AI实践哪家强?来 AICon, 解锁技术前沿,探寻产业新机! 了解详情
写点什么

关于要替代 TensorFlow 的 JAX,你知道多少?

  • 2019-02-12
  • 本文字数:3550 字

    阅读完需:约 12 分钟

关于要替代TensorFlow的JAX,你知道多少?

这个简短的教程将介绍关于 JAX 的基础知识。JAX 是一个 Python 库,它通过函数转换来增强 numpy 和 Python 代码,使运行机器学习程序中常见的操作轻而易举。具体来说,它会使得编写标准 Python / numpy 代码变得简单,并且能够立即执行


  • 通过 autograd 的后继计算函数的导数

  • 及时编译函数,通过 XLA 在加速器上高效运行

  • 自动矢量化函数,并执行处理“批量”数据等


在本教程中,我们将通过演示它在 AGI 的一个核心问题:使用神经网络学习异或(XOR)函数,依次介绍这些转换。


注意:此博客文章在此处提供交互式 Jupyter notebook:https://github.com/craffel/jax-tutorial

1 JAX 只是 numpy(大多数情况下)

从本质上讲,你可以将 JAX 视为使用执行上述转换所需的机器来增强 numpy。JAX 增强的 numpy 为 jax.numpy。除了少数例外,可以认为 jax.numpy 与 numpy 可直接互换。作为一般规则,当你计划使用 JAX 的任何转换(如计算渐变或即时编译代码),或希望代码在加速器上运行时,都应该使用 jax.numpy。当 jax.numpy 不支持你的计算时,用 numpy 就行了。


import randomimport itertools
import jaximport jax.numpy as np# Current convention is to import original numpy as "onp"import numpy as onp
from __future__ import print_function
复制代码

2 背景

如前所述,我们将使用小型神经网络学习 XOR 功能。 XOR 函数将两个二进制数作为输入并输出二进制数,如下图所示:



我们将使用具有 3 个神经元和双曲正切非线性的单个隐藏层的神经网络,通过随机梯度下降训练交叉熵损失。然后实现此模型和损失函数。请注意,代码与你在标准 numpy 中编写的完全一样。


# Sigmoid nonlinearitydef sigmoid(x):    return 1 / (1 + np.exp(-x))
# Computes our network's outputdef net(params, x): w1, b1, w2, b2 = params hidden = np.tanh(np.dot(w1, x) + b1) return sigmoid(np.dot(w2, hidden) + b2)
# Cross-entropy lossdef loss(params, x, y): out = net(params, x) cross_entropy = -y * np.log(out) - (1 - y)*np.log(1 - out) return cross_entropy
# Utility function for testing whether the net produces the correct# output for all possible inputsdef test_all_inputs(inputs, params): predictions = [int(net(params, inp) > 0.5) for inp in inputs] for inp, out in zip(inputs, predictions): print(inp, '->', out) return (predictions == [onp.bitwise_xor(*inp) for inp in inputs])
复制代码


如上所述,有些地方我们想要使用标准 numpy 而不是 jax.numpy。比如参数初始化。我们想在训练网络之前随机初始化参数,这不是我们需要衍生或编译的操作。JAX 使用自己的 jax.random 库而不是 numpy.random,为不同转换的复现性(种子)提供了更好的支持。由于我们不需要以任何方式转换参数的初始化,因此最简单的方法就是在这里使用标准


的 numpy.random 而不是 jax.random。


