Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 38 additions & 1 deletion src/bloqade/analysis/measure_id/impls.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,10 @@
from bloqade import qubit, annotate

from .lattice import (
Predicate,
AnyMeasureId,
NotMeasureId,
RawMeasureId,
MeasureIdBool,
MeasureIdTuple,
InvalidMeasureId,
Expand Down Expand Up @@ -41,10 +43,45 @@ def measure_qubit_list(
measure_id_bools = []
for _ in range(num_qubits.data):
interp.measure_count += 1
measure_id_bools.append(MeasureIdBool(interp.measure_count))
measure_id_bools.append(RawMeasureId(interp.measure_count))

return (MeasureIdTuple(data=tuple(measure_id_bools)),)

@interp.impl(qubit.stmts.IsLost)
@interp.impl(qubit.stmts.IsOne)
@interp.impl(qubit.stmts.IsZero)
def measurement_predicate(
self,
interp: MeasurementIDAnalysis,
frame: interp.Frame,
stmt: qubit.stmts.IsLost | qubit.stmts.IsOne | qubit.stmts.IsZero,
):
original_measure_id_tuple = frame.get(stmt.measurements)
# all members should be RawMeasureId, if it's anything else
# it's Invalid.
if not all(
isinstance(measure_id, RawMeasureId)
for measure_id in original_measure_id_tuple.data
):
return (InvalidMeasureId(),)

# get the proper predicate type
if isinstance(stmt, qubit.stmts.IsLost):
predicate = Predicate.IS_LOST
elif isinstance(stmt, qubit.stmts.IsOne):
predicate = Predicate.IS_ONE
elif isinstance(stmt, qubit.stmts.IsZero):
predicate = Predicate.IS_ZERO
else:
return (InvalidMeasureId(),)

# Create new MeasureIdBools with proper predicate type
predicate_measure_ids = [
MeasureIdBool(measure_id.idx, predicate)
for measure_id in original_measure_id_tuple.data
]
return (MeasureIdTuple(data=tuple(predicate_measure_ids)),)


@annotate.dialect.register(key="measure_id")
class Annotate(interp.MethodTable):
Expand Down
27 changes: 21 additions & 6 deletions src/bloqade/analysis/measure_id/lattice.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from enum import Enum
from typing import final
from dataclasses import dataclass

Expand All @@ -8,8 +9,15 @@
SimpleMeetMixin,
)


class Predicate(Enum):
IS_ZERO = 1
IS_ONE = 2
IS_LOST = 3


# Taken directly from Kai-Hsin Wu's implementation
# with minor changes to names and addition of CanMeasureId type
# with minor changes to names


@dataclass
Expand Down Expand Up @@ -57,18 +65,25 @@ def is_subseteq(self, other: MeasureId) -> bool:

@final
@dataclass
class MeasureIdBool(MeasureId):
class RawMeasureId(MeasureId):
idx: int

def is_subseteq(self, other: MeasureId) -> bool:
if isinstance(other, MeasureIdBool):
if isinstance(other, RawMeasureId):
return self.idx == other.idx
return False


# Might be nice to have some print override
# here so all the CanMeasureId's/other types are consolidated for
# readability
@final
@dataclass
class MeasureIdBool(MeasureId):
idx: int
predicate: Predicate

def is_subseteq(self, other: MeasureId) -> bool:
if isinstance(other, MeasureIdBool):
return self.predicate == other.predicate and self.idx == other.idx
return False


@final
Expand Down
3 changes: 3 additions & 0 deletions src/bloqade/qubit/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@
from ._prelude import kernel as kernel
from .stdlib.simple import (
reset as reset,
is_one as is_one,
is_lost as is_lost,
is_zero as is_zero,
measure as measure,
get_qubit_id as get_qubit_id,
get_measurement_id as get_measurement_id,
Expand Down
18 changes: 17 additions & 1 deletion src/bloqade/qubit/_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from bloqade.types import Qubit, MeasurementResult

from .stmts import New, Reset, Measure, QubitId, MeasurementId
from .stmts import New, IsOne, Reset, IsLost, IsZero, Measure, QubitId, MeasurementId


@wraps(New)
Expand Down Expand Up @@ -47,3 +47,19 @@ def get_measurement_id(

@wraps(Reset)
def reset(qubits: ilist.IList[Qubit, Any]) -> None: ...


@wraps(IsZero)
def is_zero(
measurements: ilist.IList[MeasurementResult, N],
) -> ilist.IList[bool, N]: ...


@wraps(IsOne)
def is_one(measurements: ilist.IList[MeasurementResult, N]) -> ilist.IList[bool, N]: ...


@wraps(IsLost)
def is_lost(
measurements: ilist.IList[MeasurementResult, N],
) -> ilist.IList[bool, N]: ...
36 changes: 36 additions & 0 deletions src/bloqade/qubit/stdlib/broadcast.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,3 +60,39 @@ def get_measurement_id(
measurement_ids (IList[int, N]): The list of global, unique IDs of the measurements.
"""
return _qubit.get_measurement_id(measurements)


@kernel
def is_zero(measurements: ilist.IList[MeasurementResult, N]) -> ilist.IList[bool, N]:
"""Check if each MeasurementResult in the list is equivalent to measuring the zero state.

Args:
measurements (IList[MeasurementResult, N]): The list of measurement results to check.
Returns:
IList[bool, N]: A list of booleans indicating whether each MeasurementResult is equivalent to the zero state.
"""
return _qubit.is_zero(measurements)


@kernel
def is_one(measurements: ilist.IList[MeasurementResult, N]) -> ilist.IList[bool, N]:
"""Check if each MeasurementResult in the list is equivalent to measuring the one state.

Args:
measurements (IList[MeasurementResult, N]): The list of measurement results to check.
Returns:
IList[bool, N]: A list of booleans indicating whether each MeasurementResult is equivalent to the one state.
"""
return _qubit.is_one(measurements)


@kernel
def is_lost(measurements: ilist.IList[MeasurementResult, N]) -> ilist.IList[bool, N]:
"""Check if each MeasurementResult in the list indicates atom loss.

Args:
measurements (IList[MeasurementResult, N]): The list of measurement results to check.
Returns:
IList[bool, N]: A list of booleans indicating whether each MeasurementResult indicates atom loss.
"""
return _qubit.is_lost(measurements)
41 changes: 41 additions & 0 deletions src/bloqade/qubit/stdlib/simple.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,3 +57,44 @@ def get_measurement_id(measurement: MeasurementResult) -> int:
"""
ids = broadcast.get_measurement_id(ilist.IList([measurement]))
return ids[0]


@kernel
def is_zero(measurement: MeasurementResult) -> bool:
"""Check if the measurement result is equivalent to measuring the zero state.

Args:
measurement (MeasurementResult): The measurement result to check.
Returns:
bool: True if the measurement result is equivalent to measuring the zero state, False otherwise.

"""
results = broadcast.is_zero(ilist.IList([measurement]))
return results[0]


@kernel
def is_one(measurement: MeasurementResult) -> bool:
"""Check if the measurement result is equivalent to measuring the one state.

Args:
measurement (MeasurementResult): The measurement result to check.
Returns:
bool: True if the measurement result is equivalent to measuring the one state, False otherwise.

"""
results = broadcast.is_one(ilist.IList([measurement]))
return results[0]


@kernel
def is_lost(measurement: MeasurementResult) -> bool:
"""Check if the measurement result indicates atom loss.

Args:
measurement (MeasurementResult): The measurement result to check.
Returns:
bool: True if the measurement result indicates atom loss, False otherwise.
"""
results = broadcast.is_lost(ilist.IList([measurement]))
return results[0]
24 changes: 24 additions & 0 deletions src/bloqade/qubit/stmts.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,30 @@ class Reset(ir.Statement):
qubits: ir.SSAValue = info.argument(ilist.IListType[QubitType, types.Any])


@statement
class MeasurementPredicate(ir.Statement):
traits = frozenset({lowering.FromPythonCall(), ir.Pure()})
measurements: ir.SSAValue = info.argument(
ilist.IListType[MeasurementResultType, Len]
)
result: ir.ResultValue = info.result(ilist.IListType[types.Bool, Len])


@statement(dialect=dialect)
class IsZero(MeasurementPredicate):
pass


@statement(dialect=dialect)
class IsOne(MeasurementPredicate):
pass


@statement(dialect=dialect)
class IsLost(MeasurementPredicate):
pass


# TODO: investigate why this is needed to get type inference to be correct.
@dialect.register(key="typeinfer")
class __TypeInfer(interp.MethodTable):
Expand Down
3 changes: 3 additions & 0 deletions src/bloqade/squin/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,10 @@
from .. import qubit as qubit, annotate as annotate
from ..qubit import (
reset as reset,
is_one as is_one,
qalloc as qalloc,
is_lost as is_lost,
is_zero as is_zero,
measure as measure,
get_qubit_id as get_qubit_id,
get_measurement_id as get_measurement_id,
Expand Down
8 changes: 7 additions & 1 deletion src/bloqade/squin/stdlib/broadcast/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,4 +31,10 @@
two_qubit_pauli_channel as two_qubit_pauli_channel,
single_qubit_pauli_channel as single_qubit_pauli_channel,
)
from ._qubit import reset as reset, measure as measure
from ._qubit import (
reset as reset,
is_one as is_one,
is_lost as is_lost,
is_zero as is_zero,
measure as measure,
)
3 changes: 3 additions & 0 deletions src/bloqade/squin/stdlib/broadcast/_qubit.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
from bloqade.qubit.stdlib.broadcast import (
reset as reset,
is_one as is_one,
is_lost as is_lost,
is_zero as is_zero,
measure as measure,
)
8 changes: 4 additions & 4 deletions src/bloqade/stim/rewrite/get_record_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from kirin.dialects import py

from bloqade.stim.dialects import auxiliary
from bloqade.analysis.measure_id.lattice import MeasureIdBool, MeasureIdTuple
from bloqade.analysis.measure_id.lattice import RawMeasureId, MeasureIdTuple


def insert_get_records(
Expand All @@ -12,9 +12,9 @@ def insert_get_records(
Insert GetRecord statements before the given node
"""
get_record_ssas = []
for measure_id_bool in measure_id_tuple.data:
assert isinstance(measure_id_bool, MeasureIdBool)
target_rec_idx = (measure_id_bool.idx - 1) - meas_count_at_stmt
for measure_id in measure_id_tuple.data:
assert isinstance(measure_id, RawMeasureId)
target_rec_idx = (measure_id.idx - 1) - meas_count_at_stmt
idx_stmt = py.constant.Constant(target_rec_idx)
idx_stmt.insert_before(node)
get_record_stmt = auxiliary.GetRecord(idx_stmt.result)
Expand Down
19 changes: 9 additions & 10 deletions src/bloqade/stim/rewrite/ifs_to_stim.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,7 @@
from bloqade.stim.dialects.gate import CX as stim_CX, CY as stim_CY, CZ as stim_CZ
from bloqade.analysis.measure_id import MeasureIDFrame
from bloqade.stim.dialects.auxiliary import GetRecord
from bloqade.analysis.measure_id.lattice import (
MeasureIdBool,
)
from bloqade.analysis.measure_id.lattice import Predicate, MeasureIdBool


@dataclass
Expand Down Expand Up @@ -139,8 +137,13 @@ def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:

def rewrite_IfElse(self, stmt: scf.IfElse) -> RewriteResult:

# Check the condition is a singular MeasurementIdBool
if not isinstance(self.measure_frame.entries[stmt.cond], MeasureIdBool):
condition_type = self.measure_frame.entries[stmt.cond]
# Check the condition is a singular MeasurementIdBool and that
# it was generated by querying if the measurement is equivalent to the one state
if not isinstance(condition_type, MeasureIdBool):
return RewriteResult()

if condition_type.predicate != Predicate.IS_ONE:
return RewriteResult()

# Reusing code from SplitIf,
Expand All @@ -158,13 +161,9 @@ def rewrite_IfElse(self, stmt: scf.IfElse) -> RewriteResult:
else:
return RewriteResult()

# get necessary measurement ID type from analysis
measure_id_bool = self.measure_frame.entries[stmt.cond]
assert isinstance(measure_id_bool, MeasureIdBool)

# generate get record statement
measure_id_idx_stmt = py.Constant(
(measure_id_bool.idx - 1) - self.measure_frame.num_measures_at_stmt[stmt]
(condition_type.idx - 1) - self.measure_frame.num_measures_at_stmt[stmt]
)
get_record_stmt = GetRecord(id=measure_id_idx_stmt.result) # noqa: F841

Expand Down
Loading