写点什么

基于 TensorFlow 2.0 的长短期记忆网络进行多类文本分类

  • 2019-12-12
  • 本文字数:7860 字

    阅读完需:约 26 分钟

基于 TensorFlow 2.0 的长短期记忆网络进行多类文本分类

文本分类是指将给定文本按照其内容判别到一个或多个预先确定的文本类别中的过程。文本分类是一种典型的有监督的学习过程,根据已经被标记的文本集合,通过学习,得到一个文本特征和文本类别之间的关系模型,然后利用这个关系模型对新文本进行类别判断。文本分类计数用于识别文档主题,并将之归类到预先定义的主题或主题集合中。

需要注意的是,多类文本分类与多标签分类并不同,其中多类分类区别于二分类问题,即在 个类别中互斥地选取一个作为输出;而多标签分类,是在 n 个标签中非互斥地选取 个标签作为输出。本文介绍了如何基于 TensorFlow 2.0 的长短期记忆网络进行多类文本分类,非常实用,希望对读者有所启迪。


对自然语言处理(Natural Language Processing,NLP)领域来说,很多创新之处都是关于如何在词向量中加入上下文。常用的方法之一就是使用递归神经网络(Recurrent Neural Networks,RNN)。下面是递归神经网络的概念:


  • 它们利用顺序信息。

  • 它们具备记忆能力,能够记住到目前为止计算过的内容,也就是说,我最后说的内容将影响我接下来要讲的内容。

  • 递归神经网络是文本和语音分析的理想选择。

  • 最常用的递归神经网络是长短期记忆网络(Long-Short Term Memory,LSTM)。



来源:https://colah.github.io/posts/2015-08-Understanding-LSTMs/


上图是递归神经网络的架构。


  • “A” 是前馈神经网络(Feedforward neural network)的一层。

  • 如果我们只看右边的话,它会递归地遍历每个序列的元素。

  • 如果我们将左边展开,它看起来将会跟右边一模一样。


译注: 前馈神经网络(Feedforward neural network),是最早发明、最简单的人工神经网络类型。在它内部,参数从输入层向输出层单向传播。和递归神经网络不通,它内部不会构成有向环。



来源:https://colah.github.io/posts/2015-08-Understanding-LSTMs


假设我们正在解决新闻文章数据集的文档分类问题。


  • 我们输入每个单词,这些单词以某种方式相互关联。

  • 当我们看到文章中所有的单词时,我们会在文章末尾做出预测。

  • 递归神经网络通过传递上一次输出的输入,能够保留信息,并能够在最后利用所有信息进行预测。



来源:https://colah.github.io/posts/2015-08-Understanding-LSTMs


  • 这对于短句很有效,但当我们处理一篇长文章时,将会有一个长期依赖问题。


因此,我们通常不是用普通的递归神经网络,而是使用长短期记忆网络。长短期记忆网络是一种递归神经网络,可以解决这种长期依赖问题。


译注: 长短期记忆网络(Long Short-Term Memory,LSTM),是一种时间递归神经网络,适合于处理和预测时间序列中间隔和延迟相对较长的重要事件。基于长短期记忆网络的系统可以实现机器翻译、视频分析、文档摘要、语音识别、图像识别、手写识别、控制聊天机器人、合成音乐等任务。



在我们的新闻文章文档分类示例中,有这种多对一的关系。输入是单词序列,而输出是单个类或标签。


现在,我们将使用 TensorFlow 2.0Keras,解决一个使用长短期记忆网络的 BBC 新闻文档分类问题。数据集可以点击此链接来获取。


  • 首先,我们导入库,并确保 TensorFlow 是正确的版本。


import csvimport tensorflow as tfimport numpy as npfrom tensorflow.keras.preprocessing.text import Tokenizerfrom tensorflow.keras.preprocessing.sequence import pad_sequencesfrom nltk.corpus import stopwordsSTOPWORDS = set(stopwords.words('english'))print(tf.__version__)
复制代码



  • 将超参数置于顶部,如下所示,便于进行更改和编辑。

  • 届时,我们将会讲解每个超参数是如何工作的。


