Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: log_softmax decomposition #3137

Merged
merged 10 commits into from
Oct 15, 2024
1 change: 0 additions & 1 deletion py/torch_tensorrt/dynamo/lowering/_decomposition_groups.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,6 @@
aten.logit_backward,
aten.log_sigmoid_backward,
aten.log_sigmoid_forward,
aten._log_softmax,
aten._log_softmax_backward_data,
aten.logspace,
aten.logsumexp.default,
Expand Down
9 changes: 9 additions & 0 deletions py/torch_tensorrt/dynamo/lowering/_decompositions.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,6 +287,15 @@ def scatter_add_decomposition(
return scatter_add_tensor


@register_torch_trt_decomposition(aten._log_softmax, registry=TORCH_TRT_DECOMPOSITIONS)
def log_softmax_decomposition(
x: torch.Tensor,
dim: int,
half_to_float: bool,
HolyWu marked this conversation as resolved.
Show resolved Hide resolved
) -> torch.Tensor:
return torch.log(torch.softmax(x, dim))


def get_decompositions(
enable_experimental_decompositions: bool = False,
) -> Dict[OpOverload, Callable[[Any], Any]]:
Expand Down
55 changes: 55 additions & 0 deletions tests/py/dynamo/lowering/test_decompositions.py
Original file line number Diff line number Diff line change
Expand Up @@ -1129,6 +1129,61 @@ def forward(self, input):
f"Scatter_add TRT outputs don't match with the original model.",
)

def test_lowering_log_softmax(self):
class TestModule(torch.nn.Module):
def forward(self, x):
return torch.ops.aten._log_softmax.default(x, 1, False)

# Operations expected to be removed in the traced graph after decompositions
expected_ops = {torch.ops.aten._softmax.default, torch.ops.aten.log.default}
unexpected_ops = {torch.ops.aten._log_softmax.default}

inputs = [torch.randn(1, 3).cuda()]

fx_graph = torch.fx.symbolic_trace(TestModule())
unexpected_ops_seen, expected_ops_unseen = lower_graph_testing(
fx_graph,
inputs,
expected_ops=expected_ops,
unexpected_ops=unexpected_ops,
min_block_size=1,
)

self.assertEqual(
len(unexpected_ops_seen),
0,
f"The following unexpected ops were encountered: {unexpected_ops_seen}",
)

self.assertEqual(
len(expected_ops_unseen),
0,
f"The following expected ops were not encountered: {expected_ops_unseen}",
)

torch._dynamo.reset()

# Validate that the results between Torch and Torch-TRT are similar
optimized_model = torch_tensorrt.compile(
fx_graph,
"torch_compile",
inputs,
min_block_size=1,
pass_through_build_failures=True,
)
optimized_model_results = optimized_model(*inputs).detach().cpu()
torch_model_results = fx_graph(*inputs).detach().cpu()

max_diff = float(
torch.max(torch.abs(optimized_model_results - torch_model_results))
)
self.assertAlmostEqual(
max_diff,
0,
DECIMALS_OF_AGREEMENT,
f"Log_softmax TRT outputs don't match with the original model.",
)


if __name__ == "__main__":
run_tests()
Loading