NVIDIA 初创加速计划,免费加速您的创业启动 了解详情
写点什么

分布式 tensorflow 源码解读 3:lookup.index_table_from_tensor

  • 2019-11-28
  • 本文字数:9943 字

    阅读完需:约 33 分钟

分布式tensorflow源码解读3:lookup.index_table_from_tensor

背景

推荐排序 dnn 模型中经常会用一些特别稀疏的 id 特征,此时需要对这些 id 特征做 embedding 操作,一般在 tensorflow 中都会使用一个 shape=[id_index_size, embedding_size]的 Variable 矩阵做 embedding 参数,然后根据 id 特征的 index 去 Variable 矩阵中查表得到相应的 embedding 表示。这里需要注意的是:id_index_size 的大小一般都不会等于对应 id table 的元素个数,因为有很多 id 元素不在原始的 id table 表中,比如新上架的一些商品等。此时需要将 id_index_size 设置的大一些,以留一些位置给那些不在 id table 表的元素使用。那么从原始的 id 特征值映射成[0, id_index_size)间的 index 是怎么做的呢?


如果 id 量非常小的话,可在特征提取后把 id 排序一遍生成从 0 开始的连续 id 值,但在工业界的场景下 id 特征的量级往往是百万到上亿级别,很难做排序。幸好 tensorFlow 内部有一个函数可以将原始的 id 特征值映射到从 0 开始的 index,这个函数就是 lookup.index_table_from_tensor。

例子

首先举一个 index_table_from_tensor 的使用例子:


import tensorflow as tf
sess = tf.Session()vocabulary_list = tf.constant(["emerson", "lake", "palmer"])table = tf.contrib.lookup.index_table_from_tensor( vocabulary_list=vocabulary_list, num_oov_buckets=10, default_value=-1)features = tf.constant(["emerson", "lake", "and", "palmer", "dad", "mom", "hello"])table.init.run(session=sess)ids = table.lookup(features)print(sess.run(ids))
返回值:[ 0 1 10 2 9 3 9]
复制代码


index_table_from_tensor 函数的作用是:如果在 vocabulary_list 里面的 id 特征值,则映射为从 0 到 len(table)-1 之间的 index,如果不在 vocabulary_list 里面,则通过 hash 函数映射到 len(table)至 len(table) + num_oov_buckets - 1 之间的区间中。


我们开始分析源代码,首先看下 index_table_from_tensor 的源码:


def index_table_from_tensor(vocabulary_list,                            num_oov_buckets=0,                            default_value=-1,                            hasher_spec=FastHashSpec,                            dtype=dtypes.string,                            name=None):  """Returns a lookup table that converts a string tensor into int64 IDs.  This operation constructs a lookup table to convert tensor of strings into  int64 IDs. The mapping can be initialized from a string `vocabulary_list` 1-D  tensor where each element is a key and corresponding index within the tensor  is the value.
Args: vocabulary_list: A 1-D `Tensor` that specifies the mapping of keys to indices. The type of this object must be castable to `dtype`. num_oov_buckets: The number of out-of-vocabulary buckets. default_value: The value to use for out-of-vocabulary feature values. Defaults to -1. hasher_spec: A `HasherSpec` to specify the hash function to use for assignment of out-of-vocabulary buckets. dtype: The type of values passed to `lookup`. Only string and integers are supported. name: A name for this op (optional). Returns: The lookup table to map an input `Tensor` to index `int64` `Tensor`. Raises: ValueError: If `vocabulary_list` is invalid. ValueError: If `num_oov_buckets` is negative. """
“”“ 类型检查代码 ”“”
with ops.name_scope(name, "string_to_index") as feat_to_id_scope: # 获取由传入的vocabulary_list构造的hash表,keys是vocabulary_list里面的元素,values是 # 【0, len(num_elements)-1】 keys = ops.convert_to_tensor(vocabulary_list) num_elements = array_ops.size(keys) values = math_ops.to_int64(math_ops.range(num_elements))
shared_name = "" with ops.name_scope(None, "hash_table") as hash_table_scope: table_keys = math_ops.to_int64(keys) if keys.dtype.is_integer else keys init = KeyValueTensorInitializer( table_keys, values, table_keys.dtype.base_dtype, dtypes.int64, name="table_init") table = HashTable( init, default_value, shared_name=shared_name, name=hash_table_scope) if num_oov_buckets: table = IdTableWithHashBuckets( table, num_oov_buckets=num_oov_buckets, hasher_spec=hasher_spec, name=feat_to_id_scope, key_dtype=dtype) return table
复制代码


从上面的代码片段可以看出,index_table_from_tensor 函数体中有三个比较重要的步骤:


