写点什么

关于要替代 TensorFlow 的 JAX,你知道多少?

  • 2019-02-12
  • 本文字数:3550 字

    阅读完需:约 12 分钟

关于要替代TensorFlow的JAX,你知道多少?

这个简短的教程将介绍关于 JAX 的基础知识。JAX 是一个 Python 库,它通过函数转换来增强 numpy 和 Python 代码,使运行机器学习程序中常见的操作轻而易举。具体来说,它会使得编写标准 Python / numpy 代码变得简单,并且能够立即执行


  • 通过 autograd 的后继计算函数的导数

  • 及时编译函数,通过 XLA 在加速器上高效运行

  • 自动矢量化函数,并执行处理“批量”数据等


在本教程中,我们将通过演示它在 AGI 的一个核心问题:使用神经网络学习异或(XOR)函数,依次介绍这些转换。


注意:此博客文章在此处提供交互式 Jupyter notebook:https://github.com/craffel/jax-tutorial

1 JAX 只是 numpy(大多数情况下)

从本质上讲,你可以将 JAX 视为使用执行上述转换所需的机器来增强 numpy。JAX 增强的 numpy 为 jax.numpy。除了少数例外,可以认为 jax.numpy 与 numpy 可直接互换。作为一般规则,当你计划使用 JAX 的任何转换(如计算渐变或即时编译代码),或希望代码在加速器上运行时,都应该使用 jax.numpy。当 jax.numpy 不支持你的计算时,用 numpy 就行了。


import randomimport itertools
import jaximport jax.numpy as np# Current convention is to import original numpy as "onp"import numpy as onp
from __future__ import print_function
复制代码

2 背景

如前所述,我们将使用小型神经网络学习 XOR 功能。 XOR 函数将两个二进制数作为输入并输出二进制数,如下图所示:



我们将使用具有 3 个神经元和双曲正切非线性的单个隐藏层的神经网络,通过随机梯度下降训练交叉熵损失。然后实现此模型和损失函数。请注意,代码与你在标准 numpy 中编写的完全一样。


# Sigmoid nonlinearitydef sigmoid(x):    return 1 / (1 + np.exp(-x))
# Computes our network's outputdef net(params, x): w1, b1, w2, b2 = params hidden = np.tanh(np.dot(w1, x) + b1) return sigmoid(np.dot(w2, hidden) + b2)
# Cross-entropy lossdef loss(params, x, y): out = net(params, x) cross_entropy = -y * np.log(out) - (1 - y)*np.log(1 - out) return cross_entropy
# Utility function for testing whether the net produces the correct# output for all possible inputsdef test_all_inputs(inputs, params): predictions = [int(net(params, inp) > 0.5) for inp in inputs] for inp, out in zip(inputs, predictions): print(inp, '->', out) return (predictions == [onp.bitwise_xor(*inp) for inp in inputs])
复制代码


如上所述,有些地方我们想要使用标准 numpy 而不是 jax.numpy。比如参数初始化。我们想在训练网络之前随机初始化参数,这不是我们需要衍生或编译的操作。JAX 使用自己的 jax.random 库而不是 numpy.random,为不同转换的复现性(种子)提供了更好的支持。由于我们不需要以任何方式转换参数的初始化,因此最简单的方法就是在这里使用标准


的 numpy.random 而不是 jax.random。


