MonitoredTrainingSession 是 tensorflow 管理分布式训练中一个使用很广泛的 API,集成了一些监控训练组件,如变量的初始化、从已有 checkpoint 恢复训练、summary、log 和 checkpoint 的保存等。在早期的 tf 版本中,一般使用 tf.train.Supervisor 来管理 session,后来框架升级后,官方推荐使用 MonitoredTrainingSession。MonitoredTrainingSession 有记录日志、训练可视化、checkpoint 保存、early-stop、训练效率调优等功能。
我们直接进入主题,下面是 MonitoredTrainingSession 的源码。从注释中可以了解到,MonitoredTrainingSession 的作用可用一句话来概括:如果是 chief 节点,则负责 session 的初始化或者从已有 checkpoint 恢复 session,并且创建一些用于保存 checkpoint 和 summary 的 hooks;如果是非 chief 的 worker 节点,则需要等待 chief 节点完成 session 的初始化或恢复之后,才能创建属于自己的 session。
@tf_export(v1=['train.MonitoredTrainingSession'])
def MonitoredTrainingSession(
        master='',  # pylint: disable=invalid-name
        is_chief=True,
        checkpoint_dir=None,
        scaffold=None,
        hooks=None,
        chief_only_hooks=None,
        save_checkpoint_secs=USE_DEFAULT,
        save_summaries_steps=USE_DEFAULT,
        save_summaries_secs=USE_DEFAULT,
        config=None,
        stop_grace_period_secs=120,
        log_step_count_steps=100,
        max_wait_secs=7200,
        save_checkpoint_steps=USE_DEFAULT,
        summary_dir=None):
    """Creates a `MonitoredSession` for training.

    The chief initializes or restores the session; non-chief workers wait
    for the chief before creating their own session.

    Args:
      master: Session target to connect to (e.g. `server.target`).
      is_chief: If True, this task initializes or restores the model and
        runs `chief_only_hooks`; otherwise it waits for the chief.
      checkpoint_dir: Directory to restore variables from.
      scaffold: `Scaffold` used to gather or build supportive ops; a default
        one is created when None.
      hooks: Optional list of `SessionRunHook`s used on every task.
      chief_only_hooks: `SessionRunHook`s activated only when `is_chief`.
      save_checkpoint_secs, save_checkpoint_steps, save_summaries_steps,
        save_summaries_secs, log_step_count_steps, summary_dir: Settings for
        the default step-counter/summary/checkpoint hooks (their creation is
        elided in this excerpt).
      config: `ConfigProto` used to configure the session.
      stop_grace_period_secs: Seconds given to threads to stop after
        `close()` is called.
      max_wait_secs: Maximum seconds a non-chief waits for the session to
        become available.

    Returns:
      A `MonitoredSession` object.
    """
    scaffold = scaffold or Scaffold()
    worker_context = distribute_coordinator_context.get_current_worker_context()

    # Under a distribute coordinator the worker context decides the task role.
    if worker_context:
        return _create_monitored_session_with_worker_context(
            worker_context,
            scaffold,
            checkpoint_dir=checkpoint_dir,
            hooks=hooks,
            chief_only_hooks=chief_only_hooks,
            save_checkpoint_secs=save_checkpoint_secs,
            save_summaries_steps=save_summaries_steps,
            save_summaries_secs=save_summaries_secs,
            config=config,
            stop_grace_period_secs=stop_grace_period_secs,
            log_step_count_steps=log_step_count_steps,
            max_wait_secs=max_wait_secs,
            save_checkpoint_steps=save_checkpoint_steps,
            summary_dir=summary_dir)

    # Non-chief: wait for the chief to initialize/restore the model.
    # chief_only_hooks are intentionally ignored on this path.
    if not is_chief:
        session_creator = WorkerSessionCreator(
            scaffold=scaffold,
            master=master,
            config=config,
            max_wait_secs=max_wait_secs)
        return MonitoredSession(
            session_creator=session_creator,
            hooks=hooks or [],
            stop_grace_period_secs=stop_grace_period_secs)

    # Chief: initializes or restores the session and runs chief-only hooks.
    all_hooks = []
    if chief_only_hooks:
        all_hooks.extend(chief_only_hooks)
    session_creator = ChiefSessionCreator(
        scaffold=scaffold,
        checkpoint_dir=checkpoint_dir,
        master=master,
        config=config)

    # (Elided in this excerpt: the default hooks that save summaries and
    # checkpoints are appended to all_hooks here.)

    if hooks:
        all_hooks.extend(hooks)
    return MonitoredSession(
        session_creator=session_creator,
        hooks=all_hooks,
        stop_grace_period_secs=stop_grace_period_secs)
