From Transfer Learning in Action by Dipanjan Sarkar and Raghav Bali

This article discusses classifying images using transfer learning.

Take 40% off Transfer Learning in Action by entering fccsarkar into the discount code box at checkout.

Let’s dive into a hands-on example showcasing transfer learning in the context of image classification. The objective here is to take a few sample images of animals and see how some off-the-shelf, pre-trained models fare in classifying them. We will pick a couple of state-of-the-art pre-trained models of differing complexity to compare and contrast how they interpret the true category of the input images.


The key objective here is to take a pre-trained model off-the-shelf and use it directly to predict the class of an input image. We focus on inference to keep things simple, without diving into how to train or fine-tune these models. Our methodology is to take an input image, load a pre-trained model from TensorFlow Hub in Python and predict the top-5 most probable classes of the input image. This workflow is depicted in figure 1.

Figure 1. Image Classification with Pre-trained CNNs. The figure depicts the top-5 class probabilities for the given input image using a pre-trained CNN model.

Pre-trained Model Architectures

For our experiment, we will be leveraging two state-of-the-art pre-trained convolution neural network (CNN) models, namely:

  • ResNet-50: This is a residual deep convolutional neural network (CNN) with a total of 50 layers, combining the standard convolution and pooling layers of a typical CNN with batch normalization layers for regularization. The novelty of this model is its residual or skip connections. It was trained on the standard ImageNet-1K dataset, which has a total of 1000 distinct classes.
  • BiT MultiClass ResNet-152 4x: This is Google’s state-of-the-art (SOTA) family of computer vision models called Big Transfer, published in May 2020. Its flagship architecture is a pre-trained ResNet-152 model (152 layers) that is four times wider than the original. This model uses group normalization layers instead of batch normalization for regularization. The model was trained on the ImageNet-21K dataset[1], which has a total of 21843 classes.

The foundational architecture behind both models is a convolutional neural network (CNN) which works on the principle of leveraging a multi-layered hierarchical architecture of several convolution and pooling layers with non-linear activation functions.

Convolutional neural networks

Let’s look at the essential components of CNN models. Typically, a convolutional neural network, more popularly known as a CNN, consists of a layered architecture of convolution, pooling and dense layers, besides the input and output layers. A typical architecture is depicted in figure 2.

Figure 2. Architecture of a typical Convolutional Neural Network. This usually includes a stacked hierarchy of convolution and pooling layers.

The CNN model leverages convolution and pooling layers to automatically extract different hierarchies of features, from very generic features like edges and corners to very specific features like the facial structure, whiskers and ears of the tiger depicted as an input image in figure 3. The feature maps are usually flattened using a flatten or global pooling operator to obtain a 1-dimensional feature vector. This vector is then sent as an input through a few fully-connected dense layers, and the output class is finally predicted using a softmax output layer.
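The conv/pool/flatten/dense/softmax workflow just described can be sketched in a few lines of tf.keras. The layer counts and sizes below are purely illustrative, not those of any pre-trained model discussed here:

```python
import tensorflow as tf

# Illustrative CNN: a conv/pool feature extractor, a flatten step,
# then dense layers ending in a softmax output (10 classes here).
model = tf.keras.Sequential([
    tf.keras.layers.Input(shape=(224, 224, 3)),
    tf.keras.layers.Conv2D(32, 3, activation='relu'),   # generic features (edges, corners)
    tf.keras.layers.MaxPooling2D(2),
    tf.keras.layers.Conv2D(64, 3, activation='relu'),   # more specific features
    tf.keras.layers.MaxPooling2D(2),
    tf.keras.layers.Flatten(),                          # 1-D feature vector
    tf.keras.layers.Dense(128, activation='relu'),      # fully-connected layer
    tf.keras.layers.Dense(10, activation='softmax')     # class probabilities
])
probs = model(tf.random.uniform((1, 224, 224, 3)))      # probabilities sum to 1
```

The softmax output layer guarantees that the 10 predicted class probabilities sum to 1 for each input image.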

The key objective of this multi-stage hierarchical architecture is to learn spatial hierarchies of patterns which are also translation invariant. This is possible through two main layers in the CNN architecture: the convolution and pooling layers.

Convolution Layers: The secret sauce of CNNs are its convolution layers! These layers are created by convolving multiple filters or kernels with patches of the input image, which help in extracting specific features from the input image automatically. Using a layered architecture of stacked convolution layers helps in learning spatial features with a certain hierarchy as depicted in figure 3.

Figure 3. Hierarchical feature maps extracted from convolutional layers. Each layer extracts relevant features for the input image. Shallower layers extract more generic features and deeper layers extract specific features pertaining to the given input image.

While figure 3 provides a simplistic view of a CNN, the core idea holds: coarse, generic features like edges and corners are extracted in the initial convolution layers (giving feature maps). A combination of these feature maps in deeper convolutional layers helps the CNN learn more complex visual features like the mane, eyes, cheeks and nose. Finally, the overall visual representation and concept of what a tiger looks like is built from a combination of these features.
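To make the convolution operation itself concrete, here is a tiny, hypothetical example: convolving a simple vertical-edge filter over a 4x4 image whose left half is dark and right half is bright yields a feature map that fires exactly at the edge:

```python
import tensorflow as tf

# A tiny 4x4 "image": left half dark (0), right half bright (1).
image = tf.constant([[0., 0., 1., 1.]] * 4)
image = tf.reshape(image, (1, 4, 4, 1))        # (batch, height, width, channels)

# A simple 1x2 vertical-edge filter: responds where intensity
# increases from left to right.
kernel = tf.reshape(tf.constant([-1., 1.]), (1, 2, 1, 1))

fmap = tf.nn.conv2d(image, kernel, strides=1, padding='VALID')
print(tf.reshape(fmap, (4, 3)).numpy())
# Each row is [0. 1. 0.]: the filter responds only at the dark-to-bright edge.
```

Real CNNs learn such filter values during training instead of having them hand-crafted, but the sliding-window computation is the same.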

Pooling Layers: We typically downsample the feature maps from the convolutional layers in the pooling layers, using an aggregation operation like max, min or mean. Usually max-pooling is preferred: we take patches of image pixels (e.g. a 2×2 patch) and reduce each to its maximum value (one pixel with the max value). Max-pooling is preferred because of its lower computation time as well as its ability to encode the enhanced aspects of the feature maps (by taking the maximal pixel values of image patches rather than their average). Pooling also helps reduce overfitting, decreases computation time and enables the CNN to learn translation-invariant features.
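Max-pooling is easy to verify on a toy example; each 2×2 patch of the input collapses to its maximum pixel value, halving the spatial resolution:

```python
import tensorflow as tf

# A 4x4 single-channel "image": max-pooling with a 2x2 window and
# stride 2 keeps only the maximum pixel value of each patch.
x = tf.constant([[1., 3., 2., 4.],
                 [5., 6., 1., 0.],
                 [1., 2., 9., 8.],
                 [3., 4., 7., 6.]])
x = tf.reshape(x, (1, 4, 4, 1))                # (batch, height, width, channels)
pooled = tf.nn.max_pool2d(x, ksize=2, strides=2, padding='VALID')
print(tf.reshape(pooled, (2, 2)).numpy())
# [[6. 4.]
#  [4. 9.]]
```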

The ResNet architecture

Both of the pre-trained models we mentioned earlier are variants of the ResNet CNN architecture. ResNet stands for Residual Networks, which introduced the novel concept of residual or skip connections to build deeper neural network models without running into problems like vanishing gradients and degraded generalization. A simplified view of the typical ResNet-50 architecture is depicted in figure 4.

Figure 4. ResNet-50 CNN architecture and its components. The key components include the convolution and identity block with residual (skip) connections.

As figure 4 shows, the ResNet-50 architecture consists of several stacked convolutional and pooling layers, followed by a final global average pooling layer and a fully connected layer with 1000 units to make the final class prediction. The model also intersperses batch-normalization layers between the other layers to help with regularization. The stacked conv and identity blocks are novel concepts introduced in the ResNet architecture; they make use of residual or skip connections, as seen in the detailed block diagrams in figure 4.

The whole idea of a skip connection (also known as a residual or shortcut connection) is to not just stack layers but also directly connect the original input to the output of a few stacked layers, as seen in figure 4, where the original input is added to the output of the conv or identity block. The purpose of skip connections is to enable building deeper networks without problems like vanishing gradients and performance saturation, by allowing alternate paths for gradients to flow through the network. We see different variants of the ResNet architecture in figure 5.

Figure 5. Various ResNet Architectures. The figure indicates the various ResNet models based on the total layers present in the model.
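The identity block just described can be sketched in tf.keras as follows. This is a minimal illustration of the skip connection only; real ResNet blocks also use batch normalization and 1×1 bottleneck convolutions, which are omitted here for brevity:

```python
import tensorflow as tf

def identity_block(x, filters):
    """Sketch of an identity block: two conv layers whose output is
    added back to the original input through a skip connection."""
    shortcut = x                                            # the skip connection
    y = tf.keras.layers.Conv2D(filters, 3, padding='same',
                               activation='relu')(x)
    y = tf.keras.layers.Conv2D(filters, 3, padding='same')(y)
    y = tf.keras.layers.Add()([y, shortcut])                # input + stacked output
    return tf.keras.layers.Activation('relu')(y)

inputs = tf.keras.Input(shape=(56, 56, 64))
block = tf.keras.Model(inputs, identity_block(inputs, 64))
out = block(tf.zeros((1, 56, 56, 64)))
```

Because the block's output has the same shape as its input, any number of such blocks can be stacked, which is how the 50- and 152-layer variants are built.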

For our first pre-trained model we will use a ResNet-50 model which has been trained on the ImageNet-1K dataset with a multi-class classification task. Our second pre-trained model uses Google’s pre-trained Big Transfer model for multi-label classification (BiTM), which has variants based on ResNet-50, 101 and 152. The model we use is based on a variant of the ResNet-152 architecture that is four times wider.

Big Transfer (BiT) Pre-Trained Models

The Big Transfer models (BiT) were trained and published by Google in May 2020 as part of their seminal research paper[2]. These pre-trained models are built on top of the basic ResNet architecture we discussed in the previous section, with a few tricks and enhancements. The key focus of Big Transfer models includes the following:

  • Upstream Training: Here we train large model architectures (e.g. ResNet) on large datasets (e.g. ImageNet-21k) with a long pre-training time, using concepts like Group Normalization with Weight Standardization instead of Batch Normalization. The general observation is that GroupNorm with Weight Standardization scales better to large batch sizes than BatchNorm.
  • Downstream Fine-tuning: Once the model is pre-trained, it can be fine-tuned and ‘adapted’ to any new dataset with a relatively small number of samples. Google uses a hyperparameter heuristic called BiT-HyperRule, in which stochastic gradient descent (SGD) is used with an initial learning rate of 0.003, decayed by a factor of 10 at 30%, 60% and 90% of the training steps.
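The BiT-HyperRule learning rate schedule described above can be sketched as a piecewise-constant decay in tf.keras. The total step count and momentum value here are illustrative assumptions (in practice BiT-HyperRule scales the step count with the dataset size):

```python
import tensorflow as tf

total_steps = 1000   # illustrative; BiT-HyperRule scales this with dataset size
boundaries = [int(total_steps * f) for f in (0.3, 0.6, 0.9)]

# Initial learning rate 0.003, decayed by a factor of 10 at
# 30%, 60% and 90% of the training steps.
schedule = tf.keras.optimizers.schedules.PiecewiseConstantDecay(
    boundaries, values=[3e-3, 3e-4, 3e-5, 3e-6])
optimizer = tf.keras.optimizers.SGD(learning_rate=schedule, momentum=0.9)
```

Querying the schedule at, say, step 500 returns 0.0003, i.e. one decay step down from the initial 0.003.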

In the following experiments, we will use the BiTM-R152x4 model, a pre-trained Big Transfer model based on Google’s flagship ResNet-152 CNN architecture, four times wider, and trained to perform multi-label classification on the ImageNet-21k dataset.


Let’s now use these pre-trained models to solve our objective of predicting the Top-5 classes of input images.

TIP:  The supporting code notebooks are available in the book’s accompanying GitHub repository, and we encourage you to fire up the Ch01_Image_Classification_Inference_Big_Transfer notebook for Chapter 1 and try it out yourself!

We start by loading up the specific dependencies for image processing, modeling and inference.

 import tensorflow as tf
 import tensorflow_hub as tf_hub
 from PIL import Image
 import matplotlib.pyplot as plt
 import numpy as np
 print('TF Version:', tf.__version__)
 print('TF Hub Version:', tf_hub.__version__)
 TF Version: 2.3.0
 TF Hub Version: 0.8.0

Do note that we use TensorFlow 2.x here, which is the latest version at the time of writing this book. Since we will use the pre-trained models directly for inference, we need to know the class labels of the original ImageNet-1K and ImageNet-21K datasets for the ResNet-50 and BiTM-R152x4 models respectively, as depicted in listing 1. Note that the ImageNet-1K label file includes an extra ‘background’ class at index 0, which is why you will see 1001 total classes in the output.

Listing 1. Loading ImageNet Class Labels

 data1k = []
 with open('ImageNetLabels.txt', 'r') as f:    #A
     data1k = f.readlines()
 data21k = []
 with open('imagenet21k_wordnet_lemmas.txt', 'r') as f:    #A
     data21k = f.readlines()
 imagenet1k_mapping = {i: value.strip('\n') 
                           for i, value in enumerate(data1k)}
 imagenet21k_mapping = {i: value.strip('\n') 
                           for i, value in enumerate(data21k)}
 print('ImageNet 1K (ResNet-50) Total Classes:', 
       len(imagenet1k_mapping))
 print('Sample:', list(imagenet1k_mapping.items())[:5])    #B
 print('\nImageNet 21K (BiT ResNet-152 4x) Total Classes:', 
       len(imagenet21k_mapping))
 print('Sample:', list(imagenet21k_mapping.items())[:5])    #B
 ImageNet 1K (ResNet-50) Total Classes: 1001
 Sample: [(0, 'background'), (1, 'tench'), (2, 'goldfish'),
          (3, 'great white shark'), (4, 'tiger shark')]
 ImageNet 21K (BiT ResNet-152 4x) Total Classes: 21843
 Sample: [(0, 'organism, being'), (1, 'benthos'),
          (2, 'heterotroph'), (3, 'cell'),
          (4, 'person, individual, someone, somebody, mortal, soul')]

#A Load ImageNet class label mappings

#B View sample class labels

The next step would be to load up the two pre-trained models we discussed earlier from TensorFlow Hub.

 resnet_model_url = ""
 resnet_50 = tf_hub.KerasLayer(resnet_model_url)
 bit_model_url = ""
 bit_r152x4 = tf_hub.KerasLayer(bit_model_url)

Once we have our pre-trained models ready, the next step is to build some specific utility functions, which you can access from the notebook for this chapter in our GitHub repository mentioned earlier. To give some perspective:

  • The preprocess_image(…) function helps us pre-process and shape the input image, scaling its pixel values to the range 0 to 1.
  • The visualize_predictions(…) function takes in the pre-trained model, the class label mappings, the model type and the input image as inputs to visualize the top-5 predictions as a bar chart.
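The actual helpers live in the book’s notebook, but a minimal, hypothetical sketch of what preprocess_image(…) might look like is shown below; it scales pixel values to the 0-1 range and adds a batch dimension (the real helper may do more, such as resizing):

```python
import numpy as np
import tensorflow as tf

def preprocess_image(image):
    """Hypothetical sketch: convert an image to float32, scale pixel
    values to the 0-1 range and add a batch dimension."""
    image = np.array(image, dtype=np.float32) / 255.0   # scale to [0, 1]
    return tf.expand_dims(image, axis=0)                # shape: (1, H, W, C)
```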

The ResNet-50 model outputs class probabilities directly, but the BiTM-R152x4 model outputs class logits, which need to be converted to class probabilities. Listing 2 shows the section of the visualize_predictions(…) function which achieves this.

Listing 2. Getting the class probabilities from the model predictions

 def visualize_predictions(model, image, imagenet_mapping_dict, model_type):
     if model_type =='resnet':    #A
         probs = model(image)
         probs = tf.reshape(probs, [-1])
     else:    #B
         logits = model(image)
         logits = tf.reshape(logits, [-1])
         probs = tf.nn.softmax(logits)
     top5_imagenet_idxs = np.argsort(probs)[:-6:-1]    #C
     top5_probs = np.sort(probs)[:-6:-1]    #C
     pred_labels = [imagenet_mapping_dict[i]    #C 
                        for i in top5_imagenet_idxs]

#A Get class probabilities directly for pre-trained ResNet-50

#B Get logits and then derive class probabilities for pre-trained BiTM-R152x4

#C Get the top 5 class predictions and probabilities

Remember that logits are basically the log-odds, or unnormalized class probabilities; you need to compute the softmax of these logits to get the normalized class probabilities, which sum up to 1. This is depicted in figure 6, which shows a sample neural network architecture with the logits and class probabilities for a hypothetical 3-class classification problem.

Figure 6. Logits and Softmax values in a Neural Network
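A quick numeric check of the softmax transform for a hypothetical 3-class problem:

```python
import numpy as np

# Logits for a hypothetical 3-class problem: softmax exponentiates
# and normalizes them into probabilities that sum to 1.
logits = np.array([2.0, 1.0, 0.1])
probs = np.exp(logits) / np.sum(np.exp(logits))
print(probs.round(3))   # [0.659 0.242 0.099]
```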

The softmax function basically squashes the logits using the transform depicted in figure 6 to give us the normalized class probabilities. Let’s now put our code into action! You can leverage these functions on any downloaded image, following the sequence of steps depicted in listing 3, to visualize the Top-5 predictions of our two pre-trained models.

Listing 3. Visualizing Top-5 predictions of our pre-trained models on a sample image

 img = Image.open('snow_leo.png').convert("RGB")
 pre_img = preprocess_image(img)
 plt.figure(figsize=(12, 3))
 plt.subplot(1,3,1)    #A
 visualize_predictions(model=bit_r152x4, image=pre_img,
                       imagenet_mapping_dict=imagenet21k_mapping,
                       model_type='bit')
 plt.subplot(1,3,2)    #B
 resnet_img = tf.image.resize(pre_img, (224, 224))
 visualize_predictions(model=resnet_50, image=resnet_img,
                       imagenet_mapping_dict=imagenet1k_mapping,
                       model_type='resnet')

#A Visualizing top-5 predictions for BiTM-R152x4 model

#B Visualizing top-5 predictions for ResNet-50 model

Voila! We have the Top-5 predictions from our two pre-trained models depicted in a nice visualization in figure 7.

Figure 7. Prediction results on a Snow Leopard Image

It looks like both our models performed well and, as expected, the BiTM model is more specific and more accurate, given that it has been trained on over 21K classes covering very specific animal species and breeds.

The ResNet-50 model shows more inconsistencies than the BiTM model when predicting animals of a similar genus but slightly different species, such as tigers and lions, as depicted in figure 8.

Figure 8. Correct vs. incorrect predictions of the BiTM and ResNet-50 models

Another aspect to keep in mind is that these models are not exhaustive. They don’t cover each and every entity on this planet. This would be impossible to do considering data collection for this task itself would take centuries, if not forever! An example is showcased in figure 9 where our models try to predict a very specific dog breed, the Afghan Hound, from a given image.

Figure 9. Both our models struggle to predict an Afghan Hound

Based on the Top-5 predictions in figure 9, you can see that while our BiTM model actually gets the right prediction, the prediction probability is quite low, indicating the model is not too confident (it probably hasn’t seen many examples of this dog breed during the pre-training phase). This is where we can fine-tune and adapt our models to our specific datasets, labels and outcomes. This forms the basis of the remaining chapters of the book, where we leverage pre-trained models and adapt them to very different and novel problems.

If you want to learn more about the book, you can check it out on Manning’s liveBook platform here.


[2] Big Transfer (BiT): General Visual Representation Learning, Kolesnikov et al.