feat: log_softmax decomposition
HolyWu committed Sep 5, 2024
1 parent 29b4913 commit 966f738
Showing 3 changed files with 64 additions and 1 deletion.
1 change: 0 additions & 1 deletion py/torch_tensorrt/dynamo/lowering/_decomposition_groups.py
@@ -76,7 +76,6 @@
     aten.logit_backward,
     aten.log_sigmoid_backward,
     aten.log_sigmoid_forward,
-    aten._log_softmax,
     aten._log_softmax_backward_data,
     aten.logspace,
     aten.logsumexp.default,
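
Note: `aten._log_softmax` also has a stock decomposition in PyTorch's core-ATen table, which is why it is dropped from the torch-enabled group above so the Torch-TRT version registered in `_decompositions.py` takes precedence. A minimal check, assuming a recent PyTorch (`torch._decomp` is an internal module and may move):

import torch
from torch._decomp import core_aten_decompositions

# The stock table maps OpOverloads to Python decompositions;
# aten._log_softmax.default is among them.
print(torch.ops.aten._log_softmax.default in core_aten_decompositions())  # True
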
9 changes: 9 additions & 0 deletions py/torch_tensorrt/dynamo/lowering/_decompositions.py
@@ -287,6 +287,15 @@ def scatter_add_decomposition(
     return scatter_add_tensor
 
 
+@register_torch_trt_decomposition(aten._log_softmax, registry=TORCH_TRT_DECOMPOSITIONS)
+def log_softmax_decomposition(
+    x: torch.Tensor,
+    dim: int,
+    half_to_float: bool,
+) -> torch.Tensor:
+    return torch.log(torch.softmax(x.to(torch.float) if half_to_float else x, dim))
+
+
 def get_decompositions(
     enable_experimental_decompositions: bool = False,
 ) -> Dict[OpOverload, Callable[[Any], Any]]:
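
A minimal sketch to sanity-check the identity the decomposition relies on, log_softmax(x) = log(softmax(x)). PyTorch's native kernel computes the numerically stabler x - logsumexp(x, dim, keepdim=True), so log(softmax(x)) can lose precision for extreme logits:

import torch

x = torch.randn(4, 8)
ref = torch.ops.aten._log_softmax.default(x, 1, False)
dec = torch.log(torch.softmax(x, 1))
print(torch.allclose(ref, dec, atol=1e-6))  # True for well-scaled inputs

# The stabler form used by the native kernel:
stable = x - torch.logsumexp(x, dim=1, keepdim=True)
print(torch.allclose(ref, stable, atol=1e-6))  # True
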
55 changes: 55 additions & 0 deletions tests/py/dynamo/lowering/test_decompositions.py
@@ -1129,6 +1129,61 @@ def forward(self, input):
f"Scatter_add TRT outputs don't match with the original model.",
)

def test_lowering_log_softmax(self):
class TestModule(torch.nn.Module):
def forward(self, x):
return torch.ops.aten._log_softmax.default(x, 1, False)

# Operations expected to be removed in the traced graph after decompositions
expected_ops = {torch.ops.aten._softmax.default, torch.ops.aten.log.default}
unexpected_ops = {torch.ops.aten._log_softmax.default}

inputs = [torch.randn(1, 3).cuda()]

fx_graph = torch.fx.symbolic_trace(TestModule())
unexpected_ops_seen, expected_ops_unseen = lower_graph_testing(
fx_graph,
inputs,
expected_ops=expected_ops,
unexpected_ops=unexpected_ops,
min_block_size=1,
)

self.assertEqual(
len(unexpected_ops_seen),
0,
f"The following unexpected ops were encountered: {unexpected_ops_seen}",
)

self.assertEqual(
len(expected_ops_unseen),
0,
f"The following expected ops were not encountered: {expected_ops_unseen}",
)

torch._dynamo.reset()

# Validate that the results between Torch and Torch-TRT are similar
optimized_model = torch_tensorrt.compile(
fx_graph,
"torch_compile",
inputs,
min_block_size=1,
pass_through_build_failures=True,
)
optimized_model_results = optimized_model(*inputs).detach().cpu()
torch_model_results = fx_graph(*inputs).detach().cpu()

max_diff = float(
torch.max(torch.abs(optimized_model_results - torch_model_results))
)
self.assertAlmostEqual(
max_diff,
0,
DECIMALS_OF_AGREEMENT,
f"Log_softmax TRT outputs don't match with the original model.",
)


if __name__ == "__main__":
run_tests()
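
For an end-to-end spot check outside the test harness, something like the following should exercise the new path (hypothetical usage; requires a CUDA-enabled Torch-TensorRT install):

import torch
import torch_tensorrt

model = torch.nn.LogSoftmax(dim=1).cuda().eval()
inputs = [torch.randn(1, 3).cuda()]

# ir="torch_compile" routes through the Dynamo lowering passes,
# so aten._log_softmax should decompose into softmax + log.
compiled = torch_tensorrt.compile(model, "torch_compile", inputs, min_block_size=1)
print(torch.max(torch.abs(compiled(*inputs) - model(*inputs))))  # ~0
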
