ONNX 模型交换:跨框架互操作

FreeGuideOnline 最新 2026-06-17

ONNX 模型交换:跨框架互操作完全指南

什么是 ONNX

ONNX(Open Neural Network Exchange,开放神经网络交换)是一种开放的模型表示格式,它定义了一组通用的计算节点和算子,使得不同深度学习框架训练出的模型可以相互转换、部署和复用。简单来说,ONNX 就是神经网络模型的“通用语言”。

在没有 ONNX 之前,用 PyTorch 训练的模型很难直接在 TensorFlow 环境中运行,反之亦然。每次切换到新框架或部署到不同推理引擎时,往往需要从头重写模型结构并重新训练。ONNX 的出现解决了这一痛点:你可以在任意主流框架中训练模型,导出为 .onnx 文件,然后在任何支持 ONNX 的运行时中加载推理。

ONNX 能做什么

  • 跨框架迁移:轻松将 PyTorch 模型转为 TensorFlow 模型,或从 Keras 导入到 ONNX Runtime。
  • 统一部署:同一份 ONNX 模型可以部署在云端、边缘设备、移动端和浏览器中,无需为每个平台专门优化代码。
  • 硬件加速:ONNX 运行时可以利用多种硬件加速库(如 CUDA、TensorRT、OpenVINO、DirectML),自动选择最优执行提供程序。
  • 模型优化:ONNX 工具链支持图优化、量化、算子融合等,在不改变模型精度的前提下提升推理速度。

ONNX 的核心概念

在动手转换模型之前,先理解几个关键概念会让学习更顺畅。

计算图与算子

ONNX 使用有向无环图表示模型。图中的每个节点代表一个算子(操作),如卷积、矩阵乘法、激活函数等;边表示数据流动,即张量。ONNX 规范定义了超过 160 种标准算子,覆盖了绝大多数神经网络结构的需要。

模型格式与版本

一个 ONNX 模型由以下几部分组成:

  • 模型元数据:生产者和版本信息。
  • :计算图结构,包含输入/输出定义和所有算子节点。
  • 算子集:声明模型使用了哪些算子集版本,确保兼容性。

ONNX 文件通常以 .onnx 为后缀,采用 Protocol Buffers 序列化存储,体积小且读取高效。

数据形状与类型

ONNX 支持动态形状和静态形状。如果在导出时指定某些维度为 dynamic(如批次大小),那么模型可以在推理时接受不同大小的输入。数据类型包括常见的 float32int64float16bfloat16 等。

从主流框架导出 ONNX 模型

下面演示如何从 PyTorch、TensorFlow 和 scikit-learn 导出 ONNX 模型。

从 PyTorch 导出

PyTorch 从 1.3 版本开始内置了 torch.onnx.export 函数。

import torch
import torchvision

# 加载预训练模型
model = torchvision.models.resnet18(pretrained=True)
model.eval()

# 创建一个符合模型输入尺寸的示例张量
dummy_input = torch.randn(1, 3, 224, 224)

# 导出 ONNX 模型
torch.onnx.export(
    model,
    dummy_input,
    "resnet18.onnx",
    export_params=True,          # 将权重也保存进去
    opset_version=13,            # 算子集版本
    do_constant_folding=True,    # 执行常量折叠优化
    input_names=['input'],       # 输入节点名称
    output_names=['output'],     # 输出节点名称
    dynamic_axes={'input': {0: 'batch_size'},  # 动态轴设置
                  'output': {0: 'batch_size'}}
)

常见问题:如果模型中包含控制流(如条件判断)或仅 PyTorch 特有的操作,导出可能失败。此时需要重写相关部分为等价形式,或使用 @torch.jit.script 装饰器。

从 TensorFlow / Keras 导出

需要安装第三方包 tf2onnx

pip install tf2onnx

命令行导出:

python -m tf2onnx.convert \
  --saved-model ./my_keras_model \
  --output model.onnx \
  --opset 13

如果已有 Keras 模型对象,也可以使用 Python API:

import tensorflow as tf
import tf2onnx

model = tf.keras.applications.MobileNetV2()
spec = (tf.TensorSpec((None, 224, 224, 3), tf.float32, name="input"),)
output_path = "mobilenetv2.onnx"

model_proto, _ = tf2onnx.convert.from_keras(model,
                                            input_signature=spec,
                                            opset=13,
                                            output_path=output_path)

从 scikit-learn 导出

传统机器学习模型也能通过 skl2onnx 库转为 ONNX。

pip install skl2onnx
from sklearn.ensemble import RandomForestClassifier
from skl2onnx import convert_sklearn
from skl2onnx.common.data_types import FloatTensorType
import onnx

# 训练一个简单模型
clf = RandomForestClassifier()
clf.fit(X_train, y_train)

# 定义输入类型
initial_type = [('float_input', FloatTensorType([None, 4]))]
onx = convert_sklearn(clf, initial_types=initial_type)

with open("rf_model.onnx", "wb") as f:
    f.write(onx.SerializeToString())

ONNX 模型的验证与查看

导出后,强烈建议验证模型的有效性和结构。

验证模型格式

使用 ONNX 官方库检查模型是否完整:

import onnx

model = onnx.load("resnet18.onnx")
onnx.checker.check_model(model)  # 若无异常则通过

可视化模型图

Netron 是一款极佳的可视化工具,支持直接拖拽 .onnx 文件查看网络结构。

