Hi Elman,
In this newsletter, we'll explore cutting-edge research tackling the challenge of training large language models (LLMs) with very long contexts. As you know, extending context windows beyond a few thousand tokens presents significant hurdles due to exploding memory requirements and computational costs. Traditional approaches often rely on short-context training followed by inference-time tricks, but this fails to address the growing need to train and fine-tune LLMs directly on long sequences. This newsletter delves into a technique called "adjoint sharding," which aims to make long-context training practical by dramatically reducing memory overhead. We'll examine its underlying principles, practical implementation, and potential impact on the future of LLMs.
Adjoint sharding for very long context training of state space models by Xingzi Xu, Amir Tavanaei, Kavosh Asadi, Karim Bouyarmane https://arxiv.org/abs/2501.00692
Caption: This diagram visualizes the adjoint sharding process, showcasing the sharded gradient computation across time steps (t) and layers (k). Each block represents a vector-Jacobian product (VJP) calculation for parameters A, B, and C, with inputs derived from hidden states (h), input tokens (x), and gradients of the loss with respect to intermediate outputs (dl/dy). This sharding enables parallel computation and significantly reduces memory requirements for training large language models with long contexts.
Training large language models (LLMs) with very long contexts remains a significant challenge because of GPU memory limits and the high computational cost of backpropagating through long sequences. Existing methods often train with shorter contexts and rely on specialized inference-time techniques to handle longer sequences. However, a growing number of real-world applications, such as fact extraction, summarization, and reconciliation, demand training or fine-tuning directly on long contexts. This paper introduces adjoint sharding, a technique designed to address this need by drastically reducing memory requirements during training, making long-context training computationally feasible.
At the heart of adjoint sharding lies the adjoint method, a memory-efficient optimization technique commonly used in dynamical systems. Instead of computing the gradient in one monolithic backward pass, adjoint sharding decomposes the gradient calculation into independent vector-Jacobian product (VJP) computations. The advantage is that modern VJP computations are highly efficient, often comparable in speed to a single forward pass of the model, so the decomposition does not sacrifice training speed. The paper works through the method for state-space models (SSMs) and residual networks (ResNets), demonstrating its applicability to recurrent architectures, which underpin a growing class of LLMs. Critically, the gradient computation is sharded across both the sequence dimension (t) and the layer dimension (k), as shown in the following formula:
$\frac{dL}{d\theta} = \bigoplus_{t=1}^{T}\bigoplus_{k=1}^{K}\left(\sum_{i=1}^{t} \mathrm{VJP}_{A}\!\left(\frac{dl_t}{dy_k^t} \otimes h_{i-1}^{t,k}\right),\; \sum_{i=1}^{t} \mathrm{VJP}_{B}\!\left(\frac{dl_t}{dy_k^t} \otimes x_i^{t,k}\right),\; \mathrm{VJP}_{C}\!\left(\frac{dl_t}{dy_k^t} \otimes h_t^{t,k}\right)\right)$
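To make the decomposition concrete, here is a minimal sketch in JAX of sharding the gradient of a toy single-layer linear SSM across the sequence dimension: one independent VJP per time step, whose results are simply summed. The recurrence, the loss, and all names (`output_at`, `per_step_loss`, the dimensions) are illustrative assumptions for this newsletter, not the paper's implementation.

```python
# Minimal sketch: gradient of a toy linear SSM (h_t = A h_{t-1} + B x_t, y_t = C h_t)
# computed as a sum of independent per-time-step VJPs, assuming a per-step squared loss.
import jax
import jax.numpy as jnp

d_in, d_h, d_out, T = 4, 8, 3, 16
kA, kB, kC, kx, ky = jax.random.split(jax.random.PRNGKey(0), 5)
params = {
    "A": 0.1 * jax.random.normal(kA, (d_h, d_h)),  # scaled down to keep the recurrence stable
    "B": jax.random.normal(kB, (d_h, d_in)),
    "C": jax.random.normal(kC, (d_out, d_h)),
}
xs = jax.random.normal(kx, (T, d_in))
targets = jax.random.normal(ky, (T, d_out))

def output_at(params, xs, t):
    """Unroll the SSM up to step t and return y_t only (recompute, don't store)."""
    h = jnp.zeros(d_h)
    for i in range(t + 1):
        h = params["A"] @ h + params["B"] @ xs[i]
    return params["C"] @ h

def per_step_loss(y_t, target_t):
    return 0.5 * jnp.sum((y_t - target_t) ** 2)

# Reference: one monolithic backward pass over the whole unrolled sequence.
def total_loss(params):
    return sum(per_step_loss(output_at(params, xs, t), targets[t]) for t in range(T))
ref_grads = jax.grad(total_loss)(params)

# Adjoint-sharding-style accumulation: one independent VJP per time step.
# Each VJP pulls the cotangent dl_t/dy_t back to the parameters; the per-step
# results are summed, so each term could be computed on a different worker.
sharded_grads = jax.tree_util.tree_map(jnp.zeros_like, params)
for t in range(T):
    y_t, pullback = jax.vjp(lambda p: output_at(p, xs, t), params)
    cotangent = jax.grad(per_step_loss)(y_t, targets[t])  # dl_t / dy_t
    (g_t,) = pullback(cotangent)                          # VJP w.r.t. params
    sharded_grads = jax.tree_util.tree_map(jnp.add, sharded_grads, g_t)

# The sharded accumulation matches end-to-end backprop up to float32 rounding.
assert all(jnp.allclose(ref_grads[k], sharded_grads[k], atol=1e-3) for k in params)
```

Because each per-step VJP is independent and recomputes only what it needs, the terms can live on different devices and be reduced with a sum, which is where the memory savings come from (at the price of redundant forward computation in this naive sketch).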
This sharded computation distributes the memory load, enabling training with much longer contexts. To further improve efficiency, the authors introduce truncated adjoint sharding, which keeps only the most influential gradient terms by limiting how many past states each term considers to a fixed truncation length. With truncation, the number of VJPs grows linearly with context length, a significant improvement over the quadratic growth of standard adjoint sharding. The paper also presents distributed and parallel versions of the algorithm to further improve scalability and training speed. Empirically, adjoint sharding reduces memory usage by up to 3X for a 1.27B-parameter LLM trained on 1M-token sequences. This saving raises the maximum trainable context length from 35K tokens to over 100K tokens for a 1.27B-parameter model on a cluster of five AWS P4 instances.
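As a rough illustration of the truncation idea, continuing the sketch above, one can cap how far back each per-step VJP is allowed to look, treating the hidden state from before that window as a constant. The window size `TRUNC` and the helpers `forward_states` and `truncated_output_at` are illustrative; the paper's exact truncation rule may differ.

```python
# Rough sketch of truncation: each per-step VJP only backpropagates through the
# last TRUNC transitions, starting from a cached hidden state treated as constant.
TRUNC = 4

def forward_states(params, xs):
    """Run the recurrence once and cache every hidden state h_0 .. h_T."""
    hs = [jnp.zeros(d_h)]
    for x in xs:
        hs.append(params["A"] @ hs[-1] + params["B"] @ x)
    return jnp.stack(hs)  # shape (T + 1, d_h)

cached_hs = forward_states(params, xs)

def truncated_output_at(params, h_start, xs_window):
    """Recompute y_t from a cached state; gradients only flow through the window."""
    h = h_start  # constant input: earlier steps are not differentiated through
    for x in xs_window:
        h = params["A"] @ h + params["B"] @ x
    return params["C"] @ h

trunc_grads = jax.tree_util.tree_map(jnp.zeros_like, params)
for t in range(T):
    start = max(0, t + 1 - TRUNC)  # at most TRUNC transitions per VJP
    y_t, pullback = jax.vjp(
        lambda p: truncated_output_at(p, cached_hs[start], xs[start : t + 1]), params
    )
    cotangent = jax.grad(per_step_loss)(y_t, targets[t])
    (g_t,) = pullback(cotangent)
    trunc_grads = jax.tree_util.tree_map(jnp.add, trunc_grads, g_t)
# trunc_grads approximates sharded_grads while keeping the work per step bounded,
# so the total number of VJP terms grows linearly with context length.
```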
This newsletter highlighted the advances adjoint sharding brings to long-context training for LLMs. By leveraging the adjoint method and sharding the gradient computation across time steps and layers, the technique drastically reduces memory requirements, paving the way for training on much longer sequences. Truncated adjoint sharding further improves efficiency by keeping only the most influential gradient terms. These innovations hold real promise for pushing the boundaries of LLM capabilities, enabling more effective training and fine-tuning on extensive contextual information and, ultimately, more powerful and nuanced language models. While adjoint sharding presents a compelling solution, future work includes further optimizing the parallel implementations and studying the convergence properties of the truncated variant. That continued exploration will be important for realizing the full potential of long-context training and unlocking new applications for LLMs.