GTLC全球技术领导力峰会·上海站,首批讲师正式上线! 了解详情
写点什么

深度学习正在被滥用

2021 年 4 月 22 日

深度学习正在被滥用

本文最初发布于 Medium 网站,经原作者授权由 InfoQ 中文站翻译并分享。


在某些情况下,神经网络之类模型的表现可能会胜过更简单的模型,但很多情况下事情并不是这样的。


打个比方:假设你需要购买某种交通工具来跑运输,如果你经常需要长距离运输大型物品,那么,购买卡车是很划算的投资;但如果你只是要去本地超市买点牛奶,那么买一辆卡车就太浪费了。一辆汽车(如果你关心气候变化的话,甚至可以买一辆自行车)也足以完成上述任务。


深度学习的使用场景也开始遇到这种问题了:我们假设它们的性能优于简单模型,然后把相关数据一股脑儿地塞给它们。此外,我们在应用这些模型时往往并没有对相关数据有适当的理解;比如说我们没有意识到,如果对数据有直观的了解,就不必进行深度学习。


任何模型被装在黑匣子里来分析数据时,总是会存在危险,深度学习家族的模型也不例外。

时间序列分析

我最常用的是时间序列分析,因此我们来考虑一个这方面的例子。


假设一家酒店希望预测其在整个客户群中收取的平均每日费用(或每天的平均费用)——ADR。每位客户的平均每日费用是每周开销的平均值。


LSTM 模型的配置如下:


model = tf.keras.Sequential()model.add(LSTM(4, input_shape=(1, lookback)))model.add(Dense(1))model.compile(loss='mean_squared_error', optimizer='adam')history=model.fit(X_train, Y_train, validation_split=0.2, epochs=100, batch_size=1, verbose=2)
复制代码


下面是预测与实际的每周 ADR:


资料来源:Jupyter Notebook 输出


获得的 RMSE 为 31,均值 160。RMSE(均方根误差)的大小是平均 ADR 大小的 20%。误差并不算高,但不得不承认,神经网络的目的是尽可能获得比其他模型更高的准确度,所以这个结果还是有些令人失望。


此外,这个 LSTM 模型是一个一步预测——意味着如果没有可用的时间 t 之前的所有数据,该模型就无法进行长期预测。


也就是说,我们是不是太急着对数据应用 LSTM 模型了呢?


我们先回到出发点,首先对数据做一个全面的分析。


下面是 ADR 波动的 7 周移动平均值:



资料来源:Jupyter Notebook 输出


当数据通过 7 周的移动平均值进行平滑处理后,我们可以清楚地看到季节性模式的证据。


我们来仔细看看数据的自相关函数。



资料来源:Jupyter Notebook 输出


我们可以看到,峰值相关性(在一系列负相关性之后)滞后 52,表明数据中存在年度季节属性。


有了这一信息后,我们可以使用 pmdarima 配置 ARIMA 模型来预测 ADR 波动的最后 15 周,并自动选择 p、d、q 坐标以最小化赤池量信息准则。


