YaRN:结合温度与频率缩放的极长上下文方法
YaRN 位置编码:让大语言模型轻松处理超长上下文
YaRN(Yet another RoPE extensioN)是一种高效的大语言模型位置编码扩展方法。它巧妙地将温度缩放与频率缩放结合在一起,在不牺牲短上下文性能的前提下,显著提升模型可处理的上下文长度。本教程将带你从核心原理到动手实践,深入理解 YaRN 的精妙设计。
1. 背景:模型为什么需要“位置感”
自然语言是有顺序的,“我打你”和“你打我”含义完全不同。Transformer 模型本身并不天生知道输入序列的顺序,因此需要注入位置编码来告诉模型每个词的位置信息。
近年来,旋转位置编码(RoPE) 成为主流,因为它能通过绝对位置实现相对位置编码,具有良好的外推潜力。但标准的 RoPE 在推理时输入长度超过训练长度时,性能会急剧下降。
2. 基础:旋转位置编码 (RoPE) 速览
RoPE 的核心思想是通过一个与位置相关的旋转矩阵来变换 query 和 key 向量。对于第 (m) 个位置、维度索引为 (i) 的向量对,旋转角度为:
[ \theta_i = m \cdot \omega_i,\quad \text{其中}\quad \omega_i = \frac{1}{10000^{2i/d}} ]
- (m): 位置索引
- (d): 头的维度
- (i): 维度对的序号(取值从 0 到 d/2-1)
- (\omega_i): 频率分量,高频对应 (i) 小,低频对应 (i) 大
不同维度对应不同旋转频率:高频维度对局部位置敏感,低频维度捕捉长距离依赖。当模型需要处理超过训练长度的序列时,低频维度的旋转角度会超出训练域,导致注意力混乱,这就是 “外推失效” 的根源。
3. YaRN 的设计哲学:温度与频率的协同作用
为扩展上下文窗口,YaRN 吸收了两种关键技术的优点:
- 位置插值 (Position Interpolation, PI):将位置索引按比例压缩,让长序列映射回训练长度范围,但会降低模型分辨近处位置的能力。
- NTK-aware 缩放:仅对高频维度进行缩放,保留低频维度的长距辨别力,但完全不对温度进行调整容易导致注意力分数分布异常。
YaRN 的创新在于:在 NTK-aware 频率缩放的基础上,引入 温度参数 来调节 softmax 之前的注意力 logits,以补偿因缩放导致的注意力熵变化。两者结合,使得模型既能看清远处结构,又能精确分辨近处细节,还能保持注意力分布的锐利程度。
3.1 频率缩放:让长序列“慢下来”
YaRN 对 RoPE 的频率 (\omega_i) 进行部分缩放。引入缩放因子 (s)(实际上下文长度与原始训练长度的比值)和一个波长阈值 (\lambda_{\text{ramp}})。
对于每个维度对的波长 (\lambda_i = 2\pi / \omega_i):
- 如果 (\lambda_i < \lambda_{\text{ramp}})(高频),不进行缩放,保持对局部位置的敏感度。
- 如果 (\lambda_i \ge \lambda_{\text{ramp}})(低频),对 (\omega_i) 进行缩放,等效为将位置索引除以某个因子,让这些维度适应更长的距离。
实际实现中,YaRN 使用一个平滑的分段函数来决定每个维度的缩放比例:
[ \gamma_i = \begin{cases} 1, & \text{if } \lambda_i < \lambda_{\text{ramp}} \ \left(1 - h(\lambda_i)\right) \frac{s_{\text{max}} - s_{\text{min}}}{s} + \frac{s_{\text{min}}}{s}, & \text{otherwise} \end{cases} ]
其中 (h(\lambda_i)) 是一个平滑斜坡函数,(s_{\text{min}}) 和 (s_{\text{max}}) 定义了缩放边界。新的旋转角度变为 (\theta_i = m \cdot \omega_i \cdot \gamma_i)。
3.2 温度缩放:让注意力分布保持“清晰”
仅仅缩放频率还不够。插值后,注意力 logits(点积值)的方差会发生变化,导致 softmax 输出过于平滑,模型难以聚焦关键信息。YaRN 引入了一个全局温度 (t) 来修正这一点:
[ \text{Attention}(Q, K) = \text{softmax}\left(\frac{QK^T}{t \sqrt{d_k}}\right) ]
通常 (t) 的设置与缩放因子相关,例如可以取 (t = \sqrt{s}) 或基于实验微调。温度升高会使 softmax 分布更加尖锐,帮助模型在长上下文中依然能做出清晰的注意力选择。
4. YaRN 的完整计算流程
假设有一个预训练好的模型,原始训练长度为 (L_{\text{train}}),现在需要扩展到目标长度 (L_{\text{target}}),缩放因子 (s = L_{\text{target}} / L_{\text{train}})。
-
确定频率缩放映射
根据 RoPE 各维度的波长和设置的阈值,计算每个维度对的缩放系数 (\gamma_i),得到新的 (\omega'_i = \omega_i \cdot \gamma_i)。 -
替换旋转频率
模型前向传播时,使用 (\omega'_i) 计算旋转位置编码,不再使用原始的 (\omega_i)。 -
设置注意力温度
在每个注意力层的 softmax 前,将点积结果除以温度 (t)(通常 (t \approx \sqrt{s}) 或根据小规模搜索确定)。 -
(可选)微调
在目标长文本数据上对模型进行少量步数的微调,YaRN 方法通常只需要极少的训练(甚至零样本)即可获得优秀的长序列性能。
5. 直观理解:为什么 YaRN 如此有效
- 高频不变:短距离词对(如“猫”和“吃”)的位置关系几乎不受长度扩展影响,模型对这些关系的感知能力完好保留。
- 低频压缩:长距离依赖(如文档首尾的呼应)通过缩放被映射回训练区间,原本未曾见过的遥远位置被“拉近”到模型熟悉的范围。
- 温度调节:相当于给注意力机制增加了一个“放大镜”,防止 long context 中信息被淹没在平滑的概率分布里。
三者协同,使模型在无需大量重新训练的前提下,平滑地适应 2 倍、4 倍乃至更长的上下文。
6. 实际代码示例(PyTorch 微缩版)
以下是 YaRN 应用于 Hugging Face Llama 模型的简化关键代码片段,展示如何修改旋转嵌入并加入温度。
import torch
import math
def apply_yarn_rotary_emb(q, k, cos, sin, position_ids, yarn_params):
# cos, sin 是基于原始频率生成的旋转矩阵组件
# yarn_params 包含缩放后的频率
# 根据 yarn 缩放后的频率重新生成 cos, sin
dim = q.shape[-1]
freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
# 应用 YaRN 缩放因子 gamma
gamma = compute_yarn_gamma(freq, yarn_params)
freq = freq * gamma
# ... 按位置生成新的旋转嵌入
# 然后应用到 q, k
q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
return q_embed, k_embed
def attention_with_temp(query, key, value, temperature):
attn_weights = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(query.size(-1))
attn_weights = attn_weights / temperature # YaRN 温度缩放
attn_weights = torch.softmax(attn_weights, dim=-1)
output = torch.matmul(attn_weights, value)
return output
在实际应用时,你可以使用 transformers 库中已集成的 YaRN 实现,只需在模型配置中设置 rope_scaling 参数:
from transformers import AutoModelForCausalLM, AutoTokenizer
model = AutoModelForCausalLM.from_pretrained(
"model_name",
rope_scaling={"type": "yarn", "factor": 4.0}, # 扩展4倍上下文
trust_remote_code=True
)
7. 性能表现与适用范围
- 扩展能力:在无需微调的情况下,YaRN 可将 LLaMA 等模型从 2k 扩展至 32k 甚至 128k,困惑度仍保持在合理范围。
- 短上下文保持:由于高频部分未缩放,模型在原始短任务上的性能几乎不下降。
- 微调高效:若辅以少量长文本微调,可进一步大幅提升长序列理解能力。
- 适用模型:所有基于 RoPE 的模型,如 LLaMA 系列、Mistral、Qwen 等,均可直接替换为 YaRN。
8. 常见疑问解答
问:可以直接用 YaRN 处理任意长度吗?
答:理论上只要显存足够,长度可以非常大,但注意 factor 过大时模型性能仍会逐渐衰减,推荐 factor 不超过 8 或 16。
问:温度参数如何选择?
答:YaRN 论文中给出了基于缩放因子 (s) 的经验公式 (t = \sqrt{s}),也可以在实际任务上微调温度值。有些实现直接在激活中学习一个可调节的温度。
问:必须微调吗?
答:不必须,YaRN 的 zero-shot 扩展能力非常出色。但若对长任务精度要求极高,极少量微调(几百步)就能带来显著提升。
9. 总结
YaRN 通过部分频率缩放与温度调节的巧妙组合,扫除了 RoPE 位置编码的外推障碍。它既保留了模型对局部信息的辨别力,又将长距离依赖‘拉入’训练域,同时用温度保证注意力的聚焦性。作为目前最主流的上下文扩展方案之一,YaRN 已被广泛应用于各种开源大语言模型的超长窗口版本中。
当你下次需要让一个 4k 模型阅读整本小说时,不妨试试 YaRN。