动态量化:在推理时按需量化的灵活方案
FreeGuideOnline
最新
2026-06-29
python import torch import torch.nn as nn import time
### 定义一个简单的文本分类模型(LSTM)
```python
class TextClassifier(nn.Module):
def __init__(self, vocab_size, embed_dim, hidden_dim, num_classes):
super().__init__()
self.embedding = nn.Embedding(vocab_size, embed_dim)
self.lstm = nn.LSTM(embed_dim, hidden_dim, batch_first=True)
self.fc = nn.Linear(hidden_dim, num_classes)
def forward(self, x):
x = self.embedding(x)
_, (h_n, _) = self.lstm(x)
out = self.fc(h_n[-1])
return out
实例化模型并下载预训练权重
我们用随机权重模拟:
vocab_size = 10000
embed_dim = 256
hidden_dim = 512
num_classes = 5
model = TextClassifier(vocab_size, embed_dim, hidden_dim, num_classes)
model.eval()
基准测试:浮点模型推理时间
dummy_input = torch.randint(0, vocab_size, (1, 50)) # 批次大小1,序列长度50
# 预热
with torch.no_grad():
for _ in range(10):
_ = model(dummy_input)
# 计时
start = time.time()
with torch.no_grad():
for _ in range(100):
_ = model(dummy_input)
fp32_time = time.time() - start
print(f"FP32 推理时间: {fp32_time:.4f} 秒")
应用动态量化
一条命令即完成量化,指定需要量化的层类型和精度。
quantized_model = torch.quantization.quantize_dynamic(
model, # 原始浮点模型
{nn.Linear, nn.LSTM}, # 要量化的层类型集合
dtype=torch.qint8 # 权重量化精度
)
print(quantized_model)
你会看到被量化的层名称现在包含了 DynamicQuantized 前缀。
测试量化后模型的推理时间与精度
# 预热量化模型
with torch.no_grad():
for _ in range(10):
_ = quantized_model(dummy_input)
# 计时
start = time.time()
with torch.no_grad():
for _ in range(100):
_ = quantized_model(dummy_input)
int8_time = time.time() - start
print(f"INT8 动态量化推理时间: {int8_time:.4f} 秒")
print(f"加速比: {fp32_time / int8_time:.2f}x")
模型大小对比
import os
def get_model_size(model):
torch.save(model.state_dict(), "temp.p")
size = os.path.getsize("temp.p") / 1e6 # MB
os.remove("temp.p")
return size
fp32_size = get_model_size(model)
int8_size = get_model_size(quantized_model)
print(f"FP32 模型大小: {fp32_size:.2f} MB")
print(f"INT8 模型大小: {int8_size:.2f} MB")
print(f"压缩比: {fp32_size / int8_size:.2f}x")
在典型的 LSTM 模型上,你可能会看到 2~3 倍的推理加速和接近 4 倍的模型压缩,而准确率几乎无损。
动态量化的适用场景与局限性
适用场景
- NLP 模型:基于 LSTM/GRU 的文本分类、序列标注、机器翻译等。
- Transformer 模型:BERT 等预训练模型在仅需要减小模型体积和省内存时,可以快速使用动态量化,不需要额外校准数据(但静态量化或 QAT 可能获得更好性能)。
- 时间序列预测:递归结构广泛存在。
- 嵌入式/IoT 设备:模型存储空间和内存带宽受限,动态量化改造工作量极小。
- 快速原型验证:不想投入量化训练或校准的成本,又想立刻看到压缩和加速效果。
局限性
- 激活量化开销:对于大型全连接层极多的模型,动态计算激活范围的开销可能会抵消部分加速。
- 硬件支持:部分低功耗芯片可能没有高效的 INT8 向量化指令,加速效果受限。
- 精度敏感层:某些特殊情况(如生成模型中的温度采样)可能对激活值的精度极其敏感,动态量化的浮点<->整型频繁转换会带来轻微误差。
- 不支持 GPU 加速(PyTorch 目前动态量化主要针对 CPU);PyTorch 移动端支持动态量化,但 GPU 上通常采用不同的量化策略。
动态量化的进阶技巧
自定义量化配置
可以仅量化部分层,保留部分层为浮点,平衡精度和速度:
quantized_model = torch.quantization.quantize_dynamic(
model,
{nn.LSTM}, # 只量化 LSTM 层
dtype=torch.qint8
)
结合算子融合
动态量化自动融合支持的算子(如 Linear + ReLU),如果模型中有自定义激活函数,可能需要手动进行融合。
检查量化模型结构
print(quantized_model)
# 观察哪些层被替换为 DynamicQuantizedLSTM 或 DynamicQuantizedLinear
保存与加载量化模型
# 保存
torch.save(quantized_model.state_dict(), "quantized_model.pth")
# 加载时需重新构建量化模型对象
model = TextClassifier(vocab_size, embed_dim, hidden_dim, num_classes)
model.load_state_dict(torch.load("quantized_model.pth")) # 错误!state_dict 的 key 不匹配
# 正确做法:先量化再加载
model_q = torch.quantization.quantize_dynamic(model, {nn.Linear, nn.LSTM}, dtype=torch.qint8)
model_q.load_state_dict(torch.load("quantized_model.pth"))
推荐直接使用 torch.jit.save 保存 TorchScript 形式的量化模型,实现无需模型定义即可加载。
scripted = torch.jit.script(quantized_model)
torch.jit.save(scripted, "quantized_model.pt")
# 加载
loaded = torch.jit.load("quantized_model.pt")