Skip to content

Commit

Permalink
[Embedding] Support immutable EmbeddingVariable in inference mode. (#425
Browse files Browse the repository at this point in the history
)
  • Loading branch information
candyzone authored and lixy9474 committed Sep 13, 2022
1 parent cec4417 commit 3d39cf6
Show file tree
Hide file tree
Showing 6 changed files with 57 additions and 2 deletions.
14 changes: 13 additions & 1 deletion tensorflow/core/framework/embedding/embedding_filter.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,19 @@ template<typename K, typename V, typename EV>
class EmbeddingFilter {
public:
virtual void LookupOrCreate(K key, V* val, const V* default_value_ptr,
ValuePtr<V>** value_ptr, int count) = 0;
ValuePtr<V>** value_ptr, int count) = 0;

virtual void Lookup(EV* ev, K key, V* val, const V* default_value_ptr) {
ValuePtr<V>* value_ptr = nullptr;
Status s = ev->LookupKey(key, &value_ptr);
if (s.ok()) {
V* mem_val = ev->LookupPrimaryEmb(value_ptr);
memcpy(val, mem_val, sizeof(V) * ev->ValueLen());
} else {
memcpy(val, default_value_ptr, sizeof(V) * ev->ValueLen());
}
}

virtual Status LookupOrCreateKey(K key, ValuePtr<V>** val, bool* is_filter) = 0;

virtual int64 GetFreq(K key, ValuePtr<V>* value_ptr) = 0;
Expand Down
9 changes: 9 additions & 0 deletions tensorflow/core/framework/embedding/embedding_var.h
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,10 @@ class EmbeddingVar : public ResourceBase {
return is_initialized_;
}

Status LookupKey(K key, ValuePtr<V>** value_ptr) {
return storage_manager_->Get(key, value_ptr);
}

Status LookupOrCreateKey(K key, ValuePtr<V>** value_ptr, bool* is_filter) {
return filter_->LookupOrCreateKey(key, value_ptr, is_filter);
}
Expand Down Expand Up @@ -127,6 +131,11 @@ class EmbeddingVar : public ResourceBase {
return filter_->GetFreq(key);
}

void Lookup(K key, V* val, V* default_v) {
const V* default_value_ptr = (default_v == nullptr) ? default_value_ : default_v;
filter_->Lookup(this, key, val, default_value_ptr);
}

void LookupOrCreate(K key, V* val, V* default_v, int count = 1) {
const V* default_value_ptr = (default_v == nullptr) ? default_value_ : default_v;
ValuePtr<V>* value_ptr = nullptr;
Expand Down
12 changes: 12 additions & 0 deletions tensorflow/core/framework/embedding/multilevel_embedding.h
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,18 @@ class StorageManager {
}
}

Status Get(K key, ValuePtr<V>** value_ptr) {
Status s;
int level = 0;
for (; level < hash_table_count_; ++level) {
s = kvs_[level].first->Lookup(key, value_ptr);
if (s.ok()) {
break;
}
}
return s;
}

Status GetOrCreate(K key, ValuePtr<V>** value_ptr, size_t size) {
bool found = false;
int level = 0;
Expand Down
21 changes: 20 additions & 1 deletion tensorflow/core/kernels/kv_variable_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ using GPUDevice = Eigen::GpuDevice;
namespace {
const int64 kEmbeddingVarUseDB = -214;
const int64 kInitializableEmbeddingVarUseDB = -215;
const char* kInferenceMode = "INFERENCE_MODE";
}

#define REGISTER_KV_VAR_HANDLE(ktype, vtype) \
Expand Down Expand Up @@ -370,6 +371,10 @@ template <typename TKey, typename TValue>
class KvResourceGatherOp : public OpKernel {
public:
explicit KvResourceGatherOp(OpKernelConstruction* c) : OpKernel(c) {
OP_REQUIRES_OK(c, c->GetAttr("is_inference", &is_inference_));
bool is_inference;
TF_CHECK_OK(ReadBoolFromEnvVar(kInferenceMode, false, &is_inference));
is_inference_ |= is_inference;
OP_REQUIRES_OK(c,
c->GetAttr("is_use_default_value_tensor",
&is_use_default_value_tensor_));
Expand All @@ -393,6 +398,17 @@ class KvResourceGatherOp : public OpKernel {
return 1;
};
}
if (!is_inference_) {
lookup_fn_ = [](EmbeddingVar<TKey, TValue>* ev, TKey key,
TValue* val, TValue* default_v, int count) {
ev->LookupOrCreate(key, val, default_v, count);
};
} else {
lookup_fn_ = [](EmbeddingVar<TKey, TValue>* ev, TKey key,
TValue* val, TValue* default_v, int count) {
ev->Lookup(key, val, default_v);
};
}
}

void Compute(OpKernelContext* c) override {
Expand Down Expand Up @@ -443,7 +459,7 @@ class KvResourceGatherOp : public OpKernel {
default_v, indices_flat(i), i, ev->GetDefaultValueDim(),
ev->ValueLen());
int32 count = get_count_fn_(counts, i);
ev->LookupOrCreate(indices_flat(i),
lookup_fn_(ev, indices_flat(i),
out_base + i * slice_elems, default_v_ptr, count);
}
};
Expand All @@ -463,9 +479,12 @@ class KvResourceGatherOp : public OpKernel {

private:
bool is_use_default_value_tensor_;
bool is_inference_;
std::function<
TValue*(TValue*, TKey, int64, int64, int64)> get_default_v_fn_;
std::function<int32(int32*, int64)> get_count_fn_;
std::function<void(EmbeddingVar<TKey, TValue>* ev,
TKey key, TValue* val, TValue* default_v, int count)> lookup_fn_;
};

#define REGISTER_GATHER_FULL(dev, ktype, vtype) \
Expand Down
2 changes: 2 additions & 0 deletions tensorflow/core/ops/kv_variable_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,7 @@ REGISTER_OP("KvResourceGatherV1")
.Input("counts: counts_type")
.Attr("validate_indices: bool = true")
.Attr("is_use_default_value_tensor: bool = false")
.Attr("is_inference: bool = false")
.Output("output: dtype")
.Attr("dtype: type")
.Attr("Tkeys: {int64,int32,string}")
Expand Down Expand Up @@ -281,6 +282,7 @@ REGISTER_OP("KvResourceGather")
.Output("output: dtype")
.Attr("dtype: type")
.Attr("Tkeys: {int64,int32,string}")
.Attr("is_inference: bool = false")
.SetShapeFn([](InferenceContext* c) {
ShapeAndType handle_shape_and_type;
TF_RETURN_IF_ERROR(
Expand Down
1 change: 1 addition & 0 deletions tensorflow/python/ops/embedding_variable_ops_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

from six.moves import xrange # pylint: disable=redefined-builtin

from tensorflow.core.framework import attr_value_pb2
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import string_ops
Expand Down

0 comments on commit 3d39cf6

Please sign in to comment.