【ArchSummit】如何通过AIOps推动可量化的业务价值增长和效率提升?>>> 了解详情
写点什么

分布式 tensorflow 源码解读 3:lookup.index_table_from_tensor

  • 2019-11-28
  • 本文字数:9943 字

    阅读完需:约 33 分钟

分布式tensorflow源码解读3:lookup.index_table_from_tensor

背景

推荐排序 dnn 模型中经常会用一些特别稀疏的 id 特征,此时需要对这些 id 特征做 embedding 操作,一般在 tensorflow 中都会使用一个 shape=[id_index_size, embedding_size]的 Variable 矩阵做 embedding 参数,然后根据 id 特征的 index 去 Variable 矩阵中查表得到相应的 embedding 表示。这里需要注意的是:id_index_size 的大小一般都不会等于对应 id table 的元素个数,因为有很多 id 元素不在原始的 id table 表中,比如新上架的一些商品等。此时需要将 id_index_size 设置的大一些,以留一些位置给那些不在 id table 表的元素使用。那么从原始的 id 特征值映射成[0, id_index_size)间的 index 是怎么做的呢?


如果 id 量非常小的话,可在特征提取后把 id 排序一遍生成从 0 开始的连续 id 值,但在工业界的场景下 id 特征的量级往往是百万到上亿级别,很难做排序。幸好 tensorFlow 内部有一个函数可以将原始的 id 特征值映射到从 0 开始的 index,这个函数就是 lookup.index_table_from_tensor。

例子

首先举一个 index_table_from_tensor 的使用例子:


import tensorflow as tf
sess = tf.Session()vocabulary_list = tf.constant(["emerson", "lake", "palmer"])table = tf.contrib.lookup.index_table_from_tensor( vocabulary_list=vocabulary_list, num_oov_buckets=10, default_value=-1)features = tf.constant(["emerson", "lake", "and", "palmer", "dad", "mom", "hello"])table.init.run(session=sess)ids = table.lookup(features)print(sess.run(ids))
返回值:[ 0 1 10 2 9 3 9]
复制代码


index_table_from_tensor 函数的作用是:如果在 vocabulary_list 里面的 id 特征值,则映射为从 0 到 len(table)-1 之间的 index,如果不在 vocabulary_list 里面,则通过 hash 函数映射到 len(table)至 len(table) + num_oov_buckets - 1 之间的区间中。


我们开始分析源代码,首先看下 index_table_from_tensor 的源码:


def index_table_from_tensor(vocabulary_list,                            num_oov_buckets=0,                            default_value=-1,                            hasher_spec=FastHashSpec,                            dtype=dtypes.string,                            name=None):  """Returns a lookup table that converts a string tensor into int64 IDs.  This operation constructs a lookup table to convert tensor of strings into  int64 IDs. The mapping can be initialized from a string `vocabulary_list` 1-D  tensor where each element is a key and corresponding index within the tensor  is the value.
Args: vocabulary_list: A 1-D `Tensor` that specifies the mapping of keys to indices. The type of this object must be castable to `dtype`. num_oov_buckets: The number of out-of-vocabulary buckets. default_value: The value to use for out-of-vocabulary feature values. Defaults to -1. hasher_spec: A `HasherSpec` to specify the hash function to use for assignment of out-of-vocabulary buckets. dtype: The type of values passed to `lookup`. Only string and integers are supported. name: A name for this op (optional). Returns: The lookup table to map an input `Tensor` to index `int64` `Tensor`. Raises: ValueError: If `vocabulary_list` is invalid. ValueError: If `num_oov_buckets` is negative. """
“”“ 类型检查代码 ”“”
with ops.name_scope(name, "string_to_index") as feat_to_id_scope: # 获取由传入的vocabulary_list构造的hash表,keys是vocabulary_list里面的元素,values是 # 【0, len(num_elements)-1】 keys = ops.convert_to_tensor(vocabulary_list) num_elements = array_ops.size(keys) values = math_ops.to_int64(math_ops.range(num_elements))
shared_name = "" with ops.name_scope(None, "hash_table") as hash_table_scope: table_keys = math_ops.to_int64(keys) if keys.dtype.is_integer else keys init = KeyValueTensorInitializer( table_keys, values, table_keys.dtype.base_dtype, dtypes.int64, name="table_init") table = HashTable( init, default_value, shared_name=shared_name, name=hash_table_scope) if num_oov_buckets: table = IdTableWithHashBuckets( table, num_oov_buckets=num_oov_buckets, hasher_spec=hasher_spec, name=feat_to_id_scope, key_dtype=dtype) return table
复制代码


