The costs of BERT Fine-Tuning on small datasets
Fine-tuning BERT or its variants has become one of the most popular and effective methods to tackle natural language processing tasks, especially those with limited data. BERT models have been downloaded more than 5.6 million times from Hugging Face’s public server.
However, fine-tuning remains unstable, especially when using the large variant of BERT (BERTLarge) on small datasets, arguably the most impactful use of BERT-style models. Identical learning processes with different random seeds often result in significantly different and sometimes degenerate models following fine-tuning, even though only a few, seemingly insignificant aspects of the learning process are impacted by the random seed (Phang et al., 2018; Lee et al., 2020; Dodge et al., 2020). In layman’s terms: every time you train BERT for your task, you get different results. This means you need to train again and again to get a good system. This makes scientific comparison challenging (Dodge et al., 2020) and creates huge costs, which are potentially unnecessary.
While the variance comes from randomness, we hypothesize that the major cause of this instability lies in the optimization process.
Revisiting Few-sample BERT Fine-tuning
We conducted an extensive empirical analysis of BERT fine-tuning optimization behaviors on three aspects to identify the root cause of instability:
- The Optimization Algorithm — we found that omitting debiasing in the BERTAdam algorithm (Devlin et al., 2019) is the main cause of degenerate models
- The Initialization — we found that re-initializing the top few layers of BERT stabilizes the fine-tuning procedure
- The Number of Training Iterations — we found that the model still requires hundreds of updates to converge.
1. Optimization Algorithm
We observed that omitting debiasing in the BERTAdam algorithm (Devlin et al., 2019) is the leading cause of degenerate fine-tuning runs. The following is pseudo-code for the Adam algorithm (Kingma & Ba, 2014). BERTAdam omits lines 9 and 10, which correct the biases in the first and second moment estimates.
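To make the effect of the missing correction concrete, here is a minimal sketch of one Adam update on a single scalar parameter, with a flag that switches bias correction off. The names `adam_step`, `first_step_size`, and `correct_bias`, as well as the hyperparameter values, are illustrative assumptions and not the authors' code; only the update rule follows Kingma & Ba (2014). The line numbers mentioned above refer to the original pseudo-code figure, not to this sketch.

```python
import math


def adam_step(param, grad, m, v, t, lr=2e-5, beta1=0.9, beta2=0.999,
              eps=1e-6, correct_bias=True):
    """Apply one Adam update to a scalar parameter.

    Returns the new (param, m, v). `t` is the 1-based step count.
    With correct_bias=False the raw moment estimates are used, which is
    the behavior the text attributes to BERTAdam. The default values
    here are assumptions chosen for illustration.
    """
    # Exponential moving averages of the gradient and its square.
    m = beta1 * m + (1.0 - beta1) * grad
    v = beta2 * v + (1.0 - beta2) * grad * grad

    if correct_bias:
        # Both averages start at zero, so early estimates are biased
        # toward zero; dividing by (1 - beta^t) undoes that.
        m_hat = m / (1.0 - beta1 ** t)
        v_hat = v / (1.0 - beta2 ** t)
    else:
        m_hat, v_hat = m, v

    param = param - lr * m_hat / (math.sqrt(v_hat) + eps)
    return param, m, v


def first_step_size(correct_bias, grad=0.01, lr=2e-5):
    """Magnitude of the very first update, starting from m = v = 0."""
    new_param, _, _ = adam_step(0.0, grad, m=0.0, v=0.0, t=1, lr=lr,
                                correct_bias=correct_bias)
    return abs(new_param)


if __name__ == "__main__":
    corrected = first_step_size(correct_bias=True)
    uncorrected = first_step_size(correct_bias=False)
    print(f"with bias correction:    {corrected:.3e}")
    print(f"without bias correction: {uncorrected:.3e}")
    print(f"ratio: {uncorrected / corrected:.2f}")
```

Ignoring `eps`, skipping the correction scales the update at step t by (1 - beta1^t) / sqrt(1 - beta2^t). With beta1 = 0.9 and beta2 = 0.999 this factor is about 3.2 at the first step and stays well above 1 for the first several hundred steps, which is a large share of training when the dataset is small.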
Fine-tuning BERT with the original Adam (with bias correction) eradicates almost all degenerate model training outcomes and reduces the variance across multiple randomized trials. Here, we show the test performance distribution of 20 random trials with or without bias correction on four small datasets.
Since the variance is significantly reduced, practitioners can easily get a decent model within only one to five trials instead of fine-tuning up to 20 models and picking the best one.
2. Initialization
We hypothesized that the top pre-trained layers of BERT are specific to the pre-training task and may not transfer to a dissimilar downstream task. We propose to re-initialize the top few layers of BERT to ease the fine-tuning procedure. We plot the training curves with and without re-initialization below, showing consistent improvement for models with re-initialized output layers.
The following figure shows the validation performance with different numbers of re-initialized layers. As we can see, re-initializing a single layer is already beneficial, while the best number of layers to re-initialize depends on the downstream task.
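The idea of re-initializing only the top layers can be sketched on a toy model. The code below is not the authors' implementation: the function name `reinit_top_layers`, the list-of-lists weight layout, and the standard deviation of 0.02 are all assumptions made for illustration. A real encoder would hold tensors rather than Python lists.

```python
import random


def reinit_top_layers(layers, num_reinit, std=0.02, seed=0):
    """Return a copy of `layers` with the last `num_reinit` layers resampled.

    `layers` is a list of weight lists ordered bottom to top. This toy
    structure stands in for a real encoder; std=0.02 is an assumed
    initializer scale, not a value taken from the text.
    """
    if not 0 <= num_reinit <= len(layers):
        raise ValueError("num_reinit must be between 0 and len(layers)")
    rng = random.Random(seed)
    keep = len(layers) - num_reinit
    # Lower layers keep their pre-trained weights.
    new_layers = [list(w) for w in layers[:keep]]
    # Top layers are drawn fresh from a normal distribution N(0, std^2).
    for w in layers[keep:]:
        new_layers.append([rng.gauss(0.0, std) for _ in w])
    return new_layers


if __name__ == "__main__":
    # Pretend "pre-trained" weights: 4 layers with 3 weights each.
    pretrained = [[float(i)] * 3 for i in range(1, 5)]
    model = reinit_top_layers(pretrained, num_reinit=1)
    print(model[:3] == pretrained[:3])  # compares the untouched lower layers
    print(model[3] != pretrained[3])    # compares the resampled top layer
```

Because `num_reinit` is a plain argument, it can be treated as a hyperparameter and tuned per task, in line with the observation that the best number of layers varies.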
3. Number of Training Iterations
We also studied the conventional 3-epoch fine-tuning setup of BERT. Through extensive experiments on various datasets, we observe that the widely adopted 3-epoch setup is insufficient for few-sample datasets. Even with few training examples, the model still requires hundreds of updates to converge.
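A quick calculation shows why three epochs can fall short of "hundreds of updates" on a small dataset. The helper names and the example numbers below (1,000 examples, batch size 32, a target of 500 updates) are hypothetical and chosen only to illustrate the arithmetic; the sketch assumes one optimizer update per batch.

```python
import math


def num_updates(num_examples, batch_size, epochs):
    """Optimizer updates performed, assuming one update per batch."""
    return epochs * math.ceil(num_examples / batch_size)


def epochs_needed(num_examples, batch_size, target_updates):
    """Smallest whole number of epochs that reaches `target_updates`."""
    steps_per_epoch = math.ceil(num_examples / batch_size)
    return math.ceil(target_updates / steps_per_epoch)


if __name__ == "__main__":
    # Hypothetical setting: 1,000 examples, batch size 32.
    print(num_updates(1000, 32, epochs=3))              # 96
    print(epochs_needed(1000, 32, target_updates=500))  # 16
```

In this hypothetical setting, three epochs amount to only 96 updates, so it is more informative to budget training by number of updates than by number of epochs.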
Revisiting Existing Methods for Few-sample BERT Fine-tuning
Instability in BERT fine-tuning, especially in few-sample settings, has received significant attention recently, and several methods have been proposed to address it. We revisited these methods in light of our analysis of the fine-tuning process, focusing on the impact of using debiased Adam instead of BERTAdam.
To illustrate, the following figure shows the mean test performance and standard deviation on four datasets. “Int. Task” stands for transferring via an intermediate task (MNLI), “LLRD” stands for layerwise learning rate decay, and “WD” stands for weight decay. Numbers that are statistically significantly better than the standard setting (left column) are in blue and underlined.
We found that the standard fine-tuning procedure using bias-corrected Adam already has a fairly small variance, making these more complex techniques largely unnecessary. Moreover, re-initialization and longer training serve as simple yet hard-to-beat baselines that outperform previous methods, except for “Int. Task” on RTE. The reason for this exception is that RTE is very similar to MNLI (the intermediate task).
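For readers unfamiliar with layerwise learning rate decay (LLRD), one of the methods compared above, the usual formulation gives the top layer the base learning rate and multiplies the rate by a fixed factor for each layer below it. The sketch below illustrates that schedule only; the function name and the example values are assumptions, and it does not reproduce the exact configuration used in the experiments.

```python
def layerwise_learning_rates(num_layers, top_lr, decay):
    """Learning rate per layer, ordered bottom (index 0) to top.

    The top layer gets `top_lr`; each layer below is scaled by `decay`.
    """
    if not 0.0 < decay <= 1.0:
        raise ValueError("decay must be in (0, 1]")
    return [top_lr * decay ** (num_layers - 1 - i) for i in range(num_layers)]


if __name__ == "__main__":
    # Hypothetical values: 4 layers, top learning rate 2e-5, decay 0.5.
    for i, lr in enumerate(layerwise_learning_rates(4, 2e-5, 0.5)):
        print(f"layer {i}: {lr:.2e}")
```

With `decay = 1.0` every layer gets the same rate, which recovers the standard setting.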
Why this work matters
This work carefully investigates the current, broadly adopted optimization practices in BERT fine-tuning. Our findings significantly stabilize BERT fine-tuning on small datasets. Stable training has multiple benefits. It reduces deployment costs and time, potentially making natural language processing applications more feasible and affordable for companies and individuals with limited computational resources.
Our findings focus on few-sample training scenarios, and they open, or at least ease, the way for new applications at reduced data costs. The reduction in cost broadens the accessibility and reduces the energy footprint of BERT-based models. Applications that require frequent re-training are now easier and cheaper to deploy given the reduced training costs. This work also simplifies the scientific comparison between future fine-tuning methods by making training more stable, and therefore easier to reproduce.
Read The Complete Paper:
This work has been accepted and will be published at ICLR 2021. Visit our poster during the virtual conference (Poster Session 2: May 3, 2021, 9 a.m. PDT and May 3, 2021, 11 a.m. PDT) to talk with the authors.
Felix Wu, PhD is a Research Scientist at ASAPP. He received his Ph.D. in Computer Science from Cornell University under the supervision of Prof. Kilian Q. Weinberger and his B.S. in Computer Science and Information Engineering from National Taiwan University. His research interests include Machine Learning and its applications such as Natural Language Processing and Computer Vision. Recently, he has been focusing on designing efficient neural models.