写点什么

浅谈 Tensorflow 分布式架构:parameter server 及优化策略

  • 2019-12-02
  • 本文字数:3380 字

    阅读完需:约 11 分钟

浅谈Tensorflow分布式架构:parameter server及优化策略

当我们想将一个单机的 tensorflow 训练程序改写成分布式训练(多机多卡)的时候,一般有两个大方向的选择:1.完全异步的梯度更新策略,其代表方法是 parameter server 架构。2.同步的梯度更新策略,代表方法有:百度的 ring all-reduce 策略。本文首先介绍 parameter server 架构。

parameter server 策略:

parameter server 异步更新策略是指每个 GPU 或者 CPU 计算完梯度后,无需等待其他 GPU 或 CPU 的梯度计算(有时可以设置需要等待的梯度个数),就可立即更新整体的权值,然后同步此权值,即可进行下一轮计算。



parameter server 的架构


而 Tensorflow 一开始支持分布式的时候,便是这种 parameter server 架构。TensorFlow 一般将任务分为两类 job:一类叫参数服务器,parameter server,简称为 ps,用于存储可训练的参数变量 tf.Variable;一类就是普通任务,称为 worker,用于执行具体的计算。


Tensorflow 支持两种方式实现 parameter server:低阶 API 创建 parameter server 集群方式和 tf.distribute.Strategy 中的 ParameterServerStrategy。

低阶 API 创建 parameter server 集群

完整案例 dist_tf.py:


import tensorflow as tfimport numpy as np
# 创建集群信息,包括ps和worker两种角色。# 集群有两类任务,ps和worker;ps由2个任务组成(一般一个任务是一个机器或者一个分配单元),worker由3个任务组成。ps_hosts = ["xx.xxx.xx.xxxx:oooo", "xx.xxx.xx.xxxx:oooo"]worker_hosts = ["xx.xxx.xx.xxxx:oooo", "xx.xxx.xx.xxxx:oooo", "xx.xxx.xx.xxxx:oooo"]cluster = tf.train.ClusterSpec({"ps": ps_hosts, "worker": worker_hosts})
tf.app.flags.DEFINE_string("job_name", "worker", "One of 'ps', 'worker'")tf.app.flags.DEFINE_integer("task_index", 0, "Index of task within the job")FLAGS = tf.app.flags.FLAGS
def main(_): server = tf.train.Server(cluster, job_name=FLAGS.job_name, task_index=FLAGS.task_index) if FLAGS.job_name == "ps": server.join() else: # 会根据job名,将with内的Variable op放到ps tasks,将其他计算op放到worker tasks。默认分配策略是轮询 with tf.device(tf.train.replica_device_setter( worker_device="/job:worker/task:%d" % FLAGS.task_index, cluster=cluster)):
x_data = tf.placeholder(tf.float32, [100]) y_data = tf.placeholder(tf.float32, [100])
W = tf.Variable(tf.random_uniform([1], -1.0, 1.0)) b = tf.Variable(tf.zeros([1])) y = W * x_data + b loss = tf.reduce_mean(tf.square(y - y_data))
global_step = tf.Variable(0, name="global_step", trainable=False) optimizer = tf.train.GradientDescentOptimizer(0.1) train_op = optimizer.minimize(loss, global_step=global_step)
# The StopAtStepHook handles stopping after running given steps. hooks = [tf.train.StopAtStepHook(last_step=1000000)] # The MonitoredTrainingSession takes care of session initialization, # restoring from a checkpoint, saving to a checkpoint, and closing when done # or an error occurs. with tf.train.MonitoredTrainingSession(master=server.target, is_chief=(FLAGS.task_index == 0), # 我们制定task_index为0的任务为主任务,用于负责变量初始化、做checkpoint、保存summary和复原 checkpoint_dir="/tmp/tf_train_logs", save_checkpoint_secs=None, hooks=hooks) as mon_sess: while not mon_sess.should_stop(): # Run a training step asynchronously. # See `tf.train.SyncReplicasOptimizer` for additional details on how to # perform *synchronous* training. # mon_sess.run handles AbortedError in case of preempted PS. train_x = np.random.rand(100).astype(np.float32) train_y = train_x * 0.1 + 0.3 _, step, loss_v, weight, biase = mon_sess.run([train_op, global_step, loss, W, b], feed_dict={x_data: train_x, y_data: train_y}) if step % 100 == 0: print("step: %d, weight: %f, biase: %f, loss: %f" % (step, weight, biase, loss_v)) print("Optimization finished.")

if __name__ == "__main__": tf.app.run()
复制代码


对于本例而言,我们需要在对应的 5 台机器上分别运行每个任务,共需执行五次代码,生成五个任务。


python dist_tf.py --job_name=ps --task_index=0python dist_tf.py --job_name=ps --task_index=1python dist_tf.py --job_name=worker --task_index=0python dist_tf.py --job_name=worker --task_index=1python dist_tf.py --job_name=worker --task_index=2
复制代码


低阶 API 创建 parameter server 集群缺点:


概念多,学习曲线陡峭。


单机代码到多机修改的代码量大。


需要多台机子跑不同的脚本,当然这可以通过 k8s 集群管理工具来解决。


PS 和 Worker 的比例不好选取。(建议选取偶数个的 ps,我的经验是 ps 和 worker 的比例是 1:3)


训练速度性能损失较大。(通信代价较高)


parameter server 常见的优化点:


如果有参数量较大的 embedding 变量时,可选择使用 embedding_lookup_sparse_with_distributed_aggregation 函数替代 tf.nn.embedding_lookup_sparse 函数。该函数可将 embedding 的聚合计算都放在变量所在的 PS 端,计算后转成稠密张量再传送到 Worker 上继续网络模型的计算。


