10 月 23 - 25 日,QCon 上海站即将召开,现在购票,享9折优惠 了解详情
写点什么

浅谈 Tensorflow 分布式架构:parameter server 及优化策略

  • 2019-12-02
  • 本文字数:3380 字

    阅读完需:约 11 分钟

浅谈Tensorflow分布式架构:parameter server及优化策略

当我们想将一个单机的 tensorflow 训练程序改写成分布式训练(多机多卡)的时候,一般有两个大方向的选择:1.完全异步的梯度更新策略,其代表方法是 parameter server 架构。2.同步的梯度更新策略,代表方法有:百度的 ring all-reduce 策略。本文首先介绍 parameter server 架构。

parameter server 策略:

parameter server 异步更新策略是指每个 GPU 或者 CPU 计算完梯度后,无需等待其他 GPU 或 CPU 的梯度计算(有时可以设置需要等待的梯度个数),就可立即更新整体的权值,然后同步此权值,即可进行下一轮计算。



parameter server 的架构


而 Tensorflow 一开始支持分布式的时候,便是这种 parameter server 架构。TensorFlow 一般将任务分为两类 job:一类叫参数服务器,parameter server,简称为 ps,用于存储可训练的参数变量 tf.Variable;一类就是普通任务,称为 worker,用于执行具体的计算。


Tensorflow 支持两种方式实现 parameter server:低阶 API 创建 parameter server 集群方式和 tf.distribute.Strategy 中的 ParameterServerStrategy。

低阶 API 创建 parameter server 集群

完整案例 dist_tf.py:


import tensorflow as tfimport numpy as np
# 创建集群信息,包括ps和worker两种角色。# 集群有两类任务,ps和worker;ps由2个任务组成(一般一个任务是一个机器或者一个分配单元),worker由3个任务组成。ps_hosts = ["xx.xxx.xx.xxxx:oooo", "xx.xxx.xx.xxxx:oooo"]worker_hosts = ["xx.xxx.xx.xxxx:oooo", "xx.xxx.xx.xxxx:oooo", "xx.xxx.xx.xxxx:oooo"]cluster = tf.train.ClusterSpec({"ps": ps_hosts, "worker": worker_hosts})
tf.app.flags.DEFINE_string("job_name", "worker", "One of 'ps', 'worker'")tf.app.flags.DEFINE_integer("task_index", 0, "Index of task within the job")FLAGS = tf.app.flags.FLAGS
def main(_): server = tf.train.Server(cluster, job_name=FLAGS.job_name, task_index=FLAGS.task_index) if FLAGS.job_name == "ps": server.join() else: # 会根据job名,将with内的Variable op放到ps tasks,将其他计算op放到worker tasks。默认分配策略是轮询 with tf.device(tf.train.replica_device_setter( worker_device="/job:worker/task:%d" % FLAGS.task_index, cluster=cluster)):
x_data = tf.placeholder(tf.float32, [100]) y_data = tf.placeholder(tf.float32, [100])
W = tf.Variable(tf.random_uniform([1], -1.0, 1.0)) b = tf.Variable(tf.zeros([1])) y = W * x_data + b loss = tf.reduce_mean(tf.square(y - y_data))
global_step = tf.Variable(0, name="global_step", trainable=False) optimizer = tf.train.GradientDescentOptimizer(0.1) train_op = optimizer.minimize(loss, global_step=global_step)
# The StopAtStepHook handles stopping after running given steps. hooks = [tf.train.StopAtStepHook(last_step=1000000)] # The MonitoredTrainingSession takes care of session initialization, # restoring from a checkpoint, saving to a checkpoint, and closing when done # or an error occurs. with tf.train.MonitoredTrainingSession(master=server.target, is_chief=(FLAGS.task_index == 0), # 我们制定task_index为0的任务为主任务,用于负责变量初始化、做checkpoint、保存summary和复原 checkpoint_dir="/tmp/tf_train_logs", save_checkpoint_secs=None, hooks=hooks) as mon_sess: while not mon_sess.should_stop(): # Run a training step asynchronously. # See `tf.train.SyncReplicasOptimizer` for additional details on how to # perform *synchronous* training. # mon_sess.run handles AbortedError in case of preempted PS. train_x = np.random.rand(100).astype(np.float32) train_y = train_x * 0.1 + 0.3 _, step, loss_v, weight, biase = mon_sess.run([train_op, global_step, loss, W, b], feed_dict={x_data: train_x, y_data: train_y}) if step % 100 == 0: print("step: %d, weight: %f, biase: %f, loss: %f" % (step, weight, biase, loss_v)) print("Optimization finished.")

if __name__ == "__main__": tf.app.run()
复制代码


对于本例而言,我们需要在对应的 5 台机器上分别运行每个任务,共需执行五次代码,生成五个任务。


python dist_tf.py --job_name=ps --task_index=0python dist_tf.py --job_name=ps --task_index=1python dist_tf.py --job_name=worker --task_index=0python dist_tf.py --job_name=worker --task_index=1python dist_tf.py --job_name=worker --task_index=2
复制代码


低阶 API 创建 parameter server 集群缺点:


概念多,学习曲线陡峭。


单机代码到多机修改的代码量大。


需要多台机子跑不同的脚本,当然这可以通过 k8s 集群管理工具来解决。


PS 和 Worker 的比例不好选取。(建议选取偶数个的 ps,我的经验是 ps 和 worker 的比例是 1:3)


训练速度性能损失较大。(通信代价较高)


parameter server 常见的优化点:


如果有参数量较大的 embedding 变量时,可选择使用 embedding_lookup_sparse_with_distributed_aggregation 函数替代 tf.nn.embedding_lookup_sparse 函数。该函数可将 embedding 的聚合计算都放在变量所在的 PS 端,计算后转成稠密张量再传送到 Worker 上继续网络模型的计算。


