YaRN:结合温度与频率缩放的极长上下文方法

FreeGuideOnline 最新 2026-06-22

YaRN 位置编码:让大语言模型轻松处理超长上下文

YaRN(Yet another RoPE extensioN)是一种高效的大语言模型位置编码扩展方法。它巧妙地将温度缩放频率缩放结合在一起,在不牺牲短上下文性能的前提下,显著提升模型可处理的上下文长度。本教程将带你从核心原理到动手实践,深入理解 YaRN 的精妙设计。

1. 背景:模型为什么需要“位置感”

自然语言是有顺序的,“我打你”和“你打我”含义完全不同。Transformer 模型本身并不天生知道输入序列的顺序,因此需要注入位置编码来告诉模型每个词的位置信息。

近年来,旋转位置编码(RoPE) 成为主流,因为它能通过绝对位置实现相对位置编码,具有良好的外推潜力。但标准的 RoPE 在推理时输入长度超过训练长度时,性能会急剧下降。

2. 基础:旋转位置编码 (RoPE) 速览

RoPE 的核心思想是通过一个与位置相关的旋转矩阵来变换 query 和 key 向量。对于第 (m) 个位置、维度索引为 (i) 的向量对,旋转角度为:

[ \theta_i = m \cdot \omega_i,\quad \text{其中}\quad \omega_i = \frac{1}{10000^{2i/d}} ]

  • (m): 位置索引
  • (d): 头的维度
  • (i): 维度对的序号(取值从 0 到 d/2-1)
  • (\omega_i): 频率分量,高频对应 (i) 小,低频对应 (i) 大

不同维度对应不同旋转频率:高频维度对局部位置敏感,低频维度捕捉长距离依赖。当模型需要处理超过训练长度的序列时,低频维度的旋转角度会超出训练域,导致注意力混乱,这就是 “外推失效” 的根源。

3. YaRN 的设计哲学:温度与频率的协同作用

为扩展上下文窗口,YaRN 吸收了两种关键技术的优点:

  • 位置插值 (Position Interpolation, PI):将位置索引按比例压缩,让长序列映射回训练长度范围,但会降低模型分辨近处位置的能力。
  • NTK-aware 缩放:仅对高频维度进行缩放,保留低频维度的长距辨别力,但完全不对温度进行调整容易导致注意力分数分布异常。

YaRN 的创新在于:在 NTK-aware 频率缩放的基础上,引入 温度参数 来调节 softmax 之前的注意力 logits,以补偿因缩放导致的注意力熵变化。两者结合,使得模型既能看清远处结构,又能精确分辨近处细节,还能保持注意力分布的锐利程度。

3.1 频率缩放:让长序列“慢下来”

YaRN 对 RoPE 的频率 (\omega_i) 进行部分缩放。引入缩放因子 (s)(实际上下文长度与原始训练长度的比值)和一个波长阈值 (\lambda_{\text{ramp}})。

对于每个维度对的波长 (\lambda_i = 2\pi / \omega_i):

  • 如果 (\lambda_i < \lambda_{\text{ramp}})(高频),不进行缩放,保持对局部位置的敏感度。
  • 如果 (\lambda_i \ge \lambda_{\text{ramp}})(低频),对 (\omega_i) 进行缩放,等效为将位置索引除以某个因子,让这些维度适应更长的距离。

实际实现中,YaRN 使用一个平滑的分段函数来决定每个维度的缩放比例:

[ \gamma_i = \begin{cases} 1, & \text{if } \lambda_i < \lambda_{\text{ramp}} \ \left(1 - h(\lambda_i)\right) \frac{s_{\text{max}} - s_{\text{min}}}{s} + \frac{s_{\text{min}}}{s}, & \text{otherwise} \end{cases} ]

其中 (h(\lambda_i)) 是一个平滑斜坡函数,(s_{\text{min}}) 和 (s_{\text{max}}) 定义了缩放边界。新的旋转角度变为 (\theta_i = m \cdot \omega_i \cdot \gamma_i)。

3.2 温度缩放:让注意力分布保持“清晰”

仅仅缩放频率还不够。插值后,注意力 logits(点积值)的方差会发生变化,导致 softmax 输出过于平滑,模型难以聚焦关键信息。YaRN 引入了一个全局温度 (t) 来修正这一点:

[ \text{Attention}(Q, K) = \text{softmax}\left(\frac{QK^T}{t \sqrt{d_k}}\right) ]

通常 (t) 的设置与缩放因子相关,例如可以取 (t = \sqrt{s}) 或基于实验微调。温度升高会使 softmax 分布更加尖锐,帮助模型在长上下文中依然能做出清晰的注意力选择。

4. YaRN 的完整计算流程

