Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove reshape workaround for softmax when dim != -1 #895

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 1 addition & 30 deletions fx2ait/fx2ait/converters/ait_converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -450,36 +450,7 @@ def acc_ops_softmax(
if not isinstance(input_val, AITTensor):
raise RuntimeError(f"Unexpected input for {name}: {input_val}")

dim = kwargs["dim"]
rank = len(input_val.shape())
if dim < 0:
dim = rank + dim
if dim != rank - 1:
for i in range(dim + 1, rank):
unsupported = False
if isinstance(input_val.shape()[i], IntImm):
if input_val.shape()[i].value() != 1:
unsupported = True
elif isinstance(input_val.shape()[i], IntVar):
unsupported = True
else:
raise RuntimeError(
f"unknown dimension type={type(i)} in AITTensor={input_val}"
)

if unsupported:
raise ValueError(
f"AIT softmax only supports dim=rank-1, got AITTensor={input_val}, "
f"where dim={dim}, rank={rank}"
)
reshape_dim = size()(input_val)[: dim + 1]
reshape_val = reshape()(input_val, reshape_dim)
softmax_val = softmax()(reshape_val, -1)
return reshape()(
softmax_val, reshape_dim + [IntVarTensor(IntImm(1))] * (rank - dim - 1)
)

return softmax()(input_val, dim)
return softmax()(input_val, kwargs["dim"])


@ait_converter(acc_ops.relu)
Expand Down
27 changes: 4 additions & 23 deletions fx2ait/fx2ait/test/converters/test_ait_softmax.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,17 +54,17 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:

# Test static use case
inputs = [
torch.randn(2, 3, 5, 1, 1).half().cuda(),
torch.randn(2, 3, 5, 2, 1).half().cuda(),
]
self.run_test(model, inputs, expected_ops={acc_ops.softmax})

# Test dynamic use case
inputs_spec = TensorSpec.create_spec_from_shapes(
inputs_min=[
[2, 3, 5, 1, 1],
[2, 3, 5, 4, 1],
],
inputs_max=[
[20, 10, 5, 1, 1],
[20, 10, 5, 4, 1],
],
dtype_list=[
torch.float16,
Expand All @@ -76,25 +76,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
expected_ops={acc_ops.softmax},
)

@parameterized.expand(
[
param("default", dim=2),
param("neg", dim=-3),
]
)
def test_softmax_expected_failure(self, name, dim=None):
class TestModule(torch.nn.Module):
def forward(self, x: torch.Tensor) -> torch.Tensor:
return torch.nn.functional.softmax(x, dim=dim)

model = TestModule().cuda().half()

inputs = [
torch.randn(2, 3, 5, 2, 1).half().cuda(),
]
with self.assertRaises(ValueError):
self.run_test(model, inputs, expected_ops={acc_ops.softmax})

@parameterized.expand(
[
param("default", dim=2),
Expand All @@ -119,7 +100,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
torch.float16,
],
)
with self.assertRaises(ValueError):
with self.assertRaises(NotImplementedError):
self.run_test_with_dynamic_shape(
model,
inputs_spec,
Expand Down