Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add stimcirq tags #862

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 14 additions & 30 deletions glue/cirq/stimcirq/_cirq_to_stim.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,13 @@

import cirq
import stim
import warnings

_STIMCIRQ_TAG_LOOKUP = {"H": cirq.H, "X": cirq.X, "Y": cirq.Y, "Z": cirq.Z}


def cirq_circuit_to_stim_circuit(
circuit: cirq.AbstractCircuit,
*,
qubit_to_index_dict: Optional[Dict[cirq.Qid, int]] = None,
custom_op_conversion_func: Callable | None = None
emma-louise-rosenfeld marked this conversation as resolved.
Show resolved Hide resolved
) -> stim.Circuit:
"""Converts a cirq circuit into an equivalent stim circuit.

Expand Down Expand Up @@ -95,14 +93,15 @@ def _stim_conversion_(

edit_circuit.append_operation("H", targets)
"""
return cirq_circuit_to_stim_data(circuit, q2i=qubit_to_index_dict, flatten=False)[0]
return cirq_circuit_to_stim_data(circuit, q2i=qubit_to_index_dict, flatten=False, custom_op_conversion_func=custom_op_conversion_func)[0]


def cirq_circuit_to_stim_data(
circuit: cirq.AbstractCircuit,
*,
q2i: Optional[Dict[cirq.Qid, int]] = None,
flatten: bool = False,
custom_op_conversion_func: Callable | None = None,
) -> Tuple[stim.Circuit, List[Tuple[str, int]]]:
"""Converts a Cirq circuit into a Stim circuit and also metadata about where measurements go."""
if q2i is None:
Expand All @@ -119,7 +118,7 @@ def cirq_circuit_to_stim_data(
elif isinstance(q, cirq.GridQubit):
helper.out.append_operation("QUBIT_COORDS", [q2i[q]], [q.row, q.col])

helper.process_moments(circuit)
helper.process_moments(circuit, custom_op_conversion_func=custom_op_conversion_func)
return helper.out, helper.key_out


Expand Down Expand Up @@ -432,13 +431,13 @@ def __init__(self):
self.flatten = False

def process_circuit_operation_into_repeat_block(
self, op: cirq.CircuitOperation
self, op: cirq.CircuitOperation, custom_op_conversion_func: Callable | None
) -> None:
if self.flatten or op.repetitions == 1:
moments = cirq.unroll_circuit_op(
cirq.Circuit(op), deep=False, tags_to_check=None
).moments
self.process_moments(moments)
self.process_moments(moments, custom_op_conversion_func=custom_op_conversion_func)
self.out = self.out[:-1] # Remove a trailing TICK (to avoid double TICK)
return

Expand All @@ -448,31 +447,16 @@ def process_circuit_operation_into_repeat_block(
child.have_seen_loop = True
self.have_seen_loop = True
child.process_moments(
op.transform_qubits(lambda q: op.qubit_map.get(q, q)).circuit
op.transform_qubits(lambda q: op.qubit_map.get(q, q)).circuit, custom_op_conversion_func=custom_op_conversion_func
)
self.out += child.out * op.repetitions

def process_operations(self, operations: Iterable[cirq.Operation]) -> None:
def process_operations(self, operations: Iterable[cirq.Operation], custom_op_conversion_func: Callable | None) -> None:
g2f = gate_to_stim_append_func()
t2f = gate_type_to_stim_append_func()
for op in operations:
assert isinstance(op, cirq.Operation)

# for tagged cirq operations with a particular conversion
if isinstance(op, cirq.TaggedOperation) and len(op.tags) == 1:
(tag,) = op.tags
if tag in list(_STIMCIRQ_TAG_LOOKUP.keys()):
gate = _STIMCIRQ_TAG_LOOKUP[tag]
op = gate.on(*op.qubits)
else:
warnings.warn(f"ignoring cirq {tag=} for conversion to stim")
op = op.untagged
gate = op.gate
else:
op = op.untagged
gate = op.gate

op = op.untagged
op = op.untagged if custom_op_conversion_func is None else custom_op_conversion_func(op)
gate = op.gate
targets = [self.q2i[q] for q in op.qubits]

Expand All @@ -490,7 +474,7 @@ def process_operations(self, operations: Iterable[cirq.Operation]) -> None:
continue

if isinstance(op, cirq.CircuitOperation):
self.process_circuit_operation_into_repeat_block(op)
self.process_circuit_operation_into_repeat_block(op, custom_op_conversion_func=custom_op_conversion_func)
continue

# Special case measurement, because of its metadata.
Expand Down Expand Up @@ -528,16 +512,16 @@ def process_operations(self, operations: Iterable[cirq.Operation]) -> None:
f"- It doesn't have a _stim_conversion_ method.\n"
) from ex

def process_moment(self, moment: cirq.Moment):
def process_moment(self, moment: cirq.Moment, custom_op_conversion_func: Callable | None):
length_before = len(self.out)
self.process_operations(moment)
self.process_operations(moment, custom_op_conversion_func=custom_op_conversion_func)

# Append a TICK, unless it was already handled by an internal REPEAT block.
if length_before == len(self.out) or not isinstance(
self.out[-1], stim.CircuitRepeatBlock
):
self.out.append_operation("TICK", [])

def process_moments(self, moments: Iterable[cirq.Moment]):
def process_moments(self, moments: Iterable[cirq.Moment], custom_op_conversion_func: Callable | None):
for moment in moments:
self.process_moment(moment)
self.process_moment(moment, custom_op_conversion_func=custom_op_conversion_func)
40 changes: 27 additions & 13 deletions glue/cirq/stimcirq/_cirq_to_stim_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -410,7 +410,22 @@ def test_random_gate_channel():
)


