NVIDIA 初创加速计划,免费加速您的创业启动 了解详情
写点什么

Stale Synchronous Parallel Parameter Server 解读和代码实现

  • 2019-12-02
  • 本文字数:3544 字

    阅读完需:约 12 分钟

Stale Synchronous Parallel Parameter Server解读和代码实现

论文解读:


常见的并行训练的模式有两种:


  • 同步,各个 worker 并行训练,每次更新梯度时都等待所有 worker 完成本次迭代的计算,然后一起开始下一次迭代。最简单的同步就是每个 worker 的梯度求和求平均,然后更新参数,目前比较流行的同步更新方法则是 ring-all-ruduce 方法。可以保证梯度的正确性。但是速度较慢。

  • 完全异步,各个 worker 并行训练,各自处理各自的数据不等待其他任何 worker。速度比较快,但是梯度有损失。

  • 存在的问题:传统的 SGD 是基于 batch 更新的,并行训练时各个 worker 计算当前 batch 的梯度,然后反向传播之后 push 梯度,然后 pull 最新的参数再处理下一个 batch。这个时候如果当一个 worker 更新速度特别慢,这个 worker push 的梯度是使用一个非常旧的参数计算出来的,这个梯度可能已经不适合当下的参数,甚至有时候会起到反作用。


本文提出的 SSP 方法来让 worker 在效率和正确性上做一个良好的权衡。


核心思想:各个 worker 并行训练,每次进行下一次迭代时判断一下自己的迭代比整个系统中最慢的节点的迭代快多少个 step,如果达到一个阈值就进入等待状态直到 step 小于阈值开始下一次计算。

代码实现

参考:https://blog.csdn.net/li57681522/article/details/87920210


# -*- coding:utf-8 -*-
# python dis_tf_ssp.py --job_name=ps --task_index=0# python dis_tf_ssp.py --job_name=worker --task_index=0# python dis_tf_ssp.py --job_name=worker --task_index=1
import timeimport numpy as npimport tensorflow as tf
from tensorflow.python.util.tf_export import tf_exportfrom tensorflow.python.ops import state_ops, variables, variable_scopefrom tensorflow.python.training import session_run_hook
# Define parametersFLAGS = tf.app.flags.FLAGStf.app.flags.DEFINE_float('learning_rate', 0.00003, 'Initial learning rate.')tf.app.flags.DEFINE_integer('steps_to_validate', 1000, 'Steps to validate and print loss')
# For distributedtf.app.flags.DEFINE_string("ps_hosts", "172.20.181.16:2222", "Comma-separated list of hostname:port pairs")tf.app.flags.DEFINE_string("worker_hosts", "172.20.181.16:2224", "Comma-separated list of hostname:port pairs")tf.app.flags.DEFINE_string("job_name", "worker", "One of 'ps', 'worker'")tf.app.flags.DEFINE_integer("task_index", 0, "Index of task within the job")
# Hyperparameterslearning_rate = FLAGS.learning_ratesteps_to_validate = FLAGS.steps_to_validate

# @tf_export("train.SemiSyncRunHook")# class SemiSyncRunHook(session_run_hook.SessionRunHook):class SemiSyncRunHook(tf.train.SessionRunHook): """Run by SSP."""
def __init__(self, index, worker_count, staleness=10): """Initializes a `SemiSyncRunHook`. Args: index: work index worker_count: number of workers staleness: """
if index >= worker_count: print("worker index {} is bigger than worker_count {}".format(index, worker_count)) return
self._const_max_test_step = 10000 self._last_step = 0 # 上一次wait的步骤数 self._last_time = self._now_time() # 上一次wait的时间
self._index = index self._staleness = staleness self._wait_time = 0.01 # 等待时间,单位:秒;这个时间不能设置的太长,跟worker的训练速度和staleness相关 self._worker_steps = [] # 记录worker训练步骤数的变量列表
for i in range(worker_count): worker_step = variable_scope.variable(0, trainable=False, name="worker_step_" + str(i)) self._worker_steps.append(worker_step) if i == index: self._my_step_update_op = state_ops.assign_add(worker_step, 1)
self._initialize_op = variables.variables_initializer(self._worker_steps)
def _now_time(self): return time.time()
def after_create_session(self, session, coord): session.run(self._initialize_op) # 初始化记录worker训练步骤数的变量
def before_run(self, run_context): run_context.session.run(self._my_step_update_op) # 更新本worker的训练步骤数 return None
def after_run(self, run_context, run_values): while True: # 1.获取所有worker的训练步骤数 all_worker_steps = run_context.session.run(self._worker_steps) # print("all worker steps={}. my work id={}".format(all_worker_steps, self._index))
# 2.如果训练当前worker的训练步骤数 > 最小worker训练步骤数 + staleness,sleep(10ms); 否则 break; if all_worker_steps[self._index] > min(all_worker_steps) + self._staleness: diff_step = all_worker_steps[self._index] - self._last_step if diff_step / self._const_max_test_step > 1: self._wait_time = (self._now_time() - self._last_time) / diff_step * self._staleness * 0.7
# 更新 self._last_step = all_worker_steps[self._index] self._last_time = self._now_time()
time.sleep(self._wait_time) # 等待慢worker执行 # print("all worker steps={}, my work id={}. waiting {}s...".format(all_worker_steps, self._index, self._wait_time)) else: break

