写点什么

快速搭建 tensorflow 线上服务

2019 年 9 月 24 日

快速搭建tensorflow 线上服务

在这篇 tutorial 中,我将主要介绍如何 freeze 一个训练好的 tensorflow 模型并部署成 webserver,webserver 使用的是 python flask 框架(其它框架也可以)另外训练 tensorflow 模型时,数据输入方式选择 placeholder 加载方式,并且给 tensor variable “name” 赋值,后边 freeze 模型时会用到。


1x = tf.placeholder(tf.float32, shape=[None, img_size,img_size,num_channels], name='x')23y_true = tf.placeholder(tf.float32, shape=[None, num_classes], name='y_true')
复制代码


freeze model

freeze tensorflow model 是什么,为什么要 freeze,让我们先看下经典的 Alexnet network 结构:



1conv1 layer: (11*11)*3*96 (weights) + 96 (biases)   = 349442conv2 layer: (5*5)*96*256 (weights)+ 256 (biases)   = 6146563conv3 layer: (3*3)*256*384 (weights) + 384 (biases) = 8851204conv4 layer: (3*3)*384*384 (weights) + 384 (biases) = 13274885conv5 layer: (3*3)*384*256 (weights) + 256 (biases) = 8849926fc1 layer:   (6*6)*256*4096 (weights) + 4096 (biases) = 377528327fc2 layer:   4096*4096 (weights) + 4096 (biases) = 167813128fc3 layer:   4096*1000 (weights) + 1000 (biases) = 4097000
复制代码


网络所要计算的参数量已经超过了 6000 万,除此之外在网络训练时梯度反向传播的过程中还有相同数量的梯度值需要计算。tensorflow 训练的模型文件包含所有这些参数,但当你部署这些模型时是不需要这些梯度值的。freeze 就是把所需要的 tensorflow graph、weight 等参数保存到一个文件中的过程。


tensorflow model 的参数包含在下面四类文件中:


1)model-ckpt.meta


This contains the complete graph. This contains a serialized MetaGraphDef protocol buffer, it contains the graphDef that describles the data-flow , annotations for variables, input pipelines and other relevant information.


2)model-ckpt.data-0000-of-00001


This contains all the values of variables (weights, biases, placeholders, gradients, hyper-parameters etc.)


3)model-ckpt.index


It is an immutable table(tensorflow::table::Table),Each key is a name of a tensor and it is value is a serialized BundleEntryProto.Each BundleEntryProto describles the metadata of a Tensor.


4)checkpoint


All the checkpoint information.


总的来说当我们想把模型部署到 webserber 的时候,我们就要去除一些不必要的 meta-data, gradients and unnecessary training variables 以及 encapsulate 压缩剩余的参数到一个文件中,这个压缩后的单个文件(.pb extension)被叫做 “frozen graph def”。


freeze graph 代码如下:


 1import tensorflow as tf 2from tensorflow.python.framework import graph_util 3import os,sys 4import argparse 5# 选择保存模型inference 时需要的tensor variable 6parser = argparse.ArgumentParser() 7parser.add_argument( 8       '--meta', 9       required=True,10       type=str,11       help='input model checkpoint meta data file (.meta)'12       )13parser.add_argument(14       '--prefix',15       required=True,16       type=str,17       help='input model data prefix')18FLAGS, unparsed = parser.parse_known_args()19# 确定你想从网络中保存哪个output, 大多数时候你只会用到预测节点,这里我们只保存预测节点20output_node_names = "y_pred"21#加载保存graph 的.meta 文件并在会话中恢复weights22#saver = tf.train.import_meta_graph('model.ckpt-74928.meta', clear_devices=True)23saver = tf.train.import_meta_graph(FLAGS.meta, clear_devices=True)2425# 把graph 转换为 graph_def26graph = tf.get_default_graph()27input_graph_def = graph.as_graph_def()28sess = tf.Session()29# 使用 graph_util中的函数 convert_variables_to_constants 保存graph_def 以及网络中的ends30#saver.restore(sess, "./model.ckpt-74928")31saver.restore(sess, FLAGS.prefix)32output_graph_def = graph_util.convert_variables_to_constants(33                      sess, # The session is used to retrieve the weights34                      input_graph_def, # The graph_def is used to retrieve the nodes35                      output_node_names.split(",") # The output node names are used to select the usefull nodes36)37output_graph="estate_model.pb"38# 最后序列化并把output graph 写入.pb 文件39with tf.gfile.GFile(output_graph, "wb") as f:40       f.write(output_graph_def.SerializeToString())41sess.close()42# 最后模型从600多M 减小到 200M43# 加载使用frozen 后的模型44# 创建graph并加载权重使之保存到内存中(否则每次request 都会重新加载权重)45def load_graph(trained_model):46      """47      method 1: load graph as default graph.48      #Unpersists graph from file as default graph.49      with tf.gfile.GFile(trained_model, 'rb') as f:50              graph_def = tf.GraphDef()51              graph_def.ParseFromString(f.read())52              tf.import_graph_def(graph_def, name='')53      """54      #load graph55      with tf.gfile.GFile(trained_model, "rb") as f:56             graph_def = tf.GraphDef()57             graph_def.ParseFromString(f.read())58             tf.import_graph_def(graph_def, name='')59      with tf.Graph().as_default() as graph:60             tf.import_graph_def(61                    graph_def,62                    input_map=None,63                    return_elements=None,64                    name="")65      return graph
