线性注意力机制:突破二次复杂度瓶颈的高效方法
引言:注意力机制的算力困局
在深度学习中,注意力机制(Attention)已成为序列建模与Transformer架构的核心。然而,标准缩放点积注意力(Scaled Dot-Product Attention)的计算复杂度与序列长度 N 呈二次方关系 O(N²)。当处理长文档、高分辨率图像或长视频时,显存与计算时间会急剧膨胀,成为性能瓶颈。
线性注意力机制正是为了解决这一痛点而生。它通过巧妙的数学变换,将复杂度降至 O(N),在几乎不牺牲模型表达能力的前提下,让长序列建模真正可行。本文将从基础原理出发,为你逐步拆解这一高效方案。
一、重温标准注意力的二次方根源
1.1 注意力计算的三个矩阵
标准注意力公式通常写作:
$$ \text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right) V $$
Q(Query)、K(Key)、V(Value) 的形状均为(N, d),N 为序列长度,d 为特征维度。QK^T产生一个(N, N)的注意力分数矩阵。- 对该矩阵进行 softmax 后,再与 V 相乘,得到输出。
1.2 复杂度分析:O(N²) 的由来
最耗时的步骤在于:
- 矩阵乘法
QK^T:计算量为N × d × N = N²d,即 O(N²d)。 - 注意力矩阵乘以 V:计算量为
N × N × d = N²d。
因此,总复杂度随 N 的平方线性增长。对于 10000 个 token 的序列,注意力矩阵将包含 1 亿个元素,显存占用巨大。
二、线性注意力的核心思想:核函数近似
2.1 从 Softmax 转向核函数
线性注意力的关键突破在于:将 softmax 分解为核函数的点积形式。
原始 softmax 可以表示为:
$$ \text{softmax}(QK^T){ij} = \frac{\exp(Q_i K_j^T)}{\sum{k=1}^N \exp(Q_i K_k^T)} $$
如果我们将指数函数替换为一组特征映射(核函数)的内积: $$ \exp(Q_i K_j^T) \approx \phi(Q_i) \phi(K_j)^T $$ 那么注意力就变为: $$ \text{Attention}(Q, K, V) = \frac{\phi(Q)(\phi(K)^T V)}{\phi(Q)(\phi(K)^T \mathbf{1}_N)} $$ 其中 1_N 是全1向量,用于归一化。
2.2 改变计算顺序:先算 K 和 V
观察变换后的分子:φ(K)^T V,这是一个 (d_kernel, d) 的矩阵乘法,其中 d_kernel 是特征映射后的维度。它的计算量与 N 无关,仅为 O(d_kernel × d)。
然后,再乘以 φ(Q),得到输出。通过优先计算 K 和 V 的乘积,我们完全避免了 N×N 的大型矩阵。这便是线性复杂度的精妙之处。
三、常见线性注意力实现方法
3.1 基于核函数的方法:Performer
Google 提出的 Performer 使用 正交随机特征(FAVOR+) 来近似 softmax。
- 选择特征映射
φ(x) = 1/√m [f₁(x), ..., f_m(x)],其中f_k是随机傅里叶特征。 - 实现方式:
Q' = φ(Q),K' = φ(K),然后计算KV = K'^T V,最后输出Q' KV,并做必要的归一化。 - 优点:无偏近似,可证明逼近标准注意力;缺点:需保证映射维度 m 足够大以保持精度。
3.2 基于线性代数技巧的方法:Linear Attention (Efficient Attention)
另一条路径是直接使用简单的非线性函数作为核,例如 ELU + 1 函数:
φ(x) = elu(x) + 1
计算流程:
K_tilde = φ(K),V_tilde = φ(K) V(元素对应相乘后再与 V 结合?实际是先将 K 映射后存储)- 计算归一化项
Z = φ(K)^T 1,即每行和。 - 输出 =
φ(Q) (K_tilde^T V) / φ(Q) Z
这种方法的映射维度与原始 d 相同,避免了额外计算开销,在视觉任务和语言任务中都表现良好。
3.3 基于低秩分解的方法:Linformer
Linformer 假设自注意力矩阵是低秩的,因此可以用一个小矩阵去近似。
- 额外引入两个投影矩阵
E, F ∈ R^{N×k},其中 k 远小于 N。 - 实际上通过将 K 和 V 的序列长度从 N 压缩到 k(可学习投影),将复杂度降至 O(Nk)。
- 这是一种参数化的线性复杂度,但需要固定 k,适合于最大序列长度已知的场景。
四、手把手实现一个简单的线性注意力层
以下代码基于 PyTorch,演示使用 elu+1 的线性注意力层。请基于 PyTorch 1.10+ 运行。
import torch
import torch.nn as nn
import torch.nn.functional as F
class LinearAttention(nn.Module):
def __init__(self, dim, heads=8, dim_head=64):
super().__init__()
inner_dim = dim_head * heads
self.heads = heads
self.scale = dim_head ** -0.5
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)
self.to_out = nn.Linear(inner_dim, dim)
def forward(self, x):
# x: (batch, n, dim)
qkv = self.to_qkv(x).chunk(3, dim=-1)
q, k, v = map(lambda t: t.reshape(*t.shape[:-1], self.heads, -1).transpose(1, 2), qkv)
# 非线性映射(线性注意力的核心)
q = F.elu(q) + 1
k = F.elu(k) + 1
# 计算 K^T V (先计算,复杂度 O(N·d·d_head·heads))
# 实际实现时用 einsum 保持清晰
# k: (batch, heads, n, d_head), v: (batch, heads, n, d_head)
# 我们希望计算 sum over n: k_trans = (batch, heads, d_head, n)
k_trans = k.transpose(-2, -1)
kv = torch.einsum('b h d n, b h n e -> b h d e', k_trans, v) # (b, heads, d_head, d_head)
# 计算归一化分母 Z = sum(φ(K)),形状 (b, heads, d_head)
z = k.sum(dim=2, keepdim=True) # (b, heads, 1, d_head)
# 分子:φ(Q) 与 kv 相乘
out = torch.einsum('b h n d, b h d e -> b h n e', q, kv) # (b, heads, n, d_head)
# 除以归一化项(避免除0加一个小量)
out = out / (q * z).sum(dim=-1, keepdim=True).clamp(min=1e-6)
# 合并多头
out = out.transpose(1, 2).reshape(x.shape[0], -1, self.heads * q.shape[-1])
return self.to_out(out)
关键解读:
- 经过
elu+1后,所有值均为正,保证注意力权重的非负性。 - 计算
kv = K^T V的复杂度为O(N · d_head^2 · heads),与 N 线性相关。 - 最后的归一化类似于 softmax 中的分母,但计算方式完全基于核函数的和。
五、线性注意力的优势与局限
5.1 优势
- 线性复杂度:显存与时间随序列长度线性增长,可轻松处理 4096 甚至超过万级别的 token。
- 易于实现:改动小,兼容现有 Transformer 结构。
- 支持因果掩码:通过递推形式,线性注意力天然支持自回归解码,无需存储完整注意力矩阵。
5.2 局限与挑战
- 模型容量:简单核函数(如 Elu)的表达能力弱于 softmax,可能丢失部分长距离依赖的精细模式。
- 精度取舍:随机特征映射需要足够大的映射维度才能逼近标准注意力,这会在速度与精度之间权衡。
- 训练稳定性:某些核函数可能引发梯度消失或数值不稳定,需小心设计。
5.3 何时选择线性注意力?
- 序列长度 > 1024 且资源受限。
- 对推理速度、吞吐量有极致要求的场景(如实时长语音识别、长文本摘要)。
- 需处理高分辨率特征图(如图像、视频 Transformer)。
如果任务中序列长度始终较短(N<512),标准注意力可能仍是更稳健的选择。
六、进阶方向与最新发展
- 门控线性注意力 (GLA):引入类似门控机制,动态调整信息流,提升表达力。
- Attention Free Transformer (AFT):完全抛弃点积,用逐元素操作代替。
- RetNet:微软提出的保留网络,在训练时保持并行,推理时变为循环形式,达到 O(1) 显存复杂度。
- 混合策略:将标准注意力与线性注意力混合,在局部高精度建模与全局线性概括之间取得平衡。
七、总结
线性注意力机制通过核函数近似与计算重排,优雅地将自注意力的二次方复杂度降低到线性。它并非万能替代,但在许多长序列应用中展现了显著的价值。理解其背后的思想,能帮助你根据实际任务选择合适的注意力方案,构建更高效的模型。
下一步:建议在 Colab 中运行上面的代码示例,尝试不同序列长度下的显存占用,亲身体验线性注意力的优势。同时,也可查阅 Performer、Linear Transformer 等原论文,深入了解细节。