From 2fb54020b934a32216989b2e71e0cad7a1b17895 Mon Sep 17 00:00:00 2001 From: David Plankensteiner Date: Mon, 27 Oct 2025 15:53:35 +0100 Subject: [PATCH 1/6] Bump kirin version --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index b9fa49f9..17e94b44 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -13,7 +13,7 @@ requires-python = ">=3.10" dependencies = [ "numpy>=1.22.0", "scipy>=1.13.1", - "kirin-toolchain~=0.17.30", + "kirin-toolchain~=0.20.0", "rich>=13.9.4", "pydantic>=1.3.0,<2.11.0", "pandas>=2.2.3", From dc9a5bb284c86356835f802ff68d18305c865d2b Mon Sep 17 00:00:00 2001 From: Phillip Weinberg Date: Wed, 29 Oct 2025 09:56:12 -0400 Subject: [PATCH 2/6] Fix `CallGraphPass` issues in Kirin 0.20 (#586) In this PR I port the `CallGraphPass` originally in Kirin 0.17 to the passes here in bloqade-circuit. I have also updated them to use the new Kirin APIs so they are not compatible with Kirin 0.20. Unfortunately the tests are currently breaking because of some import issues so the CI will not be able to pass for the tests until those are fixed. Co-authored-by: David Plankensteiner --- src/bloqade/native/upstream/squin2native.py | 67 +---------- src/bloqade/rewrite/passes/__init__.py | 5 + src/bloqade/rewrite/passes/callgraph.py | 116 ++++++++++++++++++++ 3 files changed, 125 insertions(+), 63 deletions(-) create mode 100644 src/bloqade/rewrite/passes/callgraph.py diff --git a/src/bloqade/native/upstream/squin2native.py b/src/bloqade/native/upstream/squin2native.py index 97604630..2a9131f2 100644 --- a/src/bloqade/native/upstream/squin2native.py +++ b/src/bloqade/native/upstream/squin2native.py @@ -1,14 +1,13 @@ from itertools import chain -from dataclasses import field, dataclass -from kirin import ir, passes, rewrite +from kirin import ir, rewrite from kirin.dialects import py, func from kirin.rewrite.abc import RewriteRule, RewriteResult -from kirin.passes.callgraph import CallGraphPass, ReplaceMethods from kirin.analysis.callgraph import CallGraph from bloqade.native import kernel, broadcast from bloqade.squin.gate import stmts, dialect as gate_dialect +from bloqade.rewrite.passes import CallGraphPass, UpdateDialectsOnCallGraph class GateRule(RewriteRule): @@ -46,63 +45,6 @@ def rewrite_Statement(self, node: ir.Statement) -> RewriteResult: return RewriteResult(has_done_something=True) -@dataclass -class UpdateDialectsOnCallGraph(passes.Pass): - """Update All dialects on the call graph to a new set of dialects given to this pass. - - Usage: - pass_ = UpdateDialectsOnCallGraph(rule=rule, dialects=new_dialects) - pass_(some_method) - - Note: This pass does not update the dialects of the input method, but copies - all other methods invoked within it before updating their dialects. - - """ - - fold_pass: passes.Fold = field(init=False) - - def __post_init__(self): - self.fold_pass = passes.Fold(self.dialects, no_raise=self.no_raise) - - def unsafe_run(self, mt: ir.Method) -> RewriteResult: - mt_map = {} - - cg = CallGraph(mt) - - all_methods = set(sum(map(tuple, cg.defs.values()), ())) - for original_mt in all_methods: - if original_mt is mt: - new_mt = original_mt - else: - new_mt = original_mt.similar(self.dialects) - mt_map[original_mt] = new_mt - - result = RewriteResult() - - for _, new_mt in mt_map.items(): - result = ( - rewrite.Walk(ReplaceMethods(mt_map)).rewrite(new_mt.code).join(result) - ) - self.fold_pass(new_mt) - - return result - - -@dataclass -class SquinToNativePass(passes.Pass): - - call_graph_pass: CallGraphPass = field(init=False) - - def __post_init__(self): - rule = rewrite.Walk(GateRule()) - self.call_graph_pass = CallGraphPass( - self.dialects, rule, no_raise=self.no_raise - ) - - def unsafe_run(self, mt: ir.Method) -> RewriteResult: - return self.call_graph_pass.unsafe_run(mt) - - class SquinToNative: """A Target that converts Squin gates to native gates.""" @@ -126,11 +68,10 @@ def emit(self, mt: ir.Method, *, no_raise=True) -> ir.Method: out = mt.similar(new_dialects) UpdateDialectsOnCallGraph(new_dialects, no_raise=no_raise)(out) - SquinToNativePass(new_dialects, no_raise=no_raise)(out) + CallGraphPass(new_dialects, rewrite.Walk(GateRule()), no_raise=no_raise)(out) # verify all kernels in the callgraph new_callgraph = CallGraph(out) - all_kernels = (ker for kers in new_callgraph.defs.values() for ker in kers) - for ker in all_kernels: + for ker in new_callgraph.edges.keys(): ker.verify() return out diff --git a/src/bloqade/rewrite/passes/__init__.py b/src/bloqade/rewrite/passes/__init__.py index b6236a76..e35a23cb 100644 --- a/src/bloqade/rewrite/passes/__init__.py +++ b/src/bloqade/rewrite/passes/__init__.py @@ -1,2 +1,7 @@ +from .callgraph import ( + CallGraphPass as CallGraphPass, + ReplaceMethods as ReplaceMethods, + UpdateDialectsOnCallGraph as UpdateDialectsOnCallGraph, +) from .aggressive_unroll import AggressiveUnroll as AggressiveUnroll from .canonicalize_ilist import CanonicalizeIList as CanonicalizeIList diff --git a/src/bloqade/rewrite/passes/callgraph.py b/src/bloqade/rewrite/passes/callgraph.py new file mode 100644 index 00000000..0b5e64f0 --- /dev/null +++ b/src/bloqade/rewrite/passes/callgraph.py @@ -0,0 +1,116 @@ +from dataclasses import field, dataclass + +from kirin import ir, passes, rewrite +from kirin.analysis import CallGraph +from kirin.rewrite.abc import RewriteRule, RewriteResult +from kirin.dialects.func.stmts import Invoke + + +@dataclass +class ReplaceMethods(RewriteRule): + new_symbols: dict[ir.Method, ir.Method] + + def rewrite_Statement(self, node: ir.Statement) -> RewriteResult: + if ( + not isinstance(node, Invoke) + or (new_callee := self.new_symbols.get(node.callee)) is None + ): + return RewriteResult() + + node.replace_by( + Invoke( + inputs=node.inputs, + callee=new_callee, + purity=node.purity, + ) + ) + + return RewriteResult(has_done_something=True) + + +@dataclass +class UpdateDialectsOnCallGraph(passes.Pass): + """Update All dialects on the call graph to a new set of dialects given to this pass. + + Usage: + pass_ = UpdateDialectsOnCallGraph(rule=rule, dialects=new_dialects) + pass_(some_method) + + Note: This pass does not update the dialects of the input method, but copies + all other methods invoked within it before updating their dialects. + + """ + + fold_pass: passes.Fold = field(init=False) + + def __post_init__(self): + self.fold_pass = passes.Fold(self.dialects, no_raise=self.no_raise) + + def unsafe_run(self, mt: ir.Method) -> RewriteResult: + mt_map = {} + + cg = CallGraph(mt) + + all_methods = set(sum(map(tuple, cg.defs.values()), ())) + for original_mt in all_methods: + if original_mt is mt: + new_mt = original_mt + else: + new_mt = original_mt.similar(self.dialects) + mt_map[original_mt] = new_mt + + result = RewriteResult() + + for _, new_mt in mt_map.items(): + result = ( + rewrite.Walk(ReplaceMethods(mt_map)).rewrite(new_mt.code).join(result) + ) + self.fold_pass(new_mt) + + return result + + +@dataclass +class CallGraphPass(passes.Pass): + """Copy all functions in the call graph and apply a rule to each of them. + + + Usage: + rule = Walk(SomeRewriteRule()) + pass_ = CallGraphPass(rule=rule, dialects=...) + pass_(some_method) + + Note: This pass modifies the input method in place, but copies + all methods invoked within it before applying the rule to them. + + """ + + rule: RewriteRule + """The rule to apply to each function in the call graph.""" + + fold_pass: passes.Fold = field(init=False) + + def __post_init__(self): + self.fold_pass = passes.Fold(self.dialects, no_raise=self.no_raise) + + def unsafe_run(self, mt: ir.Method) -> RewriteResult: + result = RewriteResult() + mt_map = {} + + cg = CallGraph(mt) + + all_methods = set(cg.edges.keys()) + for original_mt in all_methods: + if original_mt is mt: + new_mt = original_mt + else: + new_mt = original_mt.similar() + result = self.rule.rewrite(new_mt.code).join(result) + mt_map[original_mt] = new_mt + + if result.has_done_something: + for _, new_mt in mt_map.items(): + rewrite.Walk(ReplaceMethods(mt_map)).rewrite(new_mt.code) + self.fold_pass(new_mt) + + return result From 55686c3ebcb64d0e2f3713be4441b519f93973b9 Mon Sep 17 00:00:00 2001 From: Dennis Liew <48105496+zhenrongliew@users.noreply.github.com> Date: Tue, 4 Nov 2025 14:25:34 -0500 Subject: [PATCH 3/6] Code Generation for QASM2 (#555) Refactor QASM2's code generation to use the new Emit APIs. --------- Co-authored-by: David Plankensteiner --- src/bloqade/analysis/fidelity/analysis.py | 7 +- src/bloqade/analysis/measure_id/analysis.py | 3 + src/bloqade/cirq_utils/emit/base.py | 82 ++++++++++---- src/bloqade/cirq_utils/emit/gate.py | 16 +-- src/bloqade/cirq_utils/emit/noise.py | 8 +- src/bloqade/cirq_utils/emit/qubit.py | 4 +- src/bloqade/cirq_utils/lowering.py | 4 +- src/bloqade/native/upstream/__init__.py | 1 - src/bloqade/pyqrack/device.py | 2 +- src/bloqade/pyqrack/target.py | 2 +- src/bloqade/qasm2/dialects/expr/_emit.py | 27 +++-- src/bloqade/qasm2/emit/base.py | 24 ++-- src/bloqade/qasm2/emit/gate.py | 19 ++-- src/bloqade/qasm2/emit/main.py | 106 +++++++++++++++++- src/bloqade/qasm2/emit/target.py | 20 ++-- src/bloqade/qasm2/groups.py | 5 +- src/bloqade/qasm2/passes/glob.py | 4 +- src/bloqade/qasm2/passes/noise.py | 2 +- src/bloqade/qasm2/passes/parallel.py | 8 +- src/bloqade/squin/analysis/schedule.py | 2 +- src/bloqade/stim/dialects/auxiliary/emit.py | 37 +++--- .../stim/dialects/collapse/emit_str.py | 15 ++- src/bloqade/stim/dialects/gate/emit.py | 19 ++-- src/bloqade/stim/dialects/noise/emit.py | 23 ++-- src/bloqade/stim/emit/impls.py | 7 +- src/bloqade/stim/emit/stim_str.py | 77 ++++++++----- src/bloqade/stim/groups.py | 14 ++- src/bloqade/stim/parse/lowering.py | 2 + src/bloqade/stim/passes/squin_to_stim.py | 4 +- src/bloqade/test_utils.py | 2 +- test/analysis/address/test_qubit_analysis.py | 6 +- test/analysis/measure_id/test_measure_id.py | 20 ++-- test/cirq_utils/test_clifford_to_cirq.py | 4 +- test/qasm2/emit/t_qasm2.qasm | 13 +++ test/qasm2/emit/test_extended_noise.py | 2 +- test/qasm2/emit/test_qasm2.py | 23 +++- test/qasm2/emit/test_qasm2_emit.py | 21 ++-- test/qasm2/passes/test_global_to_parallel.py | 3 + test/qasm2/passes/test_global_to_uop.py | 2 + test/qasm2/passes/test_heuristic_noise.py | 2 + test/qasm2/passes/test_parallel_to_global.py | 7 ++ test/qasm2/passes/test_parallel_to_uop.py | 2 + test/qasm2/passes/test_uop_to_parallel.py | 11 +- test/qasm2/test_count.py | 16 ++- test/qasm2/test_lowering.py | 7 ++ test/qasm2/test_native.py | 2 + test/squin/test_typeinfer.py | 6 +- test/stim/dialects/stim/emit/base.py | 12 -- test/stim/dialects/stim/emit/test_stim_1q.py | 35 +++--- .../stim/dialects/stim/emit/test_stim_ctrl.py | 18 +-- .../dialects/stim/emit/test_stim_debug.py | 12 +- .../dialects/stim/emit/test_stim_detector.py | 15 ++- .../stim/dialects/stim/emit/test_stim_meas.py | 12 +- .../dialects/stim/emit/test_stim_noise.py | 22 ++-- .../dialects/stim/emit/test_stim_obs_inc.py | 12 +- .../dialects/stim/emit/test_stim_ppmeas.py | 16 ++- .../stim/emit/test_stim_qubit_coords.py | 12 +- test/stim/dialects/stim/emit/test_stim_spp.py | 15 ++- test/stim/dialects/stim/test_stim_circuits.py | 104 +++++++++++++---- test/stim/parse/base.py | 10 +- test/stim/passes/test_squin_debug_to_stim.py | 9 +- test/stim/passes/test_squin_meas_to_stim.py | 10 +- test/stim/passes/test_squin_noise_to_stim.py | 14 ++- test/stim/passes/test_squin_qubit_to_stim.py | 10 +- test/stim/test_measure_id_analysis.py | 11 +- 65 files changed, 685 insertions(+), 357 deletions(-) create mode 100644 test/qasm2/emit/t_qasm2.qasm delete mode 100644 test/stim/dialects/stim/emit/base.py diff --git a/src/bloqade/analysis/fidelity/analysis.py b/src/bloqade/analysis/fidelity/analysis.py index 763152fe..f1ad252f 100644 --- a/src/bloqade/analysis/fidelity/analysis.py +++ b/src/bloqade/analysis/fidelity/analysis.py @@ -83,12 +83,15 @@ def run_analysis( self, method: ir.Method, args: tuple | None = None, *, no_raise: bool = True ) -> tuple[ForwardFrame, Any]: self._run_address_analysis(method, no_raise=no_raise) - return super().run_analysis(method, args, no_raise=no_raise) + return super().run(method) def _run_address_analysis(self, method: ir.Method, no_raise: bool): addr_analysis = AddressAnalysis(self.dialects) - addr_frame, _ = addr_analysis.run_analysis(method=method, no_raise=no_raise) + addr_frame, _ = addr_analysis.run(method=method) self.addr_frame = addr_frame # NOTE: make sure we have as many probabilities as we have addresses self.atom_survival_probability = [1.0] * addr_analysis.qubit_count + + def method_self(self, method: ir.Method) -> EmptyLattice: + return self.lattice.bottom() diff --git a/src/bloqade/analysis/measure_id/analysis.py b/src/bloqade/analysis/measure_id/analysis.py index 6c014f66..f2d5e9f3 100644 --- a/src/bloqade/analysis/measure_id/analysis.py +++ b/src/bloqade/analysis/measure_id/analysis.py @@ -53,3 +53,6 @@ def get_const_value( return hint.data return None + + def method_self(self, method: ir.Method) -> MeasureId: + return self.lattice.bottom() diff --git a/src/bloqade/cirq_utils/emit/base.py b/src/bloqade/cirq_utils/emit/base.py index 8ef114ba..199e49dc 100644 --- a/src/bloqade/cirq_utils/emit/base.py +++ b/src/bloqade/cirq_utils/emit/base.py @@ -4,9 +4,9 @@ import cirq from kirin import ir, types, interp -from kirin.emit import EmitABC, EmitError, EmitFrame +from kirin.emit import EmitABC, EmitFrame from kirin.interp import MethodTable, impl -from kirin.dialects import py, func +from kirin.dialects import py, func, ilist from typing_extensions import Self from bloqade.squin import kernel @@ -102,7 +102,7 @@ def main(): and isinstance(mt.code, func.Function) and not mt.code.signature.output.is_subseteq(types.NoneType) ): - raise EmitError( + raise interp.exceptions.InterpreterError( "The method you are trying to convert to a circuit has a return value, but returning from a circuit is not supported." " Set `ignore_returns = True` in order to simply ignore the return values and emit a circuit." ) @@ -116,12 +116,14 @@ def main(): symbol_op_trait = mt.code.get_trait(ir.SymbolOpInterface) if (symbol_op_trait := mt.code.get_trait(ir.SymbolOpInterface)) is None: - raise EmitError("The method is not a symbol, cannot emit circuit!") + raise interp.exceptions.InterpreterError( + "The method is not a symbol, cannot emit circuit!" + ) sym_name = symbol_op_trait.get_sym_name(mt.code).unwrap() if (signature_trait := mt.code.get_trait(ir.HasSignature)) is None: - raise EmitError( + raise interp.exceptions.InterpreterError( f"The method {sym_name} does not have a signature, cannot emit circuit!" ) @@ -135,7 +137,7 @@ def main(): assert first_stmt is not None, "Method has no statements!" if len(args_ssa) - 1 != len(args): - raise EmitError( + raise interp.exceptions.InterpreterError( f"The method {sym_name} takes {len(args_ssa) - 1} arguments, but you passed in {len(args)} via the `args` keyword!" ) @@ -147,17 +149,22 @@ def main(): new_func = func.Function( sym_name=sym_name, body=callable_region, signature=new_signature ) - mt_ = ir.Method(None, None, sym_name, [], mt.dialects, new_func) + mt_ = ir.Method( + dialects=mt.dialects, + code=new_func, + sym_name=sym_name, + ) AggressiveUnroll(mt_.dialects).fixpoint(mt_) - return emitter.run(mt_, args=()) + emitter.initialize() + emitter.run(mt_) + return emitter.circuit @dataclass class EmitCirqFrame(EmitFrame): qubit_index: int = 0 qubits: Sequence[cirq.Qid] | None = None - circuit: cirq.Circuit = field(default_factory=cirq.Circuit) def _default_kernel(): @@ -166,23 +173,24 @@ def _default_kernel(): @dataclass class EmitCirq(EmitABC[EmitCirqFrame, cirq.Circuit]): - keys = ["emit.cirq", "main"] + keys = ("emit.cirq", "emit.main") dialects: ir.DialectGroup = field(default_factory=_default_kernel) void = cirq.Circuit() qubits: Sequence[cirq.Qid] | None = None + circuit: cirq.Circuit = field(default_factory=cirq.Circuit) def initialize(self) -> Self: return super().initialize() def initialize_frame( - self, code: ir.Statement, *, has_parent_access: bool = False + self, node: ir.Statement, *, has_parent_access: bool = False ) -> EmitCirqFrame: return EmitCirqFrame( - code, has_parent_access=has_parent_access, qubits=self.qubits + node, has_parent_access=has_parent_access, qubits=self.qubits ) def run_method(self, method: ir.Method, args: tuple[cirq.Circuit, ...]): - return self.run_callable(method.code, args) + return self.call(method, *args) def run_callable_region( self, @@ -196,7 +204,7 @@ def run_callable_region( # NOTE: skip self arg frame.set_values(block_args[1:], args) - results = self.eval_stmt(frame, code) + results = self.frame_eval(frame, code) if isinstance(results, tuple): if len(results) == 0: return self.void @@ -206,11 +214,17 @@ def run_callable_region( def emit_block(self, frame: EmitCirqFrame, block: ir.Block) -> cirq.Circuit: for stmt in block.stmts: - result = self.eval_stmt(frame, stmt) + result = self.frame_eval(frame, stmt) if isinstance(result, tuple): frame.set_values(stmt.results, result) - return frame.circuit + return self.circuit + + def reset(self): + pass + + def eval_fallback(self, frame: EmitCirqFrame, node: ir.Statement) -> tuple: + return tuple(None for _ in range(len(node.results))) @func.dialect.register(key="emit.cirq") @@ -218,21 +232,25 @@ class __FuncEmit(MethodTable): @impl(func.Function) def emit_func(self, emit: EmitCirq, frame: EmitCirqFrame, stmt: func.Function): - emit.run_ssacfg_region(frame, stmt.body, ()) - return (frame.circuit,) + for block in stmt.body.blocks: + frame.current_block = block + for s in block.stmts: + frame.current_stmt = s + stmt_results = emit.frame_eval(frame, s) + if isinstance(stmt_results, tuple): + if len(stmt_results) != 0: + frame.set_values(s.results, stmt_results) + continue + + return (emit.circuit,) @impl(func.Invoke) def emit_invoke(self, emit: EmitCirq, frame: EmitCirqFrame, stmt: func.Invoke): - raise EmitError( + raise interp.exceptions.InterpreterError( "Function invokes should need to be inlined! " "If you called the emit_circuit method, that should have happened, please report this issue." ) - @impl(func.Return) - def return_(self, emit: EmitCirq, frame: EmitCirqFrame, stmt: func.Return): - # NOTE: should only be hit if ignore_returns == True - return () - @py.indexing.dialect.register(key="emit.cirq") class __Concrete(interp.MethodTable): @@ -241,3 +259,19 @@ class __Concrete(interp.MethodTable): def getindex(self, interp, frame: interp.Frame, stmt: py.indexing.GetItem): # NOTE: no support for indexing into single statements in cirq return () + + @interp.impl(py.Constant) + def emit_constant(self, emit: EmitCirq, frame: EmitCirqFrame, stmt: py.Constant): + return (stmt.value.data,) # pyright: ignore[reportAttributeAccessIssue] + + +@ilist.dialect.register(key="emit.cirq") +class __IList(interp.MethodTable): + @interp.impl(ilist.New) + def new_ilist( + self, + emit: EmitCirq, + frame: interp.Frame, + stmt: ilist.New, + ): + return (ilist.IList(data=frame.get_values(stmt.values)),) diff --git a/src/bloqade/cirq_utils/emit/gate.py b/src/bloqade/cirq_utils/emit/gate.py index 78f9ce14..c884e9c3 100644 --- a/src/bloqade/cirq_utils/emit/gate.py +++ b/src/bloqade/cirq_utils/emit/gate.py @@ -20,7 +20,7 @@ def hermitian( ): qubits = frame.get(stmt.qubits) cirq_op = getattr(cirq, stmt.name.upper()) - frame.circuit.append(cirq_op.on_each(qubits)) + emit.circuit.append(cirq_op.on_each(qubits)) return () @impl(gate.stmts.S) @@ -36,7 +36,7 @@ def unitary( if stmt.adjoint: cirq_op = cirq_op ** (-1) - frame.circuit.append(cirq_op.on_each(qubits)) + emit.circuit.append(cirq_op.on_each(qubits)) return () @impl(gate.stmts.SqrtX) @@ -58,7 +58,7 @@ def sqrt( else: cirq_op = cirq.YPowGate(exponent=exponent) - frame.circuit.append(cirq_op.on_each(qubits)) + emit.circuit.append(cirq_op.on_each(qubits)) return () @impl(gate.stmts.CX) @@ -71,7 +71,7 @@ def control( targets = frame.get(stmt.targets) cirq_op = getattr(cirq, stmt.name.upper()) cirq_qubits = [(ctrl, target) for ctrl, target in zip(controls, targets)] - frame.circuit.append(cirq_op.on_each(cirq_qubits)) + emit.circuit.append(cirq_op.on_each(cirq_qubits)) return () @impl(gate.stmts.Rx) @@ -84,7 +84,7 @@ def rot(self, emit: EmitCirq, frame: EmitCirqFrame, stmt: gate.stmts.RotationGat angle = turns * 2 * math.pi cirq_op = getattr(cirq, stmt.name.title())(rads=angle) - frame.circuit.append(cirq_op.on_each(qubits)) + emit.circuit.append(cirq_op.on_each(qubits)) return () @impl(gate.stmts.U3) @@ -95,10 +95,10 @@ def u3(self, emit: EmitCirq, frame: EmitCirqFrame, stmt: gate.stmts.U3): phi = frame.get(stmt.phi) * 2 * math.pi lam = frame.get(stmt.lam) * 2 * math.pi - frame.circuit.append(cirq.Rz(rads=lam).on_each(*qubits)) + emit.circuit.append(cirq.Rz(rads=lam).on_each(*qubits)) - frame.circuit.append(cirq.Ry(rads=theta).on_each(*qubits)) + emit.circuit.append(cirq.Ry(rads=theta).on_each(*qubits)) - frame.circuit.append(cirq.Rz(rads=phi).on_each(*qubits)) + emit.circuit.append(cirq.Rz(rads=phi).on_each(*qubits)) return () diff --git a/src/bloqade/cirq_utils/emit/noise.py b/src/bloqade/cirq_utils/emit/noise.py index 70476c93..d69fa68c 100644 --- a/src/bloqade/cirq_utils/emit/noise.py +++ b/src/bloqade/cirq_utils/emit/noise.py @@ -34,7 +34,7 @@ def depolarize( p = frame.get(stmt.p) qubits = frame.get(stmt.qubits) cirfq_op = cirq.depolarize(p, n_qubits=1).on_each(qubits) - frame.circuit.append(cirfq_op) + interp.circuit.append(cirfq_op) return () @impl(noise.stmts.Depolarize2) @@ -46,7 +46,7 @@ def depolarize2( targets = frame.get(stmt.targets) cirq_qubits = [(ctrl, target) for ctrl, target in zip(controls, targets)] cirq_op = cirq.depolarize(p, n_qubits=2).on_each(cirq_qubits) - frame.circuit.append(cirq_op) + interp.circuit.append(cirq_op) return () @impl(noise.stmts.SingleQubitPauliChannel) @@ -62,7 +62,7 @@ def single_qubit_pauli_channel( qubits = frame.get(stmt.qubits) cirq_op = cirq.asymmetric_depolarize(px, py, pz).on_each(qubits) - frame.circuit.append(cirq_op) + interp.circuit.append(cirq_op) return () @@ -85,6 +85,6 @@ def two_qubit_pauli_channel( cirq_op = cirq.asymmetric_depolarize( error_probabilities=error_probabilities ).on_each(cirq_qubits) - frame.circuit.append(cirq_op) + interp.circuit.append(cirq_op) return () diff --git a/src/bloqade/cirq_utils/emit/qubit.py b/src/bloqade/cirq_utils/emit/qubit.py index 222d1798..51a84ffe 100644 --- a/src/bloqade/cirq_utils/emit/qubit.py +++ b/src/bloqade/cirq_utils/emit/qubit.py @@ -23,13 +23,13 @@ def measure_qubit_list( self, emit: EmitCirq, frame: EmitCirqFrame, stmt: qubit.Measure ): qbits = frame.get(stmt.qubits) - frame.circuit.append(cirq.measure(qbits)) + emit.circuit.append(cirq.measure(qbits)) return (emit.void,) @impl(qubit.Reset) def reset(self, emit: EmitCirq, frame: EmitCirqFrame, stmt: qubit.Reset): qubits = frame.get(stmt.qubits) - frame.circuit.append( + emit.circuit.append( cirq.ResetChannel().on_each(*qubits), ) return () diff --git a/src/bloqade/cirq_utils/lowering.py b/src/bloqade/cirq_utils/lowering.py index 8806e879..7d81b45b 100644 --- a/src/bloqade/cirq_utils/lowering.py +++ b/src/bloqade/cirq_utils/lowering.py @@ -260,9 +260,7 @@ def run( # NOTE: create a new register of appropriate size n_qubits = len(self.qreg_index) n = frame.push(py.Constant(n_qubits)) - self.qreg = frame.push( - func.Invoke((n.result,), callee=qalloc, kwargs=()) - ).result + self.qreg = frame.push(func.Invoke((n.result,), callee=qalloc)).result self.visit(state, stmt) diff --git a/src/bloqade/native/upstream/__init__.py b/src/bloqade/native/upstream/__init__.py index 2d04fe3e..23b78ebf 100644 --- a/src/bloqade/native/upstream/__init__.py +++ b/src/bloqade/native/upstream/__init__.py @@ -1,5 +1,4 @@ from .squin2native import ( GateRule as GateRule, SquinToNative as SquinToNative, - SquinToNativePass as SquinToNativePass, ) diff --git a/src/bloqade/pyqrack/device.py b/src/bloqade/pyqrack/device.py index 60e840de..4615989a 100644 --- a/src/bloqade/pyqrack/device.py +++ b/src/bloqade/pyqrack/device.py @@ -353,7 +353,7 @@ def task( kwargs = {} address_analysis = AddressAnalysis(dialects=kernel.dialects) - frame, _ = address_analysis.run_analysis(kernel) + frame, _ = address_analysis.run(kernel) if self.min_qubits == 0 and any( isinstance(a, (UnknownQubit, UnknownReg)) for a in frame.entries.values() ): diff --git a/src/bloqade/pyqrack/target.py b/src/bloqade/pyqrack/target.py index 6f18e7a0..e9f21233 100644 --- a/src/bloqade/pyqrack/target.py +++ b/src/bloqade/pyqrack/target.py @@ -51,7 +51,7 @@ def _get_interp(self, mt: ir.Method[Params, RetType]): return PyQrackInterpreter(mt.dialects, memory=DynamicMemory(options)) else: address_analysis = AddressAnalysis(mt.dialects) - frame, _ = address_analysis.run_analysis(mt) + frame, _ = address_analysis.run(mt) if self.min_qubits == 0 and any( isinstance(a, UnknownQubit) for a in frame.entries.values() ): diff --git a/src/bloqade/qasm2/dialects/expr/_emit.py b/src/bloqade/qasm2/dialects/expr/_emit.py index f429cb85..2d95843a 100644 --- a/src/bloqade/qasm2/dialects/expr/_emit.py +++ b/src/bloqade/qasm2/dialects/expr/_emit.py @@ -20,7 +20,10 @@ def emit_func( args: list[ast.Node] = [] cparams, qparams = [], [] - for arg in stmt.body.blocks[0].args: + entry_args = stmt.body.blocks[0].args + user_args = entry_args[1:] if len(entry_args) > 0 else [] + + for arg in user_args: assert arg.name is not None args.append(ast.Name(id=arg.name)) @@ -29,14 +32,22 @@ def emit_func( else: cparams.append(arg.name) - emit.run_ssacfg_region(frame, stmt.body, tuple(args)) - emit.output = ast.Gate( - name=stmt.sym_name, - cparams=cparams, - qparams=qparams, - body=frame.body, + frame.worklist.append(interp.Successor(stmt.body.blocks[0], *args)) + if len(entry_args) > 0: + frame.set(entry_args[0], ast.Name(stmt.sym_name or "gate")) + + while (succ := frame.worklist.pop()) is not None: + frame.set_values(succ.block.args[1:], succ.block_args) + block_header = emit.emit_block(frame, succ.block) + frame.block_ref[succ.block] = block_header + return ( + ast.Gate( + name=stmt.sym_name, + cparams=cparams, + qparams=qparams, + body=frame.body, + ), ) - return () @interp.impl(stmts.ConstInt) @interp.impl(stmts.ConstFloat) diff --git a/src/bloqade/qasm2/emit/base.py b/src/bloqade/qasm2/emit/base.py index cd63c547..4f7fba1d 100644 --- a/src/bloqade/qasm2/emit/base.py +++ b/src/bloqade/qasm2/emit/base.py @@ -2,8 +2,9 @@ from typing import Generic, TypeVar, overload from dataclasses import field, dataclass -from kirin import ir, idtable -from kirin.emit import EmitABC, EmitError, EmitFrame +from kirin import ir, interp, idtable +from kirin.emit import EmitABC, EmitFrame +from kirin.worklist import WorkList from typing_extensions import Self from bloqade.qasm2.parse import ast @@ -15,6 +16,9 @@ @dataclass class EmitQASM2Frame(EmitFrame[ast.Node | None], Generic[StmtType]): body: list[StmtType] = field(default_factory=list) + worklist: WorkList[interp.Successor] = field(default_factory=WorkList) + block_ref: dict[ir.Block, ast.Node | None] = field(default_factory=dict) + _indent: int = 0 @dataclass @@ -37,18 +41,18 @@ def initialize(self) -> Self: return self def initialize_frame( - self, code: ir.Statement, *, has_parent_access: bool = False + self, node: ir.Statement, *, has_parent_access: bool = False ) -> EmitQASM2Frame[StmtType]: - return EmitQASM2Frame(code, has_parent_access=has_parent_access) + return EmitQASM2Frame(node, has_parent_access=has_parent_access) def run_method( self, method: ir.Method, args: tuple[ast.Node | None, ...] ) -> tuple[EmitQASM2Frame[StmtType], ast.Node | None]: - return self.run_callable(method.code, (ast.Name(method.sym_name),) + args) + return self.call(method, *args) def emit_block(self, frame: EmitQASM2Frame, block: ir.Block) -> ast.Node | None: for stmt in block.stmts: - result = self.eval_stmt(frame, stmt) + result = self.frame_eval(frame, stmt) if isinstance(result, tuple): frame.set_values(stmt.results, result) return None @@ -70,5 +74,11 @@ def assert_node( node: ast.Node | None, ) -> A | B: if not isinstance(node, typ): - raise EmitError(f"expected {typ}, got {type(node)}") + raise TypeError(f"expected {typ}, got {type(node)}") return node + + def reset(self): + pass + + def eval_fallback(self, frame: EmitQASM2Frame, node: ir.Statement): + return tuple(None for _ in range(len(node.results))) diff --git a/src/bloqade/qasm2/emit/gate.py b/src/bloqade/qasm2/emit/gate.py index ae7c4f30..ebed83b5 100644 --- a/src/bloqade/qasm2/emit/gate.py +++ b/src/bloqade/qasm2/emit/gate.py @@ -3,11 +3,12 @@ from kirin import ir, types, interp from kirin.dialects import py, func, ilist from kirin.ir.dialect import Dialect as Dialect +from typing_extensions import Self from bloqade.types import QubitType from bloqade.qasm2.parse import ast -from .base import EmitError, EmitQASM2Base, EmitQASM2Frame +from .base import EmitQASM2Base, EmitQASM2Frame def _default_dialect_group(): @@ -18,9 +19,13 @@ def _default_dialect_group(): @dataclass class EmitQASM2Gate(EmitQASM2Base[ast.UOp | ast.Barrier, ast.Gate]): - keys = ["emit.qasm2.gate"] + keys = ("emit.qasm2.gate",) dialects: ir.DialectGroup = field(default_factory=_default_dialect_group) + def initialize(self) -> Self: + super().initialize() + return self + @ilist.dialect.register(key="emit.qasm2.gate") class Ilist(interp.MethodTable): @@ -45,7 +50,7 @@ class Func(interp.MethodTable): @interp.impl(func.Call) def emit_call(self, emit: EmitQASM2Gate, frame: EmitQASM2Frame, stmt: func.Call): - raise EmitError("cannot emit dynamic call") + raise RuntimeError("cannot emit dynamic call") @interp.impl(func.Invoke) def emit_invoke( @@ -55,7 +60,7 @@ def emit_invoke( if len(stmt.results) == 1 and stmt.results[0].type.is_subseteq(types.NoneType): ret = (None,) elif len(stmt.results) > 0: - raise EmitError( + raise RuntimeError( "cannot emit invoke with results, this " "is not compatible QASM2 gate routine" " (consider pass qreg/creg by argument)" @@ -67,10 +72,9 @@ def emit_invoke( qparams.append(frame.get(arg)) else: cparams.append(frame.get(arg)) - frame.body.append( ast.Instruction( - name=ast.Name(stmt.callee.sym_name), + name=ast.Name(stmt.callee.__getattribute__("sym_name")), params=cparams, qargs=qparams, ) @@ -80,9 +84,8 @@ def emit_invoke( @interp.impl(func.Lambda) @interp.impl(func.GetField) def emit_err(self, emit: EmitQASM2Gate, frame: EmitQASM2Frame, stmt): - raise EmitError(f"illegal statement {stmt.name} for QASM2 gate routine") + raise RuntimeError(f"illegal statement {stmt.name} for QASM2 gate routine") @interp.impl(func.Return) - @interp.impl(func.ConstantNone) def ignore(self, emit: EmitQASM2Gate, frame: EmitQASM2Frame, stmt): return () diff --git a/src/bloqade/qasm2/emit/main.py b/src/bloqade/qasm2/emit/main.py index 8bb5810c..7a98ab0e 100644 --- a/src/bloqade/qasm2/emit/main.py +++ b/src/bloqade/qasm2/emit/main.py @@ -1,8 +1,10 @@ +from typing import List, cast from dataclasses import dataclass from kirin import ir, interp from kirin.dialects import cf, scf, func from kirin.ir.dialect import Dialect as Dialect +from typing_extensions import Self from bloqade.qasm2.parse import ast from bloqade.qasm2.dialects.uop import SingleQubitGate, TwoQubitCtrlGate @@ -14,26 +16,124 @@ @dataclass class EmitQASM2Main(EmitQASM2Base[ast.Statement, ast.MainProgram]): - keys = ["emit.qasm2.main", "emit.qasm2.gate"] + keys = ("emit.qasm2.main", "emit.qasm2.gate") dialects: ir.DialectGroup + def initialize(self) -> Self: + super().initialize() + return self + + def eval_fallback(self, frame: EmitQASM2Frame, node: ir.Statement): + return tuple(None for _ in range(len(node.results))) + @func.dialect.register(key="emit.qasm2.main") class Func(interp.MethodTable): + @interp.impl(func.Invoke) + def invoke(self, emit: EmitQASM2Main, frame: EmitQASM2Frame, node: func.Invoke): + name = emit.callables.get(node.callee.code) + if name is None: + name = emit.callables.add(node.callee.code) + emit.callable_to_emit.append(node.callee.code) + + if isinstance(node.callee.code, GateFunction): + c_params: list[ast.Expr] = [] + q_args: list[ast.Bit | ast.Name] = [] + + for arg in node.args: + val = frame.get(arg) + if val is None: + raise interp.InterpreterError(f"missing mapping for arg {arg}") + if isinstance(val, (ast.Bit, ast.Name)): + q_args.append(val) + elif isinstance(val, ast.Expr): + c_params.append(val) + + instr = ast.Instruction( + name=ast.Name(name) if isinstance(name, str) else name, + params=c_params, + qargs=q_args, + ) + frame.body.append(instr) + return () + + callee_name_node = ast.Name(name) if isinstance(name, str) else name + args = tuple(frame.get_values(node.args)) + _, call_expr = emit.call(node.callee.code, callee_name_node, *args) + if call_expr is not None: + frame.body.append(call_expr) + return () @interp.impl(func.Function) def emit_func( self, emit: EmitQASM2Main, frame: EmitQASM2Frame, stmt: func.Function ): from bloqade.qasm2.dialects import glob, parallel + from bloqade.qasm2.emit.gate import EmitQASM2Gate + + if isinstance(stmt, GateFunction): + return () + + func_name = emit.callables.get(stmt) + if func_name is None: + func_name = emit.callables.add(stmt) + + for block in stmt.body.blocks: + frame.current_block = block + for s in block.stmts: + frame.current_stmt = s + stmt_results = emit.frame_eval(frame, s) + if isinstance(stmt_results, tuple): + if len(stmt_results) != 0: + frame.set_values(s._results, stmt_results) + continue + + gate_defs: list[ast.Gate] = [] + + gate_emitter = EmitQASM2Gate(dialects=emit.dialects).initialize() + gate_emitter.callables = emit.callables + + while emit.callable_to_emit: + callable_node = emit.callable_to_emit.pop() + if callable_node is None: + break + + if isinstance(callable_node, GateFunction): + with gate_emitter.eval_context(): + with gate_emitter.new_frame( + callable_node, has_parent_access=False + ) as gate_frame: + gate_result = gate_emitter.frame_eval(gate_frame, callable_node) + gate_obj = None + if isinstance(gate_result, tuple) and len(gate_result) > 0: + maybe = gate_result[0] + if isinstance(maybe, ast.Gate): + gate_obj = maybe + + if gate_obj is None: + name = emit.callables.get( + callable_node + ) or emit.callables.add(callable_node) + prefix = getattr(emit.callables, "prefix", "") or "" + emit_name = ( + name[len(prefix) :] + if prefix and name.startswith(prefix) + else name + ) + gate_obj = ast.Gate( + name=emit_name, cparams=[], qparams=[], body=[] + ) + + gate_defs.append(gate_obj) - emit.run_ssacfg_region(frame, stmt.body, ()) if emit.dialects.data.intersection((parallel.dialect, glob.dialect)): header = ast.Kirin([dialect.name for dialect in emit.dialects]) else: header = ast.OPENQASM(ast.Version(2, 0)) - emit.output = ast.MainProgram(header=header, statements=frame.body) + full_body = gate_defs + frame.body + stmt_list = cast(List[ast.Statement], full_body) + emit.output = ast.MainProgram(header=header, statements=stmt_list) return () diff --git a/src/bloqade/qasm2/emit/target.py b/src/bloqade/qasm2/emit/target.py index 53049b88..a2548bba 100644 --- a/src/bloqade/qasm2/emit/target.py +++ b/src/bloqade/qasm2/emit/target.py @@ -106,17 +106,17 @@ def emit(self, entry: ir.Method) -> ast.MainProgram: unroll_ifs=self.unroll_ifs, ).fixpoint(entry) - if not self.allow_global: - # rewrite global to parallel - GlobalToParallel(dialects=entry.dialects)(entry) + # if not self.allow_global: + # # rewrite global to parallel + # GlobalToParallel(dialects=entry.dialects)(entry) - if not self.allow_parallel: - # rewrite parallel to uop - ParallelToUOp(dialects=entry.dialects)(entry) + # if not self.allow_parallel: + # # rewrite parallel to uop + # ParallelToUOp(dialects=entry.dialects)(entry) Py2QASM(entry.dialects)(entry) - target_main = EmitQASM2Main(self.main_target) - target_main.run(entry, ()) + target_main = EmitQASM2Main(self.main_target).initialize() + target_main.run(entry) main_program = target_main.output assert main_program is not None, f"failed to emit {entry.sym_name}" @@ -127,7 +127,7 @@ def emit(self, entry: ir.Method) -> ast.MainProgram: if self.custom_gate: cg = CallGraph(entry) - target_gate = EmitQASM2Gate(self.gate_target) + target_gate = EmitQASM2Gate(self.gate_target).initialize() for _, fns in cg.defs.items(): if len(fns) != 1: @@ -150,7 +150,7 @@ def emit(self, entry: ir.Method) -> ast.MainProgram: Py2QASM(fn.dialects)(fn) - target_gate.run(fn, tuple(ast.Name(name) for name in fn.arg_names[1:])) + target_gate.run(fn) assert target_gate.output is not None, f"failed to emit {fn.sym_name}" extra.append(target_gate.output) diff --git a/src/bloqade/qasm2/groups.py b/src/bloqade/qasm2/groups.py index 280638c0..4e495562 100644 --- a/src/bloqade/qasm2/groups.py +++ b/src/bloqade/qasm2/groups.py @@ -1,6 +1,6 @@ from kirin import ir, passes from kirin.prelude import structural_no_opt -from kirin.dialects import scf, func, ilist, lowering +from kirin.dialects import scf, func, ilist, ssacfg, lowering from bloqade.qasm2.dialects import ( uop, @@ -15,7 +15,7 @@ from bloqade.qasm2.rewrite.desugar import IndexingDesugarPass -@ir.dialect_group([uop, func, expr, lowering.func, lowering.call]) +@ir.dialect_group([uop, func, expr, lowering.func, lowering.call, ssacfg]) def gate(self): fold_pass = passes.Fold(self) typeinfer_pass = passes.TypeInfer(self) @@ -58,6 +58,7 @@ def run_pass( func, lowering.func, lowering.call, + ssacfg, ] ) def main(self): diff --git a/src/bloqade/qasm2/passes/glob.py b/src/bloqade/qasm2/passes/glob.py index 99509ca0..ab8c3c9f 100644 --- a/src/bloqade/qasm2/passes/glob.py +++ b/src/bloqade/qasm2/passes/glob.py @@ -51,7 +51,7 @@ def main(): """ def generate_rule(self, mt: ir.Method) -> GlobalToUOpRule: - frame, _ = address.AddressAnalysis(mt.dialects).run_analysis(mt) + frame, _ = address.AddressAnalysis(mt.dialects).run(mt) return GlobalToUOpRule(frame.entries) def unsafe_run(self, mt: ir.Method) -> abc.RewriteResult: @@ -105,7 +105,7 @@ def main(): """ def generate_rule(self, mt: ir.Method) -> GlobalToParallelRule: - frame, _ = address.AddressAnalysis(mt.dialects).run_analysis(mt) + frame, _ = address.AddressAnalysis(mt.dialects).run(mt) return GlobalToParallelRule(frame.entries) def unsafe_run(self, mt: ir.Method) -> abc.RewriteResult: diff --git a/src/bloqade/qasm2/passes/noise.py b/src/bloqade/qasm2/passes/noise.py index 2d077a76..ad6206b6 100644 --- a/src/bloqade/qasm2/passes/noise.py +++ b/src/bloqade/qasm2/passes/noise.py @@ -55,7 +55,7 @@ def __post_init__(self): self.address_analysis = address.AddressAnalysis(self.dialects) def get_qubit_values(self, mt: ir.Method): - frame, _ = self.address_analysis.run_analysis(mt, no_raise=self.no_raise) + frame, _ = self.address_analysis.run(mt) qubit_ssa_values = {} # Traverse statements in block order to fine the first SSA value for each qubit for block in mt.callable_region.blocks: diff --git a/src/bloqade/qasm2/passes/parallel.py b/src/bloqade/qasm2/passes/parallel.py index 977acc2b..0951673e 100644 --- a/src/bloqade/qasm2/passes/parallel.py +++ b/src/bloqade/qasm2/passes/parallel.py @@ -63,7 +63,7 @@ def main(): """ def generate_rule(self, mt: ir.Method) -> ParallelToUOpRule: - frame, _ = address.AddressAnalysis(mt.dialects).run_analysis(mt) + frame, _ = address.AddressAnalysis(mt.dialects).run(mt) id_map = {} @@ -159,10 +159,10 @@ def unsafe_run(self, mt: ir.Method) -> abc.RewriteResult: .join(result) ) - frame, _ = self.constprop.run_analysis(mt) + frame, _ = self.constprop.run(mt) result = Walk(WrapConst(frame)).rewrite(mt.code).join(result) - frame, _ = address.AddressAnalysis(mt.dialects).run_analysis(mt) + frame, _ = address.AddressAnalysis(mt.dialects).run(mt) dags = schedule.DagScheduleAnalysis( mt.dialects, address_analysis=frame.entries ).get_dags(mt) @@ -193,7 +193,7 @@ class ParallelToGlobal(Pass): def generate_rule(self, mt: ir.Method) -> ParallelToGlobalRule: address_analysis = address.AddressAnalysis(mt.dialects) - frame, _ = address_analysis.run_analysis(mt) + frame, _ = address_analysis.run(mt) return ParallelToGlobalRule(frame.entries) def unsafe_run(self, mt: ir.Method) -> abc.RewriteResult: diff --git a/src/bloqade/squin/analysis/schedule.py b/src/bloqade/squin/analysis/schedule.py index ae8ac57d..e99e219d 100644 --- a/src/bloqade/squin/analysis/schedule.py +++ b/src/bloqade/squin/analysis/schedule.py @@ -226,7 +226,7 @@ def get_dags(self, mt: ir.Method, args=None, kwargs=None): if args is None: args = tuple(self.lattice.top() for _ in mt.args) - self.run(mt, args, kwargs) + self.run(mt) return self.stmt_dags diff --git a/src/bloqade/stim/dialects/auxiliary/emit.py b/src/bloqade/stim/dialects/auxiliary/emit.py index ae34e22f..079fee6b 100644 --- a/src/bloqade/stim/dialects/auxiliary/emit.py +++ b/src/bloqade/stim/dialects/auxiliary/emit.py @@ -1,7 +1,6 @@ -from kirin.emit import EmitStrFrame from kirin.interp import MethodTable, impl -from bloqade.stim.emit.stim_str import EmitStimMain +from bloqade.stim.emit.stim_str import EmitStimMain, EmitStimFrame from . import stmts from ._dialect import dialect @@ -11,7 +10,7 @@ class EmitStimAuxMethods(MethodTable): @impl(stmts.ConstInt) - def const_int(self, emit: EmitStimMain, frame: EmitStrFrame, stmt: stmts.ConstInt): + def const_int(self, emit: EmitStimMain, frame: EmitStimFrame, stmt: stmts.ConstInt): out: str = f"{stmt.value}" @@ -19,7 +18,7 @@ def const_int(self, emit: EmitStimMain, frame: EmitStrFrame, stmt: stmts.ConstIn @impl(stmts.ConstFloat) def const_float( - self, emit: EmitStimMain, frame: EmitStrFrame, stmt: stmts.ConstFloat + self, emit: EmitStimMain, frame: EmitStimFrame, stmt: stmts.ConstFloat ): out: str = f"{stmt.value:.8f}" @@ -28,26 +27,28 @@ def const_float( @impl(stmts.ConstBool) def const_bool( - self, emit: EmitStimMain, frame: EmitStrFrame, stmt: stmts.ConstBool + self, emit: EmitStimMain, frame: EmitStimFrame, stmt: stmts.ConstBool ): out: str = "!" if stmt.value else "" return (out,) @impl(stmts.ConstStr) - def const_str(self, emit: EmitStimMain, frame: EmitStrFrame, stmt: stmts.ConstBool): + def const_str( + self, emit: EmitStimMain, frame: EmitStimFrame, stmt: stmts.ConstBool + ): return (stmt.value,) @impl(stmts.Neg) - def neg(self, emit: EmitStimMain, frame: EmitStrFrame, stmt: stmts.Neg): + def neg(self, emit: EmitStimMain, frame: EmitStimFrame, stmt: stmts.Neg): operand: str = frame.get(stmt.operand) return ("-" + operand,) @impl(stmts.GetRecord) - def get_rec(self, emit: EmitStimMain, frame: EmitStrFrame, stmt: stmts.GetRecord): + def get_rec(self, emit: EmitStimMain, frame: EmitStimFrame, stmt: stmts.GetRecord): id: str = frame.get(stmt.id) out: str = f"rec[{id}]" @@ -55,14 +56,14 @@ def get_rec(self, emit: EmitStimMain, frame: EmitStrFrame, stmt: stmts.GetRecord return (out,) @impl(stmts.Tick) - def tick(self, emit: EmitStimMain, frame: EmitStrFrame, stmt: stmts.Tick): + def tick(self, emit: EmitStimMain, frame: EmitStimFrame, stmt: stmts.Tick): - emit.writeln(frame, "TICK") + frame.write_line("TICK") return () @impl(stmts.Detector) - def detector(self, emit: EmitStimMain, frame: EmitStrFrame, stmt: stmts.Detector): + def detector(self, emit: EmitStimMain, frame: EmitStimFrame, stmt: stmts.Detector): coords: tuple[str, ...] = frame.get_values(stmt.coord) targets: tuple[str, ...] = frame.get_values(stmt.targets) @@ -70,27 +71,27 @@ def detector(self, emit: EmitStimMain, frame: EmitStrFrame, stmt: stmts.Detector coord_str: str = ", ".join(coords) target_str: str = " ".join(targets) if len(coords): - emit.writeln(frame, f"DETECTOR({coord_str}) {target_str}") + frame.write_line(f"DETECTOR({coord_str}) {target_str}") else: - emit.writeln(frame, f"DETECTOR {target_str}") + frame.write_line(f"DETECTOR {target_str}") return () @impl(stmts.ObservableInclude) def obs_include( - self, emit: EmitStimMain, frame: EmitStrFrame, stmt: stmts.ObservableInclude + self, emit: EmitStimMain, frame: EmitStimFrame, stmt: stmts.ObservableInclude ): idx: str = frame.get(stmt.idx) targets: tuple[str, ...] = frame.get_values(stmt.targets) target_str: str = " ".join(targets) - emit.writeln(frame, f"OBSERVABLE_INCLUDE({idx}) {target_str}") + frame.write_line(f"OBSERVABLE_INCLUDE({idx}) {target_str}") return () @impl(stmts.NewPauliString) def new_paulistr( - self, emit: EmitStimMain, frame: EmitStrFrame, stmt: stmts.NewPauliString + self, emit: EmitStimMain, frame: EmitStimFrame, stmt: stmts.NewPauliString ): string: tuple[str, ...] = frame.get_values(stmt.string) @@ -105,13 +106,13 @@ def new_paulistr( @impl(stmts.QubitCoordinates) def qubit_coordinates( - self, emit: EmitStimMain, frame: EmitStrFrame, stmt: stmts.QubitCoordinates + self, emit: EmitStimMain, frame: EmitStimFrame, stmt: stmts.QubitCoordinates ): coords: tuple[str, ...] = frame.get_values(stmt.coord) target: str = frame.get(stmt.target) coord_str: str = ", ".join(coords) - emit.writeln(frame, f"QUBIT_COORDS({coord_str}) {target}") + frame.write_line(f"QUBIT_COORDS({coord_str}) {target}") return () diff --git a/src/bloqade/stim/dialects/collapse/emit_str.py b/src/bloqade/stim/dialects/collapse/emit_str.py index 85b29fd6..7a99b94f 100644 --- a/src/bloqade/stim/dialects/collapse/emit_str.py +++ b/src/bloqade/stim/dialects/collapse/emit_str.py @@ -1,7 +1,6 @@ -from kirin.emit import EmitStrFrame from kirin.interp import MethodTable, impl -from bloqade.stim.emit.stim_str import EmitStimMain +from bloqade.stim.emit.stim_str import EmitStimMain, EmitStimFrame from . import stmts from ._dialect import dialect @@ -27,13 +26,13 @@ class EmitStimCollapseMethods(MethodTable): @impl(stmts.MXX) @impl(stmts.MYY) @impl(stmts.MZZ) - def get_measure(self, emit: EmitStimMain, frame: EmitStrFrame, stmt: Measurement): + def get_measure(self, emit: EmitStimMain, frame: EmitStimFrame, stmt: Measurement): probability: str = frame.get(stmt.p) targets: tuple[str, ...] = frame.get_values(stmt.targets) out = f"{self.meas_map[stmt.name]}({probability}) " + " ".join(targets) - emit.writeln(frame, out) + frame.write_line(out) return () @@ -46,18 +45,18 @@ def get_measure(self, emit: EmitStimMain, frame: EmitStrFrame, stmt: Measurement @impl(stmts.RX) @impl(stmts.RY) @impl(stmts.RZ) - def get_reset(self, emit: EmitStimMain, frame: EmitStrFrame, stmt: Reset): + def get_reset(self, emit: EmitStimMain, frame: EmitStimFrame, stmt: Reset): targets: tuple[str, ...] = frame.get_values(stmt.targets) out = f"{self.reset_map[stmt.name]} " + " ".join(targets) - emit.writeln(frame, out) + frame.write_line(out) return () @impl(stmts.PPMeasurement) def pp_measure( - self, emit: EmitStimMain, frame: EmitStrFrame, stmt: stmts.PPMeasurement + self, emit: EmitStimMain, frame: EmitStimFrame, stmt: stmts.PPMeasurement ): probability: str = frame.get(stmt.p) targets: tuple[str, ...] = tuple( @@ -65,6 +64,6 @@ def pp_measure( ) out = f"MPP({probability}) " + " ".join(targets) - emit.writeln(frame, out) + frame.write_line(out) return () diff --git a/src/bloqade/stim/dialects/gate/emit.py b/src/bloqade/stim/dialects/gate/emit.py index 510e2117..c275903a 100644 --- a/src/bloqade/stim/dialects/gate/emit.py +++ b/src/bloqade/stim/dialects/gate/emit.py @@ -1,7 +1,6 @@ -from kirin.emit import EmitStrFrame from kirin.interp import MethodTable, impl -from bloqade.stim.emit.stim_str import EmitStimMain +from bloqade.stim.emit.stim_str import EmitStimMain, EmitStimFrame from . import stmts from ._dialect import dialect @@ -33,11 +32,11 @@ class EmitStimGateMethods(MethodTable): @impl(stmts.SqrtY) @impl(stmts.SqrtZ) def single_qubit_gate( - self, emit: EmitStimMain, frame: EmitStrFrame, stmt: SingleQubitGate + self, emit: EmitStimMain, frame: EmitStimFrame, stmt: SingleQubitGate ): targets: tuple[str, ...] = frame.get_values(stmt.targets) res = f"{self.gate_1q_map[stmt.name][int(stmt.dagger)]} " + " ".join(targets) - emit.writeln(frame, res) + frame.write_line(res) return () @@ -47,13 +46,13 @@ def single_qubit_gate( @impl(stmts.Swap) def two_qubit_gate( - self, emit: EmitStimMain, frame: EmitStrFrame, stmt: ControlledTwoQubitGate + self, emit: EmitStimMain, frame: EmitStimFrame, stmt: ControlledTwoQubitGate ): targets: tuple[str, ...] = frame.get_values(stmt.targets) res = f"{self.gate_ctrl_2q_map[stmt.name][int(stmt.dagger)]} " + " ".join( targets ) - emit.writeln(frame, res) + frame.write_line(res) return () @@ -68,19 +67,19 @@ def two_qubit_gate( @impl(stmts.CY) @impl(stmts.CZ) def ctrl_two_qubit_gate( - self, emit: EmitStimMain, frame: EmitStrFrame, stmt: ControlledTwoQubitGate + self, emit: EmitStimMain, frame: EmitStimFrame, stmt: ControlledTwoQubitGate ): controls: tuple[str, ...] = frame.get_values(stmt.controls) targets: tuple[str, ...] = frame.get_values(stmt.targets) res = f"{self.gate_ctrl_2q_map[stmt.name][int(stmt.dagger)]} " + " ".join( f"{ctrl} {tgt}" for ctrl, tgt in zip(controls, targets) ) - emit.writeln(frame, res) + frame.write_line(res) return () @impl(stmts.SPP) - def spp(self, emit: EmitStimMain, frame: EmitStrFrame, stmt: stmts.SPP): + def spp(self, emit: EmitStimMain, frame: EmitStimFrame, stmt: stmts.SPP): targets: tuple[str, ...] = tuple( targ.upper() for targ in frame.get_values(stmt.targets) @@ -89,6 +88,6 @@ def spp(self, emit: EmitStimMain, frame: EmitStrFrame, stmt: stmts.SPP): res = "SPP_DAG " + " ".join(targets) else: res = "SPP " + " ".join(targets) - emit.writeln(frame, res) + frame.write_line(res) return () diff --git a/src/bloqade/stim/dialects/noise/emit.py b/src/bloqade/stim/dialects/noise/emit.py index 41e57a06..901e5c88 100644 --- a/src/bloqade/stim/dialects/noise/emit.py +++ b/src/bloqade/stim/dialects/noise/emit.py @@ -1,7 +1,6 @@ -from kirin.emit import EmitStrFrame from kirin.interp import MethodTable, impl -from bloqade.stim.emit.stim_str import EmitStimMain +from bloqade.stim.emit.stim_str import EmitStimMain, EmitStimFrame from . import stmts from ._dialect import dialect @@ -24,20 +23,20 @@ class EmitStimNoiseMethods(MethodTable): @impl(stmts.Depolarize1) @impl(stmts.Depolarize2) def single_p_error( - self, emit: EmitStimMain, frame: EmitStrFrame, stmt: stmts.Depolarize1 + self, emit: EmitStimMain, frame: EmitStimFrame, stmt: stmts.Depolarize1 ): targets: tuple[str, ...] = frame.get_values(stmt.targets) p: str = frame.get(stmt.p) name = self.single_p_error_map[stmt.name] res = f"{name}({p}) " + " ".join(targets) - emit.writeln(frame, res) + frame.write_line(res) return () @impl(stmts.PauliChannel1) def pauli_channel1( - self, emit: EmitStimMain, frame: EmitStrFrame, stmt: stmts.PauliChannel1 + self, emit: EmitStimMain, frame: EmitStimFrame, stmt: stmts.PauliChannel1 ): targets: tuple[str, ...] = frame.get_values(stmt.targets) @@ -45,13 +44,13 @@ def pauli_channel1( py: str = frame.get(stmt.py) pz: str = frame.get(stmt.pz) res = f"PAULI_CHANNEL_1({px}, {py}, {pz}) " + " ".join(targets) - emit.writeln(frame, res) + frame.write_line(res) return () @impl(stmts.PauliChannel2) def pauli_channel2( - self, emit: EmitStimMain, frame: EmitStrFrame, stmt: stmts.PauliChannel2 + self, emit: EmitStimMain, frame: EmitStimFrame, stmt: stmts.PauliChannel2 ): targets: tuple[str, ...] = frame.get_values(stmt.targets) @@ -61,14 +60,14 @@ def pauli_channel2( prob_str: str = ", ".join(prob) res = f"PAULI_CHANNEL_2({prob_str}) " + " ".join(targets) - emit.writeln(frame, res) + frame.write_line(res) return () @impl(stmts.TrivialError) @impl(stmts.QubitLoss) def non_stim_error( - self, emit: EmitStimMain, frame: EmitStrFrame, stmt: stmts.TrivialError + self, emit: EmitStimMain, frame: EmitStimFrame, stmt: stmts.TrivialError ): targets: tuple[str, ...] = frame.get_values(stmt.targets) @@ -76,7 +75,7 @@ def non_stim_error( prob_str: str = ", ".join(prob) res = f"I_ERROR[{stmt.name}]({prob_str}) " + " ".join(targets) - emit.writeln(frame, res) + frame.write_line(res) return () @@ -85,7 +84,7 @@ def non_stim_error( def non_stim_corr_error( self, emit: EmitStimMain, - frame: EmitStrFrame, + frame: EmitStimFrame, stmt: stmts.TrivialCorrelatedError, ): @@ -98,6 +97,6 @@ def non_stim_corr_error( + " ".join(targets) ) emit.correlated_error_count += 1 - emit.writeln(frame, res) + frame.write_line(res) return () diff --git a/src/bloqade/stim/emit/impls.py b/src/bloqade/stim/emit/impls.py index 83a0c867..2aebafc9 100644 --- a/src/bloqade/stim/emit/impls.py +++ b/src/bloqade/stim/emit/impls.py @@ -1,17 +1,16 @@ -from kirin.emit import EmitStrFrame from kirin.interp import MethodTable, impl from kirin.dialects.debug import Info, dialect -from bloqade.stim.emit.stim_str import EmitStimMain +from bloqade.stim.emit.stim_str import EmitStimMain, EmitStimFrame @dialect.register(key="emit.stim") class EmitStimDebugMethods(MethodTable): @impl(Info) - def info(self, emit: EmitStimMain, frame: EmitStrFrame, stmt: Info): + def info(self, emit: EmitStimMain, frame: EmitStimFrame, stmt: Info): msg: str = frame.get(stmt.msg) - emit.writeln(frame, f"# {msg}") + frame.write_line(f"# {msg}") return () diff --git a/src/bloqade/stim/emit/stim_str.py b/src/bloqade/stim/emit/stim_str.py index 0da37ff8..e94f9d60 100644 --- a/src/bloqade/stim/emit/stim_str.py +++ b/src/bloqade/stim/emit/stim_str.py @@ -1,56 +1,71 @@ -from io import StringIO -from typing import IO, TypeVar -from dataclasses import field, dataclass +import sys +from typing import IO, Generic, TypeVar, cast +from dataclasses import dataclass from kirin import ir, interp -from kirin.emit import EmitStr, EmitStrFrame from kirin.dialects import func +from kirin.emit.abc import EmitABC, EmitFrame IO_t = TypeVar("IO_t", bound=IO) -def _default_dialect_group() -> ir.DialectGroup: - from ..groups import main +@dataclass +class EmitStimFrame(EmitFrame[str], Generic[IO_t]): + io: IO_t = cast(IO_t, sys.stdout) + + def write(self, value: str) -> None: + self.io.write(value) - return main + def write_line(self, value: str) -> None: + self.write(" " * self._indent + value + "\n") @dataclass -class EmitStimMain(EmitStr): - keys = ["emit.stim"] - dialects: ir.DialectGroup = field(default_factory=_default_dialect_group) - file: StringIO = field(default_factory=StringIO) +class EmitStimMain(EmitABC[EmitStimFrame, str], Generic[IO_t]): + io: IO_t = cast(IO_t, sys.stdout) + keys = ("emit.stim",) + void = "" correlation_identifier_offset: int = 0 - def initialize(self): + def initialize(self) -> "EmitStimMain": super().initialize() - self.file.truncate(0) - self.file.seek(0) self.correlated_error_count = self.correlation_identifier_offset return self - def eval_stmt_fallback( - self, frame: EmitStrFrame, stmt: ir.Statement - ) -> tuple[str, ...]: - return (stmt.name,) + def initialize_frame( + self, node: ir.Statement, *, has_parent_access: bool = False + ) -> EmitStimFrame: + return EmitStimFrame(node, self.io, has_parent_access=has_parent_access) + + def frame_call( + self, frame: EmitStimFrame, node: ir.Statement, *args: str, **kwargs: str + ) -> str: + return f"{args[0]}({', '.join(args[1:])})" - def emit_block(self, frame: EmitStrFrame, block: ir.Block) -> str | None: - for stmt in block.stmts: - result = self.eval_stmt(frame, stmt) - if isinstance(result, tuple): - frame.set_values(stmt.results, result) - return None + def get_attribute(self, frame: EmitStimFrame, node: ir.Attribute) -> str: + method = self.registry.get(interp.Signature(type(node))) + if method is None: + raise ValueError(f"Method not found for node: {node}") + return method(self, frame, node) - def get_output(self) -> str: - self.file.seek(0) - return self.file.read() + def reset(self): + self.io.truncate(0) + self.io.seek(0) + + def eval_fallback(self, frame: EmitStimFrame, node: ir.Statement) -> tuple: + return tuple("" for _ in range(len(node.results))) @func.dialect.register(key="emit.stim") class FuncEmit(interp.MethodTable): - @interp.impl(func.Function) - def emit_func(self, emit: EmitStimMain, frame: EmitStrFrame, stmt: func.Function): - _ = emit.run_ssacfg_region(frame, stmt.body, ()) - # emit.output = "\n".join(frame.body) + def emit_func(self, emit: EmitStimMain, frame: EmitStimFrame, stmt: func.Function): + for block in stmt.body.blocks: + frame.current_block = block + for stmt_ in block.stmts: + frame.current_stmt = stmt_ + res = emit.frame_eval(frame, stmt_) + if isinstance(res, tuple): + frame.set_values(stmt_.results, res) + return () diff --git a/src/bloqade/stim/groups.py b/src/bloqade/stim/groups.py index df1215d1..fdf6cde8 100644 --- a/src/bloqade/stim/groups.py +++ b/src/bloqade/stim/groups.py @@ -1,12 +1,22 @@ from kirin import ir from kirin.passes import Fold, TypeInfer -from kirin.dialects import func, debug, lowering +from kirin.dialects import func, debug, ssacfg, lowering from .dialects import gate, noise, collapse, auxiliary @ir.dialect_group( - [noise, gate, auxiliary, collapse, func, lowering.func, lowering.call, debug] + [ + noise, + gate, + auxiliary, + collapse, + func, + lowering.func, + lowering.call, + debug, + ssacfg, + ] ) def main(self): typeinfer_pass = TypeInfer(self) diff --git a/src/bloqade/stim/parse/lowering.py b/src/bloqade/stim/parse/lowering.py index 8efe9369..062b792d 100644 --- a/src/bloqade/stim/parse/lowering.py +++ b/src/bloqade/stim/parse/lowering.py @@ -98,6 +98,8 @@ def loads( signature=func.Signature((), return_node.value.type), body=body, ) + self_arg = ir.BlockArgument(body.blocks[0], 0) # Self argument + body.blocks[0]._args = (self_arg,) return ir.Method( mod=None, py_func=None, diff --git a/src/bloqade/stim/passes/squin_to_stim.py b/src/bloqade/stim/passes/squin_to_stim.py index bf73b2c7..4a8f1b4c 100644 --- a/src/bloqade/stim/passes/squin_to_stim.py +++ b/src/bloqade/stim/passes/squin_to_stim.py @@ -44,10 +44,10 @@ def unsafe_run(self, mt: Method) -> RewriteResult: # ------------------------------------------------------------------- mia = MeasurementIDAnalysis(dialects=mt.dialects) - meas_analysis_frame, _ = mia.run_analysis(mt, no_raise=self.no_raise) + meas_analysis_frame, _ = mia.run(mt) aa = AddressAnalysis(dialects=mt.dialects) - address_analysis_frame, _ = aa.run_analysis(mt, no_raise=self.no_raise) + address_analysis_frame, _ = aa.run(mt) # wrap the address analysis result rewrite_result = ( diff --git a/src/bloqade/test_utils.py b/src/bloqade/test_utils.py index a87b6e29..6f636e07 100644 --- a/src/bloqade/test_utils.py +++ b/src/bloqade/test_utils.py @@ -25,7 +25,7 @@ def print_diff(node: pprint.Printable, expected_node: pprint.Printable): def assert_nodes(node: ir.IRNode, expected_node: ir.IRNode): try: - assert node.is_equal(expected_node) + assert node.is_structurally_equal(expected_node) except AssertionError as e: print_diff(node, expected_node) raise e diff --git a/test/analysis/address/test_qubit_analysis.py b/test/analysis/address/test_qubit_analysis.py index cee699ce..1866c9a1 100644 --- a/test/analysis/address/test_qubit_analysis.py +++ b/test/analysis/address/test_qubit_analysis.py @@ -47,7 +47,7 @@ def test(): return (y, z, x) address_analysis = address.AddressAnalysis(test.dialects) - frame, _ = address_analysis.run_analysis(test, no_raise=False) + frame, _ = address_analysis.run(test) address_tuples = collect_address_types(frame, address.PartialTuple) address_qubits = collect_address_types(frame, address.AddressQubit) @@ -73,7 +73,7 @@ def test(): return extract_qubits(q) address_analysis = address.AddressAnalysis(test.dialects) - frame, _ = address_analysis.run_analysis(test, no_raise=False) + frame, _ = address_analysis.run(test) address_tuples = collect_address_types(frame, address.PartialTuple) @@ -95,7 +95,7 @@ def main(): squin.h(single_q) address_analysis = address.AddressAnalysis(main.dialects) - frame, _ = address_analysis.run_analysis(main, no_raise=False) + frame, _ = address_analysis.run(main) address_regs = collect_address_types(frame, address.AddressReg) address_qubits = collect_address_types(frame, address.AddressQubit) diff --git a/test/analysis/measure_id/test_measure_id.py b/test/analysis/measure_id/test_measure_id.py index f933a9ad..2a4f97c8 100644 --- a/test/analysis/measure_id/test_measure_id.py +++ b/test/analysis/measure_id/test_measure_id.py @@ -40,7 +40,7 @@ def test(): Flatten(test.dialects).fixpoint(test) - frame, _ = MeasurementIDAnalysis(test.dialects).run_analysis(test) + frame, _ = MeasurementIDAnalysis(test.dialects).run(test) measure_id_tuples = [ value for value in frame.entries.values() if isinstance(value, MeasureIdTuple) @@ -64,7 +64,7 @@ def test(): return ml_alias Flatten(test.dialects).fixpoint(test) - frame, _ = MeasurementIDAnalysis(test.dialects).run_analysis(test) + frame, _ = MeasurementIDAnalysis(test.dialects).run(test) # Collect MeasureIdTuples measure_id_tuples = [ @@ -105,7 +105,7 @@ def test(): squin.y(q[1]) Flatten(test.dialects).fixpoint(test) - frame, _ = MeasurementIDAnalysis(test.dialects).run_analysis(test) + frame, _ = MeasurementIDAnalysis(test.dialects).run(test) assert all( isinstance(stmt, scf.IfElse) and measures_accumulated == 5 @@ -129,7 +129,7 @@ def test(): return ms InlinePass(test.dialects).fixpoint(test) - frame, _ = MeasurementIDAnalysis(test.dialects).run_analysis(test) + frame, _ = MeasurementIDAnalysis(test.dialects).run(test) # MeasureIdBool(idx=1) should occur twice: # First from the measurement in the true branch, then @@ -158,7 +158,7 @@ def test(): # need to preserve the scf.IfElse but need things like qalloc to be inlined InlinePass(test.dialects).fixpoint(test) - frame, _ = MeasurementIDAnalysis(test.dialects).run_analysis(test) + frame, _ = MeasurementIDAnalysis(test.dialects).run(test) test.print(analysis=frame.entries) # MeasureIdBool(idx=1) should occur twice: @@ -187,7 +187,7 @@ def test(cond: bool): # We can use Flatten here because the variable condition for the scf.IfElse # means it cannot be simplified. Flatten(test.dialects).fixpoint(test) - frame, _ = MeasurementIDAnalysis(test.dialects).run_analysis(test) + frame, _ = MeasurementIDAnalysis(test.dialects).run(test) analysis_results = [ val for val in frame.entries.values() if isinstance(val, MeasureIdTuple) ] @@ -215,7 +215,7 @@ def test(): return ms_final Flatten(test.dialects).fixpoint(test) - frame, _ = MeasurementIDAnalysis(test.dialects).run_analysis(test) + frame, _ = MeasurementIDAnalysis(test.dialects).run(test) results = results_of_variables(test, ("msi", "msi2", "ms_final")) @@ -241,7 +241,7 @@ def test(idx): return ms[idx] - frame, _ = MeasurementIDAnalysis(test.dialects).run_analysis(test) + frame, _ = MeasurementIDAnalysis(test.dialects).run(test) assert [frame.entries[result] for result in results_at(test, 0, 3)] == [ InvalidMeasureId(), @@ -256,7 +256,7 @@ def test(): return ms["x"] - frame, _ = MeasurementIDAnalysis(test.dialects).run_analysis(test) + frame, _ = MeasurementIDAnalysis(test.dialects).run(test) assert [frame.entries[result] for result in results_at(test, 0, 4)] == [ InvalidMeasureId() @@ -273,7 +273,7 @@ def test(): invalid_ms = ms["x"] return invalid_ms[0] - frame, _ = MeasurementIDAnalysis(test.dialects).run_analysis(test) + frame, _ = MeasurementIDAnalysis(test.dialects).run(test) assert [frame.entries[result] for result in results_at(test, 0, 6)] == [ InvalidMeasureId() diff --git a/test/cirq_utils/test_clifford_to_cirq.py b/test/cirq_utils/test_clifford_to_cirq.py index 7220b7ae..a1c4d72d 100644 --- a/test/cirq_utils/test_clifford_to_cirq.py +++ b/test/cirq_utils/test_clifford_to_cirq.py @@ -4,8 +4,8 @@ import cirq import numpy as np import pytest -from kirin.emit import EmitError from kirin.dialects import ilist +from kirin.interp.exceptions import InterpreterError from bloqade import squin from bloqade.pyqrack import Measurement, StackMemorySimulator @@ -129,7 +129,7 @@ def main(): print(circuit) - with pytest.raises(EmitError): + with pytest.raises(InterpreterError): emit_circuit(sub_kernel) @squin.kernel diff --git a/test/qasm2/emit/t_qasm2.qasm b/test/qasm2/emit/t_qasm2.qasm new file mode 100644 index 00000000..aeb093d3 --- /dev/null +++ b/test/qasm2/emit/t_qasm2.qasm @@ -0,0 +1,13 @@ +OPENQASM 2.0; +include "qelib1.inc"; +gate custom_gate a, b { + CX a, b; +} +qreg qreg[4]; +creg creg[2]; +CX qreg[0], qreg[1]; +reset qreg[0]; +measure qreg[0] -> creg[0]; +if (creg[0] == 1) reset qreg[1]; +custom_gate qreg[0], qreg[1]; +custom_gate qreg[1], qreg[2]; diff --git a/test/qasm2/emit/test_extended_noise.py b/test/qasm2/emit/test_extended_noise.py index 0878022c..2c95705c 100644 --- a/test/qasm2/emit/test_extended_noise.py +++ b/test/qasm2/emit/test_extended_noise.py @@ -50,7 +50,7 @@ def main(): target = qasm2.emit.QASM2(allow_noise=True, allow_parallel=True) out = target.emit_str(main) - expected = """KIRIN {func,lowering.call,lowering.func,py.ilist,qasm2.core,qasm2.expr,qasm2.indexing,qasm2.noise,qasm2.parallel,qasm2.uop,scf}; + expected = """KIRIN {func,lowering.call,lowering.func,py.ilist,qasm2.core,qasm2.expr,qasm2.indexing,qasm2.noise,qasm2.parallel,qasm2.uop,scf,ssacfg}; include "qelib1.inc"; qreg qreg[4]; CX qreg[0], qreg[1]; diff --git a/test/qasm2/emit/test_qasm2.py b/test/qasm2/emit/test_qasm2.py index 83f0ab80..d7a2d6a3 100644 --- a/test/qasm2/emit/test_qasm2.py +++ b/test/qasm2/emit/test_qasm2.py @@ -1,3 +1,7 @@ +import io +from pathlib import Path +from contextlib import redirect_stdout + from bloqade import qasm2 @@ -6,21 +10,30 @@ def test_qasm2_custom_gate(): def custom_gate(a: qasm2.Qubit, b: qasm2.Qubit): qasm2.cx(a, b) + @qasm2.gate + def custom_gate2(a: qasm2.Bit, b: qasm2.Bit): + return + @qasm2.main def main(): qreg = qasm2.qreg(4) creg = qasm2.creg(2) qasm2.cx(qreg[0], qreg[1]) qasm2.reset(qreg[0]) - # qasm2.parallel.cz(ctrls=[qreg[0], qreg[1]], qargs=[qreg[2], qreg[3]]) qasm2.measure(qreg[0], creg[0]) if creg[0] == 1: qasm2.reset(qreg[1]) custom_gate(qreg[0], qreg[1]) - - main.print() - custom_gate.print() + custom_gate2(creg[0], creg[1]) + custom_gate(qreg[1], qreg[2]) target = qasm2.emit.QASM2(custom_gate=True) ast = target.emit(main) - qasm2.parse.pprint(ast) + filename = "t_qasm2.qasm" + with open(Path(__file__).parent.resolve() / filename, "r") as txt: + target = txt.read() + buf = io.StringIO() + with redirect_stdout(buf): + qasm2.parse.pprint(ast) + generated = buf.getvalue() + assert generated.strip() == target.strip() diff --git a/test/qasm2/emit/test_qasm2_emit.py b/test/qasm2/emit/test_qasm2_emit.py index e044b7ee..00eaaec4 100644 --- a/test/qasm2/emit/test_qasm2_emit.py +++ b/test/qasm2/emit/test_qasm2_emit.py @@ -20,7 +20,7 @@ def glob_u(): qasm2_str = target.emit_str(glob_u) assert ( qasm2_str - == """KIRIN {func,lowering.call,lowering.func,py.ilist,qasm2.core,qasm2.expr,qasm2.glob,qasm2.indexing,qasm2.noise,qasm2.uop,scf}; + == """KIRIN {func,lowering.call,lowering.func,py.ilist,qasm2.core,qasm2.expr,qasm2.glob,qasm2.indexing,qasm2.noise,qasm2.uop,scf,ssacfg}; include "qelib1.inc"; qreg qreg[3]; qreg qreg1[3]; @@ -43,10 +43,9 @@ def glob_u(): custom_gate=True, ) qasm2_str = target.emit_str(glob_u) - print(qasm2_str) assert ( qasm2_str - == """KIRIN {func,lowering.call,lowering.func,py.ilist,qasm2.core,qasm2.expr,qasm2.glob,qasm2.indexing,qasm2.noise,qasm2.parallel,qasm2.uop,scf}; + == """KIRIN {func,lowering.call,lowering.func,py.ilist,qasm2.core,qasm2.expr,qasm2.glob,qasm2.indexing,qasm2.noise,qasm2.parallel,qasm2.uop,scf,ssacfg}; include "qelib1.inc"; qreg qreg[3]; qreg qreg1[3]; @@ -55,6 +54,7 @@ def glob_u(): ) +@pytest.mark.xfail(reason="Global and Parallel refactor still ongoing") def test_global(): @qasm2.extended @@ -85,6 +85,7 @@ def glob_u(): ) +@pytest.mark.xfail(reason="Global and Parallel refactor still ongoing") def test_global_allow_para(): @qasm2.extended @@ -101,7 +102,7 @@ def glob_u(): qasm2_str = target.emit_str(glob_u) assert ( qasm2_str - == """KIRIN {func,lowering.call,lowering.func,py.ilist,qasm2.core,qasm2.expr,qasm2.indexing,qasm2.noise,qasm2.parallel,qasm2.uop,scf}; + == """KIRIN {func,lowering.call,lowering.func,py.ilist,qasm2.core,qasm2.expr,qasm2.indexing,qasm2.noise,qasm2.parallel,qasm2.uop,scf,ssacfg}; include "qelib1.inc"; qreg qreg[3]; qreg qreg1[3]; @@ -117,6 +118,7 @@ def glob_u(): ) +@pytest.mark.xfail(reason="Global and Parallel refactor still ongoing") def test_para(): @qasm2.extended @@ -132,7 +134,6 @@ def para_u(): custom_gate=True, ) qasm2_str = target.emit_str(para_u) - print(qasm2_str) assert ( qasm2_str == """OPENQASM 2.0; @@ -144,6 +145,7 @@ def para_u(): ) +@pytest.mark.xfail(reason="Global and Parallel refactor still ongoing") def test_para_allow_para(): @qasm2.extended @@ -160,7 +162,7 @@ def para_u(): qasm2_str = target.emit_str(para_u) assert ( qasm2_str - == """KIRIN {func,lowering.call,lowering.func,py.ilist,qasm2.core,qasm2.expr,qasm2.indexing,qasm2.noise,qasm2.parallel,qasm2.uop,scf}; + == """KIRIN {func,lowering.call,lowering.func,py.ilist,qasm2.core,qasm2.expr,qasm2.indexing,qasm2.noise,qasm2.parallel,qasm2.uop,scf,ssacfg}; include "qelib1.inc"; qreg qreg[3]; parallel.U(0.1, 0.2, 0.3) { @@ -188,7 +190,7 @@ def para_u(): qasm2_str = target.emit_str(para_u) assert ( qasm2_str - == """KIRIN {func,lowering.call,lowering.func,py.ilist,qasm2.core,qasm2.expr,qasm2.glob,qasm2.indexing,qasm2.noise,qasm2.parallel,qasm2.uop,scf}; + == """KIRIN {func,lowering.call,lowering.func,py.ilist,qasm2.core,qasm2.expr,qasm2.glob,qasm2.indexing,qasm2.noise,qasm2.parallel,qasm2.uop,scf,ssacfg}; include "qelib1.inc"; qreg qreg[3]; parallel.U(0.1, 0.2, 0.3) { @@ -199,6 +201,7 @@ def para_u(): ) +@pytest.mark.xfail(reason="Global and Parallel refactor still ongoing") def test_para_allow_global(): @qasm2.extended @@ -217,7 +220,7 @@ def para_u(): print(qasm2_str) assert ( qasm2_str - == """KIRIN {func,lowering.call,lowering.func,py.ilist,qasm2.core,qasm2.expr,qasm2.glob,qasm2.indexing,qasm2.noise,qasm2.uop,scf}; + == """KIRIN {func,lowering.call,lowering.func,py.ilist,qasm2.core,qasm2.expr,qasm2.glob,qasm2.indexing,qasm2.noise,qasm2.uop,scf,ssacfg}; include "qelib1.inc"; qreg qreg[3]; U(0.1, 0.2, 0.3) qreg[1]; @@ -306,7 +309,7 @@ def ghz_linear(): qasm2_str = target.emit_str(ghz_linear) assert qasm2_str == ( - """KIRIN {func,lowering.call,lowering.func,py.ilist,qasm2.core,qasm2.expr,qasm2.indexing,qasm2.noise,qasm2.parallel,qasm2.uop,scf}; + """KIRIN {func,lowering.call,lowering.func,py.ilist,qasm2.core,qasm2.expr,qasm2.indexing,qasm2.noise,qasm2.parallel,qasm2.uop,scf,ssacfg}; include "qelib1.inc"; qreg q[4]; h q[0]; diff --git a/test/qasm2/passes/test_global_to_parallel.py b/test/qasm2/passes/test_global_to_parallel.py index 3b147636..b77a3a13 100644 --- a/test/qasm2/passes/test_global_to_parallel.py +++ b/test/qasm2/passes/test_global_to_parallel.py @@ -1,5 +1,6 @@ from typing import List +import pytest from kirin import ir, types from kirin.rewrite import Walk, Fixpoint, CommonSubexpressionElimination from kirin.dialects import py, func, ilist @@ -17,6 +18,7 @@ def as_float(value: float): return py.constant.Constant(value=value) +@pytest.mark.xfail def test_global2para_rewrite(): @qasm2.extended @@ -77,6 +79,7 @@ def main(): assert_methods(expected_method, main) +@pytest.mark.xfail def test_global2para_rewrite2(): @qasm2.extended diff --git a/test/qasm2/passes/test_global_to_uop.py b/test/qasm2/passes/test_global_to_uop.py index dafe6f2a..9be187d9 100644 --- a/test/qasm2/passes/test_global_to_uop.py +++ b/test/qasm2/passes/test_global_to_uop.py @@ -1,5 +1,6 @@ from typing import List +import pytest from kirin import ir, types from kirin.rewrite import Walk, Fixpoint, CommonSubexpressionElimination from kirin.dialects import py, func @@ -17,6 +18,7 @@ def as_float(value: float): return py.constant.Constant(value=value) +@pytest.mark.xfail def test_global_rewrite(): @qasm2.extended diff --git a/test/qasm2/passes/test_heuristic_noise.py b/test/qasm2/passes/test_heuristic_noise.py index 78879a5e..20a4c5b0 100644 --- a/test/qasm2/passes/test_heuristic_noise.py +++ b/test/qasm2/passes/test_heuristic_noise.py @@ -1,3 +1,4 @@ +import pytest from kirin import ir, types from kirin.dialects import func, ilist from kirin.dialects.py import constant @@ -255,6 +256,7 @@ def test_parallel_cz_gate_noise(): assert_nodes(block, expected_block) +@pytest.mark.xfail def test_global_noise(): @qasm2.extended diff --git a/test/qasm2/passes/test_parallel_to_global.py b/test/qasm2/passes/test_parallel_to_global.py index c72533b8..93fbac7f 100644 --- a/test/qasm2/passes/test_parallel_to_global.py +++ b/test/qasm2/passes/test_parallel_to_global.py @@ -1,7 +1,10 @@ +import pytest + from bloqade import qasm2 from bloqade.qasm2.passes.parallel import ParallelToGlobal +@pytest.mark.xfail def test_basic_rewrite(): @qasm2.extended @@ -29,6 +32,7 @@ def main(): ) +@pytest.mark.xfail def test_if_rewrite(): @qasm2.extended def main(): @@ -63,6 +67,7 @@ def main(): ) +@pytest.mark.xfail def test_should_not_be_rewritten(): @qasm2.extended @@ -88,6 +93,7 @@ def main(): ) +@pytest.mark.xfail def test_multiple_registers(): @qasm2.extended def main(): @@ -120,6 +126,7 @@ def main(): ) +@pytest.mark.xfail def test_reverse_order(): @qasm2.extended def main(): diff --git a/test/qasm2/passes/test_parallel_to_uop.py b/test/qasm2/passes/test_parallel_to_uop.py index c3e2c59e..7484542e 100644 --- a/test/qasm2/passes/test_parallel_to_uop.py +++ b/test/qasm2/passes/test_parallel_to_uop.py @@ -1,5 +1,6 @@ from typing import List +import pytest from kirin import ir, types from kirin.dialects import py, func @@ -16,6 +17,7 @@ def as_float(value: float): return py.constant.Constant(value=value) +@pytest.mark.xfail def test_cz_rewrite(): @qasm2.extended diff --git a/test/qasm2/passes/test_uop_to_parallel.py b/test/qasm2/passes/test_uop_to_parallel.py index 2016bb7a..7e1f3bdc 100644 --- a/test/qasm2/passes/test_uop_to_parallel.py +++ b/test/qasm2/passes/test_uop_to_parallel.py @@ -1,3 +1,5 @@ +import pytest + from bloqade import qasm2 from bloqade.qasm2 import glob from bloqade.analysis import address @@ -5,6 +7,7 @@ from bloqade.qasm2.rewrite import SimpleOptimalMergePolicy +@pytest.mark.xfail def test_one(): @qasm2.gate @@ -36,7 +39,7 @@ def test(): test.print() # add this to raise error if there are broken ssa references - _, _ = address.AddressAnalysis(test.dialects).run_analysis(test, no_raise=False) + _, _ = address.AddressAnalysis(test.dialects).run(test) # check that there's parallel statements now assert any( @@ -47,6 +50,7 @@ def test(): ) +@pytest.mark.xfail def test_two(): @qasm2.extended @@ -82,9 +86,10 @@ def test(): test.print() # add this to raise error if there are broken ssa references - _, _ = address.AddressAnalysis(test.dialects).run_analysis(test, no_raise=False) + _, _ = address.AddressAnalysis(test.dialects).run(test) +@pytest.mark.xfail def test_three(): @qasm2.extended @@ -104,4 +109,4 @@ def test(): test.print() # add this to raise error if there are broken ssa references - _, _ = address.AddressAnalysis(test.dialects).run_analysis(test, no_raise=False) + _, _ = address.AddressAnalysis(test.dialects).run(test) diff --git a/test/qasm2/test_count.py b/test/qasm2/test_count.py index a09cefc0..df5cef31 100644 --- a/test/qasm2/test_count.py +++ b/test/qasm2/test_count.py @@ -1,3 +1,4 @@ +import pytest from kirin import passes from kirin.dialects import py, ilist @@ -15,6 +16,7 @@ fold = passes.Fold(qasm2.main.add(py.tuple).add(ilist)) +@pytest.mark.xfail def test_fixed_count(): @qasm2.main def fixed_count(): @@ -27,13 +29,14 @@ def fixed_count(): return q3 fold(fixed_count) - results, ret = address.run_analysis(fixed_count) + results, ret = address.run(fixed_count) # fixed_count.print(analysis=address.results) assert isinstance(ret, AddressQubit) assert ret.data == range(3, 7)[1] assert address.qubit_count == 7 +@pytest.mark.xfail def test_multiple_return_only_reg(): @qasm2.main.add(py.tuple) @@ -44,13 +47,14 @@ def tuple_count(): # tuple_count.dce() fold(tuple_count) - frame, ret = address.run_analysis(tuple_count) + frame, ret = address.run(tuple_count) # tuple_count.code.print(analysis=frame.entries) assert isinstance(ret, PartialTuple) assert isinstance(ret.data[0], AddressReg) and ret.data[0].data == range(0, 3) assert isinstance(ret.data[1], AddressReg) and ret.data[1].data == range(3, 7) +@pytest.mark.xfail def test_dynamic_address(): @qasm2.main def dynamic_address(): @@ -88,6 +92,7 @@ def dynamic_address(): # assert isinstance(result, ConstResult) +@pytest.mark.xfail def test_multi_return(): @qasm2.main.add(py.tuple) def multi_return_cnt(): @@ -97,7 +102,7 @@ def multi_return_cnt(): multi_return_cnt.code.print() fold(multi_return_cnt) - _, result = address.run_analysis(multi_return_cnt) + _, result = address.run(multi_return_cnt) print(result) assert isinstance(result, PartialTuple) assert isinstance(result.data[0], AddressReg) @@ -105,6 +110,7 @@ def multi_return_cnt(): assert isinstance(result.data[2], AddressReg) +@pytest.mark.xfail def test_list(): @qasm2.main.add(ilist) def list_count_analy(): @@ -119,6 +125,7 @@ def list_count_analy(): assert ret == AddressReg(data=(0, 1, 3)) +@pytest.mark.xfail def test_tuple_qubits(): @qasm2.main.add(py.tuple) def list_count_analy2(): @@ -159,6 +166,7 @@ def list_count_analy2(): # assert isinstance(result.data[5], AddressQubit) and result.data[5].data == 6 +@pytest.mark.xfail def test_alias(): @qasm2.main @@ -173,6 +181,6 @@ def test_alias(): test_alias.code.print() fold(test_alias) - _, ret = address.run_analysis(test_alias) + _, ret = address.run(test_alias) assert isinstance(ret, AddressQubit) assert ret.data == 0 diff --git a/test/qasm2/test_lowering.py b/test/qasm2/test_lowering.py index 617d7154..6eee1509 100644 --- a/test/qasm2/test_lowering.py +++ b/test/qasm2/test_lowering.py @@ -3,6 +3,7 @@ import tempfile import textwrap +import pytest from kirin import ir, types from kirin.dialects import func @@ -25,12 +26,14 @@ ) +@pytest.mark.xfail def test_run_lowering(): ast = qasm2.parse.loads(lines) code = QASM2(qasm2.main).run(ast) code.print() +@pytest.mark.xfail def test_loadfile(): with tempfile.TemporaryDirectory() as tmp_dir: @@ -41,6 +44,7 @@ def test_loadfile(): qasm2.loadfile(file) +@pytest.mark.xfail def test_negative_lowering(): mwe = """ @@ -80,6 +84,7 @@ def test_negative_lowering(): assert entry.code.is_structurally_equal(code) +@pytest.mark.xfail def test_gate(): qasm2_prog = textwrap.dedent( """ @@ -108,6 +113,7 @@ def test_gate(): assert math.isclose(abs(ket[3]) ** 2, 0.5, abs_tol=1e-6) +@pytest.mark.xfail def test_gate_with_params(): qasm2_prog = textwrap.dedent( """ @@ -138,6 +144,7 @@ def test_gate_with_params(): assert math.isclose(abs(ket[3]) ** 2, 0.5, abs_tol=1e-6) +@pytest.mark.xfail def test_if_lowering(): qasm2_prog = textwrap.dedent( diff --git a/test/qasm2/test_native.py b/test/qasm2/test_native.py index fdfaf64d..15e6ebcb 100644 --- a/test/qasm2/test_native.py +++ b/test/qasm2/test_native.py @@ -3,6 +3,7 @@ import cirq import numpy as np +import pytest import cirq.testing import cirq.contrib.qasm_import as qasm_import import cirq.circuits.qasm_output as qasm_output @@ -157,6 +158,7 @@ def kernel(): assert new_qasm2.count("\n") > prog.count("\n") +@pytest.mark.xfail def test_ccx_rewrite(): @qasm2.extended diff --git a/test/squin/test_typeinfer.py b/test/squin/test_typeinfer.py index 306f0d4b..2a4e5118 100644 --- a/test/squin/test_typeinfer.py +++ b/test/squin/test_typeinfer.py @@ -32,7 +32,7 @@ def test(): return q type_infer_analysis = TypeInference(dialects=test.dialects) - frame, _ = type_infer_analysis.run_analysis(test) + frame, _ = type_infer_analysis.run(test) assert [frame.entries[result] for result in results_at(test, 0, 1)] == [ IListType[QubitType, Literal(4)] @@ -48,7 +48,7 @@ def test(n: int): type_infer_analysis = TypeInference(dialects=test.dialects) - frame_ambiguous, _ = type_infer_analysis.run_analysis(test) + frame_ambiguous, _ = type_infer_analysis.run(test) assert [frame_ambiguous.entries[result] for result in results_at(test, 0, 0)] == [ IListType[QubitType, Any] @@ -67,7 +67,7 @@ def test(): return [q0, q1] type_infer_analysis = TypeInference(dialects=test.dialects) - frame, _ = type_infer_analysis.run_analysis(test) + frame, _ = type_infer_analysis.run(test) assert [frame.entries[result] for result in results_at(test, 0, 3)] == [QubitType] assert [frame.entries[result] for result in results_at(test, 0, 5)] == [QubitType] diff --git a/test/stim/dialects/stim/emit/base.py b/test/stim/dialects/stim/emit/base.py deleted file mode 100644 index a07f4456..00000000 --- a/test/stim/dialects/stim/emit/base.py +++ /dev/null @@ -1,12 +0,0 @@ -from kirin import ir - -from bloqade.stim.emit import EmitStimMain - -emit = EmitStimMain() - - -def codegen(mt: ir.Method): - # method should not have any arguments! - emit.initialize() - emit.run(mt=mt, args=()) - return emit.get_output() diff --git a/test/stim/dialects/stim/emit/test_stim_1q.py b/test/stim/dialects/stim/emit/test_stim_1q.py index 36b7b94e..0d08a70f 100644 --- a/test/stim/dialects/stim/emit/test_stim_1q.py +++ b/test/stim/dialects/stim/emit/test_stim_1q.py @@ -1,6 +1,7 @@ -from bloqade import stim +import io -from .base import codegen +from bloqade import stim +from bloqade.stim.emit import EmitStimMain def test_x(): @@ -9,18 +10,19 @@ def test_x(): def test_x(): stim.x(targets=(0, 1, 2, 3), dagger=False) - test_x.print() - out = codegen(test_x) - - assert out.strip() == "X 0 1 2 3" + buf = io.StringIO() + stim_emit = EmitStimMain(dialects=stim.main, io=buf) + stim_emit.run(test_x) + assert buf.getvalue().strip() == "X 0 1 2 3" @stim.main def test_x_dag(): stim.x(targets=(0, 1, 2, 3), dagger=True) - out = codegen(test_x_dag) - - assert out.strip() == "X 0 1 2 3" + buf = io.StringIO() + stim_emit = EmitStimMain(dialects=stim.main, io=buf) + stim_emit.run(test_x_dag) + assert buf.getvalue().strip() == "X 0 1 2 3" def test_y(): @@ -29,15 +31,16 @@ def test_y(): def test_y(): stim.y(targets=(0, 1, 2, 3), dagger=False) - test_y.print() - out = codegen(test_y) - - assert out.strip() == "Y 0 1 2 3" + buf = io.StringIO() + stim_emit = EmitStimMain(dialects=stim.main, io=buf) + stim_emit.run(test_y) + assert buf.getvalue().strip() == "Y 0 1 2 3" @stim.main def test_y_dag(): stim.y(targets=(0, 1, 2, 3), dagger=True) - out = codegen(test_y_dag) - - assert out.strip() == "Y 0 1 2 3" + buf = io.StringIO() + stim_emit = EmitStimMain(dialects=stim.main, io=buf) + stim_emit.run(test_y_dag) + assert buf.getvalue().strip() == "Y 0 1 2 3" diff --git a/test/stim/dialects/stim/emit/test_stim_ctrl.py b/test/stim/dialects/stim/emit/test_stim_ctrl.py index 5b2ac582..93ebe018 100644 --- a/test/stim/dialects/stim/emit/test_stim_ctrl.py +++ b/test/stim/dialects/stim/emit/test_stim_ctrl.py @@ -1,8 +1,9 @@ +import io + from bloqade import stim +from bloqade.stim.emit import EmitStimMain from bloqade.stim.dialects import gate, auxiliary -from .base import codegen - def test_cx(): @@ -10,8 +11,10 @@ def test_cx(): def test_simple_cx(): gate.CX(controls=(4, 5, 6, 7), targets=(0, 1, 2, 3), dagger=False) - out = codegen(test_simple_cx) - assert out.strip() == "CX 4 0 5 1 6 2 7 3" + buf = io.StringIO() + stim_emit = EmitStimMain(dialects=stim.main, io=buf) + stim_emit.run(test_simple_cx) + assert buf.getvalue().strip() == "CX 4 0 5 1 6 2 7 3" def test_cx_cond_on_measure(): @@ -24,6 +27,7 @@ def test_simple_cx_cond_measure(): dagger=False, ) - out = codegen(test_simple_cx_cond_measure) - - assert out.strip() == "CX rec[-1] 0 4 1 rec[-2] 2" + buf = io.StringIO() + stim_emit = EmitStimMain(dialects=stim.main, io=buf) + stim_emit.run(test_simple_cx_cond_measure) + assert buf.getvalue().strip() == "CX rec[-1] 0 4 1 rec[-2] 2" diff --git a/test/stim/dialects/stim/emit/test_stim_debug.py b/test/stim/dialects/stim/emit/test_stim_debug.py index b056e540..c09819a0 100644 --- a/test/stim/dialects/stim/emit/test_stim_debug.py +++ b/test/stim/dialects/stim/emit/test_stim_debug.py @@ -1,8 +1,9 @@ +import io + from kirin.dialects import debug from bloqade import stim - -from .base import codegen +from bloqade.stim.emit import EmitStimMain def test_debug(): @@ -12,5 +13,8 @@ def test_debug_main(): debug.info("debug message") test_debug_main.print() - out = codegen(test_debug_main) - assert out.strip() == "# debug message" + + buf = io.StringIO() + stim_emit: EmitStimMain[io.StringIO] = EmitStimMain(dialects=stim.main, io=buf) + stim_emit.run(test_debug_main) + assert buf.getvalue().strip() == "# debug message" diff --git a/test/stim/dialects/stim/emit/test_stim_detector.py b/test/stim/dialects/stim/emit/test_stim_detector.py index e12c22ae..013f1adc 100644 --- a/test/stim/dialects/stim/emit/test_stim_detector.py +++ b/test/stim/dialects/stim/emit/test_stim_detector.py @@ -1,6 +1,7 @@ -from bloqade import stim +import io -from .base import codegen +from bloqade import stim +from bloqade.stim.emit import EmitStimMain def test_detector(): @@ -9,9 +10,7 @@ def test_detector(): def test_simple_cx(): stim.detector(coord=(1, 2, 3), targets=(stim.rec(-3), stim.rec(-1))) - out = codegen(test_simple_cx) - - assert out.strip() == "DETECTOR(1, 2, 3) rec[-3] rec[-1]" - - -test_detector() + buf = io.StringIO() + stim_emit = EmitStimMain(dialects=stim.main, io=buf) + stim_emit.run(test_simple_cx) + assert buf.getvalue().strip() == "DETECTOR(1, 2, 3) rec[-3] rec[-1]" diff --git a/test/stim/dialects/stim/emit/test_stim_meas.py b/test/stim/dialects/stim/emit/test_stim_meas.py index 54cf8e06..1d157581 100644 --- a/test/stim/dialects/stim/emit/test_stim_meas.py +++ b/test/stim/dialects/stim/emit/test_stim_meas.py @@ -1,8 +1,9 @@ +import io + from bloqade import stim +from bloqade.stim.emit import EmitStimMain from bloqade.stim.dialects import collapse -from .base import codegen - def test_meas(): @@ -10,6 +11,7 @@ def test_meas(): def test_simple_meas(): collapse.MX(p=0.3, targets=(0, 3, 4, 5)) - out = codegen(test_simple_meas) - - assert out.strip() == "MX(0.30000000) 0 3 4 5" + buf = io.StringIO() + stim_emit = EmitStimMain(dialects=stim.main, io=buf) + stim_emit.run(test_simple_meas) + assert buf.getvalue().strip() == "MX(0.30000000) 0 3 4 5" diff --git a/test/stim/dialects/stim/emit/test_stim_noise.py b/test/stim/dialects/stim/emit/test_stim_noise.py index 773bc5c1..44b2f41c 100644 --- a/test/stim/dialects/stim/emit/test_stim_noise.py +++ b/test/stim/dialects/stim/emit/test_stim_noise.py @@ -1,16 +1,18 @@ +import io + from bloqade import stim from bloqade.stim.emit import EmitStimMain from bloqade.stim.parse import loads from bloqade.stim.dialects import noise -emit = EmitStimMain() - def codegen(mt): # method should not have any arguments! + buf = io.StringIO() + emit = EmitStimMain(dialects=stim.main, io=buf) emit.initialize() - emit.run(mt=mt, args=()) - return emit.get_output() + emit.run(mt) + return buf.getvalue().strip() def test_noise(): @@ -36,9 +38,8 @@ def test_pauli2(): targets=(0, 3, 4, 5), ) - out = codegen(test_pauli2) assert ( - out.strip() + codegen(test_pauli2) == "PAULI_CHANNEL_2(0.10000000, 0.10000000, 0.10000000, 0.10000000, 0.10000000, 0.10000000, 0.10000000, 0.10000000, 0.10000000, 0.10000000, 0.10000000, 0.10000000, 0.10000000, 0.10000000, 0.10000000) 0 3 4 5" ) @@ -48,8 +49,7 @@ def test_qubit_loss(): def test_qubit_loss(): stim.qubit_loss(probs=(0.1,), targets=(0, 1, 2)) - out = codegen(test_qubit_loss) - assert out.strip() == "I_ERROR[loss](0.10000000) 0 1 2" + assert codegen(test_qubit_loss) == "I_ERROR[loss](0.10000000) 0 1 2" def test_correlated_qubit_loss(): @@ -57,8 +57,10 @@ def test_correlated_qubit_loss(): def test_correlated_qubit_loss(): stim.correlated_qubit_loss(probs=(0.1,), targets=(0, 1, 2)) - out = codegen(test_correlated_qubit_loss) - assert out.strip() == "I_ERROR[correlated_loss:0](0.10000000) 0 1 2" + assert ( + codegen(test_correlated_qubit_loss) + == "I_ERROR[correlated_loss:0](0.10000000) 0 1 2" + ) def test_correlated_qubit_loss_multiple(): diff --git a/test/stim/dialects/stim/emit/test_stim_obs_inc.py b/test/stim/dialects/stim/emit/test_stim_obs_inc.py index 5ffca3f5..773c95ac 100644 --- a/test/stim/dialects/stim/emit/test_stim_obs_inc.py +++ b/test/stim/dialects/stim/emit/test_stim_obs_inc.py @@ -1,8 +1,9 @@ +import io + from bloqade import stim +from bloqade.stim.emit import EmitStimMain from bloqade.stim.dialects import auxiliary -from .base import codegen - def test_obs_inc(): @@ -12,6 +13,7 @@ def test_simple_obs_inc(): idx=3, targets=(auxiliary.GetRecord(-3), auxiliary.GetRecord(-1)) ) - out = codegen(test_simple_obs_inc) - - assert out.strip() == "OBSERVABLE_INCLUDE(3) rec[-3] rec[-1]" + buf = io.StringIO() + stim_emit = EmitStimMain(dialects=stim.main, io=buf) + stim_emit.run(test_simple_obs_inc) + assert buf.getvalue().strip() == "OBSERVABLE_INCLUDE(3) rec[-3] rec[-1]" diff --git a/test/stim/dialects/stim/emit/test_stim_ppmeas.py b/test/stim/dialects/stim/emit/test_stim_ppmeas.py index b32d0815..51b03254 100644 --- a/test/stim/dialects/stim/emit/test_stim_ppmeas.py +++ b/test/stim/dialects/stim/emit/test_stim_ppmeas.py @@ -1,6 +1,7 @@ -from bloqade import stim +import io -from .base import codegen +from bloqade import stim +from bloqade.stim.emit import EmitStimMain def test_mpp(): @@ -23,10 +24,7 @@ def test_mpp_main(): p=0.3, ) - test_mpp_main.print() - out = codegen(test_mpp_main) - - assert out.strip() == "MPP(0.30000000) !X0*X1*Z2 Y3*X4*!Y5" - - -test_mpp() + buf = io.StringIO() + stim_emit = EmitStimMain(dialects=stim.main, io=buf) + stim_emit.run(test_mpp_main) + assert buf.getvalue().strip() == "MPP(0.30000000) !X0*X1*Z2 Y3*X4*!Y5" diff --git a/test/stim/dialects/stim/emit/test_stim_qubit_coords.py b/test/stim/dialects/stim/emit/test_stim_qubit_coords.py index 393fc9b4..43874e04 100644 --- a/test/stim/dialects/stim/emit/test_stim_qubit_coords.py +++ b/test/stim/dialects/stim/emit/test_stim_qubit_coords.py @@ -1,8 +1,9 @@ +import io + from bloqade import stim +from bloqade.stim.emit import EmitStimMain from bloqade.stim.dialects import auxiliary -from .base import codegen - def test_qcoords(): @@ -10,6 +11,7 @@ def test_qcoords(): def test_simple_qcoords(): auxiliary.QubitCoordinates(coord=(0.1, 0.2), target=3) - out = codegen(test_simple_qcoords) - - assert out.strip() == "QUBIT_COORDS(0.10000000, 0.20000000) 3" + buf = io.StringIO() + stim_emit = EmitStimMain(dialects=stim.main, io=buf) + stim_emit.run(test_simple_qcoords) + assert buf.getvalue().strip() == "QUBIT_COORDS(0.10000000, 0.20000000) 3" diff --git a/test/stim/dialects/stim/emit/test_stim_spp.py b/test/stim/dialects/stim/emit/test_stim_spp.py index d0cf7b9e..44c0a99b 100644 --- a/test/stim/dialects/stim/emit/test_stim_spp.py +++ b/test/stim/dialects/stim/emit/test_stim_spp.py @@ -1,6 +1,7 @@ -from bloqade import stim +import io -from .base import codegen +from bloqade import stim +from bloqade.stim.emit import EmitStimMain def test_spp(): @@ -23,9 +24,7 @@ def test_spp_main(): dagger=False, ) - test_spp_main.print() - out = codegen(test_spp_main) - assert out.strip() == "SPP !X0*X1*Z2 Y3*X4*!Y5" - - -test_spp() + buf = io.StringIO() + stim_emit = EmitStimMain(dialects=stim.main, io=buf) + stim_emit.run(test_spp_main) + assert buf.getvalue().strip() == "SPP !X0*X1*Z2 Y3*X4*!Y5" diff --git a/test/stim/dialects/stim/test_stim_circuits.py b/test/stim/dialects/stim/test_stim_circuits.py index 84caf6ab..8b18431c 100644 --- a/test/stim/dialects/stim/test_stim_circuits.py +++ b/test/stim/dialects/stim/test_stim_circuits.py @@ -1,9 +1,11 @@ import re +from io import StringIO from bloqade import stim from bloqade.stim.emit import EmitStimMain -interp = EmitStimMain(stim.main) +buf = StringIO() +interp = EmitStimMain(stim.main, io=buf) def test_gates(): @@ -17,15 +19,23 @@ def test_single_qubit_gates(): stim.s(targets=(0, 1, 2), dagger=False) stim.s(targets=(0, 1, 2), dagger=True) - interp.run(test_single_qubit_gates, args=()) - print(interp.get_output()) + interp.run(test_single_qubit_gates) + expected = """SQRT_Z 0 1 2 +X 0 1 2 +Y 0 1 +Z 1 2 +H 0 1 2 +S 0 1 2 +S_DAG 0 1 2""" + assert buf.getvalue().strip() == expected @stim.main def test_two_qubit_gates(): stim.swap(targets=(2, 3)) - interp.run(test_two_qubit_gates, args=()) - print(interp.get_output()) + interp.run(test_two_qubit_gates) + expected = "SWAP 2 3" + assert buf.getvalue().strip() == expected @stim.main def test_controlled_two_qubit_gates(): @@ -33,8 +43,11 @@ def test_controlled_two_qubit_gates(): stim.cy(controls=(0, 1), targets=(2, 3), dagger=True) stim.cz(controls=(0, 1), targets=(2, 3)) - interp.run(test_controlled_two_qubit_gates, args=()) - print(interp.get_output()) + interp.run(test_controlled_two_qubit_gates) + expected = """CX 0 2 1 3 +CY 0 2 1 3 +CZ 0 2 1 3""" + assert buf.getvalue().strip() == expected # @stim.main # def test_spp(): @@ -45,14 +58,19 @@ def test_controlled_two_qubit_gates(): # print(interp.get_output()) +test_gates() + + def test_noise(): @stim.main def test_depolarize(): stim.depolarize1(p=0.1, targets=(0, 1, 2)) stim.depolarize2(p=0.1, targets=(0, 1)) - interp.run(test_depolarize, args=()) - print(interp.get_output()) + interp.run(test_depolarize) + expected = """DEPOLARIZE1(0.10000000) 0 1 2 +DEPOLARIZE2(0.10000000) 0 1""" + assert buf.getvalue().strip() == expected @stim.main def test_pauli_channel(): @@ -76,8 +94,10 @@ def test_pauli_channel(): targets=(0, 1, 2, 3), ) - interp.run(test_pauli_channel, args=()) - print(interp.get_output()) + interp.run(test_pauli_channel) + expected = """PAULI_CHANNEL_1(0.01000000, 0.01000000, 0.10000000) 0 1 2 +PAULI_CHANNEL_2(0.01000000, 0.01000000, 0.10000000, 0.01000000, 0.01000000, 0.01000000, 0.10000000, 0.01000000, 0.01000000, 0.01000000, 0.10000000, 0.10000000, 0.10000000, 0.10000000, 0.20000000) 0 1 2 3""" + assert buf.getvalue().strip() == expected @stim.main def test_pauli_error(): @@ -85,15 +105,19 @@ def test_pauli_error(): stim.y_error(p=0.1, targets=(0, 1)) stim.z_error(p=0.1, targets=(1, 2)) - interp.run(test_pauli_error, args=()) - print(interp.get_output()) + interp.run(test_pauli_error) + expected = """X_ERROR(0.10000000) 0 1 2 +Y_ERROR(0.10000000) 0 1 +Z_ERROR(0.10000000) 1 2""" + assert buf.getvalue().strip() == expected @stim.main def test_qubit_loss(): stim.qubit_loss(probs=(0.1, 0.2), targets=(0, 1, 2)) - interp.run(test_qubit_loss, args=()) - assert interp.get_output() == "\nI_ERROR[loss](0.10000000, 0.20000000) 0 1 2" + interp.run(test_qubit_loss) + expected = "I_ERROR[loss](0.10000000, 0.20000000) 0 1 2" + assert buf.getvalue().strip() == expected def test_correlated_qubit_loss(): @@ -102,10 +126,9 @@ def test_correlated_qubit_loss(): def test_correlated_qubit_loss(): stim.correlated_qubit_loss(probs=(0.1,), targets=(0, 3, 1)) - interp.run(test_correlated_qubit_loss, args=()) - + interp.run(test_correlated_qubit_loss) assert re.match( - r"\nI_ERROR\[correlated_loss:\d+\]\(0\.10000000\) 0 3 1", interp.get_output() + r"I_ERROR\[correlated_loss:\d+\]\(0\.10000000\) 0 3 1", buf.getvalue().strip() ) @@ -119,8 +142,14 @@ def test_measure(): stim.myy(p=0.04, targets=(0, 1)) stim.mxx(p=0.05, targets=(1, 2)) - interp.run(test_measure, args=()) - print(interp.get_output()) + interp.run(test_measure) + expected = """MX(0.00000000) 0 1 2 +MY(0.01000000) 0 1 +MZ(0.02000000) 1 2 +MZZ(0.03000000) 0 1 2 3 +MYY(0.04000000) 0 1 +MXX(0.05000000) 1 2""" + assert buf.getvalue().strip() == expected @stim.main def test_reset(): @@ -128,8 +157,11 @@ def test_reset(): stim.ry(targets=(0, 1)) stim.rz(targets=(1, 2)) - interp.run(test_reset, args=()) - print(interp.get_output()) + interp.run(test_reset) + expected = """RX 0 1 2 +RY 0 1 +RZ 1 2""" + assert buf.getvalue().strip() == expected def test_repetition(): @@ -160,5 +192,29 @@ def test_repetition_memory(): stim.detector(coord=(3, 2), targets=(stim.rec(-1), stim.rec(-2), stim.rec(-4))) stim.observable_include(idx=0, targets=(stim.rec(-1),)) - interp.run(test_repetition_memory, args=()) - print(interp.get_output()) + interp.run(test_repetition_memory) + expected = """RZ 0 1 2 3 4 +TICK +DEPOLARIZE1(0.10000000) 0 2 4 +CX 0 1 2 3 +TICK +CX 2 1 4 3 +TICK +MZ(0.10000000) 1 3 +DETECTOR(1, 0) rec[-2] +DETECTOR(3, 0) rec[-1] +RZ 1 3 +TICK +DEPOLARIZE1(0.10000000) 0 2 4 +CX 0 1 2 3 +TICK +CX 2 1 4 3 +TICK +MZ(0.10000000) 1 3 +DETECTOR(1, 1) rec[-2] rec[-4] +DETECTOR(3, 1) rec[-1] rec[-3] +MZ(0.10000000) 0 2 4 +DETECTOR(1, 2) rec[-2] rec[-3] rec[-5] +DETECTOR(3, 2) rec[-1] rec[-2] rec[-4] +OBSERVABLE_INCLUDE(0) rec[-1]""" + assert buf.getvalue().strip() == expected diff --git a/test/stim/parse/base.py b/test/stim/parse/base.py index a07f4456..8ef8974d 100644 --- a/test/stim/parse/base.py +++ b/test/stim/parse/base.py @@ -1,12 +1,16 @@ +from io import StringIO + from kirin import ir +from bloqade import stim from bloqade.stim.emit import EmitStimMain -emit = EmitStimMain() +buf = StringIO() +emit = EmitStimMain(stim.main, io=buf) def codegen(mt: ir.Method): # method should not have any arguments! emit.initialize() - emit.run(mt=mt, args=()) - return emit.get_output() + emit.run(node=mt) + return buf.getvalue().strip() diff --git a/test/stim/passes/test_squin_debug_to_stim.py b/test/stim/passes/test_squin_debug_to_stim.py index 7319bc47..26aa4613 100644 --- a/test/stim/passes/test_squin_debug_to_stim.py +++ b/test/stim/passes/test_squin_debug_to_stim.py @@ -1,8 +1,10 @@ +import io import os from kirin import ir from kirin.dialects import py, debug +from bloqade import stim from bloqade.squin import kernel from bloqade.stim.emit import EmitStimMain from bloqade.stim.passes import SquinToStimPass @@ -11,10 +13,11 @@ # Taken gratuitously from Kai's unit test def codegen(mt: ir.Method): # method should not have any arguments! - emit = EmitStimMain() + buf = io.StringIO() + emit = EmitStimMain(dialects=stim.main, io=buf) emit.initialize() - emit.run(mt=mt, args=()) - return emit.get_output() + emit.run(mt) + return buf.getvalue().strip() def as_int(value: int): diff --git a/test/stim/passes/test_squin_meas_to_stim.py b/test/stim/passes/test_squin_meas_to_stim.py index d52555e1..dd323051 100644 --- a/test/stim/passes/test_squin_meas_to_stim.py +++ b/test/stim/passes/test_squin_meas_to_stim.py @@ -1,9 +1,10 @@ +import io import os from kirin import ir from kirin.dialects.ilist import IList -from bloqade import squin as sq +from bloqade import stim, squin as sq from bloqade.types import MeasurementResult from bloqade.stim.emit import EmitStimMain from bloqade.stim.passes import SquinToStimPass @@ -11,10 +12,11 @@ def codegen(mt: ir.Method): # method should not have any arguments! - emit = EmitStimMain() + buf = io.StringIO() + emit = EmitStimMain(dialects=stim.main, io=buf) emit.initialize() - emit.run(mt=mt, args=()) - return emit.get_output().strip() + emit.run(mt) + return buf.getvalue().strip() def load_reference_program(filename): diff --git a/test/stim/passes/test_squin_noise_to_stim.py b/test/stim/passes/test_squin_noise_to_stim.py index 985178d5..dc3346d5 100644 --- a/test/stim/passes/test_squin_noise_to_stim.py +++ b/test/stim/passes/test_squin_noise_to_stim.py @@ -1,3 +1,4 @@ +import io import os import kirin.types as kirin_types @@ -6,7 +7,7 @@ from kirin.rewrite import Walk from kirin.dialects import ilist -from bloqade import squin as sq +from bloqade import stim, squin as sq from bloqade.squin import noise, kernel from bloqade.types import Qubit, QubitType from bloqade.stim.emit import EmitStimMain @@ -18,10 +19,11 @@ def codegen(mt: ir.Method): # method should not have any arguments! - emit = EmitStimMain() + buf = io.StringIO() + emit = EmitStimMain(dialects=stim.main, io=buf) emit.initialize() - emit.run(mt=mt, args=()) - return emit.get_output().strip() + emit.run(mt) + return buf.getvalue().strip() def load_reference_program(filename): @@ -298,7 +300,7 @@ def test(): NonExistentNoiseChannel(qubits=q) return - frame, _ = AddressAnalysis(test.dialects).run_analysis(test) + frame, _ = AddressAnalysis(test.dialects).run(test) WrapAddressAnalysis(address_analysis=frame.entries).rewrite(test.code) rewrite_result = Walk(SquinNoiseToStim()).rewrite(test.code) @@ -319,7 +321,7 @@ def test(): sq.x(qubit=q[0]) return - frame, _ = AddressAnalysis(test.dialects).run_analysis(test) + frame, _ = AddressAnalysis(test.dialects).run(test) WrapAddressAnalysis(address_analysis=frame.entries).rewrite(test.code) rewrite_result = Walk(SquinNoiseToStim()).rewrite(test.code) diff --git a/test/stim/passes/test_squin_qubit_to_stim.py b/test/stim/passes/test_squin_qubit_to_stim.py index 84a90928..63d644a3 100644 --- a/test/stim/passes/test_squin_qubit_to_stim.py +++ b/test/stim/passes/test_squin_qubit_to_stim.py @@ -1,3 +1,4 @@ +import io import os import math from math import pi @@ -5,7 +6,7 @@ from kirin import ir from kirin.dialects import py -from bloqade import qubit, squin as sq +from bloqade import stim, qubit, squin as sq from bloqade.squin import kernel from bloqade.stim.emit import EmitStimMain from bloqade.stim.passes import SquinToStimPass @@ -15,10 +16,11 @@ # Taken gratuitously from Kai's unit test def codegen(mt: ir.Method): # method should not have any arguments! - emit = EmitStimMain() + buf = io.StringIO() + emit = EmitStimMain(dialects=stim.main, io=buf) emit.initialize() - emit.run(mt=mt, args=()) - return emit.get_output() + emit.run(mt) + return buf.getvalue().strip() def as_int(value: int): diff --git a/test/stim/test_measure_id_analysis.py b/test/stim/test_measure_id_analysis.py index 3f7a79da..b39461c6 100644 --- a/test/stim/test_measure_id_analysis.py +++ b/test/stim/test_measure_id_analysis.py @@ -1,3 +1,5 @@ +import pytest + from bloqade import qubit from bloqade.squin import kernel, qalloc from bloqade.analysis.measure_id import MeasurementIDAnalysis @@ -17,12 +19,13 @@ def main(): res = (meas_res[0], meas_res[1], meas_res[2]) return res - main.print() + # main.print() - frame, _ = MeasurementIDAnalysis(kernel).run_analysis(main) - main.print(analysis=frame.entries) + frame, _ = MeasurementIDAnalysis(kernel).run(main) + # main.print(analysis=frame.entries) +@pytest.mark.xfail def test_scf_measure_analysis(): @kernel def main(): @@ -40,5 +43,5 @@ def main(): main.print() - frame, _ = MeasurementIDAnalysis(kernel).run_analysis(main) + frame, _ = MeasurementIDAnalysis(kernel).run(main) main.print(analysis=frame.entries) From 8f35f99a76109c1eed01a5abcefc2c35749443e4 Mon Sep 17 00:00:00 2001 From: David Plankensteiner Date: Fri, 7 Nov 2025 10:00:55 +0100 Subject: [PATCH 4/6] Update remaining code base to new kirin version (#600) CI is blocked by: * Requires kirin main * This kirin PR: https://github.com/QuEraComputing/kirin/pull/563 * This kirin issue: https://github.com/QuEraComputing/kirin/issues/564 Other than that, we should be good. I'm still targeting the kirin upgrade branch to make review easier, but this once the above issues are resolved, this can actually go into `main`. @weinbe58 please have a look at the address analysis, specifically at the `run_lattice` method. I had to change the signature a bit. --------- Co-authored-by: kaihsin --- src/bloqade/analysis/address/analysis.py | 36 +++++++-------- src/bloqade/analysis/address/impls.py | 46 +++++++++---------- src/bloqade/analysis/fidelity/analysis.py | 25 ++-------- src/bloqade/analysis/measure_id/analysis.py | 16 +++---- src/bloqade/analysis/measure_id/impls.py | 9 ++-- src/bloqade/cirq_utils/emit/base.py | 33 +------------ src/bloqade/cirq_utils/lowering.py | 4 +- .../analysis/logical_validation/analysis.py | 6 +-- src/bloqade/native/upstream/squin2native.py | 16 ++++--- src/bloqade/pyqrack/target.py | 3 +- src/bloqade/pyqrack/task.py | 12 ++--- src/bloqade/qasm2/_qasm_loading.py | 6 +-- src/bloqade/qasm2/dialects/expr/stmts.py | 14 +++--- src/bloqade/qasm2/dialects/noise/fidelity.py | 4 +- src/bloqade/qasm2/emit/base.py | 5 -- src/bloqade/qasm2/emit/target.py | 12 ++--- src/bloqade/qasm2/parse/lowering.py | 1 - src/bloqade/qasm2/rewrite/uop_to_parallel.py | 2 +- src/bloqade/qbraid/lowering.py | 1 + src/bloqade/qbraid/schema.py | 4 +- src/bloqade/squin/analysis/schedule.py | 15 +++--- src/bloqade/validation/analysis/analysis.py | 14 +++--- src/bloqade/validation/kernel_validation.py | 18 +++++++- test/analysis/address/test_qubit_analysis.py | 25 +++++----- test/analysis/fidelity/test_fidelity.py | 16 +++---- test/cirq_utils/test_cirq_to_squin.py | 4 +- test/gemini/test_logical_validation.py | 10 ++-- test/pyqrack/runtime/noise/qasm2/test_loss.py | 8 ++-- .../pyqrack/runtime/noise/qasm2/test_pauli.py | 2 +- test/pyqrack/runtime/test_qrack.py | 2 +- test/qasm2/emit/test_qasm2_emit.py | 5 -- test/qasm2/passes/test_global_to_parallel.py | 3 -- test/qasm2/passes/test_global_to_uop.py | 2 - test/qasm2/passes/test_heuristic_noise.py | 2 - test/qasm2/passes/test_parallel_to_global.py | 7 --- test/qasm2/passes/test_parallel_to_uop.py | 2 - test/qasm2/passes/test_uop_to_parallel.py | 5 -- test/qasm2/test_count.py | 20 +++----- test/qasm2/test_lowering.py | 7 --- test/qasm2/test_native.py | 2 - test/qbraid/test_lowering.py | 21 ++++++--- test/squin/test_qubit.py | 6 +-- .../stim_reference_programs/debug/debug.stim | 1 - .../qubit/for_loop_nontrivial_index.stim | 1 - .../qubit/nested_for_loop.stim | 1 - .../qubit/nested_list.stim | 1 - .../qubit/non_pure_loop_iterator.stim | 1 - .../qubit/pick_if_else.stim | 1 - .../stim_reference_programs/qubit/qubit.stim | 1 - .../qubit/qubit_broadcast.stim | 1 - .../qubit/qubit_loss.stim | 1 - .../qubit/qubit_reset.stim | 1 - .../qubit/rep_code.stim | 1 - .../qubit/u3_gates.stim | 1 - .../qubit/u3_to_clifford.stim | 1 - test/stim/passes/test_squin_noise_to_stim.py | 7 +-- 56 files changed, 188 insertions(+), 283 deletions(-) diff --git a/src/bloqade/analysis/address/analysis.py b/src/bloqade/analysis/address/analysis.py index 1f78cfdb..921f4b3b 100644 --- a/src/bloqade/analysis/address/analysis.py +++ b/src/bloqade/analysis/address/analysis.py @@ -15,7 +15,7 @@ class AddressAnalysis(Forward[Address]): This analysis pass can be used to track the global addresses of qubits and wires. """ - keys = ["qubit.address"] + keys = ("qubit.address",) _const_prop: const.Propagate lattice = Address next_address: int = field(init=False) @@ -45,7 +45,7 @@ def try_eval_const_prop( ) -> interp.StatementResult[Address]: _frame = self._const_prop.initialize_frame(frame.code) _frame.set_values(stmt.args, tuple(x.result for x in args)) - result = self._const_prop.eval_stmt(_frame, stmt) + result = self._const_prop.frame_eval(_frame, stmt) match result: case interp.ReturnValue(constant_ret): @@ -96,7 +96,8 @@ def run_lattice( self, callee: Address, inputs: tuple[Address, ...], - kwargs: tuple[str, ...], + keys: tuple[str, ...], + kwargs: tuple[Address, ...], ) -> Address: """Run a callable lattice element with the given inputs and keyword arguments. @@ -111,15 +112,16 @@ def run_lattice( """ match callee: - case PartialLambda(code=code, argnames=argnames): - _, ret = self.run_callable( - code, (callee,) + self.permute_values(argnames, inputs, kwargs) + case PartialLambda(code=code): + _, ret = self.call( + code, callee, *inputs, **{k: v for k, v in zip(keys, kwargs)} ) - return ret case ConstResult(const.Value(ir.Method() as method)): - _, ret = self.run_method( - method, - self.permute_values(method.arg_names, inputs, kwargs), + _, ret = self.call( + method.code, + self.method_self(method), + *inputs, + **{k: v for k, v in zip(keys, kwargs)}, ) return ret case _: @@ -137,14 +139,12 @@ def get_const_value(self, addr: Address, typ: Type[T]) -> T | None: return value - def eval_stmt_fallback(self, frame: ForwardFrame[Address], stmt: ir.Statement): - args = frame.get_values(stmt.args) + def eval_fallback(self, frame: ForwardFrame[Address], node: ir.Statement): + args = frame.get_values(node.args) if types.is_tuple_of(args, ConstResult): - return self.try_eval_const_prop(frame, stmt, args) + return self.try_eval_const_prop(frame, node, args) - return tuple(Address.from_type(result.type) for result in stmt.results) + return tuple(Address.from_type(result.type) for result in node.results) - def run_method(self, method: ir.Method, args: tuple[Address, ...]): - # NOTE: we do not support dynamic calls here, thus no need to propagate method object - self_mt = ConstResult(const.Value(method)) - return self.run_callable(method.code, (self_mt,) + args) + def method_self(self, method: ir.Method) -> Address: + return ConstResult(const.Value(method)) diff --git a/src/bloqade/analysis/address/impls.py b/src/bloqade/analysis/address/impls.py index 1a89bb3e..d7932986 100644 --- a/src/bloqade/analysis/address/impls.py +++ b/src/bloqade/analysis/address/impls.py @@ -97,7 +97,7 @@ def map_( results = [] for ele in values: - ret = interp_.run_lattice(fn, (ele,), ()) + ret = interp_.run_lattice(fn, (ele,), (), ()) results.append(ret) if isinstance(stmt, ilist.Map): @@ -180,13 +180,10 @@ def invoke( frame: ForwardFrame[Address], stmt: func.Invoke, ): - - args = interp_.permute_values( - stmt.callee.arg_names, frame.get_values(stmt.inputs), stmt.kwargs - ) - _, ret = interp_.run_method( - stmt.callee, - args, + _, ret = interp_.call( + stmt.callee.code, + interp_.method_self(stmt.callee), + *frame.get_values(stmt.inputs), ) return (ret,) @@ -219,7 +216,8 @@ def call( result = interp_.run_lattice( frame.get(stmt.callee), frame.get_values(stmt.inputs), - stmt.kwargs, + stmt.keys, + frame.get_values(stmt.kwargs), ) return (result,) @@ -319,26 +317,28 @@ def ifelse( ): body = stmt.then_body if const_cond.data else stmt.else_body with interp_.new_frame(stmt, has_parent_access=True) as body_frame: - ret = interp_.run_ssacfg_region(body_frame, body, (address_cond,)) + ret = interp_.frame_call_region(body_frame, stmt, body, address_cond) # interp_.set_values(frame, body_frame.entries.keys(), body_frame.entries.values()) return ret else: # run both branches with interp_.new_frame(stmt, has_parent_access=True) as then_frame: - then_results = interp_.run_ssacfg_region( - then_frame, stmt.then_body, (address_cond,) - ) - interp_.set_values( - frame, then_frame.entries.keys(), then_frame.entries.values() + then_results = interp_.frame_call_region( + then_frame, + stmt, + stmt.then_body, + address_cond, ) + frame.set_values(then_frame.entries.keys(), then_frame.entries.values()) with interp_.new_frame(stmt, has_parent_access=True) as else_frame: - else_results = interp_.run_ssacfg_region( - else_frame, stmt.else_body, (address_cond,) - ) - interp_.set_values( - frame, else_frame.entries.keys(), else_frame.entries.values() + else_results = interp_.frame_call_region( + else_frame, + stmt, + stmt.else_body, + address_cond, ) + frame.set_values(else_frame.entries.keys(), else_frame.entries.values()) # TODO: pick the non-return value if isinstance(then_results, interp.ReturnValue) and isinstance( else_results, interp.ReturnValue @@ -364,12 +364,12 @@ def for_loop( iter_type, iterable = interp_.unpack_iterable(frame.get(stmt.iterable)) if iter_type is None: - return interp_.eval_stmt_fallback(frame, stmt) + return interp_.eval_fallback(frame, stmt) for value in iterable: with interp_.new_frame(stmt, has_parent_access=True) as body_frame: - loop_vars = interp_.run_ssacfg_region( - body_frame, stmt.body, (value,) + loop_vars + loop_vars = interp_.frame_call_region( + body_frame, stmt, stmt.body, value, *loop_vars ) if loop_vars is None: diff --git a/src/bloqade/analysis/fidelity/analysis.py b/src/bloqade/analysis/fidelity/analysis.py index f1ad252f..815b5725 100644 --- a/src/bloqade/analysis/fidelity/analysis.py +++ b/src/bloqade/analysis/fidelity/analysis.py @@ -4,7 +4,6 @@ from kirin import ir from kirin.lattice import EmptyLattice from kirin.analysis import Forward -from kirin.interp.value import Successor from kirin.analysis.forward import ForwardFrame from ..address import Address, AddressAnalysis @@ -48,15 +47,11 @@ def main(): The fidelity of the gate set described by the analysed program. It reduces whenever a noise channel is encountered. """ - _current_gate_fidelity: float = field(init=False) - atom_survival_probability: list[float] = field(init=False) """ The probabilities that each of the atoms in the register survive the duration of the analysed program. The order of the list follows the order they are in the register. """ - _current_atom_survival_probability: list[float] = field(init=False) - addr_frame: ForwardFrame[Address] = field(init=False) def initialize(self): @@ -67,25 +62,15 @@ def initialize(self): ] return self - def posthook_succ(self, frame: ForwardFrame, succ: Successor): - self.gate_fidelity *= self._current_gate_fidelity - for i, _current_survival in enumerate(self._current_atom_survival_probability): - self.atom_survival_probability[i] *= _current_survival - - def eval_stmt_fallback(self, frame: ForwardFrame, stmt: ir.Statement): + def eval_fallback(self, frame: ForwardFrame, node: ir.Statement): # NOTE: default is to conserve fidelity, so do nothing here return - def run_method(self, method: ir.Method, args: tuple[EmptyLattice, ...]): - return self.run_callable(method.code, (self.lattice.bottom(),) + args) - - def run_analysis( - self, method: ir.Method, args: tuple | None = None, *, no_raise: bool = True - ) -> tuple[ForwardFrame, Any]: - self._run_address_analysis(method, no_raise=no_raise) - return super().run(method) + def run(self, method: ir.Method, *args, **kwargs) -> tuple[ForwardFrame, Any]: + self._run_address_analysis(method) + return super().run(method, *args, **kwargs) - def _run_address_analysis(self, method: ir.Method, no_raise: bool): + def _run_address_analysis(self, method: ir.Method): addr_analysis = AddressAnalysis(self.dialects) addr_frame, _ = addr_analysis.run(method=method) self.addr_frame = addr_frame diff --git a/src/bloqade/analysis/measure_id/analysis.py b/src/bloqade/analysis/measure_id/analysis.py index f2d5e9f3..8b65b2f3 100644 --- a/src/bloqade/analysis/measure_id/analysis.py +++ b/src/bloqade/analysis/measure_id/analysis.py @@ -22,20 +22,16 @@ class MeasurementIDAnalysis(ForwardExtra[MeasureIDFrame, MeasureId]): measure_count = 0 def initialize_frame( - self, code: ir.Statement, *, has_parent_access: bool = False + self, node: ir.Statement, *, has_parent_access: bool = False ) -> MeasureIDFrame: - return MeasureIDFrame(code, has_parent_access=has_parent_access) + return MeasureIDFrame(node, has_parent_access=has_parent_access) # Still default to bottom, # but let constants return the softer "NoMeasureId" type from impl - def eval_stmt_fallback( - self, frame: ForwardFrame[MeasureId], stmt: ir.Statement + def eval_fallback( + self, frame: ForwardFrame[MeasureId], node: ir.Statement ) -> tuple[MeasureId, ...]: - return tuple(NotMeasureId() for _ in stmt.results) - - def run_method(self, method: ir.Method, args: tuple[MeasureId, ...]): - # NOTE: we do not support dynamic calls here, thus no need to propagate method object - return self.run_callable(method.code, (self.lattice.bottom(),) + args) + return tuple(NotMeasureId() for _ in node.results) # Xiu-zhe (Roger) Luo came up with this in the address analysis, # reused here for convenience (now modified to be a bit more graceful) @@ -45,7 +41,7 @@ def run_method(self, method: ir.Method, args: tuple[MeasureId, ...]): T = TypeVar("T") def get_const_value( - self, input_type: type[T], value: ir.SSAValue + self, input_type: type[T] | tuple[type[T], ...], value: ir.SSAValue ) -> type[T] | None: if isinstance(hint := value.hints.get("const"), const.Value): data = hint.data diff --git a/src/bloqade/analysis/measure_id/impls.py b/src/bloqade/analysis/measure_id/impls.py index 993b97bd..439ae2a6 100644 --- a/src/bloqade/analysis/measure_id/impls.py +++ b/src/bloqade/analysis/measure_id/impls.py @@ -138,11 +138,10 @@ def return_(self, _: MeasurementIDAnalysis, frame: interp.Frame, stmt: func.Retu def invoke( self, interp_: MeasurementIDAnalysis, frame: interp.Frame, stmt: func.Invoke ): - _, ret = interp_.run_method( - stmt.callee, - interp_.permute_values( - stmt.callee.arg_names, frame.get_values(stmt.inputs), stmt.kwargs - ), + _, ret = interp_.call( + stmt.callee.code, + interp_.method_self(stmt.callee), + *frame.get_values(stmt.inputs), ) return (ret,) diff --git a/src/bloqade/cirq_utils/emit/base.py b/src/bloqade/cirq_utils/emit/base.py index 199e49dc..4c831d81 100644 --- a/src/bloqade/cirq_utils/emit/base.py +++ b/src/bloqade/cirq_utils/emit/base.py @@ -189,39 +189,8 @@ def initialize_frame( node, has_parent_access=has_parent_access, qubits=self.qubits ) - def run_method(self, method: ir.Method, args: tuple[cirq.Circuit, ...]): - return self.call(method, *args) - - def run_callable_region( - self, - frame: EmitCirqFrame, - code: ir.Statement, - region: ir.Region, - args: tuple, - ): - if len(region.blocks) > 0: - block_args = list(region.blocks[0].args) - # NOTE: skip self arg - frame.set_values(block_args[1:], args) - - results = self.frame_eval(frame, code) - if isinstance(results, tuple): - if len(results) == 0: - return self.void - elif len(results) == 1: - return results[0] - raise interp.InterpreterError(f"Unexpected results {results}") - - def emit_block(self, frame: EmitCirqFrame, block: ir.Block) -> cirq.Circuit: - for stmt in block.stmts: - result = self.frame_eval(frame, stmt) - if isinstance(result, tuple): - frame.set_values(stmt.results, result) - - return self.circuit - def reset(self): - pass + self.circuit = cirq.Circuit() def eval_fallback(self, frame: EmitCirqFrame, node: ir.Statement) -> tuple: return tuple(None for _ in range(len(node.results))) diff --git a/src/bloqade/cirq_utils/lowering.py b/src/bloqade/cirq_utils/lowering.py index 7d81b45b..fc6d895a 100644 --- a/src/bloqade/cirq_utils/lowering.py +++ b/src/bloqade/cirq_utils/lowering.py @@ -99,7 +99,7 @@ def main(): ``` """ - target = Squin(dialects=dialects, circuit=circuit) + target = Squin(dialects, circuit) body = target.run( circuit, source=str(circuit), # TODO: proper source string @@ -144,8 +144,6 @@ def main(): ) mt = ir.Method( - mod=None, - py_func=None, sym_name=kernel_name, arg_names=arg_names, dialects=dialects, diff --git a/src/bloqade/gemini/analysis/logical_validation/analysis.py b/src/bloqade/gemini/analysis/logical_validation/analysis.py index 14a03cbf..4cf2d7a2 100644 --- a/src/bloqade/gemini/analysis/logical_validation/analysis.py +++ b/src/bloqade/gemini/analysis/logical_validation/analysis.py @@ -9,9 +9,9 @@ class GeminiLogicalValidationAnalysis(ValidationAnalysis): first_gate = True - def eval_stmt_fallback(self, frame: ValidationFrame, stmt: ir.Statement): - if isinstance(stmt, squin.gate.stmts.Gate): + def eval_fallback(self, frame: ValidationFrame, node: ir.Statement): + if isinstance(node, squin.gate.stmts.Gate): # NOTE: to validate that only the first encountered gate can be non-Clifford, we need to track this here self.first_gate = False - return super().eval_stmt_fallback(frame, stmt) + return super().eval_fallback(frame, node) diff --git a/src/bloqade/native/upstream/squin2native.py b/src/bloqade/native/upstream/squin2native.py index 2a9131f2..998d34b6 100644 --- a/src/bloqade/native/upstream/squin2native.py +++ b/src/bloqade/native/upstream/squin2native.py @@ -62,16 +62,18 @@ def emit(self, mt: ir.Method, *, no_raise=True) -> ir.Method: all_dialects = chain.from_iterable( ker.dialects.data for kers in old_callgraph.defs.values() for ker in kers ) - new_dialects = ( - mt.dialects.union(all_dialects).discard(gate_dialect).union(kernel) - ) + combined_dialects = mt.dialects.union(all_dialects).union(kernel) - out = mt.similar(new_dialects) - UpdateDialectsOnCallGraph(new_dialects, no_raise=no_raise)(out) - CallGraphPass(new_dialects, rewrite.Walk(GateRule()), no_raise=no_raise)(out) - # verify all kernels in the callgraph + out = mt.similar(combined_dialects) + UpdateDialectsOnCallGraph(combined_dialects, no_raise=no_raise)(out) + CallGraphPass(combined_dialects, rewrite.Walk(GateRule()), no_raise=no_raise)( + out + ) + # verify all kernels in the callgraph and discard gate dialect + out.dialects.discard(gate_dialect) new_callgraph = CallGraph(out) for ker in new_callgraph.edges.keys(): + ker.dialects.discard(gate_dialect) ker.verify() return out diff --git a/src/bloqade/pyqrack/target.py b/src/bloqade/pyqrack/target.py index e9f21233..54700933 100644 --- a/src/bloqade/pyqrack/target.py +++ b/src/bloqade/pyqrack/target.py @@ -87,7 +87,8 @@ def run( """ fold = Fold(mt.dialects) fold(mt) - return self._get_interp(mt).run(mt, args, kwargs) + _, ret = self._get_interp(mt).run(mt, *args, **kwargs) + return ret def multi_run( self, diff --git a/src/bloqade/pyqrack/task.py b/src/bloqade/pyqrack/task.py index 1502f430..0acb6ef0 100644 --- a/src/bloqade/pyqrack/task.py +++ b/src/bloqade/pyqrack/task.py @@ -24,14 +24,12 @@ class PyQrackSimulatorTask(AbstractSimulatorTask[Param, RetType, MemoryType]): pyqrack_interp: PyQrackInterpreter[MemoryType] def run(self) -> RetType: - return cast( - RetType, - self.pyqrack_interp.run( - self.kernel, - args=self.args, - kwargs=self.kwargs, - ), + _, ret = self.pyqrack_interp.run( + self.kernel, + *self.args, + **self.kwargs, ) + return cast(RetType, ret) @property def state(self) -> MemoryType: diff --git a/src/bloqade/qasm2/_qasm_loading.py b/src/bloqade/qasm2/_qasm_loading.py index 63ffcd5f..57ee5815 100644 --- a/src/bloqade/qasm2/_qasm_loading.py +++ b/src/bloqade/qasm2/_qasm_loading.py @@ -4,6 +4,7 @@ from typing import Any from kirin import ir, lowering +from kirin.types import MethodType from kirin.dialects import func from . import parse @@ -82,11 +83,10 @@ def loads( body=body, ) + body.blocks[0].args.append_from(MethodType, kernel_name + "_self") + mt = ir.Method( - mod=None, - py_func=None, sym_name=kernel_name, - arg_names=[], dialects=qasm2_lowering.dialects, code=code, ) diff --git a/src/bloqade/qasm2/dialects/expr/stmts.py b/src/bloqade/qasm2/dialects/expr/stmts.py index e2e130e6..fad08b33 100644 --- a/src/bloqade/qasm2/dialects/expr/stmts.py +++ b/src/bloqade/qasm2/dialects/expr/stmts.py @@ -87,7 +87,7 @@ def print_impl(self, printer: Printer) -> None: # QASM 2.0 arithmetic operations -PyNum = types.Union(types.Int, types.Float) +PyNum = types.TypeVar("PyNum", bound=types.Union(types.Int, types.Float)) @statement(dialect=dialect) @@ -110,7 +110,7 @@ class Sin(ir.Statement): traits = frozenset({lowering.FromPythonCall()}) value: ir.SSAValue = info.argument(PyNum) """value (Union[int, float]): The number to take the sine of.""" - result: ir.ResultValue = info.result(PyNum) + result: ir.ResultValue = info.result(types.Float) """result (float): The sine of the number.""" @@ -122,7 +122,7 @@ class Cos(ir.Statement): traits = frozenset({lowering.FromPythonCall()}) value: ir.SSAValue = info.argument(PyNum) """value (Union[int, float]): The number to take the cosine of.""" - result: ir.ResultValue = info.result(PyNum) + result: ir.ResultValue = info.result(types.Float) """result (float): The cosine of the number.""" @@ -134,7 +134,7 @@ class Tan(ir.Statement): traits = frozenset({lowering.FromPythonCall()}) value: ir.SSAValue = info.argument(PyNum) """value (Union[int, float]): The number to take the tangent of.""" - result: ir.ResultValue = info.result(PyNum) + result: ir.ResultValue = info.result(types.Float) """result (float): The tangent of the number.""" @@ -146,7 +146,7 @@ class Exp(ir.Statement): traits = frozenset({lowering.FromPythonCall()}) value: ir.SSAValue = info.argument(PyNum) """value (Union[int, float]): The number to take the exponential of.""" - result: ir.ResultValue = info.result(PyNum) + result: ir.ResultValue = info.result(types.Float) """result (float): The exponential of the number.""" @@ -158,7 +158,7 @@ class Log(ir.Statement): traits = frozenset({lowering.FromPythonCall()}) value: ir.SSAValue = info.argument(PyNum) """value (Union[int, float]): The number to take the natural log of.""" - result: ir.ResultValue = info.result(PyNum) + result: ir.ResultValue = info.result(types.Float) """result (float): The natural log of the number.""" @@ -170,7 +170,7 @@ class Sqrt(ir.Statement): traits = frozenset({lowering.FromPythonCall()}) value: ir.SSAValue = info.argument(PyNum) """value (Union[int, float]): The number to take the square root of.""" - result: ir.ResultValue = info.result(PyNum) + result: ir.ResultValue = info.result(types.Float) """result (float): The square root of the number.""" diff --git a/src/bloqade/qasm2/dialects/noise/fidelity.py b/src/bloqade/qasm2/dialects/noise/fidelity.py index f7ed75c4..acd17ac9 100644 --- a/src/bloqade/qasm2/dialects/noise/fidelity.py +++ b/src/bloqade/qasm2/dialects/noise/fidelity.py @@ -31,7 +31,7 @@ def pauli_channel( # NOTE: fidelity is just the inverse probability of any noise to occur fid = (1 - p) * (1 - p_ctrl) - interp._current_gate_fidelity *= fid + interp.gate_fidelity *= fid @interp.impl(AtomLossChannel) def atom_loss( @@ -44,4 +44,4 @@ def atom_loss( addresses = interp.addr_frame.get(stmt.qargs) # NOTE: get the corresponding index and reduce survival probability accordingly for index in addresses.data: - interp._current_atom_survival_probability[index] *= 1 - stmt.prob + interp.atom_survival_probability[index] *= 1 - stmt.prob diff --git a/src/bloqade/qasm2/emit/base.py b/src/bloqade/qasm2/emit/base.py index 4f7fba1d..4fac3217 100644 --- a/src/bloqade/qasm2/emit/base.py +++ b/src/bloqade/qasm2/emit/base.py @@ -45,11 +45,6 @@ def initialize_frame( ) -> EmitQASM2Frame[StmtType]: return EmitQASM2Frame(node, has_parent_access=has_parent_access) - def run_method( - self, method: ir.Method, args: tuple[ast.Node | None, ...] - ) -> tuple[EmitQASM2Frame[StmtType], ast.Node | None]: - return self.call(method, *args) - def emit_block(self, frame: EmitQASM2Frame, block: ir.Block) -> ast.Node | None: for stmt in block.stmts: result = self.frame_eval(frame, stmt) diff --git a/src/bloqade/qasm2/emit/target.py b/src/bloqade/qasm2/emit/target.py index a2548bba..784034a4 100644 --- a/src/bloqade/qasm2/emit/target.py +++ b/src/bloqade/qasm2/emit/target.py @@ -106,13 +106,13 @@ def emit(self, entry: ir.Method) -> ast.MainProgram: unroll_ifs=self.unroll_ifs, ).fixpoint(entry) - # if not self.allow_global: - # # rewrite global to parallel - # GlobalToParallel(dialects=entry.dialects)(entry) + if not self.allow_global: + # rewrite global to parallel + GlobalToParallel(dialects=entry.dialects)(entry) - # if not self.allow_parallel: - # # rewrite parallel to uop - # ParallelToUOp(dialects=entry.dialects)(entry) + if not self.allow_parallel: + # rewrite parallel to uop + ParallelToUOp(dialects=entry.dialects)(entry) Py2QASM(entry.dialects)(entry) target_main = EmitQASM2Main(self.main_target).initialize() diff --git a/src/bloqade/qasm2/parse/lowering.py b/src/bloqade/qasm2/parse/lowering.py index 765d1eb3..07d66d28 100644 --- a/src/bloqade/qasm2/parse/lowering.py +++ b/src/bloqade/qasm2/parse/lowering.py @@ -450,7 +450,6 @@ def visit_Instruction(self, state: lowering.State[ast.Node], node: ast.Instructi func.Invoke( callee=value, inputs=tuple(params + qargs), - kwargs=tuple(), ) ) diff --git a/src/bloqade/qasm2/rewrite/uop_to_parallel.py b/src/bloqade/qasm2/rewrite/uop_to_parallel.py index b1c102ed..5c85376b 100644 --- a/src/bloqade/qasm2/rewrite/uop_to_parallel.py +++ b/src/bloqade/qasm2/rewrite/uop_to_parallel.py @@ -66,7 +66,7 @@ def same_id_checker(ssa1: ir.SSAValue, ssa2: ir.SSAValue): assert isinstance(hint1, lattice.Result) and isinstance( hint2, lattice.Result ) - return hint1.is_equal(hint2) + return hint1.is_structurally_equal(hint2) else: return False diff --git a/src/bloqade/qbraid/lowering.py b/src/bloqade/qbraid/lowering.py index 15958071..276e62ff 100644 --- a/src/bloqade/qbraid/lowering.py +++ b/src/bloqade/qbraid/lowering.py @@ -320,5 +320,6 @@ def lower_full_turns(self, value: float) -> ir.SSAValue: self.block_list.append(const_pi) turns = self.lower_number(2 * value) mul = qasm2.expr.Mul(const_pi.result, turns) + mul.result.type = types.Float self.block_list.append(mul) return mul.result diff --git a/src/bloqade/qbraid/schema.py b/src/bloqade/qbraid/schema.py index 54eed1c2..450d4f5a 100644 --- a/src/bloqade/qbraid/schema.py +++ b/src/bloqade/qbraid/schema.py @@ -238,13 +238,13 @@ def decompiled_circuit(self) -> str: str: The decompiled circuit from hardware execution. """ - from bloqade.noise import native from bloqade.qasm2.emit import QASM2 from bloqade.qasm2.passes import glob, parallel + from bloqade.qasm2.rewrite.noise import remove_noise mt = self.lower_noise_model("method") - native.RemoveNoisePass(mt.dialects)(mt) + remove_noise.RemoveNoisePass(mt.dialects)(mt) parallel.ParallelToUOp(mt.dialects)(mt) glob.GlobalToUOP(mt.dialects)(mt) return QASM2(qelib1=True).emit_str(mt) diff --git a/src/bloqade/squin/analysis/schedule.py b/src/bloqade/squin/analysis/schedule.py index e99e219d..35487e08 100644 --- a/src/bloqade/squin/analysis/schedule.py +++ b/src/bloqade/squin/analysis/schedule.py @@ -185,18 +185,17 @@ def push_current_dag(self, block: ir.Block): self.stmt_dag = StmtDag() self.use_def = {} - def run_method(self, method: ir.Method, args: tuple[GateSchedule, ...]): - # NOTE: we do not support dynamic calls here, thus no need to propagate method object - return self.run_callable(method.code, (self.lattice.bottom(),) + args) + def method_self(self, method: ir.Method) -> GateSchedule: + return self.lattice.bottom() - def eval_stmt_fallback(self, frame: ForwardFrame, stmt: ir.Statement): - if stmt.has_trait(ir.IsTerminator): + def eval_fallback(self, frame: ForwardFrame, node: ir.Statement): + if node.has_trait(ir.IsTerminator): assert ( - stmt.parent_block is not None + node.parent_block is not None ), "Terminator statement has no parent block" - self.push_current_dag(stmt.parent_block) + self.push_current_dag(node.parent_block) - return tuple(self.lattice.top() for _ in stmt.results) + return tuple(self.lattice.top() for _ in node.results) def _update_dag(self, stmt: ir.Statement, addr: address.Address): if isinstance(addr, address.AddressQubit): diff --git a/src/bloqade/validation/analysis/analysis.py b/src/bloqade/validation/analysis/analysis.py index 323cbd40..19220b73 100644 --- a/src/bloqade/validation/analysis/analysis.py +++ b/src/bloqade/validation/analysis/analysis.py @@ -28,14 +28,14 @@ class ValidationAnalysis(ForwardExtra[ValidationFrame, ErrorType], ABC): lattice = ErrorType - def run_method(self, method: ir.Method, args: tuple[ErrorType, ...]): - return self.run_callable(method.code, (self.lattice.top(),) + args) - - def eval_stmt_fallback(self, frame: ValidationFrame, stmt: ir.Statement): + def eval_fallback(self, frame: ValidationFrame, node: ir.Statement): # NOTE: default to no errors - return tuple(self.lattice.top() for _ in stmt.results) + return tuple(self.lattice.top() for _ in node.results) def initialize_frame( - self, code: ir.Statement, *, has_parent_access: bool = False + self, node: ir.Statement, *, has_parent_access: bool = False ) -> ValidationFrame: - return ValidationFrame(code, has_parent_access=has_parent_access) + return ValidationFrame(node, has_parent_access=has_parent_access) + + def method_self(self, method: ir.Method) -> ErrorType: + return self.lattice.top() diff --git a/src/bloqade/validation/kernel_validation.py b/src/bloqade/validation/kernel_validation.py index 84159352..d3f82774 100644 --- a/src/bloqade/validation/kernel_validation.py +++ b/src/bloqade/validation/kernel_validation.py @@ -48,9 +48,23 @@ class KernelValidation: validation_analysis_cls: type[ValidationAnalysis] """The analysis that you want to run in order to validate the kernel.""" - def run(self, mt: ir.Method, **kwargs) -> None: + def run(self, mt: ir.Method, no_raise: bool = True) -> None: + """Run the kernel validation analysis and raise any errors found. + + Args: + mt (ir.Method): The method to validate + no_raise (bool): Whether or not to raise errors when running the analysis. + This is only to make sure the analysis works. Errors found during + the analysis will be raised regardless of this setting. Defaults to `True`. + + """ + validation_analysis = self.validation_analysis_cls(mt.dialects) - validation_frame, _ = validation_analysis.run_analysis(mt, **kwargs) + + if no_raise: + validation_frame, _ = validation_analysis.run_no_raise(mt) + else: + validation_frame, _ = validation_analysis.run(mt) errors = validation_frame.errors diff --git a/test/analysis/address/test_qubit_analysis.py b/test/analysis/address/test_qubit_analysis.py index 1866c9a1..dddf825a 100644 --- a/test/analysis/address/test_qubit_analysis.py +++ b/test/analysis/address/test_qubit_analysis.py @@ -21,7 +21,7 @@ def test(): return (q1[1], q2) address_analysis = address.AddressAnalysis(test.dialects) - frame, _ = address_analysis.run_analysis(test, no_raise=False) + frame, _ = address_analysis.run(test) address_types = collect_address_types(frame, address.PartialTuple) test.print(analysis=frame.entries) @@ -116,7 +116,7 @@ def main(): return q address_analysis = address.AddressAnalysis(main.dialects) - address_analysis.run_analysis(main, no_raise=False) + address_analysis.run(main) def test_new_qubit(): @@ -125,7 +125,7 @@ def main(): return squin.qubit.new() address_analysis = address.AddressAnalysis(main.dialects) - _, result = address_analysis.run_analysis(main, no_raise=False) + _, result = address_analysis.run(main) assert result == address.AddressQubit(0) @@ -139,8 +139,9 @@ def main(n: int): return qreg address_analysis = address.AddressAnalysis(main.dialects) - frame, result = address_analysis.run_analysis( - main, args=(address.ConstResult(const.Unknown()),), no_raise=False + frame, result = address_analysis.run( + main, + address.ConstResult(const.Unknown()), ) assert result == address.AddressReg(data=tuple(range(4))) @@ -155,7 +156,7 @@ def main(n: int): return qreg address_analysis = address.AddressAnalysis(main.dialects) - frame, result = address_analysis.run_analysis(main, no_raise=False) + frame, result = address_analysis.run(main) assert result == address.AddressReg(data=tuple(range(4))) @@ -165,7 +166,7 @@ def main(n: int): return (0, 1) + (2, n) address_analysis = address.AddressAnalysis(main.dialects) - frame, result = address_analysis.run_analysis(main, no_raise=False) + frame, result = address_analysis.run(main) assert result == address.PartialTuple( data=( @@ -183,7 +184,7 @@ def main(n: int): return (0, 1) + [2, n] # type: ignore address_analysis = address.AddressAnalysis(main.dialects) - frame, result = address_analysis.run_analysis(main, no_raise=False) + frame, result = address_analysis.run(main) assert result == address.Bottom() @@ -194,7 +195,7 @@ def main(n: tuple[int, ...]): return (0, 1) + n address_analysis = address.AddressAnalysis(main.dialects) - frame, result = address_analysis.run_analysis(main, no_raise=False) + frame, result = address_analysis.run(main) assert result == address.Unknown() @@ -207,7 +208,7 @@ def main(q: qubit.Qubit): return (0, q, 2, q)[1::2] address_analysis = address.AddressAnalysis(main.dialects) - frame, result = address_analysis.run_analysis(main, no_raise=False) + frame, result = address_analysis.run(main) assert result == address.UnknownReg() @@ -219,7 +220,7 @@ def main(n: int): main.print() address_analysis = address.AddressAnalysis(main.dialects) - frame, result = address_analysis.run_analysis(main, no_raise=False) + frame, result = address_analysis.run(main) main.print(analysis=frame.entries) assert ( result == address.UnknownReg() @@ -260,7 +261,7 @@ def main(): func = main analysis = address.AddressAnalysis(squin.kernel) - _, ret = analysis.run_analysis(func, no_raise=False) + _, ret = analysis.run(func) assert ret == address.AddressReg(data=tuple(range(20))) assert analysis.qubit_count == 20 diff --git a/test/analysis/fidelity/test_fidelity.py b/test/analysis/fidelity/test_fidelity.py index cbb39845..78ca8f26 100644 --- a/test/analysis/fidelity/test_fidelity.py +++ b/test/analysis/fidelity/test_fidelity.py @@ -19,7 +19,7 @@ def main(): return q fid_analysis = FidelityAnalysis(main.dialects) - fid_analysis.run_analysis(main, no_raise=False) + fid_analysis.run(main) assert fid_analysis.gate_fidelity == fid_analysis._current_gate_fidelity == 1 assert fid_analysis.atom_survival_probability[0] == 1 - p_loss @@ -49,11 +49,10 @@ def main(): main.print() fid_analysis = FidelityAnalysis(main.dialects) - fid_analysis.run_analysis(main, no_raise=False) + fid_analysis.run(main) expected_fidelity = (1 - 3 * p_ch) ** 2 - assert fid_analysis.gate_fidelity == fid_analysis._current_gate_fidelity assert math.isclose(fid_analysis.gate_fidelity, expected_fidelity) @@ -69,11 +68,10 @@ def main(): main.print() fid_analysis = FidelityAnalysis(main.dialects) - fid_analysis.run_analysis(main, no_raise=False) + fid_analysis.run(main) expected_fidelity = 1 - 3 * p_ch - assert fid_analysis.gate_fidelity == fid_analysis._current_gate_fidelity assert math.isclose(fid_analysis.gate_fidelity, expected_fidelity) @@ -123,12 +121,12 @@ def main_if(): ) NoisePass(main.dialects, noise_model=model)(main) fid_analysis = FidelityAnalysis(main.dialects) - fid_analysis.run_analysis(main, no_raise=False) + fid_analysis.run(main) model = NoiseTestModel() NoisePass(main_if.dialects, noise_model=model)(main_if) fid_if_analysis = FidelityAnalysis(main_if.dialects) - fid_if_analysis.run_analysis(main_if, no_raise=False) + fid_if_analysis.run(main_if) assert 0 < fid_if_analysis.gate_fidelity == fid_analysis.gate_fidelity < 1 assert ( @@ -186,7 +184,7 @@ def main_for(): ) NoisePass(main.dialects, noise_model=model)(main) fid_analysis = FidelityAnalysis(main.dialects) - fid_analysis.run_analysis(main, no_raise=False) + fid_analysis.run(main) model = NoiseTestModel() NoisePass(main_for.dialects, noise_model=model)(main_for) @@ -194,7 +192,7 @@ def main_for(): main_for.print() fid_for_analysis = FidelityAnalysis(main_for.dialects) - fid_for_analysis.run_analysis(main_for, no_raise=False) + fid_for_analysis.run(main_for) assert 0 < fid_for_analysis.gate_fidelity == fid_analysis.gate_fidelity < 1 assert ( diff --git a/test/cirq_utils/test_cirq_to_squin.py b/test/cirq_utils/test_cirq_to_squin.py index 101cd3be..479ba051 100644 --- a/test/cirq_utils/test_cirq_to_squin.py +++ b/test/cirq_utils/test_cirq_to_squin.py @@ -250,9 +250,9 @@ def test_nesting_lowered_circuit(): @squin.kernel def main(): qreg = get_entangled_qubits() - qreg2 = squin.squin.qalloc(1) + qreg2 = squin.qalloc(1) entangle_qubits([qreg[1], qreg2[0]]) - return squin.qubit.measure(qreg2) + return squin.broadcast.measure(qreg2) # if you get up to here, the validation works main.print() diff --git a/test/gemini/test_logical_validation.py b/test/gemini/test_logical_validation.py index ce8e1a34..046b2ac4 100644 --- a/test/gemini/test_logical_validation.py +++ b/test/gemini/test_logical_validation.py @@ -30,16 +30,14 @@ def main(): if m2: squin.y(q[2]) - frame, _ = GeminiLogicalValidationAnalysis(main.dialects).run_analysis( - main, no_raise=False - ) + frame, _ = GeminiLogicalValidationAnalysis(main.dialects).run_no_raise(main) main.print(analysis=frame.entries) validator = KernelValidation(GeminiLogicalValidationAnalysis) with pytest.raises(ValidationErrorGroup): - validator.run(main) + validator.run(main, no_raise=False) def test_for_loop(): @@ -104,8 +102,8 @@ def invalid(): squin.cx(q[0], q[1]) squin.u3(0.123, 0.253, 1.2, q[0]) - frame, _ = GeminiLogicalValidationAnalysis(invalid.dialects).run_analysis( - invalid, no_raise=False + frame, _ = GeminiLogicalValidationAnalysis(invalid.dialects).run_no_raise( + invalid ) invalid.print(analysis=frame.entries) diff --git a/test/pyqrack/runtime/noise/qasm2/test_loss.py b/test/pyqrack/runtime/noise/qasm2/test_loss.py index 4c1d19a1..29214c09 100644 --- a/test/pyqrack/runtime/noise/qasm2/test_loss.py +++ b/test/pyqrack/runtime/noise/qasm2/test_loss.py @@ -1,12 +1,10 @@ -from typing import Literal from unittest.mock import Mock from kirin import ir -from kirin.dialects import ilist from bloqade import qasm2 from bloqade.qasm2 import noise -from bloqade.pyqrack import PyQrackQubit, PyQrackInterpreter, reg +from bloqade.pyqrack import PyQrackInterpreter, reg from bloqade.pyqrack.base import MockMemory @@ -34,9 +32,9 @@ def test_atom_loss(c: qasm2.CReg): input = reg.CRegister(1) memory = MockMemory() - result: ilist.IList[PyQrackQubit, Literal[2]] = PyQrackInterpreter( + _, result = PyQrackInterpreter( qasm2.extended, memory=memory, rng_state=rng_state - ).run(test_atom_loss, (input,)) + ).run(test_atom_loss, input) assert result[0].state is reg.QubitState.Lost assert result[1].state is reg.QubitState.Active diff --git a/test/pyqrack/runtime/noise/qasm2/test_pauli.py b/test/pyqrack/runtime/noise/qasm2/test_pauli.py index 04541d23..2f9f2bb9 100644 --- a/test/pyqrack/runtime/noise/qasm2/test_pauli.py +++ b/test/pyqrack/runtime/noise/qasm2/test_pauli.py @@ -11,7 +11,7 @@ def run_mock(program: ir.Method, rng_state: Mock | None = None): PyQrackInterpreter( program.dialects, memory=(memory := MockMemory()), rng_state=rng_state - ).run(program, ()) + ).run(program) assert isinstance(mock := memory.sim_reg, Mock) return mock diff --git a/test/pyqrack/runtime/test_qrack.py b/test/pyqrack/runtime/test_qrack.py index 9161cabf..e6174797 100644 --- a/test/pyqrack/runtime/test_qrack.py +++ b/test/pyqrack/runtime/test_qrack.py @@ -14,7 +14,7 @@ def run_mock(program: ir.Method, rng_state: Mock | None = None): PyQrackInterpreter( program.dialects, memory=(memory := MockMemory()), rng_state=rng_state - ).run(program, ()) + ).run(program) assert isinstance(mock := memory.sim_reg, Mock) return mock diff --git a/test/qasm2/emit/test_qasm2_emit.py b/test/qasm2/emit/test_qasm2_emit.py index 00eaaec4..f7e776eb 100644 --- a/test/qasm2/emit/test_qasm2_emit.py +++ b/test/qasm2/emit/test_qasm2_emit.py @@ -54,7 +54,6 @@ def glob_u(): ) -@pytest.mark.xfail(reason="Global and Parallel refactor still ongoing") def test_global(): @qasm2.extended @@ -85,7 +84,6 @@ def glob_u(): ) -@pytest.mark.xfail(reason="Global and Parallel refactor still ongoing") def test_global_allow_para(): @qasm2.extended @@ -118,7 +116,6 @@ def glob_u(): ) -@pytest.mark.xfail(reason="Global and Parallel refactor still ongoing") def test_para(): @qasm2.extended @@ -145,7 +142,6 @@ def para_u(): ) -@pytest.mark.xfail(reason="Global and Parallel refactor still ongoing") def test_para_allow_para(): @qasm2.extended @@ -201,7 +197,6 @@ def para_u(): ) -@pytest.mark.xfail(reason="Global and Parallel refactor still ongoing") def test_para_allow_global(): @qasm2.extended diff --git a/test/qasm2/passes/test_global_to_parallel.py b/test/qasm2/passes/test_global_to_parallel.py index b77a3a13..3b147636 100644 --- a/test/qasm2/passes/test_global_to_parallel.py +++ b/test/qasm2/passes/test_global_to_parallel.py @@ -1,6 +1,5 @@ from typing import List -import pytest from kirin import ir, types from kirin.rewrite import Walk, Fixpoint, CommonSubexpressionElimination from kirin.dialects import py, func, ilist @@ -18,7 +17,6 @@ def as_float(value: float): return py.constant.Constant(value=value) -@pytest.mark.xfail def test_global2para_rewrite(): @qasm2.extended @@ -79,7 +77,6 @@ def main(): assert_methods(expected_method, main) -@pytest.mark.xfail def test_global2para_rewrite2(): @qasm2.extended diff --git a/test/qasm2/passes/test_global_to_uop.py b/test/qasm2/passes/test_global_to_uop.py index 9be187d9..dafe6f2a 100644 --- a/test/qasm2/passes/test_global_to_uop.py +++ b/test/qasm2/passes/test_global_to_uop.py @@ -1,6 +1,5 @@ from typing import List -import pytest from kirin import ir, types from kirin.rewrite import Walk, Fixpoint, CommonSubexpressionElimination from kirin.dialects import py, func @@ -18,7 +17,6 @@ def as_float(value: float): return py.constant.Constant(value=value) -@pytest.mark.xfail def test_global_rewrite(): @qasm2.extended diff --git a/test/qasm2/passes/test_heuristic_noise.py b/test/qasm2/passes/test_heuristic_noise.py index 20a4c5b0..78879a5e 100644 --- a/test/qasm2/passes/test_heuristic_noise.py +++ b/test/qasm2/passes/test_heuristic_noise.py @@ -1,4 +1,3 @@ -import pytest from kirin import ir, types from kirin.dialects import func, ilist from kirin.dialects.py import constant @@ -256,7 +255,6 @@ def test_parallel_cz_gate_noise(): assert_nodes(block, expected_block) -@pytest.mark.xfail def test_global_noise(): @qasm2.extended diff --git a/test/qasm2/passes/test_parallel_to_global.py b/test/qasm2/passes/test_parallel_to_global.py index 93fbac7f..c72533b8 100644 --- a/test/qasm2/passes/test_parallel_to_global.py +++ b/test/qasm2/passes/test_parallel_to_global.py @@ -1,10 +1,7 @@ -import pytest - from bloqade import qasm2 from bloqade.qasm2.passes.parallel import ParallelToGlobal -@pytest.mark.xfail def test_basic_rewrite(): @qasm2.extended @@ -32,7 +29,6 @@ def main(): ) -@pytest.mark.xfail def test_if_rewrite(): @qasm2.extended def main(): @@ -67,7 +63,6 @@ def main(): ) -@pytest.mark.xfail def test_should_not_be_rewritten(): @qasm2.extended @@ -93,7 +88,6 @@ def main(): ) -@pytest.mark.xfail def test_multiple_registers(): @qasm2.extended def main(): @@ -126,7 +120,6 @@ def main(): ) -@pytest.mark.xfail def test_reverse_order(): @qasm2.extended def main(): diff --git a/test/qasm2/passes/test_parallel_to_uop.py b/test/qasm2/passes/test_parallel_to_uop.py index 7484542e..c3e2c59e 100644 --- a/test/qasm2/passes/test_parallel_to_uop.py +++ b/test/qasm2/passes/test_parallel_to_uop.py @@ -1,6 +1,5 @@ from typing import List -import pytest from kirin import ir, types from kirin.dialects import py, func @@ -17,7 +16,6 @@ def as_float(value: float): return py.constant.Constant(value=value) -@pytest.mark.xfail def test_cz_rewrite(): @qasm2.extended diff --git a/test/qasm2/passes/test_uop_to_parallel.py b/test/qasm2/passes/test_uop_to_parallel.py index 7e1f3bdc..29231d39 100644 --- a/test/qasm2/passes/test_uop_to_parallel.py +++ b/test/qasm2/passes/test_uop_to_parallel.py @@ -1,5 +1,3 @@ -import pytest - from bloqade import qasm2 from bloqade.qasm2 import glob from bloqade.analysis import address @@ -7,7 +5,6 @@ from bloqade.qasm2.rewrite import SimpleOptimalMergePolicy -@pytest.mark.xfail def test_one(): @qasm2.gate @@ -50,7 +47,6 @@ def test(): ) -@pytest.mark.xfail def test_two(): @qasm2.extended @@ -89,7 +85,6 @@ def test(): _, _ = address.AddressAnalysis(test.dialects).run(test) -@pytest.mark.xfail def test_three(): @qasm2.extended diff --git a/test/qasm2/test_count.py b/test/qasm2/test_count.py index df5cef31..32aed54b 100644 --- a/test/qasm2/test_count.py +++ b/test/qasm2/test_count.py @@ -1,4 +1,3 @@ -import pytest from kirin import passes from kirin.dialects import py, ilist @@ -16,7 +15,6 @@ fold = passes.Fold(qasm2.main.add(py.tuple).add(ilist)) -@pytest.mark.xfail def test_fixed_count(): @qasm2.main def fixed_count(): @@ -36,7 +34,6 @@ def fixed_count(): assert address.qubit_count == 7 -@pytest.mark.xfail def test_multiple_return_only_reg(): @qasm2.main.add(py.tuple) @@ -54,7 +51,6 @@ def tuple_count(): assert isinstance(ret.data[1], AddressReg) and ret.data[1].data == range(3, 7) -@pytest.mark.xfail def test_dynamic_address(): @qasm2.main def dynamic_address(): @@ -64,14 +60,16 @@ def dynamic_address(): qasm2.measure(ra[0], ca[0]) qasm2.measure(rb[1], ca[1]) if ca[0] == ca[1]: - return ra + ret = ra else: - return rb + ret = rb + + return ret # dynamic_address.code.print() dynamic_address.print() fold(dynamic_address) - frame, result = address.run_analysis(dynamic_address) + frame, result = address.run(dynamic_address) dynamic_address.print(analysis=frame.entries) assert isinstance(result, Unknown) @@ -92,7 +90,6 @@ def dynamic_address(): # assert isinstance(result, ConstResult) -@pytest.mark.xfail def test_multi_return(): @qasm2.main.add(py.tuple) def multi_return_cnt(): @@ -110,7 +107,6 @@ def multi_return_cnt(): assert isinstance(result.data[2], AddressReg) -@pytest.mark.xfail def test_list(): @qasm2.main.add(ilist) def list_count_analy(): @@ -120,12 +116,11 @@ def list_count_analy(): return f list_count_analy.code.print() - _, ret = address.run_analysis(list_count_analy) + _, ret = address.run(list_count_analy) assert ret == AddressReg(data=(0, 1, 3)) -@pytest.mark.xfail def test_tuple_qubits(): @qasm2.main.add(py.tuple) def list_count_analy2(): @@ -136,7 +131,7 @@ def list_count_analy2(): return f list_count_analy2.code.print() - _, ret = address.run_analysis(list_count_analy2) + _, ret = address.run(list_count_analy2) assert isinstance(ret, PartialTuple) assert isinstance(ret.data[0], AddressQubit) and ret.data[0].data == 0 assert isinstance(ret.data[1], AddressQubit) and ret.data[1].data == 1 @@ -166,7 +161,6 @@ def list_count_analy2(): # assert isinstance(result.data[5], AddressQubit) and result.data[5].data == 6 -@pytest.mark.xfail def test_alias(): @qasm2.main diff --git a/test/qasm2/test_lowering.py b/test/qasm2/test_lowering.py index 6eee1509..617d7154 100644 --- a/test/qasm2/test_lowering.py +++ b/test/qasm2/test_lowering.py @@ -3,7 +3,6 @@ import tempfile import textwrap -import pytest from kirin import ir, types from kirin.dialects import func @@ -26,14 +25,12 @@ ) -@pytest.mark.xfail def test_run_lowering(): ast = qasm2.parse.loads(lines) code = QASM2(qasm2.main).run(ast) code.print() -@pytest.mark.xfail def test_loadfile(): with tempfile.TemporaryDirectory() as tmp_dir: @@ -44,7 +41,6 @@ def test_loadfile(): qasm2.loadfile(file) -@pytest.mark.xfail def test_negative_lowering(): mwe = """ @@ -84,7 +80,6 @@ def test_negative_lowering(): assert entry.code.is_structurally_equal(code) -@pytest.mark.xfail def test_gate(): qasm2_prog = textwrap.dedent( """ @@ -113,7 +108,6 @@ def test_gate(): assert math.isclose(abs(ket[3]) ** 2, 0.5, abs_tol=1e-6) -@pytest.mark.xfail def test_gate_with_params(): qasm2_prog = textwrap.dedent( """ @@ -144,7 +138,6 @@ def test_gate_with_params(): assert math.isclose(abs(ket[3]) ** 2, 0.5, abs_tol=1e-6) -@pytest.mark.xfail def test_if_lowering(): qasm2_prog = textwrap.dedent( diff --git a/test/qasm2/test_native.py b/test/qasm2/test_native.py index 15e6ebcb..fdfaf64d 100644 --- a/test/qasm2/test_native.py +++ b/test/qasm2/test_native.py @@ -3,7 +3,6 @@ import cirq import numpy as np -import pytest import cirq.testing import cirq.contrib.qasm_import as qasm_import import cirq.circuits.qasm_output as qasm_output @@ -158,7 +157,6 @@ def kernel(): assert new_qasm2.count("\n") > prog.count("\n") -@pytest.mark.xfail def test_ccx_rewrite(): @qasm2.extended diff --git a/test/qbraid/test_lowering.py b/test/qbraid/test_lowering.py index 323972a8..cabca44e 100644 --- a/test/qbraid/test_lowering.py +++ b/test/qbraid/test_lowering.py @@ -56,11 +56,8 @@ def run_assert(noise_model: schema.NoiseModel, expected_stmts: List[ir.Statement ) expected_mt = ir.Method( - mod=None, - py_func=None, dialects=lowering.qbraid_noise, sym_name="test", - arg_names=[], code=expected_func_stmt, ) @@ -242,7 +239,10 @@ def test_lowering_global_w(): (lam_num := as_float(2 * -(0.5 + phi_val))), (lam := qasm2.expr.Mul(pi_lam.result, lam_num.result)), parallel.UGate( - theta=theta.result, phi=phi.result, lam=lam.result, qargs=qargs.result + theta=ir.ResultValue(theta, 0, type=types.Float), + phi=ir.ResultValue(phi, 0, type=types.Float), + lam=ir.ResultValue(lam, 0, type=types.Float), + qargs=qargs.result, ), func.Return(creg.result), ] @@ -304,7 +304,10 @@ def test_lowering_local_w(): (lam_num := as_float(2 * -(0.5 + phi_val))), (lam := qasm2.expr.Mul(pi_lam.result, lam_num.result)), parallel.UGate( - qargs=qargs.result, theta=theta.result, phi=phi.result, lam=lam.result + qargs=qargs.result, + theta=ir.ResultValue(theta, 0, type=types.Float), + phi=ir.ResultValue(phi, 0, type=types.Float), + lam=ir.ResultValue(lam, 0, type=types.Float), ), func.Return(creg.result), ] @@ -348,7 +351,9 @@ def test_lowering_global_rz(): (theta_pi := qasm2.expr.ConstPI()), (theta_num := as_float(2 * phi_val)), (theta := qasm2.expr.Mul(theta_pi.result, theta_num.result)), - parallel.RZ(theta=theta.result, qargs=qargs.result), + parallel.RZ( + theta=ir.ResultValue(theta, 0, type=types.Float), qargs=qargs.result + ), func.Return(creg.result), ] @@ -401,7 +406,9 @@ def test_lowering_local_rz(): (theta_pi := qasm2.expr.ConstPI()), (theta_num := as_float(2 * phi_val)), (theta := qasm2.expr.Mul(theta_pi.result, theta_num.result)), - parallel.RZ(theta=theta.result, qargs=qargs.result), + parallel.RZ( + theta=ir.ResultValue(theta, 0, type=types.Float), qargs=qargs.result + ), func.Return(creg.result), ] diff --git a/test/squin/test_qubit.py b/test/squin/test_qubit.py index bc16f4cc..001b9250 100644 --- a/test/squin/test_qubit.py +++ b/test/squin/test_qubit.py @@ -19,7 +19,7 @@ def main(): main.print() assert main.return_type.is_subseteq(types.Int) - @squin.kernel + @squin.kernel(fold=False) def main2(): q = squin.qalloc(2) @@ -34,10 +34,10 @@ def main2(): if m1_id != 0: # do something that errors - squin.x(q[4]) + q[0] + 1 if m2_id != 1: - squin.x(q[4]) + q[0] + 1 return squin.broadcast.measure(q) diff --git a/test/stim/passes/stim_reference_programs/debug/debug.stim b/test/stim/passes/stim_reference_programs/debug/debug.stim index 7479e928..d4d0923c 100644 --- a/test/stim/passes/stim_reference_programs/debug/debug.stim +++ b/test/stim/passes/stim_reference_programs/debug/debug.stim @@ -1,2 +1 @@ - # debug message diff --git a/test/stim/passes/stim_reference_programs/qubit/for_loop_nontrivial_index.stim b/test/stim/passes/stim_reference_programs/qubit/for_loop_nontrivial_index.stim index 5081b891..7dc34014 100644 --- a/test/stim/passes/stim_reference_programs/qubit/for_loop_nontrivial_index.stim +++ b/test/stim/passes/stim_reference_programs/qubit/for_loop_nontrivial_index.stim @@ -1,4 +1,3 @@ - H 0 CX 0 1 CX 1 2 diff --git a/test/stim/passes/stim_reference_programs/qubit/nested_for_loop.stim b/test/stim/passes/stim_reference_programs/qubit/nested_for_loop.stim index 26adce10..f3c3c462 100644 --- a/test/stim/passes/stim_reference_programs/qubit/nested_for_loop.stim +++ b/test/stim/passes/stim_reference_programs/qubit/nested_for_loop.stim @@ -1,4 +1,3 @@ - H 0 CX 0 2 CX 1 2 diff --git a/test/stim/passes/stim_reference_programs/qubit/nested_list.stim b/test/stim/passes/stim_reference_programs/qubit/nested_list.stim index 37920f35..0c58c050 100644 --- a/test/stim/passes/stim_reference_programs/qubit/nested_list.stim +++ b/test/stim/passes/stim_reference_programs/qubit/nested_list.stim @@ -1,3 +1,2 @@ - H 0 H 2 diff --git a/test/stim/passes/stim_reference_programs/qubit/non_pure_loop_iterator.stim b/test/stim/passes/stim_reference_programs/qubit/non_pure_loop_iterator.stim index 9088e275..ccb35ef0 100644 --- a/test/stim/passes/stim_reference_programs/qubit/non_pure_loop_iterator.stim +++ b/test/stim/passes/stim_reference_programs/qubit/non_pure_loop_iterator.stim @@ -1,4 +1,3 @@ - MZ(0.00000000) 0 1 2 3 4 X 0 X 1 diff --git a/test/stim/passes/stim_reference_programs/qubit/pick_if_else.stim b/test/stim/passes/stim_reference_programs/qubit/pick_if_else.stim index 05e9c279..f0285a3d 100644 --- a/test/stim/passes/stim_reference_programs/qubit/pick_if_else.stim +++ b/test/stim/passes/stim_reference_programs/qubit/pick_if_else.stim @@ -1,2 +1 @@ - H 1 diff --git a/test/stim/passes/stim_reference_programs/qubit/qubit.stim b/test/stim/passes/stim_reference_programs/qubit/qubit.stim index 17873714..51cc860c 100644 --- a/test/stim/passes/stim_reference_programs/qubit/qubit.stim +++ b/test/stim/passes/stim_reference_programs/qubit/qubit.stim @@ -1,4 +1,3 @@ - H 0 1 X 0 CX 0 1 diff --git a/test/stim/passes/stim_reference_programs/qubit/qubit_broadcast.stim b/test/stim/passes/stim_reference_programs/qubit/qubit_broadcast.stim index 708cb2a0..e033406d 100644 --- a/test/stim/passes/stim_reference_programs/qubit/qubit_broadcast.stim +++ b/test/stim/passes/stim_reference_programs/qubit/qubit_broadcast.stim @@ -1,3 +1,2 @@ - H 0 1 2 3 MZ(0.00000000) 0 1 2 3 diff --git a/test/stim/passes/stim_reference_programs/qubit/qubit_loss.stim b/test/stim/passes/stim_reference_programs/qubit/qubit_loss.stim index 4dee3da2..62b6bf1b 100644 --- a/test/stim/passes/stim_reference_programs/qubit/qubit_loss.stim +++ b/test/stim/passes/stim_reference_programs/qubit/qubit_loss.stim @@ -1,4 +1,3 @@ - H 0 1 2 3 4 I_ERROR[loss](0.10000000) 3 I_ERROR[loss](0.05000000) 0 1 2 3 4 diff --git a/test/stim/passes/stim_reference_programs/qubit/qubit_reset.stim b/test/stim/passes/stim_reference_programs/qubit/qubit_reset.stim index 958e0cfe..bcecaca7 100644 --- a/test/stim/passes/stim_reference_programs/qubit/qubit_reset.stim +++ b/test/stim/passes/stim_reference_programs/qubit/qubit_reset.stim @@ -1,3 +1,2 @@ - RZ 0 MZ(0.00000000) 0 diff --git a/test/stim/passes/stim_reference_programs/qubit/rep_code.stim b/test/stim/passes/stim_reference_programs/qubit/rep_code.stim index 9105cf43..171dd872 100644 --- a/test/stim/passes/stim_reference_programs/qubit/rep_code.stim +++ b/test/stim/passes/stim_reference_programs/qubit/rep_code.stim @@ -1,4 +1,3 @@ - RZ 0 1 2 3 4 CX 0 1 2 3 CX 2 1 4 3 diff --git a/test/stim/passes/stim_reference_programs/qubit/u3_gates.stim b/test/stim/passes/stim_reference_programs/qubit/u3_gates.stim index 4764563d..cfc100b0 100644 --- a/test/stim/passes/stim_reference_programs/qubit/u3_gates.stim +++ b/test/stim/passes/stim_reference_programs/qubit/u3_gates.stim @@ -1,4 +1,3 @@ - Z 0 SQRT_X_DAG 0 SQRT_X_DAG 0 diff --git a/test/stim/passes/stim_reference_programs/qubit/u3_to_clifford.stim b/test/stim/passes/stim_reference_programs/qubit/u3_to_clifford.stim index 6eb044c6..17a51d33 100644 --- a/test/stim/passes/stim_reference_programs/qubit/u3_to_clifford.stim +++ b/test/stim/passes/stim_reference_programs/qubit/u3_to_clifford.stim @@ -1,3 +1,2 @@ - H 0 MZ(0.00000000) 0 diff --git a/test/stim/passes/test_squin_noise_to_stim.py b/test/stim/passes/test_squin_noise_to_stim.py index dc3346d5..04df2837 100644 --- a/test/stim/passes/test_squin_noise_to_stim.py +++ b/test/stim/passes/test_squin_noise_to_stim.py @@ -252,10 +252,11 @@ def test(): SquinToStimPass(test.dialects)(test) - emit = EmitStimMain(correlation_identifier_offset=10) + buf = io.StringIO() + emit = EmitStimMain(stim.main, correlation_identifier_offset=10, io=buf) emit.initialize() - emit.run(mt=test, args=()) - stim_str = emit.get_output().strip() + emit.run(test) + stim_str = buf.getvalue().strip() assert stim_str == "I_ERROR[correlated_loss:10](0.10000000) 0 1 2 3" From 8970ee5cdabbfe85b9471798d0485531732648c8 Mon Sep 17 00:00:00 2001 From: kaihsin Date: Fri, 7 Nov 2025 14:20:26 -0500 Subject: [PATCH 5/6] merge main and fix annotation upgrade to 0.21 --- .../annotate/broadcast_with_alias.stim | 1 - .../annotate/clean_surface_code.stim | 1 - .../annotate/detector_coords_as_args.stim | 1 - .../annotate/kernel_base_op.stim | 1 - .../annotate/linear_program_rewrite.stim | 1 - .../annotate/measure_desugar.stim | 1 - .../annotate/nested_for.stim | 1 - .../annotate/no_kernel_base_op.stim | 1 - .../stim_reference_programs/annotate/rep_code.stim | 1 - .../annotate/set_detector_with_alias.stim | 1 - .../annotate/simple_if_rewrite.stim | 1 - test/stim/passes/test_annotation_to_stim.py | 10 ++++++---- test/stim/passes/test_code_basic_operations.py | 10 ++++++---- test/test_annotate.py | 13 ++++++++----- 14 files changed, 20 insertions(+), 24 deletions(-) diff --git a/test/stim/passes/stim_reference_programs/annotate/broadcast_with_alias.stim b/test/stim/passes/stim_reference_programs/annotate/broadcast_with_alias.stim index 13098f64..880dd33a 100644 --- a/test/stim/passes/stim_reference_programs/annotate/broadcast_with_alias.stim +++ b/test/stim/passes/stim_reference_programs/annotate/broadcast_with_alias.stim @@ -1,2 +1 @@ - X 0 1 \ No newline at end of file diff --git a/test/stim/passes/stim_reference_programs/annotate/clean_surface_code.stim b/test/stim/passes/stim_reference_programs/annotate/clean_surface_code.stim index 9703ef7d..1fb69374 100644 --- a/test/stim/passes/stim_reference_programs/annotate/clean_surface_code.stim +++ b/test/stim/passes/stim_reference_programs/annotate/clean_surface_code.stim @@ -1,4 +1,3 @@ - RZ 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 H 13 14 15 16 CX 13 14 16 1 3 7 0 2 4 10 11 12 diff --git a/test/stim/passes/stim_reference_programs/annotate/detector_coords_as_args.stim b/test/stim/passes/stim_reference_programs/annotate/detector_coords_as_args.stim index f7d6290a..3ae575d3 100644 --- a/test/stim/passes/stim_reference_programs/annotate/detector_coords_as_args.stim +++ b/test/stim/passes/stim_reference_programs/annotate/detector_coords_as_args.stim @@ -1,4 +1,3 @@ - MZ(0.00000000) 0 1 DETECTOR(0, 1) rec[-2] rec[-1] DETECTOR(3, 4) rec[-2] rec[-1] diff --git a/test/stim/passes/stim_reference_programs/annotate/kernel_base_op.stim b/test/stim/passes/stim_reference_programs/annotate/kernel_base_op.stim index 7aca4024..6fcaa758 100644 --- a/test/stim/passes/stim_reference_programs/annotate/kernel_base_op.stim +++ b/test/stim/passes/stim_reference_programs/annotate/kernel_base_op.stim @@ -1,4 +1,3 @@ - H 1 3 2 MZ(0.00000000) 1 3 2 DETECTOR(0, 0) rec[-3] diff --git a/test/stim/passes/stim_reference_programs/annotate/linear_program_rewrite.stim b/test/stim/passes/stim_reference_programs/annotate/linear_program_rewrite.stim index 148b64b7..2a87ac9c 100644 --- a/test/stim/passes/stim_reference_programs/annotate/linear_program_rewrite.stim +++ b/test/stim/passes/stim_reference_programs/annotate/linear_program_rewrite.stim @@ -1,4 +1,3 @@ - X 0 Y 1 Z 2 diff --git a/test/stim/passes/stim_reference_programs/annotate/measure_desugar.stim b/test/stim/passes/stim_reference_programs/annotate/measure_desugar.stim index 9092c649..313025ee 100644 --- a/test/stim/passes/stim_reference_programs/annotate/measure_desugar.stim +++ b/test/stim/passes/stim_reference_programs/annotate/measure_desugar.stim @@ -1,4 +1,3 @@ - MZ(0.00000000) 0 MZ(0.00000000) 0 MZ(0.00000000) 0 diff --git a/test/stim/passes/stim_reference_programs/annotate/nested_for.stim b/test/stim/passes/stim_reference_programs/annotate/nested_for.stim index 6697fd60..2c75e452 100644 --- a/test/stim/passes/stim_reference_programs/annotate/nested_for.stim +++ b/test/stim/passes/stim_reference_programs/annotate/nested_for.stim @@ -1,4 +1,3 @@ - MZ(0.00000000) 0 1 DETECTOR(0, 0) rec[-2] DETECTOR(0, 0) rec[-1] diff --git a/test/stim/passes/stim_reference_programs/annotate/no_kernel_base_op.stim b/test/stim/passes/stim_reference_programs/annotate/no_kernel_base_op.stim index 86a28575..0acb78f5 100644 --- a/test/stim/passes/stim_reference_programs/annotate/no_kernel_base_op.stim +++ b/test/stim/passes/stim_reference_programs/annotate/no_kernel_base_op.stim @@ -1,4 +1,3 @@ - X 0 2 1 MZ(0.00000000) 0 2 1 DETECTOR(0, 0) rec[-3] diff --git a/test/stim/passes/stim_reference_programs/annotate/rep_code.stim b/test/stim/passes/stim_reference_programs/annotate/rep_code.stim index ce74b78f..fb3090b6 100644 --- a/test/stim/passes/stim_reference_programs/annotate/rep_code.stim +++ b/test/stim/passes/stim_reference_programs/annotate/rep_code.stim @@ -1,4 +1,3 @@ - RZ 0 1 2 3 4 CX 0 1 2 3 CX 2 1 4 3 diff --git a/test/stim/passes/stim_reference_programs/annotate/set_detector_with_alias.stim b/test/stim/passes/stim_reference_programs/annotate/set_detector_with_alias.stim index 372b8ad3..d8851744 100644 --- a/test/stim/passes/stim_reference_programs/annotate/set_detector_with_alias.stim +++ b/test/stim/passes/stim_reference_programs/annotate/set_detector_with_alias.stim @@ -1,3 +1,2 @@ - MZ(0.00000000) 0 1 DETECTOR(0, 0) rec[-2] rec[-1] \ No newline at end of file diff --git a/test/stim/passes/stim_reference_programs/annotate/simple_if_rewrite.stim b/test/stim/passes/stim_reference_programs/annotate/simple_if_rewrite.stim index 89aee87c..4d99e833 100644 --- a/test/stim/passes/stim_reference_programs/annotate/simple_if_rewrite.stim +++ b/test/stim/passes/stim_reference_programs/annotate/simple_if_rewrite.stim @@ -1,4 +1,3 @@ - MZ(0.00000000) 0 1 2 3 CZ rec[-4] 0 CX rec[-4] 1 rec[-4] 2 rec[-4] 3 diff --git a/test/stim/passes/test_annotation_to_stim.py b/test/stim/passes/test_annotation_to_stim.py index 5ab36d2f..13d55223 100644 --- a/test/stim/passes/test_annotation_to_stim.py +++ b/test/stim/passes/test_annotation_to_stim.py @@ -1,19 +1,21 @@ +import io import os from kirin import ir from kirin.dialects import scf, ilist -from bloqade import squin +from bloqade import stim, squin from bloqade.stim.emit import EmitStimMain from bloqade.stim.passes import SquinToStimPass def codegen(mt: ir.Method): # method should not have any arguments! - emit = EmitStimMain() + buf = io.StringIO() + emit = EmitStimMain(dialects=stim.main, io=buf) emit.initialize() - emit.run(mt=mt, args=()) - return emit.get_output() + emit.run(mt) + return buf.getvalue().strip() def load_reference_program(filename): diff --git a/test/stim/passes/test_code_basic_operations.py b/test/stim/passes/test_code_basic_operations.py index d16db938..ecef66d7 100644 --- a/test/stim/passes/test_code_basic_operations.py +++ b/test/stim/passes/test_code_basic_operations.py @@ -10,13 +10,14 @@ PhysicalAndSquinToStim pass and are included here """ +import io import os from typing import Any from kirin import ir from kirin.dialects import ilist -from bloqade import squin +from bloqade import stim, squin from bloqade.types import Qubit from bloqade.stim.emit import EmitStimMain from bloqade.stim.passes import SquinToStimPass @@ -32,10 +33,11 @@ def load_stim_reference(filename): def codegen(mt: ir.Method): # method should not have any arguments! - emit = EmitStimMain() + buf = io.StringIO() + emit = EmitStimMain(dialects=stim.main, io=buf) emit.initialize() - emit.run(mt=mt, args=()) - return emit.get_output() + emit.run(mt) + return buf.getvalue().strip() def test_no_kernel_base_op(): diff --git a/test/test_annotate.py b/test/test_annotate.py index 74974503..f902866c 100644 --- a/test/test_annotate.py +++ b/test/test_annotate.py @@ -1,16 +1,19 @@ +import io + from kirin import ir -from bloqade import squin +from bloqade import stim, squin from bloqade.stim.emit import EmitStimMain from bloqade.stim.passes import SquinToStimPass def codegen(mt: ir.Method): # method should not have any arguments! - emit = EmitStimMain() + buf = io.StringIO() + emit = EmitStimMain(dialects=stim.main, io=buf) emit.initialize() - emit.run(mt=mt, args=()) - return emit.get_output() + emit.run(mt) + return buf.getvalue().strip() def test_annotate(): @@ -25,7 +28,7 @@ def test(): SquinToStimPass(dialects=test.dialects)(test) codegen_output = codegen(test) expected_output = ( - "\nMZ(0.00000000) 0 1 2 3\n" + "MZ(0.00000000) 0 1 2 3\n" "DETECTOR(0, 0) rec[-4] rec[-3] rec[-2]\n" "OBSERVABLE_INCLUDE(0) rec[-1]" ) From 29d7e964e882c6bd47df7db2a0f2dd86197ad7c8 Mon Sep 17 00:00:00 2001 From: kaihsin Date: Fri, 7 Nov 2025 14:21:24 -0500 Subject: [PATCH 6/6] bump kirin deps to 0.21 --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index fc723d53..5091c45b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -13,7 +13,7 @@ requires-python = ">=3.10" dependencies = [ "numpy>=1.22.0", "scipy>=1.13.1", - "kirin-toolchain~=0.20.0", + "kirin-toolchain~=0.21.0", "rich>=13.9.4", "pydantic>=1.3.0,<2.11.0", "pandas>=2.2.3",