NVIDIA 初创加速计划,免费加速您的创业启动 了解详情
写点什么

Deep Java Library (DJL) 简介:与引擎无关的 Java 深度学习框架

  • 2020-01-13
  • 本文字数:8821 字

    阅读完需:约 29 分钟

Deep Java Library (DJL) 简介:与引擎无关的Java深度学习框架

本文要点

  • 开发人员可以使用 Java 和他们喜欢的 IDE 来构建、训练和部署机器学习(ML)和深度学习(DL)模型

  • DJL 简化了深度学习(DL)框架的使用,目前支持 Apache MXNet

  • DJL 的开源对于工具包及其用户来说都是互惠互利的

  • DJL 是引擎无关的,这意味着开发人员只需编写一次代码就可以在任何引擎上运行

  • 在尝试使用 DJL 之前,Java 开发人员应该了解 ML 生命周期和常用的 ML 术语


亚马逊(Amazon)的 DJL(Deep Java Library )是一个深度学习工具包,使用它可在 Java 中原生地进行机器学习(ML)和深度学习(DL)模型开发,从而简化深度学习框架的使用。DJL 是在 2019 年 re:Invent 大会上开源的工具包,它提供了一组高级 API 来训练、测试和运行在线推理(inference)。Java 开发人员可以开发自己的模型,也可以在他们的 Java 代码中使用数据科学家用 Python 开发的预先训练的模型。


DJL 秉承了 Java 的座右铭,“编写一次,到处运行(WORA)”,因为它是引擎和深度学习框架无关的。开发人员只需编写一次就可在任何引擎上运行。DJL 目前提供了一个 Apache MXNet 的实现,这是一个可以简化深度神经网络开发的 ML 引擎。DJL API 使用 JNA(Java Native Access)来调用相应的 Apache MXNet 操作。DJL 编排管理基础设施,基于硬件配置来提供自动的 CPU/GPU 检测,以确保良好的运行效果。


DJL API 通过抽象常用的功能来开发模型,这使 Java 开发人员能够利用现有的知识,从而可以轻松地过渡到 ML。为了了解 DJL 的实际效果,我们开发一个“鞋”的分类模型作为一个简单的示例。

机器学习生命周期

我们建立“鞋”分类模型遵循了机器学习的生命周期。ML 生命周期与传统的软件开发生命周期有所不同,它包含六个具体的步骤:


  1. 获取数据

  2. 清洗并准备数据

  3. 生成模型

  4. 评估模型

  5. 部署模型

  6. 从模型中获得预测(或推理)


生命周期的最终结果是一个可以查询并返回答案(或预测)的机器学习模型。



模型只是数据中趋势和模式的数学表示。好的数据才是所有 ML 项目的基础。


在步骤 1 中,从可靠的来源中获取数据。在步骤 2 中,数据被清洗、转换并以机器可以学习的格式存储。清洗和转换过程通常是机器学习生命周期中最耗时的部分。DJL 提供了利用翻译器(translator)来对图像进行预处理的能力,这能为开发人员简化清洗和转换过程。翻译器可以执行一些图像任务,比如,可以根据预设参数调整图像的大小或将图像从彩色图转换为灰度图。


刚刚过渡向机器学习的开发人员常常会低估清洗和转换数据所需的时间,因此翻译器是快速启动该过程的好方法。步骤 3,在训练过程中,一个机器学习算法会对数据进行多遍(或多代)处理,不断研究它们,以试图学习到不同类型的“鞋”。训练过程中发现的与“鞋”相关的趋势和模式会被存储在模型中。当需要评估模型以确定其在识别“鞋”方面的能力时,第 4 步会作为训练的一部分;如果发现了错误,则予以纠正。在步骤 5 中,将模型部署到生产环境中。模型投入生产后,步骤 6 允许其他系统使用该模型。


通常,可以在代码中动态地加载模型,或者通过基于 REST 的 HTTPS 端点访问模型。

