写点什么

阿里巴巴 AAAI 2018 录用论文:通过强化学习进行图像精细描述,解决梯度消失难题

2018 年 3 月 06 日

原文下载地址:

https://102.alibaba.com/downloadFile.do?file=1518074198430/AAAI2018Stack-Captioning_Coarse-to-Fine%20Learning%20for%20Image%20Captioning_12213(1).pdf

摘要

现有的图像描述方法通常都是训练一个单级句子解码器,这难以生成丰富的细粒度的描述。另一方面,由于梯度消失问题,多级图像描述模型又难以训练。我们在本论文中提出了一种粗略到精细的多级图像描述预测框架,该框架由多个解码器构成,其中每一个都基于前一级的输出而工作,从而能得到越来越精细的图像描述。通过提供一个实施中间监督的学习目标函数,我们提出的学习方法能在训练过程中解决梯度消失的难题。尤其需要指出,我们使用了一种强化学习方法对我们的模型进行优化,该方法能够利用每个中间解码器的测试时间推理算法的输出及其前一个解码器的输出来对奖励进行归一化,这能一并解决众所周知的曝光偏差问题(exposure bias problem)和损失 - 评估不匹配问题(loss-evaluation mismatch problem)。我们在 MSCOCO 上进行了大量实验来评估我们提出的方法,结果表明我们的方法可以实现当前最佳的表现。

引言

图像描述的困难之处是让设计的模型能有效地利用图像信息和生成更接近人类的丰富的图像描述。在自然语言处理近期进展的推动下,当前的图像描述方法一般遵循编码 - 解码框架。这种框架由一个基于卷积神经网络(CNN)的图像编码器和基于循环神经网络(RNN)的句子解码器构成,有多种用于图像描述的变体。这些已有的图像描述方法的训练方式大都是根据之前的基本真值词(ground-truth words)和图像,使用反向传播,最大化每个基本真值词的可能性。

这些已有的图像描述方法存在三个主要问题。第一,它们很难生成丰富的细粒度的描述。第二,在训练和测试之间存在曝光偏差。第三,存在损失与评估的不匹配问题。

考虑到使用单级模型生成丰富的图像描述的巨大挑战性,我们在本论文中提出了一种粗略到精细的多级预测框架。我们的模型由一个图像编码器和一系列句子解码器构成,可以重复地生成细节越来越精细的图像描述。但是,直接在图像描述模型中构建这样的多级解码器面临着梯度消失问题的风险。Zhang, Lee, and Lee 2016; Fu, Zheng, and Mei 2017 等在图像识别上的研究工作表明监督非常深度的网络的中间层有助于学习,受这些研究的启发,我们也为每级解码器实施了中间监督。此外,Rennie et al. 2017 这项近期的图像描述研究使用了强化学习(RL)来解决损失 - 评估不匹配问题,并且还在训练中包含推理过程作为基准来解决曝光偏差问题;我们也设计了一种类似的基于强化学习的训练方法,但是将其从单级扩展成了我们的多级框架,其中每级都引入了作为中间监督的奖励。尤其需要指出,我们使用了一种强化学习方法对我们的模型进行优化,该方法能够利用每个中间解码器的测试时间推理算法的输出及其前一个解码器的输出来对奖励进行归一化。此外,为了应对我们的粗略到精细学习框架,我们采用了一种层叠式注意模型来为每个阶段的词预测提取更细粒度的视觉注意信息。图 1 给出了我们提出的粗略到精细框架的示意图,它由三个层叠的长短期记忆(LSTM)网络构成。第一个 LSTM 生成粗尺度的图像描述,后面的 LSTM 网络用作精细尺度的解码器。我们模型中每一级的输入都是前一级所得到的注意权重和隐藏向量,这些被用作后一级的消岐线索。由此,每一级解码器就会生成注意权重和词越来越精细的句子。

