现代 AI 的“推理速度”:KV Cache 的工程真相

在 LLM 的宣传册中,我们经常看到“每秒生成 100 个 token”这样的速度指标。但对于开发者来说,决定这个速度的核心不在于 GPU 的算力(TFLOPS),而在于一个极其关键的内存优化机制:KV Cache(Key-Value Cache)。

专属插画
现代 AI 的“推理速度”:KV Cache 的工程真相

现代 AI 的“推理速度”:KV Cache 的工程真相

在 LLM 的宣传册中,我们经常看到“每秒生成 100 个 token”这样的速度指标。但对于开发者来说,决定这个速度的核心不在于 GPU 的算力(TFLOPS),而在于一个极其关键的内存优化机制:KV Cache(Key-Value Cache)

如果你在调用 API 时感觉到首字响应快但后续生成慢,或者在部署模型时发现显存被迅速吃光,那么你实际上正在与 KV Cache 打交道。

为什么需要 KV Cache?

要理解 KV Cache,首先要理解 Transformer 的自回归(Autoregressive)特性。

LLM 生成文本是一个一个 token 产生的。当你输入“今天天气”,模型预测出“很”;然后模型将“今天天气很”作为输入,预测出“好”。

在计算第 $N$ 个 token 时,模型需要计算当前 token 与之前所有 $N-1$ 个 token 的注意力(Attention)。这意味着:
1. 重复计算:如果没有缓存,每次生成新 token,都要重新计算前面所有 token 的 $Q$ (Query), $K$ (Key), $V$ (Value) 向量。
2. 复杂度爆炸:计算量随序列长度呈平方级增长 $\mathcal{O}(N^2)$。

KV Cache 的核心逻辑很简单:既然之前的 token 不会改变,那么它们的 $K$ 和 $V$ 向量在每一轮迭代中都是相同的。我们只需要把它们存起来,下次直接用。

KV Cache 是如何工作的?

在推理过程中,模型分为两个阶段:

1. Prefill 阶段(预填充)

当你发送 Prompt 时,模型一次性处理所有输入 token。此时它会计算所有输入 token 的 $K$ 和 $V$,并将它们写入显存中的 KV Cache 区域。这个阶段是计算密集型的(Compute-bound),因为它可以利用 GPU 的并行能力。

2. Decoding 阶段(解码)

生成每个新 token 时,模型只需要为这一个新 token 计算其 $Q, K, V$。然后,它将这个新的 $K, V$ 追加到缓存中,并利用缓存中的历史 $K, V$ 来计算注意力权重。这个阶段是内存带宽密集型的(Memory-bound),因为 GPU 大部分时间在等待从显存中读取巨大的 KV Cache 矩阵。

工程上的残酷代价:显存压力

KV Cache 虽然解决了计算冗余,但它引入了巨大的内存开销。

KV Cache 的大小取决于:$\text{层数} \times \text{头数} \times \text{维度} \times \text{序列长度} \times \text{精度}$。
以 Llama-3-8B 为例(FP16 精度):
- 每增加一个 token,每个请求大约需要消耗数百 KB 到数 MB 的显存。
- 当并发用户增加或上下文窗口扩大到 128K 时,KV Cache 会迅速撑爆 A100/H100 的显存,导致 OOM (Out of Memory)。

如何优化 KV Cache?(工业界方案)

为了在不牺牲性能的前提下支持更长上下文和更高并发,工业界采用了三种主流方案:

1. MQA / GQA (Multi-Query / Grouped-Query Attention)

这是从架构层面减少缓存量。传统的 MHA 每个 Query 头都有对应的 Key/Value 头;而 GQA 让多个 Query 头共享一组 KV 头。这直接将 KV Cache 的体积压缩了数倍(例如 Llama-3 就使用了 GQA)。

2. PagedAttention (vLLM)

这是目前最主流的系统级优化方案。传统的 KV Cache 要求连续的内存空间,导致严重的碎片化(类似早期的操作系统内存管理)。PagedAttention 将 KV Cache 分页存储在不连续的物理块中,实现了类似虚拟内存的管理方式,极大地提升了吞吐量并降低了浪费。

3. 量化 (Quantization)

将 KV Cache 从 FP16 量化到 INT8 或 FP8。这可以将显存占用直接减半,且对模型精度的影响微乎其微。

总结

KV Cache 是 LLM 推理从“实验室玩具”变成“工业产品”的关键工程基石。它将 $\mathcal{O}(N^2)$ 的重复计算转化为 $\mathcal{O}(N)$ 的空间换时间策略。当我们讨论 AI 推理成本和延迟时,本质上是在讨论如何更高效地管理这块昂贵的显存缓存。

留言区

欢迎分享你的想法!

发表留言

0/500

加载留言中…