Soumajyoti Sarkar

Efficient continual pre-training of LLMs

I have been working on this topic on and off for some time now. Continual learning has been a subject of research for decades. Most of its techniques rely on some form of replay buffer, and I have been frustrated to see that they cannot easily be adapted to pretraining large language models on billions of tokens, because the optimization procedures used to select the replay buffer are too costly at that scale. So far, I haven't been able to turn this research into a publication. I therefore decided to try formulating the continual learning problem and to propose a few directions of research. This blog post will keep evolving over the next few weeks.

The objective of continual learning (CL) is to accumulate knowledge from non-stationary data, i.e., from a stream of non-i.i.d. samples. Here I discuss some methods that can be systematically applied to continual pretraining of LLMs. CL for large language models (LLMs) is shaped by two constraints: (1) the data for previous domains/tasks is assumed to be too large to retrain from scratch on all of it, and (2) catastrophic forgetting of past information, which ties into the memorization vs. generalization issue in pretrained language models. In that regard, there are two aspects of CL to focus on:

1. Backward transfer: the influence that learning on a new domain has on the performance on downstream tasks of domains on which the LLM was previously pretrained. Positive backward transfer occurs when learning about a new domain improves performance on tasks from a previously seen domain; the opposite, negative backward transfer, is catastrophic forgetting.

2. Forward transfer: the influence that learning on the current domain has on the performance on downstream tasks of domains on which the LLM will be pretrained in the future. This is the kind of transfer learning or generalization that is known to be an advantage of pretrained LLMs.
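Both quantities can be measured from a score matrix collected during the sequence of pretraining stages. The sketch below follows the convention from Gradient Episodic Memory (Lopez-Paz and Ranzato, 2017), assuming one domain per stage: R[i, j] is the downstream score on domain j after stage i, and b[j] is a reference score for domain j (GEM uses a randomly initialized model; what to use as the reference for LLMs is an assumption left open). The numbers are made up for illustration.

```python
import numpy as np


def transfer_metrics(R: np.ndarray, b: np.ndarray) -> tuple[float, float]:
    """Backward and forward transfer from a (stages x domains) score matrix.

    R[i, j]: downstream score on domain j after continual pretraining stage i.
    b[j]:    reference score on domain j before it is ever trained on.
    """
    T = R.shape[0]
    if T < 2:
        raise ValueError("need at least two pretraining stages")
    # Backward transfer: change on each earlier domain between the end of its own
    # stage and the end of the final stage; negative values indicate forgetting.
    bwt = float(np.mean([R[T - 1, i] - R[i, i] for i in range(T - 1)]))
    # Forward transfer: score on domain i just before its own stage, relative to
    # the reference; positive values indicate that earlier domains helped.
    fwt = float(np.mean([R[i - 1, i] - b[i] for i in range(1, T)]))
    return bwt, fwt


# Illustrative (made-up) scores for three stages, one new domain per stage.
R = np.array([
    [0.70, 0.40, 0.30],
    [0.65, 0.75, 0.35],
    [0.60, 0.72, 0.80],
])
b = np.array([0.50, 0.35, 0.30])
bwt, fwt = transfer_metrics(R, b)
print(f"BWT={bwt:.3f}, FWT={fwt:.3f}")  # BWT=-0.065, FWT=0.050
```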

One use of CL for LLMs is to develop a more principled approach to adaptive data selection: sampling the replay buffer either as a sequential sampling problem or as a constrained optimization problem, in a way that scales to large datasets and model sizes and is better geared towards mitigating negative backward transfer. The problem can be abstracted in many ways; here I provide a few.
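To make the "sequential sampling" view concrete, here is a minimal sketch that treats the choice of which previous domain to replay next as a multi-armed bandit, using the standard EXP3 algorithm. This is a swapped-in illustration, not a method from this post: the class name is hypothetical, and the reward (assumed to lie in [0, 1], e.g. a normalized drop in held-out loss on that domain after a replay step) is an assumption.

```python
import math
import random


class Exp3DomainSampler:
    """Hypothetical sketch: EXP3 bandit that picks which source domain to replay next.

    Rewards passed to update() must lie in [0, 1]; how they are computed is an
    assumption left to the user.
    """

    def __init__(self, num_domains: int, gamma: float = 0.1, seed: int = 0):
        self.k = num_domains
        self.gamma = gamma  # exploration rate
        self.weights = [1.0] * num_domains
        self.rng = random.Random(seed)

    def probs(self) -> list[float]:
        # Mix the normalized weights with a uniform distribution for exploration.
        total = sum(self.weights)
        return [(1 - self.gamma) * w / total + self.gamma / self.k for w in self.weights]

    def sample(self) -> int:
        return self.rng.choices(range(self.k), weights=self.probs())[0]

    def update(self, domain: int, reward: float) -> None:
        # Importance-weighted reward estimate: only the sampled domain is updated.
        estimate = reward / self.probs()[domain]
        self.weights[domain] *= math.exp(self.gamma * estimate / self.k)
        # Rescale to avoid overflow on long runs; the probabilities are unchanged.
        top = max(self.weights)
        self.weights = [w / top for w in self.weights]


sampler = Exp3DomainSampler(num_domains=3)
for _ in range(200):
    d = sampler.sample()
    reward = 1.0 if d == 0 else 0.1  # toy rewards: domain 0 is the most useful to replay
    sampler.update(d, reward)
print([round(p, 2) for p in sampler.probs()])
```

EXP3 is used here because it makes no stationarity assumption about rewards, which fits a setting where the usefulness of replaying a domain drifts as the model changes.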

Abstraction I

Formally, we have a language model M that has been trained on datasets $D = \{1, 2, 3, \dots\}$, where each dataset $d \in D$ has size $K_d$ and the total size of the data is $K = \sum_{d \in D} K_d$. For example, D could comprise datasets from 5 domains or 10 languages on which M has already been trained. We have a new dataset T of one or more domains with size R such that $R \ll K$. We need to select a subset of examples $X \subseteq D$ with $|X| \ll K$ that mitigates negative backward transfer while still achieving the domain improvement that comes with training on T. R itself may be large (for language models, several billion tokens), which distinguishes this setting from low-resource intermediate pre-finetuning and from post-training model editing with a few examples; neither is the focus here. Additionally, we consider a single aggregate dataset T for now, and the number of previous datasets $|D|$ can be fairly large.
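A natural baseline for choosing X, before any smarter selection, is to split the replay budget |X| across the previous datasets and sample uniformly within each one. The sketch below uses temperature-scaled sampling, a heuristic commonly used in multilingual pretraining: dataset d receives a share proportional to K_d^(1/temperature). The function names, dataset names and sizes are hypothetical, and this is only a reference point for the selection problem above, not a solution to it.

```python
import random


def allocate_replay_budget(sizes: dict[str, int], budget: int,
                           temperature: float = 1.0) -> dict[str, int]:
    """Split a replay budget |X| across previous datasets d in D.

    Shares are proportional to K_d ** (1 / temperature): temperature=1 gives
    size-proportional sampling, larger values upsample smaller datasets.
    """
    if budget > sum(sizes.values()):
        raise ValueError("budget exceeds the total size K of the previous data")
    if temperature <= 0:
        raise ValueError("temperature must be positive")
    scaled = {d: k ** (1.0 / temperature) for d, k in sizes.items()}
    total = sum(scaled.values())
    raw = {d: budget * s / total for d, s in scaled.items()}
    alloc = {d: min(int(r), sizes[d]) for d, r in raw.items()}
    # Hand out the remaining slots by largest fractional part, never exceeding K_d.
    leftover = budget - sum(alloc.values())
    order = sorted(raw, key=lambda d: raw[d] - int(raw[d]), reverse=True)
    while leftover > 0:
        for d in order:
            if leftover > 0 and alloc[d] < sizes[d]:
                alloc[d] += 1
                leftover -= 1
    return alloc


def sample_replay_indices(sizes: dict[str, int], alloc: dict[str, int],
                          seed: int = 0) -> dict[str, list[int]]:
    # Uniformly sample example indices within each dataset, without replacement.
    rng = random.Random(seed)
    return {d: sorted(rng.sample(range(sizes[d]), n)) for d, n in alloc.items()}


sizes = {"web": 1000, "code": 100, "books": 10}  # toy values of K_d
print(allocate_replay_budget(sizes, budget=50, temperature=1.0))  # {'web': 45, 'code': 5, 'books': 0}
print(allocate_replay_budget(sizes, budget=50, temperature=2.0))  # {'web': 35, 'code': 11, 'books': 4}
```

The temperature is the only knob: at 1 the smallest dataset can be dropped from the buffer entirely, which is exactly where negative backward transfer tends to be hardest to detect.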

Abstraction II

A different variant of this CL setup would be to select a subset $X \subseteq D_{\text{sub}}$, where sub denotes specific indices of D (e.g., choosing a few of the datasets in D), to train on along with T, since we might want to continually adapt the model to only specific downstream tasks during continual learning.

Yet another variation of the learning problem is as follows: we have a language model M that has been trained on the datasets D, and a new dataset T of some domain. However, when training on T, for some datasets $d \in D$ we only have access to their validation data and the downstream task data related to d, which prevents us from including d in the replay buffer. The goal in this case is to use a mix of rehearsal on the available data $D \setminus \{d\}$ and distillation/regularization to mitigate forgetting.
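One way to write such a mixed objective is sketched below: cross-entropy on T, cross-entropy on replayed examples from the datasets that are still available, and a distillation term KL(teacher || student) on proxy inputs for the unavailable dataset d. The assumptions are that the teacher is a frozen copy of M from before training on T, that the proxy inputs come from d's downstream task data, and that the weight lam is a free hyperparameter. Plain numpy stands in for framework tensors, so gradients are not shown, and all function names are hypothetical.

```python
import numpy as np


def log_softmax(logits: np.ndarray) -> np.ndarray:
    # Numerically stable log-softmax over the vocabulary (last) axis.
    shifted = logits - logits.max(axis=-1, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))


def cross_entropy(logits: np.ndarray, targets: np.ndarray) -> float:
    # Mean next-token negative log-likelihood; logits: (N, V), targets: (N,).
    logp = log_softmax(logits)
    return float(-logp[np.arange(len(targets)), targets].mean())


def distill_kl(student_logits: np.ndarray, teacher_logits: np.ndarray,
               temperature: float = 1.0) -> float:
    # KL(teacher || student) between temperature-softened distributions,
    # averaged over positions.
    t_logp = log_softmax(teacher_logits / temperature)
    s_logp = log_softmax(student_logits / temperature)
    return float((np.exp(t_logp) * (t_logp - s_logp)).sum(axis=-1).mean())


def mixed_objective(new_logits, new_targets, replay_logits, replay_targets,
                    student_proxy_logits, teacher_proxy_logits, lam: float = 1.0) -> float:
    # CE on T + CE on replay from D \ {d} + lam * distillation on proxy inputs for d.
    return (cross_entropy(new_logits, new_targets)
            + cross_entropy(replay_logits, replay_targets)
            + lam * distill_kl(student_proxy_logits, teacher_proxy_logits))


rng = np.random.default_rng(0)
V = 8  # toy vocabulary size
new_logits, replay_logits = rng.normal(size=(4, V)), rng.normal(size=(4, V))
targets = rng.integers(0, V, size=4)
proxy = rng.normal(size=(4, V))
# If the student matches the frozen teacher on the proxy inputs, the distillation term is zero.
print(mixed_objective(new_logits, targets, replay_logits, targets, proxy, proxy))
```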

Abstraction III

A tangential direction for handling the CL process is as follows: given a model with parameters $\Theta = \{\theta_l\}$, where l indexes the layers of the model, the goal is to decompose $\Theta$ into a shared parameter matrix $\sigma$ and a domain-adaptive parameter matrix $\tau$, where $\sigma$ is updated using D and $\tau$ is updated using T. This abstraction can be combined with the three setups above to improve the learning process.
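A minimal sketch of this decomposition for a single layer, assuming an additive split $\theta_l = \sigma + \tau$ with a LoRA-style low-rank $\tau = AB$ (that parameterization is an assumption, not part of the formulation above). Batches from D update only $\sigma$; batches from T update only $\tau$. Squared-error regression on a linear layer stands in for the language-modeling loss so that the gradients can be written by hand, and the class name is hypothetical.

```python
import numpy as np


class DecomposedLinear:
    """Hypothetical sketch of one layer with W = sigma + tau, where tau = A @ B is low-rank.

    sigma (shared) is updated only on batches from D; tau (domain-adaptive)
    only on batches from T. Squared error stands in for the LM loss.
    """

    def __init__(self, d_in: int, d_out: int, rank: int = 2, seed: int = 0):
        rng = np.random.default_rng(seed)
        self.sigma = rng.normal(scale=0.1, size=(d_in, d_out))
        self.A = rng.normal(scale=0.1, size=(d_in, rank))
        self.B = np.zeros((rank, d_out))  # tau = A @ B starts at zero, so W == sigma initially

    def weight(self) -> np.ndarray:
        return self.sigma + self.A @ self.B

    def forward(self, x: np.ndarray) -> np.ndarray:
        return x @ self.weight()

    def step(self, x: np.ndarray, y: np.ndarray, source: str, lr: float = 0.1) -> float:
        err = self.forward(x) - y
        loss = float((err ** 2).sum(axis=1).mean())  # per-example squared error, batch mean
        grad_w = 2.0 * x.T @ err / len(x)             # dLoss/dW
        if source == "D":    # replayed data: update the shared part only
            self.sigma -= lr * grad_w
        elif source == "T":  # new-domain data: update the adaptive part only (chain rule through A @ B)
            grad_a, grad_b = grad_w @ self.B.T, self.A.T @ grad_w
            self.A -= lr * grad_a
            self.B -= lr * grad_b
        else:
            raise ValueError(f"unknown source {source!r}")
        return loss  # loss before this update


rng = np.random.default_rng(1)
layer = DecomposedLinear(d_in=4, d_out=3)
x_d, x_t = rng.normal(size=(32, 4)), rng.normal(size=(32, 4))
y_d, y_t = x_d @ rng.normal(size=(4, 3)), x_t @ rng.normal(size=(4, 3))  # toy data for D and T
first_d = layer.step(x_d, y_d, source="D")
for _ in range(100):
    last_d = layer.step(x_d, y_d, source="D")  # only sigma moves
sigma_after_d = layer.sigma.copy()
layer.step(x_t, y_t, source="T")               # only tau = A @ B moves
print(first_d, last_d, np.array_equal(layer.sigma, sigma_after_d))
```

Keeping $\tau$ low-rank and initialized at zero means the model starts from exactly the weights learned on D, and the capacity given to T is an explicit, tunable budget.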

A hard-constrained version of continual learning would also impose a constraint on the latency of the method, enabling a tradeoff between performance and the time taken to adapt the model. Below, LAT denotes latency, measured either in wall-clock time or in compute (TFLOPs), and C is the corresponding budget:

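$$\min_{\sigma,\; X : |X| \le R} \; \mathbb{E}_{(x,y) \sim X}\left[\mathcal{L}_{CE}(x, y, \sigma)\right] \quad \text{s.t.} \quad \mathrm{LAT}(X) \le C$$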
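The constraint side of this problem (a cap R on the number of replayed examples and a latency budget C) has the shape of a knapsack problem. The sketch below handles only that side, with a greedy utility-per-cost heuristic; the optimization over $\sigma$ is not shown. Per-example utilities (e.g. a proxy for how much replaying an example reduces forgetting) and costs (e.g. seconds or FLOPs, roughly proportional to token count) are hypothetical inputs, and the function name is illustrative.

```python
def greedy_budgeted_selection(utility: list[float], cost: list[float],
                              max_items: int, latency_budget: float) -> tuple[list[int], float]:
    """Pick replay examples under |X| <= R (max_items) and LAT(X) <= C (latency_budget).

    Greedy by utility/cost ratio: a heuristic, not an exact solution.
    """
    # Stable sort: ties keep their original order.
    order = sorted(range(len(utility)), key=lambda i: utility[i] / cost[i], reverse=True)
    chosen: list[int] = []
    spent = 0.0
    for i in order:
        if len(chosen) == max_items:
            break
        if spent + cost[i] <= latency_budget:  # skip examples that would exceed C
            chosen.append(i)
            spent += cost[i]
    return chosen, spent


chosen, spent = greedy_budgeted_selection(
    utility=[5.0, 4.0, 3.0, 1.0], cost=[5.0, 2.0, 1.0, 1.0], max_items=2, latency_budget=4.0)
print(chosen, spent)  # [2, 1] 3.0
```

A single greedy pass costs one sort over the candidate pool, which keeps it cheap enough to rerun periodically during pretraining; the hard part it leaves open is estimating utilities that actually track backward transfer.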

Desiderata for the problem

The problem statement defined in this proposal is close to multi-task pretraining with adaptive data selection, and the aim is to improve the technique for sampling data from different tasks during training. Having defined the abstractions above, we lay out a few desiderata that apply to all of them.