写点什么

如何使用半监督学习为结构化数据训练出更好的深度学习模型

  • 2020-10-22
  • 本文字数:2368 字

    阅读完需:约 8 分钟

如何使用半监督学习为结构化数据训练出更好的深度学习模型

本文最初发表于 Towards Data Science 博客,经原作者 Youness Mansar 授权,InfoQ 中文站翻译并分享。


众所周知,深度学习在应用于文本、音频或图像等非结构化数据时效果很好,但在应用于结构化或表格化数据时,深度学习有时会落后于其他机器学习方法,如梯度提升等。在本文中,我们将使用半监督学习来提高深度神经模型在低数据环境下应用于结构化数据时的性能。我们将展示通过使用无监督的预训练,可以使神经模型的性能优于梯度提升。


本文是基于以下两篇论文:



我们实现了一个类似于 AutoInt 论文中提出的深度神经结构,使用了多头自注意力和特征嵌入。预训练部分取自 TabNet 的论文。

方法说明

我们将处理结构化数据,这意味着可以将数据写成具有列(数字、分类、序号)和行的表。我们还假设我们有大量的未标记样本,可以用于预训练,以及少量的标记样本,可用于监督学习。在接下来的实验中,我们将模拟这个环境来绘制学习曲线,并在使用不同大小的标记集时对该方法进行评估。

数据准备

让我们用一个例子来描述在将数据提供给神经网络之前我们是如何准备数据的。



在这个例子中,我们有三个样本和三个特征 {F1,F2,F3} 和一个目标。F1 是分类特征,而 F2 F3 是数字特征。


我们将为 F1 的每个模态 X 创建一个新特征 F1_X,如果 F1==X,则为其赋值 1,否则等于 0。


转换后的样本将写入一组 (Feature_Name, Feature_Value)


例如:


第一个样本 → {(F1_A, 1), (F2, 0.3), (F3, 1.3)}


第二个样本 → {(F1_B, 1), (F2, 0.4), (F3, 0.9)}


第三个样本 → {(F1_C, 1), (F2, 0.1), (F3, 0.8)}


特征名称将被馈送到嵌入层,然后与特征值相乘。

模型:

这里使用的模型是一个多头注意力块序列和逐点前馈层。在训练时,我们也使用池化的注意力跳过连接。多头注意力模块允许我们对特征之间可能存在的交互进行建模,而池化的注意力跳过连接允许我们从一组特征嵌入中获得单个向量。


预训练

在预训练步骤中,我们使用完整的未标记数据集,输入特征的损坏版本,并训练模型来预测未损坏的特征,类似于在去噪自动编码器中所做的操作。

监督式训练

在训练的监督部分,我们在编码器部分和输出端之间添加跳过连接,并尝试预测目标。


实验

在接下来的实验中,我们将使用四个数据集,其中两个用于回归,两个用于分类。


  • Sarco:有大约 5 万个样本,21 个特征和 7 个连续目标。

  • Online News:有 4 万个左右的样本,61 个特征和 1 个连续目标。

  • Adult Census:有大约 4 万个样本、15 个特征和 1 个二元目标。

  • Forest Cover:有大约 50 万个样本,54 个特征和 1 个分类目标。


我们将比较一个预训练神经模型和一个从零开始训练的神经模型,将重点关注地数据状态下的性能,这意味着几百到几千个标记样本。我们还将于一个流行的名为lightgbm的梯度提升实现进行比较。

Forest Cover:

Adult Census:


对于这个数据集,我们可以看到,如果训练集小于 2000,那么预训练是非常有效的。

Online News:

对于 Online News 数据集,我们可以看到,预训练神经网络是非常有效的,甚至在所有样本大小为 500 或更大的情况下都超过了梯度提升。



对于 Sarco 数据集,我们可以看到,预训练神经网络是非常有效的,甚至在所有样本大小的情况下超过了梯度提升。


旁注:用于重现结果的代码

重现结果的代码可以在这里找到:


https://github.com/CVxTz/DeepTabular


使用这段代码,你可以很轻松地训练分类或回归模型:


import pandas as pdfrom sklearn.model_selection import train_test_splitfrom deeptabular.deeptabular import DeepTabularClassifierif __name__ == "__main__":data = pd.read_csv("../data/census/adult.csv")train, test = train_test_split(data, test_size=0.2, random_state=1337)target = "income"num_cols = ["age", "fnlwgt", "capital.gain", "capital.loss", "hours.per.week"]cat_cols = ["workclass","education","education.num","marital.status","occupation","relationship","race","sex","native.country",]for k in num_cols:mean = train[k].mean()std = train[k].std()train[k] = (train[k] - mean) / stdtest[k] = (test[k] - mean) / stdtrain[target] = train[target].map({"<=50K": 0, ">50K": 1})test[target] = test[target].map({"<=50K": 0, ">50K": 1})classifier = DeepTabularClassifier(num_layers=10, cat_cols=cat_cols, num_cols=num_cols, n_targets=1,)classifier.fit(train, target_col=target, epochs=128)pred = classifier.predict(test)classifier.save_config("census_config.json")classifier.save_weigts("census_weights.h5")new_classifier = DeepTabularClassifier()new_classifier.load_config("census_config.json")new_classifier.load_weights("census_weights.h5")new_pred = new_classifier.predict(test)
复制代码

