Skip to content

Update KJT stride calculation logic to be based on inverse_indices for VBE KJTs. #2949

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions torchrec/schema/api_tests/test_jagged_tensor_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

import inspect
import unittest
from typing import Dict, List, Optional, Tuple
from typing import Dict, List, Optional, Tuple, Union

import torch
from torchrec.schema.utils import is_signature_compatible
Expand Down Expand Up @@ -112,7 +112,9 @@ def __init__(
lengths: Optional[torch.Tensor] = None,
offsets: Optional[torch.Tensor] = None,
stride: Optional[int] = None,
stride_per_key_per_rank: Optional[List[List[int]]] = None,
stride_per_key_per_rank: Optional[
Union[List[List[int]], torch.IntTensor]
] = None,
# Below exposed to ensure torch.script-able
stride_per_key: Optional[List[int]] = None,
length_per_key: Optional[List[int]] = None,
Expand Down
92 changes: 53 additions & 39 deletions torchrec/sparse/jagged_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1096,13 +1096,19 @@ def _maybe_compute_stride_kjt(
stride: Optional[int],
lengths: Optional[torch.Tensor],
offsets: Optional[torch.Tensor],
stride_per_key_per_rank: Optional[List[List[int]]],
stride_per_key_per_rank: Optional[torch.IntTensor],
inverse_indices: Optional[Tuple[List[str], torch.Tensor]] = None,
) -> int:
if stride is None:
if len(keys) == 0:
stride = 0
elif stride_per_key_per_rank is not None and len(stride_per_key_per_rank) > 0:
stride = max([sum(s) for s in stride_per_key_per_rank])
elif (
stride_per_key_per_rank is not None and stride_per_key_per_rank.numel() > 0
):
# For VBE KJT, use inverse_indices for the batch size of the EBC output KeyedTensor.
if inverse_indices is not None and inverse_indices[1].numel() > 0:
return inverse_indices[1].shape[-1]
stride = int(stride_per_key_per_rank.sum(dim=1).max().item())
elif offsets is not None and offsets.numel() > 0:
stride = (offsets.numel() - 1) // len(keys)
elif lengths is not None:
Expand Down Expand Up @@ -1668,14 +1674,18 @@ def _maybe_compute_lengths_offset_per_key(

def _maybe_compute_stride_per_key(
stride_per_key: Optional[List[int]],
stride_per_key_per_rank: Optional[List[List[int]]],
stride_per_key_per_rank: Optional[torch.IntTensor],
stride: Optional[int],
keys: List[str],
) -> Optional[List[int]]:
if stride_per_key is not None:
return stride_per_key
elif stride_per_key_per_rank is not None:
return [sum(s) for s in stride_per_key_per_rank]
if stride_per_key_per_rank.dim() != 2:
# after permute the kjt could be empty
return []
rt: List[int] = stride_per_key_per_rank.sum(dim=1).tolist()
return rt
elif stride is not None:
return [stride] * len(keys)
else:
Expand Down Expand Up @@ -1766,7 +1776,9 @@ def __init__(
lengths: Optional[torch.Tensor] = None,
offsets: Optional[torch.Tensor] = None,
stride: Optional[int] = None,
stride_per_key_per_rank: Optional[List[List[int]]] = None,
stride_per_key_per_rank: Optional[
Union[torch.IntTensor, List[List[int]]]
] = None,
# Below exposed to ensure torch.script-able
stride_per_key: Optional[List[int]] = None,
length_per_key: Optional[List[int]] = None,
Expand All @@ -1788,8 +1800,10 @@ def __init__(
self._lengths: Optional[torch.Tensor] = lengths
self._offsets: Optional[torch.Tensor] = offsets
self._stride: Optional[int] = stride
self._stride_per_key_per_rank: Optional[List[List[int]]] = (
stride_per_key_per_rank
self._stride_per_key_per_rank: Optional[torch.IntTensor] = (
torch.IntTensor(stride_per_key_per_rank, device="cpu")
if isinstance(stride_per_key_per_rank, list)
else stride_per_key_per_rank
)
self._stride_per_key: Optional[List[int]] = stride_per_key
self._length_per_key: Optional[List[int]] = length_per_key
Expand All @@ -1815,10 +1829,11 @@ def _init_pt2_checks(self) -> None:
return
if self._stride_per_key is not None:
pt2_checks_all_is_size(self._stride_per_key)
if self._stride_per_key_per_rank is not None:
# pyre-ignore [16]
for s in self._stride_per_key_per_rank:
pt2_checks_all_is_size(s)
_stride_per_key_per_rank = self._stride_per_key_per_rank
if _stride_per_key_per_rank is not None:
for stride_per_rank in _stride_per_key_per_rank:
for s in stride_per_rank:
torch._check_is_size(s.item())

@staticmethod
def from_offsets_sync(
Expand Down Expand Up @@ -2028,7 +2043,7 @@ def from_jt_dict(jt_dict: Dict[str, JaggedTensor]) -> "KeyedJaggedTensor":
kjt_stride, kjt_stride_per_key_per_rank = (
(stride_per_key[0], None)
if all(s == stride_per_key[0] for s in stride_per_key)
else (None, [[stride] for stride in stride_per_key])
else (None, torch.IntTensor(stride_per_key, device="cpu").reshape(-1, 1))
)
kjt = KeyedJaggedTensor(
keys=kjt_keys,
Expand Down Expand Up @@ -2165,6 +2180,7 @@ def stride(self) -> int:
self._lengths,
self._offsets,
self._stride_per_key_per_rank,
self._inverse_indices,
)
self._stride = stride
return stride
Expand Down Expand Up @@ -2193,8 +2209,13 @@ def stride_per_key_per_rank(self) -> List[List[int]]:
Returns:
List[List[int]]: stride per key per rank of the KeyedJaggedTensor.
"""
stride_per_key_per_rank = self._stride_per_key_per_rank
return stride_per_key_per_rank if stride_per_key_per_rank is not None else []
# making a local reference to the class variable to make jit.script behave
_stride_per_key_per_rank = self._stride_per_key_per_rank
return (
[]
if _stride_per_key_per_rank is None
else _stride_per_key_per_rank.tolist()
)

def variable_stride_per_key(self) -> bool:
"""
Expand Down Expand Up @@ -2514,17 +2535,17 @@ def permute(

length_per_key = self.length_per_key()
permuted_keys: List[str] = []
permuted_stride_per_key_per_rank: List[List[int]] = []
permuted_length_per_key: List[int] = []
permuted_length_per_key_sum = 0
for index in indices:
key = self.keys()[index]
permuted_keys.append(key)
permuted_length_per_key.append(length_per_key[index])
if self.variable_stride_per_key():
permuted_stride_per_key_per_rank.append(
self.stride_per_key_per_rank()[index]
)
_stride_per_key_per_rank = self._stride_per_key_per_rank
if self.variable_stride_per_key() and _stride_per_key_per_rank is not None:
permuted_stride_per_key_per_rank = _stride_per_key_per_rank[indices, :]
else:
permuted_stride_per_key_per_rank = None

permuted_length_per_key_sum = sum(permuted_length_per_key)
if not torch.jit.is_scripting() and is_non_strict_exporting():
Expand Down Expand Up @@ -2576,17 +2597,15 @@ def permute(
self.weights_or_none(),
permuted_length_per_key_sum,
)
stride_per_key_per_rank = (
permuted_stride_per_key_per_rank if self.variable_stride_per_key() else None
)

kjt = KeyedJaggedTensor(
keys=permuted_keys,
values=permuted_values,
weights=permuted_weights,
lengths=permuted_lengths.view(-1),
offsets=None,
stride=self._stride,
stride_per_key_per_rank=stride_per_key_per_rank,
stride_per_key_per_rank=permuted_stride_per_key_per_rank,
stride_per_key=None,
length_per_key=permuted_length_per_key if len(permuted_keys) > 0 else None,
lengths_offset_per_key=None,
Expand Down Expand Up @@ -2904,7 +2923,7 @@ def dist_init(

if variable_stride_per_key:
assert stride_per_rank_per_key is not None
stride_per_key_per_rank_tensor: torch.Tensor = stride_per_rank_per_key.view(
stride_per_key_per_rank: torch.Tensor = stride_per_rank_per_key.view(
num_workers, len(keys)
).T.cpu()

Expand Down Expand Up @@ -2941,23 +2960,18 @@ def dist_init(
weights,
)

stride_per_key_per_rank = torch.jit.annotate(
List[List[int]], stride_per_key_per_rank_tensor.tolist()
)
if stride_per_key_per_rank.numel() == 0:
stride_per_key_per_rank = torch.zeros(
(len(keys), 1), device="cpu", dtype=torch.int64
)

if not stride_per_key_per_rank:
stride_per_key_per_rank = [[0]] * len(keys)
if stagger > 1:
stride_per_key_per_rank_stagger: List[List[int]] = []
local_world_size = num_workers // stagger
for i in range(len(keys)):
stride_per_rank_stagger: List[int] = []
for j in range(local_world_size):
stride_per_rank_stagger.extend(
stride_per_key_per_rank[i][j::local_world_size]
)
stride_per_key_per_rank_stagger.append(stride_per_rank_stagger)
stride_per_key_per_rank = stride_per_key_per_rank_stagger
indices = [
list(range(i, num_workers, local_world_size))
for i in range(local_world_size)
]
stride_per_key_per_rank = stride_per_key_per_rank[:, indices]

kjt = KeyedJaggedTensor(
keys=keys,
Expand Down
12 changes: 12 additions & 0 deletions torchrec/sparse/tests/test_keyed_jagged_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1017,6 +1017,18 @@ def test_meta_device_compatibility(self) -> None:
lengths=torch.tensor([], device=torch.device("meta")),
)

def test_vbe_kjt_stride(self) -> None:
    """Stride of a variable-batch-size (VBE) KJT comes from inverse_indices.

    When a KeyedJaggedTensor is constructed with both stride_per_key_per_rank
    and inverse_indices, stride() is expected to equal the last dimension of
    the inverse_indices tensor (the batch size of the EBC output KeyedTensor)
    rather than a value derived from stride_per_key_per_rank.
    """
    # inverse_indices maps two keys to a batch of size 3 (the last dim);
    # that 3 is the stride value asserted below.
    inverse_indices = torch.tensor([[0, 1, 0], [0, 0, 0]])
    # NOTE(review): three keys are given but stride_per_key_per_rank has only
    # two rows and inverse_indices covers only ["f1", "f2"] — confirm this
    # key-count mismatch is intentional for the VBE path being tested.
    kjt = KeyedJaggedTensor(
        keys=["f1", "f2", "f3"],
        values=torch.tensor([5, 6, 7, 1, 2, 3, 0, 1]),
        lengths=torch.tensor([3, 3, 2]),
        stride_per_key_per_rank=[[2], [1]],
        inverse_indices=(["f1", "f2"], inverse_indices),
    )

    # Expected: stride equals inverse_indices.shape[-1] (here 3), per the
    # VBE branch that returns inverse_indices[1].shape[-1].
    self.assertEqual(kjt.stride(), inverse_indices.shape[-1])


class TestKeyedJaggedTensorScripting(unittest.TestCase):
def test_scriptable_forward(self) -> None:
Expand Down
Loading