写点什么

如何使用半监督学习为结构化数据训练出更好的深度学习模型

2020 年 10 月 22 日

如何使用半监督学习为结构化数据训练出更好的深度学习模型

本文最初发表于 Towards Data Science 博客,经原作者 Youness Mansar 授权,InfoQ 中文站翻译并分享。


众所周知,深度学习在应用于文本、音频或图像等非结构化数据时效果很好,但在应用于结构化或表格化数据时,深度学习有时会落后于其他机器学习方法,如梯度提升等。在本文中,我们将使用半监督学习来提高深度神经模型在低数据环境下应用于结构化数据时的性能。我们将展示通过使用无监督的预训练,可以使神经模型的性能优于梯度提升。


本文是基于以下两篇论文:



我们实现了一个类似于 AutoInt 论文中提出的深度神经结构,使用了多头自注意力和特征嵌入。预训练部分取自 TabNet 的论文。


方法说明

我们将处理结构化数据,这意味着可以将数据写成具有列(数字、分类、序号)和行的表。我们还假设我们有大量的未标记样本,可以用于预训练,以及少量的标记样本,可用于监督学习。在接下来的实验中,我们将模拟这个环境来绘制学习曲线,并在使用不同大小的标记集时对该方法进行评估。


数据准备

让我们用一个例子来描述在将数据提供给神经网络之前我们是如何准备数据的。



在这个例子中,我们有三个样本和三个特征 {F1,F2,F3} 和一个目标。F1 是分类特征,而 F2 F3 是数字特征。


我们将为 F1 的每个模态 X 创建一个新特征 F1_X,如果 F1==X,则为其赋值 1,否则等于 0。


转换后的样本将写入一组 (Feature_Name, Feature_Value)


例如:


第一个样本 → {(F1_A, 1), (F2, 0.3), (F3, 1.3)}


第二个样本 → {(F1_B, 1), (F2, 0.4), (F3, 0.9)}


第三个样本 → {(F1_C, 1), (F2, 0.1), (F3, 0.8)}


特征名称将被馈送到嵌入层,然后与特征值相乘。


模型:

这里使用的模型是一个多头注意力块序列和逐点前馈层。在训练时,我们也使用池化的注意力跳过连接。多头注意力模块允许我们对特征之间可能存在的交互进行建模,而池化的注意力跳过连接允许我们从一组特征嵌入中获得单个向量。



预训练

在预训练步骤中,我们使用完整的未标记数据集,输入特征的损坏版本,并训练模型来预测未损坏的特征,类似于在去噪自动编码器中所做的操作。


监督式训练

在训练的监督部分,我们在编码器部分和输出端之间添加跳过连接,并尝试预测目标。



实验

在接下来的实验中,我们将使用四个数据集,其中两个用于回归,两个用于分类。


  • Sarco:有大约 5 万个样本,21 个特征和 7 个连续目标。

  • Online News:有 4 万个左右的样本,61 个特征和 1 个连续目标。

  • Adult Census:有大约 4 万个样本、15 个特征和 1 个二元目标。

  • Forest Cover:有大约 50 万个样本,54 个特征和 1 个分类目标。


我们将比较一个预训练神经模型和一个从零开始训练的神经模型,将重点关注地数据状态下的性能,这意味着几百到几千个标记样本。我们还将于一个流行的名为lightgbm的梯度提升实现进行比较。


Forest Cover:


Adult Census:


对于这个数据集,我们可以看到,如果训练集小于 2000,那么预训练是非常有效的。


Online News:

对于 Online News 数据集,我们可以看到,预训练神经网络是非常有效的,甚至在所有样本大小为 500 或更大的情况下都超过了梯度提升。



对于 Sarco 数据集,我们可以看到,预训练神经网络是非常有效的,甚至在所有样本大小的情况下超过了梯度提升。



旁注:用于重现结果的代码

重现结果的代码可以在这里找到:


https://github.com/CVxTz/DeepTabular


使用这段代码,你可以很轻松地训练分类或回归模型:


import pandas as pdfrom sklearn.model_selection import train_test_splitfrom deeptabular.deeptabular import DeepTabularClassifierif __name__ == "__main__":data = pd.read_csv("../data/census/adult.csv")train, test = train_test_split(data, test_size=0.2, random_state=1337)target = "income"num_cols = ["age", "fnlwgt", "capital.gain", "capital.loss", "hours.per.week"]cat_cols = ["workclass","education","education.num","marital.status","occupation","relationship","race","sex","native.country",]for k in num_cols:mean = train[k].mean()std = train[k].std()train[k] = (train[k] - mean) / stdtest[k] = (test[k] - mean) / stdtrain[target] = train[target].map({"<=50K": 0, ">50K": 1})test[target] = test[target].map({"<=50K": 0, ">50K": 1})classifier = DeepTabularClassifier(num_layers=10, cat_cols=cat_cols, num_cols=num_cols, n_targets=1,)classifier.fit(train, target_col=target, epochs=128)pred = classifier.predict(test)classifier.save_config("census_config.json")classifier.save_weigts("census_weights.h5")new_classifier = DeepTabularClassifier()new_classifier.load_config("census_config.json")new_classifier.load_weights("census_weights.h5")new_pred = new_classifier.predict(test)
复制代码


