
fix: Update lowering passes in aten tracer FX #1708

Closed
wants to merge 1 commit into from

Conversation

Collaborator

@gs-olive gs-olive commented Mar 2, 2023

Description

  • Enable translation of view to reshape; view was causing failures when compiling the BERT model due to the memory layout of Tensors
  • Default to matmul within the compose_bmm lowering pass when the dimension of the inputs exceeds 3

Error displayed prior to remove_ops view fix (BERT model from Issue #1673):

RuntimeError: view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead.

Error displayed prior to compose_bmm fix:

  File "~/TensorRT/py/torch_tensorrt/fx/passes/lower_basic_pass_aten.py", line 439, in compose_bmm
    new_func,
UnboundLocalError: local variable 'new_func' referenced before assignment
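The pattern behind this error can be reduced to a branch that only sometimes assigns its result. The following is a hedged sketch, not the actual pass code; the string values are illustrative stand-ins for the functions the real pass selects:

```python
def pick_replacement(ndim):
    # Sketch of the failure: before the fix, `new_func` was assigned only
    # inside the dimension-specific branches, so inputs with more than 3
    # dimensions reached the use site with the variable unbound,
    # producing the UnboundLocalError shown above.
    if ndim == 3:
        new_func = "aten_compose_bmm_3d"  # name appearing in the diff
    else:
        # The fix: fall back to a generic matmul instead of leaving
        # `new_func` unassigned for higher-dimensional inputs.
        new_func = "torch.ops.aten.matmul"
    return new_func

print(pick_replacement(4))  # torch.ops.aten.matmul
```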

Note: test_reshape_aten is currently failing, since the aten.view.default ops are now being converted to aten.reshape.

Fixes #1673

Type of change

Please delete options that are not relevant and/or add your own.

  • Bug fix (non-breaking change which fixes an issue)

Checklist:

  • [x] My code follows the style guidelines of this project (you can use the linters)
  • [x] I have performed a self-review of my own code
  • [x] I have commented my code, particularly in hard-to-understand areas and hacks
  • [x] I have made corresponding changes to the documentation
  • [x] I have added tests to verify my fix or my feature
  • [x] New and existing unit tests pass locally with my changes
  • [x] I have added the relevant labels to my PR so that the relevant reviewers are notified

@@ -258,6 +258,7 @@ def remove_ops(
     for n in module.graph.nodes:
         if n.op == "call_function" and n.target in (
             torch.ops.aten._unsafe_view.default,
+            torch.ops.aten.view.default,
Contributor

It is not necessary to remove aten.view, since the reshape operation is decomposed into aten.view (which is safe) and we have a converter to support aten.view.

Collaborator Author

I see - thank you for the clarification on that. The reason I had removed the view operator was for cases like this:

    def forward(self, x):
        x = x.permute(0, 2, 1, 3).contiguous()
        new_shape = x.size()[:-2] + (-1,)
        return x.view(new_shape)

These show up in the GPT2 code, and when using the aten tracer, they result in the following error (though they run fine in Torch):

  File "~/TensorRT/py/torch_tensorrt/fx/tracer/dispatch_tracer/aten_tracer.py", line 161, in opt_trace
    fx_module(*args)
  File "/usr/local/lib/python3.8/dist-packages/torch/fx/graph_module.py", line 662, in call_wrapped
    return self._wrapped_call(self, *args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/fx/graph_module.py", line 281, in __call__
    raise e
  File "/usr/local/lib/python3.8/dist-packages/torch/fx/graph_module.py", line 271, in __call__
    return super(self.cls, obj).__call__(*args, **kwargs)  # type: ignore[misc]
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "<eval_with_key>.15", line 9, in forward
  File "/usr/local/lib/python3.8/dist-packages/torch/_ops.py", line 329, in __call__
    return self._op(*args, **kwargs or {})
RuntimeError: view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead.
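Stripped of the tracer, the core incompatibility can be shown directly: omitting the .contiguous() call (as the traced graph effectively does here) makes view fail where reshape succeeds. A minimal sketch with illustrative shapes:

```python
import torch

x = torch.randn(2, 3, 4, 5)
y = x.permute(0, 2, 1, 3)          # non-contiguous: strides no longer match shape
new_shape = y.size()[:-2] + (-1,)  # flatten the last two dimensions

try:
    y.view(new_shape)              # view requires a compatible memory layout
except RuntimeError as e:
    print(f"view failed: {e}")

z = y.reshape(new_shape)           # reshape falls back to a copy when needed
print(z.shape)                     # torch.Size([2, 4, 15])
```

This is why translating view to reshape avoids the error: reshape returns a view when the layout allows it and silently copies when it does not.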

new_func = aten_compose_bmm_3d
else:
Contributor

It is not clear why we need this new_func = torch.ops.aten.matmul. Is there an example or unit test?

Collaborator Author
@gs-olive commented Mar 30, 2023

This addition is related to an issue in the compose_bmm lowering pass. I noticed that input_n can have a different shape than real_input, which causes the batch matrix multiply to have 4 dimensions instead of 3, reaching this else statement. I don't yet have a minimal reproducing example, as #1789 would likely need to be addressed first.
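For context on why matmul is a reasonable default here, a sketch under the assumption that the 4-D case only needs batch broadcasting: torch.bmm is restricted to strictly 3-D inputs, while torch.matmul broadcasts over any leading batch dimensions.

```python
import torch

# 3-D: matmul and bmm agree, so the fallback is consistent with the
# existing aten_compose_bmm_3d path.
a3, b3 = torch.randn(6, 2, 3), torch.randn(6, 3, 4)
assert torch.allclose(torch.matmul(a3, b3), torch.bmm(a3, b3))

# 4-D: bmm rejects these inputs outright, but matmul batch-broadcasts
# over the leading dimensions, covering the case that previously fell
# through to the unbound `new_func` branch.
a4, b4 = torch.randn(2, 6, 2, 3), torch.randn(2, 6, 3, 4)
out = torch.matmul(a4, b4)
print(out.shape)  # torch.Size([2, 6, 2, 4])
```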

Labels
  • cla signed
  • component: api [Python] (Issues re: Python API)
  • component: fx
  • WIP (work is in progress, pull request should not be merged yet)
Development

Successfully merging this pull request may close these issues.

🐛 [Bug] Transformers BERT Model does not compile via FX Path
3 participants