vocab_size = 5000embedding_dim = 64max_length = 200trunc_type = 'post'padding_type = 'post'oov_tok = '<OOV>'training_portion = .8
复制代码


  • 定义两个包含文章和标签的列表。同时,我们删除了停用词。


articles = []labels = []with open("bbc-text.csv", 'r') as csvfile:    reader = csv.reader(csvfile, delimiter=',')    next(reader)    for row in reader:        labels.append(row[0])        article = row[1]        for word in STOPWORDS:            token = ' ' + word + ' '            article = article.replace(token, ' ')            article = article.replace(' ', ' ')        articles.append(article)print(len(labels))print(len(articles))
复制代码



数据中有 2225 篇新闻文章,我们将它们分为训练集和验证集,根据我们之前设置的参数,80% 用于训练,20% 用于验证。


train_size = int(len(articles) * training_portion)train_articles = articles[0: train_size]train_labels = labels[0: train_size]validation_articles = articles[train_size:]validation_labels = labels[train_size:]print(train_size)print(len(train_articles))print(len(train_labels))print(len(validation_articles))print(len(validation_labels))
复制代码



词法分析器(Tokenizer)为我们承担了所有繁重的工作。在我们的文章中,它将进行标记化,需要 5000 个最常见的单词。oov_token 是在遇到不可见的单词时放入一个特殊的值。这意味着我们希望 <OOV> 用于不在 word_index 中的单词。fit_on_text 将遍历所有文本,并创建如下词典:


tokenizer = Tokenizer(num_words = vocab_size, oov_token=oov_tok)tokenizer.fit_on_texts(train_articles)word_index = tokenizer.word_indexdict(list(word_index.items())[0:10])
复制代码


译注: 词法分析器(Tokenizer),是计算机科学中将字符串行转换为标记(token)串行的过程。进行词法分析的进程或者函数叫作词法分析器(lexical analyzer,简称 lexer),也叫扫描器(scanner)。词法分析器一般以函数的形式存在,供语法分析器调用。



我们可以看到,“”是我们语料库中最常见的令牌,其次是“said”、“mr”等等。


完成标记化之后,下一步就是将这些标记转换为序列列表。下面是已经转换成序列的训练数据中的第 11 篇文章。


train_sequences = tokenizer.texts_to_sequences(train_articles)print(train_sequences[10])



" 图 1"


当我们为自然语言处理训练神经网络时,我们需要相同大小的序列,这就是我们为什么使用填充的原因。如果你查看一下的话,就会发现,我们的 max_length 是 200,所以我们使用 pad_sequences ,将所有文章的长度都设置为 200。结果,你会看到第一篇文章长度为 426,变成了 200;第二篇是 192,也变成了 200。以此类推。


train_padded = pad_sequences(train_sequences, maxlen=max_length, padding=padding_type, truncating=trunc_type)print(len(train_sequences[0]))print(len(train_padded[0]))print(len(train_sequences[1]))print(len(train_padded[1]))print(len(train_sequences[10]))print(len(train_padded[10]))
复制代码



此外,还有 padding_typetruncating_type, 还有所有的 post,例如,第 11 篇文章的长度是 186,我们需要填充到 200,我们就在结尾处开始填充,也就是说,填充了 14 个 0。


print(train_padded[10])



" 图 2"


对于第一篇文章,它的长度为 426,我们需要将其截断到 200,我们就在结尾处截断。


然后,我们对验证序列执行同样的操作。


validation_sequences = tokenizer.texts_to_sequences(validation_articles)validation_padded = pad_sequences(validation_sequences, maxlen=max_length, padding=padding_type, truncating=trunc_type)print(len(validation_sequences))print(validation_padded.shape)
复制代码



现在,我们来看一下标签。因为我们的标签是文本,因此,我们将它们进行标记。在训练时,标签应该是 numpy 数组。所以,我们要将标签列表转换为 numpy 数组,如下所示:


