diff --git a/tensorflow/core/framework/embedding/embedding_filter.h b/tensorflow/core/framework/embedding/embedding_filter.h index 53c37b33d8b..97cafe7c6a6 100644 --- a/tensorflow/core/framework/embedding/embedding_filter.h +++ b/tensorflow/core/framework/embedding/embedding_filter.h @@ -45,6 +45,18 @@ class EmbeddingFilter { public: virtual void LookupOrCreate(K key, V* val, const V* default_value_ptr, ValuePtr** value_ptr, int count) = 0; + + virtual void Lookup(EV* ev, K key, V* val, const V* default_value_ptr) { + ValuePtr* value_ptr = nullptr; + Status s = ev->LookupKey(key, &value_ptr); + if (s.ok()) { + V* mem_val = ev->LookupPrimaryEmb(value_ptr); + memcpy(val, mem_val, sizeof(V) * ev->ValueLen()); + } else { + memcpy(val, default_value_ptr, sizeof(V) * ev->ValueLen()); + } + } + virtual Status LookupOrCreateKey(K key, ValuePtr** val, bool* is_filter) = 0; virtual int64 GetFreq(K key, ValuePtr* value_ptr) = 0; diff --git a/tensorflow/core/framework/embedding/embedding_var.h b/tensorflow/core/framework/embedding/embedding_var.h index bcf95648d0e..802aae5a0be 100644 --- a/tensorflow/core/framework/embedding/embedding_var.h +++ b/tensorflow/core/framework/embedding/embedding_var.h @@ -99,6 +99,10 @@ class EmbeddingVar : public ResourceBase { return is_initialized_; } + Status LookupKey(K key, ValuePtr** value_ptr) { + return storage_manager_->Get(key, value_ptr); + } + Status LookupOrCreateKey(K key, ValuePtr** value_ptr, bool* is_filter) { return filter_->LookupOrCreateKey(key, value_ptr, is_filter); } @@ -127,6 +131,11 @@ class EmbeddingVar : public ResourceBase { return filter_->GetFreq(key); } + void Lookup(K key, V* val, V* default_v) { + const V* default_value_ptr = (default_v == nullptr) ? default_value_ : default_v; + filter_->Lookup(this, key, val, default_value_ptr); + } + void LookupOrCreate(K key, V* val, V* default_v, int count = 1) { const V* default_value_ptr = (default_v == nullptr) ? default_value_ : default_v; ValuePtr* value_ptr = nullptr; diff --git a/tensorflow/core/framework/embedding/multilevel_embedding.h b/tensorflow/core/framework/embedding/multilevel_embedding.h index f22d5f1ad62..1c54d4f9465 100644 --- a/tensorflow/core/framework/embedding/multilevel_embedding.h +++ b/tensorflow/core/framework/embedding/multilevel_embedding.h @@ -223,6 +223,18 @@ class StorageManager { } } + Status Get(K key, ValuePtr** value_ptr) { + Status s; + int level = 0; + for (; level < hash_table_count_; ++level) { + s = kvs_[level].first->Lookup(key, value_ptr); + if (s.ok()) { + break; + } + } + return s; + } + Status GetOrCreate(K key, ValuePtr** value_ptr, size_t size) { bool found = false; int level = 0; diff --git a/tensorflow/core/kernels/kv_variable_ops.cc b/tensorflow/core/kernels/kv_variable_ops.cc index 246eff34182..b888cb83f4c 100644 --- a/tensorflow/core/kernels/kv_variable_ops.cc +++ b/tensorflow/core/kernels/kv_variable_ops.cc @@ -52,6 +52,7 @@ using GPUDevice = Eigen::GpuDevice; namespace { const int64 kEmbeddingVarUseDB = -214; const int64 kInitializableEmbeddingVarUseDB = -215; +const char* kInferenceMode = "INFERENCE_MODE"; } #define REGISTER_KV_VAR_HANDLE(ktype, vtype) \ @@ -370,6 +371,10 @@ template class KvResourceGatherOp : public OpKernel { public: explicit KvResourceGatherOp(OpKernelConstruction* c) : OpKernel(c) { + OP_REQUIRES_OK(c, c->GetAttr("is_inference", &is_inference_)); + bool is_inference; + TF_CHECK_OK(ReadBoolFromEnvVar(kInferenceMode, false, &is_inference)); + is_inference_ |= is_inference; OP_REQUIRES_OK(c, c->GetAttr("is_use_default_value_tensor", &is_use_default_value_tensor_)); @@ -393,6 +398,17 @@ class KvResourceGatherOp : public OpKernel { return 1; }; } + if (!is_inference_) { + lookup_fn_ = [](EmbeddingVar* ev, TKey key, + TValue* val, TValue* default_v, int count) { + ev->LookupOrCreate(key, val, default_v, count); + }; + } else { + lookup_fn_ = [](EmbeddingVar* ev, TKey key, + TValue* val, TValue* default_v, int count) { + ev->Lookup(key, val, default_v); + }; + } } void Compute(OpKernelContext* c) override { @@ -443,7 +459,7 @@ class KvResourceGatherOp : public OpKernel { default_v, indices_flat(i), i, ev->GetDefaultValueDim(), ev->ValueLen()); int32 count = get_count_fn_(counts, i); - ev->LookupOrCreate(indices_flat(i), + lookup_fn_(ev, indices_flat(i), out_base + i * slice_elems, default_v_ptr, count); } }; @@ -463,9 +479,12 @@ class KvResourceGatherOp : public OpKernel { private: bool is_use_default_value_tensor_; + bool is_inference_; std::function< TValue*(TValue*, TKey, int64, int64, int64)> get_default_v_fn_; std::function get_count_fn_; + std::function* ev, + TKey key, TValue* val, TValue* default_v, int count)> lookup_fn_; }; #define REGISTER_GATHER_FULL(dev, ktype, vtype) \ diff --git a/tensorflow/core/ops/kv_variable_ops.cc b/tensorflow/core/ops/kv_variable_ops.cc index d87b55dc396..fa07addc051 100644 --- a/tensorflow/core/ops/kv_variable_ops.cc +++ b/tensorflow/core/ops/kv_variable_ops.cc @@ -231,6 +231,7 @@ REGISTER_OP("KvResourceGatherV1") .Input("counts: counts_type") .Attr("validate_indices: bool = true") .Attr("is_use_default_value_tensor: bool = false") + .Attr("is_inference: bool = false") .Output("output: dtype") .Attr("dtype: type") .Attr("Tkeys: {int64,int32,string}") @@ -281,6 +282,7 @@ REGISTER_OP("KvResourceGather") .Output("output: dtype") .Attr("dtype: type") .Attr("Tkeys: {int64,int32,string}") + .Attr("is_inference: bool = false") .SetShapeFn([](InferenceContext* c) { ShapeAndType handle_shape_and_type; TF_RETURN_IF_ERROR( diff --git a/tensorflow/python/ops/embedding_variable_ops_test.py b/tensorflow/python/ops/embedding_variable_ops_test.py index cb5c10b10a6..ab6524ca8f9 100644 --- a/tensorflow/python/ops/embedding_variable_ops_test.py +++ b/tensorflow/python/ops/embedding_variable_ops_test.py @@ -16,6 +16,7 @@ from six.moves import xrange # pylint: disable=redefined-builtin +from tensorflow.core.framework import attr_value_pb2 from tensorflow.python.framework import ops from tensorflow.python.framework import test_util from tensorflow.python.ops import string_ops