
Distributed TensorFlow Source Code Walkthrough 2: MonitoredTrainingSession

  • 2019-11-28
  • Word count: 15917 characters

    Reading time: about 52 minutes


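MonitoredTrainingSession is a widely used TensorFlow API for managing distributed training. It integrates several components that monitor training, such as variable initialization, resuming training from an existing checkpoint, and saving summaries, logs, and checkpoints. Early TF versions generally used tf.train.Supervisor to manage the session. After later framework upgrades, the official recommendation became MonitoredTrainingSession. MonitoredTrainingSession supports logging, training visualization, checkpoint saving, early stopping, and training-efficiency tuning.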


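Let's get straight to the point. Below is the source of MonitoredTrainingSession. Its comments let us sum up its job in one sentence. On the chief node, it initializes the session or restores it from an existing checkpoint, and it creates hooks for saving checkpoints and summaries. On a non-chief worker node, it must wait until the chief has finished initializing or restoring before it can set up its own session.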


@tf_export(v1=['train.MonitoredTrainingSession'])
def MonitoredTrainingSession(
    master='',  # pylint: disable=invalid-name
    is_chief=True,
    checkpoint_dir=None,
    scaffold=None,
    hooks=None,
    chief_only_hooks=None,
    save_checkpoint_secs=USE_DEFAULT,
    save_summaries_steps=USE_DEFAULT,
    save_summaries_secs=USE_DEFAULT,
    config=None,
    stop_grace_period_secs=120,
    log_step_count_steps=100,
    max_wait_secs=7200,
    save_checkpoint_steps=USE_DEFAULT,
    summary_dir=None):
  """Creates a `MonitoredSession` for training.

  Returns:
    A `MonitoredSession` object.
  """
  scaffold = scaffold or Scaffold()
  worker_context = distribute_coordinator_context.get_current_worker_context()

  if worker_context:
    return _create_monitored_session_with_worker_context(
        worker_context,
        scaffold,
        checkpoint_dir=checkpoint_dir,
        hooks=hooks,
        chief_only_hooks=chief_only_hooks,
        save_checkpoint_secs=save_checkpoint_secs,
        save_summaries_steps=save_summaries_steps,
        save_summaries_secs=save_summaries_secs,
        config=config,
        stop_grace_period_secs=stop_grace_period_secs,
        log_step_count_steps=log_step_count_steps,
        max_wait_secs=max_wait_secs,
        save_checkpoint_steps=save_checkpoint_steps,
        summary_dir=summary_dir)

  if not is_chief:
    session_creator = WorkerSessionCreator(
        scaffold=scaffold,
        master=master,
        config=config,
        max_wait_secs=max_wait_secs)
    return MonitoredSession(
        session_creator=session_creator,
        hooks=hooks or [],
        stop_grace_period_secs=stop_grace_period_secs)

  all_hooks = []
  # ... (elided in this excerpt) the chief creates `session_creator`
  # (a ChiefSessionCreator) and adds the various hooks to all_hooks ...
  if hooks:
    all_hooks.extend(hooks)

  return MonitoredSession(
      session_creator=session_creator,
      hooks=all_hooks,
      stop_grace_period_secs=stop_grace_period_secs)
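The following minimal pure-Python sketch makes the branching above concrete. It needs no TensorFlow and shows which session creator and which hooks MonitoredTrainingSession ends up with when no distribute-coordinator worker context is set. The function plan_monitored_session and the class SessionPlan are hypothetical. Hooks are plain strings. The chief's default hooks are simplified to one summary hook and one checkpoint hook, added whenever checkpoint_dir is set. The real code decides this from the save_* arguments.

```python
from dataclasses import dataclass, field
from typing import List, Optional


@dataclass
class SessionPlan:
    """Hypothetical record of what MonitoredTrainingSession would assemble."""
    creator: str
    hooks: List[str] = field(default_factory=list)


def plan_monitored_session(is_chief: bool,
                           hooks: Optional[List[str]] = None,
                           chief_only_hooks: Optional[List[str]] = None,
                           checkpoint_dir: Optional[str] = None) -> SessionPlan:
    """Mirror the is_chief branching of MonitoredTrainingSession (no worker context)."""
    if not is_chief:
        # Non-chief workers wait for the chief; chief_only_hooks are ignored.
        return SessionPlan("WorkerSessionCreator", list(hooks or []))

    all_hooks = list(chief_only_hooks or [])
    if checkpoint_dir:
        # Simplified stand-ins for the default summary/checkpoint hooks,
        # assuming the default save_* settings.
        all_hooks += ["SummarySaverHook", "CheckpointSaverHook"]
    # User-supplied hooks are appended last, as in the excerpt above.
    all_hooks += list(hooks or [])
    return SessionPlan("ChiefSessionCreator", all_hooks)


if __name__ == "__main__":
    print(plan_monitored_session(False, hooks=["StopAtStepHook"],
                                 chief_only_hooks=["MyChiefHook"]))
    print(plan_monitored_session(True, hooks=["StopAtStepHook"],
                                 chief_only_hooks=["MyChiefHook"],
                                 checkpoint_dir="./ckpt"))
```

The key point is that a non-chief worker never sees chief_only_hooks, and only the chief gets the hooks that write checkpoints and summaries. This keeps several workers from writing to the same directory.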


First, the parameters:


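is_chief: used in distributed training to indicate whether this process is the chief. If True, it is responsible for initializing and/or restoring the underlying TensorFlow session. If False, it waits for the chief to initialize or restore the TensorFlow session.


checkpoint_dir: a string. An optional path to the directory from which variables are restored. It is a directory, not a single checkpoint file.


scaffold: a Scaffold used to gather or build supporting ops. If not specified, a default Scaffold is created. It is used to finalize the graph.


hooks: an optional list of SessionRunHook objects. You can define your own SessionRunHook or use predefined ones. For example, tf.train.StopAtStepHook() sets the condition for stopping training. tf.train.NanTensorHook(loss) monitors the loss and stops training if it becomes NaN; by default it does this by raising an exception.


chief_only_hooks: a list of SessionRunHook objects. They are activated if is_chief == True and ignored otherwise.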


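save_checkpoint_secs: how often, in seconds, checkpoints are saved with the default checkpoint saver. If both save_checkpoint_secs and save_checkpoint_steps are None, the default checkpoint saver is not used. Defaults to 600 seconds.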


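save_summaries_steps: how often, in global steps, summaries are written to disk with the default summary saver. If both save_summaries_steps and save_summaries_secs are None, the default summary saver is not used. Defaults to 100.


save_summaries_secs: how often, in seconds, summaries are written to disk with the default summary saver. If both save_summaries_steps and save_summaries_secs are None, the default summary saver is not used. Not enabled by default.


config: an instance of tf.ConfigProto used to configure the session. It is passed as the config argument of the tf.Session constructor.


stop_grace_period_secs: the number of seconds given to threads to stop after close() has been called.


log_step_count_steps: how often, in global steps, the global-steps-per-second rate is logged.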
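Several of these parameters default to the sentinel USE_DEFAULT rather than to a number. This lets the function tell "not passed" apart from an explicit None, which disables the saver. Below is a minimal pure-Python sketch of that resolution pattern. Here USE_DEFAULT is a local sentinel and resolve_schedule is a hypothetical helper. The example defaults are 100 steps for summaries, as stated above, and 600 seconds for checkpoints, which is the TF 1.x default.

```python
USE_DEFAULT = object()  # sentinel: "the caller did not pass this argument"


def resolve_schedule(steps, secs, default_steps=None, default_secs=None):
    """Return the (steps, secs) pair after replacing USE_DEFAULT sentinels."""
    if steps is USE_DEFAULT and secs is USE_DEFAULT:
        # Neither argument was passed: use the built-in default schedule.
        return default_steps, default_secs
    if secs is USE_DEFAULT:
        # Only the step-based trigger was given: disable the time-based one.
        return steps, None
    if steps is USE_DEFAULT:
        # Only the time-based trigger was given: disable the step-based one.
        return None, secs
    # Both were passed explicitly (None for both disables the saver).
    return steps, secs


# Summaries: every 100 steps by default; checkpoints: every 600 seconds.
summary_schedule = resolve_schedule(USE_DEFAULT, USE_DEFAULT, default_steps=100)
checkpoint_schedule = resolve_schedule(USE_DEFAULT, USE_DEFAULT, default_secs=600)
```

Passing only one of the two arguments switches off the other trigger. For example, save_summaries_secs=30 alone means "every 30 seconds" rather than "every 30 seconds and also every 100 steps".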


Calling it returns a MonitoredSession object, which can be used like an ordinary session.


Next, let's break the code down in detail:


def _create_monitored_session_with_worker_context(
    worker_context,  # pylint: disable=missing-docstring
    scaffold,
    checkpoint_dir=None,
    hooks=None,
    chief_only_hooks=None,
    save_checkpoint_secs=None,
    save_summaries_steps=None,
    save_summaries_secs=None,
    config=None,
    stop_grace_period_secs=120,
    log_step_count_steps=100,
    max_wait_secs=7200,
    save_checkpoint_steps=None,
    summary_dir=None):
  all_hooks = []
  # ... (elided in this excerpt) add the various hooks to the all_hooks list ...
  logging.info('all_hooks %r', all_hooks)
  # Get a session creator that matches this worker's role.
  session_creator = worker_context.session_creator(
      scaffold,
      config=config,
      checkpoint_dir=checkpoint_dir,
      max_wait_secs=max_wait_secs)
  return MonitoredSession(
      session_creator=session_creator,
      hooks=all_hooks,
      stop_grace_period_secs=stop_grace_period_secs)
# Body of the session_creator method (called as worker_context.session_creator above)
def session_creator(self,
                    scaffold=None,
                    config=None,
                    checkpoint_dir=None,
                    checkpoint_filename_with_path=None,
                    max_wait_secs=7200):
  """Returns a session creator configured with the correct master target and session config."""
  if config:
    session_config = copy.deepcopy(config)
    session_config.MergeFrom(self._session_config)
  else:
    session_config = self._session_config

  # Create a different kind of session creator depending on the role.
  if not self._strategy or self._strategy.extended.experimental_should_init:
    logging.info("Creating chief session creator with config: %r", config)
    return monitored_session.ChiefSessionCreator(
        scaffold,
        master=self.master_target,
        config=session_config,
        checkpoint_dir=checkpoint_dir,
        checkpoint_filename_with_path=checkpoint_filename_with_path)
  else:
    logging.info("Creating worker session creator with config: %r", config)
    return monitored_session.WorkerSessionCreator(
        scaffold,
        master=self.master_target,
        config=session_config,
        max_wait_secs=max_wait_secs)
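When a user config is passed, session_creator deep-copies it and then merges the worker context's own session config into the copy. For fields set in both, the context's value wins. Repeated fields are concatenated, and the caller's config object is left untouched. Below is a dict-based analogy of protobuf MergeFrom semantics. It is not the protobuf API: merge_from and build_session_config are hypothetical helpers, and the keys only borrow ConfigProto field names.

```python
import copy


def merge_from(dst: dict, src: dict) -> None:
    """Dict analogy of protobuf Message.MergeFrom; mutates dst in place.

    A key present in src counts as a "set" field: scalars overwrite,
    lists (repeated fields) are appended, nested dicts merge recursively.
    """
    for key, value in src.items():
        if isinstance(value, dict) and isinstance(dst.get(key), dict):
            merge_from(dst[key], value)
        elif isinstance(value, list):
            dst.setdefault(key, []).extend(copy.deepcopy(value))
        else:
            dst[key] = copy.deepcopy(value)


def build_session_config(user_config, context_config):
    """Mirror the config handling at the top of session_creator()."""
    if user_config:
        session_config = copy.deepcopy(user_config)  # leave the caller's config intact
        merge_from(session_config, context_config)   # context values take precedence
    else:
        session_config = context_config
    return session_config


user = {"allow_soft_placement": False, "device_filters": ["/job:ps"]}
ctx = {"allow_soft_placement": True, "device_filters": ["/job:worker/task:0"]}
merged = build_session_config(user, ctx)
```

Deep-copying first matters because MergeFrom mutates its receiver. Without the copy, every call would keep growing the user's repeated fields.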
# ChiefSessionCreator
@tf_export(v1=['train.ChiefSessionCreator'])
class ChiefSessionCreator(SessionCreator):
  """Creates a tf.compat.v1.Session for a chief."""

  def __init__(self,
               scaffold=None,
               master='',
               config=None,
               checkpoint_dir=None,
               checkpoint_filename_with_path=None):
    self._checkpoint_dir = checkpoint_dir
    self._checkpoint_filename_with_path = checkpoint_filename_with_path
    self._scaffold = scaffold or Scaffold()
    self._session_manager = None
    self._master = master
    self._config = config

  def _get_session_manager(self):
    if self._session_manager:
      return self._session_manager

    self._session_manager = sm.SessionManager(
        local_init_op=self._scaffold.local_init_op,
        ready_op=self._scaffold.ready_op,
        ready_for_local_init_op=self._scaffold.ready_for_local_init_op,
        graph=ops.get_default_graph())
    return self._session_manager

  def create_session(self):
    self._scaffold.finalize()
    return self._get_session_manager().prepare_session(
        self._master,
        saver=self._scaffold.saver,
        checkpoint_dir=self._checkpoint_dir,
        checkpoint_filename_with_path=self._checkpoint_filename_with_path,
        config=self._config,
        init_op=self._scaffold.init_op,
        init_feed_dict=self._scaffold.init_feed_dict,
        init_fn=self._scaffold.init_fn)


# WorkerSessionCreator
@tf_export(v1=['train.WorkerSessionCreator'])
class WorkerSessionCreator(SessionCreator):
  """Creates a tf.compat.v1.Session for a worker."""

  def __init__(self,
               scaffold=None,
               master='',
               config=None,
               max_wait_secs=30 * 60):
    """Initializes a worker session creator.

    Args:
      max_wait_secs: Maximum time to wait for the session to become available.
    """
    self._scaffold = scaffold or Scaffold()
    self._session_manager = None
    self._master = master
    self._config = config
    self._max_wait_secs = max_wait_secs

  def _get_session_manager(self):
    if self._session_manager:
      return self._session_manager

    self._session_manager = sm.SessionManager(
        local_init_op=self._scaffold.local_init_op,
        ready_op=self._scaffold.ready_op,
        ready_for_local_init_op=self._scaffold.ready_for_local_init_op,
        graph=ops.get_default_graph())
    return self._session_manager

  def create_session(self):
    self._scaffold.finalize()
    return self._get_session_manager().wait_for_session(
        self._master, config=self._config, max_wait_secs=self._max_wait_secs)


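The source above shows that MonitoredTrainingSession creates a different kind of session depending on the node's role. The chief node's session is created by ChiefSessionCreator, while a non-chief worker's session is created by WorkerSessionCreator. What sets the latter apart is that it calls wait_for_session(): it waits until the chief has initialized (or restored) the model and only then returns a session of its own. In both cases, the actual session creation is done by a method of the SessionManager class, which we analyze next: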
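Two threads can imitate the chief/worker handshake described above. One thread plays the chief and marks the model as initialized after a delay. The other polls the way wait_for_session() does: it sleeps recovery_wait_secs between attempts and gives up once another sleep would exceed max_wait_secs. This is a simplified pure-Python model. SharedModel, chief_prepare, worker_wait and DeadlineExceeded are hypothetical stand-ins. On every attempt, the real code also opens a fresh tf.Session, tries local_init_op, and checks ready_op.

```python
import threading
import time


class DeadlineExceeded(Exception):
    """Stand-in for tf.errors.DeadlineExceededError."""


class SharedModel:
    """Hypothetical stand-in for model state shared by the chief and workers."""

    def __init__(self):
        self._ready = threading.Event()

    def initialize(self):
        self._ready.set()

    def is_ready(self):
        return self._ready.is_set()


def chief_prepare(model, init_secs):
    """Chief: pretend initialization/restore takes init_secs, then mark ready."""
    time.sleep(init_secs)
    model.initialize()


def worker_wait(model, recovery_wait_secs, max_wait_secs):
    """Worker: poll until the model is ready; return the number of attempts."""
    deadline = time.monotonic() + max_wait_secs
    attempts = 0
    while True:
        attempts += 1
        if model.is_ready():
            return attempts
        # Do we have enough time left to try again?
        if deadline - time.monotonic() - recovery_wait_secs < 0:
            raise DeadlineExceeded(
                "Session was not ready after waiting %s secs." % max_wait_secs)
        time.sleep(recovery_wait_secs)


if __name__ == "__main__":
    shared = SharedModel()
    chief = threading.Thread(target=chief_prepare, args=(shared, 0.2))
    chief.start()
    print("worker ready after", worker_wait(shared, 0.05, 5), "attempts")
    chief.join()
```

A worker started before the chief simply keeps polling, so start-up order does not matter. If the chief never finishes, the worker fails with a deadline error instead of hanging forever.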


The official documentation gives a short example for the SessionManager class, which is quite clear:


# prepare_session() can initialize or restore a model; it needs `init_op` and `saver`.
with tf.Graph().as_default():
  # ...add operations to the graph...
  # Create a SessionManager that will checkpoint the model in '/tmp/mydir'.
  sm = SessionManager()
  sess = sm.prepare_session(master, init_op, saver, checkpoint_dir)
  # Use the session to train the graph.
  while True:
    sess.run(<my_train_op>)

A second process can wait for the model to become ready with the following code. wait_for_session() waits until the session created above has made the model ready, and only then returns a session of its own:

with tf.Graph().as_default():
  # ...add operations to the graph...
  # Create a SessionManager that will wait for the model to become ready.
  sm = SessionManager()
  sess = sm.wait_for_session(master)
  # Use the session to train the graph.
  while True:
    sess.run(<my_train_op>)


Next, let's focus on the two methods prepare_session and wait_for_session:


@tf_export(v1=["train.SessionManager"])
class SessionManager(object):

  def __init__(self,
               local_init_op=None,
               ready_op=None,
               ready_for_local_init_op=None,
               graph=None,
               recovery_wait_secs=30,
               local_init_run_options=None):
    """
    local_init_op: an op that is run every time a new session is created.
    ready_op: an op used to check whether the model is ready.
    ready_for_local_init_op: an op used to check whether the model is ready
      to run local_init_op.
    """
    # Sets default values of arguments.
    if graph is None:
      graph = ops.get_default_graph()
    self._local_init_op = local_init_op
    self._ready_op = ready_op
    self._ready_for_local_init_op = ready_for_local_init_op
    self._graph = graph
    self._recovery_wait_secs = recovery_wait_secs
    self._target = None
    self._local_init_run_options = local_init_run_options
    if ready_for_local_init_op is not None and local_init_op is None:
      raise ValueError("If you pass a ready_for_local_init_op "
                       "you must also pass a local_init_op "
                       ", ready_for_local_init_op [%s]" %
                       ready_for_local_init_op)

  def prepare_session(self,
                      master,
                      init_op=None,
                      saver=None,
                      checkpoint_dir=None,
                      checkpoint_filename_with_path=None,
                      wait_for_checkpoint=False,
                      max_wait_secs=7200,
                      config=None,
                      init_feed_dict=None,
                      init_fn=None):
    """
    In short: if a checkpoint exists, prepare_session restores the session
    from it; otherwise it initializes the session by running the given
    `init_op` and calling `init_fn`.
    """
    sess, is_loaded_from_checkpoint = self._restore_checkpoint(
        master,
        saver,
        checkpoint_dir=checkpoint_dir,
        checkpoint_filename_with_path=checkpoint_filename_with_path,
        wait_for_checkpoint=wait_for_checkpoint,
        max_wait_secs=max_wait_secs,
        config=config)
    if not is_loaded_from_checkpoint:
      if init_op is None and not init_fn and self._local_init_op is None:
        raise RuntimeError("Model is not initialized and no init_op or "
                           "init_fn or local_init_op was given")
      if init_op is not None:
        sess.run(init_op, feed_dict=init_feed_dict)
      if init_fn:
        init_fn(sess)
    # .....
    return sess

  def wait_for_session(self, master, config=None, max_wait_secs=float("Inf")):
    """Creates a new `Session` and waits for model to be ready."""
    self._target = master
    if max_wait_secs is None:
      max_wait_secs = float("Inf")
    timer = _CountDownTimer(max_wait_secs)

    while True:
      sess = session.Session(self._target, graph=self._graph, config=config)
      not_ready_msg = None
      not_ready_local_msg = None
      local_init_success, not_ready_local_msg = self._try_run_local_init_op(
          sess)
      if local_init_success:
        # Successful if local_init_op is None, or ready_for_local_init_op passes
        is_ready, not_ready_msg = self._model_ready(sess)
        if is_ready:
          return sess

      self._safe_close(sess)

      # Do we have enough time left to try again?
      remaining_ms_after_wait = (
          timer.secs_remaining() - self._recovery_wait_secs)
      if remaining_ms_after_wait < 0:
        raise errors.DeadlineExceededError(
            None, None,
            "Session was not ready after waiting %d secs." % (max_wait_secs,))

      logging.info("Waiting for model to be ready.  "
                   "Ready_for_local_init_op:  %s, ready: %s",
                   not_ready_local_msg, not_ready_msg)
      time.sleep(self._recovery_wait_secs)
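A small pure-Python model summarizes the restore-or-initialize branch of prepare_session(). FakeSession and this prepare_session function are hypothetical stand-ins. A non-None checkpoint argument plays the role of _restore_checkpoint() succeeding, and ops are plain strings recorded in a log. The part of the real method elided in the excerpt above is not modeled.

```python
from typing import Callable, List, Optional


class FakeSession:
    """Hypothetical stand-in for tf.Session that records the ops it runs."""

    def __init__(self):
        self.log: List[str] = []

    def run(self, op, feed_dict=None):
        self.log.append(op)


def prepare_session(checkpoint: Optional[str] = None,
                    init_op: Optional[str] = None,
                    init_fn: Optional[Callable[[FakeSession], None]] = None,
                    local_init_op: Optional[str] = None) -> FakeSession:
    """Restore from `checkpoint` if given, otherwise run init_op / init_fn."""
    sess = FakeSession()
    is_loaded_from_checkpoint = checkpoint is not None
    if is_loaded_from_checkpoint:
        sess.run("restore:" + checkpoint)  # plays the role of saver.restore()
    else:
        # local_init_op only matters for this check in the sketch.
        if init_op is None and not init_fn and local_init_op is None:
            raise RuntimeError("Model is not initialized and no init_op or "
                               "init_fn or local_init_op was given")
        if init_op is not None:
            sess.run(init_op)
        if init_fn:
            init_fn(sess)
    return sess


# Fresh start: run init_op, then a custom init_fn (e.g. warm-starting).
fresh = prepare_session(init_op="init", init_fn=lambda s: s.run("warm_start"))
# Resume: the checkpoint wins and init_op is skipped.
resumed = prepare_session("model.ckpt-100", init_op="init")
```

Note that a restored model skips init_op entirely. This is why restarting the chief does not wipe out trained weights.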


Once the session has been created, it is wrapped into the MonitoredSession object that is finally returned.


A complete monitored session does the following at creation time (in order):


  • Calls hook.begin() for each hook



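  • Calls scaffold.finalize() to finalize the graph

  • Creates the session

  • Initializes the model parameters with the init ops provided by the Scaffold

  • Restores variables from the checkpoint if one exists

  • Starts the queue runners

  • Calls hook.after_create_session()

When run() is called, the monitored session does the following:

  • Calls hook.before_run()

  • Calls TensorFlow's session.run() with the merged fetches and feed_dict

  • Calls hook.after_run()

  • Returns the result of session.run()

  • If an AbortedError or UnavailableError occurs, recovers or reinitializes the session before executing run() again

When close() is called, the monitored session does the following:

  • Calls hook.end()

  • Closes the queue runners and the session

  • If the monitored session is used as a context manager, suppresses the OutOfRange error, which signals that all input data has been consumed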
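A toy imitation that needs no TensorFlow can trace the hook call order listed above. MiniMonitoredSession, RecordingHook, RunContext and OutOfRange are hypothetical names. The sketch models only the hook callbacks, should_stop() via request_stop(), and a context-manager exit that swallows an out-of-range error. It does not model graph finalization, queue runners, or recovery from AbortedError/UnavailableError. Following TF 1.x behavior, hook.end() runs only on a clean shutdown, and exhausted input counts as one.

```python
class OutOfRange(Exception):
    """Stand-in for tf.errors.OutOfRangeError (input exhausted)."""


class RunContext:
    """Minimal stand-in for tf.train.SessionRunContext."""

    def __init__(self):
        self.stop_requested = False

    def request_stop(self):
        self.stop_requested = True


class RecordingHook:
    """Hypothetical hook that logs each callback and can request a stop."""

    def __init__(self, events, stop_at=None):
        self.events = events
        self.stop_at = stop_at

    def begin(self):
        self.events.append("begin")

    def after_create_session(self, session):
        self.events.append("after_create_session")

    def before_run(self, run_context):
        self.events.append("before_run")

    def after_run(self, run_context, result):
        self.events.append("after_run")
        if self.stop_at is not None and result >= self.stop_at:
            run_context.request_stop()

    def end(self, session):
        self.events.append("end")


class MiniMonitoredSession:
    """Toy model of the MonitoredSession hook lifecycle."""

    def __init__(self, step_fn, hooks):
        self._step_fn = step_fn
        self._hooks = list(hooks)
        self._stop = False
        for h in self._hooks:
            h.begin()
        self._session = object()  # stands in for the created tf.Session
        for h in self._hooks:
            h.after_create_session(self._session)

    def should_stop(self):
        return self._stop

    def run(self):
        ctx = RunContext()
        for h in self._hooks:
            h.before_run(ctx)
        result = self._step_fn()  # stands in for session.run(fetches)
        for h in self._hooks:
            h.after_run(ctx, result)
        self._stop = self._stop or ctx.stop_requested
        return result

    def close(self, exc_type=None):
        if exc_type is None:  # hook.end() only runs on a clean shutdown
            for h in self._hooks:
                h.end(self._session)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        if exc_type is not None and issubclass(exc_type, OutOfRange):
            exc_type = None  # exhausted input counts as a clean finish
        self.close(exc_type)
        return exc_type is None  # True suppresses the out-of-range error
```

Used as `with MiniMonitoredSession(step_fn, hooks) as sess: while not sess.should_stop(): sess.run()`, this mirrors the training loop at the end of this article.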

Finally, here is an example of distributed training with MonitoredTrainingSession (which returns a MonitoredSession):


from __future__ import print_function, absolute_import, division

import tensorflow as tf  # TensorFlow 1.x API (tf.app, tf.layers, tf.train.*)

tf.app.flags.DEFINE_string("ps_hosts", "localhost:2222", "ps hosts")
tf.app.flags.DEFINE_string("worker_hosts", "localhost:2223,localhost:2224", "worker hosts")
tf.app.flags.DEFINE_string("job_name", "worker", "'ps' or 'worker'")
tf.app.flags.DEFINE_integer("task_index", 0, "Index of task within the job")
tf.app.flags.DEFINE_integer("num_workers", 2, "Number of workers")
tf.app.flags.DEFINE_boolean("is_sync", False, "using synchronous training or not")

FLAGS = tf.app.flags.FLAGS


def model(images):
    """Define a simple mnist classifier"""
    net = tf.layers.dense(images, 500, activation=tf.nn.relu)
    net = tf.layers.dense(net, 500, activation=tf.nn.relu)
    net = tf.layers.dense(net, 10, activation=None)
    return net


(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
x_train = x_train.reshape(-1, 784).astype('float32')
x_test = x_test.reshape(-1, 784).astype('float32')
x_train /= 255
x_test /= 255


def get_batch(image, label, batch_size=32, training=True):
    df = tf.data.Dataset.from_tensor_slices((image, label))
    if training:
        df = df.repeat(10).shuffle(buffer_size=1000)
    df = df.batch(batch_size).prefetch(batch_size)
    iterator = df.make_one_shot_iterator()
    batch_x, batch_y = iterator.get_next()
    return batch_x, batch_y


def main(_):
    ps_hosts = FLAGS.ps_hosts.split(",")
    worker_hosts = FLAGS.worker_hosts.split(",")

    # create the cluster configured by `ps_hosts` and `worker_hosts`
    cluster = tf.train.ClusterSpec({"ps": ps_hosts, "worker": worker_hosts})

    # create a server for the local task
    server = tf.train.Server(cluster, job_name=FLAGS.job_name,
                             task_index=FLAGS.task_index)

    train_batch_x, train_batch_y = get_batch(x_train, y_train)
    test_batch_x, test_batch_y = get_batch(x_test, y_test, training=False)

    if FLAGS.job_name == "ps":
        server.join()  # ps hosts only join
    elif FLAGS.job_name == "worker":
        # workers perform the operation
        # ps_strategy = tf.contrib.training.GreedyLoadBalancingStrategy(FLAGS.num_ps)

        # Note: tf.train.replica_device_setter automatically places the parameters
        # (Variables) on the ps hosts (default placement strategy: round-robin over
        # all ps hosts) and places a copy of the other operations on each worker host.
        with tf.device(tf.train.replica_device_setter(
                worker_device="/job:worker/task:%d" % FLAGS.task_index,
                cluster=cluster)):
            logits = model(train_batch_x)
            loss = tf.reduce_mean(
                tf.nn.softmax_cross_entropy_with_logits(
                    logits=logits, labels=tf.one_hot(train_batch_y, 10)))

            # The StopAtStepHook handles stopping after running given steps.
            hooks = [tf.train.StopAtStepHook(last_step=10000)]

            global_step = tf.train.get_or_create_global_step()
            optimizer = tf.train.AdamOptimizer(learning_rate=1e-04)

            if FLAGS.is_sync:
                # synchronous training:
                # wrap the optimizer with tf.train.SyncReplicasOptimizer
                # ref: https://www.tensorflow.org/api_docs/python/tf/train/SyncReplicasOptimizer
                optimizer = tf.train.SyncReplicasOptimizer(
                    optimizer,
                    replicas_to_aggregate=FLAGS.num_workers,
                    total_num_replicas=FLAGS.num_workers)
                # create the hook which handles initialization and queues
                hooks.append(optimizer.make_session_run_hook(FLAGS.task_index == 0))

            train_op = optimizer.minimize(loss, global_step=global_step)

        # The MonitoredTrainingSession takes care of session initialization,
        # restoring from a checkpoint, saving to a checkpoint, and closing when done
        # or an error occurs.
        with tf.train.MonitoredTrainingSession(master=server.target,
                                               is_chief=(FLAGS.task_index == 0),
                                               checkpoint_dir="./checkpoint_dir",
                                               hooks=hooks) as mon_sess:
            while not mon_sess.should_stop():
                # mon_sess.run handles AbortedError in case of preempted PS.
                _, ls, step = mon_sess.run([train_op, loss, global_step])
                if step % 100 == 0:
                    print("Train step %d, loss: %f" % (step, ls))


if __name__ == "__main__":
    tf.app.run()
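The round-robin placement mentioned in the comment on tf.train.replica_device_setter can be sketched as follows. place_ops is a hypothetical helper, and ops are given as (name, op_type) pairs in creation order. PS_OP_TYPES is an assumed subset of the op types the real setter sends to ps tasks. Variable ops cycle through the ps tasks, while every other op stays on the worker device.

```python
from typing import Dict, Iterable, Tuple

# Assumed subset of the op types that replica_device_setter places on ps tasks.
PS_OP_TYPES = {"Variable", "VariableV2", "VarHandleOp"}


def place_ops(ops: Iterable[Tuple[str, str]], num_ps: int,
              worker_device: str) -> Dict[str, str]:
    """Assign variable ops round-robin to ps tasks; everything else to the worker."""
    if num_ps < 1:
        raise ValueError("need at least one ps task")
    placement = {}
    next_task = 0
    for name, op_type in ops:
        if op_type in PS_OP_TYPES:
            placement[name] = "/job:ps/task:%d" % next_task
            next_task = (next_task + 1) % num_ps  # cycle through the ps tasks
        else:
            placement[name] = worker_device
    return placement


# Hypothetical op list for the first dense layer plus the global step.
example_ops = [("dense/kernel", "VariableV2"), ("dense/bias", "VariableV2"),
               ("dense/MatMul", "MatMul"), ("global_step", "VariableV2")]
example = place_ops(example_ops, num_ps=2, worker_device="/job:worker/task:1")
```

Round-robin ignores variable sizes, so one ps task can end up holding most of the bytes. The commented-out GreedyLoadBalancingStrategy line in the example exists to address exactly that imbalance.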




This article is reprinted from Alex-zhai's Zhihu account.


Original link: https://zhuanlan.zhihu.com/p/91608555