tf.device 函数中有一个参数是设置变量在 ps 端放置策略的,可使用 tf.contrib.training.GreedyLoadBalancingStrategy 来替代默认的轮循。优点是:可根据参数的内存字节来完成类似在线垃圾收集的工作。根据 weight 和 bias 的字节数来放置到内存合适的 task 中,带来更好的负载平衡。


当参数有超大量级时(比如 embedding 参数),可在创建变量的时候使用分割变量策略:partitioner=tf.fixed_size_partitioner(ps_nums)


优化 input pipeline。链接:https://www.tensorflow.org/guide/performance/datasets


bandwidth 高带宽范亲和策略,保证多个 ps 分布在不同的物理机上。


Estimator 中的 ParameterServerStrategy 策略


# https://stackoverflow.com/questions/55003279/parameter-server-strategy-with-estimatorstensorflowimport tensorflow as tfimport osimport json
NUM_WORKERS = 1IP_ADDRS = ['localhost']PORTS = [12345]
def model_fn(...): .....
def input_fn(...): .....
复制代码

需要每个机器配置 TF_CONFIG 环境变量

os.environ['TF_CONFIG'] = json.dumps({    'cluster': {        'worker': ['%s:%d' % (IP_ADDRS[w], PORTS[w]) for w in range(NUM_WORKERS)],        'ps': ['%s:%d' % (IP_ADDRS[w], PORTS[w]) for w in range(NUM_WORKERS)]    },    'task': {'type': 'worker', 'index': 0}})
# Method for using ParamterServerStrategystrategy = tf.distribute.experimental.ParameterServerStrategy()
config = tf.estimator.RunConfig(train_distribute=strategy)
classifier = tf.estimator.Estimator( model_fn=model_fn, model_dir='/tmp/multiworker', config=config)tf.estimator.train_and_evaluate( classifier, train_spec=tf.estimator.TrainSpec(input_fn=input_fn), eval_spec=tf.estimator.EvalSpec(input_fn=input_fn))
复制代码


本文转载自 Alex-zhai 知乎账号。


原文链接:https://zhuanlan.zhihu.com/p/69010949


2019-12-02 16:234692

评论

发布
暂无评论
发现更多内容

localStorage和sessionStorage本地存储

我是哪吒

html html5 面试 大前端 html/css

软件教练说:性能优化与性能设计,“相亲相爱”的一对

华为云开发者联盟

架构 性能优化 设计 程序 软件教练

30+程序员竞争力从哪里来?

我心依然

程序员 竞争力

【Node.js】事件触发器 - 基础篇

德育处主任

Node 28天写作

面试官问:ZooKeeper是强一致的吗?怎么实现的?

Java 编程 程序员 面试 分布式

成长篇-结构思考力笔记(完整版)

小诚信驿站

程序员 刘晓成 小诚信驿站 成长笔记 28天写作

两种端到端通用目标检测方法

华为云开发者联盟

训练 目标检测 端到端 DETR DeFCN

融资融券两融系统搭建开发

v16629866266

老外程序员的Java性能优化方式是什么?JVM调优策略+工具+技巧

Java架构追梦

Java 学习 架构 面试 jvm调优

14天1000+大集群滚动升级,银行柜台竟然毫无感觉

华为云开发者联盟

大数据 金融 FusionInsight 华为云 集群

Soul 源码阅读 06|Nacos 同步数据分析

哼干嘛

数字货币将如何改变日常生活

CECBC

数字货币

jdk8 String和StringBuilder对象创建所在位置

ilovealt

Java string StringBuilder

NeoKylin-Server-5.0离线部署etcd+flannel集群,实现docker容器跨主机网络通信

星河寒水

Docker etcd flannel 麒麟操作系统 离线部署

产业互联网业务与团队的思考

Geek_vidmje

真狠!涵盖了Netty+Spark+Hadoop+分布式五部分!讲的清清楚楚!

996小迁

redis hadoop 架构 面试 Netty

BAT面试Spring全家桶:Spring+SpringBoot+SpringCloud+SpringMVC

Java架构之路

Java 程序员 架构 面试 编程语言

架构解读丨Volcano作业资源预留设计原理

华为云开发者联盟

批处理 Volcano 资源预留 作业资源预留

区块链真正的价值即将“引爆”行业应用

CECBC

区块链金融

奇葩java迭代器笔试题,做对算你厉害

田维常

迭代器模式

蚂蚁金服二面被血虐,鬼知道面试的我经历了什么?

Java架构之路

Java 程序员 架构 面试 编程语言

苹果设备电池及充电周期

张老蔫

28天写作

Elasticsearch 是分布式文件存储么 ?

escray

elastic 七日更 28天写作 死磕Elasticsearch 60天通过Elastic认证考试

区块链人才能力评价测试机构亮相

CECBC

区块链人才

谁,是产品的利益相关方?

不离

极客大学认识产品经理 极客大学产品经理训练营 跟着二爷学产品

浅说 SQLite 的许可证模式

Justin

开源 版权保护 28天写作

架构师训练营第 2 期 第 7 周 作业一

老腊肉

架构师训练营第2期

阿里一线架构师甩出“源码阅读指南”,从源码到实战,一键搞定

比伯

Java 编程 程序员 架构 计算机

Java学习笔记整理:Spring+tomcat+Kafka+多线程面试笔记

Java架构之路

Java 程序员 架构 面试 编程语言

Mybatis【16】-- Mybatis多对一关联查询

秦怀杂货店

数据库 mybatis

团队建设,凝聚人心打胜战

一笑

管理 团队建设 28天写作

浅谈Tensorflow分布式架构:parameter server及优化策略_语言 & 开发_Alex-zhai_InfoQ精选文章