Soumajyoti Sarkar

Target Agnostic Optimization of Data Distribution Mixtures in LLM Continual Learning

This blog post is a continuation of my previous work on Efficient continual pre-training of LLMs. While the earlier post focused on the foundational continual learning problem and data selection strategies, this post explores target-agnostic optimization methods for combining multiple data distributions during LLM training.

Motivation

Scale and diverse pretraining data are known to be useful ingredients in the emergent abilities that come with model scaling. An informal analysis offers an interesting observation: the ability of LLMs to perform complex reasoning is likely to arise from training on code in addition to text. This suggests that complementary data distributions can help models acquire new abilities. Following that, there have also been attempts to structure Visual QA inputs through code synthesis without any additional training. This also points to the advantage of test-time compositional generalization when models are trained on a mix of data such as code, LaTeX, math, and SQL.

One of the major requirements of any pretraining or prefinetuning stage is the ability to keep adding new datasets to pretrained or prefinetuned/instruction-finetuned checkpoints and to continue training with a mix of replay data and newer data. This leaves us with multiple data distributions collected over time, such as natural language text from different domains, math, code, tabular data, and SQL, among others. However, the trained model is then evaluated on multiple downstream tasks that are related to the training input distributions but where the input-to-output mapping of the task may have shifted.

For example, we would train the models on tables, text, SQL, and code in the pretraining and prefinetuning stages, and expect the model to perform well on tabular math problems, such as numerical QA on tables, in downstream evaluation. Note that the exact input distribution of this numerical QA task might not have been seen during pretraining or prefinetuning: the model has been trained on math, text, and table data, but not on the exact input distribution of this task. This is why we want a more robust way to combine data in these training stages while remaining agnostic to the downstream target tasks.

Additionally, when we start with pretrained and prefinetuned MTL (multi-task learned) checkpoints and want to continue training with newer task or domain data, two challenges arise:

  1. It is not clear how to mix the previous data with the newer task data, how to adjust the mix over steps/epochs based on some metric, or how to handle imbalances across tasks.
  2. It is not clear which reward or metric from the data distributions should be used to optimize this mix over steps/epochs.

We aim to devise a method that meets the following three goals:

  1. Combine datasets in a target-agnostic, robust way through an optimization
  2. Automatically adjust the mix over the steps or, in simpler words, find a better curriculum than uniform mixing
  3. Optimize with respect to the worst-case task/domain group metrics without hurting generalization

The focus of the project is not to evaluate continual learning methods since our aim is not to directly develop continual learning baselines of class incremental or task incremental learning. As mentioned, the downstream tasks are unknown during the pretraining/prefinetuning stage.

Problem Statement

We consider a learning scenario from Agnostic Federated Learning where the learner receives $p$ data sources or groups $S_1, S_2, \dots, S_p$, with each:

$$S_k = \big((x_{k,1}, y_{k,1}), \dots, (x_{k,m_k}, y_{k,m_k})\big) \in (\mathcal{X} \times \mathcal{Y})^{m_k}$$

of size $m_k$ drawn i.i.d. from a different domain or distribution $\mathcal{D}_k$. The learner's objective is to determine a hypothesis that performs well on some target distribution.

Notation: Let $\hat{\mathcal{D}}_k$ denote the empirical distribution associated to the sample $S_k$ of size $m_k$ drawn from $\mathcal{D}_k^{m_k}$.

Instead of the federated learning scenario, where training is done with the uniform distribution over the union of all samples $S_k$, with the underlying target distribution:

$$\hat{\mathcal{U}} = \sum_{k=1}^{p} \frac{m_k}{\sum_{k'=1}^{p} m_{k'}} \, \mathcal{D}_k$$

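we consider the target distribution to be an unknown mixture of the distributions $\mathcal{D}_k$ such that:

$$\mathcal{D}_\lambda = \sum_{k \in [p]} \lambda_k \mathcal{D}_k$$

for some $\lambda \in \Delta_p$, where $\Delta_p$ is the probability simplex over the $p$ sources. Since this mixture weight is unknown, the learning procedure must come up with a solution that is favorable for any $\lambda$ in a subset $\Lambda \subseteq \Delta_p$. We define the agnostic risk $\mathcal{L}_{\mathcal{D}_\Lambda}(h)$ associated to a predictor $h$ in the hypothesis class $\mathcal{H}$ as the worst-case risk over $\Lambda$, that is $\mathcal{L}_{\mathcal{D}_\Lambda}(h) = \max_{\lambda \in \Lambda} \mathcal{L}_{\mathcal{D}_\lambda}(h)$, and look for its minimizer:

$$h_{\mathcal{L}} = \operatorname*{argmin}_{h \in \mathcal{H}} \max_{\lambda \in \Lambda} \mathcal{L}_{\mathcal{D}_\lambda}(h) \tag{1}$$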
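To make the min-max objective concrete, here is a minimal sketch over a finite set of candidate predictors. Everything in it is an illustrative assumption: the per-domain losses are made-up numbers, and the names `mixture_loss`, `agnostic_risk`, and `agnostic_minimizer` are hypothetical helpers, not part of any library. When the subset of mixture weights is the full simplex, the inner maximum of a linear function is attained at a vertex, so listing the one-hot weights is enough.

```python
def mixture_loss(domain_losses, lam):
    """Loss under the mixture D_lambda: sum_k lambda_k * L_{D_k}(h)."""
    return sum(weight * loss for weight, loss in zip(lam, domain_losses))


def agnostic_risk(domain_losses, candidate_lambdas):
    """Worst-case mixture loss over a finite set of mixture weights."""
    return max(mixture_loss(domain_losses, lam) for lam in candidate_lambdas)


def agnostic_minimizer(hypotheses, candidate_lambdas):
    """Pick the hypothesis with the smallest worst-case mixture loss."""
    return min(hypotheses, key=lambda name: agnostic_risk(hypotheses[name], candidate_lambdas))


# Made-up per-domain losses for two candidate predictors (illustrative only).
hypotheses = {
    "h_avg": [0.2, 0.2, 1.2],     # good on average, poor on the third domain
    "h_robust": [0.6, 0.6, 0.7],  # slightly worse on average, no weak domain
}

# Vertices of the simplex: the worst case over all mixtures is attained here.
vertices = [(1, 0, 0), (0, 1, 0), (0, 0, 1)]
# A single uniform mixture, which corresponds to plain averaging.
uniform_only = [(1 / 3, 1 / 3, 1 / 3)]

best_worst_case = agnostic_minimizer(hypotheses, vertices)
best_on_average = agnostic_minimizer(hypotheses, uniform_only)
```

With these numbers the two criteria disagree: averaging prefers the predictor with a weak domain, while the agnostic objective prefers the one whose worst domain is best.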

In the language modeling setup, our goal is to apply the above objective to sentences sampled from a mixture of distributions, so that solving it yields suitable data mixture weights; we then continue training the model with these mixture weights under ERM. In the literature, this problem is often known as multiple distribution learning (On-Demand Sampling). Note that we are not yet solving an online version of this problem. Instead, we want to run this optimization at intervals during the usual ERM-based training of the model, for example with the seq2seq or autoregressive loss.

Proposed Approaches

Distributionally Robust Optimization (DRO)

Distributionally Robust Optimization has in recent years become one of the main tools for building robust models. The general idea behind DRO is to find a model that minimizes the loss function $\ell_\theta(x, y)$ under the worst-case distribution in an uncertainty set $\mathcal{Q}$:

$$\min_\theta \max_{q \in \mathcal{Q}} \mathbb{E}_{q}\left[\ell_\theta(x, y)\right]$$

such that $\mathcal{Q}$ covers the test distributions of interest. There are many non-parametric formulations of DRO, based for example on $f$-divergences, the Wasserstein distance, or CVaR (the last is also used in DRO-LM). Group DRO and parametric DRO are other forms, which solve the mini-max optimization problem by reformulating the optimization over $\mathcal{Q}$. For example, On-Demand Sampling formulates DRO as a likelihood ratio optimization problem:

$$\min_\theta \max_{r \in \mathcal{R}} \mathbb{E}_{(x, y) \sim p}\left[r(x, y)\, \ell_\theta(x, y)\right] \tag{2}$$

where $p$ is the true data distribution and $r: \mathcal{X} \times \mathcal{Y} \to \mathbb{R}_+$ is a function in the uncertainty set $\mathcal{R} \subseteq \{r \mid p r \in \mathcal{Q}\}$.

DoReMi Approach

The recent work DoReMi: Optimizing Data Mixtures Speeds Up Language Model Pretraining uses this DRO framework to find mixture weights over data domains. The goal of DoReMi is not to produce a robust model but to use the DRO formulation to find the domain weights. As the authors note in the paper, during optimization DRO-LM takes a worst-case subset of each minibatch to update the model on, while the Group DRO optimizer that DoReMi uses does not require subselection. Subselection is expensive since it requires evaluating the model on a large minibatch while selecting only a small fraction of it for the update. Overall, in contrast to DRO methods that aim to produce robust models, DoReMi uses DRO to optimize the data for training larger models more efficiently.

DoReMi runs at an early step and extrapolates the domain weights for the desired number of steps, since most of the variation in the domain weights during a DoReMi run seems to occur at the beginning of training. The optimization for the robust model at every step is based on the following:

$$\min_\theta \max_{\alpha \in \Delta^k} L(\theta, \alpha) = \min_\theta \max_{\alpha \in \Delta^k} \sum_{i=1}^{k} \alpha_i \left[\frac{1}{\sum_{x \in D_i} |x|} \sum_{x \in D_i} \big(\ell_\theta(x) - \ell_{\text{ref}}(x)\big)\right] \tag{3}$$

where $\ell_{\text{ref}}$ denotes the loss of a reference model of the same size as the robust model, whose loss $\ell_\theta$ is trained and updated at every step, and $\alpha_i$ denotes the weight given to domain $i$. As is evident, the optimization focuses on the minimax excess risk, and the maximization happens over the simplex of the domain weights.
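The maximization over the domain weights can be sketched in a few lines. The DoReMi paper updates the weights with an exponentiated-gradient step on the per-domain excess loss (clipped at zero per token), then renormalizes and smooths toward the uniform distribution. The function below is a minimal sketch of one such step under simplifying assumptions: losses are passed in as plain lists of per-token values for the current minibatch, and the name `doremi_weight_step` and its default hyper-parameters are hypothetical choices for illustration, not the paper's code.

```python
import math


def doremi_weight_step(alpha, robust_token_losses, ref_token_losses,
                       step_size=1.0, smoothing=1e-3):
    """One exponentiated-gradient update of the domain weights (a sketch).

    alpha:               current domain weights, a list summing to 1
    robust_token_losses: per domain, a list of per-token losses of the robust model
    ref_token_losses:    per domain, per-token losses of the reference model
    """
    num_domains = len(alpha)
    excess = []
    for robust, ref in zip(robust_token_losses, ref_token_losses):
        # Per-token excess loss, clipped at zero, normalized by token count.
        clipped = [max(l_theta - l_ref, 0.0) for l_theta, l_ref in zip(robust, ref)]
        excess.append(sum(clipped) / len(clipped) if clipped else 0.0)

    # Multiplicative update: domains with larger excess loss gain weight.
    unnormalized = [a * math.exp(step_size * e) for a, e in zip(alpha, excess)]
    total = sum(unnormalized)

    # Renormalize onto the simplex and mix with the uniform distribution.
    return [(1.0 - smoothing) * u / total + smoothing / num_domains
            for u in unnormalized]


# Toy example: domain 0 is learned worse than the reference, domain 1 is not.
new_alpha = doremi_weight_step(
    alpha=[0.5, 0.5],
    robust_token_losses=[[2.0, 2.0], [1.0, 1.0]],
    ref_token_losses=[[1.0, 1.0], [1.0, 1.5]],
)
```

In the toy call the first domain has positive excess loss and the second has none after clipping, so the update shifts weight toward the first domain while keeping the weights on the simplex.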

Unsupervised datasets can be highly imbalanced, so we need a sampling scheme over the data distributions, which DoReMi does not provide. We can start with random sampling. A further major issue with DoReMi is that it maintains two additional models alongside the LLM. We want to avoid training additional models, if possible, when solving the mini-max optimization problem.

As a starting point, some of our baselines simply evaluate existing methods, but none of them support on-demand sampling from the distributions during the optimization. That is one feature we intend to add, both to speed up training and to handle imbalances in the data distributions.

Baselines

Since our pretraining objective will still be either the seq2seq loss for encoder-decoder models or the causal LM loss for decoder-only models, the data mixture optimization step is done at regular intervals of $K$ training steps. In other words, we alternate between $K$ training steps (where $K$ is a hyper-parameter) and one mixture optimization step. Since the mixture optimization steps are costly, this keeps their impact on training speed small. The following methods describe the data mixture optimization step that is applied at those intervals (note that we are not applying this optimization step together with the ERM loss during the $K$ steps of training).
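The alternation between ordinary training and mixture optimization can be sketched as a small driver loop. This is only a skeleton under stated assumptions: `train_step` and `optimize_mixture` are hypothetical callables standing in for one ERM update on a batch from the sampled source and for whichever mixture optimization method is plugged in, and sources are sampled at random according to the current weights.

```python
import random


def sample_source(weights, rng):
    """Draw the index of a data source according to the current mixture weights."""
    return rng.choices(range(len(weights)), weights=weights, k=1)[0]


def run_schedule(total_steps, K, weights, train_step, optimize_mixture, seed=0):
    """Alternate K ERM training steps with one mixture optimization step."""
    rng = random.Random(seed)
    history = [list(weights)]  # mixture weights after each optimization step
    for step in range(1, total_steps + 1):
        source = sample_source(weights, rng)
        train_step(source)  # usual seq2seq / causal LM update on that source
        if step % K == 0:
            # The costly mixture update only runs once every K steps.
            weights = optimize_mixture(weights)
            history.append(list(weights))
    return weights, history


# Toy usage with stand-in callables (illustrative only).
sampled_sources = []
final_weights, history = run_schedule(
    total_steps=10,
    K=5,
    weights=[0.5, 0.5],
    train_step=sampled_sources.append,           # just records the sampled source
    optimize_mixture=lambda w: [0.7, 0.3],       # placeholder for a real method
)
```

Keeping the mixture update behind a single callable makes it straightforward to swap in any of the methods below without touching the training loop.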

Method 1: Simple Validation Feedback Based Optimization

We can observe the validation losses on tasks related to the training data sources $S_1, S_2, \dots, S_p$ and, based on these losses, upweight the worst-performing groups and continue training. The advantage of this approach is that it is fairly straightforward to implement. The major disadvantage is that it does not work when we have a mix of unsupervised and supervised data and objectives, since the unsupervised data sources would not come with direct validation sets. The other hurdle is finding the right upweighting factor for the samples. Similarly, we can use a meta-learning algorithm that learns to assign weights to training examples based on their gradient directions so as to minimize a validation loss.

Just Train Twice (JTT) is another baseline that proceeds in two steps. First, JTT trains an ERM model for a small number of epochs $T$. Following this, JTT trains a final ERM model on a dataset in which the examples the first ERM model got wrong appear $\lambda_{\text{up}}$ times.
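The second JTT stage amounts to building an upsampled training set. A minimal sketch follows, assuming we already have a per-example flag saying whether the first-stage model made a mistake. How a "mistake" is defined for language modeling (for instance, a loss above some threshold) is left open here, and the helper name `jtt_upsample` is hypothetical.

```python
def jtt_upsample(examples, is_mistake, lambda_up):
    """Build the second-stage dataset: mistakes appear lambda_up times, the rest once."""
    if lambda_up < 1:
        raise ValueError("lambda_up must be at least 1")
    upsampled = []
    for example, wrong in zip(examples, is_mistake):
        copies = lambda_up if wrong else 1
        upsampled.extend([example] * copies)
    return upsampled


# Toy usage: the first-stage model got only "b" wrong.
stage_two_data = jtt_upsample(["a", "b", "c"], [False, True, False], lambda_up=3)
```

The final ERM model is then trained on `stage_two_data` in the usual way, typically with shuffling.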

Method 2: Group DRO with Early Stopping

The problem with applying DRO to overparameterized models is that if a model achieves zero training loss, then it is optimal on both the worst-case (DRO) and the average training objectives. We adopt the group DRO setting, where the training distribution $P$ is assumed to be a mixture of $m$ groups $P_g$ indexed by $\mathcal{G} = \{1, 2, \dots, m\}$. In our case, we can self-annotate the groups. The uncertainty set $\mathcal{Q}$ can then be defined as the set of mixtures of these groups. Alternatively, we can assume that we know the groups of the training data sources, in which case there is some manual effort involved in the annotation.

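The group DRO model minimizes the empirical worst-group risk $\hat{\mathcal{R}}(\theta)$:

$$\hat{\theta}_{\text{DRO}} := \operatorname*{argmin}_{\theta \in \Theta} \hat{\mathcal{R}}(\theta) \tag{4}$$

where:

$$\hat{\mathcal{R}}(\theta) := \max_{g \in \mathcal{G}} \mathbb{E}_{(x, y) \sim \hat{P}_g}\left[\ell(\theta; (x, y))\right]$$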

Algorithms for solving this group DRO problem already exist, for example in Distributionally Robust Neural Networks. For over-parametrized neural networks, a few additional ingredients, such as regularization and early stopping, need to be explored to make it work.
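To illustrate the kind of online algorithm used for group DRO, here is a toy sketch in the spirit of the one in Distributionally Robust Neural Networks: group weights are increased multiplicatively for high-loss groups, and the model takes a gradient step on the reweighted loss. The setup is a deliberately simple assumption, a single scalar parameter and one quadratic loss $(\theta - c_g)^2$ per group, and the function names and hyper-parameters are hypothetical.

```python
import math


def worst_group_loss(theta, centers):
    """Empirical worst-group risk for the toy losses (theta - c_g)^2."""
    return max((theta - c) ** 2 for c in centers)


def group_dro_toy(centers, steps=2000, lr=0.05, eta_q=0.1):
    """Alternate a multiplicative update on group weights q with a gradient step on theta."""
    theta = 0.0
    q = [1.0 / len(centers)] * len(centers)
    for _ in range(steps):
        losses = [(theta - c) ** 2 for c in centers]
        # Upweight groups that currently have a high loss, then renormalize.
        q = [q_g * math.exp(eta_q * loss) for q_g, loss in zip(q, losses)]
        total = sum(q)
        q = [q_g / total for q_g in q]
        # Gradient of the q-weighted loss with respect to theta.
        grad = sum(q_g * 2.0 * (theta - c) for q_g, c in zip(q, centers))
        theta -= lr * grad
    return theta, q


centers = [0.0, 1.0]          # per-group optima (illustrative)
fractions = [0.9, 0.1]        # group 1 is rare in the training data

# ERM on the pooled data minimizes the frequency-weighted loss: a weighted mean.
theta_erm = sum(f * c for f, c in zip(fractions, centers))
theta_dro, group_weights = group_dro_toy(centers)
```

In this toy problem ERM settles close to the majority group's optimum and leaves the rare group with a large loss, while the group DRO iterate moves toward the point that balances the two group losses.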

Method 3: Modified DoReMi

We can modify DoReMi to use the pretrained model $M$ at each stage and use PEFT methods to finetune with the robust optimization loss, and then verify whether that helps the data mixture optimization.

Method 4: Lookahead DRO

The basic idea is to balance the average and worst-case losses among data distributions. Rather than simple averaging as done in ERM, prior work in multitask learning often uses arbitrary weightings $w = \{w_1, \dots, w_N\}$ in the probability simplex $\Delta_N = \{w \mid w_i > 0, \sum_i w_i = 1\}$ to modulate the losses of individual tasks:

$$\min_\theta \sum_{i=1}^{N} w_i \ell_i(\theta)$$

Note that all these cases are explored over sequential stages of the learning procedure. We plan to use the lookahead DRO as proposed in Balancing Average and Worst-case Accuracy in Multitask Learning.