我们首先解释下参数:
is_chief:用于分布式训练中,判断当前节点是否是 chief。如果为 True,它将负责初始化或恢复底层 TensorFlow 会话;如果为 False,它将等待 chief 完成会话的初始化或恢复。
checkpoint_dir:一个字符串,可选。指定用于恢复变量的 checkpoint 目录路径。
scaffold:用于收集或建立支持性 op 的脚手架。如果未指定,则会创建一个默认的 scaffold。它用于完成计算图的构建(finalize)。
hooks:SessionRunHook 对象的可选列表。可自己定义 SessionRunHook 对象,也可用已经预定义好的 SessionRunHook 对象,如:tf.train.StopAtStepHook() 设置停止训练的条件;tf.train.NanTensorHook(loss):如果 loss 的值为 NaN 则停止训练。
chief_only_hooks:SessionRunHook 对象列表。如果 is_chief == True,则激活这些 hook,否则忽略。
save_checkpoint_secs:用默认的 checkpoint saver 保存 checkpoint 的频率(以秒为单位)。如果 save_checkpoint_secs 设置为 None,不保存 checkpoint。
save_summaries_steps:使用默认 summaries saver 将摘要写入磁盘的频率(以全局步数表示)。如果 save_summaries_steps 和 save_summaries_secs 都设置为 None,则不使用默认的 summaries saver 保存 summaries。默认为 100
save_summaries_secs:使用默认 summaries saver 将摘要写入磁盘的频率(以秒为单位)。如果 save_summaries_steps 和 save_summaries_secs 都设置为 None,则不使用默认的摘要保存。默认未启用。
config:用于配置会话的 tf.ConfigProto 实例。它会作为 tf.Session 构造函数的 config 参数。
stop_grace_period_secs:调用 close() 后,留给线程停止的宽限时间(秒)。
log_step_count_steps:以全局步数计,记录 global_step/sec(每秒训练步数)日志的频率。
实例化后可得到一个 MonitoredSession 对象,可当作普通 session 使用。
然后我们仔细分解下代码:
def _create_monitored_session_with_worker_context(
        worker_context,
        scaffold,
        checkpoint_dir=None,
        hooks=None,
        chief_only_hooks=None,
        save_checkpoint_secs=None,
        save_summaries_steps=None,
        save_summaries_secs=None,
        config=None,
        stop_grace_period_secs=120,
        log_step_count_steps=100,
        max_wait_secs=7200,
        save_checkpoint_steps=None,
        summary_dir=None):
    """Creates a `MonitoredSession` whose role is decided by `worker_context`.

    Used by `MonitoredTrainingSession` when a distribute-coordinator worker
    context is active. `worker_context.session_creator(...)` returns either a
    chief or a worker session creator.

    Returns:
      A `MonitoredSession` object.
    """
    all_hooks = []
    if hooks:
        all_hooks.extend(hooks)
    # (Elided in this excerpt: chief-only hooks and the default
    # summary/checkpoint hooks are also collected into all_hooks.)

    logging.info('all_hooks %r', all_hooks)

    # Create the session creator for this task's role.
    session_creator = worker_context.session_creator(
        scaffold,
        config=config,
        checkpoint_dir=checkpoint_dir,
        max_wait_secs=max_wait_secs)
    return MonitoredSession(
        session_creator=session_creator,
        hooks=all_hooks,
        stop_grace_period_secs=stop_grace_period_secs)
# Body of the worker context's `session_creator` method.
def session_creator(self,
                    scaffold=None,
                    config=None,
                    checkpoint_dir=None,
                    checkpoint_filename_with_path=None,
                    max_wait_secs=7200):
    """Returns a session creator for this task's master target and config.

    If `config` is given it is deep-copied and `self._session_config` is
    merged into the copy; otherwise `self._session_config` is used as-is.
    When there is no strategy, or the strategy's `experimental_should_init`
    is true, a `ChiefSessionCreator` is returned; otherwise a
    `WorkerSessionCreator`, which waits for the model to become ready.
    """
    if config:
        # Copy first so the caller's config object is not mutated.
        session_config = copy.deepcopy(config)
        session_config.MergeFrom(self._session_config)
    else:
        session_config = self._session_config

    # Pick the creator according to this task's role.
    # NOTE: the log lines print the caller's `config`, not `session_config`.
    if not self._strategy or self._strategy.extended.experimental_should_init:
        logging.info("Creating chief session creator with config: %r", config)
        return monitored_session.ChiefSessionCreator(
            scaffold,
            master=self.master_target,
            config=session_config,
            checkpoint_dir=checkpoint_dir,
            checkpoint_filename_with_path=checkpoint_filename_with_path)
    else:
        logging.info("Creating worker session creator with config: %r", config)
        return monitored_session.WorkerSessionCreator(
            scaffold,
            master=self.master_target,
            config=session_config,
            max_wait_secs=max_wait_secs)
# ChiefSessionCreator
@tf_export(v1=['train.ChiefSessionCreator'])
class ChiefSessionCreator(SessionCreator):
    """Creates a tf.compat.v1.Session for a chief."""

    def __init__(self,
                 scaffold=None,
                 master='',
                 config=None,
                 checkpoint_dir=None,
                 checkpoint_filename_with_path=None):
        """Initializes a chief session creator.

        Args:
          scaffold: `Scaffold` used to gather or build supportive ops; a
            default `Scaffold()` is created when None.
          master: Session target passed to `prepare_session`.
          config: `ConfigProto` passed to `prepare_session`.
          checkpoint_dir: Directory to restore variables from.
          checkpoint_filename_with_path: Full path of a checkpoint file to
            restore from.
        """
        self._checkpoint_dir = checkpoint_dir
        self._checkpoint_filename_with_path = checkpoint_filename_with_path
        self._scaffold = scaffold or Scaffold()
        self._session_manager = None  # created lazily by _get_session_manager
        self._master = master
        self._config = config

    def _get_session_manager(self):
        """Returns the `SessionManager`, creating and caching it on first use."""
        if self._session_manager:
            return self._session_manager

        self._session_manager = sm.SessionManager(
            local_init_op=self._scaffold.local_init_op,
            ready_op=self._scaffold.ready_op,
            ready_for_local_init_op=self._scaffold.ready_for_local_init_op,
            graph=ops.get_default_graph())
        return self._session_manager

    def create_session(self):
        """Finalizes the scaffold, then initializes or restores a session.

        Delegates to `SessionManager.prepare_session`, which restores from a
        checkpoint if one exists and otherwise runs `init_op` / `init_fn`.
        """
        self._scaffold.finalize()
        return self._get_session_manager().prepare_session(
            self._master,
            saver=self._scaffold.saver,
            checkpoint_dir=self._checkpoint_dir,
            checkpoint_filename_with_path=self._checkpoint_filename_with_path,
            config=self._config,
            init_op=self._scaffold.init_op,
            init_feed_dict=self._scaffold.init_feed_dict,
            init_fn=self._scaffold.init_fn)
