Skip to content

Commit

Permalink
Merge pull request #24 from qBraid/rohan
Browse files Browse the repository at this point in the history
Rohan
  • Loading branch information
rjain37 authored Dec 28, 2023
2 parents a8dfc84 + 36610fa commit e2dc0fe
Show file tree
Hide file tree
Showing 4 changed files with 133 additions and 2 deletions.
44 changes: 44 additions & 0 deletions qbraid_qir/cirq/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,34 @@
"""
from typing import Optional

import numpy as np

import cirq
import qbraid.programs.cirq
from pyqir import Context, Module, qir_module

from qbraid_qir.cirq.elements import CirqModule, generate_module_id
from qbraid_qir.cirq.visitor import BasicQisVisitor
from qbraid_qir.exceptions import QirConversionError
from qbraid_qir.cirq.opsets import CIRQ_GATES, get_callable_from_pyqir_name


def _preprocess_circuit(circuit: cirq.Circuit) -> cirq.Circuit:
"""
Preprocesses a Cirq circuit to ensure that it is compatible with the QIR conversion.
Args:
circuit (cirq.Circuit): The Cirq circuit to preprocess.
Returns:
cirq.Circuit: The preprocessed Cirq circuit.
"""
# circuit = cirq.contrib.qasm_import.circuit_from_qasm(circuit.to_qasm()) # decompose?
qprogram = qbraid.programs.cirq.CirqCircuit(circuit)
qprogram._convert_to_line_qubits()
cirq_circuit = qprogram.program
return cirq_circuit


def _preprocess_circuit(circuit: cirq.Circuit) -> cirq.Circuit:
Expand Down Expand Up @@ -62,10 +83,33 @@ def cirq_to_qir(circuit: cirq.Circuit, name: Optional[str] = None, **kwargs) ->
if name is None:
name = generate_module_id(circuit)

# create a variable for circuit.unitary that we will use to create assertions later
input_unitary = circuit.unitary()

circuit = _preprocess_circuit(circuit)

# according to the gateset in CIRQ_GATES, perform gate decomposition;
for moment in circuit:
for op in moment:
if str(op.gate) in CIRQ_GATES:
# i don't know what to do here
callable = get_callable_from_pyqir_name(op)
else:
raise QirConversionError(f"Unsupported gate {str(op.gate)} in circuit.")


# ensure that input/output circuit.unitary() are equivalent.
output_unitary = circuit.unitary()
if not np.allclose(input_unitary, output_unitary):
raise QirConversionError("Cirq circuit unitary changed during conversion.")


llvm_module = qir_module(Context(), name)
module = CirqModule.from_circuit(circuit, llvm_module)

visitor = BasicQisVisitor(**kwargs)
module.accept(visitor)

err = llvm_module.verify()
if err is not None:
raise QirConversionError(err)
Expand Down
43 changes: 42 additions & 1 deletion qbraid_qir/cirq/opsets.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,5 +12,46 @@
Module defining supported Cirq operations/gates.
"""
import cirq
import pyqir._native

CIRQ_GATES = {}
# "barrier",
# "mz",
# "reset",

ZPOWER_DICT = {
1: pyqir._native.z,
0.5: pyqir._native.s,
0.25: pyqir._native.t,
-0.5: pyqir._native.s_adj,
-0.25: pyqir._native.t_adj,
}
CIRQ_GATES = {
'TOFFOLI': pyqir._native.ccx,
'CCX': pyqir._native.ccx,
'CCNOT': pyqir._native.ccx,
'CNOT': pyqir._native.cx,
'CZ': pyqir._native.cz,
'H': pyqir._native.h,
'SWAP': pyqir._native.swap,
'X': pyqir._native.x,
'Y': pyqir._native.y,
# 'Z': pyqir._native.z
}

def get_callable_from_pyqir_name(op: cirq.Operation):
"""Get callable from pyqir name."""
if isinstance(op, cirq.ops.ZPowGate):
return ZPOWER_DICT[op.gate.exponent]
return CIRQ_GATES[str(op.gate)]

