写点什么

招商证券 BERT 压缩实践:如何大幅提高模型推断速度?

  • 2020-11-02
  • 本文字数:6277 字

    阅读完需:约 21 分钟

招商证券BERT压缩实践:如何大幅提高模型推断速度?

BERT,全称 Bidirectional Encoder Representation from Transformers,是一款于 2018 年发布,在包括问答和语言理解等多个任务中达到顶尖性能的语言模型。它不仅击败了之前最先进的计算模型,而且在答题方面也有超过人类的表现。

招商证券希望借助 BERT 提升自研 NLP 平台的能力,为旗下智能产品家族赋能。但是,BERT 在工程方面的表现还多少存在着一些问题,推断速度慢正是其中之一。针对这一问题,招商证券信息技术中心 NLP 开发组对 BERT 模型进行了压缩,大幅提高推断速度,从而满足上线要求。

本系列中,作者会从研发思路开始,讲述如何对原始 BERT 进行改造以适应特定的领域方向,同时还会展示具体的模型压缩解决方案及效果对比,和最终的线上效果。本文是系列第一篇,如果你对 NLP 实践感兴趣,这个系列你不容错过!


以 BERT[1]为代表的基于 Transformer 架构的预训练语言模型,将 NLP 各项任务的处理能力提高到了一个新的高度。与此同时,NLP 领域也开始进入了大模型时代,动辄上亿乃至成百上千亿1的参数量,大大提高了训练及部署所需的硬件成本和时间成本,尤其对于线上场景,过大的模型导致了过长的推断时间,会直接导致服务失效。


1 BERT 模型拥有 3.4 亿参数量,3GPT3 模型拥有 1750 亿参数量


为了解决这个问题,我们探索和实践了如何通过模型结构压缩方法减小模型体积,从而大幅提高模型推断速度,降低计算延迟和硬件需求,使线上系统真正享受到 BERT 带来的性能提升红利。


现有神经网络压缩方法主要包括模型裁减、知识蒸馏、模型量化、动态计算、参数共享等多种方向。


在综合考虑了模型压缩方法的实现复杂度和最终推断速度提升能力等因素,结合工程实践经验,我们选用了结合知识蒸馏(knowledge distillation)和模型量化(model quantization)的技术路线。由于篇幅限制,本篇主要分享知识蒸馏方面的内容。


知识蒸馏的概念是在 2015 年由 Hinton 等人提出[2]。在进行知识蒸馏的过程中需要定义一个教师模型和一个学生模型,教师模型是一个训练完成且规模较大的网络,而学生模型是一个未经训练且规模较小的网络,知识蒸馏希望可以将教师模型所学习到的知识传递给学生模型。很多实验都表明,这种知识传递过程比单独训练2学生模型更加有效。


2 这里说的“单独训练”是指在没有教师网络的参与下从零开始训练。


知识蒸馏在提出后便受到广泛关注,特别是在计算机视觉(computer vision)领域出现了非常多的相关工作,这主要是因为计算机视觉领域中的许多模型的参数规模都是非常庞大的。自从 BERT 提出后,也涌现了许多针对 BERT 的知识蒸馏的工作,如 DistillBERT[3]、PKD-BERT[4]、TinyBERT[5]、MobileBERT[6]、FastBERT3[7]、LTD-BERT 等。


3 FastBERT 融合了动态计算和知识蒸馏的思想。


知识蒸馏工作的研究重心大致在以下几个方面:如何设计学生模型、如何设计知识蒸馏的目标函数、在什么阶段实施蒸馏。


1. 如何设计学生模型


