模型并行策略搜索:为模型自动寻找最优切分方案

FreeGuideOnline 最新 2026-06-28

引言:为什么需要自动寻找最优切分方案

随着大语言模型(LLM)参数量向千亿甚至万亿规模迈进,模型并行已经成为分布式训练的基本功。传统的模型并行策略往往依赖专家经验,手工雕琢层切分、设备映射和通信模式,但模型的异构结构、硬件拓扑的多样性让“最佳”策略极难确定。

手工调优不仅耗时,而且容易陷入局部最优——某个切分在 A100 集群上表现优异,搬到 H100 或跨地域集群就可能退化。模型并行策略搜索 应运而生:将切分问题建模为一个优化问题,通过算法自动探索庞大的策略空间,找到“训练吞吐最高、显存占用最低、通信开销最小”的帕累托前沿方案。

本教程将从零开始,带你理解模型并行策略搜索的核心概念、搜索空间构建、关键算法以及主流实用工具,让大模型分布式训练不再依赖“拍脑袋”。


第一章:模型并行基础回顾

在进入搜索之前,必须厘清几种基本并行模式,因为它们是搜索空间的生产原语。

1.1 数据并行(Data Parallelism, DP)

每个设备持有完整模型副本,输入数据被切分成 micro-batches,独立执行前向和反向,梯度通过 AllReduce 同步。优势是实现简单,但要求单卡能放下整个模型。

1.2 张量并行(Tensor Parallelism, TP)

将层内的权重矩阵沿行或列切分到多个设备上,计算和通信交织进行。Megatron-LM 的列并行与行并行就是典型实现。TP 通信量大,主要用于 Transformer 内的 Attention 和 MLP 块。

1.3 流水线并行(Pipeline Parallelism, PP)

模型按层纵向切分,每个设备负责连续的一组层,微批次数据像流水线一样流过各阶段。存在 F-then-B1F1B 调度来减少气泡。通信量相对较小,但要求平衡各阶段的计算负载。

1.4 序列并行(Sequence Parallelism, SP)

与张量并行配合,将长序列维度切分到多个设备,降低 LayerNorm 和 Dropout 处的激活内存。通常与 TP 共同使用。

1.5 混合并行

现代大模型训练几乎必用混合并行,例如 TP+PP+DP 的三维并行网格。搜索的目标就是为网格中每一维选择度,并为算子分配具体策略。


第二章:策略搜索问题的形式化

2.1 三个核心定义

  1. 策略 (Strategy):为计算图中每个算子指定并行化方案,包括该算子是否切分、沿哪个维度切分、切分成几份、映射到哪些设备上。
  2. 代价模型 (Cost Model):量化策略的执行代价,通常是训练一步的时间或吞吐量,以及峰值显存占用。
  3. 搜索空间 (Search Space):所有合法策略的集合。受硬件设备数量、拓扑带宽、内存容量等约束。

目标可以形式化为:

找到策略 ( s^{*} ),使得最大显存占用 ( M(s) \leq M_{\text{limit}} ),并最小化单步耗时 ( T(s) )。

实际场景中常退化为吞吐量最大化或内存满足约束下的最优。


第三章:构建搜索空间:你的可选“配方”

搜索空间的粒度和表达力直接影响搜索质量。

3.1 算子级切分空间

将模型表示为一个计算图(如通过 JAX 的 XLA 或 PyTorch 的 FX),每个计算节点(矩阵乘法、LayerNorm、激活函数等)可以配置:

  • 输入/输出的张量切分方式(沿 batch 维度、序列维度、隐藏维度等)。
  • 权重矩阵的切分(行切、列切、无切分)。
  • 设备网格拓扑(例如 (2, 4, 8) 表示 DP=2, TP=4, PP=8)。

搜索空间大小会随算子数指数增长,必须设计合理的剪枝和规则约束。

3.2 细粒度 vs. 粗粒度策略

  • 粗粒度搜索:只决定层级别的 TP、PP、DP 度,层内采用统一策略。例如 Alpa 的 inter-op 并行搜索。
  • 细粒度搜索:允许每个算子有独立策略,能够挖掘算子间并行化的重叠机会(如自动融合通信与计算)。

初学者建议从粗粒度开始,逐步过渡到细粒度,以平衡搜索代价与收益。

3.3 约束条件编码

必须将硬件限制编码进空间:

  • 设备数量:TP 维度乘积不能超过总设备数。
  • 通信拓扑:高带宽域(NVLink)内的设备适合 TP,低带宽域(跨节点网络)适合 PP。
  • 内存限制:策略必须确保单卡激活+参数+优化器状态不超过显存。

第四章:代价模型:不实际操作怎么预知性能

搜索算法需要一个代价模型来评估未实际执行的策略。

4.1 基于分析 (Analytical) 的成本模型

  • 计算代价:根据算子 FLOPs 和设备峰值算力估算,考虑矩阵乘的 arithmetic intensity。
  • 通信代价:根据切分方式推导所需的通信原语(AllGather、ReduceScatter、AllReduce),计算通信数据量,再除以链路带宽。对于流水线并行模型,还需计算 bubble 时间。
  • 内存代价:权重、梯度、优化器状态、激活值(通过 activation checkpoint 调整)求和。

这种模型计算快,但容易因为不精确(忽略 kernel launch 开销、竞争效应)导致排序错误。

4.2 基于模拟器的代价模型

