Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

dialects: (prob) add finite support distribution #10

Merged
merged 3 commits into from
Nov 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
123 changes: 121 additions & 2 deletions inconspiquous/dialects/prob.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,35 @@
from typing import ClassVar, Sequence
from typing_extensions import Self

from xdsl.dialects.builtin import IntegerAttrTypeConstr, i1
from xdsl.ir import Dialect, VerifyException
from xdsl.ir import (
Attribute,
Dialect,
Operation,
SSAValue,
VerifyException,
)
from xdsl.irdl import (
AnyAttr,
IRDLOperation,
VarConstraint,
irdl_op_definition,
operand_def,
prop_def,
result_def,
traits_def,
var_operand_def,
)
from xdsl.parser import Float64Type, FloatAttr, IndexType, IntegerType
from xdsl.parser import (
DenseArrayBase,
Float64Type,
FloatAttr,
IndexType,
IntegerType,
UnresolvedOperand,
)
from xdsl.parser import Parser
from xdsl.printer import Printer
from xdsl.pattern_rewriter import RewritePattern
from xdsl.traits import HasCanonicalizationPatternsTrait

Expand Down Expand Up @@ -68,11 +90,108 @@ def __init__(self, out_type: IntegerType | IndexType):
)


class FinSuppOpHasCanonicalizationPatterns(HasCanonicalizationPatternsTrait):
@classmethod
def get_canonicalization_patterns(cls) -> tuple[RewritePattern, ...]:
from inconspiquous.transforms.canonicalization.prob import (
FinSuppTrivial,
FinSuppRemoveCase,
FinSuppDuplicate,
)

return (FinSuppTrivial(), FinSuppRemoveCase(), FinSuppDuplicate())


@irdl_op_definition
class FinSuppOp(IRDLOperation):
name = "prob.fin_supp"

_T: ClassVar = VarConstraint("T", AnyAttr())

ins = var_operand_def(_T)

default_value = operand_def(_T)

out = result_def(_T)

probabilities = prop_def(DenseArrayBase)

traits = traits_def(FinSuppOpHasCanonicalizationPatterns())

def __init__(
self,
probabilities: Sequence[float] | DenseArrayBase,
default_value: SSAValue,
*ins: SSAValue | Operation,
attr_dict: dict[str, Attribute] | None = None,
):
result_type = SSAValue.get(default_value).type
if not isinstance(probabilities, DenseArrayBase):
probabilities = DenseArrayBase.create_dense_float(
Float64Type(), probabilities
)
super().__init__(
operands=(ins, default_value),
result_types=(result_type,),
properties={"probabilities": probabilities},
attributes=attr_dict,
)

@staticmethod
def parse_case(parser: Parser) -> tuple[UnresolvedOperand, float]:
prob = parser.parse_number()
assert isinstance(prob, float)
parser.parse_keyword("or")
operand = parser.parse_unresolved_operand()
return (operand, prob)

@classmethod
def parse(cls, parser: Parser) -> Self:
parser.parse_punctuation("[")
probabilities: list[float] = []
cases: list[UnresolvedOperand] = []
while (n := parser.parse_optional_number()) is not None:
assert isinstance(n, float)
probabilities.append(n)
parser.parse_keyword("of")
cases.append(parser.parse_unresolved_operand())
parser.parse_punctuation(",")
if cases:
parser.parse_keyword("else")
default_unresolved = parser.parse_unresolved_operand()
parser.parse_punctuation("]")
parser.parse_punctuation(":")
result_type = parser.parse_type()
ins = tuple(parser.resolve_operand(x, result_type) for x in cases)
default_value = parser.resolve_operand(default_unresolved, result_type)
attr_dict = parser.parse_optional_attr_dict()
return cls(probabilities, default_value, *ins, attr_dict=attr_dict)

@staticmethod
def print_case(c: tuple[SSAValue, int | float], printer: Printer):
operand, prob = c
printer.print_string(repr(prob) + " of ")
printer.print_operand(operand)

def print(self, printer: Printer):
printer.print_string(" [ ")
printer.print_list(
zip(self.ins, self.probabilities.as_tuple()),
lambda c: self.print_case(c, printer),
)
if self.ins:
printer.print_string(", else ")
printer.print_operand(self.default_value)
printer.print_string(" ] : ")
printer.print_attribute(self.out.type)


Prob = Dialect(
"prob",
[
BernoulliOp,
UniformOp,
FinSuppOp,
],
[],
)
67 changes: 66 additions & 1 deletion inconspiquous/transforms/canonicalization/prob.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
from xdsl.dialects.arith import Constant
from xdsl.dialects.builtin import BoolAttr
from xdsl.ir import SSAValue
from xdsl.pattern_rewriter import (
PatternRewriter,
RewritePattern,
op_type_rewrite_pattern,
)

from inconspiquous.dialects.prob import BernoulliOp
from inconspiquous.dialects.prob import BernoulliOp, FinSuppOp


