多查询注意力 MQA:共享键值头的推理加速技术

FreeGuideOnline 最新 2026-06-22

多查询注意力 MQA:共享键值头的推理加速技术

Transformer 模型在自然语言处理领域取得了巨大成功,但其解码阶段的自回归特性带来了严重的推理延迟。多查询注意力(Multi-Query Attention,MQA)正是为解决这一瓶颈而提出的一种高效注意力变体。本教程将带你从基础概念出发,逐步理解 MQA 的原理、实现与优势。


1. 背景:多头注意力及其推理瓶颈

在深入 MQA 之前,我们需要回顾标准的多头注意力(Multi-Head Attention,MHA)机制。

1.1 多头注意力 MHA 是如何工作的?

给定输入序列,MHA 通过 h 个并行的注意力头捕捉不同子空间的信息。对于每个头 i,存在独立的线性投影矩阵:

  • 查询投影:W_i^Q
  • 键投影:W_i^K
  • 值投影:W_i^V

每个头都执行缩放点积注意力,最后将所有头的输出拼接起来再进行一次线性变换。

  • 参数量h × (d_model × d_k + d_model × d_v + d_model × d_k),其中 d_k = d_v = d_model // h

1.2 自回归解码时的瓶颈

在生成文本时,模型每生成一个新 token,都需要:

  1. 计算该 token 的查询向量 (Q)。
  2. 将所有历史 token 的键 (K) 和值 (V) 重新计算并拼接(或在推理时通过缓存技术保存)。
  3. 执行注意力计算。

关键瓶颈:每次解码步骤都需要从显存中读取全部头的完整 K 和 V 矩阵。当序列变长、批量增大时,显存带宽成为主要限制,导致推理速度大幅下降。


2. 多查询注意力 MQA:核心思想

MQA 在 2019 年的论文《Fast Transformer Decoding: One Write-Head is All You Need》中被提出,其核心改造极其直接:

保留多个查询头,但在所有头之间共享同一套键(K)和值(V)的投影。

也就是说,MQA 中:

  • 查询 (Q):仍然使用 h 个独立的投影头,保持模型对不同信息的提取能力。
  • 键 (K) 和值 (V):仅保留一套投影矩阵,所有查询头共享同一个 K 和同一个 V。

2.1 直观理解

可以把 MHA 想象成 h 个独立的“读者”,每人手里有一份自己写的“问题”(Q),同时每人拿着一份自己整理的“参考资料”(K、V)。而 MQA 则是这 h 个读者共用同一套“参考资料”,但问题仍然可以各不相同。每位读者都用相同的参考资料回答自己的问题,这大幅减少了需要携带和查阅的资料量。

2.2 架构对比图

标准 MHA:                    
Q: [Head1 Q, Head2 Q, ...]  → 各自匹配自己的 K,V  
K: [Head1 K, Head2 K, ...]  
V: [Head1 V, Head2 V, ...]  

MQA:                        
Q: [Head1 Q, Head2 Q, ...]  → 全部匹配共用的 K, V  
K: [Shared K]               
V: [Shared V]               

3. 推理加速原理

MQA 对推理的优化直接体现在 KV 缓存(KV Cache)的尺寸上。

在自回归解码时,为免重复计算,模型会将之前步骤的 K 和 V 保存在显存中。KV 缓存大小由 批次大小 × 序列长度 × 头数 × 每头维度 决定。

  • MHA 的 KV 缓存尺寸B × S × h × d_k
  • MQA 的 KV 缓存尺寸B × S × 1 × d_k (因为 头数 变为 1)

这意味着 MQA 将 KV 缓存减少了 h 倍。在长序列推理中,巨大的 KV 缓存读取是主要速度瓶颈,MQA 通过降低内存带宽需求,直接提升了推理吞吐量。实验表明,在批量解码时,MQA 的加速比可以接近 h 倍。


4. 分组查询注意力 GQA:实用的折中

