HiHuo
首页
博客
手册
工具
关于
首页
博客
手册
工具
关于

KV Cache:为什么显存越用越多

推理时,模型参数大小是固定的。但跑着跑着,显存占用越来越高。

罪魁祸首是 KV Cache。

这篇讲清楚 KV Cache 是什么、为什么占这么多显存、怎么优化。


先回顾 Attention

Transformer 的核心是 Attention 机制:

Attention(Q, K, V) = softmax(Q × K^T / √d) × V

对于输入序列的每个位置:

  • Q(Query):「我想找什么」
  • K(Key):「我有什么」
  • V(Value):「我的内容是什么」

计算时,当前位置的 Q 要和之前所有位置的 K 做点积,得到注意力权重,再和 V 加权求和。


为什么需要 KV Cache

自回归生成

大模型生成文本是自回归的:每次只生成一个 token,然后把这个 token 加到输入里,再生成下一个。

输入:「今天天气」
第 1 步:生成「很」→ 输入变成「今天天气很」
第 2 步:生成「好」→ 输入变成「今天天气很好」
第 3 步:生成「。」→ 输入变成「今天天气很好。」
...

问题:重复计算

如果每次都从头算 Attention:

第 1 步:算 token 0-3 的 K、V
第 2 步:算 token 0-4 的 K、V(token 0-3 重复算了)
第 3 步:算 token 0-5 的 K、V(token 0-4 重复算了)
...

之前的 token 反复计算,效率很低。

解决方案:缓存 K 和 V

把算过的 K、V 存起来,下次直接用:

第 1 步:算 token 0-3 的 K、V,存入 cache
第 2 步:从 cache 读 token 0-3 的 K、V,只算 token 4 的 K、V,追加到 cache
第 3 步:从 cache 读 token 0-4 的 K、V,只算 token 5 的 K、V,追加到 cache
...

这就是 KV Cache:缓存历史 token 的 Key 和 Value。


KV Cache 的大小

计算公式

单个 token 的 KV Cache = 2 × num_layers × hidden_size × 2(K 和 V)× 精度

总 KV Cache = 单个 token 大小 × 序列长度 × batch_size

具体例子

LLaMA-7B:

  • num_layers = 32
  • hidden_size = 4096
  • 精度:FP16(2 字节)

单个 token:

2 × 32 × 4096 × 2 = 512KB

如果上下文长度 4096:

512KB × 4096 = 2GB

如果同时处理 8 个请求:

2GB × 8 = 16GB

光 KV Cache 就 16GB 了。

更大的模型

模型单 token KV4K 上下文32K 上下文
7B512KB2GB16GB
13B800KB3.2GB25.6GB
70B2.5MB10GB80GB

70B 模型 32K 上下文,KV Cache 就要 80GB。模型参数还要 140GB。一张卡根本放不下。


KV Cache 的问题

1. 显存占用大

上面算过了。长上下文 + 高并发,KV Cache 轻松超过模型参数本身。

2. 动态增长

KV Cache 随序列长度增长:

生成第 1 个 token:KV Cache 增加 512KB
生成第 2 个 token:KV Cache 增加 512KB
...
生成第 4096 个 token:KV Cache 达到 2GB

显存占用不固定,容易 OOM。

3. 碎片化

传统做法:为每个请求预分配最大长度的 KV Cache。

问题:

  • 大部分请求用不到最大长度,显存浪费
  • 不同请求长度不同,显存碎片化

KV Cache 优化

1. 量化

KV Cache 也可以量化:

FP16 → INT8:大小减半
FP16 → INT4:大小减 75%

对效果影响不大,但显存省很多。

2. PagedAttention

vLLM 的核心优化。把 KV Cache 分成固定大小的「页」:

传统:为每个请求分配连续的 KV Cache
     [请求1: 4096 tokens][请求2: 4096 tokens][请求3: 4096 tokens]

PagedAttention:分页管理
     [页1][页2][页3][页4][页5][页6]...
     请求1 用页 1, 3, 5
     请求2 用页 2, 4
     请求3 用页 6

好处:

  • 按需分配,不浪费
  • 消除碎片
  • 支持共享(相同前缀的请求共享 KV Cache)

3. Sliding Window Attention

有些模型只看最近 N 个 token,不用存全部历史:

