为了账号安全,请及时绑定邮箱和手机立即绑定

利用ONNX实现互操作性机器学习:Java Spring Boot实战指南

ONNX是什么?

开放神经网络交换(ONNX) 是一种开源格式,用于表示机器学习模型。它提供了一种标准方式来定义和交换不同框架之间的模型,使得在一种框架中训练的模型可以在另一种框架中使用。这种互操作性在今天的机器学习生态系统中至关重要,因为在不同的框架中(例如 TensorFlow、PyTorch、Scikit-learn)可能会创建模型。

ONNX的关键特性
  1. 互操作性:ONNX 允许开发人员轻松在不同的框架之间切换,而无需重写他们的模型。
  2. 标准化:它为机器学习模型提供了一致的格式,减少了对不同模型格式的困惑。
  3. 优化运行时ONNX 运行时是一个高性能的推理引擎,能够在包括 CPU、GPU 以及专用硬件在内的多种平台上高效执行 ONNX 模型。
  4. 广泛采用:许多流行的机器学习库和框架支持 ONNX,包括 PyTorch、TensorFlow 和 sklearn 等。
ONNX 模型的组成

一个ONNX模型包括几个部分:

  • :表示计算的核心结构,其中包括节点(操作)和边(数据流)。
  • 节点:图中的每个节点代表一个操作(例如,卷积、激活),并且有输入和输出。
  • 张量:数据以张量的形式表示,即多维数组,用于存储模型参数和输入。
  • 元数据:元数据包含了关于模型的信息,包括输入/输出规范和训练配置。
创建ONNX模型

本节我们将演示如何使用经典的鸢尾花数据集创建一个ONNX模型。鸢尾花数据集是一个非常经典的ML分类任务数据集,包含150个样本,每个样本有四个特征,分为三种不同的类别。

第1步:先准备好咱们的环境

我强烈建议为你的项目使用虚拟环境,比如conda(适用于Mac)或venv(适用于Linux)。这取决于你使用的是什么操作系统。这里有几个原因,

  • 独立性:虚拟环境让你为不同的项目独立管理依赖,防止版本冲突。
  • 可复现性:通过使用虚拟环境,你可以轻松地复制你的设置,让其他人更容易参与你的项目,或者让你以后更容易回到项目中。
  • 易用性:像 condavenv 这样的工具简化了包管理和环境设置过程,从而让你专注于开发而不是配置。
如何使用 Env for Linux:
python -m venv myEnvKia  
source myEnvKia/bin/activate  # 这是在Linux上运行的命令,在Windows上请使用:myEnvKia\Scripts\activate
在 Mac 或 Windows 上使用 Conda

在运行命令之前,请确保你的电脑上已经安装了Anaconda。

使用conda创建一个名为myenv的环境,并指定Python版本为3.9:

conda create --name myenv python=3.9  

激活该环境:

conda activate myenv
第一步:准备一下环境

首先,请确保您的环境中已安装Python及其所需的所有库。您可以使用pip进行安装。

    pip install scikit-learn skl2onnx onnx onnxmltools onnxconverter-common # 如果你在 Google Colab 中运行,记得在命令前加上感叹号!
步骤 2:建立并保存鸢尾花分类模型:

