Skip to content

Commit 1076264

Browse files
Joey Yang and facebook-github-bot
authored and committed
Look up and forward raw_ids from tracker to TBE from embedding module
Summary: We look up the identities corresponding to the slot index from `raw_id_tracker` and forward them to TBE. The identities are stored in `raw_id_tracker` during the `mc_module` lookup. `raw_ids` are retrieved per-table from `raw_id_tracker` with the API `get_indexed_lookups()`, and we concatenate them to maintain 1-to-1 alignment with `features.values()`. Reviewed By: chouxi Differential Revision: D86242001
1 parent 0898642 commit 1076264

File tree

1 file changed

+71
-2
lines changed

1 file changed

+71
-2
lines changed

torchrec/distributed/batched_embedding_kernel.py

Lines changed: 71 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2600,7 +2600,71 @@ def init_parameters(self) -> None:
26002600
weight_init_max,
26012601
)
26022602

2603-
def forward(self, features: KeyedJaggedTensor) -> torch.Tensor:
2603+
def _get_hash_zch_identities(
    self, features: KeyedJaggedTensor
) -> Optional[torch.Tensor]:
    """Collect tracked raw-id identities aligned with ``features.values()``.

    Raw ids are recorded in the raw-id tracker during the ``mc_module``
    lookup; this method retrieves them per table via ``get_indexed_lookups()``
    and concatenates them so that row ``i`` of the result corresponds to
    element ``i`` of ``features.values()``.

    Args:
        features: the input KJT whose values the identities must align with.

    Returns:
        A tensor of concatenated raw ids, or ``None`` when no tracker is
        attached, the kernel is not a
        ``SplitTableBatchedEmbeddingBagsCodegen``, or no raw ids were
        recorded for any tracked table.
    """
    # Only the split-TBE kernel consumes hash_zch_identities; bail out when
    # tracking is disabled or the kernel type does not support it.
    if self._raw_id_tracker_wrapper is None or not isinstance(
        self.emb_module, SplitTableBatchedEmbeddingBagsCodegen
    ):
        return None

    # NOTE: the guard above already excludes a None tracker, so no re-assert
    # is needed here.
    assert hasattr(
        self.emb_module, "res_params"
    ), "res_params should exist when raw_id_tracker is enabled"
    res_params: RESParams = self.emb_module.res_params  # pyre-ignore[9]
    table_names = res_params.table_names

    # TODO: get_indexed_lookups() may return multiple IndexedLookup objects
    # across multiple training iterations. Current logic appends raw_ids from
    # all batches sequentially. This may cause misalignment with
    # features.values() which only contains the current batch.
    raw_ids_dict = self._raw_id_tracker_wrapper.get_indexed_lookups(
        table_names, self.emb_module.uuid
    )

    # Build hash_zch_identities by concatenating raw IDs from tracked tables.
    # Output maintains 1-to-1 alignment with features.values().
    # Iterate through table_names explicitly (not raw_ids_dict.values()) to
    # ensure correct ordering, since there is no guarantee on dict ordering.
    #
    # E.g. If features.values() = [f1_val1, f1_val2, f2_val1, f2_val2, ...]
    # where table1 has [feature1, feature2] and table2 has [feature3, feature4]
    # then hash_zch_identities = [f1_id1, f1_id2, f2_id1, f2_id2, ...]
    #
    # TODO: Handle tables without identity tracking. Currently, only tables
    # with raw_ids are included. If some tables lack identity while others
    # have them, padding with -1 may be needed to maintain alignment.
    all_raw_ids = []
    for table_name in table_names:
        # Tables with no tracked lookups simply contribute nothing.
        all_raw_ids.extend(raw_ids_dict.get(table_name, ()))

    if not all_raw_ids:
        return None

    hash_zch_identities = torch.cat(all_raw_ids)
    assert hash_zch_identities.size(0) == features.values().numel(), (
        f"hash_zch_identities row count ({hash_zch_identities.size(0)}) must match "
        f"features.values() length ({features.values().numel()}) to maintain 1-to-1 alignment"
    )

    return hash_zch_identities
2659+
def forward(
2660+
self,
2661+
features: KeyedJaggedTensor,
2662+
) -> torch.Tensor:
2663+
forward_args: Dict[str, Any] = {}
2664+
hash_zch_identities = self._get_hash_zch_identities(features)
2665+
if hash_zch_identities is not None:
2666+
forward_args["hash_zch_identities"] = hash_zch_identities
2667+
26042668
weights = features.weights_or_none()
26052669
if weights is not None and not torch.is_floating_point(weights):
26062670
weights = None
@@ -2612,17 +2676,22 @@ def forward(self, features: KeyedJaggedTensor) -> torch.Tensor:
26122676
SSDTableBatchedEmbeddingBags,
26132677
),
26142678
):
2679+
forward_args["batch_size_per_feature_per_rank"] = (
2680+
features.stride_per_key_per_rank()
2681+
)
2682+
2683+
if len(forward_args) == 0:
26152684
return self.emb_module(
26162685
indices=features.values().long(),
26172686
offsets=features.offsets().long(),
26182687
per_sample_weights=weights,
2619-
batch_size_per_feature_per_rank=features.stride_per_key_per_rank(),
26202688
)
26212689
else:
26222690
return self.emb_module(
26232691
indices=features.values().long(),
26242692
offsets=features.offsets().long(),
26252693
per_sample_weights=weights,
2694+
**forward_args,
26262695
)
26272696

26282697
# pyre-fixme[14]: `state_dict` overrides method defined in `Module` inconsistently.

0 commit comments

Comments
 (0)