Skip to content

Commit

Permalink
Update test
Browse files Browse the repository at this point in the history
  • Loading branch information
justinchuby committed Oct 14, 2024
1 parent 4fe6dd8 commit cf6cd76
Showing 1 changed file with 7 additions and 25 deletions.
32 changes: 7 additions & 25 deletions onnxscript/optimizer/_remove_unused_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@

@parameterized.parameterized_class(("using_ir",), [(False,), (True,)])
class RemoveUnusedTest(unittest.TestCase):
using_ir: bool

def remove_unused_nodes(self, model: onnx.ModelProto):
if self.using_ir:
model_ir = ir.serde.deserialize_model(model)
Expand Down Expand Up @@ -81,11 +83,7 @@ def test_remove_unused_optional_outputs_maxpool(self):
model = self.remove_unused_nodes(model)
self.assertEqual(len(model.graph.node), 1)
self.assertEqual(model.graph.node[0].op_type, "MaxPool")
if self.using_ir:
expected_outputs = ["z", ""]
else:
expected_outputs = ["z"]
self.assertEqual(model.graph.node[0].output, expected_outputs)
self.assertEqual(model.graph.node[0].output, ["z"])

def test_remove_unused_optional_outputs_dropout_in_function(self):
model = onnx.parser.parse_model(
Expand All @@ -110,11 +108,7 @@ def test_remove_unused_optional_outputs_dropout_in_function(self):
self.assertEqual(len(model.functions), 1)
self.assertEqual(len(model.functions[0].node), 1)
self.assertEqual(model.functions[0].node[0].op_type, "MaxPool")
if self.using_ir:
expected_outputs = ["z", ""]
else:
expected_outputs = ["z"]
self.assertEqual(model.functions[0].node[0].output, expected_outputs)
self.assertEqual(model.functions[0].node[0].output, ["z"])

def test_remove_used_optional_outputs_maxpool(self):
model = onnx.parser.parse_model(
Expand Down Expand Up @@ -150,11 +144,7 @@ def test_remove_multiple_unused_optional_outputs_layernorm(self):
model = self.remove_unused_nodes(model)
self.assertEqual(len(model.graph.node), 3)
self.assertEqual(model.graph.node[2].op_type, "LayerNormalization")
if self.using_ir:
expected_outputs = ["z", "", ""]
else:
expected_outputs = ["z"]
self.assertEqual(list(model.graph.node[2].output), expected_outputs)
self.assertEqual(list(model.graph.node[2].output), ["z"])

def test_remove_trailing_unused_optional_outputs_layernorm(self):
model = onnx.parser.parse_model(
Expand All @@ -173,11 +163,7 @@ def test_remove_trailing_unused_optional_outputs_layernorm(self):
model = self.remove_unused_nodes(model)
self.assertEqual(len(model.graph.node), 3)
self.assertEqual(model.graph.node[2].op_type, "LayerNormalization")
if self.using_ir:
expected_outputs = ["z", "mean", ""]
else:
expected_outputs = ["z", "mean"]
self.assertEqual(list(model.graph.node[2].output), expected_outputs)
self.assertEqual(list(model.graph.node[2].output), ["z", "mean"])

def test_avoid_remove_non_trailing_unused_optional_outputs_layernorm(self):
model = onnx.parser.parse_model(
Expand Down Expand Up @@ -212,11 +198,7 @@ def test_remove_trailing_unused_optional_outputs_batchnorm(self):
self.assertEqual(len(model.graph.node), 1)
self.assertEqual(model.graph.node[0].op_type, "BatchNormalization")
# Check that both the mean/var outputs are removed, and training_mode attribute is removed.
if self.using_ir:
expected_outputs = ["z", "", ""]
else:
expected_outputs = ["z"]
self.assertEqual(list(model.graph.node[0].output), expected_outputs)
self.assertEqual(list(model.graph.node[0].output), ["z"])
self.assertEqual(len(model.graph.node[0].attribute), 0)

def test_avoid_remove_used_optional_outputs_batchnorm(self):
Expand Down

0 comments on commit cf6cd76

Please sign in to comment.