fix: Improve input handling in aten bmm lowering pass #1821

Closed · wants to merge 1 commit
py/torch_tensorrt/fx/passes/lower_basic_pass_aten.py (30 changes: 24 additions, 6 deletions)

@@ -412,33 +412,50 @@ def compose_bmm(
     modified = False
     for n in module.graph.nodes:
         if n.op == "call_function" and n.target in (torch.ops.aten.bmm.default,):
-            modified = True
+            modified = False
             node = n
             input_n = node.all_input_nodes[0]
             other_n = node.all_input_nodes[1]
+
+            # If no input nodes are available, the bmm argument itself could be an input
+            # Alternatively, if the node has no users, it can be eliminated
+            if len(input_n.all_input_nodes) == 0 or len(node.users) == 0:
[Review thread on the new guard]

frank-wei (Contributor):
Do we have such a case where there are no inputs to the bmm?

Author (Collaborator):
Yes, the case in #1789, which uses this model, has len(input_n.all_input_nodes) == 0:

def forward(self, x, y):
    out = torch.bmm(x, y)
    return out

The test case shows this behavior too.

frank-wei (Contributor):
I did an experiment. 1) It looks like bmm is not decomposed anymore. 2) bmm's all_input_nodes is not 0.
[screenshot of the experiment output]

Author (Collaborator):
@frank-wei I see, thanks for the additional details. I included that test case because taking the model in question and applying the compose_bmm lowering pass to it reproduces the failure. Specifically, in the above example, add the following lines:

from torch_tensorrt.fx.passes.lower_basic_pass_aten import compose_bmm
composed_module = compose_bmm(gm[0])
out = composed_module.graph_module(*input2)

Doing so raises the following error:

  File "test.py", line 17, in <module>
    composed_module = compose_bmm(gm[0])
  File "~/TensorRT/py/torch_tensorrt/fx/passes/lower_basic_pass_aten.py", line 409, in compose_bmm
    input_input_n = input_n.all_input_nodes[0]
IndexError: list index out of range
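A self-contained version of that repro, for reference: the gm[0] and input2 names in the snippet above come from the #1789 script, so the sketch below substitutes the setup used by the test added in this PR.

import torch
import torch.nn as nn
from torch_tensorrt.fx.tracer.dispatch_tracer.aten_tracer import trace
from torch_tensorrt.fx.passes.lower_basic_pass_aten import compose_bmm

class BMM(nn.Module):
    def forward(self, x, y):
        return torch.bmm(x, y)

inputs = [torch.randn(2, 3, 4), torch.randn(2, 4, 3)]
fx_model, _ = trace(BMM(), inputs)       # setup mirrors test_compose_bmm below
composed_module = compose_bmm(fx_model)  # raised IndexError before this fix
out = composed_module.graph_module(*inputs)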
+                return PassResult(module, modified)
+
             output = next(iter(node.users))
             input_input_n = input_n.all_input_nodes[0]
             if (
                 input_input_n.target != torch.ops.aten.expand.default
                 and input_n.target != torch.ops.aten.view.default
             ):
-                raise RuntimeError(
-                    "Bmm is addressed in fixed pattern. A new pattern is met!"
+                _LOGGER.warn(
+                    "Bmm is addressed in fixed pattern. "
+                    + f"A new pattern {input_input_n.target}, {input_n.target} is met! "
+                    + "Skipping bmm lowering on this operation"
                 )
+                return PassResult(module, modified)

             real_input = input_input_n.all_input_nodes[0]
             input_other_n = other_n.all_input_nodes[0]
             if (
                 input_other_n.target != torch.ops.aten.expand.default
                 and other_n.target != torch.ops.aten.view.default
             ):
-                raise RuntimeError(
-                    "Bmm is addressed in fixed pattern. A new pattern is met!"
+                _LOGGER.warn(
+                    "Bmm is addressed in fixed pattern. "
+                    + f"A new pattern {input_other_n.target}, {other_n.target} is met! "
+                    + "Skipping bmm lowering on this operation"
                 )
+                return PassResult(module, modified)

             real_other = input_other_n.all_input_nodes[0]
             if len(real_other.meta["val"].size()) == 2:
                 new_func = aten_compose_bmm_2d
-            if len(real_other.meta["val"].size()) == 3:
+            elif len(real_other.meta["val"].size()) == 3:
                 new_func = aten_compose_bmm_3d
+            else:
+                # No valid bmm replacement exists for the specified dimensions
+                return PassResult(module, modified)

             with module.graph.inserting_after(node):
                 new_args = (real_input, real_other)
@@ -449,6 +466,7 @@ def compose_bmm(
                     kwargs=None,
                 )
                 output.replace_all_uses_with(new_node)
+                modified = True

     module.graph.eliminate_dead_code()
     module.recompile()
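To see why the guard is needed: when a bmm operand is a graph placeholder, FX reports no producer nodes for it, so indexing all_input_nodes[0] fails. A minimal illustration with the vanilla FX tracer (this PR's aten tracer emits aten.bmm.default rather than torch.bmm; symbolic_trace is used here only as a simplified stand-in):

import torch
import torch.nn as nn
from torch.fx import symbolic_trace

class BMM(nn.Module):
    def forward(self, x, y):
        return torch.bmm(x, y)

gm = symbolic_trace(BMM())
for n in gm.graph.nodes:
    if n.op == "call_function" and n.target is torch.bmm:
        operand = n.all_input_nodes[0]
        # Placeholders have no producer nodes, so this list is empty; indexing
        # operand.all_input_nodes[0] is what raised IndexError before the fix.
        print(operand.op, operand.all_input_nodes)  # prints: placeholder []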
py/torch_tensorrt/fx/test/passes/test_compose_bmm.py (28 changes: 28 additions, 0 deletions)

@@ -0,0 +1,28 @@
+import torch
+import torch.nn as nn
+from parameterized import parameterized
+from torch.testing._internal.common_utils import run_tests, TestCase
+from torch_tensorrt.fx.tracer.dispatch_tracer.aten_tracer import trace
+from torch_tensorrt.fx.passes.lower_basic_pass_aten import compose_bmm
+
+
+class TestComposeBMM(TestCase):
+    @parameterized.expand(
+        [
+            ("3_dim", (2, 3, 4), (2, 4, 3)),
+            ("3_dim_same_shape", (4, 4, 4), (4, 4, 4)),
+        ]
+    )
+    def test_compose_bmm(self, test_name, x_shape, y_shape):
+        class BMM(nn.Module):
+            def forward(self, x, y):
+                return torch.bmm(x, y)
+
+        inputs = [torch.randn(x_shape), torch.randn(y_shape)]
+        fx_model, _ = trace(BMM(), inputs)
+        composed_module = compose_bmm(fx_model)
+        out = composed_module.graph_module(*inputs)
[Review thread on this line]

frank-wei (Contributor), Apr 18, 2023:
The test is not complete yet. Could you please add the converter and complete the test like the other tests?
One reference for the converter (aten2ait): https://github.com/facebookincubator/AITemplate/blob/main/fx2ait/fx2ait/converters/aten2ait_converters.py#L557-L559



+if __name__ == "__main__":
+    run_tests()
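Following up on the review comment above, a minimal way to strengthen the assertion-free test could be to check numerical agreement with eager bmm. The tail of test_compose_bmm might then read as below; the final assertion is a suggested sketch, not part of the PR (the converter-based test the reviewer requests would go further):

        inputs = [torch.randn(x_shape), torch.randn(y_shape)]
        fx_model, _ = trace(BMM(), inputs)
        composed_module = compose_bmm(fx_model)
        out = composed_module.graph_module(*inputs)
        # Hypothetical assertion, not in the PR: the composed graph should
        # agree with eager torch.bmm on the same inputs.
        torch.testing.assert_close(out, torch.bmm(*inputs))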