以 BERT-base 为例,它由 12 层的 transformer 组成,每一层的宽度 (即 hidden size)为 768,总的参数量为 109M。为了减少其网络规模,可以通过减少模型的层数(高度),或者通过减少模型的宽度来实现。DistillBERT 和 PKD-BERT 只减少了模型的层数,而 MobileBERT 只减少了模型的宽度,而 TinyBERT 既减少了模型的层数也减少了模型的宽度。早期的工作只关注层数的减少,这样做的好处是可以直接使用教师模型中的权重来初始化学生模型,但是减少层数对模型的压缩毕竟是有限的。而减少模型的宽度则意味着无法直接使用教师模型的参数对学生模型进行初始化。此外,MobileBERT 指出减少了宽度,attention head 数目也应该减少。


2. 知识蒸馏中的目标函数


为了实现将教师模型的知识传递到学生模型中,需要设计一些特定的目标函数。首先,Hinton 等人的工作[1]提出应该让学生网络和教师网络在输出的 logit 上尽可能相似,并使用了 T-softmax 和 KL 散度来实现。而后,研究者们指出 transformer 在每一层的输出(包括 embedding)也应该相似4,同时 transformer 中的 MHA(multi-head attention)模块中输出的注意力权重也应该相似。目前面向 BERT 的知识蒸馏的损失函数主要包含以上三种。


4 如果教师网络和学生网络的宽度不一样那么就期望它们投影到一个相同的空间后相似。


此外,也有研究者提出了面向特定任务的目标函数。如对于相似度匹配任务,腾讯提出的 LTD-BERT 会为两个句子分别计算一个句子表示,然后期望教师模型和学生模型在输出的句子表示尽可能相似。对于序列标注任务,期望 CRF(conditional random field)计算过程中的后验概率矩阵相似[8]


3. 预训练阶段和微调阶段的蒸馏


BERT 模型的训练常分为在通用预料上的预训练阶段和在特定任务上的微调阶段,这两个阶段上均可以实施知识蒸馏。预训练阶段的蒸馏显然需要花费非常大的算力,也需要规模庞大的语料,但是从教师模型到学生模型的知识传递过程也更加充分。而微调阶段所需的算力不大,而且只需要特定任务上的数据即可,而且现有的研究表明特定任务上的无标记数据也可以用于蒸馏过程。


是否实施预训练阶段的蒸馏取决于两点:(1) 微调阶段的知识蒸馏所得到的学生模型是否可以满足准确率(或者 f1 值)上的要求。(2) 削减层数是否可以达到所要求延时降低的效果;如果不能达到,那么需要减小模型的宽度,这使得从教师模型到学生模型的初始化无法进行。


蒸馏方法


与学术研究不同,在工程实践中,大部分公司和开发者很难有足够的算力、语料和时间支撑进行预训练阶段的蒸馏,因此我们选择了与 PKD-BERT 类似的蒸馏方法,针对不同任务在微调阶段进行蒸馏:以预训练好的 12 层 transformer encoder 的 RoBERTa 为教师,目标是蒸馏出 6 层及 3 层的学生模型。


首先,对预训练的中文 RoBERTa 进行下游任务微调训练,通过微调后的 RoBERTa(教师)参数5隔层初始化 student 模型:


5 使用 pytorch transformers 包加载模型参数


def load_state_dict_from_teacher(resolved_archive_file, n_teacher_layer, n_student_layer):    state_dict = torch.load(resolved_archive_file, map_location='cpu')   keys = []    layer_div = n_teacher_layer // n_student_layer    for key in state_dict.keys():        new_key = None        layer_i = None        for idx in range(layer_div-1, 12, layer_div):            if f'.{idx}.' in key:                new_idx = (idx + 1) // layer_div - 1                new_key = key.replace(str(idx), str(new_idx))                layer_i = idx                break        if new_key:            keys.append((layer_i, key, new_key))
keys.sort() for _, old_key, new_key in keys: state_dict[new_key] = state_dict.pop(old_key) return state_dict
复制代码


其次,定义目标函数:



无论使用何种模型压缩方法,模型输出和真实标签之间计算的交叉熵损失都是最基本的损失函数。



