自动并行化:由框架决定模型如何分布到多卡
自动并行化:让框架决定模型如何分布到多卡
什么是自动并行化?
在深度学习中,当模型大到单张GPU显存放不下,或训练速度需要进一步提升时,我们会使用多张GPU进行分布式训练。传统的做法是手动指定每个算子或每个层放在哪张卡上,数据如何切分——这通常被称为模型并行、数据并行、流水线并行等,但它们的配置非常复杂且容易出错。
自动并行化(Automatic Parallelism)是一种由训练框架或编译器自动分析模型的计算图,然后智能地决定:
- 哪些算子放到哪张GPU上
- 张量如何切分(按行、按列、按批次)
- 通信操作插入何处(All-Reduce、All-Gather等)
目标是在不改变或少改变用户代码的前提下,实现高效的多卡分布式训练。用户只需定义模型和前向逻辑,框架会处理剩余的一切。
为什么需要自动并行化?
手动分配并行策略的痛点:
- 学习成本高:前置知识包括数据并行、模型并行、流水线并行、张量并行等,组合使用更难掌握。
- 代码侵入性强:需要大量改写模型代码,例如插入
.to(device)、切分张量、手动同步等。 - 策略依赖模型结构:不同模型的最优切分方案差异巨大,很难通过通用经验一次性配置好。
- 动态硬件适配:当GPU数量或拓扑变化时,手动策略往往需要重新设计和调整。
自动并行化将工程师从琐碎的切分细节中解放出来,让框架根据某种代价模型搜索一个较优的分布式执行计划,显著提升开发效率。
自动并行化如何工作?
大多数自动并行系统遵循以下流程:
-
计算图构建
框架捕获用户定义的前向和反向计算图,通常以算子(Operation)为节点,张量(Tensor)为边。 -
设备网格定义
用户声明可用的设备拓扑,例如[2, 4]表示 2 个节点、每个节点 4 张卡,总计 8 张卡。 -
切分布局探索
框架为每个张量搜索合适的 Sharding Spec(分片规格),例如:R:张量在该维度上复制(不切分)S(0):在第 0 维切分S(1):在第 1 维切分 同时为每个算子选择在哪些设备上执行。
-
代价模型与优化
基于通信量、计算量、内存占用等构建代价函数,利用动态规划、整数线性规划或启发式算法搜索一个总代价最小的并行策略。 -
可执行图生成
将找到的最优策略翻译成具体执行图,自动插入通信算子(如 All-Reduce、All-Gather、Reduce-Scatter)和重分片操作(Resharding),并生成可在多卡上运行的代码。
哪些框架支持自动并行化?
目前主流框架中,自动并行化的支持程度有所不同:
-
PyTorch 2.x +
torch.distributed
原生主要支持数据并行 (DistributedDataParallel) 和完全分片数据并行 (FullyShardedDataParallel),这些有一定自动性,但仍需用户明确调用。真正的自动模型并行需要借助外部编译器或库。 -
PyTorch +
torch.distributed.tensor(DTensor) + 自动并行 API
torch.distributed.tensor提供了张量级别的分片表示,配合DeviceMesh和自动并行探索工具,可以实现一定程度上的自动并行。但完整的自动并行尚在发展中。 -
Google JAX 与
pjit/shard_map
JAX 在函数式编程范式下,通过pjit或自动分片编译器(GSPMD)提供高度自动化的并行能力。用户只需声明输入/输出的分片约束,编译器会自动推导计算图内的分片。这是自动并行理念的典型代表。 -
OneFlow
早期提出 Consistent Tensor(一致性视角)和 SBP(Split, Broadcast, Partial)签名,实现自动推导并行策略。 -
MindSpore
提供自动并行功能,被称为“全自动并行”,用户只需设置set_auto_parallel_context,框架自动搜索策略。 -
阿里PAI的 TorchAcc
基于 Torch 的编译加速库,具备自动并行搜索功能。
示例:JAX的自动并行化思想
以JAX为例,展示自动并行化的典型使用。我们定义一个简单的矩阵乘法函数,并让编译器自动分发到8张卡。
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec, NamedSharding
from jax.experimental import mesh_utils
# 创建 2x4 的设备网格
devices = mesh_utils.create_device_mesh((2, 4))
mesh = Mesh(devices, ('data', 'model'))
# 定义计算逻辑
def f(x, w):
return jnp.dot(x, w)
# 声明输入张量的分片方式:
# x 在 data 维切分,w 在 model 维切分(矩阵乘法的列切分)
x_sharding = NamedSharding(mesh, PartitionSpec('data', None))
w_sharding = NamedSharding(mesh, PartitionSpec(None, 'model'))
# 创建分片张量
x = jax.device_put(jnp.ones((1024, 512)), x_sharding)
w = jax.device_put(jnp.ones((512, 256)), w_sharding)
# 使用 pjit 自动推导并行执行计划
result = jax.jit(f, in_shardings=(x_sharding, w_sharding),
out_shardings=NamedSharding(mesh, PartitionSpec('data', 'model')))(x, w)
在这个过程中,用户仅指定了输入输出的分片期望,JAX的编译器(GSPMD/XLA)自动推导出 dot 操作如何切分、何时插入 All-Reduce 等。这就是典型的自动并行化。
PyTorch中的自动并行化探索
PyTorch 也在逐步向自动并行靠拢。可使用 torch.distributed.tensor 和 DeviceMesh 进行实验:
import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import distribute_tensor, Shard, Replicate
mesh = init_device_mesh("cuda", (2, 4), mesh_dim_names=("dp", "mp"))
# 创建一个张量并按策略分片
tensor = torch.randn(8, 128)
dtensor = distribute_tensor(tensor, mesh, [Shard(0)]) # 自行指定切分
# 自动并行API (实验性)
from torch.distributed.tensor.parallel import loss_parallel
from torch.utils._pytree import tree_map
# 需要结合 torch.compile 和自动并行策略探索,目前还处于原型阶段
自动并行在 PyTorch 生态中更常见于 torch.compile 结合 DTensor 以及 torch.distributed._composable.fsdp 混用。不过完全无需用户指明切分策略的“全自动”版本,还需要等待生态成熟。
自动并行的优势与局限性
优势:
- 极低的使用门槛:甚至不超过十行代码就能实现复杂混合并行。
- 策略通用性:一次搜索,适配多种模型结构。
- 自动负载均衡:代价模型考虑通信与计算,避免人为偏执。
- 可随硬件弹性伸缩:换不同卡数时重新搜索即可。
局限性:
- 搜索耗时:复杂大模型的策略空间巨大,离线搜索可能花费数小时甚至更长。
- 不一定是最优解:自动算法找到的是近似最优,某些情况下需要手调以获得极致性能。
- 调试困难:出错时看到的执行图可能与用户原始代码差距很大,不易排查。
- 依赖框架成熟度:部分框架实现尚不完善,可能遇到算子不支持等瓶颈。
如何开始使用自动并行化?
- 选择支持较好的框架:若可接受函数式编程,JAX 是自动并行的先行者;若必须使用 PyTorch,可以利用 FSDP +
torch.compile逐步尝试。 - 理解设备网格概念:自动并行最基本的输入是设备拓扑,学会用
Mesh描述你的计算资源。 - 从简单的线性模型开始:先让框架自动探索一个小型 MLP 或 ResNet 的策略,感受通信与计算的切分。
- 学会查看和分析搜索到的策略:多数框架会输出策略的文本描述或可视化,通过它学习并行规律。
- 逐步增加模型复杂度:当信任度建立后,再迁移到更大规模的模型,并对比手动优化的性能。
总结
自动并行化代表分布式训练工具链的一次进化,它将复杂的切分决策交给编译器,让算法研究员和普通工程师可以将注意力集中在模型本身。虽然它还处在快速发展阶段,但已经成为下一代训练框架的必备能力。了解其原理和典型用法,将使你在面对大规模模型训练时多一个强有力的选择。