Skip to content

Commit

Permalink
Merge pull request #852 from take-cheeze/tuple_input
Browse files Browse the repository at this point in the history
[ONNX] Flatten input test data to support nested inputs
  • Loading branch information
linshokaku authored Dec 5, 2024
2 parents c168321 + 315572d commit a70ef34
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 1 deletion.
3 changes: 2 additions & 1 deletion pytorch_pfn_extras/onnx/export_testcase.py
Original file line number Diff line number Diff line change
Expand Up @@ -422,7 +422,8 @@ def write_to_pb(f: str, tensor: torch.Tensor, name: Optional[str] = None) -> Non
os.makedirs(data_set_path, exist_ok=True)
for pb_name in glob.glob(os.path.join(data_set_path, "*.pb")):
os.remove(pb_name)
for i, (arg, name) in enumerate(zip(named_args, input_names)):
flat_inputs = torch._C._jit_flatten(named_args)[0]
for i, (arg, name) in enumerate(zip(flat_inputs, input_names)):
f = os.path.join(data_set_path, 'input_{}.pb'.format(i))
write_to_pb(f, arg, name)

Expand Down
28 changes: 28 additions & 0 deletions tests/pytorch_pfn_extras_tests/onnx_tests/test_export_testcase.py
Original file line number Diff line number Diff line change
Expand Up @@ -633,3 +633,31 @@ def custom(model, args, f, **kwargs):
actual = ort_session.run(None, {"x": x.cpu().numpy()})[0]
expected = model(x)
np.testing.assert_allclose(actual, expected.detach().numpy(), atol=1e-3)


def test_export_tuple_input():

class Net(nn.Module):
def __init__(self):
super().__init__()
self.linear = nn.Linear(5, 10, bias=False)
self.in_channels = [5, 10]

def forward(self, inputs):
assert isinstance(inputs, tuple), f"{type(inputs)=}"
linears = [self.linear(x) for x in inputs]
return linears


model = Net()
x = torch.rand(2, 5)

export_testcase(
model,
((x,),),
output_dir,
input_names=["x"],
training=model.training,
do_constant_folding=False,
opset_version=12,
)

0 comments on commit a70ef34

Please sign in to comment.