NVIDIA 初创加速计划,免费加速您的创业启动 了解详情
写点什么

关于要替代 TensorFlow 的 JAX,你知道多少?

  • 2019-02-12
  • 本文字数:3550 字

    阅读完需:约 12 分钟

关于要替代TensorFlow的JAX,你知道多少?

这个简短的教程将介绍关于 JAX 的基础知识。JAX 是一个 Python 库,它通过函数转换来增强 numpy 和 Python 代码,使运行机器学习程序中常见的操作轻而易举。具体来说,它会使得编写标准 Python / numpy 代码变得简单,并且能够立即执行


  • 通过 autograd 的后继计算函数的导数

  • 及时编译函数,通过 XLA 在加速器上高效运行

  • 自动矢量化函数,并执行处理“批量”数据等


在本教程中,我们将通过演示它在 AGI 的一个核心问题:使用神经网络学习异或(XOR)函数,依次介绍这些转换。


注意:此博客文章在此处提供交互式 Jupyter notebook:https://github.com/craffel/jax-tutorial

1 JAX 只是 numpy(大多数情况下)

从本质上讲,你可以将 JAX 视为使用执行上述转换所需的机器来增强 numpy。JAX 增强的 numpy 为 jax.numpy。除了少数例外,可以认为 jax.numpy 与 numpy 可直接互换。作为一般规则,当你计划使用 JAX 的任何转换(如计算渐变或即时编译代码),或希望代码在加速器上运行时,都应该使用 jax.numpy。当 jax.numpy 不支持你的计算时,用 numpy 就行了。


import randomimport itertools
import jaximport jax.numpy as np# Current convention is to import original numpy as "onp"import numpy as onp
from __future__ import print_function
复制代码

2 背景

如前所述,我们将使用小型神经网络学习 XOR 功能。 XOR 函数将两个二进制数作为输入并输出二进制数,如下图所示:



我们将使用具有 3 个神经元和双曲正切非线性的单个隐藏层的神经网络,通过随机梯度下降训练交叉熵损失。然后实现此模型和损失函数。请注意,代码与你在标准 numpy 中编写的完全一样。


# Sigmoid nonlinearitydef sigmoid(x):    return 1 / (1 + np.exp(-x))
# Computes our network's outputdef net(params, x): w1, b1, w2, b2 = params hidden = np.tanh(np.dot(w1, x) + b1) return sigmoid(np.dot(w2, hidden) + b2)
# Cross-entropy lossdef loss(params, x, y): out = net(params, x) cross_entropy = -y * np.log(out) - (1 - y)*np.log(1 - out) return cross_entropy
# Utility function for testing whether the net produces the correct# output for all possible inputsdef test_all_inputs(inputs, params): predictions = [int(net(params, inp) > 0.5) for inp in inputs] for inp, out in zip(inputs, predictions): print(inp, '->', out) return (predictions == [onp.bitwise_xor(*inp) for inp in inputs])
复制代码


如上所述,有些地方我们想要使用标准 numpy 而不是 jax.numpy。比如参数初始化。我们想在训练网络之前随机初始化参数,这不是我们需要衍生或编译的操作。JAX 使用自己的 jax.random 库而不是 numpy.random,为不同转换的复现性(种子)提供了更好的支持。由于我们不需要以任何方式转换参数的初始化,因此最简单的方法就是在这里使用标准


的 numpy.random 而不是 jax.random。


