2天时间,聊今年最热的 Agent、上下文工程、AI 产品创新等话题。2025 年最后一场~ 了解详情
写点什么

运用计算图搭建递归神经网络(RNN)

  • 2019-09-17
  • 本文字数:5544 字

    阅读完需:约 18 分钟

运用计算图搭建递归神经网络(RNN)

继续玩我们的计算图框架。这一次我们运用计算图搭建递归神经网络(RNN,Recursive Neural Network)。RNN 处理前后有承接关系的序列状数据,例如时序数据。当然,前后的承接也不一定是时间上的,但总之是有前后关系的序列。

RNN

RNN 的思想是:网络也分步,每步以输入序列的该步数据(向量)和上一步数据(第一步没有)为输入,进行变换,得到这一步的输出(向量)。这样的话,序列的每一步就会对下一步产生影响。RNN 用变换的参数把握序列每一步之间的关系。最后一步的输出可以送给全连接层,最终用于分类或回归。RNN 有很多种,有一些复杂的变体,本文搭建一种最简单的 RNN ,它的结构是这样的:



蓝色长条表示 m 维输入向量,一共 n 个。这表示数据是长度为 n 的序列,每一步是一个 m 维向量。绿色的矩形就是每一步的变换。yi 是每一步的 k 维输出向量。每一步用 k x k 的权值矩阵 Y 去乘前一步的输出向量(第一步没有),用 k x m 的权值矩阵 W 去乘这一步的输入向量,加和后再加上 k 维偏置向量 b ,施加激活函数 ϕ (我们取 ReLU),就得到这一步的输出。


最后一步的输出也是 k 维向量,把它送给全连接层,最后施加 SoftMax 后得到各个类别的概率,再接上一个交叉熵损失就可以用来训练分类问题了。用我们的计算图框架可以这样搭建这个简单的 RNN(代码):


seq_len = 96  # 序列长度dimension = 16  # 序列每一步的向量维度hidden_dim = 12  # RNN 时间单元的输出维度
# 时间序列变量,每一步一个 dimension 维向量(Variable 节点),保存在数组 input 中input_vectors = []for i in range(seq_len): input_vectors.append(Variable(dim=(dimension, 1), init=False, trainable=False)) # 对于本步输入的权值矩阵W = Variable(dim=(hidden_dim, dimension), init=True, trainable=True)
# 对于上步输入的权值矩阵Y = Variable(dim=(hidden_dim, hidden_dim), init=True, trainable=True)
# 偏置向量b = Variable(dim=(hidden_dim, 1), init=True, trainable=True)
# 构造 RNNlast_step = None # 上一步的输出,第一步没有上一步,先将其置为 Nonefor iv in input_vectors: y = Add(MatMul(W, iv), b)
if last_step is not None: y = Add(MatMul(Y, last_step), y)
y = ReLU(y)
last_step = y

fc1 = fc(y, hidden_dim, 6, "ReLU") # 第一全连接层fc2 = fc(fc1, 6, 2, "None") # 第二全连接层
# 分类概率prob = SoftMax(fc2)
# 训练标签label = Variable((2, 1), trainable=False)
# 交叉熵损失loss = CrossEntropyWithSoftMax(fc2, label)
复制代码


这就是构造 RNN 以及交叉熵损失的计算图的代码,很简单,right ?有了计算图以及自动求导,我们只管搭建网络即可,网络的训练就交给计算图去做了。否则你可以想象,按照示意图表示的计算,推导交叉熵损失对 RNN 的各个权值矩阵和偏置的梯度是多么困难。

时间序列问题

我们构造一份数据,它包含两类时间序列,一类是方波,一类是正弦波,代码如下:


def get_sequence_data(number_of_classes=2, dimension=10, length=10, number_of_examples=1000, train_set_ratio=0.7, seed=42):    """    生成两类序列数据。    """    xx = []    xx.append(np.sin(np.arange(0, 10, 10 / length)))  # 正弦波    xx.append(np.array(signal.square(np.arange(0, 10, 10 / length))))  # 方波

data = [] for i in range(number_of_classes): x = xx[i] for j in range(number_of_examples): sequence = x + np.random.normal(0, 1.0, (dimension, len(x))) # 加入高斯噪声 label = np.array([int(i == j) for j in range(number_of_classes)])
data.append(np.c_[sequence.reshape(1, -1), label.reshape(1, -1)])
# 把各个类别的样本合在一起 data = np.concatenate(data, axis=0)
# 随机打乱样本顺序 np.random.shuffle(data)
# 计算训练样本数量 train_set_size = int(number_of_examples * train_set_ratio) # 训练集样本数量
# 将训练集和测试集、特征和标签分开 return (data[:train_set_size, :-number_of_classes], data[:train_set_size, -number_of_classes:], data[train_set_size:, :-number_of_classes], data[train_set_size:, -number_of_classes:])
复制代码