图 1 _:我们提出的粗略到精细框架示意图。我们的模型由一个图像编码器(CNN__)和一系列句子解码器(基于注意的 LSTM_ 网络)构成。该模型以图像为输入,能够从粗略到精细不断细化图像描述。这里我们展示了两级式的图像描述渐进提升(灰色和深灰色)。

本工作的主要贡献包括:(a)一种用于图像描述的粗略到精细框架,可以使用越来越细化的注意权重逐渐增大模型复杂度;(b)一种使用归一化后的中间奖励直接优化模型的强化学习方法。实验表明我们的方法在 MSCOCO 上表现出色。

方法

在本论文中,我们考虑了学习生成图像描述的问题。我们的算法构建了一个粗略到精细模型,它具有与单级模型一样的目标,但在输出层和输入层之间具有额外的中间层。我们首先根据输入图像和目标词的黄金历史,通过最大化每个连续目标词的对数似然而对该模型进行训练,然后再使用句子层面的评估指标对模型进行优化。结果,每个中间句子解码器会预测得到越来越细化的图像描述,最后一个解码器的预测结果用作最终的图像描述。

图像编码

我们首先将给定图像编码成空间图像特征。具体来说,我们从 CNN 的最后卷积层提取图像特征,然后使用空间自适应平均池化将这些特征的尺寸调整成固定尺寸的空间表示。

粗略到精细解码

整体的粗略到精细句子解码器由一个粗略解码器和一系列基于注意的精细解码器构成,这些解码器可以根据来自前一个解码器的线索来得到每个词预测的细化后的注意图(attention map)。我们模型的第一级是一个粗略解码器,能根据全局图像特征预测得到粗略的描述。在后续阶段,每一级都是一个精细解码器,可以基于图像特征和前一级的输出而预测得到更好的图像描述。尤其需要指出,我们使用了前一级的注意权重来提供后一级词预测的区域信念。也就是说,我们以多级方式解码图像特征,其中每级的预测结果都是对前一级预测结果的精细化。

图 2 给出了我们提出的粗略到精细解码架构,其中每一级之后都使用了中间监督(奖励)。上面一行(灰色)包含一个粗略解码器(左)和两个层叠的基于注意的精细解码器(处于训练模式下);下面一行给出了处于推理模式(贪婪解码)下的精细解码器,用于计算奖励以将中间监督整合进来。

图 2

粗略解码器。我们首先在第一级的粗略搜索空间中解码,我们在这里使用一个 LSTM 网络学习一个粗略解码器,称为 LSTMcoarse。LSTMcoarse 在每个时间步骤的输入都由前一个目标词(连接着全局图像特征)和之前的隐藏状态构成。

精细解码器。在后续的多级中,每个精细解码器都会再次基于图像特征以及来自前一个 LSTM 的注意权重和隐藏状态来预测词。每个精细解码器都由一个 LSTMfine 网络和一个注意模型构成。LSTMfine 在每个时间步骤的输入都包含已出现的图像特征、前一个词嵌入及其隐藏状态、来自前一个 LSTM 的更新后的隐藏状态。

层叠式注意模型。如前所述,我们的粗略解码器基于全局图像特征生成词。但是在很多情况下,每个词都只与图像中的很小一部分有关。由于每次预测时图像中的无关区域会引入噪声,所以为词预测使用全局图像特征会得到次优的结果。因此,我们引入了注意机制,这能显著提升图像描述的表现。注意机制通常会得到一个空间图(spatial map),其中突出显示了与每个预测词相关的图像区域。为了为词预测提取更细粒度的视觉信息,我们在本研究中采用了一种层叠式注意模型来逐渐滤除噪声和定位与词预测高度相关的区域。在每个精细处理级中,我们的注意模型都会对图像特征和来自前一级的注意权重进行操作。

学习

上面描述的粗略到精细方法能得到一种深度架构。训练这样一种深度网络可能很容易出现梯度消失问题,即梯度的幅度会在反向传播通过多个中间层时强度减小。解决这种问题的一种自然方法是将监督训练目标整合到中间层中。每一级粗略到精细句子解码器的训练目标都是重复地预测词。我们首先通过为每一级定义一个最小化交叉熵损失的损失函数来训练网络。

