时隔16年Jeff Barr重返10.23-25 QCon上海站,带你看透AI如何重塑软件开发! 了解详情
写点什么

分布式 tensorflow 源码解读 2:MonitoredTrainingSession

  • 2019-11-28
  • 本文字数:15917 字

    阅读完需:约 52 分钟

分布式tensorflow源码解读2:MonitoredTrainingSession

MonitoredTrainingSession 是 tensorflow 管理分布式训练中一个使用很广泛的 API,集成了一些监控训练组件,如变量的初始化、从已有 checkpoint 恢复训练、summary、log 和 checkpoint 的保存等。在早期的 tf 版本中,一般使用 tf.train.Supervisor 来管理 session,后来框架升级后,官方推荐使用 MonitoredTrainingSession。MonitoredTrainingSession 有记录日志、训练可视化、checkpoint 保存、early-stop、训练效率调优等功能。


我们直接进入主题,下面是 MonitoredTrainingSession 源码,从注释中可了解到:MonitoredTrainingSession 的作用可用一句话来概括:如果 chief 节点,负责 session 的初始化或者从已有 checkpoint 恢复 session,并且创建一些用于保存 checkpoint 和 summary 的 hooks。如果是非 chief 的 worker 节点,则需要依赖 chief 节点完成初始化或恢复 session 这些操作后才能设置属于自己的 session。