label_tokenizer = Tokenizer()label_tokenizer.fit_on_texts(labels)training_label_seq = np.array(label_tokenizer.texts_to_sequences(train_labels))validation_label_seq = np.array(label_tokenizer.texts_to_sequences(validation_labels))print(training_label_seq[0])print(training_label_seq[1])print(training_label_seq[2])print(training_label_seq.shape)print(validation_label_seq[0])print(validation_label_seq[1])print(validation_label_seq[2])print(validation_label_seq.shape)
复制代码



在训练深度神经网络之前,我们应该探索一下我们的原始文章和填充后的文章是什么样子的。运行下面的代码,我们浏览第 11 篇文章,可以看到,一些单词变成了“”,因为它们没有进入前 5000。


reverse_word_index = dict([(value, key) for (key, value) in word_index.items()])def decode_article(text):    return ' '.join([reverse_word_index.get(i, '?') for i in text])print(decode_article(train_padded[10]))print('---')print(train_articles[10])
复制代码



“图 3”


现在,是实施长短期记忆网络的时候了。


  • 我们构建了一个 tf.keras.Sequential 模型,从嵌入层开始。嵌入层为每个单词存储一个向量。调用时,它将单词索引序列转换为向量序列。经过训练后,具有相似意义的单词,通常会具有相似的向量。

  • 双向包装器(Bidirectional wrapper)与 LSTM 层一起使用,它通过 LSTM 层向前和向后传播输入,然后连接输出。这有助于长短期记忆网络学习长期依赖关系。然后我们将其拟合到密集神经网络(Dense Neural Network)中进行分类。

  • 我们使用 relu 代替 than 函数,因为这两个函数能够彼此很好地相互替代。

  • 我们添加了 6 个单位和 softmax 激活的密集层(Dense Layer)。当我们有多个输出时,softmax 将输出层转换为概率分布。


