「如何实现流动式软件发布」线上课堂开课啦,快来报名参与课堂抽奖吧~ 了解详情
写点什么

Facebook 发布深度学习工具包 PyTorch Hub,让论文复现变得更容易

2019 年 6 月 11 日

Facebook发布深度学习工具包PyTorch Hub,让论文复现变得更容易

近日,PyTorch 社区发布了一个深度学习工具包 PyTorchHub, 帮助机器学习工作者更快实现重要论文的复现工作。PyTorchHub 由一个预训练模型仓库组成,专门用于提高研究工作的复现性以及新的研究。同时它还内置了对Google Colab的支持,并与Papers With Code集成。目前 PyTorchHub 包括了一系列与图像分类、分割、生成以及转换相关的模型。


可复现性是许多研究领域的基本要求,这其中当然包括基于机器学习技术的研究领域。然而, 许多机器学习相关论文要么无法复现,要么难以重现。随着论文数量的持续增长,包括目前在 arXiv 上预印刷的数万份论文以及提交给会议的论文,研究工作的可复现性变得越来越重要。虽然其中许多论文都附有代码以及训练好的模型,但这种帮助显然非常有限,复现过程中仍有大量需要读者自己摸索的步骤。下面让我们来看一下如何通过 PyTorch Hub 这一利器完成快速的模型发布与工作复现。



如何快速发布模型

这部分主要介绍了对于模型发布者来说如何快速高效的将自己的模型加入 PyTorch Hub 库。PyTorch Hub 支持通过添加简单的 hubconf.py 文件将预先训练的模型(模型定义和预先训练重)发布到 GitHub 存储库。这提供了模型列表以及其依赖库列表。一些示例可以在torchvisionhuggingface-bertgan-model-zoo存储库中找到。


Pytoch 社区给出了 torchvision 的 hubconf.py 文件的示例:


# Optional list of dependencies required by the packagedependencies = ['torch']
from torchvision.models.alexnet import alexnetfrom torchvision.models.densenet import densenet121, densenet169, densenet201, densenet161from torchvision.models.inception import inception_v3from torchvision.models.resnet import resnet18, resnet34, resnet50, resnet101, resnet152, resnext50_32x4d, resnext101_32x8dfrom torchvision.models.squeezenet import squeezenet1_0, squeezenet1_1from torchvision.models.vgg import vgg11, vgg13, vgg16, vgg19, vgg11_bn, vgg13_bn, vgg16_bn, vgg19_bnfrom torchvision.models.segmentation import fcn_resnet101, deeplabv3_resnet101from torchvision.models.googlenet import googlenetfrom torchvision.models.shufflenetv2 import shufflenet_v2_x0_5, shufflenet_v2_x1_0from torchvision.models.mobilenet import mobilenet_v2
复制代码


在 torchvision 中,模型有以下特性:


  • 每个模型文件可以被独立执行或实现某个功能

  • 不需要除了 PyTorch 之外的任何软件包(在 hubconf.py 中编码为 dependencies[‘torch’])

  • 他们不需要单独的入口点,因为模型在创建时可以无缝地开箱即用。


PyTroch 社区认为最小化包依赖性可减少用户加载模型时遇到的困难。这里他们给出了一个更为复杂的例子——HuggingFace’s BERT 模型,它的 hubconf.py 如下:


dependencies = ['torch', 'tqdm', 'boto3', 'requests', 'regex']
from hubconfs.bert_hubconf import ( bertTokenizer, bertModel, bertForNextSentencePrediction, bertForPreTraining, bertForMaskedLM, bertForSequenceClassification, bertForMultipleChoice, bertForQuestionAnswering, bertForTokenClassification)
复制代码


此外,对于每个模型,PyTorch 官方提到都需要为其创建一个入口点。下面是一个用于指定 bertForMaskedLM 模型的入口点的代码片段,这部分代码完成的功能是返回加载了预训练参数的模型。


def bertForMaskedLM(*args, **kwargs):    """    BertForMaskedLM includes the BertModel Transformer followed by the    pre-trained masked language modeling head.    Example:      ...    """    model = BertForMaskedLM.from_pretrained(*args, **kwargs)    return model
复制代码


这些入口点可以看成是复杂的模型结构的一种封装形式。它们可以在提供简洁高效的帮助文档的同时完成下载预训练权重的功能(例如,通过 pretrained = True),也可以集成其他特定功能,例如可视化。


通过hubconf.py,模型发布者可以在 Github 上基于template提交他们的合并请求。PyTorch 社区希望通过 PyTorch Hub 创建一系列高质量、易复现且效果好的模型以提高研究工作的复现性。因此,PyTorch 会通过与模型发布者合作的方式以完善请求,并有可能会在某些情况下拒绝发布一些低质量的模型。一旦 PyTorch 社区接受了模型发布者的请求,这些新的模型将会很快出现在 PyTorch Hub 的网页上以供用户浏览。


用户工作流

对于想使用 PyTorch Hub 对别人的工作进行复现的用户,PyTorch Hub 提供了以下几个步骤:1)浏览可用的模型;2)加载模型;3)探索已加载的模型。下面让我们来浏览几个例子。


浏览可用的入口点

用户可以使用 torch.hub.list() API 列出仓库中的所有可用入口点。


>>> torch.hub.list('pytorch/vision')>>>['alexnet','deeplabv3_resnet101','densenet121',...'vgg16','vgg16_bn','vgg19', 'vgg19_bn']
复制代码


注意,PyTorch Hub 还允许辅助入口点(除了预训练模型),例如,用于 BERT 模型预处理的 bertTokenizer,它可以使用户工作流程更加顺畅。


加载模型