也可以通过代码生成简化图:

from onnx.tools.net_drawer import GetPydotGraph, GetOpNodeProducer
import onnx

model = onnx.load("resnet18.onnx")
pydot_graph = GetPydotGraph(model.graph, name="model", rankdir="TB")
pydot_graph.write_dot("model.dot")   # 可使用 graphviz 打开

模型形状与类型推断

对大型模型,可以利用 ONNX 的形状推断功能获得所有中间张量的形状:

import onnx
from onnx import shape_inference

model = onnx.load("resnet18.onnx")
inferred_model = shape_inference.infer_shapes(model)
for node in inferred_model.graph.value_info:
    print(node.name, node.type.tensor_type.shape)

跨框架互操作实战:模型转换示例

一个典型场景:将 PyTorch 模型转为 TensorFlow 模型。

步骤 1:将 PyTorch 模型导出为 ONNX。

步骤 2:使用 onnx-tf 将 ONNX 转成 TensorFlow 模型。

pip install onnx-tf
import onnx
from onnx_tf.backend import prepare

# 加载 ONNX 模型
onnx_model = onnx.load("resnet18.onnx")

# 转换为 TensorFlow 图并保存
tf_rep = prepare(onnx_model)
tf_rep.export_graph("resnet18_tf")

执行后会生成一个 TensorFlow SavedModel 文件夹,可直接用 tf.saved_model.load() 加载。

反向转换:TensorFlow 模型转 PyTorch 同样先走 ONNX 中间格式,然后用 onnx2pytorch 工具恢复为 PyTorch 模块。

推理运行:ONNX Runtime

ONNX Runtime(ORT)是微软推出的一款高性能推理引擎,对 ONNX 模型提供一流支持。它支持 Python、C++、C#、Java 等语言,并可以跨 Windows、Linux、macOS 运行。

安装:

pip install onnxruntime   # CPU 版本
# 或
pip install onnxruntime-gpu   # GPU 版本

Python 推理示例:

import onnxruntime
import numpy as np

session = onnxruntime.InferenceSession("resnet18.onnx")

# 获取输入名称
input_name = session.get_inputs()[0].name

# 准备输入数据(与导出时使用的示例形状一致)
input_data = np.random.rand(1, 3, 224, 224).astype(np.float32)

# 执行推理
outputs = session.run(None, {input_name: input_data})
print(outputs[0].shape)

ORT 会选择系统上最优的执行提供程序(Execution Provider)。可以通过以下方式查看可用的:

providers = onnxruntime.get_available_providers()
print(providers)  # 例如 ['CPUExecutionProvider', 'CUDAExecutionProvider']

若想强制使用 GPU,可以这样创建会话:

session = onnxruntime.InferenceSession("model.onnx",
                     providers=['CUDAExecutionProvider'])

模型优化

ONNX 生态提供了丰富的优化工具,能在几乎不影响精度的情况下大幅提升推理速度。

基础图优化

在导出时开启常量折叠和算子融合只是第一步。使用 onnxoptimizer 可以进行更深层次的优化:

pip install onnxoptimizer
from onnx import optimizer
import onnx

model = onnx.load("model.onnx")
passes = ["extract_constant_to_initializer",
          "eliminate_unused_initializer",
          "eliminate_nop_transpose",
          "fuse_consecutive_transposes",
          "fuse_transpose_into_gemm"]
optimized_model = optimizer.optimize(model, passes)
onnx.save(optimized_model, "model_optimized.onnx")

量化

将模型从 FP32 量化到 INT8 可以显著减小体积并提升推理速度,尤其适合移动和边缘设备。

使用 ONNX Runtime 的量化工具:

from onnxruntime.quantization import quantize_dynamic, QuantType

quantize_dynamic(
    model_input="model.onnx",
    model_output="model_int8.onnx",
    weight_type=QuantType.QInt8
)

动态量化只对权重进行量化,激活仍然是浮点,易于实现且无需校准数据。如需更激进的静态量化,需要提供校准数据集。

转换到其他推理格式

ONNX 模型还可以进一步转换为专用加速格式,例如:

  • TensorRT:使用 trtexec 或 ONNX Runtime TensorRT EP 直接优化。
  • OpenVINO:使用 model_optimizer 转换为 IR 格式。
  • CoreML:通过 onnx-coreml 转换为 .mlmodel

常见问题与解决思路

算子不兼容

导出时如果提示某个算子不被 ONNX 支持,可以:

  1. 升级 ONNX opset 版本。
  2. 用支持的标准算子组合替代。
  3. 使用框架自带的兼容转换方法(如 PyTorch 的 torch.onnx.symbolic)。

动态形状导致推理失败

导出时设置动态轴,但在推理时未正确匹配。确保调用 run() 时传入的输入形状与动态轴定义一致,并且所有中间尺寸可推导。

模型转换后精度下降

  • 检查模型转换后是否通过了 check_model
  • 将同一份输入分别喂给原模型和 ONNX 模型,对比输出差异。
  • 差异过大时,可能是某些算子实现存在数值差异,尝试更换 opset 或禁用某些优化。

总结

ONNX 已经成为深度学习行业事实上的互操作标准。掌握 ONNX 模型交换,不仅可以为你的模型提供“一劳永逸”的部署方案,更能让你在不同框架和硬件之间自由穿梭。从今天开始,将你训练的模型导出为 ONNX 格式,纳入这个庞大而活跃的开放生态中吧。