tf.device 函数中有一个参数是设置变量在 ps 端放置策略的,可使用 tf.contrib.training.GreedyLoadBalancingStrategy 来替代默认的轮循。优点是:可根据参数的内存字节来完成类似在线垃圾收集的工作。根据 weight 和 bias 的字节数来放置到内存合适的 task 中,带来更好的负载平衡。


当参数有超大量级时(比如 embedding 参数),可在创建变量的时候使用分割变量策略:partitioner=tf.fixed_size_partitioner(ps_nums)


优化 input pipeline。链接:https://www.tensorflow.org/guide/performance/datasets


bandwidth 高带宽范亲和策略,保证多个 ps 分布在不同的物理机上。


Estimator 中的 ParameterServerStrategy 策略


# https://stackoverflow.com/questions/55003279/parameter-server-strategy-with-estimatorstensorflowimport tensorflow as tfimport osimport json
NUM_WORKERS = 1IP_ADDRS = ['localhost']PORTS = [12345]
def model_fn(...): .....
def input_fn(...): .....
复制代码

需要每个机器配置 TF_CONFIG 环境变量

os.environ['TF_CONFIG'] = json.dumps({    'cluster': {        'worker': ['%s:%d' % (IP_ADDRS[w], PORTS[w]) for w in range(NUM_WORKERS)],        'ps': ['%s:%d' % (IP_ADDRS[w], PORTS[w]) for w in range(NUM_WORKERS)]    },    'task': {'type': 'worker', 'index': 0}})
# Method for using ParamterServerStrategystrategy = tf.distribute.experimental.ParameterServerStrategy()
config = tf.estimator.RunConfig(train_distribute=strategy)
classifier = tf.estimator.Estimator( model_fn=model_fn, model_dir='/tmp/multiworker', config=config)tf.estimator.train_and_evaluate( classifier, train_spec=tf.estimator.TrainSpec(input_fn=input_fn), eval_spec=tf.estimator.EvalSpec(input_fn=input_fn))
复制代码


本文转载自 Alex-zhai 知乎账号。


原文链接:https://zhuanlan.zhihu.com/p/69010949


2019-12-02 16:234377

评论

发布
暂无评论
发现更多内容

运行时应用自我保护(RASP):应用安全的自我修养

SEAL安全

RASP

攻防演练合集 | 3个阶段,4大要点,蓝队防守全流程纲要解读

青藤云安全

网络安全 网络攻防 安全服务 攻防演练

如何用 Redis 实现一个分布式锁

Ayue、

redis 分布式锁

大数据培训 | Flink如何监控恶意登录

@零度

大数据

得物多活架构设计之路由服务设计

得物技术

架构 高可用 架构设计 双活 路由

使用Mycat进行MySQL单库分表

迷彩

架构 运维 mycat 分布式数据库中间件 6月月更

DevEco Device Tool 助力OpenHarmony设备开发

OpenHarmony开发者

OpenHarmony

不止于观测|阿里云可观测套件正式发布

阿里巴巴云原生

阿里云 云原生 可观测 套件

quarkus+saas多租户动态数据源切换实现简单完美

weir威尔

SaaS 多租户 Quarkus 动态数据源

Angular 服务器端渲染应用一个常见的内存泄漏问题

汪子熙

typescript 前端开发 angular Spartacus 6月月更

开发增效利器—2022年VsCode插件分享

中原银行

ide vscode 插件 中原银行 降本增效

《Java编程思想》作者Bruce Eckel新作,到底做了哪些升级?

图灵教育

Java

“芯”有灵“蜥”,万人在线!龙蜥社区走进 Intel MeetUp 精彩回顾

OpenAnolis小助手

开源 直播 Meetup 龙蜥社区 走进 Intel

直播带货app源码搭建中,直播CDN的原理是什么?

开源直播系统源码

软件开发 直播带货 直播系统 app源码

5 个关于 NFT 的技术漏洞

devpoint

区块链 以太坊 NFT 6月月更

如何使用 Django Forms 创建表单?

海拥(haiyong.site)

Python django 6月月更

Wallys/DR6018-S/ 802.11AX MU-MIMO OFDMA / 2* GE PORTS/WIFI 6e / BAND DUAL CONCURRENT

wallys-wifi6

【云舟说直播间】-数字安全专场明天下午正式上线

云计算

消息队列的丢失、重复与积压问题

Damon

6月月更

Kafka ETL 之后,我们将如何定义新一代实时数据集成解决方案?

tapdata

kafka ETL 数据集成 实时数据 DaaS

Vone新闻 | 旺链科技赋能众享链网自组织管理,打造企业级联盟DAO

旺链科技

区块链 产业区块链 DAO 自组织协作

navicat定时任务无效

源字节1号

Go 语言使用 MySQL 的常见故障分析和应对方法

百度Geek说

Go MySQL

坚持五件事,带你走出迷茫困境!

博文视点Broadview

成熟的知识管理,应具备哪些条件?

小炮

Rancher 2.6 全新 Monitoring 快速入门

Rancher

Kubernetes k8s rancher

想学习eTS开发?教你开发一款IQ-EQ测试应用

HarmonyOS开发者

HarmonyOS

java培训 | Java设计模式之装饰者设计模式

@零度

JAVA开发

java程序员培训 | Java设计模式之桥接模式

@零度

设计模式 JAVA开发

并购增资或将有望启动东软越通新动能?

E科讯

大数据培训 | 电商用户行为分析之订单支付实时监控

@零度

大数据 flink

浅谈Tensorflow分布式架构:parameter server及优化策略_语言 & 开发_Alex-zhai_InfoQ精选文章