写点什么

Tensorflow2.0 实现 Deep-Q-Network

  • 2019-12-02
  • 本文字数:4042 字

    阅读完需:约 13 分钟

Tensorflow2.0实现Deep-Q-Network

深度 Q 网络(Deep - Q - Network) 介绍


在 Q-learning 算法中,当状态和动作空间是离散且维数不高时,可使用 Q-table 储存每个状态动作对的 Q 值,然后通过贝尔曼方差迭代求得每个状态动作对收敛的 Q 值,然后选择最优的动作当做策略。但是而当状态和动作空间是高维连续时,比如(游戏的状态动作对数目就很大)使用 Q-table 存储每个状态动作对就显得很不现实。


所以可以将 Q-Table 的更新问题变成一个函数拟合问题,相近的状态得到相近的输出动作。DQN 就是要设计一个神经网络结构,通过函数来拟合 Q 值。


下面引用一下自己写的一篇综述里面的 DQN 训练流程图,贴自己的图,不算侵犯版权吧,哈哈。知网上可以下载到这篇文章:http://kns.cnki.net/kcms/detail/detail.aspx?dbcode=CJFD&&filename=JSJX201801001。当时3个月大概看了百余篇DRL方向的论文,才写出来的,哈哈。



DQN 的亮点:


  • 通过 experience replay(经验池)的方法来解决相关性及非静态分布问题,在训练深度网络时,通常要求样本之间是相互独立的,所以通过这种随机采样的方式,大大降低了样本之间的关联性,从而提升了算法的稳定性。

  • 使用一个神经网络产生当前 Q 值,使用另外一个神经网络产生 Target Q 值。

  • DQN 损失函数和参数更新:


损失函数:



其中 yi 表示值函数的优化目标即目标网络的 Q 值:



参数更新的梯度为:



Tensorflow 2.0 实现 DQN


整体的代码是借鉴的莫烦大神,只不过现在用的接口都是 Tensorflow 2.0,所以代码显得很简单,风格很像 keras。


# -*- coding:utf-8 -*-# Author : zhaijianwei# Date : 2019/6/19 19:48
import tensorflow as tfimport numpy as npfrom tensorflow.python.keras import layersfrom tensorflow.python.keras.optimizers import RMSprop
from DQN.maze_env import Maze

class Eval_Model(tf.keras.Model): def __init__(self, num_actions): super().__init__('mlp_q_network') self.layer1 = layers.Dense(10, activation='relu') self.logits = layers.Dense(num_actions, activation=None)
def call(self, inputs): x = tf.convert_to_tensor(inputs) layer1 = self.layer1(x) logits = self.logits(layer1) return logits

class Target_Model(tf.keras.Model): def __init__(self, num_actions): super().__init__('mlp_q_network_1') self.layer1 = layers.Dense(10, trainable=False, activation='relu') self.logits = layers.Dense(num_actions, trainable=False, activation=None)
def call(self, inputs): x = tf.convert_to_tensor(inputs) layer1 = self.layer1(x) logits = self.logits(layer1) return logits