数据

“鞋”分类模型是一个多级分类计算机视觉(CV)模型,它使用有监督学习进行训练,可以将“鞋”分为四类:靴子(boots)、凉鞋(sandals)、鞋子(shoes)或拖鞋(slippers)。有监督学习必须包含已经标记了我们想要预测的目标(或答案)的数据;这就是机器学习的方式。


“鞋”分类模型的数据源是德克萨斯大学奥斯汀分校(The University of Texas at Austin)提供的 UTZappos50k 数据集(dataset),它可免费用于学术和非商业用途。下面这个“鞋子”数据集包含了从 Zappos.com 收集的 50025 张带标签的目录图像。



“鞋”数据保存在本地,并使用 DJL 的 ImageFolder 数据集对其进行加载,该数据集可以从本地文件夹中检索图像。


// 识别训练数据的位置String trainingDatasetRoot = "src/test/resources/imagefolder/train";
// 识别验证数据的位置String validateDatasetRoot = "src/test/resources/imagefolder/validate";
// 创建训练数据 ImageFolder 数据集ImageFolder trainingDataset = initDataset(trainingDatasetRoot);
//创建验证数据 ImageFolder 数据集ImageFolder validateDataset = initDataset(validateDatasetRoot);
复制代码


在本地构造数据时,我并没有深入到 UTZappos50k 数据集所标识的最细粒度的分类等级,比如到脚踝的、膝盖等高的、到达小腿中部的、过膝的等靴子的最细粒度等级的分类标签。我的本地数据使用的是最高等级的分类,仅包括靴子、凉鞋、鞋子和拖鞋等四类。



在 DJL 术语中,数据集只用于保存训练数据。有些数据集的实现可用于下载数据(基于我们提供的 URL)、提取数据、以及自动地将数据分为训练集和验证集。


自动分离是一个特别有用的特性,因为不使用相同的数据来训练和验证模型这一点是至关重要的。该模型所使用的训练数据集用于查找“鞋”数据中的趋势和模式。验证数据集通过提供对“鞋”分类模型精度无偏差的估计来检验模型的效果。


如果用训练的数据验证模型,则会降低我们对模型分类鞋子能力的信心,因为模型是用它已经看到的数据进行测试的。在现实世界中,老师也不会使用和学习指南上完全相同的题目来测试学生,因为这不能衡量一个学生的真实知识或对资料的理解;当然,同样的概念也适用于机器学习模型。

训练

现在我们已经将“鞋”数据分为训练集和验证集,下面我们将使用神经网络来训练(或生成)模型。


