Skip to content

Commit 3550545

Browse files
EddyLXJfacebook-github-bot
authored andcommitted
Change kvzch_eviction_tbe_config to kvzch_tbe_config (#3514)
Summary: X-link: pytorch/FBGEMM#5084 X-link: facebookresearch/FBGEMM#2092 Change tbe config name from kvzch_eviction_tbe_config to kvzch_tbe_config, as it may use not only for eviction but also for some other processes like st publish. Reviewed By: emlin Differential Revision: D86212643
1 parent 0898642 commit 3550545

File tree

2 files changed

+10
-12
lines changed

2 files changed

+10
-12
lines changed

torchrec/distributed/batched_embedding_kernel.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -262,10 +262,8 @@ def _populate_ssd_tbe_params(config: GroupedEmbeddingConfig) -> Dict[str, Any]:
262262
)
263263
ssd_tbe_params["cache_sets"] = int(max_cache_sets)
264264

265-
if "kvzch_eviction_tbe_config" in fused_params and config.is_using_virtual_table():
266-
ssd_tbe_params["kvzch_eviction_tbe_config"] = fused_params.get(
267-
"kvzch_eviction_tbe_config"
268-
)
265+
if "kvzch_tbe_config" in fused_params and config.is_using_virtual_table():
266+
ssd_tbe_params["kvzch_tbe_config"] = fused_params.get("kvzch_tbe_config")
269267

270268
ssd_tbe_params["table_names"] = [table.name for table in config.embedding_tables]
271269

@@ -359,10 +357,10 @@ def _populate_zero_collision_tbe_params(
359357
l2_cache_size = tbe_params["l2_cache_size"]
360358

361359
assert (
362-
"kvzch_eviction_tbe_config" in tbe_params
363-
), "kvzch_eviction_tbe_config should be in tbe_params"
364-
eviction_tbe_config = tbe_params["kvzch_eviction_tbe_config"]
365-
tbe_params.pop("kvzch_eviction_tbe_config")
360+
"kvzch_tbe_config" in tbe_params
361+
), "kvzch_tbe_config should be in tbe_params"
362+
eviction_tbe_config = tbe_params["kvzch_tbe_config"]
363+
tbe_params.pop("kvzch_tbe_config")
366364
eviction_trigger_mode = eviction_tbe_config.kvzch_eviction_trigger_mode
367365
eviction_free_mem_threshold_gb = (
368366
eviction_tbe_config.eviction_free_mem_threshold_gb

torchrec/distributed/types.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333
from fbgemm_gpu.split_table_batched_embeddings_ops_common import (
3434
BoundsCheckMode,
3535
CacheAlgorithm,
36-
KVZCHEvictionTBEConfig,
36+
KVZCHTBEConfig,
3737
MultiPassPrefetchConfig,
3838
)
3939

@@ -668,7 +668,7 @@ class KeyValueParams:
668668
lazy_bulk_init_enabled: bool: whether to enable lazy(async) bulk init for SSD TBE
669669
enable_raw_embedding_streaming: Optional[bool]: enable raw embedding streaming for SSD TBE
670670
res_store_shards: Optional[int] = None: the number of shards to store the raw embeddings
671-
kvzch_eviction_tbe_config: Optional[KVZCHEvictionTBEConfig]: KVZCH eviction config for TBE
671+
kvzch_tbe_config: Optional[KVZCHTBEConfig]: KVZCH config for TBE
672672
673673
# Parameter Server (PS) Attributes
674674
ps_hosts (Optional[Tuple[Tuple[str, int]]]): List of PS host ip addresses
@@ -694,7 +694,7 @@ class KeyValueParams:
694694
None # enable raw embedding streaming for SSD TBE
695695
)
696696
res_store_shards: Optional[int] = None # shards to store the raw embeddings
697-
kvzch_eviction_tbe_config: Optional[KVZCHEvictionTBEConfig] = None
697+
kvzch_tbe_config: Optional[KVZCHTBEConfig] = None
698698

699699
# Parameter Server (PS) Attributes
700700
ps_hosts: Optional[Tuple[Tuple[str, int], ...]] = None
@@ -723,7 +723,7 @@ def __hash__(self) -> int:
723723
self.lazy_bulk_init_enabled,
724724
self.enable_raw_embedding_streaming,
725725
self.res_store_shards,
726-
self.kvzch_eviction_tbe_config,
726+
self.kvzch_tbe_config,
727727
)
728728
)
729729

0 commit comments

Comments
 (0)