【ArchSummit】如何通过AIOps推动可量化的业务价值增长和效率提升?>>> 了解详情
写点什么

TensorFlow 官方简化版!谷歌开源机器学习库 JAX

整理编译自Reddit

  • 2018-12-12
  • 本文字数:2353 字

    阅读完需:约 8 分钟

TensorFlow官方简化版!谷歌开源机器学习库JAX

AI 前线导读:什么?TensorFlow 有了替代品?什么?竟然还是谷歌自己做出来的?先别慌,从各种意义上来说,这个所谓的“替代品”其实是 TensorFlow 的一个简化库,名为 JAX,结合 Autograd 和 XLA,可以支持部分 TensorFlow 的功能,但是比 TensorFlow 更加简洁易用。虽然还不至于替代 TensorFlow,但已经有 Reddit 网友对 JAX 寄予厚望,并表示“早就期待能有一个可以直接调用 Numpy API 接口的库了!”,“希望它可以取代 TensorFlow!”。


更多干货内容请关注微信公众号“AI 前线”(ID:ai-front)


JAX 结合了 Autograd 和 XLA,是专为高性能机器学习研究打造的产品。



有了新版本的 Autograd,JAX 能够自动对 Python 和 NumPy 的自带函数求导,支持循环、分支、递归、闭包函数求导,而且可以求三阶导数。它支持自动模式反向求导(也就是反向传播)和正向求导,且二者可以任意组合成任何顺序。


JAX 的创新之处在于,它基于 XLA 在 GPU 和 TPU 上编译和运行 NumPy 程序。默认情况下,编译是在底层进行的,库调用能够及时编译和执行。但是 JAX 还允许使用单一函数 API jit将自己的 Python 函数及时编译成经过 XLA 优化的内核。编译和自动求导可以任意组合,因此可以在不脱离 Python 环境的情况下实现复杂算法并获得最优性能。


JAX 最初由 Matt Johnson、Roy Frostig、Dougal Maclaurin 和 Chris Leary 发起,他们均任职于谷歌大脑团队。在 GitHub 的说明文档中,作者明确表示:JAX 目前还只是一个研究项目,不是谷歌的官方产品,因此可能会有一些 bug。从作者的 GitHub 简介来看,这应该是谷歌大脑正在尝试的新项目,在同一个 GitHub 目录下的开源项目还包括 8 月份在业内引起热议的强化学习框架 Dopamine。


以下是 JAX 的简单使用示例。



GitHub 项目传送门:https://github.com/google/JAX


有关具体的安装和简单的入门指导大家可以在 GitHub 中自行查看,在此不做过多赘述。

JAX 库的实现原理

机器学习中的编程是关于函数的表达和转换。转换包括自动微分、加速器编译和自动批处理。像 Python 这样的高级语言非常适合表达函数,但是通常使用者只能应用它们。我们无法访问它们的内部结构,因此无法执行转换。


JAX 可以用于专门化高级 Python+NumPy 函数,并将其转换为可转换的表示形式,然后再提升为 Python 函数。



JAX 通过跟踪专门处理 Python 函数。跟踪一个函数意味着:监视应用于其输入,以产生其输出的所有基本操作,并在有向无环图(DAG)中记录这些操作及其之间的数据流。为了执行跟踪,JAX 包装了基本的操作,就像基本的数字内核一样,这样一来,当调用它们时,它们就会将自己添加到执行的操作列表以及输入和输出中。为了跟踪这些原语之间的数据流,跟踪的值被包装在 Tracer 类的实例中。


当 Python 函数被提供给 grad 或 jit 时,它被包装起来以便跟踪并返回。当调用包装的函数时,我们将提供的具体参数抽象到 AbstractValue 类的实例中,将它们框起来用于跟踪跟踪器类的实例,并对它们调用函数。


抽象参数表示一组可能的值,而不是特定的值:例如,jit 将 ndarray 参数抽象为抽象值,这些值表示具有相同形状和数据类型的所有 ndarray。相反,grad 抽象 ndarray 参数来表示底层值的无穷小邻域。通过在这些抽象值上跟踪 Python 函数,我们确保它足够专门化,以便转换是可处理的,并且它仍然足够通用,以便转换后的结果是有用的,并且可能是可重用的。然后将这些转换后的函数提升回 Python 可调用函数,这样就可以根据需要跟踪并再次转换它们。


JAX 跟踪的基本函数大多与 XLA HLO 1:1 对应,并在 lax.py 中定义。这种 1:1 的对应关系使得到 XLA 的大多数转换基本上都很简单,并且确保我们只有一小组原语来覆盖其他转换,比如自动微分。 jax.numpy 层是用纯 Python 编写的,它只是用 LAX 函数(以及我们已经编写的其他 numpy 函数)表示 numpy 函数。这使得 jax.numpy 易于延展。


当你使用 jax.numpy 时,底层 LAX 原语是在后台进行 jit 编译的,允许你在加速器上执行每个原语操作的同时编写不受限制的 Python+ numpy 代码。


但是 JAX 可以做更多的事情:你可以在越来越大的函数上使用 jit 来进行端到端编译和优化,而不仅仅是编译和调度到一组固定的单个原语。例如,可以编译整个网络,或者编译整个梯度计算和优化器更新步骤,而不仅仅是编译和调度卷积运算。


折衷之处是,jit 函数必须满足一些额外的专门化需求:因为我们希望编译专门针对形状和数据类型的跟踪,但不是专门针对具体值的跟踪,所以 jit 装饰器下的 Python 代码必须适用于抽象值。如果我们尝试在一个抽象的 x 上求 x >0 的值,结果是一个抽象的值,表示集合{True, False},所以 Python 分支就像 if x > 0 会引起报错。


