写点什么

基于 TensorFlow 2.0 的长短期记忆网络进行多类文本分类

  • 2019-12-12
  • 本文字数:7860 字

    阅读完需:约 26 分钟

基于 TensorFlow 2.0 的长短期记忆网络进行多类文本分类

文本分类是指将给定文本按照其内容判别到一个或多个预先确定的文本类别中的过程。文本分类是一种典型的有监督的学习过程,根据已经被标记的文本集合,通过学习,得到一个文本特征和文本类别之间的关系模型,然后利用这个关系模型对新文本进行类别判断。文本分类计数用于识别文档主题,并将之归类到预先定义的主题或主题集合中。

需要注意的是,多类文本分类与多标签分类并不同,其中多类分类区别于二分类问题,即在 个类别中互斥地选取一个作为输出;而多标签分类,是在 n 个标签中非互斥地选取 个标签作为输出。本文介绍了如何基于 TensorFlow 2.0 的长短期记忆网络进行多类文本分类,非常实用,希望对读者有所启迪。


对自然语言处理(Natural Language Processing,NLP)领域来说,很多创新之处都是关于如何在词向量中加入上下文。常用的方法之一就是使用递归神经网络(Recurrent Neural Networks,RNN)。下面是递归神经网络的概念:


  • 它们利用顺序信息。

  • 它们具备记忆能力,能够记住到目前为止计算过的内容,也就是说,我最后说的内容将影响我接下来要讲的内容。

  • 递归神经网络是文本和语音分析的理想选择。

  • 最常用的递归神经网络是长短期记忆网络(Long-Short Term Memory,LSTM)。



来源:https://colah.github.io/posts/2015-08-Understanding-LSTMs/


上图是递归神经网络的架构。


  • “A” 是前馈神经网络(Feedforward neural network)的一层。

  • 如果我们只看右边的话,它会递归地遍历每个序列的元素。

  • 如果我们将左边展开,它看起来将会跟右边一模一样。


译注: 前馈神经网络(Feedforward neural network),是最早发明、最简单的人工神经网络类型。在它内部,参数从输入层向输出层单向传播。和递归神经网络不通,它内部不会构成有向环。



来源:https://colah.github.io/posts/2015-08-Understanding-LSTMs


假设我们正在解决新闻文章数据集的文档分类问题。


  • 我们输入每个单词,这些单词以某种方式相互关联。

  • 当我们看到文章中所有的单词时,我们会在文章末尾做出预测。

  • 递归神经网络通过传递上一次输出的输入,能够保留信息,并能够在最后利用所有信息进行预测。



来源:https://colah.github.io/posts/2015-08-Understanding-LSTMs


  • 这对于短句很有效,但当我们处理一篇长文章时,将会有一个长期依赖问题。


因此,我们通常不是用普通的递归神经网络,而是使用长短期记忆网络。长短期记忆网络是一种递归神经网络,可以解决这种长期依赖问题。


译注: 长短期记忆网络(Long Short-Term Memory,LSTM),是一种时间递归神经网络,适合于处理和预测时间序列中间隔和延迟相对较长的重要事件。基于长短期记忆网络的系统可以实现机器翻译、视频分析、文档摘要、语音识别、图像识别、手写识别、控制聊天机器人、合成音乐等任务。



在我们的新闻文章文档分类示例中,有这种多对一的关系。输入是单词序列,而输出是单个类或标签。


现在,我们将使用 TensorFlow 2.0Keras,解决一个使用长短期记忆网络的 BBC 新闻文档分类问题。数据集可以点击此链接来获取。


  • 首先,我们导入库,并确保 TensorFlow 是正确的版本。


import csvimport tensorflow as tfimport numpy as npfrom tensorflow.keras.preprocessing.text import Tokenizerfrom tensorflow.keras.preprocessing.sequence import pad_sequencesfrom nltk.corpus import stopwordsSTOPWORDS = set(stopwords.words('english'))print(tf.__version__)
复制代码



  • 将超参数置于顶部,如下所示,便于进行更改和编辑。

  • 届时,我们将会讲解每个超参数是如何工作的。


vocab_size = 5000embedding_dim = 64max_length = 200trunc_type = 'post'padding_type = 'post'oov_tok = '<OOV>'training_portion = .8
复制代码


  • 定义两个包含文章和标签的列表。同时,我们删除了停用词。


