
🐛 [Bug] IndexError encountered when using bmm in FX aten path #1789

Closed
@gs-olive

Description


Bug Description

When compiling the small model below via the FX aten path (is_aten=True), an IndexError is raised in the compose_bmm lowering pass.

def forward(self, x, y):
    out = torch.bmm(x, y)
    return out

Error (end of traceback):

  File "~/TensorRT/py/torch_tensorrt/fx/lower.py", line 316, in __call__
    return do_lower(module, inputs)
  File "~/TensorRT/py/torch_tensorrt/fx/passes/pass_utils.py", line 118, in pass_with_validation
    processed_module = pass_(module, input, *args, **kwargs)
  File "~/TensorRT/py/torch_tensorrt/fx/lower.py", line 313, in do_lower
    lower_result = pm(module)
  File "/usr/local/lib/python3.8/dist-packages/torch/fx/passes/pass_manager.py", line 246, in __call__
    out = _pass(out)
  File "~/TensorRT/py/torch_tensorrt/fx/passes/lower_pass_manager_builder.py", line 68, in wrapped_fn
    return fn(gm, input)
  File "~/TensorRT/py/torch_tensorrt/fx/lower.py", line 262, in <lambda>
    trace_func=lambda module, inputs: aten_tracer.opt_trace(
  File "~/TensorRT/py/torch_tensorrt/fx/utils.py", line 136, in function_wrapper
    return f(*args, **kwargs)
  File "~/TensorRT/py/torch_tensorrt/fx/tracer/dispatch_tracer/aten_tracer.py", line 158, in opt_trace
    pr: PassResult = passes(fx_module)
  File "~/TensorRT/py/torch_tensorrt/fx/passes/lower_basic_pass_aten.py", line 420, in compose_bmm
    input_input_n = input_n.all_input_nodes[0]
IndexError: list index out of range
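The last frame shows compose_bmm reading input_n.all_input_nodes[0] for an input of the bmm node. One plausible reading, which is an assumption and not confirmed in this issue, is that in this model bmm consumes the graph placeholders x and y directly, so the input node has no inputs of its own and the list is empty. The sketch below uses a hypothetical Node stand-in (not the real torch.fx.Node) and a hypothetical helper, producer_of, to show a guarded lookup that returns None instead of raising:

```python
from dataclasses import dataclass, field
from typing import List, Optional


@dataclass
class Node:
    # Hypothetical stand-in for torch.fx.Node: only the fields this sketch needs.
    name: str
    op: str
    all_input_nodes: List["Node"] = field(default_factory=list)


def producer_of(input_n: Node) -> Optional[Node]:
    """Hypothetical helper: return the node feeding input_n, or None when
    input_n has no inputs (e.g. a graph placeholder), instead of indexing
    an empty list and raising IndexError."""
    if not input_n.all_input_nodes:
        return None
    return input_n.all_input_nodes[0]


# Placeholders have no input nodes; a view-like node has one.
x = Node("x", "placeholder")
y = Node("y", "placeholder")
view_x = Node("view_x", "call_function", [x])

print(producer_of(view_x).name)  # prints x
print(producer_of(y))            # prints None
```

A lowering pass built this way could skip the rewrite (leave bmm untouched) when the helper returns None, rather than assuming every bmm input was produced by another node.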

To Reproduce

Steps to reproduce the behavior:

  1. Run the code sample below
import torch
import torch_tensorrt
class Sample(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x, y):
        out = torch.bmm(x, y)
        return out

def main():
    model = Sample().cuda().eval()
    input_data = torch.zeros((5, 5, 5), dtype=torch.float, device="cuda:0")
    input_data_2 = torch.ones((5, 5, 5), dtype=torch.float, device="cuda:0")
    out_torch = model(input_data, input_data_2)

    mod = torch_tensorrt.fx.compile(model, [input_data, input_data_2],
                                    lower_precision=torch_tensorrt.fx.utils.LowerPrecision.FP32,
                                    min_acc_module_size=1, is_aten=True)

    out_trt = mod(input_data, input_data_2)
    print(out_trt)

if __name__ == "__main__":
    main()

Expected behavior

The model should compile without errors.

Environment

  • Torch-TensorRT Version (e.g. 1.0.0): ad5e764
  • PyTorch Version (e.g. 1.0): 2.1.0.dev20230314+cu117

Metadata

Labels

bug (Something isn't working)
