硬核干货——《中小企业 AI 实战指南》免费下载! 了解详情
写点什么

无痛的增强学习入门:Q-Learning

  • 2017-11-19
  • 本文字数:4428 字

    阅读完需:约 15 分钟

系列导读:《无痛的增强学习入门》系列文章旨在为大家介绍增强学习相关的入门知识,为大家后续的深入学习打下基础。其中包括增强学习的基本思想,MDP 框架,几种基本的学习算法介绍,以及一些简单的实际案例。

作为机器学习中十分重要的一支,增强学习在这些年取得了十分令人惊喜的成绩,这也使得越来越多的人加入到学习增强学习的队伍当中。增强学习的知识和内容与经典监督学习、非监督学习相比并不容易,而且可解释的小例子比较少,本系列将向各位读者简单介绍其中的基本知识,并以一个小例子贯穿其中。

8 Q-Learning

8.1 Q-Learning

上一节我们介绍了 TD 的 SARSA 算法,它的核心公式为:

\(q_t(s,a)=q_{t-1}(s,a)+\frac{1}{N}[R(s’)+q(s’,a’)-q_{t-1}(s,a)]\)

接下来我们要看的另一种 TD 的算法叫做 Q-Learning,它的基本公式为:

\(q_t(s,a)=q_{t-1}(s,a)+\frac{1}{N}[R(s’)+max_{a’} q(s’,a’)-q_{t-1}(s,a)]\)

两个算法的差别只在其中的一个项目,一个使用了当前 episode 中的状态 - 行动序列,另一个并没有,而是选择了数值最大的那个。这就涉及到两种不同的思路了。我们先暂时不管这个思路,来看看这个算法的效果。

首先还是实现代码:

复制代码
def q_learning(self):
iteration = 0
while True:
iteration += 1
self.q_learn_eval()
ret = self.policy_improve()
if not ret:
break

对应的策略评估代码为:

复制代码
def q_learn_eval(self):
episode_num = 1000
env = self.snake
for i in range(episode_num):
env.start()
state = env.pos
prev_act = -1
while True:
act = self.policy_act(state)
reward, state = env.action(act)
if prev_act != -1:
return_val = reward + (0 if state == -1 else np.max(self.value_q[state,:]))
self.value_n[prev_state][prev_act] += 1
self.value_q[prev_state][prev_act] += (return_val - \
self.value_q[prev_state][prev_act]) / \
self.value_n[prev_state][prev_act]
prev_act = act
prev_state = state
if state == -1:
break

实际的运行代码省略,最终的结果为:

复制代码
Timer Temporal Difference Iter COST:4.24033594131
return_pi=81
[0 0 0 0 1 1 1 1 1 1 0 1 1 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
0 1 1 1 1 1 1 1 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 1 1
1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0]
policy evaluation proceed 94 iters.
policy evaluation proceed 62 iters.
policy evaluation proceed 46 iters.
Iter 3 rounds converge
Timer PolicyIter COST:0.318824052811
return_pi=84
[0 1 1 1 1 1 1 1 1 1 1 0 0 0 1 1 1 1 1 1 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 0 0
0 0 0 0 1 1 1 1 0 0 0 0 1 1 1 1 1 1 1 1 0 0 0 1 1 1 1 1 1 1 1 1 0 0 0 1 1
1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0]

可以看出,Q-Learning 的方法和策略迭代比还是有点差距,当然通过一些方法是可以把数字提高的,但是这不是本文的重点了。不过从结果来看,Q-Learning 比 SARSA 要好一些。

8.2 On-Policy 与 Off-Policy

实际上这两种算法代表了两种思考问题的方式,分别是 On-Policy 和 Off-Policy。On-Policy 的代表是 SARSA,它的价值函数更新是完全根据 episode 的进展进行的,相当于在之前的基础上不断计算做改进;Off-Policy 的代表是 Q-Learning,它的价值函数中包含了之前的精华,所以可以算是战得更高,看得更远。当然也不能说 Off-Policy 就绝对好,Off-Policy 还会带来一些自己的问题,总之我们需要具体问题具体分析。

这里最后再做一点对两大类算法的总结。其实如果游戏有终点,模型不复杂(像我们的蛇棋),蒙特卡罗法还是有绝对的优势,但是蒙特卡罗的软肋也比较明显,它需要完整的 episode,产生的结果方差大,这对于一些大型游戏并不适合,所以在真正的产品级规模的增强学习上,TD 的身影还是更多一些。

8.3 展望

由于篇幅的限制,到这里我们实际上就完成了对增强学习的一些基础内容的介绍,但是了解了这些内容,还不足以完成更加复杂的任务,下面就让我们简单了解一些更为高级的内容。

8.3.1 Function Approximation

