This notebook presents a straightforward PyTorch implementation of a Fully Convolutional Network for semantic segmentation of aerial images. More specifically, we aim to automatically perform scene interpretation of images taken from a plane or a satellite by classifying every pixel into several land cover classes. Here, we adapt the baseline implementation of the SegNet model presented in "Beyond RGB: Very High Resolution Urban Remote Sensing With Multimodal Deep Networks", Nicolas Audebert, Bertrand Le Saux and Sébastien Lefèvre, ISPRS Journal of Photogrammetry and Remote Sensing, 2018.
The Vaihingen dataset <https://www2.isprs.org/commissions/comm2/wg4/benchmark/2d-sem-label-vaihingen/> is a well-known benchmark for urban semantic segmentation, originating from the ISPRS 2D Semantic Labeling Contest - Vaihingen. The dataset is available on the ISPRS Challenge website.
Dataset format:
* images are 3-channel geotiffs (RGB for Potsdam, IRRG for Vaihingen; both are treated here as generic 3-band inputs)
* masks are 3-channel geotiffs whose unique RGB values encode the classes
Dataset classes (in the order used by the code below):
0. Impervious surfaces (roads)
1. Building
2. Low vegetation
3. Tree
4. Car
5. Clutter/background
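The loaders below assume the archives were extracted under ./data/ following the ISPRS naming scheme (the paths mirror the DATA_FOLDER / LABEL_FOLDER templates defined further down); a quick sanity check of the expected layout, as a sketch, is:
# Quick check that the expected Vaihingen layout is in place (paths mirror the templates defined below)
from glob import glob
print(sorted(glob('./data/Vaihingen/top/top_mosaic_09cm_area*.tif'))[:3])
print(sorted(glob('./data/Vaihingen/gts_for_participants/top_mosaic_09cm_area*.tif'))[:3])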
# imports and stuff
import numpy as np
from skimage import io
from glob import glob
from tqdm import tqdm_notebook as tqdm
from sklearn.metrics import confusion_matrix
import random
import itertools
import os
# Matplotlib
import matplotlib.pyplot as plt
%matplotlib inline
# Torch imports
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data
import torch.optim as optim
import torch.optim.lr_scheduler
import torch.nn.init
from torch.autograd import Variable
import sys
FOLDER = './data/'
sys.path.append(FOLDER)
# Parameters
WINDOW_SIZE = (256, 256) # Patch size
STRIDE = 32 # Stride for testing
IN_CHANNELS = 3 # Number of input channels (e.g. RGB)
##FOLDER = "./ISPRS_dataset/" # Replace with your "/path/to/the/ISPRS/dataset/folder/"
BATCH_SIZE = 5 # Number of samples in a mini-batch
LABELS = ["roads", "buildings", "low veg.", "trees", "cars", "clutter"] # Label names
N_CLASSES = len(LABELS) # Number of classes
WEIGHTS = torch.ones(N_CLASSES) # Weights for class balancing
CACHE = True # Store the dataset in-memory
DATASET = 'Vaihingen'
if DATASET == 'Potsdam':
MAIN_FOLDER = FOLDER + 'Potsdam/'
# Uncomment the next line for IRRG data
# DATA_FOLDER = MAIN_FOLDER + '3_Ortho_IRRG/top_potsdam_{}_IRRG.tif'
# For RGB data
DATA_FOLDER = MAIN_FOLDER + '2_Ortho_RGB/top_potsdam_{}_RGB.tif'
LABEL_FOLDER = MAIN_FOLDER + '5_Labels_for_participants/top_potsdam_{}_label.tif'
ERODED_FOLDER = MAIN_FOLDER + '5_Labels_for_participants_no_Boundary/top_potsdam_{}_label_noBoundary.tif'
elif DATASET == 'Vaihingen':
MAIN_FOLDER = FOLDER + 'Vaihingen/'
DATA_FOLDER = MAIN_FOLDER + 'top/top_mosaic_09cm_area{}.tif'
LABEL_FOLDER = MAIN_FOLDER + 'gts_for_participants/top_mosaic_09cm_area{}.tif'
ERODED_FOLDER = MAIN_FOLDER + 'gts_eroded_for_participants/top_mosaic_09cm_area{}_noBoundary.tif'
# Let's define the standard ISPRS color palette
palette = {0 : (255, 255, 255), # Impervious surfaces (white)
1 : (0, 0, 255), # Buildings (blue)
2 : (0, 255, 255), # Low vegetation (cyan)
3 : (0, 255, 0), # Trees (green)
4 : (255, 255, 0), # Cars (yellow)
5 : (255, 0, 0), # Clutter (red)
6 : (0, 0, 0)} # Undefined (black)
invert_palette = {v: k for k, v in palette.items()}
def convert_to_color(arr_2d, palette=palette):
""" Numeric labels to RGB-color encoding """
arr_3d = np.zeros((arr_2d.shape[0], arr_2d.shape[1], 3), dtype=np.uint8)
for c, i in palette.items():
m = arr_2d == c
arr_3d[m] = i
return arr_3d
def convert_from_color(arr_3d, palette=invert_palette):
""" RGB-color encoding to grayscale labels """
arr_2d = np.zeros((arr_3d.shape[0], arr_3d.shape[1]), dtype=np.uint8)
for c, i in palette.items():
m = np.all(arr_3d == np.array(c).reshape(1, 1, 3), axis=2)
arr_2d[m] = i
return arr_2d
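WEIGHTS above is left uniform. To counter the strong class imbalance (cars, for instance, are rare), one could derive inverse-frequency weights from the ground-truth tiles with convert_from_color; the helper below and the tile ids passed to it are an illustrative sketch, not part of the original pipeline:
# Optional sketch: inverse-frequency class weights (the notebook itself keeps WEIGHTS uniform)
def compute_class_weights(tile_ids, n_classes=N_CLASSES):
    counts = np.zeros(n_classes)
    for id_ in tile_ids:
        gt = convert_from_color(io.imread(LABEL_FOLDER.format(id_)))
        counts += np.bincount(gt.ravel(), minlength=n_classes)[:n_classes]
    freqs = counts / counts.sum()
    weights = 1.0 / (freqs + 1e-6)  # rare classes (e.g. cars) get larger weights
    # Normalize so the weights sum to n_classes, like torch.ones(N_CLASSES)
    return torch.from_numpy(weights / weights.sum() * n_classes).float()

# WEIGHTS = compute_class_weights(['1', '3'])  # uncomment to enable class balancing (example tile ids)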
# We load one tile from the dataset and we display it
print(MAIN_FOLDER+'top/top_mosaic_09cm_area1.tif')
img = io.imread(MAIN_FOLDER+'top/top_mosaic_09cm_area1.tif')
fig = plt.figure()
fig.add_subplot(121)
print(img.shape)
plt.imshow(img)  # the Vaihingen 'top' tiles are IRRG, so this displays as a false-colour composite
# We load the ground truth
gt = io.imread(MAIN_FOLDER+'gts_for_participants/top_mosaic_09cm_area1.tif')
fig.add_subplot(122)
plt.imshow(gt)
plt.show()
# We also check that we can convert the ground truth into an array format
array_gt = convert_from_color(gt)
print("Ground truth in numerical format has shape ({},{}) : \n".format(*array_gt.shape[:2]), array_gt)
./data/Vaihingen/top/top_mosaic_09cm_area1.tif
(2569, 1919, 3)
Ground truth in numerical format has shape (2569,1919) :
[[0 0 0 ... 1 1 1]
 [0 0 0 ... 1 1 1]
 [0 0 0 ... 1 1 1]
 ...
 [1 1 1 ... 4 4 4]
 [1 1 1 ... 4 4 4]
 [1 1 1 ... 0 0 0]]
# Utils
def get_random_pos(img, window_shape):
""" Extract of 2D random patch of shape window_shape in the image """
w, h = window_shape
W, H = img.shape[-2:]
x1 = random.randint(0, W - w - 1)
x2 = x1 + w
y1 = random.randint(0, H - h - 1)
y2 = y1 + h
return x1, x2, y1, y2
def CrossEntropy2d(input, target, weight=None, size_average=True):
""" 2D version of the cross entropy loss """
dim = input.dim()
if dim == 2:
return F.cross_entropy(input, target, weight, size_average)
## return nn.CrossEntropyLoss(output, target,weight)
elif dim == 4:
output = input.view(input.size(0),input.size(1), -1)
output = torch.transpose(output,1,2).contiguous()
output = output.view(-1,output.size(2))
target = target.view(-1)
return F.cross_entropy(output, target,weight, size_average)
## return nn.CrossEntropyLoss(output, target,weight)
else:
raise ValueError('Expected 2 or 4 dimensions (got {})'.format(dim))
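On recent PyTorch versions, F.cross_entropy accepts (N, C, H, W) logits with (N, H, W) integer targets directly, so the reshaping above is mostly there for backward compatibility; a quick check on random tensors:
# The helper and the native call should agree on 4D inputs (recent PyTorch)
logits = torch.randn(2, N_CLASSES, 8, 8)
labels = torch.randint(0, N_CLASSES, (2, 8, 8))
print(CrossEntropy2d(logits, labels), F.cross_entropy(logits, labels))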
def accuracy(input, target):
return 100 * float(np.count_nonzero(input == target)) / target.size
def sliding_window(top, step=10, window_size=(20,20)):
""" Slide a window_shape window across the image with a stride of step """
for x in range(0, top.shape[0], step):
if x + window_size[0] > top.shape[0]:
x = top.shape[0] - window_size[0]
for y in range(0, top.shape[1], step):
if y + window_size[1] > top.shape[1]:
y = top.shape[1] - window_size[1]
yield x, y, window_size[0], window_size[1]
def count_sliding_window(top, step=10, window_size=(20,20)):
""" Count the number of windows in an image """
c = 0
for x in range(0, top.shape[0], step):
if x + window_size[0] > top.shape[0]:
x = top.shape[0] - window_size[0]
for y in range(0, top.shape[1], step):
if y + window_size[1] > top.shape[1]:
y = top.shape[1] - window_size[1]
c += 1
return c
def grouper(n, iterable):
""" Browse an iterator by chunk of n elements """
it = iter(iterable)
while True:
chunk = tuple(itertools.islice(it, n))
if not chunk:
return
yield chunk
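As a toy illustration of how these two helpers cooperate at test time, here is the batched traversal of a small fake image:
# Toy example: group the sliding-window coordinates into mini-batches of 4
fake_image = np.zeros((100, 120, 3))
for batch in grouper(4, sliding_window(fake_image, step=50, window_size=(40, 40))):
    print(batch)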
def metrics(predictions, gts, label_values=LABELS):
    cm = confusion_matrix(
        gts,
        predictions,
        labels=range(len(label_values)))
print("Confusion matrix :")
print(cm)
print("---")
# Compute global accuracy
total = sum(sum(cm))
accuracy = sum([cm[x][x] for x in range(len(cm))])
accuracy *= 100 / float(total)
print("{} pixels processed".format(total))
print("Total accuracy : {}%".format(accuracy))
print("---")
# Compute F1 score
F1Score = np.zeros(len(label_values))
    for i in range(len(label_values)):
        denom = np.sum(cm[i, :]) + np.sum(cm[:, i])
        # Leave the score at 0 when class i does not appear in the test set
        if denom > 0:
            F1Score[i] = 2. * cm[i, i] / denom
print("F1Score :")
for l_id, score in enumerate(F1Score):
print("{}: {}".format(label_values[l_id], score))
print("---")
# Compute kappa coefficient
total = np.sum(cm)
pa = np.trace(cm) / float(total)
pe = np.sum(np.sum(cm, axis=0) * np.sum(cm, axis=1)) / float(total*total)
    kappa = (pa - pe) / (1 - pe)
print("Kappa: " + str(kappa))
return cm, accuracy
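The confusion matrix also yields per-class intersection over union, a metric commonly reported for this benchmark; the helper below is an optional addition, not used elsewhere in the notebook:
def iou_from_cm(cm, label_values=LABELS):
    """ Optional helper: per-class IoU computed from a confusion matrix """
    ious = {}
    for i, name in enumerate(label_values):
        tp = cm[i, i]
        denom = np.sum(cm[i, :]) + np.sum(cm[:, i]) - tp
        ious[name] = tp / denom if denom > 0 else 0.0
    return ious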
# Dataset class
class ISPRS_dataset(torch.utils.data.Dataset):
def __init__(self, ids, data_files=DATA_FOLDER, label_files=LABEL_FOLDER,
cache=False, augmentation=True):
super(ISPRS_dataset, self).__init__()
self.augmentation = augmentation
self.cache = cache
        # List of files
        self.data_files = [data_files.format(id) for id in ids]
        self.label_files = [label_files.format(id) for id in ids]
# Sanity check : raise an error if some files do not exist
for f in self.data_files + self.label_files:
if not os.path.isfile(f):
raise KeyError('{} is not a file !'.format(f))
# Initialize cache dicts
self.data_cache_ = {}
self.label_cache_ = {}
def __len__(self):
# Default epoch size is 10 000 samples
return 10000
@classmethod
def data_augmentation(cls, *arrays, flip=True, mirror=True):
will_flip, will_mirror = False, False
if flip and random.random() < 0.5:
will_flip = True
if mirror and random.random() < 0.5:
will_mirror = True
results = []
for array in arrays:
if will_flip:
if len(array.shape) == 2:
array = array[::-1, :]
else:
array = array[:, ::-1, :]
if will_mirror:
if len(array.shape) == 2:
array = array[:, ::-1]
else:
array = array[:, :, ::-1]
results.append(np.copy(array))
return tuple(results)
def __getitem__(self, i):
# Pick a random image
random_idx = random.randint(0, len(self.data_files) - 1)
# If the tile hasn't been loaded yet, put in cache
if random_idx in self.data_cache_.keys():
data = self.data_cache_[random_idx]
else:
# Data is normalized in [0, 1]
data = 1/255 * np.asarray(io.imread(self.data_files[random_idx]).transpose((2,0,1)), dtype='float32')
if self.cache:
self.data_cache_[random_idx] = data
if random_idx in self.label_cache_.keys():
label = self.label_cache_[random_idx]
else:
# Labels are converted from RGB to their numeric values
label = np.asarray(convert_from_color(io.imread(self.label_files[random_idx])), dtype='int64')
if self.cache:
self.label_cache_[random_idx] = label
# Get a random patch
x1, x2, y1, y2 = get_random_pos(data, WINDOW_SIZE)
data_p = data[:, x1:x2,y1:y2]
label_p = label[x1:x2,y1:y2]
# Data augmentation
data_p, label_p = self.data_augmentation(data_p, label_p)
# Return the torch.Tensor values
return (torch.from_numpy(data_p),
torch.from_numpy(label_p))
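Before training, we can sanity-check the dataset class on a single tile (tile '1' is just an example id; any tile present on disk works):
# One random patch and its label from tile '1' (arbitrary example id)
check_set = ISPRS_dataset(['1'], cache=False)
patch, label_patch = check_set[0]
print(patch.shape, label_patch.shape)  # expected: torch.Size([3, 256, 256]) torch.Size([256, 256])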
class SegNet(nn.Module):
# SegNet network
@staticmethod
def weight_init(m):
if isinstance(m, nn.Linear):
            torch.nn.init.kaiming_normal_(m.weight.data)
def __init__(self, in_channels=IN_CHANNELS, out_channels=N_CLASSES):
super(SegNet, self).__init__()
self.pool = nn.MaxPool2d(2, return_indices=True)
self.unpool = nn.MaxUnpool2d(2)
self.conv1_1 = nn.Conv2d(in_channels, 64, 3, padding=1)
self.conv1_1_bn = nn.BatchNorm2d(64)
self.conv1_2 = nn.Conv2d(64, 64, 3, padding=1)
self.conv1_2_bn = nn.BatchNorm2d(64)
self.conv2_1 = nn.Conv2d(64, 128, 3, padding=1)
self.conv2_1_bn = nn.BatchNorm2d(128)
self.conv2_2 = nn.Conv2d(128, 128, 3, padding=1)
self.conv2_2_bn = nn.BatchNorm2d(128)
self.conv3_1 = nn.Conv2d(128, 256, 3, padding=1)
self.conv3_1_bn = nn.BatchNorm2d(256)
self.conv3_2 = nn.Conv2d(256, 256, 3, padding=1)
self.conv3_2_bn = nn.BatchNorm2d(256)
self.conv3_3 = nn.Conv2d(256, 256, 3, padding=1)
self.conv3_3_bn = nn.BatchNorm2d(256)
self.conv4_1 = nn.Conv2d(256, 512, 3, padding=1)
self.conv4_1_bn = nn.BatchNorm2d(512)
self.conv4_2 = nn.Conv2d(512, 512, 3, padding=1)
self.conv4_2_bn = nn.BatchNorm2d(512)
self.conv4_3 = nn.Conv2d(512, 512, 3, padding=1)
self.conv4_3_bn = nn.BatchNorm2d(512)
self.conv5_1 = nn.Conv2d(512, 512, 3, padding=1)
self.conv5_1_bn = nn.BatchNorm2d(512)
self.conv5_2 = nn.Conv2d(512, 512, 3, padding=1)
self.conv5_2_bn = nn.BatchNorm2d(512)
self.conv5_3 = nn.Conv2d(512, 512, 3, padding=1)
self.conv5_3_bn = nn.BatchNorm2d(512)
self.conv5_3_D = nn.Conv2d(512, 512, 3, padding=1)
self.conv5_3_D_bn = nn.BatchNorm2d(512)
self.conv5_2_D = nn.Conv2d(512, 512, 3, padding=1)
self.conv5_2_D_bn = nn.BatchNorm2d(512)
self.conv5_1_D = nn.Conv2d(512, 512, 3, padding=1)
self.conv5_1_D_bn = nn.BatchNorm2d(512)
self.conv4_3_D = nn.Conv2d(512, 512, 3, padding=1)
self.conv4_3_D_bn = nn.BatchNorm2d(512)
self.conv4_2_D = nn.Conv2d(512, 512, 3, padding=1)
self.conv4_2_D_bn = nn.BatchNorm2d(512)
self.conv4_1_D = nn.Conv2d(512, 256, 3, padding=1)
self.conv4_1_D_bn = nn.BatchNorm2d(256)
self.conv3_3_D = nn.Conv2d(256, 256, 3, padding=1)
self.conv3_3_D_bn = nn.BatchNorm2d(256)
self.conv3_2_D = nn.Conv2d(256, 256, 3, padding=1)
self.conv3_2_D_bn = nn.BatchNorm2d(256)
self.conv3_1_D = nn.Conv2d(256, 128, 3, padding=1)
self.conv3_1_D_bn = nn.BatchNorm2d(128)
self.conv2_2_D = nn.Conv2d(128, 128, 3, padding=1)
self.conv2_2_D_bn = nn.BatchNorm2d(128)
self.conv2_1_D = nn.Conv2d(128, 64, 3, padding=1)
self.conv2_1_D_bn = nn.BatchNorm2d(64)
self.conv1_2_D = nn.Conv2d(64, 64, 3, padding=1)
self.conv1_2_D_bn = nn.BatchNorm2d(64)
self.conv1_1_D = nn.Conv2d(64, out_channels, 3, padding=1)
self.apply(self.weight_init)
def forward(self, x):
# Encoder block 1
x = self.conv1_1_bn(F.relu(self.conv1_1(x)))
x = self.conv1_2_bn(F.relu(self.conv1_2(x)))
x, mask1 = self.pool(x)
# Encoder block 2
x = self.conv2_1_bn(F.relu(self.conv2_1(x)))
x = self.conv2_2_bn(F.relu(self.conv2_2(x)))
x, mask2 = self.pool(x)
# Encoder block 3
x = self.conv3_1_bn(F.relu(self.conv3_1(x)))
x = self.conv3_2_bn(F.relu(self.conv3_2(x)))
x = self.conv3_3_bn(F.relu(self.conv3_3(x)))
x, mask3 = self.pool(x)
# Encoder block 4
x = self.conv4_1_bn(F.relu(self.conv4_1(x)))
x = self.conv4_2_bn(F.relu(self.conv4_2(x)))
x = self.conv4_3_bn(F.relu(self.conv4_3(x)))
x, mask4 = self.pool(x)
# Encoder block 5
x = self.conv5_1_bn(F.relu(self.conv5_1(x)))
x = self.conv5_2_bn(F.relu(self.conv5_2(x)))
x = self.conv5_3_bn(F.relu(self.conv5_3(x)))
x, mask5 = self.pool(x)
# Decoder block 5
x = self.unpool(x, mask5)
x = self.conv5_3_D_bn(F.relu(self.conv5_3_D(x)))
x = self.conv5_2_D_bn(F.relu(self.conv5_2_D(x)))
x = self.conv5_1_D_bn(F.relu(self.conv5_1_D(x)))
# Decoder block 4
x = self.unpool(x, mask4)
x = self.conv4_3_D_bn(F.relu(self.conv4_3_D(x)))
x = self.conv4_2_D_bn(F.relu(self.conv4_2_D(x)))
x = self.conv4_1_D_bn(F.relu(self.conv4_1_D(x)))
# Decoder block 3
x = self.unpool(x, mask3)
x = self.conv3_3_D_bn(F.relu(self.conv3_3_D(x)))
x = self.conv3_2_D_bn(F.relu(self.conv3_2_D(x)))
x = self.conv3_1_D_bn(F.relu(self.conv3_1_D(x)))
# Decoder block 2
x = self.unpool(x, mask2)
x = self.conv2_2_D_bn(F.relu(self.conv2_2_D(x)))
x = self.conv2_1_D_bn(F.relu(self.conv2_1_D(x)))
# Decoder block 1
x = self.unpool(x, mask1)
x = self.conv1_2_D_bn(F.relu(self.conv1_2_D(x)))
        x = F.log_softmax(self.conv1_1_D(x), dim=1)
return x
# instantiate the network
net = SegNet()
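Before loading any pretrained weights, a quick forward pass on a random patch-sized input checks that the output is a map of per-pixel log-probabilities with the expected shape:
# Shape sanity check (CPU, random input)
with torch.no_grad():
    out = net(torch.randn(1, IN_CHANNELS, *WINDOW_SIZE))
print(out.shape)  # expected: torch.Size([1, 6, 256, 256])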
import os
from urllib.request import urlretrieve

# Download VGG-16 weights from PyTorch
vgg_url = 'https://download.pytorch.org/models/vgg16_bn-6c64b313.pth'
if not os.path.isfile('./vgg16_bn-6c64b313.pth'):
    urlretrieve(vgg_url, './vgg16_bn-6c64b313.pth')
vgg16_weights = torch.load('./vgg16_bn-6c64b313.pth')
mapped_weights = {}
# The pretrained checkpoint predates BatchNorm's num_batches_tracked buffers,
# so we skip those keys on both sides to keep the pairing aligned
vgg_keys = [k for k in vgg16_weights.keys() if 'num_batches_tracked' not in k]
segnet_keys = [k for k in net.state_dict().keys() if 'num_batches_tracked' not in k]
for k_vgg, k_segnet in zip(vgg_keys, segnet_keys):
    if "features" in k_vgg:
        mapped_weights[k_segnet] = vgg16_weights[k_vgg]
        print("Mapping {} to {}".format(k_vgg, k_segnet))
try:
    # strict=False ignores the SegNet keys (decoder, BN buffers) that have no VGG-16 counterpart
    net.load_state_dict(mapped_weights, strict=False)
    print("Loaded VGG-16 weights in SegNet !")
except:
# Ignore missing keys
pass
Mapping features.0.weight to conv1_1.weight
Mapping features.0.bias to conv1_1.bias
... (mapping output truncated)
net.cuda()
SegNet( (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False) (unpool): MaxUnpool2d(kernel_size=(2, 2), stride=(2, 2), padding=(0, 0)) (conv1_1): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (conv1_1_bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (conv1_2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (conv1_2_bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (conv2_1): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (conv2_1_bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (conv2_2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (conv2_2_bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (conv3_1): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (conv3_1_bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (conv3_2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (conv3_2_bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (conv3_3): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (conv3_3_bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (conv4_1): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (conv4_1_bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (conv4_2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (conv4_2_bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (conv4_3): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (conv4_3_bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (conv5_1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (conv5_1_bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (conv5_2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (conv5_2_bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (conv5_3): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (conv5_3_bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (conv5_3_D): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (conv5_3_D_bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (conv5_2_D): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (conv5_2_D_bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (conv5_1_D): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (conv5_1_D_bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (conv4_3_D): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (conv4_3_D_bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (conv4_2_D): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (conv4_2_D_bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (conv4_1_D): Conv2d(512, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (conv4_1_D_bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (conv3_3_D): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), 
padding=(1, 1)) (conv3_3_D_bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (conv3_2_D): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (conv3_2_D_bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (conv3_1_D): Conv2d(256, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (conv3_1_D_bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (conv2_2_D): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (conv2_2_D_bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (conv2_1_D): Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (conv2_1_D_bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (conv1_2_D): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (conv1_2_D_bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (conv1_1_D): Conv2d(64, 6, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) )
import os
# Load the datasets
if DATASET == 'Potsdam':
    all_files = sorted(glob(LABEL_FOLDER.replace('{}', '*')))
    # e.g. 'top_potsdam_2_10_label.tif' -> '2_10'
    all_ids = ["_".join(os.path.basename(f).split('_')[2:4]) for f in all_files]
elif DATASET == 'Vaihingen':
all_files = sorted(glob(LABEL_FOLDER.replace('{}', '*')))
all_ids = [f.split('area')[-1].split('.')[0] for f in all_files]
# Random tile numbers for train/test split
train_ids = random.sample(all_ids, 2 * len(all_ids) // 3 + 1)
test_ids = list(set(all_ids) - set(train_ids))
# Example of a train/val split on Vaihingen:
##train_ids = ['1', '3', '23', '26', '7', '11', '13', '28', '17', '32', '34', '37']
##test_ids = ['5', '21', '15', '30']
train_ids = ['1', '23']
test_ids = ['5']
print("Tiles for training : ", train_ids)
print("Tiles for testing : ", test_ids)
train_set = ISPRS_dataset(train_ids, cache=CACHE)
train_loader = torch.utils.data.DataLoader(train_set,batch_size=BATCH_SIZE)
Tiles for training :  ['1', '23']
Tiles for testing :  ['5']
base_lr = 0.01
params_dict = dict(net.named_parameters())
params = []
for key, value in params_dict.items():
if '_D' in key:
# Decoder weights are trained at the nominal learning rate
params += [{'params':[value],'lr': base_lr}]
else:
# Encoder weights are trained at lr / 2 (we have VGG-16 weights as initialization)
params += [{'params':[value],'lr': base_lr / 2}]
# Use the per-parameter-group learning rates defined above
optimizer = optim.SGD(params, lr=base_lr, momentum=0.9, weight_decay=0.0005)
# We define the scheduler
scheduler = optim.lr_scheduler.MultiStepLR(optimizer, [25, 35, 45], gamma=0.1)
def test(net, test_ids, all=False, stride=WINDOW_SIZE[0], batch_size=BATCH_SIZE, window_size=WINDOW_SIZE):
# Use the network on the test set
test_images = (1 / 255 * np.asarray(io.imread(DATA_FOLDER.format(id)), dtype='float32') for id in test_ids)
test_labels = (np.asarray(io.imread(LABEL_FOLDER.format(id)), dtype='uint8') for id in test_ids)
eroded_labels = (convert_from_color(io.imread(ERODED_FOLDER.format(id))) for id in test_ids)
all_preds = []
all_gts = []
# Switch the network to inference mode
net.eval()
for img, gt, gt_e in tqdm(zip(test_images, test_labels, eroded_labels), total=len(test_ids), leave=False):
pred = np.zeros(img.shape[:2] + (N_CLASSES,))
total = count_sliding_window(img, step=stride, window_size=window_size) // batch_size
for i, coords in enumerate(tqdm(grouper(batch_size, sliding_window(img, step=stride, window_size=window_size)), total=total, leave=False)):
# Display in progress results
if i > 0 and total > 10 and i % int(10 * total / 100) == 0:
_pred = np.argmax(pred, axis=-1)
fig = plt.figure()
fig.add_subplot(1,3,1)
plt.imshow(np.asarray(255 * img, dtype='uint8'))
fig.add_subplot(1,3,2)
plt.imshow(convert_to_color(_pred))
fig.add_subplot(1,3,3)
plt.imshow(gt)
clear_output()
plt.show()
# Build the tensor
image_patches = [np.copy(img[x:x+w, y:y+h]).transpose((2,0,1)) for x,y,w,h in coords]
image_patches = np.asarray(image_patches)
            image_patches = torch.from_numpy(image_patches).cuda()
            # Do the inference (no gradients needed at test time)
            with torch.no_grad():
                outs = net(image_patches)
outs = outs.data.cpu().numpy()
# Fill in the results array
for out, (x, y, w, h) in zip(outs, coords):
out = out.transpose((1,2,0))
pred[x:x+w, y:y+h] += out
del(outs)
pred = np.argmax(pred, axis=-1)
# Display the result
clear_output()
fig = plt.figure()
fig.add_subplot(1,3,1)
plt.imshow(np.asarray(255 * img, dtype='uint8'))
fig.add_subplot(1,3,2)
plt.imshow(convert_to_color(pred))
fig.add_subplot(1,3,3)
plt.imshow(gt)
plt.show()
all_preds.append(pred)
all_gts.append(gt_e)
clear_output()
# Compute some metrics
metrics(pred.ravel(), gt_e.ravel())
cm, accuracy = metrics(np.concatenate([p.ravel() for p in all_preds]), np.concatenate([p.ravel() for p in all_gts]).ravel())
if all:
return cm, accuracy, all_preds, all_gts
else:
return accuracy
from IPython.display import clear_output
def train(net, optimizer, epochs, scheduler=None, weights=WEIGHTS, save_epoch = 5):
losses = np.zeros(1000000)
mean_losses = np.zeros(100000000)
weights = weights.cuda()
    criterion = nn.NLLLoss(weight=weights)  # defined for reference; the loop below uses CrossEntropy2d
iter_ = 0
for e in range(1, epochs + 1):
if scheduler is not None:
scheduler.step()
net.train()
for batch_idx, (data, target) in enumerate(train_loader):
data, target = Variable(data.cuda()), Variable(target.cuda())
optimizer.zero_grad()
output = net(data)
loss = CrossEntropy2d(output, target, weight=weights)
## loss = F.cross_entropy(output, target, weight=weights)
loss.backward()
optimizer.step()
losses[iter_] = loss.item() ##loss.data[0]
mean_losses[iter_] = np.mean(losses[max(0,iter_-100):iter_])
if iter_ % 100 == 0:
clear_output()
rgb = np.asarray(255 * np.transpose(data.data.cpu().numpy()[0],(1,2,0)), dtype='uint8')
pred = np.argmax(output.data.cpu().numpy()[0], axis=0)
gt = target.data.cpu().numpy()[0]
print('Train (epoch {}/{}) [{}/{} ({:.0f}%)]\tLoss: {:.6f}\tAccuracy: {}'.format(
e, epochs, batch_idx, len(train_loader),
100. * batch_idx / len(train_loader), loss.item(), accuracy(pred, gt))) ##loss.data[0]
plt.plot(mean_losses[:iter_]) and plt.show()
fig = plt.figure()
fig.add_subplot(131)
plt.imshow(rgb)
plt.title('RGB')
fig.add_subplot(132)
plt.imshow(convert_to_color(gt))
plt.title('Ground truth')
fig.add_subplot(133)
plt.title('Prediction')
plt.imshow(convert_to_color(pred))
plt.show()
iter_ += 1
del(data, target, loss)
#if e % save_epoch == 0:
# We validate with the largest possible stride for faster computing
acc = test(net, test_ids, all=False, stride=min(WINDOW_SIZE))
#torch.save(net.state_dict(), './segnet256_epoch{}_{}'.format(e, acc))
torch.save(net.state_dict(), './segnet_final')
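The weights are reloaded from './segnet_final' just below, so a training run has to produce that file first; a call of the following form does it (the epoch count here is an arbitrary example, not the setting behind the reported results):
# Train SegNet; the last line of train() saves the weights to './segnet_final'
train(net, optimizer, 10, scheduler)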
net.load_state_dict(torch.load('./segnet_final'))
<All keys matched successfully>
cm, _, all_preds, all_gts = test(net, test_ids, all=True, stride=32)
Confusion matrix :
[[1264252  193273   47630   10246     392       4]
 [ 401724 1871696   19715    5392    4023       0]
 [  19725    3023  155638   85423       0       0]
 [   4455    1955   33644  252572       0      10]
 [  34059    7231    1919     783     925      29]
 [      0       0       0       0       0       0]]
---
4419738 pixels processed
Total accuracy : 80.21025228192259%
---
F1Score :
roads: 0.7803995787669922
buildings: 0.8547087855684189
low veg.: 0.5959089125211782
trees: 0.7806853235906851
cars: 0.03678956369566082
clutter: 0.0
---
Kappa: 0.6769271838526592
# from sklearn.metrics import multilabel_confusion_matrix, accuracy_score
# import seaborn as sns
# import pandas as pd
# def plot_cm(cm, LABELS):
# df_cm = pd.DataFrame(cm, index = LABELS,
# columns = LABELS)
# plt.figure(figsize = (15, 12))
# sns.heatmap(df_cm, annot=True)
# plot_cm(cm, LABELS)
for p, id_ in zip(all_preds, test_ids):
img = convert_to_color(p)
plt.imshow(img) and plt.show()
io.imsave('./inference_tile_{}.png'.format(id_), img)
UNet
For comparison, we now train a UNet with a pretrained ResNet-18 encoder on the same data.
import torch
import torch.nn as nn
from torchvision import models
def convrelu(in_channels, out_channels, kernel, padding):
return nn.Sequential(
nn.Conv2d(in_channels, out_channels, kernel, padding=padding),
nn.ReLU(inplace=True),
)
class UNet(nn.Module):
"""
UNET with ResNet18 as backbone.
"""
def __init__(self, n_class):
super().__init__()
self.base_model = models.resnet18(pretrained=True)
self.base_layers = list(self.base_model.children())
self.layer0 = nn.Sequential(*self.base_layers[:3]) # size=(N, 64, x.H/2, x.W/2)
self.layer0_1x1 = convrelu(64, 64, 1, 0)
self.layer1 = nn.Sequential(*self.base_layers[3:5]) # size=(N, 64, x.H/4, x.W/4)
self.layer1_1x1 = convrelu(64, 64, 1, 0)
self.layer2 = self.base_layers[5] # size=(N, 128, x.H/8, x.W/8)
self.layer2_1x1 = convrelu(128, 128, 1, 0)
self.layer3 = self.base_layers[6] # size=(N, 256, x.H/16, x.W/16)
self.layer3_1x1 = convrelu(256, 256, 1, 0)
self.layer4 = self.base_layers[7] # size=(N, 512, x.H/32, x.W/32)
self.layer4_1x1 = convrelu(512, 512, 1, 0)
self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
self.conv_up3 = convrelu(256 + 512, 512, 3, 1)
self.conv_up2 = convrelu(128 + 512, 256, 3, 1)
self.conv_up1 = convrelu(64 + 256, 256, 3, 1)
self.conv_up0 = convrelu(64 + 256, 128, 3, 1)
self.conv_original_size0 = convrelu(3, 64, 3, 1)
self.conv_original_size1 = convrelu(64, 64, 3, 1)
self.conv_original_size2 = convrelu(64 + 128, 64, 3, 1)
self.conv_last = nn.Conv2d(64, n_class, 1)
def forward(self, input):
x_original = self.conv_original_size0(input)
x_original = self.conv_original_size1(x_original)
layer0 = self.layer0(input)
layer1 = self.layer1(layer0)
layer2 = self.layer2(layer1)
layer3 = self.layer3(layer2)
layer4 = self.layer4(layer3)
layer4 = self.layer4_1x1(layer4)
x = self.upsample(layer4)
layer3 = self.layer3_1x1(layer3)
x = torch.cat([x, layer3], dim=1)
x = self.conv_up3(x)
x = self.upsample(x)
layer2 = self.layer2_1x1(layer2)
x = torch.cat([x, layer2], dim=1)
x = self.conv_up2(x)
x = self.upsample(x)
layer1 = self.layer1_1x1(layer1)
x = torch.cat([x, layer1], dim=1)
x = self.conv_up1(x)
x = self.upsample(x)
layer0 = self.layer0_1x1(layer0)
x = torch.cat([x, layer0], dim=1)
x = self.conv_up0(x)
x = self.upsample(x)
x = torch.cat([x, x_original], dim=1)
x = self.conv_original_size2(x)
out = self.conv_last(x)
return out
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = UNet(n_class=6)
model = model.to(device)
# check keras-like model summary using torchsummary
from torchsummary import summary
summary(model, input_size=(3, 256, 256) )
---------------------------------------------------------------- Layer (type) Output Shape Param # ================================================================ Conv2d-1 [-1, 64, 256, 256] 1,792 ReLU-2 [-1, 64, 256, 256] 0 Conv2d-3 [-1, 64, 256, 256] 36,928 ReLU-4 [-1, 64, 256, 256] 0 Conv2d-5 [-1, 64, 128, 128] 9,408 Conv2d-6 [-1, 64, 128, 128] 9,408 BatchNorm2d-7 [-1, 64, 128, 128] 128 BatchNorm2d-8 [-1, 64, 128, 128] 128 ReLU-9 [-1, 64, 128, 128] 0 ReLU-10 [-1, 64, 128, 128] 0 MaxPool2d-11 [-1, 64, 64, 64] 0 MaxPool2d-12 [-1, 64, 64, 64] 0 Conv2d-13 [-1, 64, 64, 64] 36,864 Conv2d-14 [-1, 64, 64, 64] 36,864 BatchNorm2d-15 [-1, 64, 64, 64] 128 BatchNorm2d-16 [-1, 64, 64, 64] 128 ReLU-17 [-1, 64, 64, 64] 0 ReLU-18 [-1, 64, 64, 64] 0 Conv2d-19 [-1, 64, 64, 64] 36,864 Conv2d-20 [-1, 64, 64, 64] 36,864 BatchNorm2d-21 [-1, 64, 64, 64] 128 BatchNorm2d-22 [-1, 64, 64, 64] 128 ReLU-23 [-1, 64, 64, 64] 0 ReLU-24 [-1, 64, 64, 64] 0 BasicBlock-25 [-1, 64, 64, 64] 0 BasicBlock-26 [-1, 64, 64, 64] 0 Conv2d-27 [-1, 64, 64, 64] 36,864 Conv2d-28 [-1, 64, 64, 64] 36,864 BatchNorm2d-29 [-1, 64, 64, 64] 128 BatchNorm2d-30 [-1, 64, 64, 64] 128 ReLU-31 [-1, 64, 64, 64] 0 ReLU-32 [-1, 64, 64, 64] 0 Conv2d-33 [-1, 64, 64, 64] 36,864 Conv2d-34 [-1, 64, 64, 64] 36,864 BatchNorm2d-35 [-1, 64, 64, 64] 128 BatchNorm2d-36 [-1, 64, 64, 64] 128 ReLU-37 [-1, 64, 64, 64] 0 ReLU-38 [-1, 64, 64, 64] 0 BasicBlock-39 [-1, 64, 64, 64] 0 BasicBlock-40 [-1, 64, 64, 64] 0 Conv2d-41 [-1, 128, 32, 32] 73,728 Conv2d-42 [-1, 128, 32, 32] 73,728 BatchNorm2d-43 [-1, 128, 32, 32] 256 BatchNorm2d-44 [-1, 128, 32, 32] 256 ReLU-45 [-1, 128, 32, 32] 0 ReLU-46 [-1, 128, 32, 32] 0 Conv2d-47 [-1, 128, 32, 32] 147,456 Conv2d-48 [-1, 128, 32, 32] 147,456 BatchNorm2d-49 [-1, 128, 32, 32] 256 BatchNorm2d-50 [-1, 128, 32, 32] 256 Conv2d-51 [-1, 128, 32, 32] 8,192 Conv2d-52 [-1, 128, 32, 32] 8,192 BatchNorm2d-53 [-1, 128, 32, 32] 256 BatchNorm2d-54 [-1, 128, 32, 32] 256 ReLU-55 [-1, 128, 32, 32] 0 ReLU-56 [-1, 128, 32, 32] 0 BasicBlock-57 [-1, 128, 32, 32] 0 BasicBlock-58 [-1, 128, 32, 32] 0 Conv2d-59 [-1, 128, 32, 32] 147,456 Conv2d-60 [-1, 128, 32, 32] 147,456 BatchNorm2d-61 [-1, 128, 32, 32] 256 BatchNorm2d-62 [-1, 128, 32, 32] 256 ReLU-63 [-1, 128, 32, 32] 0 ReLU-64 [-1, 128, 32, 32] 0 Conv2d-65 [-1, 128, 32, 32] 147,456 Conv2d-66 [-1, 128, 32, 32] 147,456 BatchNorm2d-67 [-1, 128, 32, 32] 256 BatchNorm2d-68 [-1, 128, 32, 32] 256 ReLU-69 [-1, 128, 32, 32] 0 ReLU-70 [-1, 128, 32, 32] 0 BasicBlock-71 [-1, 128, 32, 32] 0 BasicBlock-72 [-1, 128, 32, 32] 0 Conv2d-73 [-1, 256, 16, 16] 294,912 Conv2d-74 [-1, 256, 16, 16] 294,912 BatchNorm2d-75 [-1, 256, 16, 16] 512 BatchNorm2d-76 [-1, 256, 16, 16] 512 ReLU-77 [-1, 256, 16, 16] 0 ReLU-78 [-1, 256, 16, 16] 0 Conv2d-79 [-1, 256, 16, 16] 589,824 Conv2d-80 [-1, 256, 16, 16] 589,824 BatchNorm2d-81 [-1, 256, 16, 16] 512 BatchNorm2d-82 [-1, 256, 16, 16] 512 Conv2d-83 [-1, 256, 16, 16] 32,768 Conv2d-84 [-1, 256, 16, 16] 32,768 BatchNorm2d-85 [-1, 256, 16, 16] 512 BatchNorm2d-86 [-1, 256, 16, 16] 512 ReLU-87 [-1, 256, 16, 16] 0 ReLU-88 [-1, 256, 16, 16] 0 BasicBlock-89 [-1, 256, 16, 16] 0 BasicBlock-90 [-1, 256, 16, 16] 0 Conv2d-91 [-1, 256, 16, 16] 589,824 Conv2d-92 [-1, 256, 16, 16] 589,824 BatchNorm2d-93 [-1, 256, 16, 16] 512 BatchNorm2d-94 [-1, 256, 16, 16] 512 ReLU-95 [-1, 256, 16, 16] 0 ReLU-96 [-1, 256, 16, 16] 0 Conv2d-97 [-1, 256, 16, 16] 589,824 Conv2d-98 [-1, 256, 16, 16] 589,824 BatchNorm2d-99 [-1, 256, 16, 16] 512 BatchNorm2d-100 [-1, 256, 16, 16] 512 ReLU-101 [-1, 256, 16, 16] 0 ReLU-102 [-1, 
256, 16, 16] 0 BasicBlock-103 [-1, 256, 16, 16] 0 BasicBlock-104 [-1, 256, 16, 16] 0 Conv2d-105 [-1, 512, 8, 8] 1,179,648 Conv2d-106 [-1, 512, 8, 8] 1,179,648 BatchNorm2d-107 [-1, 512, 8, 8] 1,024 BatchNorm2d-108 [-1, 512, 8, 8] 1,024 ReLU-109 [-1, 512, 8, 8] 0 ReLU-110 [-1, 512, 8, 8] 0 Conv2d-111 [-1, 512, 8, 8] 2,359,296 Conv2d-112 [-1, 512, 8, 8] 2,359,296 BatchNorm2d-113 [-1, 512, 8, 8] 1,024 BatchNorm2d-114 [-1, 512, 8, 8] 1,024 Conv2d-115 [-1, 512, 8, 8] 131,072 Conv2d-116 [-1, 512, 8, 8] 131,072 BatchNorm2d-117 [-1, 512, 8, 8] 1,024 BatchNorm2d-118 [-1, 512, 8, 8] 1,024 ReLU-119 [-1, 512, 8, 8] 0 ReLU-120 [-1, 512, 8, 8] 0 BasicBlock-121 [-1, 512, 8, 8] 0 BasicBlock-122 [-1, 512, 8, 8] 0 Conv2d-123 [-1, 512, 8, 8] 2,359,296 Conv2d-124 [-1, 512, 8, 8] 2,359,296 BatchNorm2d-125 [-1, 512, 8, 8] 1,024 BatchNorm2d-126 [-1, 512, 8, 8] 1,024 ReLU-127 [-1, 512, 8, 8] 0 ReLU-128 [-1, 512, 8, 8] 0 Conv2d-129 [-1, 512, 8, 8] 2,359,296 Conv2d-130 [-1, 512, 8, 8] 2,359,296 BatchNorm2d-131 [-1, 512, 8, 8] 1,024 BatchNorm2d-132 [-1, 512, 8, 8] 1,024 ReLU-133 [-1, 512, 8, 8] 0 ReLU-134 [-1, 512, 8, 8] 0 BasicBlock-135 [-1, 512, 8, 8] 0 BasicBlock-136 [-1, 512, 8, 8] 0 Conv2d-137 [-1, 512, 8, 8] 262,656 ReLU-138 [-1, 512, 8, 8] 0 Upsample-139 [-1, 512, 16, 16] 0 Conv2d-140 [-1, 256, 16, 16] 65,792 ReLU-141 [-1, 256, 16, 16] 0 Conv2d-142 [-1, 512, 16, 16] 3,539,456 ReLU-143 [-1, 512, 16, 16] 0 Upsample-144 [-1, 512, 32, 32] 0 Conv2d-145 [-1, 128, 32, 32] 16,512 ReLU-146 [-1, 128, 32, 32] 0 Conv2d-147 [-1, 256, 32, 32] 1,474,816 ReLU-148 [-1, 256, 32, 32] 0 Upsample-149 [-1, 256, 64, 64] 0 Conv2d-150 [-1, 64, 64, 64] 4,160 ReLU-151 [-1, 64, 64, 64] 0 Conv2d-152 [-1, 256, 64, 64] 737,536 ReLU-153 [-1, 256, 64, 64] 0 Upsample-154 [-1, 256, 128, 128] 0 Conv2d-155 [-1, 64, 128, 128] 4,160 ReLU-156 [-1, 64, 128, 128] 0 Conv2d-157 [-1, 128, 128, 128] 368,768 ReLU-158 [-1, 128, 128, 128] 0 Upsample-159 [-1, 128, 256, 256] 0 Conv2d-160 [-1, 64, 256, 256] 110,656 ReLU-161 [-1, 64, 256, 256] 0 Conv2d-162 [-1, 6, 256, 256] 390 ================================================================ Total params: 28,976,646 Trainable params: 28,976,646 Non-trainable params: 0 ---------------------------------------------------------------- Input size (MB): 0.75 Forward/backward pass size (MB): 545.50 Params size (MB): 110.54 Estimated Total Size (MB): 656.79 ----------------------------------------------------------------
base_lr = 0.01
params_dict = dict(model.named_parameters())
params = []
for key, value in params_dict.items():
    if 'base_model' in key:
        # Pretrained ResNet-18 encoder weights are trained at lr / 2
        params += [{'params': [value], 'lr': base_lr / 2}]
    else:
        # Decoder and newly added layers are trained at the nominal learning rate
        params += [{'params': [value], 'lr': base_lr}]
# Use the per-parameter-group learning rates defined above
optimizer = optim.SGD(params, lr=base_lr, momentum=0.9, weight_decay=0.0005)
# We define the scheduler
scheduler = optim.lr_scheduler.MultiStepLR(optimizer, [25, 35, 45], gamma=0.1)
from IPython.display import clear_output
def train_unet(net, optimizer, epochs, scheduler=None, weights=WEIGHTS, save_epoch = 5):
losses = np.zeros(1000000)
mean_losses = np.zeros(100000000)
weights = weights.cuda()
    criterion = nn.NLLLoss(weight=weights)  # defined for reference; the loop below uses CrossEntropy2d
iter_ = 0
for e in range(1, epochs + 1):
if scheduler is not None:
scheduler.step()
net.train()
for batch_idx, (data, target) in enumerate(train_loader):
data, target = Variable(data.cuda()), Variable(target.cuda())
optimizer.zero_grad()
output = net(data)
loss = CrossEntropy2d(output, target, weight=weights)
## loss = F.cross_entropy(output, target, weight=weights)
loss.backward()
optimizer.step()
losses[iter_] = loss.item() ##loss.data[0]
mean_losses[iter_] = np.mean(losses[max(0,iter_-100):iter_])
if iter_ % 100 == 0:
clear_output()
rgb = np.asarray(255 * np.transpose(data.data.cpu().numpy()[0],(1,2,0)), dtype='uint8')
pred = np.argmax(output.data.cpu().numpy()[0], axis=0)
gt = target.data.cpu().numpy()[0]
print('Train (epoch {}/{}) [{}/{} ({:.0f}%)]\tLoss: {:.6f}\tAccuracy: {}'.format(
e, epochs, batch_idx, len(train_loader),
100. * batch_idx / len(train_loader), loss.item(), accuracy(pred, gt))) ##loss.data[0]
plt.plot(mean_losses[:iter_]) and plt.show()
fig = plt.figure()
fig.add_subplot(131)
plt.imshow(rgb)
plt.title('RGB')
fig.add_subplot(132)
plt.imshow(convert_to_color(gt))
plt.title('Ground truth')
fig.add_subplot(133)
plt.title('Prediction')
plt.imshow(convert_to_color(pred))
plt.show()
iter_ += 1
del(data, target, loss)
#if e % save_epoch == 0:
# We validate with the largest possible stride for faster computing
acc = test(net, test_ids, all=False, stride=min(WINDOW_SIZE))
#torch.save(net.state_dict(), './segnet256_epoch{}_{}'.format(e, acc))
torch.save(net.state_dict(), './unet_final')
train_unet(model, optimizer, 1, scheduler)
Confusion matrix :
[[1448895   52637    9375    2075    2815       0]
 [ 296506 2000951    4215     295     583       0]
 [  38080    9260  131561   84772     136       0]
 [  10849     231    9793  271727      36       0]
 [  22192    1014     264      64   21412       0]
 [      0       0       0       0       0       0]]
---
4419738 pixels processed
Total accuracy : 87.66460817360667%
---
F1Score :
roads: 0.8696016197728968
buildings: 0.9164710740035309
low veg.: 0.6279506559399738
trees: 0.8340697608388367
cars: 0.6124013270792815
clutter: nan
---
Kappa: 0.7971852830827353
model.load_state_dict(torch.load('./unet_final'))
<All keys matched successfully>
_, _, all_preds, all_gts = test(model, test_ids, all=True, stride=32)
Confusion matrix :
[[1461538   42973    6684    2123    2479       0]
 [ 301860 1995962    4206     361     161       0]
 [  42016    6378  135425   79950      40       0]
 [  10612      84    7808  274114      18       0]
 [  23059    1187     155      46   20499       0]
 [      0       0       0       0       0       0]]
---
4419738 pixels processed
Total accuracy : 87.9585622496175%
---
F1Score :
roads: 0.8712902570045683
buildings: 0.9178664074273177
low veg.: 0.6478316713985367
trees: 0.844428014725136
cars: 0.6016465374286427
clutter: nan
---
Kappa: 0.8021370362641846
for p, id_ in zip(all_preds, test_ids):
img = convert_to_color(p)
plt.imshow(img) and plt.show()
io.imsave('./inference_tile_{}.png'.format(id_), img)