【AICon】探索八个行业创新案例,教你在教育、金融、医疗、法律等领域实践大模型技术! >>> 了解详情
写点什么

深度解析 MegEngine 亚线性显存优化技术

  • 2020-05-09
  • 本文字数:6360 字

    阅读完需:约 21 分钟

深度解析MegEngine亚线性显存优化技术

基于梯度检查点的亚线性显存优化方法[1]由于较高的计算/显存性价比受到关注。MegEngine 经过工程扩展和优化,发展出一套行之有效的加强版亚线性显存优化技术,既可在计算存储资源受限的条件下,轻松训练更深的模型,又可使用更大 batch size,进一步提升模型性能,稳定 batchwise 算子。

使用 MegEngine 训练 ResNet18/ResNet50,显存占用分别最高降低 23%/40%;在更大的 Bert 模型上,降幅更是高达 75%,而额外的计算开销几乎不变。

该技术已在 MegEngine 开源,欢迎大家上手使用:https://github.com/MegEngine


深度神经网络训练是一件复杂的事情,它体现为模型的时间复杂度和空间复杂度,分别对应着计算和内存;而训练时内存占用问题是漂浮在深度学习社区上空的一块乌云,如何拨云见日,最大降低神经网络训练的内存占用,是一个绕不开的课题。


GPU 显卡等硬件为深度学习提供了必需的算力,但硬件自身有限的存储,限制了可训练模型的尺寸,尤其是大型深度网络,由此诞生出一系列相关技术,比如亚线性显存优化、梯度累加、混合精度训练、分布式训练,进行 GPU 显存优化。


其中,亚线性显存优化方法[1]由于较高的计算/显存性价比备受关注;旷视基于此,经过工程扩展和优化,发展出加强版的 MegEngine 亚线性显存优化技术,轻松把大模型甚至超大模型装进显存,也可以毫无压力使用大 batch 训练模型。


这里将围绕着深度学习框架 MegEngine 亚线性显存优化技术的工程实现和实验数据,从技术背景、原理、使用、展望等多个方面进行首次深入解读。

背 景

在深度学习领域中,随着训练数据的增加,需要相应增加模型的尺寸和复杂度,进行模型「扩容」;而 ResNet [2] 等技术的出现在算法层面扫清了训练深度模型的障碍。不断增加的数据和持续创新的算法给深度学习框架带来了新挑战,能否在模型训练时有效利用有限的计算存储资源,尤其是减少 GPU 显存占用,是评估深度学习框架性能的重要指标。


在计算存储资源一定的情况下,深度学习框架有几种降低显存占用的常用方法,其示例如下:


  • 通过合适的梯度定义,让算子的梯度计算不再依赖于前向计算作为输入,从而 in-place 地完成算子的前向计算,比如 Sigmoid、Relu 等;

  • 在生命周期没有重叠的算子之间共享显存;

  • 通过额外的计算减少显存占用,比如利用梯度检查点重新计算中间结果的亚线性显存优化方法[1];

  • 通过额外的数据传输减少显存占用,比如把暂时不用的数据从 GPU 交换到 CPU,需要时再从 CPU 交换回来。


上述显存优化技术在 MegEngine 中皆有不同程度的实现,这里重点讨论基于梯度检查点的亚线性显存优化技术。

原 理

一个神经网络模型所占用的显存空间大体分为两个方面:1)模型本身的参数,2)模型训练临时占用的空间,包括参数的梯度、特征图等。其中最大占比是 2)中以特征图形式存在的中间结果,比如,从示例[1]可知,根据实现的不同,从 70%到 90%以上的显存用来存储特征图。


这里的训练过程又可分为前向计算,反向计算和优化三个方面,其中前向计算的中间结果最占显存,还有反向计算的梯度。第 1)方面模型自身的参数内存占用最小。


MegEngine 加强版亚线性显存优化技术借鉴了[1]的方法,尤其适用于计算存储资源受限的情况,比如一张英伟达 2080Ti,只有 11G 的显存;而更贵的 Tesla V100,最大显存也只有 32G。


