分布式 tensorflow 源码解读 2:MonitoredTrainingSession

阅读数:19 2019 年 11 月 28 日 08:00

分布式tensorflow源码解读2:MonitoredTrainingSession

MonitoredTrainingSession 是 tensorflow 管理分布式训练中一个使用很广泛的 API,集成了一些监控训练组件,如变量的初始化、从已有 checkpoint 恢复训练、summary、log 和 checkpoint 的保存等。在早期的 tf 版本中,一般使用 tf.train.Supervisor 来管理 session,后来框架升级后,官方推荐使用 MonitoredTrainingSession。MonitoredTrainingSession 有记录日志、训练可视化、checkpoint 保存、early-stop、训练效率调优等功能。

我们直接进入主题,下面是 MonitoredTrainingSession 源码。从注释中可了解到,MonitoredTrainingSession 的作用可用一句话来概括:如果是 chief 节点,则负责 session 的初始化或者从已有 checkpoint 恢复 session,并且创建一些用于保存 checkpoint 和 summary 的 hooks;如果是非 chief 的 worker 节点,则需要等待 chief 节点完成初始化或恢复 session 这些操作后,才能创建属于自己的 session。

复制代码
@tf_export(v1=['train.MonitoredTrainingSession'])
def MonitoredTrainingSession(
    master='',  # pylint: disable=invalid-name
    is_chief=True,
    checkpoint_dir=None,
    scaffold=None,
    hooks=None,
    chief_only_hooks=None,
    save_checkpoint_secs=USE_DEFAULT,
    save_summaries_steps=USE_DEFAULT,
    save_summaries_secs=USE_DEFAULT,
    config=None,
    stop_grace_period_secs=120,
    log_step_count_steps=100,
    max_wait_secs=7200,
    save_checkpoint_steps=USE_DEFAULT,
    summary_dir=None):
  """Creates a `MonitoredSession` for training.

  Abridged excerpt: the code that turns the `save_*`, `log_step_count_steps`
  and `summary_dir` arguments into default checkpoint/summary/step-counter
  hooks is omitted, so those arguments are only forwarded on the
  worker-context path.

  Args:
    master: Master target handed to the session creator.
    is_chief: If True, this task initializes or restores the session; if
      False, it waits for the chief to do so.
    checkpoint_dir: Directory used to restore variables from a checkpoint.
    scaffold: `Scaffold` used to gather supporting ops; a default one is
      created when omitted.
    hooks: Optional list of `SessionRunHook` objects.
    chief_only_hooks: `SessionRunHook` objects activated only on the chief.
    save_checkpoint_secs: Checkpoint frequency in seconds (default saver).
    save_summaries_steps: Summary frequency in global steps (default saver).
    save_summaries_secs: Summary frequency in seconds (default saver).
    config: Session configuration proto.
    stop_grace_period_secs: Seconds given to threads to stop after `close()`.
    log_step_count_steps: Frequency, in global steps, of steps/sec logging.
    max_wait_secs: Maximum time a non-chief waits for the session.
    save_checkpoint_steps: Checkpoint frequency in global steps.
    summary_dir: Directory for summaries.

  Returns:
    A `MonitoredSession` object.
  """
  scaffold = scaffold or Scaffold()
  worker_context = distribute_coordinator_context.get_current_worker_context()

  # Under a distribute coordinator the worker context decides the role.
  if worker_context:
    return _create_monitored_session_with_worker_context(
        worker_context,
        scaffold,
        checkpoint_dir=checkpoint_dir,
        hooks=hooks,
        chief_only_hooks=chief_only_hooks,
        save_checkpoint_secs=save_checkpoint_secs,
        save_summaries_steps=save_summaries_steps,
        save_summaries_secs=save_summaries_secs,
        config=config,
        stop_grace_period_secs=stop_grace_period_secs,
        log_step_count_steps=log_step_count_steps,
        max_wait_secs=max_wait_secs,
        save_checkpoint_steps=save_checkpoint_steps,
        summary_dir=summary_dir)

  # Non-chief workers wait for the chief and only run the user hooks.
  if not is_chief:
    session_creator = WorkerSessionCreator(
        scaffold=scaffold,
        master=master,
        config=config,
        max_wait_secs=max_wait_secs)
    return MonitoredSession(
        session_creator=session_creator,
        hooks=hooks or [],
        stop_grace_period_secs=stop_grace_period_secs)

  # Chief path: it owns initialization / restore, so it needs its own creator.
  all_hooks = []
  if chief_only_hooks:
    all_hooks.extend(chief_only_hooks)
  session_creator = ChiefSessionCreator(
      scaffold=scaffold,
      checkpoint_dir=checkpoint_dir,
      master=master,
      config=config)

  # ... (elided in this excerpt: the default saver/summary hooks are appended
  # to all_hooks here) ...

  if hooks:
    all_hooks.extend(hooks)
  return MonitoredSession(
      session_creator=session_creator,
      hooks=all_hooks,
      stop_grace_period_secs=stop_grace_period_secs)

我们首先解释下参数:

is_chief:用于分布式系统中,用于判断该系统是否是 chief,如果为 True,它将负责初始化并恢复底层 TensorFlow 会话。如果为 False,它将等待 chief 初始化或恢复 TensorFlow 会话。

checkpoint_dir:一个字符串。指定用于恢复变量的 checkpoint 所在的目录路径。

scaffold:用于收集或建立支持性 op 的脚手架。如果未指定,则会创建一个默认的 scaffold。它用于完成计算图(graph)的构建(finalize)。

hooks:SessionRunHook 对象的可选列表。可自己定义 SessionRunHook 对象,也可用已经预定义好的 SessionRunHook 对象,如:tf.train.StopAtStepHook() 设置停止训练的条件;tf.train.NanTensorHook(loss): 如果 loss 的值为 Nan 则停止训练;

chief_only_hooks:SessionRunHook 对象列表。如果 is_chief== True,则激活这些挂钩,否则忽略。

save_checkpoint_secs:使用默认的 checkpoint saver 保存 checkpoint 的频率(以秒为单位)。如果 save_checkpoint_secs 和 save_checkpoint_steps 都设置为 None,则不使用默认的 checkpoint saver 保存 checkpoint;如果两者都未指定,默认每 600 秒保存一次。