我们用这一行代码获取长度为 96 ,维度为 16 的两类(各 1000 个)序列:


# 获取两类时间序列:正弦波和方波train_x, train_y, test_x, test_y = get_sequence_data(length=seq_len, dimension=dimension)
复制代码


看一看时间序列样本,先看正弦波:



正弦波序列


这是一个正弦波时间序列样本,它包含 16 条曲线,每一条都是 sin 曲线加噪声。之所以包含 16 条曲线,因为我们的时间序列的每一步是一个 16 维向量,按时间列起来就有了 16 条正弦曲线。正弦波时间序列是我们的正样本。方波时间序列是负样本:



方波序列


一个方波时间序列先维持 +1 一段时间,变为 -1 维持一段时间,再回到 +1 ,循环往复。由于我们的高斯噪声加得较大,可以看到正弦波和方波还是有可能混淆的,但也能看出它们之间的差异。

训练

现在就用我们构造的 RNN 训练一个分类模型,分类正弦波和方波,代码如下:


from sklearn.metrics import accuracy_score
from layer import *from node import *from optimizer import *
seq_len = 96 # 序列长度dimension = 16 # 序列每一步的向量维度hidden_dim = 12 # RNN 时间单元的输出维度
# 获取两类时间序列:正弦波和方波train_x, train_y, test_x, test_y = get_sequence_data(length=seq_len, dimension=dimension)
# 时间序列变量,每一步一个 dimension 维向量(Variable 节点),保存在数组 input 中input_vectors = []for i in range(seq_len): input_vectors.append(Variable(dim=(dimension, 1), init=False, trainable=False)) # 对于本步输入的权值矩阵W = Variable(dim=(hidden_dim, dimension), init=True, trainable=True)
# 对于上步输入的权值矩阵Y = Variable(dim=(hidden_dim, hidden_dim), init=True, trainable=True)
# 偏置向量b = Variable(dim=(hidden_dim, 1), init=True, trainable=True)
# 构造 RNNlast_step = None # 上一步的输出,第一步没有上一步,先将其置为 Nonefor iv in input_vectors: y = Add(MatMul(W, iv), b)
if last_step is not None: y = Add(MatMul(Y, last_step), y)
y = ReLU(y)
last_step = y

fc1 = fc(y, hidden_dim, 6, "ReLU") # 第一全连接层fc2 = fc(fc1, 6, 2, "None") # 第二全连接层
# 分类概率prob = SoftMax(fc2)
# 训练标签label = Variable((2, 1), trainable=False)
# 交叉熵损失loss = CrossEntropyWithSoftMax(fc2, label)
# Adam 优化器optimizer = Adam(default_graph, loss, 0.005, batch_size=16)
# 训练print("start training", flush=True)for e in range(10):
for i in range(len(train_x)): x = np.mat(train_x[i, :]).reshape(dimension, seq_len) for j in range(seq_len): input_vectors[j].set_value(x[:, j]) label.set_value(np.mat(train_y[i, :]).T)
# 执行一步优化 optimizer.one_step()
if i > 1 and (i + 1) % 100 == 0:
# 在测试集上评估模型正确率 probs = [] losses = [] for j in range(len(test_x)): # x = test_x[j, :].reshape(dimension, seq_len) x = np.mat(test_x[j, :]).reshape(dimension, seq_len) for k in range(seq_len): input_vectors[k].set_value(x[:, k]) label.set_value(np.mat(test_y[j, :]).T)
# 前向传播计算概率 prob.forward() probs.append(prob.value.A1)
# 计算损失值 loss.forward() losses.append(loss.value[0, 0])
# print("test instance: {:d}".format(j))
# 取概率最大的类别为预测类别 pred = np.argmax(np.array(probs), axis=1) truth = np.argmax(test_y, axis=1) accuracy = accuracy_score(truth, pred)
default_graph.draw() print("epoch: {:d}, iter: {:d}, loss: {:.3f}, accuracy: {:.2f}%".format(e + 1, i + 1, np.mean(losses), accuracy * 100), flush=True)
复制代码


训练 10 个 epoch 后,测试集上的正确率达到了 99% :


