From 5e46bcd9f4f4d2f81a56e026556e8d61972e714e Mon Sep 17 00:00:00 2001 From: twata Date: Thu, 6 Jul 2023 09:32:03 +0000 Subject: [PATCH] [pfto] Add prim::Loop test --- .../onnx_tests/test_export.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/tests/pytorch_pfn_extras_tests/onnx_tests/test_export.py b/tests/pytorch_pfn_extras_tests/onnx_tests/test_export.py index 201b025b6..3963c9473 100644 --- a/tests/pytorch_pfn_extras_tests/onnx_tests/test_export.py +++ b/tests/pytorch_pfn_extras_tests/onnx_tests/test_export.py @@ -416,3 +416,20 @@ def forward(self, x, y): output_names=["out"], dynamic_axes={"x": {1: "A"}, "y": {0: "B"}}, ) + + +def test_loop(): + @torch.jit.script + def f(x: torch.Tensor): + for _ in range(10): + x = x + 5 + return x + + class Model(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x: torch.Tensor): + return f(x) + + run_model_test(Model(), (torch.randn(2, 7, 17),))