save_summaries_steps:使用默认 summaries saver 将摘要写入磁盘的频率(以全局步数表示)。如果 save_summaries_steps 和 save_summaries_secs 都设置为 None,则不使用默认的 summaries saver 保存 summaries。默认为 100

save_summaries_secs:使用默认 summaries saver 将摘要写入磁盘的频率(以秒为单位)。如果 save_summaries_steps 和 save_summaries_secs 都设置为 None,则不使用默认的摘要保存。默认未启用。

config:用于配置会话的 tf.ConfigProto proto 实例,即传给 tf.Session 构造函数的 config 参数。

stop_grace_period_secs:调用 close() 之后,留给各线程停止的最长秒数。

log_step_count_steps:每隔多少个全局步记录一次训练速度(global steps / sec)。

实例化后可得到一个 MonitoredSession 对象,可当作普通 session 使用。

然后我们仔细分解下代码:

复制代码
def _create_monitored_session_with_worker_context(
    worker_context,  # pylint: disable=missing-docstring
    scaffold,
    checkpoint_dir=None,
    hooks=None,
    chief_only_hooks=None,
    save_checkpoint_secs=None,
    save_summaries_steps=None,
    save_summaries_secs=None,
    config=None,
    stop_grace_period_secs=120,
    log_step_count_steps=100,
    max_wait_secs=7200,
    save_checkpoint_steps=None,
    summary_dir=None):
  """Builds a `MonitoredSession` whose role comes from `worker_context`.

  Abridged excerpt: the code that adds `chief_only_hooks` and the default
  checkpoint/summary/step-counter hooks is omitted, so the `save_*`,
  `log_step_count_steps` and `summary_dir` arguments are unused here.

  Args:
    worker_context: Object exposing `session_creator(...)`.
    scaffold: `Scaffold` forwarded to the session creator.
    checkpoint_dir: Forwarded to the session creator.
    hooks: Optional list of hooks run by the returned session.
    config: Session config forwarded to the session creator.
    stop_grace_period_secs: Forwarded to `MonitoredSession`.
    max_wait_secs: Forwarded to the session creator.

  Returns:
    A `MonitoredSession` object.
  """
  all_hooks = []
  if hooks:
    all_hooks.extend(hooks)

  # ... (elided in this excerpt: chief-only and default saver/summary hooks
  # are appended to all_hooks here) ...

  logging.info('all_hooks %r', all_hooks)

  # The worker context picks the chief or worker flavour of session creator.
  session_creator = worker_context.session_creator(
      scaffold,
      config=config,
      checkpoint_dir=checkpoint_dir,
      max_wait_secs=max_wait_secs)
  return MonitoredSession(
      session_creator=session_creator,
      hooks=all_hooks,
      stop_grace_period_secs=stop_grace_period_secs)
# session_creator 函数主体
def session_creator(self,
                    scaffold=None,
                    config=None,
                    checkpoint_dir=None,
                    checkpoint_filename_with_path=None,
                    max_wait_secs=7200):
  """Returns a session creator for this worker's master target and config.

  Args:
    scaffold: `Scaffold` handed to the creator.
    config: Optional session config; when given, a deep copy of it is merged
      with `self._session_config` (the caller's object is not mutated).
    checkpoint_dir: Checkpoint directory (chief creator only).
    checkpoint_filename_with_path: Full checkpoint path (chief creator only).
    max_wait_secs: Maximum wait time (worker creator only).

  Returns:
    A `ChiefSessionCreator` when there is no strategy or the strategy says
    this task should initialize; otherwise a `WorkerSessionCreator`.
  """
  if config:
    session_config = copy.deepcopy(config)
    session_config.MergeFrom(self._session_config)
  else:
    session_config = self._session_config

  # Create the session creator that matches this task's role.
  if not self._strategy or self._strategy.extended.experimental_should_init:
    logging.info("Creating chief session creator with config: %r", config)
    return monitored_session.ChiefSessionCreator(
        scaffold,
        master=self.master_target,
        config=session_config,
        checkpoint_dir=checkpoint_dir,
        checkpoint_filename_with_path=checkpoint_filename_with_path)
  else:
    logging.info("Creating worker session creator with config: %r", config)
    return monitored_session.WorkerSessionCreator(
        scaffold,
        master=self.master_target,
        config=session_config,
        max_wait_secs=max_wait_secs)
# ChiefSessionCreator
@tf_export(v1=['train.ChiefSessionCreator'])
class ChiefSessionCreator(SessionCreator):
  """Session creator used by the chief task."""

  def __init__(self,
               scaffold=None,
               master='',
               config=None,
               checkpoint_dir=None,
               checkpoint_filename_with_path=None):
    """Records the settings later used by `create_session`.

    Args:
      scaffold: `Scaffold` supplying init/ready ops and the saver; a default
        one is created when omitted.
      master: Master target passed to `prepare_session`.
      config: Session config passed to `prepare_session`.
      checkpoint_dir: Directory to restore a checkpoint from.
      checkpoint_filename_with_path: Full path of a checkpoint file.
    """
    self._checkpoint_dir = checkpoint_dir
    self._checkpoint_filename_with_path = checkpoint_filename_with_path
    self._scaffold = scaffold or Scaffold()
    self._session_manager = None
    self._master = master
    self._config = config

  def _get_session_manager(self):
    """Returns the cached `SessionManager`, building it on first use."""
    if not self._session_manager:
      scaffold = self._scaffold
      self._session_manager = sm.SessionManager(
          local_init_op=scaffold.local_init_op,
          ready_op=scaffold.ready_op,
          ready_for_local_init_op=scaffold.ready_for_local_init_op,
          graph=ops.get_default_graph())
    return self._session_manager

  def create_session(self):
    """Finalizes the scaffold and returns a prepared session."""
    self._scaffold.finalize()
    manager = self._get_session_manager()
    return manager.prepare_session(
        self._master,
        saver=self._scaffold.saver,
        checkpoint_dir=self._checkpoint_dir,
        checkpoint_filename_with_path=self._checkpoint_filename_with_path,
        config=self._config,
        init_op=self._scaffold.init_op,
        init_feed_dict=self._scaffold.init_feed_dict,
        init_fn=self._scaffold.init_fn)
