推测解码:草稿模型辅助的投机性加速推理
什么是推测解码
推测解码是一种用于自回归语言模型推理的加速技术。它利用一个轻量级的“草稿模型”并行生成多个候选词,再由目标大模型进行一次前向传播验证,从而将原本必须逐词串行的解码过程转化为批处理验证,大幅提升生成速度,同时不改变最终的输出分布。
为什么需要推测解码
自回归解码的瓶颈
自回归模型每次只能生成一个词,且必须等待前一个词生成后才能计算下一个词。这种严格的顺序依赖导致:
- 显存带宽利用率极低,计算单元大量空闲。
- 生成延迟随序列长度线性增长,长文本生成尤其缓慢。
- 无法直接利用批量并行加速推理。
投机执行的启示
推测解码借鉴了处理器中的“投机执行”思想:提前猜测分支方向,若正确则获得了加速。在大模型推理中,用一个小模型快速推测多个未来词,再让大模型一次性审批,若审批通过则直接接受多个词,从而跳过多次串行计算。
核心工作原理
整体流程
推测解码的一次完整迭代包含三个阶段:
-
草稿生成
使用草稿模型对给定的前缀自回归地连续生成 K 个词(例如 5 个),得到一条候选序列。 -
目标模型验证
将这 K 个词连同前缀一起送入目标大模型,执行一次前向传播。大模型会输出每个位置的词表概率分布。 -
投机采样接受/拒绝
逐位置比较草稿词的概率:如果草稿词的概率不低于目标模型概率,则以一定概率接受该词;否则拒绝并重新采样。一旦出现拒绝,后续所有草稿词全部丢弃,从该位置开始重新生成。
严格保持目标分布
通过精心设计的接受概率,推测解码保证最终产生的词序列与直接使用目标模型逐词生成的分布完全一致,没有精度损失。这是它区别于其他有损加速方法的核心优势。
数学原理详解
设前缀为 $x_{<t}$,草稿模型 $p_D$ 和目标模型 $p_T$。草稿模型生成了词 $y_t, y_{t+1}, \dots, y_{t+K-1}$。
对于位置 $i$(从 $t$ 开始),定义草稿词 $a = y_i$。目标模型给出概率 $q(a|x_{<i})$,草稿模型给出 $p(a|x_{<i})$。
接受准则:
以概率 $\min\left(1, \frac{q(a)}{p(a)}\right)$ 接受该草稿词。
若接受,则直接将 $a$ 作为位置 $i$ 的输出;若拒绝,则从修正后的分布中重新采样一个词,采样概率为: $$ P(\text{reject and pick } v) = \frac{\max(0, q(v) - p(v))}{\sum_{w}\max(0, q(w) - p(w))} $$ 该公式保证最终输出 $v$ 的边际概率恰好等于 $q(v)$。
直觉解释
- 当草稿模型与目标模型意见一致($q(a) \ge p(a)$)时,总会接受。
- 当 $q(a) < p(a)$ 时,草稿模型过于自信,需要以一定概率拒绝,避免生成草稿模型的“幻觉”词。
- 拒绝后的重采样只从目标模型更偏好的词中选取,以确保分布对齐。
效果量化
若平均每次迭代接受长度为 $\alpha K$($\alpha$ 为接受率),则加速比理论上可接近 $\frac{K}{\text{草稿生成开销} + 1}$ 倍。实践中,在内存带宽受限场景下,可达到 2-3 倍的生成加速。
草稿模型的选择
设计要求
- 结构同源性:草稿模型应与目标模型共享相同的词表和分词器,避免额外编解码开销。
- 轻量高效:参数量通常为目标模型的 1/10 甚至更小,保证草稿生成延迟远低于一次目标模型前向传播。
- 分布近似:草稿模型输出的分布越接近目标模型,接受率越高,加速越明显。
常见草稿模型来源
- 独立小模型:同架构的窄层或浅层版本,如 Llama-7B 用 Llama-160M 作草稿。
- 蒸馏模型:通过知识蒸馏专门训练以模仿目标模型的输出。
- 目标模型子结构:直接复用目标模型的底层或早期层输出,通过简单的头投影预测下个词。
- 缓存与 n-gram 模型:在特定场景下,用简单统计模型作为草稿。
实现中的关键细节
KV 缓存管理
草稿模型和目标模型需要各自维护 KV 缓存。验证阶段只做一次目标模型的前向,因此需要对 K 个位置一次性传入。通常需处理以下问题:
- 草稿模型生成时使用的缓存片段,在验证结束后需根据接受长度进行裁剪或回滚。
- 框架需要支持动态批处理和解耦缓存。
批量化与并行
现代实现通常将草稿生成视为一个“预填充”阶段:对草稿模型使用连续的预填充模式生成 K 个词,而非逐个 decode,进一步减少 kernel 启动开销。
接受步的轻量实现
接受/拒绝采样通常只在 CPU 上处理概率向量,开销极小。关键优化是避免 GPU-CPU 频繁同步,可通过异步处理或提前计算接受随机数。
自适应的 K 值
动态调整每次推测的步数 K 可以根据接受率优化实际吞吐。若接受率高,逐步增大 K;若频繁拒绝,减小 K 以减少浪费的正向计算。
推测解码的优势与局限
优势
- 无损加速:输出分布精确保持,无质量妥协。
- 无需重新训练目标模型:草稿模型可独立选择,目标模型保持不变。
- 可与其他加速技术叠加:可与量化、FlashAttention、Tensor Parallelism 等组合使用。
- 内存带宽友好:将多次小规模计算合并为一次大计算,提高显存利用率。
局限与挑战
- 草稿模型工程复杂度:需要额外部署和管理一个小模型,增加显存占用。
- 接受率退化:若草稿模型与目标模型分布差异大,接受率低,反而拖慢速度。
- 批量推理场景受限:在同时服务多个请求时,推测解码的加速效果可能受调度与缓存开销影响。
- 长序列稳定性:K 过大时,草稿质量下降,收益边际递减。
变体与相关方法
- 推测解码 + 模型并行:在张量并行或流水线并行中,草稿模型可复制到每个设备,维护本地缓存。
- Medusa 头:在目标模型最后隐藏层增加多个并行的词预测头,直接预测未来多个词,省去独立草稿模型。
- Lookahead 解码:使用目标模型自身的 Jacobi 迭代特性并行生成草稿词,无需额外模型。
- 基于检索的草稿:通过搜索类似上下文快速构建候选序列。
实际应用与性能数据
在公开基准中,使用 Llama-2-7B 作为目标模型、Llama-68M 作为草稿模型,推测解码在单条生成任务中可获得 2.0-2.5 倍加速,且生成文本与原始模型完全对齐。在代码补全等高频重复场景中,加速比可超过 3 倍。已集成至 Hugging Face Transformers、vLLM、TensorRT-LLM 等主流框架,可通过简单配置启用。
总结
推测解码通过引入投机执行的思想,巧妙地将语言模型的自回归瓶颈转化为可并行的批量验证问题。它在严格保持原始分布的前提下,显著降低了推理延时,尤其适合内存带宽受限的生成场景。随着轻量草稿模型技术的成熟和框架支持的完善,该技术已成为大模型高效部署的重要工具。