根据 keys 和 values 创建 KeyValueTensorInitializer。


根据 KeyValueTensorInitializer 创建具体的 HashTable 类,对于不在 vocabulary_list 表里面的元素,统一返回默认值。


如果 num_oov_buckets>0,则会创建带有 Buckets 的 IdTableWithHashBuckets 类,用一个 hash 函数返回那些不在 vocabulary_list 表里面的 id 元素的 index 值。


首先我们来分析第一步的操作,先通过传入的 vocabulary_list 去构造 keys 和 values,然后传入到 KeyValueTensorInitializer 类中返回 table 初始化的 op,注意此时只是返回一个用于 table 初始化的 OP,还没有做具体的初始化工作(初始化要放到具体的 HashTable 中去完成)。下面看下 KeyValueTensorInitializer 的代码:


class KeyValueTensorInitializer(TableInitializerBase):  """Table initializers given `keys` and `values` tensors."""
def __init__(self, keys, values, key_dtype=None, value_dtype=None, name=None): with ops.name_scope(name, "key_value_init", [keys, values]) as scope: self._keys = ops.convert_to_tensor(keys, dtype=key_dtype, name="keys") self._values = ops.convert_to_tensor( values, dtype=value_dtype, name="values") self._name = scope
super(KeyValueTensorInitializer, self).__init__(self._keys.dtype, self._values.dtype)
def initialize(self, table): """Initializes the given `table` with `keys` and `values` tensors. Args: table: The table to initialize. Returns: The operation that initializes the table. Raises: TypeError: when the keys and values data types do not match the table key and value data types. """ _check_table_dtypes(table, self._keys.dtype, self._values.dtype) with ops.name_scope( self._name, values=(table.table_ref, self._keys, self._values)) as scope: init_op = gen_lookup_ops.lookup_table_import_v2( table.table_ref, self._keys, self._values, name=scope) ops.add_to_collection(ops.GraphKeys.TABLE_INITIALIZERS, init_op) return init_op可看出table的初始化操作是通过lookup_table_import_v2这个op来完成,下面就分析下这个op的相关c代码:
REGISTER_OP("LookupTableImportV2") .Input("table_handle: resource") .Input("keys: Tin") .Input("values: Tout") .Attr("Tin: type") .Attr("Tout: type") .SetShapeFn([](InferenceContext* c) { ShapeHandle handle; TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &handle));
ShapeHandle keys; TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &keys)); TF_RETURN_IF_ERROR(c->Merge(keys, c->input(2), &keys)); return Status::OK(); });
// Clear the table and insert data.class LookupTableImportOp : public OpKernel { public: explicit LookupTableImportOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
void Compute(OpKernelContext* ctx) override { lookup::LookupInterface* table; OP_REQUIRES_OK(ctx, GetLookupTable("table_handle", ctx, &table)); “”“ .... ”“”
const Tensor& keys = ctx->input(1); const Tensor& values = ctx->input(2); OP_REQUIRES_OK(ctx, table->CheckKeyAndValueTensorsForImport(keys, values));
int memory_used_before = 0; if (ctx->track_allocations()) { memory_used_before = table->MemoryUsed(); } OP_REQUIRES_OK(ctx, table->ImportValues(ctx, keys, values)); if (ctx->track_allocations()) { ctx->record_persistent_memory_allocation(table->MemoryUsed() - memory_used_before); } }};
复制代码


这里通过 table_handle 句柄来传入具体的 Table 类,上面举的例子中使用的是 HashTable 类,HashTable 类继承于 InitializableLookupTable 类,InitializableLookupTable 可理解为那些需要初始化操作的 table 集合,而 InitializableLookupTable 又继承于 LookupInterface 类。通过 table->ImportValues()方法来将 vocabulary_list 中构造的 keys 和 values 导入到 hash 表中,完成 hash 表的初始化工作。


通过 LookupTableImportV2 返回 init_op 之后,将该初始化 hash 表的 op 传给 HashTable 类,HashTable 类和其父类的代码如下:


class HashTable(InitializableLookupTableBase):
def __init__(self, initializer, default_value, shared_name=None, name=None): """Creates a non-initialized `HashTable` object. Creates a table, the type of its keys and values are specified by the initializer. Before using the table you will have to initialize it. After initialization the table will be immutable. Args: initializer: The table initializer to use. See `HashTable` kernel for supported key and value types. default_value: The value to use if a key is missing in the table. shared_name: If non-empty, this table will be shared under the given name across multiple sessions. name: A name for the operation (optional). Returns: A `HashTable` object. """ with ops.name_scope(name, "hash_table", (initializer, default_value)) as scope: table_ref = gen_lookup_ops.hash_table_v2( shared_name=shared_name, key_dtype=initializer.key_dtype, value_dtype=initializer.value_dtype, name=scope)
super(HashTable, self).__init__(table_ref, default_value, initializer) self._value_shape = self._default_value.get_shape()