对于 PyTroch Hub 中可用的模型,用户可以使用 torch.hub.load() API 加载模型入口点。此外,torch.hub.help() API 可以提供有关如何实例化模型的有用信息。


print(torch.hub.help('pytorch/vision', 'deeplabv3_resnet101'))model = torch.hub.load('pytorch/vision', 'deeplabv3_resnet101', pretrained=True)
复制代码


由于仓库的持有者会不断添加错误修复以及性能改进,PyTorch Hub 允许用户通过调用以下内容简单地获取最新更新:


model = torch.hub.load(..., force_reload=True)
复制代码


这一举措可以有效地减轻仓库持有者重复发布模型的负担,从而使他们能够更专注于自己的研究工作。同时,也确保了用户可以获得最新版本的模型。


此外,对于用户来说,稳定性也是一个重要问题。因此,某些模型所有者会从特征的分支或标签为他们提供服务,以确保代码的稳定性。例如,pytorch_GAN_zoo 会从 hub 分支为他们提供服务:


model = torch.hub.load('facebookresearch/pytorch_GAN_zoo:hub', 'DCGAN', pretrained=True, useGPU=False)
复制代码


这里,传递给 hub.load() 的 * args,** kwargs 用于实例化模型。在上面的示例中,pretrained = True 和 useGPU = False 被传递给模型的入口点。


探索已加载的模型

从 PyTorch Hub 加载模型后,用户可以使用以下工作流查看已加载模型的可用方法,并更好地了解运行它所需的参数。


其中,dir(model)可以查看模型中可用的方法。下面是 bertForMaskedLM 的一些方法:


>>> dir(model)>>>['forward'...'to''state_dict',]
复制代码


help(model.forward)则会提供使已加载的模型运行时所需参数的视图:


>>> help(model.forward)>>>Help on method forward in module pytorch_pretrained_bert.modeling:forward(input_ids, token_type_ids=None, attention_mask=None, masked_lm_labels=None)...
复制代码


更多细节可以查看BERTDeepLabV3页面:



其他探索方式与相关资源

PyTorch Hub 中提供的模型也支持 Colab,并且会直接链接在 Papers With Code 上,用户只需单击链接即可开始使用:



PyTorch 提供了一些相关资源帮助用户快速上手 PyTorch Hub:



FAQ

问:如果我们想贡献一个 Hub 中已经有了的模型,但也许我的模型具有更高的准确性,我还应该贡献吗?


答:是的,请提交您的模型,Hub 的下一步是开发投票系统以展示最佳模型。


问:谁负责保管 PyTorch Hub 的模型权重?


答:作为贡献者,您负责保管模型权重。您可以在您喜欢的云存储中托管您的模型,或者如果它符合限制,则可以在 GitHub 上托管您的模型。 如果您无法保管权重,请通过 Hub 仓库中提交问题的方式与我们联系。


问:如果我的模型使用了私有化数据进行训练怎么办?我还应该贡献这个模型吗?


答:请不要提交您的模型!PyTorch Hub 以开源研究为中心,并扩展到使用公开数据集来训练这些模型。如果提交了私有模型的合并请求,我们将恳请您重新提交使用公开数据进行训练后的模型。


问:我下载的模型保存在哪里?


答:我们遵循 XDG 基本目录规范,并遵循缓存文件和目录的通用标准。这些位置按以下顺序使用:


  • 调用 hub.set_dir(<PATH_TO_HUB_DIR>)

  • 如果环境变量了 TORCH_HOME,则为 $TORCH_HOME/hub。

  • 如果设置了环境变量 XDG_CACHE_HOME,则为 $ XDG_CACHE_HOME / torch / hub。

  • ~/.cache/torch/hub


相关推荐:



2019 年 6 月 11 日 14:5616312

评论

发布
暂无评论
发现更多内容

第二周-学习总结

ray-arch

极客大学架构师训练营

week2-作业1

Mr_No爱学习

2周 总结

水浴清风

Week_06 总结+作业

golangboy

极客大学架构师训练营

技术选型二第六周作业「架构师训练营第 1 期」

天天向善

架构师 01 期,第六周课后作业

子文

架构师训练营第六周学习总结

Gosling

极客大学架构师训练营

架构师训练营 Week6 - 课后作业

极客大学架构师训练营

第六周作业

Geek_ce484f

极客大学架构师训练营

极客时间架构 1 期:第 6 周 技术选型(二) - 命题作业

Null

打工人必会算法—快速幂算法讲解

bigsai

2020.10.26-2020.11.01 学习总结

icydolphin

极客大学架构师训练营

周练习 6

何毅曦

极客时间架构 1 期:第6周 技术选型(二) - 学习总结

Null

极客时间 - 架构训练营 第一周总结 - 设计原则

架构师课程第二周作业

文江

学习笔记:架构师训练营-第六周

四夕晖

Architecture Phase1 Week6:HomeWork

phylony-lu

极客大学架构师训练营

【第六周】课后作业

云龙

架构师训练营第二期 Week 2 作业

bigxiang

极客大学架构师训练营

第2周作业

Rocky·Chen

LeetCode题解:90. 子集 II,回溯+哈希表去重,JavaScript,详细注释

Lee Chen

算法 LeetCode 前端进阶训练营

第六周作业总结

Geek_ce484f

极客大学架构师训练营

CAP原理简述及应用

博古通今小虾米

CAP

第六周作业1

Yangjing

极客大学架构师训练营

架构师训练营第六周作业

四夕晖

第六周作业2

Yangjing

极客大学架构师训练营

week06作业

追风

架构师一期

学习总结 -week2

Mr_No爱学习

week2-作业

Mr_No爱学习

思考 - 从传统雪崩到K8S

东风微鸣

k8s

Facebook发布深度学习工具包PyTorch Hub,让论文复现变得更容易-InfoQ