HiHuo
首页
博客
手册
工具
关于
首页
博客
手册
工具
关于

为什么要分布式训练

单卡训练大模型,要么放不下,要么太慢。所以必须分布式。

这篇讲清楚:为什么单卡不够、分布式能解决什么问题、有什么代价。


单卡的极限

显存不够

一个 7B 参数的模型,用 FP16 精度:

  • 模型参数:7B × 2 字节 = 14GB

看起来 24GB 的 4090 也能放下?

但训练时还要存这些东西:

  • 梯度:和参数一样大,14GB
  • 优化器状态(Adam):参数的 2 倍,28GB
  • 激活值:取决于 batch size,可能几十 GB

加起来:

总显存 ≈ 参数 + 梯度 + 优化器状态 + 激活值
       ≈ 14 + 14 + 28 + 激活值
       ≈ 56GB + 激活值

80GB 的 A100 都不一定够。

参数量更大呢?

模型参数量FP16 显存(仅参数)训练所需显存(估算)
7B70 亿14GB80-120GB
13B130 亿26GB150-200GB
70B700 亿140GB800-1000GB
175B1750 亿350GB2-3TB

70B 以上的模型,单卡根本放不下。必须切分到多张卡上。

算力不够

就算显存够,单卡算力也可能不够。

训练 7B 模型,1 万亿 token 的数据:

  • 计算量:约 6 × 7B × 1T = 4.2 × 10^22 FLOPS
  • A100 FP16 算力:312 TFLOPS
  • 理论训练时间:4.2 × 10^22 / (312 × 10^12) / 3600 / 24 ≈ 1560 天

4 年多。等训完黄花菜都凉了。

用 100 张 A100:15.6 天。 用 1000 张 A100:1.56 天。

算力可以堆。


分布式能解决什么

1. 突破显存限制

把模型切分到多张卡上,每张卡只放一部分。

8 张 80GB 的卡 = 640GB 显存。

原来放不下的模型,现在能放了。

2. 加速训练

多张卡并行计算,训练时间近似线性下降。

1 张卡 100 天,100 张卡可能 1-2 天(考虑通信开销)。

3. 更大的 batch size

单卡 batch size 受显存限制。多卡可以每张卡放一部分数据,等效于更大的 batch size。

大 batch size 通常能提升训练效率和效果。


分布式的代价

没有免费午餐。分布式带来新的问题:

1. 通信开销

多卡之间要同步数据(梯度、激活值等),需要时间。

如果通信太慢,GPU 大部分时间在等,利用率上不去。

这就是为什么前面花那么多篇幅讲 NVLink、InfiniBand。

2. 复杂度上升

代码要改:

  • 数据怎么分
  • 模型怎么切
  • 梯度怎么同步
  • 多机怎么启动

调试也变难了:100 张卡出 bug,定位起来比单卡难多了。

3. 效率损失

理论上 100 张卡应该快 100 倍,实际可能只快 70-80 倍。

损失来自:

  • 通信时间
  • 负载不均衡
  • 同步等待

这个比例叫扩展效率(Scaling Efficiency)。80% 的扩展效率算不错了。

4. 成本上升

100 张卡的成本不只是 1 张卡的 100 倍:

  • 要买高速网络(InfiniBand 很贵)
  • 要买专用服务器(DGX 比普通服务器贵)
  • 电费、机房、运维都要增加

什么时候需要分布式

必须分布式

  • 模型参数 > 单卡显存
  • 训练时间要求紧(几天内出结果)
  • Batch size 要求大

可以考虑分布式

  • 有现成的多卡环境
  • 想加速迭代
  • 数据量很大

不需要分布式

  • 小模型(1B 以下)
  • 微调(显存占用小)
  • 个人学习研究

分布式的几种方式

后面会详细讲,这里先预告一下:

数据并行(Data Parallelism)

每张卡放完整的模型,各自处理不同的数据,最后同步梯度。

卡0: 完整模型,处理 batch 0
卡1: 完整模型,处理 batch 1
卡2: 完整模型,处理 batch 2
...
同步梯度,更新参数

适合:模型能放进单卡,想加速训练

模型并行(Model Parallelism)

把模型切分到多张卡上。

  • 张量并行:把一层切分到多张卡
  • 流水线并行:不同的层放在不同的卡
卡0: 层 0-10
卡1: 层 11-20
卡2: 层 21-30
...
数据像流水线一样经过各张卡

适合:模型太大,单卡放不下

混合并行

实际训练大模型,通常同时用多种并行:

- 张量并行:同一个节点内的 8 张卡
- 流水线并行:跨节点
- 数据并行:多组流水线处理不同数据

复杂但有效。


框架选择

分布式训练不用从零写。有现成的框架:

框架特点适合场景
PyTorch DDP官方支持,简单数据并行
DeepSpeed微软开源,功能全大模型训练
Megatron-LMNVIDIA 开源,性能强超大模型训练
FSDPPyTorch 官方大模型训练
ColossalAI国产,易用入门级大模型

后面会专门对比这些框架。


实际案例

GPT-3 训练

  • 模型:175B 参数
  • 硬件:约 1000 张 V100
  • 时间:约 1 个月
  • 并行方式:数据并行 + 模型并行

LLaMA 训练

  • 模型:65B 参数
  • 硬件:2048 张 A100
  • 时间:约 21 天
  • 数据:1.4 万亿 token

国内大模型

大部分国内大模型训练规模:

  • 几百到几千张 A100/H100
  • 几周到几个月
  • 成本:几千万到几亿人民币

小结

为什么需要分布式训练:

  1. 显存不够:大模型参数 + 梯度 + 优化器状态,单卡放不下
  2. 算力不够:单卡训练太慢,等不起
  3. Batch size 要求:更大的 batch 需要更多显存

分布式的代价:

  1. 通信开销
  2. 复杂度上升
  3. 效率损失
  4. 成本上升

核心认知:分布式不是可选的,是大模型训练的必须。 7B 以上的模型训练,基本都要分布式。

下一篇详细讲三种并行方式:数据并行、模型并行、流水线并行。