阿里云「飞天发布时刻」2024来啦!新产品、新特性、新能力、新方案,等你来探~ 了解详情
写点什么

Stale Synchronous Parallel Parameter Server 解读和代码实现

  • 2019-12-02
  • 本文字数:3544 字

    阅读完需:约 12 分钟

Stale Synchronous Parallel Parameter Server解读和代码实现

论文解读:


常见的并行训练的模式有两种:


  • 同步,各个 worker 并行训练,每次更新梯度时都等待所有 worker 完成本次迭代的计算,然后一起开始下一次迭代。最简单的同步就是每个 worker 的梯度求和求平均,然后更新参数,目前比较流行的同步更新方法则是 ring-all-ruduce 方法。可以保证梯度的正确性。但是速度较慢。

  • 完全异步,各个 worker 并行训练,各自处理各自的数据不等待其他任何 worker。速度比较快,但是梯度有损失。

  • 存在的问题:传统的 SGD 是基于 batch 更新的,并行训练时各个 worker 计算当前 batch 的梯度,然后反向传播之后 push 梯度,然后 pull 最新的参数再处理下一个 batch。这个时候如果当一个 worker 更新速度特别慢,这个 worker push 的梯度是使用一个非常旧的参数计算出来的,这个梯度可能已经不适合当下的参数,甚至有时候会起到反作用。


本文提出的 SSP 方法来让 worker 在效率和正确性上做一个良好的权衡。


核心思想:各个 worker 并行训练,每次进行下一次迭代时判断一下自己的迭代比整个系统中最慢的节点的迭代快多少个 step,如果达到一个阈值就进入等待状态直到 step 小于阈值开始下一次计算。

代码实现

参考:https://blog.csdn.net/li57681522/article/details/87920210


# -*- coding:utf-8 -*-
# python dis_tf_ssp.py --job_name=ps --task_index=0# python dis_tf_ssp.py --job_name=worker --task_index=0# python dis_tf_ssp.py --job_name=worker --task_index=1
import timeimport numpy as npimport tensorflow as tf
from tensorflow.python.util.tf_export import tf_exportfrom tensorflow.python.ops import state_ops, variables, variable_scopefrom tensorflow.python.training import session_run_hook
# Define parametersFLAGS = tf.app.flags.FLAGStf.app.flags.DEFINE_float('learning_rate', 0.00003, 'Initial learning rate.')tf.app.flags.DEFINE_integer('steps_to_validate', 1000, 'Steps to validate and print loss')
# For distributedtf.app.flags.DEFINE_string("ps_hosts", "172.20.181.16:2222", "Comma-separated list of hostname:port pairs")tf.app.flags.DEFINE_string("worker_hosts", "172.20.181.16:2224", "Comma-separated list of hostname:port pairs")tf.app.flags.DEFINE_string("job_name", "worker", "One of 'ps', 'worker'")tf.app.flags.DEFINE_integer("task_index", 0, "Index of task within the job")
# Hyperparameterslearning_rate = FLAGS.learning_ratesteps_to_validate = FLAGS.steps_to_validate

# @tf_export("train.SemiSyncRunHook")# class SemiSyncRunHook(session_run_hook.SessionRunHook):class SemiSyncRunHook(tf.train.SessionRunHook): """Run by SSP."""
def __init__(self, index, worker_count, staleness=10): """Initializes a `SemiSyncRunHook`. Args: index: work index worker_count: number of workers staleness: """
if index >= worker_count: print("worker index {} is bigger than worker_count {}".format(index, worker_count)) return
self._const_max_test_step = 10000 self._last_step = 0 # 上一次wait的步骤数 self._last_time = self._now_time() # 上一次wait的时间
self._index = index self._staleness = staleness self._wait_time = 0.01 # 等待时间,单位:秒;这个时间不能设置的太长,跟worker的训练速度和staleness相关 self._worker_steps = [] # 记录worker训练步骤数的变量列表
for i in range(worker_count): worker_step = variable_scope.variable(0, trainable=False, name="worker_step_" + str(i)) self._worker_steps.append(worker_step) if i == index: self._my_step_update_op = state_ops.assign_add(worker_step, 1)
self._initialize_op = variables.variables_initializer(self._worker_steps)
def _now_time(self): return time.time()
def after_create_session(self, session, coord): session.run(self._initialize_op) # 初始化记录worker训练步骤数的变量
def before_run(self, run_context): run_context.session.run(self._my_step_update_op) # 更新本worker的训练步骤数 return None
def after_run(self, run_context, run_values): while True: # 1.获取所有worker的训练步骤数 all_worker_steps = run_context.session.run(self._worker_steps) # print("all worker steps={}. my work id={}".format(all_worker_steps, self._index))
# 2.如果训练当前worker的训练步骤数 > 最小worker训练步骤数 + staleness,sleep(10ms); 否则 break; if all_worker_steps[self._index] > min(all_worker_steps) + self._staleness: diff_step = all_worker_steps[self._index] - self._last_step if diff_step / self._const_max_test_step > 1: self._wait_time = (self._now_time() - self._last_time) / diff_step * self._staleness * 0.7
# 更新 self._last_step = all_worker_steps[self._index] self._last_time = self._now_time()
time.sleep(self._wait_time) # 等待慢worker执行 # print("all worker steps={}, my work id={}. waiting {}s...".format(all_worker_steps, self._index, self._wait_time)) else: break

