
From Transfer Learning in Action by Dipanjan Sarkar and Raghav Bali

This article delves into fine-tuning a pre-trained ResNet-50 with the one-cycle learning rate policy.


Take 40% off Transfer Learning in Action by entering fccsarkar into the discount code box at checkout at manning.com.


Fine-tuning pre-trained ResNet-50 with one-cycle learning rate

You may have seen that it is often easy to get an initial burst in accuracy, but once you reach around 90%, you end up having to push really hard to squeeze out even a 1-2% improvement in performance. In this section, we will look at a way to dynamically change the learning rate over epochs using the one-cycle learning rate policy.

Originally proposed in a paper by Leslie Smith, the one-cycle learning rate schedule [1] revolves around a cycle with two steps during the training process. We start by ramping the learning rate up linearly from a lower to a higher value over a few epochs (step 1) and then decay it back down toward a lower learning rate over the remaining epochs (step 2). A ready-to-use implementation of the one-cycle LR schedule exists thanks to Martin Gorner's 2019 talk at TensorFlow World [2], which we will use to build our schedule as depicted in listing 1.

Listing 1. Implementing the one-cycle learning rate schedule

 
import matplotlib.pyplot as plt

def lr_function(epoch):
    start_lr = 1e-5; min_lr = 1e-5; max_lr = 1e-4    #A
    rampup_epochs = 5; sustain_epochs = 0; exp_decay = .8    #B

    def lr(epoch, start_lr, min_lr, max_lr, rampup_epochs,
           sustain_epochs, exp_decay):
        if epoch < rampup_epochs:    #C
            lr = ((max_lr - start_lr) / rampup_epochs
                  * epoch + start_lr)
        elif epoch < rampup_epochs + sustain_epochs:    #D
            lr = max_lr
        else:    #E
            lr = ((max_lr - min_lr) *
                  exp_decay**(epoch - rampup_epochs - sustain_epochs)
                  + min_lr)
        return lr

    return lr(epoch, start_lr, min_lr, max_lr,
              rampup_epochs, sustain_epochs, exp_decay)

rng = list(range(50))
y = [lr_function(x) for x in rng]
plt.plot(rng, y)
print('start lr:', y[0], '\nend lr:', y[-1])
  

#A Set values for the starting, minimum, and maximum possible learning rates

#B Set the number of epochs to ramp up and sustain the learning rate, along with the learning rate decay factor

#C Ramp-up phase, where the learning rate is increased linearly over rampup_epochs until we reach max_lr

#D Sustain phase, where we hold the learning rate at max_lr for sustain_epochs

#E Decay phase, where we decay the learning rate by a factor of exp_decay each epoch until it approaches min_lr

As depicted in listing 1, we also run the defined function for a fixed number of epochs after the function definition to see how the learning rate changes along this cycle. Following the two steps we discussed earlier, we start with an initial learning rate of 1e-5, ramp it up to 1e-4 over the first few epochs, and then reduce it back toward 1e-5 over the course of the remaining epochs. This dynamic learning rate curve is depicted with a sample run of 50 epochs in figure 1.


Figure 1. One-cycle learning rate policy over 50 epochs. Learning rate is ramped up initially, followed by a slow decay over epochs.
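To sanity-check the schedule, we can evaluate lr_function at a few representative epochs; the values below follow directly from the constants in listing 1 (note that with sustain_epochs set to 0, the peak of 1e-4 occurs exactly once, at epoch 5):

 # Spot-check the schedule at a few epochs (constants from listing 1):
 # epoch 2 (ramp-up): (1e-4 - 1e-5) / 5 * 2 + 1e-5 = 4.6e-5
 # epoch 5 (peak):    (1e-4 - 1e-5) * 0.8**0 + 1e-5 = 1e-4
 # epoch 10 (decay):  (1e-4 - 1e-5) * 0.8**5 + 1e-5 ~ 3.9e-5
 for epoch in (0, 2, 5, 10, 49):
     print(epoch, f"{lr_function(epoch):.2e}")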


We will now put our one-cycle learning rate scheduler to the test by applying it while fine-tuning our pre-trained ResNet-50 model, using the training setup depicted in listing 2.
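Note that listing 2 assumes the fine-tunable model lr_finetuned_resnet50 and the tf.data datasets train_ds and val_ds have been built earlier in the chapter. As a rough orientation, a minimal sketch of such a model (illustrative only, not the book's exact architecture) might look like this:

 import tensorflow as tf

 # A minimal sketch of a fine-tunable ResNet-50 for binary classification.
 # The input size, unfreezing strategy, and head are assumptions for illustration.
 base = tf.keras.applications.ResNet50(weights='imagenet', include_top=False,
                                       input_shape=(224, 224, 3))
 base.trainable = True    # unfreeze the backbone for fine-tuning

 lr_finetuned_resnet50 = tf.keras.Sequential([
     base,
     tf.keras.layers.GlobalAveragePooling2D(),
     tf.keras.layers.Dense(1, activation='sigmoid')    # single sigmoid output
 ])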

Listing 2. Fine-tuning pre-trained ResNet-50 with 1-Cycle LR callback

 
 epochs = 100
 callbacks = [
     tf.keras.callbacks.EarlyStopping(monitor='val_loss',    #A
                                      patience=5,
                                      restore_best_weights=True),
    tf.keras.callbacks.LearningRateScheduler(lambda epoch: lr_function(epoch),    #B
                                             verbose=True)
]
  
 lr_finetuned_resnet50.compile(
     optimizer=tf.keras.optimizers.Adam(1e-5),
     loss="binary_crossentropy", metrics=["accuracy"],
 )
  
 history = lr_finetuned_resnet50.fit(    #C
     train_ds, epochs=epochs, callbacks=callbacks,
     validation_data=val_ds,
 )
  
 Epoch 00001: LearningRateScheduler reducing learning rate to 1e-05.
 Epoch 1/100
 loss: 0.5532 - accuracy: 0.7121 - val_loss: 0.5110 - val_accuracy: 0.7232
 ...
 ...
 Epoch 00020: LearningRateScheduler reducing learning rate to 1.3958e-05.
 Epoch 20/100
 loss: 0.0110 - accuracy: 0.9972 - val_loss: 0.1012 - val_accuracy: 0.9705
  
 Epoch 00021: LearningRateScheduler reducing learning rate to 1.3166e-05.
 Epoch 21/100
 loss: 0.0255 - accuracy: 0.9917 - val_loss: 0.1713 - val_accuracy: 0.9446
  

#A Set up an early stopping callback to monitor the validation loss, stop training after 5 epochs with no improvement, and restore the best weights

#B Set up the one-cycle learning rate callback to dynamically change the learning rate each epoch; it overrides the initial learning rate passed to the Adam optimizer at the start of every epoch

#C Fine-tune our CNN based on the one-cycle learning rate policy

Based on the sample training logs in listing 2, it is quite evident that we reach a fairly high validation accuracy within 20 epochs. This is also visualized in figure 2, where we look at the learning curves for this model.


Figure 2. Learning curves for fine-tuned ResNet-50 with 1-cycle learning rate
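The curves in figure 2 can be reproduced from the history object returned by fit(); the exact plotting code is not shown in this excerpt, but a minimal sketch might look like this:

 import matplotlib.pyplot as plt

 # Plot accuracy and loss curves from the Keras History object
 fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))
 ax1.plot(history.history['accuracy'], label='train accuracy')
 ax1.plot(history.history['val_accuracy'], label='validation accuracy')
 ax1.set_xlabel('epoch'); ax1.legend()
 ax2.plot(history.history['loss'], label='train loss')
 ax2.plot(history.history['val_loss'], label='validation loss')
 ax2.set_xlabel('epoch'); ax2.legend()
 plt.show()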


We reach consistent validation accuracies in the range of 94-97% within 20 epochs, which is so far the best we have seen across all our models. On evaluating the model's performance on the test dataset, we see a pretty good performance based on the metrics depicted in figure 3.


Figure 3. 1-Cycle LR fine-tuned ResNet-50 performance on test data
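The metrics in figure 3 come from evaluating the model on the held-out test set. A minimal sketch of such an evaluation is shown below; test_ds is assumed to be a non-shuffled tf.data dataset from earlier in the chapter, and scikit-learn's classification_report stands in for whatever reporting utility the book actually uses:

 import numpy as np
 from sklearn.metrics import classification_report

 # Predicted probabilities from the sigmoid output, thresholded at 0.5
 probs = lr_finetuned_resnet50.predict(test_ds)
 preds = (probs.ravel() >= 0.5).astype(int)

 # Collect the true labels in dataset order (test_ds must not be shuffled)
 true = np.concatenate([y.numpy() for _, y in test_ds])

 print(classification_report(true, preds))    # precision, recall, F1 per class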


The metrics clearly depict a performance gain of 2% over our base fine-tuned ResNet model, and we achieve a test F1-score of 97%, which gives us a gain of 13% over our initial simple CNN model!

If you want to learn more about the book, you can check it out on Manning’s liveBook platform here.