最新发布《数智时代的AI人才粮仓模型解读白皮书(2024版)》,立即领取! 了解详情
写点什么

分布式 tensorflow 源码解读 3:lookup.index_table_from_tensor

  • 2019-11-28
  • 本文字数:9943 字

    阅读完需:约 33 分钟

分布式tensorflow源码解读3:lookup.index_table_from_tensor

背景

推荐排序 dnn 模型中经常会用一些特别稀疏的 id 特征,此时需要对这些 id 特征做 embedding 操作,一般在 tensorflow 中都会使用一个 shape=[id_index_size, embedding_size]的 Variable 矩阵做 embedding 参数,然后根据 id 特征的 index 去 Variable 矩阵中查表得到相应的 embedding 表示。这里需要注意的是:id_index_size 的大小一般都不会等于对应 id table 的元素个数,因为有很多 id 元素不在原始的 id table 表中,比如新上架的一些商品等。此时需要将 id_index_size 设置的大一些,以留一些位置给那些不在 id table 表的元素使用。那么从原始的 id 特征值映射成[0, id_index_size)间的 index 是怎么做的呢?


如果 id 量非常小的话,可在特征提取后把 id 排序一遍生成从 0 开始的连续 id 值,但在工业界的场景下 id 特征的量级往往是百万到上亿级别,很难做排序。幸好 tensorFlow 内部有一个函数可以将原始的 id 特征值映射到从 0 开始的 index,这个函数就是 lookup.index_table_from_tensor。

例子

首先举一个 index_table_from_tensor 的使用例子:


import tensorflow as tf
sess = tf.Session()vocabulary_list = tf.constant(["emerson", "lake", "palmer"])table = tf.contrib.lookup.index_table_from_tensor( vocabulary_list=vocabulary_list, num_oov_buckets=10, default_value=-1)features = tf.constant(["emerson", "lake", "and", "palmer", "dad", "mom", "hello"])table.init.run(session=sess)ids = table.lookup(features)print(sess.run(ids))
返回值:[ 0 1 10 2 9 3 9]
复制代码


index_table_from_tensor 函数的作用是:如果在 vocabulary_list 里面的 id 特征值,则映射为从 0 到 len(table)-1 之间的 index,如果不在 vocabulary_list 里面,则通过 hash 函数映射到 len(table)至 len(table) + num_oov_buckets - 1 之间的区间中。


我们开始分析源代码,首先看下 index_table_from_tensor 的源码:


def index_table_from_tensor(vocabulary_list,                            num_oov_buckets=0,                            default_value=-1,                            hasher_spec=FastHashSpec,                            dtype=dtypes.string,                            name=None):  """Returns a lookup table that converts a string tensor into int64 IDs.  This operation constructs a lookup table to convert tensor of strings into  int64 IDs. The mapping can be initialized from a string `vocabulary_list` 1-D  tensor where each element is a key and corresponding index within the tensor  is the value.
Args: vocabulary_list: A 1-D `Tensor` that specifies the mapping of keys to indices. The type of this object must be castable to `dtype`. num_oov_buckets: The number of out-of-vocabulary buckets. default_value: The value to use for out-of-vocabulary feature values. Defaults to -1. hasher_spec: A `HasherSpec` to specify the hash function to use for assignment of out-of-vocabulary buckets. dtype: The type of values passed to `lookup`. Only string and integers are supported. name: A name for this op (optional). Returns: The lookup table to map an input `Tensor` to index `int64` `Tensor`. Raises: ValueError: If `vocabulary_list` is invalid. ValueError: If `num_oov_buckets` is negative. """
“”“ 类型检查代码 ”“”
with ops.name_scope(name, "string_to_index") as feat_to_id_scope: # 获取由传入的vocabulary_list构造的hash表,keys是vocabulary_list里面的元素,values是 # 【0, len(num_elements)-1】 keys = ops.convert_to_tensor(vocabulary_list) num_elements = array_ops.size(keys) values = math_ops.to_int64(math_ops.range(num_elements))
shared_name = "" with ops.name_scope(None, "hash_table") as hash_table_scope: table_keys = math_ops.to_int64(keys) if keys.dtype.is_integer else keys init = KeyValueTensorInitializer( table_keys, values, table_keys.dtype.base_dtype, dtypes.int64, name="table_init") table = HashTable( init, default_value, shared_name=shared_name, name=hash_table_scope) if num_oov_buckets: table = IdTableWithHashBuckets( table, num_oov_buckets=num_oov_buckets, hasher_spec=hasher_spec, name=feat_to_id_scope, key_dtype=dtype) return table
复制代码