但是,只使用这里的损失函数进行训练是不够的。

为了优化每一级的评估指标,我们将图像描述生成过程看作是一个强化学习问题,即给定一个环境(之前的状态),我们想要智能体(比如 RNN、LSTM 或 GRU)查看环境(图像特征、隐藏状态和之前的词)并做出动作(预测下一个词)。在生成了一个完整句子之后,该智能体将观察句子层面的奖励并更新自己的内部状态。

实验

数据集和设置

我们在 MSCOCO 数据集上评估了我们提出的方法。

用于比较的基准方法

为了了解我们提出的方法的有效性,我们对以下模型进行了相互比较:

LSTM 和 LSTM3 layers我们根据 Vinyals et al. 2015 提出的框架实现了一个基于单层 LSTM 的图像描述模型。我们还在该单层 LSTM 模型之后增加了另外两个 LSTM 网络,我们将其称为 LSTM3 layers。

**LSTM+ATTSoft和 LSTM+ATTTop-Down。** 我们实现了两种基于视觉注意的图像描述模型:Xu et al. 2015 提出的软注意模型(LSTM+ATTSoft)和 Anderson et al. 2017 提出的自上而下注意模型(LSTM+ATTTop-Down)

Stack-Cap 和 Stack-Cap*Stack-Cap 是我们提出的方法,Stack-Cap* 是一种简化版本。Stack-Cap 和 Stack-Cap* 的架构相似,只是 Stack-Cap 应用了我们提出的层叠式注意模型而不是独立的注意模型。

定量分析

在实验中,我们首先使用标准的交叉熵损失对模型进行了优化。我们报告了我们的模型和基准模型在 Karpathy test split 上的表现,如表 1 所示。注意这里报告的所有结果都没有使用 ResNet-101 的微调。

表 1 _:在 MSCOCO_ 上的表现比较,其中 B@n 是指 BLEU-n_,M_ 指 METEOR_,C_ 指 CIDEr_。这里所有的值都是百分数(加粗数字是最佳结果)。_

在使用交叉熵损失优化了模型之后,我们又使用基于强化学习的算法针对 CIDEr 指标对它们进行了优化。表 2 给出了使用 SCST(Rennie et al. 2017)为 CIDEr 指标优化的 4 种模型的表现以及使用我们提出的粗略到精细(C2F)学习方法优化的 2 种模型的表现。可以看到我们的 Stack-Cap 模型在所有指标上都有显著的优势。

表 2

表 3 比较了我们的 Stack-Cap(C2F)模型与其它已有方法在 MSCOCO Karpathy test split 上的结果。Stack-Cap 在所有指标上都表现最佳。

表 3

在线评估。表 4 报告了我们提出的使用粗略到精细学习训练的 Stack-Cap 模型在官方 MSCOCO 评估服务器上的表现。可以看到,与当前最佳的方法相比,我们的方法非常有竞争力。注意,SCST:Att2in (Ens. 4) 的结果是使用 4 个模型联合实现的,而我们的结果是使用单个模型生成的。

表 4

定性分析

为了表明我们提出的粗略到精细方法可以逐级生成越来越好的图像描述,并且这些图像描述与自适应关注的区域有良好的关联,我们对生成的描述中词的空间注意权重进行了可视化。我们以 16 的采样系数对注意权重进行了上采样,并使用了一个高斯过滤器使之与输入图像一样大小,并将所有上采样后的空间注意图叠加到了原始输入图像上。

图 3 给出了一些生成的描述。通过多个注意层逐步进行推理,Stack-Cap 模型可以逐渐滤除噪声和定位与当前词预测高度相关的区域。可以发现,我们的 Stack-Cap 模型可以学习到与人类直觉高度对应的对齐方式。以第一张图像为例,对比粗略级生成的描述,由第一个精细解码器生成的首次细化后的描述中包含“dog”,第二个精细解码器不仅得到了“dog”,还识别出了“umbrella”。

