diff --git a/src/bloqade/analysis/measure_id/impls.py b/src/bloqade/analysis/measure_id/impls.py index eb5abc22..53bf7b1d 100644 --- a/src/bloqade/analysis/measure_id/impls.py +++ b/src/bloqade/analysis/measure_id/impls.py @@ -5,8 +5,10 @@ from bloqade import qubit, annotate from .lattice import ( + Predicate, AnyMeasureId, NotMeasureId, + RawMeasureId, MeasureIdBool, MeasureIdTuple, InvalidMeasureId, @@ -41,10 +43,45 @@ def measure_qubit_list( measure_id_bools = [] for _ in range(num_qubits.data): interp.measure_count += 1 - measure_id_bools.append(MeasureIdBool(interp.measure_count)) + measure_id_bools.append(RawMeasureId(interp.measure_count)) return (MeasureIdTuple(data=tuple(measure_id_bools)),) + @interp.impl(qubit.stmts.IsLost) + @interp.impl(qubit.stmts.IsOne) + @interp.impl(qubit.stmts.IsZero) + def measurement_predicate( + self, + interp: MeasurementIDAnalysis, + frame: interp.Frame, + stmt: qubit.stmts.IsLost | qubit.stmts.IsOne | qubit.stmts.IsZero, + ): + original_measure_id_tuple = frame.get(stmt.measurements) + # all members should be RawMeasureId, if it's anything else + # it's Invalid. + if not all( + isinstance(measure_id, RawMeasureId) + for measure_id in original_measure_id_tuple.data + ): + return (InvalidMeasureId(),) + + # get the proper predicate type + if isinstance(stmt, qubit.stmts.IsLost): + predicate = Predicate.IS_LOST + elif isinstance(stmt, qubit.stmts.IsOne): + predicate = Predicate.IS_ONE + elif isinstance(stmt, qubit.stmts.IsZero): + predicate = Predicate.IS_ZERO + else: + return (InvalidMeasureId(),) + + # Create new MeasureIdBools with proper predicate type + predicate_measure_ids = [ + MeasureIdBool(measure_id.idx, predicate) + for measure_id in original_measure_id_tuple.data + ] + return (MeasureIdTuple(data=tuple(predicate_measure_ids)),) + @annotate.dialect.register(key="measure_id") class Annotate(interp.MethodTable): diff --git a/src/bloqade/analysis/measure_id/lattice.py b/src/bloqade/analysis/measure_id/lattice.py index 34d78b3c..a7ffada3 100644 --- a/src/bloqade/analysis/measure_id/lattice.py +++ b/src/bloqade/analysis/measure_id/lattice.py @@ -1,3 +1,4 @@ +from enum import Enum from typing import final from dataclasses import dataclass @@ -8,8 +9,15 @@ SimpleMeetMixin, ) + +class Predicate(Enum): + IS_ZERO = 1 + IS_ONE = 2 + IS_LOST = 3 + + # Taken directly from Kai-Hsin Wu's implementation -# with minor changes to names and addition of CanMeasureId type +# with minor changes to names @dataclass @@ -57,18 +65,25 @@ def is_subseteq(self, other: MeasureId) -> bool: @final @dataclass -class MeasureIdBool(MeasureId): +class RawMeasureId(MeasureId): idx: int def is_subseteq(self, other: MeasureId) -> bool: - if isinstance(other, MeasureIdBool): + if isinstance(other, RawMeasureId): return self.idx == other.idx return False -# Might be nice to have some print override -# here so all the CanMeasureId's/other types are consolidated for -# readability +@final +@dataclass +class MeasureIdBool(MeasureId): + idx: int + predicate: Predicate + + def is_subseteq(self, other: MeasureId) -> bool: + if isinstance(other, MeasureIdBool): + return self.predicate == other.predicate and self.idx == other.idx + return False @final diff --git a/src/bloqade/qubit/__init__.py b/src/bloqade/qubit/__init__.py index e8881b4a..17487534 100644 --- a/src/bloqade/qubit/__init__.py +++ b/src/bloqade/qubit/__init__.py @@ -6,6 +6,9 @@ from ._prelude import kernel as kernel from .stdlib.simple import ( reset as reset, + is_one as is_one, + is_lost as is_lost, + is_zero as is_zero, measure as measure, get_qubit_id as get_qubit_id, get_measurement_id as get_measurement_id, diff --git a/src/bloqade/qubit/_interface.py b/src/bloqade/qubit/_interface.py index 86910cdd..42738f56 100644 --- a/src/bloqade/qubit/_interface.py +++ b/src/bloqade/qubit/_interface.py @@ -5,7 +5,7 @@ from bloqade.types import Qubit, MeasurementResult -from .stmts import New, Reset, Measure, QubitId, MeasurementId +from .stmts import New, IsOne, Reset, IsLost, IsZero, Measure, QubitId, MeasurementId @wraps(New) @@ -47,3 +47,19 @@ def get_measurement_id( @wraps(Reset) def reset(qubits: ilist.IList[Qubit, Any]) -> None: ... + + +@wraps(IsZero) +def is_zero( + measurements: ilist.IList[MeasurementResult, N], +) -> ilist.IList[bool, N]: ... + + +@wraps(IsOne) +def is_one(measurements: ilist.IList[MeasurementResult, N]) -> ilist.IList[bool, N]: ... + + +@wraps(IsLost) +def is_lost( + measurements: ilist.IList[MeasurementResult, N], +) -> ilist.IList[bool, N]: ... diff --git a/src/bloqade/qubit/stdlib/broadcast.py b/src/bloqade/qubit/stdlib/broadcast.py index 09e9adff..04ace274 100644 --- a/src/bloqade/qubit/stdlib/broadcast.py +++ b/src/bloqade/qubit/stdlib/broadcast.py @@ -60,3 +60,39 @@ def get_measurement_id( measurement_ids (IList[int, N]): The list of global, unique IDs of the measurements. """ return _qubit.get_measurement_id(measurements) + + +@kernel +def is_zero(measurements: ilist.IList[MeasurementResult, N]) -> ilist.IList[bool, N]: + """Check if each MeasurementResult in the list is equivalent to measuring the zero state. + + Args: + measurements (IList[MeasurementResult, N]): The list of measurement results to check. + Returns: + IList[bool, N]: A list of booleans indicating whether each MeasurementResult is equivalent to the zero state. + """ + return _qubit.is_zero(measurements) + + +@kernel +def is_one(measurements: ilist.IList[MeasurementResult, N]) -> ilist.IList[bool, N]: + """Check if each MeasurementResult in the list is equivalent to measuring the one state. + + Args: + measurements (IList[MeasurementResult, N]): The list of measurement results to check. + Returns: + IList[bool, N]: A list of booleans indicating whether each MeasurementResult is equivalent to the one state. + """ + return _qubit.is_one(measurements) + + +@kernel +def is_lost(measurements: ilist.IList[MeasurementResult, N]) -> ilist.IList[bool, N]: + """Check if each MeasurementResult in the list indicates atom loss. + + Args: + measurements (IList[MeasurementResult, N]): The list of measurement results to check. + Returns: + IList[bool, N]: A list of booleans indicating whether each MeasurementResult indicates atom loss. + """ + return _qubit.is_lost(measurements) diff --git a/src/bloqade/qubit/stdlib/simple.py b/src/bloqade/qubit/stdlib/simple.py index 29953611..ea4ec951 100644 --- a/src/bloqade/qubit/stdlib/simple.py +++ b/src/bloqade/qubit/stdlib/simple.py @@ -57,3 +57,44 @@ def get_measurement_id(measurement: MeasurementResult) -> int: """ ids = broadcast.get_measurement_id(ilist.IList([measurement])) return ids[0] + + +@kernel +def is_zero(measurement: MeasurementResult) -> bool: + """Check if the measurement result is equivalent to measuring the zero state. + + Args: + measurement (MeasurementResult): The measurement result to check. + Returns: + bool: True if the measurement result is equivalent to measuring the zero state, False otherwise. + + """ + results = broadcast.is_zero(ilist.IList([measurement])) + return results[0] + + +@kernel +def is_one(measurement: MeasurementResult) -> bool: + """Check if the measurement result is equivalent to measuring the one state. + + Args: + measurement (MeasurementResult): The measurement result to check. + Returns: + bool: True if the measurement result is equivalent to measuring the one state, False otherwise. + + """ + results = broadcast.is_one(ilist.IList([measurement])) + return results[0] + + +@kernel +def is_lost(measurement: MeasurementResult) -> bool: + """Check if the measurement result indicates atom loss. + + Args: + measurement (MeasurementResult): The measurement result to check. + Returns: + bool: True if the measurement result indicates atom loss, False otherwise. + """ + results = broadcast.is_lost(ilist.IList([measurement])) + return results[0] diff --git a/src/bloqade/qubit/stmts.py b/src/bloqade/qubit/stmts.py index bd99756c..bdc6fd34 100644 --- a/src/bloqade/qubit/stmts.py +++ b/src/bloqade/qubit/stmts.py @@ -45,6 +45,30 @@ class Reset(ir.Statement): qubits: ir.SSAValue = info.argument(ilist.IListType[QubitType, types.Any]) +@statement +class MeasurementPredicate(ir.Statement): + traits = frozenset({lowering.FromPythonCall(), ir.Pure()}) + measurements: ir.SSAValue = info.argument( + ilist.IListType[MeasurementResultType, Len] + ) + result: ir.ResultValue = info.result(ilist.IListType[types.Bool, Len]) + + +@statement(dialect=dialect) +class IsZero(MeasurementPredicate): + pass + + +@statement(dialect=dialect) +class IsOne(MeasurementPredicate): + pass + + +@statement(dialect=dialect) +class IsLost(MeasurementPredicate): + pass + + # TODO: investigate why this is needed to get type inference to be correct. @dialect.register(key="typeinfer") class __TypeInfer(interp.MethodTable): diff --git a/src/bloqade/squin/__init__.py b/src/bloqade/squin/__init__.py index 3a91751d..fba2659c 100644 --- a/src/bloqade/squin/__init__.py +++ b/src/bloqade/squin/__init__.py @@ -6,7 +6,10 @@ from .. import qubit as qubit, annotate as annotate from ..qubit import ( reset as reset, + is_one as is_one, qalloc as qalloc, + is_lost as is_lost, + is_zero as is_zero, measure as measure, get_qubit_id as get_qubit_id, get_measurement_id as get_measurement_id, diff --git a/src/bloqade/squin/stdlib/broadcast/__init__.py b/src/bloqade/squin/stdlib/broadcast/__init__.py index ae1015ea..eb3e4d79 100644 --- a/src/bloqade/squin/stdlib/broadcast/__init__.py +++ b/src/bloqade/squin/stdlib/broadcast/__init__.py @@ -31,4 +31,10 @@ two_qubit_pauli_channel as two_qubit_pauli_channel, single_qubit_pauli_channel as single_qubit_pauli_channel, ) -from ._qubit import reset as reset, measure as measure +from ._qubit import ( + reset as reset, + is_one as is_one, + is_lost as is_lost, + is_zero as is_zero, + measure as measure, +) diff --git a/src/bloqade/squin/stdlib/broadcast/_qubit.py b/src/bloqade/squin/stdlib/broadcast/_qubit.py index f6ef7bae..e88d6a9b 100644 --- a/src/bloqade/squin/stdlib/broadcast/_qubit.py +++ b/src/bloqade/squin/stdlib/broadcast/_qubit.py @@ -1,4 +1,7 @@ from bloqade.qubit.stdlib.broadcast import ( reset as reset, + is_one as is_one, + is_lost as is_lost, + is_zero as is_zero, measure as measure, ) diff --git a/src/bloqade/stim/rewrite/get_record_util.py b/src/bloqade/stim/rewrite/get_record_util.py index aaa28261..102bac1a 100644 --- a/src/bloqade/stim/rewrite/get_record_util.py +++ b/src/bloqade/stim/rewrite/get_record_util.py @@ -2,7 +2,7 @@ from kirin.dialects import py from bloqade.stim.dialects import auxiliary -from bloqade.analysis.measure_id.lattice import MeasureIdBool, MeasureIdTuple +from bloqade.analysis.measure_id.lattice import RawMeasureId, MeasureIdTuple def insert_get_records( @@ -12,9 +12,9 @@ def insert_get_records( Insert GetRecord statements before the given node """ get_record_ssas = [] - for measure_id_bool in measure_id_tuple.data: - assert isinstance(measure_id_bool, MeasureIdBool) - target_rec_idx = (measure_id_bool.idx - 1) - meas_count_at_stmt + for measure_id in measure_id_tuple.data: + assert isinstance(measure_id, RawMeasureId) + target_rec_idx = (measure_id.idx - 1) - meas_count_at_stmt idx_stmt = py.constant.Constant(target_rec_idx) idx_stmt.insert_before(node) get_record_stmt = auxiliary.GetRecord(idx_stmt.result) diff --git a/src/bloqade/stim/rewrite/ifs_to_stim.py b/src/bloqade/stim/rewrite/ifs_to_stim.py index 6a15253b..ec3d3c56 100644 --- a/src/bloqade/stim/rewrite/ifs_to_stim.py +++ b/src/bloqade/stim/rewrite/ifs_to_stim.py @@ -13,9 +13,7 @@ from bloqade.stim.dialects.gate import CX as stim_CX, CY as stim_CY, CZ as stim_CZ from bloqade.analysis.measure_id import MeasureIDFrame from bloqade.stim.dialects.auxiliary import GetRecord -from bloqade.analysis.measure_id.lattice import ( - MeasureIdBool, -) +from bloqade.analysis.measure_id.lattice import Predicate, MeasureIdBool @dataclass @@ -139,8 +137,13 @@ def rewrite_Statement(self, node: ir.Statement) -> RewriteResult: def rewrite_IfElse(self, stmt: scf.IfElse) -> RewriteResult: - # Check the condition is a singular MeasurementIdBool - if not isinstance(self.measure_frame.entries[stmt.cond], MeasureIdBool): + condition_type = self.measure_frame.entries[stmt.cond] + # Check the condition is a singular MeasurementIdBool and that + # it was generated by querying if the measurement is equivalent to the one state + if not isinstance(condition_type, MeasureIdBool): + return RewriteResult() + + if condition_type.predicate != Predicate.IS_ONE: return RewriteResult() # Reusing code from SplitIf, @@ -158,13 +161,9 @@ def rewrite_IfElse(self, stmt: scf.IfElse) -> RewriteResult: else: return RewriteResult() - # get necessary measurement ID type from analysis - measure_id_bool = self.measure_frame.entries[stmt.cond] - assert isinstance(measure_id_bool, MeasureIdBool) - # generate get record statement measure_id_idx_stmt = py.Constant( - (measure_id_bool.idx - 1) - self.measure_frame.num_measures_at_stmt[stmt] + (condition_type.idx - 1) - self.measure_frame.num_measures_at_stmt[stmt] ) get_record_stmt = GetRecord(id=measure_id_idx_stmt.result) # noqa: F841 diff --git a/test/analysis/measure_id/test_measure_id.py b/test/analysis/measure_id/test_measure_id.py index 2a4f97c8..19bc5f06 100644 --- a/test/analysis/measure_id/test_measure_id.py +++ b/test/analysis/measure_id/test_measure_id.py @@ -5,7 +5,9 @@ from bloqade.analysis.measure_id import MeasurementIDAnalysis from bloqade.stim.passes.flatten import Flatten from bloqade.analysis.measure_id.lattice import ( + Predicate, NotMeasureId, + RawMeasureId, MeasureIdBool, MeasureIdTuple, InvalidMeasureId, @@ -26,6 +28,25 @@ def results_of_variables(kernel, variable_names): return results +def test_subset_eq_MeasureIdBool(): + + m0 = MeasureIdBool(idx=1, predicate=Predicate.IS_ONE) + m1 = MeasureIdBool(idx=1, predicate=Predicate.IS_ONE) + + assert m0.is_subseteq(m1) + + # not equivalent if predicate is different + m2 = MeasureIdBool(idx=1, predicate=Predicate.IS_ZERO) + + assert not m0.is_subseteq(m2) + + # not equivalent if index is different either, + # they are only equivalent if both index and predicate match + m3 = MeasureIdBool(idx=2, predicate=Predicate.IS_ONE) + + assert not m0.is_subseteq(m3) + + def test_add(): @squin.kernel def test(): @@ -48,7 +69,7 @@ def test(): # construct expected MeasureIdTuple expected_measure_id_tuple = MeasureIdTuple( - data=tuple([MeasureIdBool(idx=i) for i in range(1, 11)]) + data=tuple([RawMeasureId(idx=i) for i in range(1, 11)]) ) assert measure_id_tuples[-1] == expected_measure_id_tuple @@ -73,7 +94,7 @@ def test(): # construct expected MeasureIdTuples measure_id_tuple_with_id_bools = MeasureIdTuple( - data=tuple([MeasureIdBool(idx=i) for i in range(1, 6)]) + data=tuple([RawMeasureId(idx=i) for i in range(1, 6)]) ) measure_id_tuple_with_not_measures = MeasureIdTuple( data=tuple([NotMeasureId() for _ in range(5)]) @@ -135,7 +156,7 @@ def test(): # First from the measurement in the true branch, then # the result of the scf.IfElse itself analysis_results = [ - val for val in frame.entries.values() if val == MeasureIdBool(idx=1) + val for val in frame.entries.values() if val == RawMeasureId(idx=1) ] assert len(analysis_results) == 2 @@ -165,7 +186,7 @@ def test(): # First from the measurement in the false branch, then # the result of the scf.IfElse itself analysis_results = [ - val for val in frame.entries.values() if val == MeasureIdBool(idx=1) + val for val in frame.entries.values() if val == RawMeasureId(idx=1) ] assert len(analysis_results) == 2 @@ -194,9 +215,9 @@ def test(cond: bool): # Both branches of the scf.IfElse should be properly traversed and contain the following # analysis results. expected_full_register_measurement = MeasureIdTuple( - data=tuple([MeasureIdBool(idx=i) for i in range(1, 6)]) + data=tuple([RawMeasureId(idx=i) for i in range(1, 6)]) ) - expected_else_measurement = MeasureIdTuple(data=(MeasureIdBool(idx=6),)) + expected_else_measurement = MeasureIdTuple(data=(RawMeasureId(idx=6),)) assert expected_full_register_measurement in analysis_results assert expected_else_measurement in analysis_results @@ -221,15 +242,15 @@ def test(): # This is an assertion against `msi` NOT the initial list of measurements assert frame.get(results["msi"]) == MeasureIdTuple( - data=tuple(list(MeasureIdBool(idx=i) for i in range(2, 7))) + data=tuple(list(RawMeasureId(idx=i) for i in range(2, 7))) ) # msi2 assert frame.get(results["msi2"]) == MeasureIdTuple( - data=tuple(list(MeasureIdBool(idx=i) for i in range(3, 7))) + data=tuple(list(RawMeasureId(idx=i) for i in range(3, 7))) ) # ms_final assert frame.get(results["ms_final"]) == MeasureIdTuple( - data=(MeasureIdBool(idx=3), MeasureIdBool(idx=5)) + data=(RawMeasureId(idx=3), RawMeasureId(idx=5)) ) @@ -278,3 +299,43 @@ def test(): assert [frame.entries[result] for result in results_at(test, 0, 6)] == [ InvalidMeasureId() ] + + +def test_measurement_predicates(): + @squin.kernel + def test(): + q = squin.qalloc(3) + ms = squin.broadcast.measure(q) + + is_zero_bools = squin.broadcast.is_zero(ms) + is_one_bools = squin.broadcast.is_one(ms) + is_lost_bools = squin.broadcast.is_lost(ms) + + return is_zero_bools, is_one_bools, is_lost_bools + + Flatten(test.dialects).fixpoint(test) + frame, _ = MeasurementIDAnalysis(test.dialects).run(test) + + results = results_of_variables( + test, ("is_zero_bools", "is_one_bools", "is_lost_bools") + ) + + expected_is_zero_bools = MeasureIdTuple( + data=tuple( + [MeasureIdBool(idx=i, predicate=Predicate.IS_ZERO) for i in range(1, 4)] + ) + ) + expected_is_one_bools = MeasureIdTuple( + data=tuple( + [MeasureIdBool(idx=i, predicate=Predicate.IS_ONE) for i in range(1, 4)] + ) + ) + expected_is_lost_bools = MeasureIdTuple( + data=tuple( + [MeasureIdBool(idx=i, predicate=Predicate.IS_LOST) for i in range(1, 4)] + ) + ) + + assert frame.get(results["is_zero_bools"]) == expected_is_zero_bools + assert frame.get(results["is_one_bools"]) == expected_is_one_bools + assert frame.get(results["is_lost_bools"]) == expected_is_lost_bools diff --git a/test/stim/passes/stim_reference_programs/qubit/valid_if_measure_predicate.stim b/test/stim/passes/stim_reference_programs/qubit/valid_if_measure_predicate.stim new file mode 100644 index 00000000..2aa5da4e --- /dev/null +++ b/test/stim/passes/stim_reference_programs/qubit/valid_if_measure_predicate.stim @@ -0,0 +1,5 @@ +MZ(0.00000000) 0 1 2 +RZ 0 1 2 +CX rec[-3] 0 +CY rec[-2] 1 +CZ rec[-1] 2 diff --git a/test/stim/passes/test_annotation_to_stim.py b/test/stim/passes/test_annotation_to_stim.py index 13d55223..bf79cca5 100644 --- a/test/stim/passes/test_annotation_to_stim.py +++ b/test/stim/passes/test_annotation_to_stim.py @@ -70,12 +70,12 @@ def main(): ms = squin.broadcast.measure(q) - if ms[0]: + if squin.is_one(ms[0]): squin.z(q[0]) squin.broadcast.x([q[1], q[2], q[3]]) squin.broadcast.z(q) - if ms[1]: + if squin.is_one(ms[1]): squin.x(q[0]) squin.y(q[1]) @@ -101,13 +101,14 @@ def main(): ms = squin.broadcast.measure(q) - if ms[0]: + if squin.is_one(ms[0]): squin.z(q[0]) else: squin.x(q[0]) return + SquinToStimPass(main.dialects)(main) assert any(isinstance(stmt, scf.IfElse) for stmt in main.code.regions[0].stmts()) @@ -120,13 +121,56 @@ def main(): ms = squin.broadcast.measure(q) - if ms[0]: + if squin.is_one(ms[0]): squin.z(q[0]) - if ms[0]: + if squin.is_one(ms[0]): squin.x(q[1]) return + SquinToStimPass(main.dialects)(main) + assert any(isinstance(stmt, scf.IfElse) for stmt in main.code.regions[0].stmts()) + + +def test_missing_predicate(): + + # No rewrite should occur because even though there is an scf.IfElse, + # it does not have the proper predicate to be rewritten. + @squin.kernel + def main(): + n_qubits = 4 + q = squin.qalloc(n_qubits) + + ms = squin.broadcast.measure(q) + + if ms[0]: + squin.z(q[0]) + + return + + SquinToStimPass(main.dialects, no_raise=True)(main) + assert any(isinstance(stmt, scf.IfElse) for stmt in main.code.regions[0].stmts()) + + +def test_incorrect_predicate(): + + # You can only rewrite squin.is_one(...) predicates to + # stim equivalent feedforward statements. Anything else + # is invalid. + + @squin.kernel + def main(): + n_qubits = 4 + q = squin.qalloc(n_qubits) + + ms = squin.broadcast.measure(q) + + if squin.is_lost(ms[0]): + squin.z(q[0]) + + return + + SquinToStimPass(main.dialects, no_raise=True)(main) assert any(isinstance(stmt, scf.IfElse) for stmt in main.code.regions[0].stmts()) diff --git a/test/stim/passes/test_squin_meas_to_stim.py b/test/stim/passes/test_squin_meas_to_stim.py index dd323051..78a86476 100644 --- a/test/stim/passes/test_squin_meas_to_stim.py +++ b/test/stim/passes/test_squin_meas_to_stim.py @@ -37,12 +37,12 @@ def main(): ms = sq.broadcast.measure(q) - if ms[0]: + if sq.is_one(ms[0]): sq.z(q[0]) sq.broadcast.x([q[1], q[2], q[3]]) sq.broadcast.z(q) - if ms[1]: + if sq.is_one(ms[1]): sq.x(q[0]) sq.y(q[1]) @@ -64,7 +64,7 @@ def main(): ms = sq.broadcast.measure(q) new_ms = ms - if new_ms[0]: + if sq.is_one(new_ms[0]): sq.z(q[0]) SquinToStimPass(main.dialects)(main) @@ -83,13 +83,13 @@ def main(): ms0 = sq.broadcast.measure(q) - if ms0[0]: # should be rec[-4] + if sq.is_one(ms0[0]): # should be rec[-4] sq.z(q[0]) # another measurement ms1 = sq.broadcast.measure(q) - if ms1[0]: # should be rec[-4] + if sq.is_one(ms1[0]): # should be rec[-4] sq.x(q[0]) # second round of measurement @@ -98,7 +98,7 @@ def main(): # try accessing measurements from the very first round ## There are now 12 total measurements, ms0[0] ## is the oldest measurement in the entire program - if ms0[0]: + if sq.is_one(ms0[0]): sq.y(q[1]) SquinToStimPass(main.dialects)(main) @@ -117,25 +117,25 @@ def main(): ms0 = sq.broadcast.measure(q) - if ms0[0]: + if sq.is_one(ms0[0]): sq.z(q[0]) ms1 = sq.broadcast.measure(q) - if ms1[0]: + if sq.is_one(ms1[0]): sq.x(q[1]) # another measurement ms2 = sq.broadcast.measure(q) - if ms2[0]: + if sq.is_one(ms2[0]): sq.y(q[2]) # Intentionally obnoxious mix of measurements mix = [ms0[0], ms1[2], ms2[3]] mix_again = (mix[2], mix[0]) - if mix_again[0]: + if sq.is_one(mix_again[0]): sq.z(q[3]) SquinToStimPass(main.dialects)(main) diff --git a/test/stim/passes/test_squin_qubit_to_stim.py b/test/stim/passes/test_squin_qubit_to_stim.py index 63d644a3..10e14883 100644 --- a/test/stim/passes/test_squin_qubit_to_stim.py +++ b/test/stim/passes/test_squin_qubit_to_stim.py @@ -4,7 +4,7 @@ from math import pi from kirin import ir -from kirin.dialects import py +from kirin.dialects import py, scf from bloqade import stim, qubit, squin as sq from bloqade.squin import kernel @@ -23,6 +23,16 @@ def codegen(mt: ir.Method): return buf.getvalue().strip() +def filter_statements_by_type( + method: ir.Method, types: tuple[type, ...] +) -> list[ir.Statement]: + return [ + stmt + for stmt in method.callable_region.blocks[0].stmts + if isinstance(stmt, types) + ] + + def as_int(value: int): return py.constant.Constant(value=value) @@ -258,6 +268,55 @@ def main(): assert codegen(main) == base_stim_prog.rstrip() +def test_valid_if_measure_predicate(): + @sq.kernel + def test(): + q = sq.qalloc(3) + ms = sq.broadcast.measure(q) + could_be_one = sq.broadcast.is_one(ms) + sq.broadcast.reset(q) + if could_be_one[0]: + sq.x(q[0]) + + if could_be_one[1]: + sq.y(q[1]) + + if could_be_one[2]: + sq.z(q[2]) + + SquinToStimPass(test.dialects)(test) + base_stim_prog = load_reference_program("valid_if_measure_predicate.stim") + assert codegen(test) == base_stim_prog.rstrip() + + +# You can only convert a combination of a predicate type and +# scf.IfElse if the predicate type is IS_ONE. Otherwise anything +# else is invalid +def test_invalid_if_measure_predicate(): + @sq.kernel + def test(): + q = sq.qalloc(3) + ms = sq.broadcast.measure(q) + could_be_zero = sq.broadcast.is_zero(ms) + could_be_lost = sq.broadcast.is_lost(ms) + sq.broadcast.reset(q) + + if could_be_zero[0]: + sq.x(q[0]) + + if could_be_lost[1]: + sq.y(q[1]) + + SquinToStimPass(test.dialects)(test) + # rewrite for scf.IfElse did not occur due to invalid predicate type, + # should have two scf.IfElse remaining + remaining_if_else = filter_statements_by_type(test, (scf.IfElse,)) + assert len(remaining_if_else) == 2 + + +test_invalid_if_measure_predicate() + + def test_non_pure_loop_iterator(): @kernel def test_squin_kernel():