【AICon】开辟产业应用新天地,大模型重塑各行各业,精华内容上线58%!>>> 了解详情
写点什么

更小、更快、更便宜、更轻量:开源 DistilBERT,BERT 的精简版本

  • 2019-08-30
  • 本文字数:4284 字

    阅读完需:约 14 分钟

更小、更快、更便宜、更轻量:开源DistilBERT,BERT的精简版本

感兴趣的朋友可以点击此处获取重现 DistilBERT 训练以及 DistilBERT 预训练权重的代码。


过去 18 个月以来,大规模语言模型的迁移学习可谓遍地开花,在几乎所有自然语言处理任务当中都实现了显著的性能改进。


作为以 Vaswani 等人的 Transformer 架构为基础的解决方案,这些经过预训练的语言模型正变得日益庞大,并仍在立足更大的数据集进行训练。英伟达公司的最新模型拥有多达 83 亿个参数:24 倍于 BERT,5 倍于 GPT-2;而 Facebook AI 拿出的 RoBERTa 则利用 160 GB 文本训练而成。



社区中的不少从业者开始怀疑,到底有没有必要训练这些越来越臃肿的 Transformer,毕竟其在训练的经济与环境成本方面已经呈现出失控的状态。通过上图,我们一起来看部分最新大型模型及其参数数量(单位为百万)。


在 Hugging Face,我们亲身体验了这些模型高涨的人气,因为我们的 NLP 库(打包了其中大部分模型)在短短几个月之内就得到超过 40 万次安装。


然而,随着这些模型进入更大的社区,一个重要甚至说极具挑战性的问题开始出现——我们该如何把这些庞然大物投入生产?我们如何在低延迟约束条件下使用这类大型模型?我们是否需要昂贵的 GPU 服务器才能实现大规模服务?


对于许多研究人员及开发人员而言,这可能是个最现实的问题。


为了构建更尊重隐私的系统,我们意识到有必要在边缘位置运行机器学习系统,从而尽可能避免以调用云 API 的方式将个人数据发送至服务器端。这就意味着,我们需要能够在智能手机等小型设备上运行轻、反应灵敏且资源需求量较低的模型版本!


最后但同样重要的是,我们也越来越关注模型扩展过程当中,严苛计算资源需求所带来的环境成本。


那么,我们该如何帮助这些庞然大物成功瘦身?


不少现有技术都有望解决问题。最常见的工具包括量化(对准确率影响较小的网络权重进行近似化)以及权重修剪(删除网络中的某些连接)。对于此类技术,推荐大家参阅 Rasa 发布的BERT量化博文


但我们最终决定专注于模型蒸馏:这是一种能够将大型模型(被称为「老师」)压缩为较小模型(即「学生」)的技术。

知识蒸馏:迁移泛化能力

知识蒸馏(有时也称为师生学习)是一种压缩技术,要求对小型模型进行训练,以使其拥有类似于大型模型(或者模型集合)的行为特征。这项技术由 Bucila 等人提出,并得到了 Hinton 等人的推广。我们这里采用的,正是 Hinton 采取的方法。


在监督学习领域,我们在训练分类模型时往往会利用对数似然信号实现概率最大化(logits 的 softmax),进而预测出正确类。在大多数情况下,性能良好的模型能够利用具有高概率的正确类预测输出分布,同时其它类的发生概率则接近于零。


但是,某些“接近于零”的概率要比其它概率更大,这在一定程度上反映出模型的泛化能力。


例如,把普通椅子误认为扶手椅虽然属于错误,但这种错误远比将其误认为蘑菇来得轻微。这种不确定性,有时被称为“暗知识”。


我们也可以从另一个角度来理解蒸馏——用于防止模型对预测结果太过确定(类似于标签平滑)。


以下为具体实例。在语言建模当中,我们可以通过查看词汇表中的分布轻松观察到这种不确定性。下图为 BERT 对《卡萨布兰卡》电影当中经典台词下一句用词的猜测:



BERT 提出的 20 大高概率用词猜测结果。语言模型确定了两个可能性最高的选项(day 与 life),接下来的词汇相比之下概率要低得多。

