🐛 [Bug] Shape mismatch bug using view in FX aten path #1788

Closed
@gs-olive

Bug Description

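When compiling the small model below via the FX aten path (is_aten=True), the lowering pass fails its correctness check: the compared outputs have different shapes, torch.Size([5, 25]) versus torch.Size([5, 1, 25]). Eager PyTorch produces [5, 25] for this view, so the extra dimension is introduced during lowering.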

def forward(self, x):
    y = x.view(-1, 25)
    return y

ERROR:

  File "~/TensorRT/py/torch_tensorrt/fx/lower.py", line 86, in compile
    return lowerer(module, input)
  File "~/TensorRT/py/torch_tensorrt/fx/lower.py", line 316, in __call__
    return do_lower(module, inputs)
  File "~/TensorRT/py/torch_tensorrt/fx/passes/pass_utils.py", line 156, in pass_with_validation
    raise e
  File "~/TensorRT/py/torch_tensorrt/fx/passes/pass_utils.py", line 142, in pass_with_validation
    torch.testing.assert_close(x, y, **kwargs2)
  File "/usr/local/lib/python3.8/dist-packages/torch/testing/_comparison.py", line 1514, in assert_close
    raise error_metas[0].to_error(msg)
AssertionError: Pass <function Lowerer.__call__.<locals>.do_lower at 0x7f4e8db0b790> failed correctness check due at output 0:
The values for attribute 'shape' do not match: torch.Size([5, 25]) != torch.Size([5, 1, 25]).
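
To make the mismatch concrete, here is a minimal pure-Python sketch of how a single -1 in a view target is resolved from the element count. The helper name resolve_view_shape is hypothetical and is not part of PyTorch or Torch-TensorRT; it only mirrors the documented behavior of view for a contiguous input.

```python
from math import prod


def resolve_view_shape(input_shape, target_shape):
    """Resolve at most one -1 in target_shape from the input's element count.

    Hypothetical helper for illustration; it mimics how Tensor.view infers
    a -1 dimension for a contiguous tensor.
    """
    numel = prod(input_shape)
    unknown = [i for i, d in enumerate(target_shape) if d == -1]
    if len(unknown) > 1:
        raise ValueError("only one dimension can be inferred")

    # Product of the explicitly given dimensions (1 if there are none).
    known = prod(d for d in target_shape if d != -1)

    if unknown:
        if known == 0 or numel % known != 0:
            raise ValueError("target shape is invalid for the input size")
        resolved = list(target_shape)
        resolved[unknown[0]] = numel // known
        return tuple(resolved)

    if known != numel:
        raise ValueError("target shape is invalid for the input size")
    return tuple(target_shape)


# Shapes taken from the issue: a (5, 5, 5) input viewed as (-1, 25).
expected = resolve_view_shape((5, 5, 5), (-1, 25))
reported = (5, 1, 25)  # the other shape in the assertion message

print(expected)
print(expected == reported)
```

Both shapes hold 125 elements, so the data itself fits either one; the two outputs differ only in rank, which is what the shape check in assert_close rejects.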

To Reproduce

Steps to reproduce the behavior:

  1. Run the code sample below:

import torch
import torch_tensorrt


class Sample(torch.nn.Module):
    def __init__(self):
        super(Sample, self).__init__()

    def forward(self, x):
        y = x.view(-1, 25)
        return y

def main():
    model = Sample().cuda().eval()
    input_data = torch.zeros((5, 5, 5), dtype=torch.float, device="cuda:0")
    out_torch = model(input_data)

    mod = torch_tensorrt.fx.compile(model, [input_data],
                                    lower_precision=torch_tensorrt.fx.utils.LowerPrecision.FP32,
                                    min_acc_module_size=1, is_aten=True)

    out_trt = mod(input_data)
    print(out_trt)

main()

Expected behavior

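The model should compile without error, and the compiled module's output should have the same shape as the PyTorch output, torch.Size([5, 25]).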

Environment

  • Torch-TensorRT Version (e.g. 1.0.0): ad5e764
  • PyTorch Version (e.g. 1.0): 2.1.0.dev20230314+cu117

Additional Information

Related to #1673 and #1708