以下是一个Python脚本,演示如何在Iris数据集上训练随机森林模型。它还包括评估模型的准确性并将模型保存成ONNX格式。

    import numpy as np  
    import json  
    from sklearn import datasets  
    from sklearn.model_selection import train_test_split  
    from sklearn.ensemble import RandomForestClassifier  
    from sklearn.metrics import accuracy_score  
    from sklearn.preprocessing import LabelEncoder  
    import onnx  
    import onnxmltools  

    # 加载 Iris 数据集  
    iris = datasets.load_iris()  
    X = iris.data  # 特征  
    y = iris.target  # 标签  

    # 编码类别标签  
    label_encoder = LabelEncoder()  
    y_encoded = label_encoder.fit_transform(y)  

    # 划分数据集为训练集和测试集  
    X_train, X_test, y_train, y_test = train_test_split(X, y_encoded, test_size=0.2, random_state=42)  

    # 训练随机森林分类器  
    model = RandomForestClassifier(n_estimators=100)  
    model.fit(X_train, y_train)  

    # 评估模型  
    y_pred = model.predict(X_test)  
    accuracy = accuracy_score(y_test, y_pred)  
    print(f"模型准确率: {accuracy * 100:.2f}%")  

    # 定义初始类型  
    initial_type = [('input', FloatTensorType([None, X.shape[1]]))]  

    # 转换为 ONNX  
    onnx_model = skl2onnx.convert_sklearn(model, initial_types=initial_type)  

    # 保存模型  
    onnx.save_model(onnx_model, 'iris_classifier.onnx')  

    # 保存标签编码器映射  
    mapping = {int(index): int(label) for index, label in enumerate(label_encoder.classes_)}    
    # 将映射序列化为 JSON  
    with open('class_mapping.json', 'w') as f:  
        json.dump(mapping, f)
脚本解释
  1. 加载数据集:使用 Scikit-learn 的 datasets.load_iris() 函数加载 Iris 数据集。
  2. 标签编码:使用 LabelEncoder 将目标标签转换为数字形式。
  3. 模型训练:使用数据集训练一个随机森林分类器。
  4. 模型评估:在测试集上评估模型的准确度。
  5. 转换为 ONNX 格式:使用 onnxmltools 将模型转换为 ONNX 格式。
  6. 保存模型:将 ONNX 模型以及类索引到类名的映射保存到文件中。
  7. 创建 JSON 标签编码类映射
步骤 3:在 Java Spring Boot 应用中使用 ONNX 模型(Model)

现在我们有了一个ONNX模型,可以将其集成到基于Java Spring Boot框架的应用程序中,以便提供一个RESTful API接口,接受输入数据并返回类别预测。

第 3.1 步:搭建 Spring Boot 项目
  1. 创建一个新的 Spring Boot 项目:使用 Spring Initializr 创建一个新的项目,并添加 Spring Web 依赖。
  2. 添加 ONNX Runtime 依赖项:在你的 pom.xml 中添加以下依赖项:

      <dependency>  
       <groupId>com.microsoft.onnxruntime</groupId>  
       <artifactId>onnxruntime</artifactId>  
       <version>1.15.0</version><!-- 请检查最新版本 -->
      </dependency>
步骤 3.2:将 ONNX 模型加载到 Spring 中
  1. 创建服务类(Service 类)

创建一个名为 OnnxModelService 的服务类(Service)来处理模型推断。

package kia.demo.ml.onnx.service;  

import ai.onnxruntime.OrtEnvironment;  
import ai.onnxruntime.OrtException;  
import ai.onnxruntime.OrtSession;  
import ai.onnxruntime.*;  
import org.springframework.beans.factory.annotation.Autowired;  
import org.springframework.core.io.Resource;  
import org.springframework.core.io.ResourceLoader;  
import org.springframework.stereotype.Service;  
import com.fasterxml.jackson.databind.ObjectMapper;  

import java.io.IOException;  
import java.nio.file.Path;  
import java.util.HashMap;  
import java.util.Map;  
import java.util.Optional;  

@Service  
public class OnnxModelService {  
    private OrtEnvironment env;  
    private OrtSession session;  
    private double accuracy;  
    private Map<Integer, String> classMapping = new HashMap<>();  

    @Autowired  
    private ResourceLoader resourceLoader;  