我们如何复制这些“暗知识”?

在师生训练当中,我们训练学生网络,用于模拟老师网络的全部输出分布(也就是知识)。


我们通过匹配输出分布的方式训练学生网络,从而实现与老师网络相同的泛化方式。


我们并没有在硬目标上使用交叉熵训练(正确类的独热编码),而是通过软目标(老师概率)将交叉熵从老师处传递给学生。我们的训练损失因此变为:



其中 t 为来自老师的 logit,s 为学生的 logit。


这个损失函数属于更丰富的训练信号,因为单一示例要比单一硬目标拥有更高的强制约束效果。


为了进一步揭示分类结果的质量,Hinton 等人提出了 softmax 温度的概念:



T 为该温度参数。


T → 0 时,分布变为 Kronecker(相当于独热目标矢量);当 T →+∞时,则变为均匀分布。在训练过程中,将相同的温度参数应用于学生与老师网络,即可进一步为每个训练示例揭示更多信号。在推论当中,T 被设置为 1 以恢复标准 Softmax。

PyTorch 编码——压缩 BERT

我们希望利用蒸馏方法对大型语言模型加以压缩。在蒸馏方面,我们使用 Kullback-Leibler 损失函数,因为其拥有相同的优化效果:



在计算关于 q(学生网络分布)的梯度时,我们获得了相同的梯度结果。我们可以利用 PyTorch 实现加快计算速度:



PyTorch 中的知识蒸馏训练步骤。点击此处复制 gist。


利用老师信号,我们能够训练出一套较小的语言模型,我们称之为 DistilBERT,属于 BERT 的监督产物(我们使用 BERT 的英文 bert-base-uncased 版本)。


根据 Hinton 等人的发现,训练损失函数属于蒸馏损失与 masked 语言建模损失的线性组合。我们的学生网络属于 BERT 的一套小型版本,其中删除了 token-type 嵌入与 pooler(用于下一句分类任务),但其余部分架构保持不变,而层数也减少至原本的二分之一。


总体而言,我们的蒸馏模型 DistilBERT 在总体参数数量上约为 BERT 的一半,但在 GLUE 语言理解基准测试中能够保留 95%的 BERT 性能表现。


注 1 — 为什么不降低隐藏层的大小?

将 768 层减至 512 层,意味着总参数量约下降至原本的二分之一。希,在现代框架当中,大多数运算都经过高度优化,而且张量的最终维度(隐藏维度)的变化会对 Transformer 架构(线性分层与层规范化)中的大部分运算产生小幅影响。在我们的实验中,层数对于推理时间的影响要远高于隐藏层的大小。

因此,更小并不代表着一定更快……

注 2 — Tang 等人在蒸馏工作当中,直接在下游任务内使用 L2 距离作为蒸馏损失

我们的早期实验结果表明,在本案例中,交叉熵损失会明显提高性能水平。我们假定在语言建模设置当中,输出空间(词汇表)要明显大于下游任务输出空间的维度。因此,logits 可以在 L2 损失中相互补偿。


训练子网络的核心不只是建立架构,还要求我们为子网络找到正确的初始化方式以实现收敛。因此,我们以作为老师的 Bert 为基础对学生 DistilBERT 进行初始化,将层数削减一半,并采用相同的隐藏大小。


我们还用到了最近 RoBERTa 论文当中提到的一些训练技巧,这也再次证明 BERT 模型的训练方式对其最终表现有着至关重要的影响。与 RoBERTa 类似,我们对 DIstilBERT 进行大批次训练,使用梯度累积(每批最多 4000 个例子)、配合动态遮挡并删除了下一句预测目标。


我们在训练设置中对资源进行了主动限制。我们利用多伦多图书语料库与英语维基百科的串联数据集(与原始 BERT 相同),并配合八块 16 GB V100 GPU 进行了约三天半的训练。


DistilBert 的代码部分来自 Facebook XLM,也有一部分来自我们 PyTorch 版本的 Google AI Bert(可点此获取),以及针对 DistilBert 进行的精心调优。这一切,都是为了更好地重现 BERT 的预测性能。