# WorkerSessionCreator
@tf_export(v1=['train.WorkerSessionCreator'])
class WorkerSessionCreator(SessionCreator):
    """Creates a tf.compat.v1.Session for a worker."""

    def __init__(self,
                 scaffold=None,
                 master='',
                 config=None,
                 max_wait_secs=30 * 60):
        """Initializes a worker session creator.

        Args:
          scaffold: `Scaffold` used to gather or build supportive ops; a
            default `Scaffold()` is created when None.
          master: Session target passed to `wait_for_session`.
          config: `ConfigProto` passed to `wait_for_session`.
          max_wait_secs: Maximum time to wait for the session to become
            available.
        """
        self._scaffold = scaffold or Scaffold()
        self._session_manager = None  # created lazily by _get_session_manager
        self._master = master
        self._config = config
        self._max_wait_secs = max_wait_secs

    def _get_session_manager(self):
        """Returns the `SessionManager`, creating and caching it on first use."""
        if self._session_manager:
            return self._session_manager

        self._session_manager = sm.SessionManager(
            local_init_op=self._scaffold.local_init_op,
            ready_op=self._scaffold.ready_op,
            ready_for_local_init_op=self._scaffold.ready_for_local_init_op,
            graph=ops.get_default_graph())
        return self._session_manager

    def create_session(self):
        """Finalizes the scaffold and waits until the model is ready.

        Unlike the chief, a worker does not initialize or restore variables;
        `wait_for_session` blocks until another task has done so.
        """
        self._scaffold.finalize()
        return self._get_session_manager().wait_for_session(
            self._master, config=self._config, max_wait_secs=self._max_wait_secs)
从上面的源码中分析得到,MonitoredTrainingSession 会根据不同的角色去创建不同种类的 Session:chief 节点由 ChiefSessionCreator 类创建 session,而非 chief 的 worker 节点由 WorkerSessionCreator 类创建。后者的特殊之处在于创建时调用的是 wait_for_session(),大致意思是需要等待 chief 节点把模型初始化(或恢复)完成之后,才去创建属于自己节点的 session。两者创建 session 时调用的都是 SessionManager 类的方法,下面我们具体分析下 SessionManager 类:
官方针对 SessionManager 类有一个简单的例子,感觉很清楚:
# prepare_session() can initialize or restore a model; it needs `init_op`
# and `saver`.
with tf.Graph().as_default():
    # ...add operations to the graph...
    # Create a SessionManager that will checkpoint the model in '/tmp/mydir'.
    sm = SessionManager()
    sess = sm.prepare_session(master, init_op=init_op, saver=saver,
                              checkpoint_dir=checkpoint_dir)
    # Use the session to train the graph (`my_train_op` is a placeholder).
    while True:
        sess.run(my_train_op)
第二个进程可以用以下方式等待模型就绪:wait_for_session() 会等待上面的进程把模型初始化好之后,再创建自己的 session。
# Second process: wait for the model to become ready, then use its own session.
with tf.Graph().as_default():
    # ...add operations to the graph...
    # Create a SessionManager that will wait for the model to become ready.
    sm = SessionManager()
    sess = sm.wait_for_session(master)
    # Use the session to train the graph (`my_train_op` is a placeholder).
    while True:
        sess.run(my_train_op)
