【ArchSummit架构师峰会】探讨数据与人工智能相互驱动的关系>>> 了解详情
写点什么

如何使用半监督学习为结构化数据训练出更好的深度学习模型

  • 2020-10-22
  • 本文字数:2368 字

    阅读完需:约 8 分钟

如何使用半监督学习为结构化数据训练出更好的深度学习模型

本文最初发表于 Towards Data Science 博客,经原作者 Youness Mansar 授权,InfoQ 中文站翻译并分享。


众所周知,深度学习在应用于文本、音频或图像等非结构化数据时效果很好,但在应用于结构化或表格化数据时,深度学习有时会落后于其他机器学习方法,如梯度提升等。在本文中,我们将使用半监督学习来提高深度神经模型在低数据环境下应用于结构化数据时的性能。我们将展示通过使用无监督的预训练,可以使神经模型的性能优于梯度提升。


本文是基于以下两篇论文:



我们实现了一个类似于 AutoInt 论文中提出的深度神经结构,使用了多头自注意力和特征嵌入。预训练部分取自 TabNet 的论文。

方法说明

我们将处理结构化数据,这意味着可以将数据写成具有列(数字、分类、序号)和行的表。我们还假设我们有大量的未标记样本,可以用于预训练,以及少量的标记样本,可用于监督学习。在接下来的实验中,我们将模拟这个环境来绘制学习曲线,并在使用不同大小的标记集时对该方法进行评估。

数据准备

让我们用一个例子来描述在将数据提供给神经网络之前我们是如何准备数据的。



在这个例子中,我们有三个样本和三个特征 {F1,F2,F3} 和一个目标。F1 是分类特征,而 F2 F3 是数字特征。


我们将为 F1 的每个模态 X 创建一个新特征 F1_X,如果 F1==X,则为其赋值 1,否则等于 0。


转换后的样本将写入一组 (Feature_Name, Feature_Value)


例如:


第一个样本 → {(F1_A, 1), (F2, 0.3), (F3, 1.3)}


第二个样本 → {(F1_B, 1), (F2, 0.4), (F3, 0.9)}


第三个样本 → {(F1_C, 1), (F2, 0.1), (F3, 0.8)}


特征名称将被馈送到嵌入层,然后与特征值相乘。

模型:

这里使用的模型是一个多头注意力块序列和逐点前馈层。在训练时,我们也使用池化的注意力跳过连接。多头注意力模块允许我们对特征之间可能存在的交互进行建模,而池化的注意力跳过连接允许我们从一组特征嵌入中获得单个向量。


预训练

在预训练步骤中,我们使用完整的未标记数据集,输入特征的损坏版本,并训练模型来预测未损坏的特征,类似于在去噪自动编码器中所做的操作。

监督式训练

在训练的监督部分,我们在编码器部分和输出端之间添加跳过连接,并尝试预测目标。


实验

在接下来的实验中,我们将使用四个数据集,其中两个用于回归,两个用于分类。


  • Sarco:有大约 5 万个样本,21 个特征和 7 个连续目标。

  • Online News:有 4 万个左右的样本,61 个特征和 1 个连续目标。

  • Adult Census:有大约 4 万个样本、15 个特征和 1 个二元目标。

  • Forest Cover:有大约 50 万个样本,54 个特征和 1 个分类目标。


我们将比较一个预训练神经模型和一个从零开始训练的神经模型,将重点关注地数据状态下的性能,这意味着几百到几千个标记样本。我们还将于一个流行的名为lightgbm的梯度提升实现进行比较。

Forest Cover:

Adult Census:


对于这个数据集,我们可以看到,如果训练集小于 2000,那么预训练是非常有效的。

Online News:

对于 Online News 数据集,我们可以看到,预训练神经网络是非常有效的,甚至在所有样本大小为 500 或更大的情况下都超过了梯度提升。



对于 Sarco 数据集,我们可以看到,预训练神经网络是非常有效的,甚至在所有样本大小的情况下超过了梯度提升。


旁注:用于重现结果的代码

重现结果的代码可以在这里找到:


https://github.com/CVxTz/DeepTabular


使用这段代码,你可以很轻松地训练分类或回归模型:


import pandas as pdfrom sklearn.model_selection import train_test_splitfrom deeptabular.deeptabular import DeepTabularClassifierif __name__ == "__main__":data = pd.read_csv("../data/census/adult.csv")train, test = train_test_split(data, test_size=0.2, random_state=1337)target = "income"num_cols = ["age", "fnlwgt", "capital.gain", "capital.loss", "hours.per.week"]cat_cols = ["workclass","education","education.num","marital.status","occupation","relationship","race","sex","native.country",]for k in num_cols:mean = train[k].mean()std = train[k].std()train[k] = (train[k] - mean) / stdtest[k] = (test[k] - mean) / stdtrain[target] = train[target].map({"<=50K": 0, ">50K": 1})test[target] = test[target].map({"<=50K": 0, ">50K": 1})classifier = DeepTabularClassifier(num_layers=10, cat_cols=cat_cols, num_cols=num_cols, n_targets=1,)classifier.fit(train, target_col=target, epochs=128)pred = classifier.predict(test)classifier.save_config("census_config.json")classifier.save_weigts("census_weights.h5")new_classifier = DeepTabularClassifier()new_classifier.load_config("census_config.json")new_classifier.load_weights("census_weights.h5")new_pred = new_classifier.predict(test)
复制代码

