Skip to content

Commit

Permalink
Add complex export test in pfto
Browse files Browse the repository at this point in the history
  • Loading branch information
twata committed Mar 3, 2023
1 parent 8ceabdb commit 53915d9
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 4 deletions.
12 changes: 8 additions & 4 deletions pytorch_pfn_extras/onnx/pfto_exporter/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,8 @@ def onnx_node_doc_string(onnx_node: torch._C.Node, torch_node: torch._C.Node) ->
torch.bool: onnx.TensorProto.DataType.BOOL,
torch.float64: onnx.TensorProto.DataType.DOUBLE,
torch.float16: onnx.TensorProto.DataType.FLOAT16,
torch.complex64: onnx.TensorProto.DataType.COMPLEX64,
torch.complex128: onnx.TensorProto.DataType.COMPLEX128,
}


Expand Down Expand Up @@ -165,6 +167,7 @@ class _ExporterOptions:
onnx_data_prop: bool = True
onnx_lowprecision_cast: bool = True
onnx_peephole: bool = True
onnx_scalar_type_analysis: bool = True
fixed_batch_size: bool = False

input_names: Optional[List[str]] = None
Expand Down Expand Up @@ -339,10 +342,11 @@ def optimize_torch(self, graph: torch._C.Graph) -> torch._C.Graph:

# ONNX level graph optimizer
def optimize_onnx(self, graph: torch._C.Graph) -> torch._C.Graph:
if pytorch_pfn_extras.requires("1.9.0"):
run_jit_pass(torch._C._jit_pass_onnx_scalar_type_analysis, graph, self.onnx_lowprecision_cast, self.opset_version)
else:
run_jit_pass(torch._C._jit_pass_onnx_scalar_type_analysis, graph)
if self.onnx_scalar_type_analysis:
if pytorch_pfn_extras.requires("1.9.0"):
run_jit_pass(torch._C._jit_pass_onnx_scalar_type_analysis, graph, self.onnx_lowprecision_cast, self.opset_version)
else:
run_jit_pass(torch._C._jit_pass_onnx_scalar_type_analysis, graph)

if self.do_constant_folding and self.opset_version in pytorch_pfn_extras.onnx._constants.onnx_constant_folding_opsets:
folded: Dict[str, torch.IValue] = torch._C._jit_pass_onnx_constant_fold( # type: ignore[attr-defined]
Expand Down
15 changes: 15 additions & 0 deletions tests/pytorch_pfn_extras_tests/onnx_tests/test_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,3 +237,18 @@ def forward(self, x):

def test_softmax():
run_model_test(torch.nn.Softmax(3), (torch.randn(1, 10, 30, 30),))


def test_complex():
class Complex(torch.nn.Module):
def forward(self, x):
return x + 1

x = torch.rand(32, 32, dtype=torch.complex64)
run_model_test(
Complex(),
(x,),
check_torch_export=False,
onnx_scalar_type_analysis=False,
skip_oxrt=True, # Add op in ONNX spec doesn't support complex input
)

0 comments on commit 53915d9

Please sign in to comment.