diff --git a/onnxscript/optimizer/_remove_unused_test.py b/onnxscript/optimizer/_remove_unused_test.py
index b87a176f6..425a00a44 100644
--- a/onnxscript/optimizer/_remove_unused_test.py
+++ b/onnxscript/optimizer/_remove_unused_test.py
@@ -11,6 +11,8 @@
 
 @parameterized.parameterized_class(("using_ir",), [(False,), (True,)])
 class RemoveUnusedTest(unittest.TestCase):
+    using_ir: bool
+
     def remove_unused_nodes(self, model: onnx.ModelProto):
         if self.using_ir:
             model_ir = ir.serde.deserialize_model(model)
@@ -81,11 +83,7 @@ def test_remove_unused_optional_outputs_maxpool(self):
         model = self.remove_unused_nodes(model)
         self.assertEqual(len(model.graph.node), 1)
         self.assertEqual(model.graph.node[0].op_type, "MaxPool")
-        if self.using_ir:
-            expected_outputs = ["z", ""]
-        else:
-            expected_outputs = ["z"]
-        self.assertEqual(model.graph.node[0].output, expected_outputs)
+        self.assertEqual(model.graph.node[0].output, ["z"])
 
     def test_remove_unused_optional_outputs_dropout_in_function(self):
         model = onnx.parser.parse_model(
@@ -110,11 +108,7 @@ def test_remove_unused_optional_outputs_dropout_in_function(self):
         self.assertEqual(len(model.functions), 1)
         self.assertEqual(len(model.functions[0].node), 1)
         self.assertEqual(model.functions[0].node[0].op_type, "MaxPool")
-        if self.using_ir:
-            expected_outputs = ["z", ""]
-        else:
-            expected_outputs = ["z"]
-        self.assertEqual(model.functions[0].node[0].output, expected_outputs)
+        self.assertEqual(model.functions[0].node[0].output, ["z"])
 
     def test_remove_used_optional_outputs_maxpool(self):
         model = onnx.parser.parse_model(
@@ -150,11 +144,7 @@ def test_remove_multiple_unused_optional_outputs_layernorm(self):
         model = self.remove_unused_nodes(model)
         self.assertEqual(len(model.graph.node), 3)
         self.assertEqual(model.graph.node[2].op_type, "LayerNormalization")
-        if self.using_ir:
-            expected_outputs = ["z", "", ""]
-        else:
-            expected_outputs = ["z"]
-        self.assertEqual(list(model.graph.node[2].output), expected_outputs)
+        self.assertEqual(list(model.graph.node[2].output), ["z"])
 
     def test_remove_trailing_unused_optional_outputs_layernorm(self):
         model = onnx.parser.parse_model(
@@ -173,11 +163,7 @@ def test_remove_trailing_unused_optional_outputs_layernorm(self):
         model = self.remove_unused_nodes(model)
         self.assertEqual(len(model.graph.node), 3)
         self.assertEqual(model.graph.node[2].op_type, "LayerNormalization")
-        if self.using_ir:
-            expected_outputs = ["z", "mean", ""]
-        else:
-            expected_outputs = ["z", "mean"]
-        self.assertEqual(list(model.graph.node[2].output), expected_outputs)
+        self.assertEqual(list(model.graph.node[2].output), ["z", "mean"])
 
     def test_avoid_remove_non_trailing_unused_optional_outputs_layernorm(self):
         model = onnx.parser.parse_model(
@@ -212,11 +198,7 @@ def test_remove_trailing_unused_optional_outputs_batchnorm(self):
         self.assertEqual(len(model.graph.node), 1)
         self.assertEqual(model.graph.node[0].op_type, "BatchNormalization")
         # Check that both the mean/var outputs are removed, and training_mode attribute is removed.
-        if self.using_ir:
-            expected_outputs = ["z", "", ""]
-        else:
-            expected_outputs = ["z"]
-        self.assertEqual(list(model.graph.node[0].output), expected_outputs)
+        self.assertEqual(list(model.graph.node[0].output), ["z"])
         self.assertEqual(len(model.graph.node[0].attribute), 0)
 
     def test_avoid_remove_used_optional_outputs_batchnorm(self):