然后我们可以重点关注下 prepare_session 和 wait_for_session 这两个函数:
@tf_export(v1=["train.SessionManager"])
class SessionManager(object):
    """Creates sessions, restoring from checkpoints or waiting for readiness.

    Excerpt: helper methods such as `_restore_checkpoint`,
    `_try_run_local_init_op`, `_model_ready` and `_safe_close` are not shown.
    """

    def __init__(self,
                 local_init_op=None,
                 ready_op=None,
                 ready_for_local_init_op=None,
                 graph=None,
                 recovery_wait_secs=30,
                 local_init_run_options=None):
        """Creates a SessionManager.

        Args:
          local_init_op: Op run every time a new session is created.
          ready_op: Op used to check whether the model is ready.
          ready_for_local_init_op: Op used to check whether the model is
            ready to run `local_init_op`. Requires `local_init_op`.
          graph: Graph to use; defaults to the current default graph.
          recovery_wait_secs: Seconds to sleep between readiness retries in
            `wait_for_session`.
          local_init_run_options: Stored for running `local_init_op` (its use
            is not shown in this excerpt).

        Raises:
          ValueError: If `ready_for_local_init_op` is given without
            `local_init_op`.
        """
        # Sets default values of arguments.
        if graph is None:
            graph = ops.get_default_graph()
        self._local_init_op = local_init_op
        self._ready_op = ready_op
        self._ready_for_local_init_op = ready_for_local_init_op
        self._graph = graph
        self._recovery_wait_secs = recovery_wait_secs
        self._target = None
        self._local_init_run_options = local_init_run_options
        if ready_for_local_init_op is not None and local_init_op is None:
            raise ValueError("If you pass a ready_for_local_init_op "
                             "you must also pass a local_init_op "
                             ", ready_for_local_init_op [%s]" %
                             ready_for_local_init_op)

    def prepare_session(self,
                        master,
                        init_op=None,
                        saver=None,
                        checkpoint_dir=None,
                        checkpoint_filename_with_path=None,
                        wait_for_checkpoint=False,
                        max_wait_secs=7200,
                        config=None,
                        init_feed_dict=None,
                        init_fn=None):
        """Creates a session, restoring it from a checkpoint when one exists.

        If nothing was restored, the model is initialized by running
        `init_op` (fed with `init_feed_dict`) and/or calling `init_fn(sess)`.

        Raises:
          RuntimeError: If nothing was restored and none of `init_op`,
            `init_fn` or `local_init_op` is available.
        """
        sess, is_loaded_from_checkpoint = self._restore_checkpoint(
            master,
            saver,
            checkpoint_dir=checkpoint_dir,
            checkpoint_filename_with_path=checkpoint_filename_with_path,
            wait_for_checkpoint=wait_for_checkpoint,
            max_wait_secs=max_wait_secs,
            config=config)
        if not is_loaded_from_checkpoint:
            if init_op is None and not init_fn and self._local_init_op is None:
                raise RuntimeError("Model is not initialized and no init_op or "
                                   "init_fn or local_init_op was given")
            if init_op is not None:
                sess.run(init_op, feed_dict=init_feed_dict)
            if init_fn:
                init_fn(sess)

        # ... (remainder of the method elided in this excerpt) ...

        return sess

    def wait_for_session(self, master, config=None, max_wait_secs=float("Inf")):
        """Creates a new `Session` and waits for model to be ready.

        Each attempt creates a fresh session, tries `local_init_op`, then
        checks readiness. A session that is not ready is closed, and the loop
        sleeps `recovery_wait_secs` before retrying.

        Raises:
          errors.DeadlineExceededError: If another wait would exceed
            `max_wait_secs`.
        """
        self._target = master
        if max_wait_secs is None:
            max_wait_secs = float("Inf")
        timer = _CountDownTimer(max_wait_secs)

        while True:
            sess = session.Session(self._target, graph=self._graph, config=config)
            not_ready_msg = None
            not_ready_local_msg = None
            local_init_success, not_ready_local_msg = self._try_run_local_init_op(
                sess)
            if local_init_success:
                # Successful if local_init_op is None, or ready_for_local_init_op passes
                is_ready, not_ready_msg = self._model_ready(sess)
                if is_ready:
                    return sess

            self._safe_close(sess)

            # Do we have enough time left to try again? (value is in seconds)
            remaining_secs_after_wait = (
                timer.secs_remaining() - self._recovery_wait_secs)
            if remaining_secs_after_wait < 0:
                raise errors.DeadlineExceededError(
                    None, None,
                    "Session was not ready after waiting %d secs." %
                    (max_wait_secs,))

            logging.info("Waiting for model to be ready. "
                         "Ready_for_local_init_op: %s, ready: %s",
                         not_ready_local_msg, not_ready_msg)
            time.sleep(self._recovery_wait_secs)
创建完 session 之后,再包装一下返回最终的 MonitoredSession 类,
一个完整的 monitored session 在创建时会(按顺序)做以下事情:
我们直接进入主题,下面是 MonitoredTrainingSession 的源码。从注释中可以了解到,MonitoredTrainingSession 的作用可用一句话来概括:如果是 chief 节点,则负责 session 的初始化或者从已有 checkpoint 恢复 session,并且创建一些用于保存 checkpoint 和 summary 的 hooks;如果是非 chief 的 worker 节点,则需要等待 chief 节点完成 session 的初始化或恢复之后,才能创建属于自己的 session。
@tf_export(v1=['train.MonitoredTrainingSession'])
def MonitoredTrainingSession(
        master='',  # pylint: disable=invalid-name
        is_chief=True,
        checkpoint_dir=None,
        scaffold=None,
        hooks=None,
        chief_only_hooks=None,
        save_checkpoint_secs=USE_DEFAULT,
        save_summaries_steps=USE_DEFAULT,
        save_summaries_secs=USE_DEFAULT,
        config=None,
        stop_grace_period_secs=120,
        log_step_count_steps=100,
        max_wait_secs=7200,
        save_checkpoint_steps=USE_DEFAULT,
        summary_dir=None):
    """Creates a `MonitoredSession` for training.

    The chief initializes or restores the session; non-chief workers wait
    for the chief before creating their own session.

    Args:
      master: Session target to connect to (e.g. `server.target`).
      is_chief: If True, this task initializes or restores the model and
        runs `chief_only_hooks`; otherwise it waits for the chief.
      checkpoint_dir: Directory to restore variables from.
      scaffold: `Scaffold` used to gather or build supportive ops; a default
        one is created when None.
      hooks: Optional list of `SessionRunHook`s used on every task.
      chief_only_hooks: `SessionRunHook`s activated only when `is_chief`.
      save_checkpoint_secs, save_checkpoint_steps, save_summaries_steps,
        save_summaries_secs, log_step_count_steps, summary_dir: Settings for
        the default step-counter/summary/checkpoint hooks (their creation is
        elided in this excerpt).
      config: `ConfigProto` used to configure the session.
      stop_grace_period_secs: Seconds given to threads to stop after
        `close()` is called.
      max_wait_secs: Maximum seconds a non-chief waits for the session to
        become available.

    Returns:
      A `MonitoredSession` object.
    """
    scaffold = scaffold or Scaffold()
    worker_context = distribute_coordinator_context.get_current_worker_context()

    # Under a distribute coordinator the worker context decides the task role.
    if worker_context:
        return _create_monitored_session_with_worker_context(
            worker_context,
            scaffold,
            checkpoint_dir=checkpoint_dir,
            hooks=hooks,
            chief_only_hooks=chief_only_hooks,
            save_checkpoint_secs=save_checkpoint_secs,
            save_summaries_steps=save_summaries_steps,
            save_summaries_secs=save_summaries_secs,
            config=config,
            stop_grace_period_secs=stop_grace_period_secs,
            log_step_count_steps=log_step_count_steps,
            max_wait_secs=max_wait_secs,
            save_checkpoint_steps=save_checkpoint_steps,
            summary_dir=summary_dir)

    # Non-chief: wait for the chief to initialize/restore the model.
    # chief_only_hooks are intentionally ignored on this path.
    if not is_chief:
        session_creator = WorkerSessionCreator(
            scaffold=scaffold,
            master=master,
            config=config,
            max_wait_secs=max_wait_secs)
        return MonitoredSession(
            session_creator=session_creator,
            hooks=hooks or [],
            stop_grace_period_secs=stop_grace_period_secs)

    # Chief: initializes or restores the session and runs chief-only hooks.
    all_hooks = []
    if chief_only_hooks:
        all_hooks.extend(chief_only_hooks)
    session_creator = ChiefSessionCreator(
        scaffold=scaffold,
        checkpoint_dir=checkpoint_dir,
        master=master,
        config=config)

    # (Elided in this excerpt: the default hooks that save summaries and
    # checkpoints are appended to all_hooks here.)

    if hooks:
        all_hooks.extend(hooks)
    return MonitoredSession(
        session_creator=session_creator,
        hooks=all_hooks,
        stop_grace_period_secs=stop_grace_period_secs)