DistilBERT 模型性能测试

我们将 DistilBERT 在 GLUE 基准测试开发集上的性能与两项基准进行了比较:其一为 BERT 基础(DistilBERT 的老师),其二为来自纽约大学的强大非 transformer 基准——ELMo 上的两个 BilSTM。我们利用纽约大学的 jiant 库获取 ELMo 基准,并使用 pytorch-transformers 获取 BERT 基准。


如下表所示,DistilBERT 的性能与基准相比更好一些,而参数数量只分别相当于二者的一半以及三分之一。在 9 项任务当中,DistilBERT 的 ELMo 基准成绩一直等同或者领先(在 QNLI 上的准确率高出 14%)。DistilBERT 的表现确实远超 BERT:我们保留了 95%以上的性能,同时将参数减少了 40%。



在 GLUE 基准测试开发集中的比较结果以及由作者上报的 ELMo 结果。BERT 与 DistilBERT 结果来自 5 次单独运行后的中位数取值。


在推理时间方面,DistilBERT 比 BERT 快 60%,体积比 BERT 小 60%,比 ELMo + BiLSTM 快 120%且模型体积更小。



为了进一步研究 DistilBERT 的加速/大小平衡点,我们在上表中比较了各个模型的参数数量,以及在 CPU 上完全处理 STS-B 开发集(批量大小为 1)所需要的推理时间。

下游任务:蒸馏与迁移学习

我们进一步研究了 DistilBERT 在有效推理约束下的下游应用效果。我们通过分类任务调优,实现对这套紧凑预训练语言模型的迁移。事实证明,这是种实现蒸馏预训练与迁移学习的好方法!



从 IMDB Review 数据集中提取到的电影评论。


我们选择了 IMDB 影评情感区中的素材,该分区共包含 5 万条英文评论,且标记为正面或负面:我们使用其中 2 万 5 千条进行训练,另外 2 万 5 千条进行测试(同时配合平衡类)。整个训练过程在单一 12 GB K80 上进行。


首先,我们在自己的数据集上训练 bert-base-uncased。我们亲爱的 BERT 老师达到了 99.98%的准确率(3 次运行取平均值)。相当完美!


接下来,我们训练 DistilBERT,使用同样的超参数。压缩模型的准确率达到 99.53%(3 次运行取平均值)。性能的绝对差为 0.5%,延迟降低 60%,大小减少 40%。


NLP 技术的另一种常见应用是问题解答。我们在 SQuAD 1.1 数据集上比较了 BERT bert-base-uncased 版本与 DistilBERT 的结果。在开发集上,BERT 的 F1 得分为 88.5,EM(完全匹配)得分为 81.2。我们利用同样的超参数进行 DistilBERT 训练,F1 分数与 EM 分数分别为 85.1 与 76.5,同 BERT 成绩的差距分别为 3 分与 5 分。


我们还研究了能否在适应阶段利用经过调优的 BERT 作为老师,配合知识蒸馏损失对 DistilBERT 实现 SQuAD 数据集上的调优。


在新案例中,我们将问题回答模型蒸馏为以往通过知识蒸馏预训练完成的语言模型,从而实现调优!这样,老师与学生将能够相互转换。


如此一来,考虑到网络规模,我们能够获得非常有趣的结果:F1 得分为 86.2,EM 得分为 78.1。与完整模型相比,差距保持在 3 分以内!

少即是多:小型模型也能带来理想性能

我们对 DistilBERT 的潜力感到非常兴奋。目前的成果只是刚刚起步,也给我们提出了很多新的问题:我们能够利用知识蒸馏技术将这些模型压缩到怎样的程度?这些技术能否用于进一步理解大型模型中存储的知识?在这类压缩当中,我们损失掉的是哪些语言/语义元素?……


在 HuggingFadce,我们一直将开源与知识共享视为自己的使命。所以,大家可以点击此处访问我们的 GitHub 库,这是我们每个人参与 NLP 深度学习项目并获取卓越成果的最简单、也最公平的方式。


因此,配合本篇博文,我们在pytorch-transformer库当中发布了实验代码(主要是重现训练与 DistilBERT 调优代码)以及一套经过训练的 DistilBERT 版本,感兴趣的朋友可以随意取用。