class InitializableLookupTableBase(LookupInterface): """Initializable lookup table interface. An initializable lookup tables persist across different steps. """
def __init__(self, table_ref, default_value, initializer): """Construct a table object from a table reference. If requires a table initializer object (subclass of `TableInitializerBase`). It provides the table key and value types, as well as the op to initialize the table. The caller is responsible to execute the initialization op. Args: table_ref: The table reference, i.e. the output of the lookup table ops. default_value: The value to use if a key is missing in the table. initializer: The table initializer to use. """ name = table_ref.op.name.split("/")[-1] super(InitializableLookupTableBase, self).__init__(initializer.key_dtype, initializer.value_dtype, name) self._table_ref = table_ref self._default_value = ops.convert_to_tensor( default_value, dtype=self._value_dtype) self._default_value.get_shape().merge_with(tensor_shape.scalar()) self._init = initializer.initialize(self)
@property def table_ref(self): """Get the underlying table reference.""" return self._table_ref
@property def default_value(self): """The default value of the table.""" return self._default_value
@property def init(self): """The table initialization op.""" return self._init
def lookup(self, keys, name=None): """Looks up `keys` in a table, outputs the corresponding values. The `default_value` is used for keys not present in the table. """ key_tensor = keys with ops.name_scope(name, "%s_Lookup" % self._name, (self._table_ref, key_tensor, self._default_value)) as scope: values = gen_lookup_ops.lookup_table_find_v2( self._table_ref, key_tensor, self._default_value, name=scope)
values.set_shape(key_tensor.get_shape()) if isinstance(keys, sparse_tensor.SparseTensor): return sparse_tensor.SparseTensor(keys.indices, values, keys.dense_shape) else: return values
复制代码


InitializableLookupTableBase 类继承于 HashTable,InitializableLookupTableBase 类中有 init 操作和 look_up,所以表的初始化和 look_up 操作其实都是在父类中完成,其中表的初始化操作上面已经讨论过,我们就重点看下 look_up 操作的相关 OP:



lookup_table_find_v2
// Table lookup op. Perform the lookup operation on the given table.class LookupTableFindOp : public OpKernel { public: explicit LookupTableFindOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
void Compute(OpKernelContext* ctx) override { lookup::LookupInterface* table; OP_REQUIRES_OK(ctx, GetLookupTable("table_handle", ctx, &table)); “”“ ...... ”“”
const Tensor& key = ctx->input(1); const Tensor& default_value = ctx->input(2); OP_REQUIRES_OK(ctx, table->CheckFindArguments(key, default_value));
TensorShape output_shape = key.shape(); output_shape.RemoveLastDims(table->key_shape().dims()); output_shape.AppendShape(table->value_shape()); Tensor* out; OP_REQUIRES_OK(ctx, ctx->allocate_output("values", output_shape, &out));
OP_REQUIRES_OK(ctx, table->Find(ctx, key, out, default_value)); }};
REGISTER_KERNEL_BUILDER(Name("LookupTableFind").Device(DEVICE_CPU), LookupTableFindOp);REGISTER_KERNEL_BUILDER(Name("LookupTableFindV2").Device(DEVICE_CPU), LookupTableFindOp);
# table->Find调用的是HashTable类的DoFind方法Status DoFind(const Tensor& key, Tensor* value, const Tensor& default_value) override { const V default_val = default_value.flat<V>()(0); const auto key_values = key.flat<K>(); auto value_values = value->flat<V>();
for (int64 i = 0; i < key_values.size(); ++i) { value_values(i) = gtl::FindWithDefault( *table_, SubtleMustCopyIfIntegral(key_values(i)), default_val); } return Status::OK(); }

复制代码


最后我们分析下 IdTableWithHashBuckets 类,该类可用一个 hash 函数返回那些不在 vocabulary_list 表里 id 元素的 index 值,先举个例子:


import tensorflow as tfnum_oov_buckets = 3input_tensor = tf.constant(["emerson", "lake", "palmer", "king", "crimnson"])table = tf.IdTableWithHashBuckets(      tf.HashTable(tf.TextFileIdTableInitializer(filename), default_value),      num_oov_buckets)out = table.lookup(input_tensor).table.init.run()print(out.eval())
#filename依次为"emerson", "lake", "palmer"
复制代码


#结果为[0, 1, 2, 4, 7]


