写点什么

腾讯提出基于协同通道剪枝的深度神经网络压缩新方法,降低模型精度损失

2020 年 5 月 07 日

腾讯提出基于协同通道剪枝的深度神经网络压缩新方法,降低模型精度损失

随着智能设备的普及,深度神经网络在智能手机、可穿戴设备等嵌入式设备上的应用越来越常见,这些嵌入式设备对模型体积、计算性能、功耗等方面都有比较严格的要求。但与此同时,深度学习网络规模却在不断增大,计算复杂度随之增高,严重限制了其在手机等智能设备上的应用。深度学习模型压缩与加速技术就是为了解决这一问题而生。InfoQ 将通过选题的方式逐一介绍各大公司在模型压缩上的技术创新和落地实践。


为了降低深度学习模型在部署阶段的计算开销,往往需要对模型进行压缩。其中,基于通道剪枝(channel pruning)的模型压缩方法对模型部署时的计算方式没有额外要求,因此是目前比较常用的模型压缩方案。目前大多数通道剪枝算法的工作方式是:根据卷积层中各个通道对重构或判别损失函数的各自影响来决定是否保留该通道,但这种方式较少考虑通道之间的相互关系。


腾讯 AI Lab 对这种相互关系在通道剪枝方面的价值做了进一步的研究。在一篇与中国科学院深圳先进技术研究院合作完成的论文《Collaborative Channel Pruning for Deep Networks》中,腾讯 AI Lab 的研究人员为深度网络压缩提出了一种新的协同通道剪枝方法,可进一步降低通道剪枝所带来的模型精度损失。该论文已被 ICML 2019 会议接收。



论文地址:


http://proceedings.mlr.press/v97/peng19c.html


具体来说,研究人员首先分析了卷积层中各个通道对于最终损失函数的联合影响。基于分析结果,将限定计算复杂度下的损失函数最小化问题建模为了一个含约束的 0/1 二次优化问题,并通过对 Hessian 矩阵进行近似,实现了一种快速高效的求解算法。


在 CIFAR-10 和 ILSVRC-12 数据集上的实验评估表明,新提出的协同剪枝算法能在保证分类准确度较小的同时显著降低模型的计算开销;比如在 ILSVRC-12 数据集上,新方法可将 ResNet-50 模型的浮点运算量降低 54.1%,同时 Top-1/5 分类准确度的损失仅为 0.83%和 0.33%。


方法解读

基于通道剪枝的模型压缩方法的基本目标是在限定计算复杂度的情况下(例如限定所保留的通道数量)最小化模型的损失函数,即:



其中,通道剪枝后的模型参数包括两部分(为简便起见,暂不考虑卷积层以外的模型参数):1)卷积层中卷积核的权重,2)用于标记通道保留与否的 0/1 取值的掩码向量(0/1 分别对应删除/保留对应的通道)。通过对掩码向量的 L-0 范数进行限制,可以保证通道剪枝后的各个卷积层中所保留的通道数量符合要求。


以单个卷积层为例。当固定其他模型参数不变时,最终的损失函数可以写成该卷积层中卷积核权重以及对应掩码向量的联合函数:



其中,g 对应于损失函数在原始未剪枝的卷积核处的一阶梯度,而 H 对应于损失函数在此处的二阶 Hessian 矩阵,v 对应于剪枝前后卷积核的变化差值(即是否在卷积核上施加掩码向量)。通过对梯度和 Hessian 矩阵按照通道进行切分,可以将掩码向量的优化问题建模为如下带约束的 0/1 二次优化问题:



其中,u_i 和 s_ij 分别对应于通道 i 和通道组合 (i,j) 对于模型最终损失函数的影响,与掩码向量无关,仅依赖于卷积核的具体权重(以及模型中该卷积层以外的其他模型参数)。



用图(graph)可视化带约束的 0/1 二次优化问题;虚线节点(1 和 5)表示被剪枝的节点,同时相应的边(虚线)也被移除。


该优化问题是一个 NP-hard 问题,难以直接求解,但可以先使用序列二次型优化(SQP,sequential quadratic programming)方法求解该问题的松弛形式,然后对最优解进行二值化处理,以得到最优的掩码向量。此外,Hessian 矩阵的计算和存储复杂度与卷积核中参数量是平方关系,因此难以直接进行计算。


针对不同的损失函数形式(包括用于回归任务的均方误差和用于分类任务的交叉熵误差),腾讯 AI Lab 分别给出了对 Hessian 矩阵进行近似计算的方法,以保证可以快速地构建和求解上述优化问题。具体的 Hessian 矩阵近似以及掩码向量的求解方法参见论文中的对应章节。


以下算法 1 简要总结了新提出的协同剪枝算法的主要工作过程:



实验和结果

为验证新提出的基于协同通道剪枝的模型压缩方法的有效性,腾讯 AI Lab 在 CIFAR-10 和 ILSVRC-12 数据集分别使用 ResNet-56 和 ResNet-50 模型进行了实验评估,结果如下表 1 和表 2 所示。



表 1:在 CIFAR-10 数据集上不同通道剪枝方法的 FLOPs 降低情况与分类准确度损失情况对比



