模型并行策略搜索:为模型自动寻找最优切分方案
引言:为什么需要自动寻找最优切分方案
随着大语言模型(LLM)参数量向千亿甚至万亿规模迈进,模型并行已经成为分布式训练的基本功。传统的模型并行策略往往依赖专家经验,手工雕琢层切分、设备映射和通信模式,但模型的异构结构、硬件拓扑的多样性让“最佳”策略极难确定。
手工调优不仅耗时,而且容易陷入局部最优——某个切分在 A100 集群上表现优异,搬到 H100 或跨地域集群就可能退化。模型并行策略搜索 应运而生:将切分问题建模为一个优化问题,通过算法自动探索庞大的策略空间,找到“训练吞吐最高、显存占用最低、通信开销最小”的帕累托前沿方案。
本教程将从零开始,带你理解模型并行策略搜索的核心概念、搜索空间构建、关键算法以及主流实用工具,让大模型分布式训练不再依赖“拍脑袋”。
第一章:模型并行基础回顾
在进入搜索之前,必须厘清几种基本并行模式,因为它们是搜索空间的生产原语。
1.1 数据并行(Data Parallelism, DP)
每个设备持有完整模型副本,输入数据被切分成 micro-batches,独立执行前向和反向,梯度通过 AllReduce 同步。优势是实现简单,但要求单卡能放下整个模型。
1.2 张量并行(Tensor Parallelism, TP)
将层内的权重矩阵沿行或列切分到多个设备上,计算和通信交织进行。Megatron-LM 的列并行与行并行就是典型实现。TP 通信量大,主要用于 Transformer 内的 Attention 和 MLP 块。
1.3 流水线并行(Pipeline Parallelism, PP)
模型按层纵向切分,每个设备负责连续的一组层,微批次数据像流水线一样流过各阶段。存在 F-then-B 或 1F1B 调度来减少气泡。通信量相对较小,但要求平衡各阶段的计算负载。
1.4 序列并行(Sequence Parallelism, SP)
与张量并行配合,将长序列维度切分到多个设备,降低 LayerNorm 和 Dropout 处的激活内存。通常与 TP 共同使用。
1.5 混合并行
现代大模型训练几乎必用混合并行,例如 TP+PP+DP 的三维并行网格。搜索的目标就是为网格中每一维选择度,并为算子分配具体策略。
第二章:策略搜索问题的形式化
2.1 三个核心定义
- 策略 (Strategy):为计算图中每个算子指定并行化方案,包括该算子是否切分、沿哪个维度切分、切分成几份、映射到哪些设备上。
- 代价模型 (Cost Model):量化策略的执行代价,通常是训练一步的时间或吞吐量,以及峰值显存占用。
- 搜索空间 (Search Space):所有合法策略的集合。受硬件设备数量、拓扑带宽、内存容量等约束。
目标可以形式化为:
找到策略 ( s^{*} ),使得最大显存占用 ( M(s) \leq M_{\text{limit}} ),并最小化单步耗时 ( T(s) )。
实际场景中常退化为吞吐量最大化或内存满足约束下的最优。
第三章:构建搜索空间:你的可选“配方”
搜索空间的粒度和表达力直接影响搜索质量。
3.1 算子级切分空间
将模型表示为一个计算图(如通过 JAX 的 XLA 或 PyTorch 的 FX),每个计算节点(矩阵乘法、LayerNorm、激活函数等)可以配置:
- 输入/输出的张量切分方式(沿 batch 维度、序列维度、隐藏维度等)。
- 权重矩阵的切分(行切、列切、无切分)。
- 设备网格拓扑(例如
(2, 4, 8)表示 DP=2, TP=4, PP=8)。
搜索空间大小会随算子数指数增长,必须设计合理的剪枝和规则约束。
3.2 细粒度 vs. 粗粒度策略
- 粗粒度搜索:只决定层级别的 TP、PP、DP 度,层内采用统一策略。例如 Alpa 的 inter-op 并行搜索。
- 细粒度搜索:允许每个算子有独立策略,能够挖掘算子间并行化的重叠机会(如自动融合通信与计算)。
初学者建议从粗粒度开始,逐步过渡到细粒度,以平衡搜索代价与收益。
3.3 约束条件编码
必须将硬件限制编码进空间:
- 设备数量:TP 维度乘积不能超过总设备数。
- 通信拓扑:高带宽域(NVLink)内的设备适合 TP,低带宽域(跨节点网络)适合 PP。
- 内存限制:策略必须确保单卡激活+参数+优化器状态不超过显存。
第四章:代价模型:不实际操作怎么预知性能
搜索算法需要一个代价模型来评估未实际执行的策略。
4.1 基于分析 (Analytical) 的成本模型
- 计算代价:根据算子 FLOPs 和设备峰值算力估算,考虑矩阵乘的 arithmetic intensity。
- 通信代价:根据切分方式推导所需的通信原语(AllGather、ReduceScatter、AllReduce),计算通信数据量,再除以链路带宽。对于流水线并行模型,还需计算 bubble 时间。
- 内存代价:权重、梯度、优化器状态、激活值(通过 activation checkpoint 调整)求和。
这种模型计算快,但容易因为不精确(忽略 kernel launch 开销、竞争效应)导致排序错误。
4.2 基于模拟器的代价模型
使用模拟器(如 ASTRA-sim、Habitat 等)运行策略的简化 trace,可以捕获网络拥塞、计算通信重叠等动态行为。精度更高但速度慢,常用于搜索过程的 refinement 阶段。
4.3 数据驱动代价模型
利用少量真实执行样本训练一个神经网络预测器,输入为策略表示,输出为吞吐/显存。需要解决策略表征学习,可以通过图神经网络编码计算图与策略。
实践中常用分析模型作快速筛选,再用少量实测验证。
第五章:主流搜索算法一览
5.1 动态规划 (DP) —— Alpa 的 inter-op 搜索
将流水线阶段的切分问题视为动态规划,对于算子序列,递推计算不同设备数和切分点的最小时延。该方法保证在特定假设下得到最优解,适用于流水线并行度的自动选择。
5.2 枚举与剪枝
对较小的搜索空间,可以暴力枚举并利用代价模型剪枝。例如,先固定 TP、PP 大小组合,通过成本模型过滤掉显存溢出或明显低效的配置。
5.3 随机搜索与贝叶斯优化
随机采样策略并评估,简单有效,尤其适合多超参数的混合并行配置。贝叶斯优化可以用高斯过程拟合吞吐表面,引导搜索向高吞吐区域,适合资源受限的搜索预算。
5.4 基于 MCTS (蒙特卡洛树搜索) 的搜索
将策略构建视为序列决策:逐算子或逐层分配并行度。MCTS 可以在庞大的组合空间中平衡探索与利用。MindSpore 的 “金箍棒” 就使用了类似思路。
5.5 强化学习 (RL)
将设备放置看作 action,通过策略网络输出配置,用代价模型作为 reward。适合极大规模空间,但训练 RL 本身成本高。
5.6 整数线性规划 (ILP)
将通信与计算约束写成线性不等式,用求解器找出精确最优解。对于小规模图可求出真最优,但图稍大即不可解。
第六章:典型工具实战:从 Alpa 到 Megatron-LM 的自动并行
6.1 Alpa:自动模型并行编译器
Alpa 是 Google 推出的基于 JAX 的自动并行库,核心思想是两层搜索:
- Inter-op 并行:决定流水线调度和设备网格的流水线维。
- Intra-op 并行:在每个流水线阶段内,使用整数规划为每个算子分配张量并行策略。
初学体验:只需给 JAX 代码加上装饰器,Alpa 便可自动搜索并在集群上执行。非常适合研究者和想快速验证想法的人。
import alpa
@alpa.parallelize
def train_step(state, batch):
...
6.2 FlexFlow Serve 与 Unity
FlexFlow 支持细粒度算子级并行,允许输入/输出张量不同维度切分(SOAP 搜索空间),采用基于代价模型的 MCMC 采样搜索最优并行代码。
6.3 Megatron-LM 的半自动网格搜索
虽然 Megatron-LM 本身不提供自动搜索,但社区实践中常写脚本遍历 TP、PP、DP 组合,依据经验公式估算显存,并实测最佳配置。也有如 Amazon SageMaker 的模型并行自动调优功能与此类似。
6.4 PyTorch 的 torch.distributed.tensor 与 DeviceMesh
PyTorch 2.0 引入分布式张量抽象,支持 SPMD 风格编程。
from torch.distributed.tensor import distribute_tensor, DeviceMesh
mesh = DeviceMesh("cuda", [[0, 1], [2, 3]])
# 然后手动配置张量切分,未来或将与自动搜索工具(如 torch.compile 配合)集成
第七章:进阶话题与最佳实践
7.1 动态重配置与重计算权衡
搜索出的策略并非一成不变。训练的不同阶段(比如 warm-up、stable)可动态调整并行拓扑,但切换成本需要纳入代价模型。另外,activation checkpointing 策略也应作为搜索变量,与并行策略共同优化。
7.2 异构硬件与非对称拓扑
当集群由不同型号 GPU 组成或网络拓扑不规则时,搜索需引入带权重的设备图。策略能映射切分到更快通信的链路,并限制弱卡的计算量。
7.3 搜索时间 vs. 训练时间的平衡
执行搜索本身也要花费时间,收益必须大于成本。对于同一个模型需要长期训练的预训练任务,搜索的投入产出比极高;对于微调或小规模实验,手工规则可能更经济。
7.4 诊断与可解释性
自动搜索的结果应该输出可视化策略:展示每一层的切分方式、设备映射以及预估的通信热图。这能帮助工程师验证合理性并累积直觉。
第八章:总结与学习路径
模型并行策略搜索正在从手工调优迈向自动化的关键阶段。要掌握这一技能,建议遵循:
- 手工实践传统并行:先用 Megatron、DeepSpeed 配置 TP、PP、DP,理解通信原语和显存模型。
- 学习分析成本模型:能够写出一个简化的吞吐和显存估算脚本,预测单个配置。
- 使用 Alpa 等自动工具:上手感受自动并行的效率和局限。
- 阅读论文:Alpa、FlexFlow、GSPMD、Unity 等经典文献能帮你理解搜索算法的演进。
- 关注产业落地:跟踪 PyTorch 分布式 API 与 compiler 的整合,以及各云平台的自动调优服务。
模型越大,并行越复杂,搜索的价值就越不可替代。早日构建自动优化的思维,能让你在未来千卡、万卡集群上从容驾驭巨型模型。
本教程由「免费在线教程」出品,持续追踪系统与机器学习前沿,致力提供零门槛高密度的技术学习内容。