Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Embedding] Support immutable EmbeddingVariable in inference mode. #438

Open
wants to merge 1 commit into
base: deeprec2206
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions tensorflow/core/framework/embedding/embedding_filter.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,18 @@ class EmbeddingFilter {
public:
virtual void LookupOrCreate(K key, V* val, const V* default_value_ptr,
ValuePtr<V>** value_ptr, int count) = 0;

// Read-only lookup of `key` through `ev`.
//
// Asks `ev->LookupKey` for the entry. When that call reports success, the
// pointer returned by `ev->LookupPrimaryEmb` is used as the copy source;
// otherwise `default_value_ptr` is used. In both cases exactly
// `ev->ValueLen()` elements of type V are copied into `val`.
//
// NOTE(review): this path never consults the filter's admission state;
// confirm that entries found by LookupKey always have a usable primary
// embedding.
virtual void Lookup(EV* ev, K key, V* val, const V* default_value_ptr) {
  ValuePtr<V>* found = nullptr;
  const bool hit = ev->LookupKey(key, &found).ok();
  // Pick the source once so that a single memcpy serves both outcomes.
  const V* src = hit ? ev->LookupPrimaryEmb(found) : default_value_ptr;
  memcpy(val, src, sizeof(V) * ev->ValueLen());
}

virtual Status LookupOrCreateKey(K key, ValuePtr<V>** val, bool* is_filter) = 0;

virtual int64 GetFreq(K key, ValuePtr<V>* value_ptr) = 0;
Expand Down
9 changes: 9 additions & 0 deletions tensorflow/core/framework/embedding/embedding_var.h
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,10 @@ class EmbeddingVar : public ResourceBase {
return is_initialized_;
}

// Looks up `key` without going through the filter: forwards directly to
// storage_manager_->Get and returns its Status unchanged.
// `value_ptr` is an output argument handed straight to the storage manager;
// presumably it is only written when the returned Status is OK -- confirm
// against StorageManager::Get.
Status LookupKey(K key, ValuePtr<V>** value_ptr) {
return storage_manager_->Get(key, value_ptr);
}

// Delegates to filter_->LookupOrCreateKey and returns its Status unchanged.
// `value_ptr` and `is_filter` are output arguments passed through as-is;
// presumably `is_filter` reports the filter's decision for `key` -- confirm
// against the concrete EmbeddingFilter implementation.
Status LookupOrCreateKey(K key, ValuePtr<V>** value_ptr, bool* is_filter) {
return filter_->LookupOrCreateKey(key, value_ptr, is_filter);
}
Expand Down Expand Up @@ -127,6 +131,11 @@ class EmbeddingVar : public ResourceBase {
return filter_->GetFreq(key);
}

void Lookup(K key, V* val, V* default_v) {
const V* default_value_ptr = (default_v == nullptr) ? default_value_ : default_v;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

代码行80

filter_->Lookup(this, key, val, default_value_ptr);
}

void LookupOrCreate(K key, V* val, V* default_v, int count = 1) {
const V* default_value_ptr = (default_v == nullptr) ? default_value_ : default_v;
ValuePtr<V>* value_ptr = nullptr;
Expand Down
12 changes: 12 additions & 0 deletions tensorflow/core/framework/embedding/multilevel_embedding.h
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,18 @@ class StorageManager {
}
}

// Searches the storage levels in order (index 0 first) for `key`.
//
// Returns the Status of the first level whose Lookup succeeds; `value_ptr`
// is whatever that level's Lookup wrote. If every level fails, the Status
// of the last level tried is returned.
//
// NOTE(review): when hash_table_count_ is 0 the loop body never runs and the
// default-constructed Status is returned with `value_ptr` untouched --
// presumably that Status reads as OK; confirm callers cannot hit this.
Status Get(K key, ValuePtr<V>** value_ptr) {
  Status status;
  for (int level = 0; level < hash_table_count_; ++level) {
    status = kvs_[level].first->Lookup(key, value_ptr);
    if (status.ok()) {
      return status;
    }
  }
  return status;
}

Status GetOrCreate(K key, ValuePtr<V>** value_ptr, size_t size) {
bool found = false;
int level = 0;
Expand Down
21 changes: 20 additions & 1 deletion tensorflow/core/kernels/kv_variable_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ using GPUDevice = Eigen::GpuDevice;
namespace {
const int64 kEmbeddingVarUseDB = -214;
const int64 kInitializableEmbeddingVarUseDB = -215;
// Name of the environment variable that KvResourceGatherOp reads (via
// ReadBoolFromEnvVar) to switch gathers into inference mode.
// Declared `const char* const` so the pointer itself cannot be reseated,
// matching the other file-local constants.
const char* const kInferenceMode = "INFERENCE_MODE";
}  // namespace

#define REGISTER_KV_VAR_HANDLE(ktype, vtype) \
Expand Down Expand Up @@ -370,6 +371,10 @@ template <typename TKey, typename TValue>
class KvResourceGatherOp : public OpKernel {
public:
explicit KvResourceGatherOp(OpKernelConstruction* c) : OpKernel(c) {
OP_REQUIRES_OK(c, c->GetAttr("is_inference", &is_inference_));
bool is_inference;
TF_CHECK_OK(ReadBoolFromEnvVar(kInferenceMode, false, &is_inference));
is_inference_ |= is_inference;
OP_REQUIRES_OK(c,
c->GetAttr("is_use_default_value_tensor",
&is_use_default_value_tensor_));
Expand All @@ -393,6 +398,17 @@ class KvResourceGatherOp : public OpKernel {
return 1;
};
}
if (!is_inference_) {
lookup_fn_ = [](EmbeddingVar<TKey, TValue>* ev, TKey key,
TValue* val, TValue* default_v, int count) {
ev->LookupOrCreate(key, val, default_v, count);
};
} else {
lookup_fn_ = [](EmbeddingVar<TKey, TValue>* ev, TKey key,
TValue* val, TValue* default_v, int count) {
ev->Lookup(key, val, default_v);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

少了count?

};
}
}

void Compute(OpKernelContext* c) override {
Expand Down Expand Up @@ -443,7 +459,7 @@ class KvResourceGatherOp : public OpKernel {
default_v, indices_flat(i), i, ev->GetDefaultValueDim(),
ev->ValueLen());
int32 count = get_count_fn_(counts, i);
ev->LookupOrCreate(indices_flat(i),
lookup_fn_(ev, indices_flat(i),
out_base + i * slice_elems, default_v_ptr, count);
}
};
Expand All @@ -463,9 +479,12 @@ class KvResourceGatherOp : public OpKernel {

private:
bool is_use_default_value_tensor_;
bool is_inference_;
std::function<
TValue*(TValue*, TKey, int64, int64, int64)> get_default_v_fn_;
std::function<int32(int32*, int64)> get_count_fn_;
std::function<void(EmbeddingVar<TKey, TValue>* ev,
TKey key, TValue* val, TValue* default_v, int count)> lookup_fn_;
};

#define REGISTER_GATHER_FULL(dev, ktype, vtype) \
Expand Down
2 changes: 2 additions & 0 deletions tensorflow/core/ops/kv_variable_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,7 @@ REGISTER_OP("KvResourceGatherV1")
.Input("counts: counts_type")
.Attr("validate_indices: bool = true")
.Attr("is_use_default_value_tensor: bool = false")
.Attr("is_inference: bool = false")
.Output("output: dtype")
.Attr("dtype: type")
.Attr("Tkeys: {int64,int32,string}")
Expand Down Expand Up @@ -281,6 +282,7 @@ REGISTER_OP("KvResourceGather")
.Output("output: dtype")
.Attr("dtype: type")
.Attr("Tkeys: {int64,int32,string}")
.Attr("is_inference: bool = false")
.SetShapeFn([](InferenceContext* c) {
ShapeAndType handle_shape_and_type;
TF_RETURN_IF_ERROR(
Expand Down
1 change: 1 addition & 0 deletions tensorflow/python/ops/embedding_variable_ops_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

from six.moves import xrange # pylint: disable=redefined-builtin

from tensorflow.core.framework import attr_value_pb2
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import string_ops
Expand Down