Skip to content

Commit

Permalink
dialects: (prob) initialize
Browse files Browse the repository at this point in the history
  • Loading branch information
alexarice committed Oct 25, 2024
1 parent 0c92224 commit 913d475
Show file tree
Hide file tree
Showing 4 changed files with 93 additions and 3 deletions.
12 changes: 9 additions & 3 deletions inconspiquous/dialects/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,15 +26,20 @@ def get_func():

return Func

def get_gate():
from inconspiquous.dialects.gate import Gate

return Gate

def get_linalg():
from xdsl.dialects.linalg import Linalg

return Linalg

def get_gate():
from inconspiquous.dialects.gate import Gate
def get_prob():
from inconspiquous.dialects.prob import Prob

return Gate
return Prob

def get_qref():
from inconspiquous.dialects.qref import Qref
Expand Down Expand Up @@ -78,6 +83,7 @@ def get_test():
"func": get_func,
"gate": get_gate,
"linalg": get_linalg,
"prob": get_prob,
"qref": get_qref,
"qubit": get_qubit,
"qssa": get_qssa,
Expand Down
65 changes: 65 additions & 0 deletions inconspiquous/dialects/prob.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
from xdsl.dialects.builtin import IntegerAttrTypeConstr, i1
from xdsl.ir import Dialect, VerifyException
from xdsl.irdl import (
IRDLOperation,
irdl_op_definition,
prop_def,
result_def,
)
from xdsl.parser import Float64Type, FloatAttr, IndexType, IntegerType


@irdl_op_definition
class BernoulliOp(IRDLOperation):
name = "prob.bernoulli"

prob = prop_def(FloatAttr[Float64Type])

out = result_def(i1)

assembly_format = "$prob attr-dict"

def __init__(self, prob: float | FloatAttr[Float64Type]):
if isinstance(prob, float):
prob = FloatAttr(prob, 64)

# Why is this needed?
assert not isinstance(prob, int)

super().__init__(
properties={
"prob": prob,
},
result_types=(i1,),
)

def verify_(self) -> None:
prob = self.prob.value.data
if prob < 0 or prob > 1:
raise VerifyException(
f"Property 'prob' = {prob} should be in the range [0, 1]"
)


@irdl_op_definition
class UniformOp(IRDLOperation):
name = "prob.uniform"

out = result_def(IntegerAttrTypeConstr)

assembly_format = "attr-dict `:` type($out)"

def __init__(self, out_type: IntegerType | IndexType):
super().__init__(
result_types=(out_type,),
)


Prob = Dialect(
"prob",
[
BernoulliOp,
UniformOp,
],
[],
)
10 changes: 10 additions & 0 deletions tests/filecheck/dialects/prob/ops.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
// RUN: QUOPT_ROUNDTRIP
// RUN: QUOPT_GENERIC_ROUNDTRIP

// CHECK: %{{.*}} = prob.bernoulli 5.000000e-01 : f64
// CHECK-GENERIC: %{{.*}} = "prob.bernoulli"() <{"prob" = 5.000000e-01 : f64}> : () -> i1
%0 = prob.bernoulli 0.5

// CHECK: %{{.*}} = prob.uniform : i32
// CHECK-GENERIC: %{{.*}} = "prob.uniform"() : () -> i32
%1 = prob.uniform : i32
9 changes: 9 additions & 0 deletions tests/filecheck/dialects/prob/verify.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
// RUN: quopt %s --verify-diagnostics --split-input-file | filecheck %s

%0 = prob.bernoulli -1.0
// CHECK: Property 'prob' = -1.0 should be in the range [0, 1]

// -----

%1 = prob.bernoulli 1.5
// CHECK: Property 'prob' = 1.5 should be in the range [0, 1]

0 comments on commit 913d475

Please sign in to comment.