原文链接:


https://medium.com/huggingface/distilbert-8cf3380435b5


公众号推荐:

2024 年 1 月,InfoQ 研究中心重磅发布《大语言模型综合能力测评报告 2024》,揭示了 10 个大模型在语义理解、文学创作、知识问答等领域的卓越表现。ChatGPT-4、文心一言等领先模型在编程、逻辑推理等方面展现出惊人的进步,预示着大模型将在 2024 年迎来更广泛的应用和创新。关注公众号「AI 前线」,回复「大模型报告」免费获取电子版研究报告。

AI 前线公众号
2019-08-30 08:007525

评论

发布
暂无评论
发现更多内容

开发者问第五期

HMS Core

HMS Core

在实践中学习类型定义、类型覆盖、CSS Modules

小鑫同学

CSS typescript 前端 11月月更

易观分析:2022年Q3中国跨境进口零售电商市场规模为1124.8亿元

易观分析

报告 跨境电商

Kotlin函数和扩展(extension)

子不语Any

kotlin Andrdoid 11月月更

Spring 依赖注入有哪几种方式

千锋IT教育

大数据培训应该怎么学习

小谷哥

java培训程序员就业情况如何

小谷哥

前端培训与自学有什么区别吗

小谷哥

Java程序员的BAT面试之路:数据库事物特性及隔离级别,记得看看

钟奕礼

Java 程序员 java面试 java编程

太全了!神仙级SpringCloud Alibaba笔记,从入门到实战,全方位讲解微服务技术栈!

程序知音

Java 微服务 Spring Cloud 后端技术

牛啊!长这么大还是头一次见24W字的SpringBoot从入门到实战文档

程序知音

Java spring springboot java架构 后端技术

2022年10月中国网约车领域月度观察

易观分析

报告 网约车

人人都可以给想象插上翅膀(内含AI绘画教程)

鼎道智联

openai AI绘画 鼎道智联

【web 开发基础】PHP 中的递归函数 (38)

迷彩

递归 11月月更 PHP递归 递归函数

Kotlin函数声明与默认参数(Default argument)

子不语Any

android kotlin 11月月更

记一次HBASE的故障分析和排查过程

鲸品堂

大数据 11月月更

Baklib知识分享|企业产品需求文档的特点

Baklib

PRD 产品需求文档

大数据培训应该怎么学习

小谷哥

完全解析大数据中MapReduce的运行流程

好程序员IT教育

大数据 MapReduce Service

APISIX Ingress 是如何支持上千个 Pod 副本的应用

API7.ai 技术团队

Kubernetes 容器 api 网关 APISIX

链表剖析及自己手撸"单链表"实现基本操作(初始化、增、删、改等)

C++后台开发

数据结构 链表 linux开发 Linux服务器开发 C++开发

Nacos 中的配置文件如何实现加密传输

江南一点雨

nacos SpringCloud

电容的“通交流、阻直流”,一次讲清楚

元器件秋姐

元器件采购 元器件电商 电容 电容特性 电容知识

服务全球开发者!灵雀云与Ubuntu推出一体化云原生解决方案

York

容器 云原生 操作系统 开源生态

易观千帆 | 2022年10月银行APP月活跃用户规模盘点

易观分析

报告 手机银行

超级自动化行业前景广阔——首个数字化转型国家标准发布:价值体系优化、创新和重构是数字化转型根本任务

九科Ninetech

即时通讯赛道开打信创牌,WorkPlus为何独树一帜?

WorkPlus

web前端培训学习后找工作难吗?

小谷哥

优先级反转那些事儿

字节跳动终端技术

ios QoS 移动开发 优先级反转 turnstile

Kotlin变量和属性

子不语Any

kotlin andiod 11月月更

Prometheus 监测 RocketMQ 最佳实践

阿里巴巴云原生

阿里云 RocketMQ 云原生 Prometheus

更小、更快、更便宜、更轻量:开源DistilBERT,BERT的精简版本_语言 & 开发_Victor Sanh_InfoQ精选文章