From fb0f4585fb4d0947fe3be33fb13302bc4ade05d4 Mon Sep 17 00:00:00 2001 From: Emma Rosenfeld Date: Tue, 26 Nov 2024 17:54:47 +0000 Subject: [PATCH 1/3] add stimcirq tags --- glue/cirq/stimcirq/_cirq_to_stim.py | 153 +++++++++++++++-------- glue/cirq/stimcirq/_cirq_to_stim_test.py | 132 +++++++++++++++---- 2 files changed, 208 insertions(+), 77 deletions(-) diff --git a/glue/cirq/stimcirq/_cirq_to_stim.py b/glue/cirq/stimcirq/_cirq_to_stim.py index 248dcd53d..9928d248e 100644 --- a/glue/cirq/stimcirq/_cirq_to_stim.py +++ b/glue/cirq/stimcirq/_cirq_to_stim.py @@ -5,10 +5,15 @@ import cirq import stim +import warnings + +_STIMCIRQ_TAG_LOOKUP = {"H": cirq.H, "X": cirq.X, "Y": cirq.Y, "Z": cirq.Z} def cirq_circuit_to_stim_circuit( - circuit: cirq.AbstractCircuit, *, qubit_to_index_dict: Optional[Dict[cirq.Qid, int]] = None + circuit: cirq.AbstractCircuit, + *, + qubit_to_index_dict: Optional[Dict[cirq.Qid, int]] = None, ) -> stim.Circuit: """Converts a cirq circuit into an equivalent stim circuit. @@ -94,7 +99,10 @@ def _stim_conversion_( def cirq_circuit_to_stim_data( - circuit: cirq.AbstractCircuit, *, q2i: Optional[Dict[cirq.Qid, int]] = None, flatten: bool = False, + circuit: cirq.AbstractCircuit, + *, + q2i: Optional[Dict[cirq.Qid, int]] = None, + flatten: bool = False, ) -> Tuple[stim.Circuit, List[Tuple[str, int]]]: """Converts a Cirq circuit into a Stim circuit and also metadata about where measurements go.""" if q2i is None: @@ -119,7 +127,9 @@ def cirq_circuit_to_stim_data( @functools.lru_cache(maxsize=1) -def gate_to_stim_append_func() -> Dict[cirq.Gate, Callable[[stim.Circuit, List[int]], None]]: +def gate_to_stim_append_func() -> ( + Dict[cirq.Gate, Callable[[stim.Circuit, List[int]], None]] +): """A dictionary mapping specific gate instances to stim circuit appending functions.""" x = (cirq.X, False) y = (cirq.Y, False) @@ -161,38 +171,38 @@ def do(c, t): cirq.ResetChannel(): use("R"), # Identities. cirq.I: use("I"), - cirq.H ** 0: do_nothing, - cirq.X ** 0: do_nothing, - cirq.Y ** 0: do_nothing, - cirq.Z ** 0: do_nothing, - cirq.ISWAP ** 0: do_nothing, - cirq.SWAP ** 0: do_nothing, + cirq.H**0: do_nothing, + cirq.X**0: do_nothing, + cirq.Y**0: do_nothing, + cirq.Z**0: do_nothing, + cirq.ISWAP**0: do_nothing, + cirq.SWAP**0: do_nothing, # Common named gates. cirq.H: use("H"), cirq.X: use("X"), cirq.Y: use("Y"), cirq.Z: use("Z"), - cirq.X ** 0.5: use("SQRT_X"), - cirq.X ** -0.5: use("SQRT_X_DAG"), - cirq.Y ** 0.5: use("SQRT_Y"), - cirq.Y ** -0.5: use("SQRT_Y_DAG"), - cirq.Z ** 0.5: use("SQRT_Z"), - cirq.Z ** -0.5: use("SQRT_Z_DAG"), + cirq.X**0.5: use("SQRT_X"), + cirq.X**-0.5: use("SQRT_X_DAG"), + cirq.Y**0.5: use("SQRT_Y"), + cirq.Y**-0.5: use("SQRT_Y_DAG"), + cirq.Z**0.5: use("SQRT_Z"), + cirq.Z**-0.5: use("SQRT_Z_DAG"), cirq.CNOT: use("CNOT"), cirq.CZ: use("CZ"), cirq.ISWAP: use("ISWAP"), - cirq.ISWAP ** -1: use("ISWAP_DAG"), - cirq.ISWAP ** 2: use("Z"), + cirq.ISWAP**-1: use("ISWAP_DAG"), + cirq.ISWAP**2: use("Z"), cirq.SWAP: use("SWAP"), cirq.X.controlled(1): use("CX"), cirq.Y.controlled(1): use("CY"), cirq.Z.controlled(1): use("CZ"), - cirq.XX ** 0.5: use("SQRT_XX"), - cirq.YY ** 0.5: use("SQRT_YY"), - cirq.ZZ ** 0.5: use("SQRT_ZZ"), - cirq.XX ** -0.5: use("SQRT_XX_DAG"), - cirq.YY ** -0.5: use("SQRT_YY_DAG"), - cirq.ZZ ** -0.5: use("SQRT_ZZ_DAG"), + cirq.XX**0.5: use("SQRT_XX"), + cirq.YY**0.5: use("SQRT_YY"), + cirq.ZZ**0.5: use("SQRT_ZZ"), + cirq.XX**-0.5: use("SQRT_XX_DAG"), + cirq.YY**-0.5: use("SQRT_YY_DAG"), + cirq.ZZ**-0.5: use("SQRT_ZZ_DAG"), # All 24 cirq.SingleQubitCliffordGate instances. sqcg(x, y): use("SQRT_X_DAG"), sqcg(x, ny): use("SQRT_X"), @@ -233,8 +243,12 @@ def gate_type_to_stim_append_func() -> Dict[Type[cirq.Gate], StimTypeHandler]: """A dictionary mapping specific gate types to stim circuit appending functions.""" return { cirq.ControlledGate: cast(StimTypeHandler, _stim_append_controlled_gate), - cirq.DensePauliString: cast(StimTypeHandler, _stim_append_dense_pauli_string_gate), - cirq.MutableDensePauliString: cast(StimTypeHandler, _stim_append_dense_pauli_string_gate), + cirq.DensePauliString: cast( + StimTypeHandler, _stim_append_dense_pauli_string_gate + ), + cirq.MutableDensePauliString: cast( + StimTypeHandler, _stim_append_dense_pauli_string_gate + ), cirq.AsymmetricDepolarizingChannel: cast( StimTypeHandler, _stim_append_asymmetric_depolarizing_channel ), @@ -245,10 +259,14 @@ def gate_type_to_stim_append_func() -> Dict[Type[cirq.Gate], StimTypeHandler]: "Z_ERROR", t, cast(cirq.PhaseFlipChannel, g).p ), cirq.PhaseDampingChannel: lambda c, g, t: c.append_operation( - "Z_ERROR", t, 0.5 - math.sqrt(1 - cast(cirq.PhaseDampingChannel, g).gamma) / 2 + "Z_ERROR", + t, + 0.5 - math.sqrt(1 - cast(cirq.PhaseDampingChannel, g).gamma) / 2, ), cirq.RandomGateChannel: cast(StimTypeHandler, _stim_append_random_gate_channel), - cirq.DepolarizingChannel: cast(StimTypeHandler, _stim_append_depolarizing_channel), + cirq.DepolarizingChannel: cast( + StimTypeHandler, _stim_append_depolarizing_channel + ), } @@ -335,28 +353,30 @@ def _stim_append_asymmetric_depolarizing_channel( "PAULI_CHANNEL_2", t, [ - g.error_probabilities.get('IX', 0), - g.error_probabilities.get('IY', 0), - g.error_probabilities.get('IZ', 0), - g.error_probabilities.get('XI', 0), - g.error_probabilities.get('XX', 0), - g.error_probabilities.get('XY', 0), - g.error_probabilities.get('XZ', 0), - g.error_probabilities.get('YI', 0), - g.error_probabilities.get('YX', 0), - g.error_probabilities.get('YY', 0), - g.error_probabilities.get('YZ', 0), - g.error_probabilities.get('ZI', 0), - g.error_probabilities.get('ZX', 0), - g.error_probabilities.get('ZY', 0), - g.error_probabilities.get('ZZ', 0), - ] + g.error_probabilities.get("IX", 0), + g.error_probabilities.get("IY", 0), + g.error_probabilities.get("IZ", 0), + g.error_probabilities.get("XI", 0), + g.error_probabilities.get("XX", 0), + g.error_probabilities.get("XY", 0), + g.error_probabilities.get("XZ", 0), + g.error_probabilities.get("YI", 0), + g.error_probabilities.get("YX", 0), + g.error_probabilities.get("YY", 0), + g.error_probabilities.get("YZ", 0), + g.error_probabilities.get("ZI", 0), + g.error_probabilities.get("ZX", 0), + g.error_probabilities.get("ZY", 0), + g.error_probabilities.get("ZZ", 0), + ], ) else: - raise NotImplementedError(f'cirq-to-stim gate {g!r}') + raise NotImplementedError(f"cirq-to-stim gate {g!r}") -def _stim_append_depolarizing_channel(c: stim.Circuit, g: cirq.DepolarizingChannel, t: List[int]): +def _stim_append_depolarizing_channel( + c: stim.Circuit, g: cirq.DepolarizingChannel, t: List[int] +): if g.num_qubits() == 1: c.append_operation("DEPOLARIZE1", t, g.p) elif g.num_qubits() == 2: @@ -383,10 +403,14 @@ def _stim_append_controlled_gate(c: stim.Circuit, g: cirq.ControlledGate, t: Lis raise TypeError(f"Phase kickback from {g!r} isn't a stabilizer operation.") return - raise TypeError(f"Don't know how to turn controlled gate {g!r} into Stim operations.") + raise TypeError( + f"Don't know how to turn controlled gate {g!r} into Stim operations." + ) -def _stim_append_random_gate_channel(c: stim.Circuit, g: cirq.RandomGateChannel, t: List[int]): +def _stim_append_random_gate_channel( + c: stim.Circuit, g: cirq.RandomGateChannel, t: List[int] +): if g.sub_gate in [cirq.X, cirq.Y, cirq.Z]: c.append_operation(f"{g.sub_gate}_ERROR", t, g.probability) elif isinstance(g.sub_gate, cirq.DensePauliString): @@ -407,11 +431,15 @@ def __init__(self): self.have_seen_loop = False self.flatten = False - def process_circuit_operation_into_repeat_block(self, op: cirq.CircuitOperation) -> None: + def process_circuit_operation_into_repeat_block( + self, op: cirq.CircuitOperation + ) -> None: if self.flatten or op.repetitions == 1: - moments = cirq.unroll_circuit_op(cirq.Circuit(op), deep=False, tags_to_check=None).moments + moments = cirq.unroll_circuit_op( + cirq.Circuit(op), deep=False, tags_to_check=None + ).moments self.process_moments(moments) - self.out = self.out[:-1] # Remove a trailing TICK (to avoid double TICK) + self.out = self.out[:-1] # Remove a trailing TICK (to avoid double TICK) return child = CirqToStimHelper() @@ -419,7 +447,9 @@ def process_circuit_operation_into_repeat_block(self, op: cirq.CircuitOperation) child.q2i = self.q2i child.have_seen_loop = True self.have_seen_loop = True - child.process_moments(op.transform_qubits(lambda q: op.qubit_map.get(q, q)).circuit) + child.process_moments( + op.transform_qubits(lambda q: op.qubit_map.get(q, q)).circuit + ) self.out += child.out * op.repetitions def process_operations(self, operations: Iterable[cirq.Operation]) -> None: @@ -427,12 +457,27 @@ def process_operations(self, operations: Iterable[cirq.Operation]) -> None: t2f = gate_type_to_stim_append_func() for op in operations: assert isinstance(op, cirq.Operation) + + # for tagged cirq operations with a particular conversion + if isinstance(op, cirq.TaggedOperation) and len(op.tags) == 1: + (tag,) = op.tags + if tag in list(_STIMCIRQ_TAG_LOOKUP.keys()): + gate = _STIMCIRQ_TAG_LOOKUP[tag] + op = gate.on(*op.qubits) + else: + warnings.warn(f"ignoring cirq {tag=} for conversion to stim") + op = op.untagged + gate = op.gate + else: + op = op.untagged + gate = op.gate + op = op.untagged gate = op.gate targets = [self.q2i[q] for q in op.qubits] custom_method = getattr( - op, '_stim_conversion_', getattr(gate, '_stim_conversion_', None) + op, "_stim_conversion_", getattr(gate, "_stim_conversion_", None) ) if custom_method is not None: custom_method( @@ -488,7 +533,9 @@ def process_moment(self, moment: cirq.Moment): self.process_operations(moment) # Append a TICK, unless it was already handled by an internal REPEAT block. - if length_before == len(self.out) or not isinstance(self.out[-1], stim.CircuitRepeatBlock): + if length_before == len(self.out) or not isinstance( + self.out[-1], stim.CircuitRepeatBlock + ): self.out.append_operation("TICK", []) def process_moments(self, moments: Iterable[cirq.Moment]): diff --git a/glue/cirq/stimcirq/_cirq_to_stim_test.py b/glue/cirq/stimcirq/_cirq_to_stim_test.py index 65f3f1d2c..4099be0c5 100644 --- a/glue/cirq/stimcirq/_cirq_to_stim_test.py +++ b/glue/cirq/stimcirq/_cirq_to_stim_test.py @@ -34,16 +34,16 @@ def solve_tableau(gate: cirq.Gate) -> Dict[cirq.PauliString, cirq.PauliString]: out: cirq.PauliString = cirq.PauliString({q: "IXZY"[s[str(q)][0]] for q in qs}) # Use phase kickback to determine the sign of the output stabilizer. - sign = cirq.NamedQubit('a') + sign = cirq.NamedQubit("a") c = cirq.Circuit( cirq.H(sign), inp.controlled_by(sign), gate(*qs), out.controlled_by(sign), cirq.H(sign), - cirq.measure(sign, key='sign'), + cirq.measure(sign, key="sign"), ) - if cirq.Simulator().sample(c)['sign'][0]: + if cirq.Simulator().sample(c)["sign"][0]: out *= -1 result[inp] = out @@ -54,7 +54,7 @@ def test_solve_tableau(): a, b = cirq.LineQubit.range(2) assert solve_tableau(cirq.I) == {cirq.X(a): cirq.X(a), cirq.Z(a): cirq.Z(a)} assert solve_tableau(cirq.S) == {cirq.X(a): cirq.Y(a), cirq.Z(a): cirq.Z(a)} - assert solve_tableau(cirq.S ** -1) == {cirq.X(a): -cirq.Y(a), cirq.Z(a): cirq.Z(a)} + assert solve_tableau(cirq.S**-1) == {cirq.X(a): -cirq.Y(a), cirq.Z(a): cirq.Z(a)} assert solve_tableau(cirq.H) == {cirq.X(a): cirq.Z(a), cirq.Z(a): cirq.X(a)} assert solve_tableau( cirq.SingleQubitCliffordGate.from_xz_map((cirq.Y, False), (cirq.X, False)) @@ -81,7 +81,9 @@ def assert_unitary_gate_converts_correctly(gate: cirq.Gate): for q, p in pre.items(): c.append_operation(f"C{p}", [2 * n, q.x]) qs = cirq.LineQubit.range(n) - conv_gate, _ = cirq_circuit_to_stim_data(cirq.Circuit(gate(*qs)), q2i={q: q.x for q in qs}) + conv_gate, _ = cirq_circuit_to_stim_data( + cirq.Circuit(gate(*qs)), q2i={q: q.x for q in qs} + ) c += conv_gate for q, p in post.items(): c.append_operation(f"C{p}", [2 * n, q.x]) @@ -90,7 +92,9 @@ def assert_unitary_gate_converts_correctly(gate: cirq.Gate): c.append_operation("H", [2 * n]) c.append_operation("M", [2 * n]) correct = np.count_nonzero(c.compile_sampler().sample_bit_packed(10)) == 0 - assert correct, f"{gate!r} failed to turn {pre} into {post}.\nConverted to:\n{conv_gate}\n" + assert ( + correct + ), f"{gate!r} failed to turn {pre} into {post}.\nConverted to:\n{conv_gate}\n" @pytest.mark.parametrize("gate", gate_to_stim_append_func().keys()) @@ -103,7 +107,9 @@ def test_unitary_gate_conversions(gate: cirq.Gate): def test_more_unitary_gate_conversions(): for p in [1, 1j, -1, -1j]: assert_unitary_gate_converts_correctly(p * cirq.DensePauliString("IXYZ")) - assert_unitary_gate_converts_correctly((p * cirq.DensePauliString("IXYZ")).controlled(1)) + assert_unitary_gate_converts_correctly( + (p * cirq.DensePauliString("IXYZ")).controlled(1) + ) a, b = cirq.LineQubit.range(2) c, _ = cirq_circuit_to_stim_data( @@ -148,9 +154,9 @@ def test_more_unitary_gate_conversions(): cirq.AsymmetricDepolarizingChannel(p_x=0, p_y=0, p_z=0.1), *[ cirq.asymmetric_depolarize(error_probabilities={a + b: 0.1}) - for a, b in list(itertools.product('IXYZ', repeat=2))[1:] + for a, b in list(itertools.product("IXYZ", repeat=2))[1:] ], - cirq.asymmetric_depolarize(error_probabilities={'IX': 0.125, 'ZY': 0.375}), + cirq.asymmetric_depolarize(error_probabilities={"IX": 0.125, "ZY": 0.375}), ] @@ -180,7 +186,9 @@ def test_frame_simulator_sampling_noisy_gates_agrees_with_cirq_data(gate: cirq.G for value, count in zip(unique, counts): expected_rate = expected_rates[value] actual_rate = count / sample_count - allowed_variation = 5 * (expected_rate * (1 - expected_rate) / sample_count) ** 0.5 + allowed_variation = ( + 5 * (expected_rate * (1 - expected_rate) / sample_count) ** 0.5 + ) if not 0 <= expected_rate - allowed_variation <= 1: raise ValueError("Not enough samples to bound results away from extremes.") assert abs(expected_rate - actual_rate) < allowed_variation, ( @@ -227,7 +235,9 @@ def test_tableau_simulator_sampling_noisy_gates_agrees_with_cirq_data(gate: cirq for value, count in zip(unique, counts): expected_rate = expected_rates[value] actual_rate = count / sample_count - allowed_variation = 5 * (expected_rate * (1 - expected_rate) / sample_count) ** 0.5 + allowed_variation = ( + 5 * (expected_rate * (1 - expected_rate) / sample_count) ** 0.5 + ) if not 0 <= expected_rate - allowed_variation <= 1: raise ValueError("Not enough samples to bound results away from extremes.") assert abs(expected_rate - actual_rate) < allowed_variation, ( @@ -266,7 +276,7 @@ def with_qubits(self, *new_qubits): raise NotImplementedError() @property - def qubits(self) -> Tuple['cirq.Qid', ...]: + def qubits(self) -> Tuple["cirq.Qid", ...]: return () a, b, c = cirq.LineQubit.range(3) @@ -316,11 +326,11 @@ def test_custom_qubit_indexing(): actual = stimcirq.cirq_circuit_to_stim_circuit( cirq.Circuit(cirq.CNOT(a, b)), qubit_to_index_dict={a: 10, b: 15} ) - assert actual == stim.Circuit('CX 10 15\nTICK') + assert actual == stim.Circuit("CX 10 15\nTICK") actual = stimcirq.cirq_circuit_to_stim_circuit( cirq.FrozenCircuit(cirq.CNOT(a, b)), qubit_to_index_dict={a: 10, b: 15} ) - assert actual == stim.Circuit('CX 10 15\nTICK') + assert actual == stim.Circuit("CX 10 15\nTICK") def test_on_loop(): @@ -337,7 +347,7 @@ def test_on_loop(): ) ) result = stimcirq.StimSampler().run(c) - assert result.measurements.keys() == {'0:a', '0:b', '1:a', '1:b', '2:a', '2:b'} + assert result.measurements.keys() == {"0:a", "0:b", "1:a", "1:b", "2:a", "2:b"} def test_multi_moment_circuit_operation(): @@ -352,7 +362,8 @@ def test_multi_moment_circuit_operation(): ) ) ) - assert stimcirq.cirq_circuit_to_stim_circuit(cc) == stim.Circuit(""" + assert stimcirq.cirq_circuit_to_stim_circuit(cc) == stim.Circuit( + """ H 0 TICK H 0 @@ -361,7 +372,8 @@ def test_multi_moment_circuit_operation(): TICK H 0 TICK - """) + """ + ) def test_on_tagged_loop(): @@ -375,9 +387,9 @@ def test_on_tagged_loop(): cirq.measure(b, key="b"), ), repetitions=3, - ).with_tags('my_tag') + ).with_tags("my_tag") ) - + stim_circuit = stimcirq.cirq_circuit_to_stim_circuit(c) assert stim.CircuitRepeatBlock in {type(instr) for instr in stim_circuit} @@ -385,10 +397,82 @@ def test_on_tagged_loop(): def test_random_gate_channel(): q0, q1 = cirq.LineQubit.range(2) - circuit = cirq.Circuit(cirq.RandomGateChannel( - sub_gate=cirq.DensePauliString((0, 1)), - probability=0.25).on(q0, q1)) - assert stimcirq.cirq_circuit_to_stim_circuit(circuit) == stim.Circuit(""" + circuit = cirq.Circuit( + cirq.RandomGateChannel( + sub_gate=cirq.DensePauliString((0, 1)), probability=0.25 + ).on(q0, q1) + ) + assert stimcirq.cirq_circuit_to_stim_circuit(circuit) == stim.Circuit( + """ E(0.25) X1 TICK - """) + """ + ) + + +def test_stimcirq_tags(): + + a, b = cirq.LineQubit.range(2) + c = cirq.FrozenCircuit( + cirq.X(a).with_tags("H"), + cirq.X(b), + cirq.measure(a, key="a"), + cirq.measure(b, key="b"), + ) + + stim_circuit = stimcirq.cirq_circuit_to_stim_circuit(c) + assert stim_circuit == stim.Circuit( + """ + H 0 + X 1 + TICK + M 0 1 + TICK + """ + ) + + a, b = cirq.LineQubit.range(2) + c = cirq.FrozenCircuit( + cirq.XPowGate(exponent=0.243).on(a).with_tags("H"), + cirq.X(b), + cirq.measure(a, key="a"), + cirq.measure(b, key="b"), + ) + + stim_circuit = stimcirq.cirq_circuit_to_stim_circuit(c) + assert stim_circuit == stim.Circuit( + """ + H 0 + X 1 + TICK + M 0 1 + TICK + """ + ) + + a, b = cirq.LineQubit.range(2) + c = cirq.Circuit( + cirq.CircuitOperation( + cirq.FrozenCircuit( + cirq.X(a).with_tags("H"), + cirq.X(b), + cirq.measure(a, key="a"), + cirq.measure(b, key="b"), + ), + repetitions=3, + ).with_tags("my_tag") + ) + + stim_circuit = stimcirq.cirq_circuit_to_stim_circuit(c) + assert stim_circuit == stim.Circuit( + """ + + REPEAT 3 { + H 0 + X 1 + TICK + M 0 1 + TICK + } +""" + ) From 1acf7fc09adc799124f83f6c95b1b23c2cac6485 Mon Sep 17 00:00:00 2001 From: Emma Rosenfeld Date: Wed, 27 Nov 2024 02:26:04 +0000 Subject: [PATCH 2/3] custom op conversion function --- glue/cirq/stimcirq/_cirq_to_stim.py | 44 ++++++++---------------- glue/cirq/stimcirq/_cirq_to_stim_test.py | 40 ++++++++++++++------- 2 files changed, 41 insertions(+), 43 deletions(-) diff --git a/glue/cirq/stimcirq/_cirq_to_stim.py b/glue/cirq/stimcirq/_cirq_to_stim.py index 9928d248e..e08abea6c 100644 --- a/glue/cirq/stimcirq/_cirq_to_stim.py +++ b/glue/cirq/stimcirq/_cirq_to_stim.py @@ -5,15 +5,13 @@ import cirq import stim -import warnings - -_STIMCIRQ_TAG_LOOKUP = {"H": cirq.H, "X": cirq.X, "Y": cirq.Y, "Z": cirq.Z} def cirq_circuit_to_stim_circuit( circuit: cirq.AbstractCircuit, *, qubit_to_index_dict: Optional[Dict[cirq.Qid, int]] = None, + custom_op_conversion_func: Callable | None = None ) -> stim.Circuit: """Converts a cirq circuit into an equivalent stim circuit. @@ -95,7 +93,7 @@ def _stim_conversion_( edit_circuit.append_operation("H", targets) """ - return cirq_circuit_to_stim_data(circuit, q2i=qubit_to_index_dict, flatten=False)[0] + return cirq_circuit_to_stim_data(circuit, q2i=qubit_to_index_dict, flatten=False, custom_op_conversion_func=custom_op_conversion_func)[0] def cirq_circuit_to_stim_data( @@ -103,6 +101,7 @@ def cirq_circuit_to_stim_data( *, q2i: Optional[Dict[cirq.Qid, int]] = None, flatten: bool = False, + custom_op_conversion_func: Callable | None = None, ) -> Tuple[stim.Circuit, List[Tuple[str, int]]]: """Converts a Cirq circuit into a Stim circuit and also metadata about where measurements go.""" if q2i is None: @@ -119,7 +118,7 @@ def cirq_circuit_to_stim_data( elif isinstance(q, cirq.GridQubit): helper.out.append_operation("QUBIT_COORDS", [q2i[q]], [q.row, q.col]) - helper.process_moments(circuit) + helper.process_moments(circuit, custom_op_conversion_func=custom_op_conversion_func) return helper.out, helper.key_out @@ -432,13 +431,13 @@ def __init__(self): self.flatten = False def process_circuit_operation_into_repeat_block( - self, op: cirq.CircuitOperation + self, op: cirq.CircuitOperation, custom_op_conversion_func: Callable | None ) -> None: if self.flatten or op.repetitions == 1: moments = cirq.unroll_circuit_op( cirq.Circuit(op), deep=False, tags_to_check=None ).moments - self.process_moments(moments) + self.process_moments(moments, custom_op_conversion_func=custom_op_conversion_func) self.out = self.out[:-1] # Remove a trailing TICK (to avoid double TICK) return @@ -448,31 +447,16 @@ def process_circuit_operation_into_repeat_block( child.have_seen_loop = True self.have_seen_loop = True child.process_moments( - op.transform_qubits(lambda q: op.qubit_map.get(q, q)).circuit + op.transform_qubits(lambda q: op.qubit_map.get(q, q)).circuit, custom_op_conversion_func=custom_op_conversion_func ) self.out += child.out * op.repetitions - def process_operations(self, operations: Iterable[cirq.Operation]) -> None: + def process_operations(self, operations: Iterable[cirq.Operation], custom_op_conversion_func: Callable | None) -> None: g2f = gate_to_stim_append_func() t2f = gate_type_to_stim_append_func() for op in operations: assert isinstance(op, cirq.Operation) - - # for tagged cirq operations with a particular conversion - if isinstance(op, cirq.TaggedOperation) and len(op.tags) == 1: - (tag,) = op.tags - if tag in list(_STIMCIRQ_TAG_LOOKUP.keys()): - gate = _STIMCIRQ_TAG_LOOKUP[tag] - op = gate.on(*op.qubits) - else: - warnings.warn(f"ignoring cirq {tag=} for conversion to stim") - op = op.untagged - gate = op.gate - else: - op = op.untagged - gate = op.gate - - op = op.untagged + op = op.untagged if custom_op_conversion_func is None else custom_op_conversion_func(op) gate = op.gate targets = [self.q2i[q] for q in op.qubits] @@ -490,7 +474,7 @@ def process_operations(self, operations: Iterable[cirq.Operation]) -> None: continue if isinstance(op, cirq.CircuitOperation): - self.process_circuit_operation_into_repeat_block(op) + self.process_circuit_operation_into_repeat_block(op, custom_op_conversion_func=custom_op_conversion_func) continue # Special case measurement, because of its metadata. @@ -528,9 +512,9 @@ def process_operations(self, operations: Iterable[cirq.Operation]) -> None: f"- It doesn't have a _stim_conversion_ method.\n" ) from ex - def process_moment(self, moment: cirq.Moment): + def process_moment(self, moment: cirq.Moment, custom_op_conversion_func: Callable | None): length_before = len(self.out) - self.process_operations(moment) + self.process_operations(moment, custom_op_conversion_func=custom_op_conversion_func) # Append a TICK, unless it was already handled by an internal REPEAT block. if length_before == len(self.out) or not isinstance( @@ -538,6 +522,6 @@ def process_moment(self, moment: cirq.Moment): ): self.out.append_operation("TICK", []) - def process_moments(self, moments: Iterable[cirq.Moment]): + def process_moments(self, moments: Iterable[cirq.Moment], custom_op_conversion_func: Callable | None): for moment in moments: - self.process_moment(moment) + self.process_moment(moment, custom_op_conversion_func=custom_op_conversion_func) diff --git a/glue/cirq/stimcirq/_cirq_to_stim_test.py b/glue/cirq/stimcirq/_cirq_to_stim_test.py index 4099be0c5..fb46008e1 100644 --- a/glue/cirq/stimcirq/_cirq_to_stim_test.py +++ b/glue/cirq/stimcirq/_cirq_to_stim_test.py @@ -410,7 +410,22 @@ def test_random_gate_channel(): ) -def test_stimcirq_tags(): +def test_stimcirq_custom_conversion(): + + _tag_lookup = {"H": cirq.H, "X": cirq.X, "Y": cirq.Y, "Z": cirq.Z} + + def _op_conversion(op: cirq.Operation) -> cirq.Operation: + """" For converting particular tagged cirq.Operator's to a value described by the tag content. + Useful when treating non-Clifford gates in cirq and converting to STIM. + """ + if isinstance(op, cirq.TaggedOperation): + tag_checks = [tag for tag in op.tags if tag in list(_tag_lookup.keys())] + if len(tag_checks) == 1: + gate = _tag_lookup[tag_checks[0]] + op = gate.on(*op.qubits) + elif len(tag_checks) > 1: + raise ValueError(f"found multiple {op.tags=} matching a conversion") + return op.untagged a, b = cirq.LineQubit.range(2) c = cirq.FrozenCircuit( @@ -420,7 +435,7 @@ def test_stimcirq_tags(): cirq.measure(b, key="b"), ) - stim_circuit = stimcirq.cirq_circuit_to_stim_circuit(c) + stim_circuit = stimcirq.cirq_circuit_to_stim_circuit(c, custom_op_conversion_func=_op_conversion) assert stim_circuit == stim.Circuit( """ H 0 @@ -439,7 +454,7 @@ def test_stimcirq_tags(): cirq.measure(b, key="b"), ) - stim_circuit = stimcirq.cirq_circuit_to_stim_circuit(c) + stim_circuit = stimcirq.cirq_circuit_to_stim_circuit(c, custom_op_conversion_func=_op_conversion) assert stim_circuit == stim.Circuit( """ H 0 @@ -463,16 +478,15 @@ def test_stimcirq_tags(): ).with_tags("my_tag") ) - stim_circuit = stimcirq.cirq_circuit_to_stim_circuit(c) + stim_circuit = stimcirq.cirq_circuit_to_stim_circuit(c, custom_op_conversion_func=_op_conversion) assert stim_circuit == stim.Circuit( """ - - REPEAT 3 { - H 0 - X 1 - TICK - M 0 1 - TICK - } -""" + REPEAT 3 { + H 0 + X 1 + TICK + M 0 1 + TICK + } + """ ) From 84d3e40c9cf8f9dc10574a3da0126b4b9719eedb Mon Sep 17 00:00:00 2001 From: Emma Rosenfeld Date: Wed, 27 Nov 2024 15:33:57 +0000 Subject: [PATCH 3/3] persist cirq tags through to stim --- glue/cirq/stimcirq/_cirq_to_stim.py | 108 +++++++++++++---------- glue/cirq/stimcirq/_cirq_to_stim_test.py | 54 +++++++++--- 2 files changed, 104 insertions(+), 58 deletions(-) diff --git a/glue/cirq/stimcirq/_cirq_to_stim.py b/glue/cirq/stimcirq/_cirq_to_stim.py index e08abea6c..f6a47344c 100644 --- a/glue/cirq/stimcirq/_cirq_to_stim.py +++ b/glue/cirq/stimcirq/_cirq_to_stim.py @@ -1,7 +1,8 @@ import functools import itertools import math -from typing import Callable, cast, Dict, Iterable, List, Optional, Sequence, Tuple, Type +from collections.abc import Callable +from typing import cast, Dict, Iterable, List, Optional, Sequence, Tuple, Type import cirq import stim @@ -11,7 +12,7 @@ def cirq_circuit_to_stim_circuit( circuit: cirq.AbstractCircuit, *, qubit_to_index_dict: Optional[Dict[cirq.Qid, int]] = None, - custom_op_conversion_func: Callable | None = None + custom_op_conversion_func: Callable[[cirq.Operation], cirq.Operation] | None = None ) -> stim.Circuit: """Converts a cirq circuit into an equivalent stim circuit. @@ -36,6 +37,9 @@ def cirq_circuit_to_stim_circuit( circuit: The circuit to convert. qubit_to_index_dict: Optional. Which integer each qubit should get mapped to. If not specified, defaults to indexing qubits in the circuit in sorted order. + custom_op_conversion_func: Optional. A function which will transform cirq operators into other cirq operators, to be then + converted to STIM. Useful in e.g. the case of non-Clifford operations in a cirq circuit, which are to be replaced + by Clifford operations in STIM. Returns: The converted circuit. @@ -101,7 +105,7 @@ def cirq_circuit_to_stim_data( *, q2i: Optional[Dict[cirq.Qid, int]] = None, flatten: bool = False, - custom_op_conversion_func: Callable | None = None, + custom_op_conversion_func: Callable[[cirq.Operation], cirq.Operation] | None = None, ) -> Tuple[stim.Circuit, List[Tuple[str, int]]]: """Converts a Cirq circuit into a Stim circuit and also metadata about where measurements go.""" if q2i is None: @@ -145,21 +149,21 @@ def use( ) -> Callable[[stim.Circuit, List[int]], None]: if len(gates) == 1 and not individuals: (g,) = gates - return lambda c, t: c.append_operation(g, t) + return lambda c, t, tag: c.append(stim.CircuitInstruction(g, t, tag=tag)) if not individuals: - def do(c, t): + def do(c, t, tag): for g in gates: - c.append_operation(g, t) + c.append(stim.CircuitInstruction(g, t, tag=tag)) else: - def do(c, t): + def do(c, t, tag): for g in gates: - c.append_operation(g, t) + c.append(stim.CircuitInstruction(g, t, tag=tag)) for g, k in individuals: - c.append_operation(g, [t[k]]) + c.append(stim.CircuitInstruction(g, [t[k]], tag)) return do @@ -251,16 +255,17 @@ def gate_type_to_stim_append_func() -> Dict[Type[cirq.Gate], StimTypeHandler]: cirq.AsymmetricDepolarizingChannel: cast( StimTypeHandler, _stim_append_asymmetric_depolarizing_channel ), - cirq.BitFlipChannel: lambda c, g, t: c.append_operation( - "X_ERROR", t, cast(cirq.BitFlipChannel, g).p + cirq.BitFlipChannel: lambda c, g, t, tag: c.append(stim.CircuitInstruction( + "X_ERROR", t, cast(cirq.BitFlipChannel, g).p, tag=tag) ), - cirq.PhaseFlipChannel: lambda c, g, t: c.append_operation( - "Z_ERROR", t, cast(cirq.PhaseFlipChannel, g).p + cirq.PhaseFlipChannel: lambda c, g, t, tag: c.append(stim.CircuitInstruction( + "Z_ERROR", t, cast(cirq.PhaseFlipChannel, g).p, tag=tag) ), - cirq.PhaseDampingChannel: lambda c, g, t: c.append_operation( + cirq.PhaseDampingChannel: lambda c, g, t, tag: c.append(stim.CircuitInstruction( "Z_ERROR", t, 0.5 - math.sqrt(1 - cast(cirq.PhaseDampingChannel, g).gamma) / 2, + tag=tag) ), cirq.RandomGateChannel: cast(StimTypeHandler, _stim_append_random_gate_channel), cirq.DepolarizingChannel: cast( @@ -270,16 +275,16 @@ def gate_type_to_stim_append_func() -> Dict[Type[cirq.Gate], StimTypeHandler]: def _stim_append_measurement_gate( - circuit: stim.Circuit, gate: cirq.MeasurementGate, targets: List[int] + circuit: stim.Circuit, gate: cirq.MeasurementGate, targets: List[int], tag: str ): for i, b in enumerate(gate.invert_mask): if b: targets[i] = stim.target_inv(targets[i]) - circuit.append_operation("M", targets) + circuit.append(stim.CircuitInstruction("M", targets, tag=tag)) def _stim_append_pauli_measurement_gate( - circuit: stim.Circuit, gate: cirq.PauliMeasurementGate, targets: List[int] + circuit: stim.Circuit, gate: cirq.PauliMeasurementGate, targets: List[int], tag: str ): obs: cirq.DensePauliString = gate.observable() @@ -304,11 +309,11 @@ def _stim_append_pauli_measurement_gate( if obs.coefficient != 1 and obs.coefficient != -1: raise NotImplementedError(f"obs.coefficient={obs.coefficient!r} not in [1, -1]") - circuit.append_operation("MPP", new_targets) + circuit.append(stim.CircuitInstruction("MPP", new_targets, tag=tag)) def _stim_append_spp_gate( - circuit: stim.Circuit, gate: cirq.PauliStringPhasorGate, targets: List[int] + circuit: stim.Circuit, gate: cirq.PauliStringPhasorGate, targets: List[int], tag: str ): obs: cirq.DensePauliString = gate.dense_pauli_string a = gate.exponent_neg @@ -329,26 +334,26 @@ def _stim_append_spp_gate( return False new_targets.pop() - circuit.append_operation("SPP" if d == 0.5 else "SPP_DAG", new_targets) + circuit.append(stim.CircuitInstruction("SPP" if d == 0.5 else "SPP_DAG", new_targets, tag=tag)) return True def _stim_append_dense_pauli_string_gate( - c: stim.Circuit, g: cirq.BaseDensePauliString, t: List[int] + c: stim.Circuit, g: cirq.BaseDensePauliString, t: List[int], tag: str ): gates = [None, "X", "Y", "Z"] for p, k in zip(g.pauli_mask, t): if p: - c.append_operation(gates[p], [k]) + c.append(stim.CircuitInstruction(gates[p], [k], tag=tag)) def _stim_append_asymmetric_depolarizing_channel( - c: stim.Circuit, g: cirq.AsymmetricDepolarizingChannel, t: List[int] + c: stim.Circuit, g: cirq.AsymmetricDepolarizingChannel, t: List[int], tag: str ): if cirq.num_qubits(g) == 1: - c.append_operation("PAULI_CHANNEL_1", t, [g.p_x, g.p_y, g.p_z]) + c.append(stim.CircuitInstruction("PAULI_CHANNEL_1", t, [g.p_x, g.p_y, g.p_z], tag=tag)) elif cirq.num_qubits(g) == 2: - c.append_operation( + c.append(stim.CircuitInstruction( "PAULI_CHANNEL_2", t, [ @@ -368,34 +373,35 @@ def _stim_append_asymmetric_depolarizing_channel( g.error_probabilities.get("ZY", 0), g.error_probabilities.get("ZZ", 0), ], + tag=tag) ) else: raise NotImplementedError(f"cirq-to-stim gate {g!r}") def _stim_append_depolarizing_channel( - c: stim.Circuit, g: cirq.DepolarizingChannel, t: List[int] + c: stim.Circuit, g: cirq.DepolarizingChannel, t: List[int], tag: str ): if g.num_qubits() == 1: - c.append_operation("DEPOLARIZE1", t, g.p) + c.append(stim.CircuitInstruction("DEPOLARIZE1", t, g.p, tag=tag)) elif g.num_qubits() == 2: - c.append_operation("DEPOLARIZE2", t, g.p) + c.append(stim.CircuitInstruction("DEPOLARIZE2", t, g.p, tag=tag)) else: raise TypeError(f"Don't know how to turn {g!r} into Stim operations.") -def _stim_append_controlled_gate(c: stim.Circuit, g: cirq.ControlledGate, t: List[int]): +def _stim_append_controlled_gate(c: stim.Circuit, g: cirq.ControlledGate, t: List[int], tag: str): if isinstance(g.sub_gate, cirq.BaseDensePauliString) and g.num_controls() == 1: gates = [None, "CX", "CY", "CZ"] for p, k in zip(g.sub_gate.pauli_mask, t[1:]): if p: - c.append_operation(gates[p], [t[0], k]) + c.append(stim.CircuitInstruction(gates[p], [t[0], k], tag=tag)) if g.sub_gate.coefficient == 1j: - c.append_operation("S", t[:1]) + c.append(stim.CircuitInstruction("S", t[:1], tag=tag)) elif g.sub_gate.coefficient == -1: - c.append_operation("Z", t[:1]) + c.append(stim.CircuitInstruction("Z", t[:1], tag=tag)) elif g.sub_gate.coefficient == -1j: - c.append_operation("S_DAG", t[:1]) + c.append(stim.CircuitInstruction("S_DAG", t[:1], tag=tag)) elif g.sub_gate.coefficient == 1: pass else: @@ -408,14 +414,14 @@ def _stim_append_controlled_gate(c: stim.Circuit, g: cirq.ControlledGate, t: Lis def _stim_append_random_gate_channel( - c: stim.Circuit, g: cirq.RandomGateChannel, t: List[int] + c: stim.Circuit, g: cirq.RandomGateChannel, t: List[int], tag: str ): if g.sub_gate in [cirq.X, cirq.Y, cirq.Z]: - c.append_operation(f"{g.sub_gate}_ERROR", t, g.probability) + c.append(stim.CircuitInstruction(f"{g.sub_gate}_ERROR", t, g.probability, tag=tag)) elif isinstance(g.sub_gate, cirq.DensePauliString): target_p = [None, stim.target_x, stim.target_y, stim.target_z] pauli_targets = [target_p[p](t) for t, p in zip(t, g.sub_gate.pauli_mask) if p] - c.append_operation(f"CORRELATED_ERROR", pauli_targets, g.probability) + c.append(stim.CircuitInstruction(f"CORRELATED_ERROR", pauli_targets, g.probability, tag=tag)) else: raise NotImplementedError( f"Don't know how to turn probabilistic {g!r} into Stim operations." @@ -431,7 +437,7 @@ def __init__(self): self.flatten = False def process_circuit_operation_into_repeat_block( - self, op: cirq.CircuitOperation, custom_op_conversion_func: Callable | None + self, op: cirq.CircuitOperation, custom_op_conversion_func: Callable[[cirq.Operation], cirq.Operation] | None ) -> None: if self.flatten or op.repetitions == 1: moments = cirq.unroll_circuit_op( @@ -451,7 +457,7 @@ def process_circuit_operation_into_repeat_block( ) self.out += child.out * op.repetitions - def process_operations(self, operations: Iterable[cirq.Operation], custom_op_conversion_func: Callable | None) -> None: + def process_operations(self, operations: Iterable[cirq.Operation], custom_op_conversion_func: Callable[[cirq.Operation], cirq.Operation] | None) -> None: g2f = gate_to_stim_append_func() t2f = gate_type_to_stim_append_func() for op in operations: @@ -459,6 +465,18 @@ def process_operations(self, operations: Iterable[cirq.Operation], custom_op_con op = op.untagged if custom_op_conversion_func is None else custom_op_conversion_func(op) gate = op.gate targets = [self.q2i[q] for q in op.qubits] + if isinstance(op, cirq.TaggedOperation): + str_tags = [tag for tag in op.tags if isinstance(tag, str)] + tag = "" + i = 0 + while i < len(str_tags): + tag += str_tags[i] + if len(str_tags) > 1: + tag += ", " + i += 1 + op = op.untagged + else: + tag = "" custom_method = getattr( op, "_stim_conversion_", getattr(gate, "_stim_conversion_", None) @@ -479,27 +497,27 @@ def process_operations(self, operations: Iterable[cirq.Operation], custom_op_con # Special case measurement, because of its metadata. if isinstance(gate, cirq.PauliStringPhasorGate): - if _stim_append_spp_gate(self.out, gate, targets): + if _stim_append_spp_gate(self.out, gate, targets, tag): continue if isinstance(gate, cirq.PauliMeasurementGate): self.key_out.append((gate.key, len(targets))) - _stim_append_pauli_measurement_gate(self.out, gate, targets) + _stim_append_pauli_measurement_gate(self.out, gate, targets, tag) continue if isinstance(gate, cirq.MeasurementGate): self.key_out.append((gate.key, len(targets))) - _stim_append_measurement_gate(self.out, gate, targets) + _stim_append_measurement_gate(self.out, gate, targets, tag) continue # Look for recognized gate values like cirq.H. val_append_func = g2f.get(gate) if val_append_func is not None: - val_append_func(self.out, targets) + val_append_func(self.out, targets, tag) continue # Look for recognized gate types like cirq.DepolarizingChannel. type_append_func = t2f.get(type(gate)) if type_append_func is not None: - type_append_func(self.out, gate, targets) + type_append_func(self.out, gate, targets, tag) continue # Ask unrecognized operations to decompose themselves into simpler operations. @@ -512,7 +530,7 @@ def process_operations(self, operations: Iterable[cirq.Operation], custom_op_con f"- It doesn't have a _stim_conversion_ method.\n" ) from ex - def process_moment(self, moment: cirq.Moment, custom_op_conversion_func: Callable | None): + def process_moment(self, moment: cirq.Moment, custom_op_conversion_func: Callable[[cirq.Operation], cirq.Operation] | None): length_before = len(self.out) self.process_operations(moment, custom_op_conversion_func=custom_op_conversion_func) @@ -522,6 +540,6 @@ def process_moment(self, moment: cirq.Moment, custom_op_conversion_func: Callabl ): self.out.append_operation("TICK", []) - def process_moments(self, moments: Iterable[cirq.Moment], custom_op_conversion_func: Callable | None): + def process_moments(self, moments: Iterable[cirq.Moment], custom_op_conversion_func: Callable[[cirq.Operation], cirq.Operation] | None): for moment in moments: self.process_moment(moment, custom_op_conversion_func=custom_op_conversion_func) diff --git a/glue/cirq/stimcirq/_cirq_to_stim_test.py b/glue/cirq/stimcirq/_cirq_to_stim_test.py index fb46008e1..5ac500866 100644 --- a/glue/cirq/stimcirq/_cirq_to_stim_test.py +++ b/glue/cirq/stimcirq/_cirq_to_stim_test.py @@ -411,13 +411,13 @@ def test_random_gate_channel(): def test_stimcirq_custom_conversion(): + """ Checks the custom operation conversion functionality. In this test, we specifically + convert cirq TaggedOperations with particular tag values to a given STIM operation, + according to the lookup `_tag_lookup`. """ _tag_lookup = {"H": cirq.H, "X": cirq.X, "Y": cirq.Y, "Z": cirq.Z} def _op_conversion(op: cirq.Operation) -> cirq.Operation: - """" For converting particular tagged cirq.Operator's to a value described by the tag content. - Useful when treating non-Clifford gates in cirq and converting to STIM. - """ if isinstance(op, cirq.TaggedOperation): tag_checks = [tag for tag in op.tags if tag in list(_tag_lookup.keys())] if len(tag_checks) == 1: @@ -425,6 +425,8 @@ def _op_conversion(op: cirq.Operation) -> cirq.Operation: op = gate.on(*op.qubits) elif len(tag_checks) > 1: raise ValueError(f"found multiple {op.tags=} matching a conversion") + else: + return op # a different tag return op.untagged a, b = cirq.LineQubit.range(2) @@ -437,7 +439,7 @@ def _op_conversion(op: cirq.Operation) -> cirq.Operation: stim_circuit = stimcirq.cirq_circuit_to_stim_circuit(c, custom_op_conversion_func=_op_conversion) assert stim_circuit == stim.Circuit( - """ + """ H 0 X 1 TICK @@ -480,13 +482,39 @@ def _op_conversion(op: cirq.Operation) -> cirq.Operation: stim_circuit = stimcirq.cirq_circuit_to_stim_circuit(c, custom_op_conversion_func=_op_conversion) assert stim_circuit == stim.Circuit( - """ - REPEAT 3 { - H 0 - X 1 - TICK - M 0 1 - TICK - } - """ + """ + REPEAT 3 { + H 0 + X 1 + TICK + M 0 1 + TICK + } + """ + ) + + a, b = cirq.LineQubit.range(2) + c = cirq.Circuit( + cirq.CircuitOperation( + cirq.FrozenCircuit( + cirq.X(a).with_tags("H"), + cirq.X(b).with_tags("hi"), + cirq.measure(a, key="a"), + cirq.measure(b, key="b"), + ), + repetitions=3, + ).with_tags("my_tag") + ) + + stim_circuit = stimcirq.cirq_circuit_to_stim_circuit(c, custom_op_conversion_func=_op_conversion) + assert stim_circuit == stim.Circuit( + """ + REPEAT 3 { + H 0 + X[hi] 1 + TICK + M 0 1 + TICK + } + """ )