Generalizable and Stable Finetuning
of Pretrained Language Models on Low-Resource Texts
Annual Conference of the North American Chapter of
the Association for Computational Linguistics (NAACL), 2024
- Sai Ashish Somayajula
- Youwei Liang
- Li Zhang
- Abhishek Singh
- Pengtao Xie
University of California, San Diego
Abstract
Pretrained language models (PLMs) have significantly advanced natural language processing (NLP), but finetuning them on low-resource datasets poses challenges such as instability and overfitting. Previous methods tackle these issues by finetuning a strategically chosen subnetwork on a downstream task while keeping the remaining weights fixed at their pretrained values. However, they rely on a suboptimal criterion for subnetwork selection, which leads to suboptimal solutions. To address these limitations, we propose a regularization method based on attention-guided weight mixup for finetuning PLMs. Our approach represents each network weight as a mixup of a task-specific weight and the pretrained weight, controlled by a learnable attention parameter, which gives finer control over subnetwork selection. Furthermore, we employ a bi-level optimization (BLO) framework over two separate splits of the training dataset, which improves generalization and combats overfitting. Extensive experiments validate the efficacy of the proposed method and show its superiority over previous methods, particularly when finetuning PLMs on low-resource datasets.
Introduction
Pretraining language models on vast amounts of unlabeled data and then finetuning the resulting pretrained language models (PLMs) on specific tasks has greatly advanced natural language processing. However, conventional finetuning of PLMs faces two challenges. First, it is unstable: performance can vary widely even under identical settings, especially on small datasets. Second, the large capacity of PLMs can lead to overfitting on small datasets, resulting in poor generalization. Adapting PLMs to low-resource tasks while maintaining stability and maximizing generalization therefore remains a significant challenge in NLP.
Prior approaches address these challenges by updating only a subnetwork and keeping the remaining weights fixed at their pretrained values. Strategies such as CHILD-TUNING_D and DPS dense are promising, but they select the subnetwork using the Fisher Information Matrix (FIM). This choice may not be ideal, particularly in low-resource scenarios where scarce data can skew the FIM estimates, and the reliance on the FIM can therefore lead to suboptimal performance. We thus advocate moving away from FIM-based discrete selection toward strategies that select a child network based on the model's downstream task performance.
We introduce an attention-guided weight mixup mechanism that automatically learns a subnetwork while training the downstream model weights. This approach turns the discrete child-network selection of FIM-based methods into a continuous relaxation, bypassing suboptimal heuristic subnetwork selection. Using bi-level optimization on two splits of the training dataset, we optimize the task weights and the attention parameters with gradient descent to improve downstream performance.
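The bi-level problem described above can be written compactly. This is a sketch in our own notation, not copied from the paper: $D_1$ and $D_2$ are the two splits of the training data, $\mathcal{L}$ is the task loss, it is shown for a single weight matrix $W \in \mathbb{R}^{m \times n}$ with pretrained value $W_0$, and the mixup convention $\alpha \circ W + (\mathbf{1} - \alpha) \circ W_0$ is assumed.

```latex
\begin{aligned}
\alpha^{*} &= \arg\min_{\alpha \in [0,1]^{m \times n}} \;
  \mathcal{L}\big(\widehat{W}(W^{*}(\alpha), \alpha);\ D_2\big) \\
\text{s.t.}\quad W^{*}(\alpha) &= \arg\min_{W} \;
  \mathcal{L}\big(\widehat{W}(W, \alpha);\ D_1\big), \\
\widehat{W}(W, \alpha) &= \alpha \circ W + (\mathbf{1} - \alpha) \circ W_0 .
\end{aligned}
```

The lower level fits the task weights for a fixed blend, and the upper level chooses the blend that generalizes best to the held-out split.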
Attention-Guided Weight Mixup
Our method uses an attention-guided weight mixup mechanism in which each weight is a linear interpolation of a task weight and a pretrained weight, controlled by an attention parameter. Discrete subnetwork selection thereby becomes the problem of determining the optimal attention parameters, which control the blend of pretrained and task weights. Specifically, for each weight in a PLM, let $W$ and $W_0$ denote the task and pretrained weights, respectively, and let $\alpha$ be the attention parameter that controls the blend. The resultant weight $\widehat{W}$ is computed as

$$\widehat{W} = \alpha \circ W + (\mathbf{1} - \alpha) \circ W_0$$

Here, $\circ$ denotes element-wise multiplication and $\mathbf{1}$ is a matrix whose entries are all 1. If every entry of $\alpha$ is restricted to 0 or 1, $\widehat{W}$ takes each entry from either the task weight or the pretrained weight, which is exactly a discrete subnetwork selection. Letting the entries range continuously from 0 to 1 relaxes this into continuous selection and regulates how much the task and pretrained weights each contribute to $\widehat{W}$.
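A minimal pure-Python sketch of the element-wise mixup for a single weight matrix, assuming the convention $\widehat{W} = \alpha \circ W + (\mathbf{1} - \alpha) \circ W_0$, so that α = 1 keeps the task weight and α = 0 keeps the pretrained weight. Matrices are plain nested lists, and the helper names `weight_mixup` and `alpha_gradient` are hypothetical. The second function shows why α can be learned by gradient descent: $\widehat{W}$ is linear in α, so its derivative with respect to α is $W - W_0$ element-wise.

```python
from typing import List

Matrix = List[List[float]]


def weight_mixup(task_w: Matrix, pre_w: Matrix, alpha: Matrix) -> Matrix:
    """Return alpha * W + (1 - alpha) * W0, computed element-wise.

    Assumed convention: alpha = 1 keeps the task weight and alpha = 0 keeps
    the pretrained weight.
    """
    if not (len(task_w) == len(pre_w) == len(alpha)):
        raise ValueError("W, W0 and alpha must have the same number of rows")
    mixed = []
    for w_row, w0_row, a_row in zip(task_w, pre_w, alpha):
        if not (len(w_row) == len(w0_row) == len(a_row)):
            raise ValueError("W, W0 and alpha must have the same shape")
        mixed.append([a * w + (1.0 - a) * w0 for w, w0, a in zip(w_row, w0_row, a_row)])
    return mixed


def alpha_gradient(task_w: Matrix, pre_w: Matrix, grad_mixed: Matrix) -> Matrix:
    """Chain rule: dL/dalpha = dL/dW_hat * (W - W0), element-wise."""
    return [
        [g * (w - w0) for w, w0, g in zip(w_row, w0_row, g_row)]
        for w_row, w0_row, g_row in zip(task_w, pre_w, grad_mixed)
    ]


# A 0/1 alpha behaves like a discrete child-network mask; fractional values blend.
W = [[1.0, 2.0], [3.0, 4.0]]
W0 = [[0.0, 0.0], [1.0, 1.0]]
print(weight_mixup(W, W0, [[1.0, 0.0], [0.5, 1.0]]))  # → [[1.0, 0.0], [2.0, 4.0]]
```

In a real PLM the same operation would be applied to every weight tensor with a framework's autograd, but the arithmetic is identical.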
With this formulation, the task weights depend on the attention parameters, that is, on the chosen child network. Conversely, the task weights must be taken into account when learning the attention parameters, because the attention parameters determine the optimal blend of pretrained and task weights in the resultant weight. To handle this mutual dependency, we employ a bi-level optimization (BLO) framework with two stages: 1) the task weights are finetuned by minimizing the training loss on one split of the training data; 2) the attention parameters are optimized to minimize the validation loss on the other split, which balances task and pretrained weights for improved performance.
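A minimal sketch of the two stages on a one-parameter toy regression problem. Several choices here are assumptions and do not come from the paper: a single scalar weight, a squared-error loss, plain gradient descent, alternating updates with a first-order approximation (the upper-level step ignores how the task weight would change if α changed), and clipping to keep α in [0, 1]. The paper's actual hypergradient computation may differ. The names `mse_and_grad`, `mix`, and `bilevel_finetune` are hypothetical.

```python
from typing import List, Tuple

Data = List[Tuple[float, float]]


def mse_and_grad(w_hat: float, data: Data) -> Tuple[float, float]:
    """MSE of the model y = w_hat * x on `data` and its derivative w.r.t. w_hat."""
    n = len(data)
    loss = sum((w_hat * x - y) ** 2 for x, y in data) / n
    grad = sum(2.0 * (w_hat * x - y) * x for x, y in data) / n
    return loss, grad


def mix(w: float, w0: float, a: float) -> float:
    """Scalar weight mixup: a * w + (1 - a) * w0."""
    return a * w + (1.0 - a) * w0


def bilevel_finetune(d1: Data, d2: Data, w0: float, outer_steps: int = 200,
                     inner_steps: int = 5, lr_w: float = 0.1, lr_a: float = 0.1,
                     a_init: float = 0.5) -> Tuple[float, float]:
    w, a = w0, a_init  # task weight starts at the pretrained value (assumption)
    for _ in range(outer_steps):
        # Lower level: a few gradient steps on the task weight w using split D1,
        # with the attention parameter a held fixed.
        for _ in range(inner_steps):
            _, g = mse_and_grad(mix(w, w0, a), d1)
            w -= lr_w * g * a  # d w_hat / d w = a
        # Upper level: one gradient step on a using split D2, treating w as fixed
        # (first-order approximation).
        _, g = mse_and_grad(mix(w, w0, a), d2)
        a -= lr_a * g * (w - w0)  # d w_hat / d a = w - w0
        a = min(1.0, max(0.0, a))  # keep a in [0, 1] by clipping (assumption)
    return w, a


# Toy data from y = 2x; the "pretrained" weight w0 = 1.5 is deliberately off.
d1 = [(x, 2.0 * x) for x in (-1.0, -0.5, 0.5, 1.0)]
d2 = [(x, 2.0 * x) for x in (-0.75, 0.25, 0.75)]
w, a = bilevel_finetune(d1, d2, w0=1.5)
print(f"w={w:.4f} a={a:.4f} w_hat={mix(w, 1.5, a):.4f}")
```

Using disjoint splits for the two levels is what lets the upper level act as a regularizer: α is rewarded only for blends that also fit data the task weights were not trained on.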
Results
The major takeaways of the paper are summarized below. For a more comprehensive explanation, please refer to the paper.
Comparison with FIM-based subnetwork selection methods in low-resource scenarios
We compare our method with vanilla finetuning, CHILD-TUNING_D, and DPS dense using BERT_LARGE with 300, 500, and 1000 training examples. Reported results are evaluation metrics averaged over all eight GLUE datasets for each training-set size. The highest performance in each row is indicated in bold. Our method outperforms vanilla finetuning and prior FIM-based subnetwork selection methods by a significant margin.
Comparison with parameter-efficient finetuning (PEFT) methods in low-resource scenarios
Averaged performance across the CoLA, RTE, STSB, and MRPC datasets for Vanilla, Prompt Tuning, Prefix-Tuning, LoRA, and our method in low-resource scenarios with 500 and 1000 training instances. PEFT methods do not always outperform vanilla finetuning in low-resource scenarios. This is expected, since the primary goal of PEFT methods is to match the performance of vanilla finetuning without its large computational cost. In contrast, our method consistently improves over vanilla finetuning in these low-resource settings.
Evaluation across various PLMs shows consistently lower performance variability than vanilla finetuning
Comparison of our method and vanilla finetuning on five popular PLMs. Each model was evaluated over ten runs with different random seeds, and we report the mean and standard deviation. The average score is the mean performance across four datasets, and the best scores are highlighted in bold. Underlined values indicate occurrences of degenerate seeds. Our method yields a notable gain over vanilla finetuning along with a substantial decrease in standard deviation.
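The per-model statistics above can be computed as in the following sketch. The cut-off `degenerate_below` is a hypothetical parameter, because the criterion for a degenerate seed is not stated here. The use of the sample (n − 1) standard deviation rather than the population standard deviation is also an assumption. The scores in the example are made-up placeholders, not results from the paper.

```python
import statistics
from typing import Dict, List, Sequence, Union


def summarize_seeds(scores: Sequence[float],
                    degenerate_below: float) -> Dict[str, Union[float, List[int]]]:
    """Mean, sample standard deviation, and indices of runs below a cut-off."""
    return {
        "mean": statistics.mean(scores),
        "std": statistics.stdev(scores),  # sample std (n - 1); an assumption
        "degenerate_seeds": [i for i, s in enumerate(scores) if s < degenerate_below],
    }


# Placeholder scores for ten seeds (illustrative only); seed 3 collapsed.
runs = [61.0, 63.5, 62.0, 0.0, 60.5, 62.5, 61.5, 63.0, 62.0, 61.0]
print(summarize_seeds(runs, degenerate_below=10.0))
```

A single collapsed seed inflates the standard deviation far more than it shifts the mean. This is why reducing degenerate runs shows up mainly as a lower standard deviation.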
Comparison with prior regularization based methods
Comparison of our method with prior regularization-based methods on four small datasets (CoLA, RTE, MRPC, STSB) that are known to cause unstable finetuning of BERT_LARGE. The mean and standard deviation (std) over ten random seeds are reported for each method. Bold indicates the best performance. Our method surpasses all other baselines in average score, with a particularly notable improvement on CoLA, which illustrates the effectiveness of our approach. Two-sided t-tests were performed between our method and the vanilla method. The p-values are below 0.05, indicating a statistically significant improvement over vanilla finetuning.
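The significance check can be sketched with the standard library. The text states only that two-sided t-tests were used. This sketch assumes an unpaired Student's t-test with pooled variance and ten seeds per method (18 degrees of freedom, two-sided 5% critical value of about 2.101). It compares the statistic with that critical value instead of computing an exact p-value, which would need a t-distribution CDF (for example via `scipy.stats.ttest_ind`). The function names are hypothetical, and the scores are placeholders.

```python
import math
import statistics
from typing import Sequence

# Two-sided 5% critical value of Student's t with 18 degrees of freedom,
# i.e. two independent groups of ten seeds each (an assumption).
T_CRIT_DF18 = 2.101


def pooled_t_statistic(a: Sequence[float], b: Sequence[float]) -> float:
    """Student's two-sample t statistic with pooled variance (needs nonzero variance)."""
    na, nb = len(a), len(b)
    pooled_var = ((na - 1) * statistics.variance(a)
                  + (nb - 1) * statistics.variance(b)) / (na + nb - 2)
    return (statistics.mean(a) - statistics.mean(b)) / math.sqrt(
        pooled_var * (1.0 / na + 1.0 / nb))


def significant_at_5pct(ours: Sequence[float], vanilla: Sequence[float]) -> bool:
    """Two-sided test at the 5% level; the critical value assumes 10 runs each."""
    if len(ours) != 10 or len(vanilla) != 10:
        raise ValueError("the critical value assumes ten runs per method")
    return abs(pooled_t_statistic(ours, vanilla)) > T_CRIT_DF18


# Placeholder scores (illustrative only), not results from the paper.
vanilla = [60.0 + 0.5 * i for i in range(10)]
ours = [s + 2.0 for s in vanilla]
print(pooled_t_statistic(ours, vanilla), significant_at_5pct(ours, vanilla))
```

The sign of the statistic gives the direction of the difference. A positive value with a significant result means the first method scores higher on average.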
Acknowledgements
We thank Professor Taylor Berg for the insightful discussions on the work.
The website template was borrowed from Michaël Gharbi.