Skip to content

Commit

Permalink
Fancy constraints for qssa
Browse files Browse the repository at this point in the history
  • Loading branch information
alexarice committed Oct 18, 2024
1 parent 5c6f568 commit e050248
Show file tree
Hide file tree
Showing 7 changed files with 109 additions and 62 deletions.
8 changes: 5 additions & 3 deletions inconspiquous/alloc.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from abc import ABC
from xdsl.ir import ParametrizedAttribute
from typing import Sequence
from xdsl.ir import Attribute, ParametrizedAttribute
from xdsl.irdl import WithRangeType

Check failure on line 4 in inconspiquous/alloc.py

View workflow job for this annotation

GitHub Actions / build (3.10)

"WithRangeType" is unknown import symbol (reportGeneralTypeIssues)

Check failure on line 4 in inconspiquous/alloc.py

View workflow job for this annotation

GitHub Actions / build (3.10)

Type of "WithRangeType" is unknown (reportUnknownVariableType)


class AllocAttr(ParametrizedAttribute, ABC):
pass
class AllocAttr(ParametrizedAttribute, WithRangeType, ABC):

Check failure on line 7 in inconspiquous/alloc.py

View workflow job for this annotation

GitHub Actions / build (3.10)

Base class type is unknown, obscuring type of derived class (reportUntypedBaseClass)
def get_types(self) -> Sequence[Attribute]: ...
22 changes: 8 additions & 14 deletions inconspiquous/dialects/qssa.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import ClassVar
from xdsl.dialects.builtin import i1
from xdsl.ir import Dialect, Operation, SSAValue, VerifyException
from xdsl.ir import Dialect, Operation, SSAValue
from xdsl.irdl import (
EqAttrConstraint,
IRDLOperation,
Expand All @@ -15,23 +15,23 @@
var_result_def,
)

from inconspiquous.gates import GateAttr
from inconspiquous.gates import GateAttr, GateConstraint
from inconspiquous.dialects.qubit import BitType


@irdl_op_definition
class GateOp(IRDLOperation):
name = "qssa.gate"

gate = prop_def(GateAttr)

_T: ClassVar[RangeConstraint] = RangeVarConstraint(
"T", RangeOf(EqAttrConstraint(BitType()))
_Q: ClassVar[RangeConstraint] = RangeVarConstraint(
"Q", RangeOf(EqAttrConstraint(BitType()))
)

ins = var_operand_def(_T)
gate = prop_def(GateConstraint(_Q))

ins = var_operand_def(_Q)

outs = var_result_def(_T)
outs = var_result_def(_Q)

assembly_format = "`<` $gate `>` $ins attr-dict `:` type($ins)"

Expand All @@ -44,12 +44,6 @@ def __init__(self, gate: GateAttr, *ins: SSAValue | Operation):
result_types=tuple(BitType() for _ in ins),
)

def verify_(self) -> None:
if len(self.ins) != self.gate.num_qubits:
raise VerifyException(
f"Gate {self.gate} expected {self.gate.num_qubits} input qubits but got {len(self.ins)}."
)


@irdl_op_definition
class MeasureOp(IRDLOperation):
Expand Down
24 changes: 19 additions & 5 deletions inconspiquous/dialects/qubit.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,17 @@
from xdsl.ir import Dialect, ParametrizedAttribute, TypeAttribute
from typing import ClassVar, Sequence
from xdsl.ir import Attribute, Dialect, ParametrizedAttribute, TypeAttribute
from xdsl.irdl import (
AnyAttr,
BaseAttr,
IRDLOperation,
RangeConstraint,
RangeOf,
RangeVarConstraint,
WithRangeTypeConstraint,
irdl_attr_definition,
irdl_op_definition,
prop_def,
result_def,
var_result_def,
)

from inconspiquous.alloc import AllocAttr
Expand All @@ -27,14 +34,21 @@ class AllocZeroAttr(AllocAttr):

name = "qubit.zero"

def get_types(self) -> Sequence[Attribute]:
return (BitType(),)


@irdl_op_definition
class AllocOp(IRDLOperation):
name = "qubit.alloc"

alloc = prop_def(AllocAttr, default_value=AllocZeroAttr())
_T: ClassVar[RangeConstraint] = RangeVarConstraint("T", RangeOf(AnyAttr()))

alloc = prop_def(
WithRangeTypeConstraint(BaseAttr(AllocAttr), _T), default_value=AllocZeroAttr()
)

outs = result_def(BitType())
outs = var_result_def(_T)

assembly_format = "(`` `<` $alloc^ `>`)? attr-dict"

Expand All @@ -43,7 +57,7 @@ def __init__(self, alloc: AllocAttr = AllocZeroAttr()):
properties={
"alloc": alloc,
},
result_types=[BitType()],
result_types=[alloc.get_types()],
)