结论

在计算机视觉或自然语言领域,无监督预训练可以提高神经网络的性能。在本文中,我们展示了它在应用于结构化数据时也能起作用,使其在低数据环境与其他机器学习方法(如梯度提升)具有竞争力。


作者简介:


Youness Mansar,供职于 Fortia Financial Solutions 的数据科学家。巴黎中央理工学院(Ecole Centrale Paris)应用数学硕士学位和巴黎-萨克雷高等师范学校(École normale supérieure Paris-Saclay)机器学习硕士。作为 Fortia 的数据科学家,曾参与过多个涉及自然语言处理和深度学习的项目。


原文链接:


https://towardsdatascience.com/training-better-deep-learning-models-for-structured-data-using-semi-supervised-learning-8acc3b536319


2020-10-22 09:002466
用户头像
刘燕 InfoQ高级技术编辑

发布了 1112 篇内容, 共 575.9 次阅读, 收获喜欢 1981 次。

关注

评论

发布
暂无评论
发现更多内容

算法大赛报名 | OMG!这些名企的真实数据竟用来battle

工赋开发者社区

算法 工业互联网

再谈BOM和DOM(4):DOM0/DOM2事件处理分析

zhoulujun

DOM DOM事件 DOM0 DOM2

三年开发经验,字节跳动抖音组离职后,一口气拿到15家公司Offer

Java架构师迁哥

再谈BOM和DOM(2):DOM节点层次/属性/选择器/节点关系/操作详解

zhoulujun

JavaScript DOM BOM 对象模型 文档模型

Vue进阶(幺叁捌):vue路由传参的几种基本方式

No Silver Bullet

Vue 路由 7月日更

7.24 杭州站 | 阿里云 Serverless Developer Meetup 开放报名!

Serverless Devs

云计算 阿里云 Serverless 云原生

字节取消“大小周”,管理者与员工的“灵魂争夺战"从未停歇

架构实战营模块八作业

竹林七贤

5分钟速读之Rust权威指南(四十一)高级类型

wzx

rust

医美行业哪个环节最赚钱?

石云升

行业分析 7月日更

【redis前传】自己手写一个LRU策略

zxhtom

Java redis 原理 造轮子 jdk运用

JVM锁bug导致G1 GC挂起问题分析和解决

毕昇JDK社区

没想到我也可以入职阿里!二本毕业、两年crud经验,侥幸通过面试定级P6

Java架构师迁哥

再谈BOM和DOM(3):DOM节点操作-元素样式修改及DOM内容增删改查

zhoulujun

DOM BOM 文档对象 DOM结点操作 DOM增删改查

再谈BOM和DOM(5):各个大流浪器DOM和BOM里面的那些坑—兼容性

zhoulujun

DOM事件兼容性

我乃平常客,本持平常心| 2021 年中总结

编程三昧

程序人生 大前端 代码人生

hadoop 1.0 和 hadoop 2.0 的区别

五分钟学大数据

hadoop 7月日更

架构实战营1期第二模块作业

五只羊

架构实战营

Ta想做一粒智慧的种子

脑极体

神来之笔,2021CTF内核漏洞精选解析

网络安全学海

网络安全 信息安全 CTF 安全漏洞 渗透测试·

重温历史 致敬百年 “复兴大道100号”线上VR展馆正式开馆

百度大脑

百度 虚拟现实

性能测试软启动初探

FunTester

性能测试 接口测试 测试框架 压力测试 测试开发

攒塑料袋,究竟是如何刻进中国人DNA的?

脑极体

火爆 GitHub!这个图像分割神器开源了

百度大脑

百度 算法

物联网安全难题还需行业标杆来解

熵核科技

物联网安全

拥抱云原生,腾讯发布TCSS容器安全服务!

腾讯安全云鼎实验室

容器 云原生

保洁阿姨分享:腾讯架构师JDK源码笔记,13万字,带你飙向实战

Java架构师迁哥

再谈BOM和DOM(1):BOM与DOM概述

zhoulujun

JavaScript DOM BOM 对象模型 文档模型

数据仓库的基本概念

大数据技术指南

7月日更

fil矿机怎么选择?用什么fil矿机比较好?

FIL矿机怎么买 fil挖矿

熵核科技,自主研发虚拟机赋能安全操作系统

熵核科技

支付安全 安全操作系统 物联网安全 eSIM安全

如何使用半监督学习为结构化数据训练出更好的深度学习模型_AI&大模型_Youness Mansar_InfoQ精选文章