epoch: 1, iter: 100, loss: 0.693, accuracy: 51.08%epoch: 1, iter: 200, loss: 0.692, accuracy: 51.08%epoch: 1, iter: 300, loss: 0.677, accuracy: 78.31%epoch: 1, iter: 400, loss: 0.573, accuracy: 49.31%epoch: 1, iter: 500, loss: 0.520, accuracy: 53.92%epoch: 1, iter: 600, loss: 0.599, accuracy: 97.08%epoch: 1, iter: 700, loss: 0.617, accuracy: 99.00%epoch: 2, iter: 100, loss: 0.601, accuracy: 94.46%epoch: 2, iter: 200, loss: 0.579, accuracy: 82.08%epoch: 2, iter: 300, loss: 0.558, accuracy: 76.15%epoch: 2, iter: 400, loss: 0.531, accuracy: 67.85%epoch: 2, iter: 500, loss: 0.507, accuracy: 63.77%epoch: 2, iter: 600, loss: 0.493, accuracy: 61.15%epoch: 2, iter: 700, loss: 0.479, accuracy: 62.23%epoch: 3, iter: 100, loss: 0.443, accuracy: 69.92%epoch: 3, iter: 200, loss: 0.393, accuracy: 85.85%epoch: 3, iter: 300, loss: 0.365, accuracy: 97.69%epoch: 3, iter: 400, loss: 0.284, accuracy: 95.08%epoch: 3, iter: 500, loss: 0.199, accuracy: 95.69%epoch: 3, iter: 600, loss: 0.490, accuracy: 80.62%epoch: 3, iter: 700, loss: 0.264, accuracy: 94.31%epoch: 4, iter: 100, loss: 0.320, accuracy: 83.46%epoch: 4, iter: 200, loss: 0.333, accuracy: 80.92%epoch: 4, iter: 300, loss: 0.276, accuracy: 90.15%epoch: 4, iter: 400, loss: 0.242, accuracy: 95.00%epoch: 4, iter: 500, loss: 0.217, accuracy: 96.38%epoch: 4, iter: 600, loss: 0.191, accuracy: 95.31%epoch: 4, iter: 700, loss: 0.167, accuracy: 94.00%epoch: 5, iter: 100, loss: 0.142, accuracy: 94.62%epoch: 5, iter: 200, loss: 0.111, accuracy: 96.85%epoch: 5, iter: 300, loss: 0.116, accuracy: 96.85%epoch: 5, iter: 400, loss: 0.080, accuracy: 96.77%epoch: 5, iter: 500, loss: 0.059, accuracy: 98.54%epoch: 5, iter: 600, loss: 0.054, accuracy: 98.54%epoch: 5, iter: 700, loss: 0.042, accuracy: 99.00%epoch: 6, iter: 100, loss: 0.047, accuracy: 98.46%epoch: 6, iter: 200, loss: 0.049, accuracy: 98.08%epoch: 6, iter: 300, loss: 0.030, accuracy: 99.15%epoch: 6, iter: 400, loss: 0.029, accuracy: 99.23%epoch: 6, iter: 500, loss: 0.028, accuracy: 99.08%epoch: 6, iter: 600, loss: 0.029, accuracy: 99.08%epoch: 6, iter: 700, loss: 0.024, accuracy: 99.15%epoch: 7, iter: 100, loss: 0.023, accuracy: 99.15%epoch: 7, iter: 200, loss: 0.031, accuracy: 98.85%epoch: 7, iter: 300, loss: 0.023, accuracy: 99.46%epoch: 7, iter: 400, loss: 0.022, accuracy: 99.54%epoch: 7, iter: 500, loss: 0.022, accuracy: 99.38%epoch: 7, iter: 600, loss: 0.027, accuracy: 98.77%epoch: 7, iter: 700, loss: 0.019, accuracy: 99.46%epoch: 8, iter: 100, loss: 0.018, accuracy: 99.54%epoch: 8, iter: 200, loss: 0.018, accuracy: 99.46%epoch: 8, iter: 300, loss: 0.018, accuracy: 99.54%epoch: 8, iter: 400, loss: 0.018, accuracy: 99.62%epoch: 8, iter: 500, loss: 0.017, accuracy: 99.54%epoch: 8, iter: 600, loss: 0.026, accuracy: 99.00%epoch: 8, iter: 700, loss: 0.021, accuracy: 99.23%epoch: 9, iter: 100, loss: 0.017, accuracy: 99.62%epoch: 9, iter: 200, loss: 0.016, accuracy: 99.54%epoch: 9, iter: 300, loss: 0.015, accuracy: 99.54%epoch: 9, iter: 400, loss: 0.014, accuracy: 99.69%epoch: 9, iter: 500, loss: 0.014, accuracy: 99.62%epoch: 9, iter: 600, loss: 0.014, accuracy: 99.69%epoch: 9, iter: 700, loss: 0.014, accuracy: 99.62%epoch: 10, iter: 100, loss: 0.014, accuracy: 99.54%epoch: 10, iter: 200, loss: 0.014, accuracy: 99.54%epoch: 10, iter: 300, loss: 0.015, accuracy: 99.69%epoch: 10, iter: 400, loss: 0.014, accuracy: 99.69%epoch: 10, iter: 500, loss: 0.013, accuracy: 99.62%epoch: 10, iter: 600, loss: 0.016, accuracy: 99.38%epoch: 10, iter: 700, loss: 0.017, accuracy: 99.38%
复制代码