# WorkerSessionCreator
@tf_export(v1=['train.WorkerSessionCreator'])
class WorkerSessionCreator(SessionCreator):
  """Session creator used by non-chief worker tasks."""

  def __init__(self,
               scaffold=None,
               master='',
               config=None,
               max_wait_secs=30 * 60):
    """Initializes a worker session creator.

    Args:
      scaffold: `Scaffold` supplying the ready/local-init ops; a default one
        is created when omitted.
      master: Master target passed to `wait_for_session`.
      config: Session config passed to `wait_for_session`.
      max_wait_secs: Maximum time to wait for the session to become available.
    """
    self._scaffold = scaffold or Scaffold()
    self._session_manager = None
    self._master = master
    self._config = config
    self._max_wait_secs = max_wait_secs

  def _get_session_manager(self):
    """Returns the cached `SessionManager`, building it on first use."""
    if not self._session_manager:
      scaffold = self._scaffold
      self._session_manager = sm.SessionManager(
          local_init_op=scaffold.local_init_op,
          ready_op=scaffold.ready_op,
          ready_for_local_init_op=scaffold.ready_for_local_init_op,
          graph=ops.get_default_graph())
    return self._session_manager

  def create_session(self):
    """Finalizes the scaffold and waits until a ready session is available."""
    self._scaffold.finalize()
    manager = self._get_session_manager()
    return manager.wait_for_session(
        self._master, config=self._config, max_wait_secs=self._max_wait_secs)

从上面的源码中分析得到,MonitoredTrainingSession 可根据不同的角色去创建不同种类的 Session:chief 节点由 ChiefSessionCreator 类创建 session,而非 chief 的 worker 节点由 WorkerSessionCreator 类创建。后者的特殊之处在于创建时调用的是 wait_for_session(),大致意思是需要等待 chief 节点的 session 创建完成之后,才去创建属于自己节点的 session。这两种创建 session 的方式都是 SessionManager 类的方法,下面我们具体分析下 SessionManager 类:

官方针对 SessionManager 类有一个简单的例子,感觉很清楚:

复制代码
# prepare_session() initializes or restores a model; it is given an
# `init_op` and a `saver`.
with tf.Graph().as_default():
  # add operations to the graph...
  # Create a SessionManager that will checkpoint the model in '/tmp/mydir'.
  sm = SessionManager()
  sess = sm.prepare_session(master, init_op, saver, checkpoint_dir)
  # Use the session to train the graph.
  while True:
    sess.run(<my_train_op>)

第二个进程可以用以下方法创建 session:wait_for_session() 的意思是需要等上面那个 session 创建好(模型就绪)之后,再创建自己的 session。

复制代码
# A second process waits for the model to become ready, then gets its own
# session.
with tf.Graph().as_default():
  # ...add operations to the graph...
  # Create a SessionManager that will wait for the model to become ready.
  sm = SessionManager()
  sess = sm.wait_for_session(master)
  # Use the session to train the graph.
  while True:
    sess.run(<my_train_op>)

然后我们可以重点关注下 prepare_session 和 wait_for_session 这两个函数:

复制代码
@tf_export(v1=["train.SessionManager"])
class SessionManager(object):
  """Creates sessions, either by preparing the model or by waiting for it.

  Abridged excerpt: the helpers used below (`_restore_checkpoint`,
  `_try_run_local_init_op`, `_model_ready`, `_safe_close`) are not shown.
  """

  def __init__(self,
               local_init_op=None,
               ready_op=None,
               ready_for_local_init_op=None,
               graph=None,
               recovery_wait_secs=30,
               local_init_run_options=None):
    """Stores the ops and settings used to create and check sessions.

    Args:
      local_init_op: Op run every time a new session is created.
      ready_op: Op used to check whether the model is ready.
      ready_for_local_init_op: Op used to check whether the model is ready
        for `local_init_op` to run.
      graph: Graph the sessions use; defaults to the current default graph.
      recovery_wait_secs: Seconds `wait_for_session` sleeps between attempts.
      local_init_run_options: Stored on the instance; unused in this excerpt.

    Raises:
      ValueError: If `ready_for_local_init_op` is given without a
        `local_init_op`.
    """
    # Sets default values of arguments.
    if graph is None:
      graph = ops.get_default_graph()
    self._local_init_op = local_init_op
    self._ready_op = ready_op
    self._ready_for_local_init_op = ready_for_local_init_op
    self._graph = graph
    self._recovery_wait_secs = recovery_wait_secs
    self._target = None
    self._local_init_run_options = local_init_run_options
    if ready_for_local_init_op is not None and local_init_op is None:
      raise ValueError("If you pass a ready_for_local_init_op "
                       "you must also pass a local_init_op "
                       ", ready_for_local_init_op [%s]" %
                       ready_for_local_init_op)

  def prepare_session(self,
                      master,
                      init_op=None,
                      saver=None,
                      checkpoint_dir=None,
                      checkpoint_filename_with_path=None,
                      wait_for_checkpoint=False,
                      max_wait_secs=7200,
                      config=None,
                      init_feed_dict=None,
                      init_fn=None):
    """Restores a session from a checkpoint, or initializes a fresh one.

    If a checkpoint is loaded the session is returned as restored; otherwise
    the model is initialized by running `init_op` and/or calling `init_fn`.

    Args:
      master: Master target forwarded to `_restore_checkpoint`.
      init_op: Optional op run (with `init_feed_dict`) when nothing was
        restored.
      saver: Saver forwarded to `_restore_checkpoint`.
      checkpoint_dir: Forwarded to `_restore_checkpoint`.
      checkpoint_filename_with_path: Forwarded to `_restore_checkpoint`.
      wait_for_checkpoint: Forwarded to `_restore_checkpoint`.
      max_wait_secs: Forwarded to `_restore_checkpoint`.
      config: Session config forwarded to `_restore_checkpoint`.
      init_feed_dict: Feed dict used when running `init_op`.
      init_fn: Optional callable invoked as `init_fn(sess)` after `init_op`.

    Returns:
      The prepared session.

    Raises:
      RuntimeError: If nothing was restored and none of `init_op`, `init_fn`
        or a local init op is available.
    """
    sess, is_loaded_from_checkpoint = self._restore_checkpoint(
        master,
        saver,
        checkpoint_dir=checkpoint_dir,
        checkpoint_filename_with_path=checkpoint_filename_with_path,
        wait_for_checkpoint=wait_for_checkpoint,
        max_wait_secs=max_wait_secs,
        config=config)
    if not is_loaded_from_checkpoint:
      if init_op is None and not init_fn and self._local_init_op is None:
        raise RuntimeError("Model is not initialized and no init_op or "
                           "init_fn or local_init_op was given")
      if init_op is not None:
        sess.run(init_op, feed_dict=init_feed_dict)
      if init_fn:
        init_fn(sess)

    # ... (further steps are elided in this excerpt) ...

    return sess

  def wait_for_session(self, master, config=None, max_wait_secs=float("Inf")):
    """Creates a new `Session` and waits for model to be ready.

    Each attempt opens a session, tries the local init op and checks model
    readiness; a session that is not ready is closed before sleeping
    `recovery_wait_secs` and retrying.

    Args:
      master: Master target; also stored as `self._target`.
      config: Session config.
      max_wait_secs: Maximum time to wait; `None` means wait forever.

    Returns:
      A session whose model is ready.

    Raises:
      errors.DeadlineExceededError: If there is not enough time left for
        another attempt.
    """
    self._target = master
    if max_wait_secs is None:
      max_wait_secs = float("Inf")
    timer = _CountDownTimer(max_wait_secs)
    while True:
      sess = session.Session(self._target, graph=self._graph, config=config)
      not_ready_msg = None
      not_ready_local_msg = None
      local_init_success, not_ready_local_msg = self._try_run_local_init_op(
          sess)
      if local_init_success:
        # Successful if local_init_op is None, or ready_for_local_init_op passes
        is_ready, not_ready_msg = self._model_ready(sess)
        if is_ready:
          return sess
      self._safe_close(sess)
      # Do we have enough time left to try again?
      remaining_secs_after_wait = (
          timer.secs_remaining() - self._recovery_wait_secs)
      if remaining_secs_after_wait < 0:
        raise errors.DeadlineExceededError(
            None, None,
            "Session was not ready after waiting %d secs." % (max_wait_secs,))
      logging.info("Waiting for model to be ready.  "
                   "Ready_for_local_init_op:  %s, ready: %s",
                   not_ready_local_msg, not_ready_msg)
      time.sleep(self._recovery_wait_secs)