model = tf.keras.Sequential([    # Add an Embedding layer expecting input vocab of size 5000, and output embedding dimension of size 64 we set at the top    tf.keras.layers.Embedding(vocab_size, embedding_dim),    tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(embedding_dim)),#    tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(32)),    # use ReLU in place of tanh function since they are very good alternatives of each other.    tf.keras.layers.Dense(embedding_dim, activation='relu'),    # Add a Dense layer with 6 units and softmax activation.    # When we have multiple outputs, softmax convert outputs layers into a probability distribution.    tf.keras.layers.Dense(6, activation='softmax')])model.summary()
复制代码



“图 4”


在我们的模型摘要中,我们有嵌入,双向包含长短期记忆网络,然后就是两个密集层(Dense layer)。双向的输出为 128,因为它是我们在长短期记忆网络中输入的两倍。我们也可以堆叠 LSTM 层,但我们发现,结果反而更糟。


print(set(labels))



我们总共有 5 个标签,但因为我们没有对标签进行独热编码(One-hot encode),因此,我们不得不使用


sparse_categorical_crossentropy 作为损失函数,它似乎认为 0 也是一个可能的标签,而词法分析器对象是从整数 1 开始标记化,而不是整数 0。结果,尽管从未使用过 0,但最后一个密集层需要标签 0、1、2、3、4、5 的输出。


如果你希望最后一个密集层为 5,那么你就需要从训练和验证标签中减去 1。我决定保持现状。


我决定训练 10 个轮数,正如你将看到的,这是很多轮数。


model.compile(loss='sparse_categorical_crossentropy', optimizer='adam', metrics=['accuracy'])num_epochs = 10history = model.fit(train_padded, training_label_seq, epochs=num_epochs, validation_data=(validation_padded, validation_label_seq), verbose=2)
复制代码



“图 5”


def plot_graphs(history, string):  plt.plot(history.history[string])  plt.plot(history.history['val_'+string])  plt.xlabel("Epochs")  plt.ylabel(string)  plt.legend([string, 'val_'+string])  plt.show()plot_graphs(history, "accuracy")plot_graphs(history, "loss")
复制代码


图6


我们可能只需 3 到 4 个轮数。在训练结束时,我们可以发现有点过拟合。


在后续文章中,我们将致力于改进这一模型。


你可以在 Github 找到本文的 Jupyter notebook


参考文献:


作者介绍:

Susan Li,是加拿大多伦多的高级数据科学家。她的理想是,每次发表文章,就改变世界。


原文链接:


https://towardsdatascience.com/multi-class-text-classification-with-lstm-using-tensorflow-2-0-d88627c10a35


2019-12-12 08:002758
用户头像
刘燕 InfoQ高级技术编辑

发布了 1112 篇内容, 共 546.1 次阅读, 收获喜欢 1978 次。

关注

评论

发布
暂无评论
发现更多内容

音视频开发进阶|第六讲:色彩和色彩空间·下篇

ZEGO即构

音视频开发 色彩

元宇宙场景技术实践|实现“虚拟人”自由

ZEGO即构

一种基于Prompt的通用信息抽取(UIE)框架

阿里技术

深度学习 信息抽取

阿里云E-HPC+i4p大内存实例,加速寻因生物单细胞数据分析效率

阿里云弹性计算

HPC

阿里云Imagine Computing创新技术大赛正式开启!

阿里云CloudImagine

阿里云 技术大赛

浅谈:数字资产永续合约交易所开发有什么好处?

W13902449729

合约交易所开发 区块链交易所开发

React核心工作原理

xiaofeng

React

react源码中的生命周期和事件系统

flyzz177

React

聊聊前端开发中的 Ghost Design 设计思路

汪子熙

前端开发 angular web开发 SAP 11月月更

深入浅出分布式,阿里大牛手写《分布式核心原理》Github一夜爆火

Java永远的神

分布式 程序人生 分布式计算 分布式系统 分布式存储

Oracle、MySQL等数据库故障处理优质文章分享 | 10月文章汇总

墨天轮

MySQL 数据库 oracle 性能优化 故障恢复

OpenHarmony移植案例: build lite源码分析之hb命令__entry__.py

华为云开发者联盟

鸿蒙 芯片 华为云 源代码 企业号十月 PK 榜

一本书,带你走出Spring新手村

博文视点Broadview

美团前端常考手写面试题(边面边更)

helloworld1024fd

JavaScript

想会用synchronized锁,先掌握底层核心原理

华为云开发者联盟

开发 华为云 企业号十月 PK 榜

动手实践丨使用华为云IoT边缘体验“边云协同”

华为云开发者联盟

云计算 华为云 企业号十月 PK 榜

React组件设计模式-纯组件,函数组件,高阶组件

xiaofeng

React

React性能优化的8种方式

xiaofeng

React

最近面试经常被问到的js手写题

helloworld1024fd

JavaScript

假如面试官要你手写一个promise

helloworld1024fd

JavaScript

react源码中的协调与调度

flyzz177

React

CIO们开始将软件供应链升级为安全优先级top

SEAL安全

DevOps 开源软件 软件供应链 SBOM 软件供应链安全

共筑使能千行百业的数字底座 | HDC 2022松湖对话顺利召开

OpenHarmony开发者

OpenHarmony

重磅!涛思数据发布TDengine PI连接器

TDengine

数据库 tdengine 时序数据库

React的5种高级模式

夏天的味道123

React

react源码中的hooks

flyzz177

React

写个JS深拷贝,面试备用

helloworld1024fd

JavaScript

React组件复用的技巧

夏天的味道123

React

React组件复用的发展史

夏天的味道123

React

从华泰证券年报看数字化转型的平台化趋势

王和全

数字化转型 数字化 华泰证券 平台化

走进 Orca 架构及技术世界

KaiwuDB

数据库·

基于 TensorFlow 2.0 的长短期记忆网络进行多类文本分类_AI&大模型_Susan Li_InfoQ精选文章