Expand Down
37 changes: 37 additions & 0 deletions inconspiquous/gates.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,15 @@
from abc import ABC

from xdsl.ir import (
Attribute,
ParametrizedAttribute,
VerifyException,
)
from xdsl.irdl import (
BaseAttr,
ConstraintContext,
GenericAttrConstraint,
RangeVarConstraint,

Check failure on line 12 in inconspiquous/gates.py

View workflow job for this annotation

GitHub Actions / build (3.10)

"RangeVarConstraint" is unknown import symbol (reportGeneralTypeIssues)

Check failure on line 12 in inconspiquous/gates.py

View workflow job for this annotation

GitHub Actions / build (3.10)

Type of "RangeVarConstraint" is unknown (reportUnknownVariableType)
)


Expand Down Expand Up @@ -33,3 +41,32 @@ class TwoQubitGate(GateAttr):
@property
def num_qubits(self) -> int:
return 2


class GateConstraint(GenericAttrConstraint[GateAttr]):
"""
Constrains a given range variable to have the correct size for the gate.
"""

range_var: str
gate_constraint: GenericAttrConstraint[GateAttr]

def __init__(
self,
range_constraint: str | RangeVarConstraint[Attribute],

Check failure on line 56 in inconspiquous/gates.py

View workflow job for this annotation

GitHub Actions / build (3.10)

Type of parameter "range_constraint" is partially unknown   Parameter type is "str | Unknown" (reportUnknownParameterType)
gate_constraint: GenericAttrConstraint[GateAttr] = BaseAttr[GateAttr](GateAttr),
):
if isinstance(range_constraint, str):
self.range_var = range_constraint
else:
self.range_var = range_constraint.name

Check failure on line 62 in inconspiquous/gates.py

View workflow job for this annotation

GitHub Actions / build (3.10)

Type of "name" is unknown (reportUnknownMemberType)
self.gate_constraint = gate_constraint

def verify(self, attr: Attribute, constraint_context: ConstraintContext) -> None:
if self.range_var in constraint_context.range_variables:

Check failure on line 66 in inconspiquous/gates.py

View workflow job for this annotation

GitHub Actions / build (3.10)

Cannot access member "range_variables" for type "ConstraintContext"   Member "range_variables" is unknown (reportGeneralTypeIssues)

Check failure on line 66 in inconspiquous/gates.py

View workflow job for this annotation

GitHub Actions / build (3.10)

Type of "range_variables" is unknown (reportUnknownMemberType)
attrs = constraint_context.get_range_variable(self.range_var)

Check failure on line 67 in inconspiquous/gates.py

View workflow job for this annotation

GitHub Actions / build (3.10)

Cannot access member "get_range_variable" for type "ConstraintContext"   Member "get_range_variable" is unknown (reportGeneralTypeIssues)
assert isinstance(attr, GateAttr)
if attr.num_qubits != len(attrs):
raise VerifyException(
f"Gate {attr} expected {attr.num_qubits} qubits but got {len(attrs)}"
)
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ dev-dependencies = [
]

[tool.uv.sources]
xdsl = { git = "https://github.com/xdslproject/xdsl", branch = "alexarice/range-variable-sequence" }
xdsl = { git = "https://github.com/xdslproject/xdsl", branch = "alexarice/constraint-branch" }

[project.scripts]
quopt = "inconspiquous.tools.quopt_main:main"
Expand Down
6 changes: 3 additions & 3 deletions tests/filecheck/dialects/qssa/gate_counts.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -2,20 +2,20 @@

%q0 = qubit.alloc

// CHECK: attributes ('!qubit.bit',) expected from range variable 'T', but got ('!qubit.bit', '!qubit.bit')
// CHECK: attributes ('!qubit.bit',) expected from range variable 'Q', but got ('!qubit.bit', '!qubit.bit')
%q1, %q2 = "qssa.gate"(%q0) <{"gate" = #gate.cnot}> : (!qubit.bit) -> (!qubit.bit, !qubit.bit)

// -----

%q0 = qubit.alloc
%q1 = qubit.alloc

// CHECK: attributes ('!qubit.bit', '!qubit.bit') expected from range variable 'T', but got ('!qubit.bit',)
// CHECK: attributes ('!qubit.bit', '!qubit.bit') expected from range variable 'Q', but got ('!qubit.bit',)
%q2 = "qssa.gate"(%q0, %q1) <{"gate" = #gate.cnot}> : (!qubit.bit, !qubit.bit) -> !qubit.bit

// -----

%q0 = qubit.alloc

// CHECK: Gate #gate.cnot expected 2 input qubits but got 1
// CHECK: Gate #gate.cnot expected 2 qubits but got 1
%q1 = "qssa.gate"(%q0) <{"gate" = #gate.cnot}> : (!qubit.bit) -> !qubit.bit
Loading

0 comments on commit e050248

Please sign in to comment.