From aec605f35fd5a362f93b2a62b7c848a6f93793cc Mon Sep 17 00:00:00 2001
From: Jez Ng
Date: Thu, 10 Aug 2023 15:42:33 -0700
Subject: [PATCH] Remove reshape workaround for softmax when dim != -1

Summary: Now that https://github.com/facebookincubator/AITemplate/pull/845
has landed, the backend supports softmax with `dim != -1` directly, and the
fx converter no longer needs the workaround from
https://github.com/facebookincubator/AITemplate/pull/395.

Differential Revision: D48248330

fbshipit-source-id: fad534f63b642ecbf79a90f7fae3c4cc9ad4dadf
---
 fx2ait/fx2ait/converters/ait_converters.py    | 31 +------------------
 .../test/converters/test_ait_softmax.py       | 27 +++-------------
 2 files changed, 5 insertions(+), 53 deletions(-)

diff --git a/fx2ait/fx2ait/converters/ait_converters.py b/fx2ait/fx2ait/converters/ait_converters.py
index ca079ffb4..c792d26a7 100644
--- a/fx2ait/fx2ait/converters/ait_converters.py
+++ b/fx2ait/fx2ait/converters/ait_converters.py
@@ -450,36 +450,7 @@ def acc_ops_softmax(
     if not isinstance(input_val, AITTensor):
         raise RuntimeError(f"Unexpected input for {name}: {input_val}")
 
-    dim = kwargs["dim"]
-    rank = len(input_val.shape())
-    if dim < 0:
-        dim = rank + dim
-    if dim != rank - 1:
-        for i in range(dim + 1, rank):
-            unsupported = False
-            if isinstance(input_val.shape()[i], IntImm):
-                if input_val.shape()[i].value() != 1:
-                    unsupported = True
-            elif isinstance(input_val.shape()[i], IntVar):
-                unsupported = True
-            else:
-                raise RuntimeError(
-                    f"unknown dimension type={type(i)} in AITTensor={input_val}"
-                )
-
-            if unsupported:
-                raise ValueError(
-                    f"AIT softmax only supports dim=rank-1, got AITTensor={input_val}, "
-                    f"where dim={dim}, rank={rank}"
-                )
-        reshape_dim = size()(input_val)[: dim + 1]
-        reshape_val = reshape()(input_val, reshape_dim)
-        softmax_val = softmax()(reshape_val, -1)
-        return reshape()(
-            softmax_val, reshape_dim + [IntVarTensor(IntImm(1))] * (rank - dim - 1)
-        )
-
-    return softmax()(input_val, dim)
+    return softmax()(input_val, kwargs["dim"])
 
 
 @ait_converter(acc_ops.relu)
diff --git a/fx2ait/fx2ait/test/converters/test_ait_softmax.py b/fx2ait/fx2ait/test/converters/test_ait_softmax.py
index d1171f852..1871070c5 100644
--- a/fx2ait/fx2ait/test/converters/test_ait_softmax.py
+++ b/fx2ait/fx2ait/test/converters/test_ait_softmax.py
@@ -54,17 +54,17 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         # Test static use case
         inputs = [
-            torch.randn(2, 3, 5, 1, 1).half().cuda(),
+            torch.randn(2, 3, 5, 2, 1).half().cuda(),
         ]
         self.run_test(model, inputs, expected_ops={acc_ops.softmax})
 
         # Test dynamic use case
         inputs_spec = TensorSpec.create_spec_from_shapes(
             inputs_min=[
-                [2, 3, 5, 1, 1],
+                [2, 3, 5, 4, 1],
             ],
             inputs_max=[
-                [20, 10, 5, 1, 1],
+                [20, 10, 5, 4, 1],
             ],
             dtype_list=[
                 torch.float16,
             ],
@@ -76,25 +76,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
             expected_ops={acc_ops.softmax},
         )
 
-    @parameterized.expand(
-        [
-            param("default", dim=2),
-            param("neg", dim=-3),
-        ]
-    )
-    def test_softmax_expected_failure(self, name, dim=None):
-        class TestModule(torch.nn.Module):
-            def forward(self, x: torch.Tensor) -> torch.Tensor:
-                return torch.nn.functional.softmax(x, dim=dim)
-
-        model = TestModule().cuda().half()
-
-        inputs = [
-            torch.randn(2, 3, 5, 2, 1).half().cuda(),
-        ]
-        with self.assertRaises(ValueError):
-            self.run_test(model, inputs, expected_ops={acc_ops.softmax})
-
     @parameterized.expand(
         [
             param("default", dim=2),
@@ -119,7 +100,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
                 torch.float16,
             ],
         )
-        with self.assertRaises(ValueError):
+        with self.assertRaises(NotImplementedError):
             self.run_test_with_dynamic_shape(
                 model,
                 inputs_spec,