From 7b54824acd5656bdd8eb192e9ae8a54e24158939 Mon Sep 17 00:00:00 2001 From: lxy268263 Date: Fri, 10 Jun 2022 18:01:57 +0800 Subject: [PATCH] [Embedding] Support print EmbeddingVariableOption. --- tensorflow/python/ops/variables.py | 42 +++++++++++++++++++++++++++++- 1 file changed, 41 insertions(+), 1 deletion(-) diff --git a/tensorflow/python/ops/variables.py b/tensorflow/python/ops/variables.py index 99f702ca9e7..597b51fe870 100644 --- a/tensorflow/python/ops/variables.py +++ b/tensorflow/python/ops/variables.py @@ -186,6 +186,9 @@ def __init__(self, if default_value_dim <=0: print("default value dim must larger than 1, the default value dim is set to default 4096.") default_value_dim = 4096 + def __repr__(self): + return "[Initializier: {init}".format(init=self.initializer) + \ + ", default_value_dim: {dim}]".format(dim=self.default_value_dim) class MultihashOption(object): def __init__(self, @@ -203,6 +206,8 @@ class GlobalStepEvict(object): def __init__(self, steps_to_live = None): self.steps_to_live = steps_to_live + def __repr__(self): + return "[Eviction type: GlobalStepEvict, steps_to_live: {step}]".format(step=self.steps_to_live) @tf_export(v1=["L2WeightEvict"]) class L2WeightEvict(object): @@ -211,6 +216,8 @@ def __init__(self, self.l2_weight_threshold = l2_weight_threshold if l2_weight_threshold <= 0 and l2_weight_threshold != -1.0: logging.warning("l2_weight_threshold is invalid, l2_weight-based eviction is disabled") + def __repr__(self): + return "[Eviction type: L2WeightEvict, l2_weight_threshold: {threshold}]".format(threshold=self.l2_weight_threshold) @tf_export(v1=["CheckpointOption"]) class CheckpointOption(object): @@ -250,6 +257,23 @@ def __init__(self, config_pb2.StorageType.DRAM_SSDHASH, config_pb2.StorageType.DRAM_LEVELDB]: raise ValueError("storage_path musnt'be None when storage_type is set") + def __repr__(self): + name_dict = {} + name_dict[None] = "None" + name_dict[config_pb2.StorageType.DRAM] = "DRAM" + name_dict[config_pb2.StorageType.PMEM_LIBPMEM] = "PMEM_LIBPMEM" + name_dict[config_pb2.StorageType.PMEM_MEMKIND] = "PMEM_MEMKIND" + name_dict[config_pb2.StorageType.LEVELDB] = "LEVELDB" + name_dict[config_pb2.StorageType.SSDHASH] = "SSDHASH" + name_dict[config_pb2.StorageType.DRAM_LEVELDB] = "DRAM_LEVELDB" + name_dict[config_pb2.StorageType.PMEM_MEMKIND] = "HBM_DRAM" + name_dict[config_pb2.StorageType.DRAM_SSDHASH] = "DRAM_SSDHASH" + name_dict[config_pb2.StorageType.DRAM_PMEM] = "DRAM_PMEM" + name_dict[config_pb2.StorageType.DRAM_PMEM_SSDHASH] = "DRAM_PMEM_SSDHASH" + name_dict[config_pb2.StorageType.HBM_DRAM_SSDHASH] = "HBM_DRAM_SSDHASH" + return "[Storage type: {type}".format(type=name_dict[self.storage_type]) + \ + ", storage path: {path}".format(path=self.storage_path) + \ + ", storage_size: {size}]".format(size=self.storage_size) @tf_export(v1=["EmbeddingVariableOption"]) class EmbeddingVariableOption(object): @@ -267,12 +291,23 @@ def __init__(self, self.ckpt = ckpt self.filter_strategy = filter_option self.storage_option = storage_option - self.init = init_option + self.init = init_option + + def __repr__(self): + return "ht_type: {type}".format(type=self.ht_type) + \ + ", ht_partition_num: {num}".format(num=self.ht_partition_num) + \ + ", evict_option: {evict}".format(evict=self.evict) + \ + ", filter_strategy: {filter}".format(filter=self.filter_strategy) + \ + ", storage_option: {storage}".format(storage=self.storage_option) + \ + ", init_option: {init}".format(init=self.init) + @tf_export(v1=["CounterFilter"]) class CounterFilter(object): def __init__(self, filter_freq = 0): self.filter_freq = filter_freq + def __repr__(self): + return "[filter_type: CounterFilter, filter_freq: {freq}]".format(freq=self.filter_freq) @tf_export(v1=["CBFFilter"]) class CBFFilter(object): @@ -294,6 +329,11 @@ def __init__(self, self.false_positive_probability = false_positive_probability self.counter_type = counter_type self.filter_freq = filter_freq + def __repr__(self): + return "[filter_type: CBFFilter, filter_freq: {freq}".format(freq=self.filter_freq) + \ + ", max_element_size: {size}".format(size=self.max_element_size) + \ + ", false_positive_probability: {prob}".format(prob=self.false_positive_probability) + \ + ", self.counter_type: {type}]".format(type=self.counter_type) class EmbeddingVariableConfig(object): def __init__(self,