Skip to content

Commit

Permalink
Simplify logic by using properties of input DAG.
Browse files Browse the repository at this point in the history
  • Loading branch information
dlyongemallo committed Jul 22, 2024
1 parent 395a271 commit 45811d8
Showing 1 changed file with 11 additions and 21 deletions.
32 changes: 11 additions & 21 deletions zxpass/zxpass.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,12 @@

"""A transpiler pass for Qiskit which uses ZX-Calculus for circuit optimization, implemented using PyZX."""

from collections import OrderedDict
from typing import Dict, List, Tuple, Callable, Optional, Type, Union
import numpy as np

from qiskit.transpiler.basepasses import TransformationPass
from qiskit.dagcircuit import DAGCircuit, DAGOpNode
from qiskit.circuit import Qubit, Clbit, Instruction
from qiskit.circuit import Qubit, Instruction

from qiskit.circuit.library import XGate, YGate, ZGate, HGate, SGate, TGate, SXGate
from qiskit.circuit.library import SdgGate, TdgGate, SXdgGate
Expand Down Expand Up @@ -101,12 +100,6 @@ class ZXPass(TransformationPass):

def __init__(self, optimize: Optional[Callable[[zx.Circuit], zx.Circuit]] = None):
super().__init__()

self.cregs: OrderedDict[str, Clbit] = OrderedDict()
self.clbits: List[Clbit] = []
self.qregs: OrderedDict[str, Qubit] = OrderedDict()
self.qubits: List[Qubit] = []
self.qubit_to_index: Dict[Qubit, int] = {}
self.optimize: Callable[[zx.Circuit], zx.Circuit] = optimize or _optimize

def _dag_to_circuits_and_nodes(self, dag: DAGCircuit) -> List[Union[zx.Circuit, DAGOpNode]]:
Expand All @@ -118,11 +111,7 @@ def _dag_to_circuits_and_nodes(self, dag: DAGCircuit) -> List[Union[zx.Circuit,
"""

circuits_and_nodes: List[Union[zx.Circuit, DAGOpNode]] = []
self.cregs = dag.cregs
self.clbits = dag.clbits
self.qregs = dag.qregs
self.qubits = dag.qubits
self.qubit_to_index = {qubit: index for index, qubit in enumerate(dag.qubits)}
qubit_to_index = {qubit: index for index, qubit in enumerate(dag.qubits)}

current_circuit: Optional[zx.Circuit] = None
for node in dag.topological_op_nodes():
Expand All @@ -149,7 +138,7 @@ def _dag_to_circuits_and_nodes(self, dag: DAGCircuit) -> List[Union[zx.Circuit,
kwargs = {'adjoint': adjoint[0]} if adjoint else {}
if current_circuit is None:
current_circuit = zx.Circuit(len(dag.qubits))
current_circuit.add_gate(gate_type(*[self.qubit_to_index[qarg] for qarg in node.qargs], # type: ignore
current_circuit.add_gate(gate_type(*[qubit_to_index[qarg] for qarg in node.qargs], # type: ignore
*[param / np.pi for param in node.op.params], # type: ignore
**kwargs)) # type: ignore

Expand All @@ -159,18 +148,19 @@ def _dag_to_circuits_and_nodes(self, dag: DAGCircuit) -> List[Union[zx.Circuit,

return circuits_and_nodes

def _recover_dag(self, circuits_and_nodes: List[Union[zx.Circuit, DAGOpNode]]) -> DAGCircuit:
def _recover_dag(self, circuits_and_nodes: List[Union[zx.Circuit, DAGOpNode]], original_dag: DAGCircuit) -> DAGCircuit:
"""Recover a DAG from a list of a pyzx Circuits and DAGOpNodes.
:param circuits_and_nodes: The list of (optimized) PyZX Circuits and DAGOpNodes from which to recover the DAG.
:param original_dag: The original input DAG to ZXPass.
:return: An optimized version of the original input DAG to ZXPass.
"""

dag = DAGCircuit()
dag.cregs = self.cregs
dag.add_clbits(self.clbits)
dag.qregs = self.qregs
dag.add_qubits(self.qubits)
dag.cregs = original_dag.cregs
dag.add_clbits(original_dag.clbits)
dag.qregs = original_dag.qregs
dag.add_qubits(original_dag.qubits)
for circuit_or_node in circuits_and_nodes:
if isinstance(circuit_or_node, DAGOpNode):
dag.apply_operation_back(circuit_or_node.op, circuit_or_node.qargs, circuit_or_node.cargs)
Expand All @@ -183,7 +173,7 @@ def _recover_dag(self, circuits_and_nodes: List[Union[zx.Circuit, DAGOpNode]]) -
qargs: List[Qubit] = []
for attr in ['ctrl1', 'ctrl2', 'control', 'target']:
if hasattr(gate, attr):
qargs.append(self.qubits[getattr(gate, attr)])
qargs.append(original_dag.qubits[getattr(gate, attr)])
params: List[float] = []
if hasattr(gate, 'phase'):
params = [float(gate.phase) * np.pi]
Expand All @@ -208,7 +198,7 @@ def run(self, dag: DAGCircuit) -> DAGCircuit:
circuits_and_nodes = [self.optimize(circuit) if isinstance(circuit, zx.Circuit) else circuit
for circuit in circuits_and_nodes]

return self._recover_dag(circuits_and_nodes)
return self._recover_dag(circuits_and_nodes, dag)

def name(self) -> str:
return "ZXPass"

0 comments on commit 45811d8

Please sign in to comment.