Skip to content

Commit

Permalink
Extend DecomposeClassicalExp to handle ClExprOp (#1678)
Browse files Browse the repository at this point in the history
  • Loading branch information
cqc-alec authored Nov 15, 2024
1 parent 5055846 commit c91e562
Show file tree
Hide file tree
Showing 6 changed files with 191 additions and 17 deletions.
6 changes: 3 additions & 3 deletions pytket/binders/passes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ static PassPtr gen_default_aas_routing_pass(

const PassPtr &DecomposeClassicalExp() {
// a special box decomposer for Circuits containing
// ClassicalExpBox<py::object>
// ClassicalExpBox<py::object> and ClExprOp
static const PassPtr pp([]() {
Transform t = Transform([](Circuit &circ) {
py::module decomposer =
Expand Down Expand Up @@ -483,8 +483,8 @@ PYBIND11_MODULE(passes, m) {
py::arg("excluded_opgroups") = std::unordered_set<std::string>());
m.def(
"DecomposeClassicalExp", &DecomposeClassicalExp,
"Replaces each :py:class:`ClassicalExpBox` by a sequence of "
"classical gates.");
"Replaces each :py:class:`ClassicalExpBox` and `ClExprOp` by a sequence "
"of classical gates.");
m.def(
"DecomposeMultiQubitsCX", &DecomposeMultiQubitsCX,
"Converts all multi-qubit gates into CX and single-qubit gates.");
Expand Down
2 changes: 2 additions & 0 deletions pytket/docs/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ Features:
and `flatten_registers`
* Implement `dagger()` and `transpose()` for `CustomGate`.
* Use `ClExprOp` by default when converting from QASM.
* Extend `DecomposeClassicalExp` to handle `ClExprOp` as well as
`ClassicalExpBox`.

Deprecations:

Expand Down
2 changes: 1 addition & 1 deletion pytket/pytket/_tket/passes.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -342,7 +342,7 @@ def DecomposeBoxes(excluded_types: set[pytket._tket.circuit.OpType] = set(), exc
"""
def DecomposeClassicalExp() -> BasePass:
"""
Replaces each :py:class:`ClassicalExpBox` by a sequence of classical gates.
Replaces each :py:class:`ClassicalExpBox` and `ClExprOp` by a sequence of classical gates.
"""
def DecomposeMultiQubitsCX() -> BasePass:
"""
Expand Down
168 changes: 164 additions & 4 deletions pytket/pytket/circuit/decompose_classical.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,28 @@
import copy
from collections.abc import Callable
from heapq import heappop, heappush
from typing import Generic, TypeVar

from pytket._tket.circuit import Circuit, ClassicalExpBox, Conditional, OpType
from typing import Any, Generic, TypeVar

from pytket._tket.circuit import (
Circuit,
ClassicalExpBox,
ClBitVar,
ClExpr,
ClExprOp,
ClOp,
ClRegVar,
Conditional,
OpType,
WiredClExpr,
)
from pytket._tket.unit_id import (
_TEMP_BIT_NAME,
_TEMP_BIT_REG_BASE,
_TEMP_REG_SIZE,
Bit,
BitRegister,
)
from pytket.circuit.clexpr import check_register_alignments, has_reg_output
from pytket.circuit.logic_exp import (
BitLogicExp,
BitWiseOp,
Expand Down Expand Up @@ -242,8 +254,131 @@ def recursive_walk(
return recursive_walk


class ClExprDecomposer:
def __init__(
self,
circ: Circuit,
bit_posn: dict[int, int],
reg_posn: dict[int, list[int]],
args: list[Bit],
bit_heap: BitHeap,
reg_heap: RegHeap,
kwargs: dict[str, Any],
):
self.circ: Circuit = circ
self.bit_posn: dict[int, int] = bit_posn
self.reg_posn: dict[int, list[int]] = reg_posn
self.args: list[Bit] = args
self.bit_heap: BitHeap = bit_heap
self.reg_heap: RegHeap = reg_heap
self.kwargs: dict[str, Any] = kwargs
# Construct maps from int (i.e. ClBitVar) to Bit, and from int (i.e. ClRegVar)
# to BitRegister:
self.bit_vars = {i: args[p] for i, p in bit_posn.items()}
self.reg_vars = {
i: BitRegister(args[p[0]].reg_name, len(p)) for i, p in reg_posn.items()
}

def add_var(self, var: Variable) -> None:
"""Add a Bit or BitRegister to the circuit if not already present."""
if isinstance(var, Bit):
self.circ.add_bit(var, reject_dups=False)
else:
assert isinstance(var, BitRegister)
for bit in var.to_list():
self.circ.add_bit(bit, reject_dups=False)

def set_bits(self, var: Variable, val: int) -> None:
"""Set the value of a Bit or BitRegister."""
assert val >= 0
if isinstance(var, Bit):
assert val >> 1 == 0
self.circ.add_c_setbits([bool(val)], [var], **self.kwargs)
else:
assert isinstance(var, BitRegister)
assert val >> var.size == 0
self.circ.add_c_setreg(val, var, **self.kwargs)

def decompose_expr(self, expr: ClExpr, out_var: Variable | None) -> Variable:
"""Add the decomposed expression to the circuit and return the Bit or
BitRegister that contains the result.
:param expr: the expression to decompose
:param out_var: where to put the output (if None, create a new scratch location)
"""
op: ClOp = expr.op
heap: VarHeap = self.reg_heap if has_reg_output(op) else self.bit_heap

# Eliminate (recursively) subsidiary expressions from the arguments, and convert
# all terms to Bit or BitRegister:
terms: list[Variable] = []
for arg in expr.args:
if isinstance(arg, int):
# Assign to a fresh variable
fresh_var = heap.fresh_var()
self.add_var(fresh_var)
self.set_bits(fresh_var, arg)
terms.append(fresh_var)
elif isinstance(arg, ClBitVar):
terms.append(self.bit_vars[arg.index])
elif isinstance(arg, ClRegVar):
terms.append(self.reg_vars[arg.index])
else:
assert isinstance(arg, ClExpr)
terms.append(self.decompose_expr(arg, None))

# Enable reuse of temporary terms:
for term in terms:
if heap.is_heap_var(term):
heap.push(term)

if out_var is None:
out_var = heap.fresh_var()
self.add_var(out_var)
match op:
case ClOp.BitAnd:
self.circ.add_c_and(*terms, out_var, **self.kwargs) # type: ignore
case ClOp.BitNot:
self.circ.add_c_not(*terms, out_var, **self.kwargs) # type: ignore
case ClOp.BitOne:
assert isinstance(out_var, Bit)
self.circ.add_c_setbits([True], [out_var], **self.kwargs)
case ClOp.BitOr:
self.circ.add_c_or(*terms, out_var, **self.kwargs) # type: ignore
case ClOp.BitXor:
self.circ.add_c_xor(*terms, out_var, **self.kwargs) # type: ignore
case ClOp.BitZero:
assert isinstance(out_var, Bit)
self.circ.add_c_setbits([False], [out_var], **self.kwargs)
case ClOp.RegAnd:
self.circ.add_c_and_to_registers(*terms, out_var, **self.kwargs) # type: ignore
case ClOp.RegNot:
self.circ.add_c_not_to_registers(*terms, out_var, **self.kwargs) # type: ignore
case ClOp.RegOne:
assert isinstance(out_var, BitRegister)
self.circ.add_c_setbits(
[True] * out_var.size, out_var.to_list(), **self.kwargs
)
case ClOp.RegOr:
self.circ.add_c_or_to_registers(*terms, out_var, **self.kwargs) # type: ignore
case ClOp.RegXor:
self.circ.add_c_xor_to_registers(*terms, out_var, **self.kwargs) # type: ignore
case ClOp.RegZero:
assert isinstance(out_var, BitRegister)
self.circ.add_c_setbits(
[False] * out_var.size, out_var.to_list(), **self.kwargs
)
case _:
raise DecomposeClassicalError(
f"{op} cannot be decomposed to TKET primitives."
)
return out_var


def _decompose_expressions(circ: Circuit) -> tuple[Circuit, bool]:
"""Rewrite a circuit command-wise, decomposing ClassicalExpBox."""
"""Rewrite a circuit command-wise, decomposing ClassicalExpBox and ClExprOp."""
if not check_register_alignments(circ):
raise DecomposeClassicalError("Circuit contains non-register-aligned ClExprOp.")
bit_heap = BitHeap()
reg_heap = RegHeap()
# add already used heap variables to heaps
Expand Down Expand Up @@ -343,6 +478,31 @@ def _decompose_expressions(circ: Circuit) -> tuple[Circuit, bool]:
replace_targets[out_reg] = comp_reg
modified = True
continue

elif optype == OpType.ClExpr:
assert isinstance(op, ClExprOp)
wexpr: WiredClExpr = op.expr
expr: ClExpr = wexpr.expr
bit_posn = wexpr.bit_posn
reg_posn = wexpr.reg_posn
output_posn = wexpr.output_posn
assert len(output_posn) > 0
output0 = args[output_posn[0]]
assert isinstance(output0, Bit)
out_var: Variable = (
BitRegister(output0.reg_name, len(output_posn))
if has_reg_output(expr.op)
else output0
)
decomposer = ClExprDecomposer(
newcirc, bit_posn, reg_posn, args, bit_heap, reg_heap, kwargs # type: ignore
)
comp_var = decomposer.decompose_expr(expr, out_var)
if comp_var != out_var:
replace_targets[out_var] = comp_var
modified = True
continue

if optype == OpType.Barrier:
# add_gate doesn't work for metaops
newcirc.add_barrier(args)
Expand Down
14 changes: 10 additions & 4 deletions pytket/tests/qasm_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
reg_lt,
reg_neq,
)
from pytket.circuit.decompose_classical import DecomposeClassicalError
from pytket.circuit.logic_exp import BitWiseOp, create_bit_logic_exp
from pytket.passes import DecomposeBoxes, DecomposeClassicalExp
from pytket.qasm import (
Expand Down Expand Up @@ -464,14 +465,18 @@ def test_extended_qasm() -> None:

assert circuit_to_qasm_str(c2, "hqslib1")

assert not DecomposeClassicalExp().apply(c)
with pytest.raises(DecomposeClassicalError):
DecomposeClassicalExp().apply(c)


def test_decomposable_extended() -> None:
@pytest.mark.parametrize("use_clexpr", [True, False])
def test_decomposable_extended(use_clexpr: bool) -> None:
fname = str(curr_file_path / "qasm_test_files/test18.qasm")
out_fname = str(curr_file_path / "qasm_test_files/test18_output.qasm")

c = circuit_from_qasm_wasm(fname, "testfile.wasm", maxwidth=64, use_clexpr=True)
c = circuit_from_qasm_wasm(
fname, "testfile.wasm", maxwidth=64, use_clexpr=use_clexpr
)
DecomposeClassicalExp().apply(c)

out_qasm = circuit_to_qasm_str(c, "hqslib1", maxwidth=64)
Expand Down Expand Up @@ -1233,7 +1238,8 @@ def test_multibitop() -> None:
test_hqs_conditional_params()
test_barrier()
test_barrier_2()
test_decomposable_extended()
test_decomposable_extended(True)
test_decomposable_extended(False)
test_alternate_encoding()
test_header_stops_gate_definition()
test_tk2_definition()
Expand Down
16 changes: 11 additions & 5 deletions pytket/tests/qasm_test_files/test18_output.qasm
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,22 @@ creg a[2];
creg b[3];
creg c[4];
creg d[1];
creg tk_SCRATCH_BIT[7];
creg tk_SCRATCH_BITREG_0[64];
c = 2;
tk_SCRATCH_BITREG_0[0] = b[0] & a[0];
tk_SCRATCH_BITREG_0[1] = b[1] & a[1];
c[0] = a[0];
c[1] = a[1];
if(b!=2) c[1] = ((b[1] & a[1]) | a[0]);
c = ((b & a) | d);
d[0] = (a[0] ^ 1);
a = CCE(a, b);
if(b!=2) tk_SCRATCH_BIT[6] = b[1] & a[1];
c[0] = tk_SCRATCH_BITREG_0[0] | d[0];
if(b!=2) c[1] = tk_SCRATCH_BIT[6] | a[0];
tk_SCRATCH_BIT[6] = 1;
d[0] = a[0] ^ tk_SCRATCH_BIT[6];
if(c>=2) h q[0];
CCE(c);
a = CCE(a, b);
if(c<=2) h q[0];
CCE(c);
if(c<=1) h q[0];
if(c>=3) h q[0];
if(c!=2) h q[0];
Expand Down

0 comments on commit c91e562

Please sign in to comment.