写点什么

如何利用 TensorFlow Hub 让 BERT 开发更简单?

  • 2021-01-14
  • 本文字数:3818 字

    阅读完需:约 13 分钟

如何利用TensorFlow Hub 让BERT开发更简单?

在自然语言处理领域,BERT 和其他 Transformer 编码器架构都非常成功,无论是推进学术基准的技术水平,还是在 Google Search 这样的大规模应用中,均是如此。BERT 自 TensorFlow 创建以来一直可用,但它最初依赖于非 TensorFlow 的 Python 代码,以将原始文本转换为模型输入。

 

如今,在 TensorFlow 中构建 BERT 会更加简单。开发者可在 TensorFlow Hub 上使用预训练编码器和匹配的文本预处理模型。在 TensorFlow 中运行 BERT 对文本输入的操作只需要几行代码:

 

# Load BERT and the preprocessing model from TF Hub.preprocess = hub.load('https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/1')encoder = hub.load('https://tfhub.dev/tensorflow/bert_en_uncased_L-12_H-768_A-12/3')

# Use BERT on a batch of raw text inputs.input = preprocess(['Batch of inputs', 'TF Hub makes BERT easy!', 'More text.'])pooled_output = encoder(input)["pooled_output"]print(pooled_output)

tf.Tensor([[-0.8384154 -0.26902363 -0.3839138 ... -0.3949695 -0.58442086 0.8058556 ] [-0.8223734 -0.2883956 -0.09359277 ... -0.13833837 -0.6251748 0.88950026] [-0.9045408 -0.37877116 -0.7714909 ... -0.5112085 -0.70791864 0.92950743]],shape=(3, 768), dtype=float32
复制代码

 

这些编码器和预处理模型已经用 TensorFlow Model Garden 的 NLP 库构建,并以 SavedModel 格式导出到 TensorFlow Hub。实际上,预处理使用 TF.text 库中的 TensorFlow ops 输入文本进行标记化:允许开发者建立自己的 TensorFlow 模型,将原始文本输入到预测输出,而无需使用 Python 的循环。这样可以提高计算速度,去除样板代码,减少出错的可能性,并且可以将整个文本序列化为输出模型,使得 BERT 在生产环境中更容易使用。

 

为了详细说明这些模型的具体作用,我们发布了两个新的教程:

 

  • 初级教程:解决一项情感分析任务,不需要任何特殊定制,就能得到很好的模型质量。这是最简单的使用 BERT 和预处理模型的方法。

  • 高级教程:解决了在 TPU 上运行 GLUE 基准中的自然语言处理分类任务。它还说明了如何在需要多段输入的情况下使用预处理模型。

 

选择 BERT 模型

 

BERT 模型是在大型文本语料库(例如,Wikipedia 文章的归档)上使用自我监督任务进行预训练的,比如根据上下文预测句子中的单词。这种类型的训练使模型能够在没有标记数据的情况下学习文本语义的强大表示。但是训练它需要大量的计算:在 16 个 TPU 上花费 4 天的时间(如 2018 年 BERT 论文所报道的)。所幸的是,在这种昂贵的预训练完成一次后,我们就可以为许多不同的任务高效地重用这种丰富的表示了。

 

  • 八个 BERT 模型是与 BERT 原始作者发布的训练权重一起提供的。

  • 24 个 Small BERT 具有相同的通用架构,但 Transformer 会更少或更小,这让你可以探索速度、尺寸和质量之间的权衡。

  • ALBERT:这是四种不同大小的“A Lite Bert”,通过在层之间共享参数来减少模型大小(但不是计算时间)。

  • 8 个 BERT Experts 都具有相同的 BERT 架构和大小,但是为预训练域和中间微调任务提供了不同的选择,以便更好地配合目标任务。

  • Electra 具有与 BERT 相同的架构(有三种不同的大小),但在预训练时作为判别器,类似于生成对抗网络(Generative Adversarial Network,GAN)。

  • BERT 与 Talking-Heads Attention 和 Gated GELU [base, large] 对 Transformer 架构的核心进行了两个改进。

  • Lambert 已经接受了一些由 LAMB 优化器和 Roberta 提供的技术训练。

  • .......

 

这些模型是 BERT 编码器。上述链接将可以访问 TF Hub 上的文档,其中提到了各自所使用的正确的预处理模型。我们建议开发者访问这些模型页面,以便了解更多关于每个模型所针对的不同应用场景。基于其通用界面,通过更改编码器模型及其预处理的 URL,可以方便地对不同编码器进行特定任务的性能实验和比较。

预处理模型

 

对于每个 BERT 编码器,都有一个匹配的预处理模型。它使用 TF.text 库提供的 TensorFlow ops,它可以将原始文本转换为编码器所期望的数字输入时序。不像纯 Python 的预处理那样,这些操作可以作为 TensorFlow 模型的一部分,用于直接从文本输入中提供服务。每个 TF Hub 的预处理模型都已经配置了词汇表及其相关的文本归一化逻辑,无需进行进一步的设置。

 

前面我们已经介绍了最简单的预处理模型的使用方法,接下来让我们仔细看看。


preprocess = hub.load('https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/1')input = preprocess(["This is an amazing movie!"]) {'input_word_ids': <tf.Tensor: shape=(1, 128), dtype=int32, numpy=  array([[ 101, 2023, 2003, 2019, 6429, 3185,  999,  102,    0,  ...]])>, 'input_mask': <tf.Tensor: shape=(1, 128), dtype=int32, numpy=  array([[   1,    1,    1,    1,    1,    1,    1,    1,    0,  ...,]])>, 'input_type_ids': <tf.Tensor: shape=(1, 128), dtype=int32, numpy=  array([[   0,    0,    0,    0,    0,    0,    0,    0,    0,  ...,]])>}
复制代码

 

像这样调用 preprocess() 可以将原始文本输入转换为固定长度的 BERT 编码器输入序列。你可以看到,它由一个张量 input_word_ids 组成,其中包含了每个标记化输入的数字 id,包括开始、结束和填充标记,再加上两个辅助张量:一个 input_mask(用于区分非填充和填充标记)和每个标记的 input_type_ids(可以区分每个输入的多个文本段,我们将在下面讨论)。

 

相同的预处理 SavedModel 还提供了更细粒度的 API,支持在编码器的一个输入序列中使用一个或两个不同的文本段。下面我们来看一个句子蕴含任务:

 

text_premises = ["The fox jumped over the lazy dog.",                 "Good day."]tokenized_premises = preprocess.tokenize(text_premises) <tf.RaggedTensor  [[[1996], [4419], [5598], [2058], [1996], [13971], [3899], [1012]],  [[2204], [2154], [1012]]]> text_hypotheses = ["The dog was lazy.",  # Entailed.                   "Axe handle!"]        # Not entailed.tokenized_hypotheses = preprocess.tokenize(text_hypotheses) <tf.RaggedTensor  [[[1996], [3899], [2001], [13971], [1012]],  [[12946], [5047], [999]]]>
复制代码

 

每个标记化的结果是一个数字 token idRaggedTensor,完整地表示每一个文本输入。如果某些前提和假设对太长,无法在下一步用于 BERT 输入的 seq_length 内适应,则可以在这里进行额外的预处理,比如修剪文本段或将其分割成多个编码器输入。

 

然后,将标记化的输入打包为用于 BERT 编码器的固定长度的输入序列:

 

encoder_inputs = preprocess.bert_pack_inputs(   [tokenized_premises, tokenized_hypotheses],   seq_length=18)  # Optional argument, defaults to 128. {'input_word_ids': <tf.Tensor: shape=(2, 18), dtype=int32, numpy=  array([[  101,  1996,  4419,  5598,  2058,  1996, 13971,  3899,  1012,            102,  1996,  3899,  2001, 13971,  1012,   102,     0,     0],         [  101,  2204,  2154,  1012,   102, 12946,  5047,   999,   102,              0,     0,     0,     0,     0,     0,     0,     0,     0]])>, 'input_mask': <tf.Tensor: shape=(2, 18), dtype=int32, numpy=  array([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0],         [1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0]])>, 'input_type_ids': <tf.Tensor: shape=(2, 18), dtype=int32, numpy=  array([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0],         [0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0]])>}
复制代码

 

打包的结果是已经熟悉的 input_word_idsinput_maskinput_type_ids(第一个和第二个输入分别为 0 和 1)。所有输出都有一个公共的 seq_length(默认为 128)。在打包过程中,超过 seq_length 的输入被截断为大致相等的大小。 

加速模型训练

 

TensorFlow Hub 将 BERT 编码器和预处理模型作为独立的部分,用于加速训练,特别是在 TPU 上。

 

张量处理单元(Tensor Processing Units,TPU)是 Google 定制开发的加速器硬件,它擅长于大规模机器学习计算,比如对 BERT 所需的计算进行微调。TPU 工作在密集的张量上,并期望像字符串这样的可变长度数据,已由主机 CPU 转换为固定大小的张量。

 

由于 BERT 编码器模型与其相关的预处理模型之间的解耦,可以将编码器微调计算作为模型训练的一部分分配给 TPU,而预处理模型则在主机 CPU 上执行。通过使用 tf.data.Dataset.map(),可以在数据集中异步运行预处理计算,并且 TPU 上的编码器模型可以消耗密集的输出。这种异步预处理还可以改善其他加速器的性能。

 

我们的高级 BERT 教程可以在使用 TPU 工作器的 Colab 运行时中运行,并演示了这种端到端的方式。

总结

 

在 TensorFlow 中使用 BERT 和类似的模型已经变得更加简单了。TensorFlow Hub 提供了大量预训练 BERT 编码器文本预处理模型,只需几行代码就能很容易地使用。

 

作者介绍:

 

Arno Eigenwillig,软件工程师。 Luiz GUStavo Martins,开发技术推广工程师。

 

原文链接:

 

https://blog.tensorflow.org/2020/12/making-bert-easier-with-preprocessing-models-from-tensorflow-hub.html

公众号推荐:

2024 年 1 月,InfoQ 研究中心重磅发布《大语言模型综合能力测评报告 2024》,揭示了 10 个大模型在语义理解、文学创作、知识问答等领域的卓越表现。ChatGPT-4、文心一言等领先模型在编程、逻辑推理等方面展现出惊人的进步,预示着大模型将在 2024 年迎来更广泛的应用和创新。关注公众号「AI 前线」,回复「大模型报告」免费获取电子版研究报告。

AI 前线公众号
2021-01-14 09:522129
用户头像
赵钰莹 InfoQ 主编

发布了 874 篇内容, 共 603.9 次阅读, 收获喜欢 2671 次。

关注

评论

发布
暂无评论
发现更多内容

Elasticsearch filter vs. query 对比

escray

elastic 七日更 28天写作 死磕Elasticsearch 60天通过Elastic认证考试 2月春节不断更

APM(应用性能监控) 行业认知系列 - 一

东风微鸣

APM Trace 可观察性

APM 行业认知系列 - 二

东风微鸣

APM Trace 可观察性

APM 行业认知系列 - 四

东风微鸣

APM Trace 可观察性

滚雪球学 Python 番外系列,自动化测试是个啥?

梦想橡皮擦

Python 28天写作 2月春节不断更

可能是Java 8 Optional最佳实践

ES_her0

28天写作

你的面试专属!JVM G1GC的算法+实现,90张图+33段代码

Java架构追梦

Java 架构 JVM 调优 G1GC

MySQL事务浅析|由浅入深

MySQL 编程 架构

Angular性能优化实践——巧用第三方组件和懒加载技术

葡萄城技术团队

angular SpreadJS

LoadRunner测试中遇见的不可思议的问题及其解决方法

陈磊@Criss

2020回顾,2021学习目标

叫练

学习 2021年展望 2020年度总结

SpringBoot之自定义启动异常堆栈信息打印

false℃

架构设计篇之微服务实战笔记(一)

小诚信驿站

架构师 刘晓成 小诚信驿站 28天写作 架构师成长笔记

Elasticsearch踩坑记之深度分页

topsion

大数据 elasticsearch 深度分页

如何 0 改造,让单体/微服务应用成为Serverless Application

阿里巴巴云原生

Docker Serverless 容器 微服务 云原生

面试官:Java性能调优你会多少?一个问题就把我问的哑口无言,哭了!

996小迁

架构 面试 Java性能调优

APM 行业认知系列 - 三

东风微鸣

APM Trace 可观察性

重大更新!一文了解京东通用目标重识别开源库FastReID V1.0

京东科技开发者

AI 监控

【STM32】CubeMX+HAL 点亮 LED

AXYZdong

硬件 stm32 2月春节不断更

【LeetCode】数组的度Java题解

Albert

算法 LeetCode 28天写作 2月春节不断更

IDEA 敏捷开发技巧——后缀完成

程序员小航

Java 后端 IDEA

电子产品中EMC隔离设计的方法

不脱发的程序猿

二月春节不断更 电路设计 EMC 电子产品

Java实体映射利器---MapStruct

是小毛吖

Java MapStruct

著名的Java并发编程大师都这么说了,你还不知道伪共享么!

看点代码再上班

Java 后端

Golang代码测试:一点到面用测试驱动开发

华为云开发者联盟

测试 TDD 代码 Go 语言

大小厂必问Java后端面试题(含答案)

yes

Java 面试 后端

全网最新、最全面蚂蚁金服面经分享:简历模板/面试题库/Java核心技术笔记

比伯

Java 编程 程序员 面试 技术宅

读书总结2020

IT民工大叔

#读书

《经济学人》2021年2月20日刊精彩文章导读及资源下载

wbliu85

产品训练营--第四期作业

曦语

产品训练营

EMC设计中电缆屏蔽使用方法

不脱发的程序猿

二月春节不断更 电路设计 EMC 电子产品 电缆屏蔽

如何利用TensorFlow Hub 让BERT开发更简单?_AI&大模型_Arno Eigenwillig_InfoQ精选文章