
fix: Improve input handling in aten bmm lowering pass #1821

Closed
wants to merge 1 commit

Conversation

gs-olive
Collaborator

Description

  • Add checks to the bmm lowering pass so that it returns the module unchanged when invalid schemas are used, instead of failing or throwing a RuntimeError
  • Add a regression test case to catch the error in bmm and evaluate the lowering pass

Fixes #1789

Type of change

  • Bug fix (non-breaking change which fixes an issue)

Checklist:

  • [x] My code follows the style guidelines of this project (you can use the linters)
  • [x] I have performed a self-review of my own code
  • [x] I have commented my code, particularly in hard-to-understand areas and hacks
  • [x] I have made corresponding changes to the documentation
  • [x] I have added tests to verify my fix or my feature
  • [x] New and existing unit tests pass locally with my changes
  • [x] I have added the relevant labels to my PR so that the relevant reviewers are notified


# If no input nodes are available, the bmm argument itself could be an input
# Alternatively, if the node has no users, it can be eliminated
if len(input_n.all_input_nodes) == 0 or len(node.users) == 0:
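For context, a minimal standalone sketch of how such a guard behaves in an FX pass (the pass name and structure here are illustrative, not the actual Torch-TensorRT compose_bmm implementation):

```python
import torch
import torch.fx


class BMM(torch.nn.Module):
    def forward(self, x, y):
        return torch.bmm(x, y)


def guarded_bmm_pass(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
    """Illustrative pass: skip rewriting bmm nodes whose argument is a
    graph input (placeholder) or that have no users, instead of indexing
    into an empty all_input_nodes list."""
    for node in gm.graph.nodes:
        if node.op == "call_function" and node.target == torch.bmm:
            input_n = node.args[0]
            # The guard from this PR: a placeholder has no input nodes,
            # and a dead node has no users -- leave the module unchanged.
            if len(input_n.all_input_nodes) == 0 or len(node.users) == 0:
                continue
            # ... actual rewriting logic would go here ...
    return gm


gm = torch.fx.symbolic_trace(BMM())
out_gm = guarded_bmm_pass(gm)
x, y = torch.randn(2, 3, 4), torch.randn(2, 4, 5)
assert torch.equal(out_gm(x, y), torch.bmm(x, y))
```

With the guard in place the pass is a no-op for this module, so the graph still computes the same result as `torch.bmm`.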
Contributor

Do we have a case where there are no inputs to the bmm?

Collaborator Author

Yes, the case in #1789, which uses this model, has `len(input_n.all_input_nodes) == 0`:

def forward(self, x, y):
    out = torch.bmm(x, y)
    return out

The test case shows this behavior too.
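To illustrate the claim, here is a minimal sketch (using `torch.fx.symbolic_trace` rather than the aten tracer from the issue, so this is an approximation of the #1789 setup, not an exact reproduction): the bmm node's argument is a placeholder, and placeholders have no input nodes, which is what makes `input_n.all_input_nodes[0]` fail.

```python
import torch
import torch.fx


class BMM(torch.nn.Module):
    def forward(self, x, y):
        return torch.bmm(x, y)


gm = torch.fx.symbolic_trace(BMM())
bmm_node = next(n for n in gm.graph.nodes if n.target == torch.bmm)
input_n = bmm_node.args[0]  # the `x` placeholder feeding bmm
assert input_n.op == "placeholder"
assert len(input_n.all_input_nodes) == 0  # indexing [0] here raises IndexError
```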

Contributor

I ran an experiment: 1) it looks like bmm is no longer decomposed, and 2) bmm's all_input_nodes is not empty.
[screenshot: experiment output]

Collaborator Author

@frank-wei I see, thanks for the additional details. I included that test case because applying the compose_bmm lowering pass to the model in question reproduces the failure. Specifically, in the above example, add the following lines:

from torch_tensorrt.fx.passes.lower_basic_pass_aten import compose_bmm
composed_module = compose_bmm(gm[0])
out = composed_module.graph_module(*input2)

When doing so, the following error is raised:

  File "test.py", line 17, in <module>
    composed_module = compose_bmm(gm[0])
  File "~/TensorRT/py/torch_tensorrt/fx/passes/lower_basic_pass_aten.py", line 409, in compose_bmm
    input_input_n = input_n.all_input_nodes[0]
IndexError: list index out of range

inputs = [torch.randn(x_shape), torch.randn(y_shape)]
fx_model, _ = trace(BMM(), inputs)
composed_module = compose_bmm(fx_model)
out = composed_module.graph_module(*inputs)
Contributor

@frank-wei commented Apr 18, 2023


The test is not complete yet. Could you please add the converter and complete the test like the other tests?
One reference for the converter(aten2ait): https://github.com/facebookincubator/AITemplate/blob/main/fx2ait/fx2ait/converters/aten2ait_converters.py#L557-L559

@narendasan
Collaborator

@gs-olive what is remaining to be done on this PR?

@gs-olive
Collaborator Author

This PR could use some additional testing per @frank-wei's comments, and I discussed with @peri044 starting to look into some of the aten lowering pass issues such as #1787, #1788, and #1789. I can add the requested testing to this PR.

@peri044
Collaborator

peri044 commented Jun 7, 2023

@frank-wei I see that torch.bmm is not decomposed anymore as well. I was wondering if we could remove this pass. We could either remove it here or do it under fx_ts_compat (by cloning the tracer files and modifying them). Can you let us know what your thoughts are?

@frank-wei
Contributor

> @frank-wei I see that torch.bmm is not decomposed anymore as well. I was wondering if we could remove this pass. We could either remove it here or do it under fx_ts_compat (by cloning the tracer files and modifying them). Can you let us know what your thoughts are?

Hi @peri044, we can remove it from here.

@peri044
Collaborator

peri044 commented Jun 12, 2023

Thanks @frank-wei Here is the PR : #2009

@gs-olive
Collaborator Author

Closed in favor of #2009

@gs-olive gs-olive closed this Jun 20, 2023
Successfully merging this pull request may close these issues.

🐛 [Bug] IndexError encountered when using bmm in FX aten path