def test_stimcirq_tags():
def test_stimcirq_custom_conversion():

_tag_lookup = {"H": cirq.H, "X": cirq.X, "Y": cirq.Y, "Z": cirq.Z}

def _op_conversion(op: cirq.Operation) -> cirq.Operation:
"""" For converting particular tagged cirq.Operator's to a value described by the tag content.
emma-louise-rosenfeld marked this conversation as resolved.
Show resolved Hide resolved
Useful when treating non-Clifford gates in cirq and converting to STIM.
"""
if isinstance(op, cirq.TaggedOperation):
tag_checks = [tag for tag in op.tags if tag in list(_tag_lookup.keys())]
if len(tag_checks) == 1:
gate = _tag_lookup[tag_checks[0]]
op = gate.on(*op.qubits)
elif len(tag_checks) > 1:
raise ValueError(f"found multiple {op.tags=} matching a conversion")
return op.untagged

a, b = cirq.LineQubit.range(2)
c = cirq.FrozenCircuit(
Expand All @@ -420,7 +435,7 @@ def test_stimcirq_tags():
cirq.measure(b, key="b"),
)

stim_circuit = stimcirq.cirq_circuit_to_stim_circuit(c)
stim_circuit = stimcirq.cirq_circuit_to_stim_circuit(c, custom_op_conversion_func=_op_conversion)
assert stim_circuit == stim.Circuit(
"""
H 0
Expand All @@ -439,7 +454,7 @@ def test_stimcirq_tags():
cirq.measure(b, key="b"),
)

stim_circuit = stimcirq.cirq_circuit_to_stim_circuit(c)
stim_circuit = stimcirq.cirq_circuit_to_stim_circuit(c, custom_op_conversion_func=_op_conversion)
assert stim_circuit == stim.Circuit(
"""
H 0
Expand All @@ -463,16 +478,15 @@ def test_stimcirq_tags():
).with_tags("my_tag")
)

stim_circuit = stimcirq.cirq_circuit_to_stim_circuit(c)
stim_circuit = stimcirq.cirq_circuit_to_stim_circuit(c, custom_op_conversion_func=_op_conversion)
assert stim_circuit == stim.Circuit(
"""

REPEAT 3 {
H 0
X 1
TICK
M 0 1
TICK
}
"""
REPEAT 3 {
H 0
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do a test where the stim circuit ends up containing a tag.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@Strilanc this required some extra functionality (persisting cirq tags through to stim tags) beyond our use case here but I gave a stab at it. See my most recent commit for an example of what I'm thinking. I haven't run/written tests on it yet.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@Strilanc , done, check my latest commit. Note that tags on cirq.CircuitOperation's aren't persisted through - only tags associated with cirq.TaggedOperations are maintained

X 1
TICK
M 0 1
TICK
}
"""
)