写点什么

详解 Google 多任务学习模型 MMoE ( KDD 2018 )

2019 年 6 月 26 日

详解 Google 多任务学习模型 MMoE ( KDD 2018 )

文章发表在 KDD 2018 Research Track 上,Modeling Task Relationships in Multi-task Learning with Multi-gate Mixture-of-Experts 。地址:


https://www.kdd.org/kdd2018/accepted-papers/view/modeling-task-relationships-in-multi-task-learning-with-multi-gate-mixture-


在工业界基于神经网络的多任务学习在推荐等场景业务应用广泛,比如在推荐系统中对用户推荐物品时,不仅要推荐用户感兴趣的物品,还要尽可能地促进转化和购买,因此要对用户评分和购买两种目标同时建模。阿里之前提出的 ESSM 模型属于同时对点击率和转换率进行建模,提出的模型是典型的 shared-bottom 结构。多任务学习中有个问题就是如果子任务差异很大,往往导致多任务模型效果不佳。今天要介绍的这篇文章是谷歌的一个内容推荐团队考虑了多任务之间的区别提出了 MMoE 模型,并取得了不错的效果。


一、Motivation

多任务模型通过学习不同任务的联系和差异,可提高每个任务的学习效率和质量。多任务学习的的框架广泛采用 shared-bottom 的结构,不同任务间共用底部的隐层。这种结构本质上可以减少过拟合的风险,但是效果上可能受到任务差异和数据分布带来的影响。也有一些其他结构,比如两个任务的参数不共用,但是通过对不同任务的参数增加 L2 范数的限制;也有一些对每个任务分别学习一套隐层然后学习所有隐层的组合。和 shared-bottom 结构相比,这些模型对增加了针对任务的特定参数,在任务差异会影响公共参数的情况下对最终效果有提升。缺点就是模型增加了参数量所以需要更大的数据量来训练模型,而且模型更复杂并不利于在真实生产环境中实际部署使用。


因此,论文中提出了一个 Multi-gate Mixture-of-Experts ( MMoE ) 的多任务学习结构。MMoE 模型刻画了任务相关性,基于共享表示来学习特定任务的函数,避免了明显增加参数的缺点。


二、模型介绍

MMoE 模型的结构 ( 下图 c ) 基于广泛使用的 Shared-Bottom 结构 ( 下图 a ) 和 MoE 结构,其中图 ( b ) 是图 ( c ) 的一种特殊情况,下面依次介绍。



  • Shared-Bottom Multi-task Model


如上图 a 所示,shared-bottom 网络 ( 表示为函数 f ) 位于底部,多个任务共用这一层。往上,K 个子任务分别对应一个 tower network ( 表示为 ) ,每个子任务的输出


  • Mixture-of-Experts


MoE 模型可以形式化表示为:



其中 是 n 个 expert network ( expert network 可认为是一个神经网络 ) 。g 是组合 experts 结果的 gating network ,具体来说 g 产生 n 个 experts 上的概率分布,最终的输出是所有 experts 的带权加和。显然,MoE 可看做基于多个独立模型的集成方法。这里注意 MoE 并不对应上图中的 b 部分。


后面有些文章将 MoE 作为一个基本的组成单元,将多个 MoE 结构堆叠在一个大网络中。比如一个 MoE 层可以接受上一层 MoE 层的输出作为输入,其输出作为下一层的输入使用。


  • Multi-gate Mixture-of-Experts


文章提出的模型 ( 简称 MMoE ) 目的就是相对于 shared-bottom 结构不明显增加模型参数的要求下捕捉任务的不同。其核心思想是将 shared-bottom 网络中的函数 f 替换成 MoE 层,如上图 c 所示,形式化表达为:



其中 ,输入就是 input feature ,输出是所有 experts 上的权重。


一方面,因为 gating networks 通常是轻量级的,而且 expert networks 是所有任务共用,所以相对于论文中提到的一些 baseline 方法在计算量和参数量上具有优势。


另一方面,相对于所有任务公共一个门控网络 ( One-gate MoE model ,如上图 b ) ,这里 MMoE ( 上图 c ) 中每个任务使用单独的 gating networks 。每个任务的 gating networks 通过最终输出权重不同实现对 experts 的选择性利用。不同任务的 gating networks 可以学习到不同的组合 experts 的模式,因此模型考虑到了捕捉到任务的相关性和区别。


三、总结

整体来看,这篇文章是对多任务学习的一个扩展,通过门控网络的机制来平衡多任务的做法在真实业务场景中具有借鉴意义。下面补充介绍文中的一个数据集设置的做法和实验结果中对不同模型的相互对比分析。


  • 人工构造数据集