假设有一个预训练好的模型,原始训练长度为 (L_{\text{train}}),现在需要扩展到目标长度 (L_{\text{target}}),缩放因子 (s = L_{\text{target}} / L_{\text{train}})。

  1. 确定频率缩放映射
    根据 RoPE 各维度的波长和设置的阈值,计算每个维度对的缩放系数 (\gamma_i),得到新的 (\omega'_i = \omega_i \cdot \gamma_i)。

  2. 替换旋转频率
    模型前向传播时,使用 (\omega'_i) 计算旋转位置编码,不再使用原始的 (\omega_i)。

  3. 设置注意力温度
    在每个注意力层的 softmax 前,将点积结果除以温度 (t)(通常 (t \approx \sqrt{s}) 或根据小规模搜索确定)。

  4. (可选)微调
    在目标长文本数据上对模型进行少量步数的微调,YaRN 方法通常只需要极少的训练(甚至零样本)即可获得优秀的长序列性能。

5. 直观理解:为什么 YaRN 如此有效

  • 高频不变:短距离词对(如“猫”和“吃”)的位置关系几乎不受长度扩展影响,模型对这些关系的感知能力完好保留。
  • 低频压缩:长距离依赖(如文档首尾的呼应)通过缩放被映射回训练区间,原本未曾见过的遥远位置被“拉近”到模型熟悉的范围。
  • 温度调节:相当于给注意力机制增加了一个“放大镜”,防止 long context 中信息被淹没在平滑的概率分布里。

三者协同,使模型在无需大量重新训练的前提下,平滑地适应 2 倍、4 倍乃至更长的上下文。

6. 实际代码示例(PyTorch 微缩版)

以下是 YaRN 应用于 Hugging Face Llama 模型的简化关键代码片段,展示如何修改旋转嵌入并加入温度。

import torch
import math

def apply_yarn_rotary_emb(q, k, cos, sin, position_ids, yarn_params):
    # cos, sin 是基于原始频率生成的旋转矩阵组件
    # yarn_params 包含缩放后的频率

    # 根据 yarn 缩放后的频率重新生成 cos, sin
    dim = q.shape[-1]
    freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
    # 应用 YaRN 缩放因子 gamma
    gamma = compute_yarn_gamma(freq, yarn_params)
    freq = freq * gamma
    # ... 按位置生成新的旋转嵌入
    # 然后应用到 q, k
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed

def attention_with_temp(query, key, value, temperature):
    attn_weights = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(query.size(-1))
    attn_weights = attn_weights / temperature  # YaRN 温度缩放
    attn_weights = torch.softmax(attn_weights, dim=-1)
    output = torch.matmul(attn_weights, value)
    return output

在实际应用时,你可以使用 transformers 库中已集成的 YaRN 实现,只需在模型配置中设置 rope_scaling 参数:

from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained(
    "model_name",
    rope_scaling={"type": "yarn", "factor": 4.0},  # 扩展4倍上下文
    trust_remote_code=True
)

7. 性能表现与适用范围

  • 扩展能力:在无需微调的情况下,YaRN 可将 LLaMA 等模型从 2k 扩展至 32k 甚至 128k,困惑度仍保持在合理范围。
  • 短上下文保持:由于高频部分未缩放,模型在原始短任务上的性能几乎不下降。
  • 微调高效:若辅以少量长文本微调,可进一步大幅提升长序列理解能力。
  • 适用模型:所有基于 RoPE 的模型,如 LLaMA 系列、Mistral、Qwen 等,均可直接替换为 YaRN。

8. 常见疑问解答

问:可以直接用 YaRN 处理任意长度吗?
答:理论上只要显存足够,长度可以非常大,但注意 factor 过大时模型性能仍会逐渐衰减,推荐 factor 不超过 8 或 16。

问:温度参数如何选择?
答:YaRN 论文中给出了基于缩放因子 (s) 的经验公式 (t = \sqrt{s}),也可以在实际任务上微调温度值。有些实现直接在激活中学习一个可调节的温度。

问:必须微调吗?
答:不必须,YaRN 的 zero-shot 扩展能力非常出色。但若对长任务精度要求极高,极少量微调(几百步)就能带来显著提升。

9. 总结

YaRN 通过部分频率缩放温度调节的巧妙组合,扫除了 RoPE 位置编码的外推障碍。它既保留了模型对局部信息的辨别力,又将长距离依赖‘拉入’训练域,同时用温度保证注意力的聚焦性。作为目前最主流的上下文扩展方案之一,YaRN 已被广泛应用于各种开源大语言模型的超长窗口版本中。

当你下次需要让一个 4k 模型阅读整本小说时,不妨试试 YaRN。