NVIDIA 初创加速计划,免费加速您的创业启动 了解详情
写点什么

如何使用半监督学习为结构化数据训练出更好的深度学习模型

  • 2020-10-22
  • 本文字数:2368 字

    阅读完需:约 8 分钟

如何使用半监督学习为结构化数据训练出更好的深度学习模型

本文最初发表于 Towards Data Science 博客,经原作者 Youness Mansar 授权,InfoQ 中文站翻译并分享。


众所周知,深度学习在应用于文本、音频或图像等非结构化数据时效果很好,但在应用于结构化或表格化数据时,深度学习有时会落后于其他机器学习方法,如梯度提升等。在本文中,我们将使用半监督学习来提高深度神经模型在低数据环境下应用于结构化数据时的性能。我们将展示通过使用无监督的预训练,可以使神经模型的性能优于梯度提升。


本文是基于以下两篇论文:



我们实现了一个类似于 AutoInt 论文中提出的深度神经结构,使用了多头自注意力和特征嵌入。预训练部分取自 TabNet 的论文。

方法说明

我们将处理结构化数据,这意味着可以将数据写成具有列(数字、分类、序号)和行的表。我们还假设我们有大量的未标记样本,可以用于预训练,以及少量的标记样本,可用于监督学习。在接下来的实验中,我们将模拟这个环境来绘制学习曲线,并在使用不同大小的标记集时对该方法进行评估。

数据准备

让我们用一个例子来描述在将数据提供给神经网络之前我们是如何准备数据的。



在这个例子中,我们有三个样本和三个特征 {F1,F2,F3} 和一个目标。F1 是分类特征,而 F2 F3 是数字特征。


我们将为 F1 的每个模态 X 创建一个新特征 F1_X,如果 F1==X,则为其赋值 1,否则等于 0。


转换后的样本将写入一组 (Feature_Name, Feature_Value)


例如:


第一个样本 → {(F1_A, 1), (F2, 0.3), (F3, 1.3)}


第二个样本 → {(F1_B, 1), (F2, 0.4), (F3, 0.9)}


第三个样本 → {(F1_C, 1), (F2, 0.1), (F3, 0.8)}


特征名称将被馈送到嵌入层,然后与特征值相乘。

模型:

这里使用的模型是一个多头注意力块序列和逐点前馈层。在训练时,我们也使用池化的注意力跳过连接。多头注意力模块允许我们对特征之间可能存在的交互进行建模,而池化的注意力跳过连接允许我们从一组特征嵌入中获得单个向量。


预训练

在预训练步骤中,我们使用完整的未标记数据集,输入特征的损坏版本,并训练模型来预测未损坏的特征,类似于在去噪自动编码器中所做的操作。

监督式训练

在训练的监督部分,我们在编码器部分和输出端之间添加跳过连接,并尝试预测目标。


实验

在接下来的实验中,我们将使用四个数据集,其中两个用于回归,两个用于分类。


  • Sarco:有大约 5 万个样本,21 个特征和 7 个连续目标。

  • Online News:有 4 万个左右的样本,61 个特征和 1 个连续目标。

  • Adult Census:有大约 4 万个样本、15 个特征和 1 个二元目标。

  • Forest Cover:有大约 50 万个样本,54 个特征和 1 个分类目标。


我们将比较一个预训练神经模型和一个从零开始训练的神经模型,将重点关注地数据状态下的性能,这意味着几百到几千个标记样本。我们还将于一个流行的名为lightgbm的梯度提升实现进行比较。

Forest Cover:

Adult Census:


对于这个数据集,我们可以看到,如果训练集小于 2000,那么预训练是非常有效的。

Online News:

对于 Online News 数据集,我们可以看到,预训练神经网络是非常有效的,甚至在所有样本大小为 500 或更大的情况下都超过了梯度提升。



对于 Sarco 数据集,我们可以看到,预训练神经网络是非常有效的,甚至在所有样本大小的情况下超过了梯度提升。


旁注:用于重现结果的代码

重现结果的代码可以在这里找到:


https://github.com/CVxTz/DeepTabular


使用这段代码,你可以很轻松地训练分类或回归模型:


import pandas as pdfrom sklearn.model_selection import train_test_splitfrom deeptabular.deeptabular import DeepTabularClassifierif __name__ == "__main__":data = pd.read_csv("../data/census/adult.csv")train, test = train_test_split(data, test_size=0.2, random_state=1337)target = "income"num_cols = ["age", "fnlwgt", "capital.gain", "capital.loss", "hours.per.week"]cat_cols = ["workclass","education","education.num","marital.status","occupation","relationship","race","sex","native.country",]for k in num_cols:mean = train[k].mean()std = train[k].std()train[k] = (train[k] - mean) / stdtest[k] = (test[k] - mean) / stdtrain[target] = train[target].map({"<=50K": 0, ">50K": 1})test[target] = test[target].map({"<=50K": 0, ">50K": 1})classifier = DeepTabularClassifier(num_layers=10, cat_cols=cat_cols, num_cols=num_cols, n_targets=1,)classifier.fit(train, target_col=target, epochs=128)pred = classifier.predict(test)classifier.save_config("census_config.json")classifier.save_weigts("census_weights.h5")new_classifier = DeepTabularClassifier()new_classifier.load_config("census_config.json")new_classifier.load_weights("census_weights.h5")new_pred = new_classifier.predict(test)
复制代码