损失是学生模型和教师模型输出的 logit 上计算的蒸馏损失(KL 散度)。由于原始 logit 分布往往非常稀疏,因此在进行 softmax 时,引入一个超参数温度 T,来控制 logit 的平滑程度,即 T-softmax。如下式所示,T 设置得越大,得到的分布越平滑。



蒸馏损失也从另外的角度阐述了学生模型为什么能更好地从教师模型蒸馏得到了知识:


  • 当类的标记有误时,教师模型可以一定程度上消除这些错误。

  • 教师模型输出的 logit 比原始给定的类标记更加平滑,容易学习。

  • 相比于原始的值为 0/1 的类标记,logit 包含更多的信息量,特别是当类别数目很多的时候。

  • 部分样本的类标记可能根本无法通过输入的特征学习到,这些样本会对模型的训练产生干扰,而以教师模型输出的 logit 为目标进行训练则会消除这些干扰。


的 pytorch 实现如下:


from torch import nnimport torch.nn.functional as F
def distillation_loss(student_scores, labels, teacher_scores, T, alpha): """ student_scores, teacher_scores: [batch_size, n_labels] labels: [batch_size, ] """ if teacher_scores is not None: d_loss_func = nn.KLDivLoss(reduction='mean') d_loss = d_loss_func( F.log_softmax(student_scores/T, dim=1), F.softmax(teacher_scores / T, dim=1) ) * T * T else: assert alpha == 0 d_loss = 0.0 nll_loss = F.cross_entropy(student_scores, labels, reduction='mean') tol_loss = (1.0 - alpha) * nll_loss + alpha * d_loss return tol_loss, nll_loss, d_loss
复制代码


除此之外,我们增加了与 PKD-BERT 类似的学生模型和教师模型在特征表示上的损失,该方法更进一步,在学生模型每一层及对应的教师模型上逐层计算表示损失,计算过程如下式表示:



实现如下,当前我们将各个隐藏层的 MSE 损失进行了简单的算术平均,未来可以改为加权平均值,作为超参手动设置或由模型学习得到。


def patience_loss(student_hiddens, teacher_hiddens, normalize=False):    """    student_hiddens: list([batch_size, seq_length, hidden_size]) __len__ = 4 / 7    teacher_hiddens: list([batch_size, seq_length, hidden_size]) __len__ = 13    """    if normalize:        student_hiddens = [F.normalize(hidden, p=2, dim=-1) for hidden in student_hiddens]        teacher_hiddens = [F.normalize(hidden, p=2, dim=-1) for hidden in teacher_hiddens]        # embedding_loss    pt_loss = F.mse_loss(student_hiddens[0], teacher_hiddens[0])        # hidden_loss    n_student_layer = len(student_hiddens) - 1    n_teacher_layer = len(teacher_hiddens) - 1    assert n_student_layer in (3, 6)    assert n_student_layer in (12, )        N = n_teacher_layer // n_student_layer         for i in range(n_student_layer):        pt_loss += F.mse_loss(student_hiddens[i+1], teacher_hiddens[i*N+1])        return pt_loss / (n_student_layer+1)
复制代码


的计算方法与原始 PKD-BERT 定义的损失有所差别,主要在于包括了 Embedding 上的计算损失,并且在整个序列特征上表示计算,而不仅仅是对于[CLS]标签的单独特征表示。



最后,通过两个超参数调节各个函数比重,获得蒸馏方法的整体目标函数。



图 1:模型结构及目标函数计算图示


测试数据及结果


在测试过程中,我们使用 Chinese-GLUE6 的数据集中的6个下游任务进行微调训练及模型蒸馏:


6 https://github.com/CLUEbenchmark/CLUE


表 1:Chinese-GLUE 数据集中不同任务


任务名称描述
AFQMC金融语义相似度
TNEWS新闻短文本分类
IFLYTEK长文本分类
CMNLI语言推理任务
WSC代词消歧
CSL论文关键词识别


进行蒸馏前后的结果对比如下表:


表 2:知识蒸馏前后模型准确率及推断加速比


ModelAFQMCTNEWSIFLYTEKCMNLI7WSCCSLAVGspeed8
teacher(12)74.1457.6060.4573.8687.9580.572.421.00x
student(6)73.1056.2561.4970.6275.6677.7069.141.95x
student(3)72.1755.2959.9565.7766.4574.4065.673.58x
rbt(3)70.5555.3257.6464.7065.2175.1164.763.58x
majority69.0010.8914.9334.9463.4950.0040.54/


7 使用了前 2 万条数据进行 CMNLI 的微调训练和测试

8 使用 cpu 进行小 batch 推断,通过比较平均耗时获得加速比


可以看到相较于教师模型,学生模型在层数减少一半(6 层)之后,性能下降了 3 个百分点,再减少一半(3 层)后下降 7 个百分点。由于当前蒸馏方法保持了 Embedding 层规模不变,因此推断时的加速比并没有完全与蒸馏程度成正比(在 4 倍压缩比下,加速比为 3.58 倍)。同时,我们还与直接预训练的 3 层 RoBERTa 在各任务下进行微调后的结果进行对比,发现蒸馏模型总体上要优于非蒸馏模型。


此外,我们在 6 层的学生模型基础上进行了消融实验分析蒸馏过程中各个过程的重要程度,结果如下表所示:


表 3:知识蒸馏前后模型准确率及推断加速比


ModelAFQMCTNEWSIFLYTEKCMNLIWSCCSLAVG
teacher(12)74.1457.6060.4573.8687.9580.572.42
student-rand-init(6) $L_{CE}$69.0051.7038.8246.9263.4950.0040.54
student (6) $L_{CE}$72.1356.0660.0668.9773.3674.1067.45
student(6) $L_{CE}+L_{KD}$72.8955.9060.4169.9972.0477.2368.08
student(6) $L_{CE}+L_{KD}+L_{PT}$73.1056.2561.4970.6275.6677.7069.14


可以看出,使用教师模型初始化学生模型,对蒸馏效果的影响最明显,随机初始化学生模型很容易导致模型无法收敛。损失及损失对蒸馏效果都有正面影响,其中损失显得更加有用一些。


至此我们完成了蒸馏的基本过程,将 BERT encoder 层高度缩减为原始的 1/4,蒸馏后模型的推断速度得到了显著提升。但是直接使用该模型在未配备 GPU 的服务器上提供在线服务仍然十分吃力,下一步我们通过应用参数量化等方法进一步压缩模型规模并最终上线,请关注我们下次分享。[fh4]


作者简介


招商证券信息技术中心 NLP 开发组,专注于自然语言处理和人工智能技术在金融科技领域的研究、设计、开发与应用落地。目前已孵化出智能搜索、智能推荐、智能助手、智能选股等多项产品,并采用平台化策略服务公司内外各项智能化需求。


参考文献


[1] J. Devlin, M.-W. Chang, K. Lee, et al. BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding[J]. arXiv preprint arXiv:1810.04805,2018


[2] G. E. Hinton, O. Vinyals and J. Dean. Distilling the Knowledge in a Neural Network[J]. arXiv preprint arXiv:1503.02531,2015


[3] V. Sanh, L. Debut, J. Chaumond, et al. DistilBERT, a distilled version of BERT: smaller, faster, cheaper and lighter[J]. arXiv preprint arXiv:1910.01108,2019


[4] S. Sun, Y. Cheng, Z. Gan, et al. Patient Knowledge Distillation for BERT Model Compression[J]. arXiv preprint arXiv:1908.09355,2019


[5] X. Jiao, Y. Yin, L. Shang, et al. TinyBERT: Distilling BERT for Natural Language Understanding[J]. arXiv preprint arXiv:1909.10351,2019


