张量并行:切分 Transformer 层的分布式训练

FreeGuideOnline 最新 2026-06-14

什么是张量并行

张量并行(Tensor Parallelism,TP)是一种模型并行策略,专门用于将单个 Transformer 层 内部的参数张量切分到多个计算设备(如 GPU)上。与数据并行复制整个模型不同,张量并行直接对权重矩阵进行分片,让多个设备协同计算同一层的前向与反向传播,从而突破显存限制,训练更大的模型。

在分布式训练中,当单卡无法容纳一个完整的层时,张量并行是首选方案。它属于模型并行的一种细粒度形式,通常与流水线并行、数据并行结合使用,形成 3D 并行策略。


为什么需要张量并行

显存墙问题

随着模型参数规模增长,单张 GPU 显存远远不够。以一个简单的线性层 Y = XA 为例,如果矩阵 A 的大小为 4096×4096(FP16 约占 32MB),这看起来不大。但在 Transformer 中,一个层包含多头注意力的 Q、K、V、O 矩阵,以及前馈网络的 up、gate、down 矩阵,加上激活值和优化器状态,总显存会急剧膨胀。

数据并行的局限

纯数据并行要求每个设备都持有完整模型副本。当模型本身大小超过单卡显存时,数据并行无法直接使用。虽然可以通过梯度累积或 ZeRO 优化器部分缓解,但对于超大层(如嵌入层或单个注意力头的权重),仍需模型并行层面的切分。

计算效率

相比于流水线并行,张量并行在同一层内引入通信,但可以将计算和通信重叠,并利用高速的 NVLink 或 InfiniBand 降低延迟。对于 Transformer 结构,某些矩阵乘法天然适合切分,能实现计算负载均衡。


张量并行的核心思想

张量并行的本质是 将大型矩阵乘法分布到多个处理器上。一个全连接层的前向计算 Z = XW,输入 X 大小为 [B, S, H](批次×序列长度×隐藏维度),权重 W[H, 4H]。切分方式有两种基本形式:行切分列切分

列切分(Column-wise Sharding)

将权重矩阵 W 按列分成多块,例如分成两份 W1W2,每块大小 [H, 2H]。输入 X 是完整副本,分别计算 Z1 = X W1Z2 = X W2,然后将结果在列维度拼接:

Z = [Z1, Z2]

这种方式在前向时无需通信,但需要 X 在每个设备上都有副本。在反向传播时,梯度需要在列方向拼接后传递给前一层。

行切分(Row-wise Sharding)

如果后续再接一个线性层 Y = ZV,其中 V 的大小是 [4H, H]。此时可以采用行切分:将 V 按行分成 V1V2,每块大小 [2H, H]。前一层列切分的输出 Z1Z2 直接与本地的 V1V2 相乘,得到 Y1 = Z1 V1Y2 = Z2 V2,然后使用 All-Reduce 求和得到最终输出:

Y = Y1 + Y2

这种前后配合的方式,使得在 Transformer MLP 或注意力模块中,只需在行切分层后进行一次 All-Reduce,列切分层无需通信。只需要在前向传递中,对输入 X 保持相同(通常是复制),就可以形成高效的并行模式。


Transformer 层的张量并行切分

MLP 块切分

典型的 Transformer MLP 包含两个全连接层(以 GPT 风格为例,gate 和 up 投影后逐元素相乘,再 down 投影)。我们以简单线性 fc1: H -> 4Hfc2: 4H -> H 为例:

  1. fc1 采用列切分:权重 W1 按列分成 2 份,每份 [H, 2H],输入 X 复制给两个设备,分别计算 [B, S, 2H] 的输出。
  2. 激活函数:对本地输出应用 GELU 等激活,无需通信。
  3. fc2 采用行切分:权重 W2 按行分成 2 份,每份 [2H, H],与本地输出相乘得到 Y_i,形状 [B, S, H]
  4. All-Reduce 求和:所有设备的 Y_i 相加得到最终 MLP 输出。

前向过程中只在 fc2 后进行一次 All-Reduce;反向传播时,对应梯度流会反向执行 Reduce-Scatter。

自注意力块切分

注意力机制的 Q、K、V 投影矩阵非常适合列切分,输出投影矩阵适合行切分。

  • QKV 投影:将权重 W_QW_KW_V 按列切分到每个设备。例如头的总数为 num_heads,可以按头数均匀切分,每个设备负责一部分头(即每个设备持有若干完整头)。或者按隐藏维度列切分。
  • 注意力计算:每个设备使用本地的 QKV 执行对应的缩放点积注意力,得到本地的头输出。
  • 输出投影:将 W_O 按行切分,各设备计算本地输出 Y_i,最后 All-Reduce 得到注意力模块最终输出。

这种方式自然地将多头注意力分布在不同设备上,每个设备负责一部分头,减少了通信次数。

输入与输出层

  • 嵌入层:通常采用行切分(词汇表维度切分),每个设备存储一部分嵌入向量,然后通过 All-Gather 获取完整嵌入,或者使用 All-Reduce 分发。对于共享嵌入权重(输入与输出相同),需注意保持一致性。
  • 输出层:若词汇表很大,可将输出投影矩阵按列切分,计算 logits 后再通信;或使用交叉熵的分布式计算,如通过 All-Gather 获取完整 logits,但效率较低,实践中常设计专门的词汇表并行。