def initial_params():    return [        onp.random.randn(3, 2),  # w1        onp.random.randn(3),  # b1        onp.random.randn(3),  # w2        onp.random.randn(),  #b2    ]
复制代码

3 jax.grad

我们将使用的第一个转换是 jax.grad。jax.grad 接受一个函数并返回一个新函数,该函数计算原始函数的渐变。默认情况下,相对于第一个参数进行渐变;这可以通过 jgn.grad 的 argnums 参数来控制。要使用梯度下降,我们希望能够根据神经网络的参数计算损失函数的梯度。为此,使用 jax.grad(loss)就可以,它将提供一个可以调用以获得这些渐变的函数。


loss_grad = jax.grad(loss)
# Stochastic gradient descent learning ratelearning_rate = 1.# All possible inputsinputs = np.array([[0, 0], [0, 1], [1, 0], [1, 1]])
# Initialize parameters randomlyparams = initial_params()
for n in itertools.count(): # Grab a single random input x = inputs[onp.random.choice(inputs.shape[0])] # Compute the target output y = onp.bitwise_xor(*x) # Get the gradient of the loss for this input/output pair grads = loss_grad(params, x, y) # Update parameters via gradient descent params = [param - learning_rate * grad for param, grad in zip(params, grads)] # Every 100 iterations, check whether we've solved XOR if not n % 100: print('Iteration {}'.format(n)) if test_all_inputs(inputs, params): break
复制代码


4 jax.jit

虽然我们精心编写的 numpy 代码运行起来效果还行,但对于现代机器学习来说,我们希望这些代码运行得尽可能快。这一般通过在 GPU 或 TPU 等不同的“加速器”上运行代码来实现。JAX 提供了一个 JIT(即时)编译器,它采用标准的 Python / numpy 函数,经编译可以在加速器上高效运行。编译函数还可以避免 Python 解释器的开销,这决定了你是否使用加速器。总的来说,jax.jit 可以显著加速代码运行,且基本上没有编码开销,你需要做的就是让 JAX 为你编译函数。使用 jax.jit 时,即使是微小的神经网络也可以实现相当惊人的加速度:


# Time the original gradient function%timeit loss_grad(params, x, y)loss_grad = jax.jit(jax.grad(loss))# Run once to trigger JIT compilationloss_grad(params, x, y)%timeit loss_grad(params, x, y)
复制代码


10 loops, best of 3: 13.1 ms per loop


1000 loops, best of 3: 862 µs per loop


请注意,JAX 允许我们将变换链接在一起。首先,我们使用 jax.grad 获取了丢失的梯度,然后使用 jax.jit 立即进行编译。这是使 JAX 更强大的一个因素——除了链接 jax.jit 和 jax.grad 之外,我们还可以多次应用 jax.grad 以获得更高阶的导数等。为了确保训练神经网络经过编译后仍然有效,我们再次对它进行训练。请注意,训练代码没有任何变化。


params = initial_params()
for n in itertools.count(): x = inputs[onp.random.choice(inputs.shape[0])] y = onp.bitwise_xor(*x) grads = loss_grad(params, x, y) params = [param - learning_rate * grad for param, grad in zip(params, grads)] if not n % 100: print('Iteration {}'.format(n)) if test_all_inputs(inputs, params): break
复制代码


5 jax.vmap

精明的读者可能已经注意到,我们一直在一个例子上训练我们的神经网络。这是“真正的”随机梯度下降;在实践中,当训练现代机器学习模型时,我们执行“小批量”梯度下降,在梯度下降的每个步骤中,我们对一小批示例中的损失梯度求平均值。JAX 提供了 jax.vmap,这是一个自动“矢量化”函数的转换。这意味着它允许你在输入的某个轴上并行计算函数的输出。对我们来说,这意味着我们可以应用 jax.vmap 函数转换并立即获得损失函数渐变的版本,该版本适用于小批量示例。


jax.vmap 还可接受其他参数:


  • in_axes 是一个元组或整数,它告诉 JAX 函数参数应该对哪些轴并行化。元组应该与 vmap’d 函数的参数数量相同,或者只有一个参数时为整数。示例中,我们将使用(None,0,0),指“不在第一个参数(params)上并行化,并在第二个和第三个参数(x 和 y)的第一个(第零个)维度上并行化”。

  • out_axes 类似于 in_axes,除了它指定了函数输出的哪些轴并行化。我们在例子中使用 0,表示在函数唯一输出的第一个(第零个)维度上进行并行化(损失梯度)。


请注意,我们必须稍微修改一下训练代码——我们需要一次抓取一批数据而不是单个示例,并在应用它们来更新参数之前对批处理中的渐变求平均。


loss_grad = jax.jit(jax.vmap(jax.grad(loss), in_axes=(None, 0, 0), out_axes=0))
params = initial_params()
batch_size = 100
for n in itertools.count(): # Generate a batch of inputs x = inputs[onp.random.choice(inputs.shape[0], size=batch_size)] y = onp.bitwise_xor(x[:, 0], x[:, 1]) # The call to loss_grad remains the same! grads = loss_grad(params, x, y) # Note that we now need to average gradients over the batch params = [param - learning_rate * np.mean(grad, axis=0) for param, grad in zip(params, grads)] if not n % 100: print('Iteration {}'.format(n)) if test_all_inputs(inputs, params): break
复制代码


6 指南

这就是我们将在这个简短的教程中介绍的内容,但这实际上涵盖了大量的 JAX 知识。由于 JAX 主要是 numpy 和 Python,因此你可以利用现有知识,而不必学习基本的新框架或范例。


有关其他资源,请查看 JAX GitHub:


https://github.com/google/jax 上的 notebook 和示例目录。


2019-02-12 08:056411
用户头像

发布了 98 篇内容, 共 64.7 次阅读, 收获喜欢 285 次。

关注

评论

发布
暂无评论
发现更多内容

别再说你不懂Java内存模型了!!!

做梦都在改BUG

Java 内存模型 JMM 并发

火山引擎DataTester 3大功能升级:聚焦敏捷、智能与易用,帮助企业降本增效

字节跳动数据平台

大数据 AB testing实战 A/B 测试 对比实验

Neuron 2.4.0发布:体验下一代工业物联网连接和管理

EMQ映云科技

UI 物联网 IoT neuron 企业号 4 月 PK 榜

2023年免费堡垒机软件推荐-行云管家堡垒机免费版

行云管家

网络安全 堡垒机

优化配置Little Snitch for Mac的规则和设置

理理

macOS防火墙软件 Little Snitch 下载 小飞贼防火墙 Little Snitch 规则

Mysql 连接查询

AutoCAD 2024 Mac版如何设置自动保存文件为低版本?

理理

CAD制图软件 AutoCAD 2024 Mac版 AutoCAD使用技巧 保存文件为低版本

迁移prometheus数据

TiDB 社区干货传送门

迁移 实践案例 集群管理

TiDB迁移、升级与案例分享(TiDB v4.0.11 → v6.5.1)

TiDB 社区干货传送门

迁移 版本升级 安装 & 部署 扩/缩容 6.x 实践

Mac人工智能图像降噪 Topaz Photo AI 安装激活教程

理理

Topaz Photo AI下载 图像降噪 苹果mac软件下载 Topaz Photo AI mac Topaz Photo AI破解

4 月 25 日直播预告 | 深入解读 Flink 1.17

Apache Flink

大数据 flink 实时计算

膜拜!华为内部都在强推的783页大数据处理系统:Hadoop源代码

做梦都在改BUG

Java 大数据 hadoop

maya软件在建模上有什么优势?

Finovy Cloud

maya 3D软件

零样本文本分类应用:基于UTC的医疗意图多分类,打通数据标注-模型训练-模型调优-预测部署全流程。

汀丶人工智能

人工智能 自然语言处理 深度学习 文本分类 小样本学习

全平台数据(数据库)管理工具 DataCap 管理 Rainbond 上的所有数据库

北京好雨科技有限公司

数据库 Kubernetes 云原生 rainbond 企业号 4 月 PK 榜

Karmada 多云容器编排引擎支持多调度组,助力成本优化

华为云开发者联盟

云原生 后端 华为云 华为云开发者联盟 企业号 4 月 PK 榜

苹果电脑臻选4K壁纸app:Dynamic Wallpaper

理理

动态壁纸 Dynamic Wallpaper下载 苹果壁纸下载 mac壁纸一软件 高清壁纸下载

tidb-loadbalance 客户端方式软负载均衡配置实践

TiDB 社区干货传送门

数据库架构设计 数据库连接

WordPress 使用 TiDB Cloud 替换 MySQL

TiDB 社区干货传送门

迁移 实践案例 版本测评 应用适配

ChatGPT 真能带货吗?晒一下 SQL Chat 上线 3 周以来的真实运营数据📊

Bytebase

MySQL sql postgres ChatGPT SQL Server

谷歌 Chrome 正式发布 WebGPU!Orillusion开源倒计时!

Orillusion

开源 WebGL 元宇宙 web3d #WebGPU

用数据分析的方法去做dba,维护好tidb数据库。

TiDB 社区干货传送门

6.x 实践

python正则 | python小知识

AIWeker

Python python小知识 三周年连更

国内外主流的8款 IT 项目管理平台

爱吃小舅的鱼

项目管理工具 PingCode 项目研发管理软件

HummerRisk V1.0 开发手册(微服务版)

HummerCloud

开源 微服务 云原生安全

Parallels Desktop PD 18虚拟机关闭、停止、中止和暂停操作的区别

理理

pd虚拟机 Parallels Desktop PD18虚拟机操作

阿里P8推荐学习的44个微服务架构设计模式,真的太香了!

做梦都在改BUG

Java 架构 微服务 设计模式

迁移PD坑-cdc任务全部stop

TiDB 社区干货传送门

实践案例 集群管理 故障排查/诊断

新特性解析丨TiDB 资源管控的设计思路与场景解析

TiDB 社区干货传送门

版本测评

【Linux】之创建普通用户并禁止root用户远程登陆

A-刘晨阳

Linux 三周年连更 用户名

chatGPT衣食住行10种场景系列教程(01)chatGPT热点事件汇总+开发利器

非喵鱼

java openai AIGC ChatGPT 三周年连更

关于要替代TensorFlow的JAX,你知道多少?_AI&大模型_Colin Raffel_InfoQ精选文章