KV Cache:为什么显存越用越多
推理时,模型参数大小是固定的。但跑着跑着,显存占用越来越高。
罪魁祸首是 KV Cache。
这篇讲清楚 KV Cache 是什么、为什么占这么多显存、怎么优化。
先回顾 Attention
Transformer 的核心是 Attention 机制:
Attention(Q, K, V) = softmax(Q × K^T / √d) × V
对于输入序列的每个位置:
- Q(Query):「我想找什么」
- K(Key):「我有什么」
- V(Value):「我的内容是什么」
计算时,当前位置的 Q 要和之前所有位置的 K 做点积,得到注意力权重,再和 V 加权求和。
为什么需要 KV Cache
自回归生成
大模型生成文本是自回归的:每次只生成一个 token,然后把这个 token 加到输入里,再生成下一个。
输入:「今天天气」
第 1 步:生成「很」→ 输入变成「今天天气很」
第 2 步:生成「好」→ 输入变成「今天天气很好」
第 3 步:生成「。」→ 输入变成「今天天气很好。」
...
问题:重复计算
如果每次都从头算 Attention:
第 1 步:算 token 0-3 的 K、V
第 2 步:算 token 0-4 的 K、V(token 0-3 重复算了)
第 3 步:算 token 0-5 的 K、V(token 0-4 重复算了)
...
之前的 token 反复计算,效率很低。
解决方案:缓存 K 和 V
把算过的 K、V 存起来,下次直接用:
第 1 步:算 token 0-3 的 K、V,存入 cache
第 2 步:从 cache 读 token 0-3 的 K、V,只算 token 4 的 K、V,追加到 cache
第 3 步:从 cache 读 token 0-4 的 K、V,只算 token 5 的 K、V,追加到 cache
...
这就是 KV Cache:缓存历史 token 的 Key 和 Value。
KV Cache 的大小
计算公式
单个 token 的 KV Cache = 2 × num_layers × hidden_size × 2(K 和 V)× 精度
总 KV Cache = 单个 token 大小 × 序列长度 × batch_size
具体例子
LLaMA-7B:
- num_layers = 32
- hidden_size = 4096
- 精度:FP16(2 字节)
单个 token:
2 × 32 × 4096 × 2 = 512KB
如果上下文长度 4096:
512KB × 4096 = 2GB
如果同时处理 8 个请求:
2GB × 8 = 16GB
光 KV Cache 就 16GB 了。
更大的模型
| 模型 | 单 token KV | 4K 上下文 | 32K 上下文 |
|---|---|---|---|
| 7B | 512KB | 2GB | 16GB |
| 13B | 800KB | 3.2GB | 25.6GB |
| 70B | 2.5MB | 10GB | 80GB |
70B 模型 32K 上下文,KV Cache 就要 80GB。模型参数还要 140GB。一张卡根本放不下。
KV Cache 的问题
1. 显存占用大
上面算过了。长上下文 + 高并发,KV Cache 轻松超过模型参数本身。
2. 动态增长
KV Cache 随序列长度增长:
生成第 1 个 token:KV Cache 增加 512KB
生成第 2 个 token:KV Cache 增加 512KB
...
生成第 4096 个 token:KV Cache 达到 2GB
显存占用不固定,容易 OOM。
3. 碎片化
传统做法:为每个请求预分配最大长度的 KV Cache。
问题:
- 大部分请求用不到最大长度,显存浪费
- 不同请求长度不同,显存碎片化
KV Cache 优化
1. 量化
KV Cache 也可以量化:
FP16 → INT8:大小减半
FP16 → INT4:大小减 75%
对效果影响不大,但显存省很多。
2. PagedAttention
vLLM 的核心优化。把 KV Cache 分成固定大小的「页」:
传统:为每个请求分配连续的 KV Cache
[请求1: 4096 tokens][请求2: 4096 tokens][请求3: 4096 tokens]
PagedAttention:分页管理
[页1][页2][页3][页4][页5][页6]...
请求1 用页 1, 3, 5
请求2 用页 2, 4
请求3 用页 6
好处:
- 按需分配,不浪费
- 消除碎片
- 支持共享(相同前缀的请求共享 KV Cache)
3. Sliding Window Attention
有些模型只看最近 N 个 token,不用存全部历史:
传统:存 token 0 到 token 4095 的 KV
滑动窗口:只存最近 1024 个 token 的 KV
显存占用固定,但会丢失长程依赖。
Mistral 模型用了这个技术。
4. Multi-Query Attention (MQA)
让多个 Query head 共享同一组 K、V:
传统 MHA:32 个 Q head,32 个 K head,32 个 V head
MQA:32 个 Q head,1 个 K head,1 个 V head
KV Cache 大小降到原来的 1/32。
缺点是效果略有下降。
5. Grouped-Query Attention (GQA)
MQA 和 MHA 的折中:
GQA:32 个 Q head,8 个 K head,8 个 V head
每 4 个 Q head 共享一组 K、V。
LLaMA-2 70B 用了 GQA。
对比
| 方法 | KV Cache 大小(相对 MHA) | 效果影响 |
|---|---|---|
| MHA(传统) | 1x | 基准 |
| GQA | 0.25x(8组) | 很小 |
| MQA | 0.03x(1组) | 有一些 |
| 量化 INT8 | 0.5x | 很小 |
| 滑动窗口 | 固定 | 丢长程信息 |
KV Cache 管理
预分配 vs 动态分配
预分配:启动时为每个请求槽位分配最大 KV Cache
优点:简单,不用管理
缺点:浪费显存
动态分配:按需分配,用多少分多少
优点:省显存
缺点:碎片化,分配开销
vLLM 用 PagedAttention 实现高效动态分配。
显存预算
推理服务要规划显存预算:
总显存 = 模型参数 + KV Cache + 其他开销
80GB 显存:
- 模型参数:14GB(7B FP16)
- 其他开销:2GB
- 可用于 KV Cache:64GB
64GB / 512KB = 128K tokens
如果平均每个请求 2K tokens,可以同时处理 64 个请求
监控
运行时要监控 KV Cache 使用率:
# vLLM 监控
metrics:
- gpu_cache_usage_perc # KV Cache 使用率
- num_running_requests # 运行中请求数
- num_waiting_requests # 等待中请求数
KV Cache 满了,新请求就得排队。
长上下文的挑战
现在模型上下文越来越长:4K → 8K → 32K → 128K → 1M
问题放大
上下文越长,KV Cache 越大:
| 上下文 | 7B 模型 KV Cache(单请求) |
|---|---|
| 4K | 2GB |
| 32K | 16GB |
| 128K | 64GB |
| 1M | 512GB |
128K 上下文,单个请求的 KV Cache 就 64GB,一张卡只能跑一个请求。
解决思路
- KV Cache 压缩:量化、稀疏
- KV Cache offload:放到 CPU 内存或 SSD
- 分层缓存:热数据在 GPU,冷数据在 CPU
- 稀疏 Attention:不存所有 token,只存重要的
这是当前研究热点。
代码层面
查看 KV Cache 大小
# 估算 KV Cache 大小
def estimate_kv_cache_size(
num_layers,
hidden_size,
num_heads,
seq_length,
batch_size,
dtype_bytes=2 # FP16
):
# 每层每个 token 的 KV
per_token = 2 * hidden_size * dtype_bytes # K + V
# 所有层
per_token_all_layers = per_token * num_layers
# 总大小
total = per_token_all_layers * seq_length * batch_size
return total / (1024**3) # GB
# LLaMA-7B, 4K 上下文, batch=8
size = estimate_kv_cache_size(32, 4096, 32, 4096, 8)
print(f"KV Cache: {size:.2f} GB") # ~16GB
vLLM 配置 KV Cache
from vllm import LLM
llm = LLM(
model="meta-llama/Llama-2-7b-hf",
gpu_memory_utilization=0.9, # 90% 显存用于模型+KV Cache
max_model_len=4096, # 最大上下文长度
)
小结
KV Cache 核心知识:
是什么:缓存历史 token 的 Key 和 Value,避免重复计算
为什么大:
- 每个 token 每层都要存
- 上下文越长越大
- 并发越高越大
优化方法:
- 量化(INT8/INT4)
- PagedAttention(分页管理)
- GQA/MQA(共享 KV)
- 滑动窗口
关键认知:
- 长上下文推理,KV Cache 是主要显存消耗
- 推理服务的并发能力受限于 KV Cache
下一篇讲 vLLM:PagedAttention 和 Continuous Batching 的原理。