    A Hands-on Tutorial for Transfer Learning in Python

    on April 14, 2022

    Fitting complex neural network models is a computationally heavy process, which requires access to large amounts of data. In this article we introduce transfer learning - a method for leveraging pre-trained models, which speeds up fitting and reduces the need to process large amounts of data. We provide guidance on transfer learning for NLP and computer vision use cases and show simple implementations using Hugging Face and Keras.


    The continuous development and improvement of machine learning approaches is driving a new industrial revolution. This has had a substantial impact on all aspects of life and may produce effects as far-reaching as those observed during the original industrial revolution. Nevertheless, current model fitting approaches only serve to enhance human learning capability rather than mimic human intelligence [1].

    To be able to perform inference on unseen data, machine learning algorithms require known data to learn from and heavily rely on it. For example, in a classification task, the annotated data needs to be divided into training and testing subsets. The feature vector of the former is used for training the model and the latter to test the performance of the trained model on unseen data. The dataset can consist of one or many categories and the inference must be performed on unseen data belonging to a category present in the training subset.

    Let us provide an example to clarify the point above. We have a dataset of fruit images with three categories: apples, bananas, and grapefruits. To train a model capable of accurate classification, this dataset is divided into training and testing subsets, for example in a 70/30 ratio. The selected machine learning algorithm is trained on the training subset, and its performance is measured to ensure high accuracy and low error. The model can be improved further by tuning its hyperparameters. After training, the model has learned to classify images of the three fruit classes. The next step is to evaluate its performance on the testing images, which the model has not seen before. Finally, the model can be deployed in production.
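
    The 70/30 split described above could be sketched as follows. This is a minimal illustration using only the Python standard library; the file names and the `train_test_split` helper are hypothetical, and in practice a library routine (such as scikit-learn's function of the same name) would usually be used instead.

```python
import random

def train_test_split(samples, train_fraction=0.7, seed=42):
    """Shuffle labelled samples and split them into training and testing subsets."""
    shuffled = list(samples)
    random.Random(seed).shuffle(shuffled)  # fixed seed keeps the split reproducible
    cut = int(round(len(shuffled) * train_fraction))
    return shuffled[:cut], shuffled[cut:]

# Hypothetical dataset: (file name, label) pairs, 10 images per fruit class
labels = ["apple", "banana", "grapefruit"]
dataset = [(f"{label}_{i}.jpg", label) for label in labels for i in range(10)]

train, test = train_test_split(dataset)
print(len(train), len(test))  # 21 9
```

    Note that this simple shuffle does not guarantee that each class is represented in the same proportion in both subsets; a stratified split would be needed for that.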

    The advances in deep learning neural networks have allowed significant improvements in learning performance, resulting in their dominance in some areas including computer vision. Neural networks [2, 3] make use of data in the form of feature vectors and their operation loosely mimics a biological neuron that receives multiple inputs from other neurons. If the sum of the inputs surpasses a defined threshold, the neuron will transmit a signal to the next neuron. In a basic artificial neural network, the neurons (nodes) which are arranged in layers are connected to the neurons in the next layer. For each input, an initial random weight is assigned and subsequently modified via error back-propagation. This is continued until the model becomes proficient in predicting the class of an unknown sample [4].
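
    To make the threshold behaviour concrete, here is a minimal sketch of a single artificial neuron learning the logical AND function. It makes several simplifying assumptions that are not part of the description above: the threshold is folded into a learnable bias (the neuron fires when the weighted sum plus bias exceeds zero), the weights start at zero rather than at random values so that the run is reproducible, and the weights are updated with the simple perceptron learning rule rather than back-propagation, which is needed only for multi-layer networks.

```python
def fires(inputs, weights, bias):
    """Return 1 if the weighted sum of the inputs plus the bias exceeds 0, else 0."""
    weighted_sum = sum(x * w for x, w in zip(inputs, weights)) + bias
    return 1 if weighted_sum > 0 else 0

def train_perceptron(samples, learning_rate=0.1, epochs=20):
    """Adjust weights and bias whenever the neuron's output differs from the target."""
    weights = [0.0] * len(samples[0][0])
    bias = 0.0
    for _ in range(epochs):
        for inputs, target in samples:
            error = target - fires(inputs, weights, bias)  # -1, 0 or 1
            weights = [w + learning_rate * error * x for w, x in zip(weights, inputs)]
            bias += learning_rate * error
    return weights, bias

# Truth table of the logical AND function: (inputs, expected output)
and_gate = [((0, 0), 0), ((0, 1), 0), ((1, 0), 0), ((1, 1), 1)]

weights, bias = train_perceptron(and_gate)
predictions = [fires(inputs, weights, bias) for inputs, _ in and_gate]
print(predictions)  # [0, 0, 0, 1]
```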

    Despite these advances, several major obstacles remain to the wide adoption of machine learning approaches. Machine learning models are typically limited to one task and one domain and may suffer from catastrophic forgetting [5]. Training a model in many instances requires large, annotated datasets, which may not be available in some domains (for example, the medical field). In addition, in multiclass classification, class imbalance poses a further challenge. Moreover, a model trained on one dataset may not perform equally well on a separate dataset due to differences in input features and feature distribution.

    Irrespective of the size of the dataset, it may also be very difficult to include all categories within a dataset. In the fruit example above, the model was trained to classify three types of fruit, whilst there are more than 2000 types of fruit around the world. This becomes a bigger issue in other domains: for example, the number of described insect species in the United States alone is approximately 91,000. Creating annotated datasets to cover all insect species would be very complex and laborious work.

    In summary, whilst the idea of intelligent machines is not new, and the assumption that machine intelligence can become identical to human intelligence is central to AI, present-day AI does not reproduce the intelligence and learning that biological brains are capable of. For example, teaching a child to distinguish apples from oranges is a trivial task and can be done even without showing the child these objects: describing the shape, colour, and other properties is enough for the child to use these semantic descriptions to tell an orange from a banana. Learning from few examples is not limited to humans either. With reinforcement-based training, pigeons have been shown to discriminate between cancerous and healthy tissues from a small sample size of 200 [6].

    The long-term solution is for AI models to learn intelligently, as humans do. In the meantime, new machine learning approaches, including transfer learning and few-shot learning, have been developed to reduce the reliance on data for model training.

    Transfer learning

    Training complex neural networks from scratch with large datasets requires access to high-end computing resources, such as clusters of CPUs and GPUs, as the training process is computationally expensive [7, 8]. In addition, even if access to computing resources is not an obstacle, large datasets for model training are not always available.

    To circumvent these issues and leverage the power of neural networks, one simple and increasingly popular approach is transfer learning. This approach builds new models using the weights of pre-trained models, which have already been trained on one task (e.g., classification) and one dataset (e.g., ImageNet, one of the largest open-source datasets), and applies the new models to a similar task and an unseen dataset.

    The intuition behind transfer learning is that if the datasets in two problems contain similar data points, their feature representations are also similar, and thus the weights obtained from one training run can be used to solve subsequent similar problems instead of starting from random weights and training from scratch. Pre-trained models can be fitted more quickly on the second task as they already contain the weights learnt on the first task. They can be used directly, or used to extract features from the new data. It is also possible to modify the pre-trained model by adding new layers, aiming to obtain higher performance or to mould the model around the dataset.

    Many different pre-trained architectures (e.g., convolutional neural networks; CNNs) are available for transfer learning. One of the best-performing architectures is ResNet, developed and evaluated on ImageNet and other datasets including CIFAR-10. This model uses residual connections, in which the input to a block skips the learnt transform layers through an identity mapping and is then added to their output. These connections, together with batch normalisation, mitigate the problems of vanishing and exploding gradients [9].
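
    The idea of a residual connection can be illustrated with a few lines of plain Python. This is a deliberately simplified sketch: the real ResNet blocks use convolutions and batch normalisation, whereas the hypothetical `residual_block` below uses two small fully connected layers. What it does show is the defining step, in which the block's input is added back to the output of the learnt layers.

```python
def relu(values):
    """Element-wise rectified linear unit."""
    return [max(0.0, v) for v in values]

def linear(values, weights):
    """Multiply a vector by a weight matrix given as a list of rows."""
    return [sum(w * v for w, v in zip(row, values)) for row in weights]

def residual_block(x, weights_1, weights_2):
    """Compute relu(F(x) + x), where F is two linear layers with a ReLU in between."""
    transformed = linear(relu(linear(x, weights_1)), weights_2)  # learnt path F(x)
    return relu([f + xi for f, xi in zip(transformed, x)])       # add the identity shortcut

x = [1.0, 2.0]
zero_weights = [[0.0, 0.0], [0.0, 0.0]]

# Even if the learnt layers contribute nothing, the input still passes through
print(residual_block(x, zero_weights, zero_weights))  # [1.0, 2.0]
```

    Because the shortcut carries the input forward unchanged, the signal (and, during training, the gradient) has a path that does not depend on the learnt layers, which is why very deep networks built from such blocks remain trainable.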

    One drawback of transfer learning is that the features learned during the first task may not be sufficiently specific to the second task, and as such the accompanying weights will not result in high performance on the second task. For example, models that are trained to recognise certain types of images can be reused to classify other images successfully only if the image features of the datasets for both tasks are similar [10, 11].

    In the next two sections, we will provide simple examples of transfer learning in natural language processing and computer vision.

    Transfer learning for natural language processing

    Many organisations provide pre-trained models, one of which is Hugging Face, launched in 2017. Their focus is on NLP tasks, and their libraries, built on PyTorch and TensorFlow, provide various resources including access to datasets and transformers. The latter offer an easy way to use pre-trained NLP models. One popular NLP model is BERT (Bidirectional Encoder Representations from Transformers), whose bidirectional approach allows the model to consider the entire sequence of words at once [12].

    We will use Hugging Face for sentiment analysis. First, we will install transformers, if not already available in the environment:

    !pip install transformers

    In a typical NLP pipeline, several steps are taken to pre-process and prepare data to train and fine-tune the model. Hugging Face provides pipelines that perform all these steps with a few lines of code. Essentially, pipelines are composed of a tokeniser and a model to perform the task on the input text.

    Below, we will use a pipeline for sentiment analysis with a BERT-based model on this text: “Transfer learning can help overcome issues related to over-reliance on data in machine learning”.

    #importing pipeline 
    from transformers import pipeline

    The pipeline will download and cache a default pre-trained model (distilbert-base-uncased-finetuned-sst-2-english) and a tokeniser for sentiment analysis. We can then use the classifier on the input text:

    # Initialising the pipeline with default model
    classifier = pipeline('sentiment-analysis')

    # Performing sentiment analysis on below text
    classifier('Transfer learning can help overcome issues related to over-reliance on data in machine learning')

    The above pipeline produces the label “POSITIVE” with a score of 0.908. We can also specify a particular model when creating the classifier. Below we will use “siebert/sentiment-roberta-large-english”, a sentiment model based on RoBERTa [13].

    # Initialising the pipeline with the selected model
    classifier = pipeline('sentiment-analysis', model='siebert/sentiment-roberta-large-english')

    # Performing sentiment analysis on below text
    classifier('Transfer learning can help overcome issues related to over-reliance on data in machine learning')

    The above pipeline produces the label “POSITIVE” with a score of 0.998, i.e. the selected model assigns the positive class with higher confidence than the default model.
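
    As noted earlier, a pipeline is essentially a tokeniser plus a model. The sketch below spells out those steps explicitly for the default sentiment model. It is an illustration under several assumptions rather than a description of the pipeline's internals: it assumes the PyTorch backend is installed, the `classify` and `softmax` helpers are hypothetical names introduced here, and the post-processing is assumed to be a softmax over the model's raw outputs.

```python
import math

MODEL_NAME = "distilbert-base-uncased-finetuned-sst-2-english"

def softmax(logits):
    """Convert raw model outputs (logits) into probabilities that sum to 1."""
    largest = max(logits)
    exps = [math.exp(value - largest) for value in logits]  # shift for numerical stability
    total = sum(exps)
    return [value / total for value in exps]

def classify(text, model_name=MODEL_NAME):
    """Tokenise the text, run the pre-trained model and return the most likely label."""
    # Imported here so that the softmax helper can be used without these packages
    import torch
    from transformers import AutoModelForSequenceClassification, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSequenceClassification.from_pretrained(model_name)

    inputs = tokenizer(text, return_tensors="pt")      # text -> token ids
    with torch.no_grad():                              # inference only, no gradients
        logits = model(**inputs).logits[0].tolist()    # one raw score per class
    probabilities = softmax(logits)
    best = max(range(len(probabilities)), key=probabilities.__getitem__)
    return {"label": model.config.id2label[best], "score": probabilities[best]}
```

    Calling `classify` on the example sentence used above should return a dictionary with a label and a score, in the same form as the entries returned by the pipeline.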

    Transfer learning for computer vision

    A pre-trained model can be used to make predictions, to extract features, or as a starting point for fine-tuning. Here, we will use the pre-trained ResNet50 model to classify an example image. This is a 50-layer deep convolutional neural network trained on more than a million images from the ImageNet database. It can classify images into 1000 classes including many objects and animals [9].

    We will use the Keras library to implement transfer learning of ResNet50:

    # Importing required libraries
    import cv2
    import matplotlib.pyplot as plt
    import numpy as np
    from tensorflow.keras.applications.resnet50 import ResNet50
    from tensorflow.keras.preprocessing import image
    from tensorflow.keras.applications.resnet50 import preprocess_input, decode_predictions

    Next, we will create the model and load the weights pre-trained on ImageNet:

    model = ResNet50(weights='imagenet')

    The classification can subsequently be performed on an example image and the result is printed and displayed:

    # Loading the image at the 224x224 input size expected by ResNet50
    img_path = 'Resources/tiger.jpg'
    img = image.load_img(img_path, target_size=(224, 224))
    x = image.img_to_array(img)
    x = np.expand_dims(x, axis=0)  # adding a batch dimension
    x = preprocess_input(x)
    preds = model.predict(x)

    # Showing the image with prediction
    img_bgr = cv2.imread(img_path, 1)
    img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
    plt.imshow(img_rgb)
    plt.title('Prediction using ResNet50 pre-trained model: {}'.format(decode_predictions(preds, top=1)[0]))
    plt.show()

    Figure 1. Prediction using ResNet50 pre-trained model. The result correctly identifies a tiger in the image.
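
    The example above uses the pre-trained model for prediction only. For a new task, such as the three-class fruit dataset discussed earlier, the same network could instead serve as a frozen feature extractor with a new classification head on top. The following is a sketch of one common way to do this in Keras, not the only one; the `build_transfer_model` helper, the choice of head and the training settings are assumptions made for illustration.

```python
def build_transfer_model(num_classes, weights="imagenet"):
    """ResNet50 feature extractor with a new, trainable classification head."""
    from tensorflow.keras import layers, models
    from tensorflow.keras.applications.resnet50 import ResNet50

    # Load the convolutional base without the original 1000-class output layer
    base = ResNet50(weights=weights, include_top=False, input_shape=(224, 224, 3))
    base.trainable = False  # freeze the pre-trained weights

    inputs = layers.Input(shape=(224, 224, 3))
    x = base(inputs, training=False)          # keep batch normalisation in inference mode
    x = layers.GlobalAveragePooling2D()(x)    # one feature vector per image
    outputs = layers.Dense(num_classes, activation="softmax")(x)  # new head

    model = models.Model(inputs, outputs)
    model.compile(optimizer="adam",
                  loss="sparse_categorical_crossentropy",
                  metrics=["accuracy"])
    return model
```

    A model built with `build_transfer_model(3)` could then be trained with `model.fit` on images that have been passed through `preprocess_input`; only the weights of the new output layer are updated. To fine-tune further, some of the base layers can be unfrozen afterwards and training continued with a low learning rate.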


    In this article we looked at transfer learning - a machine learning technique that reuses a completed model that was developed for one task as the starting point for a new model to accomplish a new task. The knowledge acquired by the first model is thus transferred to the second model.

    Transfer learning can speed up progress and improve performance when training a new model and it is therefore very appealing when access to compute infrastructure is limited and the time for fitting the model is of the essence.

    In a follow-up post we'll look at N-shot and zero-shot learning, variants of transfer learning that were specifically designed to handle cases where examples are scarce or no labelled data is available at all.


    [1] S. Makridakis, "The forthcoming Artificial Intelligence (AI) revolution: Its impact on society and firms," Futures, vol. 90, pp. 46-60, 2017.

    [2] F. Rosenblatt, "The perceptron: a probabilistic model for information storage and organization in the brain," Psychological Review, vol. 65, no. 6, p. 386, 1958.

    [3] W. S. McCulloch and W. Pitts, "A logical calculus of the ideas immanent in nervous activity," The Bulletin of Mathematical Biophysics, vol. 5, no. 4, pp. 115-133, 1943.

    [4] D. E. Rumelhart, G. E. Hinton, and R. J. Williams, "Learning representations by back-propagating errors," Nature, vol. 323, no. 6088, pp. 533-536, 1986.

    [5] M. McCloskey and N. J. Cohen, "Catastrophic interference in connectionist networks: The sequential learning problem," in Psychology of Learning and Motivation, vol. 24: Elsevier, 1989, pp. 109-165.

    [6] R. M. Levenson, E. A. Krupinski, V. M. Navarro, and E. A. Wasserman, "Pigeons (Columba livia) as trainable observers of pathology and radiology breast cancer images," PLoS ONE, vol. 10, no. 11, p. e0141357, 2015.

    [7] K. Simonyan and A. Zisserman, "Very deep convolutional networks for large-scale image recognition," arXiv preprint arXiv:1409.1556, 2014.

    [8] S. J. Pan and Q. Yang, "A survey on transfer learning," IEEE Transactions on Knowledge and Data Engineering, vol. 22, no. 10, pp. 1345-1359, 2009.

    [9] K. He, X. Zhang, S. Ren, and J. Sun, "Deep residual learning for image recognition," in Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, 2016, pp. 770-778.

    [10] J. Donahue et al., "DeCAF: A deep convolutional activation feature for generic visual recognition," in International Conference on Machine Learning, 2014: PMLR, pp. 647-655.

    [11] T. Pfister, K. Simonyan, J. Charles, and A. Zisserman, "Deep convolutional neural networks for efficient pose estimation in gesture videos," in Asian Conference on Computer Vision, 2014: Springer, pp. 538-552.

    [12] J. Devlin, M.-W. Chang, K. Lee, and K. Toutanova, "BERT: Pre-training of deep bidirectional transformers for language understanding," arXiv preprint arXiv:1810.04805, 2018.

    [13] Y. Liu et al., "RoBERTa: A robustly optimized BERT pretraining approach," arXiv preprint arXiv:1907.11692, 2019.
