Skip to content

Commit

Permalink
bugfix: allow empty tuple for inputs or arg_inputs (#3122)
Browse files (browse the repository at this point in the history)
  • Loading branch information
jiwoong-choi authored Sep 3, 2024
1 parent ae7e6c8 commit 8e75039
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 2 deletions.
4 changes: 2 additions & 2 deletions py/torch_tensorrt/dynamo/_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -604,10 +604,10 @@ def convert_exported_program_to_serialized_trt_engine(
DeprecationWarning,
stacklevel=2,
)
- if not arg_inputs and not inputs:
+ if arg_inputs is None and inputs is None:
raise AssertionError("'arg_inputs' and 'inputs' should not both be None.")

- elif arg_inputs and inputs:
+ elif arg_inputs is not None and inputs is not None:
raise AssertionError(
"'arg_inputs' and 'inputs' should not be used at the same time."
)
Expand Down
47 changes: 47 additions & 0 deletions tests/py/dynamo/models/test_export_kwargs_serde.py
Original file line number Diff line number Diff line change
Expand Up @@ -525,3 +525,50 @@ def forward(self, x, b=5, c=None, d=None):
engine = convert_exported_program_to_serialized_trt_engine(
exp_program, **compile_spec
)


def test_custom_model_compile_engine_with_pure_kwarg_inputs():
    """Serialize a TRT engine for a model whose inputs are all keyword arguments.

    The model is exported with ``args=()`` and the converter is given
    ``arg_inputs=()``, so every tensor reaches ``forward`` through
    ``kwarg_inputs``. The test passes if export and conversion complete
    without raising; the serialized engine itself is not inspected.
    """

    class net(nn.Module):
        """Small conv -> pool -> linear network with optional keyword inputs."""

        def __init__(self):
            super().__init__()
            self.conv1 = nn.Conv2d(3, 12, 3, padding=1)
            self.bn = nn.BatchNorm2d(12)
            self.conv2 = nn.Conv2d(12, 12, 3, padding=1)
            self.fc1 = nn.Linear(12 * 56 * 56, 10)

        def forward(self, x, b=5, c=None, d=None):
            # First stage: conv -> relu -> batch norm -> 2x2 max pool.
            x = F.max_pool2d(self.bn(F.relu(self.conv1(x))), (2, 2))
            # Second stage: conv -> relu -> 2x2 max pool.
            x = F.max_pool2d(F.relu(self.conv2(x)), (2, 2))
            # Flatten everything after the batch dimension, then apply the
            # keyword-provided offset.
            x = torch.flatten(x, 1) + b
            if c is not None:
                x = x * c
            if d is not None:
                x = x - d["value"]
            return self.fc1(x)

    module = net().eval().to("cuda")

    # Every model input, including the primary tensor ``x``, is supplied by
    # keyword; ``c`` is deliberately omitted so it keeps its default of None.
    image = torch.rand((1, 3, 224, 224)).to("cuda")
    offset = torch.tensor(6).to("cuda")
    subtrahend = torch.tensor(8).to("cuda")
    keyword_inputs = {"x": image, "b": offset, "d": {"value": subtrahend}}

    conversion_options = dict(
        arg_inputs=(),  # empty positional inputs, paired with kwarg_inputs
        kwarg_inputs=keyword_inputs,
        device=torchtrt.Device("cuda:0"),
        enabled_precisions={torch.float},
        pass_through_build_failures=True,
        optimization_level=1,
        min_block_size=1,
        ir="dynamo",
    )

    exported = torch.export.export(module, args=(), kwargs=keyword_inputs)
    convert_exported_program_to_serialized_trt_engine(exported, **conversion_options)

0 comments on commit 8e75039

Please sign in to comment.