class BernoulliConst(RewritePattern):
Expand All @@ -23,3 +24,67 @@ def match_and_rewrite(self, op: BernoulliOp, rewriter: PatternRewriter):

if prob == 0.0:
rewriter.replace_matched_op(Constant(BoolAttr.from_bool(False)))


class FinSuppTrivial(RewritePattern):
"""
prob.fin_supp [ %x ] == %x
"""

@op_type_rewrite_pattern
def match_and_rewrite(self, op: FinSuppOp, rewriter: PatternRewriter):
if not op.probabilities.data.data:
rewriter.replace_matched_op((), (op.default_value,))


class FinSuppRemoveCase(RewritePattern):
"""
A case can be removed if its probability is 0 or it's equal to the default case.
"""

@op_type_rewrite_pattern
def match_and_rewrite(self, op: FinSuppOp, rewriter: PatternRewriter):
probs = op.probabilities.as_tuple()
if not any(
p == 0.0 or c == op.default_value
for p, c in zip(probs, op.ins, strict=True)
):
return
new_probabilities = tuple(
p
for p, c in zip(probs, op.ins, strict=True)
if p != 0.0 and c != op.default_value
)
new_ins = tuple(
c
for p, c in zip(probs, op.ins, strict=True)
if p != 0.0 and c != op.default_value
)
rewriter.replace_matched_op(
FinSuppOp(new_probabilities, op.default_value, *new_ins)
)


class FinSuppDuplicate(RewritePattern):
"""
If two cases are the same then we can merge them.
"""

@op_type_rewrite_pattern
def match_and_rewrite(self, op: FinSuppOp, rewriter: PatternRewriter):
print(op.ins)
if len(set(op.ins)) == len(op.ins):
return
seen: dict[SSAValue, int] = dict()
new_probs: list[float] = []
new_ins: list[SSAValue] = []

for p, c in zip(op.probabilities.as_tuple(), op.ins, strict=True):
if c not in seen:
seen[c] = len(new_probs)
new_probs.append(p)
new_ins.append(c)
else:
new_probs[seen[c]] += p

rewriter.replace_matched_op(FinSuppOp(new_probs, op.default_value, *new_ins))
6 changes: 3 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,9 @@ dev-dependencies = [
"ruff>=0.6.5",
"pyright>=1.1.380",
"pytest>=8.3.3",
"lit<16.0.0",
"filecheck==0.0.23",
"pre-commit==3.3.1",
"lit<19.0.0",
"filecheck==1.0.1",
"pre-commit==4.0.1",
"psutil>=6.0.0",
]

Expand Down
19 changes: 19 additions & 0 deletions tests/filecheck/dialects/prob/canonicalize.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,22 @@

// Stop them being dead code eliminated
"test.op"(%0, %1) : (i1, i1) -> ()

// CHECK: %[[#x1:]] = "test.op"() {"fin_supp_test"} : () -> i64
%2 = "test.op"() {"fin_supp_test"} : () -> i64
%3 = prob.fin_supp [ %2 ] : i64
// CHECK-NEXT: "test.op"(%[[#x1]]) : (i64) -> ()
"test.op"(%3) : (i64) -> ()

// CHECK: %[[#first:]], %[[#second:]], %[[#third:]] = "test.op"() : () -> (i32, i32, i32)
%4, %5, %6 = "test.op"() : () -> (i32, i32, i32)
// CHECK-NEXT: %{{.*}} = prob.fin_supp [ 0.375 of %[[#first]], else %[[#second]] ] : i32
%7 = prob.fin_supp [ 0.125 of %4, 0.25 of %4, else %5 ] : i32

// CHECK-NEXT: %{{.*}} = prob.fin_supp [ 0.1 of %[[#first]], else %[[#third]] ] : i32
%8 = prob.fin_supp [ 0.1 of %4, 0.0 of %5, else %6 ] : i32

// CHECK-NEXT: %{{.*}} = prob.fin_supp [ 0.2 of %[[#second]], else %[[#first]] ] : i32
%9 = prob.fin_supp [ 0.1 of %4, 0.2 of %5, else %4 ] : i32

"test.op"(%7, %8, %9) : (i32, i32, i32) -> ()
12 changes: 12 additions & 0 deletions tests/filecheck/dialects/prob/ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,15 @@
// CHECK: %{{.*}} = prob.uniform : i32
// CHECK-GENERIC: %{{.*}} = "prob.uniform"() : () -> i32
%1 = prob.uniform : i32

%2, %3, %4 = "test.op"() : () -> (i64, i64, i64)

// CHECK: %{{.*}} = prob.fin_supp [ 0.1 of %{{.*}}, 0.2 of %{{.*}}, else %{{.*}} ] : i64
%5 = prob.fin_supp [
0.1 of %2,
0.2 of %3,
else %4
] : i64

// CHECK: %{{.*}} = prob.fin_supp [ %{{.*}} ] : i64
%6 = prob.fin_supp [ %4 ] : i64
18 changes: 9 additions & 9 deletions uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading