Target Agnostic Optimization of Data Distribution Mixtures in LLM Continual Learning
This blog post is a continuation of my previous work on Efficient continual pre-training of LLMs. While the earlier post focused on the foundational continual learning problem and data selection strategies, this post explores target-agnostic optimization methods for combining multiple data distributions during LLM training.
Motivation
Scale and dataset diversity during pretraining are known to be useful ingredients in the emergent abilities that come with model scaling. Some interesting observations from this informal analysis suggest that the ability of LLMs to perform complex reasoning likely arises from training on code data in addition to text. This points to the idea that complementary data distributions can help derive new abilities. Along the same lines, there have been attempts to structure Visual QA inputs with code synthesis without any additional training, which points to the advantage of test-time compositional generalization when models are trained on a mix of data such as code, LaTeX, math, and SQL.
One of the major requirements of any pretraining and prefinetuning stage is to keep adding new datasets to pretrained or prefinetuned/instruction-finetuned checkpoints and to continue training the models on a mix of replay data and newer data. This leaves us with multiple data distributions collected over time: natural language text from different domains, math, code, tabular, and SQL data, among others. The model is then evaluated on downstream tasks related to the input distributions it was trained on, but where the input-output mapping space of the tasks may have shifted.
For example, we might train the models on tables, text, SQL, and code in the pretraining and prefinetuning stages, and expect the model to perform well on a downstream task such as numerical QA over tables. Note that the exact input distribution of this numerical QA task might not have been seen during pretraining or prefinetuning: those stages include math, text, and tables data, but not the task's exact input distribution. This is one of the reasons we want a more robust way to combine data in these training stages while remaining agnostic of the downstream tasks.
Additionally, when we start with pretrained and prefinetuned MTL (multi-task learned) checkpoints and want to continue training with newer task or domain data, two challenges arise:
- It is not clear how to mix the previous data with the newer task data, how to adjust the mix over the steps/epochs based on some metric, and how to handle imbalances across some of the tasks.
- Which reward or metric, derived from the data distributions, should be used to optimize this mix over the steps/epochs?
We aim to devise a method that addresses the following three problems:
- A target-agnostic, robust way of combining datasets through an optimization
- Automatic adjustment of the mix over the training steps, or in simpler words, a better curriculum than uniform mixing
- Optimization with respect to the worst-case task/domain group metrics without impacting average generalization
The focus of this project is not to evaluate continual learning methods, since our aim is not to develop class-incremental or task-incremental learning baselines. As mentioned, the downstream tasks are unknown during the pretraining/prefinetuning stage.
Problem Statement
We consider a learning scenario from Agnostic Federated Learning where the learner receives $p$ samples $S_1, \ldots, S_p$, each sample $S_k = (x_{k,1}, \ldots, x_{k,m_k})$ of size $m_k$ drawn i.i.d. from a distribution $\mathcal{D}_k$.

Notation: Let $\widehat{\mathcal{D}}_k$ denote the empirical distribution defined by sample $S_k$, $m = \sum_{k=1}^{p} m_k$ the total sample size, and $\Delta_p = \{\lambda \in \mathbb{R}^p_{\geq 0} : \sum_{k=1}^{p} \lambda_k = 1\}$ the simplex over the $p$ distributions.

Instead of the federated learning scenario where training is done with the uniform distribution over the union of all samples,

$$\overline{\mathcal{U}} = \sum_{k=1}^{p} \frac{m_k}{m}\, \widehat{\mathcal{D}}_k,$$

we consider the target distribution to be an unknown mixture of the distributions,

$$\mathcal{D}_\lambda = \sum_{k=1}^{p} \lambda_k\, \mathcal{D}_k,$$

for some $\lambda \in \Delta_p$. The learner then seeks a predictor $h$ that performs well under the worst-case mixture, i.e., one with small agnostic loss $\max_{\lambda \in \Delta_p} \mathcal{L}_{\mathcal{D}_\lambda}(h)$.
In the language modeling setup, our goal is to train the above predictor on sentences sampled from a mixture of distributions such that the optimization recovers good data mixture weights; we then continue training the model with these mixture weights under ERM. In the literature, this problem is often known as multiple distribution learning (On-Demand Sampling). Note that we are not solving an online version of this problem yet; rather, we want to run this optimization framework at intervals of the usual ERM-based training of the models, i.e., the seq2seq or autoregressive loss.
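As a rough sketch of how this could plug into ordinary training, assuming hypothetical `train_step` and `optimize_mixture_weights` callables (all names here are illustrative, not from any specific library):

```python
import random

def train_with_periodic_mixture_optimization(
    model, domain_datasets, num_steps, reweight_every, train_step, optimize_mixture_weights
):
    """Alternate standard ERM training with a mixture-weight optimization phase.

    domain_datasets: dict mapping domain name -> list of examples.
    train_step: callable(model, batch) performing one ERM update.
    optimize_mixture_weights: callable(model, domain_datasets, weights) -> new weights.
    """
    domains = list(domain_datasets)
    # Start from a uniform mixture over the available domains.
    weights = {d: 1.0 / len(domains) for d in domains}

    for step in range(num_steps):
        # Sample a domain according to the current mixture weights, then a batch from it.
        domain = random.choices(domains, weights=[weights[d] for d in domains], k=1)[0]
        batch = random.sample(domain_datasets[domain], k=min(8, len(domain_datasets[domain])))
        train_step(model, batch)

        # Periodically re-optimize the mixture instead of keeping it fixed for the whole run.
        if (step + 1) % reweight_every == 0:
            weights = optimize_mixture_weights(model, domain_datasets, weights)

    return model, weights
```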
Proposed Approaches
Distributionally Robust Optimization (DRO)
Distributionally Robust Optimization in recent years has been one of the tools to address the problem of building a robust model. The general idea behind DRO is to find a model $\theta$ which minimizes the worst-case expected loss

$$\min_{\theta} \; \sup_{Q \in \mathcal{Q}} \; \mathbb{E}_{x \sim Q}\left[\ell(x; \theta)\right],$$

where $\mathcal{Q}$ is an uncertainty set of distributions around the training distribution, for example the set of all mixtures of a finite collection of group distributions.
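For intuition, when the uncertainty set is just the set of mixtures of a finite collection of group distributions, the inner supremum is attained at a vertex of the simplex, i.e., it reduces to picking the single worst group. A minimal sketch with precomputed per-group losses:

```python
def worst_group_loss(per_group_losses):
    """Inner maximization of the DRO objective over mixtures of a finite set of groups.

    per_group_losses: dict mapping group name -> average loss of the current model
    on that group. The worst mixture puts all its mass on the worst group, so the
    supremum equals the maximum per-group loss; the outer minimization over the
    model would then try to push this value down.
    """
    worst_group = max(per_group_losses, key=per_group_losses.get)
    return worst_group, per_group_losses[worst_group]
```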
DoReMi Approach
The recent work DoReMi: Optimizing Data Mixtures Speeds Up Language Model Pretraining uses this DRO framework to find mixture weights over data domains. The goal of DoReMi is not to produce a robust model but to use the DRO formulation to find the domain weights. As the paper notes, during optimization DRO-LM takes a worst-case subset of each minibatch to update the model on, while the Group DRO optimizer that DoReMi uses requires no subselection. Subselection is expensive since it requires evaluating the model on a large minibatch while only selecting a small fraction to update the model with. Overall, in contrast to these DRO methods, which aim to produce robust models, DoReMi uses DRO to optimize the data mixture for training larger models more efficiently.
DoReMi runs at an early step and extrapolates the domain weights for the desired number of steps, since most of the variation in the domain weights during a DoReMi run seems to occur in the beginning of training. The optimization for the robust proxy model at every step is based on the following mini-max objective over the $k$ domains $D_1, \ldots, D_k$:

$$\min_{\theta} \max_{\alpha \in \Delta_k} \; L(\theta, \alpha) := \sum_{i=1}^{k} \alpha_i \cdot \frac{1}{\sum_{x \in D_i} |x|} \sum_{x \in D_i} \big[\ell_\theta(x) - \ell_{\mathrm{ref}}(x)\big],$$

where $\theta$ are the proxy model parameters, $\alpha \in \Delta_k$ are the domain weights, $|x|$ is the number of tokens in example $x$, $\ell_\theta(x)$ is the proxy model's loss on $x$, and $\ell_{\mathrm{ref}}(x)$ is the loss of a fixed reference model, so the objective is driven by the per-domain excess loss.
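A hedged sketch of the multiplicative-weights update on the domain weights $\alpha$ that this objective suggests, in the spirit of DoReMi's Group DRO optimizer (the step size, clipping, and smoothing values below are illustrative, not the paper's exact hyperparameters):

```python
import math

def update_domain_weights(alpha, excess_losses, step_size=1.0, smoothing=1e-3):
    """One exponentiated-gradient step on domain weights from per-domain excess losses.

    alpha: dict domain -> current weight (sums to 1).
    excess_losses: dict domain -> mean excess loss on that domain's minibatch
        (proxy model loss minus reference model loss).
    """
    # Multiplicative weights: domains with larger (clipped) excess loss get upweighted.
    unnormalized = {
        d: alpha[d] * math.exp(step_size * max(excess_losses[d], 0.0)) for d in alpha
    }
    total = sum(unnormalized.values())
    new_alpha = {d: v / total for d, v in unnormalized.items()}

    # Mix with the uniform distribution as a small smoothing regularizer.
    k = len(new_alpha)
    return {d: (1 - smoothing) * w + smoothing / k for d, w in new_alpha.items()}
```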
The unsupervised datasets can be highly imbalanced, and we need a sampling scheme over the data distributions, which DoReMi does not handle; we can start with random sampling. Additionally, a major issue with DoReMi is that it maintains two additional models (a reference model and a proxy model) alongside the main LLM. We want to avoid training additional models where possible while still solving the mini-max optimization problem.
In some of the baselines below, as a starting point, we simply evaluate existing methods, but none of them handle on-demand sampling from the distributions during the optimization. That is one feature we intend to add, both to speed up training and to handle imbalances in the data distributions.
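As an illustration of what on-demand sampling could look like in this setting (a simple sketch under our own assumptions, not the algorithm from the On-Demand Sampling paper), each optimization round only draws as many examples from each distribution as the current weights call for, rather than evaluating the model on every domain in full:

```python
import random

def sample_on_demand(domain_datasets, weights, budget):
    """Draw a fixed total budget of examples, split across domains by the current weights.

    Down-weighted or small domains contribute few examples, so the optimization step
    never has to touch every domain in full, regardless of how imbalanced the raw
    dataset sizes are.
    """
    batch = []
    for domain, data in domain_datasets.items():
        n = max(1, round(budget * weights[domain]))
        # Sample with replacement so tiny domains can still fill their quota.
        batch.extend(random.choices(data, k=n))
    return batch
```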
Baselines
Since our pretraining objective will still be either the seq2seq loss for encoder-decoder models or the causal LM loss for decoder-only models, the data mixture optimization step would be run at regular intervals of $T$ training steps.
Method 1: Simple Validation Feedback Based Optimization
We can observe the validation losses on held-out sets corresponding to the training data sources $\mathcal{D}_1, \ldots, \mathcal{D}_p$ at each optimization interval and use them as feedback to adjust the mixture weights: sources whose validation loss remains high (or decreases slowly) are upweighted for the next interval, while sources that are already well fit are downweighted.
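A minimal instantiation of this baseline, assuming we simply pass the per-source validation losses through a softmax (the temperature and the softmax itself are our own choices, not fixed by the method description):

```python
import math

def weights_from_validation_losses(val_losses, temperature=1.0):
    """Turn per-source validation losses into mixture weights via a softmax.

    Sources whose validation loss is still high receive more of the sampling
    budget at the next interval; well-fit sources are down-weighted.
    """
    scaled = {d: loss / temperature for d, loss in val_losses.items()}
    max_val = max(scaled.values())  # subtract the max for numerical stability
    exp = {d: math.exp(v - max_val) for d, v in scaled.items()}
    total = sum(exp.values())
    return {d: v / total for d, v in exp.items()}
```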
Just Train Twice (JTT) is another baseline that proceeds in two steps. First, JTT trains a standard ERM model for a small number of epochs and identifies the training examples this model still gets wrong (the error set), which tend to come from the worst-performing groups. Second, it trains a final model on data where the error-set examples are upweighted, improving worst-group performance without requiring group labels during training.
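A compressed sketch of the two JTT stages in their original classification form (the upweighting factor and epoch counts are hyperparameters; `train_erm` and `predict` are assumed helper callables):

```python
def just_train_twice(train_erm, predict, examples, labels,
                     upweight=5, warmup_epochs=1, final_epochs=10):
    """Two-stage Just Train Twice baseline.

    Stage 1: train an identification model with plain ERM for a few epochs and
    collect the examples it still gets wrong (the error set).
    Stage 2: retrain on a dataset where error-set examples are repeated
    `upweight` times, which implicitly emphasizes the hardest groups.
    """
    id_model = train_erm(examples, labels, epochs=warmup_epochs)
    error_set = {
        i for i, (x, y) in enumerate(zip(examples, labels)) if predict(id_model, x) != y
    }

    weighted_examples, weighted_labels = [], []
    for i, (x, y) in enumerate(zip(examples, labels)):
        repeats = upweight if i in error_set else 1
        weighted_examples.extend([x] * repeats)
        weighted_labels.extend([y] * repeats)

    return train_erm(weighted_examples, weighted_labels, epochs=final_epochs)
```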
Method 2: Group DRO with Early Stopping
The problem with applying DRO to overparameterized models is that if a model achieves zero training loss, then it is optimal on both the worst-case (DRO) and the average training objectives. We adopt the group DRO setting, in which the training distribution $P$ is a mixture of $m$ group distributions $P_g$ and the aim is to minimize the worst expected loss over these groups rather than the average loss.
The group DRO model minimizes the empirical worst-group risk

$$\hat{\theta}_{\mathrm{DRO}} = \arg\min_{\theta} \; \max_{g \in \mathcal{G}} \; \mathbb{E}_{(x, y) \sim \hat{P}_g}\left[\ell(\theta; (x, y))\right],$$

where $\mathcal{G}$ is the set of groups, $\hat{P}_g$ is the empirical distribution over the training points belonging to group $g$, and $\ell$ is the training loss.
Algorithms to solve this group DRO objective already exist (Distributionally Robust Neural Networks), and a few additional ingredients, such as strong regularization and early stopping, need to be explored to make it work for over-parameterized neural networks.
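A hedged sketch of the interleaved updates in an online group DRO algorithm of this kind: an exponentiated-gradient step on the adversarial group weights, followed by a model update on the re-weighted loss (the step size and the `model_update` callable are illustrative assumptions):

```python
import math

def group_dro_step(q, group_losses, model_update, eta_q=0.1):
    """One step of online group DRO.

    q: dict group -> current adversarial weight over groups (sums to 1).
    group_losses: dict group -> average loss of the current minibatch examples
        belonging to that group.
    model_update: callable(weighted_loss_terms) that applies a gradient step on
        the model using the q-weighted per-group losses.
    """
    # Exponentiated-gradient ascent on the group weights (the inner max):
    # groups with higher loss get more adversarial weight.
    unnormalized = {g: q[g] * math.exp(eta_q * group_losses[g]) for g in q}
    total = sum(unnormalized.values())
    q = {g: v / total for g, v in unnormalized.items()}

    # Descent step on the model using the re-weighted loss (the outer min).
    model_update({g: q[g] * group_losses[g] for g in q})
    return q
```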
Method 3: Modified DoReMi
We can modify DoReMi to use the pretrained model we start from as the reference model, so that no separate reference model needs to be trained from scratch; the proxy run then only has to produce the domain weights, which are used for the continual training of the main model.
Method 4: Lookahead DRO
The basic idea is to balance the average and worst-case losses across data distributions. Rather than simple averaging as done in ERM, prior work in multitask learning often uses fixed or heuristic weightings of the per-task losses.
Note that all these cases are explored over sequential stages of the learning procedure. We plan to use the lookahead DRO as proposed in Balancing Average and Worst-case Accuracy in Multitask Learning.
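To make the trade-off concrete, here is a simple convex combination of the average and worst-group losses; this is only an illustration of balancing the two objectives, not the lookahead procedure from the paper itself:

```python
def balanced_objective(group_losses, beta=0.5):
    """Interpolate between the ERM objective and the group DRO objective.

    beta = 0 recovers plain averaging over groups; beta = 1 recovers the
    worst-group loss; intermediate values trade average performance against
    worst-case performance.
    """
    average = sum(group_losses.values()) / len(group_losses)
    worst = max(group_losses.values())
    return (1 - beta) * average + beta * worst
```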