
Commit 8f3ed1a

Ankang Liu authored and meta-codesync[bot] committed

Revert D85017179: enable feature score auto collection in EBC

Differential Revision: D85017179
Original commit changeset: 3d62f8adbe20
Original Phabricator Diff: D85017179
fbshipit-source-id: 5a980f997b521920c72f2d028d7bf47338b170ea
1 parent f758d08 commit 8f3ed1a

File tree: 6 files changed (+8, -339 lines)


torchrec/distributed/batched_embedding_kernel.py (1 addition, 39 deletions)

@@ -326,8 +326,6 @@ def _populate_zero_collision_tbe_params(
         meta_header_lens[i] = table.virtual_table_eviction_policy.get_meta_header_len()
         if not isinstance(table.virtual_table_eviction_policy, NoEvictionPolicy):
             enabled = True
-
-    fs_eviction_enabled: bool = False
     if enabled:
         counter_thresholds = [0] * len(config.embedding_tables)
         ttls_in_mins = [0] * len(config.embedding_tables)

@@ -386,7 +384,6 @@ def _populate_zero_collision_tbe_params(
                     raise ValueError(
                         f"Do not support multiple eviction strategy in one tbe {eviction_strategy} and 5 for tables {table_names}"
                     )
-                fs_eviction_enabled = True
             elif isinstance(policy_t, TimestampBasedEvictionPolicy):
                 training_id_eviction_trigger_count[i] = (
                     policy_t.training_id_eviction_trigger_count

@@ -462,7 +459,6 @@ def _populate_zero_collision_tbe_params(
         backend_return_whole_row=(backend_type == BackendType.DRAM),
         eviction_policy=eviction_policy,
         embedding_cache_mode=embedding_cache_mode_,
-        feature_score_collection_enabled=fs_eviction_enabled,
     )


@@ -2905,7 +2901,6 @@ def __init__(
         _populate_zero_collision_tbe_params(
             ssd_tbe_params, self._bucket_spec, config, backend_type
         )
-        self._kv_zch_params: KVZCHParams = ssd_tbe_params["kv_zch_params"]
         compute_kernel = config.embedding_tables[0].compute_kernel
         embedding_location = compute_kernel_to_embedding_location(compute_kernel)

@@ -3190,40 +3185,7 @@ def forward(self, features: KeyedJaggedTensor) -> torch.Tensor:
             self._split_weights_res = None
             self._optim.set_sharded_embedding_weight_ids(sharded_embedding_weight_ids=None)

-        weights = features.weights_or_none()
-        per_sample_weights = None
-        score_weights = None
-        if weights is not None and weights.dtype == torch.float64:
-            fp32_weights = weights.view(torch.float32)
-            per_sample_weights = fp32_weights[:, 0]
-            score_weights = fp32_weights[:, 1]
-        elif weights is not None and weights.dtype == torch.float32:
-            if self._kv_zch_params.feature_score_collection_enabled:
-                score_weights = weights.view(-1)
-            else:
-                per_sample_weights = weights.view(-1)
-        if features.variable_stride_per_key() and isinstance(
-            self.emb_module,
-            (
-                SplitTableBatchedEmbeddingBagsCodegen,
-                DenseTableBatchedEmbeddingBagsCodegen,
-                SSDTableBatchedEmbeddingBags,
-            ),
-        ):
-            return self.emb_module(
-                indices=features.values().long(),
-                offsets=features.offsets().long(),
-                weights=score_weights,
-                per_sample_weights=per_sample_weights,
-                batch_size_per_feature_per_rank=features.stride_per_key_per_rank(),
-            )
-        else:
-            return self.emb_module(
-                indices=features.values().long(),
-                offsets=features.offsets().long(),
-                weights=score_weights,
-                per_sample_weights=per_sample_weights,
-            )
+        return super().forward(features)


 class BatchedFusedEmbeddingBag(
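For context on the forward() block removed above: it unpacked float64-typed KJT weights into separate per-sample weights and feature scores, which suggests the two float32 streams were bit-packed into one float64 column upstream. Below is a minimal, self-contained sketch of that reinterpretation round trip in plain torch; the tensor names are illustrative and not a TorchRec API.

import torch

# Illustrative stand-ins: one per-sample weight and one feature score per row,
# shaped (N, 1) so they can be packed column-wise.
per_sample_weights = torch.tensor([[0.5], [1.0], [2.0]], dtype=torch.float32)
score_weights = torch.tensor([[3.0], [4.0], [5.0]], dtype=torch.float32)

# Pack: interleave the two float32 columns and reinterpret each adjacent pair
# of float32 values as one float64 value -> shape (N, 1), dtype float64.
packed = torch.cat([per_sample_weights, score_weights], dim=1).view(torch.float64)

# Unpack (what the removed kernel forward() did for float64 weights):
# reinterpret back to float32 -> shape (N, 2); column 0 is the per-sample
# weight, column 1 is the feature score.
fp32_weights = packed.view(torch.float32)
assert torch.equal(fp32_weights[:, 0], per_sample_weights.view(-1))
assert torch.equal(fp32_weights[:, 1], score_weights.view(-1))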

torchrec/distributed/embedding_lookup.py (4 additions, 35 deletions)

@@ -66,7 +66,6 @@
     QuantBatchedEmbeddingBag,
 )
 from torchrec.distributed.types import rank_device, ShardedTensor, ShardingType
-from torchrec.modules.embedding_configs import FeatureScoreBasedEvictionPolicy
 from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

 logger: logging.Logger = logging.getLogger(__name__)

@@ -516,23 +515,6 @@ def __init__(
     ) -> None:
         super().__init__()
         self._emb_modules: nn.ModuleList = nn.ModuleList()
-        self._feature_score_auto_collections: List[bool] = []
-        for config in grouped_configs:
-            collection = False
-            for table in config.embedding_tables:
-                if table.use_virtual_table and isinstance(
-                    table.virtual_table_eviction_policy, FeatureScoreBasedEvictionPolicy
-                ):
-                    if (
-                        table.virtual_table_eviction_policy.enable_auto_feature_score_collection
-                    ):
-                        collection = True
-            self._feature_score_auto_collections.append(collection)
-
-        logger.info(
-            f"GroupedPooledEmbeddingsLookup: {self._feature_score_auto_collections=}"
-        )
-
         for config in grouped_configs:
             self._emb_modules.append(
                 self._create_embedding_kernel(config, device, pg, sharding_type)

@@ -710,11 +692,8 @@ def forward(
         features_by_group = sparse_features.split(
             self._feature_splits,
         )
-        for config, emb_op, features, fs_auto_collection in zip(
-            self.grouped_configs,
-            self._emb_modules,
-            features_by_group,
-            self._feature_score_auto_collections,
+        for config, emb_op, features in zip(
+            self.grouped_configs, self._emb_modules, features_by_group
         ):
             if (
                 config.has_feature_processor

@@ -724,19 +703,9 @@ def forward(
                 features = self._feature_processor(features)

             if config.is_weighted:
-                feature_weights = CommOpGradientScaling.apply(
+                features._weights = CommOpGradientScaling.apply(
                     features._weights, self._scale_gradient_factor
-                ).float()
-
-                if fs_auto_collection and features.weights_or_none() is not None:
-                    score_weights = features.weights().float()
-                    assert (
-                        feature_weights.numel() == score_weights.numel()
-                    ), f"feature_weights.numel() {feature_weights.numel()} != score_weights.numel() {score_weights.numel()}"
-                    cat_weights = torch.cat(
-                        [feature_weights, score_weights], dim=1
-                    ).view(torch.float64)
-                    features._weights = cat_weights
+                )

             lookup = emb_op(features)
             embeddings.append(lookup)
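For reference, the __init__ logic deleted in this file computed one boolean per grouped config: whether any virtual table in that group has a FeatureScoreBasedEvictionPolicy with automatic feature score collection enabled. The sketch below isolates that decision using hypothetical stand-in config classes, not the real TorchRec types, and drops the isinstance check on the concrete policy class.

from dataclasses import dataclass, field
from typing import List, Optional


@dataclass
class EvictionPolicyStub:  # stand-in for FeatureScoreBasedEvictionPolicy
    enable_auto_feature_score_collection: bool = False


@dataclass
class TableStub:  # stand-in for an embedding table config
    use_virtual_table: bool = False
    virtual_table_eviction_policy: Optional[EvictionPolicyStub] = None


@dataclass
class GroupedConfigStub:  # stand-in for a grouped embedding config
    embedding_tables: List[TableStub] = field(default_factory=list)


def feature_score_auto_collections(groups: List[GroupedConfigStub]) -> List[bool]:
    # One flag per grouped config: True if any virtual table in the group
    # opts into automatic feature score collection (mirrors the removed loop).
    return [
        any(
            t.use_virtual_table
            and t.virtual_table_eviction_policy is not None
            and t.virtual_table_eviction_policy.enable_auto_feature_score_collection
            for t in g.embedding_tables
        )
        for g in groups
    ]


# Example: the first group collects scores, the second does not.
groups = [
    GroupedConfigStub([TableStub(True, EvictionPolicyStub(True))]),
    GroupedConfigStub([TableStub(True, EvictionPolicyStub(False))]),
]
assert feature_score_auto_collections(groups) == [True, False]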

torchrec/distributed/embeddingbag.py (0 additions, 27 deletions)

@@ -51,10 +51,6 @@
     KJTList,
     ShardedEmbeddingModule,
 )
-from torchrec.distributed.feature_score_utils import (
-    create_sharding_type_to_feature_score_mapping,
-    may_collect_feature_scores,
-)
 from torchrec.distributed.fused_params import (
     FUSED_PARAM_IS_SSD_TABLE,
     FUSED_PARAM_SSD_TABLE_LIST,

@@ -569,24 +565,6 @@ def __init__(
         # forward pass flow control
         self._has_uninitialized_input_dist: bool = True
         self._has_features_permute: bool = True
-
-        self._enable_feature_score_weight_accumulation: bool = False
-        self._enabled_feature_score_auto_collection: bool = False
-        self._sharding_type_feature_score_mapping: Dict[str, Dict[str, float]] = {}
-        (
-            self._enable_feature_score_weight_accumulation,
-            self._enabled_feature_score_auto_collection,
-            self._sharding_type_feature_score_mapping,
-        ) = create_sharding_type_to_feature_score_mapping(
-            self._embedding_bag_configs, self.sharding_type_to_sharding_infos
-        )
-
-        logger.info(
-            f"EBC feature score weight accumulation enabled: {self._enable_feature_score_weight_accumulation}, "
-            f"auto collection enabled: {self._enabled_feature_score_auto_collection}, "
-            f"sharding type to feature score mapping: {self._sharding_type_feature_score_mapping}"
-        )
-
         # Get all fused optimizers and combine them.
         optims = []
         for lookup in self._lookups:

@@ -1587,11 +1565,6 @@ def input_dist(
         features_by_shards = features.split(
             self._feature_splits,
         )
-        features_by_shards = may_collect_feature_scores(
-            features_by_shards,
-            self._enabled_feature_score_auto_collection,
-            self._sharding_type_feature_score_mapping,
-        )
         awaitables = []
         for input_dist, features_by_shard, sharding_type in zip(
             self._input_dists,

torchrec/distributed/feature_score_utils.py (2 additions, 2 deletions)

@@ -17,7 +17,7 @@
 from torchrec.distributed.embedding_types import ShardingType

 from torchrec.modules.embedding_configs import (
-    BaseEmbeddingConfig,
+    EmbeddingConfig,
     FeatureScoreBasedEvictionPolicy,
 )
 from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

@@ -26,7 +26,7 @@


 def create_sharding_type_to_feature_score_mapping(
-    embedding_configs: Sequence[BaseEmbeddingConfig],
+    embedding_configs: Sequence[EmbeddingConfig],
     sharding_type_to_sharding_infos: Dict[str, List[EmbeddingShardingInfo]],
 ) -> Tuple[bool, bool, Dict[str, Dict[str, float]]]:
     enable_feature_score_weight_accumulation = False