复制代码


最后加载保存的.pb 文件


1app = Flask(__name__)2FLAGS, unparsed = parser.parse_known_args()3g1 = load_graph(FLAGS.graph1)4session1 = tf.Session(graph=g1, config=config)5@app.route('/image_classification', methods=['POST'])6def parse_request():7...8...9app.run(host="10.200.0.174", port=int("16888"), debug=True, use_reloader=False)
复制代码


问题与解决

以下是在部署模型中遇到的一些坑:


1.模型 inference 耗时严重


运行 freeze 后的模型发现单张图片的 inference 时间消耗达到了几秒钟,经过定位发现是每次 inference 时 tensorflow 会把所有的参数从内存加载到 GPU 显存中,本质上,Tensorflow 在每次启用 run_graph 时,将所有计算加载至内存中,如果你试着在 GPU 上执行推断时会明显发现这一现象,你会看到 GPU 内存随着 tensorflow 在 GPU 上加载和卸载模型参数而升降。


解决方案:


去掉 with tf.Session() as sess 构造,向 run_graph 添加 sess 变量,这样处理后图模型的参数只会在 webserver 第一次启动时从内存加载到 GPU 显存消耗一段时间,之后每次 inference 模型参数都是在 GPU 显存中。


2.tensorflow 内存泄漏以及耗时不断增加的问题


问题代码:


 1with tf.Graph().as_default(): 2          # build graph 3          preprocessed_image = tf.placeholder(tf.float32, shape=(image_size,image_size,3), name="preprocessed_images") 4          processed_image = tf.expand_dims(preprocessed_image, 0) 5          # execute graph 6          with tf.Session() as sess: 7                 image_string_tmp = tf.gfile.FastGFile(line, 'rb').read() 8                 # 严禁把tf.image.decode_image() operate 写在此处 9                 image_decode_tmp = tf.image.decode_image(testImage_string_tmp, channels=3)10                 preprocessed_image_tmp = inception_preprocessing.preprocess_image(image_decode_tmp, image_size, image_size, is_training=False)11                 preprocessed_image_tmp_val = sess.run([preprocessed_image_tmp])12                 np_probabilities = sess.run(probabilities,{"preprocessed_image:0":preprocessed_image_tmp_val[0]})
复制代码


通过使用 time.time() 和 resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1024 记录每一步骤的耗时以及内存占用情况。


从日志从发现是 tf.image.decode_image 造成的内存泄露以及耗时不断增加的问题。


解决方案


 1def preprocess(img_name, height, width, 2                           central_fraction=0.875, scope=None): 3      """ 4      :param image: preprocess image name 5      :param height: 6      :param width: 7      :param central_fraction: fraction of the image to crop 8      :param scope: scope  for name_scope 9      :return: 3-D float Tensor of prepared image.10      """11      image_raw_data = tf.gfile.FastGFile(img_name, 'r').read()12      file_extension = img_name.rsplit('.', 1)[1]13      logging.info("file_extension: %s", file_extension)14      if file_extension == 'jpg' or file_extension == 'jpeg':15             image_raw_data = tf.image.decode_jpeg(image_raw_data)16      elif file_extension == 'png':17             image_raw_data = tf.image.decode_png(image_raw_data)18             image_raw_data = tf.image.encode_jpeg(image_raw_data)19             image_raw_data = tf.image.decode_jpeg(image_raw_data)202122def run_graph1(filename, sess):23      # build graph24      with sess.graph.as_default():25             image_width = 25626             image_height = 25627             num_channels = 328             start_load_graph = time.time()29             y_pred = sess.graph.get_tensor_by_name("y_pred:0")30             ## Let's feed the images to the input placeholders31             x = sess.graph.get_tensor_by_name("x:0")32             # y_true = graph.get_tensor_by_name("y_true:0")33             # y_test_images = np.zeros((1, 2))34             #sess = tf.Session(graph=graph, config=config)35             load_graph_elapsed = time.time() - start_load_graph36             logging.info("load_graph_elapsed: %f:", load_graph_elapsed)3738             # compute preprocess image time39             start_process = time.time()40             images = preprocess(os.path.join(UPLOAD_FOLDER, filename), image_height, image_width)41             process_elapsed = time.time() - start_process42             logging.info("process_elapsed: %f:", process_elapsed)43             # execute graph44             image = images.eval(session=sess)45             x_batch = image.reshape(1, image_height, image_width, num_channels)46             feed_data_time = time.time()47             ### Creating the feed_dict that is required to be fed to calculate y_pred48             feed_dict_testing = {x: x_batch}
