NTK 感知缩放:非线性插值保护高频信息
FreeGuideOnline
最新
2026-06-22
θ_i = base^{-2i / d}
其中 `base` 通常取 10000。
3. 对于位置 `pos`,该对向量会旋转一个角度 `pos * θ_i`。
基础旋转角频率 `θ_i` 会从高频(i 小,θ 大)指数级衰减到低频(i 大,θ 小)。这种设计可以类比为傅里叶级数:不同频率的基函数编码不同尺度的位置依赖。
## 扩展上下文时的直接插值为何失败
当我们要将模型应用到比训练上下文更长的序列时,一个直觉的做法是**位置线性插值(Position Interpolation, PI)**。例如,模型训练时最大位置 `L` 为 2048,现在要使用 4096 的上下文,我们将所有 token 的位置除以 2:
new_pos = pos * (L_old / L_new)
这种做法等效于对所有旋转频率施加一个统一的缩放因子 `s = L_new / L_old`,将旋转角度变为 `pos / s * θ_i`。
**问题在于:线性插值“一视同仁”地压缩了所有频率,导致高频分量被过度展宽,失去了对局部位置的敏感度。** 模型依赖高频旋转来区分相邻 token 的精细位置关系,一旦这些高频被强行降频,模型在短距离上的注意力模式就会受损,出现“困惑度升高”甚至完全无法收敛的现象。
## NTK 感知缩放的工作原理
NTK 感知缩放受到神经正切核(NTK)理论的启发,提出一种**非线性插值**方案:**低频分量可以大幅压缩以适应更长上下文,但高频分量应当几乎保持不变,从而保护局部注意力**。
具体做法不是直接改变位置索引,而是**修改 RoPE 的频率基 `base`**。
### 频率基缩放与非线性插值的等价性
我们将原始的 `base`(例如 10000)乘以一个缩放因子 `α`,得到新的基:
base_new = base * α^(d / (d-2))
这等价于对所有旋转频率进行如下更改:
θ'_i = (base * α^(d/(d-2))) ^ (-2i/d) = base^{-2i/d} * α^{-2i/(d-2)}
注意到指数 `-2i/(d-2)` 几乎为 0 当 i 很小(高频),而接近 1 当 i 很大(低频)。因此:
- 对于高频分量(i → 0),`θ' ≈ θ`,基本不变。
- 对于低频分量(i → d/2),`θ' ≈ θ / α`,等效于按 `α` 线性缩放。
这样我们就实现了对频率的**非线性拉伸**:高频不受影响,低频被压缩以承载更长的位置周期。这正是 NTK 感知缩放名称的来源——它像 NTK 理论那样,意识到网络的不同频率信道对外推的敏感度不同,并给予不同的处理。
### 缩放因子 α 的选择
若想将上下文从 `L` 扩展到 `L_target`,缩放因子通常取经验值,使得最低频的周期能够覆盖新长度。最低频率对应的波长为 `2π / θ_{max}`,我们希望它在 `L_target` 内至少完成一个周期,因此大致有:
α ≈ (L_target / L)^(d / (d-2))
简化实践中,经常直接按期望的上下文倍数设定 `α`,并结合少量微调。许多开源项目(如 llama.cpp, vLLM)提供了 `alpha` 参数,让用户可以直接测试。
## 代码实现示例
以下伪代码展示了 NTK 感知缩放的核心修改(以 PyTorch 为例):
```python
import torch
import math
def ntk_aware_scaled_rope(x, base=10000.0, scale_factor=2.0, dim=128):
# scale_factor: 上下文扩展倍数
# dim: 单个头的维度
# 计算新的 base
alpha = scale_factor ** (dim / (dim - 2))
base_new = base * alpha
# 生成频率 θ_i
inv_freq = 1.0 / (base_new ** (torch.arange(0, dim, 2).float() / dim))
# 应用于 RoPE 的后续计算...
# 注意:实际实现中还需考虑位置编码生成的具体逻辑
return inv_freq