图1:亚线性显存优化原理,其中 (b) 保存了Relu结果,实际中Relu结果可用in-place计算


图 1(a) 给出了卷积神经网络的基本单元,它由 Conv-BN-Relu 组成。可以看到,反向计算梯度的过程依赖于前向计算获取的中间结果,一个网络需要保存的中间结果与其大小成正比,即显存复杂度为 O(n)。


本质上,亚线性显存优化方法是以时间换空间,以计算换显存,如图 1(b) 所示,它的算法原理如下:


  • 选取神经网络中 k 个检查点,从而把网络分成 k 个 block,需要注意的是,初始输入也作为一个检查点;前向计算过程中只保存检查点处的中间结果;

  • 反向计算梯度的过程中,首先从相应检查点出发,重新计算单个 block 需要的中间结果,然后计算 block 内部各个 block 的梯度;不同 block 的中间结果计算共享显存。这种方法有着明显的优点,即大幅降低了模型的空间复杂度,同时缺点是增加了额外的计算:

  • 显存占用从 O(n)变成 O(n/k)+ O(k),O(n/k)代表计算单个节点需要的显存,O(k)代表 k 个检查点需要的显存, 取 k=sqrt(n),O(n/k)+ O(k)~O(sqrt(n)),可以看到显存占用从线性变成了亚线性;

  • 因为在反向梯度的计算过程中需要从检查点恢复中间结果,整体需要额外执行一次前向计算。

工 程

在[1]的基础上,MegEngine 结合自身实践,做了工程扩展和优化,把亚线性显存优化方法扩展至任意的计算图,并结合其它常见的显存优化方法,发展出一套行之有效的加强版亚线性显存优化技术。


亚线性优化方法采用简单的网格搜索(grid search)选择检查点,MegEngine 在此基础上增加遗传算法,采用边界移动、块合并、块分裂等策略,实现更细粒度的优化,进一步降低了显存占用。


如图 2 所示,采用型号为 2080Ti 的 GPU 训练 ResNet50,分别借助基准、亚线性、亚线性+遗传算法三种显存优化策略,对比了可使用的最大 batch size。仅使用亚线性优化,batch size 从 133 增至 211,是基准的 1.6x;而使用亚线性+遗传算法联合优化,batch size 进一步增至 262,较基准提升 2x。



图 2:三种显存优化方法优化 batch size 的对比:ResNet50


通过选定同一模型、给定 batch size,可以更好地观察遗传算法优化显存占用的情况。如图 3 所示,随着迭代次数的增加,遗传算法逐渐收敛显存占用,并在第 5 次迭代之后达到一个较稳定的状态。



图 3:遗传算法收敛示意图


此外,MegEngine 亚线性优化技术通过工程改良,不再局限于简单的链状结构和同质计算节点, 可用于任意的计算图,计算节点也可异质,从而拓展了技术的适用场景;并可配合上述显存优化方法,进一步降低模型的显存占用。

实 验

MegEngine 基于亚线性显存技术开展了相关实验,这里固定 batch size=64,在 ResNet18 和 ResNet50 两个模型上,考察模型训练时的显存占用和计算时间。


如图 4 所示,相较于基准实现,使用 MegEngine 亚线性显存技术训练 ResNet18 时,显存占用降低 32%, 计算时间增加 24%;在较大的 ReNet50 上,显存占用降低 40%,计算时间增加 25%。同时经过理论分析可知,模型越大,亚线性显存优化的效果越明显,额外的计算时间则几乎不变。



图 4:MegEngine 亚线性优化技术实验显存/时间对比:ReNet18/ReNet50


在更大模型 Bert 上实验数据表明,借助 MegEngine 亚线性显存技术,显存占用最高降低 75%,而计算时间仅增加 23%,这与理论分析相一致。有兴趣的同学可前往 MegEngine ModeHub 试手更多模型实验:https://megengine.org.cn/model-hub/

使 用

