From 21d12604354955584be5b38e3315df97203d4891 Mon Sep 17 00:00:00 2001 From: Shuai Yang Date: Thu, 16 Jan 2025 19:55:30 -0800 Subject: [PATCH] Fix specialization issue in keyed_jagged_index_select_dim1_forward_cuda (#3578) Summary: X-link: https://github.com/facebookresearch/FBGEMM/pull/664 Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/3578 `lengths` is a tensor with symbolic shapes. Calling `len` on it will force specialization on it which will cause data dependent failure as shown below: {F1974383976} tlparse: https://fburl.com/74rjmr8e The fix is to replace `len` with equivalent operations which support symbolic shapes. Reviewed By: TroyGarden Differential Revision: D67491452 fbshipit-source-id: ed2207b310697d774a284f296c8d34ca2da61adc --- fbgemm_gpu/fbgemm_gpu/sparse_ops.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/fbgemm_gpu/fbgemm_gpu/sparse_ops.py b/fbgemm_gpu/fbgemm_gpu/sparse_ops.py index 730322ef7..01fa577bb 100644 --- a/fbgemm_gpu/fbgemm_gpu/sparse_ops.py +++ b/fbgemm_gpu/fbgemm_gpu/sparse_ops.py @@ -900,13 +900,14 @@ def keyed_jagged_index_select_dim1_forward_cuda_impl_abstract( lengths: torch.Tensor, offsets: torch.Tensor, indices: torch.Tensor, - batch_size: int, + batch_size: torch.SymInt, weights: Optional[torch.Tensor] = None, - selected_lengths_sum: Optional[int] = None, + selected_lengths_sum: Optional[torch.SymInt] = None, ) -> List[torch.Tensor]: - num_batches = len(lengths) // batch_size - torch._check(len(lengths) + 1 == len(offsets)) - torch._check(len(lengths) % batch_size == 0) + num_batches = lengths.size(0) // batch_size + torch._check(lengths.size(0) + 1 == offsets.size(0)) + # pyre-ignore + torch._check(lengths.size(0) % batch_size == 0) if weights is not None: # weights must have the same shape as values