Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Embedding] Support print EmbeddingVariableOption. #260

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 41 additions & 1 deletion tensorflow/python/ops/variables.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,9 @@ def __init__(self,
if default_value_dim <=0:
print("default value dim must larger than 1, the default value dim is set to default 4096.")
default_value_dim = 4096
def __repr__(self):
return "[Initializier: {init}".format(init=self.initializer) + \
", default_value_dim: {dim}]".format(dim=self.default_value_dim)

class MultihashOption(object):
def __init__(self,
Expand All @@ -203,6 +206,8 @@ class GlobalStepEvict(object):
def __init__(self,
steps_to_live = None):
self.steps_to_live = steps_to_live
def __repr__(self):
return "[Eviction type: GlobalStepEvict, steps_to_live: {step}]".format(step=self.steps_to_live)

@tf_export(v1=["L2WeightEvict"])
class L2WeightEvict(object):
Expand All @@ -211,6 +216,8 @@ def __init__(self,
self.l2_weight_threshold = l2_weight_threshold
if l2_weight_threshold <= 0 and l2_weight_threshold != -1.0:
logging.warning("l2_weight_threshold is invalid, l2_weight-based eviction is disabled")
def __repr__(self):
return "[Eviction type: L2WeightEvict, l2_weight_threshold: {threshold}]".format(threshold=self.l2_weight_threshold)

@tf_export(v1=["CheckpointOption"])
class CheckpointOption(object):
Expand Down Expand Up @@ -250,6 +257,23 @@ def __init__(self,
config_pb2.StorageType.DRAM_SSDHASH,
config_pb2.StorageType.DRAM_LEVELDB]:
raise ValueError("storage_path musnt'be None when storage_type is set")
def __repr__(self):
name_dict = {}
name_dict[None] = "None"
name_dict[config_pb2.StorageType.DRAM] = "DRAM"
name_dict[config_pb2.StorageType.PMEM_LIBPMEM] = "PMEM_LIBPMEM"
name_dict[config_pb2.StorageType.PMEM_MEMKIND] = "PMEM_MEMKIND"
name_dict[config_pb2.StorageType.LEVELDB] = "LEVELDB"
name_dict[config_pb2.StorageType.SSDHASH] = "SSDHASH"
name_dict[config_pb2.StorageType.DRAM_LEVELDB] = "DRAM_LEVELDB"
name_dict[config_pb2.StorageType.PMEM_MEMKIND] = "HBM_DRAM"
name_dict[config_pb2.StorageType.DRAM_SSDHASH] = "DRAM_SSDHASH"
name_dict[config_pb2.StorageType.DRAM_PMEM] = "DRAM_PMEM"
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

dict放到函数外面,别人也可以用
参考一下tf 的dtypes.py 中 _TYPE_TO_STRING 的实现方法

name_dict[config_pb2.StorageType.DRAM_PMEM_SSDHASH] = "DRAM_PMEM_SSDHASH"
name_dict[config_pb2.StorageType.HBM_DRAM_SSDHASH] = "HBM_DRAM_SSDHASH"
return "[Storage type: {type}".format(type=name_dict[self.storage_type]) + \
", storage path: {path}".format(path=self.storage_path) + \
", storage_size: {size}]".format(size=self.storage_size)

@tf_export(v1=["EmbeddingVariableOption"])
class EmbeddingVariableOption(object):
Expand All @@ -267,12 +291,23 @@ def __init__(self,
self.ckpt = ckpt
self.filter_strategy = filter_option
self.storage_option = storage_option
self.init = init_option
self.init = init_option

def __repr__(self):
return "ht_type: {type}".format(type=self.ht_type) + \
", ht_partition_num: {num}".format(num=self.ht_partition_num) + \
", evict_option: {evict}".format(evict=self.evict) + \
", filter_strategy: {filter}".format(filter=self.filter_strategy) + \
", storage_option: {storage}".format(storage=self.storage_option) + \
", init_option: {init}".format(init=self.init)


@tf_export(v1=["CounterFilter"])
class CounterFilter(object):
def __init__(self, filter_freq = 0):
self.filter_freq = filter_freq
def __repr__(self):
return "[filter_type: CounterFilter, filter_freq: {freq}]".format(freq=self.filter_freq)

@tf_export(v1=["CBFFilter"])
class CBFFilter(object):
Expand All @@ -294,6 +329,11 @@ def __init__(self,
self.false_positive_probability = false_positive_probability
self.counter_type = counter_type
self.filter_freq = filter_freq
def __repr__(self):
return "[filter_type: CBFFilter, filter_freq: {freq}".format(freq=self.filter_freq) + \
", max_element_size: {size}".format(size=self.max_element_size) + \
", false_positive_probability: {prob}".format(prob=self.false_positive_probability) + \
", self.counter_type: {type}]".format(type=self.counter_type)

class EmbeddingVariableConfig(object):
def __init__(self,
Expand Down