def main(_): ps_hosts = FLAGS.ps_hosts.split(",") worker_hosts = FLAGS.worker_hosts.split(",") cluster = tf.train.ClusterSpec({"ps": ps_hosts, "worker": worker_hosts}) server = tf.train.Server(cluster, job_name=FLAGS.job_name, task_index=FLAGS.task_index)
worker_count = len(worker_hosts)
if FLAGS.job_name == "ps": server.join() elif FLAGS.job_name == "worker": with tf.device(tf.train.replica_device_setter( worker_device="/job:worker/task:%d" % FLAGS.task_index, cluster=cluster)): global_step = tf.Variable(0, name='global_step', trainable=False)
X = tf.placeholder(tf.float32) Y = tf.placeholder(tf.float32) w = tf.Variable(0.0, name="weight") b = tf.Variable(0.0, name="reminder") y = w * X + b
loss = tf.reduce_mean(tf.square(y - Y)) optimizer = tf.train.GradientDescentOptimizer(learning_rate)
# 更新梯度 train_op = optimizer.minimize(loss, global_step=global_step)
hooks = [tf.train.StopAtStepHook(last_step=1000000)]
semiSyncRunHook = SemiSyncRunHook(FLAGS.task_index, worker_count=worker_count, staleness=10) hooks.append(semiSyncRunHook)
with tf.train.MonitoredTrainingSession( master=server.target, is_chief=(FLAGS.task_index == 0), checkpoint_dir="./ssp_saved_model", hooks=hooks) as mon_sess: while not mon_sess.should_stop(): train_x = np.random.randn(1) train_y = 2 * train_x + np.random.randn(1) * 0.33 + 10 _, loss_v, step = mon_sess.run([train_op, loss, global_step], feed_dict={X: train_x, Y: train_y}) if step % steps_to_validate == 0: w_, b_ = mon_sess.run([w, b]) print("step: %d, weight: %f, biase: %f, loss: %f" % (step, w_, b_, loss_v))

if __name__ == "__main__": tf.app.run()
复制代码


参考文献:


More Effective Distributed ML via a Stale Synchronous Parallel Parameter Server


本文转载自 Alex-zhai 知乎账号。


原文链接:https://zhuanlan.zhihu.com/p/69665236


2019-12-02 16:25821

评论

发布
暂无评论
发现更多内容

EMQ 映云科技与 RT-Thread 达成战略合作,共建产业物联网平台

EMQ映云科技

人工智能 云计算 大数据 物联网 emq

「技术点串烧」☕【Java 技术指南】「难点-核心-遗漏」Java线程状态流转及生命周期的技术指南!

洛神灬殇

Java 线程 Thread 9月日更

低代码的5个误区,你踩雷了吗?

禅道项目管理

低代码 开发

一文读懂数据库最新技术趋势:TDSQL带你深度纵览VLDB 2019

腾讯云数据库

数据库

狂刷《Java权威面试指南(阿里版)》,冲击“金九银十”有望了

Java 编程 架构 面试 程序人生

为数据赋能:腾讯TDSQL分布式金融级数据库前沿技术 - 云+社区 - 腾讯云

腾讯云数据库

数据库 tdsql

打爆怪兽 一起来养猪 养蜂人 幸福饭店

游戏开发_软件开发

软件 App 开发 游戏 语音合成

令我入职阿里的750页微服务架构深度解析文档有何神秘之处?

Java 编程 架构 面试 架构师

TLS协议分析 (五) handshake协议 证书与密钥交换

OpenIM

论亚马逊QLDB与腾讯TDSQL对历史数据的管理和计算

腾讯云数据库

数据库 tdsql

架构训练营 模块一

Leach Sun

华为云GaussDB首次亮相2021服贸会,为数字人民币提供坚实数据底座

华为云数据库小助手

金融科技 数字经济 GaussDB 华为云数据库

揭秘TDSQL-A:兼容Oracle的同时支持海量数据交互

腾讯云数据库

数据库 tdsql

2021云计算白皮书发布,腾讯云原生数据库TDSQL-C助力共建云上技术生态

腾讯云数据库

数据库 tdsql

TDSQL:深度剖析数据库国产化迁移之路

腾讯云数据库

数据库 tdsql

安卓工控主板双网口有什么用途?

双赞工控

安卓主板 工控主板

腾讯云数据库TDSQL两篇论文入选数据库顶会SIGMOD,产学研结合助力国产数据库生态建设

腾讯云数据库

数据库 tdsql

如何在MacOS上无缝切换Win11和MacOS?

Zhendong

MacBook m1 Parallels

springboot项目集成docker

try catch

Docker Dockerfile springboot

浅析 DDD 领域驱动设计

牧小农

DDD 领域驱动

把工作讲给家人听

FunTester

读书笔记 FunTester 奈非文化手册 办公效率 居家工作

精品!阿里P7爆款《K8s+Jenkins》技术笔记,高质量干货必收藏

Java 程序员 架构 面试 k8s

腾讯云 TDSQL 审计原理揭秘

腾讯云数据库

数据库 tdsql

Android技术分享| 开源Demo any自习室布局架构

anyRTC开发者

音视频 移动开发 在线自习室 Android技术分享

微信或推出聊天记录云备份付费服务,你的微信记录值多少钱?

郑州埃文科技

云服务 微信聊天 数据风险管理

高性能利器:CDN我建议你好好学一下!

九灵

Java 分布式 微服务 CDN

在同一台计算机中运行多个MySQL服务

Java 数据库 后端 msyql

搞懂现代Web端即时通讯技术一文就够:WebSocket、socket.io、SSE

JackJiang

websocket 即时通讯 IM

TDSQL:从自主可控金融级数据库看腾讯“智能+”技术中台之路

腾讯云数据库

数据库 tdsql

TDSQL:关于未来,数据库大咖们都聊了什么?

腾讯云数据库

数据库 tdsql

你的 SQL 还在回表查询吗?快给它安排上覆盖索引

Java MySQL 数据库 后端

Stale Synchronous Parallel Parameter Server解读和代码实现_语言 & 开发_Rick_InfoQ精选文章