一、并行的意义
当你尝试训练一个70B参数的模型时,会遇到两个硬性瓶颈:
| 瓶颈类型 | 具体表现 | 估算数据 |
显存瓶颈 | 单卡放不下模型 | 70B参数FP16需要140GB+,加上优化器状态、梯度、激活值,总需求超过300GB |
计算瓶颈 | 单卡训练太慢 | 全量训练可能需要数月甚至数年 |
并行策略的核心思想:把"大"问题拆解成多个"小"问题,分布到多GPU上同时解决。
二、 三种并行策略详解
2.1 数据并行(Data Parallelism, DP)
核心思想:数据分片,模型复制。每张卡都存完整的模型,但处理不同的数据批次。
关键特征:
显存占用:每卡都存完整模型,显存需求无减少。
加速效果:理想情况下线性加速(4卡≈4倍速),实际因通信损耗约70-90%效率。
适用场景:模型能塞进单卡,但想加速训练。
通信开销:每轮迭代需要一次All-Reduce同步梯度,通信量与模型参数量成正比。对于70B模型,每次迭代需传输140GB梯度数据(FP16)。
2.2 张量并行(Tensor Parallelism, TP)
核心思想:模型纵向切分,层内分割。把单层网络的参数矩阵切分到多卡,每张卡只存部分权重。
以Transformer的MLP层为例(简化表示):
原始计算: Y = GeLU(X @ A) @ B
A: [hidden_size, 4*hidden_size]
B: [4*hidden_size, hidden_size]
TP=2时的切分方式(按列切分A,按行切分B):
┌──────────────────────────────────────────────────────────┐
│ GPU 0 持有 A[:, :2h] 和 B[:2h, :] │
│ GPU 1 持有 A[:, 2h:] 和 B[2h:, :] │
│ │
│ 前向传播: │
│ GPU 0: Z0 = X @ A[:, :2h] → GeLU(Z0) @ B[:2h, :] = Y0 │
│ GPU 1: Z1 = X @ A[:, 2h:] → GeLU(Z1) @ B[2h:, :] = Y1 │
│ │
│ 需要All-Gather通信合并 Y = Y0 + Y1 │
└──────────────────────────────────────────────────────────┘2.3 流水线并行(Pipeline Parallelism, PP)
核心思想:模型横向切分,层间分割。把不同层放到不同卡,数据像流水线一样流动。
气泡问题与GPipe改进:
朴素流水线(如GPipe)把batch切成更多微批次(micro-batches),让GPU尽可能填满:
微批次数量 = 4, PP度数 = 4 时的填充效果:
GPU 0: [F0][F1][F2][F3][B3][B2][B1][B0]
GPU 1: [F0][F1][F2][F3][B3][B2][B1][B0]
GPU 2: [F0][F1][F2][F3][B3][B2][B1][B0]
GPU 3: [F0][F1][F2][F3][B3][B2][B1][B0]
气泡率 ≈ (PP-1)/(PP+M-1) M为微批次数量,M越大气泡越小1F1B(One Forward One Backward)调度:当前主流方案,显存更优,但气泡稍大。
关键特征:
显存占用:每卡只存部分层,显存与层数成正比减少
加速效果:接近线性加速(通信极少),但受气泡影响,通常80-95%效率
硬性限制:PP度数 ≤ 总层数,且最好整除;层间通信量小(激活值),可跨节点
三、三维并行的组合与配置
实际大模型训练(如GPT-3、LLaMA)需要TP + PP + DP三者组合:
配置约束清单:
| 维度 | 硬性约束 | 软性建议 | 通信方式 |
TP | ≤ 单节点GPU数(通常8) | 2/4/8,避免3/5/6 | NVLink/NCCL,带宽>400GB/s |
PP | ≤ 总层数,建议整除 | 4-8,太大则气泡大 | IB/RoCE,带宽>50GB/s |
DP | 无上限,但需整除总卡数 | 根据global batch size定 | 跨节点All-Reduce |
配置优先级:
先定TP:能塞进单节点就不跨节点,TP通信最密集
再定PP:层数允许范围内最大化,减少单卡显存
最后DP:剩余卡数做数据并行,提升吞吐
四、显存与速度的深度分析
4.1 显存占用公式
单卡显存需求 ≈
模型参数: Params × 2 (FP16/BF16) / TP
优化器状态: Params × 4 * 2 (Adam) / PP
梯度: Params × 2 / TP
激活值: Batch × Seq × Hidden × Layers × 4 / (TP × PP) [重计算可减]TP降低参数/梯度显存,但不降优化器(除非配合ZeRO)
PP降低激活值显存(层数少了),也降参数
DP不降单卡显存,只通过多卡扩大全局batch
4.2 速度提升的边界
理想吞吐 = 单卡吞吐 × 总卡数 × 并行效率
效率损失来源:
├─ TP: All-Reduce通信,单节点内<5%损失,跨节点>30%损失
├─ PP: 气泡时间,GPipe约10-20%,1F1B约15-25%
└─ DP: 梯度同步,大batch时<5%,小batch时>20%实践建议:
优先用PP扩展:通信最少,扩展性最好
TP只用单节点:跨节点TP是性能杀手
DP需要大batch:global batch ≥ 1M tokens时效率最佳
五、配置实例:70B模型训练
硬件:4节点×8A100-80GB(共32卡)
模型配置:70B参数,80层, hidden_size=8192, seq_len=4096
计算配置:
TP=8(单节点内8卡)
PP=4(4节点构成完整模型)
DP=1
| 项目 | 每卡参数量 | 每卡显存 |
bf16 参数 | 2.2 B | 2.2 B × 2 B = 4.4 GB |
fp32 master weight | 2.2 B | 2.2 B × 4 B = 8.8 GB |
Adam(momentum + variance) | 2.2 B | 2 × 2.2 B × 4 B = 17.6 GB |
bf16 梯度 | 2.2 B | 2.2 B × 2 B = 4.4 GB |
激活值 |
| 8192 (seq) × 1 (micro_batch) × 8192 (hidden) × 2 B / TP8 ≈ 16 MB 峰值激活 ≈ 4 (in-flight) × 20 (layers) × 16 MB ≈ 1.3 GB |
小计 |
| ~37 GB |