当前位置:网站首页 > 云计算与后端部署 > 正文

onnx模型部署修改(onnx 模型)



概述

神经网络本质上是一个计算图。计算图的节点是算子,边是参与运算的张量。而通过可视化 ONNX 模型,我们知道 ONNX 记录了所有算子节点的属性信息,并把参与运算的张量信息存储在算子节点的输入输出信息中。事实上,ONNX 模型的结构可以用类图大致表示如下:

如图所示,一个 ONNX 模型可以用 ModelProto 类表示。

  • ModelProto 包含了版本、创建者等日志信息,还包含了存储计算图结构的 graph。
  • GraphProto 类则由输入张量信息、输出张量信息、节点信息组成。
  • 张量信息 ValueInfoProto 类包括张量名、基本数据类型、形状。
  • 节点信息 NodeProto 类包含了算子名、算子输入张量名、算子输出张量名。

定义ONNX

尝试完全用 ONNX 的 Python API 构造一个描述线性函数 output=a*x+b 的 ONNX 模型。我们将根据上面的结构,自底向上地构造这个模型。

import onnx from onnx import helper from onnx import TensorProto  a = helper.make_tensor_value_info('a', TensorProto.FLOAT, [10, 10]) x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [10, 10]) b = helper.make_tensor_value_info('b', TensorProto.FLOAT, [10, 10]) output = helper.make_tensor_value_info('output', TensorProto.FLOAT, [10, 10])

之后,我们要构造算子节点信息 NodeProto,这可以通过在 helper.make_node 中传入算子类型、输入算子名、输出算子名这三个信息来实现。我们这里先构造了描述 c=a*x 的乘法节点,再构造了 output=c+b 的加法节点。如下面的代码所示:

mul = helper.make_node('Mul', ['a', 'x'], ['c']) add = helper.make_node('Add', ['c', 'b'], ['output'])

在计算机中,图一般是用一个节点集和一个边集表示的。而 ONNX 巧妙地把边的信息保存在了节点信息里,省去了保存边集的步骤。在 ONNX 中,如果某节点的输入名和之前某节点的输出名相同,就默认这两个节点是相连的。

正是因为有这种边的隐式定义规则,所以 ONNX 对节点的输入有一定的要求:一个节点的输入,要么是整个模型的输入,要么是之前某个节点的输出

接下来,我们用 helper.make_graph 来构造计算图 GraphProto。helper.make_graph 函数需要传入节点、图名称、输入张量信息、输出张量信息这 4 个参数。如下面的代码所示,我们把之前构造出来的 NodeProto 对象和 ValueInfoProto 对象按照顺序传入即可。

graph = helper.make_graph([mul, add], 'linear_func', [a, x, b], [output])

make_graph 的节点参数有一个要求:计算图的节点必须以拓扑序给出(如果按拓扑序遍历所有节点的话,能保证每个节点的输入都能在之前节点的输出里找到)。

拓扑排序:对一个有向无环图(Directed Acyclic Graph简称DAG)G进行拓扑排序,是将G中所有顶点排成一个线性序列,使得图中任意一对顶点u和v,若边∈E(G),则u在线性序列中出现在v之前。通常,这样的线性序列称为满足拓扑次序(Topological Order)的序列,简称拓扑序列。简单的说,由某个集合上的一个偏序得到该集合上的一个全序,这个操作称之为拓扑排序。

最后,我们用 helper.make_model 把计算图 GraphProto 封装进模型 ModelProto 里,一个 ONNX 模型就构造完成了。make_model 函数中还可以添加模型制作者、版本等信息,为了简单起见,我们没有添加额外的信息。

model = helper.make_model(graph)

构造完模型之后,用下面这三行代码来检查模型正确性、把模型以文本形式输出、存储到一个 “.onnx” 文件里。这里用 onnx.checker.check_model 来检查模型是否满足 ONNX 标准是必要的,因为无论模型是否满足标准,ONNX 都允许我们用 onnx.save 存储模型。

onnx.checker.check_model(model) print(model) onnx.save(model, 'linear_func.onnx')

完整代码如下:

import onnx from onnx import helper from onnx import TensorProto  # input and output a = helper.make_tensor_value_info('a', TensorProto.FLOAT, [10, 10]) x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [10, 10]) b = helper.make_tensor_value_info('b', TensorProto.FLOAT, [10, 10]) output = helper.make_tensor_value_info('output', TensorProto.FLOAT, [10, 10])  # Mul mul = helper.make_node('Mul', ['a', 'x'], ['c'])  # Add add = helper.make_node('Add', ['c', 'b'], ['output'])  # graph and model graph = helper.make_graph([mul, add], 'linear_func', [a, x, b], [output]) model = helper.make_model(graph)  # save model onnx.checker.check_model(model) print(model) onnx.save(model, 'linear_func.onnx')

可以用 ONNX Runtime 运行模型,来看看模型是否正确:

import onnxruntime import numpy as np  sess = onnxruntime.InferenceSession('linear_func.onnx') a = np.random.rand(10, 10).astype(np.float32) b = np.random.rand(10, 10).astype(np.float32) x = np.random.rand(10, 10).astype(np.float32)  output = sess.run(['output'], {'a': a, 'b': b, 'x': x})[0] 
# 比较两个array是不是每一元素都相等,默认在1e-05的误差范围内assert np.allclose(output, a * x + b)