MegEngine 官网提供了亚线性显存优化技术的使用文档。当你的 GPU 显存有限,苦于无法训练较深、较大的神经网络模型,或者无法使用大 batch 进一步提升深度神经网络的性能,抑或想要使 batchwise 算子更加稳定,那么,MegEngine 亚线性显存优化技术正是你需要的解决方案。


上手 MegEngine 亚线性优化技术非常便捷,无需手动设定梯度检查点,通过几个简单的参数,轻松控制遗传算法的搜索策略。具体使用时,在 MegEngine 静态图接口中调用 SublinearMemoryConfig 设置 trace 的参数 sublinear_memory_config,即可打开亚线性显存优化:


from megengine.jit import trace, SublinearMemoryConfig config = SublinearMemoryConfig() @trace(symbolic=True, sublinear_memory_config=config)def train_func(data, label, *, net, optimizer):    ...
复制代码


MegEngine 在编译计算图和训练模型时,虽有少量的额外时间开销,但会显著缓解显存不足问题。下面以 ResNet50 为例,说明 MegEngine 可有效突破显存瓶颈,训练 batch size 从 100 最高增至 200:


import osfrom multiprocessing import Process  def train_resnet_demo(batch_size, enable_sublinear, genetic_nr_iter=0):    import megengine as mge    import megengine.functional as F    import megengine.hub as hub    import megengine.optimizer as optim    from megengine.jit import trace, SublinearMemoryConfig    import numpy as np     print(        "Run with batch_size={}, enable_sublinear={}, genetic_nr_iter={}".format(            batch_size, enable_sublinear, genetic_nr_iter        )    )    # 使用GPU运行这个例子    assert mge.is_cuda_available(), "Please run with GPU"    try:        # 我们从 megengine hub 中加载一个 resnet50 模型。        resnet = hub.load("megengine/models", "resnet50")         optimizer = optim.SGD(resnet.parameters(), lr=0.1,)         config = None        if enable_sublinear:            config = SublinearMemoryConfig(genetic_nr_iter=genetic_nr_iter)         @trace(symbolic=True, sublinear_memory_config=config)        def train_func(data, label, *, net, optimizer):            pred = net(data)            loss = F.cross_entropy_with_softmax(pred, label)            optimizer.backward(loss)         resnet.train()        for i in range(10):            batch_data = np.random.randn(batch_size, 3, 224, 224).astype(np.float32)            batch_label = np.random.randint(1000, size=(batch_size,)).astype(np.int32)            optimizer.zero_grad()            train_func(batch_data, batch_label, net=resnet, optimizer=optimizer)            optimizer.step()    except:        print("Failed")        return     print("Sucess")  # 以下示例结果在2080Ti GPU运行得到,显存容量为 11 GB # 不使用亚线性内存优化,允许的batch_size最大为 100 左右p = Process(target=train_resnet_demo, args=(100, False))p.start()p.join()# 报错显存不足p = Process(target=train_resnet_demo, args=(200, False))p.start()p.join() # 使用亚线性内存优化,允许的batch_size最大为 200 左右p = Process(target=train_resnet_demo, args=(200, True, 20))p.start()p.join()
复制代码

展 望

如上所述,MegEngine 的亚线性显存优化技术通过额外做一次前向计算,即可达到 O(sqrt(n))的空间复杂度。如果允许做更多次的前向计算,对整个网络递归地调用亚线性显存算法,有望在时间复杂度为 O(n log n)的情况下,达到 O(log n)的空间复杂度。


更进一步,MegEngine 还将探索亚线性显存优化技术与数据并行/模型并行、混合精度训练的组合使用问题,以期获得更佳的集成效果。最后,在 RNN 以及 GNN、Transformer 等其他类型网络上的使用问题,也是 MegEngine 未来的一个探索方向。


了解更多信息请查询:



参考文献


1. Chen, T., Xu, B., Zhang, C., & Guestrin, C. (2016). Training deep nets with sublinear memory cost. arXiv preprint arXiv:1604.06174.


2. He, K., Zhang, X., Ren, S., & Sun, J. (2016). Deep residual learning for image recognition. In Proceedings of the IEEE conference on computer vision and pattern recognition (pp. 770-778).


公众号推荐:

