量化感知训练 QAT:在训练时模拟量化效应

FreeGuideOnline 最新 2026-06-29

什么是量化感知训练

量化感知训练(Quantization-Aware Training,QAT)是一种在模型训练过程中模拟低精度计算效应的技术。与简单训练后量化不同,QAT 让模型在训练时就“知道”自己将来会被量化为 INT8 或其他低比特格式,从而主动学习对量化误差更鲁棒的权重分布。

为什么需要 QAT

深度学习模型部署到端侧设备时常需要将 FP32 权重和激活压缩为 INT8,以降低功耗、提升推理速度。直接进行训练后量化往往会带来明显的精度损失,尤其在小型模型或敏感任务上。QAT 通过在训练图中插入伪量化节点,使损失函数的梯度能反映量化误差,最终得到一个更容易被量化的模型。

QAT 与传统训练的核心区别

特性 标准训练 量化感知训练
权重与激活精度 FP32 全精度 模拟 INT8 行为
反向传播 直接计算梯度 使用直通估计器绕过量化节点
部署流程 训练后再量化 训练后仅需转换即可获得量化模型
精度保持 可能出现明显下降 通常可将损失控制在 1% 以内

QAT 的工作原理

QAT 的核心思想是在前向传播时使用量化后的值进行计算,而在反向传播时保留全精度梯度更新。这样既能让模型适应量化噪声,又不会破坏优化过程的稳定性。

伪量化节点的插入

在 QAT 中,我们不会真正修改权重的存储格式,而是在每个需要进行量化的操作前插入伪量化节点。这些节点执行三个步骤:

  1. 钳位:将输入张量限制在可表示范围内,例如对称量化的 [-max_value, max_value]。
  2. 缩放与取整:将钳位后的值除以缩放因子,并进行就近取整。
  3. 反量化:将取整后的整数乘以缩放因子,还原回浮点空间。

经过反量化后的值与原始值存在差异,这个差异就是量化误差。后续层接收到的是带有误差的值,模型前向路径完全模拟了低精度推理时的计算效果。

直通估计器的使用

反向传播时,四舍五入取整函数的导数几乎处处为 0,这会导致梯度消失。QAT 普遍使用直通估计器来解决这一问题:将取整操作视为恒等映射,直接传递上游梯度。

这样,网络更新权重时就仿佛取整不存在,但前向损失中已经包含了取整带来的误差信号,从而引导权重向更有利于量化的方向移动。

典型 QAT 流程

一个完整的量化感知训练通常包含以下阶段:

  1. 使用标准 FP32 训练获得一个基线模型。
  2. 在模型图中插入伪量化节点,并配置量化策略(如每层对称/非对称量化、per-tensor 或 per-channel 量化)。
  3. 载入基线权重,以较低学习率进行微调。微调期间,伪量化节点保持激活。
  4. 微调完成后,将伪量化节点转换为真正的定点和整数版本,生成最终部署模型。

量化策略的选择

  • 权重量化:通常采用 per-channel 的对称量化,因为权重分布往往不同通道尺度差异大,per-channel 可大幅减少误差。
  • 激活量化:常采用 per-tensor 的非对称量化,利用零点偏移来适配激活的分布(如 ReLU 输出非负)。
  • 量化范围:可通过最大值校准或指数移动平均来动态更新激活的量化范围。

初学者可以从 per-tensor 激活量化 + per-channel 权重量化的组合开始,这是大多数框架的默认推荐。

动手实践:一个 PyTorch QAT 示例

以下示例基于 PyTorch 的 torch.quantization 模块,展示如何对一个简单卷积网络进行量化感知训练。

准备模型与数据

定义一个两层卷积的小型网络,并在训练前插入量化与反量化桩(stubs):

import torch
import torch.nn as nn
import torch.quantization as quant

class SimpleCNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 16, 3)
        self.relu1 = nn.ReLU()
        self.conv2 = nn.Conv2d(16, 32, 3)
        self.relu2 = nn.ReLU()
        self.fc = nn.Linear(32 * 6 * 6, 10)

    def forward(self, x):
        x = self.relu1(self.conv1(x))
        x = self.relu2(self.conv2(x))
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x

# 转为 QAT 模式
model = SimpleCNN()
model.qconfig = quant.get_default_qat_qconfig('fbgemm')
model = quant.prepare_qat(model, inplace=True)

这里使用的 'fbgemm' 后端支持 per-channel 权重量化,适用于 x86 平台。如果目标是 ARM 设备,可替换为 'qnnpack'

微调与转换

将模型置于训练模式,以较小的学习率进行微调:

optimizer = torch.optim.SGD(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()

model.train()
for epoch in range(5):
    for inputs, labels in data_loader:
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

# 转换为真正的量化模型
model.eval()
quantized_model = quant.convert(model, inplace=False)

convert 调用后,伪量化节点被替换为真正的量化算子,权重被转换为 INT8,激活在推理时动态量化。

调优技巧与常见问题

如何选择微调轮数

QAT 微调通常只需原始训练周期的 10%~20%。过多的微调可能导致模型偏离已学到的高精度分布,反而损害精度。建议在 2~5 个 epoch 内观察验证集指标,一旦稳定即可停止。

批量归一化的特殊处理

批量归一化层在 QAT 中需要特别关注。通常将 BN 层与前面的卷积融合,然后将融合后的权重进行量化。PyTorch 的 fuse_modules 工具可自动完成此步骤:

model = quant.fuse_modules(model, [['conv1', 'relu1'], ['conv2', 'relu2']])

融合后的模型再执行 prepare_qat

激活范围校准

激活的量化范围对精度影响显著。使用指数移动平均更新最小值与最大值,可以更稳定地估计推理时的激活分布。在 PyTorch 中,通过 qconfig 的 observer 类型控制,例如 MovingAverageMinMaxObserver

诊断量化误差

若微调后模型精度不佳,可通过以下方式排查:

  • 比较原始模型与 QAT 模型在同一批数据上的激活分布差异。
  • 检查每层量化后的信噪比(原始权重与反量化权重的余弦相似度)。
  • 尝试将权重量化从 per-tensor 改为 per-channel。

框架支持一览

主流深度学习框架均提供了对 QAT 的原生支持:

  • PyTorchtorch.quantization 模块,支持静态/动态 QAT。
  • TensorFlowtf.quantization API,配合 TensorFlow Model Optimization 工具包。
  • ONNX Runtime:支持导入 QAT 模型并利用硬件加速指令。
  • OpenVINO / TensorRT:提供 QAT 微调后的 INT8 推理优化。

初学者建议从 PyTorch 或 TensorFlow 官方 QAT 教程入手,配合小型图像分类任务快速验证效果。

总结

量化感知训练通过在训练时显式建模量化噪声,让模型学会在低精度下保持表达能力。它的核心在于伪量化节点直通估计器,流程简单却十分有效。掌握 QAT 将使你能够在不显著损失精度的前提下,将模型部署到资源受限的设备上。

下一步可以尝试在自己训练的模型上应用 QAT,比较训练后量化与 QAT 的精度差异,并尝试调节量化参数与微调策略,找到最优部署方案。