NVIDIA 初创加速计划,免费加速您的创业启动 了解详情
写点什么

带你跨过神经网络训练常见的 37 个坑

  • 2017-09-17
  • 本文字数:4091 字

    阅读完需:约 13 分钟

神经网络已经持续训练了 12 个小时。它看起来很好:梯度在变化,损失也在下降。但是预测结果出来了:全部都是零值,全部都是背景,什么也检测不到。我质问我的计算机:“我做错了什么?”,它却无法回答。

如果你的模型正在输出垃圾(比如预测所有输出的平均值,或者它的精确度真的很低),那么你从哪里开始检查呢?

无法训练神经网络的原因有很多。在经历了许多次调试之后,我发现有一些检查是经常做的。这张列表汇总了我的经验以及最好的想法,希望对读者也有所帮助。

〇. 使用指南

许多事情都可能出错。但其中有些事情相比于其他方面更容易出问题。在出现问题时,我通常会做以下几件事情。

  1. 从已知的用于该数据类型的简单模型入手(例如 VGG 用于图像处理)。尽可能使用标准误差。
  2. 去掉所有的花哨的预处理程序,例如正则化和数据增强。
  3. 微调模型,仔细检查预处理,应该和原始模型的训练设置保持一致。
  4. 验证输入数据的正确性。
  5. 从较小的数据集开始(2-20 个样本)。在小数据集上过拟合之后再增加数据量。
  6. 慢慢加入之前忽略的项:增强或正则化、自定义损失函数,以及尝试更多的复杂模型。

如果上面的步骤还不能解决,可以开始一项一项的按以下列表进行检查。

Ⅰ. 数据集问题

1. 检查你的输入数据

检查馈送到网络的输入数据是否正确。例如,我不止一次混淆了图像的宽度和高度。有时,我错误地让输入数据全部为零,或者一遍遍地使用同一批数据。所以要打印或显示一些批次的输入和目标输出,并确保它们是正确的。

2. 尝试随机输入

尝试向网络传入随机数而不是真实数据,看看错误的产生方式是否相同。如果是,说明在某些时候你的网络把数据转化为了垃圾。试着逐层调试,并查看出错的地方。

3. 检查数据加载器

你的数据也许很好,但是把输入数据读取到网络的代码可能有问题,所以我们应该在进行其他操作之前打印出第一层的输入并进行检查。

4. 确保输入与输出相关联

检查少许输入样本是否有正确的标签,确保打乱输入样本同样也要打乱输出标签。

5. 输入与输出之间的关系是否太随机

相较于随机的部分(可以认为股票价格也是这种情况),输入与输出之间的非随机部分也许占得比重太小。也就是说输入与输出的关联度太低。没有统一的方法来检测它,因为这取决于数据的性质。

6. 数据集中是否有太多的噪声

我曾经遇到过这种情况,当我从一个食品网站抓取一个图像数据集时,错误标签太多以至于网络无法学习。手动检查一些输入样本并查看标签是否大致正确。例如这篇文章,由于在MNIST 数据集中使用了50% 损坏的标签,只得到了50% 的准确率。

7. 打乱数据集

如果你的数据集没有被随机打乱,并且有特定的序列(按标签排序),这可能给学习带来不利影响。打乱数据集可以避免这一问题。要确保输入和标签都被重新排列。

8. 减少类别失衡

是不是对于一张类别 B 的图像,有 1000 张类别 A 图像?如果是这种情况,那么你也许需要平衡损失函数或者尝试其他解决类别失衡的方法

9. 你有足够的训练实例吗?

如果你从头开始训练一个网络(不是调试),你很可能需要大量数据。对于图像分类,每个类别需要 1000 张图像甚至更多。

10. 确保一批数据不是单一标签

这可能发生在排好顺序的数据集中(即前 10000 个样本属于同一个分类)。可通过打乱数据集轻松修复这个问题。

11. 缩减训练批次大小

这篇文章指出巨大的批次会降低模型的泛化能力。

补充. 使用标准数据集(例如MNIST,cifar10)

测试新的网络结构,或者写了一段新代码时,首先要使用标准数据集,而不是你自己的数据。这是因为在这些数据集上已经有了许多参考结果,他们被证明是“可解的”。不会出现标签噪音、训练/ 测试分布差距、数据集太难等问题。

Ⅱ. 数据归一化/ 增强

12. 归一化特征

你的输入已经归一化到零均值和单位方差了吗?

13. 你是否应用了过量的数据增强?

数据增强有正则化效果。过量的数据增强,加上其它形式的正则化(权重 L2,dropout 操作,等等)可能会导致网络欠拟合。

14. 检查预训练模型的预处理过程

如果你正在使用一个已经预训练过的模型,确保你现在正在使用的归一化和预处理与之前训练模型的设置相同。例如,一个图像的像素是在 [0, 1],[-1, 1] 或 [0, 255] 的范围内吗?

15. 检查训练、验证、测试集的预处理