class DeepQNetwork: def __init__(self, n_actions, n_features, eval_model, target_model):
self.params = { 'n_actions': n_actions, 'n_features': n_features, 'learning_rate': 0.01, 'reward_decay': 0.9, 'e_greedy': 0.9, 'replace_target_iter': 300, 'memory_size': 500, 'batch_size': 32, 'e_greedy_increment': None }
# total learning step
self.learn_step_counter = 0
# initialize zero memory [s, a, r, s_] self.epsilon = 0 if self.params['e_greedy_increment'] is not None else self.params['e_greedy'] self.memory = np.zeros((self.params['memory_size'], self.params['n_features'] * 2 + 2))
self.eval_model = eval_model self.target_model = target_model
self.eval_model.compile( optimizer=RMSprop(lr=self.params['learning_rate']), loss='mse' ) self.cost_his = []
def store_transition(self, s, a, r, s_): if not hasattr(self, 'memory_counter'): self.memory_counter = 0
transition = np.hstack((s, [a, r], s_))
# replace the old memory with new memory index = self.memory_counter % self.params['memory_size'] self.memory[index, :] = transition
self.memory_counter += 1
def choose_action(self, observation): # to have batch dimension when feed into tf placeholder observation = observation[np.newaxis, :]
if np.random.uniform() < self.epsilon: # forward feed the observation and get q value for every actions actions_value = self.eval_model.predict(observation) print(actions_value) action = np.argmax(actions_value) else: action = np.random.randint(0, self.params['n_actions']) return action
def learn(self): # sample batch memory from all memory if self.memory_counter > self.params['memory_size']: sample_index = np.random.choice(self.params['memory_size'], size=self.params['batch_size']) else: sample_index = np.random.choice(self.memory_counter, size=self.params['batch_size'])
batch_memory = self.memory[sample_index, :]
q_next = self.target_model.predict(batch_memory[:, -self.params['n_features']:]) q_eval = self.eval_model.predict(batch_memory[:, :self.params['n_features']])
# change q_target w.r.t q_eval's action q_target = q_eval.copy()
batch_index = np.arange(self.params['batch_size'], dtype=np.int32) eval_act_index = batch_memory[:, self.params['n_features']].astype(int) reward = batch_memory[:, self.params['n_features'] + 1]
q_target[batch_index, eval_act_index] = reward + self.params['reward_decay'] * np.max(q_next, axis=1)
# check to replace target parameters if self.learn_step_counter % self.params['replace_target_iter'] == 0: for eval_layer, target_layer in zip(self.eval_model.layers, self.target_model.layers): target_layer.set_weights(eval_layer.get_weights()) print('\ntarget_params_replaced\n')
""" For example in this batch I have 2 samples and 3 actions: q_eval = [[1, 2, 3], [4, 5, 6]] q_target = q_eval = [[1, 2, 3], [4, 5, 6]] Then change q_target with the real q_target value w.r.t the q_eval's action. For example in: sample 0, I took action 0, and the max q_target value is -1; sample 1, I took action 2, and the max q_target value is -2: q_target = [[-1, 2, 3], [4, 5, -2]] So the (q_target - q_eval) becomes: [[(-1)-(1), 0, 0], [0, 0, (-2)-(6)]] We then backpropagate this error w.r.t the corresponding action to network, leave other action as error=0 cause we didn't choose it. """
# train eval network
self.cost = self.eval_model.train_on_batch(batch_memory[:, :self.params['n_features']], q_target)
self.cost_his.append(self.cost)
# increasing epsilon self.epsilon = self.epsilon + self.params['e_greedy_increment'] if self.epsilon < self.params['e_greedy'] \ else self.params['e_greedy'] self.learn_step_counter += 1
def plot_cost(self): import matplotlib.pyplot as plt plt.plot(np.arange(len(self.cost_his)), self.cost_his) plt.ylabel('Cost') plt.xlabel('training steps') plt.show()

def run_maze(): step = 0 for episode in range(300): # initial observation observation = env.reset()
while True: # fresh env env.render() # RL choose action based on observation action = RL.choose_action(observation) # RL take action and get next observation and reward observation_, reward, done = env.step(action) RL.store_transition(observation, action, reward, observation_) if (step > 200) and (step % 5 == 0): RL.learn() # swap observation observation = observation_ # break while loop when end of this episode if done: break step += 1 # end of game print('game over') env.destroy()

if __name__ == "__main__": # maze game env = Maze() eval_model = Eval_Model(num_actions=env.n_actions) target_model = Target_Model(num_actions=env.n_actions) RL = DeepQNetwork(env.n_actions, env.n_features, eval_model, target_model) env.after(100, run_maze) env.mainloop() RL.plot_cost()
复制代码


参考文献:


https://www.jianshu.com/p/10930c371cac


https://github.com/MorvanZhou/Reinforcement-learning-with-tensorflow


