
Distributed TensorFlow Source Code Reading, Part 2: MonitoredTrainingSession

  • 2019-11-28



MonitoredTrainingSession is a widely used API for managing distributed training in TensorFlow. It bundles a number of training-monitoring components, such as variable initialization, resuming training from an existing checkpoint, and saving summaries, logs, and checkpoints. In early versions of TensorFlow, sessions were usually managed with tf.train.Supervisor; after the framework was upgraded, the official recommendation became MonitoredTrainingSession. It provides logging, training visualization, checkpoint saving, early stopping, and training-efficiency tuning.


Let's get straight to the point. Below is the source code of MonitoredTrainingSession. From the comments, its job can be summarized in one sentence: if this is the chief node, it is responsible for initializing the session or restoring it from an existing checkpoint, and for creating the hooks that save checkpoints and summaries; if it is a non-chief worker node, it has to wait for the chief to finish initializing or restoring the session before it can set up its own session.


@tf_export(v1=['train.MonitoredTrainingSession'])
def MonitoredTrainingSession(
    master='',  # pylint: disable=invalid-name
    is_chief=True,
    checkpoint_dir=None,
    scaffold=None,
    hooks=None,
    chief_only_hooks=None,
    save_checkpoint_secs=USE_DEFAULT,
    save_summaries_steps=USE_DEFAULT,
    save_summaries_secs=USE_DEFAULT,
    config=None,
    stop_grace_period_secs=120,
    log_step_count_steps=100,
    max_wait_secs=7200,
    save_checkpoint_steps=USE_DEFAULT,
    summary_dir=None):
  """Creates a `MonitoredSession` for training.

  Returns:
    A `MonitoredSession` object.
  """
  scaffold = scaffold or Scaffold()
  worker_context = distribute_coordinator_context.get_current_worker_context()

  if worker_context:
    return _create_monitored_session_with_worker_context(
        worker_context,
        scaffold,
        checkpoint_dir=checkpoint_dir,
        hooks=hooks,
        chief_only_hooks=chief_only_hooks,
        save_checkpoint_secs=save_checkpoint_secs,
        save_summaries_steps=save_summaries_steps,
        save_summaries_secs=save_summaries_secs,
        config=config,
        stop_grace_period_secs=stop_grace_period_secs,
        log_step_count_steps=log_step_count_steps,
        max_wait_secs=max_wait_secs,
        save_checkpoint_steps=save_checkpoint_steps,
        summary_dir=summary_dir)

  if not is_chief:
    session_creator = WorkerSessionCreator(
        scaffold=scaffold,
        master=master,
        config=config,
        max_wait_secs=max_wait_secs)
    return MonitoredSession(
        session_creator=session_creator,
        hooks=hooks or [],
        stop_grace_period_secs=stop_grace_period_secs)

  all_hooks = []
  # Collect the various hooks into the all_hooks list.
  # ... (the excerpt omits the code that builds the ChiefSessionCreator and the
  # default checkpoint/summary/logging hooks) ...
  if hooks:
    all_hooks.extend(hooks)

  return MonitoredSession(
      session_creator=session_creator,
      hooks=all_hooks,
      stop_grace_period_secs=stop_grace_period_secs)


Let's first go through the parameters:

is_chief: used in distributed training to indicate whether this task is the chief. If True, it is responsible for initializing or restoring the underlying TensorFlow session; if False, it waits for the chief to initialize or restore the session.

checkpoint_dir: a string. The path of the checkpoint directory used to restore variables.

scaffold: a Scaffold used to collect or build supporting ops. If not specified, a default Scaffold is created. It is used to finalize the graph.

hooks: an optional list of SessionRunHook objects. You can define your own SessionRunHook objects, or use the predefined ones, for example tf.train.StopAtStepHook() to set a stopping condition for training, or tf.train.NanTensorHook(loss) to stop training when the loss becomes NaN.

chief_only_hooks: a list of SessionRunHook objects. These hooks are activated only if is_chief == True, and are ignored otherwise.

save_checkpoint_secs: how often checkpoints are saved by the default checkpoint saver, in seconds. If save_checkpoint_secs is set to None, no checkpoints are saved.

save_summaries_steps: how often summaries are written to disk by the default summary saver, measured in global steps. If both save_summaries_steps and save_summaries_secs are None, the default summary saver is not used. Defaults to 100.

save_summaries_secs: how often summaries are written to disk by the default summary saver, in seconds. If both save_summaries_steps and save_summaries_secs are None, the default summary saving is not used. Disabled by default.

config: an instance of tf.ConfigProto used to configure the session; it is passed as the config argument of the tf.Session constructor.

stop_grace_period_secs: the number of seconds threads are given to stop after close() has been called.

log_step_count_steps: the frequency, in global steps, at which the global steps per second are logged.


After instantiation you get a MonitoredSession object, which can be used like an ordinary session.
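To make the parameters above concrete, here is a minimal single-process usage sketch of my own (assuming TensorFlow 1.x; the toy variable, loss, and checkpoint directory are illustrative placeholders, not taken from the original article):

import tensorflow as tf

# A trivial graph standing in for a real model.
global_step = tf.train.get_or_create_global_step()
w = tf.Variable(1.0)
loss = tf.square(w - 3.0)
train_op = tf.train.GradientDescentOptimizer(0.1).minimize(loss, global_step=global_step)

hooks = [
    tf.train.StopAtStepHook(last_step=1000),  # stop training after 1000 global steps
    tf.train.NanTensorHook(loss),             # abort training if the loss becomes NaN
]

# is_chief=True because this sketch runs in a single process.
with tf.train.MonitoredTrainingSession(
        is_chief=True,
        checkpoint_dir='/tmp/mts_demo',   # checkpoints and summaries are written here
        hooks=hooks,
        save_checkpoint_secs=60,
        log_step_count_steps=100) as sess:
    while not sess.should_stop():
        sess.run(train_op)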


Now let's break the code down step by step:


def _create_monitored_session_with_worker_context(
    worker_context,  # pylint: disable=missing-docstring
    scaffold,
    checkpoint_dir=None,
    hooks=None,
    chief_only_hooks=None,
    save_checkpoint_secs=None,
    save_summaries_steps=None,
    save_summaries_secs=None,
    config=None,
    stop_grace_period_secs=120,
    log_step_count_steps=100,
    max_wait_secs=7200,
    save_checkpoint_steps=None,
    summary_dir=None):
  all_hooks = []
  # ... collect the various hooks into the all_hooks list (elided) ...
  logging.info('all_hooks %r', all_hooks)
  # Create the session.
  session_creator = worker_context.session_creator(
      scaffold,
      config=config,
      checkpoint_dir=checkpoint_dir,
      max_wait_secs=max_wait_secs)
  return MonitoredSession(
      session_creator=session_creator,
      hooks=all_hooks,
      stop_grace_period_secs=stop_grace_period_secs)
# The body of the session_creator function:
def session_creator(self,
                    scaffold=None,
                    config=None,
                    checkpoint_dir=None,
                    checkpoint_filename_with_path=None,
                    max_wait_secs=7200):
  """Returns a session creator for the correct master target and session config."""
  if config:
    session_config = copy.deepcopy(config)
    session_config.MergeFrom(self._session_config)
  else:
    session_config = self._session_config

  # Create the session creator according to the role of this task.
  if not self._strategy or self._strategy.extended.experimental_should_init:
    logging.info("Creating chief session creator with config: %r", config)
    return monitored_session.ChiefSessionCreator(
        scaffold,
        master=self.master_target,
        config=session_config,
        checkpoint_dir=checkpoint_dir,
        checkpoint_filename_with_path=checkpoint_filename_with_path)
  else:
    logging.info("Creating worker session creator with config: %r", config)
    return monitored_session.WorkerSessionCreator(
        scaffold,
        master=self.master_target,
        config=session_config,
        max_wait_secs=max_wait_secs)
# ChiefSessionCreator
@tf_export(v1=['train.ChiefSessionCreator'])
class ChiefSessionCreator(SessionCreator):
  """Creates a tf.compat.v1.Session for a chief."""

  def __init__(self,
               scaffold=None,
               master='',
               config=None,
               checkpoint_dir=None,
               checkpoint_filename_with_path=None):
    self._checkpoint_dir = checkpoint_dir
    self._checkpoint_filename_with_path = checkpoint_filename_with_path
    self._scaffold = scaffold or Scaffold()
    self._session_manager = None
    self._master = master
    self._config = config

  def _get_session_manager(self):
    if self._session_manager:
      return self._session_manager

    self._session_manager = sm.SessionManager(
        local_init_op=self._scaffold.local_init_op,
        ready_op=self._scaffold.ready_op,
        ready_for_local_init_op=self._scaffold.ready_for_local_init_op,
        graph=ops.get_default_graph())
    return self._session_manager

  def create_session(self):
    self._scaffold.finalize()
    return self._get_session_manager().prepare_session(
        self._master,
        saver=self._scaffold.saver,
        checkpoint_dir=self._checkpoint_dir,
        checkpoint_filename_with_path=self._checkpoint_filename_with_path,
        config=self._config,
        init_op=self._scaffold.init_op,
        init_feed_dict=self._scaffold.init_feed_dict,
        init_fn=self._scaffold.init_fn)


# WorkerSessionCreator
@tf_export(v1=['train.WorkerSessionCreator'])
class WorkerSessionCreator(SessionCreator):
  """Creates a tf.compat.v1.Session for a worker."""

  def __init__(self,
               scaffold=None,
               master='',
               config=None,
               max_wait_secs=30 * 60):
    """Initializes a worker session creator.

    Args:
      max_wait_secs: Maximum time to wait for the session to become available.
    """
    self._scaffold = scaffold or Scaffold()
    self._session_manager = None
    self._master = master
    self._config = config
    self._max_wait_secs = max_wait_secs

  def _get_session_manager(self):
    if self._session_manager:
      return self._session_manager

    self._session_manager = sm.SessionManager(
        local_init_op=self._scaffold.local_init_op,
        ready_op=self._scaffold.ready_op,
        ready_for_local_init_op=self._scaffold.ready_for_local_init_op,
        graph=ops.get_default_graph())
    return self._session_manager

  def create_session(self):
    self._scaffold.finalize()
    return self._get_session_manager().wait_for_session(
        self._master, config=self._config, max_wait_secs=self._max_wait_secs)


From the source code above we can see that MonitoredTrainingSession creates different kinds of sessions depending on the role of the node: the chief node's session is created by the ChiefSessionCreator class, while a non-chief worker node's session is created by the WorkerSessionCreator class. The special part is that the worker calls wait_for_session(), which roughly means it has to wait until the chief's session has been created before creating its own. In both cases the session is actually created by a method of the SessionManager class, so let's take a closer look at SessionManager.
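To make this division of labor concrete, here is a small sketch of my own (not from the original article) that bypasses the MonitoredTrainingSession wrapper and hands one of the two creator classes directly to MonitoredSession. Here server, is_chief, and train_op are assumed to come from your own cluster and graph setup:

import tensorflow as tf

if is_chief:
    # prepare_session(): initialize the variables, or restore them from checkpoint_dir.
    session_creator = tf.train.ChiefSessionCreator(
        master=server.target, checkpoint_dir='/tmp/train_dir')
else:
    # wait_for_session(): block until the chief has made the model ready.
    session_creator = tf.train.WorkerSessionCreator(
        master=server.target, max_wait_secs=7200)

with tf.train.MonitoredSession(session_creator=session_creator) as sess:
    while not sess.should_stop():
        sess.run(train_op)

MonitoredTrainingSession does essentially this selection for you, and in addition installs the default checkpoint, summary, and logging hooks on the chief.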


The official documentation gives a simple example of the SessionManager class that makes things quite clear:


  # prepare_session() initializes or restores a model; it needs an `init_op` and a `saver`.
  with tf.Graph().as_default():
    # ...add operations to the graph...
    # Create a SessionManager that will checkpoint the model in '/tmp/mydir'.
    sm = SessionManager()
    sess = sm.prepare_session(master, init_op, saver, checkpoint_dir)
    # Use the session to train the graph.
    while True:
      sess.run(<my_train_op>)

A second process can then launch its ops as follows. wait_for_session() means it has to wait until the session above has been created before creating its own session:

  # ...only then does it create its own session.
  with tf.Graph().as_default():
    # ...add operations to the graph...
    # Create a SessionManager that will wait for the model to become ready.
    sm = SessionManager()
    sess = sm.wait_for_session(master)
    # Use the session to train the graph.
    while True:
      sess.run(<my_train_op>)


Now let's focus on the two functions prepare_session and wait_for_session:


@tf_export(v1=["train.SessionManager"])
class SessionManager(object):

  def __init__(self,
               local_init_op=None,
               ready_op=None,
               ready_for_local_init_op=None,
               graph=None,
               recovery_wait_secs=30,
               local_init_run_options=None):
    """
    local_init_op is an op that is run every time a new session is created.
    ready_op is an op used to check whether the model is ready.
    ready_for_local_init_op is an op used to check whether the model is ready to run local_init_op.
    """
    # Sets default values of arguments.
    if graph is None:
      graph = ops.get_default_graph()
    self._local_init_op = local_init_op
    self._ready_op = ready_op
    self._ready_for_local_init_op = ready_for_local_init_op
    self._graph = graph
    self._recovery_wait_secs = recovery_wait_secs
    self._target = None
    self._local_init_run_options = local_init_run_options
    if ready_for_local_init_op is not None and local_init_op is None:
      raise ValueError("If you pass a ready_for_local_init_op "
                       "you must also pass a local_init_op "
                       ", ready_for_local_init_op [%s]" %
                       ready_for_local_init_op)
  def prepare_session(self,
                      master,
                      init_op=None,
                      saver=None,
                      checkpoint_dir=None,
                      checkpoint_filename_with_path=None,
                      wait_for_checkpoint=False,
                      max_wait_secs=7200,
                      config=None,
                      init_feed_dict=None,
                      init_fn=None):
    """
    prepare_session() restores the session from a checkpoint if one exists;
    otherwise it creates the session from the given `init_op` and by calling
    the `init_fn` function.
    """
    sess, is_loaded_from_checkpoint = self._restore_checkpoint(
        master,
        saver,
        checkpoint_dir=checkpoint_dir,
        checkpoint_filename_with_path=checkpoint_filename_with_path,
        wait_for_checkpoint=wait_for_checkpoint,
        max_wait_secs=max_wait_secs,
        config=config)
    if not is_loaded_from_checkpoint:
      if init_op is None and not init_fn and self._local_init_op is None:
        raise RuntimeError("Model is not initialized and no init_op or "
                           "init_fn or local_init_op was given")
      if init_op is not None:
        sess.run(init_op, feed_dict=init_feed_dict)
      if init_fn:
        init_fn(sess)
    # ... (rest of the function elided) ...
    return sess

  def wait_for_session(self, master, config=None, max_wait_secs=float("Inf")):
    """Creates a new `Session` and waits for model to be ready."""
    self._target = master
    if max_wait_secs is None:
      max_wait_secs = float("Inf")
    timer = _CountDownTimer(max_wait_secs)

    while True:
      sess = session.Session(self._target, graph=self._graph, config=config)
      not_ready_msg = None
      not_ready_local_msg = None
      local_init_success, not_ready_local_msg = self._try_run_local_init_op(
          sess)
      if local_init_success:
        # Successful if local_init_op is None, or ready_for_local_init_op passes
        is_ready, not_ready_msg = self._model_ready(sess)
        if is_ready:
          return sess
      self._safe_close(sess)
      # Do we have enough time left to try again?
      remaining_ms_after_wait = (
          timer.secs_remaining() - self._recovery_wait_secs)
      if remaining_ms_after_wait < 0:
        raise errors.DeadlineExceededError(
            None, None,
            "Session was not ready after waiting %d secs." % (max_wait_secs,))
      logging.info("Waiting for model to be ready.  "
                   "Ready_for_local_init_op:  %s, ready: %s",
                   not_ready_local_msg, not_ready_msg)
      time.sleep(self._recovery_wait_secs)


Once the session has been created, it is wrapped once more and the final MonitoredSession object is returned.


What a complete monitored session does when it is created, in order (a small custom SessionRunHook sketch illustrating these callbacks follows the lists below):


  • Call hook.begin() for every hook

  • Call scaffold.finalize() to finalize the graph

  • Create the session

  • Initialize the model parameters via the Scaffold

  • If a checkpoint exists, restore the parameters from it

  • Launch the queue runners

  • Call hook.after_create_session()


When run() is called, the monitored session does the following:

  • Call hook.before_run()

  • Call the underlying TensorFlow session.run() with the merged fetches and feed_dict

  • Call hook.after_run()

  • Return the result of session.run() that the caller asked for

  • If an AbortedError or UnavailableError occurs, recover or reinitialize the session before running run() again


When close() is called, the monitored session does the following:

  • Call hook.end()

  • Close the queue runners and the session

  • Raise an OutOfRangeError if all the input data has been consumed
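
To illustrate this lifecycle, here is a small custom hook sketch of my own (not from the original article). It subclasses tf.train.SessionRunHook and overrides only the callbacks it needs; the monitored session then drives these callbacks at exactly the points listed above:

import tensorflow as tf

class LossLoggerHook(tf.train.SessionRunHook):
    """Logs the loss tensor every `every_n_steps` calls to run()."""

    def __init__(self, loss_op, every_n_steps=100):
        self._loss_op = loss_op
        self._every_n_steps = every_n_steps
        self._step = 0

    def begin(self):
        # Called once, before the session is created and the graph is finalized.
        self._step = 0

    def before_run(self, run_context):
        # Ask the session to fetch the loss alongside whatever the caller fetches.
        self._step += 1
        return tf.train.SessionRunArgs(self._loss_op)

    def after_run(self, run_context, run_values):
        if self._step % self._every_n_steps == 0:
            print('step %d, loss %.4f' % (self._step, run_values.results))

    def end(self, session):
        print('training finished after %d run() calls' % self._step)

# Usage: pass hooks=[LossLoggerHook(loss)] to MonitoredTrainingSession.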

Finally, here is an example of distributed training using the MonitoredSession class:


from __future__ import print_function, absolute_import, division
import tensorflow as tf
tf.app.flags.DEFINE_string("ps_hosts", "localhost:2222", "ps hosts")
tf.app.flags.DEFINE_string("worker_hosts", "localhost:2223,localhost:2224", "worker hosts")
tf.app.flags.DEFINE_string("job_name", "worker", "'ps' or 'worker'")
tf.app.flags.DEFINE_integer("task_index", 0, "Index of task within the job")
tf.app.flags.DEFINE_integer("num_workers", 2, "Number of workers")
tf.app.flags.DEFINE_boolean("is_sync", False, "using synchronous training or not")
FLAGS = tf.app.flags.FLAGS

def model(images):
    """Define a simple mnist classifier"""
    net = tf.layers.dense(images, 500, activation=tf.nn.relu)
    net = tf.layers.dense(net, 500, activation=tf.nn.relu)
    net = tf.layers.dense(net, 10, activation=None)
    return net

(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
x_train = x_train.reshape(-1, 784).astype('float32')
x_test = x_test.reshape(-1, 784).astype('float32')
x_train /= 255
x_test /= 255

def get_batch(image, label, batch_size=32, training=True):
    df = tf.data.Dataset.from_tensor_slices((image, label))
    if training:
        df = df.repeat(10).shuffle(buffer_size=1000)
    df = df.batch(batch_size).prefetch(batch_size)
    iterator = df.make_one_shot_iterator()
    batch_x, batch_y = iterator.get_next()
    return batch_x, batch_y

def main(_):
    ps_hosts = FLAGS.ps_hosts.split(",")
    worker_hosts = FLAGS.worker_hosts.split(",")

    # create the cluster configured by `ps_hosts' and 'worker_hosts'
    cluster = tf.train.ClusterSpec({"ps": ps_hosts, "worker": worker_hosts})

    # create a server for the local task
    server = tf.train.Server(cluster, job_name=FLAGS.job_name,
                             task_index=FLAGS.task_index)

    train_batch_x, train_batch_y = get_batch(x_train, y_train)
    test_batch_x, test_batch_y = get_batch(x_test, y_test, training=False)

    if FLAGS.job_name == "ps":
        server.join()  # ps hosts only join
    elif FLAGS.job_name == "worker":
        # workers perform the operations
        # ps_strategy = tf.contrib.training.GreedyLoadBalancingStrategy(FLAGS.num_ps)

        # Note: tf.train.replica_device_setter automatically places the parameters
        # (Variables) on the ps hosts (default placement strategy: round-robin over
        # all ps hosts), and also places a copy of the operations on each worker host.
        with tf.device(tf.train.replica_device_setter(
                worker_device="/job:worker/task:%d" % FLAGS.task_index,
                cluster=cluster)):
            logits = model(train_batch_x)
            loss = tf.reduce_mean(
                tf.nn.softmax_cross_entropy_with_logits(
                    logits=logits, labels=tf.one_hot(train_batch_y, 10)))

            # The StopAtStepHook handles stopping after running the given steps.
            hooks = [tf.train.StopAtStepHook(last_step=10000)]

            global_step = tf.train.get_or_create_global_step()
            optimizer = tf.train.AdamOptimizer(learning_rate=1e-04)
            if FLAGS.is_sync:
                # synchronous training:
                # use tf.train.SyncReplicasOptimizer to wrap the optimizer
                # ref: https://www.tensorflow.org/api_docs/python/tf/train/SyncReplicasOptimizer
                optimizer = tf.train.SyncReplicasOptimizer(
                    optimizer,
                    replicas_to_aggregate=FLAGS.num_workers,
                    total_num_replicas=FLAGS.num_workers)
                # create the hook which handles initialization and queues
                hooks.append(optimizer.make_session_run_hook((FLAGS.task_index == 0)))

            train_op = optimizer.minimize(loss, global_step=global_step)

            # The MonitoredTrainingSession takes care of session initialization,
            # restoring from a checkpoint, saving to a checkpoint, and closing when done
            # or when an error occurs.
            with tf.train.MonitoredTrainingSession(master=server.target,
                                                   is_chief=(FLAGS.task_index == 0),
                                                   checkpoint_dir="./checkpoint_dir",
                                                   hooks=hooks) as mon_sess:
                while not mon_sess.should_stop():
                    # mon_sess.run handles AbortedError in case of a preempted PS.
                    _, ls, step = mon_sess.run([train_op, loss, global_step])
                    if step % 100 == 0:
                        print("Train step %d, loss: %f" % (step, ls))


if __name__ == "__main__":
    tf.app.run()


References:


https://www.cnblogs.com/estragon/p/10034511.html


https://zhuanlan.zhihu.com/p/88876923


This article is reproduced from Alex-zhai's Zhihu account.


Original link: https://zhuanlan.zhihu.com/p/91608555

