diff --git a/docs/reference.md b/docs/reference.md index 95f02aa1..1ba59cf2 100644 --- a/docs/reference.md +++ b/docs/reference.md @@ -14,9 +14,9 @@ ::: opensquirrel.replacer -::: opensquirrel.squirrel_ast +::: opensquirrel.squirrel_ir -::: opensquirrel.squirrel_ast_creator +::: opensquirrel.squirrel_ir_creator ::: opensquirrel.squirrel_error_handler diff --git a/opensquirrel/__init__.py b/opensquirrel/__init__.py index e75d3ac3..d9fbe6bb 100644 --- a/opensquirrel/__init__.py +++ b/opensquirrel/__init__.py @@ -1,3 +1,3 @@ from opensquirrel.circuit import Circuit from opensquirrel.circuit_builder import CircuitBuilder -from opensquirrel.default_gates import DefaultGates +from opensquirrel.default_gates import default_gate_aliases diff --git a/opensquirrel/circuit.py b/opensquirrel/circuit.py index ad7f00ee..adafeb74 100644 --- a/opensquirrel/circuit.py +++ b/opensquirrel/circuit.py @@ -1,27 +1,21 @@ -import antlr4 +from typing import Callable + import numpy as np -from opensquirrel.default_gates import DefaultGates # For the doctest. -from opensquirrel.mckay_decomposer import McKayDecomposer -from opensquirrel.parsing.antlr.generated import CQasm3Lexer, CQasm3Parser -from opensquirrel.parsing.antlr.squirrel_ast_creator import SquirrelASTCreator -from opensquirrel.replacer import Replacer -from opensquirrel.squirrel_ast import SquirrelAST -from opensquirrel.squirrel_error_handler import SquirrelErrorHandler -from opensquirrel.test_interpreter import TestInterpreter -from opensquirrel.type_checker import TypeChecker -from opensquirrel.writer import Writer +import opensquirrel.mckay_decomposer as mckay_decomposer +import opensquirrel.replacer as replacer +import opensquirrel.test_interpreter as test_interpreter +import opensquirrel.writer as writer +from opensquirrel.default_gates import default_gate_aliases, default_gate_set # For the doctest. +from opensquirrel.parsing.antlr.squirrel_ir_from_string import squirrel_ir_from_string +from opensquirrel.squirrel_ir import SquirrelIR, Gate class Circuit: """The Circuit class is the only interface to access OpenSquirrel's features. - A Circuit object is constructed from a cQasm3 string, representing a quantum circuit, and a Python dictionary - containing the prototypes and semantic of the allowed quantum gates. - A default set of gates is exposed as `opensquirrel.default_gates` but it can be replaced and extended. - Examples: - >>> c = Circuit.from_string(DefaultGates, "version 3.0; qubit[3] q; h q[0]") + >>> c = Circuit.from_string("version 3.0; qubit[3] q; h q[0]") >>> c version 3.0 @@ -39,19 +33,15 @@ class Circuit: rz q[0], 1.5707963 x90 q[0] - - Args: - squirrelAST: OpenSquirrel internal AST. """ - def __init__(self, squirrelAST: SquirrelAST): - """Create a circuit object from a SquirrelAST object.""" + def __init__(self, squirrel_ir: SquirrelIR): + """Create a circuit object from a SquirrelIR object.""" - self.gates = squirrelAST.gates - self.squirrel_ast = squirrelAST + self.squirrel_ir = squirrel_ir @classmethod - def from_string(cls, gates: dict, cqasm3_string: str): + def from_string(cls, cqasm3_string: str, gate_set: [Callable[..., Gate]] = default_gate_set, gate_aliases: dict[str, Callable[..., Gate]] = default_gate_aliases): """Create a circuit object from a cQasm3 string. All the gates in the circuit need to be defined in the `gates` argument. @@ -61,31 +51,15 @@ def from_string(cls, gates: dict, cqasm3_string: str): * for example of `gates` dictionary, please look at TestGates.py """ - input_stream = antlr4.InputStream(cqasm3_string) - - lexer = CQasm3Lexer.CQasm3Lexer(input_stream) - - stream = antlr4.CommonTokenStream(lexer) - - parser = CQasm3Parser.CQasm3Parser(stream) - - parser.removeErrorListeners() - parser.addErrorListener(SquirrelErrorHandler()) - - tree = parser.prog() - - typeChecker = TypeChecker(gates) - typeChecker.visit(tree) # FIXME: return error instead of throwing? - - squirrelASTCreator = SquirrelASTCreator(gates) - - return Circuit(squirrelASTCreator.visit(tree)) + return Circuit(squirrel_ir_from_string(cqasm3_string, gate_set = gate_set, gate_aliases = gate_aliases)) - def getNumberOfQubits(self) -> int: - return self.squirrel_ast.nQubits + @property + def number_of_qubits(self) -> int: + return self.squirrel_ir.number_of_qubits - def getQubitRegisterName(self) -> str: - return self.squirrel_ast.qubitRegisterName + @property + def qubit_register_name(self) -> str: + return self.squirrel_ir.qubit_register_name def decompose_mckay(self): """Perform gate fusion on all one-qubit gates and decompose them in the McKay style. @@ -98,19 +72,16 @@ def decompose_mckay(self): for the input and output circuit - those outputs should be equivalent modulo global phase. """ - mcKayDecomposer = McKayDecomposer(self.gates) - self.squirrel_ast = mcKayDecomposer.process(self.squirrel_ast) + self.squirrel_ir = mckay_decomposer.decompose_mckay(self.squirrel_ir) # FIXME: inplace - def replace(self, gateName: str, f): + def replace(self, gate_name: str, f): """Manually replace occurrences of a given gate with a list of gates. * this can be called decomposition - but it's the least fancy version of it * function parameter gives the decomposition based on parameters of original gate """ - assert gateName in self.gates, f"Cannot replace unknown gate `{gateName}`" - replacer = Replacer(self.gates) # FIXME: only one instance of this is needed. - self.squirrel_ast = replacer.process(self.squirrel_ast, gateName, f) + replacer.replace(self.squirrel_ir, gate_name, f) def test_get_circuit_matrix(self) -> np.ndarray: """Get the (large) unitary matrix corresponding to the circuit. @@ -120,8 +91,7 @@ def test_get_circuit_matrix(self) -> np.ndarray: * result is stored as a numpy array of complex numbers """ - interpreter = TestInterpreter(self.gates) - return interpreter.process(self.squirrel_ast) + return test_interpreter.get_circuit_matrix(self.squirrel_ir) def __repr__(self) -> str: """Write the circuit to a cQasm3 string. @@ -129,5 +99,4 @@ def __repr__(self) -> str: * comments are removed """ - writer = Writer(self.gates) - return writer.process(self.squirrel_ast) + return writer.squirrel_ir_to_string(self.squirrel_ir) diff --git a/opensquirrel/circuit_builder.py b/opensquirrel/circuit_builder.py index 88f82ec6..d86604dc 100644 --- a/opensquirrel/circuit_builder.py +++ b/opensquirrel/circuit_builder.py @@ -1,32 +1,53 @@ +import inspect +from typing import Callable + from opensquirrel.circuit import Circuit -from opensquirrel.default_gates import DefaultGates -from opensquirrel.squirrel_ast import SquirrelAST +from opensquirrel.default_gates import default_gate_set, default_gate_aliases +from opensquirrel.gate_library import GateLibrary +from opensquirrel.squirrel_ir import Comment, Qubit, SquirrelIR, Gate -class CircuitBuilder: - """A class using the builder pattern to make construction of circuits easy. +class CircuitBuilder(GateLibrary): + """ + A class using the builder pattern to make construction of circuits easy from Python. Adds corresponding gate when a method is called. Checks gates are known and called with the right arguments. Mainly here to allow for Qiskit-style circuit construction: - >>> myCircuit = CircuitBuilder(DefaultGates, 3).h(0).cnot(0, 1).cnot(0, 2).to_circuit() - + >>> CircuitBuilder(number_of_qubits=3).h(Qubit(0)).cnot(Qubit(0), Qubit(1)).cnot(Qubit(0), Qubit(2)).to_circuit() + version 3.0 + + qubit[3] q + + h q[0] + cnot q[0], q[1] + cnot q[0], q[2] + """ - __default_qubit_register_name = "q" + _default_qubit_register_name = "q" - def __init__(self, gates: dict, numberOfQubits: int): - self.squirrelAST = SquirrelAST(gates, numberOfQubits, self.__default_qubit_register_name) + def __init__(self, number_of_qubits: int, gate_set: [Callable[..., Gate]] = default_gate_set, gate_aliases: dict[str, Callable[..., Gate]] = default_gate_aliases): + GateLibrary.__init__(self, gate_set, gate_aliases) + self.squirrel_ir = SquirrelIR(number_of_qubits, self._default_qubit_register_name) def __getattr__(self, attr): - def addComment(commentString: str): - self.squirrelAST.add_comment(commentString) + def add_comment(comment_string: str): + self.squirrel_ir.add_comment(Comment(comment_string)) return self - def addThisGate(*args): - self.squirrelAST.add_gate(attr, *args) + def add_this_gate(*args): + generator_f = GateLibrary.get_gate_f(self, attr) + + for i, par in enumerate(inspect.signature(generator_f).parameters.values()): + if not isinstance(args[i], par.annotation): + raise TypeError( + f"Wrong argument type for gate `{attr}`, got {type(args[i])} but expected {par.annotation}" + ) + + self.squirrel_ir.add_gate(generator_f(*args)) return self - return addComment if attr == "comment" else addThisGate + return add_comment if attr == "comment" else add_this_gate def to_circuit(self) -> Circuit: - return Circuit(self.squirrelAST) + return Circuit(self.squirrel_ir) diff --git a/opensquirrel/common.py b/opensquirrel/common.py index 8da47982..d4450fed 100644 --- a/opensquirrel/common.py +++ b/opensquirrel/common.py @@ -1,6 +1,5 @@ import cmath import math -from enum import Enum from typing import Tuple import numpy as np @@ -8,25 +7,13 @@ ATOL = 0.0000001 -class ExprType(Enum): - QUBITREFS = 1 - FLOAT = 2 - INT = 3 - - -class ArgType(Enum): - QUBIT = 0 - FLOAT = 1 - INT = 2 - - -def exprTypeToArgType(t: ExprType) -> ArgType: - if t == ExprType.QUBITREFS: - return ArgType.QUBIT - if t == ExprType.FLOAT: - return ArgType.FLOAT - if t == ExprType.INT: - return ArgType.INT +def normalize_angle(x: float) -> float: + t = x - 2 * math.pi * (x // (2 * math.pi) + 1) + if t < -math.pi + ATOL: + t += 2 * math.pi + elif t > math.pi: + t -= 2 * math.pi + return t X = np.array([[0, 1], [1, 0]]) @@ -34,7 +21,7 @@ def exprTypeToArgType(t: ExprType) -> ArgType: Z = np.array([[1, 0], [0, -1]]) -def Can1(axis: Tuple[float, float, float], angle: float, phase: float = 0) -> np.ndarray: +def can1(axis: Tuple[float, float, float], angle: float, phase: float = 0) -> np.ndarray: nx, ny, nz = axis norm = math.sqrt(nx**2 + ny**2 + nz**2) assert norm > 0.00000001 diff --git a/opensquirrel/default_gates.py b/opensquirrel/default_gates.py index 615a7eea..1a7464aa 100644 --- a/opensquirrel/default_gates.py +++ b/opensquirrel/default_gates.py @@ -1,104 +1,75 @@ -import cmath import math -import numpy as np - -from opensquirrel.common import ArgType -from opensquirrel.gates import MultiQubitMatrixSemantic, SingleQubitAxisAngleSemantic - -DefaultGates = { - "h": { - "signature": (ArgType.QUBIT,), - "semantic": SingleQubitAxisAngleSemantic(axis=(1, 0, 1), angle=math.pi, phase=math.pi / 2), - }, - "H": "h", # This is how you define an alias. - "hadamard": "h", - "x": { - "signature": (ArgType.QUBIT,), - "semantic": SingleQubitAxisAngleSemantic(axis=(1, 0, 0), angle=math.pi, phase=math.pi / 2), - }, - "X": "x", - "y": { - "signature": (ArgType.QUBIT,), - "semantic": SingleQubitAxisAngleSemantic(axis=(0, 1, 0), angle=math.pi, phase=math.pi / 2), - }, - "Y": "y", - "z": { - "signature": (ArgType.QUBIT,), - "semantic": SingleQubitAxisAngleSemantic(axis=(0, 0, 1), angle=math.pi, phase=math.pi / 2), - }, - "Z": "z", - "rx": { - "signature": (ArgType.QUBIT, ArgType.FLOAT), - "semantic": lambda theta: SingleQubitAxisAngleSemantic(axis=(1, 0, 0), angle=theta, phase=0), - }, - "RX": "rx", - "x90": { - "signature": (ArgType.QUBIT,), - "semantic": SingleQubitAxisAngleSemantic(axis=(1, 0, 0), angle=math.pi / 2, phase=0), - }, - "X90": "x90", - "mx90": { - "signature": (ArgType.QUBIT,), - "semantic": SingleQubitAxisAngleSemantic(axis=(1, 0, 0), angle=-math.pi / 2, phase=0), - }, - "MX90": "mx90", - "ry": { - "signature": (ArgType.QUBIT, ArgType.FLOAT), - "semantic": lambda theta: SingleQubitAxisAngleSemantic(axis=(0, 1, 0), angle=theta, phase=0), - }, - "RY": "ry", - "y90": { - "signature": (ArgType.QUBIT,), - "semantic": SingleQubitAxisAngleSemantic(axis=(0, 1, 0), angle=math.pi / 2, phase=0), - }, - "Y90": "y90", - "rz": { - "signature": (ArgType.QUBIT, ArgType.FLOAT), - "semantic": lambda theta: SingleQubitAxisAngleSemantic(axis=(0, 0, 1), angle=theta, phase=0), - }, - "RZ": "rz", - "cnot": { - "signature": (ArgType.QUBIT, ArgType.QUBIT), - "semantic": MultiQubitMatrixSemantic( - np.array( - [ - [1, 0, 0, 0], - [0, 1, 0, 0], - [0, 0, 0, 1], - [0, 0, 1, 0], - ] - ) - ), - }, - "CNOT": "cnot", - "cz": { - "signature": (ArgType.QUBIT, ArgType.QUBIT), - "semantic": MultiQubitMatrixSemantic( - np.array( - [ - [1, 0, 0, 0], - [0, 1, 0, 0], - [0, 0, 1, 0], - [0, 0, 0, -1], - ] - ) - ), - }, - "CZ": "cz", - "cr": { - "signature": (ArgType.QUBIT, ArgType.QUBIT, ArgType.FLOAT), - "semantic": lambda theta: MultiQubitMatrixSemantic( - np.array( - [ - [1, 0, 0, 0], - [0, 1, 0, 0], - [0, 0, 1, 0], - [0, 0, 0, cmath.rect(1, theta)], - ] - ) - ), - }, - "CR": "cr", - # Rest is TODO -} +from opensquirrel.squirrel_ir import * + + +@named_gate +def h(q: Qubit) -> Gate: + return BlochSphereRotation(qubit=q, axis=(1, 0, 1), angle=math.pi, phase=math.pi / 2) + + +@named_gate +def x(q: Qubit) -> Gate: + return BlochSphereRotation(qubit=q, axis=(1, 0, 0), angle=math.pi, phase=math.pi / 2) + + +@named_gate +def x90(q: Qubit) -> Gate: + return BlochSphereRotation(qubit=q, axis=(1, 0, 0), angle=math.pi / 2, phase=0) + + +@named_gate +def y(q: Qubit) -> Gate: + return BlochSphereRotation(qubit=q, axis=(0, 1, 0), angle=math.pi, phase=math.pi / 2) + + +@named_gate +def y90(q: Qubit) -> Gate: + return BlochSphereRotation(qubit=q, axis=(0, 1, 0), angle=math.pi / 2, phase=0) + + +@named_gate +def z(q: Qubit) -> Gate: + return BlochSphereRotation(qubit=q, axis=(0, 0, 1), angle=math.pi, phase=math.pi / 2) + + +@named_gate +def z90(q: Qubit) -> Gate: + return BlochSphereRotation(qubit=q, axis=(0, 0, 1), angle=math.pi / 2, phase=0) + + +@named_gate +def rx(q: Qubit, theta: Float) -> Gate: + return BlochSphereRotation(qubit=q, axis=(1, 0, 0), angle=theta.value, phase=0) + + +@named_gate +def ry(q: Qubit, theta: Float) -> Gate: + return BlochSphereRotation(qubit=q, axis=(0, 1, 0), angle=theta.value, phase=0) + + +@named_gate +def rz(q: Qubit, theta: Float) -> Gate: + return BlochSphereRotation(qubit=q, axis=(0, 0, 1), angle=theta.value, phase=0) + + +@named_gate +def cnot(control: Qubit, target: Qubit) -> Gate: + return ControlledGate(control, x(target)) + + +@named_gate +def cz(control: Qubit, target: Qubit) -> Gate: + return ControlledGate(control, z(target)) + + +@named_gate +def cr(control: Qubit, target: Qubit, theta: Float) -> Gate: + return ControlledGate( + control, BlochSphereRotation(qubit=target, axis=(0, 0, 1), angle=theta.value, phase=theta.value / 2) + ) + + +default_gate_set = [h, x, x90, y, y90, z, z90, cz, cr, cnot, rx, ry, rz, x] +default_gate_aliases = {f.__name__: f for f in default_gate_set} +default_gate_aliases.update({"X": x, "RX": rx, "RY": ry, "RZ": rz, "Hadamard": h, "H": h}) diff --git a/opensquirrel/gate_library.py b/opensquirrel/gate_library.py new file mode 100644 index 00000000..65270258 --- /dev/null +++ b/opensquirrel/gate_library.py @@ -0,0 +1,13 @@ +class GateLibrary: + def __init__(self, gate_set, gate_aliases): + self.gate_set = gate_set + self.gate_aliases = gate_aliases + + def get_gate_f(self, gate_name: str): + try: + generator_f = next(f for f in self.gate_set if f.__name__ == gate_name) + except StopIteration: + if gate_name not in self.gate_aliases: + raise Exception(f"Unknown gate `{gate_name}`") + generator_f = self.gate_aliases[gate_name] + return generator_f diff --git a/opensquirrel/gates.py b/opensquirrel/gates.py deleted file mode 100644 index fb3e3f19..00000000 --- a/opensquirrel/gates.py +++ /dev/null @@ -1,71 +0,0 @@ -from typing import Tuple - -import numpy as np - -from opensquirrel.common import ArgType - - -class Semantic: - pass - - -class SingleQubitAxisAngleSemantic(Semantic): - def __init__(self, axis: Tuple[float, float, float], angle: float, phase: float): - self.axis = self._normalize(np.asarray(axis, dtype=np.float64)) - self.angle = angle - self.phase = phase - - def _normalize(self, axis): - norm = np.linalg.norm(axis) - axis /= norm - return axis - - -class MultiQubitMatrixSemantic(Semantic): - def __init__(self, matrix: np.ndarray): - self.matrix = matrix - - -class ControlledSemantic(MultiQubitMatrixSemantic): - def __init__(self, numberOfControlQubits: int, matrix: np.ndarray): - pass # TODO - - -def queryEntry(gatesDict: dict, gateName: str): - if gateName not in gatesDict: - raise Exception(f"Unknown gate or alias of gate: `{gateName}`") - - entry = gatesDict[gateName] - - if isinstance(entry, str): - return queryEntry(gatesDict, entry) - - return entry - - -def querySemantic(gatesDict: dict, gateName: str, *gateArgs): - signature = querySignature(gatesDict, gateName) - assert len(gateArgs) == sum(1 for t in signature if t != ArgType.QUBIT) - - entry = queryEntry(gatesDict, gateName) - - assert "semantic" in entry, f"Gate semantic not defined for gate: `{gateName}`" - - semantic = entry["semantic"] - - if isinstance(semantic, Semantic): - assert len(gateArgs) == 0, f"Gate `{gateName}` accepts no argument" - - return semantic - - return semantic(*gateArgs) # TODO: nice error when args don't match? But should be already checked by typer - - -def querySignature(gatesDict: dict, gateName: str): - entry = queryEntry(gatesDict, gateName) - - assert "signature" in entry, f"Gate signature not defined for gate: `{gateName}`" - - signature = entry["signature"] - - return signature diff --git a/opensquirrel/mckay_decomposer.py b/opensquirrel/mckay_decomposer.py index e2fedc72..718869f6 100644 --- a/opensquirrel/mckay_decomposer.py +++ b/opensquirrel/mckay_decomposer.py @@ -3,37 +3,26 @@ import numpy as np -from opensquirrel.common import ATOL, ArgType -from opensquirrel.gates import SingleQubitAxisAngleSemantic, queryEntry, querySemantic, querySignature -from opensquirrel.squirrel_ast import SquirrelAST +from opensquirrel.common import ATOL, normalize_angle +from opensquirrel.default_gates import rz, x90 +from opensquirrel.squirrel_ir import BlochSphereRotation, Float, Gate, Int, Qubit, SquirrelIR, Statement -def normalizeAngle(x: float) -> float: - t = x - 2 * pi * (x // (2 * pi) + 1) - if t < -pi + ATOL: - t += 2 * pi - elif t > pi: - t -= 2 * pi - return t +class _McKayDecomposerImpl: + def __init__(self, number_of_qubits: int, qubit_register_name: str): + self.output = SquirrelIR(number_of_qubits, qubit_register_name) + self.accumulated_1q_gates = {} - -class McKayDecomposer: - def __init__(self, gates): - self.gates = gates - - queryEntry(self.gates, "rz") # FIXME: improve. Pass those gates as parameters to the constructor. - queryEntry(self.gates, "x90") - - def _decomposeAndAdd(self, qubit, angle: float, axis: Tuple[float, float, float]): + def decompose_and_add(self, qubit: Qubit, angle: float, axis: Tuple[float, float, float]): if abs(angle) < ATOL: return # McKay decomposition - zaMod = sqrt(cos(angle / 2) ** 2 + (axis[2] * sin(angle / 2)) ** 2) - zbMod = abs(sin(angle / 2)) * sqrt(axis[0] ** 2 + axis[1] ** 2) + za_mod = sqrt(cos(angle / 2) ** 2 + (axis[2] * sin(angle / 2)) ** 2) + zb_mod = abs(sin(angle / 2)) * sqrt(axis[0] ** 2 + axis[1] ** 2) - theta = pi - 2 * atan2(zbMod, zaMod) + theta = pi - 2 * atan2(zb_mod, za_mod) alpha = atan2(-sin(angle / 2) * axis[2], cos(angle / 2)) beta = atan2(-sin(angle / 2) * axis[0], -sin(angle / 2) * axis[1]) @@ -41,97 +30,90 @@ def _decomposeAndAdd(self, qubit, angle: float, axis: Tuple[float, float, float] lam = beta - alpha phi = -beta - alpha - pi - lam = normalizeAngle(lam) - phi = normalizeAngle(phi) - theta = normalizeAngle(theta) + lam = normalize_angle(lam) + phi = normalize_angle(phi) + theta = normalize_angle(theta) if abs(lam) > ATOL: - self.output.add_gate("rz", qubit, lam) + self.output.add_gate(rz(qubit, Float(lam))) - self.output.add_gate("x90", qubit) + self.output.add_gate(x90(qubit)) if abs(theta) > ATOL: - self.output.add_gate("rz", qubit, theta) + self.output.add_gate(rz(qubit, Float(theta))) - self.output.add_gate("x90", qubit) + self.output.add_gate(x90(qubit)) if abs(phi) > ATOL: - self.output.add_gate("rz", qubit, phi) + self.output.add_gate(rz(qubit, Float(phi))) - def _flush(self, q): - if q not in self.oneQubitGates: + def flush(self, q): + if q not in self.accumulated_1q_gates: return - p = self.oneQubitGates.pop(q) - self._decomposeAndAdd(q, p["angle"], p["axis"]) + p = self.accumulated_1q_gates.pop(q) + self.decompose_and_add(q, p["angle"], p["axis"]) - def _flush_all(self): - while len(self.oneQubitGates) > 0: - self._flush(next(iter(self.oneQubitGates))) + def flush_all(self): + while len(self.accumulated_1q_gates) > 0: + self.flush(next(iter(self.accumulated_1q_gates))) - def _acc(self, qubit, semantic: SingleQubitAxisAngleSemantic): - axis, angle, phase = semantic.axis, semantic.angle, semantic.phase + def accumulate(self, qubit, bloch_sphere_rotation: BlochSphereRotation): + axis, angle, phase = bloch_sphere_rotation.axis, bloch_sphere_rotation.angle, bloch_sphere_rotation.phase - if qubit not in self.oneQubitGates: - self.oneQubitGates[qubit] = {"angle": angle, "axis": axis, "phase": phase} + if qubit not in self.accumulated_1q_gates: + self.accumulated_1q_gates[qubit] = {"angle": angle, "axis": axis, "phase": phase} return - existing = self.oneQubitGates[qubit] - combinedPhase = phase + existing["phase"] + existing = self.accumulated_1q_gates[qubit] + combined_phase = phase + existing["phase"] a = angle l = axis b = existing["angle"] m = existing["axis"] - combinedAngle = 2 * acos(cos(a / 2) * cos(b / 2) - sin(a / 2) * sin(b / 2) * np.dot(l, m)) + combined_angle = 2 * acos(cos(a / 2) * cos(b / 2) - sin(a / 2) * sin(b / 2) * np.dot(l, m)) - if abs(sin(combinedAngle / 2)) < ATOL: - self.oneQubitGates.pop(qubit) + if abs(sin(combined_angle / 2)) < ATOL: + self.accumulated_1q_gates.pop(qubit) return - combinedAxis = ( + combined_axis = ( 1 - / sin(combinedAngle / 2) + / sin(combined_angle / 2) * (sin(a / 2) * cos(b / 2) * l + cos(a / 2) * sin(b / 2) * m + sin(a / 2) * sin(b / 2) * np.cross(l, m)) ) - self.oneQubitGates[qubit] = {"angle": combinedAngle, "axis": combinedAxis, "phase": combinedPhase} + self.accumulated_1q_gates[qubit] = {"angle": combined_angle, "axis": combined_axis, "phase": combined_phase} - def process(self, squirrelAST): - # FIXME: duplicate gates in ast and self?? - self.output = SquirrelAST(self.gates, squirrelAST.nQubits, squirrelAST.qubitRegisterName) - self.oneQubitGates = {} + def process_gate(self, gate: Gate): + qubit_arguments = [arg for arg in gate.arguments if isinstance(arg, Qubit)] - for operation in squirrelAST.operations: - if isinstance(operation, str): - continue - - self._processSingleOperation(operation) + if len(qubit_arguments) >= 2: + [self.flush(q) for q in qubit_arguments] + self.output.add_gate(gate) + return - self._flush_all() + if len(qubit_arguments) == 0: + assert False, "Unsupported" + return - return self.output # FIXME: instead of returning a new AST, modify existing one + assert isinstance( + gate, BlochSphereRotation + ), f"Not supported for single qubit gate `{gate.name}`: {type(gate)}" - def _processSingleOperation(self, operation): - gateName, gateArgs = operation + self.accumulate(qubit_arguments[0], gate) - signature = querySignature(self.gates, gateName) - qubitArguments = [arg for i, arg in enumerate(gateArgs) if signature[i] == ArgType.QUBIT] - nonQubitArguments = [arg for i, arg in enumerate(gateArgs) if signature[i] != ArgType.QUBIT] - if len(qubitArguments) >= 2: - [self._flush(q) for q in qubitArguments] - self.output.add_gate(gateName, *gateArgs) - return +def decompose_mckay(squirrel_ir): + impl = _McKayDecomposerImpl(squirrel_ir.number_of_qubits, squirrel_ir.qubit_register_name) - if len(qubitArguments) == 0: - assert False, "Unsupported" - return + for statement in squirrel_ir.statements: + if not isinstance(statement, Gate): + continue - semantic = querySemantic(self.gates, gateName, *nonQubitArguments) + impl.process_gate(statement) - assert isinstance( - semantic, SingleQubitAxisAngleSemantic - ), f"Not supported for single qubit gate `{gateName}`: {type(semantic)}" + impl.flush_all() - self._acc(qubitArguments[0], semantic) + return impl.output # FIXME: instead of returning a new IR, modify existing one diff --git a/opensquirrel/parsing/antlr/qubit_range_checker.py b/opensquirrel/parsing/antlr/qubit_range_checker.py new file mode 100644 index 00000000..6b78174d --- /dev/null +++ b/opensquirrel/parsing/antlr/qubit_range_checker.py @@ -0,0 +1,66 @@ +from opensquirrel.parsing.antlr.generated import CQasm3Visitor + + +class QubitRangeChecker(CQasm3Visitor.CQasm3Visitor): + """ + This class checks that all qubit indices make sense in an ANTLR parse tree. + It is an instance of the ANTLR abstract syntax tree visitor class. + Therefore, method names are fixed and based on rule names in the Grammar .g4 file. + """ + + def __init__(self): + self.number_of_qubits = 0 + + def visitProg(self, ctx): + self.visit(ctx.qubitRegisterDeclaration()) + for gate in ctx.gateApplication(): + self.visit(gate) + + def visitQubitRegisterDeclaration(self, ctx): + self.number_of_qubits = int(str(ctx.INT())) + + def visitGateApplication(self, ctx): + visited_args = (self.visit(arg) for arg in ctx.expr()) + qubit_argument_sizes = [qubit_range_size for qubit_range_size in visited_args if qubit_range_size is not None] + + if len(qubit_argument_sizes) > 0 and not all(s == qubit_argument_sizes[0] for s in qubit_argument_sizes): + raise Exception("Invalid gate call with qubit arguments of different sizes") + + def visitQubit(self, ctx): + qubit_index = int(str(ctx.INT())) + if qubit_index >= self.number_of_qubits: + raise Exception(f"Qubit index {qubit_index} out of range") + + return 1 + + def visitQubits(self, ctx): + qubit_indices = list(map(int, map(str, ctx.INT()))) + for qubit_index in qubit_indices: + if qubit_index >= self.number_of_qubits: + raise Exception(f"Qubit index {qubit_index} out of range") + + return len(qubit_indices) + + def visitQubitRange(self, ctx): + first_qubit_index = int(str(ctx.INT(0))) + last_qubit_index = int(str(ctx.INT(1))) + + if first_qubit_index > last_qubit_index: + raise Exception(f"Qubit index range {first_qubit_index}:{last_qubit_index} malformed") + + if max(first_qubit_index, last_qubit_index) >= self.number_of_qubits: + raise Exception(f"Qubit index range {first_qubit_index}:{last_qubit_index} out of range") + + return last_qubit_index - first_qubit_index + 1 + + def visitIntLiteral(self, ctx): + return None + + def visitNegatedIntLiteral(self, ctx): + return None + + def visitFloatLiteral(self, ctx): + return None + + def visitNegatedFloatLiteral(self, ctx): + return None diff --git a/opensquirrel/parsing/antlr/squirrel_ast_creator.py b/opensquirrel/parsing/antlr/squirrel_ast_creator.py deleted file mode 100644 index 0e05df61..00000000 --- a/opensquirrel/parsing/antlr/squirrel_ast_creator.py +++ /dev/null @@ -1,71 +0,0 @@ -from opensquirrel.common import ArgType -from opensquirrel.gates import querySignature -from opensquirrel.parsing.antlr.generated import CQasm3Visitor -from opensquirrel.squirrel_ast import SquirrelAST - - -class SquirrelASTCreator(CQasm3Visitor.CQasm3Visitor): - """ - This class creates a SquirrelAST object from an ANTLR parse tree. - It is an instance of the ANTLR grammar visitor. - Therefore, method names are fixed and based on rule names in the Grammar .g4 file. - """ - - def __init__(self, gates): - self.gates = gates - self.squirrel_ast = None - - def visitProg(self, ctx): - qubit_register_name, number_of_qubits = self.visit(ctx.qubitRegisterDeclaration()) # Use? - - self.squirrel_ast = SquirrelAST(self.gates, number_of_qubits, qubit_register_name) - - for gate_application in ctx.gateApplication(): - self.visit(gate_application) - - return self.squirrel_ast - - def visitGateApplication(self, ctx): - gate_name = str(ctx.ID()) - - signature = querySignature(self.gates, gate_name) - - number_of_operands = next( - len(self.visit(ctx.expr(i))) for i in range(len(signature)) if signature[i] == ArgType.QUBIT - ) - - expanded_args = [ - self.visit(ctx.expr(i)) - if signature[i] == ArgType.QUBIT - else [self.visit(ctx.expr(i)) for _ in range(number_of_operands)] - for i in range(len(signature)) - ] - - for individual_args in zip(*expanded_args): - self.squirrel_ast.add_gate(gate_name, *individual_args) - - def visitQubitRegisterDeclaration(self, ctx): - return str(ctx.ID()), int(str(ctx.INT())) - - def visitQubit(self, ctx): - return [int(str(ctx.INT()))] - - def visitQubits(self, ctx): - return list(map(int, map(str, ctx.INT()))) - - def visitQubitRange(self, ctx): - qubit1 = int(str(ctx.INT(0))) - qubit2 = int(str(ctx.INT(1))) - return list(range(qubit1, qubit2 + 1)) - - def visitFloatLiteral(self, ctx): - return float(str(ctx.FLOAT())) - - def visitNegatedFloatLiteral(self, ctx): - return -float(str(ctx.FLOAT())) - - def visitIntLiteral(self, ctx): - return int(str(ctx.INT())) - - def visitNegatedIntLiteral(self, ctx): - return -int(str(ctx.INT())) diff --git a/opensquirrel/squirrel_error_handler.py b/opensquirrel/parsing/antlr/squirrel_error_handler.py similarity index 80% rename from opensquirrel/squirrel_error_handler.py rename to opensquirrel/parsing/antlr/squirrel_error_handler.py index 02fc77c6..2efac352 100644 --- a/opensquirrel/squirrel_error_handler.py +++ b/opensquirrel/parsing/antlr/squirrel_error_handler.py @@ -6,7 +6,7 @@ class SquirrelParseException(Exception): class SquirrelErrorHandler(Antlr4ErrorListener): - def syntaxError(self, recognizer, offendingSymbol, line, column, msg, e): + def syntaxError(self, recognizer, offending_symbol, line, column, msg, e): stack = recognizer.getRuleInvocationStack() stack.reverse() raise SquirrelParseException(f"Parsing error at {line}:{column}: {msg}") diff --git a/opensquirrel/parsing/antlr/squirrel_ir_creator.py b/opensquirrel/parsing/antlr/squirrel_ir_creator.py new file mode 100644 index 00000000..e3217158 --- /dev/null +++ b/opensquirrel/parsing/antlr/squirrel_ir_creator.py @@ -0,0 +1,73 @@ +import inspect + +from opensquirrel.default_gates import default_gate_set, default_gate_aliases +from opensquirrel.gate_library import GateLibrary +from opensquirrel.parsing.antlr.generated import CQasm3Visitor +from opensquirrel.squirrel_ir import Float, Int, Qubit, SquirrelIR + + +class SquirrelIRCreator(GateLibrary, CQasm3Visitor.CQasm3Visitor): + """ + This class creates a SquirrelIR object from an ANTLR parse tree. + It is an instance of the ANTLR abstract syntax tree visitor class. + Therefore, method names are fixed and based on rule names in the Grammar .g4 file. + """ + + def __init__(self, gate_set = default_gate_set, gate_aliases = default_gate_aliases): + GateLibrary.__init__(self, gate_set, gate_aliases) + self.squirrel_ir = None + + def visitProg(self, ctx): + number_of_qubits, qubit_register_name = self.visit(ctx.qubitRegisterDeclaration()) + + self.squirrel_ir = SquirrelIR(number_of_qubits, qubit_register_name) + + for gate_application in ctx.gateApplication(): + self.visit(gate_application) + + return self.squirrel_ir + + def visitGateApplication(self, ctx): + gate_name = str(ctx.ID()) + + generator_f = GateLibrary.get_gate_f(self, gate_name) + parameters = inspect.signature(generator_f).parameters + + number_of_operands = next( + len(self.visit(ctx.expr(i))) for i, par in enumerate(parameters.values()) if par.annotation == Qubit + ) + + # The below is for handling e.g. `cr q[1:3], q[5:7], 1.23` + expanded_args = [ + self.visit(ctx.expr(i)) if par.annotation == Qubit else [self.visit(ctx.expr(i))] * number_of_operands + for i, par in enumerate(parameters.values()) + ] + + for individual_args in zip(*expanded_args): + self.squirrel_ir.add_gate(generator_f(*individual_args)) + + def visitQubitRegisterDeclaration(self, ctx): + return int(str(ctx.INT())), str(ctx.ID()) + + def visitQubit(self, ctx): + return [Qubit(int(str(ctx.INT())))] + + def visitQubits(self, ctx): + return list(map(Qubit, map(int, map(str, ctx.INT())))) + + def visitQubitRange(self, ctx): + first_qubit_index = int(str(ctx.INT(0))) + last_qubit_index = int(str(ctx.INT(1))) + return list(map(Qubit, range(first_qubit_index, last_qubit_index + 1))) + + def visitFloatLiteral(self, ctx): + return Float(float(str(ctx.FLOAT()))) + + def visitNegatedFloatLiteral(self, ctx): + return Float(-float(str(ctx.FLOAT()))) + + def visitIntLiteral(self, ctx): + return Int(int(str(ctx.INT()))) + + def visitNegatedIntLiteral(self, ctx): + return Int(-int(str(ctx.INT()))) diff --git a/opensquirrel/parsing/antlr/squirrel_ir_from_string.py b/opensquirrel/parsing/antlr/squirrel_ir_from_string.py new file mode 100644 index 00000000..bf5d98ff --- /dev/null +++ b/opensquirrel/parsing/antlr/squirrel_ir_from_string.py @@ -0,0 +1,44 @@ +import antlr4 + +from opensquirrel.parsing.antlr.generated import CQasm3Lexer, CQasm3Parser +from opensquirrel.parsing.antlr.qubit_range_checker import QubitRangeChecker +from opensquirrel.parsing.antlr.squirrel_error_handler import SquirrelErrorHandler +from opensquirrel.parsing.antlr.squirrel_ir_creator import SquirrelIRCreator +from opensquirrel.parsing.antlr.type_checker import TypeChecker + + +def antlr_tree_from_string(s: str): + input_stream = antlr4.InputStream(s) + + lexer = CQasm3Lexer.CQasm3Lexer(input_stream) + + stream = antlr4.CommonTokenStream(lexer) + + parser = CQasm3Parser.CQasm3Parser(stream) + + parser.removeErrorListeners() + parser.addErrorListener(SquirrelErrorHandler()) + + return parser.prog() + + +def type_check_antlr_tree(tree, gate_set, gate_aliases): + type_checker = TypeChecker(gate_set, gate_aliases) + type_checker.visit(tree) # FIXME: return error instead of throwing? + + +def check_qubit_ranges_of_antlr_tree(tree): + qubit_range_checker = QubitRangeChecker() + qubit_range_checker.visit(tree) # FIXME: return error instead of throwing? + + +def squirrel_ir_from_string(s: str, gate_set, gate_aliases): + tree = antlr_tree_from_string(s) + + type_check_antlr_tree(tree, gate_set=gate_set, gate_aliases=gate_aliases) + + check_qubit_ranges_of_antlr_tree(tree) + + squirrel_ir_creator = SquirrelIRCreator(gate_set=gate_set, gate_aliases=gate_aliases) + + return squirrel_ir_creator.visit(tree) diff --git a/opensquirrel/parsing/antlr/type_checker.py b/opensquirrel/parsing/antlr/type_checker.py new file mode 100644 index 00000000..5474d76a --- /dev/null +++ b/opensquirrel/parsing/antlr/type_checker.py @@ -0,0 +1,75 @@ +import inspect + +from opensquirrel.default_gates import default_gate_set, default_gate_aliases +from opensquirrel.gate_library import GateLibrary +from opensquirrel.parsing.antlr.generated import CQasm3Visitor +from opensquirrel.squirrel_ir import Float, Int, Qubit + + +class TypeChecker(GateLibrary, CQasm3Visitor.CQasm3Visitor): + """ + This class checks that all gate parameter types make sense in an ANTLR parse tree. + It is an instance of the ANTLR abstract syntax tree visitor class. + Therefore, method names are fixed and based on rule names in the Grammar .g4 file. + """ + + def __init__(self, gate_set = default_gate_set, gate_aliases = default_gate_aliases): + GateLibrary.__init__(self, gate_set, gate_aliases) + self.register_name = None + + def visitProg(self, ctx): + self.visit(ctx.qubitRegisterDeclaration()) + for gate in ctx.gateApplication(): + self.visit(gate) + + def visitQubitRegisterDeclaration(self, ctx): + self.register_name = str(ctx.ID()) + + def visitGateApplication(self, ctx): + # Check that the types of the operands match the gate generator function. + gate_name = str(ctx.ID()) + generator_f = GateLibrary.get_gate_f(self, gate_name) + + parameters = inspect.signature(generator_f).parameters + + if len(ctx.expr()) > len(parameters): + raise Exception(f"Gate `{gate_name}` takes {len(parameters)} arguments, but {len(ctx.expr())} were given!") + + for i, param in enumerate(parameters.values()): + actual_type = self.visit(ctx.expr(i)) + expected_type = param.annotation + if actual_type != expected_type: + raise Exception( + f"Argument #{i} passed to gate `{gate_name}` is of type" + f" {actual_type} but should be {expected_type}" + ) + + def visitQubit(self, ctx): + if str(ctx.ID()) != self.register_name: + raise Exception(f"Qubit register {str(ctx.ID())} not declared") + + return Qubit + + def visitQubits(self, ctx): + if str(ctx.ID()) != self.register_name: + raise Exception(f"Qubit register {str(ctx.ID())} not declared") + + return Qubit + + def visitQubitRange(self, ctx): + if str(ctx.ID()) != self.register_name: + raise Exception(f"Qubit register {str(ctx.ID())} not declared") + + return Qubit + + def visitIntLiteral(self, ctx): + return Int + + def visitNegatedIntLiteral(self, ctx): + return Int + + def visitFloatLiteral(self, ctx): + return Float + + def visitNegatedFloatLiteral(self, ctx): + return Float diff --git a/opensquirrel/replacer.py b/opensquirrel/replacer.py index 36136a25..1d9bb085 100644 --- a/opensquirrel/replacer.py +++ b/opensquirrel/replacer.py @@ -1,52 +1,31 @@ -from opensquirrel.common import ArgType -from opensquirrel.gates import querySignature -from opensquirrel.squirrel_ast import SquirrelAST +from typing import List +from opensquirrel.squirrel_ir import Comment, Gate, SquirrelIR -class Replacer: - def __init__(self, gates): - self.gates = gates - def process(self, squirrelAST: SquirrelAST, replacedGateName: str, f): - result = SquirrelAST(self.gates, squirrelAST.nQubits, squirrelAST.qubitRegisterName) +def replace(squirrel_ir: SquirrelIR, gate_name_to_replace: str, f): + statement_index = 0 + while statement_index < len(squirrel_ir.statements): + statement = squirrel_ir.statements[statement_index] - signature = querySignature(self.gates, replacedGateName) + if isinstance(statement, Comment): + statement_index += 1 + continue - for operation in squirrelAST.operations: - if isinstance(operation, str): - continue + if not isinstance(statement, Gate): + raise Exception("Unsupported") - otherGateName, otherArgs = operation + if statement.name != gate_name_to_replace: + statement_index += 1 + continue - if otherGateName != replacedGateName: - result.add_gate(otherGateName, *otherArgs) - continue + # FIXME: handle case where if f is not a function but directly a list. - # FIXME: handle case where if f is not a function but directly a list. + replacement: List[Gate] = f(*statement.arguments) + squirrel_ir.statements[statement_index : statement_index + 1] = replacement + statement_index += len(replacement) - assert len(otherArgs) == len(signature) - originalQubits = set(otherArgs[i] for i in range(len(otherArgs)) if signature[i] == ArgType.QUBIT) + # TODO: Here, check that the semantic of the replacement is the same! + # For this, need to update the simulation capabilities. - replacement = f(*otherArgs) - - # TODO: Here, check that the semantic of the replacement is the same! - # For this, need to update the simulation capabilities. - - # TODO: Do we allow skipping the replacement, based on arguments? - - assert isinstance(replacement, list), "Substitution needs to be a list" - - for replacementGate in replacement: - replacementGateName, replacementGateArgs = replacementGate - - replacementGateSignature = querySignature(self.gates, replacementGateName) - assert len(replacementGateArgs) == len(replacementGateSignature) - assert all( - replacementGateArgs[i] in originalQubits - for i in range(len(replacementGateArgs)) - if replacementGateSignature[i] == ArgType.QUBIT - ), (f"Substitution for gate `{replacedGateName}` " f"must use the input qubits {originalQubits} only") - - result.add_gate(replacementGateName, *replacementGateArgs) - - return result + # TODO: Do we allow skipping the replacement, based on arguments? diff --git a/opensquirrel/squirrel_ast.py b/opensquirrel/squirrel_ast.py deleted file mode 100644 index 67fc81d9..00000000 --- a/opensquirrel/squirrel_ast.py +++ /dev/null @@ -1,54 +0,0 @@ -from typing import Any - -from opensquirrel.gates import querySignature - - -class SquirrelAST: - # This is just a list of gates (for now?) - def __init__(self, gates, nQubits: int, qubitRegisterName: str): - self.gates = gates - self.nQubits: int = nQubits - self.operations: list[Any] = [] - self.qubitRegisterName: str = qubitRegisterName - - def add_gate(self, gateName: str, *interpretedArgs): - signature = querySignature(self.gates, gateName) - assert len(signature) == len(interpretedArgs), f"Wrong number of arguments for gate `{gateName}`" - - # FIXME: Also check int vs float - - self.operations.append((gateName, interpretedArgs)) - - def add_comment(self, commentString: str): - assert "*/" not in commentString, "Comment contains illegal characters" - self.operations.append(commentString) - - def __eq__(self, other): - if self.gates != other.gates: - return False - - if self.nQubits != other.nQubits: - return False - - if self.qubitRegisterName != other.qubitRegisterName: - return False - - if len(self.operations) != len(other.operations): - return False - - for i in range(len(self.operations)): - leftName, leftArgs = self.operations[i] - rightName, rightArgs = other.operations[i] - - if leftName != rightName: - return False - - if len(leftArgs) != len(rightArgs) or any(leftArgs[i] != rightArgs[i] for i in range(len(leftArgs))): - # if len(leftArgs) != len(rightArgs) or any(abs(leftArgs[i] - rightArgs[i]) > - # ATOL for i in range(len(leftArgs))): - return False - - return True - - def __repr__(self): - return f"""AST ({self.nQubits} qubits, register {self.qubitRegisterName}): {self.operations}""" diff --git a/opensquirrel/squirrel_ir.py b/opensquirrel/squirrel_ir.py new file mode 100644 index 00000000..77804409 --- /dev/null +++ b/opensquirrel/squirrel_ir.py @@ -0,0 +1,225 @@ +import inspect +from abc import ABC +from dataclasses import dataclass +from functools import wraps +from typing import Callable, List, Tuple, Union, Optional + +import numpy as np + +from opensquirrel.common import ATOL, normalize_angle + + +class SquirrelIRVisitor(ABC): + def visit_comment(self, comment: "Comment"): + pass + + def visit_int(self, i: "Int"): + pass + + def visit_float(self, f: "Float"): + pass + + def visit_qubit(self, qubit: "Qubit"): + pass + + def visit_gate(self, gate: "Gate"): + pass + + def visit_bloch_sphere_rotation(self, bloch_sphere_rotation: "BlochSphereRotation"): + pass + + def visit_matrix_gate(self, matrix_gate: "MatrixGate"): + pass + + def visit_controlled_gate(self, controlled_gate: "ControlledGate"): + pass + + +class IRNode(ABC): + def accept(self, visitor: SquirrelIRVisitor): + pass + + +class Expression(IRNode, ABC): + pass + + +class NonQubitExpression(Expression, ABC): + pass + + +@dataclass +class Float(NonQubitExpression): + value: float + + def accept(self, visitor: SquirrelIRVisitor): + return visitor.visit_float(self) + + +@dataclass +class Int(NonQubitExpression): + value: int + + def accept(self, visitor: SquirrelIRVisitor): + return visitor.visit_int(self) + + +@dataclass +class Qubit(Expression): + index: int + + def __hash__(self): + return hash(self.index) + + def accept(self, visitor: SquirrelIRVisitor): + return visitor.visit_qubit(self) + + +class Statement(IRNode, ABC): + pass + + +class Gate(Statement, ABC): + arguments: Optional[Tuple[Expression, ...]] = None + + @property + def name(self) -> Optional[str]: + return self.generator.__name__ if self.generator else None + + def is_anonymous(self) -> bool: + return self.arguments is None + + +class BlochSphereRotation(Gate): + generator: Optional[Callable[..., "BlochSphereRotation"]] = None + + def __init__(self, qubit: Qubit, axis: Tuple[float, float, float], angle: float, phase: float = 0): + self.qubit: Qubit = qubit + self.axis = self._normalize_axis(np.array(axis).astype(np.float64)) + self.angle = normalize_angle(angle) + self.phase = normalize_angle(phase) + + @staticmethod + def _normalize_axis(axis): + norm = np.linalg.norm(axis) + axis /= norm + return axis + + def __repr__(self): + return f"BlochSphereRotation(q[{self.qubit.index}], axis={self.axis}, angle={self.angle}, phase={self.phase})" + + def __eq__(self, other): + if self.qubit != other.qubit: + return False + + if abs(self.phase - other.phase) > ATOL: + return False + + if np.allclose(self.axis, other.axis): + return abs(self.angle - other.angle) < ATOL + elif np.allclose(self.axis, -other.axis): + return abs(self.angle + other.angle) < ATOL + return False + + def accept(self, visitor: SquirrelIRVisitor): + visitor.visit_gate(self) + return visitor.visit_bloch_sphere_rotation(self) + + +class MatrixGate(Gate): + generator: Optional[Callable[..., "MatrixGate"]] = None + + def __init__(self, matrix: np.ndarray, operands: List[Qubit]): + assert len(operands) >= 2, "For 1q gates, please use BlochSphereRotation" + assert matrix.shape == (1 << len(operands), 1 << len(operands)) + + self.matrix = matrix + self.operands = operands + + def __eq__(self, other): + # Should we allow a global phase difference here? + return np.allclose(self.matrix, other.matrix) + + def __repr__(self): + return f"MatrixGate(qubits={list(map(lambda q: q.index, self.operands))}, matrix={self.matrix})" + + def accept(self, visitor: SquirrelIRVisitor): + visitor.visit_gate(self) + return visitor.visit_matrix_gate(self) + + +class ControlledGate(Gate): + generator: Optional[Callable[..., "ControlledGate"]] = None + + def __init__(self, control: Qubit, target_gate: Gate): + self.control = control + self.target_gate = target_gate + + def __eq__(self, other): + if self.control != other.control: + return False + + return self.target_gate == other.target_gate + + def __repr__(self): + return f"ControlledGate(control_qubit={self.control.index}, {self.target_gate})" + + def accept(self, visitor: SquirrelIRVisitor): + visitor.visit_gate(self) + return visitor.visit_controlled_gate(self) + + +def named_gate(gate_generator: Callable[..., Gate]) -> Callable[..., Gate]: + @wraps(gate_generator) + def wrapper(*args): + for i, par in enumerate(inspect.signature(gate_generator).parameters.values()): + if not issubclass(par.annotation, Expression): + raise TypeError("Gate argument types must be expressions") + + result = gate_generator(*args) + result.generator = gate_generator + result.arguments = args + return result + + return wrapper + + +@dataclass +class Comment(Statement): + str: str + + def __post_init__(self): + assert "*/" not in self.str, "Comment contains illegal characters" + + def accept(self, visitor: SquirrelIRVisitor): + return visitor.visit_comment(self) + + +class SquirrelIR: + # This is just a list of gates (for now?) + def __init__(self, number_of_qubits: int, qubit_register_name: str): + self.number_of_qubits: int = number_of_qubits + self.statements: List[Statement] = [] + self.qubit_register_name: str = qubit_register_name + + def add_gate(self, g: Gate): + self.statements.append(g) + + def add_comment(self, comment: Comment): + self.statements.append(comment) + + def __eq__(self, other): + if self.number_of_qubits != other.number_of_qubits: + return False + + if self.qubit_register_name != other.qubit_register_name: + return False + + return self.statements == other.statements + + def __repr__(self): + return f"""IR ({self.number_of_qubits} qubits, register {self.qubit_register_name}): {self.statements}""" + + def accept(self, visitor: SquirrelIRVisitor): + for statement in self.statements: + statement.accept(visitor) diff --git a/opensquirrel/test_interpreter.py b/opensquirrel/test_interpreter.py index 9607b33f..f48f14e5 100644 --- a/opensquirrel/test_interpreter.py +++ b/opensquirrel/test_interpreter.py @@ -1,30 +1,25 @@ import numpy as np -from opensquirrel.common import ArgType -from opensquirrel.gates import querySemantic, querySignature -from opensquirrel.utils.matrix_expander import get_expanded_matrix +from opensquirrel.squirrel_ir import Comment, Gate, SquirrelIR, SquirrelIRVisitor +from opensquirrel.utils.matrix_expander import get_matrix -class TestInterpreter: - def __init__(self, gates): - self.gates = gates +class _TestInterpreterImpl(SquirrelIRVisitor): + def __init__(self, number_of_qubits): + self.number_of_qubits = number_of_qubits + self.matrix = np.eye(1 << self.number_of_qubits, dtype=np.complex128) - def process(self, squirrelAST): - totalUnitary = np.eye(1 << squirrelAST.nQubits, dtype=np.complex128) + def visit_gate(self, gate: Gate): + big_matrix = get_matrix(gate, number_of_qubits=self.number_of_qubits) + self.matrix = big_matrix @ self.matrix - for operation in squirrelAST.operations: - if isinstance(operation, str): - continue + def visit_comment(self, comment: Comment): + pass - gateName, gateArgs = operation - signature = querySignature(self.gates, gateName) - assert len(gateArgs) == len(signature) - qubitOperands = [gateArgs[i] for i in range(len(gateArgs)) if signature[i] == ArgType.QUBIT] - semantic = querySemantic( - self.gates, gateName, *[gateArgs[i] for i in range(len(gateArgs)) if signature[i] != ArgType.QUBIT] - ) - bigMatrix = get_expanded_matrix(semantic, qubitOperands, number_of_qubits=squirrelAST.nQubits) - totalUnitary = bigMatrix @ totalUnitary +def get_circuit_matrix(squirrel_ir: SquirrelIR): + impl = _TestInterpreterImpl(squirrel_ir.number_of_qubits) - return totalUnitary + squirrel_ir.accept(impl) + + return impl.matrix diff --git a/opensquirrel/type_checker.py b/opensquirrel/type_checker.py deleted file mode 100644 index 306d0c65..00000000 --- a/opensquirrel/type_checker.py +++ /dev/null @@ -1,95 +0,0 @@ -from opensquirrel.common import ExprType, exprTypeToArgType -from opensquirrel.gates import querySignature -from opensquirrel.parsing.antlr.generated import CQasm3Visitor - - -class TypeChecker(CQasm3Visitor.CQasm3Visitor): - def __init__(self, gates): - self.gates = gates - self.nQubits = 0 - - def visitProg(self, ctx): - self.visit(ctx.qubitRegisterDeclaration()) - for gate in ctx.gateApplication(): - self.visit(gate) - - def visitQubitRegisterDeclaration(self, ctx): - self.nQubits = int(str(ctx.INT())) - self.registerName = str(ctx.ID()) - - def visitGateApplication(self, ctx): - # Check that the type of operands match the gate declaration - gateName = str(ctx.ID()) - if gateName not in self.gates: - raise Exception(f"Unknown gate `{gateName}`") - - expectedSignature = querySignature(self.gates, gateName) - - if len(ctx.expr()) != len(expectedSignature): - raise Exception( - f"Gate `{gateName}` takes {len(expectedSignature)} arguments," f" but {len(ctx.expr())} were given!" - ) - - i = 0 - qubitArrays = None - for arg in ctx.expr(): - argumentData = self.visit(arg) - argumentType = argumentData[0] - if argumentType == ExprType.QUBITREFS: - if isinstance(qubitArrays, int) and qubitArrays != argumentData[1]: - raise Exception("Invalid gate call with qubit arguments of different sizes") - - qubitArrays = argumentData[1] - - if expectedSignature[i] != exprTypeToArgType(argumentType): - raise Exception( - f"Argument #{i} passed to gate `{gateName}` is of type" - f" {exprTypeToArgType(argumentType)} but should be {expectedSignature[i]}" - ) - i += 1 - - def visitQubit(self, ctx): - if str(ctx.ID()) != self.registerName: - raise Exception(f"Qubit register {str(ctx.ID())} not declared") - - qubitIndex = int(str(ctx.INT())) - if qubitIndex >= self.nQubits: - raise Exception(f"Qubit index {qubitIndex} out of range") - - return (ExprType.QUBITREFS, 1) - - def visitQubits(self, ctx): - if str(ctx.ID()) != self.registerName: - raise Exception(f"Qubit register {str(ctx.ID())} not declared") - - qubitIndices = list(map(int, map(str, ctx.INT()))) - if any(i >= self.nQubits for i in qubitIndices): - raise Exception(f"Qubit index {next(i for i in qubitIndices if i >= self.nQubits)} out of range") - - return ExprType.QUBITREFS, len(qubitIndices) - - def visitQubitRange(self, ctx): - if str(ctx.ID()) != self.registerName: - raise Exception(f"Qubit register {str(ctx.ID())} not declared") - - qubitIndex1 = int(str(ctx.INT(0))) - qubitIndex2 = int(str(ctx.INT(1))) - if max(qubitIndex1, qubitIndex2) >= self.nQubits: - raise Exception(f"Qubit indices {qubitIndex1}:{qubitIndex2} out of range") - - if qubitIndex1 > qubitIndex2: - raise Exception(f"Qubit indices {qubitIndex1}:{qubitIndex2} malformed") - - return (ExprType.QUBITREFS, qubitIndex2 - qubitIndex1 + 1) - - def visitIntLiteral(self, ctx): - return (ExprType.INT,) - - def visitNegatedIntLiteral(self, ctx): - return (ExprType.INT,) - - def visitFloatLiteral(self, ctx): - return (ExprType.FLOAT,) - - def visitNegatedFloatLiteral(self, ctx): - return (ExprType.FLOAT,) diff --git a/opensquirrel/utils/matrix_expander.py b/opensquirrel/utils/matrix_expander.py index 52534c7a..a3660f9b 100644 --- a/opensquirrel/utils/matrix_expander.py +++ b/opensquirrel/utils/matrix_expander.py @@ -3,11 +3,12 @@ import numpy as np -from opensquirrel.common import Can1 -from opensquirrel.gates import MultiQubitMatrixSemantic, Semantic, SingleQubitAxisAngleSemantic +from opensquirrel.common import can1 +from opensquirrel.squirrel_ir import BlochSphereRotation, ControlledGate, MatrixGate, Qubit, Gate, \ + SquirrelIRVisitor -def get_reduced_ket(ket: int, qubits: List[int]) -> int: +def get_reduced_ket(ket: int, qubits: List[Qubit]) -> int: """ Given a quantum ket represented by its corresponding base-10 integer, this computes the reduced ket where only the given qubits appear, in order. @@ -22,27 +23,27 @@ def get_reduced_ket(ket: int, qubits: List[int]) -> int: The non-negative integer corresponding to the reduced ket. Examples: - >>> get_reduced_ket(1, [0]) # 0b01 + >>> get_reduced_ket(1, [Qubit(0)]) # 0b01 1 - >>> get_reduced_ket(1111, [2]) # 0b01 + >>> get_reduced_ket(1111, [Qubit(2)]) # 0b01 1 - >>> get_reduced_ket(1111, [5]) # 0b0 + >>> get_reduced_ket(1111, [Qubit(5)]) # 0b0 0 - >>> get_reduced_ket(1111, [2, 5]) # 0b01 + >>> get_reduced_ket(1111, [Qubit(2), Qubit(5)]) # 0b01 1 - >>> get_reduced_ket(101, [1, 0]) # 0b10 + >>> get_reduced_ket(101, [Qubit(1), Qubit(0)]) # 0b10 2 - >>> get_reduced_ket(101, [0, 1]) # 0b01 + >>> get_reduced_ket(101, [Qubit(0), Qubit(1)]) # 0b01 1 """ reduced_ket = 0 for i, qubit in enumerate(qubits): - reduced_ket |= ((ket & (1 << qubit)) >> qubit) << i + reduced_ket |= ((ket & (1 << qubit.index)) >> qubit.index) << i return reduced_ket -def expand_ket(base_ket: int, reduced_ket: int, qubits: List[int]) -> int: +def expand_ket(base_ket: int, reduced_ket: int, qubits: List[Qubit]) -> int: """ Given a base quantum ket on n qubits and a reduced ket on a subset of those qubits, this computes the expanded ket where the reduction qubits and the other qubits are set based on the reduced ket and the base ket, respectively. @@ -59,34 +60,88 @@ def expand_ket(base_ket: int, reduced_ket: int, qubits: List[int]) -> int: The non-negative integer corresponding to the expanded ket. Examples: - >>> expand_ket(0b00000, 0b0, [5]) # 0b000000 + >>> expand_ket(0b00000, 0b0, [Qubit(5)]) # 0b000000 0 - >>> expand_ket(0b00000, 0b1, [5]) # 0b100000 + >>> expand_ket(0b00000, 0b1, [Qubit(5)]) # 0b100000 32 - >>> expand_ket(0b00111, 0b0, [5]) # 0b000111 + >>> expand_ket(0b00111, 0b0, [Qubit(5)]) # 0b000111 7 - >>> expand_ket(0b00111, 0b1, [5]) # 0b100111 + >>> expand_ket(0b00111, 0b1, [Qubit(5)]) # 0b100111 39 - >>> expand_ket(0b0000, 0b000, [1, 2, 3]) # 0b0000 + >>> expand_ket(0b0000, 0b000, [Qubit(1), Qubit(2), Qubit(3)]) # 0b0000 0 - >>> expand_ket(0b0000, 0b001, [1, 2, 3]) # 0b0010 + >>> expand_ket(0b0000, 0b001, [Qubit(1), Qubit(2), Qubit(3)]) # 0b0010 2 - >>> expand_ket(0b0000, 0b011, [1, 2, 3]) # 0b0110 + >>> expand_ket(0b0000, 0b011, [Qubit(1), Qubit(2), Qubit(3)]) # 0b0110 6 - >>> expand_ket(0b0000, 0b101, [1, 2, 3]) # 0b1010 + >>> expand_ket(0b0000, 0b101, [Qubit(1), Qubit(2), Qubit(3)]) # 0b1010 10 - >>> expand_ket(0b0001, 0b101, [1, 2, 3]) # 0b1011 + >>> expand_ket(0b0001, 0b101, [Qubit(1), Qubit(2), Qubit(3)]) # 0b1011 11 """ expanded_ket = base_ket for i, qubit in enumerate(qubits): - expanded_ket &= ~(1 << qubit) # Erase bit. - expanded_ket |= ((reduced_ket & (1 << i)) >> i) << qubit # Set bit to value from reduced_ket. + expanded_ket &= ~(1 << qubit.index) # Erase bit. + expanded_ket |= ((reduced_ket & (1 << i)) >> i) << qubit.index # Set bit to value from reduced_ket. return expanded_ket -def get_expanded_matrix(semantic: Semantic, qubit_operands: List[int], number_of_qubits: int) -> np.ndarray: +class MatrixExpander(SquirrelIRVisitor): + def __init__(self, number_of_qubits: int): + self.number_of_qubits = number_of_qubits + + def visit_bloch_sphere_rotation(self, rot): + assert rot.qubit.index < self.number_of_qubits + + result = np.kron( + np.kron(np.eye(1 << (self.number_of_qubits - rot.qubit.index - 1)), can1(rot.axis, rot.angle, rot.phase)), + np.eye(1 << rot.qubit.index), + ) + assert result.shape == (1 << self.number_of_qubits, 1 << self.number_of_qubits) + return result + + def visit_controlled_gate(self, gate): + assert gate.control.index < self.number_of_qubits + + expanded_matrix = gate.target_gate.accept(self) + for col_index, col in enumerate(expanded_matrix.T): + if col_index & (1 << gate.control.index) == 0: + col[:] = 0 + col[col_index] = 1 + return expanded_matrix + + def visit_matrix_gate(self, gate): + # The convention is to write gate matrices with operands reversed. + # For instance, the first operand of CNOT is the control qubit, and this is written as + # 1, 0, 0, 0 + # 0, 1, 0, 0 + # 0, 0, 0, 1 + # 0, 0, 1, 0 + # which corresponds to control being q[1] and target being q[0], + # since qubit #i corresponds to the i-th LEAST significant bit. + qubit_operands = list(reversed(gate.operands)) + + assert all(q.index < self.number_of_qubits for q in qubit_operands) + + m = gate.matrix + + assert m.shape == (1 << len(qubit_operands), 1 << len(qubit_operands)) + + expanded_matrix = np.zeros((1 << self.number_of_qubits, 1 << self.number_of_qubits), dtype=m.dtype) + + for expanded_matrix_column in range(expanded_matrix.shape[1]): + small_matrix_col = get_reduced_ket(expanded_matrix_column, qubit_operands) + + for small_matrix_row, value in enumerate(m[:, small_matrix_col]): + expanded_matrix_row = expand_ket(expanded_matrix_column, small_matrix_row, qubit_operands) + expanded_matrix[expanded_matrix_row][expanded_matrix_column] = value + + assert expanded_matrix.shape == (1 << self.number_of_qubits, 1 << self.number_of_qubits) + return expanded_matrix + + +def get_matrix(gate: Gate, number_of_qubits: int) -> np.ndarray: """ Compute the unitary matrix corresponding to the gate applied to those qubit operands, taken among any number of qubits. This can be used for, e.g., @@ -95,20 +150,19 @@ def get_expanded_matrix(semantic: Semantic, qubit_operands: List[int], number_of - simulating a circuit (simulation in this way is inefficient for large numbers of qubits). Args: - semantic: The semantic of the gate. - qubit_operands: The qubit indices on which the gate operates. + gate: The gate, including the qubits on which it is operated on. number_of_qubits: The total number of qubits. Examples: - >>> X = SingleQubitAxisAngleSemantic((1, 0, 0), math.pi, math.pi / 2) - >>> get_expanded_matrix(X, [1], 2).astype(int) # X q[1] + >>> X = lambda q: BlochSphereRotation(qubit=q, axis=(1, 0, 0), angle=math.pi, phase=math.pi / 2) + >>> get_matrix(X(Qubit(1)), 2).astype(int) # X q[1] array([[0, 0, 1, 0], [0, 0, 0, 1], [1, 0, 0, 0], [0, 1, 0, 0]]) - >>> CNOT = MultiQubitMatrixSemantic(np.array([[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 0, 1], [0, 0, 1, 0]])) - >>> get_expanded_matrix(CNOT, [0, 2], 3) # CNOT q[0], q[2] + >>> CNOT02 = ControlledGate(Qubit(0), X(Qubit(2))) + >>> get_matrix(CNOT02, 3).astype(int) # CNOT q[0], q[2] array([[1, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 1, 0, 0], [0, 0, 1, 0, 0, 0, 0, 0], @@ -117,7 +171,7 @@ def get_expanded_matrix(semantic: Semantic, qubit_operands: List[int], number_of [0, 1, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 1, 0], [0, 0, 0, 1, 0, 0, 0, 0]]) - >>> get_expanded_matrix(CNOT, [1, 2], 3) # CNOT q[1], q[2] + >>> get_matrix(ControlledGate(Qubit(1), X(Qubit(2))), 3).astype(int) # CNOT q[1], q[2] array([[1, 0, 0, 0, 0, 0, 0, 0], [0, 1, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 1, 0], @@ -127,43 +181,7 @@ def get_expanded_matrix(semantic: Semantic, qubit_operands: List[int], number_of [0, 0, 1, 0, 0, 0, 0, 0], [0, 0, 0, 1, 0, 0, 0, 0]]) """ - if isinstance(semantic, SingleQubitAxisAngleSemantic): - assert len(qubit_operands) == 1 - - which_qubit = qubit_operands[0] - - axis, angle, phase = semantic.axis, semantic.angle, semantic.phase - result = np.kron( - np.kron(np.eye(1 << (number_of_qubits - which_qubit - 1)), Can1(axis, angle, phase)), - np.eye(1 << which_qubit), - ) - assert result.shape == (1 << number_of_qubits, 1 << number_of_qubits) - return result - - assert isinstance(semantic, MultiQubitMatrixSemantic) - - # The convention is to write gate matrices with operands reversed. - # For instance, the first operand of CNOT is the control qubit, and this is written as - # 1, 0, 0, 0 - # 0, 1, 0, 0 - # 0, 0, 0, 1 - # 0, 0, 1, 0 - # which corresponds to control being q[1] and target being q[0], - # since qubit #i corresponds to the i-th LEAST significant bit. - qubit_operands.reverse() - - m = semantic.matrix - - assert m.shape == (1 << len(qubit_operands), 1 << len(qubit_operands)) - - expanded_matrix = np.zeros((1 << number_of_qubits, 1 << number_of_qubits), dtype=m.dtype) - - for expanded_matrix_column in range(expanded_matrix.shape[1]): - small_matrix_col = get_reduced_ket(expanded_matrix_column, qubit_operands) - for small_matrix_row, value in enumerate(m[:, small_matrix_col]): - expanded_matrix_row = expand_ket(expanded_matrix_column, small_matrix_row, qubit_operands) - expanded_matrix[expanded_matrix_row][expanded_matrix_column] = value + expander = MatrixExpander(number_of_qubits) + return gate.accept(expander) - assert expanded_matrix.shape == (1 << number_of_qubits, 1 << number_of_qubits) - return expanded_matrix diff --git a/opensquirrel/writer.py b/opensquirrel/writer.py index 60d6ed79..50a42f38 100644 --- a/opensquirrel/writer.py +++ b/opensquirrel/writer.py @@ -1,41 +1,37 @@ -from opensquirrel.common import ArgType -from opensquirrel.gates import querySignature +from opensquirrel.squirrel_ir import Comment, Float, Gate, Int, Qubit, SquirrelIR, SquirrelIRVisitor, Statement -class Writer: - NUMBER_OF_SIGNIFICANT_DIGITS = 8 +class _WriterImpl(SquirrelIRVisitor): + number_of_significant_digits = 8 - def __init__(self, gates): - self.gates = gates + def __init__(self, number_of_qubits, qubit_register_name): + self.qubit_register_name = qubit_register_name + self.output = f"""version 3.0\n\nqubit[{number_of_qubits}] {qubit_register_name}\n\n""" - @classmethod - def _format_arg(cls, squirrelAST, arg, t: ArgType): - if t == ArgType.QUBIT: - return f"{squirrelAST.qubitRegisterName}[{arg}]" - if t == ArgType.INT: - return f"{int(arg)}" - if t == ArgType.FLOAT: - return f"{float(arg):.{Writer.NUMBER_OF_SIGNIFICANT_DIGITS}}" + def visit_qubit(self, qubit: Qubit): + return f"{self.qubit_register_name}[{qubit.index}]" - assert False, "Unknown argument type" + def visit_int(self, _int: Int): + return f"{_int.value}" - def process(self, squirrelAST): - output = "" - output += f"""version 3.0\n\nqubit[{squirrelAST.nQubits}] {squirrelAST.qubitRegisterName}\n\n""" + def visit_float(self, _float: Float): + return f"{_float.value:.{self.number_of_significant_digits}}" - for operation in squirrelAST.operations: - if isinstance(operation, str): - comment = operation - assert "*/" not in comment, "Comment contains illegal characters" + def visit_gate(self, gate: Gate): + if gate.is_anonymous(): + self.output += "\n" + return - output += f"\n/* {comment} */\n\n" - continue + formatted_args = (arg.accept(self) for arg in gate.arguments) + self.output += f"{gate.name} {', '.join(formatted_args)}\n" - gateName, gateArgs = operation - signature = querySignature(self.gates, gateName) + def visit_comment(self, comment: Comment): + self.output += f"\n/* {comment.str} */\n\n" - args = [Writer._format_arg(squirrelAST, arg, t) for arg, t in zip(gateArgs, signature)] - output += f"{gateName} {', '.join(args)}\n" +def squirrel_ir_to_string(squirrel_ir: SquirrelIR): + writer_impl = _WriterImpl(squirrel_ir.number_of_qubits, squirrel_ir.qubit_register_name) - return output + squirrel_ir.accept(writer_impl) + + return writer_impl.output diff --git a/test/test_circuitbuilder.py b/test/test_circuitbuilder.py index 0ca5b2b8..367d08b7 100644 --- a/test/test_circuitbuilder.py +++ b/test/test_circuitbuilder.py @@ -1,44 +1,64 @@ import unittest from opensquirrel.circuit_builder import CircuitBuilder -from opensquirrel.default_gates import DefaultGates +from opensquirrel.default_gates import * +from opensquirrel.squirrel_ir import Qubit class CircuitBuilderTest(unittest.TestCase): def test_simple(self): - builder = CircuitBuilder(DefaultGates, 3) + print(default_gate_aliases) + builder = CircuitBuilder(3) - builder.h(0) - builder.cnot(0, 1) + builder.h(Qubit(0)) + builder.cnot(Qubit(0), Qubit(1)) circuit = builder.to_circuit() - self.assertEqual(circuit.getNumberOfQubits(), 3) - self.assertEqual(circuit.getQubitRegisterName(), "q") - self.assertEqual(len(circuit.squirrel_ast.operations), 2) - self.assertEqual(circuit.squirrel_ast.operations[0], ("h", (0,))) - self.assertEqual(circuit.squirrel_ast.operations[1], ("cnot", (0, 1))) + self.assertEqual(circuit.number_of_qubits, 3) + self.assertEqual(circuit.qubit_register_name, "q") + self.assertEqual( + circuit.squirrel_ir.statements, + [ + h(Qubit(0)), + cnot(Qubit(0), Qubit(1)), + ], + ) def test_chain(self): - builder = CircuitBuilder(DefaultGates, 3) + builder = CircuitBuilder(3) - circuit = builder.h(0).cnot(0, 1).to_circuit() + circuit = builder.h(Qubit(0)).cnot(Qubit(0), Qubit(1)).to_circuit() - self.assertEqual(len(circuit.squirrel_ast.operations), 2) - self.assertEqual(circuit.squirrel_ast.operations[0], ("h", (0,))) - self.assertEqual(circuit.squirrel_ast.operations[1], ("cnot", (0, 1))) + self.assertEqual( + circuit.squirrel_ir.statements, + [ + h(Qubit(0)), + cnot(Qubit(0), Qubit(1)), + ], + ) def test_unknown_gate(self): - builder = CircuitBuilder(DefaultGates, 3) + builder = CircuitBuilder(3) - with self.assertRaisesRegex(Exception, "Unknown gate or alias of gate: `un`"): + with self.assertRaisesRegex(Exception, "Unknown gate `un`"): builder.un(0) def test_wrong_number_of_arguments(self): - builder = CircuitBuilder(DefaultGates, 3) + builder = CircuitBuilder(3) - with self.assertRaisesRegex(AssertionError, "Wrong number of arguments for gate `h`"): - builder.h(0, 1) + with self.assertRaisesRegex(TypeError, r"h\(\) takes 1 positional argument but 2 were given"): + builder.h(Qubit(0), Qubit(1)) + + def test_wrong_argument_type(self): + builder = CircuitBuilder(3) + + with self.assertRaisesRegex( + TypeError, + "Wrong argument type for gate `h`, got but expected " + "", + ): + builder.h(0) if __name__ == "__main__": diff --git a/test/test_decompose_mckay.py b/test/test_decompose_mckay.py index 90417544..de8e2952 100644 --- a/test/test_decompose_mckay.py +++ b/test/test_decompose_mckay.py @@ -2,97 +2,95 @@ import numpy as np +import opensquirrel.mckay_decomposer as mckay_decomposer +import opensquirrel.test_interpreter as test_interpreter from opensquirrel.common import ATOL -from opensquirrel.default_gates import DefaultGates -from opensquirrel.mckay_decomposer import McKayDecomposer -from opensquirrel.squirrel_ast import SquirrelAST -from opensquirrel.test_interpreter import TestInterpreter +from opensquirrel.default_gates import * +from opensquirrel.squirrel_ir import Float, Qubit, SquirrelIR -def areMatricesEqualUpToGlobalPhase(matrixA, matrixB): - firstNonZero = next( +def are_matrices_equal_up_to_global_phase(matrixA, matrixB): + first_non_zero = next( (i, j) for i in range(matrixA.shape[0]) for j in range(matrixA.shape[1]) if abs(matrixA[i, j]) > ATOL ) - if abs(matrixB[firstNonZero]) < ATOL: + if abs(matrixB[first_non_zero]) < ATOL: return False - phaseDifference = matrixA[firstNonZero] / matrixB[firstNonZero] + phase_difference = matrixA[first_non_zero] / matrixB[first_non_zero] - return np.allclose(matrixA, phaseDifference * matrixB) + return np.allclose(matrixA, phase_difference * matrixB) class DecomposeMcKayTests(unittest.TestCase): - def checkMcKayDecomposition(self, squirrelAST, expectedAST=None): + def checkMcKayDecomposition(self, squirrel_ir, expected_ir=None): """ - Check whether the mcKay decomposition transformation applied to the input AST preserves the + Check whether the mcKay decomposition transformation applied to the input IR preserves the circuit matrix up to an irrelevant global phase factor. """ - interpreter = TestInterpreter(squirrelAST.gates) # Store matrix before decompositions. - expectedMatrix = interpreter.process(squirrelAST) + expected_matrix = test_interpreter.get_circuit_matrix(squirrel_ir) - decomposer = McKayDecomposer(squirrelAST.gates) - output = decomposer.process(squirrelAST) + output = mckay_decomposer.decompose_mckay(squirrel_ir) - self.assertEqual(output.nQubits, squirrelAST.nQubits) - self.assertEqual(output.qubitRegisterName, squirrelAST.qubitRegisterName) + self.assertEqual(output.number_of_qubits, squirrel_ir.number_of_qubits) + self.assertEqual(output.qubit_register_name, squirrel_ir.qubit_register_name) - if expectedAST is not None: - self.assertEqual(output, expectedAST) + if expected_ir is not None: + self.assertEqual(output, expected_ir) # Get matrix after decompositions. - actualMatrix = interpreter.process(output) + actual_matrix = test_interpreter.get_circuit_matrix(output) - self.assertTrue(areMatricesEqualUpToGlobalPhase(actualMatrix, expectedMatrix)) + self.assertTrue(are_matrices_equal_up_to_global_phase(actual_matrix, expected_matrix)) def test_one(self): - ast = SquirrelAST(DefaultGates, 2, "squirrel") + ir = SquirrelIR(2, "squirrel") - ast.add_gate("ry", 0, 23847628349.123) - ast.add_gate("rx", 0, 29384672.234) - ast.add_gate("rz", 0, 9877.87634) + ir.add_gate(ry(Qubit(0), Float(23847628349.123))) + ir.add_gate(rx(Qubit(0), Float(29384672.234))) + ir.add_gate(rz(Qubit(0), Float(9877.87634))) - self.checkMcKayDecomposition(ast) + self.checkMcKayDecomposition(ir) def test_two(self): - ast = SquirrelAST(DefaultGates, 2, "squirrel") + ir = SquirrelIR(2, "squirrel") - ast.add_gate("ry", 0, 23847628349.123) - ast.add_gate("cnot", 0, 1) - ast.add_gate("rx", 0, 29384672.234) - ast.add_gate("rz", 0, 9877.87634) - ast.add_gate("cnot", 0, 1) - ast.add_gate("rx", 0, 29384672.234) - ast.add_gate("rz", 0, 9877.87634) + ir.add_gate(ry(Qubit(0), Float(23847628349.123))) + ir.add_gate(cnot(Qubit(0), Qubit(1))) + ir.add_gate(rx(Qubit(0), Float(29384672.234))) + ir.add_gate(rz(Qubit(0), Float(9877.87634))) + ir.add_gate(cnot(Qubit(0), Qubit(1))) + ir.add_gate(rx(Qubit(0), Float(29384672.234))) + ir.add_gate(rz(Qubit(0), Float(9877.87634))) - self.checkMcKayDecomposition(ast) + self.checkMcKayDecomposition(ir) def test_small_random(self): - ast = SquirrelAST(DefaultGates, 4, "q") - - ast.add_gate("H", 2) - ast.add_gate("cr", 2, 3, 2.123) - ast.add_gate("H", 1) - ast.add_gate("H", 0) - ast.add_gate("H", 2) - ast.add_gate("H", 1) - ast.add_gate("H", 0) - ast.add_gate("cr", 2, 3, 2.123) - - expectedAst = SquirrelAST(DefaultGates, 4, "q") - - expectedAst.add_gate("x90", 2) - expectedAst.add_gate("rz", 2, 1.5707963267948966) - expectedAst.add_gate("x90", 2) - expectedAst.add_gate("cr", 2, 3, 2.123) - expectedAst.add_gate("x90", 2) - expectedAst.add_gate("rz", 2, 1.5707963267948966) - expectedAst.add_gate("x90", 2) - expectedAst.add_gate("cr", 2, 3, 2.123) - - self.checkMcKayDecomposition(ast, expectedAst) + ir = SquirrelIR(4, "q") + + ir.add_gate(h(Qubit(2))) + ir.add_gate(cr(Qubit(2), Qubit(3), Float(2.123))) + ir.add_gate(h(Qubit(1))) + ir.add_gate(h(Qubit(0))) + ir.add_gate(h(Qubit(2))) + ir.add_gate(h(Qubit(1))) + ir.add_gate(h(Qubit(0))) + ir.add_gate(cr(Qubit(2), Qubit(3), Float(2.123))) + + expected_ir = SquirrelIR(4, "q") + + expected_ir.add_gate(x90(Qubit(2))) + expected_ir.add_gate(rz(Qubit(2), Float(1.5707963267948966))) + expected_ir.add_gate(x90(Qubit(2))) + expected_ir.add_gate(cr(Qubit(2), Qubit(3), Float(2.123))) + expected_ir.add_gate(x90(Qubit(2))) + expected_ir.add_gate(rz(Qubit(2), Float(1.5707963267948966))) + expected_ir.add_gate(x90(Qubit(2))) + expected_ir.add_gate(cr(Qubit(2), Qubit(3), Float(2.123))) + + self.checkMcKayDecomposition(ir, expected_ir) if __name__ == "__main__": diff --git a/test/test_integration.py b/test/test_integration.py index 2bdc617f..56f96eaa 100644 --- a/test/test_integration.py +++ b/test/test_integration.py @@ -3,13 +3,12 @@ import unittest from opensquirrel.circuit import Circuit -from opensquirrel.default_gates import DefaultGates +from opensquirrel.default_gates import * class IntegrationTest(unittest.TestCase): def test_simple(self): myCircuit = Circuit.from_string( - DefaultGates, """ version 3.0 @@ -20,8 +19,7 @@ def test_simple(self): cnot qreg[0], qreg[1] rx qreg[0], -2.3 ry qreg[1], -3.14 - """, - ) + """) # Decompose CNOT as # @@ -31,7 +29,12 @@ def test_simple(self): # myCircuit.replace( - "cnot", lambda control, target: [("h", (target,)), ("cz", (control, target)), ("h", (target,))] + "cnot", + lambda control, target: [ + h(target), + cz(control, target), + h(target), + ], ) # Do 1q-gate fusion and decompose with McKay decomposition. @@ -73,7 +76,6 @@ def test_simple(self): def test_qi(self): myCircuit = Circuit.from_string( - DefaultGates, """ version 3.0 @@ -101,8 +103,7 @@ def test_qi(self): cr q[2], q[3], 2.123 RY q[1], -1.5707963 - """, - ) + """) myCircuit.decompose_mckay() output = str(myCircuit) diff --git a/test/test_parsing.py b/test/test_parsing.py index 0e33308c..e9577aac 100644 --- a/test/test_parsing.py +++ b/test/test_parsing.py @@ -2,20 +2,19 @@ import antlr4 -from opensquirrel.default_gates import DefaultGates +from opensquirrel.default_gates import * from opensquirrel.parsing.antlr.generated import CQasm3Lexer, CQasm3Parser -from opensquirrel.parsing.antlr.squirrel_ast_creator import SquirrelASTCreator -from opensquirrel.squirrel_error_handler import SquirrelErrorHandler, SquirrelParseException -from opensquirrel.type_checker import TypeChecker +from opensquirrel.parsing.antlr.squirrel_error_handler import SquirrelErrorHandler, SquirrelParseException +from opensquirrel.parsing.antlr.squirrel_ir_creator import SquirrelIRCreator +from opensquirrel.parsing.antlr.type_checker import TypeChecker class ParsingTest(unittest.TestCase): def setUp(self): - self.gates = DefaultGates - self.astCreator = SquirrelASTCreator(DefaultGates) + self.ir_creator = SquirrelIRCreator() - def typeCheck(self, cQasm3String): - input_stream = antlr4.InputStream(cQasm3String) + def type_check(self, cqasm_string): + input_stream = antlr4.InputStream(cqasm_string) lexer = CQasm3Lexer.CQasm3Lexer(input_stream) @@ -28,57 +27,59 @@ def typeCheck(self, cQasm3String): tree = parser.prog() - typeChecker = TypeChecker(self.gates) - typeChecker.visit(tree) + type_checker = TypeChecker() + type_checker.visit(tree) return tree - def getAST(self, cQasm3String): - tree = self.typeCheck(cQasm3String) - return self.astCreator.visit(tree) + def get_ir(self, cqasm_string): + tree = self.type_check(cqasm_string) + return self.ir_creator.visit(tree) def test_empty(self): with self.assertRaisesRegex(SquirrelParseException, "Parsing error at 1:0: mismatched input '' expecting"): - self.typeCheck("") + self.type_check("") def test_illegal(self): with self.assertRaisesRegex( SquirrelParseException, "Parsing error at 1:0: mismatched input 'illegal' expecting" ): - self.typeCheck("illegal") + self.type_check("illegal") def test_no_qubits(self): with self.assertRaisesRegex( SquirrelParseException, "Parsing error at 1:14: mismatched input 'h' expecting 'qubit" ): - self.typeCheck("version 3.0; h q[0]") + self.type_check("version 3.0; h q[0]") def test_wrong_version(self): with self.assertRaisesRegex( SquirrelParseException, "Parsing error at 1:8: mismatched input '3.1' expecting '3.0'" ): - self.typeCheck("version 3.1; qubit[1] q; h q[0]") + self.type_check("version 3.1; qubit[1] q; h q[0]") def test_unknown_gate(self): with self.assertRaisesRegex(Exception, "Unknown gate `unknowngate`"): - self.typeCheck("version 3.0; qubit[1] q; unknowngate q[0]") + self.type_check("version 3.0; qubit[1] q; unknowngate q[0]") def test_wrong_argument_type(self): with self.assertRaisesRegex( - Exception, "Argument #1 passed to gate `rx` is of type ArgType.INT but should be ArgType.FLOAT" + Exception, + "Argument #1 passed to gate `rx` is of type but should be ", ): - self.typeCheck("version 3.0; qubit[1] q; rx q[0], 42") + self.type_check("version 3.0; qubit[1] q; rx q[0], 42") def test_wrong_argument_type_2(self): with self.assertRaisesRegex( - Exception, "Argument #0 passed to gate `rx` is of type ArgType.FLOAT but should be ArgType.QUBIT" + Exception, + "Argument #0 passed to gate `rx` is of type but should be ", ): - self.typeCheck("version 3.0; qubit[1] q; rx 42., q[0]") + self.type_check("version 3.0; qubit[1] q; rx 42., q[0]") - # FIXME: add comments to AST when parsing? + # FIXME: add comments to IR when parsing? def test_simple(self): - ast = self.getAST( + ir = self.get_ir( """ version 3.0 @@ -88,13 +89,13 @@ def test_simple(self): """ ) - self.assertEqual(ast.nQubits, 1) - self.assertEqual(ast.qubitRegisterName, "qu") - self.assertEqual(len(ast.operations), 1) - self.assertEqual(ast.operations[0], ("h", (0,))) + self.assertEqual(ir.number_of_qubits, 1) + self.assertEqual(ir.qubit_register_name, "qu") + self.assertEqual(len(ir.statements), 1) + self.assertEqual(ir.statements[0], h(Qubit(0))) def test_r_xyz(self): - ast = self.getAST( + ir = self.get_ir( """ version 3.0 qubit[2] squirrel @@ -105,15 +106,15 @@ def test_r_xyz(self): """ ) - self.assertEqual(ast.nQubits, 2) - self.assertEqual(ast.qubitRegisterName, "squirrel") - self.assertEqual(len(ast.operations), 3) - self.assertEqual(ast.operations[0], ("h", (0,))) - self.assertEqual(ast.operations[1], ("rx", (1, 1.23))) - self.assertEqual(ast.operations[2], ("ry", (0, -42))) + self.assertEqual(ir.number_of_qubits, 2) + self.assertEqual(ir.qubit_register_name, "squirrel") + self.assertEqual(len(ir.statements), 3) + self.assertEqual(ir.statements[0], h(Qubit(0))) + self.assertEqual(ir.statements[1], rx(Qubit(1), Float(1.23))) + self.assertEqual(ir.statements[2], ry(Qubit(0), Float(-42))) def test_multiple_qubits(self): - ast = self.getAST( + ir = self.get_ir( """ version 3.0 qubit[10] large @@ -123,17 +124,17 @@ def test_multiple_qubits(self): """ ) - self.assertEqual(ast.nQubits, 10) - self.assertEqual(ast.qubitRegisterName, "large") - self.assertEqual(len(ast.operations), 5) - self.assertEqual(ast.operations[0], ("h", (0,))) - self.assertEqual(ast.operations[1], ("h", (3,))) - self.assertEqual(ast.operations[2], ("h", (6,))) - self.assertEqual(ast.operations[3], ("x90", (4,))) - self.assertEqual(ast.operations[4], ("x90", (5,))) + self.assertEqual(ir.number_of_qubits, 10) + self.assertEqual(ir.qubit_register_name, "large") + self.assertEqual(len(ir.statements), 5) + self.assertEqual(ir.statements[0], h(Qubit(0))) + self.assertEqual(ir.statements[1], h(Qubit(3))) + self.assertEqual(ir.statements[2], h(Qubit(6))) + self.assertEqual(ir.statements[3], x90(Qubit(4))) + self.assertEqual(ir.statements[4], x90(Qubit(5))) def test_aliases(self): - ast = self.getAST( + ir = self.get_ir( """ version 3.0 qubit[2] q; @@ -142,7 +143,7 @@ def test_aliases(self): """ ) - self.assertEqual(ast.operations[0], ("H", (1,))) + self.assertEqual(ir.statements[0], h(Qubit(1))) if __name__ == "__main__": diff --git a/test/test_replacer.py b/test/test_replacer.py index 4e3cef4d..d7a18d46 100644 --- a/test/test_replacer.py +++ b/test/test_replacer.py @@ -1,32 +1,31 @@ import unittest -from opensquirrel.default_gates import DefaultGates -from opensquirrel.replacer import Replacer -from opensquirrel.squirrel_ast import SquirrelAST +import opensquirrel.replacer as replacer +from opensquirrel.default_gates import * +from opensquirrel.squirrel_ir import Gate, Qubit, SquirrelIR -def hadamard_decomposition(q): +def hadamard_decomposition(q: Qubit): return [ - ("y90", (q,)), - ("x", (q,)), + y90(q), + x(q), ] class ReplacerTest(unittest.TestCase): def test_replace(self): - squirrelAST = SquirrelAST(DefaultGates, 3, "test") + squirrel_ir = SquirrelIR(3, "test") - replacer = Replacer(DefaultGates) + squirrel_ir.add_gate(h(Qubit(0))) - squirrelAST.add_gate("h", 0) + replacer.replace(squirrel_ir, "h", hadamard_decomposition) - replaced = replacer.process(squirrelAST, "h", hadamard_decomposition) + expected_ir = SquirrelIR(3, "test") - self.assertEqual(replaced.nQubits, 3) - self.assertEqual(replaced.qubitRegisterName, "test") - self.assertEqual(len(replaced.operations), 2) - self.assertEqual(replaced.operations[0], ("y90", (0,))) - self.assertEqual(replaced.operations[1], ("x", (0,))) + expected_ir.add_gate(y90(Qubit(0))) + expected_ir.add_gate(x(Qubit(0))) + + self.assertEqual(expected_ir, squirrel_ir) if __name__ == "__main__": diff --git a/test/test_testinterpreter.py b/test/test_testinterpreter.py index f9292893..de4f607a 100644 --- a/test/test_testinterpreter.py +++ b/test/test_testinterpreter.py @@ -4,20 +4,17 @@ import numpy as np from opensquirrel.circuit import Circuit -from opensquirrel.default_gates import DefaultGates class TestInterpreterTest(unittest.TestCase): def test_hadamard(self): circuit = Circuit.from_string( - DefaultGates, r""" version 3.0 qubit[1] q h q[0] -""", - ) +""") self.assertTrue( np.allclose( circuit.test_get_circuit_matrix(), @@ -33,20 +30,17 @@ def test_hadamard(self): def test_double_hadamard(self): circuit = Circuit.from_string( - DefaultGates, r""" version 3.0 qubit[1] q h q[0] h q[0] -""", - ) +""") self.assertTrue(np.allclose(circuit.test_get_circuit_matrix(), np.eye(2))) def test_triple_hadamard(self): circuit = Circuit.from_string( - DefaultGates, r""" version 3.0 qubit[1] q @@ -54,8 +48,7 @@ def test_triple_hadamard(self): h q[0] h q[0] h q[0] -""", - ) +""") self.assertTrue( np.allclose( circuit.test_get_circuit_matrix(), @@ -71,15 +64,13 @@ def test_triple_hadamard(self): def test_hadamard_x(self): circuit = Circuit.from_string( - DefaultGates, r""" version 3.0 qubit[2] q h q[0] x q[1] -""", - ) +""") self.assertTrue( np.allclose( circuit.test_get_circuit_matrix(), @@ -97,15 +88,13 @@ def test_hadamard_x(self): def test_x_hadamard(self): circuit = Circuit.from_string( - DefaultGates, r""" version 3.0 qubit[2] q h q[1] x q[0] -""", - ) +""") self.assertTrue( np.allclose( circuit.test_get_circuit_matrix(), @@ -123,15 +112,12 @@ def test_x_hadamard(self): def test_cnot(self): circuit = Circuit.from_string( - DefaultGates, r""" version 3.0 qubit[2] q cnot q[1], q[0] -""", - ) - +""") self.assertTrue( np.allclose( circuit.test_get_circuit_matrix(), @@ -148,14 +134,12 @@ def test_cnot(self): def test_cnot_reversed(self): circuit = Circuit.from_string( - DefaultGates, r""" version 3.0 qubit[2] q cnot q[0], q[1] -""", - ) +""") self.assertTrue( np.allclose( @@ -173,15 +157,13 @@ def test_cnot_reversed(self): def test_hadamard_cnot(self): circuit = Circuit.from_string( - DefaultGates, r""" version 3.0 qubit[2] q h q[0] cnot q[0], q[1] -""", - ) +""") self.assertTrue( np.allclose( @@ -200,15 +182,13 @@ def test_hadamard_cnot(self): def test_hadamard_cnot_0_2(self): circuit = Circuit.from_string( - DefaultGates, r""" version 3.0 qubit[3] q h q[0] cnot q[0], q[2] -""", - ) +""") print(circuit.test_get_circuit_matrix()) self.assertTrue( np.allclose( diff --git a/test/test_writer.py b/test/test_writer.py index 501f9da2..d45da49f 100644 --- a/test/test_writer.py +++ b/test/test_writer.py @@ -1,17 +1,15 @@ import unittest -from opensquirrel.default_gates import DefaultGates -from opensquirrel.squirrel_ast import SquirrelAST -from opensquirrel.writer import Writer +import opensquirrel.writer as writer +from opensquirrel.default_gates import * +from opensquirrel.squirrel_ir import Comment, Float, Qubit, SquirrelIR class WriterTest(unittest.TestCase): def test_write(self): - squirrelAST = SquirrelAST(DefaultGates, 3, "myqubitsregister") + squirrel_ir = SquirrelIR(3, "myqubitsregister") - writer = Writer(DefaultGates) - - written = writer.process(squirrelAST) + written = writer.squirrel_ir_to_string(squirrel_ir) self.assertEqual( written, @@ -22,10 +20,10 @@ def test_write(self): """, ) - squirrelAST.add_gate("h", 0) - squirrelAST.add_gate("cr", 0, 1, 1.234) + squirrel_ir.add_gate(h(Qubit(0))) + squirrel_ir.add_gate(cr(Qubit(0), Qubit(1), Float(1.234))) - written = writer.process(squirrelAST) + written = writer.squirrel_ir_to_string(squirrel_ir) self.assertEqual( written, @@ -38,17 +36,33 @@ def test_write(self): """, ) - def test_comment(self): - squirrelAST = SquirrelAST(DefaultGates, 3, "q") + def test_anonymous_gate(self): + squirrel_ir = SquirrelIR(1, "q") + + squirrel_ir.add_gate(cr(Qubit(0), Qubit(1), Float(1.234))) + squirrel_ir.add_gate(BlochSphereRotation(Qubit(0), axis=(1, 1, 1), angle=1.23)) + squirrel_ir.add_gate(cr(Qubit(0), Qubit(1), Float(1.234))) + + self.assertEqual( + writer.squirrel_ir_to_string(squirrel_ir), + """version 3.0 + +qubit[1] q + +cr q[0], q[1], 1.234 + +cr q[0], q[1], 1.234 +""") - writer = Writer(DefaultGates) + def test_comment(self): + squirrel_ir = SquirrelIR(3, "q") - squirrelAST.add_gate("h", 0) - squirrelAST.add_comment("My comment") - squirrelAST.add_gate("cr", 0, 1, 1.234) + squirrel_ir.add_gate(h(Qubit(0))) + squirrel_ir.add_comment(Comment("My comment")) + squirrel_ir.add_gate(cr(Qubit(0), Qubit(1), Float(1.234))) self.assertEqual( - writer.process(squirrelAST), + writer.squirrel_ir_to_string(squirrel_ir), """version 3.0 qubit[3] q @@ -58,25 +72,21 @@ def test_comment(self): /* My comment */ cr q[0], q[1], 1.234 -""", - ) +""") def test_cap_significant_digits(self): - squirrelAST = SquirrelAST(DefaultGates, 3, "q") - - writer = Writer(DefaultGates) + squirrel_ir = SquirrelIR(3, "q") - squirrelAST.add_gate("cr", 0, 1, 1.6546514861321684321654) + squirrel_ir.add_gate(cr(Qubit(0), Qubit(1), Float(1.6546514861321684321654))) self.assertEqual( - writer.process(squirrelAST), + writer.squirrel_ir_to_string(squirrel_ir), """version 3.0 qubit[3] q cr q[0], q[1], 1.6546515 -""", - ) +""") if __name__ == "__main__":