我们首先解释下参数:
is_chief:用于分布式训练中,判断当前节点是否是 chief。如果为 True,它将负责初始化或恢复底层 TensorFlow 会话;如果为 False,它将等待 chief 完成会话的初始化或恢复。
checkpoint_dir:一个字符串,可选。指定用于恢复变量的 checkpoint 目录路径。
scaffold:用于收集或建立支持性 op 的脚手架。如果未指定,则会创建一个默认的 scaffold。它用于完成计算图的构建(finalize)。
hooks:SessionRunHook 对象的可选列表。可自己定义 SessionRunHook 对象,也可用已经预定义好的 SessionRunHook 对象,如:tf.train.StopAtStepHook() 设置停止训练的条件;tf.train.NanTensorHook(loss):如果 loss 的值为 NaN 则停止训练。
chief_only_hooks:SessionRunHook 对象列表。如果 is_chief == True,则激活这些 hook,否则忽略。
save_checkpoint_secs:用默认的 checkpoint saver 保存 checkpoint 的频率(以秒为单位)。如果 save_checkpoint_secs 设置为 None,不保存 checkpoint。
save_summaries_steps:使用默认 summaries saver 将摘要写入磁盘的频率(以全局步数表示)。如果 save_summaries_steps 和 save_summaries_secs 都设置为 None,则不使用默认的 summaries saver 保存 summaries。默认为 100
save_summaries_secs:使用默认 summaries saver 将摘要写入磁盘的频率(以秒为单位)。如果 save_summaries_steps 和 save_summaries_secs 都设置为 None,则不使用默认的摘要保存。默认未启用。
config:用于配置会话的 tf.ConfigProto 实例。它会作为 tf.Session 构造函数的 config 参数。
stop_grace_period_secs:调用 close() 后,留给线程停止的宽限时间(秒)。
log_step_count_steps:以全局步数计,记录 global_step/sec(每秒训练步数)日志的频率。
实例化后可得到一个 MonitoredSession 对象,可当作普通 session 使用。
然后我们仔细分解下代码:
def _create_monitored_session_with_worker_context(
        worker_context,
        scaffold,
        checkpoint_dir=None,
        hooks=None,
        chief_only_hooks=None,
        save_checkpoint_secs=None,
        save_summaries_steps=None,
        save_summaries_secs=None,
        config=None,
        stop_grace_period_secs=120,
        log_step_count_steps=100,
        max_wait_secs=7200,
        save_checkpoint_steps=None,
        summary_dir=None):
    """Creates a `MonitoredSession` whose role is decided by `worker_context`.

    Used by `MonitoredTrainingSession` when a distribute-coordinator worker
    context is active. `worker_context.session_creator(...)` returns either a
    chief or a worker session creator.

    Returns:
      A `MonitoredSession` object.
    """
    all_hooks = []
    if hooks:
        all_hooks.extend(hooks)
    # (Elided in this excerpt: chief-only hooks and the default
    # summary/checkpoint hooks are also collected into all_hooks.)

    logging.info('all_hooks %r', all_hooks)

    # Create the session creator for this task's role.
    session_creator = worker_context.session_creator(
        scaffold,
        config=config,
        checkpoint_dir=checkpoint_dir,
        max_wait_secs=max_wait_secs)
    return MonitoredSession(
        session_creator=session_creator,
        hooks=all_hooks,
        stop_grace_period_secs=stop_grace_period_secs)
# Body of the worker context's `session_creator` method.
def session_creator(self,
                    scaffold=None,
                    config=None,
                    checkpoint_dir=None,
                    checkpoint_filename_with_path=None,
                    max_wait_secs=7200):
    """Returns a session creator for this task's master target and config.

    If `config` is given it is deep-copied and `self._session_config` is
    merged into the copy; otherwise `self._session_config` is used as-is.
    When there is no strategy, or the strategy's `experimental_should_init`
    is true, a `ChiefSessionCreator` is returned; otherwise a
    `WorkerSessionCreator`, which waits for the model to become ready.
    """
    if config:
        # Copy first so the caller's config object is not mutated.
        session_config = copy.deepcopy(config)
        session_config.MergeFrom(self._session_config)
    else:
        session_config = self._session_config

    # Pick the creator according to this task's role.
    # NOTE: the log lines print the caller's `config`, not `session_config`.
    if not self._strategy or self._strategy.extended.experimental_should_init:
        logging.info("Creating chief session creator with config: %r", config)
        return monitored_session.ChiefSessionCreator(
            scaffold,
            master=self.master_target,
            config=session_config,
            checkpoint_dir=checkpoint_dir,
            checkpoint_filename_with_path=checkpoint_filename_with_path)
    else:
        logging.info("Creating worker session creator with config: %r", config)
        return monitored_session.WorkerSessionCreator(
            scaffold,
            master=self.master_target,
            config=session_config,
            max_wait_secs=max_wait_secs)
