Bug Description
When compiling the small model below through the FX aten path (is_aten=True), lowering fails with an IndexError in the compose_bmm lowering pass:
def forward(self, x, y):
    out = torch.bmm(x, y)
    return out
ERROR:
  File "~/TensorRT/py/torch_tensorrt/fx/lower.py", line 316, in __call__
    return do_lower(module, inputs)
  File "~/TensorRT/py/torch_tensorrt/fx/passes/pass_utils.py", line 118, in pass_with_validation
    processed_module = pass_(module, input, *args, **kwargs)
  File "~/TensorRT/py/torch_tensorrt/fx/lower.py", line 313, in do_lower
    lower_result = pm(module)
  File "/usr/local/lib/python3.8/dist-packages/torch/fx/passes/pass_manager.py", line 246, in __call__
    out = _pass(out)
  File "~/TensorRT/py/torch_tensorrt/fx/passes/lower_pass_manager_builder.py", line 68, in wrapped_fn
    return fn(gm, input)
  File "~/TensorRT/py/torch_tensorrt/fx/lower.py", line 262, in <lambda>
    trace_func=lambda module, inputs: aten_tracer.opt_trace(
  File "~/TensorRT/py/torch_tensorrt/fx/utils.py", line 136, in function_wrapper
    return f(*args, **kwargs)
  File "~/TensorRT/py/torch_tensorrt/fx/tracer/dispatch_tracer/aten_tracer.py", line 158, in opt_trace
    pr: PassResult = passes(fx_module)
  File "~/TensorRT/py/torch_tensorrt/fx/passes/lower_basic_pass_aten.py", line 420, in compose_bmm
    input_input_n = input_n.all_input_nodes[0]
IndexError: list index out of range
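A possible reading of the last frame, not confirmed against the pass source: compose_bmm takes input_n (presumably one of the bmm operands) and indexes all_input_nodes[0], i.e. it assumes that operand is itself produced by an upstream op. In torch.fx, graph inputs are placeholder nodes whose all_input_nodes list is empty, and in this repro x and y feed torch.bmm directly. The sketch below models that situation with a hypothetical FakeNode class (a stand-in for torch.fx.Node that only carries all_input_nodes) and contrasts the unguarded access with a hypothetical guarded helper that lets the caller skip the match instead of crashing.

```python
from typing import List, Optional


class FakeNode:
    """Hypothetical stand-in for torch.fx.Node; models only all_input_nodes."""

    def __init__(self, name: str, op: str, inputs: Optional[List["FakeNode"]] = None):
        self.name = name
        self.op = op
        # torch.fx.Node exposes the nodes it consumes via all_input_nodes.
        self.all_input_nodes: List["FakeNode"] = list(inputs or [])


def first_input_unguarded(node: FakeNode) -> FakeNode:
    # Mirrors the failing line: raises IndexError when the node has no inputs.
    return node.all_input_nodes[0]


def first_input_guarded(node: FakeNode) -> Optional[FakeNode]:
    # Hypothetical fix: return None so the caller can skip this bmm pattern.
    if not node.all_input_nodes:
        return None
    return node.all_input_nodes[0]


# Graph for the repro: x and y are graph inputs (placeholders) fed straight to bmm.
x = FakeNode("x", "placeholder")
y = FakeNode("y", "placeholder")
bmm = FakeNode("bmm", "call_function", [x, y])

input_n = bmm.all_input_nodes[0]  # the placeholder x, which has no inputs
try:
    first_input_unguarded(input_n)
except IndexError as exc:
    print("unguarded:", type(exc).__name__)

print("guarded:", first_input_guarded(input_n))
```

Running it shows the unguarded access raising IndexError on the placeholder, while the guarded helper returns None.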
To Reproduce
Steps to reproduce the behavior:
- Run the code sample below
import torch
import torch_tensorrt


class Sample(torch.nn.Module):
    def __init__(self):
        super(Sample, self).__init__()

    def forward(self, x, y):
        out = torch.bmm(x, y)
        return out


def main():
    model = Sample().cuda().eval()
    input_data = torch.zeros((5, 5, 5), dtype=torch.float, device="cuda:0")
    input_data_2 = torch.ones((5, 5, 5), dtype=torch.float, device="cuda:0")
    out_torch = model(input_data, input_data_2)
    mod = torch_tensorrt.fx.compile(model, [input_data, input_data_2],
                                    lower_precision=torch_tensorrt.fx.utils.LowerPrecision.FP32,
                                    min_acc_module_size=1, is_aten=True)
    out_trt = mod(input_data, input_data_2)
    print(out_trt)


main()
Expected behavior
The model should compile without errors.
Environment
- Torch-TensorRT Version (e.g. 1.0.0): ad5e764
- PyTorch Version (e.g. 1.0): 2.1.0.dev20230314+cu117