从上面的代码片段可以看出,index_table_from_tensor 函数体中有三个比较重要的步骤:


根据 keys 和 values 创建 KeyValueTensorInitializer。


根据 KeyValueTensorInitializer 创建具体的 HashTable 类,对于不在 vocabulary_list 表里面的元素,统一返回默认值。


如果 num_oov_buckets>0,则会创建带有 Buckets 的 IdTableWithHashBuckets 类,用一个 hash 函数返回那些不在 vocabulary_list 表里面的 id 元素的 index 值。


首先我们来分析第一步的操作,先通过传入的 vocabulary_list 去构造 keys 和 values,然后传入到 KeyValueTensorInitializer 类中返回 table 初始化的 op,注意此时只是返回一个用于 table 初始化的 OP,还没有做具体的初始化工作(初始化要放到具体的 HashTable 中去完成)。下面看下 KeyValueTensorInitializer 的代码:


class KeyValueTensorInitializer(TableInitializerBase):  """Table initializers given `keys` and `values` tensors."""
def __init__(self, keys, values, key_dtype=None, value_dtype=None, name=None): with ops.name_scope(name, "key_value_init", [keys, values]) as scope: self._keys = ops.convert_to_tensor(keys, dtype=key_dtype, name="keys") self._values = ops.convert_to_tensor( values, dtype=value_dtype, name="values") self._name = scope
super(KeyValueTensorInitializer, self).__init__(self._keys.dtype, self._values.dtype)
def initialize(self, table): """Initializes the given `table` with `keys` and `values` tensors. Args: table: The table to initialize. Returns: The operation that initializes the table. Raises: TypeError: when the keys and values data types do not match the table key and value data types. """ _check_table_dtypes(table, self._keys.dtype, self._values.dtype) with ops.name_scope( self._name, values=(table.table_ref, self._keys, self._values)) as scope: init_op = gen_lookup_ops.lookup_table_import_v2( table.table_ref, self._keys, self._values, name=scope) ops.add_to_collection(ops.GraphKeys.TABLE_INITIALIZERS, init_op) return init_op可看出table的初始化操作是通过lookup_table_import_v2这个op来完成,下面就分析下这个op的相关c代码:
REGISTER_OP("LookupTableImportV2") .Input("table_handle: resource") .Input("keys: Tin") .Input("values: Tout") .Attr("Tin: type") .Attr("Tout: type") .SetShapeFn([](InferenceContext* c) { ShapeHandle handle; TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &handle));
ShapeHandle keys; TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &keys)); TF_RETURN_IF_ERROR(c->Merge(keys, c->input(2), &keys)); return Status::OK(); });
// Clear the table and insert data.class LookupTableImportOp : public OpKernel { public: explicit LookupTableImportOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
void Compute(OpKernelContext* ctx) override { lookup::LookupInterface* table; OP_REQUIRES_OK(ctx, GetLookupTable("table_handle", ctx, &table)); “”“ .... ”“”
const Tensor& keys = ctx->input(1); const Tensor& values = ctx->input(2); OP_REQUIRES_OK(ctx, table->CheckKeyAndValueTensorsForImport(keys, values));
int memory_used_before = 0; if (ctx->track_allocations()) { memory_used_before = table->MemoryUsed(); } OP_REQUIRES_OK(ctx, table->ImportValues(ctx, keys, values)); if (ctx->track_allocations()) { ctx->record_persistent_memory_allocation(table->MemoryUsed() - memory_used_before); } }};
复制代码