def initial_params():    return [        onp.random.randn(3, 2),  # w1        onp.random.randn(3),  # b1        onp.random.randn(3),  # w2        onp.random.randn(),  #b2    ]
复制代码

3 jax.grad

我们将使用的第一个转换是 jax.grad。jax.grad 接受一个函数并返回一个新函数,该函数计算原始函数的渐变。默认情况下,相对于第一个参数进行渐变;这可以通过 jgn.grad 的 argnums 参数来控制。要使用梯度下降,我们希望能够根据神经网络的参数计算损失函数的梯度。为此,使用 jax.grad(loss)就可以,它将提供一个可以调用以获得这些渐变的函数。


loss_grad = jax.grad(loss)
# Stochastic gradient descent learning ratelearning_rate = 1.# All possible inputsinputs = np.array([[0, 0], [0, 1], [1, 0], [1, 1]])
# Initialize parameters randomlyparams = initial_params()
for n in itertools.count(): # Grab a single random input x = inputs[onp.random.choice(inputs.shape[0])] # Compute the target output y = onp.bitwise_xor(*x) # Get the gradient of the loss for this input/output pair grads = loss_grad(params, x, y) # Update parameters via gradient descent params = [param - learning_rate * grad for param, grad in zip(params, grads)] # Every 100 iterations, check whether we've solved XOR if not n % 100: print('Iteration {}'.format(n)) if test_all_inputs(inputs, params): break
复制代码


4 jax.jit

虽然我们精心编写的 numpy 代码运行起来效果还行,但对于现代机器学习来说,我们希望这些代码运行得尽可能快。这一般通过在 GPU 或 TPU 等不同的“加速器”上运行代码来实现。JAX 提供了一个 JIT(即时)编译器,它采用标准的 Python / numpy 函数,经编译可以在加速器上高效运行。编译函数还可以避免 Python 解释器的开销,这决定了你是否使用加速器。总的来说,jax.jit 可以显著加速代码运行,且基本上没有编码开销,你需要做的就是让 JAX 为你编译函数。使用 jax.jit 时,即使是微小的神经网络也可以实现相当惊人的加速度:


# Time the original gradient function%timeit loss_grad(params, x, y)loss_grad = jax.jit(jax.grad(loss))# Run once to trigger JIT compilationloss_grad(params, x, y)%timeit loss_grad(params, x, y)
复制代码


10 loops, best of 3: 13.1 ms per loop


1000 loops, best of 3: 862 µs per loop


请注意,JAX 允许我们将变换链接在一起。首先,我们使用 jax.grad 获取了丢失的梯度,然后使用 jax.jit 立即进行编译。这是使 JAX 更强大的一个因素——除了链接 jax.jit 和 jax.grad 之外,我们还可以多次应用 jax.grad 以获得更高阶的导数等。为了确保训练神经网络经过编译后仍然有效,我们再次对它进行训练。请注意,训练代码没有任何变化。


params = initial_params()
for n in itertools.count(): x = inputs[onp.random.choice(inputs.shape[0])] y = onp.bitwise_xor(*x) grads = loss_grad(params, x, y) params = [param - learning_rate * grad for param, grad in zip(params, grads)] if not n % 100: print('Iteration {}'.format(n)) if test_all_inputs(inputs, params): break
复制代码


5 jax.vmap

精明的读者可能已经注意到,我们一直在一个例子上训练我们的神经网络。这是“真正的”随机梯度下降;在实践中,当训练现代机器学习模型时,我们执行“小批量”梯度下降,在梯度下降的每个步骤中,我们对一小批示例中的损失梯度求平均值。JAX 提供了 jax.vmap,这是一个自动“矢量化”函数的转换。这意味着它允许你在输入的某个轴上并行计算函数的输出。对我们来说,这意味着我们可以应用 jax.vmap 函数转换并立即获得损失函数渐变的版本,该版本适用于小批量示例。


jax.vmap 还可接受其他参数:


  • in_axes 是一个元组或整数,它告诉 JAX 函数参数应该对哪些轴并行化。元组应该与 vmap’d 函数的参数数量相同,或者只有一个参数时为整数。示例中,我们将使用(None,0,0),指“不在第一个参数(params)上并行化,并在第二个和第三个参数(x 和 y)的第一个(第零个)维度上并行化”。

  • out_axes 类似于 in_axes,除了它指定了函数输出的哪些轴并行化。我们在例子中使用 0,表示在函数唯一输出的第一个(第零个)维度上进行并行化(损失梯度)。


请注意,我们必须稍微修改一下训练代码——我们需要一次抓取一批数据而不是单个示例,并在应用它们来更新参数之前对批处理中的渐变求平均。


loss_grad = jax.jit(jax.vmap(jax.grad(loss), in_axes=(None, 0, 0), out_axes=0))
params = initial_params()
batch_size = 100
for n in itertools.count(): # Generate a batch of inputs x = inputs[onp.random.choice(inputs.shape[0], size=batch_size)] y = onp.bitwise_xor(x[:, 0], x[:, 1]) # The call to loss_grad remains the same! grads = loss_grad(params, x, y) # Note that we now need to average gradients over the batch params = [param - learning_rate * np.mean(grad, axis=0) for param, grad in zip(params, grads)] if not n % 100: print('Iteration {}'.format(n)) if test_all_inputs(inputs, params): break
复制代码


6 指南

这就是我们将在这个简短的教程中介绍的内容,但这实际上涵盖了大量的 JAX 知识。由于 JAX 主要是 numpy 和 Python,因此你可以利用现有知识,而不必学习基本的新框架或范例。


有关其他资源,请查看 JAX GitHub:


https://github.com/google/jax 上的 notebook 和示例目录。


公众号推荐:

跳进 AI 的奇妙世界,一起探索未来工作的新风貌!想要深入了解 AI 如何成为产业创新的新引擎?好奇哪些城市正成为 AI 人才的新磁场?《中国生成式 AI 开发者洞察 2024》由 InfoQ 研究中心精心打造,为你深度解锁生成式 AI 领域的最新开发者动态。无论你是资深研发者,还是对生成式 AI 充满好奇的新手,这份报告都是你不可错过的知识宝典。欢迎大家扫码关注「AI前线」公众号,回复「开发者洞察」领取。

2019-02-12 08:056194
用户头像

发布了 98 篇内容, 共 62.6 次阅读, 收获喜欢 285 次。

关注

评论

发布
暂无评论
发现更多内容

云原生技术及其未来发展趋势展望 | 趋势解读

浪潮云

云原生

IAP:物联网终端软件升级技术

华为云开发者联盟

IoT LiteOS iap 物联网终端 OTA

【LeetCode】132模式Java题解

Albert

算法 LeetCode 3月日更

区块链BaaS应用服务平台的搭建

13828808769

区块链+ #区块链#

神策大数据技术直播系列课第二季,开讲啦

神策技术社区

大数据 性能优化 大前端 工程师 事件分析

情指勤指挥调度平台搭建,公安局情报指挥系统

PostgreSQL 集群宕机后恢复

桜喵ノねこ

“刷脸”日益泛滥,“掌经脉”开辟生物识别新路

E科讯

区块链技术或加速企业“碳中和”战略落地

CECBC

区块

用 WebRTC 打造一个音乐教育 App,要解决哪些音质难题?

阿里云视频云

音视频 WebRTC 在线教育 RTC

终于知道为啥网页不让我复制粘贴了!

华为云开发者联盟

js 代码 button事件 复制粘贴 输入框

带你了解数据库的“吸尘器”:VACUUM

华为云开发者联盟

数据库 数据 GaussDB(DWS) VACUUM

商品溯源之痛,区块链对商品假冒的解决方案

13828808769

区块链+ 区块链应用 区块链发展 #区块链#

自媒体平台数据统计分析爬虫之【趣头条】模拟登陆分析详解及数据统计接口详解

ucsheep

接口 爬虫 趣头条 模拟登录

直播预告 | 数据操作加速器,CloudQuery v1.3.5 发布

BinTools图尔兹

sql 编辑器 数据治理 数据安全 数据库管理工具

“英特尔‘IDM2.0’的疯狂”

E科讯

“数字云南”建设成效逐渐显现 区块链财政电子票据带来民生与环保效益

CECBC

区块链

云原生数据库风起云涌,华为云GaussDB破浪前行

华为云开发者联盟

数据库 架构 云原生 华为云 GaussDB

《Redis 核心技术与实战》学习笔记 08:GEO数据类型和时间序列数据

escray

redis 学习 极客时间 3月日更 Redis 核心技术与实战

PHP程序员如何简单的开展服务治理架构(一)

CrazyCodes

php 服务治理

智慧公安重点人员管控系统大数据分析平台的搭建

13828808769

智慧城市 智慧交通

超详细!手把手带你快速入门 GitHub!

JackTian

git GitHub 开源

技术杂谈 | Flutter 的性能分析、工程架构与细节处理

有道技术团队

flutter

量化策略软件搭建,马丁策略交易软件开发

k8s(Kubernetes)中Pod,Deployment,ReplicaSet,Service之间关系分析

ucsheep

Kubernetes k8s pod Deployment ReplicaSet

翻译:《实用的Python编程》07_04_Function_decorators

codists

Python PEP

Python SMTP 发送邮件方法

HoneyMoose

Spring-Retry重试实现原理,有点东西哈

Java小咖秀

Java spring 源码 原理 开发

分而治之——D&C

Kylin

3月日更 21天挑战 分而治之

网络连接总超时?从四层模型上解析网络是怎么连接的

京东科技开发者

计算机网络 服务器 域名

力扣(LeetCode)刷题,简单题(第14期)

不脱发的程序猿

面试 LeetCode 28天写作 算法攻关 3月日更

关于要替代TensorFlow的JAX,你知道多少?_AI&大模型_Colin Raffel_InfoQ精选文章