张量并行的通信模式

张量并行的核心通信算子包括 All-ReduceAll-Gather,在 Transformer 中常见两种模式:

模式一:列切分 + 行切分 + All-Reduce

即前文所述,前一层列切分输出无需通信,后一层行切分之后执行 All-Reduce 求和。这种模式将通信融合在层的末尾,有利于流水优化。

模式二:仅行切分并 All-Reduce

若某层无法前接列切分(如网络入口),可以对输入 X 进行行切分,即每个设备持有 X 的一部分,权重 W 完全复制。计算 Y_i = X_i W 后执行 All-Gather 得到完整结果。但这种方式的通信量更大,不常用。

前向与反向的通信对称

  • 前向 All-Reduce 对应反向中的 Identity(梯度复制)。
  • 实际训练中,反向传播会自动推导通信。对于行切分+All-Reduce,反向会变成 Reduce-Scatter。

实现细节与代码示例

以 PyTorch 为例,使用 torch.distributed 和简单的通信原语模拟 2 卡张量并行的 MLP:

import torch
import torch.nn as nn
import torch.distributed as dist

class ColumnParallelLinear(nn.Module):
    def __init__(self, in_features, out_features, world_size, gather_output=False):
        super().__init__()
        self.world_size = world_size
        self.gather_output = gather_output
        self.weight = nn.Parameter(torch.randn(in_features, out_features // world_size))

    def forward(self, x):
        y = x @ self.weight
        if self.gather_output:
            # 如果需要完整输出,All-Gather 拼接
            y_list = [torch.empty_like(y) for _ in range(self.world_size)]
            dist.all_gather(y_list, y)
            y = torch.cat(y_list, dim=-1)
        return y

class RowParallelLinear(nn.Module):
    def __init__(self, in_features, out_features, world_size):
        super().__init__()
        self.world_size = world_size
        self.weight = nn.Parameter(torch.randn(in_features // world_size, out_features))

    def forward(self, x):
        # x 已经是分片后的本地结果
        y = x @ self.weight
        dist.all_reduce(y)  # 求和
        return y

# 使用示例(2 卡)
world_size = 2
# 假设输入 x 完整复制到每个设备,形状 [B, S, H]
x = torch.randn(2, 512, 1024)  # 示例
fc1 = ColumnParallelLinear(1024, 4096, world_size)
activation = nn.GELU()
fc2 = RowParallelLinear(4096, 1024, world_size)

h = fc1(x)          # 本地 [B, S, 2048]
h = activation(h)
out = fc2(h)        # 本地 [B, S, 1024] -> all_reduce 得到完整输出

在真实框架(如 Megatron-LM、DeepSpeed、PyTorch 的 FSDP 配合 TP)中,这些通信和分片细节已被高度封装,但原理不变。


张量并行的优势与挑战

优势

  • 显存分担:每卡只需存储一份权重分片,激活值也可根据分片策略减少(如配合序列并行)。
  • 支持超大层:即使单个层超过单卡显存也能训练。
  • 计算效率高:在高速互联下,通信可隐藏在计算之后,整体吞吐接近线性扩展。

挑战

  • 通信开销:每层都需要 All-Reduce,对网络带宽和延迟敏感。通常在节点内利用 NVLink(高速),跨节点则尽量用流水线并行减少张量并行跨节点。
  • 负载均衡:需要谨慎分割维度,保证各设备计算量均等,避免闲置。
  • 实现复杂:需要修改模型代码,处理层归一化、残差连接等细节的放置。
  • 批次大小限制:张量并行通常与大的全局批次结合,可能不适合小批量场景。

与其他并行策略的结合

与数据并行结合

数据并行复制整个张量并行组。每个 TP 组内进行张量切分,组间进行数据并行。梯度在 TP 组内局部 All-Reduce 后,再跨 DP 组进行全局 All-Reduce。Megatron-LM 的 PTD-P(Pipeline-Tensor-Data Parallel)即此模式。

与流水线并行结合

将模型按层切分到多个 TP 组,每个组负责连续的几层,组间通过流水线传递中间激活。这样可减少跨节点通信,将张量并行限制在节点内。

与 ZeRO 的关系

ZeRO 优化器的 Stage-3 同样对参数进行分片,但其分片维度是参数,而非计算中的矩阵乘法维度。ZeRO 分片在优化器状态和梯度层面,计算时依然使用完整参数(通过集合通信重建)。张量并行在计算时就保持分片,两者可以正交组合。


总结

张量并行通过切分单个 Transformer 层内的权重矩阵,实现大模型分布式训练的内存和计算均衡。核心是列切分与行切分的组合,配合 All-Reduce 实现层内通信。理解了 MLP 和自注意力的切分模式,就能很好地掌握张量并行的设计。在实践中,张量并行常与数据并行、流水线并行协同,形成完整的 3D 并行方案,是当今千亿、万亿参数模型训练的基础技术之一。