序列并行:长上下文训练时的维度切分
什么是序列并行
序列并行(Sequence Parallelism,SP)是一种专门为长上下文训练设计的分布式训练技术。在Transformer模型训练中,当输入序列长度变得极大(例如数万甚至数十万个token)时,单张GPU的显存已无法容纳整个激活张量。序列并行的核心思想是沿着序列维度(token 维度)对激活、计算和存储进行切分,将长序列分割成多个子序列,分配到不同的设备上并行处理,从而支持无限长的上下文训练。
与数据并行、张量并行、流水线并行不同,序列并行直接解决的是激活显存随序列长度平方增长的自注意力瓶颈,而非单纯切分模型参数。
为什么需要序列并行
自注意力的显存陷阱
标准的多头自注意力(Multi-Head Self-Attention)所需显存可概括为:
- 查询(Q)、键(K)、值(V)矩阵:形状为
[batch, seq_len, hidden_dim]→ 显存为 O(seq_len) - 注意力分数矩阵:
Q × K^T→ 形状[batch, num_heads, seq_len, seq_len]→ 显存为 O(seq_len²) - 注意力输出:与V的乘积,同样涉及
seq_len²的中间张量
当 seq_len 从 2k 增至 128k 时,注意力矩阵的显存占用会增长数千倍,成为训练长上下文模型的绝对瓶颈。单纯增加 GPU 数量无法解决这个问题,因为每张卡仍需要完整的 seq_len 维度才能计算局部注意力。
传统并行策略的局限
常见并行策略对长序列的支持能力有限:
| 并行策略 | 切分维度 | 能否解决长序列激活瓶颈 | 主要问题 |
|---|---|---|---|
| 数据并行 | 批次(batch) | 否 | 每张卡仍需完整序列长度 |
| 张量并行 | 隐藏层/注意力头 | 部分 | 只切分参数和计算,序列维度保留 |
| 流水线并行 | 层 | 否 | 微批次内仍为完整序列 |
| 序列并行 | 序列(seq) | 是 | 专为长序列设计,通信模式需精心设计 |
序列并行可与上述所有并行策略正交叠加,构建出支持超长上下文训练的复合并行方案。
序列并行的核心设计理念
序列并行需要解决两个核心问题:
- 如何切分长序列:将输入
[B, S, D]沿序列维度拆分为[B, S/N, D]并分配到 N 个设备上。 - 如何在不持有完整序列的情况下计算自注意力:因为注意力需要看到所有 token,必须设计跨设备的通信机制来完成全局注意力的近似或精确计算。
根据注意力计算方式的不同,序列并行大致分为两类:
- 稀疏化方法:迫使每个 token 只关注局部或特定模式的 token,自然适配序列切分(如 Local Attention、Sparse Attention)。
- 全注意力方法:通过环形通信(Ring Attention)、分块重计算或异步 All-to-All 等方式,在严格保证数学等价的同时实现序列拆分。
本教程重点介绍工业界广泛采用的全注意力序列并行。
经典实现一:Ring Attention(环形自注意力)
Ring Attention 是一种优雅的序列并行方案,它利用环形通信在设备间传递分块的 Key 和 Value,从而在无全局同步开销的情况下完成近似的全局注意力计算(数学上是精确的,但分块次序不同)。
工作流程
假设有 N 个设备,序列被均匀切成 N 块。每个设备持有完整的 Q 和局部的 K、V 块。
- 初始化:各设备用本地 Q 块与本地 K、V 块计算部分注意力,得到部分输出。同时将本地的 K、V 块发送给下一个设备。
- 循环通信 - 计算:重复 N-1 次以下步骤:
- 接收上一个设备传来的 K、V 块;
- 用本地 Q 块与收到的 K、V 块计算额外的部分注意力,累加到输出中;
- 将收到的 K、V 块传递给下一个设备。
- 经过 N 步后,每个设备的 Q 块都“看遍”了所有 K、V 块,完成了全序列注意力计算。输出结果在本地即可获得,无需最终汇聚。
这一机制实现了计算与通信的完美重叠:在等待下一个 K、V 块传输的同时,可以开始当前步的计算。
显存优势
Ring Attention 将每个设备上的注意力矩阵显存从 S² 降至 S×(S/N) = S²/N,序列维度上的显存随设备数线性下降,从而使训练百万级别 token 的上下文成为可能。
经典实现二:Megatron-SP(序列并行 + 张量并行融合)
NVIDIA Megatron-LM 框架中的序列并行提供了一种更轻量级的方案,专为 Transformer 结构中的 Dropout 和 LayerNorm 等非注意力部分设计,并常与张量并行结合使用,以减少冗余激活。
机制
不直接切分自注意力的序列维度,而是将 Transformer 块中与序列维度无关但消耗大量显存的操作进行并行化:
- LayerNorm:原本在每个设备上完整计算
[B, S, D],序列并行下可将序列维度切分,各设备计算[B, S/N, D]的统计量,然后通过 All-Gather 得到完整结果(或采用分布式 LayerNorm 直接输出分块)。 - Dropout:各设备独立对本地序列块执行 dropout,无需通信。
- 残差连接中的加法:输入被切分,加法天然本地执行。
在 Megatron-SP 中,真正的注意力计算仍依赖于张量并行(切分注意力头),而序列并行主要用于减轻 LayerNorm 和 Dropout 的激活显存压力。该方案的优势在于通信量小,实现简单,适合与张量并行深度集成。
通信与计算的权衡
序列并行的不同实现对应不同的通信模式:
| 方案 | 通信操作 | 通信量(相对) | 同步模式 |
|---|---|---|---|
| Ring Attention | P2P发送/接收 K、V 块 | O(B × S × D × N) 总传输量 | 管线式环形传递 |
| Megatron-SP (LayerNorm前向) | All-Gather 用于恢复序列 | O(B × S × D) | 全局同步 |
| 序列并行的 all-to-all 方案 | All-to-All 重新分发 QKV 块 | O(B × S × D) | 全局同步 |
Ring Attention 通过环形 P2P 通信将全局同步的瓶颈打散,更适合极端序列长度下的弹性扩展;而基于 All-Gather 的方案则代码侵入性小,更容易在现有框架上实现。
与其他并行策略的协同
序列并行是训练超长上下文大模型的基石,但它通常不会单独使用。典型的组合方式如下:
例子:训练一个 128K 上下文长度的 Transformer 模型
- 第一层并行:序列并行(切分序列),将 128K 序列分到 8 张 GPU 上,每卡只有 16K token。
- 第二层并行:张量并行(切分注意力头和 FFN 矩阵),进一步降低单卡模型参数和计算量。
- 第三层并行:流水线并行(切分层),将 Transformer 层分组分配到不同节点。
- 外层加速:数据并行(切分 batch),提升整体吞吐。
这样的层次化并行方案实现了计算、显存、通信的三维平衡,使得原本需要数千 GB 显存的长上下文训练能够在数百张 GPU 上高效运行。
动手实践:使用 Ring Attention 的最小示例(伪代码)
以下伪代码展示了 Ring Attention 的核心循环,假设使用 torch.distributed 通信原语:
def ring_attention(q, k, v, rank, world_size, ring_group):
# q, k, v 均为本地块:[batch, local_seq_len, num_heads, head_dim]
# ring_group 包含按环形顺序排列的设备列表
# 初始计算本地块
attn_out = scaled_dot_product_attention(q, k, v)
# 获取当前 k, v 的发送副本
send_k, send_v = k.clone(), v.clone()
# 确定环形中的下一个和前一个设备
next_rank = ring_group[(ring_group.index(rank) + 1) % world_size]
prev_rank = ring_group[(ring_group.index(rank) - 1) % world_size]
for step in range(world_size - 1):
# 异步发送本地k,v,异步接收远端k,v
recv_k = torch.empty_like(k)
recv_v = torch.empty_like(v)
send_op_k = dist.isend(send_k, dst=next_rank)
send_op_v = dist.isend(send_v, dst=next_rank)
recv_op_k = dist.irecv(recv_k, src=prev_rank)
recv_op_v = dist.irecv(recv_v, src=prev_rank)
# 等待通信完成(实际实现中可重叠计算与通信)
recv_op_k.wait()
recv_op_v.wait()
send_op_k.wait()
send_op_v.wait()
# 使用接收到的k,v继续计算注意力并累加
attn_out += scaled_dot_product_attention(q, recv_k, recv_v)
# 将收到的块继续向下传递
send_k, send_v = recv_k, recv_v
return attn_out
实际训练代码需集成到 Transformer 层中,并处理好反向传播的对应通信。社区开源库(如 FlashAttention、xFormers)已提供序列并行的生产级封装,可直接调用。
常见问题
序列并行会增加计算量吗?
数学上不增加,计算总量不变,但分块计算可能引入额外的重计算(如 Megatron-SP 中的 LayerNorm 重算)。Ring Attention 的总计算量严格等于标准注意力,但因为数据局部性更优,实际效率可能更高。
序列长度可以无限扩展吗?
理论上可以随设备数量线性扩展。实际受限于网络带宽和延迟:Ring Attention 的通信步数与设备数成正比,当设备过多时,环形流水线效率会下降。针对此问题,学术界提出了基于 All-to-All 的 Striped Attention、Tree Attention 等方案来降低通信步数。
如何与 FlashAttention 结合?
序列并行与 FlashAttention 天然互补。FlashAttention 优化单个块内的注意力计算,而序列并行处理块之间的全局通信。将两者结合(例如在 Ring Attention 的每个本地计算步使用 FlashAttention)可以同时获得块级IO优化和跨设备的序列切分收益,这也是目前训练超长上下文的最前沿方案。