虽然 MQA 大幅提升速度,但共享 K、V 会轻微损失模型质量。为此,2023 年的 LLaMA 2 等模型采用了分组查询注意力(Grouped-Query Attention,GQA)

GQA 将查询头分成 g 个组,每个组共享一套 K、V 投影。

  • g = 1 时,GQA 等价于 MQA。
  • g = h 时,GQA 等价于 MHA。

GQA 在内存和效果之间提供了更好的平衡,是目前大模型推理部署的主流选择。理解 MQA 是理解 GQA 的基础。


5. 代码级实现示例

以下是一个简化的 MQA 实现思路(使用 PyTorch 风格伪代码)。假设 d_model = 512, h = 8, d_k = 64

class MultiQueryAttention(nn.Module):
    def __init__(self, d_model, h, d_k):
        super().__init__()
        self.h = h
        self.d_k = d_k
        # 查询投影:仍然为 h 个头
        self.W_Q = nn.Linear(d_model, h * d_k)
        # 键和值投影:仅一套(1 个头)
        self.W_K = nn.Linear(d_model, d_k)
        self.W_V = nn.Linear(d_model, d_k)
        # 输出投影
        self.out = nn.Linear(h * d_k, d_model)

    def forward(self, x):
        # x: (batch, seq_len, d_model)
        # 计算 Q,并拆分为 (batch, seq_len, h, d_k)
        Q = self.W_Q(x).view(batch, seq_len, self.h, self.d_k)
        # 计算共享的 K 和 V,形状为 (batch, seq_len, d_k)
        K = self.W_K(x)
        V = self.W_V(x)
        # 将 K, V 扩展维度以便广播:(batch, seq_len, 1, d_k)
        K = K.unsqueeze(2)
        V = V.unsqueeze(2)

        # 缩放点积注意力
        attn_scores = torch.matmul(Q, K.transpose(-2, -1)) / (self.d_k ** 0.5)
        attn_weights = F.softmax(attn_scores, dim=-1)
        # 输出 (batch, seq_len, h, d_k)
        heads_out = torch.matmul(attn_weights, V)
        # 拼接所有头并投影
        concat = heads_out.reshape(batch, seq_len, self.h * self.d_k)
        return self.out(concat)

实际推理框架会通过优化的 CUDA Kernel 批量执行注意力操作,并直接管理 KV 缓存。以上代码清晰地展示了参数数量的减少和计算逻辑。


6. 优缺权衡与适用场景

6.1 优点

  • 极速推理:KV 缓存尺寸锐减,显存带宽压力大幅降低,解码速度显著提升。
  • 训练时显存节省:虽然训练时仍保留全部中间状态,但参数减少本身也略微降低了训练显存占用。
  • 实现简单:只需修改 K、V 投影层的维度,无需更改模型主体架构。

6.2 缺点

  • 质量微降:共享 K、V 可能限制了多头注意力从不同表示子空间提取信息的能力,在极大规模模型上损失更小。
  • 灵活性降低:某些需要多样化键值匹配的任务(如信息检索增强生成)可能轻微受损。

6.3 何时使用

  • 推理吞吐要求极高的在线服务(聊天机器人、实时翻译)。
  • 需要在边缘设备或消费级 GPU 上部署大型语言模型。
  • 后续方案中,如果模型规模足够大,GQA 往往是更稳健的选择。

7. 总结

多查询注意力 MQA 通过一个简单的观察——“解码时瓶颈在于 KV 缓存的读取而非计算”——巧妙地用参数共享换取了显存带宽。它是现代高效 Transformer 解码技术的基石,启发了 GQA 等更通用的变体。理解 MQA,你将更深刻地掌握大模型推理优化的核心思路。

若你想进一步实验,可以尝试将标准 Transformer 解码器中的 MHA 替换为 MQA,并比较相同序列长度下的推理速度,差异会十分直观。