From 966f738639efbc977eae284221ebe2446e4232f2 Mon Sep 17 00:00:00 2001
From: HolyWu
Date: Thu, 5 Sep 2024 21:23:15 +0800
Subject: [PATCH] feat: log_softmax decomposition

---
 .../dynamo/lowering/_decomposition_groups.py  |  1 -
 .../dynamo/lowering/_decompositions.py        | 12 ++++
 .../py/dynamo/lowering/test_decompositions.py | 55 +++++++++++++++++++
 3 files changed, 67 insertions(+), 1 deletion(-)

diff --git a/py/torch_tensorrt/dynamo/lowering/_decomposition_groups.py b/py/torch_tensorrt/dynamo/lowering/_decomposition_groups.py
index a84a550a1e..c472c31a84 100644
--- a/py/torch_tensorrt/dynamo/lowering/_decomposition_groups.py
+++ b/py/torch_tensorrt/dynamo/lowering/_decomposition_groups.py
@@ -76,7 +76,6 @@
     aten.logit_backward,
     aten.log_sigmoid_backward,
     aten.log_sigmoid_forward,
-    aten._log_softmax,
     aten._log_softmax_backward_data,
     aten.logspace,
     aten.logsumexp.default,
diff --git a/py/torch_tensorrt/dynamo/lowering/_decompositions.py b/py/torch_tensorrt/dynamo/lowering/_decompositions.py
index f86e3c5cb5..63d39872c3 100644
--- a/py/torch_tensorrt/dynamo/lowering/_decompositions.py
+++ b/py/torch_tensorrt/dynamo/lowering/_decompositions.py
@@ -287,6 +287,18 @@ def scatter_add_decomposition(
     return scatter_add_tensor
 
 
+@register_torch_trt_decomposition(aten._log_softmax, registry=TORCH_TRT_DECOMPOSITIONS)
+def log_softmax_decomposition(
+    x: torch.Tensor,
+    dim: int,
+    half_to_float: bool,
+) -> torch.Tensor:
+    # Honor half_to_float from aten._log_softmax: upcast to float32 when requested.
+    return torch.log(
+        torch.softmax(x, dim, dtype=torch.float if half_to_float else None)
+    )
+
+
 def get_decompositions(
     enable_experimental_decompositions: bool = False,
 ) -> Dict[OpOverload, Callable[[Any], Any]]:
diff --git a/tests/py/dynamo/lowering/test_decompositions.py b/tests/py/dynamo/lowering/test_decompositions.py
index 2f06aa7d23..1e259f3ad5 100644
--- a/tests/py/dynamo/lowering/test_decompositions.py
+++ b/tests/py/dynamo/lowering/test_decompositions.py
@@ -1129,6 +1129,61 @@ def forward(self, input):
             f"Scatter_add TRT outputs don't match with the original model.",
         )
 
+    def test_lowering_log_softmax(self):
+        class TestModule(torch.nn.Module):
+            def forward(self, x):
+                return torch.ops.aten._log_softmax.default(x, 1, False)
+
+        # Ops expected to appear, and ops expected to be removed, after decompositions
+        expected_ops = {torch.ops.aten._softmax.default, torch.ops.aten.log.default}
+        unexpected_ops = {torch.ops.aten._log_softmax.default}
+
+        inputs = [torch.randn(1, 3).cuda()]
+
+        fx_graph = torch.fx.symbolic_trace(TestModule())
+        unexpected_ops_seen, expected_ops_unseen = lower_graph_testing(
+            fx_graph,
+            inputs,
+            expected_ops=expected_ops,
+            unexpected_ops=unexpected_ops,
+            min_block_size=1,
+        )
+
+        self.assertEqual(
+            len(unexpected_ops_seen),
+            0,
+            f"The following unexpected ops were encountered: {unexpected_ops_seen}",
+        )
+
+        self.assertEqual(
+            len(expected_ops_unseen),
+            0,
+            f"The following expected ops were not encountered: {expected_ops_unseen}",
+        )
+
+        torch._dynamo.reset()
+
+        # Validate that the results between Torch and Torch-TRT are similar
+        optimized_model = torch_tensorrt.compile(
+            fx_graph,
+            "torch_compile",
+            inputs,
+            min_block_size=1,
+            pass_through_build_failures=True,
+        )
+        optimized_model_results = optimized_model(*inputs).detach().cpu()
+        torch_model_results = fx_graph(*inputs).detach().cpu()
+
+        max_diff = float(
+            torch.max(torch.abs(optimized_model_results - torch_model_results))
+        )
+        self.assertAlmostEqual(
+            max_diff,
+            0,
+            DECIMALS_OF_AGREEMENT,
+            "Log_softmax TRT outputs don't match the original model.",
+        )
+
 
 if __name__ == "__main__":
     run_tests()
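
Note for reviewers (not part of the patch): a minimal standalone sketch of the identity the decomposition relies on, namely that `aten._log_softmax(x, dim, half_to_float=False)` equals `log(softmax(x, dim))`. The shapes and tolerance below are arbitrary; the `half_to_float=True` path (a float32 upcast) is CUDA-specific and not exercised here.

```python
import torch

# Check the identity used by the decomposition on CPU:
# aten._log_softmax(x, dim, half_to_float=False) == log(softmax(x, dim)).
x = torch.randn(4, 8)
dim = 1

fused = torch.ops.aten._log_softmax.default(x, dim, False)
decomposed = torch.log(torch.softmax(x, dim))

# The fused kernel is the more numerically stable of the two, but for
# well-scaled inputs both agree to floating-point tolerance.
assert torch.allclose(fused, decomposed, atol=1e-6)
print("max abs diff:", (fused - decomposed).abs().max().item())
```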