From d44b0dc8273984afabb090822f6c88a66660f307 Mon Sep 17 00:00:00 2001 From: Alex Rice Date: Tue, 5 Nov 2024 00:39:35 +0000 Subject: [PATCH 1/3] dialects: (prob) add finite support distribution --- inconspiquous/dialects/prob.py | 121 +++++++++++++++++- .../transforms/canonicalization/prob.py | 67 +++++++++- pyproject.toml | 6 +- .../filecheck/dialects/prob/canonicalize.mlir | 19 +++ tests/filecheck/dialects/prob/ops.mlir | 12 ++ uv.lock | 18 +-- 6 files changed, 228 insertions(+), 15 deletions(-) diff --git a/inconspiquous/dialects/prob.py b/inconspiquous/dialects/prob.py index b00d956..796acab 100644 --- a/inconspiquous/dialects/prob.py +++ b/inconspiquous/dialects/prob.py @@ -1,13 +1,33 @@ +from typing import ClassVar, Self, Sequence from xdsl.dialects.builtin import IntegerAttrTypeConstr, i1 -from xdsl.ir import Dialect, VerifyException +from xdsl.ir import ( + Attribute, + Dialect, + Operation, + SSAValue, + VerifyException, +) from xdsl.irdl import ( + AnyAttr, IRDLOperation, + VarConstraint, irdl_op_definition, + operand_def, prop_def, result_def, traits_def, + var_operand_def, +) +from xdsl.parser import ( + DenseArrayBase, + Float64Type, + FloatAttr, + IndexType, + IntegerType, + UnresolvedOperand, ) -from xdsl.parser import Float64Type, FloatAttr, IndexType, IntegerType +from xdsl.parser import Parser +from xdsl.printer import Printer from xdsl.pattern_rewriter import RewritePattern from xdsl.traits import HasCanonicalizationPatternsTrait @@ -68,11 +88,108 @@ def __init__(self, out_type: IntegerType | IndexType): ) +class FinSuppOpHasCanonicalizationPatterns(HasCanonicalizationPatternsTrait): + @classmethod + def get_canonicalization_patterns(cls) -> tuple[RewritePattern, ...]: + from inconspiquous.transforms.canonicalization.prob import ( + FinSuppTrivial, + FinSuppRemoveCase, + FinSuppDuplicate, + ) + + return (FinSuppTrivial(), FinSuppRemoveCase(), FinSuppDuplicate()) + + +@irdl_op_definition +class FinSuppOp(IRDLOperation): + name = "prob.fin_supp" + + _T: ClassVar = VarConstraint("T", AnyAttr()) + + ins = var_operand_def(_T) + + default_value = operand_def(_T) + + out = result_def(_T) + + probabilities = prop_def(DenseArrayBase) + + traits = traits_def(FinSuppOpHasCanonicalizationPatterns()) + + def __init__( + self, + probabilities: Sequence[float] | DenseArrayBase, + default_value: SSAValue, + *ins: SSAValue | Operation, + attr_dict: dict[str, Attribute] | None = None, + ): + result_type = SSAValue.get(default_value).type + if not isinstance(probabilities, DenseArrayBase): + probabilities = DenseArrayBase.create_dense_float( + Float64Type(), probabilities + ) + super().__init__( + operands=(ins, default_value), + result_types=(result_type,), + properties={"probabilities": probabilities}, + attributes=attr_dict, + ) + + @staticmethod + def parse_case(parser: Parser) -> tuple[UnresolvedOperand, float]: + prob = parser.parse_number() + assert isinstance(prob, float) + parser.parse_keyword("or") + operand = parser.parse_unresolved_operand() + return (operand, prob) + + @classmethod + def parse(cls, parser: Parser) -> Self: + parser.parse_punctuation("[") + probabilities: list[float] = [] + cases: list[UnresolvedOperand] = [] + while (n := parser.parse_optional_number()) is not None: + assert isinstance(n, float) + probabilities.append(n) + parser.parse_keyword("of") + cases.append(parser.parse_unresolved_operand()) + parser.parse_punctuation(",") + if cases: + parser.parse_keyword("else") + default_unresolved = parser.parse_unresolved_operand() + parser.parse_punctuation("]") + parser.parse_punctuation(":") + result_type = parser.parse_type() + ins = tuple(parser.resolve_operand(x, result_type) for x in cases) + default_value = parser.resolve_operand(default_unresolved, result_type) + attr_dict = parser.parse_optional_attr_dict() + return cls(probabilities, default_value, *ins, attr_dict=attr_dict) + + @staticmethod + def print_case(c: tuple[SSAValue, int | float], printer: Printer): + operand, prob = c + printer.print_string(repr(prob) + " of ") + printer.print_operand(operand) + + def print(self, printer: Printer): + printer.print_string(" [ ") + printer.print_list( + zip(self.ins, self.probabilities.as_tuple()), + lambda c: self.print_case(c, printer), + ) + if self.ins: + printer.print_string(", else ") + printer.print_operand(self.default_value) + printer.print_string(" ] : ") + printer.print_attribute(self.out.type) + + Prob = Dialect( "prob", [ BernoulliOp, UniformOp, + FinSuppOp, ], [], ) diff --git a/inconspiquous/transforms/canonicalization/prob.py b/inconspiquous/transforms/canonicalization/prob.py index 04e0971..80e77f1 100644 --- a/inconspiquous/transforms/canonicalization/prob.py +++ b/inconspiquous/transforms/canonicalization/prob.py @@ -1,12 +1,13 @@ from xdsl.dialects.arith import Constant from xdsl.dialects.builtin import BoolAttr +from xdsl.ir import SSAValue from xdsl.pattern_rewriter import ( PatternRewriter, RewritePattern, op_type_rewrite_pattern, ) -from inconspiquous.dialects.prob import BernoulliOp +from inconspiquous.dialects.prob import BernoulliOp, FinSuppOp class BernoulliConst(RewritePattern): @@ -23,3 +24,67 @@ def match_and_rewrite(self, op: BernoulliOp, rewriter: PatternRewriter): if prob == 0.0: rewriter.replace_matched_op(Constant(BoolAttr.from_bool(False))) + + +class FinSuppTrivial(RewritePattern): + """ + prob.fin_supp [ %x ] == %x + """ + + @op_type_rewrite_pattern + def match_and_rewrite(self, op: FinSuppOp, rewriter: PatternRewriter): + if not op.probabilities.data.data: + rewriter.replace_matched_op((), (op.default_value,)) + + +class FinSuppRemoveCase(RewritePattern): + """ + A case can be removed if its probability is 0 or it's equal to the default case. + """ + + @op_type_rewrite_pattern + def match_and_rewrite(self, op: FinSuppOp, rewriter: PatternRewriter): + probs = op.probabilities.as_tuple() + if not any( + p == 0.0 or c == op.default_value + for p, c in zip(probs, op.ins, strict=True) + ): + return + new_probabilities = tuple( + p + for p, c in zip(probs, op.ins, strict=True) + if p != 0.0 and c != op.default_value + ) + new_ins = tuple( + c + for p, c in zip(probs, op.ins, strict=True) + if p != 0.0 and c != op.default_value + ) + rewriter.replace_matched_op( + FinSuppOp(new_probabilities, op.default_value, *new_ins) + ) + + +class FinSuppDuplicate(RewritePattern): + """ + If two cases are the same then we can merge them. + """ + + @op_type_rewrite_pattern + def match_and_rewrite(self, op: FinSuppOp, rewriter: PatternRewriter): + print(op.ins) + if len(set(op.ins)) == len(op.ins): + return + seen: dict[SSAValue, int] = dict() + new_probs: list[float] = [] + new_ins: list[SSAValue] = [] + + for p, c in zip(op.probabilities.as_tuple(), op.ins, strict=True): + if c not in seen: + seen[c] = len(new_probs) + new_probs.append(p) + new_ins.append(c) + else: + new_probs[seen[c]] += p + + rewriter.replace_matched_op(FinSuppOp(new_probs, op.default_value, *new_ins)) diff --git a/pyproject.toml b/pyproject.toml index d4e7cc0..9861bc2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -13,9 +13,9 @@ dev-dependencies = [ "ruff>=0.6.5", "pyright>=1.1.380", "pytest>=8.3.3", - "lit<16.0.0", - "filecheck==0.0.23", - "pre-commit==3.3.1", + "lit<19.0.0", + "filecheck==1.0.1", + "pre-commit==4.0.1", "psutil>=6.0.0", ] diff --git a/tests/filecheck/dialects/prob/canonicalize.mlir b/tests/filecheck/dialects/prob/canonicalize.mlir index 3674916..b312d65 100644 --- a/tests/filecheck/dialects/prob/canonicalize.mlir +++ b/tests/filecheck/dialects/prob/canonicalize.mlir @@ -8,3 +8,22 @@ // Stop them being dead code eliminated "test.op"(%0, %1) : (i1, i1) -> () + +// CHECK: %[[#x1:]] = "test.op"() {"fin_supp_test"} : () -> i64 +%2 = "test.op"() {"fin_supp_test"} : () -> i64 +%3 = prob.fin_supp [ %2 ] : i64 +// CHECK-NEXT: "test.op"(%[[#x1]]) : (i64) -> () +"test.op"(%3) : (i64) -> () + +// CHECK: %[[#first:]], %[[#second:]], %[[#third:]] = "test.op"() : () -> (i32, i32, i32) +%4, %5, %6 = "test.op"() : () -> (i32, i32, i32) +// CHECK-NEXT: %{{.*}} = prob.fin_supp [ 0.30000000000000004 of %[[#first]], else %[[#second]] ] : i32 +%7 = prob.fin_supp [ 0.1 of %4, 0.2 of %4, else %5 ] : i32 + +// CHECK-NEXT: %{{.*}} = prob.fin_supp [ 0.1 of %[[#first]], else %[[#third]] ] : i32 +%8 = prob.fin_supp [ 0.1 of %4, 0.0 of %5, else %6 ] : i32 + +// CHECK-NEXT: %{{.*}} = prob.fin_supp [ 0.2 of %[[#second]], else %[[#first]] ] : i32 +%9 = prob.fin_supp [ 0.1 of %4, 0.2 of %5, else %4 ] : i32 + +"test.op"(%7, %8, %9) : (i32, i32, i32) -> () diff --git a/tests/filecheck/dialects/prob/ops.mlir b/tests/filecheck/dialects/prob/ops.mlir index fa3b735..cb77ed9 100644 --- a/tests/filecheck/dialects/prob/ops.mlir +++ b/tests/filecheck/dialects/prob/ops.mlir @@ -8,3 +8,15 @@ // CHECK: %{{.*}} = prob.uniform : i32 // CHECK-GENERIC: %{{.*}} = "prob.uniform"() : () -> i32 %1 = prob.uniform : i32 + +%2, %3, %4 = "test.op"() : () -> (i64, i64, i64) + +// CHECK: %{{.*}} = prob.fin_supp [ 0.1 of %{{.*}}, 0.2 of %{{.*}}, else %{{.*}} ] : i64 +%5 = prob.fin_supp [ + 0.1 of %2, + 0.2 of %3, + else %4 +] : i64 + +// CHECK: %{{.*}} = prob.fin_supp [ %{{.*}} ] : i64 +%6 = prob.fin_supp [ %4 ] : i64 diff --git a/uv.lock b/uv.lock index ac3fb37..388034d 100644 --- a/uv.lock +++ b/uv.lock @@ -39,11 +39,11 @@ wheels = [ [[package]] name = "filecheck" -version = "0.0.23" +version = "1.0.1" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/c3/fe/9f11b6dff466ec57a9f49902858ef1f4a90993998ac6fdd6cf579d5fbd95/filecheck-0.0.23.tar.gz", hash = "sha256:1c5db511fb7b5a32e1e24736479cfe754ea27c9ae0d5b6d52c0af132c8db3e7d", size = 12804 } +sdist = { url = "https://files.pythonhosted.org/packages/93/d2/7e8bf9acf2ccb522fef4845d940b7db14782e6a36929da0ca5a4791bc2b6/filecheck-1.0.1.tar.gz", hash = "sha256:bbc3c49c190bd3af2445426a193ff0b54e1fad5a81ea7c2116c4dc36f36614f2", size = 20480 } wheels = [ - { url = "https://files.pythonhosted.org/packages/d8/6a/a864c347dcffa6ac6b97f3770b5f4642b26cb3acf04a5b5bc2b14a04149b/filecheck-0.0.23-py3-none-any.whl", hash = "sha256:cc1dc3fc2fc682ccd059b0d535606d32235613a32c018211d93aa6a99047ceb2", size = 13217 }, + { url = "https://files.pythonhosted.org/packages/c7/29/827b9f240e03c2cc6a2fd534ac980a12b3c7e8d8aa71c2f1039a5f91e932/filecheck-1.0.1-py3-none-any.whl", hash = "sha256:2d1a0e8784b723a4b04a655cff3af09dd159a31f4e39d477186f5547553124ab", size = 23702 }, ] [[package]] @@ -97,9 +97,9 @@ requires-dist = [{ name = "xdsl", extras = ["gui"], git = "https://github.com/xd [package.metadata.requires-dev] dev = [ - { name = "filecheck", specifier = "==0.0.23" }, - { name = "lit", specifier = "<16.0.0" }, - { name = "pre-commit", specifier = "==3.3.1" }, + { name = "filecheck", specifier = "==1.0.1" }, + { name = "lit", specifier = "<19.0.0" }, + { name = "pre-commit", specifier = "==4.0.1" }, { name = "psutil", specifier = ">=6.0.0" }, { name = "pyright", specifier = ">=1.1.380" }, { name = "pytest", specifier = ">=8.3.3" }, @@ -227,7 +227,7 @@ wheels = [ [[package]] name = "pre-commit" -version = "3.3.1" +version = "4.0.1" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "cfgv" }, @@ -236,9 +236,9 @@ dependencies = [ { name = "pyyaml" }, { name = "virtualenv" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/f6/f9/fd40593d83357bb03733c0e77e71a08f2f5f523595d0a10401d7e5c22f16/pre_commit-3.3.1.tar.gz", hash = "sha256:733f78c9a056cdd169baa6cd4272d51ecfda95346ef8a89bf93712706021b907", size = 176605 } +sdist = { url = "https://files.pythonhosted.org/packages/2e/c8/e22c292035f1bac8b9f5237a2622305bc0304e776080b246f3df57c4ff9f/pre_commit-4.0.1.tar.gz", hash = "sha256:80905ac375958c0444c65e9cebebd948b3cdb518f335a091a670a89d652139d2", size = 191678 } wheels = [ - { url = "https://files.pythonhosted.org/packages/77/39/86e07f4e9671ee9311fa4bafc41c66d6a907192707160e3f45272e78be38/pre_commit-3.3.1-py2.py3-none-any.whl", hash = "sha256:218e9e3f7f7f3271ebc355a15598a4d3893ad9fc7b57fe446db75644543323b9", size = 202528 }, + { url = "https://files.pythonhosted.org/packages/16/8f/496e10d51edd6671ebe0432e33ff800aa86775d2d147ce7d43389324a525/pre_commit-4.0.1-py2.py3-none-any.whl", hash = "sha256:efde913840816312445dc98787724647c65473daefe420785f885e8ed9a06878", size = 218713 }, ] [[package]] From a033e158a200117b7b0bb28e8f4cd85ee8cd30c4 Mon Sep 17 00:00:00 2001 From: Alex Rice Date: Thu, 7 Nov 2024 17:08:05 +0000 Subject: [PATCH 2/3] fix 3.10 --- inconspiquous/dialects/prob.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/inconspiquous/dialects/prob.py b/inconspiquous/dialects/prob.py index 796acab..f144775 100644 --- a/inconspiquous/dialects/prob.py +++ b/inconspiquous/dialects/prob.py @@ -1,4 +1,6 @@ -from typing import ClassVar, Self, Sequence +from typing import ClassVar, Sequence +from typing_extensions import Self + from xdsl.dialects.builtin import IntegerAttrTypeConstr, i1 from xdsl.ir import ( Attribute, From 8d2e74dc3710b21185927c0f3c5a8b28d57b93b0 Mon Sep 17 00:00:00 2001 From: Alex Rice Date: Fri, 8 Nov 2024 09:46:38 +0000 Subject: [PATCH 3/3] Make test use powers of 2 --- tests/filecheck/dialects/prob/canonicalize.mlir | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/filecheck/dialects/prob/canonicalize.mlir b/tests/filecheck/dialects/prob/canonicalize.mlir index b312d65..2162b81 100644 --- a/tests/filecheck/dialects/prob/canonicalize.mlir +++ b/tests/filecheck/dialects/prob/canonicalize.mlir @@ -17,8 +17,8 @@ // CHECK: %[[#first:]], %[[#second:]], %[[#third:]] = "test.op"() : () -> (i32, i32, i32) %4, %5, %6 = "test.op"() : () -> (i32, i32, i32) -// CHECK-NEXT: %{{.*}} = prob.fin_supp [ 0.30000000000000004 of %[[#first]], else %[[#second]] ] : i32 -%7 = prob.fin_supp [ 0.1 of %4, 0.2 of %4, else %5 ] : i32 +// CHECK-NEXT: %{{.*}} = prob.fin_supp [ 0.375 of %[[#first]], else %[[#second]] ] : i32 +%7 = prob.fin_supp [ 0.125 of %4, 0.25 of %4, else %5 ] : i32 // CHECK-NEXT: %{{.*}} = prob.fin_supp [ 0.1 of %[[#first]], else %[[#third]] ] : i32 %8 = prob.fin_supp [ 0.1 of %4, 0.0 of %5, else %6 ] : i32