10 月 23 - 25 日,QCon 上海站即将召开,现在购票,享9折优惠 了解详情
写点什么

Java 机器学习工具箱:Amazon Deep Java Library

作者:Xinyu Liu, Frank Liu 等

  • 2020-06-17
  • 本文字数:4023 字

    阅读完需:约 13 分钟

Java机器学习工具箱:Amazon Deep Java Library

本文要点


  • 目前还没有用 Java 开发机器学习应用程序的标准

  • JSR 381 的提出就是为了填补这项空白

  • Amazon 的 Deep Java Library(DJL)是这个新标准的其中一种实现

  • VisRec 是 JSR 381 的一部分,用于图像的视觉识别

  • DJL 包含许多预训练的模型


近年来,人们对机器学习的兴趣稳步增长。具体来说,现在,企业在各种各样的场景中使用机器学习进行图像识别。它在汽车工业医疗保健安全零售仓库、农场和农业的自动化产品跟踪食品识别,甚至通过手机摄像头进行实时翻译等方面都有应用。借助机器学习和视觉识别,机器可以从 MRI 和 CT 扫描结果中发现癌症COVID-19


如今,这些解决方案主要是用 Python 开发的,使用了开源和专有的 ML 工具包,每个工具包都有自己的 API。尽管Java在企业中很流行,但是 Java 中没有任何标准是针对机器学习应用程序开发的。JSR-381的提出就是为了填补这项空白,它为 Java 应用程序开发人员提供了一套标准的、灵活的、Java 友好的、面向视觉识别(VisRec)应用程序(如图像分类和对象检测)的 API。JSR-381 有几个依赖于 TensorFlow、MXNet 和 DeepNetts 等机器学习平台的实现。其中一个实现是基于Deep Java Library(DJL)的,这是一个由 Amazon 开发的开源库,用于使用 Java 构建机器学习应用。DJL 通过绑定必要的图像处理例程,提供了流行机器学习框架(如TensorFlowMXNetPyTorch)的钩子,对于 JSR-381 的用户来说,这是一个灵活而简单的选项。


在本文中,我们将演示 Java 开发人员如何使用 JSR-381 VisRec API 在不到 10 行代码内利用 DJL 的预训练模型实现图像分类或对象检测。我们还通过两个例子演示了用户如何在 10 分钟内使用预先训练好的机器学习模型。让我们开始吧!

使用预训练的模型识别手写数字

识别手写数字是一个有用的应用,也是视觉识别的一个“hello world”示例。对人类来说,识别手写数字似乎很容易。得益于我们大脑中视觉和模式匹配子系统的处理能力和协作,我们通常可以从潦草的手写文件中正确地识别出数字。然而,由于可能存在许多变化,这个看似简单的任务对于机器来说是难以置信的复杂。这是机器学习,特别是视觉识别的一个很好的用例。JSR 381 库中有一个很好的示例,使用 JSR-381 VisRec API 正确地识别出了手写数字。这个示例将手写数字与MNIST手写数字数据集进行比较,后者是一个包含超过 6 万幅图像的公开数据库。预测图像所代表的内容称为图像分类。我们的示例查看一副新图像,并确定它具体是哪个数字的概率。


对于这项任务,VisRec API 提供了一个ImageClassifier接口,可以使用泛型参数具体化为输入图像的特定 Java 类。它还提供了一个classie()方法,该方法执行图像分类并返回所有可能的图像类别与概率的Map。根据 VisRec API 的约定,每个模型都提供一个静态的builder()方法,它返回一个对应的builder对象,并允许开发者配置所有相关的设置,例如imageHeightimageWidth


在我们的手写数字示例中,要定义一个图像分类器,就需要使用inputClass(BufferedImage.class) 配置输入句柄。你可以通过它指定使用哪个类来表示图像。你可以使用imageHeight(28)imageWidth(28) 将图像尺寸调整到 28x28,模型最初训练时就用的这个大小。


分类器对象构建完成后,将输入图像输入到分类器以识别图像。


File input = new File("../jsr381/src/test/resources/0.png");// 使用mlp文件夹里的预训练模型Path modelPath = Paths.get("../jsr381/src/test/resources/mlp");ImageClassifier<BufferedImage> classifier =       NeuralNetImageClassifier.builder()           // 输入时一个图像文件,应该作为BufferImage进行处理           .inputClass(BufferedImage.class)           // 图像尺寸应该调整到28 x 28           .imageHeight(28)           .imageWidth(28)           .importModel(modelPath)           .build();// 执行推断并获取分类结果Map<String, Float> result = classifier.classify(input);// 打印结果for (Map.Entry<String, Float> entry : result.entrySet()) {   System.out.println(entry.getKey() + ": " + entry.getValue());}
复制代码


执行上述代码会产生以下输出:


0: 0.99976332: 6.915607E-55: 2.7744078E-56: 6.1097984E-59: 3.8322916E-5
复制代码


对于图像中的数字,该模型识别出五种可能的选项,以及每个选项的概率。分类器以 99.98%的压倒性概率正确地预测了数字 0。


推而广之,如果需要检测出同一副图像中的多个不同的对象该怎么办?

使用预训练的单帧检测器(SSD)模型识别物体

单帧检测器(SSD)是一种利用一个深度神经网络从图像中检测物体的机制。本例使用预先训练好的 SSD 模型识别图像中的对象。对象检测是一项比较具有挑战性的视觉识别任务。除了对图像中的对象进行分类外,对象检测还可以识别图像中对象的位置。它还可以在关注对象周围绘制一个边框并添加一个类别(文本)标签。


SSD 机制是机器学习领域的一项最新进展,它检测对象的速度非常快,与此同时,还能保持与需要更大计算量的模型相媲美的准确性。要了解关于 SSD 模型的更多信息,可以阅读博文“理解SSD MultiBox——深度学习中的实时对象检测”以及《深入机器学习》这本书里的这个练习


使用 DJL 的 JSR-381 实现,用户可以访问预先训练好的、开箱即用的 SSD 模型实现。DJL 使用ModelZoo来简化模型部署。下面的代码块使用ModelZoo.loadModel()加载一个预先训练好的模型,实例化一个对象检测器类,并将这个模型应用到一副示例图像上。


// 定义一个满足用户需求的模型查找标准Criteria<BufferedImage, DetectedObjects> criteria =        Criteria.builder()                .setTypes(BufferedImage.class, DetectedObjects.class)                // 查找一个对象检测模型                .optApplication(Application.CV.OBJECT_DETECTION)                .build();// 加载模型,创建一个SimpleObjectDectector对象try (ZooModel<BufferedImage, DetectedObjects> model = ModelZoo.loadModel(criteria)) {   // SimpleObjectDetector是一个负责检测对象的高级JSR-381 API   SimpleObjectDetector objectDetector = new SimpleObjectDetector(model);   // 加载图像   BufferedImage input =       BufferedImageUtils.fromUrl(           "https://djl-ai.s3.amazonaws.com/resources/images/dog_bike_car.jpg");   // 检测对象   Map<String, List<BoundingBox>> result = objectDetector.detectObject(input);   for (List<BoundingBox> boundingBoxes : result.values()) {       for (BoundingBox boundingBox : boundingBoxes) {           System.out.println(boundingBox.toString());       }   }}
复制代码


下面是一副可供我们使用的新图像。



在这幅图像上运行代码将产生如下结果:


BoundingBox{id=0, x=124.0, y=119.0, width=456.45093, height=338.8393, label=bicycle, score=0.9538524}BoundingBox{id=0, x=469.0, y=78.0, width=225.19464, height=92.147675, label=car, score=0.99991035}BoundingBox{id=0, x=128.0, y=201.0, width=210.51933, height=341.7647, label=dog, score=0.9375212}
复制代码


如果你希望给从图像上检测到的每个对象添加边框,只需几行代码即可。要了解更多信息,请参见完整的GitHub示例。该模型对三个关注对象(自行车、汽车和狗)进行分类,在每个对象周围画一个边框,并提供一个由概率反映的置信度。



值得注意的是,预训练模型的检测精度取决于用于训练模型的图像。模型的精度可以通过再训练来提高,也可以使用一组更能代表最终应用程序的图像开发一个自定义的模型。然而,这种方法非常耗时,并且需要使用大量的训练数据。对于许多 ML 应用程序,使用预先训练好的模型建立基线通常是值得的。这可以节省大量收集、准备数据和从头训练模型的时间。

未来展望

在这篇文章中,我们仅仅了解了使用 JSR-381 API 的 DJL 实现可以做些什么。你可以使用 ModelZoo 中预先训练好的模型库探索和实现更多的模型,或者引入自己的模型。


感兴趣的读者可以检出DJL,这是一个由 Amazon 的 Java 开发人员为 Java 社区构建的开源库。我们试图简化 Java 中机器学习的开发和部署。欢迎加入我们!


DJL 有很多用例,你可以开发一个客服问答应用程序,实现瑜伽姿势的姿态估计,或者训练你自己的模型来检测后院的入侵者。我们的Spring Boot入门套件还简化了 ML 与 Spring Boot 应用程序的集成。读者可以通过我们的介绍性博客网站示例库了解更多关于 DJL 的信息。请访问我们的Github库,在我们的 Slack频道与我们合作。



参考资料




作者简介:


Frank Liu AWS AI软件工程师。他专注于为软件工程师和科学家打造创新型深度学习工具。在业余时间,他喜欢与朋友和家人一起徒步旅行。