这里通过 table_handle 句柄来传入具体的 Table 类,上面举的例子中使用的是 HashTable 类,HashTable 类继承于 InitializableLookupTable 类,InitializableLookupTable 可理解为那些需要初始化操作的 table 集合,而 InitializableLookupTable 又继承于 LookupInterface 类。通过 table->ImportValues()方法来将 vocabulary_list 中构造的 keys 和 values 导入到 hash 表中,完成 hash 表的初始化工作。


通过 LookupTableImportV2 返回 init_op 之后,将该初始化 hash 表的 op 传给 HashTable 类,HashTable 类和其父类的代码如下:


class HashTable(InitializableLookupTableBase):
def __init__(self, initializer, default_value, shared_name=None, name=None): """Creates a non-initialized `HashTable` object. Creates a table, the type of its keys and values are specified by the initializer. Before using the table you will have to initialize it. After initialization the table will be immutable. Args: initializer: The table initializer to use. See `HashTable` kernel for supported key and value types. default_value: The value to use if a key is missing in the table. shared_name: If non-empty, this table will be shared under the given name across multiple sessions. name: A name for the operation (optional). Returns: A `HashTable` object. """ with ops.name_scope(name, "hash_table", (initializer, default_value)) as scope: table_ref = gen_lookup_ops.hash_table_v2( shared_name=shared_name, key_dtype=initializer.key_dtype, value_dtype=initializer.value_dtype, name=scope)
super(HashTable, self).__init__(table_ref, default_value, initializer) self._value_shape = self._default_value.get_shape()

class InitializableLookupTableBase(LookupInterface): """Initializable lookup table interface. An initializable lookup tables persist across different steps. """
def __init__(self, table_ref, default_value, initializer): """Construct a table object from a table reference. If requires a table initializer object (subclass of `TableInitializerBase`). It provides the table key and value types, as well as the op to initialize the table. The caller is responsible to execute the initialization op. Args: table_ref: The table reference, i.e. the output of the lookup table ops. default_value: The value to use if a key is missing in the table. initializer: The table initializer to use. """ name = table_ref.op.name.split("/")[-1] super(InitializableLookupTableBase, self).__init__(initializer.key_dtype, initializer.value_dtype, name) self._table_ref = table_ref self._default_value = ops.convert_to_tensor( default_value, dtype=self._value_dtype) self._default_value.get_shape().merge_with(tensor_shape.scalar()) self._init = initializer.initialize(self)
@property def table_ref(self): """Get the underlying table reference.""" return self._table_ref
@property def default_value(self): """The default value of the table.""" return self._default_value
@property def init(self): """The table initialization op.""" return self._init
def lookup(self, keys, name=None): """Looks up `keys` in a table, outputs the corresponding values. The `default_value` is used for keys not present in the table. """ key_tensor = keys with ops.name_scope(name, "%s_Lookup" % self._name, (self._table_ref, key_tensor, self._default_value)) as scope: values = gen_lookup_ops.lookup_table_find_v2( self._table_ref, key_tensor, self._default_value, name=scope)
values.set_shape(key_tensor.get_shape()) if isinstance(keys, sparse_tensor.SparseTensor): return sparse_tensor.SparseTensor(keys.indices, values, keys.dense_shape) else: return values
复制代码


InitializableLookupTableBase 类继承于 HashTable,InitializableLookupTableBase 类中有 init 操作和 look_up,所以表的初始化和 look_up 操作其实都是在父类中完成,其中表的初始化操作上面已经讨论过,我们就重点看下 look_up 操作的相关 OP:



lookup_table_find_v2
// Table lookup op. Perform the lookup operation on the given table.class LookupTableFindOp : public OpKernel { public: explicit LookupTableFindOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
void Compute(OpKernelContext* ctx) override { lookup::LookupInterface* table; OP_REQUIRES_OK(ctx, GetLookupTable("table_handle", ctx, &table)); “”“ ...... ”“”
const Tensor& key = ctx->input(1); const Tensor& default_value = ctx->input(2); OP_REQUIRES_OK(ctx, table->CheckFindArguments(key, default_value));
TensorShape output_shape = key.shape(); output_shape.RemoveLastDims(table->key_shape().dims()); output_shape.AppendShape(table->value_shape()); Tensor* out; OP_REQUIRES_OK(ctx, ctx->allocate_output("values", output_shape, &out));
OP_REQUIRES_OK(ctx, table->Find(ctx, key, out, default_value)); }};
REGISTER_KERNEL_BUILDER(Name("LookupTableFind").Device(DEVICE_CPU), LookupTableFindOp);REGISTER_KERNEL_BUILDER(Name("LookupTableFindV2").Device(DEVICE_CPU), LookupTableFindOp);
# table->Find调用的是HashTable类的DoFind方法Status DoFind(const Tensor& key, Tensor* value, const Tensor& default_value) override { const V default_val = default_value.flat<V>()(0); const auto key_values = key.flat<K>(); auto value_values = value->flat<V>();
for (int64 i = 0; i < key_values.size(); ++i) { value_values(i) = gtl::FindWithDefault( *table_, SubtleMustCopyIfIntegral(key_values(i)), default_val); } return Status::OK(); }

复制代码


最后我们分析下 IdTableWithHashBuckets 类,该类可用一个 hash 函数返回那些不在 vocabulary_list 表里 id 元素的 index 值,先举个例子:


import tensorflow as tfnum_oov_buckets = 3input_tensor = tf.constant(["emerson", "lake", "palmer", "king", "crimnson"])table = tf.IdTableWithHashBuckets(      tf.HashTable(tf.TextFileIdTableInitializer(filename), default_value),      num_oov_buckets)out = table.lookup(input_tensor).table.init.run()print(out.eval())
#filename依次为"emerson", "lake", "palmer"
复制代码


#结果为[0, 1, 2, 4, 7]


下面看下 IdTableWithHashBuckets 的源码:


class IdTableWithHashBuckets(LookupInterface):  """String to Id table wrapper that assigns out-of-vocabulary keys to buckets.  For example, if an instance of `IdTableWithHashBuckets` is initialized with a  string-to-id table that maps:  * `emerson -> 0`  * `lake -> 1`  * `palmer -> 2`  The `IdTableWithHashBuckets` object will performs the following mapping:  * `emerson -> 0`  * `lake -> 1`  * `palmer -> 2`  * `<other term> -> bucket_id`, where bucket_id will be between `3` and  `3 + num_oov_buckets - 1`, calculated by:  `hash(<term>) % num_oov_buckets + vocab_size`  If input_tensor is `["emerson", "lake", "palmer", "king", "crimson"]`,  the lookup result is `[0, 1, 2, 4, 7]`.  If `table` is None, only out-of-vocabulary buckets are used.  """
def __init__(self, table, num_oov_buckets, hasher_spec=FastHashSpec, name=None, key_dtype=None): """Construct a `IdTableWithHashBuckets` object. Args: table: Table that maps `tf.string` or `tf.int64` keys to `tf.int64` ids. num_oov_buckets: Number of buckets to use for out-of-vocabulary keys. hasher_spec: A `HasherSpec` to specify the hash function to use for assignation of out-of-vocabulary buckets (optional). name: A name for the operation (optional). key_dtype: Data type of keys passed to `lookup`. Defaults to `table.key_dtype` if `table` is specified, otherwise `tf.string`. Must be string or integer, and must be castable to `table.key_dtype`. Raises: ValueError: when `table` in None and `num_oov_buckets` is not positive. TypeError: when `hasher_spec` is invalid. """ # If a name ends with a '/' it is a "name scope", remove all trailing '/' # characters to use as table name. “”“ 类型检查 “”“
self._num_oov_buckets = num_oov_buckets super(IdTableWithHashBuckets, self).__init__(key_dtype, dtypes.int64, name.split("/")[-1])
@property def init(self): """The table initialization op.""" if self._table: return self._table.init with ops.name_scope(None, "init"): return control_flow_ops.no_op()
@property def table_ref(self): if self._table is not None: return self._table.table_ref return None
def _get_string_to_hash_bucket_fn(self, hasher_spec): """Returns the string_to_hash_bucket op to use based on `hasher_spec`.""" if not isinstance(hasher_spec, HasherSpec): raise TypeError("hasher_spec must be of type HasherSpec %s" % hasher_spec) if hasher_spec.hasher == "fasthash": return string_ops.string_to_hash_bucket_fast if hasher_spec.hasher == "legacy": return string_ops.string_to_hash_bucket if hasher_spec.hasher == "stronghash": return functools.partial( string_ops.string_to_hash_bucket_strong, key=hasher_spec.key) raise ValueError("Unknown hasher %s" % hasher_spec.hasher)
def lookup(self, keys, name=None):
values = keys if self._num_oov_buckets == 0: ids = self._table.lookup(values, name=name) else: # TODO(yleon): Consider moving this functionality to its own kernel. with ops.name_scope(name, "%s_Lookup" % self.name) as scope: str_to_hash_bucket = self._get_string_to_hash_bucket_fn( self._hasher_spec) buckets = str_to_hash_bucket( _as_string(values), num_buckets=self._num_oov_buckets, name="hash_bucket") if self._table: ids = self._table.lookup(values) buckets = math_ops.add(buckets, self._table.size()) is_id_non_default = math_ops.not_equal(ids, self._table.default_value) ids = array_ops.where(is_id_non_default, ids, buckets, name=scope) else: ids = buckets return ids
复制代码


上面的 lookup 操作为核心代码,大致流程为:先将当前待查找的所有 keys 通过 str_to_hash_bucket 这个 hash 函数映射成 int 类型的 bucket 号,然后为了防止和已经在表中的那些 index 冲突,需要统一加上 table.size()。然后在 table 中查找所有 keys 的 index,如果在 id table 表中会返回正确的 index,如果不在就返回默认值。然后将那些查找后等于默认值的 keys 的 index 换成通过 hash 函数得到的 bucket 号。这样存在的性能问题是当 id table 表较大时,会有很多时间浪费在没必要的 hash 操作上。所以一种优化方案是:先查表得到所有 keys 的 index,然后只 hash 那些映射值等于默认值的元素。


我自己基于这个思路实现了一版,暂且就称为 lookup_v2 吧,但是虽然这样不用为所有的 keys 做 hash,但也增加了几个操作,所以性能还不能保证。后续会考虑直接优化 c++层面的代码。下面是核心代码:


 def lookup_v2(self, keys, name=None):
values = keys if self._num_oov_buckets == 0: ids = self._table.lookup(values, name=name) else: with ops.name_scope(name, "%s_Lookup" % self.name) as scope: str_to_hash_bucket = self._get_string_to_hash_bucket_fn( self._hasher_spec) if self._table: init_ids = self._table.lookup(values) is_id_default_idx = array_ops.where(math_ops.equal(init_ids, self._table.default_value)) hash_values = array_ops.gather(values, array_ops.squeeze(is_id_default_idx)) default_buckets = str_to_hash_bucket( _as_string(hash_values), num_buckets=self._num_oov_buckets, name="hash_bucket") default_buckets = math_ops.add(default_buckets, self._table.size()) default_buckets = control_flow_ops.cond(gen_math_ops.equal(array_ops.size(default_buckets), 1), lambda: array_ops.expand_dims(default_buckets, axis=0), lambda: default_buckets) ids = gen_array_ops.tensor_scatter_update(init_ids, is_id_default_idx, default_buckets) else: ids = str_to_hash_bucket(_as_string(values), num_buckets=self._num_oov_buckets, name="hash_bucket") return ids
复制代码