[6] Z. Sun, H. Yu, X. Song, et al. MobileBERT: a Compact Task-Agnostic BERT for Resource-Limited Devices[J]. arXiv preprint arXiv:2004.02984,2020


[7] W. Liu, P. Zhou, Z. Zhao, et al. FastBERT: a Self-distilling BERT with Adaptive Inference Time[J]. arXiv preprint arXiv:2004.02178,2020


[8] X. Wang, Y. Jiang, N. Bach, et al. Structure-Level Knowledge Distillation For Multilingual Sequence Labeling[J]. arXiv preprint arXiv:2004.03846,2020


2020-11-02 15:432912
用户头像
陈思 InfoQ编辑

发布了 576 篇内容, 共 247.0 次阅读, 收获喜欢 1281 次。

关注

评论

发布
暂无评论
发现更多内容

Newbe.Claptrap 框架如何实现多级生命周期控制?

newbe36524

架构 微服务 .net core ASP.NET Core

计算机网络基础(十一)---网络层-OSPF协议

书旅

计算机网络 网络 协议栈 OSPF

授人以渔:stm32资料查询技巧

华为云开发者联盟

架构 armv8 芯片 华为云 二进制

华为云GaussDB(DWS)内存知识点,你知道吗?

华为云开发者联盟

数据库 大数据 数据 内存 华为云

架构师训练营第八章-作业1

A Matt

<<前端进阶篇>> PDF 出炉了 — 「阿宝哥」,精心准备的 6 万多字 170 页的前端进阶资料

阿宝哥

大前端

基于 Golang的侵入式 Opentracing实现全链路追踪 ----实践篇

是老郭啊

财务分析与主要的模型

松子(李博源)

第九周

hdhdh

初识分布式:MIT 6.284系列(一)

Kerwin

分布式 MIT 28天写作

英特尔®AI计算盒参考设计发布 加速智能边缘崛起

最新动态

实用!一键生成数据库文档,堪称数据库界的Swagger

程序员小富

Java MySQL

云小课 | IPv4枯了,IPv6来了

华为云开发者联盟

IP 公有云 虚拟私有云 华为云 虚拟化

智能膜切机,解决手机贴膜行业难题

Geek_116789

JVM系列之:JIT中的Virtual Call

程序那些事

Java JVM JIT

ARTS打卡 第10周

引花眠

ARTS 打卡计划

你问我答:微服务治理应该如何去做?

BoCloud博云

容器 微服务 PaaS API 博云

将信将疑,将中台进行到底

郭华

老哥,您看我这篇Java集合,还有机会评优吗?

cxuan

Java 后端

技术管理者带团队的几个实用技巧

Phoenix

团队管理 企业文化 团队 价值观

什么?不写代码也能做功能开发! -RUOYI 教程二

Java_若依框架教程

Java 无代码开发 若依

数据人必须知道的SQL概念(A—Z)

大唐小生

sql 数据 职场成长

微软苏州集体抵制来自阿里、华为的跳槽者:请停止你的“奋斗逼”行为!网友:看到 955 不加班的公司名单,我酸了

程序员生活志

程序员 加班 996

Vue中使用装饰器,我是认真的

前端有的玩

Java Vue 装饰器

秒杀系统问题与方案设计

superman

秒杀 架构总结

谈一谈webpack打包

林浩

Java 大前端 webpack

手写一个重入锁

诸葛小猿

synchronized CAS 重入锁 compareAndSwap ReentrantLock

系统设计系列之如何设计一个短链服务

看山

架构 面试 分布式 架构设计 短链服务

飞天茅台超卖事故:Redis分布式锁请慎用!

程序员生活志

redis 分布式

《深度工作》学习笔记(3)

石云升

学习 深度工作 工作哲学

在人工智能时代追逐的“后浪”

华为云开发者联盟

程序员 AI 开发者 技术社区 华为云

招商证券BERT压缩实践:如何大幅提高模型推断速度?_AI_招商证券信息技术中心NLP开发组_InfoQ精选文章