
Distributed TensorFlow Source Code Walkthrough 2: MonitoredTrainingSession

  • 2019-11-28


MonitoredTrainingSession is a widely used API for managing distributed training in TensorFlow. It bundles a number of training-supervision components, such as variable initialization, resuming training from an existing checkpoint, and saving summaries, logs, and checkpoints. In earlier versions of TF, tf.train.Supervisor was generally used to manage the session; after the framework was upgraded, the official recommendation became MonitoredTrainingSession. It provides logging, training visualization, checkpoint saving, early stopping, and training-efficiency tuning.


Let's get straight to the point. Below is the MonitoredTrainingSession source code. From the comments, its job can be summarized in one sentence: if this is the chief node, it is responsible for initializing the session or restoring it from an existing checkpoint, and it creates the hooks that save checkpoints and summaries. If this is a non-chief worker node, it must wait for the chief to finish initializing or restoring the session before it can set up its own session.


@tf_export(v1=['train.MonitoredTrainingSession'])
def MonitoredTrainingSession(
    master='',  # pylint: disable=invalid-name
    is_chief=True,
    checkpoint_dir=None,
    scaffold=None,
    hooks=None,
    chief_only_hooks=None,
    save_checkpoint_secs=USE_DEFAULT,
    save_summaries_steps=USE_DEFAULT,
    save_summaries_secs=USE_DEFAULT,
    config=None,
    stop_grace_period_secs=120,
    log_step_count_steps=100,
    max_wait_secs=7200,
    save_checkpoint_steps=USE_DEFAULT,
    summary_dir=None):
  """Creates a `MonitoredSession` for training.

  Returns:
    A `MonitoredSession` object.
  """
  scaffold = scaffold or Scaffold()
  worker_context = distribute_coordinator_context.get_current_worker_context()

  if worker_context:
    return _create_monitored_session_with_worker_context(
        worker_context,
        scaffold,
        checkpoint_dir=checkpoint_dir,
        hooks=hooks,
        chief_only_hooks=chief_only_hooks,
        save_checkpoint_secs=save_checkpoint_secs,
        save_summaries_steps=save_summaries_steps,
        save_summaries_secs=save_summaries_secs,
        config=config,
        stop_grace_period_secs=stop_grace_period_secs,
        log_step_count_steps=log_step_count_steps,
        max_wait_secs=max_wait_secs,
        save_checkpoint_steps=save_checkpoint_steps,
        summary_dir=summary_dir)

  if not is_chief:
    session_creator = WorkerSessionCreator(
        scaffold=scaffold,
        master=master,
        config=config,
        max_wait_secs=max_wait_secs)
    return MonitoredSession(
        session_creator=session_creator,
        hooks=hooks or [],
        stop_grace_period_secs=stop_grace_period_secs)

  all_hooks = []
  # Collect the various hooks into the all_hooks list.
  if hooks:
    all_hooks.extend(hooks)
  # (Elided in the original excerpt: the chief path also builds a
  # ChiefSessionCreator and appends the default checkpoint/summary saver hooks.)
  return MonitoredSession(
      session_creator=session_creator,
      hooks=all_hooks,
      stop_grace_period_secs=stop_grace_period_secs)


Let's first go over the parameters:


  • is_chief: Used in distributed settings to indicate whether this node is the chief. If True, the node is responsible for initializing and restoring the underlying TensorFlow session; if False, it waits for the chief to initialize or restore the session.

  • checkpoint_dir: A string. The path to a checkpoint directory from which variables are restored.

  • scaffold: A Scaffold used to collect or build supportive ops. If not specified, a default Scaffold is created. It is used to finalize the graph.

  • hooks: An optional list of SessionRunHook objects. You can define your own SessionRunHook, or use one of the predefined ones, e.g. tf.train.StopAtStepHook() to set a condition for stopping training, or tf.train.NanTensorHook(loss) to stop training when loss becomes NaN.

  • chief_only_hooks: A list of SessionRunHook objects. These hooks are activated only if is_chief == True, and ignored otherwise.

  • save_checkpoint_secs: How often, in seconds, checkpoints are saved with the default checkpoint saver. If save_checkpoint_secs is set to None, checkpoints are not saved.

  • save_summaries_steps: How often, in number of global steps, summaries are written to disk with the default summary saver. If both save_summaries_steps and save_summaries_secs are None, the default summary saver is not used. Defaults to 100.

  • save_summaries_secs: How often, in seconds, summaries are written to disk with the default summary saver. If both save_summaries_steps and save_summaries_secs are None, the default summary saver is not used. Not enabled by default.

  • config: An instance of tf.ConfigProto used to configure the session; it is the config argument of the tf.Session constructor.

  • stop_grace_period_secs: Number of seconds given to threads to stop after close() has been called.

  • log_step_count_steps: How often, in number of global steps, the global steps per second is logged.


Instantiating it yields a MonitoredSession object, which can be used like an ordinary session.
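

Before digging into the internals, here is a minimal single-process sketch of how these parameters fit together (the toy graph and the ./ckpt directory are placeholders of our own, not from the original source):

import tensorflow as tf

# Toy graph so the sketch is self-contained: drive a single variable to zero.
global_step = tf.train.get_or_create_global_step()
w = tf.Variable(5.0)
loss = tf.square(w)
train_op = tf.train.GradientDescentOptimizer(0.1).minimize(
    loss, global_step=global_step)

hooks = [
    tf.train.StopAtStepHook(last_step=1000),  # stop after 1000 global steps
    tf.train.NanTensorHook(loss),             # abort if loss becomes NaN
]

with tf.train.MonitoredTrainingSession(
    is_chief=True,               # single process, so it is its own chief
    checkpoint_dir='./ckpt',     # checkpoints are saved to and restored from here
    hooks=hooks,
    save_checkpoint_secs=600,    # checkpoint every 10 minutes
    save_summaries_steps=100) as sess:
  while not sess.should_stop():  # flipped by StopAtStepHook at step 1000
    sess.run(train_op)           # used just like an ordinary session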


Now let's break the code down in detail:


def _create_monitored_session_with_worker_context(
    worker_context,  # pylint: disable=missing-docstring
    scaffold,
    checkpoint_dir=None,
    hooks=None,
    chief_only_hooks=None,
    save_checkpoint_secs=None,
    save_summaries_steps=None,
    save_summaries_secs=None,
    config=None,
    stop_grace_period_secs=120,
    log_step_count_steps=100,
    max_wait_secs=7200,
    save_checkpoint_steps=None,
    summary_dir=None):
  all_hooks = []
  # Collect the various hooks into the all_hooks list (elided here).
  logging.info('all_hooks %r', all_hooks)
  # Create the session.
  session_creator = worker_context.session_creator(
      scaffold,
      config=config,
      checkpoint_dir=checkpoint_dir,
      max_wait_secs=max_wait_secs)
  return MonitoredSession(
      session_creator=session_creator,
      hooks=all_hooks,
      stop_grace_period_secs=stop_grace_period_secs)


# Body of the session_creator method
def session_creator(self,
                    scaffold=None,
                    config=None,
                    checkpoint_dir=None,
                    checkpoint_filename_with_path=None,
                    max_wait_secs=7200):
  """Returns a session creator for the correct master target and session config."""
  if config:
    session_config = copy.deepcopy(config)
    session_config.MergeFrom(self._session_config)
  else:
    session_config = self._session_config

  # Create the session creator according to the node's role.
  if not self._strategy or self._strategy.extended.experimental_should_init:
    logging.info("Creating chief session creator with config: %r", config)
    return monitored_session.ChiefSessionCreator(
        scaffold,
        master=self.master_target,
        config=session_config,
        checkpoint_dir=checkpoint_dir,
        checkpoint_filename_with_path=checkpoint_filename_with_path)
  else:
    logging.info("Creating worker session creator with config: %r", config)
    return monitored_session.WorkerSessionCreator(
        scaffold,
        master=self.master_target,
        config=session_config,
        max_wait_secs=max_wait_secs)


# ChiefSessionCreator
@tf_export(v1=['train.ChiefSessionCreator'])
class ChiefSessionCreator(SessionCreator):
  """Creates a tf.compat.v1.Session for a chief."""

  def __init__(self,
               scaffold=None,
               master='',
               config=None,
               checkpoint_dir=None,
               checkpoint_filename_with_path=None):
    self._checkpoint_dir = checkpoint_dir
    self._checkpoint_filename_with_path = checkpoint_filename_with_path
    self._scaffold = scaffold or Scaffold()
    self._session_manager = None
    self._master = master
    self._config = config

  def _get_session_manager(self):
    if self._session_manager:
      return self._session_manager
    self._session_manager = sm.SessionManager(
        local_init_op=self._scaffold.local_init_op,
        ready_op=self._scaffold.ready_op,
        ready_for_local_init_op=self._scaffold.ready_for_local_init_op,
        graph=ops.get_default_graph())
    return self._session_manager

  def create_session(self):
    self._scaffold.finalize()
    return self._get_session_manager().prepare_session(
        self._master,
        saver=self._scaffold.saver,
        checkpoint_dir=self._checkpoint_dir,
        checkpoint_filename_with_path=self._checkpoint_filename_with_path,
        config=self._config,
        init_op=self._scaffold.init_op,
        init_feed_dict=self._scaffold.init_feed_dict,
        init_fn=self._scaffold.init_fn)


# WorkerSessionCreator
@tf_export(v1=['train.WorkerSessionCreator'])
class WorkerSessionCreator(SessionCreator):
  """Creates a tf.compat.v1.Session for a worker."""

  def __init__(self,
               scaffold=None,
               master='',
               config=None,
               max_wait_secs=30 * 60):
    """Initializes a worker session creator.

    Args:
      max_wait_secs: Maximum time to wait for the session to become available.
    """
    self._scaffold = scaffold or Scaffold()
    self._session_manager = None
    self._master = master
    self._config = config
    self._max_wait_secs = max_wait_secs

  def _get_session_manager(self):
    if self._session_manager:
      return self._session_manager
    self._session_manager = sm.SessionManager(
        local_init_op=self._scaffold.local_init_op,
        ready_op=self._scaffold.ready_op,
        ready_for_local_init_op=self._scaffold.ready_for_local_init_op,
        graph=ops.get_default_graph())
    return self._session_manager

  def create_session(self):
    self._scaffold.finalize()
    return self._get_session_manager().wait_for_session(
        self._master, config=self._config, max_wait_secs=self._max_wait_secs)


From the source above, MonitoredTrainingSession creates different kinds of sessions for different roles: on the chief node the session is created by the ChiefSessionCreator class, while on non-chief worker nodes it is created by the WorkerSessionCreator class. The special part is that the latter calls wait_for_session(), which roughly means the worker must wait until the chief node's session has been created before it creates its own. In both cases, session creation goes through methods of the SessionManager class, which we analyze in detail below:
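

For context, here is a minimal sketch of the conventional setup on each node (the host addresses are placeholders; the "worker task 0 is the chief" convention matches the full example at the end of this article):

import tensorflow as tf

# Placeholder cluster: one ps task and two worker tasks on localhost.
cluster = tf.train.ClusterSpec({
    "ps": ["localhost:2222"],
    "worker": ["localhost:2223", "localhost:2224"],
})
job_name, task_index = "worker", 1   # normally parsed from command-line flags
server = tf.train.Server(cluster, job_name=job_name, task_index=task_index)

# Convention: task 0 of the worker job is the chief. With is_chief=False,
# MonitoredTrainingSession builds a WorkerSessionCreator whose create_session()
# blocks in SessionManager.wait_for_session() until the chief's session is ready.
is_chief = (job_name == "worker" and task_index == 0)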


The official documentation has a simple example for the SessionManager class that makes this quite clear:


  # prepare_session initializes or restores a model; it takes an `init_op`
  # and a `saver`.
  with tf.Graph().as_default():
    # ...add operations to the graph...
    # Create a SessionManager that will checkpoint the model in '/tmp/mydir'.
    sm = SessionManager()
    sess = sm.prepare_session(master, init_op, saver, checkpoint_dir)
    # Use the session to train the graph.
    while True:
      sess.run(<my_train_op>)

A second process can launch its ops as follows; wait_for_session() means it waits for the session above to be created before creating its own session:

  with tf.Graph().as_default():
    # ...add operations to the graph...
    # Create a SessionManager that will wait for the model to become ready.
    sm = SessionManager()
    sess = sm.wait_for_session(master)
    # Use the session to train the graph.
    while True:
      sess.run(<my_train_op>)
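
How does wait_for_session() decide that the model is "ready"? Via the ready_op. As far as we know, the default Scaffold builds its ready_op around tf.report_uninitialized_variables(); a simplified sketch of the check (sess here is assumed to be an open tf.Session on the cluster):

# Simplified readiness check, under the default-Scaffold assumption above.
ready_op = tf.report_uninitialized_variables()
uninitialized = sess.run(ready_op)    # names of variables not yet initialized
is_ready = (uninitialized.size == 0)  # empty result => the model is ready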


Now let's focus on the two functions prepare_session and wait_for_session:


@tf_export(v1=["train.SessionManager"])class SessionManager(object):  def __init__(self,               local_init_op=None,               ready_op=None,               ready_for_local_init_op=None,               graph=None,               recovery_wait_secs=30,               local_init_run_options=None):    """
复制代码


local_init_op 是每当有一个新的session被创建时,就会运行下local_init_op这个操作。ready_op 用于check模型是否准备好的一个op。ready_for_local_init_op是checkp模型是否已经可以运行local_init_op的一个op。
复制代码


    """    # Sets default values of arguments.    if graph is None:      graph = ops.get_default_graph()    self._local_init_op = local_init_op    self._ready_op = ready_op    self._ready_for_local_init_op = ready_for_local_init_op    self._graph = graph    self._recovery_wait_secs = recovery_wait_secs    self._target = None    self._local_init_run_options = local_init_run_options    if ready_for_local_init_op is not None and local_init_op is None:      raise ValueError("If you pass a ready_for_local_init_op "                       "you must also pass a local_init_op "                       ", ready_for_local_init_op [%s]" %                       ready_for_local_init_op)
def prepare_session(self, master, init_op=None, saver=None, checkpoint_dir=None, checkpoint_filename_with_path=None, wait_for_checkpoint=False, max_wait_secs=7200, config=None, init_feed_dict=None, init_fn=None): """
复制代码


其实prepare_session函数的作用就是如果有checkpoint存在,就从checkpoint恢复session,如果不存在checkpoint就从传入的`init_op`和 调用`init_fn`函数去创建session。
复制代码


    """    sess, is_loaded_from_checkpoint = self._restore_checkpoint(        master,        saver,        checkpoint_dir=checkpoint_dir,        checkpoint_filename_with_path=checkpoint_filename_with_path,        wait_for_checkpoint=wait_for_checkpoint,        max_wait_secs=max_wait_secs,        config=config)    if not is_loaded_from_checkpoint:      if init_op is None and not init_fn and self._local_init_op is None:        raise RuntimeError("Model is not initialized and no init_op or "                           "init_fn or local_init_op was given")      if init_op is not None:        sess.run(init_op, feed_dict=init_feed_dict)      if init_fn:        init_fn(sess)    ”“”    .....    “”“    return sess

def wait_for_session(self, master, config=None, max_wait_secs=float("Inf")): """ Creates a new `Session` and waits for model to be ready. """ self._target = master if max_wait_secs is None: max_wait_secs = float("Inf") timer = _CountDownTimer(max_wait_secs)
while True: sess = session.Session(self._target, graph=self._graph, config=config) not_ready_msg = None not_ready_local_msg = None local_init_success, not_ready_local_msg = self._try_run_local_init_op( sess) if local_init_success: # Successful if local_init_op is None, or ready_for_local_init_op passes is_ready, not_ready_msg = self._model_ready(sess) if is_ready: return sess self._safe_close(sess) # Do we have enough time left to try again? remaining_ms_after_wait = ( timer.secs_remaining() - self._recovery_wait_secs) if remaining_ms_after_wait < 0: raise errors.DeadlineExceededError( None, None, "Session was not ready after waiting %d secs." % (max_wait_secs,)) logging.info("Waiting for model to be ready. " "Ready_for_local_init_op: %s, ready: %s", not_ready_local_msg, not_ready_msg) time.sleep(self._recovery_wait_secs)


After the session is created, it is wrapped once more and the final MonitoredSession object is returned.


Here is everything a complete monitored session does during creation (in order):


  • Calls hook.begin() for each hook



  • Calls scaffold.finalize() to finalize the graph

  • Creates the session

  • Initializes the model parameters, via the Scaffold

  • Restores the parameters from a checkpoint if one exists

  • Launches the queue runners

  • Calls hook.after_create_session()

When run() is called, the monitored session does the following:

  • Calls hook.before_run()

  • Calls TensorFlow's session.run() with the merged fetches and feed_dict

  • Calls hook.after_run()

  • Returns the result of session.run()

  • If an AbortedError or UnavailableError occurs, recovers or reinitializes the session before executing run() again

When close() is called, the monitored session does the following:

  • Calls hook.end()

  • Closes the queue runners and the session

  • Raises an OutOfRange error if all of the input data has been consumed
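

To make this lifecycle concrete, here is a minimal, hypothetical custom hook (TimingHook is our own illustration, not part of TensorFlow) showing where each callback fires:

import time

import tensorflow as tf

class TimingHook(tf.train.SessionRunHook):
  """Hypothetical hook that logs how long each training step takes."""

  def begin(self):
    # Runs once, before the session is created and the graph is finalized.
    self._step = 0

  def after_create_session(self, session, coord):
    # Runs once the session exists (after init or checkpoint restore).
    print("session ready, starting to train")

  def before_run(self, run_context):
    # Runs before every MonitoredSession.run() call.
    self._t0 = time.time()
    return None  # no extra fetches or feeds to merge

  def after_run(self, run_context, run_values):
    # Runs after every MonitoredSession.run() call.
    self._step += 1
    if self._step % 100 == 0:
      print("step %d took %.3fs" % (self._step, time.time() - self._t0))

  def end(self, session):
    # Runs once when close() is called.
    print("finished after %d run() calls" % self._step)

Passing hooks=[TimingHook()] to MonitoredTrainingSession is enough; the session drives every callback at the points listed above.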

Finally, here is an example of distributed training using the MonitoredSession class:


from __future__ import print_function, absolute_import, division

import tensorflow as tf

tf.app.flags.DEFINE_string("ps_hosts", "localhost:2222", "ps hosts")
tf.app.flags.DEFINE_string("worker_hosts", "localhost:2223,localhost:2224",
                           "worker hosts")
tf.app.flags.DEFINE_string("job_name", "worker", "'ps' or 'worker'")
tf.app.flags.DEFINE_integer("task_index", 0, "Index of task within the job")
tf.app.flags.DEFINE_integer("num_workers", 2, "Number of workers")
tf.app.flags.DEFINE_boolean("is_sync", False,
                            "using synchronous training or not")

FLAGS = tf.app.flags.FLAGS


def model(images):
  """Define a simple mnist classifier"""
  net = tf.layers.dense(images, 500, activation=tf.nn.relu)
  net = tf.layers.dense(net, 500, activation=tf.nn.relu)
  net = tf.layers.dense(net, 10, activation=None)
  return net


(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
x_train = x_train.reshape(-1, 784).astype('float32')
x_test = x_test.reshape(-1, 784).astype('float32')
x_train /= 255
x_test /= 255


def get_batch(image, label, batch_size=32, training=True):
  df = tf.data.Dataset.from_tensor_slices((image, label))
  if training:
    df = df.repeat(10).shuffle(buffer_size=1000)
  df = df.batch(batch_size).prefetch(batch_size)
  iterator = df.make_one_shot_iterator()
  batch_x, batch_y = iterator.get_next()
  return batch_x, batch_y


def main(_):
  ps_hosts = FLAGS.ps_hosts.split(",")
  worker_hosts = FLAGS.worker_hosts.split(",")

  # create the cluster configured by 'ps_hosts' and 'worker_hosts'
  cluster = tf.train.ClusterSpec({"ps": ps_hosts, "worker": worker_hosts})

  # create a server for the local task
  server = tf.train.Server(cluster,
                           job_name=FLAGS.job_name,
                           task_index=FLAGS.task_index)

  train_batch_x, train_batch_y = get_batch(x_train, y_train)
  test_batch_x, test_batch_y = get_batch(x_test, y_test, training=False)

  if FLAGS.job_name == "ps":
    server.join()  # ps hosts only join
  elif FLAGS.job_name == "worker":
    # workers perform the operations
    # ps_strategy = tf.contrib.training.GreedyLoadBalancingStrategy(FLAGS.num_ps)

    # Note: tf.train.replica_device_setter automatically places the parameters
    # (Variables) on the ps hosts (default placement strategy: round-robin over
    # all ps hosts), and also places a copy of the operations on each worker host.
    with tf.device(
        tf.train.replica_device_setter(
            worker_device="/job:worker/task:%d" % FLAGS.task_index,
            cluster=cluster)):
      logits = model(train_batch_x)
      loss = tf.reduce_mean(
          tf.nn.softmax_cross_entropy_with_logits(
              logits=logits, labels=tf.one_hot(train_batch_y, 10)))

      # The StopAtStepHook handles stopping after running given steps.
      hooks = [tf.train.StopAtStepHook(last_step=10000)]

      global_step = tf.train.get_or_create_global_step()
      optimizer = tf.train.AdamOptimizer(learning_rate=1e-04)
      if FLAGS.is_sync:
        # synchronous training: wrap the optimizer with
        # tf.train.SyncReplicasOptimizer
        # ref: https://www.tensorflow.org/api_docs/python/tf/train/SyncReplicasOptimizer
        optimizer = tf.train.SyncReplicasOptimizer(
            optimizer,
            replicas_to_aggregate=FLAGS.num_workers,
            total_num_replicas=FLAGS.num_workers)
        # create the hook which handles initialization and queues
        hooks.append(optimizer.make_session_run_hook(FLAGS.task_index == 0))
      train_op = optimizer.minimize(loss, global_step=global_step)

    # The MonitoredTrainingSession takes care of session initialization,
    # restoring from a checkpoint, saving to a checkpoint, and closing when
    # done or when an error occurs.
    with tf.train.MonitoredTrainingSession(
        master=server.target,
        is_chief=(FLAGS.task_index == 0),
        checkpoint_dir="./checkpoint_dir",
        hooks=hooks) as mon_sess:
      while not mon_sess.should_stop():
        # mon_sess.run handles AbortedError in case of preempted PS.
        _, ls, step = mon_sess.run([train_op, loss, global_step])
        if step % 100 == 0:
          print("Train step %d, loss: %f" % (step, ls))


if __name__ == "__main__":
  tf.app.run()
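
Assuming the script above is saved as trainer.py (a filename chosen here for illustration), the cluster is started by launching one process per task with the flags defined at the top of the script, for example:

python trainer.py --job_name=ps --task_index=0
python trainer.py --job_name=worker --task_index=0   # task 0 acts as the chief
python trainer.py --job_name=worker --task_index=1   # non-chief worker

The worker with task_index 1 will sit in wait_for_session() until worker 0 has initialized or restored the model.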


References:


https://www.cnblogs.com/estragon/p/10034511.html


https://zhuanlan.zhihu.com/p/88876923


This article is reposted from Alex-zhai's Zhihu account.


Original link: https://zhuanlan.zhihu.com/p/91608555