结论

在计算机视觉或自然语言领域,无监督预训练可以提高神经网络的性能。在本文中,我们展示了它在应用于结构化数据时也能起作用,使其在低数据环境与其他机器学习方法(如梯度提升)具有竞争力。


作者简介:


Youness Mansar,供职于 Fortia Financial Solutions 的数据科学家。巴黎中央理工学院(Ecole Centrale Paris)应用数学硕士学位和巴黎-萨克雷高等师范学校(École normale supérieure Paris-Saclay)机器学习硕士。作为 Fortia 的数据科学家,曾参与过多个涉及自然语言处理和深度学习的项目。


原文链接:


https://towardsdatascience.com/training-better-deep-learning-models-for-structured-data-using-semi-supervised-learning-8acc3b536319


公众号推荐:

跳进 AI 的奇妙世界,一起探索未来工作的新风貌!想要深入了解 AI 如何成为产业创新的新引擎?好奇哪些城市正成为 AI 人才的新磁场?《中国生成式 AI 开发者洞察 2024》由 InfoQ 研究中心精心打造,为你深度解锁生成式 AI 领域的最新开发者动态。无论你是资深研发者,还是对生成式 AI 充满好奇的新手,这份报告都是你不可错过的知识宝典。欢迎大家扫码关注「AI前线」公众号,回复「开发者洞察」领取。

2020-10-22 09:001519
用户头像
刘燕 InfoQ高级技术编辑

发布了 1112 篇内容, 共 493.1 次阅读, 收获喜欢 1966 次。

关注

评论

发布
暂无评论
发现更多内容

Spring 框架使用了哪些设计模式?

Java快了!

spring框架

前端常见react面试题合集

beifeng1996

前端 React

JWT本无状态,为何却要存储在Redis破坏其无状态特性?

知识浅谈

JWT 9月月更

你知道数据资产管理的目标是什么?

雨果

数据中台 数据资产管理

大数据ELK(二):Elasticsearch简单介绍

Lansonli

elasticsearch 9月月更

MFC模拟消息发送,自定义以及系统消息

中国好公民st

c++ 消息分发 9月月更

TCPIP协议栈的心跳、丢包重传、连接超时机制实例详解

Java快了!

网易易盾 GameSentry 正式开源,做游戏安全保障的尖兵利刃

网易智企

安全 测试

心血来潮,手绘一张Spring学习思维,内容详细全面,秋招面试必看!

收到请回复

Java 云计算 开源 架构 编程语言

什么是访问控制列表ACL?

wljslmz

acl 访问控制列表 9月月更

谁能说清楚数据资产管理与数据治理是什么关系?

雨果

数据治理

为超级品牌打造「上瘾算法」|Whale 帷幄发布全新 DAM & VAP 内容数字化产品

科技热闻

PANews与NFTScan联合推出Top50 NFT Collection全球影响力榜单

NFT Research

Ethereum NFT

SAP ABAP 平台新的编程模型

Jerry Wang

SAP abap Netweaver 思爱普 9月月更

怎样才能开一场高效的迭代评审会?

LigaAI

Scrum 迭代 LigaAI 敏捷实践 企业号九月金秋榜

前端经典面试题(有答案)

loveX001

JavaScript 前端

20道高频react面试题(附答案)

beifeng1996

前端 React

华为云宣布全面建设全球初创生态,3年内赋能10000家高潜初创企业

华为云开发者联盟

云计算 创业 创新创业 企业号九月金秋榜

[极致用户体验] 让你的网页,适配微信大字号模式!体验超好,快来收藏

HullQin

CSS JavaScript html 前端 9月月更

现代数据栈如何降低数据平台的复杂度?

Kyligence

数据分析 云原生 指标中台 指标自动化

Java进阶(二十一)java 空字符串与null区别

No Silver Bullet

Java null 9月月更 空字符串

【HTML-CSS】小游戏--渣灰哥的愿望之砍砍渣灰

Sam9029

JavaScript HTML5, CSS3 9月月更

“基础-中级-高级”Java程序员面试合集,看完献出我的膝盖!

收到请回复

Java 云计算 开源 架构 编程语言

2022前端经典vue面试题(持续更新中)

bb_xiaxia1998

Vue 前端

数据、管理、分析和运营:大数据专家面临的四大挑战!

雨果

大数据

一线架构师开发总结:剖析并发编程+JVM性能,深入Tomcat与MySQL!

收到请回复

Java 云计算 开源 架构 编程语言

Github星标90K!京东架构师一篇讲明白百亿级并发系统架构设计

了不起的程序猿

Java 程序员 高并发 java程序员 高并发系统设计

3D打印机打印模型的10大技巧

Dylan

3D模型

Scrum 实施过程的主要内容及5大常用工具

PingCode

谁来说说数据质量评估的标准是什么?

雨果

数据质量

2022前端二面必会vue面试题汇总

bb_xiaxia1998

Vue 前端

如何使用半监督学习为结构化数据训练出更好的深度学习模型_AI&大模型_Youness Mansar_InfoQ精选文章