从上面的代码片段可以看出,index_table_from_tensor 函数体中有三个比较重要的步骤:


根据 keys 和 values 创建 KeyValueTensorInitializer。


根据 KeyValueTensorInitializer 创建具体的 HashTable 类,对于不在 vocabulary_list 表里面的元素,统一返回默认值。


如果 num_oov_buckets>0,则会创建带有 Buckets 的 IdTableWithHashBuckets 类,用一个 hash 函数返回那些不在 vocabulary_list 表里面的 id 元素的 index 值。


首先我们来分析第一步的操作,先通过传入的 vocabulary_list 去构造 keys 和 values,然后传入到 KeyValueTensorInitializer 类中返回 table 初始化的 op,注意此时只是返回一个用于 table 初始化的 OP,还没有做具体的初始化工作(初始化要放到具体的 HashTable 中去完成)。下面看下 KeyValueTensorInitializer 的代码:


class KeyValueTensorInitializer(TableInitializerBase):  """Table initializers given `keys` and `values` tensors."""
def __init__(self, keys, values, key_dtype=None, value_dtype=None, name=None): with ops.name_scope(name, "key_value_init", [keys, values]) as scope: self._keys = ops.convert_to_tensor(keys, dtype=key_dtype, name="keys") self._values = ops.convert_to_tensor( values, dtype=value_dtype, name="values") self._name = scope
super(KeyValueTensorInitializer, self).__init__(self._keys.dtype, self._values.dtype)
def initialize(self, table): """Initializes the given `table` with `keys` and `values` tensors. Args: table: The table to initialize. Returns: The operation that initializes the table. Raises: TypeError: when the keys and values data types do not match the table key and value data types. """ _check_table_dtypes(table, self._keys.dtype, self._values.dtype) with ops.name_scope( self._name, values=(table.table_ref, self._keys, self._values)) as scope: init_op = gen_lookup_ops.lookup_table_import_v2( table.table_ref, self._keys, self._values, name=scope) ops.add_to_collection(ops.GraphKeys.TABLE_INITIALIZERS, init_op) return init_op可看出table的初始化操作是通过lookup_table_import_v2这个op来完成,下面就分析下这个op的相关c代码:
REGISTER_OP("LookupTableImportV2") .Input("table_handle: resource") .Input("keys: Tin") .Input("values: Tout") .Attr("Tin: type") .Attr("Tout: type") .SetShapeFn([](InferenceContext* c) { ShapeHandle handle; TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &handle));
ShapeHandle keys; TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &keys)); TF_RETURN_IF_ERROR(c->Merge(keys, c->input(2), &keys)); return Status::OK(); });
// Clear the table and insert data.class LookupTableImportOp : public OpKernel { public: explicit LookupTableImportOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
void Compute(OpKernelContext* ctx) override { lookup::LookupInterface* table; OP_REQUIRES_OK(ctx, GetLookupTable("table_handle", ctx, &table)); “”“ .... ”“”
const Tensor& keys = ctx->input(1); const Tensor& values = ctx->input(2); OP_REQUIRES_OK(ctx, table->CheckKeyAndValueTensorsForImport(keys, values));
int memory_used_before = 0; if (ctx->track_allocations()) { memory_used_before = table->MemoryUsed(); } OP_REQUIRES_OK(ctx, table->ImportValues(ctx, keys, values)); if (ctx->track_allocations()) { ctx->record_persistent_memory_allocation(table->MemoryUsed() - memory_used_before); } }};
复制代码


这里通过 table_handle 句柄来传入具体的 Table 类,上面举的例子中使用的是 HashTable 类,HashTable 类继承于 InitializableLookupTable 类,InitializableLookupTable 可理解为那些需要初始化操作的 table 集合,而 InitializableLookupTable 又继承于 LookupInterface 类。通过 table->ImportValues()方法来将 vocabulary_list 中构造的 keys 和 values 导入到 hash 表中,完成 hash 表的初始化工作。


通过 LookupTableImportV2 返回 init_op 之后,将该初始化 hash 表的 op 传给 HashTable 类,HashTable 类和其父类的代码如下:


class HashTable(InitializableLookupTableBase):
def __init__(self, initializer, default_value, shared_name=None, name=None): """Creates a non-initialized `HashTable` object. Creates a table, the type of its keys and values are specified by the initializer. Before using the table you will have to initialize it. After initialization the table will be immutable. Args: initializer: The table initializer to use. See `HashTable` kernel for supported key and value types. default_value: The value to use if a key is missing in the table. shared_name: If non-empty, this table will be shared under the given name across multiple sessions. name: A name for the operation (optional). Returns: A `HashTable` object. """ with ops.name_scope(name, "hash_table", (initializer, default_value)) as scope: table_ref = gen_lookup_ops.hash_table_v2( shared_name=shared_name, key_dtype=initializer.key_dtype, value_dtype=initializer.value_dtype, name=scope)
super(HashTable, self).__init__(table_ref, default_value, initializer) self._value_shape = self._default_value.get_shape()

class InitializableLookupTableBase(LookupInterface): """Initializable lookup table interface. An initializable lookup tables persist across different steps. """
def __init__(self, table_ref, default_value, initializer): """Construct a table object from a table reference. If requires a table initializer object (subclass of `TableInitializerBase`). It provides the table key and value types, as well as the op to initialize the table. The caller is responsible to execute the initialization op. Args: table_ref: The table reference, i.e. the output of the lookup table ops. default_value: The value to use if a key is missing in the table. initializer: The table initializer to use. """ name = table_ref.op.name.split("/")[-1] super(InitializableLookupTableBase, self).__init__(initializer.key_dtype, initializer.value_dtype, name) self._table_ref = table_ref self._default_value = ops.convert_to_tensor( default_value, dtype=self._value_dtype) self._default_value.get_shape().merge_with(tensor_shape.scalar()) self._init = initializer.initialize(self)
@property def table_ref(self): """Get the underlying table reference.""" return self._table_ref
@property def default_value(self): """The default value of the table.""" return self._default_value
@property def init(self): """The table initialization op.""" return self._init
def lookup(self, keys, name=None): """Looks up `keys` in a table, outputs the corresponding values. The `default_value` is used for keys not present in the table. """ key_tensor = keys with ops.name_scope(name, "%s_Lookup" % self._name, (self._table_ref, key_tensor, self._default_value)) as scope: values = gen_lookup_ops.lookup_table_find_v2( self._table_ref, key_tensor, self._default_value, name=scope)
values.set_shape(key_tensor.get_shape()) if isinstance(keys, sparse_tensor.SparseTensor): return sparse_tensor.SparseTensor(keys.indices, values, keys.dense_shape) else: return values
复制代码


InitializableLookupTableBase 类继承于 HashTable,InitializableLookupTableBase 类中有 init 操作和 look_up,所以表的初始化和 look_up 操作其实都是在父类中完成,其中表的初始化操作上面已经讨论过,我们就重点看下 look_up 操作的相关 OP:



lookup_table_find_v2
// Table lookup op. Perform the lookup operation on the given table.class LookupTableFindOp : public OpKernel { public: explicit LookupTableFindOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
void Compute(OpKernelContext* ctx) override { lookup::LookupInterface* table; OP_REQUIRES_OK(ctx, GetLookupTable("table_handle", ctx, &table)); “”“ ...... ”“”
const Tensor& key = ctx->input(1); const Tensor& default_value = ctx->input(2); OP_REQUIRES_OK(ctx, table->CheckFindArguments(key, default_value));
TensorShape output_shape = key.shape(); output_shape.RemoveLastDims(table->key_shape().dims()); output_shape.AppendShape(table->value_shape()); Tensor* out; OP_REQUIRES_OK(ctx, ctx->allocate_output("values", output_shape, &out));
OP_REQUIRES_OK(ctx, table->Find(ctx, key, out, default_value)); }};
REGISTER_KERNEL_BUILDER(Name("LookupTableFind").Device(DEVICE_CPU), LookupTableFindOp);REGISTER_KERNEL_BUILDER(Name("LookupTableFindV2").Device(DEVICE_CPU), LookupTableFindOp);
# table->Find调用的是HashTable类的DoFind方法Status DoFind(const Tensor& key, Tensor* value, const Tensor& default_value) override { const V default_val = default_value.flat<V>()(0); const auto key_values = key.flat<K>(); auto value_values = value->flat<V>();
for (int64 i = 0; i < key_values.size(); ++i) { value_values(i) = gtl::FindWithDefault( *table_, SubtleMustCopyIfIntegral(key_values(i)), default_val); } return Status::OK(); }

复制代码


最后我们分析下 IdTableWithHashBuckets 类,该类可用一个 hash 函数返回那些不在 vocabulary_list 表里 id 元素的 index 值,先举个例子:


import tensorflow as tfnum_oov_buckets = 3input_tensor = tf.constant(["emerson", "lake", "palmer", "king", "crimnson"])table = tf.IdTableWithHashBuckets(      tf.HashTable(tf.TextFileIdTableInitializer(filename), default_value),      num_oov_buckets)out = table.lookup(input_tensor).table.init.run()print(out.eval())
#filename依次为"emerson", "lake", "palmer"
复制代码


