ONNX 模型交换:跨框架互操作
ONNX 模型交换:跨框架互操作完全指南
什么是 ONNX
ONNX(Open Neural Network Exchange,开放神经网络交换)是一种开放的模型表示格式,它定义了一组通用的计算节点和算子,使得不同深度学习框架训练出的模型可以相互转换、部署和复用。简单来说,ONNX 就是神经网络模型的“通用语言”。
在没有 ONNX 之前,用 PyTorch 训练的模型很难直接在 TensorFlow 环境中运行,反之亦然。每次切换到新框架或部署到不同推理引擎时,往往需要从头重写模型结构并重新训练。ONNX 的出现解决了这一痛点:你可以在任意主流框架中训练模型,导出为 .onnx 文件,然后在任何支持 ONNX 的运行时中加载推理。
ONNX 能做什么
- 跨框架迁移:轻松将 PyTorch 模型转为 TensorFlow 模型,或从 Keras 导入到 ONNX Runtime。
- 统一部署:同一份 ONNX 模型可以部署在云端、边缘设备、移动端和浏览器中,无需为每个平台专门优化代码。
- 硬件加速:ONNX 运行时可以利用多种硬件加速库(如 CUDA、TensorRT、OpenVINO、DirectML),自动选择最优执行提供程序。
- 模型优化:ONNX 工具链支持图优化、量化、算子融合等,在不改变模型精度的前提下提升推理速度。
ONNX 的核心概念
在动手转换模型之前,先理解几个关键概念会让学习更顺畅。
计算图与算子
ONNX 使用有向无环图表示模型。图中的每个节点代表一个算子(操作),如卷积、矩阵乘法、激活函数等;边表示数据流动,即张量。ONNX 规范定义了超过 160 种标准算子,覆盖了绝大多数神经网络结构的需要。
模型格式与版本
一个 ONNX 模型由以下几部分组成:
- 模型元数据:生产者和版本信息。
- 图:计算图结构,包含输入/输出定义和所有算子节点。
- 算子集:声明模型使用了哪些算子集版本,确保兼容性。
ONNX 文件通常以 .onnx 为后缀,采用 Protocol Buffers 序列化存储,体积小且读取高效。
数据形状与类型
ONNX 支持动态形状和静态形状。如果在导出时指定某些维度为 dynamic(如批次大小),那么模型可以在推理时接受不同大小的输入。数据类型包括常见的 float32、int64、float16、bfloat16 等。
从主流框架导出 ONNX 模型
下面演示如何从 PyTorch、TensorFlow 和 scikit-learn 导出 ONNX 模型。
从 PyTorch 导出
PyTorch 从 1.3 版本开始内置了 torch.onnx.export 函数。
import torch
import torchvision
# 加载预训练模型
model = torchvision.models.resnet18(pretrained=True)
model.eval()
# 创建一个符合模型输入尺寸的示例张量
dummy_input = torch.randn(1, 3, 224, 224)
# 导出 ONNX 模型
torch.onnx.export(
model,
dummy_input,
"resnet18.onnx",
export_params=True, # 将权重也保存进去
opset_version=13, # 算子集版本
do_constant_folding=True, # 执行常量折叠优化
input_names=['input'], # 输入节点名称
output_names=['output'], # 输出节点名称
dynamic_axes={'input': {0: 'batch_size'}, # 动态轴设置
'output': {0: 'batch_size'}}
)
常见问题:如果模型中包含控制流(如条件判断)或仅 PyTorch 特有的操作,导出可能失败。此时需要重写相关部分为等价形式,或使用 @torch.jit.script 装饰器。
从 TensorFlow / Keras 导出
需要安装第三方包 tf2onnx。
pip install tf2onnx
命令行导出:
python -m tf2onnx.convert \
--saved-model ./my_keras_model \
--output model.onnx \
--opset 13
如果已有 Keras 模型对象,也可以使用 Python API:
import tensorflow as tf
import tf2onnx
model = tf.keras.applications.MobileNetV2()
spec = (tf.TensorSpec((None, 224, 224, 3), tf.float32, name="input"),)
output_path = "mobilenetv2.onnx"
model_proto, _ = tf2onnx.convert.from_keras(model,
input_signature=spec,
opset=13,
output_path=output_path)
从 scikit-learn 导出
传统机器学习模型也能通过 skl2onnx 库转为 ONNX。
pip install skl2onnx
from sklearn.ensemble import RandomForestClassifier
from skl2onnx import convert_sklearn
from skl2onnx.common.data_types import FloatTensorType
import onnx
# 训练一个简单模型
clf = RandomForestClassifier()
clf.fit(X_train, y_train)
# 定义输入类型
initial_type = [('float_input', FloatTensorType([None, 4]))]
onx = convert_sklearn(clf, initial_types=initial_type)
with open("rf_model.onnx", "wb") as f:
f.write(onx.SerializeToString())
ONNX 模型的验证与查看
导出后,强烈建议验证模型的有效性和结构。
验证模型格式
使用 ONNX 官方库检查模型是否完整:
import onnx
model = onnx.load("resnet18.onnx")
onnx.checker.check_model(model) # 若无异常则通过
可视化模型图
Netron 是一款极佳的可视化工具,支持直接拖拽 .onnx 文件查看网络结构。
也可以通过代码生成简化图:
from onnx.tools.net_drawer import GetPydotGraph, GetOpNodeProducer
import onnx
model = onnx.load("resnet18.onnx")
pydot_graph = GetPydotGraph(model.graph, name="model", rankdir="TB")
pydot_graph.write_dot("model.dot") # 可使用 graphviz 打开
模型形状与类型推断
对大型模型,可以利用 ONNX 的形状推断功能获得所有中间张量的形状:
import onnx
from onnx import shape_inference
model = onnx.load("resnet18.onnx")
inferred_model = shape_inference.infer_shapes(model)
for node in inferred_model.graph.value_info:
print(node.name, node.type.tensor_type.shape)
跨框架互操作实战:模型转换示例
一个典型场景:将 PyTorch 模型转为 TensorFlow 模型。
步骤 1:将 PyTorch 模型导出为 ONNX。
步骤 2:使用 onnx-tf 将 ONNX 转成 TensorFlow 模型。
pip install onnx-tf
import onnx
from onnx_tf.backend import prepare
# 加载 ONNX 模型
onnx_model = onnx.load("resnet18.onnx")
# 转换为 TensorFlow 图并保存
tf_rep = prepare(onnx_model)
tf_rep.export_graph("resnet18_tf")
执行后会生成一个 TensorFlow SavedModel 文件夹,可直接用 tf.saved_model.load() 加载。
反向转换:TensorFlow 模型转 PyTorch 同样先走 ONNX 中间格式,然后用 onnx2pytorch 工具恢复为 PyTorch 模块。
推理运行:ONNX Runtime
ONNX Runtime(ORT)是微软推出的一款高性能推理引擎,对 ONNX 模型提供一流支持。它支持 Python、C++、C#、Java 等语言,并可以跨 Windows、Linux、macOS 运行。
安装:
pip install onnxruntime # CPU 版本
# 或
pip install onnxruntime-gpu # GPU 版本
Python 推理示例:
import onnxruntime
import numpy as np
session = onnxruntime.InferenceSession("resnet18.onnx")
# 获取输入名称
input_name = session.get_inputs()[0].name
# 准备输入数据(与导出时使用的示例形状一致)
input_data = np.random.rand(1, 3, 224, 224).astype(np.float32)
# 执行推理
outputs = session.run(None, {input_name: input_data})
print(outputs[0].shape)
ORT 会选择系统上最优的执行提供程序(Execution Provider)。可以通过以下方式查看可用的:
providers = onnxruntime.get_available_providers()
print(providers) # 例如 ['CPUExecutionProvider', 'CUDAExecutionProvider']
若想强制使用 GPU,可以这样创建会话:
session = onnxruntime.InferenceSession("model.onnx",
providers=['CUDAExecutionProvider'])
模型优化
ONNX 生态提供了丰富的优化工具,能在几乎不影响精度的情况下大幅提升推理速度。
基础图优化
在导出时开启常量折叠和算子融合只是第一步。使用 onnxoptimizer 可以进行更深层次的优化:
pip install onnxoptimizer
from onnx import optimizer
import onnx
model = onnx.load("model.onnx")
passes = ["extract_constant_to_initializer",
"eliminate_unused_initializer",
"eliminate_nop_transpose",
"fuse_consecutive_transposes",
"fuse_transpose_into_gemm"]
optimized_model = optimizer.optimize(model, passes)
onnx.save(optimized_model, "model_optimized.onnx")
量化
将模型从 FP32 量化到 INT8 可以显著减小体积并提升推理速度,尤其适合移动和边缘设备。
使用 ONNX Runtime 的量化工具:
from onnxruntime.quantization import quantize_dynamic, QuantType
quantize_dynamic(
model_input="model.onnx",
model_output="model_int8.onnx",
weight_type=QuantType.QInt8
)
动态量化只对权重进行量化,激活仍然是浮点,易于实现且无需校准数据。如需更激进的静态量化,需要提供校准数据集。
转换到其他推理格式
ONNX 模型还可以进一步转换为专用加速格式,例如:
- TensorRT:使用
trtexec或 ONNX Runtime TensorRT EP 直接优化。 - OpenVINO:使用
model_optimizer转换为 IR 格式。 - CoreML:通过
onnx-coreml转换为.mlmodel。
常见问题与解决思路
算子不兼容
导出时如果提示某个算子不被 ONNX 支持,可以:
- 升级 ONNX opset 版本。
- 用支持的标准算子组合替代。
- 使用框架自带的兼容转换方法(如 PyTorch 的
torch.onnx.symbolic)。
动态形状导致推理失败
导出时设置动态轴,但在推理时未正确匹配。确保调用 run() 时传入的输入形状与动态轴定义一致,并且所有中间尺寸可推导。
模型转换后精度下降
- 检查模型转换后是否通过了
check_model。 - 将同一份输入分别喂给原模型和 ONNX 模型,对比输出差异。
- 差异过大时,可能是某些算子实现存在数值差异,尝试更换 opset 或禁用某些优化。
总结
ONNX 已经成为深度学习行业事实上的互操作标准。掌握 ONNX 模型交换,不仅可以为你的模型提供“一劳永逸”的部署方案,更能让你在不同框架和硬件之间自由穿梭。从今天开始,将你训练的模型导出为 ONNX 格式,纳入这个庞大而活跃的开放生态中吧。