开放神经网络交换(ONNX) 是一种开源格式,用于表示机器学习模型。它提供了一种标准方式来定义和交换不同框架之间的模型,使得在一种框架中训练的模型可以在另一种框架中使用。这种互操作性在今天的机器学习生态系统中至关重要,因为在不同的框架中(例如 TensorFlow、PyTorch、Scikit-learn)可能会创建模型。
ONNX的关键特性- 互操作性:ONNX 允许开发人员轻松在不同的框架之间切换,而无需重写他们的模型。
- 标准化:它为机器学习模型提供了一致的格式,减少了对不同模型格式的困惑。
- 优化运行时:ONNX 运行时是一个高性能的推理引擎,能够在包括 CPU、GPU 以及专用硬件在内的多种平台上高效执行 ONNX 模型。
- 广泛采用:许多流行的机器学习库和框架支持 ONNX,包括 PyTorch、TensorFlow 和 sklearn 等。
一个ONNX模型包括几个部分:
- 图:表示计算的核心结构,其中包括节点(操作)和边(数据流)。
- 节点:图中的每个节点代表一个操作(例如,卷积、激活),并且有输入和输出。
- 张量:数据以张量的形式表示,即多维数组,用于存储模型参数和输入。
- 元数据:元数据包含了关于模型的信息,包括输入/输出规范和训练配置。
本节我们将演示如何使用经典的鸢尾花数据集创建一个ONNX模型。鸢尾花数据集是一个非常经典的ML分类任务数据集,包含150个样本,每个样本有四个特征,分为三种不同的类别。
第1步:先准备好咱们的环境我强烈建议为你的项目使用虚拟环境,比如conda
(适用于Mac)或venv
(适用于Linux)。这取决于你使用的是什么操作系统。这里有几个原因,
- 独立性:虚拟环境让你为不同的项目独立管理依赖,防止版本冲突。
- 可复现性:通过使用虚拟环境,你可以轻松地复制你的设置,让其他人更容易参与你的项目,或者让你以后更容易回到项目中。
- 易用性:像
conda
和venv
这样的工具简化了包管理和环境设置过程,从而让你专注于开发而不是配置。
python -m venv myEnvKia
source myEnvKia/bin/activate # 这是在Linux上运行的命令,在Windows上请使用:myEnvKia\Scripts\activate
在 Mac 或 Windows 上使用 Conda
在运行命令之前,请确保你的电脑上已经安装了Anaconda。
使用conda创建一个名为myenv的环境,并指定Python版本为3.9:
conda create --name myenv python=3.9
激活该环境:
conda activate myenv
第一步:准备一下环境
首先,请确保您的环境中已安装Python及其所需的所有库。您可以使用pip进行安装。
pip install scikit-learn skl2onnx onnx onnxmltools onnxconverter-common # 如果你在 Google Colab 中运行,记得在命令前加上感叹号!
步骤 2:建立并保存鸢尾花分类模型:
以下是一个Python脚本,演示如何在Iris数据集上训练随机森林模型。它还包括评估模型的准确性并将模型保存成ONNX格式。
import numpy as np
import json
from sklearn import datasets
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score
from sklearn.preprocessing import LabelEncoder
import onnx
import onnxmltools
# 加载 Iris 数据集
iris = datasets.load_iris()
X = iris.data # 特征
y = iris.target # 标签
# 编码类别标签
label_encoder = LabelEncoder()
y_encoded = label_encoder.fit_transform(y)
# 划分数据集为训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y_encoded, test_size=0.2, random_state=42)
# 训练随机森林分类器
model = RandomForestClassifier(n_estimators=100)
model.fit(X_train, y_train)
# 评估模型
y_pred = model.predict(X_test)
accuracy = accuracy_score(y_test, y_pred)
print(f"模型准确率: {accuracy * 100:.2f}%")
# 定义初始类型
initial_type = [('input', FloatTensorType([None, X.shape[1]]))]
# 转换为 ONNX
onnx_model = skl2onnx.convert_sklearn(model, initial_types=initial_type)
# 保存模型
onnx.save_model(onnx_model, 'iris_classifier.onnx')
# 保存标签编码器映射
mapping = {int(index): int(label) for index, label in enumerate(label_encoder.classes_)}
# 将映射序列化为 JSON
with open('class_mapping.json', 'w') as f:
json.dump(mapping, f)
脚本解释
- 加载数据集:使用 Scikit-learn 的
datasets.load_iris()
函数加载 Iris 数据集。 - 标签编码:使用
LabelEncoder
将目标标签转换为数字形式。 - 模型训练:使用数据集训练一个随机森林分类器。
- 模型评估:在测试集上评估模型的准确度。
- 转换为 ONNX 格式:使用
onnxmltools
将模型转换为 ONNX 格式。 - 保存模型:将 ONNX 模型以及类索引到类名的映射保存到文件中。
- 创建 JSON 标签编码类映射
现在我们有了一个ONNX模型,可以将其集成到基于Java Spring Boot框架的应用程序中,以便提供一个RESTful API接口,接受输入数据并返回类别预测。
第 3.1 步:搭建 Spring Boot 项目- 创建一个新的 Spring Boot 项目:使用 Spring Initializr 创建一个新的项目,并添加 Spring Web 依赖。
- 添加 ONNX Runtime 依赖项:在你的
pom.xml
中添加以下依赖项:
<dependency>
<groupId>com.microsoft.onnxruntime</groupId>
<artifactId>onnxruntime</artifactId>
<version>1.15.0</version><!-- 请检查最新版本 -->
</dependency>
步骤 3.2:将 ONNX 模型加载到 Spring 中
- 创建服务类(Service 类):
创建一个名为 OnnxModelService
的服务类(Service)来处理模型推断。
package kia.demo.ml.onnx.service;
import ai.onnxruntime.OrtEnvironment;
import ai.onnxruntime.OrtException;
import ai.onnxruntime.OrtSession;
import ai.onnxruntime.*;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.core.io.Resource;
import org.springframework.core.io.ResourceLoader;
import org.springframework.stereotype.Service;
import com.fasterxml.jackson.databind.ObjectMapper;
import java.io.IOException;
import java.nio.file.Path;
import java.util.HashMap;
import java.util.Map;
import java.util.Optional;
@Service
public class OnnxModelService {
private OrtEnvironment env;
private OrtSession session;
private double accuracy;
private Map<Integer, String> classMapping = new HashMap<>();
@Autowired
private ResourceLoader resourceLoader;
public OnnxModelService() throws OrtException, IOException {
// 初始化ONNX Runtime环境
env = OrtEnvironment.getEnvironment();
// 从类路径加载模型
Resource resource = resourceLoader.getResource("classpath:iris_classifier.onnx");
Path modelPath = resource.getFile().toPath();
session = env.createSession(modelPath.toString());
// 从JSON文件中加载类别映射
Resource mappingResource = resourceLoader.getResource("classpath:class_mapping.json");
ObjectMapper objectMapper = new ObjectMapper();
classMapping = objectMapper.readValue(mappingResource.getFile(), HashMap.class);
// 为了演示,这里硬编码了准确率
this.accuracy = 1.0; // 如果需要,可以用实际计算的准确率替换
}
public float[] predict(float[][] inputData) throws OrtException {
OnnxTensor inputTensor = OnnxTensor.createTensor(env, inputData);
Map<String, OnnxTensor> inputs = new HashMap<>();
inputs.put("input", inputTensor);
// 模型的主要处理部分
OrtSession.Result outputs = session.run(inputs);
Optional<OnnxValue> optionalValue = outputs.get("output");
OnnxTensor predictionsTensor = (OnnxTensor) optionalValue.get();
float[] predictions = (float[]) predictionsTensor.getValue();
// 清理资源
inputTensor.close();
outputs.close();
return predictions;
}
}
第 3.3 步:创建预测控制模块
创建一个名为 IrisPredictionController
的 REST 控制器,用于处理 API 请求的控制器。
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.web.bind.annotation.*;
@RestController
@RequestMapping("/api/predict")
public class IrisPredictionController {
@Autowired
private OnnxModelService onnxModelService;
@PostMapping
public float[] predict(@RequestBody float[][] inputData) {
try {
return onnxModelService.predict(inputData);
} catch (OrtException e) {
throw new RuntimeException("推理过程失败", e);
}
}
}
第 3.4 步:启动您的 Spring Boot 应用
把 iris_classifier.onnx
和 class_mapping.json
文件放到你的 Spring Boot 项目的 src/main/resources
文件夹里。然后启动你的 Spring Boot 应用程序。
你可以用像 Postman 或 curl 这样的工具来测试预测接口。
示例请求:
[
[5.1, 3.5, 1.4, 0.2], // Iris 属 setosa 的示例数据
[6.7, 3.1, 4.4, 1.4] // Iris 属 versicolor 的示例数据
]
下面是一个发送POST请求的curl命令:
curl -X POST http://localhost:8080/api/predict \
-H "Content-Type: application/json" \
-d '[
[5.1, 3.5, 1.4, 0.2],
[6.7, 3.1, 4.4, 1.4]
]'
发送的数据是两组数字,每组代表一个样本的特征值:
关于命令的解释
-X POST
: 这表示我们正在做一个POST请求。http://localhost:8080/api/predict
: 这是你要测试的接口的URL。-H "Content-Type: application/json"
: 这设置了头部,表明我们发送的是JSON格式的数据。-d '[...]'
: 这里包含了要发送到服务器的JSON格式数据,其中特征以数组形式表示。
示例回答.
[1.0, 2.0] // 这些类别的标签包括 'Iris Setosa' 和 Iris Versicolor
现在就来展示一下 ONNX 可以做到的超乎你想象的事情,让你意想不到。
要展示ONNX的性能表现,你可以使用Apache Benchmark(ab)工具对你的Spring Boot应用进行压力测试,你可以按照以下步骤。Apache Benchmark是一个简单的工具,可以用来对你的web服务器进行基准测试,可以帮助你评估预测接口的性能。
步骤 1:安装 ab
如果你还没安装 Apache Benchmark,你可以用包管理器来搞定。下面是一些不同系统的安装方法。
Ubuntu/Debian 系统:运行以下命令来安装Apache2工具:
sudo apt-get install apache2-utils
对于 MacOS 用户(使用 Homebrew)
运行以下命令安装httpd:brew install httpd
我们需要准备好测试。
你可以使用以下命令对 /api/predict
端点进行压力测试。这个例子会发送一系列请求,每个请求都包含用于预测的示例数据。首先,准备一个包含你想要测试的输入数据的 JSON 文件,比如 input.json
。例如,可以创建一个名为 input.json
的文件,内容如下:
[
[5.1, 3.5, 1.4, 0.2],
[6.7, 3.1, 4.4, 1.4]
]
看这些数据,每个小括号代表一个数据点,包含四个数值。这是可能用于数据分析或机器学习模型输入的数值列表。
第三步:运行 Apache 压力测试工具命令
使用以下命令来运行压力测试。
ab -n 1000 -c 100 -p input.json -T application/json http://localhost:8080/api/predict
来解释一下这个命令
-n 1000
: 这指定了总共要执行的请求数(本例中为1000)。-c 100
: 这指定了每次同时执行的请求数量(100个并发请求)。-p input.json
: 这个选项指定了包含要通过POST请求发送的数据的文件。-T application/json
: 这将POST请求的内容类型设置为JSON格式。http://localhost:8080/api/predict
: 这是被测试的端点的URL。
当你运行那个命令时,你会看到类似以下的内容。
这是 ApacheBench,版本 2.3 ($Revision: 1879491$)
...
并发级别: 100
测试所用时间: 5.123 秒
完成请求数: 1000
失败请求次数: 0
总传输数据量: 123456 字节
HTML 传输数据量: 123456 字节
每秒平均请求数: 195.33 [#/sec]
每请求平均耗时: 512.3 [ms]
所有并发请求平均耗时: 5.123 [ms]
平均接收速率: 24.76 [KBytes/sec]
...
几个关键指标
- 每秒请求数:表示服务器每秒处理的请求数。
- 测试总耗时:完成所有请求所需的总时间。
- 失败请求数:失败的请求数量;理想状态下,这个数字最好是零。
- 每个请求平均处理时间:处理每个请求的平均时间。
运行完 Apache 压力测试(如 Apache Benchmark)后,我们得到了 /api/predict
端点的以下结果:
- 总请求数 : 1000
- 并发请求 : 100
- 用时 : 5.123 秒 // 这也太惊人了!
- 每秒请求数 : 195.33
- 每个请求平均耗时 : 512.3 毫秒 (平均)
在这篇文章中,我们探讨了ONNX格式及其在机器学习生态系统中的重要性。我们展示了如何使用Iris数据集来创建ONNX模型,并如何将该模型部署到Java Spring Boot应用程序中进行推理。ONNX的集成提供了灵活性和互操作性,这使得在不同环境中使用机器学习模型变得更简单。这种设置为未来构建更复杂的机器学习应用程序奠定了坚实的基础。
Git 仓库:https://github.com/KiaShamaei/onnxSpring
参考文献:https://onnxruntime.ai/docs/tutorials/mnist_java.html
[ONNX Runtime | 首页]跨平台加速机器学习。内置优化可加速训练和推理。https://onnxruntime.ai/?source=post_page-----67a2ffbe7bf1--------------------------------共同学习,写下你的评论
评论加载中...
作者其他优质文章