#结果为[0, 1, 2, 4, 7]


下面看下 IdTableWithHashBuckets 的源码:


class IdTableWithHashBuckets(LookupInterface):  """String to Id table wrapper that assigns out-of-vocabulary keys to buckets.  For example, if an instance of `IdTableWithHashBuckets` is initialized with a  string-to-id table that maps:  * `emerson -> 0`  * `lake -> 1`  * `palmer -> 2`  The `IdTableWithHashBuckets` object will performs the following mapping:  * `emerson -> 0`  * `lake -> 1`  * `palmer -> 2`  * `<other term> -> bucket_id`, where bucket_id will be between `3` and  `3 + num_oov_buckets - 1`, calculated by:  `hash(<term>) % num_oov_buckets + vocab_size`  If input_tensor is `["emerson", "lake", "palmer", "king", "crimson"]`,  the lookup result is `[0, 1, 2, 4, 7]`.  If `table` is None, only out-of-vocabulary buckets are used.  """
def __init__(self, table, num_oov_buckets, hasher_spec=FastHashSpec, name=None, key_dtype=None): """Construct a `IdTableWithHashBuckets` object. Args: table: Table that maps `tf.string` or `tf.int64` keys to `tf.int64` ids. num_oov_buckets: Number of buckets to use for out-of-vocabulary keys. hasher_spec: A `HasherSpec` to specify the hash function to use for assignation of out-of-vocabulary buckets (optional). name: A name for the operation (optional). key_dtype: Data type of keys passed to `lookup`. Defaults to `table.key_dtype` if `table` is specified, otherwise `tf.string`. Must be string or integer, and must be castable to `table.key_dtype`. Raises: ValueError: when `table` in None and `num_oov_buckets` is not positive. TypeError: when `hasher_spec` is invalid. """ # If a name ends with a '/' it is a "name scope", remove all trailing '/' # characters to use as table name. “”“ 类型检查 “”“
self._num_oov_buckets = num_oov_buckets super(IdTableWithHashBuckets, self).__init__(key_dtype, dtypes.int64, name.split("/")[-1])
@property def init(self): """The table initialization op.""" if self._table: return self._table.init with ops.name_scope(None, "init"): return control_flow_ops.no_op()
@property def table_ref(self): if self._table is not None: return self._table.table_ref return None
def _get_string_to_hash_bucket_fn(self, hasher_spec): """Returns the string_to_hash_bucket op to use based on `hasher_spec`.""" if not isinstance(hasher_spec, HasherSpec): raise TypeError("hasher_spec must be of type HasherSpec %s" % hasher_spec) if hasher_spec.hasher == "fasthash": return string_ops.string_to_hash_bucket_fast if hasher_spec.hasher == "legacy": return string_ops.string_to_hash_bucket if hasher_spec.hasher == "stronghash": return functools.partial( string_ops.string_to_hash_bucket_strong, key=hasher_spec.key) raise ValueError("Unknown hasher %s" % hasher_spec.hasher)
def lookup(self, keys, name=None):
values = keys if self._num_oov_buckets == 0: ids = self._table.lookup(values, name=name) else: # TODO(yleon): Consider moving this functionality to its own kernel. with ops.name_scope(name, "%s_Lookup" % self.name) as scope: str_to_hash_bucket = self._get_string_to_hash_bucket_fn( self._hasher_spec) buckets = str_to_hash_bucket( _as_string(values), num_buckets=self._num_oov_buckets, name="hash_bucket") if self._table: ids = self._table.lookup(values) buckets = math_ops.add(buckets, self._table.size()) is_id_non_default = math_ops.not_equal(ids, self._table.default_value) ids = array_ops.where(is_id_non_default, ids, buckets, name=scope) else: ids = buckets return ids
复制代码


上面的 lookup 操作为核心代码,大致流程为:先将当前待查找的所有 keys 通过 str_to_hash_bucket 这个 hash 函数映射成 int 类型的 bucket 号,然后为了防止和已经在表中的那些 index 冲突,需要统一加上 table.size()。然后在 table 中查找所有 keys 的 index,如果在 id table 表中会返回正确的 index,如果不在就返回默认值。然后将那些查找后等于默认值的 keys 的 index 换成通过 hash 函数得到的 bucket 号。这样存在的性能问题是当 id table 表较大时,会有很多时间浪费在没必要的 hash 操作上。所以一种优化方案是:先查表得到所有 keys 的 index,然后只 hash 那些映射值等于默认值的元素。


