Skip to content

Commit

Permalink
Bugfix in Cirq Interop: Attempt 2 (#1100)
Browse files Browse the repository at this point in the history
* Greedy topological sort of the binst graph to minimize qubit allocations / deallocations

* Docstring for _PrioritizedItem

* Fix failing CI

* Undo thc.ipynb change

* Another attempt at cirq interop bugfix

* Fix formatting

* Fix as_cirq_op for SGate(is_adjoint=True)
  • Loading branch information
tanujkhattar authored Jul 8, 2024
1 parent 19928a3 commit 3315602
Show file tree
Hide file tree
Showing 16 changed files with 98 additions and 28 deletions.
8 changes: 0 additions & 8 deletions qualtran/_infra/adjoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@

import cirq
from attrs import frozen
from numpy.typing import NDArray

from .composite_bloq import _binst_to_cxns, _cxns_to_soq_dict, _map_soqs, _reg_to_soq, BloqBuilder
from .gate_with_registers import GateWithRegisters
Expand Down Expand Up @@ -142,13 +141,6 @@ def decompose_bloq(self) -> 'CompositeBloq':
"""The decomposition is the adjoint of `subbloq`'s decomposition."""
return self.subbloq.decompose_bloq().adjoint()

def decompose_from_registers(
self, *, context: cirq.DecompositionContext, **quregs: NDArray[cirq.Qid] # type: ignore[type-var]
) -> cirq.OP_TREE:
if isinstance(self.subbloq, GateWithRegisters):
return cirq.inverse(self.subbloq.decompose_from_registers(context=context, **quregs))
return super().decompose_from_registers(context=context, **quregs)

def _circuit_diagram_info_(
self, args: 'cirq.CircuitDiagramInfoArgs'
) -> cirq.CircuitDiagramInfo:
Expand Down
6 changes: 3 additions & 3 deletions qualtran/_infra/gate_with_registers_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,11 +141,11 @@ def test_gate_with_registers_decompose_from_context_auto_generated():
cirq.testing.assert_has_diagram(
circuit,
"""
l: ───BloqWithDecompose───X───────free───
l: ───BloqWithDecompose───X───
r: ───r───────────────────alloc───Z──────
r: ───r───────────────────Z───
t: ───t───────────────────Y──────────────
t: ───t───────────────────Y───
""",
)

Expand Down
3 changes: 3 additions & 0 deletions qualtran/bloqs/arithmetic/comparison.py
Original file line number Diff line number Diff line change
Expand Up @@ -610,6 +610,9 @@ def build_call_graph(self, ssa: 'SympySymbolAllocator') -> Set['BloqCountT']:
def _has_unitary_(self):
return True

def adjoint(self) -> 'Bloq':
return self


@bloq_example
def _leq_symb() -> LessThanEqual:
Expand Down
3 changes: 2 additions & 1 deletion qualtran/bloqs/basic_gates/s_gate.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,8 @@ def as_cirq_op(
import cirq

(q,) = q
return cirq.S(q), {'q': np.array([q])}
p = -1 if self.is_adjoint else 1
return cirq.S(q) ** p, {'q': np.array([q])}

def pretty_name(self) -> str:
maybe_dag = '†' if self.is_adjoint else ''
Expand Down
3 changes: 2 additions & 1 deletion qualtran/bloqs/basic_gates/s_gate_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,10 @@ def test_to_cirq():
bb = BloqBuilder()
q = bb.add(PlusState())
q = bb.add(SGate(), q=q)
q = bb.add(SGate().adjoint(), q=q)
cbloq = bb.finalize(q=q)
circuit = cbloq.to_cirq_circuit()
cirq.testing.assert_has_diagram(circuit, "_c(0): ───H───S───")
cirq.testing.assert_has_diagram(circuit, "_c(0): ───H───S───S^-1───")


def test_tensors():
Expand Down
15 changes: 14 additions & 1 deletion qualtran/bloqs/bookkeeping/allocate.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import cached_property
from typing import Dict, List, Tuple, TYPE_CHECKING
from typing import Dict, List, Tuple, TYPE_CHECKING, Union

import numpy as np
import sympy
from attrs import frozen

Expand All @@ -34,8 +35,11 @@
from qualtran.drawing import directional_text_box, Text, WireSymbol

if TYPE_CHECKING:
import cirq
import quimb.tensor as qtn

from qualtran.cirq_interop import CirqQuregT


@frozen
class Allocate(_BookkeepingBloq):
Expand Down Expand Up @@ -83,6 +87,15 @@ def wire_symbol(self, reg: Register, idx: Tuple[int, ...] = tuple()) -> 'WireSym
assert reg.name == 'reg'
return directional_text_box('alloc', Side.RIGHT)

def as_cirq_op(
self, qubit_manager: 'cirq.QubitManager'
) -> Tuple[Union['cirq.Operation', None], Dict[str, 'CirqQuregT']]:
shape = (*self.signature[0].shape, self.signature[0].bitsize)
return (
None,
{'reg': np.array(qubit_manager.qalloc(self.signature.n_qubits())).reshape(shape)},
)


@bloq_example
def _alloc() -> Allocate:
Expand Down
10 changes: 9 additions & 1 deletion qualtran/bloqs/bookkeeping/free.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.

from functools import cached_property
from typing import Dict, List, Tuple, TYPE_CHECKING
from typing import Dict, List, Tuple, TYPE_CHECKING, Union

import sympy
from attrs import frozen
Expand All @@ -35,8 +35,10 @@
from qualtran.drawing import directional_text_box, Text, WireSymbol

if TYPE_CHECKING:
import cirq
import quimb.tensor as qtn

from qualtran.cirq_interop import CirqQuregT
from qualtran.simulation.classical_sim import ClassicalValT


Expand Down Expand Up @@ -92,6 +94,12 @@ def wire_symbol(self, reg: Register, idx: Tuple[int, ...] = tuple()) -> 'WireSym
assert reg.name == 'reg'
return directional_text_box('free', Side.LEFT)

def as_cirq_op(
self, qubit_manager: 'cirq.QubitManager', reg: 'CirqQuregT'
) -> Tuple[Union['cirq.Operation', None], Dict[str, 'CirqQuregT']]:
qubit_manager.qfree(reg.flatten().tolist())
return (None, {})


@bloq_example
def _free() -> Free:
Expand Down
6 changes: 3 additions & 3 deletions qualtran/bloqs/data_loading/select_swap_qrom_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,9 +62,9 @@ def test_select_swap_qrom(data, block_size):
cirq.decompose_once(qrom.on_registers(**qubit_regs), context=context)
)

dirty_target_ancilla = [
q for q in qrom_circuit.all_qubits() if isinstance(q, cirq.ops.BorrowableQubit)
]
dirty_target_ancilla = sorted(
qrom_circuit.all_qubits() - set(q for qs in qubit_regs.values() for q in qs.flatten())
)

circuit = cirq.Circuit(
# Prepare dirty ancillas in an arbitrary state.
Expand Down
4 changes: 2 additions & 2 deletions qualtran/bloqs/reflections/reflection_using_prepare.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,12 @@
from qualtran.bloqs.basic_gates.global_phase import GlobalPhase
from qualtran.bloqs.basic_gates.rotation import ZPowGate
from qualtran.bloqs.basic_gates.x_basis import XGate
from qualtran.bloqs.block_encoding.lcu_select_and_prepare import PrepareOracle
from qualtran.bloqs.mcmt.multi_control_multi_target_pauli import MultiControlPauli
from qualtran.resource_counting.generalizers import ignore_split_join
from qualtran.symbolics.types import SymbolicInt

if TYPE_CHECKING:
from qualtran.bloqs.block_encoding.lcu_select_and_prepare import PrepareOracle
from qualtran.resource_counting import BloqCountT, SympySymbolAllocator


Expand Down Expand Up @@ -77,7 +77,7 @@ class ReflectionUsingPrepare(SpecializedSingleQubitControlledGate):
Babbush et. al. (2018). Figure 1.
"""

prepare_gate: PrepareOracle
prepare_gate: 'PrepareOracle'
control_val: Optional[int] = None
global_phase: complex = 1
eps: float = 1e-11
Expand Down
12 changes: 10 additions & 2 deletions qualtran/bloqs/reflections/reflection_using_prepare_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import pytest
from numpy.typing import NDArray

from qualtran import Bloq
from qualtran import Adjoint, Bloq
from qualtran._infra.gate_with_registers import get_named_qubits
from qualtran.bloqs.arithmetic import LessThanConstant, LessThanEqual
from qualtran.bloqs.basic_gates import ZPowGate
Expand All @@ -30,6 +30,7 @@
from qualtran.bloqs.reflections.prepare_identity import PrepareIdentity
from qualtran.bloqs.reflections.reflection_using_prepare import ReflectionUsingPrepare
from qualtran.bloqs.state_preparation import StatePreparationAliasSampling
from qualtran.cirq_interop import BloqAsCirqGate
from qualtran.cirq_interop.testing import GateHelper
from qualtran.resource_counting.generalizers import (
ignore_alloc_free,
Expand Down Expand Up @@ -58,6 +59,13 @@ def keep(op: cirq.Operation):
ret = op in gateset_to_keep
if op.gate is not None and isinstance(op.gate, cirq.ops.raw_types._InverseCompositeGate):
ret |= op.gate._original in gateset_to_keep
if op.gate is not None and isinstance(op.gate, Adjoint):
subgate = (
op.gate.subbloq
if isinstance(op.gate.subbloq, cirq.Gate)
else BloqAsCirqGate(op.gate.subbloq)
)
ret |= subgate in gateset_to_keep
return ret


Expand All @@ -73,7 +81,7 @@ def construct_gate_helper_and_qubit_order(gate, decompose_once: bool = False):
)
ordered_input = list(itertools.chain(*g.quregs.values()))
qubit_order = cirq.QubitOrder.explicit(ordered_input, fallback=cirq.QubitOrder.DEFAULT)
assert len(circuit.all_qubits()) < 30
assert len(circuit.all_qubits()) < 24
return g, qubit_order, circuit


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -104,3 +104,16 @@ def test_prepare_uniform_superposition_consistent_protocols():
PrepareUniformSuperposition(5, cvs=()),
PrepareUniformSuperposition(5, cvs=[]),
)


def test_prepare_uniform_superposition_adjoint():
n = 3
target = cirq.NamedQubit.range((n - 1).bit_length(), prefix='target')
control = [cirq.NamedQubit('control')]
op = PrepareUniformSuperposition(n, cvs=(0,)).on_registers(ctrl=control, target=target)
gqm = cirq.GreedyQubitManager(prefix="_ancilla", maximize_reuse=True)
context = cirq.DecompositionContext(gqm)
circuit = cirq.Circuit(op, cirq.decompose(cirq.inverse(op), context=context))
identity = cirq.Circuit(cirq.identity_each(*circuit.all_qubits())).final_state_vector()
result = cirq.Simulator(dtype=np.complex128).simulate(circuit)
np.testing.assert_allclose(result.final_state_vector, identity, atol=1e-8)
14 changes: 13 additions & 1 deletion qualtran/bloqs/swap_network/cswap_approx_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,11 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import random
from typing import Dict, Tuple, Union

import cirq
import numpy as np
import pytest
import sympy

Expand All @@ -38,6 +39,17 @@ def test_cswap_approx_decomp():
assert_valid_bloq_decomposition(csa)


def test_cswap_approx_decomposition():
csa = CSwapApprox(4)
circuit = (
csa.as_composite_bloq().to_cirq_circuit()
+ csa.adjoint().as_composite_bloq().to_cirq_circuit()
)
initial_state = cirq.testing.random_superposition(2**9, random_state=1234)
result = cirq.Simulator(dtype=np.complex128).simulate(circuit, initial_state=initial_state)
np.testing.assert_allclose(result.final_state_vector, initial_state)


@pytest.mark.parametrize('n', [5, 32])
def test_approx_cswap_t_count(n):
cswap = CSwapApprox(bitsize=n)
Expand Down
2 changes: 1 addition & 1 deletion qualtran/bloqs/swap_network/swap_with_zero.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ def build_composite_bloq(

def build_call_graph(self, ssa: 'SympySymbolAllocator') -> Set['BloqCountT']:
num_swaps = prod(x for x in self.n_target_registers) - 1
return {(CSwapApprox(self.target_bitsize), num_swaps)}
return {(self.cswap_n, num_swaps)}

def _circuit_diagram_info_(self, args) -> cirq.CircuitDiagramInfo:
from qualtran.cirq_interop._bloq_to_cirq import _wire_symbol_to_cirq_diagram_info
Expand Down
1 change: 0 additions & 1 deletion qualtran/cirq_interop/_bloq_to_cirq.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,6 @@ def _bloq_to_cirq_op(
del qvar_to_qreg[soq]

op, out_quregs = bloq.as_cirq_op(qubit_manager=qubit_manager, **in_quregs)

# 2. Update the mappings based on output soquets and `out_quregs`.
for cxn in succ_cxns:
soq = cxn.left
Expand Down
2 changes: 1 addition & 1 deletion qualtran/cirq_interop/_bloq_to_cirq_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,7 @@ def test_bloq_as_cirq_gate_left_register():
bb.free(q)
cbloq = bb.finalize()
circuit = cbloq.to_cirq_circuit()
cirq.testing.assert_has_diagram(circuit, """_c(0): ───alloc───X───free───""")
cirq.testing.assert_has_diagram(circuit, """_c(0): ───X───""")


def test_bloq_as_cirq_gate_for_mod_exp():
Expand Down
24 changes: 22 additions & 2 deletions qualtran/cirq_interop/_interop_qubit_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,30 @@ def __init__(self, qm: Optional[cirq.QubitManager] = None):
self._managed_qubits: Set[cirq.Qid] = set()

def qalloc(self, n: int, dim: int = 2) -> List['cirq.Qid']:
return self._qm.qalloc(n, dim)
ret: List['cirq.Qid'] = []
qubits_to_free: List['cirq.Qid'] = []
while len(ret) < n:
new_alloc = self._qm.qalloc(n - len(ret), dim)
for q in new_alloc:
if q in self._managed_qubits:
qubits_to_free.append(q)
else:
ret.append(q)
self._qm.qfree(qubits_to_free)
return ret

def qborrow(self, n: int, dim: int = 2) -> List['cirq.Qid']:
return self._qm.qborrow(n, dim)
ret: List['cirq.Qid'] = []
qubits_to_free: List['cirq.Qid'] = []
while len(ret) < n:
new_alloc = self._qm.qborrow(n - len(ret), dim)
for q in new_alloc:
if q in self._managed_qubits:
qubits_to_free.append(q)
else:
ret.append(q)
self._qm.qfree(qubits_to_free)
return ret

def manage_qubits(self, qubits: Iterable[cirq.Qid]):
self._managed_qubits |= set(qubits)
Expand Down

0 comments on commit 3315602

Please sign in to comment.