>>> Arima_model=pm.auto_arima(train_df, start_p=0, start_q=0, max_p=10, max_q=10, start_P=0, start_Q=0, max_P=10, max_Q=10, m=52, stepwise=True, seasonal=True, information_criterion='aic', trace=True, d=1, D=1, error_action='warn', suppress_warnings=True, random_state = 20, n_fits=30)Performing stepwise search to minimize aicARIMA(0,1,0)(0,1,0)[52]             : AIC=422.399, Time=0.27 secARIMA(1,1,0)(1,1,0)[52]             : AIC=inf, Time=16.12 secARIMA(0,1,1)(0,1,1)[52]             : AIC=inf, Time=19.08 secARIMA(0,1,0)(1,1,0)[52]             : AIC=inf, Time=14.55 secARIMA(0,1,0)(0,1,1)[52]             : AIC=inf, Time=11.94 secARIMA(0,1,0)(1,1,1)[52]             : AIC=inf, Time=16.47 secARIMA(1,1,0)(0,1,0)[52]             : AIC=414.708, Time=0.56 secARIMA(1,1,0)(0,1,1)[52]             : AIC=inf, Time=15.98 secARIMA(1,1,0)(1,1,1)[52]             : AIC=inf, Time=20.41 secARIMA(2,1,0)(0,1,0)[52]             : AIC=413.878, Time=1.01 secARIMA(2,1,0)(1,1,0)[52]             : AIC=inf, Time=22.19 secARIMA(2,1,0)(0,1,1)[52]             : AIC=inf, Time=25.80 secARIMA(2,1,0)(1,1,1)[52]             : AIC=inf, Time=28.23 secARIMA(3,1,0)(0,1,0)[52]             : AIC=414.514, Time=1.13 secARIMA(2,1,1)(0,1,0)[52]             : AIC=415.165, Time=2.18 secARIMA(1,1,1)(0,1,0)[52]             : AIC=413.365, Time=1.11 secARIMA(1,1,1)(1,1,0)[52]             : AIC=415.351, Time=24.93 secARIMA(1,1,1)(0,1,1)[52]             : AIC=inf, Time=21.92 secARIMA(1,1,1)(1,1,1)[52]             : AIC=inf, Time=30.36 secARIMA(0,1,1)(0,1,0)[52]             : AIC=411.433, Time=0.59 secARIMA(0,1,1)(1,1,0)[52]             : AIC=413.422, Time=11.57 secARIMA(0,1,1)(1,1,1)[52]             : AIC=inf, Time=23.39 secARIMA(0,1,2)(0,1,0)[52]             : AIC=413.343, Time=0.82 secARIMA(1,1,2)(0,1,0)[52]             : AIC=415.196, Time=1.63 secARIMA(0,1,1)(0,1,0)[52] intercept   : AIC=413.377, Time=1.04 secBest model:  ARIMA(0,1,1)(0,1,0)[52]Total fit time: 313.326 seconds
复制代码


根据上面的输出,ARIMA(0,1,1)(0,1,0)[52]是 AIC 的最佳拟合模型。使用这个模型,对于 160 的平均 ADR,可获得 10 的 RMSE。


这比 LSTM 实现的 RMSE 要低得多(这是一件好事),仅占均值大小的 6%多。


对数据进行适当的分析后,人们会认识到,数据中存在的年度季节属性可以让时间序列更具可预测性,而使用深度学习模型来尝试预测这种属性在很大程度上是多余的。

回归分析:预测客户 ADR 值


我们换个角度来讨论上述问题。


现在我们不再尝试预测平均每周 ADR,而是尝试预测每个客户的 ADR 值。


为此我们使用两个基于回归的模型:


  • 线性 SVM(支持向量机)

  • 基于回归的神经网络


两种模型均使用以下特征来预测每个客户的 ADR 值:


  • IsCanceled:客户是否取消预订

  • country:客户的原籍国

  • marketsegment:客户的细分市场

  • deposittype:客户是否已支付订金

  • customertype:客户类型

  • rcps:所需的停车位

  • arrivaldateweekno:到达的星期数


我们使用平均绝对误差作为效果指标,来对比两个模型相对于平均值获得的 MAE。

线性支持向量机


这里定义了 epsilon 为 0.5 的 LinearSVR,并使用训练数据进行了训练:


svm_reg_05 = LinearSVR(epsilon=0.5)svm_reg_05.fit(X_train, y_train)
复制代码


现在使用测试集中的特征值进行预测:


>>> svm_reg_05.predict(atest)array([ 81.7431138 , 107.46098525, 107.46098525, ...,  94.50144931,94.202052  ,  94.50144931])
复制代码


这是相对于均值的均值绝对误差:


>>> mean_absolute_error(btest, bpred)30.332614341027753>>> np.mean(btest)105.30446539770578
复制代码


MAE 是均值大小的 28%。让我们看看基于回归的神经网络是否可以做得更好。

基于回归的神经网络


神经网络的定义如下:


model = Sequential()model.add(Dense(8, input_dim=8, kernel_initializer='normal', activation='elu'))model.add(Dense(2670, activation='elu'))model.add(Dense(1, activation='linear'))model.summary()
复制代码


使用的批大小是 150,用 30 个 epoch 训练模型:


model.compile(loss='mse', optimizer='adam', metrics=['mse','mae'])history=model.fit(xtrain_scale, ytrain_scale, epochs=30, batch_size=150, verbose=1, validation_split=0.2)predictions = model.predict(xval_scale)
复制代码


现在将测试集的特征输入到模型中,以下是 MAE 和平均值:


>>> mean_absolute_error(btest, bpred)28.908454264679218>>> np.mean(btest)105.30446539770578
复制代码


我们看到,MAE 仅仅比使用 SVM 所获得的 MAE 低一点。因此,当线性 SVM 模型显示出几乎相同的准确度时,很难证明使用神经网络来预测客户 ADR 是合适的选项。


