fix: Improve input handling in aten bmm lowering pass #1821
Conversation
- Add checks to the `bmm` lowering pass so it returns the module unchanged when invalid schemas are used, instead of failing or throwing a `RuntimeError`
- Add a regression test case to catch the error in `bmm` and exercise the lowering pass
```python
# If no input nodes are available, the bmm argument itself could be an input
# Alternatively, if the node has no users, it can be eliminated
if len(input_n.all_input_nodes) == 0 or len(node.users) == 0:
```
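The intent of that guard can be illustrated with a minimal stand-in for FX nodes. The `Node` class below is hypothetical (not the real `torch.fx.Node`), used only to show why a `bmm` argument that is itself a graph input trips the check:

```python
class Node:
    # Hypothetical stand-in for torch.fx.Node, for illustration only.
    def __init__(self, all_input_nodes=None, users=None):
        self.all_input_nodes = all_input_nodes or []
        self.users = users or {}

def should_skip(node, input_n):
    # Mirror of the guard: a bmm argument that is itself a graph input
    # (a placeholder) has no producing nodes, and a node with no users is
    # dead code; in either case the pass leaves the module unchanged.
    return len(input_n.all_input_nodes) == 0 or len(node.users) == 0

placeholder = Node()  # a graph input: it has no producing nodes
bmm = Node(all_input_nodes=[placeholder], users={"out": None})
print(should_skip(bmm, placeholder))  # True: skip rewriting this bmm
```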
Do we have such a case, where there are no inputs to the bmm?
Yes, the case in #1789, which uses this model, has `len(input_n.all_input_nodes) == 0`:

```python
def forward(self, x, y):
    out = torch.bmm(x, y)
    return out
```

The test case shows this behavior too.
@frank-wei I see, thanks for the additional details. The reason I included that test case was for the case where we take the model in question and apply the `compose_bmm` lowering pass to it, in which case we see the failure. Specifically, in the above example, add the following lines:
```python
from torch_tensorrt.fx.passes.lower_basic_pass_aten import compose_bmm

composed_module = compose_bmm(gm[0])
out = composed_module.graph_module(*input2)
```
When doing so, the following error is elicited:
```
File "test.py", line 17, in <module>
  composed_module = compose_bmm(gm[0])
File "~/TensorRT/py/torch_tensorrt/fx/passes/lower_basic_pass_aten.py", line 409, in compose_bmm
  input_input_n = input_n.all_input_nodes[0]
IndexError: list index out of range
```
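One defensive way to avoid that `IndexError` is to guard the access and return early when the list is empty. This is only a sketch of the pattern with a hypothetical stand-in class, not the actual patch in the PR:

```python
class FakeNode:
    # Hypothetical stand-in for torch.fx.Node, for illustration only.
    def __init__(self, all_input_nodes=()):
        self.all_input_nodes = list(all_input_nodes)

def first_producer(input_n):
    # Guard before indexing: a graph input (placeholder) has no producing
    # nodes, so return None instead of raising IndexError on [0].
    nodes = input_n.all_input_nodes
    return nodes[0] if nodes else None

print(first_producer(FakeNode()))  # None, rather than an IndexError
```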
```python
inputs = [torch.randn(x_shape), torch.randn(y_shape)]
fx_model, _ = trace(BMM(), inputs)
composed_module = compose_bmm(fx_model)
out = composed_module.graph_module(*inputs)
```
The test is not complete yet. Could you please add the converter and complete the test like the other tests? One reference for the converter (aten2ait): https://github.com/facebookincubator/AITemplate/blob/main/fx2ait/fx2ait/converters/aten2ait_converters.py#L557-L559
@gs-olive what is remaining to be done on this PR?
This PR could use some additional testing per @frank-wei's comments, and I discussed with @peri044 about starting to look into some of the aten lowering pass issues such as #1787, #1788, #1789. I can add the requested additional testing to this PR.
@frank-wei I see the torch.bmm is not decomposed anymore as well. I was wondering if we could remove this pass from fx_ts_compat (by cloning the tracer files and modifying). Can you let us know what your thoughts are?
Hi @peri044, we can remove it from here:
Thanks @frank-wei. Here is the PR: #2009

Closed in favor of #2009
Description

- Add checks to the `bmm` lowering pass to return the module with no changes in the event of invalid schemas used, instead of failing or throwing a `RuntimeError`

Fixes #1789