public final class Training extends AbstractTraining {
. . .
@Override protected void train(Arguments arguments) throws IOException {
// 识别训练数据的位置 String trainingDatasetRoot = "src/test/resources/imagefolder/train";
// 识别验证数据的位置 String validateDatasetRoot = "src/test/resources/imagefolder/validate";
//创建训练数据 ImageFolder 数据集 ImageFolder trainingDataset = initDataset(trainingDatasetRoot);
//创建验证数据 ImageFolder 数据集 ImageFolder validateDataset = initDataset(validateDatasetRoot);
. . . try (Model model = Models.getModel(NUM_OF_OUTPUT, NEW_HEIGHT, NEW_WIDTH)) { TrainingConfig config = setupTrainingConfig(loss);
try (Trainer trainer = model.newTrainer(config)) { trainer.setMetrics(metrics);
trainer.setTrainingListener(this);
Shape inputShape = new Shape(1, 3, NEW_HEIGHT, NEW_WIDTH);
// 根据相应输入的形状初始化训练器 trainer.initialize(inputShape);
//在数据中查找模式 fit(trainer, trainingDataset, validateDataset, "build/logs/training");
//设置模型属性 model.setProperty("Epoch", String.valueOf(EPOCHS)); model.setProperty("Accuracy", String.format("%.2f", getValidationAccuracy()));
// 训练完成后保存模型,为后面的推理做准备 //模型保存为 shoeclassifier-0000.params model.save(Paths.get(modelParamsPath), modelParamsName); } } }
}
复制代码


第一步是通过调用 Models.getModel(NUM_OF_OUTPUT, NEW_HEIGHT, NEW_WIDTH) 来获取模型实例。深度学习是机器学习的一种形式,它使用神经网络来训练模型。神经网络是以人脑中的神经元来进行建模的;神经元是可以将信息(或数据)传递给其他细胞的细胞。


ResNet-50 是一种常用于图像分类的神经网络,50 表示从初始输入数据和最终预测之间有 50 个学习层(或神经元)。getModel() 方法用于创建一个空模型,构造一个 ResNet-50 神经网络,并将神经网络设置到该模型中。


public class Models {   public static ai.djl.Model getModel(int numOfOutput, int height, int width) {       //创建一个空模型的新实例       ai.djl.Model model = ai.djl.Model.newInstance();
//是构建神经网络所需的可组合单元;可以像像乐高积木一样将它们连结在一起, //形成一个复杂的网络 Block resNet50 = //构建网络 new ResNetV1.Builder() .setImageShape(new Shape(3, height, width)) .setNumLayers(50) .setOutSize(numOfOutput) .build();
//将神经网络设置到模型中 model.setBlock(resNet50); return model; }}
复制代码


下一步是通过调用 model.newTrainer(config) 方法来设置和配置训练器。通过调用 setupTrainingConfig(loss) 方法来初始化配置对象,该方法通过设置训练的配置(或超参)来决定如何训练网络。


接下来的步骤使我们可以通过设置以下内容来向 Trainer 中添加功能:


  • 使用 trainer.setMetrics(metrics) 来设置 Metrics

  • 使用 trainer.setTrainingListener(this) 来设置训练监听器

  • 使用 trainer.initialize(inputShape) 来设置合适的输入形状


Metrics 在训练期间收集并报告关键绩效指标(KPI),该 KPI 可用于分析和监控训练的效果和稳定性。下一步是通过调用 fit(trainer, trainingDataset, validateDataset, “build/logs/training”) 方法来启动训练过程,该方法将迭代训练数据并存储在模型中找到的模式。训练结束时,使用 model.save(Paths.get(modelParamsPath) 方法将一个表现良好的、经过验证的模型工件及属性保存在本地。


训练过程中报告的度量指标如下所示。注意,随着每代(epoch)(或每遍(pass))的递增,模型的精度都会提高;第 9 代(epoch)的最终训练精度为 90%。


推理

现在我们已经生成了模型,它可以用于对我们不知道类型(或目标)的新数据执行推理(或预测)。


private Classifications predict() throws IOException, ModelException, TranslateException  {   //在训练期间保存到模型的位置   String modelParamsPath = "build/logs";
//训练时设置的模型名称 String modelParamsName = "shoeclassifier";
//需要分类的图像路径 String imageFilePath = "src/test/resources/slippers.jpg";
//从路径加载图像文件 BufferedImage img = BufferedImageUtils.fromFile(Paths.get(imageFilePath));
//持有每个标签的概率分数 Classifications predictResult;
try (Model model = Models.getModel(NUM_OF_OUTPUT, NEW_HEIGHT, NEW_WIDTH)) { //加载模型 model.load(Paths.get(modelParamsPath), modelParamsName);
//定义用于预处理和后置处理的翻译器 Translator<BufferedImage, Classifications> translator = new MyTranslator();
//使用预测器运行推理 try (Predictor<BufferedImage, Classifications> predictor = model.newPredictor(translator)) { predictResult = predictor.predict(img); } }
return predictResult;}
复制代码


在设置了模型和要分类的图像的必要路径之后,使用 Models.getModel(NUM_OF_OUTPUT, NEW_HEIGHT, NEW_WIDTH) 方法获取一个空模型实例,并使用 model.load(Paths.get(modelParamsPath), modelParamsName) 方法对其进行初始化。它将会加载上一步训练的模型。


接下来,使用 model.newPredictor(translator) 方法初始化一个带有指定的 Translator 的 Predictor。在 DJL 术语中,Translator 提供了模型预处理和置后处理的能力。例如,对于 CV 模型,需要将图像重塑为灰度图;Translator 是可以做到的。Predictor 使我们可以利用 predictor.predict(img) 方法来对加载的 Model 进行推理,并传入图像进行分类。


这个示例展示的是单个的预测,但是 DJL 也支持批量预测。推理存储在 predictResult 中,predictResult 包含了每个标签的概率估计。


推理(每张图片)及其对应的概率得分如下所示。






(表格对应的图片如上所示)


图像概率得分
如图1[信息] - [                 分类: “0”, 概率: 0.98985                 分类: “1”, 概率: 0.00225                 分类: “2”, 概率: 0.00224                 分类: “3”, 概率: 0.00564             ] 分类0 代表靴子,概率得分为 98.98%
图2[信息] - [                分类: “0”, 概率: 0.02111                分类: “1”, 概率: 0.76524                分类: “2”, 概率: 0.01159                分类: “3”, 概率: 0.20204           ] 分类1 代表凉鞋,概率得分为 o76.52%
图3[信息] - [                分类: “0”, 概率: 0.05523                分类: “1”, 概率: 0.01417                分类: “2”, 概率: 0.87900                分类: “3”, 概率: 0.05158               ] 分类2 代表鞋子,概率得分为 87.90%
图4[信息] - [                 分类: “0”, 概率: 0.00003                 分类: “1”, 概率: 0.01133                分类: “2”, 概率: 0.00179                 分类: “3”, 概率: 0.98682               ] 分类3 代表拖鞋,概率得分为of 98.68%.


DJL 提供了与其他 Java 库一样的原生 Java 开发体验和功能。设计这些 API 是为了指导开发人员能够用最佳实践来完成深度学习任务。在开始使用 DJL 之前,需要对 ML 生命周期有一个很好的理解。如果您是 ML 初学者,请先阅读这篇概述或 InfoQ 的系列文章《软件开发人员机器学习入门》。在理解了生命周期和常见的 ML 术语之后,开发人员就可以快速地掌握 DJL 的 API 了。


亚马逊已经开源了 DJL,有关该工具包的更多详细信息可以在 DJL 网站Java 库 API 规范(Java Library API Specification) 页面上找到。您也可以回顾下“鞋”分类模型的代码,以进一步探索该示例。

作者介绍

Kesha Williams 是一位屡获殊荣的软件工程师、机器学习实践者和 A Cloud Guru 的技术讲师,拥有 24 年的经验。在大学任教期间,她曾培训并指导了数千名来自美国、欧洲和亚洲的 Java 软件工程师。她经常带领创新团队验证新兴技术,并在全球各地的会议上分享她的经验教训。作为 TED 的 Spotlight Presentation Academy 的获得者,她在 TED 舞台上做过机器学习的演讲。此外,她在人工智能领域的开创性工作为她赢得了亚马逊的 Alexa Champion 和 AWS Machine Learning Hero 的殊荣。在业余时间,她通过在线社交专业网络平台 Colors of STEM 指导女性科技从业者。


原文链接:


Getting to Know Deep Java Library (DJL)


公众号推荐:

跳进 AI 的奇妙世界,一起探索未来工作的新风貌!想要深入了解 AI 如何成为产业创新的新引擎?好奇哪些城市正成为 AI 人才的新磁场?《中国生成式 AI 开发者洞察 2024》由 InfoQ 研究中心精心打造,为你深度解锁生成式 AI 领域的最新开发者动态。无论你是资深研发者,还是对生成式 AI 充满好奇的新手,这份报告都是你不可错过的知识宝典。欢迎大家扫码关注「AI前线」公众号,回复「开发者洞察」领取。

2020-01-13 09:155859

评论

发布
暂无评论
发现更多内容

数据库原理及MySQL应用 | 数据表操作

TiAmo

数据库 MySQL数据库 数据表 三周年连更

支持宽屏格式的音乐播放器:Swinsian mac版

真大的脸盆

Mac Mac 软件 播放器 音乐播放器

一文了解 Go 标准库 math 和 rand 的常用函数

陈明勇

Go golang math 三周年连更 rand

ER 图是什么?这一篇让你搞懂 ER 图!

产品海豚湾

数据分析 产品设计 ER图 产品架构 业务梳理

css水平垂直居中各种方法实现方式

肥晨

三周年连更

Unity 之 安卓堆栈跟踪和日志工具 (Android Logcat | 符号表解析Bugly捕获)

陈言必行

Unity 三周年连更 日志工具

如何基于 zap 封装一个更好用的日志库

江湖十年

后端 日志 log Go 语言

连续信源的熵与RD

timerring

信息论 三周年连更

NoClassDefFoundError 和 ClassNotFoundException 有什么区别 | 社区征文

共饮一杯无

NoClassDefFoundError 三周年连更 ClassNotFoundException

CSS架构之BEM设计模式

肥晨

CSS bem 三周年连更 css架构

Ai工具推荐 - Claude(手机端也可使用的媲美ChatGPT的产品)

炜娓道来程序人生

AI 工具 ChatGPT

C生万物 | 函数的讲解与剖析【内附众多案例详解】

Fire_Shield

C语言 三周年连更

Fragment——底部导航栏的实现

智趣匠

Fragment QRadioButton 三周年连更

如何管理你的python包 | python小知识

AIWeker

Python python小知识 三周年连更

Java中「Future」接口详解

Java 架构

跨平台应用开发进阶(四十七)APP字体库文件处理方案

No Silver Bullet

App 跨平台应用开发 三周年连更 字体库

一文读懂Spring中的AOP机制

老周聊架构

三周年连更

$ZZZ 以 Launchpad 形式多平台首发,GoSleep 成 Sleep to Earn 叙事成 X2E 新宠

股市老人

Spring Data开发手册|Java持久化API(JPA)需要了解到什么程度呢?

浅羽技术

Java 框架 jpa ORM 三周年连更

从零开始学习MySQL调试跟踪(2)

GreatSQL

使用chatGPT自动回复抖音评论

南城FE

人工智能 AI 前端 后端

Mysql常用数据类型及其默认值

will

MySQL varchar 数据类型 tinyint

参与开源之夏 x OpenTiny 跨端跨框架 UI 组件库贡献,可以赢取奖金🏆!这份《OpenTiny 开源贡献指南》请收好🎁!

Kagol

开源 Vue 前端 UI组件库

IPv6地址分类

穿过生命散发芬芳

ipv6 三周年连更

抖音起诉某刷量软件侵权获胜,如何严厉打击刷量、数据造假现象

石头IT视角

进程与线程、并行和并发有啥区别 | 社区征文

共饮一杯无

Java 多线程 三周年征文

Mac上实用的工具分享

IT蜗壳-Tango

三周年连更

如何检查 Linux 内存使用量是否耗尽?这5个命令堪称绝了!

wljslmz

Linux 三周年连更

FastAPI 快速开发 Web API 项目: 响应模型与错误处理

宇宙之一粟

Python FastApi 三周年连更

markdown格式基础用法

乌龟哥哥

三周年连更

HTTP报文的组成

阿泽🧸

HTTP 三周年连更

Deep Java Library (DJL) 简介:与引擎无关的Java深度学习框架_AI&大模型_Kesha Williams_InfoQ精选文章