# ChiefSessionCreator
@tf_export(v1=['train.ChiefSessionCreator'])
class ChiefSessionCreator(SessionCreator):
    """Creates a tf.compat.v1.Session for a chief."""

    def __init__(self,
                 scaffold=None,
                 master='',
                 config=None,
                 checkpoint_dir=None,
                 checkpoint_filename_with_path=None):
        """Initializes a chief session creator.

        Args:
          scaffold: `Scaffold` used to gather or build supportive ops; a
            default `Scaffold()` is created when None.
          master: Session target passed to `prepare_session`.
          config: `ConfigProto` passed to `prepare_session`.
          checkpoint_dir: Directory to restore variables from.
          checkpoint_filename_with_path: Full path of a checkpoint file to
            restore from.
        """
        self._checkpoint_dir = checkpoint_dir
        self._checkpoint_filename_with_path = checkpoint_filename_with_path
        self._scaffold = scaffold or Scaffold()
        self._session_manager = None  # created lazily by _get_session_manager
        self._master = master
        self._config = config

    def _get_session_manager(self):
        """Returns the `SessionManager`, creating and caching it on first use."""
        if self._session_manager:
            return self._session_manager

        self._session_manager = sm.SessionManager(
            local_init_op=self._scaffold.local_init_op,
            ready_op=self._scaffold.ready_op,
            ready_for_local_init_op=self._scaffold.ready_for_local_init_op,
            graph=ops.get_default_graph())
        return self._session_manager

    def create_session(self):
        """Finalizes the scaffold, then initializes or restores a session.

        Delegates to `SessionManager.prepare_session`, which restores from a
        checkpoint if one exists and otherwise runs `init_op` / `init_fn`.
        """
        self._scaffold.finalize()
        return self._get_session_manager().prepare_session(
            self._master,
            saver=self._scaffold.saver,
            checkpoint_dir=self._checkpoint_dir,
            checkpoint_filename_with_path=self._checkpoint_filename_with_path,
            config=self._config,
            init_op=self._scaffold.init_op,
            init_feed_dict=self._scaffold.init_feed_dict,
            init_fn=self._scaffold.init_fn)
# WorkerSessionCreator
@tf_export(v1=['train.WorkerSessionCreator'])
class WorkerSessionCreator(SessionCreator):
    """Creates a tf.compat.v1.Session for a worker."""

    def __init__(self,
                 scaffold=None,
                 master='',
                 config=None,
                 max_wait_secs=30 * 60):
        """Initializes a worker session creator.

        Args:
          scaffold: `Scaffold` used to gather or build supportive ops; a
            default `Scaffold()` is created when None.
          master: Session target passed to `wait_for_session`.
          config: `ConfigProto` passed to `wait_for_session`.
          max_wait_secs: Maximum time to wait for the session to become
            available.
        """
        self._scaffold = scaffold or Scaffold()
        self._session_manager = None  # created lazily by _get_session_manager
        self._master = master
        self._config = config
        self._max_wait_secs = max_wait_secs

    def _get_session_manager(self):
        """Returns the `SessionManager`, creating and caching it on first use."""
        if self._session_manager:
            return self._session_manager

        self._session_manager = sm.SessionManager(
            local_init_op=self._scaffold.local_init_op,
            ready_op=self._scaffold.ready_op,
            ready_for_local_init_op=self._scaffold.ready_for_local_init_op,
            graph=ops.get_default_graph())
        return self._session_manager

    def create_session(self):
        """Finalizes the scaffold and waits until the model is ready.

        Unlike the chief, a worker does not initialize or restore variables;
        `wait_for_session` blocks until another task has done so.
        """
        self._scaffold.finalize()
        return self._get_session_manager().wait_for_session(
            self._master, config=self._config, max_wait_secs=self._max_wait_secs)
从上面的源码中分析得到,MonitoredTrainingSession 会根据不同的角色去创建不同种类的 Session:chief 节点由 ChiefSessionCreator 类创建 session,而非 chief 的 worker 节点由 WorkerSessionCreator 类创建。后者的特殊之处在于创建时调用的是 wait_for_session(),大致意思是需要等待 chief 节点把模型初始化(或恢复)完成之后,才去创建属于自己节点的 session。两者创建 session 时调用的都是 SessionManager 类的方法,下面我们具体分析下 SessionManager 类:
官方针对 SessionManager 类有一个简单的例子,感觉很清楚:
# prepare_session() can initialize or restore a model; it needs `init_op`
# and `saver`.
with tf.Graph().as_default():
    # ...add operations to the graph...
    # Create a SessionManager that will checkpoint the model in '/tmp/mydir'.
    sm = SessionManager()
    sess = sm.prepare_session(master, init_op=init_op, saver=saver,
                              checkpoint_dir=checkpoint_dir)
    # Use the session to train the graph (`my_train_op` is a placeholder).
    while True:
        sess.run(my_train_op)
第二个进程可以用以下方式等待模型就绪:wait_for_session() 会等待上面的进程把模型初始化好之后,再创建自己的 session。
with tf.Graph().as_default():
    # ...add operations to the graph...
    # Create a SessionManager that will wait for the model to become ready.
    sm = SessionManager()
    sess = sm.wait_for_session(master)
    # Use the session to train the graph (`my_train_op` is a placeholder).
    while True:
        sess.run(my_train_op)