# some testcases for the above function
circuit = cirq.Circuit()
# circuit.append(cirq.ops.Z(cirq.LineQubit(0)))
# circuit.append(cirq.ops.CNOT(cirq.LineQubit(1), cirq.LineQubit(2)))
# circuit.append(cirq.ops.CNOT(cirq.LineQubit(2), cirq.LineQubit(3)))
# circuit.append(cirq.ops.H(cirq.LineQubit(0)))
# circuit.append(cirq.ops.H(cirq.LineQubit(1)))

# for op in circuit.all_operations():
# print(isinstance(op.gate, cirq.ops.ZPowGate))
4 changes: 4 additions & 0 deletions qbraid_qir/cirq/visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
entry_point,
)

from qbraid_qir.cirq.opsets import CIRQ_GATES, get_callable_from_pyqir_name
from qbraid_qir.cirq.elements import CirqModule

_log = logging.getLogger(name=__name__)
Expand Down Expand Up @@ -127,6 +128,9 @@ def visit_operation(self, operation: cirq.Operation, qids: list[cirq.Qid]):
results = [pyqir.result(self._module.context, n) for n in qlabels]
# call some function that depends on qubits and results

callable = get_callable_from_pyqir_name(operation)
callable(self._builder, *qubits, *results)

def ir(self) -> str:
return str(self._module)

Expand Down
44 changes: 43 additions & 1 deletion tests/fixtures/basic_gates.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

import cirq
import pytest
import numpy as np

# All of the following dictionaries map from the names of methods on Cirq Circuit objects
# to the name of the equivalent pyqir BasicQisBuilder method
Expand Down Expand Up @@ -68,7 +69,6 @@ def test_fixture():

return test_fixture


# Generate simple single-qubit gate fixtures
for gate in _one_qubit_gates:
name = _fixture_name(gate)
Expand Down Expand Up @@ -96,12 +96,22 @@ def _generate_two_qubit_fixture(gate_name: str):
@pytest.fixture()
def test_fixture():
circuit = cirq.Circuit()
q1 = cirq.NamedQubit("q1")
q2 = cirq.NamedQubit("q2")
circuit.append(getattr(cirq, gate_name)(q1, q2))
qs = cirq.LineQubit(2)
circuit.append(getattr(cirq, gate_name)(qs[0], qs[1]))
return _map_gate_name(gate_name), circuit

return test_fixture

# Create a new function to generate a fixture for n-qubit gates
def _generate_n_qubit_fixture(gate_name: str, n: int):
@pytest.fixture()
def test_fixture():
circuit = cirq.Circuit()
qubits = [cirq.NamedQubit(f"q{i}") for i in range(n)]
circuit.append(getattr(cirq, gate_name)(*qubits))

# Generate double-qubit gate fixtures
for gate in _two_qubit_gates.keys():
Expand All @@ -119,6 +129,38 @@ def test_fixture():

return test_fixture

# New function for more complex gate structures:
def _generate_complex_gate_fixture(gate_sequence):
@pytest.fixture()
def test_fixture():
circuit = cirq.Circuit()
qubits = [cirq.NamedQubit(f"q{i}") for i in range(len(gate_sequence))]
for gate, qubit_indices in gate_sequence:
gates_to_apply = [getattr(cirq, gate)(qubits[i]) for i in qubit_indices]
circuit.append(gates_to_apply)
return circuit
return test_fixture

def test_qft():
for n in range(2, 5): # Test for different numbers of qubits
circuit = cirq.Circuit()
qubits = [cirq.NamedQubit(f'q{i}') for i in range(n)]
circuit.append(cirq.qft(*qubits))
# Add assertions or checks here

@pytest.mark.parametrize("angle", np.linspace(0, 2*np.pi, 5))
def test_rx_gate(angle):
qubit = cirq.NamedQubit('q')
circuit = cirq.Circuit(cirq.rx(angle)(qubit))
# Add assertions or checks for the rotation

def test_bell_state():
qubits = [cirq.NamedQubit(f'q{i}') for i in range(2)]
circuit = cirq.Circuit()
circuit.append([cirq.H(qubits[0]), cirq.CNOT(qubits[0], qubits[1])])
# Check if the circuit produces the correct entangled state

single_op_tests = [_fixture_name(s) for s in _one_qubit_gates]

# Generate three-qubit gate fixtures
for gate in _three_qubit_gates.keys():
Expand Down

0 comments on commit e2dc0fe

Please sign in to comment.