结论

在计算机视觉或自然语言领域,无监督预训练可以提高神经网络的性能。在本文中,我们展示了它在应用于结构化数据时也能起作用,使其在低数据环境与其他机器学习方法(如梯度提升)具有竞争力。


作者简介:


Youness Mansar,供职于 Fortia Financial Solutions 的数据科学家。巴黎中央理工学院(Ecole Centrale Paris)应用数学硕士学位和巴黎-萨克雷高等师范学校(École normale supérieure Paris-Saclay)机器学习硕士。作为 Fortia 的数据科学家,曾参与过多个涉及自然语言处理和深度学习的项目。


原文链接:


https://towardsdatascience.com/training-better-deep-learning-models-for-structured-data-using-semi-supervised-learning-8acc3b536319


2020 年 10 月 22 日 09:00895
用户头像
刘燕 InfoQ记者

发布了 693 篇内容, 共 222.9 次阅读, 收获喜欢 1336 次。

关注

评论

发布
暂无评论
发现更多内容

金九银十跳槽必备大厂Java面试宝典,近千道面试题加详解,刷完你也进大厂

Java技术那些事

Java 程序员 面试 计算机 8月日更

MySQL优化-批量插入与1亿条数据效率COUNT

一个大红包

8月日更

面试字节跳动java岗被算法吊打,60天苦修这些笔记,侥幸收获offer

Java~~~

Java 架构 面试 算法 红黑树

crudapi增删改查接口零代码产品成功案例之金茶王投票系统

crudapi

Vue API crud crudapi 投票

宅家复习一个月,成功入职腾讯,才知道算法实在太太太重要了

Crud的程序员

面试 算法 数据结构与算法

模块五设计微博评论高性能高可用计算架构

kitten

校招失败,在小公司熬了2年后我终于进了字节跳动,竭尽全力(Java岗)

今晚早点睡

Java 编程 字节跳动 面试 计算机

太香了!阿里高工携18位架构师耗时57天整合的1658页金九银十面试押题宝典全新开源

程序员小毕

Java spring 程序员 架构 面试

终于学完国内算法第一人10年经验总结的数据结构与算法详解文档

公众号_愿天堂没有BUG

Java 编程 程序员 架构 面试

linux 工具之pstack/gstack

糖米唐爹

RunC TOCTOU逃逸CVE-2021-30465分析

腾讯安全云鼎实验室

容器安全 漏洞分析

阿里 Java 二面必问:8张图带你解决所有TCP可靠传输问题

云流

编程 面试 TCP 网络 计算机

云原生,开发者的黄金时代

阿里巴巴云原生

云计算 阿里云 云原生 中间件

Vue进阶(四十二):var、let、const 三者区别

No Silver Bullet

Vue var const let 8月日更

架构实战营模块 5 作业指导

华仔

#架构实战营

【我和达梦的故事】 有奖征文活动开始啦,万元奖品池+现金奖励等你拿!

墨天轮

数据库 征文大赛 国产数据库 达梦

0代码之缘

明道云

三面头条,靠P9级算法大牛分享的两本算法pdf书籍,轻松拿到offer

Crud的程序员

程序员 算法 编程语言 数据结构与算法

架构实战营1期模块5作业——高性能计算架构

tt

架构实战营

云原生,开发者的黄金时代

阿里巴巴中间件

云计算 阿里云 云原生 中间件

谷歌架构师分享gRPC与云原生应用开发Go和Java为例文档

公众号_愿天堂没有BUG

Java 编程 程序员 架构 面试

终于拿到了深入Java虚拟机:JVMG1GC的算法与实现文档

公众号_愿天堂没有BUG

Java 编程 程序员 架构 面试

Nebula Operator 云上实践

Nebula Graph

阿里云 云原生 k8s 图数据库 分布式图数据库

面面俱到!腾讯大牛把源码分析、基础案例、实战案例、面试、系统架构,全部总结到这份Java多线程与高并发里面了

云流

Java 编程 程序员 面试 多线程

准备两个月,面试五分钟,Java中高级岗面试为何越来越难?

程序员改bug

Java spring 程序员 架构 编程语言

深入分析JavaScript模块循环引用

普普通通程序员

在Alibaba广受喜爱的“Java突击宝典”简直太牛了

Crud的程序员

Java 编程 架构 编程语言

不会被开除吧?一顿饭换来“字节”面试题库Java岗,刷完直接入职大厂

Java架构师迁哥

我们可能是被工具耽误的一代

非著名程序员

深度思考 认知提升 成长笔记 8月日更

蚂蚁金服+拼多多+抖音+天猫(技术三面)面经合集助你拿大厂offer

程序员改bug

Java 程序员 面试 编程语言

教你快速从SQL过度到Elasticsearch的DSL查询

Java技术那些事

Java 编程 程序员 计算机 8月日更

如何使用半监督学习为结构化数据训练出更好的深度学习模型-InfoQ