feat: log_softmax decomposition (#3137)
HolyWu authored Oct 15, 2024
1 parent 9fed78c commit 873dd36
Showing 7 changed files with 111 additions and 36 deletions.
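
In short, this commit drops aten._log_softmax from the upstream decomposition list and instead registers a Torch-TensorRT decomposition that rewrites it as torch.log(torch.softmax(...)), ops which Torch-TensorRT can already convert; the softmax converter and its tests also gain explicit handling of the half_to_float argument. As a rough usage sketch (illustrative only: the module, shapes, and option values below are made up, loosely mirroring the new decomposition test), a model calling log_softmax can now be compiled like this:

    import torch
    import torch_tensorrt


    class LogSoftmaxModel(torch.nn.Module):
        def forward(self, x):
            return torch.nn.functional.log_softmax(x, dim=1)


    model = LogSoftmaxModel().eval().cuda()
    inputs = [torch.randn(1, 3, 224, 224, device="cuda")]

    # With the decomposition registered, _log_softmax is rewritten to
    # softmax + log during lowering, so no dedicated converter is needed.
    trt_model = torch_tensorrt.compile(
        model,
        "torch_compile",
        inputs,
        min_block_size=1,
    )
    print(trt_model(*inputs).shape)
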
7 changes: 6 additions & 1 deletion py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
@@ -704,6 +704,11 @@ def aten_ops_unsqueeze(
 @dynamo_tensorrt_converter(
     torch.ops.aten._softmax.default, supports_dynamic_shapes=True
 )
+@enforce_tensor_types(
+    {
+        0: (TRTTensor,),
+    }
+)
 def aten_ops_softmax(
     ctx: ConversionContext,
     target: Target,
@@ -712,7 +717,7 @@ def aten_ops_softmax(
     name: str,
 ) -> Union[TRTTensor, Sequence[TRTTensor]]:
     return impl.normalization.softmax(
-        ctx, target, SourceIR.ATEN, name, args[0], args[1]
+        ctx, target, SourceIR.ATEN, name, args[0], args[1], args[2]
     )


2 changes: 1 addition & 1 deletion py/torch_tensorrt/dynamo/conversion/impl/attention.py
@@ -151,7 +151,7 @@ def scaled_dot_product_attention(
     )

     softmax = impl.normalization.softmax(
-        ctx, target, source_ir, name + "_softmax", scaled, -1
+        ctx, target, source_ir, name + "_softmax", scaled, -1, False
     )
     out = impl.matmul.matrix_multiply(
         ctx,
27 changes: 5 additions & 22 deletions py/torch_tensorrt/dynamo/conversion/impl/normalization/ops.py
@@ -439,30 +439,13 @@ def softmax(
     source_ir: Optional[SourceIR],
     name: str,
     input: TRTTensor,
-    dim: Optional[Any] = None,
+    dim: int,
+    half_to_float: bool,
 ) -> Union[TRTTensor, Sequence[TRTTensor]]:
-    input_ranks = len(input.shape)
+    dim = get_positive_dim(dim, len(input.shape))

-    if not isinstance(input, TRTTensor):
-        raise RuntimeError(
-            f"softmax received input {input} that is not part "
-            "of the TensorRT region!"
-        )
-
-    # Used to get dim when dim is None. Copied from PyTorch softmax implementation.
-    def get_softmax_dim(ndim: int) -> int:
-        if ndim == 0 or ndim == 1 or ndim == 3:
-            ret = 0
-        else:
-            ret = 1
-        return ret
-
-    if dim is None:
-        dim = get_softmax_dim(input_ranks)
-    else:
-        dim = cast(int, dim)
-
-    dim = get_positive_dim(dim, input_ranks)
+    if half_to_float:
+        input = cast_trt_tensor(ctx, input, torch.float, name, target, source_ir)

     layer = ctx.net.add_softmax(input)
     layer.axes = 1 << dim
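
As an aside on the converter body kept above: TensorRT's softmax layer takes its reduction axis as a bitmask, which is why the code sets layer.axes = 1 << dim after normalizing dim to a non-negative index. A tiny illustrative helper (hypothetical, not part of the change) makes the mapping explicit:

    def softmax_axes_mask(dim: int) -> int:
        # Bit `dim` selects the axis the softmax reduces over,
        # e.g. dim=0 -> 0b0001, dim=2 -> 0b0100.
        return 1 << dim

    assert softmax_axes_mask(0) == 0b0001
    assert softmax_axes_mask(2) == 0b0100
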
1 change: 0 additions & 1 deletion py/torch_tensorrt/dynamo/lowering/_decomposition_groups.py
@@ -76,7 +76,6 @@
     aten.logit_backward,
     aten.log_sigmoid_backward,
     aten.log_sigmoid_forward,
-    aten._log_softmax,
     aten._log_softmax_backward_data,
     aten.logspace,
     aten.logsumexp.default,
11 changes: 11 additions & 0 deletions py/torch_tensorrt/dynamo/lowering/_decompositions.py
@@ -388,6 +388,17 @@ def scatter_reduce_decomposition(
     return scatter_loop_tensor


+@register_torch_trt_decomposition(aten._log_softmax, registry=TORCH_TRT_DECOMPOSITIONS)
+def log_softmax_decomposition(
+    x: torch.Tensor,
+    dim: int,
+    half_to_float: bool,
+) -> torch.Tensor:
+    return torch.log(
+        torch.softmax(x, dim, dtype=torch.float if half_to_float else None)
+    )
+
+
 def get_decompositions(
     enable_experimental_decompositions: bool = False,
 ) -> Dict[OpOverload, Callable[[Any], Any]]:
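
The decomposition itself leans on the identity log_softmax(x) = log(softmax(x)), with half_to_float simply forcing the softmax to compute and return float32. A quick eager-mode sanity check of that identity (illustrative, not part of the commit):

    import torch

    x = torch.randn(2, 3, 5)
    ref = torch.nn.functional.log_softmax(x, dim=1)
    dec = torch.log(torch.softmax(x, dim=1))
    assert torch.allclose(ref, dec, atol=1e-6)

Note that computing log(softmax(x)) in two steps can be less numerically stable than a fused log_softmax when inputs contain very large values.
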
37 changes: 26 additions & 11 deletions tests/py/dynamo/conversion/test_softmax_aten.py
@@ -1,32 +1,47 @@
 import torch
+from parameterized import parameterized
 from torch.testing._internal.common_utils import run_tests
 from torch_tensorrt import Input

 from .harness import DispatchTestCase


-class TestSoftMaxConverter(DispatchTestCase):
-    def test_softmax(self):
+class TestSoftmaxConverter(DispatchTestCase):
+    @parameterized.expand(
+        [
+            (torch.float, False),
+            (torch.half, False),
+            (torch.half, True),
+        ]
+    )
+    def test_softmax(self, dtype, half_to_float):
         class TestModule(torch.nn.Module):
             def forward(self, x):
-                return torch.ops.aten._softmax.default(x, 1, False)
+                return torch.ops.aten._softmax.default(x, 1, half_to_float)

-        inputs = [torch.randn(1, 3, 224, 224)]
+        inputs = [torch.randn(1, 3, 224, 224, dtype=dtype)]
         self.run_test(TestModule(), inputs)

-    def test_softmax_with_dynamic_shape(self):
+    @parameterized.expand(
+        [
+            (torch.float, False),
+            (torch.half, False),
+            (torch.half, True),
+        ]
+    )
+    def test_softmax_with_dynamic_shape(self, dtype, half_to_float):
         class TestModule(torch.nn.Module):
             def forward(self, x):
-                return torch.ops.aten._softmax.default(x, 2, False)
+                return torch.ops.aten._softmax.default(x, 2, half_to_float)

         input_specs = [
             Input(
-                shape=(-1, 3, -1, -1),
-                dtype=torch.float32,
-                shape_ranges=[((1, 3, 1, 1), (1, 3, 5, 5), (2, 3, 10, 10))],
-            ),
+                min_shape=(1, 1, 1, 1),
+                opt_shape=(2, 4, 6, 8),
+                max_shape=(8, 8, 8, 8),
+                dtype=dtype,
+            )
         ]

         self.run_test_with_dynamic_shape(TestModule(), input_specs)


62 changes: 62 additions & 0 deletions tests/py/dynamo/lowering/test_decompositions.py
@@ -1525,6 +1525,68 @@ def forward(self, input):
             f"Scatter_reduce TRT outputs don't match with the original model.",
         )

+    @parameterized.expand(
+        [
+            (torch.float, False),
+            (torch.half, False),
+            (torch.half, True),
+        ]
+    )
+    def test_lowering_log_softmax(self, dtype, half_to_float):
+        class TestModule(torch.nn.Module):
+            def forward(self, x):
+                return torch.ops.aten._log_softmax.default(x, 1, half_to_float)
+
+        # Operations expected to be removed in the traced graph after decompositions
+        expected_ops = {torch.ops.aten._softmax.default, torch.ops.aten.log.default}
+        unexpected_ops = {torch.ops.aten._log_softmax.default}
+
+        inputs = [torch.randn(1, 3, 5, 7, dtype=dtype, device="cuda")]
+
+        fx_graph = torch.fx.symbolic_trace(TestModule())
+        unexpected_ops_seen, expected_ops_unseen = lower_graph_testing(
+            fx_graph,
+            inputs,
+            expected_ops=expected_ops,
+            unexpected_ops=unexpected_ops,
+            min_block_size=1,
+        )
+
+        self.assertEqual(
+            len(unexpected_ops_seen),
+            0,
+            f"The following unexpected ops were encountered: {unexpected_ops_seen}",
+        )
+
+        self.assertEqual(
+            len(expected_ops_unseen),
+            0,
+            f"The following expected ops were not encountered: {expected_ops_unseen}",
+        )
+
+        torch._dynamo.reset()
+
+        # Validate that the results between Torch and Torch-TRT are similar
+        optimized_model = torch_tensorrt.compile(
+            fx_graph,
+            "torch_compile",
+            inputs,
+            min_block_size=1,
+            pass_through_build_failures=True,
+        )
+        optimized_model_results = optimized_model(*inputs).detach().cpu()
+        torch_model_results = fx_graph(*inputs).detach().cpu()
+
+        max_diff = float(
+            torch.max(torch.abs(optimized_model_results - torch_model_results))
+        )
+        self.assertAlmostEqual(
+            max_diff,
+            0,
+            DECIMALS_OF_AGREEMENT,
+            f"Log_softmax TRT outputs don't match with the original model.",
+        )
+

 if __name__ == "__main__":
     run_tests()