图 3

此外,我们的方法还能生成更为描述性的句子。比如,喷气机图像的注意可视化表明 Stack-Cap 模型可以查询这些喷气机的关系以及它们身后长长的烟雾尾迹,因为这些突出区域有很高的关注权重。这个例子以及其它案例说明层叠式注意可以为序列预测更有效地探索视觉信息。也就是说,我们使用层叠式注意的方法可以从粗略到精细地考虑图像中的视觉信息,这与通常通过粗略到精细流程理解图像的人类视觉系统很相似。

如果您也有论文被国际顶会录用或者对论文编译整理工作感兴趣,欢迎关注 AI前线(ai-front),在后台留下联系方式,我们将与您联系,并进行更多交流!

2018 年 3 月 06 日 17:262953

评论

发布
暂无评论
发现更多内容

如果张东升是个程序员

程序员生活志

程序员 张东升

Cordova项目使用Android Studio真机调试

麦叔

android Android Studio 真机调试

IDEA 不为人知的 5 个骚技巧!真香!

王磊

Java 工具 IDEA

小白也有大厂梦,如何从零开始掌握高薪Java工程师必备技能?

无予且行

Java 架构 面试 后端 大厂

Java程序员的必修课之Spring理解透彻了吗?不会还咋去面试?

犬来八荒

Java spring 面试 后端 框架

大厂经验(3):Android端埋点自动采集技术原理剖析

DeeperMan

前端 数据采集 采集 埋点

面试官80%会问的分布式事务中的“最大努力通知”事务

无予且行

Java MySQL 面试 事务 java面试

有了多线程,为什么还要有协程?

八两

线程 进程 协程 GMP 进程线程区别

JVM中的双亲委派机制你还没懂吗?

阿文

Java JVMTI JVM 深入理解JVM JVM原理

Hexo blog 创建指导手册

想飞的鱼

GitHub Hexo GitHub Pages Blog

Hash一致性算法的Java实现

wei

熟悉JVM吗?为什么新生代内存需要有两个Survivor区?

南南

Java java面试 深入理解JVM JVM原理

现在面试这么难,背下题就能过的时代一去不复返了

小谈

Java JVM Java 面试 springboot SpringCloud

对mysql事务的认识,再不懂我就捶死我自己!

你是人间四月天

MySQL 面试 mysql事务 Java 面试 大厂面试

游戏夜读 | 跟风说一说爬虫

game1night

如何通过调试学习 nginx ?

张小方

c++ nginx 高性能 后端开发 服务器端开发

架构师训练营第 5 周——学习总结

在野

极客大学架构师训练营

【Python】 any() 和 or 区别你真的知道吗?

Leetao

Python 数据结构 Python基础知识

碎片化学习行不行

封不羁

架构师训练营第五周 - 总结

Eric

极客大学架构师训练营

老是自以为JVM懂了,那你知道 i = i++和 i = ++i 的区别吗?

小谈

Java 面试 编程语言 JVM 程序

解决死锁的4种基本方法(建议收藏)

小吴选手

Java 死锁

「架构师训练营」第 5 周作业 - 一致性哈希算法

guoguo 👻

极客大学架构师训练营

道路千万条,安全只三条

石君

安全评估 安全设计

k6简单入门

IT民工仁兄

性能测试

Java线程池最细的解释,看完后彻底征服面试官

小新

Java 架构 面试 线程 线程池

Git 的远端操作及解析(含思维导图)

多选参数

git GitHub gitlab

一致性哈希 -- java 实现

lei Shi

公司一直用Mybatis的原因原来在这!不得不竖起我的大拇指

小闫

Java mybatis mybatis-config.xml mybatis缓存

工业4.0|振动分析能做到预防性维护吗?

清水河路人甲

计算机操作系统基础(十三)---线程同步之读写锁

书旅

php laravel 线程 操作系统 进程

阿里巴巴AAAI 2018录用论文:通过强化学习进行图像精细描述,解决梯度消失难题-InfoQ