Skip to content

Commit

Permalink
Add test cases for pattern matching against optional inputs (#1890)
Browse files Browse the repository at this point in the history
I updated the pattern matcher to support matching against optional
inputs. The change was accidentally pushed into the main branch (not a
sub-branch as I thought) ... guess the branch protections were not good
enough, changed it now.

Adding test-cases now to test it in this PR.

---------

Signed-off-by: Ganesan Ramalingam <[email protected]>
  • Loading branch information
gramalingam authored Oct 2, 2024
1 parent dc9a12d commit 35fdcf5
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 5 deletions.
10 changes: 6 additions & 4 deletions onnxscript/rewriter/pattern.py
Original file line number Diff line number Diff line change
Expand Up @@ -592,10 +592,12 @@ def producer(self) -> NodePattern:

Var = ValuePattern


def _is_pattern_variable(x: Any) -> bool:
# The derived classes of ValuePattern represent constant patterns and node-output patterns.
return type(x) is ValuePattern


class Constant(ValuePattern):
"""Represents a pattern that matches against a scalar constant value."""

Expand Down Expand Up @@ -988,16 +990,16 @@ def _bind_value(self, pattern_value: ValuePattern, value: ir.Value | None) -> bo

def _match_value(self, pattern_value: ValuePattern, value: ir.Value | None) -> bool:
"""Match an IR value against a ValuePattern instance."""
if value is None:
if not _is_pattern_variable(pattern_value):
return self.fail("Mismatch: input value is None, but pattern value is not a variable.")

if not self._bind_value(pattern_value, value):
return False

if isinstance(pattern_value, NodeOutputPattern):
if value is None:
return self.fail("Mismatch: Computed node pattern does not match None.")
return self._match_node_output(pattern_value, value)
if isinstance(pattern_value, Constant):
if value is None:
return self.fail("Mismatch: Constant pattern does not match None.")
return self._match_constant(pattern_value, value)
return True

Expand Down
59 changes: 58 additions & 1 deletion onnxscript/rewriter/pattern_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@
import onnx.checker
import onnx.parser

from onnxscript import ir
from onnxscript import FLOAT, ir, script
from onnxscript import opset17 as op
from onnxscript.rewriter import _ir_utils, cast_constant_of_shape, pattern

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -420,6 +421,62 @@ def concat(op, x, y, result: ir.Value):
self.assertEqual(model.graph[0].op_type, "Concat")
self.assertNotIn("axis", model.graph[0].attributes)

def test_match_none_input(self):
def none_pattern(op, x):
# match against a call to Original where the first input is None
return op.Original(None, x)

def replacement(op, x):
return op.Replaced(x)

rule = pattern.RewriteRule(none_pattern, replacement)

@script()
def test_model(x: FLOAT[1024]) -> FLOAT[1024]:
# Pattern should match following call
t1 = op.Original(None, x)
# Pattern should not match following call
z = op.Original(t1, x)
return z

model_proto = test_model.to_model_proto()
model = ir.serde.deserialize_model(model_proto)

count = rule.apply_to_model(model)
self.assertEqual(count, 1)
self.assertEqual(len(model.graph), 2)
self.assertEqual(model.graph.node(0).op_type, "Replaced")
self.assertEqual(model.graph.node(1).op_type, "Original")

def test_match_optional_input(self):
def none_pattern(op, optional_input, x):
# match against a call to Original where the first input may or may not be None
return op.Original(optional_input, x)

def replacement(op, optional_input, x):
if optional_input is None:
return op.ReplacedNone(x)
return op.ReplacedNotNone(x)

rule = pattern.RewriteRule(none_pattern, replacement)

@script()
def test_model(x: FLOAT[1024]) -> FLOAT[1024]:
# Pattern should match following call
t1 = op.Original(None, x)
# as well as this one
z = op.Original(t1, x)
return z

model_proto = test_model.to_model_proto()
model = ir.serde.deserialize_model(model_proto)

count = rule.apply_to_model(model)
self.assertEqual(count, 2)
self.assertEqual(len(model.graph), 2)
self.assertEqual(model.graph.node(0).op_type, "ReplacedNone")
self.assertEqual(model.graph.node(1).op_type, "ReplacedNotNone")


class PatternBuilderTest(unittest.TestCase):
def test_pattern_builder_context(self):
Expand Down

0 comments on commit 35fdcf5

Please sign in to comment.