张量并行:切分 Transformer 层的分布式训练
什么是张量并行
张量并行(Tensor Parallelism,TP)是一种模型并行策略,专门用于将单个 Transformer 层 内部的参数张量切分到多个计算设备(如 GPU)上。与数据并行复制整个模型不同,张量并行直接对权重矩阵进行分片,让多个设备协同计算同一层的前向与反向传播,从而突破显存限制,训练更大的模型。
在分布式训练中,当单卡无法容纳一个完整的层时,张量并行是首选方案。它属于模型并行的一种细粒度形式,通常与流水线并行、数据并行结合使用,形成 3D 并行策略。
为什么需要张量并行
显存墙问题
随着模型参数规模增长,单张 GPU 显存远远不够。以一个简单的线性层 Y = XA 为例,如果矩阵 A 的大小为 4096×4096(FP16 约占 32MB),这看起来不大。但在 Transformer 中,一个层包含多头注意力的 Q、K、V、O 矩阵,以及前馈网络的 up、gate、down 矩阵,加上激活值和优化器状态,总显存会急剧膨胀。
数据并行的局限
纯数据并行要求每个设备都持有完整模型副本。当模型本身大小超过单卡显存时,数据并行无法直接使用。虽然可以通过梯度累积或 ZeRO 优化器部分缓解,但对于超大层(如嵌入层或单个注意力头的权重),仍需模型并行层面的切分。
计算效率
相比于流水线并行,张量并行在同一层内引入通信,但可以将计算和通信重叠,并利用高速的 NVLink 或 InfiniBand 降低延迟。对于 Transformer 结构,某些矩阵乘法天然适合切分,能实现计算负载均衡。
张量并行的核心思想
张量并行的本质是 将大型矩阵乘法分布到多个处理器上。一个全连接层的前向计算 Z = XW,输入 X 大小为 [B, S, H](批次×序列长度×隐藏维度),权重 W 为 [H, 4H]。切分方式有两种基本形式:行切分 和 列切分。
列切分(Column-wise Sharding)
将权重矩阵 W 按列分成多块,例如分成两份 W1 和 W2,每块大小 [H, 2H]。输入 X 是完整副本,分别计算 Z1 = X W1 和 Z2 = X W2,然后将结果在列维度拼接:
Z = [Z1, Z2]
这种方式在前向时无需通信,但需要 X 在每个设备上都有副本。在反向传播时,梯度需要在列方向拼接后传递给前一层。
行切分(Row-wise Sharding)
如果后续再接一个线性层 Y = ZV,其中 V 的大小是 [4H, H]。此时可以采用行切分:将 V 按行分成 V1、V2,每块大小 [2H, H]。前一层列切分的输出 Z1、Z2 直接与本地的 V1、V2 相乘,得到 Y1 = Z1 V1 和 Y2 = Z2 V2,然后使用 All-Reduce 求和得到最终输出:
Y = Y1 + Y2
这种前后配合的方式,使得在 Transformer MLP 或注意力模块中,只需在行切分层后进行一次 All-Reduce,列切分层无需通信。只需要在前向传递中,对输入 X 保持相同(通常是复制),就可以形成高效的并行模式。
Transformer 层的张量并行切分
MLP 块切分
典型的 Transformer MLP 包含两个全连接层(以 GPT 风格为例,gate 和 up 投影后逐元素相乘,再 down 投影)。我们以简单线性 fc1: H -> 4H,fc2: 4H -> H 为例:
- fc1 采用列切分:权重
W1按列分成 2 份,每份[H, 2H],输入X复制给两个设备,分别计算[B, S, 2H]的输出。 - 激活函数:对本地输出应用 GELU 等激活,无需通信。
- fc2 采用行切分:权重
W2按行分成 2 份,每份[2H, H],与本地输出相乘得到Y_i,形状[B, S, H]。 - All-Reduce 求和:所有设备的
Y_i相加得到最终 MLP 输出。
前向过程中只在 fc2 后进行一次 All-Reduce;反向传播时,对应梯度流会反向执行 Reduce-Scatter。
自注意力块切分
注意力机制的 Q、K、V 投影矩阵非常适合列切分,输出投影矩阵适合行切分。
- QKV 投影:将权重
W_Q、W_K、W_V按列切分到每个设备。例如头的总数为num_heads,可以按头数均匀切分,每个设备负责一部分头(即每个设备持有若干完整头)。或者按隐藏维度列切分。 - 注意力计算:每个设备使用本地的
Q、K、V执行对应的缩放点积注意力,得到本地的头输出。 - 输出投影:将
W_O按行切分,各设备计算本地输出Y_i,最后 All-Reduce 得到注意力模块最终输出。
这种方式自然地将多头注意力分布在不同设备上,每个设备负责一部分头,减少了通信次数。
输入与输出层
- 嵌入层:通常采用行切分(词汇表维度切分),每个设备存储一部分嵌入向量,然后通过 All-Gather 获取完整嵌入,或者使用 All-Reduce 分发。对于共享嵌入权重(输入与输出相同),需注意保持一致性。
- 输出层:若词汇表很大,可将输出投影矩阵按列切分,计算 logits 后再通信;或使用交叉熵的分布式计算,如通过 All-Gather 获取完整 logits,但效率较低,实践中常设计专门的词汇表并行。
张量并行的通信模式
张量并行的核心通信算子包括 All-Reduce 和 All-Gather,在 Transformer 中常见两种模式:
模式一:列切分 + 行切分 + All-Reduce
即前文所述,前一层列切分输出无需通信,后一层行切分之后执行 All-Reduce 求和。这种模式将通信融合在层的末尾,有利于流水优化。
模式二:仅行切分并 All-Reduce
若某层无法前接列切分(如网络入口),可以对输入 X 进行行切分,即每个设备持有 X 的一部分,权重 W 完全复制。计算 Y_i = X_i W 后执行 All-Gather 得到完整结果。但这种方式的通信量更大,不常用。
前向与反向的通信对称
- 前向 All-Reduce 对应反向中的 Identity(梯度复制)。
- 实际训练中,反向传播会自动推导通信。对于行切分+All-Reduce,反向会变成 Reduce-Scatter。
实现细节与代码示例
以 PyTorch 为例,使用 torch.distributed 和简单的通信原语模拟 2 卡张量并行的 MLP:
import torch
import torch.nn as nn
import torch.distributed as dist
class ColumnParallelLinear(nn.Module):
def __init__(self, in_features, out_features, world_size, gather_output=False):
super().__init__()
self.world_size = world_size
self.gather_output = gather_output
self.weight = nn.Parameter(torch.randn(in_features, out_features // world_size))
def forward(self, x):
y = x @ self.weight
if self.gather_output:
# 如果需要完整输出,All-Gather 拼接
y_list = [torch.empty_like(y) for _ in range(self.world_size)]
dist.all_gather(y_list, y)
y = torch.cat(y_list, dim=-1)
return y
class RowParallelLinear(nn.Module):
def __init__(self, in_features, out_features, world_size):
super().__init__()
self.world_size = world_size
self.weight = nn.Parameter(torch.randn(in_features // world_size, out_features))
def forward(self, x):
# x 已经是分片后的本地结果
y = x @ self.weight
dist.all_reduce(y) # 求和
return y
# 使用示例(2 卡)
world_size = 2
# 假设输入 x 完整复制到每个设备,形状 [B, S, H]
x = torch.randn(2, 512, 1024) # 示例
fc1 = ColumnParallelLinear(1024, 4096, world_size)
activation = nn.GELU()
fc2 = RowParallelLinear(4096, 1024, world_size)
h = fc1(x) # 本地 [B, S, 2048]
h = activation(h)
out = fc2(h) # 本地 [B, S, 1024] -> all_reduce 得到完整输出
在真实框架(如 Megatron-LM、DeepSpeed、PyTorch 的 FSDP 配合 TP)中,这些通信和分片细节已被高度封装,但原理不变。
张量并行的优势与挑战
优势
- 显存分担:每卡只需存储一份权重分片,激活值也可根据分片策略减少(如配合序列并行)。
- 支持超大层:即使单个层超过单卡显存也能训练。
- 计算效率高:在高速互联下,通信可隐藏在计算之后,整体吞吐接近线性扩展。
挑战
- 通信开销:每层都需要 All-Reduce,对网络带宽和延迟敏感。通常在节点内利用 NVLink(高速),跨节点则尽量用流水线并行减少张量并行跨节点。
- 负载均衡:需要谨慎分割维度,保证各设备计算量均等,避免闲置。
- 实现复杂:需要修改模型代码,处理层归一化、残差连接等细节的放置。
- 批次大小限制:张量并行通常与大的全局批次结合,可能不适合小批量场景。
与其他并行策略的结合
与数据并行结合
数据并行复制整个张量并行组。每个 TP 组内进行张量切分,组间进行数据并行。梯度在 TP 组内局部 All-Reduce 后,再跨 DP 组进行全局 All-Reduce。Megatron-LM 的 PTD-P(Pipeline-Tensor-Data Parallel)即此模式。
与流水线并行结合
将模型按层切分到多个 TP 组,每个组负责连续的几层,组间通过流水线传递中间激活。这样可减少跨节点通信,将张量并行限制在节点内。
与 ZeRO 的关系
ZeRO 优化器的 Stage-3 同样对参数进行分片,但其分片维度是参数,而非计算中的矩阵乘法维度。ZeRO 分片在优化器状态和梯度层面,计算时依然使用完整参数(通过集合通信重建)。张量并行在计算时就保持分片,两者可以正交组合。
总结
张量并行通过切分单个 Transformer 层内的权重矩阵,实现大模型分布式训练的内存和计算均衡。核心是列切分与行切分的组合,配合 All-Reduce 实现层内通信。理解了 MLP 和自注意力的切分模式,就能很好地掌握张量并行的设计。在实践中,张量并行常与数据并行、流水线并行协同,形成完整的 3D 并行方案,是当今千亿、万亿参数模型训练的基础技术之一。