diff --git a/doc/code/lab_dev.rst b/doc/code/lab_dev.rst index 80112c235..0d5554d98 100644 --- a/doc/code/lab_dev.rst +++ b/doc/code/lab_dev.rst @@ -9,6 +9,5 @@ mrmustard.lab_dev lab_dev/states lab_dev/transformations lab_dev/circuits - lab_dev/simulator .. currentmodule:: mrmustard.lab_dev diff --git a/doc/code/lab_dev/simulator.rst b/doc/code/lab_dev/simulator.rst deleted file mode 100644 index 16ae4b041..000000000 --- a/doc/code/lab_dev/simulator.rst +++ /dev/null @@ -1,8 +0,0 @@ -mrmustard.lab_dev.simulator -=========================== - -.. currentmodule:: mrmustard.lab_dev.simulator - -.. automodapi:: mrmustard.lab_dev.simulator - :no-heading: - :include-all-objects: diff --git a/mrmustard/lab_dev/__init__.py b/mrmustard/lab_dev/__init__.py index adc9ebd8e..4ef839b6b 100644 --- a/mrmustard/lab_dev/__init__.py +++ b/mrmustard/lab_dev/__init__.py @@ -20,6 +20,5 @@ from .circuit_components_utils import * from .circuits import * from .states import * -from .simulator import * from .transformations import * from .wires import Wires diff --git a/mrmustard/lab_dev/circuit_components_utils/branch_and_bound.py b/mrmustard/lab_dev/circuit_components_utils/branch_and_bound.py new file mode 100644 index 000000000..ca44ff142 --- /dev/null +++ b/mrmustard/lab_dev/circuit_components_utils/branch_and_bound.py @@ -0,0 +1,518 @@ +# Copyright 2024 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Branch and bound algorithm for optimal contraction of a tensor network. +""" + +from __future__ import annotations +import random +from copy import deepcopy +from queue import PriorityQueue +from math import factorial +import numpy as np +from typing import Generator +import networkx as nx +from mrmustard.lab_dev.wires import Wires +from mrmustard.lab_dev.circuit_components import CircuitComponent + +Edge = tuple[int, int] + + +# ===================== +# ====== Classes ====== +# ===================== + + +class GraphComponent: + r""" + A lightweight "CircuitComponent" without the actual representation. + Basically a wrapper around Wires, so that it can emulate components in + a circuit. It exposes the repr, wires, shape, name and cost of obtaining + the component from previous contractions. + + Args: + repr: The name of the representation of the component. + wires: The wires of the component. + shape: The fock shape of the component. + name: The name of the component. + cost: The cost of obtaining this component. + """ + + def __init__(self, repr: str, wires: Wires, shape: list[int], name: str = "", cost: int = 0): + if None in shape: + raise ValueError("Detected `None`s in shape. Please provide a full shape.") + self.repr = repr + self.wires = wires + self.shape = list(shape) + self.name = name + self.cost = cost + + @classmethod + def from_circuitcomponent(cls, c: CircuitComponent): + r""" + Creates a GraphComponent from a CircuitComponent. + + Args: + c: A CircuitComponent. + """ + return GraphComponent( + repr=str(c.representation.__class__.__name__), + wires=Wires(*c.wires.args), + shape=c.auto_shape(), + name=c.__class__.__name__, + ) + + def contraction_cost(self, other: GraphComponent) -> int: + r""" + Returns the computational cost in approx FLOPS for contracting this component with another + one. Three cases are possible: + + 1. If both components are in Fock representation the cost is the product of the values along + both shapes, counting only once the shape of the contracted indices. E.g. a tensor with shape + (20,30,40) contracts its 1,2 indices with the 0,1 indices of a tensor with shape (30,40,50,60). + The cost is 20 x 30 x 40 x 50 x 60 = 72_000_000 (note 30,40 were counted only once). + + 2. If the representations are a mix of Bargmann and Fock, we add the cost of converting the + Bargmann to Fock, which is the product of the shape of the Bargmann component. + + 3. If both are in Bargmann representation, the contraction can be a simple a Gaussian integral + which scales like the cube of the number of contracted indices, i.e. ~ just 8 in the example above. + + Arguments: + other: GraphComponent + + Returns: + int: contraction cost in approx FLOPS + + """ + idxA, idxB = self.wires.contracted_indices(other.wires) + m = len(idxA) # same as len(idxB) + nA, nB = len(self.shape) - m, len(other.shape) - m + + if self.repr == "Bargmann" and other.repr == "Bargmann": + cost = ( # +1s to include vector part) + m * m * m # M inverse + + (m + 1) * m * nA # left matmul + + (m + 1) * m * nB # right matmul + + (m + 1) * m # addition + + m * m * m # determinant of M + ) + else: # otherwise need to use fock cost + prod_A = np.prod([s for i, s in enumerate(self.shape) if i not in idxA]) + prod_B = np.prod([s for i, s in enumerate(other.shape) if i not in idxB]) + prod_contracted = np.prod( + [min(self.shape[i], other.shape[j]) for i, j in zip(idxA, idxB)] + ) + cost = ( + prod_A * prod_B * prod_contracted # matmul + + np.prod(self.shape) * (self.repr == "Bargmann") # conversion + + np.prod(other.shape) * (other.repr == "Bargmann") # conversion + ) + return int(cost) + + def __matmul__(self, other) -> GraphComponent: + r""" + Returns the contracted GraphComponent. + + Args: + other: Another GraphComponent + """ + new_wires, perm = self.wires @ other.wires + idxA, idxB = self.wires.contracted_indices(other.wires) + shape_A = [n for i, n in enumerate(self.shape) if i not in idxA] + shape_B = [n for i, n in enumerate(other.shape) if i not in idxB] + shape = shape_A + shape_B + new_shape = [shape[p] for p in perm] + new_component = GraphComponent( + "Bargmann" if self.repr == other.repr == "Bargmann" else "Fock", + new_wires, + new_shape, + f"({self.name}@{other.name})", + self.contraction_cost(other) + self.cost + other.cost, + ) + return new_component + + def __repr__(self): + return f"{self.name}({self.shape}, {self.wires})" + + +class Graph(nx.DiGraph): + r""" + Power pack for nx.DiGraph with additional attributes and methods. + + Args: + solution: The sequence of edges contracted to obtain this graph. + costs: The costs of contracting each edge in the current solution. + """ + + def __init__(self, solution: tuple[Edge, ...] = (), costs: tuple[int, ...] = ()): + super().__init__() + self.solution = solution + self.costs = costs + + @property + def cost(self) -> int: + r""" + Returns the total cost of the graph. + """ + return sum(self.costs) + + def component(self, n) -> GraphComponent: + r""" + Returns the ``GraphComponent`` associated with a node. + + Args: + n: The node index. + """ + return self.nodes[n]["component"] + + def components(self) -> Generator[GraphComponent, None, None]: + r""" + Yields the ``GraphComponents`` associated with the nodes. + """ + for n in self.nodes: + yield self.component(n) + + def __lt__(self, other: Graph) -> bool: + r""" + Compares two graphs by their cost. Used for sorting in the priority queue. + + Args: + other: Another graph. + """ + return self.cost < other.cost + + def __hash__(self) -> int: + r""" + Returns a hash of the graph. + """ + return hash( + tuple(self.nodes) + + tuple(self.edges) + + tuple(self.solution) + + tuple(sum((c.shape for c in self.components()), start=[])) + ) + + +# ======================= +# ====== Functions ====== +# ======================= + + +def optimize_fock_shapes(graph: Graph, iteration: int, verbose: bool) -> Graph: + r""" + Iteratively optimizes the Fock shapes of the components in the graph. + + Args: + graph: The graph to optimize. + iteration: The iteration number. + verbose: Whether to print the progress. + """ + h = hash(graph) + for A, B in graph.edges: + wires_A = graph.nodes[A]["component"].wires + wires_B = graph.nodes[B]["component"].wires + idx_A, idx_B = wires_A.contracted_indices(wires_B) + # ensure at idx_i and idx_j the shapes are the minimum + for i, j in zip(idx_A, idx_B): + value = min( + graph.nodes[A]["component"].shape[i], + graph.nodes[B]["component"].shape[j], + ) + graph.nodes[A]["component"].shape[i] = value + graph.nodes[B]["component"].shape[j] = value + + for component in graph.components(): + if component.name == "BSgate": + a, b, c, d = component.shape + if c and d: + if not a or a > c + d: + a = c + d + if not b or b > c + d: + b = c + d + if a and b: + if not c or c > a + b: + c = a + b + if not d or d > a + b: + d = a + b + component.shape = [a, b, c, d] + + if h != hash(graph) and verbose: + print(f"Iteration {iteration}: graph updated") + graph = optimize_fock_shapes(graph, iteration + 1, verbose) + return graph + + +def parse_components(components: list[CircuitComponent]) -> Graph: + r""" + Parses a list of CircuitComponents into a Graph. + + Each node in the graph corresponds to a GraphComponent and an edge between two nodes indicates that + the GraphComponents are connected in the circuit. Whether they are connected by one wire + or by many, in the graph they will have a single edge between them. + + Args: + components: A list of CircuitComponents. + """ + validate_components(components) + graph = Graph() + for i, A in enumerate(components): + comp = GraphComponent.from_circuitcomponent(A) + wires = Wires(*A.wires.args) + comp.wires = wires + for j, B in enumerate(components[i + 1 :]): + ovlp_bra, ovlp_ket = wires.overlap(B.wires) + if ovlp_ket or ovlp_bra: + graph.add_edge(i, i + j + 1) + wires = Wires( + wires.args[0] - ovlp_bra, + wires.args[1], + wires.args[2] - ovlp_ket, + wires.args[3], + ) + if not wires.output: + break + graph.add_node(i, component=comp) + return graph + + +def validate_components(components: list[CircuitComponent]) -> None: + r""" + Raises an error if the components will not contract correctly. + + Args: + components: A list of CircuitComponents. + """ + if len(components) == 0: + return + w = components[0].wires + for comp in components[1:]: + w = (w @ comp.wires)[0] + + +def contract(graph: Graph, edge: Edge, debug: int = 0) -> Graph: + r""" + Contracts an edge in a graph and returns the contracted graph. + Makes a copy of the original graph. + + Args: + graph (Graph): A graph. + edge (tuple[int, int]): An edge to contract. + debug (int): Whether to print debug information. + + Returns: + Graph: A new graph with the contracted edge. + """ + new_graph = nx.contracted_edge(graph, edge, self_loops=False, copy=True) + A = graph.nodes[edge[0]]["component"] + B = graph.nodes[edge[1]]["component"] + if debug > 0: + print(f"A wires: {A.wires}, B wires: {B.wires}") + new_graph.nodes[edge[0]]["component"] = A @ B + new_graph.costs = graph.costs + (graph.edges[edge]["cost"],) + new_graph.solution = graph.solution + (edge,) + assign_costs(new_graph) + return new_graph + + +def children(graph: Graph, cost_bound: int) -> set[Graph]: + r""" + Returns a set of graphs obtained by contracting each edge. + Only graphs with a cost below ``cost_bound`` are returned. + Two nodes are contracted by removing the edge between them and merging + the two nodes into a single node. The shape of the new node + is the union of the shapes of the two nodes without the wires that were + contracted (this is all handled by the wires). + + Args: + graph (Graph): A graph. + cost_bound (int): The maximum cost of the children. + + Returns: + set[Graph]: The set of graphs obtained by contracting each edge. + """ + children_set = set() + for edge in sorted(graph.out_edges, key=lambda e: graph.out_edges[e]["cost"]): + if graph.cost + graph.edges[edge]["cost"] < cost_bound: + children_set.add(contract(graph, edge)) + return children_set + + +def grandchildren(graph: Graph, cost_bound: int) -> set[Graph]: + r""" + A set of grandchildren constructed from each child's children. + Only grandchildren with a cost below ``cost_bound`` are returned. + Note that children without further children are included, so with + a single call to this function we get all the descendants up to + the grandchildren level including leaf nodes whether they are + children or grandchildren. + + Args: + graph (Graph): A graph. + cost_bound (int): The maximum cost of the grandchildren + + Returns: + set[Graph]: The set of grandchildren below the cost bound. + """ + grandchildren_set = set() + for child in children(graph, cost_bound): + if child.number_of_edges() == 0: + grandchildren_set.add(child) + continue + for grandchild in children(child, cost_bound): + grandchildren_set.add(grandchild) + return grandchildren_set + + +def assign_costs(graph: Graph, debug: int = 0) -> None: + r""" + Assigns to each edge in the graph the cost of contracting the two nodes it connects. + + Args: + graph (Graph): A graph. + debug (int): Whether to print debug information + """ + for edge in graph.edges: + A = graph.nodes[edge[0]]["component"] + B = graph.nodes[edge[1]]["component"] + graph.edges[edge]["cost"] = A.contraction_cost(B) + if debug > 0: + print( + f"cost of edge {edge}: {A.repr}|{A.shape} x {B.repr}|{B.shape} = {graph.edges[edge]['cost']}" + ) + + +def random_solution(graph: Graph) -> Graph: + r""" + Returns a random solution to contract a given graph. + + Args: + graph (Graph): The initial graph. + + Returns: + Graph: The contracted graph + """ + while graph.number_of_edges() > 0: + edge = random.choice(list(graph.edges)) + graph = contract(graph, edge) + return graph + + +def reduce_first(graph: Graph, code: str) -> tuple[Graph, Edge | bool]: + r""" + Reduces the first pair of nodes that match the pattern in the code. + The first number and letter describe a node with that number of + edges and that repr (B for Bargmann, F for Fock), and the last letter + describes the repr of the second node. + For example 1BB means we will contract the first occurrence of a node + that has one edge (a leaf) connected to a node of repr B with an arbitrary + number of edges. + We typically use codes like 1BB, 2BB, 1FF, 2FF by default because they are + safe, and codes like 1BF, 1FB optionally as they are not always the best choice. + + Args: + graph: A graph. + code: A pattern indicating the type of nodes to contract. + """ + n, tA, tB = code + for node in graph.nodes: + if int(n) == graph.degree(node): + for edge in list(graph.out_edges(node)) + list(graph.in_edges(node)): + A = graph.nodes[edge[0]]["component"] + B = graph.nodes[edge[1]]["component"] + if A.repr[0] == tA and B.repr[0] == tB: + graph = contract(graph, edge) + return graph, edge + return graph, False + + +def heuristic(graph: Graph, code: str, verbose: bool) -> Graph: + r""" + Simplifies the graph by contracting all pairs of nodes that match the given pattern. + + Args: + graph: A graph. + code: A pattern indicating the type of nodes to contract. + verbose: Whether to print the progress. + """ + edge = True + while edge: + graph, edge = reduce_first(graph, code) + if edge and verbose: + print(f"{code} edge: {edge} | tot cost: {graph.cost}") + return graph + + +def optimal_contraction( + graph: Graph, + n_init: int, + heuristics: tuple[str, ...], + verbose: bool, +) -> Graph: + r""" + Finds the optimal path to contract a graph. + + Args: + graph: The graph to contract. + n_init: The number of random contractions to find an initial cost upper bound. + heuristics: A sequence of patterns to reduce in order. + verbose: Whether to print the progress. + + Returns: + The optimally contracted graph with associated cost and solution + """ + assign_costs(graph) + if verbose: + print("\n===== Simplify graph via heuristics =====") + for code in heuristics: + graph = heuristic(graph, code, verbose) + if graph.number_of_edges() == 0: + return graph + + if verbose: + print(f"\n===== Branch and bound ({factorial(len(graph.nodes)):_d} paths) =====") + best = Graph(costs=(np.inf,)) # will be replaced by first random contraction + for _ in range(n_init): + rand = random_solution(deepcopy(graph)) + best = rand if rand.cost < best.cost else best + if verbose: + print( + f"Best cost from {n_init} random contractions: {best.cost}. Solution: {best.solution}\n" + ) + + queue = PriorityQueue() + queue.put(graph) + while not queue.empty(): + candidate = queue.get() + if verbose: + print( + f"Queue: {queue.qsize()}/{queue.unfinished_tasks} | cost: {candidate.cost} | solution: {candidate.solution}", + end="\x1b[1K\r", + ) + + if candidate.cost >= best.cost: + if verbose: + print("warning: early stop") + return candidate # early stopping because first in queue is already worse + elif candidate.number_of_edges() == 0: # better solution! πŸ₯³ + best = candidate + queue.queue = [g for g in queue.queue if g.cost < best.cost] # prune + else: + for g in grandchildren(candidate, cost_bound=best.cost): + if g not in queue.queue: + queue.put(g) + if verbose: + print(f"\n\nFinal path: best cost = {best.cost}. Solution is {best.solution}") + return best diff --git a/mrmustard/lab_dev/circuits.py b/mrmustard/lab_dev/circuits.py index 4b530ef36..711ea3b73 100644 --- a/mrmustard/lab_dev/circuits.py +++ b/mrmustard/lab_dev/circuits.py @@ -12,21 +12,21 @@ # See the License for the specific language governing permissions and # limitations under the License. -# pylint: disable=too-many-branches """ -A class to quantum circuits. +A class to simulate quantum circuits. """ from __future__ import annotations from collections import defaultdict -from pydoc import locate +from copy import deepcopy from typing import Sequence +from pydoc import locate from mrmustard import math, settings from mrmustard.utils.serialize import save from mrmustard.lab_dev.circuit_components import CircuitComponent -from mrmustard.lab_dev.transformations import BSgate +import mrmustard.lab_dev.circuit_components_utils.branch_and_bound as bb __all__ = ["Circuit"] @@ -38,8 +38,9 @@ class Circuit: are contracted is specified by the ``path`` attribute. Different orders of contraction lead to the same result, but the cost of the contraction - can vary significantly. The ``path`` attribute is used to specify the order in which the - components are contracted. + can vary significantly. The ``optimize`` method optimizes the Fock shapes and the contraction + path of the circuit, while the ``contract`` method contracts the components in the order + specified by the ``path`` attribute. .. code-block:: @@ -73,162 +74,91 @@ class Circuit: components: A list of circuit components. """ - def __init__(self, components: Sequence[CircuitComponent] | None = None) -> None: - self._components = [c._light_copy() for c in components] if components else [] - self._path = [] - - # a dictionary to keep track of the underlying graph, mapping the ``ids`` of output wires - # to the ``ids`` of the input wires that they are being contracted with. It is initialized - # automatically (once and for all) when a path is validated for the first time. - self._graph: dict[int, int] = {} - - @property - def indices_dict(self) -> dict[int, dict[int, dict[int, int]]]: + def __init__( + self, + components: Sequence[CircuitComponent] | None = None, + ) -> None: + self.components = [c._light_copy() for c in components] if components else [] + self._graph = bb.parse_components(self.components) + self.path: list[tuple[int, int]] = [ + (0, i) for i in range(1, len(self.components)) + ] # default path (likely not optimal) + + def optimize( + self, n_init: int = 100, with_BF_heuristic: bool = True, verbose: bool = True + ) -> None: r""" - A dictionary that maps the index of each component to all the components it is connected to. - For each connected component, the value is a dictionary with a key-value pair for each component - connected to the first one, where the key is the index of this component and the value is - a dictionary with all the wire index pairs that are being contracted between the two components. - - For example, if components[i] is connected to components[j] and they are contracting two wires - at index pairs (a, b) and (c, d), then indices_dict[i][j] = {a: b, c:d}. + Optimizes the Fock shapes and the contraction path of this circuit. + It allows one to exclude the 1BF and 1FB heuristic in case contracting 1-wire Fock/Bagmann + components with multimode Bargmann/Fock components leads to a higher total cost. - This dictionary is used to propagate the shapes of the components in the circuit. + Args: + n_init: The number of random contractions to find an initial cost upper bound. + with_BF_heuristic: If True (default), the 1BF/1FB heuristics are included in the optimization process. + verbose: If True (default), the progress of the optimization is shown. """ - if not hasattr(self, "_idxdict"): - self._idxdict = self._indices_dict() - return self._idxdict - - def _indices_dict(self): - ret = {} - for i, opA in enumerate(self.components): - out_idx = set(opA.wires.output.indices) - indices: dict[int, dict[int, int]] = {} - for j, opB in enumerate(self.components[i + 1 :]): - ovlp_bra = opA.wires.output.bra.modes & opB.wires.input.bra.modes - ovlp_ket = opA.wires.output.ket.modes & opB.wires.input.ket.modes - if not (ovlp_bra or ovlp_ket): - continue - iA = opA.wires.output.bra[ovlp_bra].indices + opA.wires.output.ket[ovlp_ket].indices - iB = opB.wires.input.bra[ovlp_bra].indices + opB.wires.input.ket[ovlp_ket].indices - if not out_idx.intersection(iA): - continue - indices[i + j + 1] = dict(zip(iA, iB)) - out_idx -= set(iA) - if not out_idx: - break - ret[i] = indices - return ret - - def propagate_shapes(self): - r"""Propagates the shape information so that the shapes of the components are better - than those provided by the auto_shape attribute. - - .. code-block:: - - >>> from mrmustard.lab_dev import BSgate, Dgate, Coherent, Circuit, SqueezedVacuum - >>> from mrmustard import settings - >>> settings.AUTOSHAPE_PROBABILITY = 0.999 + graph = deepcopy(self._graph) + bb.assign_costs(graph) + G = bb.random_solution(graph) + if len(G.nodes) > 1: + raise ValueError("Circuit has disconnected components.") + + i = list(G.nodes)[0] + if len(G.nodes[i]["component"].wires) > 0: + raise NotImplementedError("Cannot optimize a circuit with dangling wires yet.") + + self.optimize_fock_shapes(verbose) + heuristics = ( + ("1BB", "2BB", "1BF", "1FB", "1FF", "2FF") + if with_BF_heuristic + else ("1BB", "2BB", "1FF", "2FF") + ) + optimized_graph = bb.optimal_contraction(self._graph, n_init, heuristics, verbose) + self.path = list(optimized_graph.solution) + + def contract(self) -> CircuitComponent: + r""" + Contracts the components in this circuit in the order specified by the ``path`` attribute. - >>> circ = Circuit([Coherent([0], x=1.0), Dgate([0], 0.1)]) - >>> assert [op.auto_shape() for op in circ] == [(5,), (50,50)] - >>> circ.propagate_shapes() - >>> assert [op.auto_shape() for op in circ] == [(5,), (50, 5)] + Returns: + The result of contracting the circuit. - >>> circ = Circuit([SqueezedVacuum([0,1], r=[0.5,-0.5]), BSgate([0,1], 0.9)]) - >>> assert [op.auto_shape() for op in circ] == [(6, 6), (50, 50, 50, 50)] - >>> circ.propagate_shapes() - >>> assert [op.auto_shape() for op in circ] == [(6, 6), (12, 12, 6, 6)] + Raises: + ValueError: If ``circuit`` has an incomplete path. """ + if len(self.path) != len(self) - 1: + msg = f"``circuit.path`` needs to specify {len(self) - 1} contractions, found " + msg += ( + f"{len(self.path)}. Please run the ``.optimize()`` method or set the path manually." + ) + raise ValueError(msg) - for component in self: - component.manual_shape = list(component.auto_shape()) - - # update the manual_shapes until convergence - changes = self._update_shapes() - while changes: - changes = self._update_shapes() - - def _update_shapes(self) -> bool: - r"""Updates the shapes of the components in the circuit graph by propagating the known shapes. - If two wires are connected and one of them has shape n and the other None, the shape of the - wire with None is updated to n. If both wires have a shape, the minimum is taken. - - For a BSgate, we apply the rule that the sum of the shapes of the inputs must be equal to the sum of - the shapes of the outputs. + ret = dict(enumerate(self.components)) + for idx0, idx1 in self.path: + ret[idx0] = ret[idx0] >> ret.pop(idx1) - It returns True if any shape was updated, False otherwise. - """ - changes = False - # get shapes from neighbors if needed - for i, component in enumerate(self.components): - for j, indices in self.indices_dict[i].items(): - for a, b in indices.items(): - s_ia = self.components[i].manual_shape[a] - s_jb = self.components[j].manual_shape[b] - s = min(s_ia or 1e42, s_jb or 1e42) if (s_ia or s_jb) else None - if self.components[j].manual_shape[b] != s: - self.components[j].manual_shape[b] = s - changes = True - if self.components[i].manual_shape[a] != s: - self.components[i].manual_shape[a] = s - changes = True - - # propagate through BSgates - for i, component in enumerate(self.components): - if isinstance(component, BSgate): - a, b, c, d = component.manual_shape - if c and d: - if not a or a > c + d: - a = c + d - changes = True - if not b or b > c + d: - b = c + d - changes = True - if a and b: - if not c or c > a + b: - c = a + b - changes = True - if not d or d > a + b: - d = a + b - changes = True - - self.components[i].manual_shape = [a, b, c, d] - - return changes - - @property - def components(self) -> Sequence[CircuitComponent]: - r""" - The components in this circuit. - """ - return self._components + return list(ret.values())[0] - @property - def path(self) -> list[tuple[int, int]]: + def optimize_fock_shapes(self, verbose: bool) -> None: r""" - A list describing the desired contraction path followed by the ``Simulator``. + Optimizes the Fock shapes of the components in this circuit. + It starts by matching the existing connected wires and keeps the smaller shape, + then it enforces the BSgate symmetry (conservation of photon number) to further + reduce the shapes across the circuit. + This operation acts in place. """ - return self._path + if verbose: + print("===== Optimizing Fock shapes =====") + self._graph = bb.optimize_fock_shapes(self._graph, 0, verbose) + for i, c in enumerate(self.components): + c.manual_shape = self._graph.component(i).shape - @path.setter - def path(self, value: list[tuple[int, int]]) -> None: + def check_contraction(self, n: int) -> None: r""" - A function to set the path. + An auxiliary function that helps visualize the contraction path of the circuit. - In addition to setting the path, it validates it using the ``validate_path`` method. - - Args: - path: The path. - """ - self.validate_path(value) - self._path = value - - def lookup_path(self) -> None: - r""" - An auxiliary function that helps building the contraction path for this circuit. - - Shows the remaining components and the corresponding contraction indices. + Shows the remaining components and the corresponding contraction indices after n + of the contractions in ``self.path``. .. code-block:: @@ -243,7 +173,7 @@ def lookup_path(self) -> None: >>> # ``circ`` has no path: all the components are available, and indexed >>> # as they appear in the list of components - >>> circ.lookup_path() + >>> circ.check_contraction(0) # no contractions β†’ index: 0 mode 0: β—–Vacβ—— @@ -268,9 +198,10 @@ def lookup_path(self) -> None: - >>> # start building the path - >>> circ.path = [(0, 1)] - >>> circ.lookup_path() + >>> # start building the path manually + >>> circ.path = ((0, 1), (2, 3), (0, 2)) + + >>> circ.check_contraction(1) # after 1 contraction β†’ index: 0 mode 0: β—–Vac◗──S(0.1,0.0) @@ -290,8 +221,7 @@ def lookup_path(self) -> None: - >>> circ.path = [(0, 1), (2, 3)] - >>> circ.lookup_path() + >>> circ.check_contraction(2) # after 2 contractions β†’ index: 0 mode 0: β—–Vac◗──S(0.1,0.0) @@ -307,8 +237,7 @@ def lookup_path(self) -> None: - >>> circ.path = [(0, 1), (2, 3), (0, 2)] - >>> circ.lookup_path() + >>> circ.check_contraction(3) # after 3 contractions β†’ index: 0 mode 0: β—–Vac◗──S(0.1,0.0)──╭‒──────────────────────── @@ -323,7 +252,7 @@ def lookup_path(self) -> None: ValueError: If ``circuit.path`` contains invalid contractions. """ remaining = {i: Circuit([c]) for i, c in enumerate(self.components)} - for idx0, idx1 in self.path: + for idx0, idx1 in self.path[:n]: try: left = remaining[idx0].components right = remaining.pop(idx1).components @@ -340,97 +269,6 @@ def lookup_path(self) -> None: print(msg) - def make_path(self, strategy: str = "l2r") -> None: - r""" - Automatically generates a path for this circuit. - - The available strategies are: - * ``l2r``: The two left-most components are contracted together, then the - resulting component is contracted with the third one from the left, et cetera. - * ``r2l``: The two right-most components are contracted together, then the - resulting component is contracted with the third one from the right, et cetera. - - Args: - strategy: The strategy used to generate the path. - """ - if strategy == "l2r": - self.path = [(0, i) for i in range(1, len(self))] - elif strategy == "r2l": - self.path = [(i, i + 1) for i in range(len(self) - 2, -1, -1)] - else: - msg = f"Strategy ``{strategy}`` is not available." - raise ValueError(msg) - - def validate_path(self, path) -> None: - r""" - A convenience function to check whether a given contraction path is valid for this circuit. - - Uses the wires' ``ids`` to understand what pairs of wires would be contracted, if the - simulation was carried from left to right. Next, it checks whether ``path`` is an equivalent - contraction path, meaning that it instructs to contract the same wires as a ``l2r`` path. - - Args: - path: A candidate contraction path. - - Raises: - ValueError: If the given path is not equivalent to a left-to-right path. - """ - wires = [c.wires for c in self.components] - - # if at least one of the ``Wires`` has wires on the bra side, add the adjoint - # to all the other ``Wires`` - add_adjoints = len(set(bool(w.bra) for w in wires)) != 1 - if add_adjoints: - wires = [(w @ w.adjoint)[0] if bool(w.bra) is False else w for w in wires] - - # if the circuit has no graph, compute it - if not self._graph: - # a dictionary to store the ``ids`` of the dangling wires - ids_dangling_wires = {m: {"ket": None, "bra": None} for w in wires for m in w.modes} - - # populate the graph - for w in wires: - # if there is a dangling wire, add a contraction - for m in w.input.ket.modes: # ket side - if ids_dangling_wires[m]["ket"]: - self._graph[ids_dangling_wires[m]["ket"]] = w.input.ket[m].ids[0] - ids_dangling_wires[m]["ket"] = None - for m in w.input.bra.modes: # bra side - if ids_dangling_wires[m]["bra"]: - self._graph[ids_dangling_wires[m]["bra"]] = w.input.bra[m].ids[0] - ids_dangling_wires[m]["bra"] = None - - # update the dangling wires - for m in w.output.ket.modes: # ket side - if w.output.ket[m].ids: - if ids_dangling_wires[m]["ket"]: - raise ValueError("Dangling wires cannot be overwritten.") - ids_dangling_wires[m]["ket"] = w.output.ket[m].ids[0] - for m in w.output.bra.modes: # bra side - if w.output.bra[m].ids: - if ids_dangling_wires[m]["bra"]: - raise ValueError("Dangling wires cannot be overwritten.") - ids_dangling_wires[m]["bra"] = w.output.bra[m].ids[0] - - # use ``self._graph`` to validate the path - remaining = dict(enumerate(wires)) - for i1, i2 in path: - overlap_ket = remaining[i1].output.ket.modes & remaining[i2].input.ket.modes - for m in overlap_ket: - key = remaining[i1].output.ket[m].ids[0] - val = remaining[i2].input.ket[m].ids[0] - if self._graph[key] != val: - raise ValueError(f"The contraction ``{(i1, i2)}`` is invalid.") - - overlap_bra = remaining[i1].output.bra.modes & remaining[i2].input.bra.modes - for m in overlap_bra: - key = remaining[i1].output.bra[m].ids[0] - val = remaining[i2].input.bra[m].ids[0] - if self._graph[key] != val: - raise ValueError(f"The contraction ``{i1, i2}`` is invalid.") - - remaining[i1] = (remaining[i1] @ remaining.pop(i2))[0] - def serialize(self, filestem: str = None): r""" Serialize a Circuit. @@ -457,18 +295,19 @@ def deserialize(cls, data: dict) -> Circuit: classes: list[CircuitComponent] = [locate(c.pop("class")) for c in comps] circ = cls([c._deserialize(comp_data) for c, comp_data in zip(classes, comps)]) - if path: # re-evaluates the hidden `_graph` property - circ.path = [tuple(p) for p in path] + circ.path = [tuple(p) for p in path] return circ def __eq__(self, other: Circuit) -> bool: + if not isinstance(other, Circuit): + return false return self.components == other.components def __getitem__(self, idx: int) -> CircuitComponent: r""" The component in position ``idx`` of this circuit's components. """ - return self._components[idx] + return self.components[idx] def __len__(self): r""" diff --git a/mrmustard/lab_dev/simulator.py b/mrmustard/lab_dev/simulator.py deleted file mode 100644 index aa1b59887..000000000 --- a/mrmustard/lab_dev/simulator.py +++ /dev/null @@ -1,115 +0,0 @@ -# Copyright 2023 Xanadu Quantum Technologies Inc. - -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at - -# http://www.apache.org/licenses/LICENSE-2.0 - -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -Simulators for quantum circuits. -""" - -from __future__ import annotations - -from .circuit_components import CircuitComponent -from .circuits import Circuit - -__all__ = ["Simulator"] - - -class Simulator: - r""" - A simulator for quantum circuits. - - Circuits can be simulated by using the ``run`` method of ``Simulator``: - - .. code-block:: - - >>> from mrmustard.lab_dev import * - >>> import numpy as np - - >>> # initialize a circuit - >>> state = Number(modes=[0, 1], n=[2, 0], cutoffs=2) - >>> gate = BSgate([0, 1], theta=np.pi/4) - >>> proj1 = Number(modes=[1], n=[0]).dual - >>> circuit = Circuit([state, gate, proj1]) - - >>> # run the simulation - >>> result = Simulator().run(circuit) - - >>> # the simulator returns a component that can be potentially be plugged - >>> # into another circuit - >>> assert isinstance(result, CircuitComponent) - - The simulation is carried out by contracting the components of the given circuit in pairs, - until only one component is left and returned. In the examples above, the contractions happen - in a "left-to-right" fashion, meaning that the left-most component in the circuit (``state``) - is contracted with the one in its right (``gate``), and finally the resulting component is - contracted with the projector. This provides a simple and convenient way to run simulations, - but for large circuits, different contraction paths may be more efficient. - - The ``path`` attribute of ``Circuit``\s allows customising the contraction order and potentially - speeding up the simulation. When a ``path`` of the type ``[(i, j), (l, m), ...]`` is given, the - simulator creates a dictionary of the type ``{0: c0, ..., N: cN}``, where ``[c0, .., cN]`` - is the ``circuit.component`` list. Then: - - * The two components ``ci`` and ``cj`` in positions ``i`` and ``j`` are contracted. ``ci`` is - replaced by the resulting component ``cj >> cj``, while ``cj`` is popped. - * The two components ``cl`` and ``cm`` in positions ``l`` and ``m`` are contracted. ``cl`` is - replaced by the resulting component ``cl >> cm``, while ``cm`` is popped. - * Et cetera. - - Below is an example where a circuit is simulated in a "right-to-left" fashion: - - .. code-block:: - - >>> from mrmustard.lab_dev import * - >>> import numpy as np - - >>> state = Number(modes=[0, 1], n=[2, 0], cutoffs=2) - >>> gate = BSgate([0, 1], theta=np.pi/4) - >>> proj01 = Number(modes=[0, 1], n=[2, 0]).dual - - >>> # initialize the circuit and specify a custom path - >>> circuit = Circuit([state, gate, proj01]) - >>> circuit.path = [(1, 2), (0, 1)] - - >>> result = Simulator().run(circuit) - - The setter for ``path`` also validates the path using the ``validate_path`` function of - ``Circuit``. - """ - - def run(self, circuit: Circuit) -> CircuitComponent: - r""" - Runs the simulations of the given circuit. - - Arguments: - circuit: The circuit to simulate. - - Returns: - A circuit component representing the entire circuit. - - Raises: - ValueError: If ``circuit`` has an incomplete path. - """ - if not circuit.path: - circuit.make_path() - - if len(circuit.path) != len(circuit) - 1: - msg = f"``circuit.path`` needs to specify {len(circuit) - 1} contractions, " - msg += f"found {len(circuit.path)}." - raise ValueError(msg) - - ret = dict(enumerate(circuit.components)) - for idx0, idx1 in circuit.path: - ret[idx0] = ret[idx0] >> ret.pop(idx1) - - return list(ret.values())[0] diff --git a/mrmustard/lab_dev/states/two_mode_squeezed_vacuum.py b/mrmustard/lab_dev/states/two_mode_squeezed_vacuum.py index e04f74686..7cb535ee0 100644 --- a/mrmustard/lab_dev/states/two_mode_squeezed_vacuum.py +++ b/mrmustard/lab_dev/states/two_mode_squeezed_vacuum.py @@ -18,6 +18,8 @@ from __future__ import annotations +from typing import Sequence + from mrmustard.physics.representations import Bargmann from mrmustard.physics import triples from .base import Ket @@ -52,7 +54,7 @@ class TwoModeSqueezedVacuum(Ket): def __init__( self, - modes: tuple[int, int], + modes: Sequence[int], r: float = 0.0, phi: float = 0.0, r_trainable: bool = False, diff --git a/mrmustard/lab_dev/transformations/bsgate.py b/mrmustard/lab_dev/transformations/bsgate.py index d6fb8ea95..71df7cd21 100644 --- a/mrmustard/lab_dev/transformations/bsgate.py +++ b/mrmustard/lab_dev/transformations/bsgate.py @@ -18,6 +18,8 @@ from __future__ import annotations +from typing import Sequence + from .base import Unitary from ...physics.representations import Bargmann from ...physics import triples @@ -88,7 +90,7 @@ class BSgate(Unitary): def __init__( self, - modes: tuple[int, int], + modes: Sequence[int], theta: float = 0.0, phi: float = 0.0, theta_trainable: bool = False, diff --git a/mrmustard/lab_dev/transformations/s2gate.py b/mrmustard/lab_dev/transformations/s2gate.py index 1d1f8f756..32d30a37f 100644 --- a/mrmustard/lab_dev/transformations/s2gate.py +++ b/mrmustard/lab_dev/transformations/s2gate.py @@ -18,6 +18,8 @@ from __future__ import annotations +from typing import Sequence + from .base import Unitary from ...physics.representations import Bargmann from ...physics import triples @@ -71,7 +73,7 @@ class S2gate(Unitary): def __init__( self, - modes: tuple[int, int], + modes: Sequence[int], r: float = 0.0, phi: float = 0.0, r_trainable: bool = False, diff --git a/mrmustard/lab_dev/wires.py b/mrmustard/lab_dev/wires.py index e1bc43a10..7f896df1d 100644 --- a/mrmustard/lab_dev/wires.py +++ b/mrmustard/lab_dev/wires.py @@ -64,17 +64,17 @@ class Wires: β•šβ•β•β•β•β•β•β•β•β•β• β””β”€β”€β”€β”€β”€β”€β”€β”˜ - A ket representing the state of mode ``1`` has only output wires: + A ket representing the state of mode ``1`` has only output wires: ╔═════════╗ 1 β”Œβ”€β”€β”€β”€β”€β”€β”€β” β•‘ Ket ║─────▢│Ket outβ”‚ β•šβ•β•β•β•β•β•β•β•β•β• β””β”€β”€β”€β”€β”€β”€β”€β”˜ - A measurement acting on mode ``0`` has input wires on the ket side and classical output wires: + A measurement acting on mode ``0`` has input wires on the ket side and classical output wires: - β”Œβ”€β”€β”€β”€β”€β”€β” 0 ╔═════════════╗ 0 β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” - β”‚Ket in│─────▢║ Measurement ║─────▢│Classical outβ”‚ - β””β”€β”€β”€β”€β”€β”€β”˜ β•šβ•β•β•β•β•β•β•β•β•β•β•β•β•β• β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ + β”Œβ”€β”€β”€β”€β”€β”€β” 0 ╔═════════════╗ 0 β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” + β”‚Ket in│─────▢║ Measurement ║─────▢│Classical outβ”‚ + β””β”€β”€β”€β”€β”€β”€β”˜ β•šβ•β•β•β•β•β•β•β•β•β•β•β•β•β• β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ The ``Wires`` class can then be used to create subsets of wires: @@ -193,7 +193,12 @@ def adjoint(self) -> Wires: New ``Wires`` object obtained by swapping ket and bra wires. """ return Wires( - self.args[2], self.args[3], self.args[0], self.args[1], self.args[4], self.args[5] + self.args[2], + self.args[3], + self.args[0], + self.args[1], + self.args[4], + self.args[5], ) @cached_property @@ -234,7 +239,12 @@ def dual(self) -> Wires: New ``Wires`` object obtained by swapping input and output wires. """ return Wires( - self.args[1], self.args[0], self.args[3], self.args[2], self.args[5], self.args[4] + self.args[1], + self.args[0], + self.args[3], + self.args[2], + self.args[5], + self.args[4], ) @cached_property @@ -363,6 +373,29 @@ def sorted_args(self) -> tuple[list[int], ...]: """ return tuple(sorted(s) for s in self.args) + def contracted_indices(self, other: Wires): + r""" + Returns the indices being contracted between self and other when calling matmul. + + Args: + other: another Wires object + """ + ovlp_bra, ovlp_ket = self.overlap(other) + idxA = self.output.bra[ovlp_bra].indices + self.output.ket[ovlp_ket].indices + idxB = other.input.bra[ovlp_bra].indices + other.input.ket[ovlp_ket].indices + return idxA, idxB + + def overlap(self, other: Wires) -> tuple[set[int], set[int]]: + r""" + Returns the modes that overlap between the two ``Wires`` objects. + + Args: + other: Another ``Wires`` object. + """ + ovlp_ket = self.output.ket.modes & other.input.ket.modes + ovlp_bra = self.output.bra.modes & other.input.bra.modes + return ovlp_bra, ovlp_ket + def __add__(self, other: Wires) -> Wires: r""" New ``Wires`` object that combines the wires of ``self`` and those of ``other``. diff --git a/tests/test_lab_dev/test_circuits.py b/tests/test_lab_dev/test_circuits.py index 5a1c9e38b..550ce6a57 100644 --- a/tests/test_lab_dev/test_circuits.py +++ b/tests/test_lab_dev/test_circuits.py @@ -29,6 +29,7 @@ ) from mrmustard import settings from mrmustard.utils.serialize import load +import mrmustard.lab_dev.circuit_components_utils.branch_and_bound as bb class TestCircuit: @@ -44,52 +45,34 @@ def test_init(self): circ1 = Circuit([vac, s01, bs01, bs12]) assert circ1.components == [vac, s01, bs01, bs12] - assert circ1.path == [] + assert circ1.path == [(0, 1), (0, 2), (0, 3)] circ2 = Circuit() >> vac >> s01 >> bs01 >> bs12 assert circ2.components == [vac, s01, bs01, bs12] - assert circ2.path == [] + assert circ2.path == [(0, 1), (0, 2), (0, 3)] def test_propagate_shapes(self): MAX = settings.AUTOSHAPE_MAX settings.AUTOSHAPE_PROBABILITY = 0.999 circ = Circuit([Coherent([0], x=1.0), Dgate([0], 0.1)]) assert [op.auto_shape() for op in circ] == [(5,), (MAX, MAX)] - circ.propagate_shapes() + circ.optimize_fock_shapes(verbose=False) assert [op.auto_shape() for op in circ] == [(5,), (MAX, 5)] - circ = Circuit([SqueezedVacuum([0, 1], r=[0.5, -0.5]), BSgate([0, 1], 0.9)]) + circ = Circuit([SqueezedVacuum([0, 1], r=[0.5, -0.5]), BSgate((0, 1), 0.9)]) assert [op.auto_shape() for op in circ] == [(6, 6), (MAX, MAX, MAX, MAX)] - circ.propagate_shapes() + circ.optimize_fock_shapes(verbose=True) assert [op.auto_shape() for op in circ] == [(6, 6), (12, 12, 6, 6)] - settings.AUTOSHAPE_PROBABILITY = 0.99999 - def test_make_path(self): - vac = Vacuum([0, 1, 2]) - s01 = Sgate([0, 1]) - bs01 = BSgate([0, 1]) - bs12 = BSgate([1, 2]) - - circ = Circuit([vac, s01, bs01, bs12]) - - circ.make_path("l2r") - assert circ.path == [(0, 1), (0, 2), (0, 3)] - - circ.make_path("r2l") - assert circ.path == [(2, 3), (1, 2), (0, 1)] - - with pytest.raises(ValueError): - circ.make_path("my_strategy") - def test_lookup_path(self, capfd): vac = Vacuum([0, 1, 2]) s01 = Sgate([0, 1]) - bs01 = BSgate([0, 1]) - bs12 = BSgate([1, 2]) + bs01 = BSgate((0, 1)) + bs12 = BSgate((1, 2)) circ = Circuit([vac, s01, bs01, bs12]) - circ.lookup_path() + circ.check_contraction(0) out1, _ = capfd.readouterr() exp1 = "\n" exp1 += "β†’ index: 0\n" @@ -108,7 +91,7 @@ def test_lookup_path(self, capfd): assert out1 == exp1 circ.path += [(0, 1)] - circ.lookup_path() + circ.check_contraction(1) out2, _ = capfd.readouterr() exp2 = "\n" exp2 += "β†’ index: 0\n" @@ -139,50 +122,12 @@ def test_path(self, path): def test_path_errors(self): vac12 = Vacuum([1, 2]) - d1 = Dgate([1], x=0.1, y=0.1) - d12 = Dgate([1, 2], x=0.1, y=[0.1, 0.2]) - circuit1 = Circuit([vac12, vac12]) - with pytest.raises(ValueError, match="Dangling wires cannot be overwritten."): - circuit1.path += [(0, 1)] + with pytest.raises(ValueError, match="overlap"): + Circuit([vac12, vac12]) - circuit2 = Circuit([vac12.adjoint, vac12.adjoint]) - with pytest.raises(ValueError, match="Dangling wires cannot be overwritten."): - circuit2.path += [(0, 1)] - - circuit3 = Circuit([vac12, d1, d12]) - with pytest.raises(ValueError, match="is invalid."): - circuit3.path += [(0, 2)] - - circuit4 = Circuit([vac12.adjoint, d1.adjoint, d12.adjoint]) - with pytest.raises(ValueError, match="is invalid."): - circuit4.path += [(0, 2)] - - def test_graph(self): - vac12 = Vacuum([1, 2]) - d1 = Dgate([1], x=0.1, y=0.1) - d2 = Dgate([2], x=0.1, y=0.2) - d12 = Dgate([1, 2], x=0.1, y=[0.1, 0.2]) - a1 = Attenuator([1], transmissivity=0.8) - n12 = Number([1, 2], n=1).dual - - circuit = Circuit([vac12, d1, d2, d12, a1, n12]) - assert not circuit._graph # pylint: disable=protected-access - - circuit.make_path("l2r") - graph1 = circuit._graph # pylint: disable=protected-access - - c0 = circuit.components[0] - c1 = circuit.components[1] - assert graph1[c0.wires.output.ket[1].ids[0]] == c1.wires.input.ket[1].ids[0] - - circuit.make_path("r2l") - graph2 = circuit._graph # pylint: disable=protected-access - assert graph1 == graph2 - - circuit.path = [(0, 1), (2, 3)] - graph3 = circuit._graph # pylint: disable=protected-access - assert graph1 == graph3 + with pytest.raises(ValueError, match="overlap"): + Circuit([vac12.adjoint, vac12.adjoint]) def test_eq(self): vac = Vacuum([0, 1, 2]) @@ -290,6 +235,33 @@ def test_repr_issue_334(self): r1 += "\n\n" assert repr(circ1) == r1 + def test_optimize_path(self): + "tests the optimize method" + # contracting the last two first is better + circ = Circuit([Number([0], n=15), Sgate([0], r=1.0), Coherent([0], x=1.0).dual]) + circ.optimize(with_BF_heuristic=True) # with default heuristics + assert circ.path == [(1, 2), (0, 1)] + + circ = Circuit([Number([0], n=15), Sgate([0], r=1.0), Coherent([0], x=1.0).dual]) + circ.optimize(with_BF_heuristic=False) # without the BF heuristic + assert circ.path == [(1, 2), (0, 1)] + + circ = Circuit([Number([0], n=15), Sgate([0], r=1.0), Coherent([0], x=1.0).dual]) + circ.optimize(n_init=1, verbose=False) + assert circ.path == [(1, 2), (0, 1)] + + def test_wrong_path(self): + "tests an exception is raised if contract is called with a wrond path" + circ = Circuit([Number([0], n=15), Sgate([0], r=1.0), Dgate([0], x=1.0)]) + circ.path = [(0, 3)] + with pytest.raises(ValueError): + circ.contract() + + def test_contract(self): + "tests the contract method" + circ = Circuit([Number([0], n=15), Sgate([0], r=1.0), Dgate([0], x=1.0)]) + assert circ.contract() == Number([0], n=15) >> Sgate([0], r=1.0) >> Dgate([0], x=1.0) + def test_serialize_makes_zip(self, tmpdir): """Test that serialize makes a JSON and a zip.""" settings.CACHE_DIR = tmpdir @@ -311,15 +283,33 @@ def test_serialize_custom_name(self, tmpdir): def test_path_is_loaded(self, tmpdir): """Test that circuit paths are saved if already evaluated.""" settings.CACHE_DIR = tmpdir - vac = Vacuum([0, 1, 2]) - s01 = Sgate([0, 1]) + vac = Vacuum([0]) + S0 = Sgate([0]) + s0 = SqueezedVacuum([1]) bs01 = BSgate([0, 1]) - bs12 = BSgate([1, 2]) - - circ = Circuit([vac, s01, bs01, bs12]) - assert not load(circ.serialize())._path - - circ.make_path() - assert circ._path - - assert load(circ.serialize())._path == circ.path + c0 = Coherent([0]).dual + c1 = Coherent([1]).dual + + circ = Circuit([vac, S0, s0, bs01, c0, c1]) + base_path = circ.path + assert load(circ.serialize()).path == base_path + + circ.optimize() + opt_path = circ.path + assert opt_path != base_path + assert load(circ.serialize()).path == opt_path + + def test_graph_children_and_grandchildren(self): + """tests that the children function returns the correct graphs""" + + circ = Circuit([Number([0], n=15), Sgate([0], r=1.0), Dgate([0], x=1.0)]) + bb.assign_costs(circ._graph) + children_set = bb.children(circ._graph, int(1e20)) + for child in children_set: + assert isinstance(child, bb.Graph) + assert len(child.nodes) == 2 + + grandchildren_set = bb.grandchildren(circ._graph, int(1e20)) + for grandchild in grandchildren_set: + assert isinstance(grandchild, bb.Graph) + assert len(grandchild.nodes) == 1 diff --git a/tests/test_lab_dev/test_simulator.py b/tests/test_lab_dev/test_simulator.py deleted file mode 100644 index 624eff309..000000000 --- a/tests/test_lab_dev/test_simulator.py +++ /dev/null @@ -1,63 +0,0 @@ -# Copyright 2023 Xanadu Quantum Technologies Inc. - -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at - -# http://www.apache.org/licenses/LICENSE-2.0 - -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Tests for the ``Simulator`` class.""" - -# pylint: disable=missing-function-docstring, expression-not-assigned - -import numpy as np -import pytest - -from mrmustard import settings -from mrmustard.lab_dev.circuits import Circuit -from mrmustard.lab_dev.simulator import Simulator -from mrmustard.lab_dev.states import Vacuum, Number -from mrmustard.lab_dev.transformations import Dgate, Attenuator - -# original settings -autocutoff_max0 = settings.AUTOSHAPE_MAX - - -class TestSimulator: - r""" - Tests for the ``Circuit`` class. - """ - - @pytest.mark.parametrize( - "path", - [ - [], - [(0, 1), (2, 3), (0, 2), (0, 4), (0, 5)], - [(4, 5), (3, 4), (2, 3), (1, 2), (0, 1)], - ], - ) - def test_run(self, path): - settings.AUTOSHAPE_MAX = 10 - - vac12 = Vacuum([1, 2]) - d1 = Dgate([1], x=0.1, y=0.1) - d2 = Dgate([2], x=0.1, y=0.2) - d12 = Dgate([1, 2], x=0.1, y=[0.1, 0.2]) - a1 = Attenuator([1], transmissivity=0.8) - n12 = Number([1, 2], n=1).dual - - circuit = Circuit([vac12, d1, d2, d12, a1, n12]) - circuit.path = path - - res = Simulator().run(circuit) - exp = vac12 >> d1 >> d2 >> d12 >> a1 >> n12 - - assert np.isclose(res, exp) - - settings.AUTOSHAPE_MAX = autocutoff_max0 diff --git a/tests/test_physics/test_representations.py b/tests/test_physics/test_representations.py index 2513c1675..b0f03f7e1 100644 --- a/tests/test_physics/test_representations.py +++ b/tests/test_physics/test_representations.py @@ -144,7 +144,7 @@ def test_add_error(self): fock = Fock(np.random.random((1, 4, 4, 4)), batched=True) with pytest.raises(ValueError): - bargmann + fock + bargmann + fock # pylint: disable=pointless-statement @pytest.mark.parametrize("n", [1, 2, 3]) def test_sub(self, n): diff --git a/tests/test_utils/test_serialize.py b/tests/test_utils/test_serialize.py index d8398f4fd..21e85cab5 100644 --- a/tests/test_utils/test_serialize.py +++ b/tests/test_utils/test_serialize.py @@ -138,7 +138,8 @@ def test_two_numpy_obj(self): def test_overlap_forbidden(self): """Test that array names must be distinct from non-array names.""" with pytest.raises( - ValueError, match=r"Arrays cannot have the same name as generic data: {'val'}" + ValueError, + match=r"Arrays cannot have the same name as generic data: {'val'}", ): save(Dummy, arrays={"val": [1]}, val=2) @@ -189,14 +190,16 @@ def test_all_components_serializable(self): Attenuator([1, 2], transmissivity=0.1), BtoPS([0], s=1), TraceOut([0, 1]), - Thermal([1, 2], nbar=3), - Coherent([0, 1], x=[0.3, 0.4], y=0.2, y_trainable=True, y_bounds=(-0.5, 0.5)), + Thermal([0, 1], nbar=3), + Coherent([0, 1], x=[0.3, 0.4], y=0.2, y_trainable=True, y_bounds=(-0.5, 0.5)).dual, DisplacedSqueezed([0], 1, 2, 3, 4, x_bounds=(-1.5, 1.5), x_trainable=True), - Number([0, 1], n=[10, 20]), - QuadratureEigenstate([1, 2], x=1, phi=0, phi_trainable=True, phi_bounds=[-1, 1]), + Number([1], n=[20]), + QuadratureEigenstate( + [0, 1, 2], x=1, phi=0, phi_trainable=True, phi_bounds=(-1, 1) + ).dual, SqueezedVacuum([0, 1, 2], r=[0.3, 0.4, 0.5], phi=0.2), - TwoModeSqueezedVacuum([0, 1], r=0.3, phi=0.2), - Vacuum([1, 2]), + TwoModeSqueezedVacuum([0, 1], r=0.3, phi=0.2).dual, + Vacuum([2]).dual, ], ) path = circ.serialize()