CS231n 指出了一个常见的陷阱:“任何预处理数据(例如数据均值)必须只在训练数据上进行计算,然后再应用到验证、测试数据中。例如,计算均值,然后在整个数据集的每个图像中都减去它,再把数据分发进训练、验证、测试集中,这是一个典型的错误。”

此外,要在每一个样本或批次(batch)中检查是否存在不同的预处理。

Ⅲ. 实现问题

16. 试着解决某一问题的更简单的版本

这将会有助于找到问题的根源究竟在哪里。例如,如果目标输出是一个物体类别和坐标,那就试着把预测结果仅限制在物体类别当中。

17. “碰巧”寻找正确的损失

还是来源于 CS231n 的技巧:用小参数进行初始化,不使用正则化。例如,如果我们有 10 个类别,“碰巧”就意味着我们将会在 10% 的时间里得到正确类别,Softmax 损失是正确类别的负 log 概率: -ln(0.1) = 2.302。然后,试着增加正则化的强度,这样应该会增加损失。

18. 检查你的损失函数

如果你实现的是你自己的损失函数,那么就要检查错误,并且添加单元测试。通常情况下,损失可能会有些不正确,并且略微损害网络的性能表现。

19. 核实损失输入

如果你正在使用的是框架提供的损失函数,那么要确保你传递给它的东西是它所期望的。例如,在 PyTorch 中,我会混淆 NLLLoss 和 CrossEntropyLoss,因为一个需要 softmax 输入,而另一个不需要。

20. 调整损失权重

如果你的损失由几个更小的损失函数组成,那么确保它们每一个的相应幅值都是正确的。这可能会涉及到测试损失权重的不同组合。

21. 监控其它指标

有时损失并不是衡量你的网络是否被正确训练的最佳预测器。如果可以的话,使用其它指标来帮助你,例如精度。

22. 测试任意的自定义层

你自己在网络中实现过任意层吗?检查并且复核以确保它们的运行符合你的预期。

23. 检查“冷冻”层或变量

检查你是否无意中阻止了一些层或变量的梯度更新,这些层或变量本来应该是可以学习的。

24. 扩大网络规模

可能你网络的表现力不足以捕捉目标函数。试着加入更多的层,或在全连层中增加更多的隐藏单元。

25. 检查隐维度误差

如果你的输入看上去像(k,H,W)= (64, 64, 64),那么很容易错过与错误维度相关的误差。给输入维度使用一些“奇怪”的数值(例如,每一个维度使用不同的质数),并且检查它们是如何通过网络传播的。

26. 探索梯度检查

如果你手动实现了梯度下降,梯度检查会确保你的反向传播能像预期一样工作。

更多信息: 1 2 3

Ⅳ. 训练问题

27. 一个真正小的数据集

过拟合数据的一个小子集,并确保它能正常工作。例如,仅使用 1 个 或 2 个实例训练,并查看你的网络是否能够区分它们。然后再训练每个分类的更多实例。

28. 检查权重初始化

如果不确定,请使用 Xavier He 初始化。同样,初始化也许会给你带来坏的局部最小值,因此尝试不同的初始化,看看是否有效。

29. 改变你的超参数

或许你正在使用一个很糟糕的超参数集。如果可行,尝试一下网格搜索

30. 减少正则化

太多的正则化会导致网络严重地欠拟合。减少正则化,比如 dropout、批归一、权重/偏差 L2 正则化等。在课程《编程人员的深度学习实战》中, Jeremy Howard 建议首先解决欠拟合问题。这意味着你充分地过拟合训练数据,并且只在那时处理过拟合。

31. 给它一些时间

也许你的网络需要更多的时间来训练,在它能做出有意义的预测之前。如果你的损失在稳步下降,那就再多训练一会儿。

32. 从训练模式转换为测试模式

一些框架有批归一化层、Dropout 层,而其他的层在训练和测试时表现并不同。转换到适当的模式有助于网络更好地预测。

33. 可视化训练

  • 监督每层的激活值、权重和更新。确保它们的大小匹配。例如,参数更新的大小幅度(权重和偏差)应该是 1-e3
  • 考虑可视化库,例如 Tensorboard Crayon 。紧要时你也可以打印权重、偏差或激活值。
  • 寻找平均值远大于 0 的层激活。尝试批归一化层或者 ELU 单元。
  • Deeplearning4j 指出了权重和偏差柱状图的期望值应该是什么样的: 对于权重,一段时间之后这些柱状图应该有一个近似高斯的(正态)分布。对于偏差,这些柱状图通常会从 0 开始,并经常以近似高斯(LSTM 是例外情况)结束。留意那些向正无穷或负无穷发散的参数。留意那些变得很大的偏差。这有可能发生在分类网络的输出层,如果类别的分布不均匀。
  • 检查层更新,它们应该呈高斯分布。

34. 尝试不同的优化器