创建完 session 之后,再包装一下,返回最终的 MonitoredSession 对象。

一个完整的 monitored session 在创建时会做的事情(按顺序):

  • 为每个 hook 调用 hook.begin()
    MonitoredTrainingSession 是 tensorflow 管理分布式训练中一个使用很广泛的 API,集成了一些监控训练组件,如变量的初始化、从已有 checkpoint 恢复训练、summary、log 和 checkpoint 的保存等。在早期的 tf 版本中,一般使用 tf.train.Supervisor 来管理 session,后来框架升级后,官方推荐使用 MonitoredTrainingSession。MonitoredTrainingSession 有记录日志、训练可视化、checkpoint 保存、early-stop、训练效率调优等功能。

我们直接进入主题,下面是 MonitoredTrainingSession 源码。从注释中可了解到,MonitoredTrainingSession 的作用可用一句话来概括:如果是 chief 节点,则负责 session 的初始化或者从已有 checkpoint 恢复 session,并且创建一些用于保存 checkpoint 和 summary 的 hooks;如果是非 chief 的 worker 节点,则需要等待 chief 节点完成初始化或恢复 session 这些操作后,才能创建属于自己的 session。

@tf_export(v1=['train.MonitoredTrainingSession'])
def MonitoredTrainingSession(
    master='',  # pylint: disable=invalid-name
    is_chief=True,
    checkpoint_dir=None,
    scaffold=None,
    hooks=None,
    chief_only_hooks=None,
    save_checkpoint_secs=USE_DEFAULT,
    save_summaries_steps=USE_DEFAULT,
    save_summaries_secs=USE_DEFAULT,
    config=None,
    stop_grace_period_secs=120,
    log_step_count_steps=100,
    max_wait_secs=7200,
    save_checkpoint_steps=USE_DEFAULT,
    summary_dir=None):
  """Creates a `MonitoredSession` for training.

  Abridged excerpt: the code that turns the `save_*`, `log_step_count_steps`
  and `summary_dir` arguments into default checkpoint/summary/step-counter
  hooks is omitted, so those arguments are only forwarded on the
  worker-context path.

  Args:
    master: Master target handed to the session creator.
    is_chief: If True, this task initializes or restores the session; if
      False, it waits for the chief to do so.
    checkpoint_dir: Directory used to restore variables from a checkpoint.
    scaffold: `Scaffold` used to gather supporting ops; a default one is
      created when omitted.
    hooks: Optional list of `SessionRunHook` objects.
    chief_only_hooks: `SessionRunHook` objects activated only on the chief.
    save_checkpoint_secs: Checkpoint frequency in seconds (default saver).
    save_summaries_steps: Summary frequency in global steps (default saver).
    save_summaries_secs: Summary frequency in seconds (default saver).
    config: Session configuration proto.
    stop_grace_period_secs: Seconds given to threads to stop after `close()`.
    log_step_count_steps: Frequency, in global steps, of steps/sec logging.
    max_wait_secs: Maximum time a non-chief waits for the session.
    save_checkpoint_steps: Checkpoint frequency in global steps.
    summary_dir: Directory for summaries.

  Returns:
    A `MonitoredSession` object.
  """
  scaffold = scaffold or Scaffold()
  worker_context = distribute_coordinator_context.get_current_worker_context()

  # Under a distribute coordinator the worker context decides the role.
  if worker_context:
    return _create_monitored_session_with_worker_context(
        worker_context,
        scaffold,
        checkpoint_dir=checkpoint_dir,
        hooks=hooks,
        chief_only_hooks=chief_only_hooks,
        save_checkpoint_secs=save_checkpoint_secs,
        save_summaries_steps=save_summaries_steps,
        save_summaries_secs=save_summaries_secs,
        config=config,
        stop_grace_period_secs=stop_grace_period_secs,
        log_step_count_steps=log_step_count_steps,
        max_wait_secs=max_wait_secs,
        save_checkpoint_steps=save_checkpoint_steps,
        summary_dir=summary_dir)

  # Non-chief workers wait for the chief and only run the user hooks.
  if not is_chief:
    session_creator = WorkerSessionCreator(
        scaffold=scaffold,
        master=master,
        config=config,
        max_wait_secs=max_wait_secs)
    return MonitoredSession(
        session_creator=session_creator,
        hooks=hooks or [],
        stop_grace_period_secs=stop_grace_period_secs)

  # Chief path: it owns initialization / restore, so it needs its own creator.
  all_hooks = []
  if chief_only_hooks:
    all_hooks.extend(chief_only_hooks)
  session_creator = ChiefSessionCreator(
      scaffold=scaffold,
      checkpoint_dir=checkpoint_dir,
      master=master,
      config=config)

  # ... (elided in this excerpt: the default saver/summary hooks are appended
  # to all_hooks here) ...

  if hooks:
    all_hooks.extend(hooks)
  return MonitoredSession(
      session_creator=session_creator,
      hooks=all_hooks,
      stop_grace_period_secs=stop_grace_period_secs)
