Skip to content

Commit

Permalink
add meas basis change for cirq
Browse files Browse the repository at this point in the history
  • Loading branch information
TheGupta2012 committed Aug 20, 2024
1 parent 47184a0 commit d2d5464
Show file tree
Hide file tree
Showing 4 changed files with 78 additions and 3 deletions.
32 changes: 30 additions & 2 deletions qbraid_qir/cirq/opsets.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,17 @@ def i(builder, qubits):
pyqir._native.x(builder, qubits)


def measure_x(builder, qubit, result):
pyqir._native.h(builder, qubit)
pyqir._native.mz(builder, qubit, result)


def measure_y(builder, qubit, result):
pyqir._native.s_adj(builder, qubit)
pyqir._native.h(builder, qubit)
pyqir._native.mz(builder, qubit, result)


PYQIR_OP_MAP = {
# Identity Gate
"I": i,
Expand All @@ -51,7 +62,9 @@ def i(builder, qubits):
# Three-Qubit Gates
"TOFFOLI": pyqir._native.ccx,
# Classical Gates/Operations
"MEASURE": pyqir._native.mz,
"measure_z": pyqir._native.mz,
"measure_x": measure_x,
"measure_y": measure_y,
"reset": pyqir._native.reset,
}

Expand All @@ -76,7 +89,22 @@ def map_cirq_op_to_pyqir_callable(
gate = operation.gate

if isinstance(gate, cirq.ops.MeasurementGate):
op_name = "MEASURE"
op_name = "measure_z"

elif isinstance(gate, cirq.ops.PauliMeasurementGate):
op_name = str(gate.observable())
op_name = op_name.removeprefix("+") # Remove the '+' sign
op_name = op_name.removeprefix("-") # Remove the '-' sign
# TODO: is XYZ same as X tensor Y tensor Z?
# if yes, then we can extend this to multi-qubit measurements

if op_name not in ["X", "Y", "Z"]:
raise CirqConversionError(
f"Multi-qubit gate {op_name} not supported for measurement."
)

op_name = f"measure_{op_name.lower()}"

elif isinstance(gate, (cirq.ops.Rx, cirq.ops.Ry, cirq.ops.Rz)):
op_name = gate.__class__.__name__
elif isinstance(gate, cirq.ops.Pauli):
Expand Down
2 changes: 1 addition & 1 deletion qbraid_qir/cirq/visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ def _branch(conds, pyqir_func):
else:
pyqir_func, op_str = map_cirq_op_to_pyqir_callable(operation)

if op_str == "MEASURE":
if op_str.startswith("measure"):
handle_measurement(pyqir_func)
elif op_str in ["Rx", "Ry", "Rz"]:
pyqir_func(self._builder, operation.gate._rads, *qubits)
Expand Down
11 changes: 11 additions & 0 deletions tests/cirq_qir/test_cirq_preprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import numpy as np
import pytest

from qbraid_qir.cirq.exceptions import CirqConversionError
from qbraid_qir.cirq.passes import preprocess_circuit

# pylint: disable=redefined-outer-name
Expand Down Expand Up @@ -56,3 +57,13 @@ def test_empty_circuit_conversion():
circuit = cirq.Circuit()
converted_circuit = preprocess_circuit(circuit)
assert len(converted_circuit.all_qubits()) == 0, "Converted empty circuit should have no qubits"


def test_multi_qubit_measurement_error():
qubits = cirq.LineQubit.range(3)
circuit = cirq.Circuit()
ps = cirq.X(qubits[0]) * cirq.Y(qubits[1]) * cirq.X(qubits[2])
meas_gates = cirq.measure_single_paulistring(ps)
circuit.append(meas_gates)
with pytest.raises(CirqConversionError):
preprocess_circuit(circuit)
36 changes: 36 additions & 0 deletions tests/cirq_qir/test_cirq_to_qir.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@
from tests.qir_utils import (
assert_equal_qir,
check_attributes,
check_measure_op,
check_single_qubit_gate_op,
double_op_call_string,
generic_op_call_string,
get_entry_point_body,
Expand Down Expand Up @@ -161,6 +163,40 @@ def test_measurement(circuit_name, request):
assert len(func) == 4


def test_single_pauli_measurement():
# single Pauli gate
qubit = cirq.LineQubit.range(1)
circuit = cirq.Circuit()
ps = cirq.X(qubit[0])
meas_gates = cirq.measure_single_paulistring(ps)
circuit.append(meas_gates)

qir_module = cirq_to_qir(circuit, record_output=False)
generated_qir = str(qir_module).splitlines()

check_attributes(generated_qir, 1, 1)
check_single_qubit_gate_op(generated_qir, 1, [0], "h")
check_measure_op(generated_qir, 1, [0], [0])


def test_pauli_term_measurements():
# single terms
qubits = cirq.LineQubit.range(3)
circuit = cirq.Circuit()
ps = cirq.X(qubits[0]) * cirq.Y(qubits[1]) * cirq.X(qubits[2])
meas_gates = cirq.measure_paulistring_terms(ps)
for gate in meas_gates:
circuit.append(gate)

qir_module = cirq_to_qir(circuit, record_output=False)
generated_qir = str(qir_module).splitlines()

check_attributes(generated_qir, 3, 3)
check_single_qubit_gate_op(generated_qir, 1, [1], "sdg")
check_single_qubit_gate_op(generated_qir, 3, [0, 1, 2], "h")
check_measure_op(generated_qir, 3, [0, 1, 2], [0, 1, 2])


def test_verify_qir_bell_fixture(pyqir_bell):
"""Test that pyqir fixture generates code equal to test_qir_bell.ll file."""
test_name = "test_qir_bell"
Expand Down

0 comments on commit d2d5464

Please sign in to comment.