使用模拟器(如 ASTRA-sim、Habitat 等)运行策略的简化 trace,可以捕获网络拥塞、计算通信重叠等动态行为。精度更高但速度慢,常用于搜索过程的 refinement 阶段。

4.3 数据驱动代价模型

利用少量真实执行样本训练一个神经网络预测器,输入为策略表示,输出为吞吐/显存。需要解决策略表征学习,可以通过图神经网络编码计算图与策略。

实践中常用分析模型作快速筛选,再用少量实测验证。


第五章:主流搜索算法一览

5.1 动态规划 (DP) —— Alpa 的 inter-op 搜索

将流水线阶段的切分问题视为动态规划,对于算子序列,递推计算不同设备数和切分点的最小时延。该方法保证在特定假设下得到最优解,适用于流水线并行度的自动选择。

5.2 枚举与剪枝

对较小的搜索空间,可以暴力枚举并利用代价模型剪枝。例如,先固定 TP、PP 大小组合,通过成本模型过滤掉显存溢出或明显低效的配置。

5.3 随机搜索与贝叶斯优化

随机采样策略并评估,简单有效,尤其适合多超参数的混合并行配置。贝叶斯优化可以用高斯过程拟合吞吐表面,引导搜索向高吞吐区域,适合资源受限的搜索预算。

5.4 基于 MCTS (蒙特卡洛树搜索) 的搜索

将策略构建视为序列决策:逐算子或逐层分配并行度。MCTS 可以在庞大的组合空间中平衡探索与利用。MindSpore 的 “金箍棒” 就使用了类似思路。

5.5 强化学习 (RL)

将设备放置看作 action,通过策略网络输出配置,用代价模型作为 reward。适合极大规模空间,但训练 RL 本身成本高。

5.6 整数线性规划 (ILP)

将通信与计算约束写成线性不等式,用求解器找出精确最优解。对于小规模图可求出真最优,但图稍大即不可解。


第六章:典型工具实战:从 Alpa 到 Megatron-LM 的自动并行

6.1 Alpa:自动模型并行编译器

Alpa 是 Google 推出的基于 JAX 的自动并行库,核心思想是两层搜索:

  • Inter-op 并行:决定流水线调度和设备网格的流水线维。
  • Intra-op 并行:在每个流水线阶段内,使用整数规划为每个算子分配张量并行策略。

初学体验:只需给 JAX 代码加上装饰器,Alpa 便可自动搜索并在集群上执行。非常适合研究者和想快速验证想法的人。

import alpa
@alpa.parallelize
def train_step(state, batch):
    ...

6.2 FlexFlow Serve 与 Unity

FlexFlow 支持细粒度算子级并行,允许输入/输出张量不同维度切分(SOAP 搜索空间),采用基于代价模型的 MCMC 采样搜索最优并行代码。

6.3 Megatron-LM 的半自动网格搜索

虽然 Megatron-LM 本身不提供自动搜索,但社区实践中常写脚本遍历 TP、PP、DP 组合,依据经验公式估算显存,并实测最佳配置。也有如 Amazon SageMaker 的模型并行自动调优功能与此类似。

6.4 PyTorch 的 torch.distributed.tensor 与 DeviceMesh

PyTorch 2.0 引入分布式张量抽象,支持 SPMD 风格编程。

from torch.distributed.tensor import distribute_tensor, DeviceMesh
mesh = DeviceMesh("cuda", [[0, 1], [2, 3]])
# 然后手动配置张量切分,未来或将与自动搜索工具(如 torch.compile 配合)集成

第七章:进阶话题与最佳实践

7.1 动态重配置与重计算权衡

搜索出的策略并非一成不变。训练的不同阶段(比如 warm-up、stable)可动态调整并行拓扑,但切换成本需要纳入代价模型。另外,activation checkpointing 策略也应作为搜索变量,与并行策略共同优化。

7.2 异构硬件与非对称拓扑

当集群由不同型号 GPU 组成或网络拓扑不规则时,搜索需引入带权重的设备图。策略能映射切分到更快通信的链路,并限制弱卡的计算量。

7.3 搜索时间 vs. 训练时间的平衡

执行搜索本身也要花费时间,收益必须大于成本。对于同一个模型需要长期训练的预训练任务,搜索的投入产出比极高;对于微调或小规模实验,手工规则可能更经济。

7.4 诊断与可解释性

自动搜索的结果应该输出可视化策略:展示每一层的切分方式、设备映射以及预估的通信热图。这能帮助工程师验证合理性并累积直觉。


第八章:总结与学习路径

模型并行策略搜索正在从手工调优迈向自动化的关键阶段。要掌握这一技能,建议遵循:

  1. 手工实践传统并行:先用 Megatron、DeepSpeed 配置 TP、PP、DP,理解通信原语和显存模型。
  2. 学习分析成本模型:能够写出一个简化的吞吐和显存估算脚本,预测单个配置。
  3. 使用 Alpa 等自动工具:上手感受自动并行的效率和局限。
  4. 阅读论文:Alpa、FlexFlow、GSPMD、Unity 等经典文献能帮你理解搜索算法的演进。
  5. 关注产业落地:跟踪 PyTorch 分布式 API 与 compiler 的整合,以及各云平台的自动调优服务。

模型越大,并行越复杂,搜索的价值就越不可替代。早日构建自动优化的思维,能让你在未来千卡、万卡集群上从容驾驭巨型模型。


本教程由「免费在线教程」出品,持续追踪系统与机器学习前沿,致力提供零门槛高密度的技术学习内容。