Efficient continual pre-training of LLMs
I have been working on this topic on and off for some time now. While continual learning has been a subject of research for many decades, I have been frustrated to see that most of its techniques, which tend to rely on some form of replay buffer, cannot be easily adapted to pretraining large language models on billions of tokens, because the optimization procedures for selecting replay buffers are too costly at that scale. So far, I have not been able to turn my research into any fruitful publications. I therefore decided to take a shot at formulating the continual learning problem and proposing a few directions of research. This blog post will continue evolving over the next few weeks.
The objective of continual learning (CL) is to accumulate knowledge from non-stationary data, i.e., from a stream of non-i.i.d. samples. Here I will discuss some methods that can be systematically applied to continual pretraining of LLMs. That said, CL for large language models (LLMs) is known to be difficult to implement for two reasons: (1) the data for previous domains/tasks is assumed to be too large to restart pretraining from scratch on all of it, and (2) catastrophic forgetting of past information, which leads to the memorization-vs-generalization issue in pretrained language models. In that regard, there are two aspects of CL to focus on:
1. Backward transfer: the influence that learning a new domain has on performance on downstream tasks from domains on which the LLM was previously pretrained. Positive backward transfer occurs when learning a new domain improves performance on tasks from a previously seen domain; the opposite, negative backward transfer, is catastrophic forgetting.
2. Forward transfer: the influence that learning a new domain has on performance on downstream tasks from domains on which the LLM will be pretrained in the future. This is the transfer-learning or generalization aspect that is a known advantage of pretrained LLMs. Both notions are made precise by the metrics sketched right after this list.
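To make these two notions precise, one can borrow the standard metrics from the continual learning literature (e.g., the GEM paper of Lopez-Paz & Ranzato); the matrix $R$ and the baseline $\bar{b}_j$ below are notation I am adopting here for convenience, not definitions from elsewhere in this post:

```latex
% R_{i,j}: downstream performance on domain j's tasks after the model
% has been continually pretrained through domain i (of T total);
% \bar{b}_j: performance of the initial model on domain j's tasks.
\mathrm{BWT} = \frac{1}{T-1} \sum_{j=1}^{T-1} \left( R_{T,j} - R_{j,j} \right),
\qquad
\mathrm{FWT} = \frac{1}{T-1} \sum_{j=2}^{T} \left( R_{j-1,j} - \bar{b}_j \right).
```

Negative BWT is exactly catastrophic forgetting, and positive FWT is the generalization benefit mentioned above.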
One of the utilities of CL for LLMs is in developing a more principled approach to adaptive data selection: treating the sampling of the replay buffer either as a sequential sampling problem or as a constrained optimization problem, one that stays fast on large datasets and model sizes and is better geared towards mitigating negative backward transfer. Many abstractions of the problem can be formulated in this sense; here I provide a few.
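As a concrete instance of the sequential-sampling view, here is a minimal sketch that maintains a fixed-size replay buffer via reservoir sampling over a document stream; the class name and buffer capacity are mine, chosen for illustration, not taken from any specific method:

```python
import random

class ReservoirReplayBuffer:
    """Fixed-size replay buffer filled by reservoir sampling.

    Every document seen so far has an equal probability of being
    in the buffer, without storing or revisiting the full stream.
    """

    def __init__(self, capacity: int, seed: int = 0):
        self.capacity = capacity
        self.buffer: list = []
        self.seen = 0
        self.rng = random.Random(seed)

    def add(self, doc) -> None:
        self.seen += 1
        if len(self.buffer) < self.capacity:
            self.buffer.append(doc)
        else:
            # Replace a random slot with probability capacity / seen.
            j = self.rng.randrange(self.seen)
            if j < self.capacity:
                self.buffer[j] = doc

    def sample(self, k: int) -> list:
        return self.rng.sample(self.buffer, min(k, len(self.buffer)))
```

Each stream element costs O(1), so no per-batch optimization problem is ever solved, which is exactly the property the desiderata below ask for.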
Abstraction I
Formally, we have a language model $f_\theta$ that has been pretrained on a sequence of domains $\mathcal{D}_1, \dots, \mathcal{D}_{t-1}$, and we receive a new domain $\mathcal{D}_t$. The goal is to continue pretraining $f_\theta$ on $\mathcal{D}_t$, possibly mixed with a small replay buffer drawn from $\mathcal{D}_1, \dots, \mathcal{D}_{t-1}$, such that performance on downstream tasks of the earlier domains is preserved.
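A minimal training-loop sketch of this abstraction, assuming a PyTorch model with a Hugging-Face-style `model(**batch).loss` interface; the mixing ratio, the loader, and the `tokenize` helper are illustrative assumptions:

```python
import torch

def continual_pretrain(model, optimizer, new_domain_loader,
                       replay_buffer, tokenize,
                       replay_ratio: float = 0.25):
    """One pass of continued pretraining on the new domain D_t,
    mixing in replay batches from earlier domains D_1..D_{t-1}."""
    model.train()
    for batch in new_domain_loader:
        # With probability replay_ratio, substitute a replay batch
        # so earlier domains keep contributing gradient signal.
        if replay_buffer.buffer and torch.rand(()) < replay_ratio:
            docs = replay_buffer.sample(len(batch["input_ids"]))
            batch = tokenize(docs)
        loss = model(**batch).loss  # standard LM loss
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
```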
Abstraction II
A different variant of this CL setup would be to select a subset $\mathcal{B} \subseteq \mathcal{D}_1 \cup \dots \cup \mathcal{D}_{t-1}$ under a fixed budget $|\mathcal{B}| \le b$, and to continue pretraining on $\mathcal{D}_t \cup \mathcal{B}$, so that negative backward transfer is mitigated at a small fraction of the cost of retraining on all previous data.
Yet a different variation of the learning problem can be as follows: we have a language model $f_\theta$, and the data selection is posed as a constrained optimization problem: choose sampling weights over the previous domains and the new domain $\mathcal{D}_t$ so as to maximize performance on $\mathcal{D}_t$ subject to the losses on the earlier domains not increasing.
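One cheap instantiation of the subset-selection variant is to score candidate documents from the earlier domains by their loss under the current model and keep the highest-loss ones, i.e., those the model is closest to forgetting. The scoring proxy, the budget, and the helper names below are assumptions made for illustration:

```python
import heapq
import torch

@torch.no_grad()
def select_replay_subset(model, candidate_docs, tokenize, budget: int):
    """Pick `budget` documents from earlier domains to replay,
    scored by LM loss under the current model (one forward pass
    per candidate, run once per new domain, not per mini-batch)."""
    model.eval()
    scored = []
    for doc in candidate_docs:
        batch = tokenize([doc])
        loss = model(**batch).loss.item()
        scored.append((loss, doc))
    # Keep the `budget` highest-loss candidates.
    top = heapq.nlargest(budget, scored, key=lambda pair: pair[0])
    return [doc for _, doc in top]
```

Because the selection runs once per incoming domain rather than inside the SGD loop, it stays within the feasibility desiderata laid out below.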
Abstraction III
A tangential direction to handling the CL process is as follows: given a model with parameters $\theta_{t-1}$ obtained after pretraining on the earlier domains, constrain the adaptation to the new domain in parameter space, e.g., by penalizing movement of the parameters that were important for the earlier domains, instead of selecting data at all.
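A minimal sketch of such a parameter-space constraint, in the spirit of quadratic-penalty methods like Elastic Weight Consolidation; the diagonal importance weights and the penalty coefficient are assumptions of the sketch, not part of this proposal:

```python
import torch

def ewc_penalty(model, anchor_params, importance, lam: float = 1.0):
    """Quadratic penalty keeping parameters near the pre-adaptation
    anchor, weighted by a per-parameter importance estimate
    (e.g., a diagonal Fisher computed on earlier-domain data)."""
    penalty = 0.0
    for name, p in model.named_parameters():
        penalty = penalty + (importance[name]
                             * (p - anchor_params[name]).pow(2)).sum()
    return lam * penalty

# In the training loop, the new-domain objective then becomes:
#   loss = lm_loss + ewc_penalty(model, anchor_params, importance)
```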
A hard-constrained version of continual learning would also entail constraints on the latency of the method, enabling a trade-off between performance and the time taken to adapt these models.
Desiderata for the problem
The problem statement defined in this proposal is close to multi-task pretraining with adaptive data selection, where we want to improve the technique for sampling data from different tasks during training. Having defined the above abstractions, we lay out a few desiderata that apply to all of them.
1. We want to minimize negative backward transfer on downstream tasks related to the previously seen domains $\mathcal{D}_1, \dots, \mathcal{D}_{t-1}$, but for each new task we do not want to iterate over all of this previous data (which would be infeasible), and neither do we want to store or compute gradients, offline or online, over all of it with the current or initial starting model. As an example, we might want to continually train multilingual models on topic-specific corpora such as a Covid corpus or a Twitter corpus in English; for multilingual models, this means ensuring that language-specific task performance does not degrade after training. Additionally, we can continually pretrain the language model with more data of one domain, such as tables, or of one language, such as Egyptian Arabic, to improve its performance; in that sense, the language or the tabular data becomes the new domain. (A cheap monitoring sketch for this desideratum follows the list.)
2. We want to avoid techniques like online batch selection during the CL process, which solve a subset-selection optimization problem for every mini-batch of an SGD run. This is not computationally feasible for large models and large datasets.
3. We want to avoid techniques like network expansion where possible, since these are often hard to adapt under model-size restrictions and for specific model types. The same applies to post-training techniques such as adapters for parameter-efficient finetuning: adding our own adapters or prefixes such as soft prompts at the pre-training/pre-finetuning stage increases complexity.
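As noted in the first desideratum, backward transfer can be monitored without iterating over all previous data by keeping a small fixed "anchor" batch per previous domain and raising an alarm when its loss drifts; the dictionary layout and tolerance below are illustrative assumptions:

```python
import torch

@torch.no_grad()
def backward_transfer_alarm(model, anchor_batches, baseline_losses,
                            tolerance: float = 0.05):
    """Flag domains whose loss on a small fixed anchor set has risen
    by more than `tolerance` over its pre-adaptation baseline.

    anchor_batches: dict mapping domain name -> one tokenized batch
    baseline_losses: dict with the same keys, losses before adaptation
    """
    model.eval()
    degraded = []
    for domain, batch in anchor_batches.items():
        loss = model(**batch).loss.item()
        if loss > baseline_losses[domain] * (1.0 + tolerance):
            degraded.append(domain)
    return degraded  # e.g., trigger extra replay for these domains
```

Flagged domains can then receive extra replay weight in the next training phase, closing the loop with the abstractions above at the cost of a handful of forward passes.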