这就是我们的简单 RNN ,以后有机会我们再尝试搭建类似 LSTM 这种更复杂的 RNN 。


作者介绍


张觉非,本科毕业于复旦大学,硕士毕业于中国科学院大学,先后任职于新浪微博、阿里,目前就职于奇虎 360,任机器学习技术专家。


本文来自 DataFun 社区


原文链接


https://mp.weixin.qq.com/s?__biz=MzU1NTMyOTI4Mw==&mid=2247493606&idx=1&sn=bf89adb739302688e6b837084bff911a&chksm=fbd7558acca0dc9c6a9754975ee796239b5fa2c26a38f604c56d19a5189f8a0febd75698ddd7&scene=27#wechat_redirect


2019-09-17 08:001543

评论

发布
暂无评论
发现更多内容

架构师训练营第五周作业

一剑

一口气说出 OAuth2.0 的四种授权方式

程序员小富

Java oauth2.0

架构师第5周总结

老姜

啃碎并发(二):Java线程的生命周期

猿灯塔

「深度解析」AI训练之数据缓存

焱融科技

人工智能 AI 存储 焱融科技 数据缓存

女同事问哪吒什么是 Spring 循环依赖?我...

通天哪吒

小师妹学JVM之:cache line对代码性能的影响

程序那些事

JVM 小师妹 性能调优 cache line 签约计划第二季

Ceph数据恢复初探

焱融科技

焱融科技 文件存储 分布式存储 数据恢复 Ceph

干货 | 如何评估Kubernetes持久化存储方案

焱融科技

Kubernetes 容器 云原生 k8s

Java 线程池中的线程复用是如何实现的?

武培轩

Java 程序员 后端 线程池 源码解析

超详细!一文带你了解 LVS 负载均衡集群!

JackTian

Linux 负载均衡 运维 LVS 服务器集群

一致性 hash 算法

Z冰红茶

一致性Hash算法

SpringBoot 中使用 Filter 的正确姿势

Java课代表

数据分析师成长体系漫谈--数据埋点

analysis-lion

数据分析 数据采集 埋点

一次非常有意思的 SQL 优化经历: 从 30248.271s 到 0.001s

Java小咖秀

MySQL 面试 后端 经验分享 优化逻辑

第五周作业

Linuxer

极客大学架构师训练营

很多人毕业多年以后,还是改不掉学生思维

小智

职场 思维方式 高考

计算机操作系统基础(十四)---线程同步之条件变量

书旅

php laravel 操作系统 进程 线程’

这份高考卷,只有程序员能得满分...

程序员生活志

程序员 高考

一致性哈希算法实现

老姜

架构师训练营 - 第 5 课总结 -20200704- 技术选型

👑👑merlan

负载均衡 缓存 分布式数据库 架构设计 消息队列

week05 学习总结 分布式缓存&消息队列&负载

Z冰红茶

一致性hash的理解与实现

dongge

架构师训练营第五周课后总结

Cloud.

游戏夜读 | 关卡设计的难点

game1night

架构师训练营第5周

大丁💸💵💴💶🚀🐟

第 5 周作业:一致性 Hash 算法

姜 某某

啃碎并发(三):Java线程上下文切换

猿灯塔

联想来酷广谱化生存:后疫情时代的"硬核品牌"启示录

Geek_116789

用进废退,增加能力熟练度与经验值,让你的技能再次精进。

叶小鍵

【Python】__name__ 是什么?

Leetao

Python Python基础

运用计算图搭建递归神经网络(RNN)_文化 & 方法_DataFunTalk_InfoQ精选文章