def main(_): ps_hosts = FLAGS.ps_hosts.split(",") worker_hosts = FLAGS.worker_hosts.split(",") cluster = tf.train.ClusterSpec({"ps": ps_hosts, "worker": worker_hosts}) server = tf.train.Server(cluster, job_name=FLAGS.job_name, task_index=FLAGS.task_index)
worker_count = len(worker_hosts)
if FLAGS.job_name == "ps": server.join() elif FLAGS.job_name == "worker": with tf.device(tf.train.replica_device_setter( worker_device="/job:worker/task:%d" % FLAGS.task_index, cluster=cluster)): global_step = tf.Variable(0, name='global_step', trainable=False)
X = tf.placeholder(tf.float32) Y = tf.placeholder(tf.float32) w = tf.Variable(0.0, name="weight") b = tf.Variable(0.0, name="reminder") y = w * X + b
loss = tf.reduce_mean(tf.square(y - Y)) optimizer = tf.train.GradientDescentOptimizer(learning_rate)
# 更新梯度 train_op = optimizer.minimize(loss, global_step=global_step)
hooks = [tf.train.StopAtStepHook(last_step=1000000)]
semiSyncRunHook = SemiSyncRunHook(FLAGS.task_index, worker_count=worker_count, staleness=10) hooks.append(semiSyncRunHook)
with tf.train.MonitoredTrainingSession( master=server.target, is_chief=(FLAGS.task_index == 0), checkpoint_dir="./ssp_saved_model", hooks=hooks) as mon_sess: while not mon_sess.should_stop(): train_x = np.random.randn(1) train_y = 2 * train_x + np.random.randn(1) * 0.33 + 10 _, loss_v, step = mon_sess.run([train_op, loss, global_step], feed_dict={X: train_x, Y: train_y}) if step % steps_to_validate == 0: w_, b_ = mon_sess.run([w, b]) print("step: %d, weight: %f, biase: %f, loss: %f" % (step, w_, b_, loss_v))

if __name__ == "__main__": tf.app.run()
复制代码


参考文献:


More Effective Distributed ML via a Stale Synchronous Parallel Parameter Server


本文转载自 Alex-zhai 知乎账号。


原文链接:https://zhuanlan.zhihu.com/p/69665236


2019-12-02 16:25819

评论

发布
暂无评论
发现更多内容

QSDK/ipq5018/2T2R/Bluetooth BT5.1 supporting QCN9074/industrial wifi6 module

wallysSK

QCN9074 ipq5018

Java: 在Excel中插入和提取图片

Geek_249eec

Java Excel 图片

迎接工业互联网的龙卷风暴,软通动力绘制了一张转型地图

脑极体

武汉前端培训学习靠不靠谱?

小谷哥

ERP系统是什么?能起到什么作用?

优秀

ERP系统

那些你不知道的炫酷开关交互效果(12种)

南城FE

CSS 前端 交互设计

纷繁复杂见真章 华为云大型产品需求管理利器CodeArts Req解读

Geek_2d6073

研发 Leader 怎样写出非研发也看得懂的年终总结?

思码逸研发效能

研发效能 年终总结

云计算技术是基于互联网和网络的新技术

Finovy Cloud

云服务器 云技术 云渲染

柏睿数据完成阿里云PolarDB数据库产品生态集成认证

阿里云数据库开源

阿里云 polarDB PolarDB-X PolarDB-PG PolarDB for PostgreSQL

大数据有没有必要参加培训?

小谷哥

体验百度Java后端一面凉经,让我有了新的感悟

小小怪下士

Java 百度 程序员 面试

AI作画技术实践第二期|用腾讯云智能图片融合优化AI绘画的效果

牵着蜗牛去散步

腾讯云 腾讯 AI作画 腾讯云智能 智能内容创作

学习web前端应该选择哪个培训机构?

小谷哥

Best Machine Learning Tools for Java

Mahipal_Nehra

Java AI Machine Learning tools best tools

公司项目终于用上了插入式注解,真香!

Java永远的神

程序员 程序人生 项目 架构师 后端开发

行业方案 | 新规落地,企业集团财务公司如何构建数智财务体系?

袋鼠云数栈

任务管理轻松实现大规模设备管理控制——设备管理类

阿里云AIoT

运维 监控 云安全 消息中间件 储存

数据治理:聊聊数据血缘!

用友BIP

安全可信 | 强墙出击!天翼云Web应用防火墙(原生版)硬核亮相!

天翼云开发者社区

安全 防火墙

重磅 | 招商局集团、招商局港口荣获CGMA年度大奖——九科信息与百年招商局共同探索财务数智化转型之路

九科Ninetech

零基础学习前端开发培训机构怎么选

小谷哥

瓴羊Quick BI 权限管理:开拓数据分析效率和智能化水平的新高度

对不起该用户已成仙‖

数字先锋 | 主机、硬盘、CPU统统没有? 这个电教室有点“潮”!

天翼云开发者社区

云主机 云电脑

LeaRun.net代码生成器 一键生成前后端代码

力软低代码开发平台

创新研发负载分担机制,天翼云IPv6网络带宽再升级!

天翼云开发者社区

负载均衡 网络 ipv6

聚焦技术,锐意创新,GaussDB给世界一个更优选择

Geek_2d6073

我人傻了!新入职的同事三下五除二就搭建了一个简易版秒杀系统

程序员小毕

程序员 程序人生 后端 架构师 秒杀系统

一名曾因线上P0故障导致月工资扣了10%的码农心得:如何在故障10分钟黄金时间快速排障

KINDLING

Java 运维 可观测性 线上故障 ebpf

零基础去程序员培训机构靠不靠谱?

小谷哥

可视化:数据可视化的作用

Data 探险实验室

数据分析 可视化 数据可视化 数据大屏

Stale Synchronous Parallel Parameter Server解读和代码实现_语言 & 开发_Rick_InfoQ精选文章