@@ -2600,7 +2600,71 @@ def init_parameters(self) -> None:
                 weight_init_max,
             )
 
-    def forward(self, features: KeyedJaggedTensor) -> torch.Tensor:
+    def _get_hash_zch_identities(
+        self, features: KeyedJaggedTensor
+    ) -> Optional[torch.Tensor]:
+        if self._raw_id_tracker_wrapper is None or not isinstance(
+            self.emb_module, SplitTableBatchedEmbeddingBagsCodegen
+        ):
+            return None
+
+        raw_id_tracker_wrapper = self._raw_id_tracker_wrapper
+        assert (
+            raw_id_tracker_wrapper is not None
+        ), "self._raw_id_tracker_wrapper should not be None"
+        assert hasattr(
+            self.emb_module, "res_params"
+        ), "res_params should exist when raw_id_tracker is enabled"
+        res_params: RESParams = self.emb_module.res_params  # pyre-ignore[9]
+        table_names = res_params.table_names
+
+        # TODO: get_indexed_lookups() may return multiple IndexedLookup objects
+        # across multiple training iterations. The current logic appends raw_ids
+        # from all batches sequentially, which may cause misalignment with
+        # features.values(), which only contains the current batch.
+        raw_ids_dict = raw_id_tracker_wrapper.get_indexed_lookups(
+            table_names, self.emb_module.uuid
+        )
+
+        # Build hash_zch_identities by concatenating raw IDs from the tracked tables.
+        # The output maintains 1-to-1 alignment with features.values().
+        # Iterate through table_names explicitly (not raw_ids_dict.values()) to
+        # ensure correct ordering, since there is no guarantee on dict ordering.
+        #
+        # E.g. if features.values() = [f1_val1, f1_val2, f2_val1, f2_val2, ...],
+        # where table1 has [feature1, feature2] and table2 has [feature3, feature4],
+        # then hash_zch_identities = [f1_id1, f1_id2, f2_id1, f2_id2, ...].
+        #
+        # TODO: Handle tables without identity tracking. Currently, only tables with
+        # raw_ids are included. If some tables lack identities while others have them,
+        # padding with -1 may be needed to maintain alignment.
+        all_raw_ids = []
+        for table_name in table_names:
+            if table_name in raw_ids_dict:
+                raw_ids_list = raw_ids_dict[table_name]
+                for raw_ids in raw_ids_list:
+                    all_raw_ids.append(raw_ids)
+
+        if not all_raw_ids:
+            return None
+
+        hash_zch_identities = torch.cat(all_raw_ids)
+        assert hash_zch_identities.size(0) == features.values().numel(), (
+            f"hash_zch_identities row count ({hash_zch_identities.size(0)}) must match "
+            f"features.values() length ({features.values().numel()}) to maintain 1-to-1 alignment"
+        )
+
+        return hash_zch_identities
+
+    def forward(
+        self,
+        features: KeyedJaggedTensor,
+    ) -> torch.Tensor:
+        forward_args: Dict[str, Any] = {}
+        hash_zch_identities = self._get_hash_zch_identities(features)
+        if hash_zch_identities is not None:
+            forward_args["hash_zch_identities"] = hash_zch_identities
+
         weights = features.weights_or_none()
         if weights is not None and not torch.is_floating_point(weights):
             weights = None
@@ -2612,17 +2676,22 @@ def forward(self, features: KeyedJaggedTensor) -> torch.Tensor:
                 SSDTableBatchedEmbeddingBags,
             ),
         ):
+            forward_args["batch_size_per_feature_per_rank"] = (
+                features.stride_per_key_per_rank()
+            )
+
+        if len(forward_args) == 0:
             return self.emb_module(
                 indices=features.values().long(),
                 offsets=features.offsets().long(),
                 per_sample_weights=weights,
-                batch_size_per_feature_per_rank=features.stride_per_key_per_rank(),
             )
         else:
             return self.emb_module(
                 indices=features.values().long(),
                 offsets=features.offsets().long(),
                 per_sample_weights=weights,
+                **forward_args,
             )
 
     # pyre-fixme[14]: `state_dict` overrides method defined in `Module` inconsistently.
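The new forward path relies on two small patterns that can be checked in isolation: per-table raw-ID tensors are concatenated in table_names order so they stay aligned 1-to-1 with the flattened feature values, and optional keyword arguments are collected in a forward_args dict that is only unpacked when non-empty. Below is a minimal standalone sketch of those patterns with toy tensors; build_identities, toy_emb_lookup, and the sample data are illustrative stand-ins, not the actual torchrec modules or raw-ID tracker API.

# Minimal sketch (assumed toy inputs, not the real torchrec types): concatenate
# per-table raw IDs in a fixed table order, then dispatch with optional kwargs.
from typing import Any, Dict, List, Optional

import torch


def build_identities(
    table_names: List[str],
    raw_ids_dict: Dict[str, List[torch.Tensor]],
    num_values: int,
) -> Optional[torch.Tensor]:
    # Iterate table_names (not the dict) so the concatenation order is deterministic.
    all_raw_ids: List[torch.Tensor] = []
    for table_name in table_names:
        for raw_ids in raw_ids_dict.get(table_name, []):
            all_raw_ids.append(raw_ids)
    if not all_raw_ids:
        return None
    identities = torch.cat(all_raw_ids)
    # Keep the 1-to-1 alignment invariant with the flattened feature values.
    assert identities.size(0) == num_values, "identities must align 1-to-1 with values"
    return identities


def toy_emb_lookup(indices: torch.Tensor, **kwargs: Any) -> torch.Tensor:
    # Stand-in for self.emb_module(...); just reports which extra kwargs arrived.
    print("extra kwargs:", sorted(kwargs))
    return indices.float()


if __name__ == "__main__":
    values = torch.tensor([11, 12, 21, 22])  # flattened feature values, 4 entries
    raw_ids_dict = {
        "table1": [torch.tensor([101, 102])],
        "table2": [torch.tensor([201, 202])],
    }

    forward_args: Dict[str, Any] = {}
    identities = build_identities(["table1", "table2"], raw_ids_dict, values.numel())
    if identities is not None:
        forward_args["hash_zch_identities"] = identities

    # Only unpack the optional kwargs when something was collected.
    if len(forward_args) == 0:
        toy_emb_lookup(values)
    else:
        toy_emb_lookup(values, **forward_args)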