The "Memory Fragments" of Modern AI: A Deep Dive into KV Cache
When discussing LLM inference optimization, techniques such as PagedAttention and Speculative Decoding come up frequently. Underneath, they revolve around one core object: the KV Cache (Key-Value Cache). If you want to understand why large-model inference consumes so much VRAM and why generation slows down as the context grows longer, the KV Cache is the natural entry point.
What is KV Cache?
The core of an LLM is the Transformer architecture. During text generation, the model operates autoregressively: each new token is predicted from the prompt plus every token generated so far, so without any caching the entire sequence must be fed back into the model at every step.
In the Transformer's attention mechanism, each token produces three vectors: Query (Q), Key (K), and Value (V).
- Query: What the current token is "looking for."
- Key: What historical tokens "can provide."
- Value: The "actual content" contained in historical tokens.
The calculation process is: $\text{Attention}(Q, K, V) = \text{softmax}(\frac{QK^T}{\sqrt{d_k}})V$.
The key point is that for tokens that have already been processed, their $K$ and $V$ vectors remain completely unchanged in all subsequent generation steps, because causal masking means a token's representation depends only on itself and earlier tokens. If we recalculated the $K$ and $V$ vectors of every historical token each time a new token is generated, step $t$ would redo $\mathcal{O}(t)$ projections, so generating $n$ tokens would cost $\mathcal{O}(n^2)$ projection work in total.
To avoid this redundant computation, we store the $K$ and $V$ vectors produced at each step in VRAM; this is the KV Cache. At each step we then only compute $Q$, $K$, $V$ for the new token and read all earlier $K$, $V$ directly from the cache for the attention matrix multiplication. The projection work drops to $\mathcal{O}(1)$ per step, or $\mathcal{O}(n)$ over the whole sequence. (The attention scores themselves still cost $\mathcal{O}(t)$ at step $t$, because the new query must be compared against every cached key.)
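The difference between the two strategies can be illustrated with a minimal pure-Python sketch. It assumes a hypothetical toy model with a single attention head in a single layer, random projection matrices, and input vectors given up front as stand-ins for token embeddings (no sampling). `ToyHead`, `step_without_cache`, and `step_with_cache` are illustrative names, not any library's API. The sketch checks that both paths give the same attention output at every step and counts how many K/V vectors each one computes.

```python
import math
import random

def matvec(W, x):
    # W is a list of rows (d_out x d_in); returns W @ x
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]

def attend(q, keys, values):
    # softmax(q . k / sqrt(d_k)) over all positions, then a weighted sum of the V vectors
    d_k = len(q)
    scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(d_k) for k in keys]
    m = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [sum(e * v[j] for e, v in zip(exps, values)) / total
            for j in range(len(values[0]))]

class ToyHead:
    """One attention head of one layer with random weights (hypothetical toy model)."""

    def __init__(self, d, seed=0):
        rng = random.Random(seed)

        def rand_matrix():
            return [[rng.gauss(0, 1) for _ in range(d)] for _ in range(d)]

        self.W_q, self.W_k, self.W_v = rand_matrix(), rand_matrix(), rand_matrix()
        self.kv_projections = 0  # number of K and V vectors computed so far

    def _kv(self, x):
        self.kv_projections += 2
        return matvec(self.W_k, x), matvec(self.W_v, x)

    def step_without_cache(self, xs):
        # Recompute K and V for every token in the sequence so far
        keys, values = zip(*(self._kv(x) for x in xs))
        return attend(matvec(self.W_q, xs[-1]), keys, values)

    def step_with_cache(self, x, cache):
        # Compute K and V only for the new token; earlier ones are read from the cache
        k, v = self._kv(x)
        cache["K"].append(k)
        cache["V"].append(v)
        return attend(matvec(self.W_q, x), cache["K"], cache["V"])

d, n = 8, 16
rng = random.Random(1)
xs = [[rng.gauss(0, 1) for _ in range(d)] for _ in range(n)]  # stand-in token embeddings

plain, cached = ToyHead(d), ToyHead(d)  # same seed, so identical weights
cache = {"K": [], "V": []}
for t in range(1, n + 1):
    out_plain = plain.step_without_cache(xs[:t])
    out_cached = cached.step_with_cache(xs[t - 1], cache)
    assert max(abs(a - b) for a, b in zip(out_plain, out_cached)) < 1e-9

print(plain.kv_projections, cached.kv_projections)  # 272 32
```

Without the cache, step $t$ projects all $t$ tokens, giving $2 \times (1 + 2 + \dots + 16) = 272$ K/V vectors; with the cache, each token is projected exactly once, giving $2 \times 16 = 32$.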
The Cost of KV Cache: The VRAM Black Hole
Although KV Cache significantly boosts speed, it imposes immense pressure on VRAM.
1. Calculation Formula
The size of a model's KV Cache depends on:
- $\text{batch\_size}$ (concurrency)
- $\text{seq\_len}$ (sequence length)
- $\text{num\_layers}$ (number of layers)
- $\text{num\_kv\_heads}$ (number of Key/Value heads; equal to the number of attention heads in standard Multi-Head Attention)
- $\text{head\_dim}$ (dimension per head)
- $\text{precision}$ (bytes per element, e.g., 2 bytes for FP16)
The formula is: $\text{Memory} = 2 \times \text{batch\_size} \times \text{seq\_len} \times \text{num\_layers} \times \text{num\_kv\_heads} \times \text{head\_dim} \times \text{precision}$
(The factor of 2 is there because both Key and Value must be stored.)
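The formula above translates directly into a one-line helper. In this sketch, `kv_cache_bytes` is a hypothetical name, `bytes_per_elem` plays the role of `precision`, and `num_kv_heads` is the number of heads that actually store K/V (equal to the attention head count for standard Multi-Head Attention). The example configuration is hypothetical as well.

```python
def kv_cache_bytes(batch_size, seq_len, num_layers, num_kv_heads, head_dim, bytes_per_elem=2):
    """KV Cache size in bytes; the leading 2 accounts for storing both K and V."""
    return 2 * batch_size * seq_len * num_layers * num_kv_heads * head_dim * bytes_per_elem

# Hypothetical small model: one 1024-token sequence, 12 layers, 12 KV heads of dim 64, FP16
size = kv_cache_bytes(batch_size=1, seq_len=1024, num_layers=12, num_kv_heads=12, head_dim=64)
print(size / 1024**2, "MiB")  # 36.0 MiB
```

Because every factor enters linearly, doubling the batch size, the context length, or the number of KV heads doubles the cache.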
2. Practical Quantification
Take Llama-3-8B as an example (assuming FP16):
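- Layers: 32, attention heads: 32, dimension per head: 128.
- If every attention head kept its own Key/Value head (standard MHA), the KV data generated by a single token in a single layer would be $2 \times 32 \times 128 \times 2\text{ bytes} = 16\text{ KB}$.
- KV size for a single token across the entire model = $32\text{ layers} \times 16\text{ KB} = 512\text{ KB}$.
Doesn't seem like much? But if batch_size=32 and seq_len=4096:
$32 \times 4096 \times 512\text{ KB} = 64\text{ GiB} \approx 68.7\text{ GB}$ (treating 1 KB as 1024 bytes).
In practice Llama-3-8B uses GQA with only 8 KV heads, which divides every figure above by 4: 4 KB per token per layer, 128 KB per token, and 16 GiB for the same batch.
Still, the MHA numbers show the scale of the problem: the model weights occupy only about 16 GB (8B parameters × 2 bytes), yet a full-size KV Cache for high-concurrency long-context inference would be roughly four times larger, enough to fill an entire A100 (80 GB) or spill onto more GPUs just to store these "memory fragments."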
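The same arithmetic can be turned around to ask how many full-length sequences fit on one GPU. This is a rough sketch under loudly stated assumptions: the A100's 80 GB is treated as 80 GiB, the FP16 weights as about 15 GiB, and all remaining memory goes to the KV Cache (activations, CUDA context, and allocator overhead are ignored). The helper names are hypothetical.

```python
GIB = 1024 ** 3

def kv_bytes_per_token(num_layers, num_kv_heads, head_dim, bytes_per_elem=2):
    # One Key and one Value vector per KV head per layer
    return 2 * num_layers * num_kv_heads * head_dim * bytes_per_elem

def max_concurrent_sequences(gpu_gib, weights_gib, seq_len, per_token_bytes):
    # Assumption: everything not used by the weights is available for the KV Cache
    free_bytes = (gpu_gib - weights_gib) * GIB
    return int(free_bytes // (seq_len * per_token_bytes))

mha = kv_bytes_per_token(num_layers=32, num_kv_heads=32, head_dim=128)  # MHA variant
gqa = kv_bytes_per_token(num_layers=32, num_kv_heads=8, head_dim=128)   # 8 KV heads (GQA)
print(mha // 1024, gqa // 1024)                              # 512 128  (KiB per token)
print(max_concurrent_sequences(80, 15, 4096, mha))           # 32
print(max_concurrent_sequences(80, 15, 4096, gqa))           # 130
```

Under these assumptions the MHA variant tops out at about 32 concurrent 4096-token sequences, while 8 KV heads raise the ceiling roughly fourfold.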
How to Optimize KV Cache?
Facing this VRAM pressure, the industry has converged on three mainstream families of solutions:
MQA and GQA (Structural Optimization)
Traditional Multi-Head Attention (MHA) assigns one Key/Value head to each Query head.
- MQA (Multi-Query Attention): All Query heads share a single pair of KV heads. The KV Cache shrinks directly to $1/\text{num\_heads}$ of its original size, at the cost of some loss in model quality.
- GQA (Grouped-Query Attention): A compromise. Query heads are split into groups, and each group shares one pair of KV heads; Llama-3-8B, for example, has 32 Query heads sharing 8 KV heads (groups of 4). This shrinks the KV Cache by the group size while keeping quality close to MHA.
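The grouping can be made concrete with a small sketch of which KV head each Query head reads. It assumes consecutive Query heads are grouped together, a common convention but not the only possible layout; `kv_head_for_query_head` and `describe` are hypothetical helpers. MHA and MQA fall out as the two extremes of the same rule.

```python
def kv_head_for_query_head(q_head, num_q_heads, num_kv_heads):
    # Consecutive query heads form a group; each group reads one shared KV head
    assert num_q_heads % num_kv_heads == 0, "query heads must split evenly into groups"
    group_size = num_q_heads // num_kv_heads
    return q_head // group_size

def describe(num_q_heads, num_kv_heads):
    mapping = [kv_head_for_query_head(h, num_q_heads, num_kv_heads) for h in range(num_q_heads)]
    cache_ratio = num_kv_heads / num_q_heads  # KV Cache size relative to MHA
    return mapping, cache_ratio

for name, kv in [("MHA", 32), ("GQA", 8), ("MQA", 1)]:
    mapping, ratio = describe(32, kv)
    print(f"{name}: {kv} KV heads, cache = {ratio:.4g} x MHA, query heads 0-7 -> {mapping[:8]}")
```

Only `num_kv_heads` K/V vectors per token per layer are ever written to the cache; the Query heads in a group simply read the same entries.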
PagedAttention (Memory Management Optimization)
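Traditional serving systems reserve one contiguous chunk of memory per request, sized for its maximum possible length, so reserved-but-unused slots (internal fragmentation) and gaps between chunks (external fragmentation) waste a large share of VRAM. PagedAttention, introduced by vLLM, splits the KV Cache into fixed-size blocks (pages), similar to virtual memory in operating systems. Blocks need not be contiguous and are allocated on demand as a sequence grows, so waste is limited to the partially filled last block of each sequence and VRAM utilization approaches 100%.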
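The core bookkeeping idea can be shown with a toy block allocator: a free list of physical blocks plus a per-sequence block table that maps logical token positions to (block, offset). This is a simplified illustration of the paging concept only, not vLLM's actual data structures or API; `PagedKVCache`, its methods, and the block size are hypothetical, and the K/V "vectors" are placeholder strings.

```python
class PagedKVCache:
    """Toy block-based KV Cache: a free list of physical blocks plus per-sequence block tables."""

    def __init__(self, num_blocks, block_size):
        self.block_size = block_size                # tokens per block
        self.free_blocks = list(range(num_blocks))  # physical block ids not in use
        self.block_tables = {}                      # seq_id -> list of physical block ids
        self.lengths = {}                           # seq_id -> tokens stored so far
        self.storage = {}                           # (block_id, offset) -> (k, v)

    def append(self, seq_id, k, v):
        table = self.block_tables.setdefault(seq_id, [])
        n = self.lengths.get(seq_id, 0)
        if n % self.block_size == 0:                # last block is full (or none yet): allocate one
            if not self.free_blocks:
                raise MemoryError("no free KV blocks")
            table.append(self.free_blocks.pop(0))
        block = table[n // self.block_size]
        self.storage[(block, n % self.block_size)] = (k, v)
        self.lengths[seq_id] = n + 1

    def get(self, seq_id, pos):
        # Translate a logical token position into (physical block, offset)
        block = self.block_tables[seq_id][pos // self.block_size]
        return self.storage[(block, pos % self.block_size)]

    def free(self, seq_id):
        # Return a finished sequence's blocks to the pool
        for block in self.block_tables.pop(seq_id, []):
            for off in range(self.block_size):
                self.storage.pop((block, off), None)
            self.free_blocks.append(block)
        self.lengths.pop(seq_id, None)

cache = PagedKVCache(num_blocks=8, block_size=4)
for t in range(6):  # two sequences growing in lockstep
    cache.append("a", k=f"kA{t}", v=f"vA{t}")
    cache.append("b", k=f"kB{t}", v=f"vB{t}")
print(cache.block_tables)  # {'a': [0, 2], 'b': [1, 3]}
print(cache.get("a", 5))   # ('kA5', 'vA5')
cache.free("a")
print(cache.free_blocks)   # [4, 5, 6, 7, 0, 2]
```

The two sequences end up with interleaved, non-contiguous blocks, and each wastes at most the unused tail of its last block (here 2 of 4 slots), instead of a reservation sized for the maximum length.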
Quantization (Precision Optimization)
The KV Cache can be quantized from FP16 to INT8 or FP8, or even INT4. 8-bit formats halve the cache's VRAM footprint and 4-bit formats roughly quarter it, usually with only a small impact on generation quality, especially at 8 bits.
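As one illustration, here is a sketch of simple symmetric INT8 quantization with a single scale per vector. It assumes the scale is stored in FP16 alongside the data; real systems use a variety of schemes (per-token, per-channel, or per-group scales, FP8 formats), so `quantize_int8` and `dequantize_int8` are hypothetical helpers, and the vector values are made up.

```python
def quantize_int8(values):
    # Symmetric per-vector quantization: one scale, values mapped into [-127, 127]
    max_abs = max(abs(x) for x in values)
    scale = max_abs / 127 if max_abs > 0 else 1.0
    q = [max(-127, min(127, round(x / scale))) for x in values]
    return q, scale

def dequantize_int8(q, scale):
    return [x * scale for x in q]

k = [0.12, -1.5, 0.73, 2.4, -0.05, 0.9, -2.2, 1.1]  # made-up Key vector
q, scale = quantize_int8(k)
restored = dequantize_int8(q, scale)

fp16_bytes = 2 * len(k)      # 2 bytes per element in FP16
int8_bytes = len(q) + 2      # 1 byte per element + one FP16 scale (assumption)
max_err = max(abs(a - b) for a, b in zip(k, restored))
print(fp16_bytes, int8_bytes)  # 16 10
```

Rounding keeps each element within half a quantization step (`scale / 2`) of its original value. With a realistic head dimension of 128, the 2-byte scale overhead is small (256 bytes in FP16 versus 130 bytes in INT8), so the footprint is close to halved.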
Conclusion
KV Cache is a prime example of "trading space for time" in LLM inference. It solves the problem of redundant computation in autoregressive generation but has also become the biggest bottleneck limiting throughput and context length. From GQA to PagedAttention and quantization, the evolution of AI systems engineering is essentially a struggle against this massive "memory fragment."