Distributed TensorFlow Source Code Walkthrough 3: lookup.index_table_from_tensor


Background

DNN ranking models for recommendation often rely on extremely sparse ID features, which have to be embedded. In TensorFlow this is usually done with a Variable matrix of shape [id_index_size, embedding_size] holding the embedding parameters; the index of an ID feature is then used to look up the corresponding row of that Variable to obtain its embedding. Note that id_index_size is generally not equal to the number of elements in the corresponding ID table, because many ID values do not appear in the original ID table, for example newly listed products. id_index_size therefore has to be set somewhat larger, leaving room for elements that are not in the ID table.
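
To make this concrete, here is a minimal sketch (TF 1.x, with purely illustrative sizes and variable names that are not from the original article) of the embedding pattern just described: a Variable of shape [id_index_size, embedding_size] plus a lookup by index.

import tensorflow as tf

id_index_size = 1000      # deliberately larger than the id table, to leave headroom
embedding_size = 8

embedding_table = tf.get_variable(
    "id_embedding", shape=[id_index_size, embedding_size],
    initializer=tf.truncated_normal_initializer(stddev=0.02))

id_indices = tf.constant([3, 42, 999], dtype=tf.int64)   # indices in [0, id_index_size)
id_embeddings = tf.nn.embedding_lookup(embedding_table, id_indices)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(id_embeddings).shape)   # (3, 8)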

So how is a raw ID feature value mapped to an index in [0, id_index_size)? If the number of IDs is very small, one could sort the IDs after feature extraction and assign consecutive values starting from 0, but in industrial settings ID features typically number in the millions to hundreds of millions, which makes sorting impractical. Fortunately, TensorFlow has a built-in function that maps raw ID feature values to indices starting from 0: lookup.index_table_from_tensor.

Example

Let's start with a usage example of index_table_from_tensor:

import tensorflow as tf

sess = tf.Session()
vocabulary_list = tf.constant(["emerson", "lake", "palmer"])
table = tf.contrib.lookup.index_table_from_tensor(
    vocabulary_list=vocabulary_list, num_oov_buckets=10, default_value=-1)
features = tf.constant(["emerson", "lake", "and", "palmer", "dad", "mom", "hello"])
table.init.run(session=sess)
ids = table.lookup(features)
print(sess.run(ids))
# Output: [ 0 1 10 2 9 3 9]

index_table_from_tensor works as follows: an ID feature value that appears in vocabulary_list is mapped to an index between 0 and len(table) - 1; a value that is not in vocabulary_list is mapped by a hash function into the range from len(table) to len(table) + num_oov_buckets - 1.
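
To make the out-of-vocabulary part concrete, here is a minimal sketch (TF 1.x, assuming the default FastHashSpec hasher) that reproduces the OOV mapping by hand: hash the key into one of num_oov_buckets buckets with tf.string_to_hash_bucket_fast and then offset by the vocabulary size.

import tensorflow as tf

vocab_size = 3            # len(["emerson", "lake", "palmer"])
num_oov_buckets = 10
oov_keys = tf.constant(["and", "dad", "mom", "hello"])

# tf.string_to_hash_bucket_fast is the op behind the default FastHashSpec.
oov_ids = tf.string_to_hash_bucket_fast(oov_keys, num_oov_buckets) + vocab_size

with tf.Session() as sess:
    # Should reproduce the OOV entries in the output above: [10 9 3 9]
    print(sess.run(oov_ids))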

Now let's dig into the source code, starting with index_table_from_tensor:

def index_table_from_tensor(vocabulary_list,
                            num_oov_buckets=0,
                            default_value=-1,
                            hasher_spec=FastHashSpec,
                            dtype=dtypes.string,
                            name=None):
  """Returns a lookup table that converts a string tensor into int64 IDs.

  This operation constructs a lookup table to convert tensor of strings into
  int64 IDs. The mapping can be initialized from a string `vocabulary_list` 1-D
  tensor where each element is a key and corresponding index within the tensor
  is the value.

  Args:
    vocabulary_list: A 1-D `Tensor` that specifies the mapping of keys to
      indices. The type of this object must be castable to `dtype`.
    num_oov_buckets: The number of out-of-vocabulary buckets.
    default_value: The value to use for out-of-vocabulary feature values.
      Defaults to -1.
    hasher_spec: A `HasherSpec` to specify the hash function to use for
      assignment of out-of-vocabulary buckets.
    dtype: The type of values passed to `lookup`. Only string and integers are
      supported.
    name: A name for this op (optional).

  Returns:
    The lookup table to map an input `Tensor` to index `int64` `Tensor`.

  Raises:
    ValueError: If `vocabulary_list` is invalid.
    ValueError: If `num_oov_buckets` is negative.
  """
  # ... type-checking code omitted ...
  with ops.name_scope(name, "string_to_index") as feat_to_id_scope:
    # Build the hash table from the given vocabulary_list: the keys are the
    # elements of vocabulary_list and the values are the indices
    # 0 .. num_elements - 1.
    keys = ops.convert_to_tensor(vocabulary_list)
    num_elements = array_ops.size(keys)
    values = math_ops.to_int64(math_ops.range(num_elements))

    shared_name = ""
    with ops.name_scope(None, "hash_table") as hash_table_scope:
      table_keys = math_ops.to_int64(keys) if keys.dtype.is_integer else keys
      init = KeyValueTensorInitializer(
          table_keys,
          values,
          table_keys.dtype.base_dtype,
          dtypes.int64,
          name="table_init")
      table = HashTable(
          init, default_value, shared_name=shared_name, name=hash_table_scope)
    if num_oov_buckets:
      table = IdTableWithHashBuckets(
          table,
          num_oov_buckets=num_oov_buckets,
          hasher_spec=hasher_spec,
          name=feat_to_id_scope,
          key_dtype=dtype)
    return table

From the code above, there are three important steps in the body of index_table_from_tensor:

1. Build keys and values and create a KeyValueTensorInitializer from them.
2. Create the concrete HashTable class from the KeyValueTensorInitializer; elements that are not in vocabulary_list all map to the default value.
3. If num_oov_buckets > 0, wrap the table in an IdTableWithHashBuckets, which uses a hash function to produce index values for the ID elements that are not in vocabulary_list.
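
Before walking through each step, here is a minimal hand-rolled sketch (TF 1.x contrib API, with an illustrative three-word vocabulary) of steps 1 and 2: build keys/values, wrap them in a KeyValueTensorInitializer, and hand that to a HashTable. Note that nothing is written into the table until the init op actually runs.

import tensorflow as tf

keys = tf.constant(["emerson", "lake", "palmer"])
values = tf.constant([0, 1, 2], dtype=tf.int64)

init = tf.contrib.lookup.KeyValueTensorInitializer(keys, values)   # step 1
table = tf.contrib.lookup.HashTable(init, default_value=-1)        # step 2

ids = table.lookup(tf.constant(["lake", "and"]))
with tf.Session() as sess:
    sess.run(table.init)   # runs lookup_table_import_v2 under the hood
    print(sess.run(ids))   # [ 1 -1]: "and" misses and falls back to default_value
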
Let's analyze the first step. The keys and values built from the given vocabulary_list are passed to the KeyValueTensorInitializer class, which returns the op that initializes the table. Note that at this point only an op for initializing the table is returned; no actual initialization has taken place yet (that is carried out inside the concrete HashTable). Here is the code of KeyValueTensorInitializer:

class KeyValueTensorInitializer(TableInitializerBase):
  """Table initializers given `keys` and `values` tensors."""

  def __init__(self, keys, values, key_dtype=None, value_dtype=None, name=None):
    with ops.name_scope(name, "key_value_init", [keys, values]) as scope:
      self._keys = ops.convert_to_tensor(keys, dtype=key_dtype, name="keys")
      self._values = ops.convert_to_tensor(
          values, dtype=value_dtype, name="values")
      self._name = scope
    super(KeyValueTensorInitializer, self).__init__(self._keys.dtype,
                                                    self._values.dtype)

  def initialize(self, table):
    """Initializes the given `table` with `keys` and `values` tensors.

    Args:
      table: The table to initialize.

    Returns:
      The operation that initializes the table.

    Raises:
      TypeError: when the keys and values data types do not match the table
      key and value data types.
    """
    _check_table_dtypes(table, self._keys.dtype, self._values.dtype)
    with ops.name_scope(
        self._name, values=(table.table_ref, self._keys,
                            self._values)) as scope:
      init_op = gen_lookup_ops.lookup_table_import_v2(
          table.table_ref, self._keys, self._values, name=scope)
    ops.add_to_collection(ops.GraphKeys.TABLE_INITIALIZERS, init_op)
    return init_op

As we can see, the table initialization is performed by the lookup_table_import_v2 op. Let's look at the relevant C++ code behind this op:
REGISTER_OP("LookupTableImportV2")
    .Input("table_handle: resource")
    .Input("keys: Tin")
    .Input("values: Tout")
    .Attr("Tin: type")
    .Attr("Tout: type")
    .SetShapeFn([](InferenceContext* c) {
      ShapeHandle handle;
      TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &handle));

      ShapeHandle keys;
      TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &keys));
      TF_RETURN_IF_ERROR(c->Merge(keys, c->input(2), &keys));
      return Status::OK();
    });
// Clear the table and insert data.
class LookupTableImportOp : public OpKernel {
 public:
  explicit LookupTableImportOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

  void Compute(OpKernelContext* ctx) override {
    lookup::LookupInterface* table;
    OP_REQUIRES_OK(ctx, GetLookupTable("table_handle", ctx, &table));
    // ... checks omitted ...

    const Tensor& keys = ctx->input(1);
    const Tensor& values = ctx->input(2);
    OP_REQUIRES_OK(ctx, table->CheckKeyAndValueTensorsForImport(keys, values));

    int memory_used_before = 0;
    if (ctx->track_allocations()) {
      memory_used_before = table->MemoryUsed();
    }
    OP_REQUIRES_OK(ctx, table->ImportValues(ctx, keys, values));
    if (ctx->track_allocations()) {
      ctx->record_persistent_memory_allocation(table->MemoryUsed() -
                                               memory_used_before);
    }
  }
};

Here the concrete table class is passed in through the table_handle handle. In the example above it is the HashTable class. HashTable derives from InitializableLookupTable, which can be understood as the family of tables that require an initialization step, and InitializableLookupTable in turn derives from LookupInterface. The call table->ImportValues() imports the keys and values built from vocabulary_list into the hash table, completing its initialization.
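
Note also that KeyValueTensorInitializer.initialize() adds init_op to the GraphKeys.TABLE_INITIALIZERS collection (the ops.add_to_collection call above), so instead of running each table.init separately, all lookup tables in a graph can be initialized in one shot with tf.tables_initializer(). A minimal sketch, reusing the vocabulary of the first example:

import tensorflow as tf

vocabulary_list = tf.constant(["emerson", "lake", "palmer"])
table = tf.contrib.lookup.index_table_from_tensor(
    vocabulary_list=vocabulary_list, num_oov_buckets=10)
ids = table.lookup(tf.constant(["palmer", "unknown"]))

with tf.Session() as sess:
    sess.run(tf.tables_initializer())   # runs every op collected in TABLE_INITIALIZERS
    print(sess.run(ids))                # "palmer" -> 2, "unknown" -> some OOV bucket id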

After LookupTableImportV2 returns init_op, this op that initializes the hash table is handed to the HashTable class. The code of HashTable and its parent class is as follows:

class HashTable(InitializableLookupTableBase):

  def __init__(self, initializer, default_value, shared_name=None, name=None):
    """Creates a non-initialized `HashTable` object.

    Creates a table, the type of its keys and values are specified by the
    initializer.
    Before using the table you will have to initialize it. After initialization
    the table will be immutable.

    Args:
      initializer: The table initializer to use. See `HashTable` kernel for
        supported key and value types.
      default_value: The value to use if a key is missing in the table.
      shared_name: If non-empty, this table will be shared under
        the given name across multiple sessions.
      name: A name for the operation (optional).

    Returns:
      A `HashTable` object.
    """
    with ops.name_scope(name, "hash_table", (initializer,
                                             default_value)) as scope:
      table_ref = gen_lookup_ops.hash_table_v2(
          shared_name=shared_name,
          key_dtype=initializer.key_dtype,
          value_dtype=initializer.value_dtype,
          name=scope)

      super(HashTable, self).__init__(table_ref, default_value, initializer)
      self._value_shape = self._default_value.get_shape()


class InitializableLookupTableBase(LookupInterface):
  """Initializable lookup table interface.

  An initializable lookup table persists across different steps.
  """

  def __init__(self, table_ref, default_value, initializer):
    """Construct a table object from a table reference.

    If requires a table initializer object (subclass of `TableInitializerBase`).
    It provides the table key and value types, as well as the op to initialize
    the table. The caller is responsible to execute the initialization op.

    Args:
      table_ref: The table reference, i.e. the output of the lookup table ops.
      default_value: The value to use if a key is missing in the table.
      initializer: The table initializer to use.
    """
    name = table_ref.op.name.split("/")[-1]
    super(InitializableLookupTableBase,
          self).__init__(initializer.key_dtype, initializer.value_dtype,
                         name)
    self._table_ref = table_ref
    self._default_value = ops.convert_to_tensor(
        default_value, dtype=self._value_dtype)
    self._default_value.get_shape().merge_with(tensor_shape.scalar())
    self._init = initializer.initialize(self)

  @property
  def table_ref(self):
    """Get the underlying table reference."""
    return self._table_ref

  @property
  def default_value(self):
    """The default value of the table."""
    return self._default_value

  @property
  def init(self):
    """The table initialization op."""
    return self._init

  def lookup(self, keys, name=None):
    """Looks up `keys` in a table, outputs the corresponding values.

    The `default_value` is used for keys not present in the table.
    """
    key_tensor = keys
    with ops.name_scope(name, "%s_Lookup" % self._name,
                        (self._table_ref, key_tensor,
                         self._default_value)) as scope:
      values = gen_lookup_ops.lookup_table_find_v2(
          self._table_ref, key_tensor, self._default_value, name=scope)

    values.set_shape(key_tensor.get_shape())
    if isinstance(keys, sparse_tensor.SparseTensor):
      return sparse_tensor.SparseTensor(keys.indices, values, keys.dense_shape)
    else:
      return values

HashTable inherits from InitializableLookupTableBase, and InitializableLookupTableBase provides both the init op and the lookup method, so table initialization and lookup are actually implemented in the parent class. We have already covered initialization above, so let's focus on the op behind lookup:

// Table lookup op. Perform the lookup operation on the given table.
class LookupTableFindOp : public OpKernel {
 public:
  explicit LookupTableFindOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

  void Compute(OpKernelContext* ctx) override {
    lookup::LookupInterface* table;
    OP_REQUIRES_OK(ctx, GetLookupTable("table_handle", ctx, &table));
    // ... checks omitted ...

    const Tensor& key = ctx->input(1);
    const Tensor& default_value = ctx->input(2);
    OP_REQUIRES_OK(ctx, table->CheckFindArguments(key, default_value));

    TensorShape output_shape = key.shape();
    output_shape.RemoveLastDims(table->key_shape().dims());
    output_shape.AppendShape(table->value_shape());
    Tensor* out;
    OP_REQUIRES_OK(ctx, ctx->allocate_output("values", output_shape, &out));

    OP_REQUIRES_OK(ctx, table->Find(ctx, key, out, default_value));
  }
};

REGISTER_KERNEL_BUILDER(Name("LookupTableFind").Device(DEVICE_CPU),
                        LookupTableFindOp);
REGISTER_KERNEL_BUILDER(Name("LookupTableFindV2").Device(DEVICE_CPU),
                        LookupTableFindOp);

// table->Find dispatches to the DoFind method of the HashTable class:
Status DoFind(const Tensor& key, Tensor* value,
              const Tensor& default_value) override {
  const V default_val = default_value.flat<V>()(0);
  const auto key_values = key.flat<K>();
  auto value_values = value->flat<V>();

  for (int64 i = 0; i < key_values.size(); ++i) {
    value_values(i) = gtl::FindWithDefault(
        *table_, SubtleMustCopyIfIntegral(key_values(i)), default_val);
  }
  return Status::OK();
}
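
As a rough pure-Python analogue (not TensorFlow code, just for intuition) of what DoFind does: walk over the flattened keys one by one and emit the stored value, falling back to default_value when a key is missing, which is exactly the semantics of gtl::FindWithDefault.

def do_find(table, keys, default_value):
    # Mirrors the loop in DoFind: one FindWithDefault per key.
    return [table.get(k, default_value) for k in keys]

print(do_find({"emerson": 0, "lake": 1, "palmer": 2}, ["lake", "and"], -1))   # [1, -1]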

Finally, let's analyze the IdTableWithHashBuckets class, which uses a hash function to produce index values for ID elements that are not in vocabulary_list. First, an example:

import tensorflow as tf

num_oov_buckets = 3
input_tensor = tf.constant(["emerson", "lake", "palmer", "king", "crimnson"])
table = tf.IdTableWithHashBuckets(
    tf.HashTable(tf.TextFileIdTableInitializer(filename), default_value),
    num_oov_buckets)
out = table.lookup(input_tensor)
table.init.run()
print(out.eval())
# The vocabulary file `filename` contains "emerson", "lake", "palmer", one per line.

# Output: [0, 1, 2, 4, 7]
Now let's look at the source code of IdTableWithHashBuckets:

class IdTableWithHashBuckets(LookupInterface):
  """String to Id table wrapper that assigns out-of-vocabulary keys to buckets.

  For example, if an instance of `IdTableWithHashBuckets` is initialized with a
  string-to-id table that maps:
  * `emerson -> 0`
  * `lake -> 1`
  * `palmer -> 2`
  The `IdTableWithHashBuckets` object will perform the following mapping:
  * `emerson -> 0`
  * `lake -> 1`
  * `palmer -> 2`
  * `<other term> -> bucket_id`, where bucket_id will be between `3` and
    `3 + num_oov_buckets - 1`, calculated by:
    `hash(<term>) % num_oov_buckets + vocab_size`
  If input_tensor is `["emerson", "lake", "palmer", "king", "crimson"]`,
  the lookup result is `[0, 1, 2, 4, 7]`.
  If `table` is None, only out-of-vocabulary buckets are used.
  """

  def __init__(self,
               table,
               num_oov_buckets,
               hasher_spec=FastHashSpec,
               name=None,
               key_dtype=None):
    """Construct a `IdTableWithHashBuckets` object.

    Args:
      table: Table that maps `tf.string` or `tf.int64` keys to `tf.int64` ids.
      num_oov_buckets: Number of buckets to use for out-of-vocabulary keys.
      hasher_spec: A `HasherSpec` to specify the hash function to use for
        assignation of out-of-vocabulary buckets (optional).
      name: A name for the operation (optional).
      key_dtype: Data type of keys passed to `lookup`. Defaults to
        `table.key_dtype` if `table` is specified, otherwise `tf.string`.
        Must be string or integer, and must be castable to `table.key_dtype`.

    Raises:
      ValueError: when `table` is None and `num_oov_buckets` is not positive.
      TypeError: when `hasher_spec` is invalid.
    """
    # If a name ends with a '/' it is a "name scope", remove all trailing '/'
    # characters to use as table name.
    # ... type checks omitted (this part also stores self._table and
    # self._hasher_spec) ...
    self._num_oov_buckets = num_oov_buckets
    super(IdTableWithHashBuckets, self).__init__(key_dtype, dtypes.int64,
                                                 name.split("/")[-1])

  @property
  def init(self):
    """The table initialization op."""
    if self._table:
      return self._table.init
    with ops.name_scope(None, "init"):
      return control_flow_ops.no_op()

  @property
  def table_ref(self):
    if self._table is not None:
      return self._table.table_ref
    return None

  def _get_string_to_hash_bucket_fn(self, hasher_spec):
    """Returns the string_to_hash_bucket op to use based on `hasher_spec`."""
    if not isinstance(hasher_spec, HasherSpec):
      raise TypeError("hasher_spec must be of type HasherSpec %s" % hasher_spec)
    if hasher_spec.hasher == "fasthash":
      return string_ops.string_to_hash_bucket_fast
    if hasher_spec.hasher == "legacy":
      return string_ops.string_to_hash_bucket
    if hasher_spec.hasher == "stronghash":
      return functools.partial(
          string_ops.string_to_hash_bucket_strong, key=hasher_spec.key)
    raise ValueError("Unknown hasher %s" % hasher_spec.hasher)

  def lookup(self, keys, name=None):
    values = keys
    if self._num_oov_buckets == 0:
      ids = self._table.lookup(values, name=name)
    else:
      # TODO(yleon): Consider moving this functionality to its own kernel.
      with ops.name_scope(name, "%s_Lookup" % self.name) as scope:
        str_to_hash_bucket = self._get_string_to_hash_bucket_fn(
            self._hasher_spec)
        buckets = str_to_hash_bucket(
            _as_string(values),
            num_buckets=self._num_oov_buckets,
            name="hash_bucket")
        if self._table:
          ids = self._table.lookup(values)
          buckets = math_ops.add(buckets, self._table.size())
          is_id_non_default = math_ops.not_equal(ids, self._table.default_value)
          ids = array_ops.where(is_id_non_default, ids, buckets, name=scope)
        else:
          ids = buckets
    return ids

The lookup method above is the core code. The overall flow is: first map all the keys being looked up through the str_to_hash_bucket hash function into integer bucket ids, then add table.size() to them so they do not collide with the indices already in the table. Next, look up the indices of all keys in the table: keys that are in the ID table return their correct index, while keys that are not return the default value. Finally, the keys whose lookup result equals the default value have their index replaced with the bucket id obtained from the hash function. The performance problem is that when the ID table is large, a lot of time is wasted on unnecessary hash operations, since every key is hashed even though most of them are in the table. One optimization is therefore: first look up the indices of all keys, then hash only the elements whose lookup result equals the default value.

I implemented a version based on this idea, which I will call lookup_v2 for now. Although it avoids hashing every key, it also adds several extra ops, so the performance gain is not guaranteed; optimizing the C++ code directly is something to consider later. The core code is below:

def lookup_v2(self, keys, name=None):
  values = keys
  if self._num_oov_buckets == 0:
    ids = self._table.lookup(values, name=name)
  else:
    with ops.name_scope(name, "%s_Lookup" % self.name) as scope:
      str_to_hash_bucket = self._get_string_to_hash_bucket_fn(
          self._hasher_spec)
      if self._table:
        # Look up all keys first; only the keys that hit the default value
        # (i.e. the out-of-vocabulary ones) are hashed afterwards.
        init_ids = self._table.lookup(values)
        is_id_default_idx = array_ops.where(
            math_ops.equal(init_ids, self._table.default_value))
        hash_values = array_ops.gather(values,
                                       array_ops.squeeze(is_id_default_idx))
        default_buckets = str_to_hash_bucket(
            _as_string(hash_values),
            num_buckets=self._num_oov_buckets,
            name="hash_bucket")
        default_buckets = math_ops.add(default_buckets, self._table.size())
        # When there is exactly one OOV key, squeeze/gather produce a scalar,
        # so restore the dimension before the scatter update.
        default_buckets = control_flow_ops.cond(
            gen_math_ops.equal(array_ops.size(default_buckets), 1),
            lambda: array_ops.expand_dims(default_buckets, axis=0),
            lambda: default_buckets)
        ids = gen_array_ops.tensor_scatter_update(init_ids, is_id_default_idx,
                                                  default_buckets)
      else:
        ids = str_to_hash_bucket(
            _as_string(values),
            num_buckets=self._num_oov_buckets,
            name="hash_bucket")
  return ids
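
For reference, here is a standalone sketch of the gather/hash/scatter pattern that lookup_v2 relies on, written with the same low-level ops the method above uses (it assumes a TF 1.x version where gen_array_ops.tensor_scatter_update is available, as in the code above; the concrete tensors are purely illustrative).

import tensorflow as tf
from tensorflow.python.ops import array_ops, gen_array_ops, math_ops, string_ops

keys = tf.constant(["emerson", "king", "palmer", "crimson"])
init_ids = tf.constant([0, -1, 2, -1], dtype=tf.int64)   # -1 = default_value, i.e. a miss
vocab_size = tf.constant(3, dtype=tf.int64)
num_oov_buckets = 3

miss_idx = array_ops.where(math_ops.equal(init_ids, -1))               # shape [num_misses, 1]
miss_keys = array_ops.gather(keys, array_ops.reshape(miss_idx, [-1]))  # hash only the misses
oov_ids = string_ops.string_to_hash_bucket_fast(
    miss_keys, num_buckets=num_oov_buckets) + vocab_size               # bucket ids in [3, 5]
ids = gen_array_ops.tensor_scatter_update(init_ids, miss_idx, oov_ids)

with tf.Session() as sess:
    # In-vocabulary positions keep 0 and 2; the two misses get bucket ids in [3, 5].
    print(sess.run(ids))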

References:

https://mathmach.com/2442aa9e/
https://github.com/tensorflow/tensorflow/blob/5b900cfe4b3b848f577315a0dde09a729f770e95/tensorflow/python/ops/lookup_ops.py#L1042

This article is reposted from Alex-zhai's Zhihu account.

Original link: https://zhuanlan.zhihu.com/p/93116229
