Hi Elman,
In this newsletter, we'll delve into the latest advancements in deep learning architectures specifically designed to tackle the challenges of long-context language modeling. As the demand for LLMs capable of processing and generating extended text sequences grows, innovative approaches to memory management and computational efficiency become crucial. We'll explore two cutting-edge papers that offer promising solutions to these challenges, each employing unique strategies to optimize performance and scalability in long-context scenarios.
TreeKV: Smooth Key-Value Cache Compression with Tree Structures by Ziwei He, Jian Yuan, Haoli Bai, Jingwen Leng, Bo Jiang https://arxiv.org/abs/2501.04987
Caption: This diagram illustrates the tree structure of TreeKV, a novel KV cache compression method for LLMs. As the cache reaches its maximum size (4 in this example), TreeKV begins compressing by merging less important adjacent tokens, starting from the oldest entries, as indicated by the dashed box and the arrival of the "sky" token. This tree-based approach allows for smooth transitions in context granularity, prioritizing recent tokens while efficiently managing memory.
Scaling transformer-based LLMs for long sequences and resource-constrained environments requires efficient key-value (KV) cache management. Existing methods, which typically evict tokens based on position or global importance scores, often lead to information loss and regional biases. TreeKV offers a novel, training-free solution to this problem by employing a tree structure for smooth KV cache compression.
The core idea behind TreeKV is to organize keys and values in a tree-like hierarchy that leverages temporal locality, allowing smooth transitions in context granularity between short-range and long-range contexts. The motivation stems from a wavelet analysis showing that tokens' contributions to generation gradually increase, and diverge from those of neighboring tokens, as they approach the end of a sequence, suggesting increasing complexity and variability from distant to nearby context. TreeKV's tree structure mirrors this observation: it is sparse on the left (distant past) and dense on the right (recent context).
Unlike many other compression methods, TreeKV is applicable to both generation and prefilling stages. During decoding, it maintains a fixed cache size c. When the cache reaches capacity, a tree-based approach strategically evicts less important token KV pairs within a specific eviction scope (idx, idx + 1), cycling through the cache and prioritizing the removal of older tokens. Token importance is determined by averaged attention weights: s = a<sup>(t)T</sup>V<sup>(t)</sup>. During prefilling, TreeKV employs a similar strategy but operates on blocks of tokens for enhanced efficiency.
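To make the decoding-stage procedure concrete, here is a minimal sketch of a TreeKV-style eviction step. It is an illustration under simplifying assumptions, not the authors' implementation: it evicts (rather than merges) the less important of the two adjacent entries in the eviction scope, and the function name, list-based cache, and cursor update rule are all hypothetical.

```python
def treekv_decode_step(keys, values, scores, new_k, new_v, new_score,
                       cache_size, evict_idx):
    """One decoding step of a TreeKV-style cache (simplified sketch).

    keys, values, scores are parallel lists, oldest entry first; scores holds
    a running importance per cached token (e.g. an averaged attention weight).
    evict_idx is the cursor defining the current eviction scope
    (evict_idx, evict_idx + 1).
    """
    if len(keys) >= cache_size:
        # Within the eviction scope, keep the more important of the two
        # adjacent entries; older context is thereby kept at coarser granularity.
        i, j = evict_idx, evict_idx + 1
        drop = i if scores[i] <= scores[j] else j
        for buf in (keys, values, scores):
            buf.pop(drop)
        # Cycle the eviction cursor through the cache so removals sweep
        # from the oldest entries toward the newer ones.
        evict_idx = (evict_idx + 1) % (cache_size - 1)
    # Append the newly generated token's KV pair and importance score.
    keys.append(new_k)
    values.append(new_v)
    scores.append(new_score)
    return evict_idx
```

A caller would initialize `evict_idx = 0` and thread the returned cursor through successive decoding steps, so the cache never exceeds `cache_size` entries.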
Experimental results on language modeling tasks using Llama-2-7B on PG19 and OpenWebText2 demonstrate TreeKV's effectiveness. With a 16x cache reduction, it achieved the lowest perplexity among all baselines, enabling the LLM to generalize to sequences of at least 16k tokens. On PG19 with a 16k context window, TreeKV outperformed the second-best method by 3.6%, and on OpenWebText2 by 1.1%. Furthermore, on the LongBench benchmark with Llama-3.2-1B-Instruct, TreeKV consistently outperformed other methods, achieving the best performance with only 6% of the cache budget, confirming the crucial role of the tree structure.
AdaSkip: Adaptive Sublayer Skipping for Accelerating Long-Context LLM Inference by Zhuomin He, Yizhen Yao, Pengfei Zuo, Bin Gao, Qinya Li, Zhenzhe Zheng, Fan Wu https://arxiv.org/abs/2501.02336
Caption: This figure illustrates different layer skipping strategies in Transformer models. a) Early Skipping removes layers at the beginning, b) Periodic Skipping removes layers at regular intervals, c) Early Exit stops computation after a certain layer, and d) AdaSkip (Adaptive skip layer), the proposed method, dynamically skips layers based on sublayer importance calculated using input-output similarity. The color intensity represents the importance of attention (A) and feed-forward network (F) sublayers, with red indicating high importance and blue indicating low importance.
AdaSkip presents a novel adaptive sublayer skipping method designed to accelerate long-context LLM inference. Existing layer-wise skipping strategies often fall short due to their inability to adapt to model and context variability, disregard for sublayer significance, and inapplicability to the prefilling phase. AdaSkip addresses these issues by leveraging on-the-fly similarity information (cosine similarity between input and output vectors: Similarity(a, b) = a·b / (||a|| ||b||)) to pinpoint less important layers and sublayers (attention and FFN modules) for both prefilling and decoding phases.
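As a rough illustration of this similarity signal, the sketch below computes the input-output cosine similarity of a sublayer's hidden states and accumulates it per (layer, sublayer) pair. The function names and the flattening of hidden states into single vectors are simplifying assumptions, not the paper's code.

```python
import torch
import torch.nn.functional as F

def io_similarity(x_in: torch.Tensor, x_out: torch.Tensor) -> float:
    """Cosine similarity between a sublayer's input and output hidden states.
    A value near 1 means the sublayer barely changes the representation,
    marking it as a candidate for skipping."""
    return F.cosine_similarity(x_in.flatten(), x_out.flatten(), dim=0).item()

def record_sublayer_similarity(records, layer_idx, sublayer_name, x_in, x_out):
    """Accumulate IO similarities per (layer, sublayer) across prompts/tasks.
    `sublayer_name` is just a label such as "attn" or "ffn"."""
    records.setdefault((layer_idx, sublayer_name), []).append(
        io_similarity(x_in, x_out))
```

In practice, the input and output tensors would be captured with forward hooks on each decoder layer's attention and FFN modules during prefilling.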
AdaSkip's methodology comprises offline and online learning phases. Offline learning uses historical data to pre-compute sublayer importance during prefilling by averaging IO similarity across multiple tasks. This informs skipping predictions in new tasks. Online learning, during decoding, refines these decisions by dynamically assessing sublayer importance based on the current context using the first few decoded tokens. A scaling factor (Scale<sub>j</sub>) compensates for potential deviations in input and output vectors due to residual connections.
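Building on the similarity records from the previous sketch, the following hypothetical snippet shows how offline statistics might be turned into a skip set and how a skipped sublayer could be bypassed with a compensating scale factor. The ranking rule and the way `scale` stands in for Scale<sub>j</sub> are assumptions for illustration only, not the paper's exact formulation.

```python
def choose_skip_set(records, num_skip):
    """Rank (layer, sublayer) pairs by average IO similarity collected offline
    and mark the `num_skip` most redundant ones (highest similarity) for skipping."""
    avg = {key: sum(vals) / len(vals) for key, vals in records.items()}
    return set(sorted(avg, key=avg.get, reverse=True)[:num_skip])

def sublayer_forward(sublayer_fn, hidden, key, skip_set, scale=1.0):
    """Residual sublayer step that bypasses computation for skipped sublayers.
    `scale` is an assumed stand-in for the Scale_j factor applied to the
    residual stream when the sublayer's update is omitted."""
    if key in skip_set:
        return hidden * scale            # skip: propagate (rescaled) residual only
    return hidden + sublayer_fn(hidden)  # normal residual connection
```

During decoding, the skip set could be refreshed from similarities measured on the first few generated tokens, mirroring the online refinement described above.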
Experiments across various long-context benchmarks and models (LLaMA3.1-8B-128k, InternLM-7B-8k, Vicuna-v1.5-7B-16k) demonstrate AdaSkip's superior performance over existing strategies. For example, with 8 skipped sublayers, AdaSkip achieved near-full model performance on tasks like TREC and TriviaQA with the LLaMA model. Speedups varied across models and tasks, reaching up to 17% improvement over baselines. The results highlight the importance of adaptive, sublayer-wise skipping for long-context inference, allowing AdaSkip to effectively balance speed and accuracy, unlike fixed strategies. Furthermore, AdaSkip maintains strong end-to-end performance even with skipping in both prefilling and decoding phases.
This newsletter has showcased two innovative approaches to enhancing the efficiency and scalability of long-context LLMs. TreeKV introduces a novel tree-based KV cache compression method that leverages temporal locality and smooth transitions in context granularity, resulting in significant performance gains and reduced memory footprint. AdaSkip, on the other hand, focuses on adaptive sublayer skipping, dynamically identifying and bypassing less important computations during both prefilling and decoding phases. Both methods demonstrate promising results on various benchmarks and models, offering valuable insights into optimizing long-context LLM architectures. The distinct strategies employed by each method highlight the diverse avenues being explored to address the challenges of long-range dependencies in language modeling, paving the way for more powerful and efficient LLMs capable of handling increasingly complex and extensive textual data.