下面看下 IdTableWithHashBuckets 的源码:


class IdTableWithHashBuckets(LookupInterface):  """String to Id table wrapper that assigns out-of-vocabulary keys to buckets.  For example, if an instance of `IdTableWithHashBuckets` is initialized with a  string-to-id table that maps:  * `emerson -> 0`  * `lake -> 1`  * `palmer -> 2`  The `IdTableWithHashBuckets` object will performs the following mapping:  * `emerson -> 0`  * `lake -> 1`  * `palmer -> 2`  * `<other term> -> bucket_id`, where bucket_id will be between `3` and  `3 + num_oov_buckets - 1`, calculated by:  `hash(<term>) % num_oov_buckets + vocab_size`  If input_tensor is `["emerson", "lake", "palmer", "king", "crimson"]`,  the lookup result is `[0, 1, 2, 4, 7]`.  If `table` is None, only out-of-vocabulary buckets are used.  """
def __init__(self, table, num_oov_buckets, hasher_spec=FastHashSpec, name=None, key_dtype=None): """Construct a `IdTableWithHashBuckets` object. Args: table: Table that maps `tf.string` or `tf.int64` keys to `tf.int64` ids. num_oov_buckets: Number of buckets to use for out-of-vocabulary keys. hasher_spec: A `HasherSpec` to specify the hash function to use for assignation of out-of-vocabulary buckets (optional). name: A name for the operation (optional). key_dtype: Data type of keys passed to `lookup`. Defaults to `table.key_dtype` if `table` is specified, otherwise `tf.string`. Must be string or integer, and must be castable to `table.key_dtype`. Raises: ValueError: when `table` in None and `num_oov_buckets` is not positive. TypeError: when `hasher_spec` is invalid. """ # If a name ends with a '/' it is a "name scope", remove all trailing '/' # characters to use as table name. “”“ 类型检查 “”“
self._num_oov_buckets = num_oov_buckets super(IdTableWithHashBuckets, self).__init__(key_dtype, dtypes.int64, name.split("/")[-1])
@property def init(self): """The table initialization op.""" if self._table: return self._table.init with ops.name_scope(None, "init"): return control_flow_ops.no_op()
@property def table_ref(self): if self._table is not None: return self._table.table_ref return None
def _get_string_to_hash_bucket_fn(self, hasher_spec): """Returns the string_to_hash_bucket op to use based on `hasher_spec`.""" if not isinstance(hasher_spec, HasherSpec): raise TypeError("hasher_spec must be of type HasherSpec %s" % hasher_spec) if hasher_spec.hasher == "fasthash": return string_ops.string_to_hash_bucket_fast if hasher_spec.hasher == "legacy": return string_ops.string_to_hash_bucket if hasher_spec.hasher == "stronghash": return functools.partial( string_ops.string_to_hash_bucket_strong, key=hasher_spec.key) raise ValueError("Unknown hasher %s" % hasher_spec.hasher)
def lookup(self, keys, name=None):
values = keys if self._num_oov_buckets == 0: ids = self._table.lookup(values, name=name) else: # TODO(yleon): Consider moving this functionality to its own kernel. with ops.name_scope(name, "%s_Lookup" % self.name) as scope: str_to_hash_bucket = self._get_string_to_hash_bucket_fn( self._hasher_spec) buckets = str_to_hash_bucket( _as_string(values), num_buckets=self._num_oov_buckets, name="hash_bucket") if self._table: ids = self._table.lookup(values) buckets = math_ops.add(buckets, self._table.size()) is_id_non_default = math_ops.not_equal(ids, self._table.default_value) ids = array_ops.where(is_id_non_default, ids, buckets, name=scope) else: ids = buckets return ids
复制代码


上面的 lookup 操作为核心代码,大致流程为:先将当前待查找的所有 keys 通过 str_to_hash_bucket 这个 hash 函数映射成 int 类型的 bucket 号,然后为了防止和已经在表中的那些 index 冲突,需要统一加上 table.size()。然后在 table 中查找所有 keys 的 index,如果在 id table 表中会返回正确的 index,如果不在就返回默认值。然后将那些查找后等于默认值的 keys 的 index 换成通过 hash 函数得到的 bucket 号。这样存在的性能问题是当 id table 表较大时,会有很多时间浪费在没必要的 hash 操作上。所以一种优化方案是:先查表得到所有 keys 的 index,然后只 hash 那些映射值等于默认值的元素。


我自己基于这个思路实现了一版,暂且就称为 lookup_v2 吧,但是虽然这样不用为所有的 keys 做 hash,但也增加了几个操作,所以性能还不能保证。后续会考虑直接优化 c++层面的代码。下面是核心代码:


 def lookup_v2(self, keys, name=None):