有关使用 jit 的更多要求,请参见:https://github.com/google/jax#whats-supported


好消息是,jit 是可选的:JAX 库在后台对单个操作和函数使用 jit,允许编写不受限制的 Python+Numpy,同时仍然使用硬件加速器。但是,当你希望最大化性能时,通常可以在自己的代码中使用 jit 编译和端到端优化更大的函数。

后续计划

目前项目小组还将对以下几项做更多尝试和更新:


  1. 完善说明文档

  2. 支持 Cloud TPU

  3. 支持多 GPU 和多 TPU

  4. 支持完整的 NumPy 功能和部分 SciPy 功能

  5. 全面支持 vmap

  6. 加速

  7. 降低 XLA 函数调度开销

  8. 线性代数例程(CPU 上的 MKL 和 GPU 上的 MAGMA)

  9. 高效自动微分原语condwhile


有关 JAX 库的介绍大致如此,如果你在尝试了 JAX 之后有一些较好的使用心得,欢迎随时向我们投稿,AI 前线十分愿意将你的经验传播给更多开发者。


再次附上 GitHub 链接:https://github.com/google/jax


相关资源:


JAX 论文链接:https://www.sysml.cc/doc/146.pdf

会议推荐

AICon


公众号推荐:

跳进 AI 的奇妙世界,一起探索未来工作的新风貌!想要深入了解 AI 如何成为产业创新的新引擎?好奇哪些城市正成为 AI 人才的新磁场?《中国生成式 AI 开发者洞察 2024》由 InfoQ 研究中心精心打造,为你深度解锁生成式 AI 领域的最新开发者动态。无论你是资深研发者,还是对生成式 AI 充满好奇的新手,这份报告都是你不可错过的知识宝典。欢迎大家扫码关注「AI前线」公众号,回复「开发者洞察」领取。

2018-12-12 07:002274
用户头像
陈思 InfoQ编辑

发布了 576 篇内容, 共 263.3 次阅读, 收获喜欢 1293 次。

关注

评论 1 条评论

发布
暂无评论
发现更多内容

RAG 修炼手册|一文讲透 RAG 背后的技术

Zilliz

nlp 向量数据库 LLM rag enbedding

企业架构设计原则之品质均衡性(二)

凌晞

企业架构 架构设计 架构设计原则

企业架构设计原则之品质均衡性(三)

凌晞

企业架构 架构设计

大模型工程化落地,足够细分的优质数据是关键

澳鹏Appen

AI工程化 数据标注 训练数据 大模型 LLM

探索314协议代币合约开发:解析AVE热搜上币与项目推广

区块链软件开发推广运营

dapp开发 区块链开发 链游开发 NFT开发 公链开发

NL2SQL基础系列(2):主流大模型与微调方法精选集,Text2SQL经典算法技术回顾七年发展脉络梳理

汀丶人工智能

大模型 NL2SQL

华为天气“赏春计划”来袭,浪漫解锁影音会员、出行礼包多重福利

最新动态

智能制造与AI大模型

百度开发者中心

人工智能 深度学习 大模型 智能制造

【精选】发布应用到应用商店的基本介绍

获取体育视频源数据与搭建自主体育直播平台源码的作用

软件开发-梦幻运营部

AMD 以全新第二代 Versal 系列器件扩展领先自适应 SoC 产品组合

财见

解析为什么企业出海需要SD-WAN专线

Ogcloud

SD-WAN 企业组网 SD-WAN组网 SD-WAN服务商 SDWAN

文件处理的神器,一键上传签署,安全又高效!

聚道云软件连接器

案例分享

大模型分布式训练并行技术

百度开发者中心

人工智能 深度学习 大模型

Qt Group与高通公司合作,简化工业物联网的用户界面开发

财见

2024上海国际智慧物业展览会

AIOTE智博会

智慧物业展 智慧物业展会 智慧物业展览会 智慧物业博览会

RUM 最佳实践-交互延迟的探索与发现

观测云

性能优化

软件测试学习笔记丨Python的自动解包 自动组包

测试人

Python 软件测试 测试开发

牛蛙!GoFrame2.7正式版的监控组件真是及时雨

王中阳Go

Go golang 面试题 面经 大厂面经

小红书笔记详情API接口解析:轻松抓取内容数据,提升业务效率

技术冰糖葫芦

API Explorer api 货币化 API】 pinduoduo API

基于istio实现单集群地域故障转移

华为云开发者联盟

微服务 istio 华为云 华为云开发者联盟 企业号2024年4月PK榜

选择国外云主机的五大理由以及优劣势分析

一只扑棱蛾子

国外主机

大型连锁企业异地组网稳定性提升指南

Ogcloud

SD-WAN SD-WAN组网 SD-WAN服务商 异地组网 SDWAN

ETLCloud结合kafka的数据集成

RestCloud

kafka ETL 数据集成

国内低代码哪家强?深入探讨低代码选型关键指标和评估模型

牛刀专业低代码

低代码开发平台 国内低代码 低代码选择 低代码平台比较 低代码排名

【干货】零售商的商品规划策略

第七在线

软件测试学习笔记丨测试框架体系 TDD DDT BDD ATDD 介绍

测试人

软件测试 测试开发

云手机解决海外社媒运营的诸多挑战

Ogcloud

云手机 海外云手机 云手机海外版 国外云手机 跨境云手机

英特尔和Altera发布边缘和FPGA产品,提供FPGA AI套件加速开发者创新

E科讯

人大金仓:国产数据库的领航者,高速公路信息化的创新力量

科技热闻

零基础到精通,Postman安装使用教程(一)

霍格沃兹测试开发学社

TensorFlow官方简化版!谷歌开源机器学习库JAX_AI&大模型_InfoQ精选文章