然后我们可以重点关注下 prepare_session 和 wait_for_session 这两个函数:
@tf_export(v1=["train.SessionManager"])
class SessionManager(object):
    """Creates sessions, restoring from checkpoints or waiting for readiness.

    Excerpt: helper methods such as `_restore_checkpoint`,
    `_try_run_local_init_op`, `_model_ready` and `_safe_close` are not shown.
    """

    def __init__(self,
                 local_init_op=None,
                 ready_op=None,
                 ready_for_local_init_op=None,
                 graph=None,
                 recovery_wait_secs=30,
                 local_init_run_options=None):
        """Creates a SessionManager.

        Args:
          local_init_op: Op run every time a new session is created.
          ready_op: Op used to check whether the model is ready.
          ready_for_local_init_op: Op used to check whether the model is
            ready to run `local_init_op`. Requires `local_init_op`.
          graph: Graph to use; defaults to the current default graph.
          recovery_wait_secs: Seconds to sleep between readiness retries in
            `wait_for_session`.
          local_init_run_options: Stored for running `local_init_op` (its use
            is not shown in this excerpt).

        Raises:
          ValueError: If `ready_for_local_init_op` is given without
            `local_init_op`.
        """
        # Sets default values of arguments.
        if graph is None:
            graph = ops.get_default_graph()
        self._local_init_op = local_init_op
        self._ready_op = ready_op
        self._ready_for_local_init_op = ready_for_local_init_op
        self._graph = graph
        self._recovery_wait_secs = recovery_wait_secs
        self._target = None
        self._local_init_run_options = local_init_run_options
        if ready_for_local_init_op is not None and local_init_op is None:
            raise ValueError("If you pass a ready_for_local_init_op "
                             "you must also pass a local_init_op "
                             ", ready_for_local_init_op [%s]" %
                             ready_for_local_init_op)

    def prepare_session(self,
                        master,
                        init_op=None,
                        saver=None,
                        checkpoint_dir=None,
                        checkpoint_filename_with_path=None,
                        wait_for_checkpoint=False,
                        max_wait_secs=7200,
                        config=None,
                        init_feed_dict=None,
                        init_fn=None):
        """Creates a session, restoring it from a checkpoint when one exists.

        If nothing was restored, the model is initialized by running
        `init_op` (fed with `init_feed_dict`) and/or calling `init_fn(sess)`.

        Raises:
          RuntimeError: If nothing was restored and none of `init_op`,
            `init_fn` or `local_init_op` is available.
        """
        sess, is_loaded_from_checkpoint = self._restore_checkpoint(
            master,
            saver,
            checkpoint_dir=checkpoint_dir,
            checkpoint_filename_with_path=checkpoint_filename_with_path,
            wait_for_checkpoint=wait_for_checkpoint,
            max_wait_secs=max_wait_secs,
            config=config)
        if not is_loaded_from_checkpoint:
            if init_op is None and not init_fn and self._local_init_op is None:
                raise RuntimeError("Model is not initialized and no init_op or "
                                   "init_fn or local_init_op was given")
            if init_op is not None:
                sess.run(init_op, feed_dict=init_feed_dict)
            if init_fn:
                init_fn(sess)

        # ... (remainder of the method elided in this excerpt) ...

        return sess

    def wait_for_session(self, master, config=None, max_wait_secs=float("Inf")):
        """Creates a new `Session` and waits for model to be ready.

        Each attempt creates a fresh session, tries `local_init_op`, then
        checks readiness. A session that is not ready is closed, and the loop
        sleeps `recovery_wait_secs` before retrying.

        Raises:
          errors.DeadlineExceededError: If another wait would exceed
            `max_wait_secs`.
        """
        self._target = master
        if max_wait_secs is None:
            max_wait_secs = float("Inf")
        timer = _CountDownTimer(max_wait_secs)

        while True:
            sess = session.Session(self._target, graph=self._graph, config=config)
            not_ready_msg = None
            not_ready_local_msg = None
            local_init_success, not_ready_local_msg = self._try_run_local_init_op(
                sess)
            if local_init_success:
                # Successful if local_init_op is None, or ready_for_local_init_op passes
                is_ready, not_ready_msg = self._model_ready(sess)
                if is_ready:
                    return sess

            self._safe_close(sess)

            # Do we have enough time left to try again? (value is in seconds)
            remaining_secs_after_wait = (
                timer.secs_remaining() - self._recovery_wait_secs)
            if remaining_secs_after_wait < 0:
                raise errors.DeadlineExceededError(
                    None, None,
                    "Session was not ready after waiting %d secs." %
                    (max_wait_secs,))

            logging.info("Waiting for model to be ready. "
                         "Ready_for_local_init_op: %s, ready: %s",
                         not_ready_local_msg, not_ready_msg)
            time.sleep(self._recovery_wait_secs)
创建完 session 之后,再包装一下返回最终的 MonitoredSession 类,
一个完整的 monitored session 在创建时会(按顺序)做以下事情:
为每个 hook 调用 hook.begin()
调用 scaffold.finalize()完成 graph
创建 session
通过 Scaffold 初始化模型参数
如果存在 checkpoint 则根据 checkpoint restore 参数
启动 queue runners
调用 hook.after_create_session()函数
当 run 函数调用时,monitored session 做的事情:
调用 hook.before_run()
使用合并后的 fetches 和 feed_dict 调用 TensorFlow 的 session.run()
调用 hook.after_run()
返回 session.run()的结果
如果发生 AbortedError 或者 UnavailableError,则在再次执行 run()之前恢复或者重新初始化会话
当 close()函数调用时,monitored session 做的事情:
调用 hook.end()
关闭 queue runners 和 session
如果 monitored session 被用作上下文管理器(with 语句),则会抑制(忽略)表示所有输入数据已被消耗完的 OutOfRange 异常,而不是将其抛出。
最后,给大家贴一个使用 MonitoredTrainingSession 进行分布式训练的 example:
from __future__ import print_function, absolute_import, division
import tensorflow as tf
# Command-line flags describing the cluster layout and this task's role.
tf.app.flags.DEFINE_string("ps_hosts", "localhost:2222", "ps hosts")
tf.app.flags.DEFINE_string("worker_hosts", "localhost:2223,localhost:2224",
                           "worker hosts")