    public OnnxModelService() throws OrtException, IOException {  
        // 初始化ONNX Runtime环境  
        env = OrtEnvironment.getEnvironment();  
        // 从类路径加载模型  
        Resource resource = resourceLoader.getResource("classpath:iris_classifier.onnx");  
        Path modelPath = resource.getFile().toPath();  
        session = env.createSession(modelPath.toString());  

        // 从JSON文件中加载类别映射  
        Resource mappingResource = resourceLoader.getResource("classpath:class_mapping.json");  
        ObjectMapper objectMapper = new ObjectMapper();  
        classMapping = objectMapper.readValue(mappingResource.getFile(), HashMap.class);  

        // 为了演示,这里硬编码了准确率  
        this.accuracy = 1.0; // 如果需要,可以用实际计算的准确率替换  

    }  

    public float[] predict(float[][] inputData) throws OrtException {  
        OnnxTensor inputTensor = OnnxTensor.createTensor(env, inputData);  
        Map<String, OnnxTensor> inputs = new HashMap<>();  
        inputs.put("input", inputTensor);  

        // 模型的主要处理部分  
        OrtSession.Result outputs = session.run(inputs);  

        Optional<OnnxValue> optionalValue = outputs.get("output");  
        OnnxTensor predictionsTensor = (OnnxTensor) optionalValue.get();  
        float[] predictions = (float[]) predictionsTensor.getValue();  

        // 清理资源  
        inputTensor.close();  
        outputs.close();  
        return predictions;  
    }  
}

第 3.3 步:创建预测控制模块

创建一个名为 IrisPredictionController 的 REST 控制器,用于处理 API 请求的控制器。

    import org.springframework.beans.factory.annotation.Autowired;  
    import org.springframework.web.bind.annotation.*;  

    @RestController  
    @RequestMapping("/api/predict")  
    public class IrisPredictionController {  

        @Autowired  
        private OnnxModelService onnxModelService;  

        @PostMapping  
        public float[] predict(@RequestBody float[][] inputData) {  
            try {  
                return onnxModelService.predict(inputData);  
            } catch (OrtException e) {  
                throw new RuntimeException("推理过程失败", e);  
            }  
        }  
    }
第 3.4 步:启动您的 Spring Boot 应用

iris_classifier.onnxclass_mapping.json 文件放到你的 Spring Boot 项目的 src/main/resources 文件夹里。然后启动你的 Spring Boot 应用程序。

步骤 3.5:测试 API 功能

你可以用像 Postman 或 curl 这样的工具来测试预测接口。

示例请求:

    [  
        [5.1, 3.5, 1.4, 0.2],  // Iris 属 setosa 的示例数据  
        [6.7, 3.1, 4.4, 1.4]   // Iris 属 versicolor 的示例数据  
    ]
下面是一个发送POST请求的curl命令:
curl -X POST http://localhost:8080/api/predict \  
-H "Content-Type: application/json" \  
-d '[  
    [5.1, 3.5, 1.4, 0.2],  
    [6.7, 3.1, 4.4, 1.4]  
]'
发送的数据是两组数字,每组代表一个样本的特征值:

关于命令的解释

  • -X POST: 这表示我们正在做一个POST请求。
  • http://localhost:8080/api/predict: 这是你要测试的接口的URL。
  • -H "Content-Type: application/json": 这设置了头部,表明我们发送的是JSON格式的数据。
  • -d '[...]': 这里包含了要发送到服务器的JSON格式数据,其中特征以数组形式表示。

示例回答.

[1.0, 2.0] // 这些类别的标签包括 'Iris Setosa' 和 Iris Versicolor

现在就来展示一下 ONNX 可以做到的超乎你想象的事情,让你意想不到。

要展示ONNX的性能表现,你可以使用Apache Benchmark(ab)工具对你的Spring Boot应用进行压力测试,你可以按照以下步骤。Apache Benchmark是一个简单的工具,可以用来对你的web服务器进行基准测试,可以帮助你评估预测接口的性能。

步骤 1:安装 ab

如果你还没安装 Apache Benchmark,你可以用包管理器来搞定。下面是一些不同系统的安装方法。

Ubuntu/Debian 系统:
运行以下命令来安装Apache2工具:
sudo apt-get install apache2-utils
对于 MacOS 用户(使用 Homebrew)