def initial_params():    return [        onp.random.randn(3, 2),  # w1        onp.random.randn(3),  # b1        onp.random.randn(3),  # w2        onp.random.randn(),  #b2    ]
复制代码

3 jax.grad

我们将使用的第一个转换是 jax.grad。jax.grad 接受一个函数并返回一个新函数,该函数计算原始函数的渐变。默认情况下,相对于第一个参数进行渐变;这可以通过 jgn.grad 的 argnums 参数来控制。要使用梯度下降,我们希望能够根据神经网络的参数计算损失函数的梯度。为此,使用 jax.grad(loss)就可以,它将提供一个可以调用以获得这些渐变的函数。


loss_grad = jax.grad(loss)
# Stochastic gradient descent learning ratelearning_rate = 1.# All possible inputsinputs = np.array([[0, 0], [0, 1], [1, 0], [1, 1]])
# Initialize parameters randomlyparams = initial_params()
for n in itertools.count(): # Grab a single random input x = inputs[onp.random.choice(inputs.shape[0])] # Compute the target output y = onp.bitwise_xor(*x) # Get the gradient of the loss for this input/output pair grads = loss_grad(params, x, y) # Update parameters via gradient descent params = [param - learning_rate * grad for param, grad in zip(params, grads)] # Every 100 iterations, check whether we've solved XOR if not n % 100: print('Iteration {}'.format(n)) if test_all_inputs(inputs, params): break
复制代码


4 jax.jit

虽然我们精心编写的 numpy 代码运行起来效果还行,但对于现代机器学习来说,我们希望这些代码运行得尽可能快。这一般通过在 GPU 或 TPU 等不同的“加速器”上运行代码来实现。JAX 提供了一个 JIT(即时)编译器,它采用标准的 Python / numpy 函数,经编译可以在加速器上高效运行。编译函数还可以避免 Python 解释器的开销,这决定了你是否使用加速器。总的来说,jax.jit 可以显著加速代码运行,且基本上没有编码开销,你需要做的就是让 JAX 为你编译函数。使用 jax.jit 时,即使是微小的神经网络也可以实现相当惊人的加速度:


# Time the original gradient function%timeit loss_grad(params, x, y)loss_grad = jax.jit(jax.grad(loss))# Run once to trigger JIT compilationloss_grad(params, x, y)%timeit loss_grad(params, x, y)
复制代码


10 loops, best of 3: 13.1 ms per loop


1000 loops, best of 3: 862 µs per loop


请注意,JAX 允许我们将变换链接在一起。首先,我们使用 jax.grad 获取了丢失的梯度,然后使用 jax.jit 立即进行编译。这是使 JAX 更强大的一个因素——除了链接 jax.jit 和 jax.grad 之外,我们还可以多次应用 jax.grad 以获得更高阶的导数等。为了确保训练神经网络经过编译后仍然有效,我们再次对它进行训练。请注意,训练代码没有任何变化。


params = initial_params()
for n in itertools.count(): x = inputs[onp.random.choice(inputs.shape[0])] y = onp.bitwise_xor(*x) grads = loss_grad(params, x, y) params = [param - learning_rate * grad for param, grad in zip(params, grads)] if not n % 100: print('Iteration {}'.format(n)) if test_all_inputs(inputs, params): break
复制代码


5 jax.vmap

精明的读者可能已经注意到,我们一直在一个例子上训练我们的神经网络。这是“真正的”随机梯度下降;在实践中,当训练现代机器学习模型时,我们执行“小批量”梯度下降,在梯度下降的每个步骤中,我们对一小批示例中的损失梯度求平均值。JAX 提供了 jax.vmap,这是一个自动“矢量化”函数的转换。这意味着它允许你在输入的某个轴上并行计算函数的输出。对我们来说,这意味着我们可以应用 jax.vmap 函数转换并立即获得损失函数渐变的版本,该版本适用于小批量示例。


jax.vmap 还可接受其他参数:


  • in_axes 是一个元组或整数,它告诉 JAX 函数参数应该对哪些轴并行化。元组应该与 vmap’d 函数的参数数量相同,或者只有一个参数时为整数。示例中,我们将使用(None,0,0),指“不在第一个参数(params)上并行化,并在第二个和第三个参数(x 和 y)的第一个(第零个)维度上并行化”。

  • out_axes 类似于 in_axes,除了它指定了函数输出的哪些轴并行化。我们在例子中使用 0,表示在函数唯一输出的第一个(第零个)维度上进行并行化(损失梯度)。


请注意,我们必须稍微修改一下训练代码——我们需要一次抓取一批数据而不是单个示例,并在应用它们来更新参数之前对批处理中的渐变求平均。


loss_grad = jax.jit(jax.vmap(jax.grad(loss), in_axes=(None, 0, 0), out_axes=0))
params = initial_params()
batch_size = 100
for n in itertools.count(): # Generate a batch of inputs x = inputs[onp.random.choice(inputs.shape[0], size=batch_size)] y = onp.bitwise_xor(x[:, 0], x[:, 1]) # The call to loss_grad remains the same! grads = loss_grad(params, x, y) # Note that we now need to average gradients over the batch params = [param - learning_rate * np.mean(grad, axis=0) for param, grad in zip(params, grads)] if not n % 100: print('Iteration {}'.format(n)) if test_all_inputs(inputs, params): break
复制代码


6 指南

这就是我们将在这个简短的教程中介绍的内容,但这实际上涵盖了大量的 JAX 知识。由于 JAX 主要是 numpy 和 Python,因此你可以利用现有知识,而不必学习基本的新框架或范例。


有关其他资源,请查看 JAX GitHub:


https://github.com/google/jax 上的 notebook 和示例目录。


2019-02-12 08:056557
用户头像

发布了 98 篇内容, 共 66.0 次阅读, 收获喜欢 285 次。

关注

评论

发布
暂无评论
发现更多内容

Oracle 常用SQL语句大全(精),java框架学习顺序

Java 程序员 后端

Redis 变慢了?那你这样试试,不行就捶我,java面试问职业规划

Java 程序员 后端

mysql的 int 类型,刨析返回类型为BigDicemal 类型的奇怪现象

Java 程序员 后端

Offer拿来吧你!秒杀系统?这不是必考的嘛,kafka与rabbitmq面试题

Java 程序员 后端

Redis与MySQL数据双写一致性工程落地案例,java最新技术百度云

Java 程序员 后端

MySQL各种锁详情,实战分析

Java 程序员 后端

微信朋友圈复杂度分析

abingagl

Nginx面试三连问:如何工作?负载均衡策略有哪些

Java 程序员 后端

P8大牛带你细谈架构中的限流与计数器的实现方式

Java 程序员 后端

Redis 中 RDB 和 AOF 持久化有啥区别?看这儿,你就懂了

Java 程序员 后端

Redis入门HelloWorld,java入门视频教程

Java 程序员 后端

MySQL是如何恢复到某一天的某一秒的状态?,现在做Java开发有前途吗

Java 程序员 后端

MySQL面试题:谈谈MySQL 索引,B,java程序员面试算法宝典pdf下载

Java 程序员 后端

mysql的timestamp会存在时区问题?,java技术专家方向

Java 程序员 后端

Nginx架构浅析:为什么不用多线程模型管理连接与处理逻辑业务?

Java 程序员 后端

OpenKruise :SidecarSet 助力 Mesh 容器热升级

Java 程序员 后端

quartz-2,linux视频教程百度云

Java 程序员 后端

MySQL事务:ACID特性的实现原理知多少,java教学视频百度云

Java 程序员 后端

Mysql优化提高笔记整理,来自于一位鹅厂大佬的笔记

Java 程序员 后端

MySQL索引原理B+树,java学习视频百度云盘

Java 程序员 后端

Redis事务详述,java多并发面试题

Java 程序员 后端

Netty编解码开发+多协议开发和应用+源码,Java开发经验谈

Java 程序员 后端

Redis 最全性能监控指标:汇总实战,实战java虚拟机葛一鸣第二版pdf

Java 程序员 后端

Netty相关面试题汇总,java并发编程电子书

Java 程序员 后端

RabbitMQ实现即时通讯居然如此简单!后端代码都省得写了

Java 程序员 后端

Mysql进阶三板斧(一)带你彻底搞懂View视图的原理及应用

Java 程序员 后端

MySQL进阶三板斧(三)看清,java高级框架思维导图

Java 程序员 后端

Redis从入门到精通,至少要看看这篇,java医疗管理系统技术描述

Java 程序员 后端

Redis实战(五)-字符串,kafka基本原理

Java 程序员 后端

MySQL慢查询,一口从天而降的锅!,java程序开发基础彭政答案

Java 程序员 后端

Offer拿来吧你!秒杀系统?这不是必考的嘛(1)

Java 程序员 后端

关于要替代TensorFlow的JAX,你知道多少?_AI&大模型_Colin Raffel_InfoQ精选文章