tf.app.flags.DEFINE_string("job_name", "worker", "'ps' or'worker'")
tf.app.flags.DEFINE_integer("task_index", 0, "Index of task within the job")
tf.app.flags.DEFINE_integer("num_workers", 2, "Number of workers")
tf.app.flags.DEFINE_boolean("is_sync", False,
                            "using synchronous training or not")

FLAGS = tf.app.flags.FLAGS
def model(images):
    """Define a simple mnist classifier.

    Two 500-unit ReLU dense layers followed by a 10-unit linear layer.

    Args:
      images: Batch of flattened input images.

    Returns:
      Unnormalized logits for the 10 classes.
    """
    net = tf.layers.dense(images, 500, activation=tf.nn.relu)
    net = tf.layers.dense(net, 500, activation=tf.nn.relu)
    net = tf.layers.dense(net, 10, activation=None)
    return net
# Load MNIST and flatten each image into a 784-dim float32 vector.
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
x_train = x_train.reshape(-1, 784).astype('float32')
x_test = x_test.reshape(-1, 784).astype('float32')
# Scale pixel values (presumably 8-bit, 0..255) into [0, 1].
x_train /= 255
x_test /= 255
def get_batch(image, label, batch_size=32, training=True):
    """Builds an input pipeline and returns its next (images, labels) batch.

    In training mode the data is repeated 10 times and shuffled with a
    1000-element buffer before batching.

    Returns:
      A `(batch_x, batch_y)` pair of tensors from a one-shot iterator.
    """
    dataset = tf.data.Dataset.from_tensor_slices((image, label))
    if training:
        dataset = dataset.repeat(10).shuffle(buffer_size=1000)
    dataset = dataset.batch(batch_size).prefetch(batch_size)
    iterator = dataset.make_one_shot_iterator()
    batch_x, batch_y = iterator.get_next()
    return batch_x, batch_y
def main(_):
    """Runs this task as a parameter server or as a training worker."""
    ps_hosts = FLAGS.ps_hosts.split(",")
    worker_hosts = FLAGS.worker_hosts.split(",")

    # Create the cluster described by `ps_hosts` and `worker_hosts`.
    cluster = tf.train.ClusterSpec({"ps": ps_hosts, "worker": worker_hosts})

    # Create a server for the local task.
    server = tf.train.Server(cluster,
                             job_name=FLAGS.job_name,
                             task_index=FLAGS.task_index)

    if FLAGS.job_name == "ps":
        # Parameter servers only host variables and serve requests.
        server.join()
    elif FLAGS.job_name == "worker":
        # Only workers consume data, so the input pipeline is built here
        # rather than on every task.
        train_batch_x, train_batch_y = get_batch(x_train, y_train)

        # tf.train.replica_device_setter automatically places the parameters
        # (Variables) on the ps hosts (default placement strategy: round-robin
        # over all ps hosts) and places the other ops on this worker.
        # A custom ps_strategy, e.g.
        # tf.contrib.training.GreedyLoadBalancingStrategy, can be passed instead.
        with tf.device(tf.train.replica_device_setter(
                worker_device="/job:worker/task:%d" % FLAGS.task_index,
                cluster=cluster)):
            logits = model(train_batch_x)
            loss = tf.reduce_mean(
                tf.nn.softmax_cross_entropy_with_logits(
                    logits=logits, labels=tf.one_hot(train_batch_y, 10)))

            # The StopAtStepHook handles stopping after running given steps.
            hooks = [tf.train.StopAtStepHook(last_step=10000)]

            global_step = tf.train.get_or_create_global_step()
            optimizer = tf.train.AdamOptimizer(learning_rate=1e-04)
            if FLAGS.is_sync:
                # Synchronous training: wrap the optimizer in
                # SyncReplicasOptimizer so gradients from num_workers replicas
                # are aggregated before each update.
                # ref: https://www.tensorflow.org/api_docs/python/tf/train/SyncReplicasOptimizer
                optimizer = tf.train.SyncReplicasOptimizer(
                    optimizer,
                    replicas_to_aggregate=FLAGS.num_workers,
                    total_num_replicas=FLAGS.num_workers)
                # Create the hook which handles initialization and queues;
                # task 0 acts as the chief.
                hooks.append(
                    optimizer.make_session_run_hook(FLAGS.task_index == 0))

            train_op = optimizer.minimize(loss, global_step=global_step)

        # The MonitoredTrainingSession takes care of session initialization,
        # restoring from a checkpoint, saving to a checkpoint, and closing when
        # done or an error occurs.
        with tf.train.MonitoredTrainingSession(
                master=server.target,
                is_chief=(FLAGS.task_index == 0),
                checkpoint_dir="./checkpoint_dir",
                hooks=hooks) as mon_sess:
            while not mon_sess.should_stop():
                # mon_sess.run handles AbortedError in case of preempted PS.
                _, ls, step = mon_sess.run([train_op, loss, global_step])
                if step % 100 == 0:
                    print("Train step %d, loss: %f" % (step, ls))


if __name__ == "__main__":
    tf.app.run()
参考文献:
https://www.cnblogs.com/estragon/p/10034511.html
https://zhuanlan.zhihu.com/p/88876923
本文转载自 Alex-zhai 知乎账号。
原文链接:https://zhuanlan.zhihu.com/p/91608555
评论