阿里、蚂蚁、晟腾、中科加禾精彩分享 AI 基础设施洞见,现购票可享受 9 折优惠 |AICon 了解详情
写点什么

TensorFlow 官方简化版!谷歌开源机器学习库 JAX

整理编译自Reddit

  • 2018-12-12
  • 本文字数:2353 字

    阅读完需:约 8 分钟

TensorFlow官方简化版!谷歌开源机器学习库JAX

AI 前线导读:什么?TensorFlow 有了替代品?什么?竟然还是谷歌自己做出来的?先别慌,从各种意义上来说,这个所谓的“替代品”其实是 TensorFlow 的一个简化库,名为 JAX,结合 Autograd 和 XLA,可以支持部分 TensorFlow 的功能,但是比 TensorFlow 更加简洁易用。虽然还不至于替代 TensorFlow,但已经有 Reddit 网友对 JAX 寄予厚望,并表示“早就期待能有一个可以直接调用 Numpy API 接口的库了!”,“希望它可以取代 TensorFlow!”。


更多干货内容请关注微信公众号“AI 前线”(ID:ai-front)


JAX 结合了 Autograd 和 XLA,是专为高性能机器学习研究打造的产品。



有了新版本的 Autograd,JAX 能够自动对 Python 和 NumPy 的自带函数求导,支持循环、分支、递归、闭包函数求导,而且可以求三阶导数。它支持自动模式反向求导(也就是反向传播)和正向求导,且二者可以任意组合成任何顺序。


JAX 的创新之处在于,它基于 XLA 在 GPU 和 TPU 上编译和运行 NumPy 程序。默认情况下,编译是在底层进行的,库调用能够及时编译和执行。但是 JAX 还允许使用单一函数 API jit将自己的 Python 函数及时编译成经过 XLA 优化的内核。编译和自动求导可以任意组合,因此可以在不脱离 Python 环境的情况下实现复杂算法并获得最优性能。


JAX 最初由 Matt Johnson、Roy Frostig、Dougal Maclaurin 和 Chris Leary 发起,他们均任职于谷歌大脑团队。在 GitHub 的说明文档中,作者明确表示:JAX 目前还只是一个研究项目,不是谷歌的官方产品,因此可能会有一些 bug。从作者的 GitHub 简介来看,这应该是谷歌大脑正在尝试的新项目,在同一个 GitHub 目录下的开源项目还包括 8 月份在业内引起热议的强化学习框架 Dopamine。


以下是 JAX 的简单使用示例。



GitHub 项目传送门:https://github.com/google/JAX


有关具体的安装和简单的入门指导大家可以在 GitHub 中自行查看,在此不做过多赘述。

JAX 库的实现原理

机器学习中的编程是关于函数的表达和转换。转换包括自动微分、加速器编译和自动批处理。像 Python 这样的高级语言非常适合表达函数,但是通常使用者只能应用它们。我们无法访问它们的内部结构,因此无法执行转换。


JAX 可以用于专门化高级 Python+NumPy 函数,并将其转换为可转换的表示形式,然后再提升为 Python 函数。



JAX 通过跟踪专门处理 Python 函数。跟踪一个函数意味着:监视应用于其输入,以产生其输出的所有基本操作,并在有向无环图(DAG)中记录这些操作及其之间的数据流。为了执行跟踪,JAX 包装了基本的操作,就像基本的数字内核一样,这样一来,当调用它们时,它们就会将自己添加到执行的操作列表以及输入和输出中。为了跟踪这些原语之间的数据流,跟踪的值被包装在 Tracer 类的实例中。


当 Python 函数被提供给 grad 或 jit 时,它被包装起来以便跟踪并返回。当调用包装的函数时,我们将提供的具体参数抽象到 AbstractValue 类的实例中,将它们框起来用于跟踪跟踪器类的实例,并对它们调用函数。


抽象参数表示一组可能的值,而不是特定的值:例如,jit 将 ndarray 参数抽象为抽象值,这些值表示具有相同形状和数据类型的所有 ndarray。相反,grad 抽象 ndarray 参数来表示底层值的无穷小邻域。通过在这些抽象值上跟踪 Python 函数,我们确保它足够专门化,以便转换是可处理的,并且它仍然足够通用,以便转换后的结果是有用的,并且可能是可重用的。然后将这些转换后的函数提升回 Python 可调用函数,这样就可以根据需要跟踪并再次转换它们。


JAX 跟踪的基本函数大多与 XLA HLO 1:1 对应,并在 lax.py 中定义。这种 1:1 的对应关系使得到 XLA 的大多数转换基本上都很简单,并且确保我们只有一小组原语来覆盖其他转换,比如自动微分。 jax.numpy 层是用纯 Python 编写的,它只是用 LAX 函数(以及我们已经编写的其他 numpy 函数)表示 numpy 函数。这使得 jax.numpy 易于延展。


当你使用 jax.numpy 时,底层 LAX 原语是在后台进行 jit 编译的,允许你在加速器上执行每个原语操作的同时编写不受限制的 Python+ numpy 代码。


但是 JAX 可以做更多的事情:你可以在越来越大的函数上使用 jit 来进行端到端编译和优化,而不仅仅是编译和调度到一组固定的单个原语。例如,可以编译整个网络,或者编译整个梯度计算和优化器更新步骤,而不仅仅是编译和调度卷积运算。


折衷之处是,jit 函数必须满足一些额外的专门化需求:因为我们希望编译专门针对形状和数据类型的跟踪,但不是专门针对具体值的跟踪,所以 jit 装饰器下的 Python 代码必须适用于抽象值。如果我们尝试在一个抽象的 x 上求 x >0 的值,结果是一个抽象的值,表示集合{True, False},所以 Python 分支就像 if x > 0 会引起报错。


