GSPMD:统一描述设备切分与通信的编译器抽象
GSPMD:统一描述设备切分与通信的编译器抽象
1. 为什么需要 GSPMD?
在训练大型机器学习模型(如 GPT‑4、PaLM)时,单个加速器(GPU/TPU)无法容纳完整的模型参数或训练数据,必须使用分布式并行训练。常见的分布式策略包括:
- 数据并行:每个设备拥有模型完全副本,处理不同数据批次,然后同步梯度。
- 模型并行:将模型不同层放在不同设备上,流水线式执行。
- 张量并行:将单个算子(例如矩阵乘法)的计算切分到多个设备上,每个设备保留部分张量。
这些策略通常由用户手动指定,并且每种策略的通信模式(AllReduce、AllGather、CollectivePermute)和切分方式各不相同。随着模型规模增长,混合使用多种并行策略变得不可避免,编写高效且正确的分布式代码变得极其复杂。
GSPMD(General and Scalable Parallelization for ML Computation Graphs)是 Google 提出的一种编译器抽象,它提供统一的方式描述任意计算图中张量的设备切分方式,然后由编译器自动推导出所需的通信算子和设备间数据搬运,用户只需要对关键张量标注“希望如何切分”,无需手写通信代码。
2. GSPMD 的核心思想
GSPMD 的核心是基于分片注解的编译转换。它将分布式问题分解为两个独立的阶段:
- 用户标注分片方式:告诉编译器程序中的某些张量应该按照什么方式切分到不同的设备上(例如沿某个维度切分或者复制)。
- 编译器自动推导通信:根据算子的语义和输入输出的分片方式,自动插入集合通信操作(如 AllReduce、AllGather)或设备间数据搬运,使程序变成合法的多设备执行图。
这种分离让用户只需关心“数据怎么分片”,而不必手动安排“如何通信”。GSPMD 同样是一种中间表示(IR),它能够作为高层并行策略和底层编译器优化的桥梁。
3. GSPMD 的分片标注
GSPMD 使用 sharding 注解来描述张量在设备网格上的分布。一个设备网格是一个多维逻辑设备阵列,例如 4×2 的 TPU 网格。
3.1 基本分片类型
对于张量的每个维度,用户可以指定以下三种分片方式之一:
- 复制(replicated):该维度不切分,所有设备持有相同的数据。
- 分片(sharded):沿该维度将张量切分成若干块,每块放置到一个设备上。需要指定该维度映射到设备网格的哪个维度。
- 未指定:交由编译器自动决定。
例如,对于一个形状为 [B, L, D] 的张量,若标注为:
B维度:复制L维度:映射到设备网格的第 0 维D维度:映射到设备网格的第 1 维
则表示该张量将按 L 和 D 两个维度被切分成一个二维设备网格上的块,而批次维度则被复制。
3.2 整张量的复制与全分片
- 全复制:所有维度均复制,即传统数据并行。
- 全分片:如 ZeRO‑3 或者张量并行的极端情况。
4. GSPMD 编译工作原理
编译过程分为几个关键步骤:
- 标注传播:用户只需在少数关键张量(如模型参数、输入数据)上标记分片。编译器通过前向传播和后向传播,尝试为图中其余每个张量推导出合法的分片标注。
- 冲突解决:当同一个张量从不同路径获得不一致的分片标注时,编译器会插入必要的**重分片(resharding)**操作,这些重分片就是集合通信(AllGather、AllReduce、All‑to‑All 等)或者简单的本地拼接/切分。
- 算子转换:对于每个算子,GSPMD 根据其输入和输出的分片投影,确定该算子在每个设备上应该执行的部分计算。部分算子(如 ReLU)天然支持分片执行,部分算子(如 Softmax)可能需要全局规约。
- 生成多设备代码:最终输出一个在设备网格上可执行的 HLO/MLIR 图,每个设备拥有自己的局部计算图,通信操作显式插入。
以矩阵乘法 C = A × B 为例:
- 若
A按行分片(切分第一个维度),B被复制,则C也自然按行分片,不需要通信。 - 若
A按列分片(切分第二个维度),B也按对应维度分片,那么每个设备只能计算部分乘积累,需要后续 AllReduce 得到完整C。 - GSPMD 根据标注自动决定采用哪种方案,并在需要的地方插入 AllReduce。
5. 一个完整的示例
假设我们想训练一个 Transformer,使用 2×2 设备网格。我们只在输入嵌入矩阵 W_emb 上标注:沿着词表维度(第 1 维)映射到设备网格第 0 维,沿着嵌入维度(第 0 维)映射到设备网格第 1 维。
# 伪代码标注
sharding(W_emb) = (('mesh_dim_0', 'mesh_dim_1'),)
编译器将自动推导出:
- 嵌入查找操作后,输出的分词表示将沿着序列长度维度被复制(或者分片,取决于后续)。
- 注意力矩阵乘法
Q = X × W_Q中,根据W_Q的分片(可能由W_emb传播得出)自动插入 AllGather 或 AllReduce。 - 前馈网络层同理。
最终用户无需关心每一层如何划分,只要顶层规范,即可获得混合了数据并行、张量并行甚至流水并行的多设备程序。
6. GSPMD 的优势
- 简化编程:用户用少量注解替代大量手动通信代码。
- 灵活组合:同一个程序中可以无缝混合多种并行策略,GSPMD 自动处理边界。
- 可复用优化:编译器可以进行全局重分片优化,减少冗余通信。
- 硬件无关:分片标注与设备网格抽象不依赖具体硬件拓扑,可迁移。
7. GSPMD 在现有框架中的应用
GSPMD 已经成为多个机器学习编译器的核心组件:
- JAX:通过
pmap或自动并行化的pjit/shard_map使用 GSPMD 思想。 - TensorFlow:
DTensor是 GSPMD 的 TensorFlow 实现,提供tf.experimental.dtensor接口。 - PyTorch:类似于
DTensor的torch.distributed.tensor也借鉴了 GSPMD 的统一分片抽象。
8. 总结
GSPMD 提供了一种统一描述设备切分与通信的编译器抽象,将复杂的分布式并行问题简化为张量分片注解和编译推导。通过学习 GSPMD,初学者可以理解现代分布式训练框架为什么能够自动处理数据并行、模型并行和张量并行的混合策略——内核正是这种将“如何切分数据”与“如何通信”解耦的设计。
如果你想进一步动手实践,推荐从 JAX 的 pjit 或 TensorFlow 的 DTensor 教程开始,体验只用几行注解就能在多个 TPU/GPU 上运行大型模型。