articles = []labels = []with open("bbc-text.csv", 'r') as csvfile:    reader = csv.reader(csvfile, delimiter=',')    next(reader)    for row in reader:        labels.append(row[0])        article = row[1]        for word in STOPWORDS:            token = ' ' + word + ' '            article = article.replace(token, ' ')            article = article.replace(' ', ' ')        articles.append(article)print(len(labels))print(len(articles))
复制代码



数据中有 2225 篇新闻文章,我们将它们分为训练集和验证集,根据我们之前设置的参数,80% 用于训练,20% 用于验证。


train_size = int(len(articles) * training_portion)train_articles = articles[0: train_size]train_labels = labels[0: train_size]validation_articles = articles[train_size:]validation_labels = labels[train_size:]print(train_size)print(len(train_articles))print(len(train_labels))print(len(validation_articles))print(len(validation_labels))
复制代码



词法分析器(Tokenizer)为我们承担了所有繁重的工作。在我们的文章中,它将进行标记化,需要 5000 个最常见的单词。oov_token 是在遇到不可见的单词时放入一个特殊的值。这意味着我们希望 <OOV> 用于不在 word_index 中的单词。fit_on_text 将遍历所有文本,并创建如下词典:


tokenizer = Tokenizer(num_words = vocab_size, oov_token=oov_tok)tokenizer.fit_on_texts(train_articles)word_index = tokenizer.word_indexdict(list(word_index.items())[0:10])
复制代码


译注: 词法分析器(Tokenizer),是计算机科学中将字符串行转换为标记(token)串行的过程。进行词法分析的进程或者函数叫作词法分析器(lexical analyzer,简称 lexer),也叫扫描器(scanner)。词法分析器一般以函数的形式存在,供语法分析器调用。



我们可以看到,“”是我们语料库中最常见的令牌,其次是“said”、“mr”等等。


完成标记化之后,下一步就是将这些标记转换为序列列表。下面是已经转换成序列的训练数据中的第 11 篇文章。


train_sequences = tokenizer.texts_to_sequences(train_articles)print(train_sequences[10])



" 图 1"


当我们为自然语言处理训练神经网络时,我们需要相同大小的序列,这就是我们为什么使用填充的原因。如果你查看一下的话,就会发现,我们的 max_length 是 200,所以我们使用 pad_sequences ,将所有文章的长度都设置为 200。结果,你会看到第一篇文章长度为 426,变成了 200;第二篇是 192,也变成了 200。以此类推。


train_padded = pad_sequences(train_sequences, maxlen=max_length, padding=padding_type, truncating=trunc_type)print(len(train_sequences[0]))print(len(train_padded[0]))print(len(train_sequences[1]))print(len(train_padded[1]))print(len(train_sequences[10]))print(len(train_padded[10]))
复制代码



此外,还有 padding_typetruncating_type, 还有所有的 post,例如,第 11 篇文章的长度是 186,我们需要填充到 200,我们就在结尾处开始填充,也就是说,填充了 14 个 0。


print(train_padded[10])



" 图 2"


对于第一篇文章,它的长度为 426,我们需要将其截断到 200,我们就在结尾处截断。


然后,我们对验证序列执行同样的操作。


validation_sequences = tokenizer.texts_to_sequences(validation_articles)validation_padded = pad_sequences(validation_sequences, maxlen=max_length, padding=padding_type, truncating=trunc_type)print(len(validation_sequences))print(validation_padded.shape)
复制代码



现在,我们来看一下标签。因为我们的标签是文本,因此,我们将它们进行标记。在训练时,标签应该是 numpy 数组。所以,我们要将标签列表转换为 numpy 数组,如下所示:


label_tokenizer = Tokenizer()label_tokenizer.fit_on_texts(labels)training_label_seq = np.array(label_tokenizer.texts_to_sequences(train_labels))validation_label_seq = np.array(label_tokenizer.texts_to_sequences(validation_labels))print(training_label_seq[0])print(training_label_seq[1])print(training_label_seq[2])print(training_label_seq.shape)print(validation_label_seq[0])print(validation_label_seq[1])print(validation_label_seq[2])print(validation_label_seq.shape)
复制代码