在真实数据集中我们无法改变任务之间的相关性,所以不太方便进行研究任务相关性对多任务模型的影响。轮文中人工构建了两个回归任务的数据集,然后通过两个任务的标签的 Pearson 相关系数来作为任务相关性的度量。在工业界中通过人工构造的数据集来验证自己的假设是个有意思的做法。


  • 模型的可训练性


模型的可训练性,就是模型对于超参数和初始化是否足够鲁棒。作者在人工合成数据集上进行了实验,观察不同随机种子和模型初始化方法对 loss 的影响。这里简单介绍下两个现象:第一,Shared-Bottom models 的效果方差要明显大于基于 MoE 的方法,说明 Shared-Bottom 模型有很多偏差的局部最小点;第二,如果任务相关度非常高,则 OMoE 和 MMoE 的效果近似,但是如果任务相关度很低,则 OMoE 的效果相对于 MMoE 明显下降,说明 MMoE 中的 multi-gate 的结构对于任务差异带来的冲突有一定的缓解作用。


作者介绍


杨镒铭,阿里妈妈算法专家。硕士毕业于中国科学技术大学,记录广告、推荐等方面的模型积累 @知乎专栏作者。


本文来自 DataFun 社区


原文链接


https://mp.weixin.qq.com/s/2Rc6W82Iy6rTyWa14Yf9Gg


2019 年 6 月 26 日 08:009223

评论

发布
暂无评论
发现更多内容

八、Kubernetes 入门实践

悟尘

Docker Kubernetes 容器 k8s Compose

使用Typora + PicGo 图床 + jsDelivr CDN实现高效 Markdown 创作

悟尘

Typora PicGo iPic jsDelivr CDN

长假将至,推荐两个好东西

池建强

算法 视觉笔记

告诉你一个学习编程的诀窍(建议收藏)

ithuangqing

学习 编程 自学编程

程序员到底应该学习什么语言好?

页面仔小杨

H5功能足够强大,为什么还要微信小程序?

顾强

微信小程序 移动应用

附录1、Docker 常用命令及示例

悟尘

Docker 容器

附录4、Docker-compose 配置文件编写指南

悟尘

Docker Docker-compose

意想不到的收获哦

南辞

写在开头

杨友峰

Java 期现

从少儿编程讲讲开发行业的大趋势

kimmking

在线教育 少儿编程

四、Docker 网络原理、分类及容器互联配置

悟尘

Docker Kubernetes 容器 k8s Compose

七、Docker Compose 入门实践

悟尘

Docker Kubernetes 容器 k8s Compose

为什么说此前的WiFi安全方案都是小弟?

石君

wifi 无线网络 无线网络安全 Wi-Fi安全

Hexo-admonition 插件安装使用指南

悟尘

Hexo Hexo-admonition Admonition

五、Docker 数据持久化存储与性能调优

悟尘

Docker 容器 k8s Compose kubernet

六、基于多阶段构建减小镜像体积降低复杂度

悟尘

Docker Kubernetes 容器 k8s Compose

附录3、Docker-compose 命令使用指南

悟尘

Docker Docker-compose

Hexo-deployer-cos-cdn 插件安装使用指南

悟尘

Hexo COS CDN Hexo-deployer-cos-cdn

Netty 源码解析(三): Netty 的 Future 和 Promise

猿灯塔

spring-cloud-stream 集成 rocketmq

再见孙悟空

RocketMQ Spring Cloud

附录2、Dockerfile 参考及最佳实践

悟尘

Docker Dockerfile

Node.js 必知必会(安装配置、应用实例及同步控制)

悟尘

node.js

游戏夜读 | 设计师的数据模型

game1night

曾国藩家书嘉言钞(六)

熊小北同学

曾国藩 曾国藩家书 嘉言钞

高性能交易系统设计原理

廖雪峰

架构

读 Guide to Java String Pool

shengjk1

Java string pool

源码分析 Vector 和 ArrayList

张sir

Java 源码 collection

VSCode-aliyun-oss-paste-image 插件安装使用指南

悟尘

vscode Paste-image

Netty 源码解析(二):Netty 的 Channel

猿灯塔

Netty

Redis高可用-哨兵模式配置

for

redis 高可用 主从配置 redis高可用 redis哨兵模式

打造 VUCA 时代的 10 倍速 IT 团队

打造 VUCA 时代的 10 倍速 IT 团队

详解 Google 多任务学习模型 MMoE ( KDD 2018 )-InfoQ