优化器的选择不应当妨碍网络的训练,除非你选择了特别糟糕的超参数。但是,选择一个合适的优化器非常有助于在最短的时间内获得最多的训练结果。描述算法的论文应该指定了优化器,如果没有,我倾向于选择 Adam 或者带有动量的朴素 SGD。

关于梯度下降的优化器可以参考 Sebastian Ruder 的博文

35. 梯度爆炸、梯度消失

  • 检查隐藏层的更新情况,过大的值说明可能出现了梯度爆炸。这时,梯度截断(Gradient clipping)可能会有所帮助。
  • 检查隐藏层的激活值。 Deeplearning4j 中有一个很好的指导方针:“一个好的激活值标准差大约在 0.5 到 2.0 之间。明显超过这一范围可能就代表着激活值消失或爆炸。”

36. 增加、减少学习速率

低学习速率将会导致你的模型收敛很慢。高学习速率将会在开始阶段减少你的损失,但是可能会导致你很难找到一个好的解决方案。

试着把你当前的学习速率乘以 0.1 或 10 然后进行循环。

37. 克服 NaN

据我所知,在训练 RNNs 时得到 NaN(Non-a-Number,非数)是一个很大的问题。一些解决它的方法:

  • 减小学习速率,尤其是如果你在前 100 次迭代中就得到了 NaN。
  • NaNs 的出现可能是由于用零作了除数,或用零或负数作了自然对数。
  • Russell Stewart 在《如何处理 NaN》中分享了很多心得。
  • 尝试逐层评估你的网络,这样就会看见 NaN 到底出现在了哪里。

关于作者:Slav Ivanov 是保加利亚索菲亚的企业家和 ML 实践者。博客主页

查看英文原文: 37 Reasons why your Neural Network is not working


感谢薛命灯对本文的审校。

给InfoQ 中文站投稿或者参与内容翻译工作,请邮件至 editors@cn.infoq.com 。也欢迎大家通过新浪微博( @InfoQ @丁晓昀),微信(微信号: InfoQChina )关注我们。

2017-09-17 17:4017302
用户头像

发布了 52 篇内容, 共 28.2 次阅读, 收获喜欢 72 次。

关注

评论

发布
暂无评论
发现更多内容

网络安全小白别拜师了,求人不如求己

网络安全学海

黑客 网络安全 信息安全 渗透测试 安全漏洞

学习心得-架构训练营-第一课

Fm

【DPDK工程师手册】 —— 官方文档,最新视频,开源项目,论文,大厂内部ppt,知名工程师一览表

奔着腾讯去

Linux DPDK VPP

架构师实战营作业[模块一]

看,有只猪

模块一作业

紫云

架构实战营

图像分类-cifar100 实验研究

毛显新

人工智能 神经网络 tensorflow 图像识别 keras

学生管理系统(作业)

Geek_a772a7

AI巨头们建造的“新世界”,进展如何?

脑极体

在线JSON转XML工具

入门小站

工具

Flutter 安卓 Platform 与 Dart 端消息通信方式 Channel 源码解析

工匠若水

flutter android 8月日更

作业

Li. Mr

架构实战营模块六作业

老猎人

架构实战营

架构训练营 模块一作业

初一

☕【Java技术指南】「OpenJDK专题」想不想编译属于你自己的JDK呢?(Windows10环境)

洛神灬殇

Java jdk Openjdk 8月日更

架构训练营模块一作业

guangbao

极客时间【架构实战营】第二期 模块一作业

Geek_91606e

架构实战营

Linux之nohup命令

入门小站

Linux

百度地图开发-实现离线地图功能 05

Andy阿辉

android 百度地图 Android 小菜鸟 Android端 8月日更

模块一作业

potti

架构实战营

正经人一辈子都用不到的 JavaScript 方法总结 (二)

编程三昧

JavaScript 大前端 8月日更

graphql中的'子查询'

杜艮魁

开源 后端 graphql

搜索引擎渐行渐远,未来路在何方

石头IT视角

初识html,一文搞懂HTMl骨架标签都有哪些含义及浏览器内核

你好bk

html html5 大前端 浏览器 html/css

微信业务架构图-作业

Geek_a772a7

🚀【Guava技术指南】「RateLimiter类」服务请求流控实现方案

洛神灬殇

Java ratelimiter Guava 8月日更

微信的业务架构图

Rabbit

架构实战营

HTTP协议之:HTTP/1.1和HTTP/2

程序那些事

HTTP 程序那些事 HTTP协议 http2

OceanBase 源码解读(三)分区的一生

OceanBase 数据库

数据库 分布式数据库 oceanbase OceanBase 开源 OceanBase 社区版

01. 你身边的AI

数据与智能

人工智能

[架构实战营]模块一

Amy

架构实战营 业务架构图

python通过PyQt5实现登录界面

Python研究者

8月日更

带你跨过神经网络训练常见的37个坑_语言 & 开发_Slav Ivanov_InfoQ精选文章