
🐛 [Bug] IndexError encountered when using bmm in FX aten path #1789

Closed
gs-olive opened this issue Mar 30, 2023 · 1 comment
Labels: bug Something isn't working

Comments

gs-olive (Collaborator)

Bug Description

When compiling the small model below via the FX aten path, an IndexError is raised in the compose_bmm lowering pass.

def forward(self, x, y):
    out = torch.bmm(x, y)
    return out

ERROR:

  File "~/TensorRT/py/torch_tensorrt/fx/lower.py", line 316, in __call__
    return do_lower(module, inputs)
  File "~/TensorRT/py/torch_tensorrt/fx/passes/pass_utils.py", line 118, in pass_with_validation
    processed_module = pass_(module, input, *args, **kwargs)
  File "~/TensorRT/py/torch_tensorrt/fx/lower.py", line 313, in do_lower
    lower_result = pm(module)
  File "/usr/local/lib/python3.8/dist-packages/torch/fx/passes/pass_manager.py", line 246, in __call__
    out = _pass(out)
  File "~/TensorRT/py/torch_tensorrt/fx/passes/lower_pass_manager_builder.py", line 68, in wrapped_fn
    return fn(gm, input)
  File "~/TensorRT/py/torch_tensorrt/fx/lower.py", line 262, in <lambda>
    trace_func=lambda module, inputs: aten_tracer.opt_trace(
  File "~/TensorRT/py/torch_tensorrt/fx/utils.py", line 136, in function_wrapper
    return f(*args, **kwargs)
  File "~/TensorRT/py/torch_tensorrt/fx/tracer/dispatch_tracer/aten_tracer.py", line 158, in opt_trace
    pr: PassResult = passes(fx_module)
  File "~/TensorRT/py/torch_tensorrt/fx/passes/lower_basic_pass_aten.py", line 420, in compose_bmm
    input_input_n = input_n.all_input_nodes[0]
IndexError: list index out of range
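
For context, a minimal torch.fx sketch (using symbolic_trace rather than the aten tracer that compose_bmm actually runs on, so it only approximates the failing graph) suggests how the index can run out of range: when a bmm operand is a graph placeholder, it has no producer node, so its all_input_nodes list is empty and indexing [0] raises IndexError.

# Illustrative sketch only -- traced with torch.fx.symbolic_trace, not the
# aten tracer used by the compose_bmm pass in torch_tensorrt.fx.
import torch
import torch.fx

class Sample(torch.nn.Module):
    def forward(self, x, y):
        return torch.bmm(x, y)

gm = torch.fx.symbolic_trace(Sample())
for node in gm.graph.nodes:
    if node.op == "call_function" and node.target is torch.bmm:
        input_n = node.all_input_nodes[0]            # the placeholder `x`
        print(input_n.op, input_n.all_input_nodes)   # placeholder []
        # input_n.all_input_nodes[0] would raise IndexError: list index out of range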

To Reproduce

Steps to reproduce the behavior:

  1. Run the code sample below
import torch
import torch_tensorrt

class Sample(torch.nn.Module):
    def __init__(self):
        super(Sample, self).__init__()

    def forward(self, x, y):
        out = torch.bmm(x, y)
        return out

def main():
    model = Sample().cuda().eval()
    input_data = torch.zeros((5, 5, 5), dtype=torch.float, device="cuda:0")
    input_data_2 = torch.ones((5, 5, 5), dtype=torch.float, device="cuda:0")
    out_torch = model(input_data, input_data_2)

    mod = torch_tensorrt.fx.compile(model, [input_data, input_data_2],
                                    lower_precision=torch_tensorrt.fx.utils.LowerPrecision.FP32,
                                    min_acc_module_size=1, is_aten=True)

    out_trt = mod(input_data, input_data_2)
    print(out_trt)

main()

Expected behavior

The model should compile successfully via the FX aten path.

Environment

  • Torch-TensorRT Version (e.g. 1.0.0): ad5e764
  • PyTorch Version (e.g. 1.0): 2.1.0.dev20230314+cu117
peri044 (Collaborator) commented Jun 13, 2023

This error would be fixed by #2009

peri044 closed this as completed Jun 13, 2023