Xinyu Liu AWS AI软件开发经理。他热衷于机器学习和大规模分布式系统。


Frank Greco 是 Crossroads Technologies 公司的创始人和首席执行官。他是高级技术顾问和企业架构师,致力于为开发人员提供云计算和 AI/ML 工具。他是一名 Java 冠军程序员、NYJavaSIG的主席,并在欧洲举办了企业机器学习国际会议。他业余时间喜欢弹吉他。


Zoran Sevarac Deep Netts的 CEO。他致力于为 Java 开发人员构建用户友好的深度学习工具,并创建 AI Java 标准。他是贝尔格莱德大学的教授和 Java 冠军程序员。他业余时间喜欢弹吉他。


Balaji Kamakoti AWS AI高级产品经理。他致力于让开发者更容易使用深度学习产品。在业余时间,他喜欢打网球,弹萨罗德琴(一种无琴格的弦乐器)。


特别感谢JCPJSR-381团队的宝贵贡献:Kevin Berendsen、Sandhya Kapoor、Werner Keil、Constantin Drabo、Ankara Parida、Melissa Mckay、Buddha Jyoti Prasad、Shreya Gupta、Amit Nagesh、Heather VanCura 和 Harold Ogle。


原文链接:


Machine Learning in Java With Amazon Deep Java Library


2020-06-17 10:392625

评论

发布
暂无评论
发现更多内容

一文了解 MySQL 中的锁

Ayue、

MySQL 数据库 1月月更

Awesome DAO 文章和资源推荐(8/100)

hackstoic

DAO

用 K3s 来运行安装和极狐GitLab Runner

极狐GitLab

一个cpp协程库的前世今生(二十四)对象池与栈内存池

SkyFire

c++ cocpp

低代码实现探索(三十)低代码设计器设计方式

零道云-混合式低代码平台

微信业务架构图及学生管理系统架构设计实践

IT屠狗辈

系统架构 架构实战营 微信业务架构图

微信业务架构+学生管理系统毕设方案

李大虾

#架构实战营 「架构实战营」

一条SQL查询语句是如何执行的?

蝉沐风

MySQL sql 面试

SAST 为什么会成为网络安全领域的下一件大事?

旋极智能

静态分析 静态测试工具 代码静态分析

TortoiseSVN 执行清理( cleanUp )失败的解决方案

编程三昧

svn 开发工具 1月月更

微信业务架构图 & 学生管理系统架构设计

smile

架构实战营

征集用户| 填写 2022 Apache Pulsar 用户调查问卷,抽取丰厚礼品

Apache Pulsar

开源 云原生 中间件 Apache Pulsar 社区

面试官太难伺候?一个try-catch问出这么多花样

阿Q说代码

效率 字节码指令 1月月更 try-catch finally-return

行业先锋畅聊 Flink 未来 —— FFA 2021 圆桌会议(北京)

Apache Flink

大数据 flink 编程 后端 实时计算

Hive 数据倾斜问题定位排查及解决

五分钟学大数据

hive 1月月更

作业帮基于 Flink 的实时计算平台实践

Apache Flink

大数据 flink 编程 实时计算 IT

Hive企业级性能优化

五分钟学大数据

hive 1月月更

深入理解Python内存管理与垃圾回收

宇宙之一粟

Python 内存管理 1月月更

ReactNative进阶(三十五):应用脚手架 Yo 构建 RN 页面

No Silver Bullet

React Native 1月月更 BloC yo

为什么需要闭包?闭包是什么概念?

蜜糖的代码注释

Java 后端 开发

Mybatis中的VFS是个啥

尹昶胜

mybatis

如何使用JavaScript开发AR(增强现实)移动应用

汪子熙

JavaScript AR 1月月更 增强现实

微信业务架构、学生管理系统(草稿)

Geek_16d2b8

架构训练营

Web or Native 哪个才是元宇宙的未来(下)?

Orillusion

WebGL 元宇宙 Metaverse webgpu

混沌工程之ChaosMesh使用之模拟CPU使用率

zuozewei

混沌工程 Chaos Mesh 1月月更

模块一

Geek_f3e842

架构实战营

征文投稿丨在轻量应用服务器上部署SpringBoot项目

阿里云弹性计算

阿里云 用户投稿 轻量应用

代码之外的生存指南,先掌握这五步。

叶小鍵

华为云FusionInsight连续三次获得第一 加速释放数据要素价值

数据湖洞见

大数据 FusionInsight 华为云

面向复杂度架构设计之学生管理系统

晨亮

「架构实战营」

Fabric.js 将本地图像上传到画布背景

德育处主任

前端 数据可视化 前端可视化 FabricJS Fabric.js

Java机器学习工具箱:Amazon Deep Java Library_AI&大模型_InfoQ精选文章