This newsletter explores the latest advancements in deep learning architectures designed for long-context language modeling. We'll examine two recent papers that tackle the challenges of extending context windows, one focusing on the limitations of recurrent neural networks (RNNs) and the other proposing a novel evaluation metric for long-range memory in various architectures.
Stuffed Mamba: State Collapse and State Capacity of RNN-Based Long-Context Modeling by Yingfa Chen, Xinrong Zhang, Shengding Hu, Xu Han, Zhiyuan Liu, Maosong Sun https://arxiv.org/abs/2410.07145
Caption: This graph illustrates the relationship between RNN performance, training length, and contextual memory capacity. Shorter training lengths lead to over-parameterization and state collapse when the context exceeds training length, while longer training allows the model to approach the upper bound of its memory capacity, mitigating state collapse. The dashed red line indicates the optimal training length where performance peaks and under-parameterization begins.
RNNs, with their linear computational complexity relative to sequence length, present an attractive alternative to Transformers for long-context language modeling. However, their performance in contexts significantly longer than their training length has been disappointing. This paper investigates the underlying reasons for this limitation, identifying a phenomenon termed state collapse (SC) and proposing mitigation strategies.
SC manifests as a substantial performance drop when the input sequence length exceeds the training length. The authors attribute SC to state overparameterization, where the RNN's memory capacity surpasses the demands of the training data. Because the state can hold the entire training-length context without ever needing to discard anything, the model never learns to selectively forget information as the context grows. Evidence for this is found in the observation that during SC, certain channels in the recurrent state exhibit exploding values, while others vanish after normalization.
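The exploding/vanishing channel signature suggests a simple diagnostic. The sketch below is our own illustration, not the authors' code: it inspects a recurrent state tensor and flags channels whose magnitude dwarfs the rest; `channel_stats` and the `explode_ratio` threshold are hypothetical names.

```python
# Minimal sketch (not from the paper): flag "exploding" channels in an
# RNN/SSM recurrent state, the pattern described for state collapse.
# Assumes `state` has shape (heads, head_dim, state_dim); names are illustrative.
import torch

def channel_stats(state: torch.Tensor, explode_ratio: float = 100.0):
    """Return per-channel RMS values and a mask of 'exploding' channels."""
    # Collapse everything except the last (channel) dimension.
    flat = state.reshape(-1, state.shape[-1])
    rms = flat.pow(2).mean(dim=0).sqrt()          # per-channel magnitude
    median = rms.median()
    exploding = rms > explode_ratio * median      # channels dominating the state
    return rms, exploding

# Example with a synthetic state: a few channels with runaway values.
state = torch.randn(8, 64, 128)
state[..., :3] *= 1e4                             # simulate exploding channels
rms, exploding = channel_stats(state)
print(f"{exploding.sum().item()} of {rms.numel()} channels look collapsed")
```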
The hidden state h<sub>t</sub> can be represented as a weighted sum of past inputs: h<sub>t</sub> = Σ<sup>t</sup><sub>i=1</sub> a<sub>i:t</sub>B<sub>i</sub>x<sub>i</sub>, where a<sub>i:t</sub> is the cumulative decay applied to the i-th token's contribution B<sub>i</sub>x<sub>i</sub> at time step t, i.e., the strength of that token's memory trace. Analysis reveals that collapsing states tend to retain information from all tokens within the training length, failing to adequately decay older information.
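To make the weighted-sum view concrete, here is a small numerical check (a sketch under our own assumptions, not the paper's code) that a gated linear recurrence of the form h<sub>t</sub> = a<sub>t</sub>h<sub>t-1</sub> + B<sub>t</sub>x<sub>t</sub> unrolls into exactly this sum, with a<sub>i:t</sub> equal to the product of the decay gates applied after token i:

```python
# Toy check that the recurrent form  h_t = a_t * h_{t-1} + B_t @ x_t
# equals the explicit weighted sum   h_t = sum_i a_{i:t} * B_i @ x_i,
# where a_{i:t} = a_{i+1} * a_{i+2} * ... * a_t (cumulative decay).
import numpy as np

rng = np.random.default_rng(0)
T, d_in, d_state = 6, 4, 8
x = rng.normal(size=(T, d_in))
B = rng.normal(size=(T, d_state, d_in))
a = rng.uniform(0.5, 1.0, size=T)          # per-step scalar decay gates

# Recurrent form.
h = np.zeros(d_state)
for t in range(T):
    h = a[t] * h + B[t] @ x[t]

# Explicit weighted-sum form.
h_sum = np.zeros(d_state)
for i in range(T):
    decay = np.prod(a[i + 1:])             # empty product = 1 for the last token
    h_sum += decay * (B[i] @ x[i])

assert np.allclose(h, h_sum)
print("recurrence and weighted-sum forms agree")
```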
To counter SC, the paper proposes three training-free mitigation methods: Forget More and Remember Less, which modifies the update rule to decrease the retention of new information and increase the decay of old information; State Normalization, which normalizes the state after each update to prevent exploding values; and Sliding Window by State Difference, which approximates a sliding window by calculating the difference between two states. A fourth method, continual training on longer sequences, encourages the model to learn appropriate forgetting behavior. Experiments on the Mamba-2 architecture demonstrate that these methods enable processing of sequences exceeding 1 million tokens without experiencing SC. Furthermore, training on longer sequences confirms that SC can be avoided when the training length surpasses the state capacity.
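Of the three training-free fixes, State Normalization is the easiest to illustrate. The sketch below is a generic linear-recurrence toy, not the authors' Mamba-2 implementation; `step_with_state_norm` and `target_norm` are hypothetical names, and rescaling only when the norm exceeds the target is our own simplification:

```python
# Illustrative sketch of the "State Normalization" idea: after each
# recurrent update, rescale the state so its norm cannot explode.
import numpy as np

def step_with_state_norm(h, a_t, B_t, x_t, target_norm=1.0, eps=1e-8):
    """One recurrent update followed by renormalization of the state."""
    h = a_t * h + B_t @ x_t                 # standard gated update
    norm = np.linalg.norm(h)
    if norm > target_norm:                  # only shrink, never amplify
        h = h * (target_norm / (norm + eps))
    return h

rng = np.random.default_rng(1)
d_in, d_state = 4, 8
h = np.zeros(d_state)
for t in range(10_000):                     # far beyond any "training length"
    x_t = rng.normal(size=d_in)
    B_t = rng.normal(size=(d_state, d_in))
    h = step_with_state_norm(h, a_t=0.999, B_t=B_t, x_t=x_t)
print("state norm after 10k steps:", np.linalg.norm(h))
```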
The paper also explores the connection between state capacity and state size. A linear relationship is observed between the state size and the training length at which SC disappears. Experiments on passkey retrieval reveal an exponential relationship between state capacity and state size for this specific task. Impressively, a Mamba-2 370M model achieves near-perfect passkey retrieval accuracy at a 256K context length, outperforming comparably sized Transformer models. These findings highlight the potential of RNNs for efficient and effective long-context modeling.
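For readers unfamiliar with the benchmark, passkey retrieval hides a random number inside long filler text and asks the model to recall it. A rough sketch of how such an example is typically constructed (the paper's exact prompt wording and lengths may differ):

```python
# Rough sketch of a passkey-retrieval example: a random passkey buried at a
# chosen depth inside repetitive filler text. Illustrative only.
import random

def make_passkey_prompt(n_filler: int = 2000, depth: float = 0.5) -> tuple[str, str]:
    passkey = str(random.randint(10_000, 99_999))
    filler = "The grass is green. The sky is blue. The sun is yellow. "
    needle = f"The pass key is {passkey}. Remember it. {passkey} is the pass key. "
    before = int(n_filler * depth)
    context = filler * before + needle + filler * (n_filler - before)
    prompt = context + "\nWhat is the pass key? The pass key is"
    return prompt, passkey

prompt, answer = make_passkey_prompt()
print(len(prompt.split()), "words; answer:", answer)
```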
Forgetting Curve: A Reliable Method for Evaluating Memorization Capability for Long-context Models by Xinyu Liu, Runsong Zhao, Pengcheng Huang, Chunyang Xiao, Bei Li, Jingang Wang, Tong Xiao, Jingbo Zhu https://arxiv.org/abs/2410.04727
This paper addresses the lack of a robust metric for evaluating the true memory capacity of long-context language models. Existing metrics like perplexity and "Needle in a Haystack" have limitations, such as conflating memory with other language modeling skills and being sensitive to prompt engineering. The authors propose the forgetting curve as a novel, standardized method for assessing long-context memorization.
The forgetting curve leverages the emergent copy ability of LLMs and consists of two curves: a copy accuracy curve and a language modeling (LM) accuracy curve. The copy accuracy curve measures the model's ability to reconstruct the second half of a string when given its first half. The LM accuracy curve measures the model's prediction accuracy for the same second half, but preceded by an unrelated prefix of equal length. Both curves employ teacher forcing and track token prediction accuracy. The difference between these curves reveals the model's memory behavior, transitioning from fine-grained memory (perfect replication) to coarse-grained memory (above-chance replication) and finally to amnesia (no difference from LM accuracy).
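The sketch below shows schematically how the two accuracies could be computed under teacher forcing; `logits_fn(ids)`, returning next-token logits of shape (sequence length, vocabulary), is an assumed stand-in for a concrete model interface, and the function names are ours, not the authors':

```python
# Schematic forgetting-curve measurement. Copy accuracy conditions the second
# half of a string on its own first half; LM accuracy conditions the same
# second half on an unrelated prefix of equal length. Both use teacher forcing.
import torch

def teacher_forced_accuracy(logits_fn, prefix_ids, target_ids):
    """Fraction of target tokens predicted correctly given the prefix."""
    ids = torch.cat([prefix_ids, target_ids])
    logits = logits_fn(ids)                          # (seq_len, vocab)
    # Logits at position i predict token i+1; slice the target region.
    start = len(prefix_ids) - 1
    preds = logits[start : start + len(target_ids)].argmax(dim=-1)
    return (preds == target_ids).float().mean().item()

def memory_probe(logits_fn, first_half, second_half, unrelated_prefix):
    copy_acc = teacher_forced_accuracy(logits_fn, first_half, second_half)
    lm_acc = teacher_forced_accuracy(logits_fn, unrelated_prefix, second_half)
    return copy_acc, lm_acc   # copy_acc >> lm_acc means the model still remembers
```

Sweeping the length of the two halves and plotting both accuracies against length yields the forgetting curve; the point where copy accuracy falls back to LM accuracy marks the onset of amnesia.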
The forgetting curve was applied to 14 open-source long-context LLMs, including transformer and RNN/SSM architectures, with claimed context lengths ranging from 4k to 1M tokens. The results effectively visualized and quantified memory capabilities. For instance, Llama-3 models demonstrated a substantial improvement in fine-grained memory length (4k tokens) compared to Llama-2 (0 tokens), while both families largely confirmed their claimed coarse-grained memory lengths. The curves also confirmed the effectiveness of transformer-based context extension techniques. However, RNN/SSM models, despite theoretically supporting infinite context, exhibited limited coarse-grained memory and no fine-grained memory.
Intriguingly, the forgetting curve revealed no direct correlation between memory capacity and perplexity. A Transformer-XL style model, known for improved perplexity at long range, showed decreasing perplexity with increasing context length, while its copy accuracy converged with LM accuracy after 1k tokens. This empirically supports the argument that perplexity primarily reflects short-context modeling and is not a reliable indicator of long-context memorization. By isolating memory from other aspects of language understanding, the forgetting curve offers a more robust and insightful evaluation of LLM memory, potentially guiding future research. The authors acknowledge limitations: the maximum measurable dependency length is half the total sequence length, and the evaluation so far centers on the Llama family, leaving broader applicability to other models as future work.
This newsletter has highlighted two key aspects of the ongoing quest for effective long-context language modeling. The "Stuffed Mamba" paper sheds light on the limitations of RNNs, specifically the phenomenon of state collapse and its connection to state overparameterization. The proposed mitigation techniques offer promising avenues for unlocking the full potential of RNNs for long sequences. Conversely, the "Forgetting Curve" paper introduces a novel evaluation metric that provides a more accurate and nuanced understanding of memory capacity across various architectures. By disentangling memory from other aspects of language modeling, the forgetting curve offers valuable insights into the strengths and weaknesses of different approaches, paving the way for more targeted research and development in the exciting field of long-context language modeling.