Semantic Segmentation with tf.Keras

In this article I will talk about semantic segmentation with tf.keras and a possible implementation that recreates a simplified version of the U-Net neural network architecture. This implementation is only an experiment and is not made for any competition. I just wanted to explore semantic segmentation while learning more about Machine Learning, Keras and Computer Vision.

What is image segmentation?

Image segmentation is the computer vision technique of understanding what is displayed in an image at the pixel level. It is similar to image recognition, where objects are recognized (and possibly localized), but with segmentation the “recognition” happens for every single pixel. Image segmentation therefore describes the contents of an image much more granularly.

The main applications are:

  • Photo/Video editing and creativity tools
  • Traffic control systems
  • Autonomous vehicles
  • Robotics

What types of segmentation are there?

Semantic segmentation

The process of assigning a class label to each pixel. The classes can be different objects, e.g. buildings vs. cars. There can also be sub-classes of a class, e.g. vehicle -> car, truck, van etc. Nevertheless, all pixels found to belong to a car are assigned the same label.

Instance segmentation

Instance segmentation is even more advanced. It allows distinct objects within a single class to be separated. So instead of assigning the class car to two cars in an image, it labels the two cars as car1 and car2.

Deep Learning Image Segmentation

Neural Network architecture for semantic segmentation

Basic structure

The Encoder

A set of layers that extracts features of an image through a sequence of progressively narrower and deeper filters, removing spatial detail while focusing on the most salient features during the contraction.

The Decoder

A set of layers that progressively grows the output of the encoder into a segmentation mask resembling the pixel resolution of the input image.

Skip connections

Long-range connections in the neural network that allow the model to draw on features at varying spatial scales to improve model accuracy.

U-Net

The U-Net model architecture derives its name from its U shape. The encoder and decoder work together: the contracting path extracts the salient features of an image, and the expansive path uses those features to determine which label each pixel should be given. The encoder is made up of blocks that downscale the image into narrower feature layers using convolutional layers with a non-linear activation function followed by a max-pooling layer. The decoder mirrors those blocks in the opposite direction, upscaling its output back to the original image size and ultimately predicting a label for each pixel. Skip connections cut across the U to improve performance.

The Dataset: KITTI segmentation

I’m using the KITTI semantic segmentation dataset. It consists of 200 semantically annotated training images as well as 200 test images, and provides ground-truth data for 34 different labels, like road, sidewalk etc. Since not all labels are evenly distributed across the training images, which makes it harder for the model to learn, I will skip some of the provided labels and only focus on 20 meaningful classes. 200 images is quite a low number for training a CNN, so the data will be augmented with rotation, zooming and other image processing operations, as sketched below.
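A common way to augment image/mask pairs consistently in tf.keras is to run two ImageDataGenerator instances with the same transform parameters and the same random seed, one for the images and one for the label masks. The sketch below only illustrates the idea; the directory paths and parameter values are placeholders, not the exact preprocessing used for this experiment.

from tensorflow.keras.preprocessing.image import ImageDataGenerator

# Identical transform parameters for images and masks so pixels stay aligned.
aug_args = dict(rotation_range=10, zoom_range=0.1,
                width_shift_range=0.1, height_shift_range=0.1,
                horizontal_flip=True, fill_mode='nearest')
image_gen = ImageDataGenerator(**aug_args)
mask_gen = ImageDataGenerator(**aug_args)  # masks should keep nearest-neighbour interpolation so label ids are not blended

seed = 42  # same seed -> the same random transform is applied to each image/mask pair
image_iter = image_gen.flow_from_directory('data/images', class_mode=None,
                                           target_size=(256, 256), seed=seed)
mask_iter = mask_gen.flow_from_directory('data/masks', class_mode=None,
                                         target_size=(256, 256),
                                         color_mode='grayscale', seed=seed)
train_generator = zip(image_iter, mask_iter)  # yields (image_batch, mask_batch) tuples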

Furthermore, I am not interested in taking part in any competition, so I am not using the provided dev-kit for scientific comparison. Instead I tried to use tf.Keras as much as possible.

Metric

To define the success of our model we need to choose a metric. For this we will use the Intersection over Union (IoU) metric. Since we have a multi-class problem, we will use the mean IoU over all classes. Luckily for us, tf.keras already provides a tf.keras.metrics.MeanIoU implementation.
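One detail: tf.keras.metrics.MeanIoU expects integer label maps, while our model outputs per-pixel softmax probabilities. The model code below references a MaskMeanIoU metric whose implementation is not shown in this article; as an assumption, a minimal version that adapts MeanIoU to softmax outputs could look like this:

import tensorflow as tf

class MaskMeanIoU(tf.keras.metrics.MeanIoU):
    """Mean IoU that accepts per-class probabilities instead of label maps (assumed implementation)."""

    def update_state(self, y_true, y_pred, sample_weight=None):
        # Collapse the softmax output to one predicted class id per pixel.
        y_pred = tf.argmax(y_pred, axis=-1)
        return super().update_state(y_true, y_pred, sample_weight)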

The model with tf.keras

Now we need to build the model for semantic segmentation with tf.keras’s Functional API. It is basically just a sequence of convolutional layers with MaxPooling2D for the contracting path and UpSampling2D layers for the expansive path. I’ll be using 3 down and 3 up blocks, resulting in 3 skip connections. More layers would probably increase the overall performance, but they would also increase the training time. The model as is already contains 7.7 million trainable parameters.

Loss Function

Each pixel of the network’s output is compared with the corresponding pixel in the ground-truth label image. We apply the sparse_categorical_crossentropy loss to each pixel. Sparse because we are not one-hot encoding our categories but use the integer ids directly.
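To make the shapes concrete: the ground truth holds one integer class id per pixel, the network outputs a probability distribution over classes per pixel, and the loss is computed pixel-wise. A tiny, purely illustrative example:

import tensorflow as tf

# Toy 2x2 "image" with 3 classes.
y_true = tf.constant([[[0, 1], [2, 1]]])                           # shape (1, 2, 2): integer class ids
y_pred = tf.nn.softmax(tf.random.uniform((1, 2, 2, 3)), axis=-1)   # shape (1, 2, 2, 3): per-class probabilities
loss = tf.keras.losses.sparse_categorical_crossentropy(y_true, y_pred)
print(loss.shape)  # (1, 2, 2) -> one loss value per pixel, averaged during training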

import tensorflow as tf
from tensorflow.keras.layers import (Input, Conv2D, MaxPooling2D,
                                     UpSampling2D, Dropout, concatenate)
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam


def get_conv_layer(parent, filters):
    # Two 3x3 convolutions, as in the original U-Net encoder blocks.
    conv = Conv2D(filters, 3, activation='relu', padding='same', kernel_initializer='he_normal')(parent)
    conv = Conv2D(filters, 3, activation='relu', padding='same', kernel_initializer='he_normal')(conv)
    return conv


def get_up_conv_layer(parent, connection, filters):
    # Upsample, then merge with the skip connection from the contracting path.
    up = Conv2D(filters, 2, activation='relu', padding='same', kernel_initializer='he_normal')(UpSampling2D(size=(2, 2))(parent))
    merge = concatenate([connection, up], axis=3)
    conv = Conv2D(filters, 3, activation='relu', padding='same', kernel_initializer='he_normal')(merge)
    conv = Conv2D(filters, 3, activation='relu', padding='same', kernel_initializer='he_normal')(conv)
    return conv


def unet(number_output_channels, pretrained_weights=None, input_size=(256, 256, 3)):
    inputs = Input(input_size)
    # Contracting path: 3 down blocks.
    conv1 = get_conv_layer(inputs, 64)
    pool1 = MaxPooling2D(pool_size=(2, 2))(conv1)
    conv2 = get_conv_layer(pool1, 128)
    pool2 = MaxPooling2D(pool_size=(2, 2))(conv2)
    conv3 = get_conv_layer(pool2, 256)
    drop3 = Dropout(0.5)(conv3)
    pool3 = MaxPooling2D(pool_size=(2, 2))(drop3)
    # Bottleneck.
    conv4 = get_conv_layer(pool3, 512)
    drop4 = Dropout(0.5)(conv4)
    # Expansive path: 3 up blocks with skip connections.
    up3 = get_up_conv_layer(parent=drop4, connection=drop3, filters=256)
    up2 = get_up_conv_layer(parent=up3, connection=conv2, filters=128)
    up1 = get_up_conv_layer(parent=up2, connection=conv1, filters=64)
    # 1x1 convolution with softmax produces a per-pixel class distribution.
    conv_last = Conv2D(number_output_channels, 1, activation='softmax')(up1)
    model = Model(inputs=inputs, outputs=conv_last)
    # MaskMeanIoU is the MeanIoU wrapper discussed in the Metric section.
    model.compile(optimizer=Adam(learning_rate=1e-3),
                  loss='sparse_categorical_crossentropy',
                  metrics=['accuracy', MaskMeanIoU(name='iou', num_classes=number_output_channels)])

    if pretrained_weights:
        model.load_weights(pretrained_weights)
    return model
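Building and training the model could then look roughly like this; the step and epoch counts are only an illustration, reusing the augmented train_generator sketched earlier:

model = unet(number_output_channels=20)     # 20 meaningful KITTI classes
model.summary()                             # ~7.7 million trainable parameters with this configuration
history = model.fit(train_generator, steps_per_epoch=200, epochs=10)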

Results

After training the model on Google Colab with a GPU instance for 10 epochs, which took about 4 hours, it achieved an overall mean IoU of ~65% and a validation IoU of ~25%. This is of course open for improvement, but I was aware that reducing the complexity of the model would lead to suboptimal results.

Now let’s look at the visual results, which in my humble opinion look quite good.

Note: this is a test image that was never used during training/validation.

Look how well it segmented the road and the sidewalk; the cars and the vegetation also look proper.

Now I’ll take a look at how well the model generalizes by using a random image from the internet. Considering the amount of training and the simplicity of the model, I think it’s still really awesome!

Segmentation of a random image not affiliated with the dataset

Conclusion

In this article, I’ve familiarized myself with semantic segmentation with tf.keras. The trained neural network does what I wanted it to do, but there are still many issues, most notably the bad mean IoU values, which will need some more investigation and tweaking. For optimization, we could increase the number of contraction/expansion blocks. What would also be interesting, and would reduce training time, is to use a pretrained image recognition model like VGG16 or MobileNet for the contracting path and let the model “only” learn the classification task.
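A minimal sketch of that last idea, assuming a frozen, ImageNet-pretrained VGG16 as the contracting path and reusing the get_up_conv_layer blocks from above as the decoder. The layer names come from tf.keras.applications.VGG16; the rest of the wiring is just an illustration, not something trained for this article.

from tensorflow.keras.applications import VGG16
from tensorflow.keras.layers import Conv2D
from tensorflow.keras.models import Model

def vgg16_unet(number_output_channels, input_size=(256, 256, 3)):
    # Frozen VGG16 convolutional base acts as the contracting path.
    vgg = VGG16(include_top=False, weights='imagenet', input_shape=input_size)
    vgg.trainable = False

    # Intermediate feature maps reused as skip connections.
    skips = [vgg.get_layer(name).output
             for name in ('block1_conv2', 'block2_conv2', 'block3_conv3')]
    x = vgg.get_layer('block4_conv3').output  # 32x32 bottleneck for a 256x256 input

    # Decoder: same up blocks as before, from coarse to fine.
    for connection, filters in zip(reversed(skips), (256, 128, 64)):
        x = get_up_conv_layer(parent=x, connection=connection, filters=filters)

    outputs = Conv2D(number_output_channels, 1, activation='softmax')(x)
    return Model(inputs=vgg.input, outputs=outputs)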

Investigating other models for semantic segmentation with tf.keras, like pyramid structures, Mask R-CNN and DeepLab, would also be very interesting. You can check out the full Python notebook on my GitHub.