values = keys if self._num_oov_buckets == 0: ids = self._table.lookup(values, name=name) else: with ops.name_scope(name, "%s_Lookup" % self.name) as scope: str_to_hash_bucket = self._get_string_to_hash_bucket_fn( self._hasher_spec) if self._table: init_ids = self._table.lookup(values) is_id_default_idx = array_ops.where(math_ops.equal(init_ids, self._table.default_value)) hash_values = array_ops.gather(values, array_ops.squeeze(is_id_default_idx)) default_buckets = str_to_hash_bucket( _as_string(hash_values), num_buckets=self._num_oov_buckets, name="hash_bucket") default_buckets = math_ops.add(default_buckets, self._table.size()) default_buckets = control_flow_ops.cond(gen_math_ops.equal(array_ops.size(default_buckets), 1), lambda: array_ops.expand_dims(default_buckets, axis=0), lambda: default_buckets) ids = gen_array_ops.tensor_scatter_update(init_ids, is_id_default_idx, default_buckets) else: ids = str_to_hash_bucket(_as_string(values), num_buckets=self._num_oov_buckets, name="hash_bucket") return ids
复制代码


参考文献:


https://mathmach.com/2442aa9e/


https://github.com/tensorflow/tensorflow/blob/5b900cfe4b3b848f577315a0dde09a729f770e95/tensorflow/python/ops/lookup_ops.py#L1042


本文转载自 Alex-zhai 知乎账号。


原文链接:https://zhuanlan.zhihu.com/p/93116229


2019-11-28 08:003149

评论

发布
暂无评论
发现更多内容

synchronized源码分析之锁的膨胀

Ayue、

synchronized 锁机制 锁升级

React进阶(十二):HOOK

No Silver Bullet

React Hooks 12月日更

React进阶(十一):create-react-app脚手架关闭 eslint 提醒

No Silver Bullet

React 12月日更 creat-react-app

Istio的认证授权机制分析

xcbeyond

istio 认证授权 28天写作 12月日更

[Pulsar] Consumer消费

Zike Yang

Apache Pulsar 12月日更

Android EventBus 集成问题小结

阿策小和尚

28天写作 Android 小菜鸟 12月日更

dart系列之:手写Library,Library编写最佳实践

程序那些事

flutter dart 程序那些事 12月日更

元宇宙地产:品牌和投资者的大好机会?

devpoint

以太坊 NFT 元宇宙 12月日更

解决:Command ‘mongo‘ not found, but can be installed with

liuzhen007

28天写作 12月日更

【架构实战营】模块七

衣谷

架构实战营

元宇宙100讲-0x007

hackstoic

元宇宙

50 K8S之Contour控制器

穿过生命散发芬芳

k8s 28天写作 12月日更

如何实现Redis限流

喵叔

28天写作 12月日更

NGINX从入门到实践-基础篇

小志Codings

nginx Python3

Kvrocks 在 RocksDB 上的优化实践

Kvrocks

Redis 协议

怎么活的超脱:把自己的生活看成一场戏

mtfelix

28天写作

DDD领域驱动设计实战(三)-深入理解实体

JavaEdge

12月日更

语音信号处理 4:语音信号在时域和频域的表示

轻口味

28天写作 12月日更

Flutter开发小技巧【Flutter专题23】

坚果

flutter 28天写作 12月日更

念叨了一年的游戏叙事书中文版终于出了!

博文视点Broadview

vivo:不做开发者的过客,变成IoT的归人

脑极体

【CSS 学习总结】目录 - CSS 知识点梳理

Brave

CSS 12月日更

Python 自动化领域起点篇,Selenium WebDriver 学习第1篇

梦想橡皮擦

12月日更

花一点时间优化一次年迈的后台系统的检索体验

为自己带盐

28天写作 12月日更 ​jQuery

什么是VLAN?如何配置?VLAN间路由又是怎样的?一文了解!

Ethereal

VLAN 网络技术

开始了解DevSecOps 1

搬砖的周狮傅

DevSecOps

如何用Docker Compose部署项目?

秦怀杂货店

Docker springboot

面试官:HashSet如何保证元素不重复?

王磊

【架构实战营】模块八

衣谷

架构实战营

【LeetCode】一年中的第几天Java题解

Albert

算法 LeetCode 12月日更

20《重学JAVA》--集合(二)

杨鹏Geek

Java25周年 28天写作 12月日更

分布式tensorflow源码解读3:lookup.index_table_from_tensor_文化 & 方法_Alex-zhai_InfoQ精选文章