Skip to content

Commit

Permalink
Add deduce_jagged_tensor_with_graph_analysis flag for batch dim disti…
Browse files Browse the repository at this point in the history
…nguish (#929)

Summary:

For vdd, it seems that the jagged tensor batch dim is identical to dense tensor batch dim, which caused issue in bmm kernel, that it cannot handle batch size as large as 2^16.

This fix adds a flag `deduce_jagged_tensor_with_graph_analysis` so that when it is turnt on, we depend on graph analysis, i.e. `try_getting_jagged_tensor_map`, to deduce batch dim for jagged tensor. This can be more reliable than deducing based on value.

Differential Revision: D49262422
  • Loading branch information
tissue3 authored and facebook-github-bot committed Sep 14, 2023
1 parent 4ca3435 commit 6d87d33
Showing 1 changed file with 4 additions and 17 deletions.
21 changes: 4 additions & 17 deletions fx2ait/fx2ait/tensor_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -306,8 +306,9 @@ def _get_max_seq_lens_from_offsets(

return max_seq_lens

@staticmethod
def _try_getting_jagged_tensor_map(
@classmethod
def try_getting_jagged_tensor_map(
cls,
inputs: List[torch.Tensor],
jagged_tensor_batch_dims: Set[int],
fx_inputs: Optional[List[torch.fx.Node]] = None,
Expand Down Expand Up @@ -371,6 +372,7 @@ def from_input_list_with_batch_size_jagged_tensor(
additional_inputs: List[torch.Tensor] = None,
infer_max_seq_lens_from_offsets: bool = False,
fx_inputs: List[torch.fx.Node] = None,
jagged_tensor_map: Optional[Dict[int, int]] = None,
) -> List["TensorSpec"]:
"""
Most of the recommendation models will work fine using this function.
Expand All @@ -385,21 +387,6 @@ def from_input_list_with_batch_size_jagged_tensor(
jagged_offsets_batch_dims=jagged_offsets_batch_dims,
)

jagged_tensor_map = cls._try_getting_jagged_tensor_map(
inputs=inputs,
jagged_tensor_batch_dims=jagged_tensor_batch_dims,
fx_inputs=fx_inputs,
)
if jagged_tensor_map:
logger.info("Successfully detected a jagged_tensor_map:")
for input_id, jagged_tensor_id in jagged_tensor_map.items():
logger.info(f"{input_id=}, {jagged_tensor_id=}")
else:
logger.info(
"Unable to detect a jagged_tensor_map: falling back "
"to the batch dim-based jagged tensor detection."
)

result: List = []
result_unsorted: List = []
left_inputs: List = []
Expand Down

0 comments on commit 6d87d33

Please sign in to comment.