一切顺利的话,这段代码不会有任何报错信息。这说明我们的模型等价于执行 a * x + b 这个计算。

netron 可视化查看:

读写 ONNX

import onnx model = onnx.load('linear_func.onnx') 
# 访问节点graph = model.graph node = graph.node input = graph.input output = graph.output

可以用jupyter-notebook调试,很方便

当我们想知道 ONNX 模型某数据对象有哪些属性时,只需要先把数据对象输出一下,然后在输出结果找出属性名即可。

import onnx model = onnx.load('linear_func.onnx')  node = model.graph.node node[1].op_type = 'Sub'  onnx.checker.check_model(model) onnx.save(model, 'linear_func_2.onnx')

调试 ONNX

在实际部署中,如果用深度学习框架导出的 ONNX 模型出了问题,一般要通过修改框架的代码来解决,而不会从 ONNX 入手,我们把 ONNX 模型当成一个不可修改的黑盒看待。

子模型提取

ONNX 官方为开发者提供了子模型提取(extract)的功能。子模型提取,顾名思义,就是从一个给定的 ONNX 模型中,拿出一个子模型。这个子模型的节点集、边集都是原模型中对应集合的子集。让我们来用 PyTorch 导出一个复杂一点的 ONNX 模型,并在它的基础上执行提取操作:

import torch import onnx
class Model(torch.nn.Module): def __init__(self): super().__init__() self.convs1 = torch.nn.Sequential(torch.nn.Conv2d(3, 3, 3), torch.nn.Conv2d(3, 3, 3), torch.nn.Conv2d(3, 3, 3)) self.convs2 = torch.nn.Sequential(torch.nn.Conv2d(3, 3, 3), torch.nn.Conv2d(3, 3, 3)) self.convs3 = torch.nn.Sequential(torch.nn.Conv2d(3, 3, 3), torch.nn.Conv2d(3, 3, 3)) self.convs4 = torch.nn.Sequential(torch.nn.Conv2d(3, 3, 3), torch.nn.Conv2d(3, 3, 3), torch.nn.Conv2d(3, 3, 3)) def forward(self, x): x = self.convs1(x) x1 = self.convs2(x) x2 = self.convs3(x) x = x1 + x2 x = self.convs4(x) return x model = Model() input = torch.randn(1, 3, 20, 20) # PyTorch 自动生成输入和输出的张量序号torch.onnx.export(model, input, 'whole_model.onnx')
onnx.utils.extract_model('whole_model.onnx', 'partial_model.onnx', ['22'], ['28'])

模型的可视化结果如下图所示:

子模型的可视化结果如下图所示:

onnx.utils.extract_model 就是完成子模型提取的函数,它的参数分别是原模型路径、输出模型路径、子模型的输入边(输入张量)、子模型的输出边(输出张量)。

添加额外输出

我们在提取时新设定了一个输出张量,如下面的代码所示:

onnx.utils.extract_model('whole_model.onnx', 'submodel_1.onnx', ['22'], ['27', '31'])

我们可以看到子模型会添加一条把张量输出的新边,如下图所示:

输入信息不足

尝试提取的子模型输入是边 24,输出是边 28。

# Error onnx.utils.extract_model('whole_model.onnx', 'submodel_3.onnx', ['24'], ['28'])

想通过边 24 计算边 28 的结果,至少还需要输入边 26,或者更上面的边。仅凭借边 24 是无法计算出边 28 的结果的,因此这样提取子模型会报错。

在使用 ONNX 模型时,可以在提取子模型时,添加了一条原来模型中不存在的输出边,用推理引擎输出中间节点的值。所以在框架模型和 ONNX 模型的精度对齐中,只要能够输出中间节点的值,就能定位到精度出现偏差的算子。

但子模型提取固然是一个便利的 ONNX 调试工具。但是,在实际的情况中,我们一般是用 PyTorch 等框架导出 ONNX 模型。这里有两个问题:

  • 一旦 PyTorch 模型改变,ONNX 模型的边序号也会改变。这样每次提取同样的子模块时都要重新去 ONNX 模型里查序号,如此繁琐的调试方法是不会在实践中采用的。
  • 即使我们能保证 ONNX 的边序号不发生改变,也难以把 PyTorch 代码和 ONNX 节点对应起来——当模型结构变得十分复杂时,要识别 ONNX 中每个节点的含义是不可能的。

到此这篇onnx模型部署修改(onnx 模型)的文章就介绍到这了,更多相关内容请继续浏览下面的相关推荐文章,希望大家都能在编程的领域有一番成就!

版权声明


相关文章:

  • ueditor官网版本(ueditor部署方法)2026-01-17 10:09:09
  • 操作系统课后(操作系统课后题答案汤小丹)2026-01-17 10:09:09
  • 操作系统教程第一版(操作系统教程课后答案)2026-01-17 10:09:09
  • redis-cli 端口(redis client 端口)2026-01-17 10:09:09
  • redis的端口号是多少(redis的默认端口号)2026-01-17 10:09:09
  • ceph存储部署(cephadm部署ceph集群)2026-01-17 10:09:09
  • 苹果软件后缀apk(iphone软件后缀)2026-01-17 10:09:09
  • 安装软件后缀(安装软件后缀-ce)2026-01-17 10:09:09
  • onnx模型部署到手机(onnx模型是什么)2026-01-17 10:09:09
  • orecal(orecal默认端口)2026-01-17 10:09:09
  • 全屏图片