前面我们提到的所有的增强学习算法都是基于表格的,表格的好处在于我们可以独立地考虑每一个状态、每一个行动的价值,不需要把很多状态汇集在一起考虑,然而在很多实际问题中,状态的数目非常多,甚至状态本身就不是连续的,那么采用表格来进行计算就显得不太现实,于是研究人员就开始研究一些更为合适的表达方式,这时候,我们的机器学习就登场了,我们可以建立一些模型,并用一个模型来表示状态、行动到价值函数的关系。

我们令状态为\(s \in S\),行动为\(a \in A\),最终的价值函数为\(v \in R\),那么我们要建立这样一个映射:

\(S \times A \rightarrow R\)

这样我们就把增强学习中的一个子问题转换成了监督学习,而监督学习是大家熟悉的套路,所以做起来就更加的得心应手了。实际上这就是一个回归问题,于是所有可以用于回归问题的模型都可以被我们用上。比方说线性回归,支持向量回归,决策树,神经网络。因为现在深度学习十分火,于是我们可以用一个深层的神经网络完成这个映射。

模型函数和监督学习使我们又了从另一个角度观察增强学习的可能。此时的模型要考虑 bias 和 variance 的权衡,要考虑模型的泛化性,这些问题最终都会映射到增强学习上。模型的表示形式也有很多种,前面给出的\(S \times A \rightarrow R\) 只是模型表示的一种形式,我们还可以表示成\(S \rightarrow A,R\) 这样的形式。对于第一种表示形式,由于不同的行动将作用在同一个状态下(或者不同的行动被同一种行动操纵),模型中的参数表示必然会存在重复,那么为了更好地共享参数,我们可以使用后面一种形式来表示。

说完了模型的形式,那么接下来就来看看模型的目标函数。最简单的目标函数自然是平方损失函数:

\(obj=\frac{1}{2}\sum_i^N(v’_i(s,a;w) - v_i)^2\)

其中的\(v_i’\) 表示模型估计的价值,而\(v_i\) 表示当前的真实价值,定义了目标函数,下面我们就可以利用机器学习经典的梯度下降法求解了:

\(\frac{\partial obj}{\partial w}=\sum_i^N (v_i’ -v_i) \frac{\partial v_i’}{\partial w}\)

这个公式是不是看上去很眼熟?实际上如果

\(\frac{\partial v_i’}{\partial w}=1\)

,那么模型的最优解就等于

\(v_i’=\frac{1}{N}\sum_i^N v_i\)

也就是所有训练数据的平均值,这个结果和表格版的计算公式是一致的,也就是说表格版的算法实际上也是一种模型,只不过它的梯度处处为 1。

关于模型更多的讨论,我们也可以取阅读各种论文,论文中对这里面的各个问题都有深入的讨论。

8.3.2 Policy Gradient

另外一个方向则是跳出已有的思维框架,朝着另一种运算方式前进。这种方法被称为 Policy Gradient。在这种方法中,我们不再采用先求价值函数,后更新策略的方式,而是直接针对策略进行建模,也就是\(\pi(a|s)\) 建模。

我们回到增强学习问题的源头,最初我们希望找到一种策略使得 Agent 的长期回报最大化,也就是:

\(max E_{\pi}[v_{\pi}(s_0)]\)

这个公式求解梯度可以展开为:

\(\nabla v_{\pi}(s_0)=E_{\pi}[\gamma^t G_t \nabla log \pi(a_t|s_t;w)]\)

更新公式为:

\(\theta_{t+1}=\theta_t + \alpha G_t \nabla log\pi(a_t|s_t;w)\)

这其中的推导过程就省略不谈了。得到了目标函数的梯度,那么我们就可以直接根据梯度去求极值了。我们真正关心的部分实际上是里面的那个求梯度的部分,所以当整体目标达到最大时,策略也就达到了最大值,因此目标函数的梯度可以回传给模型以供使用。由于 log 函数不改变函数的单调性,对最终的最优策略步影响,于是我们可以直接对\(log \pi(a|s;w)\) 进行建模。这个方法被称为 REINFORCE。

当然,这个算法里面还存在着一些问题。比方说里面的\(G_t\) 同样是用蒙特卡罗的方法得到的。它和蒙特卡罗方法一样,存在着高方差的问题,为了解决高方差的问题,有人提出构建一个 BaseLine 模型,让每一个\(G_t\) 减掉 BaseLine 数字,从而降低了模型的方差。更新公式为:

\(\theta_{t+1}=\theta_t + \alpha (G_t-b(S_t)) \nabla log\pi(a_t|s_t;w)\)

8.3.3 Actor-Critic

既然有基于 Policy-Gradient 的蒙特卡罗方法,那么就应该有 TD 的方法。这种方法被称为 Actor-Critic 方法。这个方法由两个模型组成,其中 Actor 负责 Policy 的建模,Critic 负责价值函数的建模,于是上面基于 BaseLine 的 REINFORCE 算法的公式就变成了:

\(\theta_{t+1}=\theta_t + \alpha (R_{t+1}+\gamma \hat{v}(s_{t+1};w)-\hat{v}(S_t;w)) \nabla log\pi(a_t|s_t;w)\)

这样每一轮的优化也就变成了先根据模拟的 episode 信息优化价值函数,然后再优化策略的过程。

一般来说,上面公式中的\((R_{t+1}+\gamma \hat{v}(s_{t+1};w)-\hat{v}(S_t;w))\) 表示向前看一步的价值和就在当前分析的价值的差。下棋的人们都知道“下棋看三步”的道理,也就是说有时面对一些游戏的状态,我们在有条件的情况下可以多向前看看,再做出决定,所以一般认为多看一步的价值函数的值会更高更好一些,所以上面的那一项通常被称为优势项(Advantage)。

在前面的实验中我们也发现,对于无模型的问题,我们要进行大量的采样才能获得比较精确的结果,对于更大的问题来说,需要的采样模拟也就更多。因此,对于大量的计算量,如何更快地完成计算也成为了一个问题。有一个比较知名且效果不错的算法,被称为 A3C(Asynchronous Advantage Actor-Critic) 的方法,主要是采用了异步的方式进行并行采样,更新参数,这个方法也得到了比较好的效果。

以上就是《无痛的增强学习》入门篇,我们以蛇棋为例介绍了增强学习基础框架 MDP,介绍了模型已知的几种方法——策略迭代、价值迭代、泛化迭代法,介绍了模型未知的几种方法——蒙特卡罗、SARSA、Q-Learning,还简单介绍了一些更高级的计算方法,希望大家能从中有所收获。由于作者才疏学浅,行文中难免有疏漏之处,还请各位谅解。

作者介绍

冯超,毕业于中国科学院大学,猿辅导研究团队视觉研究负责人,小猿搜题拍照搜题负责人之一。2017 年独立撰写《深度学习轻松学:核心算法与视觉实践》一书,以轻松幽默的语言深入详细地介绍了深度学习的基本结构,模型优化和参数设置细节,视觉领域应用等内容。自 2016 年起在知乎开设了自己的专栏:《无痛的机器学习》,发表机器学习与深度学习相关文章,收到了不错的反响,并被多家媒体转载。曾多次参与社区技术分享活动。

2017-11-19 17:084131

评论

发布
暂无评论
发现更多内容

【YashanDB知识库】解压安装包时报错"tar:Error is not recoverable"

YashanDB

数据库 yashandb

阿里巴巴1688 API接口深度解析:高效获取商品详情与关键词搜索商品实战指南

代码忍者

1688API接口

分析代码变更与新增代码覆盖率的最佳实践

测试人

软件测试

2024 京东零售技术年度总结

京东零售技术

【YashanDB知识库】调用外部UDF未能识别Java环境配置

YashanDB

数据库 yashandb

【YashanDB知识库】过期统计信息导致SQL执行计划变差

YashanDB

数据库 yashandb

HarmonyOS NEXT——独立开发者们的机遇之门

最新动态

数字孪生丨如何利用现有数据提升产品耐久性?

Altair RapidMiner

大数据 数字孪生 智能制造 altair 仿真设计

【YashanDB知识库】个别数据库用户无法登录数据库,报错 io fail:IO.EOF

YashanDB

数据库 yashandb

从设计到伴飞:数字孪生赋能航空航天新时代

DevOps和数字孪生

航天航空

代码质量保证的利器:Git 预提交钩子

俞凡

最佳实践

【YashanDB知识库】ycm托管主机报错libnsl.so.1 no such file or directory

YashanDB

数据库 yashandb

【YashanDB知识库】安装共享集群时报错:YAS-05721 invalid input parameter, reason: node name invalid

YashanDB

数据库 yashandb

共探数据可信流通时代的密态新算力|走进隐语年度嘉年华精彩现场

隐语SecretFlow

【YashanDB知识库】隐藏参数怎么查看初始值

YashanDB

数据库 yashandb

【YashanDB知识库】原生mysql驱动配置连接崖山数据库

YashanDB

数据库 yashandb

Sonarqube 代码分析技术体系

测试人

软件测试

ITIL 4的4个维度

ServiceDesk_Plus

ITIL

【YashanDB知识库】如何使用MySQL客户端链接YashanDB

YashanDB

数据库 yashandb

复盘2024,大模型的商业化主线是什么?

脑极体

AI

代码复杂度定义与分析方法

测试人

软件测试

【YashanDB知识库】yashandb升级后,yasboot restart出现版本回退、报错control file version incompatible

YashanDB

数据库 yashandb

无痛的增强学习入门:Q-Learning_语言 & 开发_冯超_InfoQ精选文章