表 2:在 ILSVRC-12 数据集上不同通道剪枝方法的 FLOPs 降低情况与 top-1 和 top-5 分类准确度损失情况对比


从结果上看,新提出的协同通道剪枝方法在相似或者更低的计算复杂度下,可以实现更低的精度损失,而通过在模型的中间层引入辅助分类器(CCP-AC),可以进一步地降低通道剪枝后模型的精度损失。


此外,研究人员还分析了基于序列二次型优化(SQP)方法求解带约束的 0/1 二次型优化问题的收敛情况。下图给出了对于 ResNet-50 模型中部分卷积层,0/1 二次型优化问题中的损失函数与 SQP 方法的迭代次数的关系曲线:



从上图可以看到,对于各个卷积层,SQP 仅需不到 10 次迭代过程就可以快速收敛,求解过程是较为高效的。


总结与展望

这篇论文提出了一种基于协同通道剪枝的深度神经网络压缩方法:通过对卷积层中各个通道对最终损失函数的联合影响进行分析来降低压缩后模型的精度损失。本文将决定通道保留与否的过程建模为一个带约束的 0/1 二次优化问题,并基于近似的 Hessian 矩阵和 SQP 算法给出了一种快速高效的求解算法。


近期一个关注度较高的探索方向结合基于通道剪枝的模型压缩与神经架构搜索(NAS),以降低或去除对预先训练一个预测精度高但计算开销大的复杂模型的需求(例如本文需要在 ILSVRC-12 数据集上预先训练 ResNet-50 模型,再进行模型压缩)。如何在基于模型参数共享的网络结构搜索过程中,借鉴目前模型压缩领域中的某些设计思想,加入适当的约束或者参数共享策略,从而提高搜索得到的网络结构的预测精度与计算效率,也是未来比较值得探讨的研究方向。


延伸阅读:


《深度学习模型压缩技术的落地实践与创新》专题


2020 年 5 月 07 日 11:125034
用户头像
蔡芳芳 InfoQ高级编辑

发布了 548 篇内容, 共 259.9 次阅读, 收获喜欢 1681 次。

关注

评论

发布
暂无评论
发现更多内容

MySQL InnoDB 存储引擎 - 锁

Arthur

ARTS week3

姜海天

二叉树深度优先遍历

封不羁

Java 算法 二叉树

架构师训练营第四周

Melo

架构训练营第四周 - 作业

无心水

极客大学架构师训练营

WPF中的Data Binding调试指南

大白技术控

.net 微软 WPF

测试阶段发现缺陷多怎么办?

洪永潮

Oracle SQL调优系列之看懂执行计划explain

Nicky.Ma

sql

架构师第4周

上山砍柴

极客大学架构师训练营

架构师训练营第四周-总结

无心水

极客大学架构师训练营

​外包公司干了不到3个月,我离职了...(防坑指南)

程序员生活志

程序员 外包 程序员人生 工作经历

面试官:我们来聊下锁吧

java金融

Java 乐观锁 悲观锁

架构师训练营第三周学习总结

lwy

创业一定要学投资

Neco.W

创业 投资

2020年6月26日 查询性能优化

瑞克与莫迪

区块链的应用为什么这么难?出路在哪?

CECBC区块链专委会

比特币 区块链技术 Token 联盟共识

[译]都0202年了,你还觉得go-scheduler很难理解吗?

卓丁

golang golang scheduler go调度 GPM goroutines

极客大学架构师训练营 系统架构 第7课 听课总结

John(易筋)

极客时间 系统架构 高并发 极客大学 极客大学架构师训练营

抖音、腾讯、阿里、美团春招服务端开发岗位硬核面试(完结)

aoho

面试 后端 阿里

Docker基础修炼2--Docker镜像原理及常用命令

黑马腾云

Docker Linux 容器 运维 镜像

新手村:Redis基础补充知识

多选参数

数据库 redis 数据库设计 redis6.0.0

Why Spring ???

猴哥一一 cium

Java spring 源码 Spring Boot 框架设计

区块链系列教程之:比特币中的挖矿

程序那些事

比特币 区块链 挖矿

【总结】企业级案例驱动 打造高可用、高并发、多IDC部署业务中台微服务架构

魔曦

极客大学架构师训练营

ARTS WEEK4

紫枫

ARTS 打卡计划

从0开始设计Flutter独立APP | 第一篇: 数据库与状态管理

渔子长

flutter 前端 跨平台

过早优化是万恶之源

非著名程序员

程序员 程序人生 提升认知 程序员成长

极客大学架构师训练营 框架开发 模式与重构 JUnit、Spring、Hive核心源码解析 第6课

John(易筋)

spring 极客时间 极客大学 极客大学架构师训练营 JUnit

架构师训练营第三周命题作业

lwy

极客大学架构师训练营

基于阿里云服务网格(ASM)的GRPC服务部署实践

韩陆

Kubernetes gRPC Service Mesh

近两年流行面试题:Spring循环依赖问题

Java小咖秀

spring 面试题 ioc

演讲经验交流会|ArchSummit 上海站

演讲经验交流会|ArchSummit 上海站

腾讯提出基于协同通道剪枝的深度神经网络压缩新方法,降低模型精度损失-InfoQ