fix: Update lowering passes in aten tracer FX #1708

Conversation
- Enable translation to `reshape` from `view`, which was causing failures when compiling the BERT model due to the memory layout of Tensors
- Default to `matmul` within the `compose_bmm` lowering pass when the dimension of the inputs exceeds 3
Commits `4990f6c` to `a063082`
```diff
@@ -258,6 +258,7 @@ def remove_ops(
     for n in module.graph.nodes:
         if n.op == "call_function" and n.target in (
             torch.ops.aten._unsafe_view.default,
+            torch.ops.aten.view.default,
```
It is not necessary to remove `aten.view`, since the `reshape` operation is decomposed into `aten.view` (which is safe), and we have a converter to support `aten.view`.
I see - thank you for the clarification on that. The reason I had removed the `view` operator was for cases like this:

```python
def forward(self, x):
    x = x.permute(0, 2, 1, 3).contiguous()
    new_shape = x.size()[:-2] + (-1,)
    return x.view(new_shape)
```
These show up in the GPT2 code, and when using the aten tracer, they result in the following error (though they run fine in Torch):
```
File "~/TensorRT/py/torch_tensorrt/fx/tracer/dispatch_tracer/aten_tracer.py", line 161, in opt_trace
    fx_module(*args)
File "/usr/local/lib/python3.8/dist-packages/torch/fx/graph_module.py", line 662, in call_wrapped
    return self._wrapped_call(self, *args, **kwargs)
File "/usr/local/lib/python3.8/dist-packages/torch/fx/graph_module.py", line 281, in __call__
    raise e
File "/usr/local/lib/python3.8/dist-packages/torch/fx/graph_module.py", line 271, in __call__
    return super(self.cls, obj).__call__(*args, **kwargs)  # type: ignore[misc]
File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
File "<eval_with_key>.15", line 9, in forward
File "/usr/local/lib/python3.8/dist-packages/torch/_ops.py", line 329, in __call__
    return self._op(*args, **kwargs or {})
RuntimeError: view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead.
```
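The failure mode above can be reproduced in isolation (a minimal sketch with hypothetical shapes, not taken from the PR): after `permute()`, the tensor's strides no longer describe a contiguous layout, so `view()` cannot reinterpret the memory in place, while `reshape()` falls back to copying:

```python
import torch

# After permute(), strides are (60, 5, 20, 1), which cannot be
# flattened across the last two dimensions without a copy.
x = torch.randn(2, 3, 4, 5).permute(0, 2, 1, 3)  # shape (2, 4, 3, 5), non-contiguous

try:
    x.view(2, 4, -1)  # spans dimensions across non-contiguous subspaces
except RuntimeError as e:
    print("view failed:", e)

# reshape() copies the data when a zero-copy view is impossible
y = x.reshape(2, 4, -1)
print(y.shape)  # torch.Size([2, 4, 15])
```

This is why lowering `view` to `reshape` sidesteps the memory-layout errors when the intervening `contiguous()` call is lost during tracing.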
```python
                new_func = aten_compose_bmm_3d
            else:
```
It is not clear why we need this `new_func = torch.ops.aten.matmul`. Any example or unit test?
This addition is related to an issue in the `compose_bmm` lowering pass. I noticed that `input_n` can have a different shape than `real_input`, which causes the batch matrix multiply to have 4 dimensions instead of 3, reaching this `else` statement. I don't yet have a minimal reproducing example, as #1789 would likely need to be addressed first.
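The motivation for falling back to `matmul` can be illustrated directly (a sketch with hypothetical 4-D shapes, not taken from the PR): `torch.bmm` strictly requires 3-D inputs, whereas `torch.matmul` broadcasts the leading batch dimensions and so handles the 4-D case that reaches this branch:

```python
import torch

# Hypothetical 4-D batched operands: batch dims (2, 8), matrix dims (4, 6) x (6, 5)
a = torch.randn(2, 8, 4, 6)
b = torch.randn(2, 8, 6, 5)

try:
    torch.bmm(a, b)  # bmm rejects anything that is not exactly 3-D
except RuntimeError as e:
    print("bmm failed:", e)

out = torch.matmul(a, b)  # leading batch dims are broadcast
print(out.shape)  # torch.Size([2, 8, 4, 5])
```

Reshaping to 3-D and back (as `aten_compose_bmm_3d` presumably does) is an alternative, but `matmul` avoids the extra reshapes when the rank exceeds 3.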
Description

- Enable translation to `reshape` from `view`, which was causing failures when compiling the BERT model due to the memory layout of Tensors
- Default to `matmul` within the `compose_bmm` lowering pass when the dimension of the inputs exceeds 3

Error displayed prior to the `remove_ops` view fix (BERT model from Issue #1673):

Error displayed prior to the `compose_bmm` fix:

Note: `test_reshape_aten` is currently failing since the `aten.view.default` ops are being converted to `aten.reshape`

Fixes #1673