From 243055a14182cbf57466bd5451fd88e64cc83dd4 Mon Sep 17 00:00:00 2001 From: Yixin Bao Date: Wed, 29 Oct 2025 19:14:23 -0700 Subject: [PATCH] Create always on mask tensor based on remapped lengths Differential Revision: D83387941 --- torchrec/modules/mc_embedding_modules.py | 13 +++++++++++++ .../modules/tests/test_mc_embedding_modules.py | 18 ++++++++++++++++++ 2 files changed, 31 insertions(+) diff --git a/torchrec/modules/mc_embedding_modules.py b/torchrec/modules/mc_embedding_modules.py index 6e7850dba..129da1e69 100644 --- a/torchrec/modules/mc_embedding_modules.py +++ b/torchrec/modules/mc_embedding_modules.py @@ -93,6 +93,19 @@ def forward( return embedding_res, None return embedding_res, features + def lookup_remapped_lengths_mask( + self, + features: KeyedJaggedTensor, + ) -> torch.Tensor: + features = self._managed_collision_collection(features) + remapped_lengths = return_remapped_lengths_as_mask(features) + return remapped_lengths + + +@torch.fx.wrap +def return_remapped_lengths_as_mask(features: KeyedJaggedTensor) -> torch.Tensor: + return features.lengths().to(torch.bool) + class ManagedCollisionEmbeddingCollection(BaseManagedCollisionEmbeddingCollection): """ diff --git a/torchrec/modules/tests/test_mc_embedding_modules.py b/torchrec/modules/tests/test_mc_embedding_modules.py index 58c4fa466..cdb296e6b 100644 --- a/torchrec/modules/tests/test_mc_embedding_modules.py +++ b/torchrec/modules/tests/test_mc_embedding_modules.py @@ -21,6 +21,7 @@ from torchrec.modules.mc_embedding_modules import ( ManagedCollisionEmbeddingBagCollection, ManagedCollisionEmbeddingCollection, + return_remapped_lengths_as_mask, ) from torchrec.modules.mc_modules import ( DistanceLFU_EvictionPolicy, @@ -409,3 +410,20 @@ def test_mc_collection_traceable(self) -> None: ) mcc.train(False) symbolic_trace(mcc, leaf_modules=[ComputeJTDictToKJT.__name__]) + + def test_return_remapped_lengths_as_mask(self) -> None: + mask = return_remapped_lengths_as_mask( + KeyedJaggedTensor( + keys=["f0"], + values=torch.rand(6), + lengths=torch.tensor([1, 0, 1, 0, 1, 0, 0, 1, 1, 1], dtype=torch.int64), + ) + ) + self.assertTrue( + torch.equal( + mask, + torch.tensor( + [True, False, True, False, True, False, False, True, True, True] + ), + ) + )