传统:存 token 0 到 token 4095 的 KV
滑动窗口:只存最近 1024 个 token 的 KV

显存占用固定,但会丢失长程依赖。

Mistral 模型用了这个技术。

4. Multi-Query Attention (MQA)

让多个 Query head 共享同一组 K、V:

传统 MHA:32 个 Q head,32 个 K head,32 个 V head
MQA:32 个 Q head,1 个 K head,1 个 V head

KV Cache 大小降到原来的 1/32。

缺点是效果略有下降。

5. Grouped-Query Attention (GQA)

MQA 和 MHA 的折中:

GQA:32 个 Q head,8 个 K head,8 个 V head

每 4 个 Q head 共享一组 K、V。

LLaMA-2 70B 用了 GQA。

对比

方法KV Cache 大小(相对 MHA)效果影响
MHA(传统)1x基准
GQA0.25x(8组)很小
MQA0.03x(1组)有一些
量化 INT80.5x很小
滑动窗口固定丢长程信息

KV Cache 管理

预分配 vs 动态分配

预分配:启动时为每个请求槽位分配最大 KV Cache

优点:简单,不用管理
缺点:浪费显存

动态分配:按需分配,用多少分多少

优点:省显存
缺点:碎片化,分配开销

vLLM 用 PagedAttention 实现高效动态分配。

显存预算

推理服务要规划显存预算:

总显存 = 模型参数 + KV Cache + 其他开销

80GB 显存:
- 模型参数:14GB(7B FP16)
- 其他开销:2GB
- 可用于 KV Cache:64GB

64GB / 512KB = 128K tokens

如果平均每个请求 2K tokens,可以同时处理 64 个请求

监控

运行时要监控 KV Cache 使用率:

# vLLM 监控
metrics:
  - gpu_cache_usage_perc    # KV Cache 使用率
  - num_running_requests    # 运行中请求数
  - num_waiting_requests    # 等待中请求数

KV Cache 满了,新请求就得排队。


长上下文的挑战

现在模型上下文越来越长:4K → 8K → 32K → 128K → 1M

问题放大

上下文越长,KV Cache 越大:

上下文7B 模型 KV Cache(单请求)
4K2GB
32K16GB
128K64GB
1M512GB

128K 上下文,单个请求的 KV Cache 就 64GB,一张卡只能跑一个请求。

解决思路

  1. KV Cache 压缩:量化、稀疏
  2. KV Cache offload:放到 CPU 内存或 SSD
  3. 分层缓存:热数据在 GPU,冷数据在 CPU
  4. 稀疏 Attention:不存所有 token,只存重要的

这是当前研究热点。


代码层面

查看 KV Cache 大小

# 估算 KV Cache 大小
def estimate_kv_cache_size(
    num_layers,
    hidden_size,
    num_heads,
    seq_length,
    batch_size,
    dtype_bytes=2  # FP16
):
    # 每层每个 token 的 KV
    per_token = 2 * hidden_size * dtype_bytes  # K + V
    # 所有层
    per_token_all_layers = per_token * num_layers
    # 总大小
    total = per_token_all_layers * seq_length * batch_size
    return total / (1024**3)  # GB

# LLaMA-7B, 4K 上下文, batch=8
size = estimate_kv_cache_size(32, 4096, 32, 4096, 8)
print(f"KV Cache: {size:.2f} GB")  # ~16GB

vLLM 配置 KV Cache

from vllm import LLM

llm = LLM(
    model="meta-llama/Llama-2-7b-hf",
    gpu_memory_utilization=0.9,  # 90% 显存用于模型+KV Cache
    max_model_len=4096,  # 最大上下文长度
)

小结

KV Cache 核心知识:

是什么:缓存历史 token 的 Key 和 Value,避免重复计算

为什么大:

  • 每个 token 每层都要存
  • 上下文越长越大
  • 并发越高越大

优化方法:

  • 量化(INT8/INT4)
  • PagedAttention(分页管理)
  • GQA/MQA(共享 KV)
  • 滑动窗口

关键认知:

  • 长上下文推理,KV Cache 是主要显存消耗
  • 推理服务的并发能力受限于 KV Cache

下一篇讲 vLLM:PagedAttention 和 Continuous Batching 的原理。