http://inoryy.com/post/tensorf


本文转载自 Alex-zhai 知乎账号。


原文链接:https://zhuanlan.zhihu.com/p/70009692


2019-12-02 16:221292

评论

发布
暂无评论
发现更多内容

HummerRisk V0.6.0:列表高级搜索,对象存储、操作审计扩充支持

HummerCloud

云安全 云原生安全

太强了!GitHub大佬白嫖的SpringCloud微服务进阶宝典,啃完感觉能吊锤面试官!

程序知音

Java 微服务 SpringCloud java架构 后端技术

民办二本程序员阿里、百度、平安等五厂面经,5份offer(含真题)

钟奕礼

Java 程序员 java面试 java编程

StoneDB 首席架构师李浩:如何选择一款 HTAP 产品?

StoneDB

MySQL HTAP 数据库· StoneDB 12 月 PK 榜

聚焦稳定性,Dubbo 发版规划公布

Apache Dubbo

Java 开源 微服务 云原生 dubbo

架构训练营 模块一作业

提姆

阿里程序员给我一份Java笔、面试宝典,看目录的那一刻,我傻了!

钟奕礼

Java 程序员 java面试 java编程

北京等保备案预约平台是哪个?多久能办好?

行云管家

等保 等保测评 等保备案 北京

老板答应了我,只要回答对几道简单的Spring问题,就给我涨3K

钟奕礼

Java 程序员 java面试 java编程

TDH 社区版上新宽表数据库 Hyperbase,轻松实现海量数据的毫秒级精确检索

星环科技

数据库

腾讯云与流媒体服务商BeLive达成合作,助力提升东南亚与周边地区直播水平

科技热闻

带你了解基于Ploto构建自动驾驶平台

华为云开发者联盟

人工智能 自动驾驶 华为云 12 月 PK 榜

Spring 事务失效的六种情况

江南一点雨

spring 事务

数字先锋| 农业农村部大数据公共平台基座上线,天翼云擎起乡村振兴新希望!

天翼云开发者社区

Redis事务、pub/sub、PipeLine-管道、benchmark性能测试详解

C++后台开发

redis 中间件 性能测试 后端开发 C++开发

当线下门店遇上AI:华为云ModelBox携手佳华科技客流分析实践

华为云开发者联盟

人工智能 数字化转型 华为云 12 月 PK 榜

不让Bug陪你过年,StarRocks年终抓虫派对重金相邀!

StarRocks

#数据库

iptables 命令和 iptables.service 服务的区别

山河已无恙

12月月更

天翼云电脑为医共体建设加buff!

天翼云开发者社区

京东内部流传的MyBatis笔记,短小而精悍,处处是源码细节

小小怪下士

Java 源码 程序员 mybatis

云服务器好用吗,有哪些特点?

Finovy Cloud

云服务器 云渲染

Java程序员:为了跳槽刷完1000道真题,没想到老板直接给我升职了

钟奕礼

Java 程序员 java面试 java编程

拼多多电商部java岗三面落选,记下的面试题,不睡觉都要背下来!

钟奕礼

Java 程序员 java面试 java编程

【设计指南】避免PCB板翘,合格的工程师选择这样设计!

华秋PCB

生产 PCB PCB设计

糟糕,数据库异常不可用怎么办?

华为云开发者联盟

数据库 后端 华为云 12 月 PK 榜

这个算法不一般,控制拥塞有一手!

天翼云开发者社区

更快更稳更安全!天翼云CDN了解一下

天翼云开发者社区

当ChatGPT火爆全球,中国交互AI平台「聆心智能」获得千万元融资

硬科技星球

公司CTO:高性能开发,你不会Netty,怎么好意思拿20K?

钟奕礼

Java 程序员 java面试 java编程

低代码平台,企业业务创新的最佳路径

元年技术洞察

低代码 数字化转型 #方舟平台

Tensorflow2.0实现Deep-Q-Network_语言 & 开发_Alex-zhai_InfoQ精选文章