我们首先解释下参数:

is_chief:用于分布式系统中,用于判断该系统是否是 chief,如果为 True,它将负责初始化并恢复底层 TensorFlow 会话。如果为 False,它将等待 chief 初始化或恢复 TensorFlow 会话。

checkpoint_dir:一个字符串。指定用于恢复变量的 checkpoint 所在的目录路径。

scaffold:用于收集或建立支持性 op 的脚手架。如果未指定,则会创建一个默认的 scaffold。它用于完成计算图(graph)的构建(finalize)。

hooks:SessionRunHook 对象的可选列表。可自己定义 SessionRunHook 对象,也可用已经预定义好的 SessionRunHook 对象,如:tf.train.StopAtStepHook() 设置停止训练的条件;tf.train.NanTensorHook(loss): 如果 loss 的值为 Nan 则停止训练;

chief_only_hooks:SessionRunHook 对象列表。如果 is_chief== True,则激活这些挂钩,否则忽略。

save_checkpoint_secs:使用默认的 checkpoint saver 保存 checkpoint 的频率(以秒为单位)。如果 save_checkpoint_secs 和 save_checkpoint_steps 都设置为 None,则不使用默认的 checkpoint saver 保存 checkpoint;如果两者都未指定,默认每 600 秒保存一次。

save_summaries_steps:使用默认 summaries saver 将摘要写入磁盘的频率(以全局步数表示)。如果 save_summaries_steps 和 save_summaries_secs 都设置为 None,则不使用默认的 summaries saver 保存 summaries。默认为 100

save_summaries_secs:使用默认 summaries saver 将摘要写入磁盘的频率(以秒为单位)。如果 save_summaries_steps 和 save_summaries_secs 都设置为 None,则不使用默认的摘要保存。默认未启用。

config:用于配置会话的 tf.ConfigProto proto 实例,即传给 tf.Session 构造函数的 config 参数。

stop_grace_period_secs:调用 close() 之后,留给各线程停止的最长秒数。

log_step_count_steps:每隔多少个全局步记录一次训练速度(global steps / sec)。

实例化后可得到一个 MonitoredSession 对象,可当作普通 session 使用。

然后我们仔细分解下代码:

def _create_monitored_session_with_worker_context(
    worker_context,  # pylint: disable=missing-docstring
    scaffold,
    checkpoint_dir=None,
    hooks=None,
    chief_only_hooks=None,
    save_checkpoint_secs=None,
    save_summaries_steps=None,
    save_summaries_secs=None,
    config=None,
    stop_grace_period_secs=120,
    log_step_count_steps=100,
    max_wait_secs=7200,
    save_checkpoint_steps=None,
    summary_dir=None):
  """Builds a `MonitoredSession` whose role comes from `worker_context`.

  Abridged excerpt: the code that adds `chief_only_hooks` and the default
  checkpoint/summary/step-counter hooks is omitted, so the `save_*`,
  `log_step_count_steps` and `summary_dir` arguments are unused here.

  Args:
    worker_context: Object exposing `session_creator(...)`.
    scaffold: `Scaffold` forwarded to the session creator.
    checkpoint_dir: Forwarded to the session creator.
    hooks: Optional list of hooks run by the returned session.
    config: Session config forwarded to the session creator.
    stop_grace_period_secs: Forwarded to `MonitoredSession`.
    max_wait_secs: Forwarded to the session creator.

  Returns:
    A `MonitoredSession` object.
  """
  all_hooks = []
  if hooks:
    all_hooks.extend(hooks)

  # ... (elided in this excerpt: chief-only and default saver/summary hooks
  # are appended to all_hooks here) ...

  logging.info('all_hooks %r', all_hooks)

  # The worker context picks the chief or worker flavour of session creator.
  session_creator = worker_context.session_creator(
      scaffold,
      config=config,
      checkpoint_dir=checkpoint_dir,
      max_wait_secs=max_wait_secs)
  return MonitoredSession(
      session_creator=session_creator,
      hooks=all_hooks,
      stop_grace_period_secs=stop_grace_period_secs)

session_creator 函数主体