复制代码


虽然 tf.image.decode_image 仅仅是对图片进行解码(把图片字符转换成 tensor,可能存在为 tensor 分配内存的操作),在使用 tensorflow 的过程中把涉及 tensor 的相关操作放在构建图中。


3.加载多个模型


在 Tensorflow 中,所有操作对象都包装到相应的 Session 中的,所以想要使用不同的模型就需要将这些模型加载到不同的 Session 中并在使用的时候申明是哪个 Session,从而避免由于 Session 和想使用的模型不匹配导致的错误。而使用多个 graph,就需要为每个 graph 使用不同的 Session,但是每个 graph 也可以在多个 Session 中使用,这个时候就需要在每个 Session 使用的时候明确申明使用的 graph。


需要注意的是由于有多个 graph,所以 sess.graph 与 tf.get_default_value 的值是不相等的,因此在进入 sess 的时候必须 sess.graph.as_default()明确申明 sess.graph 为当前默认 graph,否则就会报错。


 1def run_graph1(filename, sess): 2      with sess.graph.as_default(): 3             image_width = 256 4             image_height = 256 5             num_channels = 3 6             start_load_graph = time.time() 7             y_pred = sess.graph.get_tensor_by_name("y_pred:0") 8             ## Let's feed the images to the input placeholders 9             x = sess.graph.get_tensor_by_name("x:0")10             load_graph_elapsed = time.time() - start_load_graph11             logging.info("load_graph_elapsed: %f:", load_graph_elapsed)12             # compute preprocess image time13             start_process = time.time()14             images = preprocess(os.path.join(UPLOAD_FOLDER, filename), image_height, image_width)15             process_elapsed = time.time() - start_process16             logging.info("process_elapsed: %f:", process_elapsed)17             image = images.eval(session=sess)18             x_batch = image.reshape(1, image_height, image_width, num_channels)19             feed_data_time = time.time()20             ### Creating the feed_dict that is required to be fed to calculate y_pred21             feed_dict_testing = {x: x_batch}22             feed_data_elapsed = time.time() - feed_data_time23             logging.info("feed_data_time:", feed_data_elapsed)24             start_compute_time = time.time()25             result = sess.run(y_pred, feed_dict=feed_dict_testing)26             compute_elapsed_time = time.time() - start_compute_time27             logging.info("compute_elapsed_time: %f:", compute_elapsed_time)28             return result293031def run_graph2(filename, sess):32      with sess.graph.as_default():33             image_width = 25634             image_height = 25635             num_channels = 336             start_load_graph = time.time()37             y_pred = sess.graph.get_tensor_by_name("y_pred:0")38             ## Let's feed the images to the input placeholders39             x = sess.graph.get_tensor_by_name("x:0")40             load_graph_elapsed = time.time() - start_load_graph41             logging.info("load_graph_elapsed: %f:", load_graph_elapsed)42             # compute preprocess image time43             start_process = time.time()44             images = preprocess(os.path.join(UPLOAD_FOLDER, filename), image_height, image_width)45             process_elapsed = time.time() - start_process46             logging.info("process_elapsed: %f:", process_elapsed)47             image = images.eval(session=sess)48             x_batch = image.reshape(1, image_height, image_width, num_channels)49             feed_data_time = time.time()50             ### Creating the feed_dict that is required to be fed to calculate y_pred51             feed_dict_testing = {x: x_batch}52             feed_data_elapsed = time.time() - feed_data_time53             logging.info("feed_data_time:", feed_data_elapsed)54             start_compute_time = time.time()55             result = sess.run(y_pred, feed_dict=feed_dict_testing)56             compute_elapsed_time = time.time() - start_compute_time57             logging.info("compute_elapsed_time: %f:", compute_elapsed_time)58             return result5960# web框架加载时就把不同的图赋值给不同的session61app = Flask(__name__)62FLAGS, unparsed = parser.parse_known_args()63g1 = load_graph(FLAGS.graph1)64g2 = load_graph(FLAGS.graph2)65session1 = tf.Session(graph=g1, config=config)66session2 = tf.Session(graph=g2, config=config)