我自己基于这个思路实现了一版,暂且就称为 lookup_v2 吧,但是虽然这样不用为所有的 keys 做 hash,但也增加了几个操作,所以性能还不能保证。后续会考虑直接优化 c++层面的代码。下面是核心代码:


 def lookup_v2(self, keys, name=None):
values = keys if self._num_oov_buckets == 0: ids = self._table.lookup(values, name=name) else: with ops.name_scope(name, "%s_Lookup" % self.name) as scope: str_to_hash_bucket = self._get_string_to_hash_bucket_fn( self._hasher_spec) if self._table: init_ids = self._table.lookup(values) is_id_default_idx = array_ops.where(math_ops.equal(init_ids, self._table.default_value)) hash_values = array_ops.gather(values, array_ops.squeeze(is_id_default_idx)) default_buckets = str_to_hash_bucket( _as_string(hash_values), num_buckets=self._num_oov_buckets, name="hash_bucket") default_buckets = math_ops.add(default_buckets, self._table.size()) default_buckets = control_flow_ops.cond(gen_math_ops.equal(array_ops.size(default_buckets), 1), lambda: array_ops.expand_dims(default_buckets, axis=0), lambda: default_buckets) ids = gen_array_ops.tensor_scatter_update(init_ids, is_id_default_idx, default_buckets) else: ids = str_to_hash_bucket(_as_string(values), num_buckets=self._num_oov_buckets, name="hash_bucket") return ids
复制代码


参考文献:


https://mathmach.com/2442aa9e/


https://github.com/tensorflow/tensorflow/blob/5b900cfe4b3b848f577315a0dde09a729f770e95/tensorflow/python/ops/lookup_ops.py#L1042


本文转载自 Alex-zhai 知乎账号。


原文链接:https://zhuanlan.zhihu.com/p/93116229


2019-11-28 08:003141

评论

发布
暂无评论
发现更多内容

【Week02】框架设计

Aldaron

第二周总结

qqq

极客大学架构师训练营

面向开发者的 WSL2 安装指南

simpleapples

Python Windows 10 wsl Go 语言

【架构】—回归本质(面向对象)

不二架构

面向对象 架构师 极客大学架构师训练营

第二周 作业

尔东雨田

【大厂面试06期】谈一谈你对Redis持久化的理解?

NotFound9

数据库 redis 后端

第二周总结

晨光

Mybatis-plus 之 DIP

无心水

极客大学架构师训练营

看清远处模糊的事,不如做好身边清楚的事

Neco.W

创业心态 未知

当你启动Redis的时候,Redis做了什么

老胡爱分享

redis 源码分析 面试

28岁程序员期权过亿,彪悍从字节退休,网友:酸了酸了!

程序员生活志

程序员 字节跳动 开发 退休

小师妹学JVM之:JDK14中JVM的性能优化

程序那些事

JVM 小师妹 JIT JDK14 签约计划第二季

课程总结

GAC·DU

20年行业变革与技术演进,当下CDN如何为政企数字化转型加速?

阿里云Edge Plus

CDN 边缘计算 移动视频

第二周作业

晨光

架构训练营第二章作业

mh

依赖倒置及 Cache 重构设计

秤须苑

极客大学架构师训练营

第二周作业

Aldaron

数仓系列 | 深入解读 Flink 资源管理机制

Apache Flink

大数据 flink 流计算 实时计算

从字符串到常量池,一文看懂String类设计

程序员DMZ

JVM 常量池 intern

如何构建低延时的直播体验,让互动更实时?

阿里云Edge Plus

CDN 短视频 直播 视频

免费下载 | 阿里云实时计算整体解决方案白皮书重磅发布!

Apache Flink

大数据 flink 流计算 实时计算

设计原则之依赖倒置和接口隔离

dapaul

极客大学架构师训练营 框架设计、设计原则、设计模式 第四课 听课总结

John(易筋)

极客时间 极客大学 极客大学架构师训练营 设计原则 框架设计

依赖倒置架构

GAC·DU

架构师训练营 0 期第二周

Blink

红警1游戏开源,代码非常规范。网友:秀色可餐

程序员生活志

开源 红警1

使用WebMaker快速预览Ionic页面效果

davidce

Ionic WebMaker 混合应用开发

设计模式的主要原则

编程这件事

dapaul

架构师训练营第二章 总结

尔东雨田

分布式tensorflow源码解读3:lookup.index_table_from_tensor_文化 & 方法_Alex-zhai_InfoQ精选文章