def session_creator(self,
                    scaffold=None,
                    config=None,
                    checkpoint_dir=None,
                    checkpoint_filename_with_path=None,
                    max_wait_secs=7200):
  """Returns a session creator for this worker's master target and config.

  Args:
    scaffold: `Scaffold` handed to the creator.
    config: Optional session config; when given, a deep copy of it is merged
      with `self._session_config` (the caller's object is not mutated).
    checkpoint_dir: Checkpoint directory (chief creator only).
    checkpoint_filename_with_path: Full checkpoint path (chief creator only).
    max_wait_secs: Maximum wait time (worker creator only).

  Returns:
    A `ChiefSessionCreator` when there is no strategy or the strategy says
    this task should initialize; otherwise a `WorkerSessionCreator`.
  """
  if config:
    session_config = copy.deepcopy(config)
    session_config.MergeFrom(self._session_config)
  else:
    session_config = self._session_config

  # Create the session creator that matches this task's role.
  if not self._strategy or self._strategy.extended.experimental_should_init:
    logging.info("Creating chief session creator with config: %r", config)
    return monitored_session.ChiefSessionCreator(
        scaffold,
        master=self.master_target,
        config=session_config,
        checkpoint_dir=checkpoint_dir,
        checkpoint_filename_with_path=checkpoint_filename_with_path)
  else:
    logging.info("Creating worker session creator with config: %r", config)
    return monitored_session.WorkerSessionCreator(
        scaffold,
        master=self.master_target,
        config=session_config,
        max_wait_secs=max_wait_secs)

ChiefSessionCreator

@tf_export(v1=['train.ChiefSessionCreator'])
class ChiefSessionCreator(SessionCreator):
  """Creates a tf.compat.v1.Session for a chief."""

  def __init__(self,
               scaffold=None,
               master='',
               config=None,
               checkpoint_dir=None,
               checkpoint_filename_with_path=None):
    """Records the settings later used by `create_session`.

    Args:
      scaffold: `Scaffold` supplying init/ready ops and the saver; a default
        one is created when omitted.
      master: Master target passed to `prepare_session`.
      config: Session config passed to `prepare_session`.
      checkpoint_dir: Directory to restore a checkpoint from.
      checkpoint_filename_with_path: Full path of a checkpoint file.
    """
    self._checkpoint_dir = checkpoint_dir
    self._checkpoint_filename_with_path = checkpoint_filename_with_path
    self._scaffold = scaffold or Scaffold()
    self._session_manager = None
    self._master = master
    self._config = config

  def _get_session_manager(self):
    """Returns the cached `SessionManager`, building it on first use."""
    if self._session_manager:
      return self._session_manager
    self._session_manager = sm.SessionManager(
        local_init_op=self._scaffold.local_init_op,
        ready_op=self._scaffold.ready_op,
        ready_for_local_init_op=self._scaffold.ready_for_local_init_op,
        graph=ops.get_default_graph())
    return self._session_manager

  def create_session(self):
    """Finalizes the scaffold and returns a prepared session."""
    self._scaffold.finalize()
    return self._get_session_manager().prepare_session(
        self._master,
        saver=self._scaffold.saver,
        checkpoint_dir=self._checkpoint_dir,
        checkpoint_filename_with_path=self._checkpoint_filename_with_path,
        config=self._config,
        init_op=self._scaffold.init_op,
        init_feed_dict=self._scaffold.init_feed_dict,
        init_fn=self._scaffold.init_fn)

WorkerSessionCreator

@tf_export(v1=['train.WorkerSessionCreator'])
class WorkerSessionCreator(SessionCreator):
  """Creates a tf.compat.v1.Session for a worker."""

  def __init__(self,
               scaffold=None,
               master='',
               config=None,
               max_wait_secs=30 * 60):
    """Initializes a worker session creator.

    Args:
      scaffold: `Scaffold` supplying the ready/local-init ops; a default one
        is created when omitted.
      master: Master target passed to `wait_for_session`.
      config: Session config passed to `wait_for_session`.
      max_wait_secs: Maximum time to wait for the session to become available.
    """
    self._scaffold = scaffold or Scaffold()
    self._session_manager = None
    self._master = master
    self._config = config
    self._max_wait_secs = max_wait_secs

  def _get_session_manager(self):
    """Returns the cached `SessionManager`, building it on first use."""
    if self._session_manager:
      return self._session_manager
    self._session_manager = sm.SessionManager(
        local_init_op=self._scaffold.local_init_op,
        ready_op=self._scaffold.ready_op,
        ready_for_local_init_op=self._scaffold.ready_for_local_init_op,
        graph=ops.get_default_graph())
    return self._session_manager

  def create_session(self):
    """Finalizes the scaffold and waits until a ready session is available."""
    self._scaffold.finalize()
    return self._get_session_manager().wait_for_session(
        self._master, config=self._config, max_wait_secs=self._max_wait_secs)
从上面的源码中分析得到,MonitoredTrainingSession 可根据不同的角色去创建不同种类的 Session:chief 节点由 ChiefSessionCreator 类创建 session,而非 chief 的 worker 节点由 WorkerSessionCreator 类创建。后者的特殊之处在于创建时调用的是 wait_for_session(),大致意思是需要等待 chief 节点的 session 创建完成之后,才去创建属于自己节点的 session。这两种创建 session 的方式都是 SessionManager 类的方法,下面我们具体分析下 SessionManager 类:

官方针对 SessionManager 类有一个简单的例子,感觉很清楚:

prepare_session 函数可以初始化或者 restore 一个模型,需要传入 `init_op` 和 `saver`:

with tf.Graph().as_default():
  # add operations to the graph...
  # Create a SessionManager that will checkpoint the model in '/tmp/mydir'.
  sm = SessionManager()
  sess = sm.prepare_session(master, init_op, saver, checkpoint_dir)
  # Use the session to train the graph.
  while True:
    sess.run(<my_train_op>)

第二个进程可以用以下方法创建 session:wait_for_session() 的意思是需要等上面那个 session 创建好(模型就绪)之后,再创建自己的 session。

with tf.Graph().as_default():
  # ...add operations to the graph...
  # Create a SessionManager that will wait for the model to become ready.
  sm = SessionManager()
  sess = sm.wait_for_session(master)
  # Use the session to train the graph.
  while True:
    sess.run(<my_train_op>)
然后我们可以重点关注下 prepare_session 和 wait_for_session 这两个函数:

@tf_export(v1=["train.SessionManager"])
class SessionManager(object):
  """Creates sessions, either by preparing the model or by waiting for it.

  Abridged excerpt: the helpers used below (`_restore_checkpoint`,
  `_try_run_local_init_op`, `_model_ready`, `_safe_close`) are not shown.
  """

  def __init__(self,
               local_init_op=None,
               ready_op=None,
               ready_for_local_init_op=None,
               graph=None,
               recovery_wait_secs=30,
               local_init_run_options=None):
    """Stores the ops and settings used to create and check sessions.

    Args:
      local_init_op: Op run every time a new session is created.
      ready_op: Op used to check whether the model is ready.
      ready_for_local_init_op: Op used to check whether the model is ready
        for `local_init_op` to run.
      graph: Graph the sessions use; defaults to the current default graph.
      recovery_wait_secs: Seconds `wait_for_session` sleeps between attempts.
      local_init_run_options: Stored on the instance; unused in this excerpt.

    Raises:
      ValueError: If `ready_for_local_init_op` is given without a
        `local_init_op`.
    """
    # Sets default values of arguments.
    if graph is None:
      graph = ops.get_default_graph()
    self._local_init_op = local_init_op
    self._ready_op = ready_op
    self._ready_for_local_init_op = ready_for_local_init_op
    self._graph = graph
    self._recovery_wait_secs = recovery_wait_secs
    self._target = None
    self._local_init_run_options = local_init_run_options
    if ready_for_local_init_op is not None and local_init_op is None:
      raise ValueError("If you pass a ready_for_local_init_op "
                       "you must also pass a local_init_op "
                       ", ready_for_local_init_op [%s]" %
                       ready_for_local_init_op)

  def prepare_session(self,
                      master,
                      init_op=None,
                      saver=None,
                      checkpoint_dir=None,
                      checkpoint_filename_with_path=None,
                      wait_for_checkpoint=False,
                      max_wait_secs=7200,
                      config=None,
                      init_feed_dict=None,
                      init_fn=None):
    """Restores a session from a checkpoint, or initializes a fresh one.

    If a checkpoint is loaded the session is returned as restored; otherwise
    the model is initialized by running `init_op` and/or calling `init_fn`.

    Args:
      master: Master target forwarded to `_restore_checkpoint`.
      init_op: Optional op run (with `init_feed_dict`) when nothing was
        restored.
      saver: Saver forwarded to `_restore_checkpoint`.
      checkpoint_dir: Forwarded to `_restore_checkpoint`.
      checkpoint_filename_with_path: Forwarded to `_restore_checkpoint`.
      wait_for_checkpoint: Forwarded to `_restore_checkpoint`.
      max_wait_secs: Forwarded to `_restore_checkpoint`.
      config: Session config forwarded to `_restore_checkpoint`.
      init_feed_dict: Feed dict used when running `init_op`.
      init_fn: Optional callable invoked as `init_fn(sess)` after `init_op`.

    Returns:
      The prepared session.

    Raises:
      RuntimeError: If nothing was restored and none of `init_op`, `init_fn`
        or a local init op is available.
    """
    sess, is_loaded_from_checkpoint = self._restore_checkpoint(
        master,
        saver,
        checkpoint_dir=checkpoint_dir,
        checkpoint_filename_with_path=checkpoint_filename_with_path,
        wait_for_checkpoint=wait_for_checkpoint,
        max_wait_secs=max_wait_secs,
        config=config)
    if not is_loaded_from_checkpoint:
      if init_op is None and not init_fn and self._local_init_op is None:
        raise RuntimeError("Model is not initialized and no init_op or "
                           "init_fn or local_init_op was given")
      if init_op is not None:
        sess.run(init_op, feed_dict=init_feed_dict)
      if init_fn:
        init_fn(sess)

    # ... (further steps are elided in this excerpt) ...

    return sess

  def wait_for_session(self, master, config=None, max_wait_secs=float("Inf")):
    """Creates a new `Session` and waits for model to be ready.

    Each attempt opens a session, tries the local init op and checks model
    readiness; a session that is not ready is closed before sleeping
    `recovery_wait_secs` and retrying.

    Args:
      master: Master target; also stored as `self._target`.
      config: Session config.
      max_wait_secs: Maximum time to wait; `None` means wait forever.

    Returns:
      A session whose model is ready.

    Raises:
      errors.DeadlineExceededError: If there is not enough time left for
        another attempt.
    """
    self._target = master
    if max_wait_secs is None:
      max_wait_secs = float("Inf")
    timer = _CountDownTimer(max_wait_secs)
    while True:
      sess = session.Session(self._target, graph=self._graph, config=config)
      not_ready_msg = None
      not_ready_local_msg = None
      local_init_success, not_ready_local_msg = self._try_run_local_init_op(
          sess)
      if local_init_success:
        # Successful if local_init_op is None, or ready_for_local_init_op passes
        is_ready, not_ready_msg = self._model_ready(sess)
        if is_ready:
          return sess
      self._safe_close(sess)
      # Do we have enough time left to try again?
      remaining_secs_after_wait = (
          timer.secs_remaining() - self._recovery_wait_secs)
      if remaining_secs_after_wait < 0:
        raise errors.DeadlineExceededError(
            None, None,
            "Session was not ready after waiting %d secs." % (max_wait_secs,))
      logging.info("Waiting for model to be ready.  "
                   "Ready_for_local_init_op:  %s, ready: %s",
                   not_ready_local_msg, not_ready_msg)
      time.sleep(self._recovery_wait_secs)

创建完 session 之后,再包装一下,返回最终的 MonitoredSession 对象。

一个完整的 monitored session 在创建时会做的事情(按顺序):

为每个 hook 调用 hook.begin()

  • 调用 scaffold.finalize() 完成 graph

  • 创建 session

  • 通过 Scaffold 提供的初始化 op 对模型参数做初始化

  • 如果存在 checkpoint 则根据 checkpoint restore 参数

  • 启动 queue runners

  • 调用 hook.after_create_session() 函数
    当 run 函数调用时,monitored session 做的事情:

  • 调用 hook.before_run()

  • 调用 TensorFlow 中的 session.run() with merged fetches and feed_dict

  • 调用 hook.after_run()

  • 返回 session.run() 的结果

  • 如果发生 AbortedError 或者 UnavailableError,则在再次执行 run() 之前恢复或者重新初始化会话
    当 close() 函数调用时,monitored session 做的事情:

  • 调用 hook.end()

  • 关闭 queue runners 和 session

  • 如果 monitored session 被用作上下文管理器(with 语句),则抑制(不再向上抛出)表示所有输入数据已被消耗完的 OutOfRange 异常。
    最后,给大家贴一个使用 MonitoredSession 类进行分布式训练的 example:

复制代码
from __future__ import print_function, absolute_import, division
import tensorflow as tf
# Command-line flags describing the cluster and this process's role in it.
# The two *_hosts flags are comma-separated lists; main() splits them on ",".
tf.app.flags.DEFINE_string("ps_hosts", "localhost:2222", "ps hosts")
tf.app.flags.DEFINE_string("worker_hosts", "localhost:2223,localhost:2224", "worker hosts")
# job_name and task_index identify the local task inside the cluster.
tf.app.flags.DEFINE_string("job_name", "worker", "'ps' or'worker'")
tf.app.flags.DEFINE_integer("task_index", 0, "Index of task within the job")
# num_workers is only read when is_sync is set (SyncReplicasOptimizer sizes).
tf.app.flags.DEFINE_integer("num_workers", 2, "Number of workers")
tf.app.flags.DEFINE_boolean("is_sync", False, "using synchronous training or not")
FLAGS = tf.app.flags.FLAGS
def model(images):
    """Build a small fully connected MNIST classifier.

    Args:
        images: Input tensor fed to the first dense layer.

    Returns:
        The output of the final 10-unit dense layer (no activation).
    """
    hidden = images
    # Two 500-unit ReLU layers followed by a linear 10-unit output layer.
    for _ in range(2):
        hidden = tf.layers.dense(hidden, 500, activation=tf.nn.relu)
    return tf.layers.dense(hidden, 10, activation=None)
# Load MNIST via Keras and flatten every image into a 784-element float32 row.
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
x_train = x_train.reshape(-1, 784).astype('float32')
x_test = x_test.reshape(-1, 784).astype('float32')
# In-place division by 255; presumably rescales 0..255 pixel values to [0, 1].
x_train /= 255
x_test /= 255
def get_batch(image, label, batch_size=32, training=True):
    """Return `(batch_x, batch_y)` tensors that iterate over the given data.

    Args:
        image: Features, sliced along the first dimension.
        label: Labels, sliced along the first dimension.
        batch_size: Batch size; also used as the prefetch buffer size.
        training: When True the data is repeated 10 times and shuffled with a
            1000-element buffer.

    Returns:
        A `(batch_x, batch_y)` pair from a one-shot iterator.
    """
    dataset = tf.data.Dataset.from_tensor_slices((image, label))
    if training:
        dataset = dataset.repeat(10)
        dataset = dataset.shuffle(buffer_size=1000)
    dataset = dataset.batch(batch_size)
    dataset = dataset.prefetch(batch_size)
    features, labels = dataset.make_one_shot_iterator().get_next()
    return features, labels
def main(_):
    """Run the local task (parameter server or worker) of the training job.

    A "ps" task only joins the server. A "worker" task builds the model and
    trains it inside a `MonitoredTrainingSession`, with task 0 acting as
    chief. Any other job name does nothing after the server is created.

    Args:
        _: Unused argument supplied by `tf.app.run()`.
    """
    ps_hosts = FLAGS.ps_hosts.split(",")
    worker_hosts = FLAGS.worker_hosts.split(",")
    # create the cluster configured by `ps_hosts` and `worker_hosts`
    cluster = tf.train.ClusterSpec({"ps": ps_hosts, "worker": worker_hosts})
    # create a server for local task
    server = tf.train.Server(cluster, job_name=FLAGS.job_name, task_index=FLAGS.task_index)

    if FLAGS.job_name == "ps":
        server.join()  # ps hosts only join
        return
    if FLAGS.job_name != "worker":
        return

    # workers perform the operation
    is_chief = FLAGS.task_index == 0
    train_batch_x, train_batch_y = get_batch(x_train, y_train)

    # Note: tf.train.replica_device_setter automatically places the parameters
    # (Variables) on the ps hosts (default placement strategy: round-robin
    # over all ps hosts) and places the other operations on this worker.
    with tf.device(tf.train.replica_device_setter(
            worker_device="/job:worker/task:%d" % FLAGS.task_index,
            cluster=cluster)):
        logits = model(train_batch_x)
        loss = tf.reduce_mean(
            tf.nn.softmax_cross_entropy_with_logits(
                logits=logits, labels=tf.one_hot(train_batch_y, 10)))

        # The StopAtStepHook handles stopping after running given steps.
        hooks = [tf.train.StopAtStepHook(last_step=10000)]

        global_step = tf.train.get_or_create_global_step()
        optimizer = tf.train.AdamOptimizer(learning_rate=1e-04)

        if FLAGS.is_sync:
            # synchronous training
            # use tf.train.SyncReplicasOptimizer wrap optimizer
            # ref: https://www.tensorflow.org/api_docs/python/tf/train/SyncReplicasOptimizer
            optimizer = tf.train.SyncReplicasOptimizer(
                optimizer,
                replicas_to_aggregate=FLAGS.num_workers,
                total_num_replicas=FLAGS.num_workers)
            # create the hook which handles initialization and queues
            hooks.append(optimizer.make_session_run_hook(is_chief))

        train_op = optimizer.minimize(loss, global_step=global_step)

        # The MonitoredTrainingSession takes care of session initialization,
        # restoring from a checkpoint, saving to a checkpoint, and closing when done
        # or an error occurs.
        with tf.train.MonitoredTrainingSession(master=server.target,
                                               is_chief=is_chief,
                                               checkpoint_dir="./checkpoint_dir",
                                               hooks=hooks) as mon_sess:
            while not mon_sess.should_stop():
                # mon_sess.run handles AbortedError in case of preempted PS.
                _, ls, step = mon_sess.run([train_op, loss, global_step])
                if step % 100 == 0:
                    print("Train step %d, loss: %f" % (step, ls))
# Script entry point: tf.app.run() is the launcher for main().
if __name__ == "__main__":
    tf.app.run()

参考文献:

https://www.cnblogs.com/estragon/p/10034511.html
https://zhuanlan.zhihu.com/p/88876923

本文转载自 Alex-zhai 知乎账号。

原文链接: https://zhuanlan.zhihu.com/p/91608555

评论

发布