参考文献:


https://mathmach.com/2442aa9e/


https://github.com/tensorflow/tensorflow/blob/5b900cfe4b3b848f577315a0dde09a729f770e95/tensorflow/python/ops/lookup_ops.py#L1042


本文转载自 Alex-zhai 知乎账号。


原文链接:https://zhuanlan.zhihu.com/p/93116229


2019-11-28 08:003151

评论

发布
暂无评论
发现更多内容

2020年亚洲南京大数据产业展览会

南京专业智博会

展览会 论坛会 博览会 智博会

PageHelper

BitSea

Wireshark的使用与数据分析(一)

姬翔

原创 | 使用JUnit、AssertJ和Mockito编写单元测试和实践TDD (十二)编写测试-简单测试

编程道与术

Java 编程 TDD 单元测试 JUnit

用大数据分析了一线城市1000多份岗位招聘需求,告诉你如何科学找工作

程序员柠檬

Python 数据分析

2020亚洲智能家居全屋智能展会-南京站

南京专业智博会

展览会 论坛会 博览会 智博会

分布式锁🔒是个啥❓ 其实就这么点事

山中兰花草

Java redis 后端 分布式锁 开发

你真的清楚 Nginx 指令的规则吗?

子杨

nginx 运维

2020亚洲南京第十三届人工智能机器人服务展览会

南京专业智博会

展览会 论坛会 博览会 智博会

面试造飞机系列:说下微服务接口设计原则?

程序员柠檬

分布式 微服务 后端 架构设计

Linux下程序出问题不要慌,资深程序员教你6招搞定!

程序员柠檬

Linux 程序员 后台开发

原创 面试官:你说对MySQL事务很熟?那我问你10个问题

程序员柠檬

MySQL 数据库

Java 底层基础笔记(一)硬件

奈何花开

Java Linux 计算机基础

游戏夜读 | 记忆里的老游戏

game1night

Markdown 几行字符就可以生成思维导图了!

JackTian

markdown 思维导图 markdown语法 markdown编辑器 Markmap

2020年南京第十三届智慧停车展会

南京专业智博会

展览会 博览会 智博会 展览会论坛会

如何优雅地实现泛型类的类型参数化

KAMI

Java 编程 反射 泛型

10分钟白嫖我的常用的在线工具网站清单

JavaGuide

设计 在线工具 工具类网站 PDF

使用 PCA 进行降维可视化,了解特征分布

黄大路

数据挖掘 数据分析 可视化

github看代码效率提高10倍!因为用了sourcegraph这个工具

程序员柠檬

GitHub 程序员 效率工具

k6新崛起的性能测试工具

风中之心

DevOps 性能 性能测试

这可能是 Markdown 写微信公众号的一款神器了!

JackTian

效率工具 markdown markdown编辑器 markdownnice 神器

ARTS打卡-01

Geek_yansheng25

ARTS 打卡计划

2020年南京第十三届物联网应用展览会

南京专业智博会

展览会 论坛会 博览会 智博会

2020南京第十三届智慧工地装备展览会

南京专业智博会

展览会 论坛会 博览会 智博会

Jupyter最佳实践

pydata

ARTS-week-1

saddamwilson

ARTS 打卡计划

不忘初心,继续努力

一周思进

ARTS 打卡计划

如何衡量产品需求效果

黄大路

产品经理 产品设计 运营

思维模型 - 概念篇

石云升

学习 高效 思维模型 决策

推荐几款有意思的小众App(05.30)

静陌

产品 App

分布式tensorflow源码解读3:lookup.index_table_from_tensor_文化 & 方法_Alex-zhai_InfoQ精选文章