复制代码


作者介绍:


作者爱贝(企业代号名),目前负责贝壳找房图像处理方向的相关工作。


本文转载自公众号贝壳产品技术(ID:gh_9afeb423f390)。


原文链接:


https://mp.weixin.qq.com/s/0cmEmIC_CEgiC_o4JwTHnw


2019 年 9 月 24 日 10:27852

评论

发布
暂无评论
发现更多内容

2021Java岗面试预备手册:涵盖20个技术栈​​​​,助你通往大厂的面试必备指南

Java成神之路

Java 程序员 架构 面试 编程语言

WinDbg 分析高内存占用问题

圣杰

dotnet windbg

浅谈nodejs进程和线程

梁龙先森

前端 nodejs 2月春节不断更

JVM又曾放过谁,垃圾终将被回收!

Simon郎

Java 大数据 架构 JVM 后端开发

猜猜用什么来存储Docker的镜像?这还真是个“非常手段”

互联网架构师小马

收录99+案例!Github获赞百万的性能优化小册也太香了

Java王路飞

Java 程序员 面试 性能优化 JVM

架构设计篇之微服务实战笔记(三)

小诚信驿站

架构师 刘晓成 小诚信驿站 28天写作 架构师成长笔记

流媒体传输协议之 RTMP

阿里云视频云

TCP 音视频 RTMP 传输协议 流媒体;

行业缩水,光靠一份神仙般的“进阶面试宝典”,我居然拿到开发岗60K京东offer

Java成神之路

Java 程序员 架构 面试 编程语言

话题讨论 | 比特币攻击重现江湖,你准备好了吗?

程序员架构进阶

话题讨论 28天写作 2月春节不断更 话题王者 勒索攻击

容器 & 服务:一个Java应用的Docker构建实战

程序员架构进阶

Docker 容器 七日更 28天写作 2月春节不断更

用Stylish精简极客时间专栏页面

闪闪带你学前端

CSS

为图片添加Emoji,微信这隐藏功能让你不花冤枉钱

彭宏豪95

微信 效率 效率工具 emoji

微信十年,弹指一挥间

彭宏豪95

微信 产品 互联网 写作

不可思议!我靠这些笔记练手,竟然拿到了蚂蚁Java岗后端开发offer

Java成神之路

Java 程序员 架构 面试 编程语言

泰山版震撼来袭!阿里巴巴2021年Java程序员面试指导小册已开源

Java架构追梦

Java 架构 面试 金三银四 跳槽

还愁追不到女神吗?一键生成舔狗日记,一秒速成舔狗之王

不脱发的程序猿

程序人生 28天写作 二月春节不断更 舔狗文化

话题讨论 | 现实中程序员是怎样飞快敲代码的?

xcbeyond

程序人生 话题讨论

话题讨论 | 你在互联网大厂是个啥级别?

架构精进之路

话题讨论 28天写作 话题王者

为什么不推荐使用汉字作为密码?

不脱发的程序猿

程序人生 密码学 28天写作 二月春节不断更

优秀!阿里甩出GC面试小册,仅7天Github获赞96.9K

程序员小毕

Java 程序员 算法 JVM GC

太猛了!Github大佬那白嫖的分布式进阶宝典,啃完感觉能吊锤面试官

Crud的程序员

Java 架构

电影台词反向搜索视频片段,这个工具也太好用了吧|33 台词

彭宏豪95

效率 效率工具 电影

Laravel来信|Event

LeastCoding

laravel Event 观察者模式

Dapr 知多少 | 分布式应用运行时

圣杰

架构 云原生 k8s dapr

2021最新百度/平安/蚂蚁金服/腾讯/拼多多面经总结(附答案解析)

比伯

Java 编程 架构 面试 计算机

话题讨论 | 程序员是做前端开发好,还是后端开发好呢?

xcbeyond

程序人生 话题讨论

有赞 Flink 实时任务资源优化探索与实践

Apache Flink

flink

2021阿里总监最新手码BAT等大厂面经!GitHub已标星86.2K

比伯

Java 编程 架构 面试 程序人生

架构师不至于“架构”-《架构师应该知道的37件事》阅读笔记

代码技艺

读书笔记 架构 架构师

CAP理论,我拆开来细讲,用例子展现,够清晰了吧

互联网架构师小马

快速搭建tensorflow 线上服务-InfoQ