跳进 AI 的奇妙世界,一起探索未来工作的新风貌!想要深入了解 AI 如何成为产业创新的新引擎?好奇哪些城市正成为 AI 人才的新磁场?《中国生成式 AI 开发者洞察 2024》由 InfoQ 研究中心精心打造,为你深度解锁生成式 AI 领域的最新开发者动态。无论你是资深研发者,还是对生成式 AI 充满好奇的新手,这份报告都是你不可错过的知识宝典。欢迎大家扫码关注「AI前线」公众号,回复「开发者洞察」领取。

2020-05-09 12:586675
用户头像
刘燕 InfoQ高级技术编辑

发布了 1112 篇内容, 共 493.6 次阅读, 收获喜欢 1967 次。

关注

评论

发布
暂无评论
发现更多内容

百亿级分布式文件系统之元数据设计

焱融科技

云计算 技术 分布式 高性能 文件存储

Compose 中的主题

Changing Lin

8月日更

高并发中,那些不得不说的线程池与ThreadPoolExecutor类

华为云开发者联盟

Java 线程 高并发 线程池 ThreadPoolExecutor类

用Java仿一个低配版的Everything软件

Regan Yue

Java 8月日更 Everything

能源区块链研究 | 加密行业碳抵消有助于大众接纳比特币吗?

CECBC

出现吧,Python Web 菜谱系统的首页,不会前端技术,也能做

梦想橡皮擦

8月日更

架构实战营-模块二作业

俞立夫

架构实战营

Fastdata for TSDB: SQL使时序数据可扩展

数据库 大数据 时序数据库 tsdb 数据智能

智能时代的信任口诀:让计算远离算计

白洞计划

Excelize 发布 2.4.1 版本,新增并发安全支持

xuri

Excel Go 语言 Excelize #Github

FastApi-15-文件上传-3

Python研究所

FastApi 8月日更

LeetCode题解:220. 存在重复元素 III,暴力法,JavaScript,详细注释

Lee Chen

算法 大前端 LeetCode

OpenYurt 联手 eKuiper,解决 IoT 场景下边缘流数据处理难题

阿里巴巴云原生

云计算 阿里云 开源 云原生 中间件

Python入门:ChainMap 有效管理多个上下文

华为云开发者联盟

Python 字典 上下文 映射 ChainMap

区块链技术:为什么说波卡能加速区块链行业的发展?

CECBC

数据加密和BCrypt哈希算法应用 | StartDT Tech Lab 15

奇点云

导播上云,把 “虚拟演播厅” 搬到奥运村

阿里云视频云

阿里云 视频处理 视频直播 视频云 云导播

Seata TCC模式原理与实战

码农参上

分布式事务 seata SpringCloud Alibaba 8月日更

【Vue2.x 源码学习】第三十七篇 - 组件部分 - 组件的合并

Brave

源码 vue2 8月日更

Spark RDD模型

布兰特

spark

docker的使用

Rubble

8月日更

netty系列之:自定义编码解码器

程序那些事

Java Netty 程序那些事

趣说开源|学生如何参与开源社区?

SphereEx

数据库 开源

【LeetCode】有效的字母异位词Java题解

Albert

算法 LeetCode 8月日更

手撸二叉树之递增顺序搜索树

HelloWorld杰少

数据结构与算法 8月日更

零代码以“王者荣耀”为例解析设计七原则

华为云开发者联盟

软件 设计原则 王者荣耀 单一职责

如何将知识引入机器学习模型提升泛化能力?

华为云开发者联盟

机器学习 算法 数据 模型 物理学

“遇见”未来“编程”语言,面向组件编程,送给在校学生

清风

Java 小程序 毕业设计

Go语言:如何通过Go来更好的开发并发程序 ?

微客鸟窝

Go 语言

为什么区块链是互联网的100倍?

CECBC

基于java springboot体育馆预约微信小程序源码(毕设)设计开发

清风

Java 小程序 源码 毕业设计

深度解析MegEngine亚线性显存优化技术_AI&大模型_旷视研究院_InfoQ精选文章