In recent years, we have become increasingly good at training deep neural networks to learn a very accurate mapping from inputs to outputs, whether these are images, sentences, or label predictions, from large amounts of labeled data.
What our models still frightfully lack is the ability to generalize to conditions that are different from the ones encountered during training. When is this necessary? Every time you apply your model not to a carefully constructed dataset but to the real world. The real world is messy and contains an infinite number of novel scenarios, many of which your model has not encountered during training and for which it is in turn ill-prepared to make predictions. The ability to transfer knowledge to new conditions is generally known as transfer learning and is what we will discuss in the rest of this post.
Over the course of this blog post, I will first contrast transfer learning with machine learning's most pervasive and successful paradigm, supervised learning. I will then outline reasons why transfer learning warrants our attention. Subsequently, I will give a more technical definition and detail different transfer learning scenarios. I will then provide examples of applications of transfer learning before delving into practical methods that can be used to transfer knowledge. Finally, I will give an overview of related directions and provide an outlook into the future.
What is Transfer Learning?
In the classic supervised learning scenario of machine learning, if we intend to train a model for some task and domain A, we assume that we are provided with labeled data for that same task and domain. We can see this clearly in Figure 1, where the task and domain of the training and test data of our model A are the same. (We will define in more detail later what exactly a task and a domain are.) For the moment, let us assume that a task is the objective our model aims to perform, e.g. recognizing objects in images, and a domain is where our data comes from, e.g. images taken in San Francisco coffee shops.
We can now train a model A on this dataset and expect it to perform well on unseen data from the same task and domain. On another occasion, when given data for some other task or domain B, we again require labeled data for that task or domain, which we use to train a new model B that we can then expect to perform well on this data.
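To make this concrete, here is a minimal sketch of the traditional paradigm in Python, one labeled dataset and one model per task and domain. It uses scikit-learn and synthetic stand-in data; all names here are illustrative, not taken from any particular system.

```python
# A toy sketch of the classic supervised paradigm: one labeled dataset
# and one model per task/domain. All data here is synthetic stand-in data.
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split

# Stand-ins for labeled data from task/domain A and task/domain B;
# the shift makes B's feature distribution differ from A's.
X_a, y_a = make_classification(n_samples=1000, random_state=0)
X_b, y_b = make_classification(n_samples=1000, shift=2.0, random_state=1)

# Task/domain A: train model A on labeled data from A, test on data from A.
Xa_tr, Xa_te, ya_tr, ya_te = train_test_split(X_a, y_a, random_state=0)
model_a = LogisticRegression(max_iter=1000).fit(Xa_tr, ya_tr)
print("A on A:", model_a.score(Xa_te, ya_te))

# Task/domain B gets its own labeled data and its own model; in the
# traditional paradigm, model_a is not reused.
Xb_tr, Xb_te, yb_tr, yb_te = train_test_split(X_b, y_b, random_state=0)
model_b = LogisticRegression(max_iter=1000).fit(Xb_tr, yb_tr)
print("B on B:", model_b.score(Xb_te, yb_te))
```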
The traditional supervised learning paradigm breaks down when we do not have sufficient labeled data for the task or domain we care about to train a reliable model.
If we want to train a model to detect pedestrians in night-time images, we could apply a model that has been trained on a similar domain, e.g. on day-time images. In practice, however, we often experience a deterioration or collapse in performance, as the model has inherited the bias of its training data and does not know how to generalize to the new domain.
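Reusing the toy sketch above, this deterioration is easy to observe: simply evaluate model A, trained on domain A, on held-out data from the shifted domain B.

```python
# Continuing the toy sketch above: a model trained on domain A is
# evaluated on domain B's test data, where its accuracy typically drops.
print("A on B:", model_a.score(Xb_te, yb_te))  # usually well below "A on A"
```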
If we want to train a model to perform a new task, such as detecting bicyclists, we cannot even reuse an existing model, as the labels between the tasks differ.
Transfer learning allows us to deal with these scenarios by leveraging the already existing labeled data of some related task or domain. We try to store this knowledge gained in solving the source task in the source domain and apply it to our problem of interest as can be seen in Figure 2.
In practice, we seek to transfer as much knowledge as we can from the source setting to our target task or domain. This knowledge can take on various forms depending on the data: it can pertain to how objects are composed, allowing us to identify novel objects more easily; it can concern the general words people use to express their opinions; and so on.
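As one concrete illustration of how such knowledge can be transferred, the sketch below fine-tunes a network pre-trained on a source domain for a new target task. It assumes PyTorch and torchvision as an example stack (this post is framework-agnostic), and the 2-class target task is hypothetical.

```python
# A minimal sketch of transfer via a pre-trained feature extractor,
# assuming PyTorch/torchvision; the 2-class target task is hypothetical.
import torch
import torch.nn as nn
import torchvision.models as models

# Load a network whose weights encode knowledge from the source setting
# (here: object recognition on ImageNet).
model = models.resnet18(pretrained=True)

# Freeze the pre-trained layers so the stored knowledge is preserved.
for param in model.parameters():
    param.requires_grad = False

# Replace the final classifier with a new head for the target task,
# e.g. pedestrian vs. background.
model.fc = nn.Linear(model.fc.in_features, 2)

# Only the new head is trained on the (possibly small) target dataset.
optimizer = torch.optim.SGD(model.fc.parameters(), lr=1e-3)
```

Training only the new head on the target data is the simplest variant; unfreezing some of the later layers and fine-tuning them with a small learning rate is a common refinement.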
Why Transfer Learning Now?
Andrew Ng, chief scientist at Baidu and professor at Stanford, said during his widely popular NIPS 2016 tutorial that transfer learning will be -- after supervised learning -- the next driver of ML commercial success.
In particular, he sketched out a chart on a whiteboard that I've sought to replicate as faithfully as possible in Figure 4 below (sorry about the unlabelled axes). According to Andrew Ng, transfer learning will become a key driver of Machine Learning success in industry.
It is indisputable that ML use and success in industry has so far been mostly driven by supervised learning. Fuelled by advances in Deep Learning, more capable computing utilities, and large labeled datasets, supervised learning has been largely responsible for the wave of renewed interest in AI, funding rounds and acquisitions, and in particular the applications of machine learning that we have seen in recent years and that have become part of our daily lives. If we disregard naysayers and heralds of another AI winter and instead trust the prescience of Andrew Ng, this success will likely continue.
It is less clear, however, why transfer learning, which has been around for decades and is currently little utilized in industry, will see the explosive growth predicted by Ng. Even more so as transfer learning currently receives relatively little visibility compared to other areas of machine learning such as unsupervised learning and reinforcement learning, which have come to enjoy increasing popularity: unsupervised learning -- the key ingredient on the quest to General AI according to Yann LeCun, as can be seen in Figure 5 -- has seen a resurgence of interest, driven in particular by Generative Adversarial Networks. Reinforcement learning, in turn, spear-headed by Google DeepMind, has led to advances in game-playing AI, exemplified by the success of AlphaGo, and has already seen success in the real world, e.g. by reducing Google's data center cooling bill by 40%. Both of these areas, while promising, will likely have only a comparatively small commercial impact in the foreseeable future and will mostly remain within the confines of cutting-edge research papers, as they still face many challenges.
What makes transfer learning different? In the following, we will look at the factors that -- in our opinion -- motivate Ng's prognosis and outline the reasons why now is the time to pay attention to transfer learning.
The current use of machine learning in industry is characterised by a dichotomy:
On the one hand, over the course of the last few years, we have obtained the ability to train more and more accurate models. We are now at a stage where, for many tasks, state-of-the-art models perform so well that their accuracy is no longer a hindrance for users. How good? The newest residual networks [1] achieve superhuman performance at recognising objects on ImageNet; Google's Smart Reply [2] automatically handles 10% of all mobile responses; speech recognition error rates have consistently dropped, and speech input is now more accurate than typing [3]; we can automatically identify skin cancer as well as dermatologists; Google's NMT system [4] is used in production for more than 10 language pairs; Baidu can generate realistic-sounding speech in real time; the list goes on. This level of maturity has allowed the large-scale deployment of these models to millions of users and has enabled widespread adoption.
On the other hand, these successful models are immensely data-hungry and rely on huge amounts of labeled data to achieve their performance. For some tasks and domains, this data is available because it has been painstakingly gathered over many years. In a few cases it is public, e.g. ImageNet [5], but large amounts of labeled data are usually proprietary or expensive to obtain, as in the case of many speech or MT datasets, because they provide an edge over the competition.
At the same time, a machine learning model applied in the wild is faced with a myriad of conditions it has never seen before and does not know how to deal with: each client and every user has their own preferences and possesses or generates data that differs from the data used for training, and a model is asked to perform many tasks that are related to, but not the same as, the task it was trained for. In all of these situations, our current state-of-the-art models, despite exhibiting human-level or even super-human performance on the task and domain they were trained on, suffer a significant loss in performance or even break down completely.
Transfer learning can help us deal with these novel scenarios and is necessary for production-scale use of machine learning that goes beyond tasks and domains where labeled data is plentiful. So far, we have applied our models to the tasks and domains that -- while impactful -- are the low-hanging fruit in terms of data availability. To also serve the long tail of the distribution, we must learn to transfer the knowledge we have acquired to new tasks and domains.
To be able to do this, we need to understand the concepts that transfer learning involves. For this reason, we will give a more technical definition in the following section.