结论

在计算机视觉或自然语言领域,无监督预训练可以提高神经网络的性能。在本文中,我们展示了它在应用于结构化数据时也能起作用,使其在低数据环境与其他机器学习方法(如梯度提升)具有竞争力。


作者简介:


Youness Mansar,供职于 Fortia Financial Solutions 的数据科学家。巴黎中央理工学院(Ecole Centrale Paris)应用数学硕士学位和巴黎-萨克雷高等师范学校(École normale supérieure Paris-Saclay)机器学习硕士。作为 Fortia 的数据科学家,曾参与过多个涉及自然语言处理和深度学习的项目。


原文链接:


https://towardsdatascience.com/training-better-deep-learning-models-for-structured-data-using-semi-supervised-learning-8acc3b536319


公众号推荐:

跳进 AI 的奇妙世界,一起探索未来工作的新风貌!想要深入了解 AI 如何成为产业创新的新引擎?好奇哪些城市正成为 AI 人才的新磁场?《中国生成式 AI 开发者洞察 2024》由 InfoQ 研究中心精心打造,为你深度解锁生成式 AI 领域的最新开发者动态。无论你是资深研发者,还是对生成式 AI 充满好奇的新手,这份报告都是你不可错过的知识宝典。欢迎大家扫码关注「AI前线」公众号,回复「开发者洞察」领取。

2020-10-22 09:001529
用户头像
刘燕 InfoQ高级技术编辑

发布了 1112 篇内容, 共 494.0 次阅读, 收获喜欢 1967 次。

关注

评论

发布
暂无评论
发现更多内容

创新监管首批8个试点应用公示 其中7个涉及区块链

CECBC

Netty-物联网设备Channel管理

凸出

Java Netty ConcurrentHashMap 物联网 channel

第7周总结:性能

慵秋

百度大脑OCR技术助力钢铁物流实现智能管理

百度大脑

人工智能 百度大脑 文字识别

数字货币并不能完美诠释区块链金融

CECBC

区块链技术 社会价值 打通数据孤岛 重建产业信用

Self-Compassion,对自己好一点

霍太稳@极客邦科技

创业 个人成长 自我管理 创业心态

架构感悟 7- 性能优化何为

旭东(Frank)

关于数据库索引的知识点,你所需要了解的都在这儿了

鄙人薛某

MySQL 索引结构 索引 MySQL优化

Java如何调用Python(二)

wjchenge

程序设计理念-CentOs7实践Nginx-带来安装服务的通用法则

图南日晟

nginx 架构设计 环境安装

web压力性能测试

周冬辉

压力测试

为啥Underlay才是容器网络的最佳落地选择

BoCloud博云

云计算 容器

进击的 Flink:网易云音乐实时数仓建设实践

Apache Flink

flink

超详细讲解网络中的数据链路层~

程序员的时光

技术​选型的艺术

YourBatman

技术选型 湖北

Kubernetes的拐点助推器:左手开源,右手边缘计算

华为云开发者联盟

Kubernetes 容器 边缘计算 华为云

创业使人成长系列 (4)- 常用账号申请

石云升

支付宝 微信商户 商标

Django Models随机获取指定数量数据方法

BigYoung

django 数据 random 随机 Models

漫画:如何证明sleep不释放锁,而wait释放锁?

王磊

Java Wait Sleep

API网关——Kong实践分享

BoCloud博云

云计算 容器 PaaS API

PV与UV你的网站也可以

北漂码农有话说

挑战10的1,143,913次方种算法组合:这都不是事儿!

华为云开发者联盟

华为 算法 进化 华为云

为什么我们要自主开发一个稳定可靠的容器网络

BoCloud博云

云计算 容器 PaaS fabric

BIGO海量小文件存储实践

InfoQ_3597a20b53cc

【数据结构】Java 常用集合类 ConcurrentHashMap(JDK 1.8)

Alex🐒

Java 源码 数据结构 并发编程

【数据结构】Java 常用集合类 HashMap(JDK 1.8)

Alex🐒

Java 源码 数据结构

计算机网络基础(八)---网络层-路由概述

书旅

计算机网络 网络协议 计算机基础 AS

web 性能压测工具

Z冰红茶

原生Ingress灰度发布能力不够?我们是这么干的

BoCloud博云

云计算 容器 云原生 PaaS

随着并发压力的增加,系统响应时间和吞吐量如何变化,为什么?

chenzt

前浪出新招,996已过时,互联网员工都开始住公司了!(爆公司信息)

程序员生活志

加班 996 007 互联网公司

如何使用半监督学习为结构化数据训练出更好的深度学习模型_AI&大模型_Youness Mansar_InfoQ精选文章