Elman, in this newsletter we delve into the cutting edge of long-context language modeling, exploring two novel approaches that tackle the computational challenges of handling millions of tokens. As LLMs evolve to process increasingly longer sequences, efficient inference becomes paramount. These papers offer innovative solutions, from parallelization strategies to intelligent input reduction, pushing the boundaries of what is possible in long-context understanding.
Mnemosyne: Parallelization Strategies for Efficiently Serving Multi-Million Context Length LLM Inference Requests Without Approximations by Amey Agrawal, Junda Chen, Íñigo Goiri, Ramachandran Ramjee, Chaojie Zhang, Alexey Tumanov, Esha Choukse https://arxiv.org/abs/2409.17264
Caption: This diagram illustrates Mnemosyne's KV Cache Parallelism (KVP), where the key-value cache is distributed across multiple workers (KVP0 and KVP1), each containing two Sequence Pipeline Parallelism (SPP) units. Within each SPP unit, Tensor Parallelism (TP) is employed across 8 GPUs (TP 0-7), enabling parallel token generation and reducing Time Between Tokens (TBT). This 3D parallelization strategy (TP, SPP, KVP) allows Mnemosyne to efficiently handle long sequences and achieve near-interactive speeds for 10M token decodes.
Serving LLMs with extremely long contexts, reaching millions of tokens, presents significant challenges for inference. Existing parallelization techniques, while effective for training, fall short on inference-specific issues: the prefill and decode phases behave very differently and carry distinct latency constraints, Time to First Token (TTFT) and Time Between Tokens (TBT), respectively. Current solutions also lack efficient batching support, hindering hardware utilization. Mnemosyne introduces a novel 3D parallelization strategy to tackle these challenges, enabling efficient and scalable inference for contexts up to 10 million tokens without approximations.
At its core, Mnemosyne introduces three key innovations: adaptive chunking, Sequence Pipeline Parallelism (SPP), and KV Cache Parallelism (KVP). Adaptive chunking dynamically adjusts the chunk size based on workload characteristics, optimizing the trade-off between chunking overhead and latency targets, which is particularly crucial for mixed batching scenarios. SPP combines chunked prefills with pipeline parallelism, scheduling the next chunk as soon as the previous chunk clears the first stage, so that all pipeline stages stay busy. This yields a near-linear reduction in prefill time as pipeline stages (and thus GPUs) are added, captured by the formula T<sub>SPP</sub>(n,c) ≈ T<sub>p</sub>(n,c)/P<sub>SPP</sub>, where n is the sequence length, c is the chunk size, T<sub>p</sub>(n,c) is the prefill time without SPP, and P<sub>SPP</sub> is the number of pipeline stages. KVP distributes the key-value cache across multiple workers during the decode phase, parallelizing the attention over the cached context and reducing TBT, as modeled by T<sub>KVP</sub>(n) ≈ T<sub>attn</sub>(n)/P<sub>KVP</sub> + (T<sub>d</sub>(n) − T<sub>attn</sub>(n)) + T<sub>comm</sub>, where T<sub>attn</sub>(n) is the attention time, T<sub>d</sub>(n) is the overall per-token decode time, P<sub>KVP</sub> is the number of KVP workers, and T<sub>comm</sub> is the overhead of aggregating partial attention results.
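To make these scaling relations concrete, here is a minimal sketch that expresses the two latency models as plain Python functions. The function names and example numbers are illustrative assumptions for exposition, not measurements or code from the Mnemosyne system.

```python
# Illustrative latency models for SPP and KVP scaling (not Mnemosyne's actual code).
# All times are in arbitrary units; the inputs are assumed measured baselines.

def spp_prefill_latency(t_prefill: float, p_spp: int) -> float:
    """Approximate prefill time under Sequence Pipeline Parallelism:
    T_SPP(n, c) ~ T_p(n, c) / P_SPP, i.e. near-linear speedup in pipeline stages."""
    return t_prefill / p_spp

def kvp_decode_latency(t_attn: float, t_decode: float, p_kvp: int, t_comm: float) -> float:
    """Approximate per-token decode time under KV Cache Parallelism:
    only attention over the cached context is split across P_KVP workers,
    the non-attention part of the forward pass (t_decode - t_attn) is unchanged,
    and t_comm accounts for aggregating partial attention outputs."""
    return t_attn / p_kvp + (t_decode - t_attn) + t_comm

if __name__ == "__main__":
    # Hypothetical long-context request where attention dominates decode time.
    print(spp_prefill_latency(t_prefill=600.0, p_spp=8))                    # 75.0
    print(kvp_decode_latency(t_attn=0.9, t_decode=1.0, p_kvp=8, t_comm=0.02))  # ~0.23
```

Note how the KVP model makes explicit why the speedup saturates: only the attention term shrinks with more workers, while the rest of the decode step and the communication overhead remain.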
Mnemosyne integrates these techniques into a cohesive 3D parallelism strategy, combining Tensor Parallelism (TP) with SPP and KVP. This hierarchical design, where each KVP unit contains a full model replica with multiple pipeline stages utilizing TP, allows for efficient processing of long sequences by accelerating both prefill and decode computations. This architecture enables flexible scheduling for mixed workloads, crucial for real-world applications. Furthermore, the system incorporates platform-level optimizations in inter-process communication, model execution (using FlashInfer kernels and CUDA graphs), and page-table management (GPU-side page tables) to maximize performance. Evaluation on Llama-3 (8B and 70B) models demonstrates significant improvements in TTFT, TBT, and resource utilization, paving the way for truly interactive experiences with multi-million token contexts.
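As a rough illustration of how the three axes compose, the sketch below enumerates a dense GPU layout for the topology described in the caption above (2 KVP groups × 2 SPP stages × 8-way TP, i.e. 32 GPUs). The mapping and names are assumptions for exposition, not Mnemosyne's actual scheduler.

```python
# Illustrative enumeration of a 3D-parallel GPU layout (KVP x SPP x TP).
# The dense rank-to-GPU mapping below is an assumption, not Mnemosyne's implementation.
from itertools import product

def layout(kvp: int, spp: int, tp: int):
    """Yield ((kvp_group, spp_stage, tp_rank), global_gpu_id) for a dense mapping."""
    for gpu_id, (k, s, t) in enumerate(product(range(kvp), range(spp), range(tp))):
        yield (k, s, t), gpu_id

if __name__ == "__main__":
    grid = dict(layout(kvp=2, spp=2, tp=8))
    print(len(grid))        # 32 GPUs in total
    print(grid[(1, 0, 3)])  # global id of TP rank 3, SPP stage 0, KVP group 1 -> 19
```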
Discovering the Gems in Early Layers: Accelerating Long-Context LLMs with 1000x Input Token Reduction by Zhenmei Shi, Yifei Ming, Xuan-Phi Nguyen, Yingyu Liang, Shafiq Joty https://arxiv.org/abs/2409.17422
This paper introduces GemFilter, a novel and efficient approach to address the computational bottleneck of long-context LLM inference. The central idea revolves around leveraging the LLM's early layers to identify and select the most relevant tokens before full processing, drastically reducing the context length for subsequent computations. This insight stems from the observation that LLMs often pinpoint essential information in their initial layers, even before generating a response. The attention matrices within these "filter layers" effectively highlight crucial tokens, allowing for significant compression of the input sequence.
GemFilter operates in two stages. First, a forward pass through only the early filter layers of the LLM selects the top k tokens based on the attention scores from the last query token. This drastically reduces the input length, for instance, from 128,000 tokens to a mere 100. In the second stage, the selected "gem" tokens are fed to the full LLM for standard generation. This two-stage process significantly reduces the computational burden of the initial prompt computation phase, which is the bottleneck for long contexts. The time complexity for prompt computation is reduced from Θ(mhn²d) for standard attention to Θ(rhn²d) for GemFilter, where m is the total number of layers, r is the number of filter layers (r < m), h is the number of attention heads, and d is the hidden dimension.
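Below is a hedged sketch of this two-stage procedure using Hugging Face Transformers. It assumes a model that exposes per-layer attention maps (attn_implementation="eager"); the filter-layer index, top-k value, and helper names are illustrative choices rather than the paper's released implementation. For clarity, stage one here runs a full forward pass, whereas the actual method stops after the first r filter layers, which is where the speedup comes from.

```python
# Hedged sketch of GemFilter-style two-stage inference (illustrative, simplified).
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "meta-llama/Meta-Llama-3.1-8B-Instruct"  # any decoder-only LLM
tok = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name, torch_dtype=torch.bfloat16,
    attn_implementation="eager", device_map="auto",
)

def gemfilter_generate(prompt: str, filter_layer: int = 13, top_k: int = 100,
                       max_new_tokens: int = 64) -> str:
    ids = tok(prompt, return_tensors="pt").input_ids.to(model.device)

    # Stage 1: read the chosen early layer's attention from the last query position
    # and keep the top-k most-attended input tokens, preserving their original order.
    with torch.no_grad():
        out = model(ids, output_attentions=True)
    attn = out.attentions[filter_layer][0]      # (heads, seq, seq)
    scores = attn[:, -1, :].mean(dim=0)         # last query token, averaged over heads
    keep = torch.topk(scores, k=min(top_k, ids.shape[1])).indices.sort().values

    # Stage 2: standard generation on the compressed "gem" token sequence only.
    gem_ids = ids[:, keep]
    gen = model.generate(gem_ids, max_new_tokens=max_new_tokens, do_sample=False)
    return tok.decode(gen[0, gem_ids.shape[1]:], skip_special_tokens=True)
```

Averaging over heads is just one way to aggregate the filter layer's attention scores, and the best filter-layer index is model-specific; both would need tuning in practice.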
Evaluations on various LLMs, including LLaMA 3.1 8B Instruct, Mistral Nemo 12B Instruct, and Phi 3.5 Mini 3.8B Instruct, demonstrate GemFilter's impressive performance gains. Using benchmarks like Needle in a Haystack and LongBench, it achieves a remarkable 2.4x speedup and a 30% reduction in GPU memory usage compared to state-of-the-art methods like SnapKV. GemFilter's simplicity, training-free nature, and broad applicability make it a compelling solution for optimizing long-context LLM inference. Moreover, its inherent interpretability, allowing direct inspection of the selected token sequence, offers valuable insights into the LLM's inner workings.
This newsletter highlights two distinct but complementary approaches to enhancing long-context LLM inference. Mnemosyne tackles the problem through sophisticated parallelization, pushing the boundaries of scale to 10 million tokens. GemFilter, on the other hand, offers an elegant solution by intelligently shrinking the input, achieving substantial speedups and memory savings. Together, these advancements mark real progress toward making long-context LLMs practical and efficient, opening up new possibilities for applications that require extensive contextual understanding. The combination of system-level parallelization and intelligent input filtering points to a promising direction for future research in this rapidly evolving field.