在训练深度神经网络之前,我们应该探索一下我们的原始文章和填充后的文章是什么样子的。运行下面的代码,我们浏览第 11 篇文章,可以看到,一些单词变成了“”,因为它们没有进入前 5000。


reverse_word_index = dict([(value, key) for (key, value) in word_index.items()])def decode_article(text):    return ' '.join([reverse_word_index.get(i, '?') for i in text])print(decode_article(train_padded[10]))print('---')print(train_articles[10])
复制代码



“图 3”


现在,是实施长短期记忆网络的时候了。


  • 我们构建了一个 tf.keras.Sequential 模型,从嵌入层开始。嵌入层为每个单词存储一个向量。调用时,它将单词索引序列转换为向量序列。经过训练后,具有相似意义的单词,通常会具有相似的向量。

  • 双向包装器(Bidirectional wrapper)与 LSTM 层一起使用,它通过 LSTM 层向前和向后传播输入,然后连接输出。这有助于长短期记忆网络学习长期依赖关系。然后我们将其拟合到密集神经网络(Dense Neural Network)中进行分类。

  • 我们使用 relu 代替 than 函数,因为这两个函数能够彼此很好地相互替代。

  • 我们添加了 6 个单位和 softmax 激活的密集层(Dense Layer)。当我们有多个输出时,softmax 将输出层转换为概率分布。