有关使用 jit 的更多要求,请参见:https://github.com/google/jax#whats-supported


好消息是,jit 是可选的:JAX 库在后台对单个操作和函数使用 jit,允许编写不受限制的 Python+Numpy,同时仍然使用硬件加速器。但是,当你希望最大化性能时,通常可以在自己的代码中使用 jit 编译和端到端优化更大的函数。

后续计划

目前项目小组还将对以下几项做更多尝试和更新:


  1. 完善说明文档

  2. 支持 Cloud TPU

  3. 支持多 GPU 和多 TPU

  4. 支持完整的 NumPy 功能和部分 SciPy 功能

  5. 全面支持 vmap

  6. 加速

  7. 降低 XLA 函数调度开销

  8. 线性代数例程(CPU 上的 MKL 和 GPU 上的 MAGMA)

  9. 高效自动微分原语condwhile


有关 JAX 库的介绍大致如此,如果你在尝试了 JAX 之后有一些较好的使用心得,欢迎随时向我们投稿,AI 前线十分愿意将你的经验传播给更多开发者。


再次附上 GitHub 链接:https://github.com/google/jax


相关资源:


JAX 论文链接:https://www.sysml.cc/doc/146.pdf

会议推荐

AICon


公众号推荐:

跳进 AI 的奇妙世界,一起探索未来工作的新风貌!想要深入了解 AI 如何成为产业创新的新引擎?好奇哪些城市正成为 AI 人才的新磁场?《中国生成式 AI 开发者洞察 2024》由 InfoQ 研究中心精心打造,为你深度解锁生成式 AI 领域的最新开发者动态。无论你是资深研发者,还是对生成式 AI 充满好奇的新手,这份报告都是你不可错过的知识宝典。欢迎大家扫码关注「AI前线」公众号,回复「开发者洞察」领取。

2018-12-12 07:002268
用户头像
陈思 InfoQ编辑

发布了 576 篇内容, 共 262.4 次阅读, 收获喜欢 1293 次。

关注

评论 1 条评论

发布
暂无评论
发现更多内容

阿里大牛耗时三年整理出来的4588页Java面试诛仙手册,已全面开源

java小李

Linux 面试

声网 Agora 音频互动 MoS 分方法:为音频互动体验进行实时打分

声网

算法 网络

技术分析| 即时通讯和实时通讯的区别

anyRTC开发者

音视频 WebRTC 即时通讯 实时通讯 实时消息

纷多多拼团系统开发案例详解,纷多多拼团现成源码

系统开发咨询1357O98O718

短视频平台获客软件系统开发

短视频营销系统开发内容

秀出新天际的SpringBoot笔记,让开发像搭积木一样简单

java小李

Spring Boot java架构

java并发编程

十二万伏特皮卡丘

独家!精挑细选三个月的臻品Java面试题,无糟粕!高质量

白亦杨

Java 编程 程序员 架构师 计算机

系统性能优化-数据结构

从简历被拒到收割8个大厂offer,我用了3个月成功破茧成蝶

java小李

面试

Fil还有希望吗?目前Fil发展如何了?

区块链 IPFS Filecoin fil filecoin生态

阿里这份15w字Java核心面试笔记!GitHub凭借百万下载量位居榜首

java小李

面试 Java核心笔记

阿里资深架构师倾情力荐:Java全线成长宝典,P5到P8一应俱全

愚者

Java 面试

柏益美康系统开发案例详解,柏益美康开发源码

系统开发咨询1357O98O718

阿里云飞天论文获国际架构顶会 ATC 2021最佳论文:全球仅三篇

阿里云大数据AI技术

渣本展示Spring Cloud 架构绝活!最后成功入职阿里

java小李

Spring Cloud

LeetCode题解:61. 旋转链表,闭合为环,JavaScript,详细注释

Lee Chen

算法 大前端 LeetCode

北鲲云超算在生命科学领域的使用场景中有什么作用?

北鲲云

泪目!跳槽太不容易,蚂蚁金服三轮面试,四个小时灵魂拷问

java小李

面试 Leader

阿里内网流传的9w字图解网络(全彩版)GitHub现已下载量过百万

java小李

HTTP

HarmonyOS学习路之开发篇——线程管理

爱吃土豆丝的打工人

多线程 HarmonyOS 线程管理

我们向华为公司学什么?

石云升

学习 华为 7月日更

对话交互:封闭域任务型与开放域闲聊算法技术

OPPO小布助手

人工智能 深度学习 对话 智能助手 语义理解

模块四作业

燕燕 yen yen

架构实战营

拿来吧你!从阿里P8手里抢来的的JDK源码解析手册,Alibaba真的强

java小李

jdk

阿里内网疯传的P8“顶级”分布式架构手册,GitHub上线直接霸榜了

java小李

微信业务架构 P8

大专的我狂刷29天“阿里内部面试笔记”最终直接斩获十七个Offer

java小李

大数据 面试

贝丽美牙系统开发(开发案例),贝丽美牙源码设计

系统开发咨询1357O98O718

香到爆!SpringBoot/SpringCloud全套学习脑图+面试笔记免费分享

java小李

SpringCloud Alibaba

我看 JAVA 之 并发编程【二】java.util.concurrent.locks

awen

Java AQS lock Condition LockSupport

TensorFlow官方简化版!谷歌开源机器学习库JAX_AI&大模型_InfoQ精选文章