运行以下命令安装httpd:brew install httpd

我们需要准备好测试。

你可以使用以下命令对 /api/predict 端点进行压力测试。这个例子会发送一系列请求,每个请求都包含用于预测的示例数据。首先,准备一个包含你想要测试的输入数据的 JSON 文件,比如 input.json。例如,可以创建一个名为 input.json 的文件,内容如下:

    [  
        [5.1, 3.5, 1.4, 0.2],  
        [6.7, 3.1, 4.4, 1.4]  
    ]
看这些数据,每个小括号代表一个数据点,包含四个数值。这是可能用于数据分析或机器学习模型输入的数值列表。
第三步:运行 Apache 压力测试工具命令

使用以下命令来运行压力测试。

ab -n 1000 -c 100 -p input.json -T application/json http://localhost:8080/api/predict
来解释一下这个命令
  • -n 1000: 这指定了总共要执行的请求数(本例中为1000)。
  • -c 100: 这指定了每次同时执行的请求数量(100个并发请求)。
  • -p input.json: 这个选项指定了包含要通过POST请求发送的数据的文件。
  • -T application/json: 这将POST请求的内容类型设置为JSON格式。
  • http://localhost:8080/api/predict: 这是被测试的端点的URL。
现在我们来分析一下结果

当你运行那个命令时,你会看到类似以下的内容。

    这是 ApacheBench,版本 2.3 ($Revision: 1879491$)  
    ...  
    并发级别:      100  
    测试所用时间:   5.123 秒  
    完成请求数:      1000  
    失败请求次数:    0  
    总传输数据量:    123456 字节  
    HTML 传输数据量:  123456 字节  
    每秒平均请求数:    195.33 [#/sec]  
    每请求平均耗时:    512.3 [ms]  
    所有并发请求平均耗时:  5.123 [ms]  
    平均接收速率:    24.76 [KBytes/sec]  
    ...
几个关键指标
  • 每秒请求数:表示服务器每秒处理的请求数。
  • 测试总耗时:完成所有请求所需的总时间。
  • 失败请求数:失败的请求数量;理想状态下,这个数字最好是零。
  • 每个请求平均处理时间:处理每个请求的平均时间。
性能测试的结果

运行完 Apache 压力测试(如 Apache Benchmark)后,我们得到了 /api/predict 端点的以下结果:

  • 总请求数 : 1000
  • 并发请求 : 100
  • 用时 : 5.123 秒 // 这也太惊人了!
  • 每秒请求数 : 195.33
  • 每个请求平均耗时 : 512.3 毫秒 (平均)
结论

在这篇文章中,我们探讨了ONNX格式及其在机器学习生态系统中的重要性。我们展示了如何使用Iris数据集来创建ONNX模型,并如何将该模型部署到Java Spring Boot应用程序中进行推理。ONNX的集成提供了灵活性和互操作性,这使得在不同环境中使用机器学习模型变得更简单。这种设置为未来构建更复杂的机器学习应用程序奠定了坚实的基础。

Git 仓库:https://github.com/KiaShamaei/onnxSpring

参考文献:

https://onnxruntime.ai/docs/tutorials/mnist_java.html

[ONNX Runtime | 首页]跨平台加速机器学习。内置优化可加速训练和推理。https://onnxruntime.ai/?source=post_page-----67a2ffbe7bf1--------------------------------
点击查看更多内容
TA 点赞

若觉得本文不错,就分享一下吧!

评论

作者其他优质文章

正在加载中
  • 推荐
  • 评论
  • 收藏
  • 共同学习,写下你的评论
感谢您的支持,我会继续努力的~
扫码打赏,你说多少就多少
赞赏金额会直接到老师账户
支付方式
打开微信扫一扫,即可进行扫码打赏哦
今天注册有机会得

100积分直接送

付费专栏免费学

大额优惠券免费领

立即参与 放弃机会
意见反馈 帮助中心 APP下载
官方微信

举报

0/150
提交
取消