Skip to content

feat: log_softmax decomposition #3137

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
Oct 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -704,6 +704,11 @@ def aten_ops_unsqueeze(
@dynamo_tensorrt_converter(
torch.ops.aten._softmax.default, supports_dynamic_shapes=True
)
@enforce_tensor_types(
{
0: (TRTTensor,),
}
)
def aten_ops_softmax(
ctx: ConversionContext,
target: Target,
Expand All @@ -712,7 +717,7 @@ def aten_ops_softmax(
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
return impl.normalization.softmax(
ctx, target, SourceIR.ATEN, name, args[0], args[1]
ctx, target, SourceIR.ATEN, name, args[0], args[1], args[2]
)


Expand Down
2 changes: 1 addition & 1 deletion py/torch_tensorrt/dynamo/conversion/impl/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ def scaled_dot_product_attention(
)

softmax = impl.normalization.softmax(
ctx, target, source_ir, name + "_softmax", scaled, -1
ctx, target, source_ir, name + "_softmax", scaled, -1, False
)
out = impl.matmul.matrix_multiply(
ctx,
Expand Down
27 changes: 5 additions & 22 deletions py/torch_tensorrt/dynamo/conversion/impl/normalization/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -439,30 +439,13 @@ def softmax(
source_ir: Optional[SourceIR],
name: str,
input: TRTTensor,
dim: Optional[Any] = None,
dim: int,
half_to_float: bool,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
input_ranks = len(input.shape)
dim = get_positive_dim(dim, len(input.shape))

if not isinstance(input, TRTTensor):
raise RuntimeError(
f"softmax received input {input} that is not part "
"of the TensorRT region!"
)

# Used to get dim when dim is None. Copied from PyTorch softmax implementation.
def get_softmax_dim(ndim: int) -> int:
if ndim == 0 or ndim == 1 or ndim == 3:
ret = 0
else:
ret = 1
return ret

if dim is None:
dim = get_softmax_dim(input_ranks)
else:
dim = cast(int, dim)

dim = get_positive_dim(dim, input_ranks)
if half_to_float:
input = cast_trt_tensor(ctx, input, torch.float, name, target, source_ir)

layer = ctx.net.add_softmax(input)
layer.axes = 1 << dim
Expand Down
1 change: 0 additions & 1 deletion py/torch_tensorrt/dynamo/lowering/_decomposition_groups.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,6 @@
aten.logit_backward,
aten.log_sigmoid_backward,
aten.log_sigmoid_forward,
aten._log_softmax,
aten._log_softmax_backward_data,
aten.logspace,
aten.logsumexp.default,
Expand Down
11 changes: 11 additions & 0 deletions py/torch_tensorrt/dynamo/lowering/_decompositions.py
Original file line number Diff line number Diff line change
Expand Up @@ -388,6 +388,17 @@ def scatter_reduce_decomposition(
return scatter_loop_tensor


@register_torch_trt_decomposition(aten._log_softmax, registry=TORCH_TRT_DECOMPOSITIONS)
def log_softmax_decomposition(
x: torch.Tensor,
dim: int,
half_to_float: bool,
) -> torch.Tensor:
return torch.log(
torch.softmax(x, dim, dtype=torch.float if half_to_float else None)
)


def get_decompositions(
enable_experimental_decompositions: bool = False,
) -> Dict[OpOverload, Callable[[Any], Any]]:
Expand Down
37 changes: 26 additions & 11 deletions tests/py/dynamo/conversion/test_softmax_aten.py
Original file line number Diff line number Diff line change
@@ -1,32 +1,47 @@
import torch
from parameterized import parameterized
from torch.testing._internal.common_utils import run_tests
from torch_tensorrt import Input

from .harness import DispatchTestCase


class TestSoftMaxConverter(DispatchTestCase):
def test_softmax(self):
class TestSoftmaxConverter(DispatchTestCase):
@parameterized.expand(
[
(torch.float, False),
(torch.half, False),
(torch.half, True),
]
)
def test_softmax(self, dtype, half_to_float):
class TestModule(torch.nn.Module):
def forward(self, x):
return torch.ops.aten._softmax.default(x, 1, False)
return torch.ops.aten._softmax.default(x, 1, half_to_float)

inputs = [torch.randn(1, 3, 224, 224)]
inputs = [torch.randn(1, 3, 224, 224, dtype=dtype)]
self.run_test(TestModule(), inputs)

def test_softmax_with_dynamic_shape(self):
@parameterized.expand(
[
(torch.float, False),
(torch.half, False),
(torch.half, True),
]
)
def test_softmax_with_dynamic_shape(self, dtype, half_to_float):
class TestModule(torch.nn.Module):
def forward(self, x):
return torch.ops.aten._softmax.default(x, 2, False)
return torch.ops.aten._softmax.default(x, 2, half_to_float)

input_specs = [
Input(
shape=(-1, 3, -1, -1),
dtype=torch.float32,
shape_ranges=[((1, 3, 1, 1), (1, 3, 5, 5), (2, 3, 10, 10))],
),
min_shape=(1, 1, 1, 1),
opt_shape=(2, 4, 6, 8),
max_shape=(8, 8, 8, 8),
dtype=dtype,
)
]

self.run_test_with_dynamic_shape(TestModule(), input_specs)


Expand Down
62 changes: 62 additions & 0 deletions tests/py/dynamo/lowering/test_decompositions.py
Original file line number Diff line number Diff line change
Expand Up @@ -1525,6 +1525,68 @@ def forward(self, input):
f"Scatter_reduce TRT outputs don't match with the original model.",
)

@parameterized.expand(
[
(torch.float, False),
(torch.half, False),
(torch.half, True),
]
)
def test_lowering_log_softmax(self, dtype, half_to_float):
class TestModule(torch.nn.Module):
def forward(self, x):
return torch.ops.aten._log_softmax.default(x, 1, half_to_float)

# Operations expected to be removed in the traced graph after decompositions
expected_ops = {torch.ops.aten._softmax.default, torch.ops.aten.log.default}
unexpected_ops = {torch.ops.aten._log_softmax.default}

inputs = [torch.randn(1, 3, 5, 7, dtype=dtype, device="cuda")]

fx_graph = torch.fx.symbolic_trace(TestModule())
unexpected_ops_seen, expected_ops_unseen = lower_graph_testing(
fx_graph,
inputs,
expected_ops=expected_ops,
unexpected_ops=unexpected_ops,
min_block_size=1,
)

self.assertEqual(
len(unexpected_ops_seen),
0,
f"The following unexpected ops were encountered: {unexpected_ops_seen}",
)

self.assertEqual(
len(expected_ops_unseen),
0,
f"The following expected ops were not encountered: {expected_ops_unseen}",
)

torch._dynamo.reset()

# Validate that the results between Torch and Torch-TRT are similar
optimized_model = torch_tensorrt.compile(
fx_graph,
"torch_compile",
inputs,
min_block_size=1,
pass_through_build_failures=True,
)
optimized_model_results = optimized_model(*inputs).detach().cpu()
torch_model_results = fx_graph(*inputs).detach().cpu()

max_diff = float(
torch.max(torch.abs(optimized_model_results - torch_model_results))
)
self.assertAlmostEqual(
max_diff,
0,
DECIMALS_OF_AGREEMENT,
f"Log_softmax TRT outputs don't match with the original model.",
)


if __name__ == "__main__":
run_tests()
Loading