model = tf.keras.Sequential([    # Add an Embedding layer expecting input vocab of size 5000, and output embedding dimension of size 64 we set at the top    tf.keras.layers.Embedding(vocab_size, embedding_dim),    tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(embedding_dim)),#    tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(32)),    # use ReLU in place of tanh function since they are very good alternatives of each other.    tf.keras.layers.Dense(embedding_dim, activation='relu'),    # Add a Dense layer with 6 units and softmax activation.    # When we have multiple outputs, softmax convert outputs layers into a probability distribution.    tf.keras.layers.Dense(6, activation='softmax')])model.summary()
复制代码



“图 4”


在我们的模型摘要中,我们有嵌入,双向包含长短期记忆网络,然后就是两个密集层(Dense layer)。双向的输出为 128,因为它是我们在长短期记忆网络中输入的两倍。我们也可以堆叠 LSTM 层,但我们发现,结果反而更糟。


print(set(labels))



我们总共有 5 个标签,但因为我们没有对标签进行独热编码(One-hot encode),因此,我们不得不使用


sparse_categorical_crossentropy 作为损失函数,它似乎认为 0 也是一个可能的标签,而词法分析器对象是从整数 1 开始标记化,而不是整数 0。结果,尽管从未使用过 0,但最后一个密集层需要标签 0、1、2、3、4、5 的输出。


如果你希望最后一个密集层为 5,那么你就需要从训练和验证标签中减去 1。我决定保持现状。


我决定训练 10 个轮数,正如你将看到的,这是很多轮数。


model.compile(loss='sparse_categorical_crossentropy', optimizer='adam', metrics=['accuracy'])num_epochs = 10history = model.fit(train_padded, training_label_seq, epochs=num_epochs, validation_data=(validation_padded, validation_label_seq), verbose=2)
复制代码



“图 5”


def plot_graphs(history, string):  plt.plot(history.history[string])  plt.plot(history.history['val_'+string])  plt.xlabel("Epochs")  plt.ylabel(string)  plt.legend([string, 'val_'+string])  plt.show()plot_graphs(history, "accuracy")plot_graphs(history, "loss")
复制代码


图6


我们可能只需 3 到 4 个轮数。在训练结束时,我们可以发现有点过拟合。


在后续文章中,我们将致力于改进这一模型。


你可以在 Github 找到本文的 Jupyter notebook


参考文献:


作者介绍:

Susan Li,是加拿大多伦多的高级数据科学家。她的理想是,每次发表文章,就改变世界。


原文链接:


https://towardsdatascience.com/multi-class-text-classification-with-lstm-using-tensorflow-2-0-d88627c10a35


2019-12-12 08:002939
用户头像
刘燕 InfoQ高级技术编辑

发布了 1112 篇内容, 共 584.5 次阅读, 收获喜欢 1981 次。

关注

评论

发布
暂无评论
发现更多内容

【HarmonyOS Next】鸿蒙循环渲染ForEach,LazyForEach,Repeat使用心得体会

GeorgeGcs

foreach LazyForEach Repeat

【HarmonyOS NEXT】systemDateTime 时间戳转换为时间格式 Date,DateTimeFormat

GeorgeGcs

Date systemDateTime DateTimeFormat

AI之山,鸿蒙之水,画一幅未来之家

脑极体

AI

YashanDB巡检

YashanDB

数据库 yashandb

如何在云效中使用 DeepSeek 等大模型实现 AI 智能评审

阿里云云效

阿里云 云原生 云效

YashanDB dump

YashanDB

数据库 yashandb

【HarmonyOS NEXT】鸿蒙三方应用跳转到系统浏览器

GeorgeGcs

鸿蒙 三方应用 系统浏览器

【HarmonyOS NEXT】鸿蒙应用点9图的处理(draw9patch)

GeorgeGcs

鸿蒙 draw9patch 应用点9图

YashanDB故障诊断架构

YashanDB

数据库 yashandb

如何在云效中使用 DeepSeek 等大模型实现 AI 智能评审

阿里巴巴云原生

阿里云 AI 云原生

【HarmonyOS NEXT】鸿蒙应用实现屏幕录制详解和源码

GeorgeGcs

鸿蒙 源码 应用 屏幕录制 详解

YashanDB健康检查

YashanDB

数据库 yashandb

HarmonyOS NEXT 实现拖动卡片背景模糊效果

威哥爱编程

HarmonyOS HarmonyOS框架 HarmonyOS NEXT

【HarmonyOS Next】鸿蒙TaskPool和Worker详解 (一)

GeorgeGcs

Worker askPool

【HarmonyOS Next】鸿蒙应用公钥和证书MD5指纹的获取

GeorgeGcs

应用公钥 证书MD5指纹 获取

【HarmonyOS Next】拒绝权限二次申请授权处理

GeorgeGcs

拒绝权限 二次申请 授权处理

【HarmonyOS NEXT】设备显示白屏 syswarning happended in XXX

GeorgeGcs

设备显示白屏 syswarning happended in XXX

具身智能:人工智能的革命——从算法智能到物理智能的范式转移

测试人

人工智能

【HarmonyOS NEXT】鸿蒙应用如何进行页面横竖屏切换以及注意事项,自动切换横竖屏,监听横竖屏

GeorgeGcs

鸿蒙应用 横竖屏切换 自动切换横竖屏 监听横竖屏

【HarmonyOS NEXT】解决自定义弹框遮挡气泡提示的问题

GeorgeGcs

自定义弹框 间隙

联合民生证券,探讨AI技术驱动下的财富管理新范式

非凸科技

YashanDB故障诊断概念

YashanDB

数据库 yashandb

【HarmonyOS Next】鸿蒙状态管理V2装饰器详解

GeorgeGcs

鸿蒙状态管理 V2装饰器

【HarmonyOS Next】鸿蒙监听手机按键

GeorgeGcs

鸿蒙 监听 手机按键

Netty源码—客户端接入流程

不在线第一只蜗牛

Java php 服务器

【HarmonyOS Next】 共享HSP和应用内HSP,useNormalizedOHMUrl详解

GeorgeGcs

共享HSP 应用内HSP useNormalizedOHMUrl

【HarmonyOS Next】鸿蒙应用进程和线程详解

GeorgeGcs

鸿蒙 线程 应用进程 详解

【HarmonyOS NEXT】鸿蒙跳转华为应用市场目标APP下载页

GeorgeGcs

鸿蒙 华为应用市场 目标APP下载页 跳转

类似智联招聘/前程无忧,BOSS直聘网站小程序项目源码定制开发搭建

网站,小程序,APP开发定制

【HarmonyOS NEXT】鸿蒙应用使用后台任务之长时任务,解决屏幕录制音乐播放等操作不被挂起

GeorgeGcs

后台任务 长时任务 屏幕录制音乐播放

Redis 高可用方案

天翼云开发者社区

redis

基于 TensorFlow 2.0 的长短期记忆网络进行多类文本分类_AI&大模型_Susan Li_InfoQ精选文章