无论如何,用于“解释”ADR 的特征选择之类的因素比模型本身有着更大的相关性。俗话说,“进垃圾,出垃圾”。如果特征选取很烂,模型输出也会很差。


在上面这个例子里,尽管两个回归模型都显示出一定程度的预测能力,但很可能要么 1)选择数据集中的其他特征可以进一步提高准确性,要么 2)ADR 的变量太多,对数据集中特征的影响太大。例如,数据集没有告诉我们关于每个客户收入水平的任何信息,这些因素将极大地影响他们每天的平均支出。

结论


在上面的两个示例中我们已经看到,使用“更轻”的模型已经能够匹配(或超过)深度学习模型所实现的准确性。


在某些情况下,数据可能非常复杂,需要“从头开始”在数据中使用算法学习模式,但这往往是例外,而不是规则。


对于任何数据科学问题,关键是首先要了解我们正在使用的数据,模型的选择往往是次要的。


可以在此处(https://github.com/MGCodesandStats/hotel-modelling)找到上述示例的数据集和 Jupyter 笔记本。


原文链接:


https://towardsdatascience.com/deep-learning-is-becoming-overused-1e6b08bc709f

2021 年 4 月 22 日 10:051
用户头像
刘燕 InfoQ记者

发布了 555 篇内容, 共 174.6 次阅读, 收获喜欢 1055 次。

关注

评论

发布
暂无评论
发现更多内容

HDFS SHELL详解(6)

罗小龙

hadoop 28天写作 hdfs shell

与前端训练营的日子 --Week11

SamGo

学习

大型企业引进低代码开发技术是大趋势

Sam678678

用Rust写点啥:数据结构篇——单向链表

Kurtis Moxley

数据结构 rust

Java 异常处理

学个球

Java java异常处理

夜莺二次开发指南-用户资源中心

qinyening

滴滴夜莺 夜莺监控

Socket粘包问题终极解决方案—Netty版(2W字)!

王磊

Java socket Netty

实时媒体AI,打破内容创作天花板,加速视频创新

华为云开发者社区

人工智能 云原生 媒体 视频

智慧平安社区平台搭建方案,智慧社区综合管理系统开发

WX13823153201

智慧平安社区平台搭建

进来抄作业:分布式系统中保证高可用性的常用经验

华为云开发者社区

高可用 运维 设计 分布式系统 系统

智汇华云 | 安超OS为企业数字化转型构建坚实的云基座

华云数据

每个人都拥有这项神技能

熊斌

职场成长 28天写作

低代码开发技术

Sam678678

Redis 学习笔记 02:SDS 简单动态字符串

架构精进之路

redis 七日更 28天写作

项目管理系列(1)-如何开好一个周会

Ian哥

项目管理 28天写作

港股配资系统搭建

软件开发大鱼V15724971504

金融科技 港股交易系统开发 在线开户系统 CFD交易系统 港股多账户系统

知乎问答:“既然生命无意义,为什么要活着?”

三只猫

28天写作

【PS】给黑白照片上色

学习委员

PhotoShop ps 28天写作

十八般武艺玩转GaussDB(DWS)性能调优:路径干预

华为云开发者社区

数据库 sql 性能调优 GaussDB 算子

JFR定位线上问题实例 - JFR导致的雪崩问题定位与解决

AI乔治

Java 架构 线程

调查bug的手段有哪些?(没有调查,就没有发言权,二)Jan 13, 2021

王泰

28天写作

okhttp3 第一次使用

我就感觉到快

夜莺二次开发指南-任务执行中心

qinyening

滴滴夜莺 夜莺监控

一文学会Java死锁和CPU 100% 问题的排查技巧

AI乔治

Java 架构 死锁 cpu 100%

夜莺二次开发指南-资产设备管理

qinyening

滴滴夜莺 夜莺监控

僵尸进程的成因以及僵尸可以被“杀死”吗?

AI乔治

Java 架构 进程

微信视频号常见问题 | 视频号 28 天 (06)

赵新龙

28天写作

智能合约上链系统开发|智能合约上链APP软件开发

开發I852946OIIO

系统开发

开发复杂业务系统,有哪些设计思路

邴越

Android开发时的多点触控是如何实现的?

博文视点Broadview

一次慢查询暴露的隐蔽的问题

AI乔治

Java sql 架构 SQL优化

DNSPod与开源应用专场

DNSPod与开源应用专场

深度学习正在被滥用-InfoQ