@tf_export(v1=['train.MonitoredTrainingSession'])def MonitoredTrainingSession(    master='',  # pylint: disable=invalid-name    is_chief=True,    checkpoint_dir=None,    scaffold=None,    hooks=None,    chief_only_hooks=None,    save_checkpoint_secs=USE_DEFAULT,    save_summaries_steps=USE_DEFAULT,    save_summaries_secs=USE_DEFAULT,    config=None,    stop_grace_period_secs=120,    log_step_count_steps=100,    max_wait_secs=7200,    save_checkpoint_steps=USE_DEFAULT,    summary_dir=None):
""" Creates a `MonitoredSession` for training. Returns: A `MonitoredSession` object. """ scaffold = scaffold or Scaffold() worker_context = distribute_coordinator_context.get_current_worker_context()
if worker_context: return _create_monitored_session_with_worker_context( worker_context, scaffold, checkpoint_dir=checkpoint_dir, hooks=hooks, chief_only_hooks=chief_only_hooks, save_checkpoint_secs=save_checkpoint_secs, save_summaries_steps=save_summaries_steps, save_summaries_secs=save_summaries_secs, config=config, stop_grace_period_secs=stop_grace_period_secs, log_step_count_steps=log_step_count_steps, max_wait_secs=max_wait_secs, save_checkpoint_steps=save_checkpoint_steps, summary_dir=summary_dir)
if not is_chief: session_creator = WorkerSessionCreator( scaffold=scaffold, master=master, config=config, max_wait_secs=max_wait_secs) return MonitoredSession( session_creator=session_creator, hooks=hooks or [], stop_grace_period_secs=stop_grace_period_secs)
all_hooks = [] “”“ 将多个hook都加入到all_hooks这个列表中 ”“” if hooks: all_hooks.extend(hooks)
return MonitoredSession( session_creator=session_creator, hooks=all_hooks, stop_grace_period_secs=stop_grace_period_secs)
复制代码


我们首先解释下参数:


is_chief:用于分布式系统中,用于判断该系统是否是 chief,如果为 True,它将负责初始化并恢复底层 TensorFlow 会话。如果为 False,它将等待 chief 初始化或恢复 TensorFlow 会话。


checkpoint_dir:一个字符串。指定一个用于恢复变量的 checkpoint 文件路径。


scaffold:用于收集或建立支持性 op 的脚手架。如果未指定,则会创建默认一个默认的 scaffold。它用于完成图表的创建。


hooks:SessionRunHook 对象的可选列表。可自己定义 SessionRunHook 对象,也可用已经预定义好的 SessionRunHook 对象,如:tf.train.StopAtStepHook()设置停止训练的条件;tf.train.NanTensorHook(loss):如果 loss 的值为 Nan 则停止训练;


chief_only_hooks:SessionRunHook 对象列表。如果 is_chief== True,则激活这些挂钩,否则忽略。


save_checkpoint_secs:用默认的 checkpoint saver 保存 checkpoint 的频率(以秒为单位)。如果 save_checkpoint_secs 设置为 None,不保存 checkpoint。


save_summaries_steps:使用默认 summaries saver 将摘要写入磁盘的频率(以全局步数表示)。如果 save_summaries_steps 和 save_summaries_secs 都设置为 None,则不使用默认的 summaries saver 保存 summaries。默认为 100


save_summaries_secs:使用默认 summaries saver 将摘要写入磁盘的频率(以秒为单位)。如果 save_summaries_steps 和 save_summaries_secs 都设置为 None,则不使用默认的摘要保存。默认未启用。


config:用于配置会话的 tf.ConfigProtoproto 的实例。它是 tf.Session 的构造函数的 config 参数。


stop_grace_period_secs:调用 close()后线程停止的秒数。


log_step_count_steps:记录全局步/秒的全局步数的频率。


实例化后可得到一个 MonitoredSession 对象,可当作普通 session 使用。


然后我们仔细分解下代码:


def _create_monitored_session_with_worker_context(    worker_context,  # pylint: disable=missing-docstring    scaffold,    checkpoint_dir=None,    hooks=None,    chief_only_hooks=None,    save_checkpoint_secs=None,    save_summaries_steps=None,    save_summaries_secs=None,    config=None,    stop_grace_period_secs=120,    log_step_count_steps=100,    max_wait_secs=7200,    save_checkpoint_steps=None,    summary_dir=None):  all_hooks = []
“”“
复制代码


将多个 hook 都加入到 all_hooks 这个列表中


  ”“”
logging.info('all_hooks %r', all_hooks) # 创建session session_creator = worker_context.session_creator( scaffold, config=config, checkpoint_dir=checkpoint_dir, max_wait_secs=max_wait_secs)
return MonitoredSession( session_creator=session_creator, hooks=all_hooks, stop_grace_period_secs=stop_grace_period_secs)
# session_creator 函数主体 def session_creator(self, scaffold=None, config=None, checkpoint_dir=None, checkpoint_filename_with_path=None, max_wait_secs=7200): """
复制代码


根据正确master的target和session的config去返回session的creator方法体。
复制代码


    """    if config:      session_config = copy.deepcopy(config)      session_config.MergeFrom(self._session_config)    else:      session_config = self._session_config
复制代码


# 根据不同的角色来创建session
复制代码


    if not self._strategy or self._strategy.extended.experimental_should_init:      logging.info("Creating chief session creator with config: %r", config)      return monitored_session.ChiefSessionCreator(          scaffold,          master=self.master_target,          config=session_config,          checkpoint_dir=checkpoint_dir,          checkpoint_filename_with_path=checkpoint_filename_with_path)    else:      logging.info("Creating worker session creator with config: %r", config)      return monitored_session.WorkerSessionCreator(          scaffold,          master=self.master_target,          config=session_config,          max_wait_secs=max_wait_secs)
# ChiefSessionCreator@tf_export(v1=['train.ChiefSessionCreator'])class ChiefSessionCreator(SessionCreator): """Creates a tf.compat.v1.Session for a chief.""" def __init__(self, scaffold=None, master='', config=None, checkpoint_dir=None, checkpoint_filename_with_path=None): self._checkpoint_dir = checkpoint_dir self._checkpoint_filename_with_path = checkpoint_filename_with_path self._scaffold = scaffold or Scaffold() self._session_manager = None self._master = master self._config = config
def _get_session_manager(self): if self._session_manager: return self._session_manager
self._session_manager = sm.SessionManager( local_init_op=self._scaffold.local_init_op, ready_op=self._scaffold.ready_op, ready_for_local_init_op=self._scaffold.ready_for_local_init_op, graph=ops.get_default_graph()) return self._session_manager
def create_session(self): self._scaffold.finalize() return self._get_session_manager().prepare_session( self._master, saver=self._scaffold.saver, checkpoint_dir=self._checkpoint_dir, checkpoint_filename_with_path=self._checkpoint_filename_with_path, config=self._config, init_op=self._scaffold.init_op, init_feed_dict=self._scaffold.init_feed_dict, init_fn=self._scaffold.init_fn)
# WorkerSessionCreator@tf_export(v1=['train.WorkerSessionCreator'])class WorkerSessionCreator(SessionCreator): """Creates a tf.compat.v1.Session for a worker.""" def __init__(self, scaffold=None, master='', config=None, max_wait_secs=30 * 60): """Initializes a worker session creator.
Args: max_wait_secs: Maximum time to wait for the session to become available. """ self._scaffold = scaffold or Scaffold() self._session_manager = None self._master = master self._config = config self._max_wait_secs = max_wait_secs
def _get_session_manager(self): if self._session_manager: return self._session_manager
self._session_manager = sm.SessionManager( local_init_op=self._scaffold.local_init_op, ready_op=self._scaffold.ready_op, ready_for_local_init_op=self._scaffold.ready_for_local_init_op, graph=ops.get_default_graph()) return self._session_manager
def create_session(self): self._scaffold.finalize() return self._get_session_manager().wait_for_session( self._master, config=self._config, max_wait_secs=self._max_wait_secs)
复制代码


从上面的源码中分析得到,MonitoredTrainingSession 可根据不同的角色去创建不同种类的 Session,其中 chief 节点是由 ChiefSessionCreator 类去创建 session,而非 chief 的 worker 节点是由 WorkerSessionCreator 类创建,特殊之处就是创建时调用的是 wait_for_session(),大致意识是需要等待 chief 节点的 session 创建完成之后才去创建属于自己节点的 session。其中创建 session 都是属于 SessionManager 类的一个方法,下面我们具体分析下 SessionManager 类:


官方针对 SessionManager 类有一个简单的例子,感觉很清楚:


  # prepare_session函数可以初始化或者restore一个模型,需要传入`init_op`和 `saver`   with tf.Graph().as_default():    # add operations to the graph...    # Create a SessionManager that will checkpoint the model in '/tmp/mydir'.    sm = SessionManager()    sess = sm.prepare_session(master, init_op, saver, checkpoint_dir)    # Use the session to train the graph.    while True:      sess.run(<my_train_op>)
复制代码

第二个进程可以用以下方法启动 op,wait_for_session()的意思是需要等上面一个 session 创建好之后

  # 再创建自己的session  with tf.Graph().as_default():    # ...add operations to the graph...    # Create a SessionManager that will wait for the model to become ready.    sm = SessionManager()    sess = sm.wait_for_session(master)    # Use the session to train the graph.    while True:      sess.run(<my_train_op>)
复制代码


然后我们可以重点关注下 prepare_session 和 wait_for_session 这两个函数:


@tf_export(v1=["train.SessionManager"])class SessionManager(object):  def __init__(self,               local_init_op=None,               ready_op=None,               ready_for_local_init_op=None,               graph=None,               recovery_wait_secs=30,               local_init_run_options=None):    """
复制代码


local_init_op 是每当有一个新的session被创建时,就会运行下local_init_op这个操作。ready_op 用于check模型是否准备好的一个op。ready_for_local_init_op是checkp模型是否已经可以运行local_init_op的一个op。
复制代码


    """    # Sets default values of arguments.    if graph is None:      graph = ops.get_default_graph()    self._local_init_op = local_init_op    self._ready_op = ready_op    self._ready_for_local_init_op = ready_for_local_init_op    self._graph = graph    self._recovery_wait_secs = recovery_wait_secs    self._target = None    self._local_init_run_options = local_init_run_options    if ready_for_local_init_op is not None and local_init_op is None:      raise ValueError("If you pass a ready_for_local_init_op "                       "you must also pass a local_init_op "                       ", ready_for_local_init_op [%s]" %                       ready_for_local_init_op)
def prepare_session(self, master, init_op=None, saver=None, checkpoint_dir=None, checkpoint_filename_with_path=None, wait_for_checkpoint=False, max_wait_secs=7200, config=None, init_feed_dict=None, init_fn=None): """
复制代码


其实prepare_session函数的作用就是如果有checkpoint存在,就从checkpoint恢复session,如果不存在checkpoint就从传入的`init_op`和 调用`init_fn`函数去创建session。
复制代码


    """    sess, is_loaded_from_checkpoint = self._restore_checkpoint(        master,        saver,        checkpoint_dir=checkpoint_dir,        checkpoint_filename_with_path=checkpoint_filename_with_path,        wait_for_checkpoint=wait_for_checkpoint,        max_wait_secs=max_wait_secs,        config=config)    if not is_loaded_from_checkpoint:      if init_op is None and not init_fn and self._local_init_op is None:        raise RuntimeError("Model is not initialized and no init_op or "                           "init_fn or local_init_op was given")      if init_op is not None:        sess.run(init_op, feed_dict=init_feed_dict)      if init_fn:        init_fn(sess)    ”“”    .....    “”“    return sess

def wait_for_session(self, master, config=None, max_wait_secs=float("Inf")): """ Creates a new `Session` and waits for model to be ready. """ self._target = master if max_wait_secs is None: max_wait_secs = float("Inf") timer = _CountDownTimer(max_wait_secs)
while True: sess = session.Session(self._target, graph=self._graph, config=config) not_ready_msg = None not_ready_local_msg = None local_init_success, not_ready_local_msg = self._try_run_local_init_op( sess) if local_init_success: # Successful if local_init_op is None, or ready_for_local_init_op passes is_ready, not_ready_msg = self._model_ready(sess) if is_ready: return sess self._safe_close(sess) # Do we have enough time left to try again? remaining_ms_after_wait = ( timer.secs_remaining() - self._recovery_wait_secs) if remaining_ms_after_wait < 0: raise errors.DeadlineExceededError( None, None, "Session was not ready after waiting %d secs." % (max_wait_secs,)) logging.info("Waiting for model to be ready. " "Ready_for_local_init_op: %s, ready: %s", not_ready_local_msg, not_ready_msg) time.sleep(self._recovery_wait_secs)
复制代码


创建完 session 之后,再包装一下返回最终的 MonitoredSession 类,


一个完整的 monitored session 在创建时间内可做的事情(按顺序):


  • 为每个 hook 调用 hook.begin()

  • MonitoredTrainingSession 是 tensorflow 管理分布式训练中一个使用很广泛的 API,集成了一些监控训练组件,如变量的初始化、从已有 checkpoint 恢复训练、summary、log 和 checkpoint 的保存等。在早期的 tf 版本中,一般使用 tf.train.Supervisor 来管理 session,后来框架升级后,官方推荐使用 MonitoredTrainingSession。MonitoredTrainingSession 有记录日志、训练可视化、checkpoint 保存、early-stop、训练效率调优等功能。


我们直接进入主题,下面是 MonitoredTrainingSession 源码,从注释中可了解到:MonitoredTrainingSession 的作用可用一句话来概括:如果 chief 节点,负责 session 的初始化或者从已有 checkpoint 恢复 session,并且创建一些用于保存 checkpoint 和 summary 的 hooks。如果是非 chief 的 worker 节点,则需要依赖 chief 节点完成初始化或恢复 session 这些操作后才能设置属于自己的 session。


@tf_export(v1=[‘train.MonitoredTrainingSession’])


def MonitoredTrainingSession(


master=’’, # pylint: disable=invalid-name


is_chief=True,


checkpoint_dir=None,


scaffold=None,


hooks=None,


chief_only_hooks=None,


save_checkpoint_secs=USE_DEFAULT,


save_summaries_steps=USE_DEFAULT,


save_summaries_secs=USE_DEFAULT,


config=None,


stop_grace_period_secs=120,


log_step_count_steps=100,


max_wait_secs=7200,


save_checkpoint_steps=USE_DEFAULT,


summary_dir=None):


“”"


Creates a MonitoredSession for training.


Returns:


A MonitoredSession object.


“”"


scaffold = scaffold or Scaffold()


worker_context = distribute_coordinator_context.get_current_worker_context()


if worker_context:


return _create_monitored_session_with_worker_context(


worker_context,


scaffold,


checkpoint_dir=checkpoint_dir,


hooks=hooks,


chief_only_hooks=chief_only_hooks,


save_checkpoint_secs=save_checkpoint_secs,


save_summaries_steps=save_summaries_steps,


save_summaries_secs=save_summaries_secs,


config=config,


stop_grace_period_secs=stop_grace_period_secs,


log_step_count_steps=log_step_count_steps,


max_wait_secs=max_wait_secs,


save_checkpoint_steps=save_checkpoint_steps,


summary_dir=summary_dir)


if not is_chief:


session_creator = WorkerSessionCreator(


scaffold=scaffold,


master=master,


config=config,


max_wait_secs=max_wait_secs)


return MonitoredSession(


session_creator=session_creator,


hooks=hooks or [],


stop_grace_period_secs=stop_grace_period_secs)


all_hooks = []


“”“


将多个 hook 都加入到 all_hooks 这个列表中


”“”


if hooks:


all_hooks.extend(hooks)


return MonitoredSession(


session_creator=session_creator,


hooks=all_hooks,


stop_grace_period_secs=stop_grace_period_secs)


我们首先解释下参数:


is_chief:用于分布式系统中,用于判断该系统是否是 chief,如果为 True,它将负责初始化并恢复底层 TensorFlow 会话。如果为 False,它将等待 chief 初始化或恢复 TensorFlow 会话。


checkpoint_dir:一个字符串。指定一个用于恢复变量的 checkpoint 文件路径。


scaffold:用于收集或建立支持性 op 的脚手架。如果未指定,则会创建默认一个默认的 scaffold。它用于完成图表的创建。


hooks:SessionRunHook 对象的可选列表。可自己定义 SessionRunHook 对象,也可用已经预定义好的 SessionRunHook 对象,如:tf.train.StopAtStepHook()设置停止训练的条件;tf.train.NanTensorHook(loss):如果 loss 的值为 Nan 则停止训练;


chief_only_hooks:SessionRunHook 对象列表。如果 is_chief== True,则激活这些挂钩,否则忽略。


save_checkpoint_secs:用默认的 checkpoint saver 保存 checkpoint 的频率(以秒为单位)。如果 save_checkpoint_secs 设置为 None,不保存 checkpoint。


save_summaries_steps:使用默认 summaries saver 将摘要写入磁盘的频率(以全局步数表示)。如果 save_summaries_steps 和 save_summaries_secs 都设置为 None,则不使用默认的 summaries saver 保存 summaries。默认为 100


save_summaries_secs:使用默认 summaries saver 将摘要写入磁盘的频率(以秒为单位)。如果 save_summaries_steps 和 save_summaries_secs 都设置为 None,则不使用默认的摘要保存。默认未启用。


config:用于配置会话的 tf.ConfigProtoproto 的实例。它是 tf.Session 的构造函数的 config 参数。


stop_grace_period_secs:调用 close()后线程停止的秒数。


log_step_count_steps:记录全局步/秒的全局步数的频率。


实例化后可得到一个 MonitoredSession 对象,可当作普通 session 使用。


然后我们仔细分解下代码:


def _create_monitored_session_with_worker_context(


worker_context, # pylint: disable=missing-docstring


scaffold,


checkpoint_dir=None,


hooks=None,


chief_only_hooks=None,


save_checkpoint_secs=None,


save_summaries_steps=None,


save_summaries_secs=None,


config=None,


stop_grace_period_secs=120,


log_step_count_steps=100,


max_wait_secs=7200,


save_checkpoint_steps=None,


summary_dir=None):


all_hooks = []


“”“


将多个 hook 都加入到 all_hooks 这个列表中


”“”


logging.info(‘all_hooks %r’, all_hooks)

创建 session

session_creator = worker_context.session_creator(


scaffold,


config=config,


checkpoint_dir=checkpoint_dir,


max_wait_secs=max_wait_secs)


return MonitoredSession(


session_creator=session_creator,


hooks=all_hooks,


stop_grace_period_secs=stop_grace_period_secs)

session_creator 函数主体

def session_creator(self,


scaffold=None,


config=None,


checkpoint_dir=None,


checkpoint_filename_with_path=None,


max_wait_secs=7200):


“”"


根据正确 master 的 target 和 session 的 config 去返回 session 的 creator 方法体。


“”"


if config:


session_config = copy.deepcopy(config)


session_config.MergeFrom(self._session_config)


else:


session_config = self._session_config


# 根据不同的角色来创建sessionif not self._strategy or self._strategy.extended.experimental_should_init:  logging.info("Creating chief session creator with config: %r", config)  return monitored_session.ChiefSessionCreator(      scaffold,      master=self.master_target,      config=session_config,      checkpoint_dir=checkpoint_dir,      checkpoint_filename_with_path=checkpoint_filename_with_path)else:  logging.info("Creating worker session creator with config: %r", config)  return monitored_session.WorkerSessionCreator(      scaffold,      master=self.master_target,      config=session_config,      max_wait_secs=max_wait_secs)
复制代码

ChiefSessionCreator

@tf_export(v1=[‘train.ChiefSessionCreator’])


class ChiefSessionCreator(SessionCreator):


“”“Creates a tf.compat.v1.Session for a chief.”""


def init(self,


scaffold=None,


master=’’,


config=None,


checkpoint_dir=None,


checkpoint_filename_with_path=None):


self._checkpoint_dir = checkpoint_dir


self._checkpoint_filename_with_path = checkpoint_filename_with_path


self._scaffold = scaffold or Scaffold()


self._session_manager = None


self._master = master


self._config = config


def _get_session_manager(self):


if self._session_manager:


return self._session_manager


self._session_manager = sm.SessionManager(    local_init_op=self._scaffold.local_init_op,    ready_op=self._scaffold.ready_op,    ready_for_local_init_op=self._scaffold.ready_for_local_init_op,    graph=ops.get_default_graph())return self._session_manager
复制代码


def create_session(self):


self._scaffold.finalize()


return self._get_session_manager().prepare_session(


self._master,


saver=self._scaffold.saver,


checkpoint_dir=self._checkpoint_dir,


checkpoint_filename_with_path=self._checkpoint_filename_with_path,


config=self._config,


init_op=self._scaffold.init_op,


init_feed_dict=self._scaffold.init_feed_dict,


init_fn=self._scaffold.init_fn)

WorkerSessionCreator

@tf_export(v1=[‘train.WorkerSessionCreator’])


class WorkerSessionCreator(SessionCreator):


“”“Creates a tf.compat.v1.Session for a worker.”""


def init(self,


scaffold=None,


master=’’,


config=None,


max_wait_secs=30 * 60):


“”"Initializes a worker session creator.


Args:  max_wait_secs: Maximum time to wait for the session to become available."""self._scaffold = scaffold or Scaffold()self._session_manager = Noneself._master = masterself._config = configself._max_wait_secs = max_wait_secs
复制代码


def _get_session_manager(self):


if self._session_manager:


return self._session_manager


self._session_manager = sm.SessionManager(    local_init_op=self._scaffold.local_init_op,    ready_op=self._scaffold.ready_op,    ready_for_local_init_op=self._scaffold.ready_for_local_init_op,    graph=ops.get_default_graph())return self._session_manager
复制代码


def create_session(self):


self._scaffold.finalize()


return self._get_session_manager().wait_for_session(


self._master, config=self._config, max_wait_secs=self._max_wait_secs)


从上面的源码中分析得到,MonitoredTrainingSession 可根据不同的角色去创建不同种类的 Session,其中 chief 节点是由 ChiefSessionCreator 类去创建 session,而非 chief 的 worker 节点是由 WorkerSessionCreator 类创建,特殊之处就是创建时调用的是 wait_for_session(),大致意识是需要等待 chief 节点的 session 创建完成之后才去创建属于自己节点的 session。其中创建 session 都是属于 SessionManager 类的一个方法,下面我们具体分析下 SessionManager 类:


官方针对 SessionManager 类有一个简单的例子,感觉很清楚:

prepare_session 函数可以初始化或者 restore 一个模型,需要传入init_opsaver

with tf.Graph().as_default():


# add operations to the graph…


# Create a SessionManager that will checkpoint the model in ‘/tmp/mydir’.


sm = SessionManager()


sess = sm.prepare_session(master, init_op, saver, checkpoint_dir)


# Use the session to train the graph.


while True:


sess.run(<my_train_op>)

第二个进程可以用以下方法启动 op,wait_for_session()的意思是需要等上面一个 session 创建好之后

再创建自己的 session

with tf.Graph().as_default():


# …add operations to the graph…


# Create a SessionManager that will wait for the model to become ready.


sm = SessionManager()


sess = sm.wait_for_session(master)


# Use the session to train the graph.


while True:


sess.run(<my_train_op>)


然后我们可以重点关注下 prepare_session 和 wait_for_session 这两个函数:


@tf_export(v1=[“train.SessionManager”])


class SessionManager(object):


def init(self,


local_init_op=None,


ready_op=None,


ready_for_local_init_op=None,


graph=None,


recovery_wait_secs=30,


local_init_run_options=None):


“”"


local_init_op 是每当有一个新的 session 被创建时,就会运行下 local_init_op 这个操作。


ready_op 用于 check 模型是否准备好的一个 op。


ready_for_local_init_op 是 checkp 模型是否已经可以运行 local_init_op 的一个 op。


“”"


# Sets default values of arguments.


if graph is None:


graph = ops.get_default_graph()


self._local_init_op = local_init_op


self._ready_op = ready_op


self._ready_for_local_init_op = ready_for_local_init_op


self._graph = graph


self._recovery_wait_secs = recovery_wait_secs


self._target = None


self._local_init_run_options = local_init_run_options


if ready_for_local_init_op is not None and local_init_op is None:


raise ValueError("If you pass a ready_for_local_init_op "


"you must also pass a local_init_op "


“, ready_for_local_init_op [%s]” %


ready_for_local_init_op)


def prepare_session(self,


master,


init_op=None,


saver=None,


checkpoint_dir=None,


checkpoint_filename_with_path=None,


wait_for_checkpoint=False,


max_wait_secs=7200,


config=None,


init_feed_dict=None,


init_fn=None):


“”"


其实 prepare_session 函数的作用就是如果有 checkpoint 存在,就从 checkpoint 恢复 session,如果


不存在 checkpoint 就从传入的init_op和 调用init_fn函数去创建 session。


“”"


sess, is_loaded_from_checkpoint = self._restore_checkpoint(


master,


saver,


checkpoint_dir=checkpoint_dir,


checkpoint_filename_with_path=checkpoint_filename_with_path,


wait_for_checkpoint=wait_for_checkpoint,


max_wait_secs=max_wait_secs,


config=config)


if not is_loaded_from_checkpoint:


if init_op is None and not init_fn and self._local_init_op is None:


raise RuntimeError("Model is not initialized and no init_op or "


“init_fn or local_init_op was given”)


if init_op is not None:


sess.run(init_op, feed_dict=init_feed_dict)


if init_fn:


init_fn(sess)


”“”



“”“


return sess


def wait_for_session(self, master, config=None, max_wait_secs=float(“Inf”)):


“”"


Creates a new Session and waits for model to be ready.


“”"


self._target = master


if max_wait_secs is None:


max_wait_secs = float(“Inf”)


timer = _CountDownTimer(max_wait_secs)


while True:  sess = session.Session(self._target, graph=self._graph, config=config)  not_ready_msg = None  not_ready_local_msg = None  local_init_success, not_ready_local_msg = self._try_run_local_init_op(      sess)  if local_init_success:    # Successful if local_init_op is None, or ready_for_local_init_op passes    is_ready, not_ready_msg = self._model_ready(sess)    if is_ready:      return sess  self._safe_close(sess)  # Do we have enough time left to try again?  remaining_ms_after_wait = (      timer.secs_remaining() - self._recovery_wait_secs)  if remaining_ms_after_wait < 0:    raise errors.DeadlineExceededError(        None, None,        "Session was not ready after waiting %d secs." % (max_wait_secs,))  logging.info("Waiting for model to be ready.  "               "Ready_for_local_init_op:  %s, ready: %s",               not_ready_local_msg, not_ready_msg)  time.sleep(self._recovery_wait_secs)
复制代码


创建完 session 之后,再包装一下返回最终的 MonitoredSession 类,


一个完整的 monitored session 在创建时间内可做的事情(按顺序):


为每个 hook 调用 hook.begin()


  • 调用 scaffold.finalize()完成 graph

  • 创建 session

  • 为模型参数做初始化 ,通过 Scaffold

  • 如果存在 checkpoint 则根据 checkpoint restore 参数

  • 发布 runners 队列

  • 调用 hook.after_create_session()函数

  • 当 run 函数调用时,monitored session 做的事情:

  • 调用 hook.before_run()

  • 调用 TensorFlow 中的 session.run() with merged fetches and feed_dict

  • 调用 hook.after_run()

  • 返回 session.run()的结果

  • 如果发生 AbortedError 或者 UnavailableError,则在再次执行 run()之前恢复或者重新初始化会话

  • 当 close()函数调用时,monitored session 做的事情:

  • 调用 hook.end()

  • 关闭 queue runners 和 session

  • 如果所有的输入数据被消耗完,抛出 OutOfRange 异常。

  • 最后,给大家贴一个使用 MonitoredSession 类进行分布式训练的 example:


from __future__ import print_function, absolute_import, division
import tensorflow as tf
tf.app.flags.DEFINE_string("ps_hosts", "localhost:2222", "ps hosts")tf.app.flags.DEFINE_string("worker_hosts", "localhost:2223,localhost:2224", "worker hosts")tf.app.flags.DEFINE_string("job_name", "worker", "'ps' or'worker'")tf.app.flags.DEFINE_integer("task_index", 0, "Index of task within the job")tf.app.flags.DEFINE_integer("num_workers", 2, "Number of workers")tf.app.flags.DEFINE_boolean("is_sync", False, "using synchronous training or not")
FLAGS = tf.app.flags.FLAGS

def model(images): """Define a simple mnist classifier""" net = tf.layers.dense(images, 500, activation=tf.nn.relu) net = tf.layers.dense(net, 500, activation=tf.nn.relu) net = tf.layers.dense(net, 10, activation=None) return net

(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()x_train = x_train.reshape(-1, 784).astype('float32')x_test = x_test.reshape(-1, 784).astype('float32')x_train /= 255x_test /= 255

def get_batch(image, label, batch_size=32, training=True): df = tf.data.Dataset.from_tensor_slices((image, label)) if training: df = df.repeat(10).shuffle(buffer_size=1000) df = df.batch(batch_size).prefetch(batch_size) iterator = df.make_one_shot_iterator() batch_x, batch_y = iterator.get_next() return batch_x, batch_y

def main(_): ps_hosts = FLAGS.ps_hosts.split(",") worker_hosts = FLAGS.worker_hosts.split(",")
# create the cluster configured by `ps_hosts' and 'worker_hosts' cluster = tf.train.ClusterSpec({"ps": ps_hosts, "worker": worker_hosts})
# create a server for local task server = tf.train.Server(cluster, job_name=FLAGS.job_name, task_index=FLAGS.task_index)
train_batch_x, train_batch_y = get_batch(x_train, y_train) test_batch_x, test_batch_y = get_batch(x_test, y_test, training=False)
if FLAGS.job_name == "ps": server.join() # ps hosts only join elif FLAGS.job_name == "worker": # workers perform the operation # ps_strategy = tf.contrib.training.GreedyLoadBalancingStrategy(FLAGS.num_ps)
# Note: tf.train.replica_device_setter automatically place the paramters (Variables) # on the ps hosts (default placement strategy: round-robin over all ps hosts, and also # place multi copies of operations to each worker host with tf.device(tf.train.replica_device_setter(worker_device="/job:worker/task:%d" % FLAGS.task_index, cluster=cluster)):
logits = model(train_batch_x) loss = tf.reduce_mean( tf.nn.softmax_cross_entropy_with_logits(logits=logits, labels=tf.one_hot(train_batch_y, 10)))
# The StopAtStepHook handles stopping after running given steps. hooks = [tf.train.StopAtStepHook(last_step=10000)]
global_step = tf.train.get_or_create_global_step() optimizer = tf.train.AdamOptimizer(learning_rate=1e-04) if FLAGS.is_sync: # asynchronous training # use tf.train.SyncReplicasOptimizer wrap optimizer # ref: https://www.tensorflow.org/api_docs/python/tf/train/SyncReplicasOptimizer optimizer = tf.train.SyncReplicasOptimizer(optimizer, replicas_to_aggregate=FLAGS.num_workers, total_num_replicas=FLAGS.num_workers) # create the hook which handles initialization and queues hooks.append(optimizer.make_session_run_hook((FLAGS.task_index == 0)))
train_op = optimizer.minimize(loss, global_step=global_step)
# The MonitoredTrainingSession takes care of session initialization, # restoring from a checkpoint, saving to a checkpoint, and closing when done # or an error occurs. with tf.train.MonitoredTrainingSession(master=server.target, is_chief=(FLAGS.task_index == 0), checkpoint_dir="./checkpoint_dir", hooks=hooks) as mon_sess: while not mon_sess.should_stop(): # mon_sess.run handles AbortedError in case of preempted PS. _, ls, step = mon_sess.run([train_op, loss, global_step]) if step % 100 == 0: print("Train step %d, loss: %f" % (step, ls))
if __name__ == "__main__": tf.app.run()
复制代码


参考文献:


https://www.cnblogs.com/estragon/p/10034511.html


https://zhuanlan.zhihu.com/p/88876923


本文转载自 Alex-zhai 知乎账号。


原文链接:https://zhuanlan.zhihu.com/p/91608555


2019-11-28 08:001708

评论

发布
暂无评论
发现更多内容

聊聊接口文档的事儿

京茶吉鹿

接口文档 Knife4j swagger2

KgCaptcha验证的那些事

宙哈哈

php Python html 验证码

2023企业上云暨算云融合产业大会在京召开

中国IDC圈

算力 可信云

统一观测丨使用 Prometheus 监控 Nginx Ingress 网关最佳实践

阿里巴巴云原生

阿里云 云原生 Prometheus

# 架构实战营-模块1-作业

Geek_e948d4

最强嘴替:新任技术管理者如何快速成长,完成转型逆袭?

LigaAI

技术管理 管理者 逆袭 技术人成长 企业号 4 月 PK 榜

Spring MVC 之 HttpMessageConverter

Java spring Spring MVC

开源轻量级 IM 框架 MobileIMSDK 的微信小程序端已发布!

JackJiang

网络编程 IM 即时通讯IM

2023 - Dubbo 谷歌编程之夏报名启动了!

阿里巴巴云原生

阿里云 云原生 dubbo

网站上的视频资源被偷偷转载了...

为自己带盐

知识产权 ffmpeg HLS openssl

KgCaptcha验证码实现笔记

宙哈哈

Python html 验证码

强强携手促发展 中建信息成为麒麟软件全国总经销商

极客天地

架构训练营模块一作业

请叫我馒头哥丶

架构 架构实战营

硬核!GitHub置顶102W字Redis高手心法笔记

Java 数据库 redis 缓存 面试

喜讯!华秋电子荣获深圳市半导体行业协会优秀合作奖

华秋电子

LeaRun低代码开发平台 赋能企业快速落地BI大屏

力软低代码开发平台

软件测试/测试开发丨必知必会的Docker 命令

测试人

Docker 软件测试 自动化测试 测试开发

软件测试/测试开发丨Docker 搭建Web服务器nginx

测试人

nginx Docker 软件测试 自动化测试 测试开发

北京国家会计学院副教授王亚星:智能会计和价值财务有力支撑企业高质量发展

用友BIP

【送猫超卡、阿里云代金券】动手体验 SAE+云效 10 分钟快速打通 CI/CD 流水线

阿里巴巴云原生

阿里云 Serverless 云原生

2023 年“和鲸杯”辽宁省普通高等学校本科大学生计算机设计竞赛启动会顺利召开

ModelWhale

大数据 人才培养 数据科学 数据思维 数据竞赛

教你如何通过CodeArts IDE插件调用API,高效合成语音

华为云开发者联盟

云计算 开发 华为云 华为云开发者联盟 企业号 4 月 PK 榜

三点几嚟,饮茶先啦!PaddleSpeech发布全流程粤语语音合成

飞桨PaddlePaddle

人工智能 机器学习 深度学习 语音识别

看我如何用定值 Cookie 实现反爬

华为云开发者联盟

爬虫 开发 华为云 华为云开发者联盟 企业号 4 月 PK 榜

CANN训练:模型推理时数据预处理方法及归一化参数计算

华为云开发者联盟

人工智能 华为云 华为云开发者联盟 企业号 4 月 PK 榜

PHP短信验证码防刷方案

宙哈哈

php html 图片验证码

数智时代的来临,养老行业接入人工智能技术已是势不可挡

加入高科技仿生人

人工智能 AI 养老服务 养老

ES和MongoDB:一次别开生面的比较

Java你猿哥

数据库 mongodb elasticsearch ES API

没有设计师?没问题!Spring+OpenAI让你也能生成漂亮的图片!

Java你猿哥

Java spring maven API

华为云GaussDB践行数字化,护航证券保险高质量发展

华为云开发者联盟

数据库 后端 华为云 华为云开发者联盟 企业号 4 月 PK 榜

分布式tensorflow源码解读2:MonitoredTrainingSession_语言 & 开发_Alex-zhai_InfoQ精选文章