Skip to content

Commit

Permalink
persist cirq tags through to stim
Browse files Browse the repository at this point in the history
  • Loading branch information
emma-louise-rosenfeld committed Nov 27, 2024
1 parent 1acf7fc commit 84d3e40
Show file tree
Hide file tree
Showing 2 changed files with 104 additions and 58 deletions.
108 changes: 63 additions & 45 deletions glue/cirq/stimcirq/_cirq_to_stim.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import functools
import itertools
import math
from typing import Callable, cast, Dict, Iterable, List, Optional, Sequence, Tuple, Type
from collections.abc import Callable
from typing import cast, Dict, Iterable, List, Optional, Sequence, Tuple, Type

import cirq
import stim
Expand All @@ -11,7 +12,7 @@ def cirq_circuit_to_stim_circuit(
circuit: cirq.AbstractCircuit,
*,
qubit_to_index_dict: Optional[Dict[cirq.Qid, int]] = None,
custom_op_conversion_func: Callable | None = None
custom_op_conversion_func: Callable[[cirq.Operation], cirq.Operation] | None = None
) -> stim.Circuit:
"""Converts a cirq circuit into an equivalent stim circuit.
Expand All @@ -36,6 +37,9 @@ def cirq_circuit_to_stim_circuit(
circuit: The circuit to convert.
qubit_to_index_dict: Optional. Which integer each qubit should get mapped to. If not specified, defaults to
indexing qubits in the circuit in sorted order.
custom_op_conversion_func: Optional. A function which will transform cirq operators into other cirq operators, to be then
converted to STIM. Useful in e.g. the case of non-Clifford operations in a cirq circuit, which are to be replaced
by Clifford operations in STIM.
Returns:
The converted circuit.
Expand Down Expand Up @@ -101,7 +105,7 @@ def cirq_circuit_to_stim_data(
*,
q2i: Optional[Dict[cirq.Qid, int]] = None,
flatten: bool = False,
custom_op_conversion_func: Callable | None = None,
custom_op_conversion_func: Callable[[cirq.Operation], cirq.Operation] | None = None,
) -> Tuple[stim.Circuit, List[Tuple[str, int]]]:
"""Converts a Cirq circuit into a Stim circuit and also metadata about where measurements go."""
if q2i is None:
Expand Down Expand Up @@ -145,21 +149,21 @@ def use(
) -> Callable[[stim.Circuit, List[int]], None]:
if len(gates) == 1 and not individuals:
(g,) = gates
return lambda c, t: c.append_operation(g, t)
return lambda c, t, tag: c.append(stim.CircuitInstruction(g, t, tag=tag))

if not individuals:

def do(c, t):
def do(c, t, tag):
for g in gates:
c.append_operation(g, t)
c.append(stim.CircuitInstruction(g, t, tag=tag))

else:

def do(c, t):
def do(c, t, tag):
for g in gates:
c.append_operation(g, t)
c.append(stim.CircuitInstruction(g, t, tag=tag))
for g, k in individuals:
c.append_operation(g, [t[k]])
c.append(stim.CircuitInstruction(g, [t[k]], tag))

return do

Expand Down Expand Up @@ -251,16 +255,17 @@ def gate_type_to_stim_append_func() -> Dict[Type[cirq.Gate], StimTypeHandler]:
cirq.AsymmetricDepolarizingChannel: cast(
StimTypeHandler, _stim_append_asymmetric_depolarizing_channel
),
cirq.BitFlipChannel: lambda c, g, t: c.append_operation(
"X_ERROR", t, cast(cirq.BitFlipChannel, g).p
cirq.BitFlipChannel: lambda c, g, t, tag: c.append(stim.CircuitInstruction(
"X_ERROR", t, cast(cirq.BitFlipChannel, g).p, tag=tag)
),
cirq.PhaseFlipChannel: lambda c, g, t: c.append_operation(
"Z_ERROR", t, cast(cirq.PhaseFlipChannel, g).p
cirq.PhaseFlipChannel: lambda c, g, t, tag: c.append(stim.CircuitInstruction(
"Z_ERROR", t, cast(cirq.PhaseFlipChannel, g).p, tag=tag)
),
cirq.PhaseDampingChannel: lambda c, g, t: c.append_operation(
cirq.PhaseDampingChannel: lambda c, g, t, tag: c.append(stim.CircuitInstruction(
"Z_ERROR",
t,
0.5 - math.sqrt(1 - cast(cirq.PhaseDampingChannel, g).gamma) / 2,
tag=tag)
),
cirq.RandomGateChannel: cast(StimTypeHandler, _stim_append_random_gate_channel),
cirq.DepolarizingChannel: cast(
Expand All @@ -270,16 +275,16 @@ def gate_type_to_stim_append_func() -> Dict[Type[cirq.Gate], StimTypeHandler]:


def _stim_append_measurement_gate(
circuit: stim.Circuit, gate: cirq.MeasurementGate, targets: List[int]
circuit: stim.Circuit, gate: cirq.MeasurementGate, targets: List[int], tag: str
):
for i, b in enumerate(gate.invert_mask):
if b:
targets[i] = stim.target_inv(targets[i])
circuit.append_operation("M", targets)
circuit.append(stim.CircuitInstruction("M", targets, tag=tag))


def _stim_append_pauli_measurement_gate(
circuit: stim.Circuit, gate: cirq.PauliMeasurementGate, targets: List[int]
circuit: stim.Circuit, gate: cirq.PauliMeasurementGate, targets: List[int], tag: str
):
obs: cirq.DensePauliString = gate.observable()

Expand All @@ -304,11 +309,11 @@ def _stim_append_pauli_measurement_gate(
if obs.coefficient != 1 and obs.coefficient != -1:
raise NotImplementedError(f"obs.coefficient={obs.coefficient!r} not in [1, -1]")

circuit.append_operation("MPP", new_targets)
circuit.append(stim.CircuitInstruction("MPP", new_targets, tag=tag))


def _stim_append_spp_gate(
circuit: stim.Circuit, gate: cirq.PauliStringPhasorGate, targets: List[int]
circuit: stim.Circuit, gate: cirq.PauliStringPhasorGate, targets: List[int], tag: str
):
obs: cirq.DensePauliString = gate.dense_pauli_string
a = gate.exponent_neg
Expand All @@ -329,26 +334,26 @@ def _stim_append_spp_gate(
return False
new_targets.pop()

circuit.append_operation("SPP" if d == 0.5 else "SPP_DAG", new_targets)
circuit.append(stim.CircuitInstruction("SPP" if d == 0.5 else "SPP_DAG", new_targets, tag=tag))
return True


def _stim_append_dense_pauli_string_gate(
c: stim.Circuit, g: cirq.BaseDensePauliString, t: List[int]
c: stim.Circuit, g: cirq.BaseDensePauliString, t: List[int], tag: str
):
gates = [None, "X", "Y", "Z"]
for p, k in zip(g.pauli_mask, t):
if p:
c.append_operation(gates[p], [k])
c.append(stim.CircuitInstruction(gates[p], [k], tag=tag))


def _stim_append_asymmetric_depolarizing_channel(
c: stim.Circuit, g: cirq.AsymmetricDepolarizingChannel, t: List[int]
c: stim.Circuit, g: cirq.AsymmetricDepolarizingChannel, t: List[int], tag: str
):
if cirq.num_qubits(g) == 1:
c.append_operation("PAULI_CHANNEL_1", t, [g.p_x, g.p_y, g.p_z])
c.append(stim.CircuitInstruction("PAULI_CHANNEL_1", t, [g.p_x, g.p_y, g.p_z], tag=tag))
elif cirq.num_qubits(g) == 2:
c.append_operation(
c.append(stim.CircuitInstruction(
"PAULI_CHANNEL_2",
t,
[
Expand All @@ -368,34 +373,35 @@ def _stim_append_asymmetric_depolarizing_channel(
g.error_probabilities.get("ZY", 0),
g.error_probabilities.get("ZZ", 0),
],
tag=tag)
)
else:
raise NotImplementedError(f"cirq-to-stim gate {g!r}")


def _stim_append_depolarizing_channel(
c: stim.Circuit, g: cirq.DepolarizingChannel, t: List[int]
c: stim.Circuit, g: cirq.DepolarizingChannel, t: List[int], tag: str
):
if g.num_qubits() == 1:
c.append_operation("DEPOLARIZE1", t, g.p)
c.append(stim.CircuitInstruction("DEPOLARIZE1", t, g.p, tag=tag))
elif g.num_qubits() == 2:
c.append_operation("DEPOLARIZE2", t, g.p)
c.append(stim.CircuitInstruction("DEPOLARIZE2", t, g.p, tag=tag))
else:
raise TypeError(f"Don't know how to turn {g!r} into Stim operations.")


def _stim_append_controlled_gate(c: stim.Circuit, g: cirq.ControlledGate, t: List[int]):
def _stim_append_controlled_gate(c: stim.Circuit, g: cirq.ControlledGate, t: List[int], tag: str):
if isinstance(g.sub_gate, cirq.BaseDensePauliString) and g.num_controls() == 1:
gates = [None, "CX", "CY", "CZ"]
for p, k in zip(g.sub_gate.pauli_mask, t[1:]):
if p:
c.append_operation(gates[p], [t[0], k])
c.append(stim.CircuitInstruction(gates[p], [t[0], k], tag=tag))
if g.sub_gate.coefficient == 1j:
c.append_operation("S", t[:1])
c.append(stim.CircuitInstruction("S", t[:1], tag=tag))
elif g.sub_gate.coefficient == -1:
c.append_operation("Z", t[:1])
c.append(stim.CircuitInstruction("Z", t[:1], tag=tag))
elif g.sub_gate.coefficient == -1j:
c.append_operation("S_DAG", t[:1])
c.append(stim.CircuitInstruction("S_DAG", t[:1], tag=tag))
elif g.sub_gate.coefficient == 1:
pass
else:
Expand All @@ -408,14 +414,14 @@ def _stim_append_controlled_gate(c: stim.Circuit, g: cirq.ControlledGate, t: Lis


def _stim_append_random_gate_channel(
c: stim.Circuit, g: cirq.RandomGateChannel, t: List[int]
c: stim.Circuit, g: cirq.RandomGateChannel, t: List[int], tag: str
):
if g.sub_gate in [cirq.X, cirq.Y, cirq.Z]:
c.append_operation(f"{g.sub_gate}_ERROR", t, g.probability)
c.append(stim.CircuitInstruction(f"{g.sub_gate}_ERROR", t, g.probability, tag=tag))
elif isinstance(g.sub_gate, cirq.DensePauliString):
target_p = [None, stim.target_x, stim.target_y, stim.target_z]
pauli_targets = [target_p[p](t) for t, p in zip(t, g.sub_gate.pauli_mask) if p]
c.append_operation(f"CORRELATED_ERROR", pauli_targets, g.probability)
c.append(stim.CircuitInstruction(f"CORRELATED_ERROR", pauli_targets, g.probability, tag=tag))
else:
raise NotImplementedError(
f"Don't know how to turn probabilistic {g!r} into Stim operations."
Expand All @@ -431,7 +437,7 @@ def __init__(self):
self.flatten = False

def process_circuit_operation_into_repeat_block(
self, op: cirq.CircuitOperation, custom_op_conversion_func: Callable | None
self, op: cirq.CircuitOperation, custom_op_conversion_func: Callable[[cirq.Operation], cirq.Operation] | None
) -> None:
if self.flatten or op.repetitions == 1:
moments = cirq.unroll_circuit_op(
Expand All @@ -451,14 +457,26 @@ def process_circuit_operation_into_repeat_block(
)
self.out += child.out * op.repetitions

def process_operations(self, operations: Iterable[cirq.Operation], custom_op_conversion_func: Callable | None) -> None:
def process_operations(self, operations: Iterable[cirq.Operation], custom_op_conversion_func: Callable[[cirq.Operation], cirq.Operation] | None) -> None:
g2f = gate_to_stim_append_func()
t2f = gate_type_to_stim_append_func()
for op in operations:
assert isinstance(op, cirq.Operation)
op = op.untagged if custom_op_conversion_func is None else custom_op_conversion_func(op)
gate = op.gate
targets = [self.q2i[q] for q in op.qubits]
if isinstance(op, cirq.TaggedOperation):
str_tags = [tag for tag in op.tags if isinstance(tag, str)]
tag = ""
i = 0
while i < len(str_tags):
tag += str_tags[i]
if len(str_tags) > 1:
tag += ", "
i += 1
op = op.untagged
else:
tag = ""

custom_method = getattr(
op, "_stim_conversion_", getattr(gate, "_stim_conversion_", None)
Expand All @@ -479,27 +497,27 @@ def process_operations(self, operations: Iterable[cirq.Operation], custom_op_con

# Special case measurement, because of its metadata.
if isinstance(gate, cirq.PauliStringPhasorGate):
if _stim_append_spp_gate(self.out, gate, targets):
if _stim_append_spp_gate(self.out, gate, targets, tag):
continue
if isinstance(gate, cirq.PauliMeasurementGate):
self.key_out.append((gate.key, len(targets)))
_stim_append_pauli_measurement_gate(self.out, gate, targets)
_stim_append_pauli_measurement_gate(self.out, gate, targets, tag)
continue
if isinstance(gate, cirq.MeasurementGate):
self.key_out.append((gate.key, len(targets)))
_stim_append_measurement_gate(self.out, gate, targets)
_stim_append_measurement_gate(self.out, gate, targets, tag)
continue

# Look for recognized gate values like cirq.H.
val_append_func = g2f.get(gate)
if val_append_func is not None:
val_append_func(self.out, targets)
val_append_func(self.out, targets, tag)
continue

# Look for recognized gate types like cirq.DepolarizingChannel.
type_append_func = t2f.get(type(gate))
if type_append_func is not None:
type_append_func(self.out, gate, targets)
type_append_func(self.out, gate, targets, tag)
continue

# Ask unrecognized operations to decompose themselves into simpler operations.
Expand All @@ -512,7 +530,7 @@ def process_operations(self, operations: Iterable[cirq.Operation], custom_op_con
f"- It doesn't have a _stim_conversion_ method.\n"
) from ex

def process_moment(self, moment: cirq.Moment, custom_op_conversion_func: Callable | None):
def process_moment(self, moment: cirq.Moment, custom_op_conversion_func: Callable[[cirq.Operation], cirq.Operation] | None):
length_before = len(self.out)
self.process_operations(moment, custom_op_conversion_func=custom_op_conversion_func)

Expand All @@ -522,6 +540,6 @@ def process_moment(self, moment: cirq.Moment, custom_op_conversion_func: Callabl
):
self.out.append_operation("TICK", [])

def process_moments(self, moments: Iterable[cirq.Moment], custom_op_conversion_func: Callable | None):
def process_moments(self, moments: Iterable[cirq.Moment], custom_op_conversion_func: Callable[[cirq.Operation], cirq.Operation] | None):
for moment in moments:
self.process_moment(moment, custom_op_conversion_func=custom_op_conversion_func)
54 changes: 41 additions & 13 deletions glue/cirq/stimcirq/_cirq_to_stim_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -411,20 +411,22 @@ def test_random_gate_channel():


def test_stimcirq_custom_conversion():
""" Checks the custom operation conversion functionality. In this test, we specifically
convert cirq TaggedOperations with particular tag values to a given STIM operation,
according to the lookup `_tag_lookup`. """

_tag_lookup = {"H": cirq.H, "X": cirq.X, "Y": cirq.Y, "Z": cirq.Z}

def _op_conversion(op: cirq.Operation) -> cirq.Operation:
"""" For converting particular tagged cirq.Operator's to a value described by the tag content.
Useful when treating non-Clifford gates in cirq and converting to STIM.
"""
if isinstance(op, cirq.TaggedOperation):
tag_checks = [tag for tag in op.tags if tag in list(_tag_lookup.keys())]
if len(tag_checks) == 1:
gate = _tag_lookup[tag_checks[0]]
op = gate.on(*op.qubits)
elif len(tag_checks) > 1:
raise ValueError(f"found multiple {op.tags=} matching a conversion")
else:
return op # a different tag
return op.untagged

a, b = cirq.LineQubit.range(2)
Expand All @@ -437,7 +439,7 @@ def _op_conversion(op: cirq.Operation) -> cirq.Operation:

stim_circuit = stimcirq.cirq_circuit_to_stim_circuit(c, custom_op_conversion_func=_op_conversion)
assert stim_circuit == stim.Circuit(
"""
"""
H 0
X 1
TICK
Expand Down Expand Up @@ -480,13 +482,39 @@ def _op_conversion(op: cirq.Operation) -> cirq.Operation:

stim_circuit = stimcirq.cirq_circuit_to_stim_circuit(c, custom_op_conversion_func=_op_conversion)
assert stim_circuit == stim.Circuit(
"""
REPEAT 3 {
H 0
X 1
TICK
M 0 1
TICK
}
"""
"""
REPEAT 3 {
H 0
X 1
TICK
M 0 1
TICK
}
"""
)

a, b = cirq.LineQubit.range(2)
c = cirq.Circuit(
cirq.CircuitOperation(
cirq.FrozenCircuit(
cirq.X(a).with_tags("H"),
cirq.X(b).with_tags("hi"),
cirq.measure(a, key="a"),
cirq.measure(b, key="b"),
),
repetitions=3,
).with_tags("my_tag")
)

stim_circuit = stimcirq.cirq_circuit_to_stim_circuit(c, custom_op_conversion_func=_op_conversion)
assert stim_circuit == stim.Circuit(
"""
REPEAT 3 {
H 0
X[hi] 1
TICK
M 0 1
TICK
}
"""
)

0 comments on commit 84d3e40

Please sign in to comment.