ALiBi:简单线性偏置实现长度外推的注意力
ALiBi 线性偏置:让 Transformer 轻松应对比训练时更长的序列
1. 什么是 ALiBi ?
ALiBi(Attention with Linear Biases)是一种为 Transformer 注意力机制添加位置信息的方法。它不向词向量添加位置编码,而是在计算注意力分数时,为不同距离的查询-键对加上一个固定的、预先定义的线性递减偏置。这种极其简单的设计让模型在推理时能够自然地处理比训练序列长得多的文本,即实现长度外推。
2. 为什么需要长度外推?
大语言模型通常在固定的上下文窗口(如512、2048个 token)上训练。但在实际应用中,我们可能希望模型总结长文档、进行多轮对话或处理长代码文件,这些任务的输入长度远超训练长度。
传统位置编码(如正弦编码、可学习绝对位置编码)在这种场景下会“失效”,因为模型从未见过超出训练长度的位置向量。强行外推往往导致 PPL(困惑度)爆炸,生成质量急剧下降。
ALiBi 的目标是:在不增加计算成本、无需微调的情况下,让你的模型轻松理解更长的序列。
3. 传统位置编码的困境
③.1 绝对位置编码
- 可学习绝对位置编码:每个位置有一个独立嵌入。模型根本无法外推,因为超出训练长度的位置没有对应的嵌入向量。
- 正弦位置编码:虽然具有理论上的无限长度能力,但实验表明,只靠正弦函数的外推性能十分脆弱,常需要配合其他技巧(如位置插值)。
③.2 相对位置编码
像 T5 这样的模型使用了相对位置偏置,但它是通过学习一个偏置查找表实现的。外推时,超出训练长度的相对距离会找不到对应的偏置值,需要截断或插值,这同样限制了外推能力。
ALiBi 另辟蹊径:不再“编码”位置,而是“惩罚”远程注意力。
4. ALiBi 的核心原理:减法偏置被距离定义
ALiBi 对标准注意力机制所做的修改可以用一行公式表达:
Attention Score(q, k) = q · k - m * |i - j|
其中:
q · k是查询向量和键向量的点积(缩放点积注意力)。|i - j|是查询位置i和键位置j之间的相对距离。m是一个头特定的斜率(slope),它是一个预先计算好的、不会在训练中改变的标量。
直观理解:当 token A 尝试关注 token B 时,如果两者距离很远,我们就从原始注意力分数中减去一个更大的值,让 A 更倾向于关注近距离的邻居,而抑制对遥远位置的注意力。 这种“近大远小”的偏置完全由距离决定,不涉及任何可学习参数的外推问题。
④.1 多头斜率设计
为了让不同注意力头捕获不同尺度的依赖关系,ALiBi 为每个头设置了不同的斜率 m:
- 对于
n个注意力头,斜率集合通常取几何级数,例如从2^{-2/n}到2^{-8/n}之间的等比数列。 - 具体做法是:取一个基础起始值(如
2^{-2/n})和基础结束值,按等比生成n个斜率。 - 较小斜率的头能够关注更远的距离,较大斜率的头则极度聚焦于局部上下文。这种“分工”让模型在没有位置编码的情况下依然能分辨远近。
5. 数学机制详解
我们以单头为例,从输入序列经过线性投影得到查询 Q 和键 K。假设序列长度为 L,i 为查询 token 的位置,j 为键 token 的位置。
-
计算原始缩放点积得分:
score(i, j) = (Q_i · K_j) / sqrt(d_k) -
计算相对距离偏置矩阵
B,其中B[i, j] = m * |i - j|。注意 ALiBi 使用的是减法偏置(即从得分中减去),不是加法。 -
修正后的注意力得分:
adjusted_score(i, j) = score(i, j) - m * |i - j| -
随后进行 Softmax 并乘以 Value 矩阵完成注意力聚合。
从外推的角度看: 如果你将序列长度从 512 外推到 2048,新出现的距离 |i - j| 会连续地获得相应的偏置值。没有超出范围的“索引”问题,因为偏置函数是纯距离的函数,可以无限延展。
6. 与 T5 相对偏置的对比
| 特性 | T5 相对位置偏置 | ALiBi 线性偏置 |
|---|---|---|
| 偏置来源 | 可学习的标量,按相对距离存入查找表 | 固定的、预定义的线性函数 m * distance |
| 最大外推距离 | 受查找表大小限制,需要截断或插值 | 无理论上限,函数天生可外推 |
| 额外参数量 | (2 * clip_distance + 1) * heads | 0(斜率为常数,不参与训练) |
| 计算开销 | 查表并相加 | 直接计算减法,极其廉价 |
ALiBi 用零参数、无截断的方案达到了更干净的长度外推效果。
7. 动手实现:PyTorch 风格的 ALiBi 注意力
下面是一个简化的实现示例,展示如何在自回归注意力中集成 ALiBi 偏置。
import torch
import math
def get_alibi_slopes(num_heads):
# 几何级数生成斜率,从 2^(-2/num_heads) 到 2^(-8/num_heads)
start = 2 ** (-2 / num_heads)
end = 2 ** (-8 / num_heads)
ratios = torch.linspace(0, 1, num_heads)
slopes = start * (end / start) ** ratios
return slopes # shape: (num_heads)
def create_alibi_bias(seq_len, slopes):
"""
生成上三角偏置矩阵用于因果注意力。
slopes: (num_heads) 斜率
返回: (1, num_heads, seq_len, seq_len) 偏置张量
"""
# 距离矩阵,因果掩码只允许 j <= i
positions = torch.arange(seq_len).unsqueeze(0) # (1, seq_len)
distances = positions - positions.T # (seq_len, seq_len) ,包含负值
distances = distances.abs().unsqueeze(0) # (1, seq_len, seq_len)
# 为每个头生成偏置:(num_heads, 1, 1) * (1, seq_len, seq_len) -> (num_heads, seq_len, seq_len)
slopes = slopes.view(-1, 1, 1)
bias = slopes * distances # 注意这里是减法偏置,用在 score 中时为 -bias
return -bias.unsqueeze(0) # (1, num_heads, seq_len, seq_len)
# 使用示例
class AlibiAttention(torch.nn.Module):
def __init__(self, embed_dim, num_heads):
super().__init__()
self.num_heads = num_heads
self.head_dim = embed_dim // num_heads
self.q_proj = torch.nn.Linear(embed_dim, embed_dim)
self.k_proj = torch.nn.Linear(embed_dim, embed_dim)
self.v_proj = torch.nn.Linear(embed_dim, embed_dim)
self.out_proj = torch.nn.Linear(embed_dim, embed_dim)
# 预先计算并注册为缓冲区(不参与梯度)
slopes = get_alibi_slopes(num_heads)
self.register_buffer('slopes', slopes)
def forward(self, x):
batch_size, seq_len, _ = x.shape
q = self.q_proj(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1,2)
k = self.k_proj(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1,2)
v = self.v_proj(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1,2)
# 缩放点积
scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)
# 生成 ALiBi 偏置并加到 scores 上(实际上是减去偏置)
alibi_bias = create_alibi_bias(seq_len, self.slopes).to(x.device)
scores = scores + alibi_bias[:, :, :seq_len, :seq_len]
# 因果掩码(可选,根据任务)
causal_mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1).bool().to(x.device)
scores = scores.masked_fill(causal_mask, float('-inf'))
attn_weights = torch.softmax(scores, dim=-1)
out = torch.matmul(attn_weights, v)
out = out.transpose(1, 2).contiguous().view(batch_size, seq_len, -1)
return self.out_proj(out)
几点注意:
- 斜率是固定常数,不参与优化。
- 该实现显式构建了偏置矩阵,对于极长序列可能占用内存。生产代码可以用计算距离的动态方式以节省内存。
- 如果需要双向注意力(如 BERT),不使用因果掩码,但偏置同样为
-m * |i-j|。
8. 为什么 ALiBi 能实现优秀的外推?
实验(出自论文《Train Short, Test Long: Attention with Linear Biases Enables Input Length Extrapolation》)表明:
- 在 WikiText-103 上,ALiBi 模型训练 1024 长度,直接外推到 2048、3072 甚至更长,PPL 几乎持平,而正弦位置编码的 PPL 急剧上升。
- 外推能力来源于偏置的“单调”与“无界”:距离增大,惩罚单调递增,且永远不会遇到未定义值。这与人类阅读长文本时自然衰减远距离关联的倾向一致。
- 模型在训练时学会了依赖距离偏置而不是绝对位置来理解次序,因此即使距离再大,只要惩罚模式延续,模型就能正常运作。
9. 应用场景与使用建议
✅ 适用场景
- 自回归语言模型(GPT 风格):天然适合,只需在因果注意力中加入 ALiBi 偏置。
- 长文本摘要、长文档问答:需要处理超长上下文的任务。
- 资源受限下的微调:你想将 2K 模型用于 8K 序列,但又不想完全重新训练位置编码。
⚠ 注意事项
- 仅对自注意力有效:交叉注意力(如编码器-解码器交互)通常不需要长度外推,但也可加入。
- 双向注意力(如 BERT 预训练)可以使用 ALiBi,此时偏置矩阵对称,
-m * |i-j|。 - 斜率初始化:使用几何级数作为起始斜率通常足够,但针对特定任务可以进行小范围网格搜索找到最优斜率分布。
- 与 Rotary Position Embedding (RoPE) 对比:两者都是当前流行的位置方案。RoPE 通过旋转矩阵编码相对位置,外推时通常需要配合位置插值;ALiBi 更纯粹,无需位置插值即可外推,但在某些基准上 RoPE + 插值可能表现更优。选择取决于你对简洁外推的偏好以及具体下游任务。
10. 总结
ALiBi 用一组头特定的恒定斜率,在注意力分数上施加距离惩罚,极其优雅地解决了 Transformer 的长度外推难题。它带来的启示是:有时候,放弃复杂的可学习编码,转为使用一个反映距离单调性的静态偏置,反而能获得更好的泛化能力。
对于希望让模型“训练短,测试长”的开发者来说,ALiBi 是一种轻量、高效且易于实现的位置感知方案。