Skip to content

Commit

Permalink
Saving changes
Browse files Browse the repository at this point in the history
  • Loading branch information
Nicolas Quesada committed Aug 17, 2023
1 parent 8e43ac1 commit c23cae3
Show file tree
Hide file tree
Showing 3 changed files with 122 additions and 26 deletions.
8 changes: 8 additions & 0 deletions thewalrus/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,12 @@
rec_torontonian,
rec_ltorontonian,
)

from ._montrealer import (
mtl,
lmtl,
)

from ._version import __version__


Expand All @@ -152,6 +158,8 @@
"reduction",
"hermite_multidimensional",
"grad_hermite_multidimensional",
"mtl",
"lmtl",
"version",
]

Expand Down
50 changes: 25 additions & 25 deletions thewalrus/_torontonian.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,26 +20,41 @@
from ._hafnian import reduction, find_kept_edges, nb_ix


def tor(A, recursive=True):
"""Returns the Torontonian of a matrix.
def tor_input_checks(A, loops=None):
"""Checks the correcteness of the inputs for the torontonian/montrealer.
Args:
A (array): a square array of even dimensions.
recursive: use the faster recursive implementation.
Returns:
np.float64 or np.complex128: the torontonian of matrix A.
A (array): an NxN array of even dimensions.
loops (array): optinal argument, an N-length vector of even dimensions.
"""
if not isinstance(A, np.ndarray):
raise TypeError("Input matrix must be a NumPy array.")

matshape = A.shape

if matshape[0] != matshape[1]:
raise ValueError("Input matrix must be square.")


if matshape[0] % 2 != 0:
raise ValueError("matrix dimension must be even")


if loops is not None:
if not isinstance(loops, np.ndarray):
raise TypeError("Input matrix must be a NumPy array.")
if matshape[0] != len(loops):
raise ValueError("gamma must be a vector matching the dimension of A")

def tor(A, recursive=True):
"""Returns the Torontonian of a matrix.
Args:
A (array): a square array of even dimensions.
recursive: use the faster recursive implementation.
Returns:
np.float64 or np.complex128: the torontonian of matrix A.
"""
tor_input_checks(A)
return rec_torontonian(A) if recursive else numba_tor(A)


Expand All @@ -54,23 +69,8 @@ def ltor(A, gamma, recursive=True):
Returns:
np.float64 or np.complex128: the loop torontonian of matrix A, vector gamma
"""
tor_input_checks(A, gamma)

if not isinstance(A, np.ndarray):
raise TypeError("Input matrix must be a NumPy array.")

if not isinstance(gamma, np.ndarray):
raise TypeError("Input matrix must be a NumPy array.")

matshape = A.shape

if matshape[0] != matshape[1]:
raise ValueError("Input matrix must be square.")

if matshape[0] != len(gamma):
raise ValueError("gamma must be a vector matching the dimension of A")

if matshape[0] % 2 != 0:
raise ValueError("matrix dimension must be even")

return rec_ltorontonian(A, gamma) if recursive else numba_ltor(A, gamma)

Expand Down
90 changes: 89 additions & 1 deletion thewalrus/reference.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,12 +34,14 @@
.. autosummary::
hafnian
montrealer
Code details
------------
.. autofunction::
hafnian
montrealer
Auxiliary functions
-------------------
Expand All @@ -49,6 +51,8 @@
partitions
spm
pmp
rspm
rpmp
T
Code details
Expand All @@ -58,7 +62,7 @@

# pylint: disable=too-many-arguments
from collections import OrderedDict
from itertools import tee
from itertools import tee, product, permutations, chain
from types import GeneratorType

MAXSIZE = 1000
Expand Down Expand Up @@ -278,3 +282,87 @@ def hafnian(M, loop=False):
tot_sum = tot_sum + result

return tot_sum


def mapper(x, objects):
"""Helper function to turn a permutation and bistring into an element of rpmp.
Args:
x (tuple): tuple containing a permutation and a bistring.
objects (list): list objects to permute
Returns:
tuple: permuted objects
"""
(perm, bit) = x
m = len(bit)
Blist = [list(range(m)), list(range(m, 2 * m))]
for i, j in enumerate(bit):
if int(j):
(Blist[0][i], Blist[1][i]) = (Blist[1][i], Blist[0][i])
Blist = [Blist[0][i] for i in tuple((0,) + perm)] + [Blist[1][i] for i in tuple((0,) + perm)]
dico_list = {j: i + 1 for i, j in enumerate(Blist)}
new_mapping_list = {
objects[dico_list[i] - 1]: objects[dico_list[j] - 1]
for i, j in zip(list(range(0, m - 1)) + [m], list(range(m + 1, 2 * m)) + [m - 1])
}
return tuple(new_mapping_list.items())

def bitstrings(n):
"""Returns the bistrings from 0 to n/2
Args:
n (int) : Twice the highest bitstring value.
Returns:
(iterator) : An iterable of all bistrings.
"""
for binary in map("".join, product("01", repeat=n - 1)):
yield "0" + binary

def rpmp(s):
"""Generates the restricted set of perfect matchings matching permutations.
Args:
s (tuple): tuple of labels to be used
Returns:
generator: the set of restricted perfect matching permutations of the tuple s
"""
m = len(s) // 2
local_mapper = lambda x: mapper(x, s)
for i in product(permutations(range(1, m)), bitstrings(m)):
yield local_mapper(i)

def splitter(elem):
"""Takes an element from rpmp and returns all the associated elements in rspm
Args:
elem (tuple): tuple representing an element of rpmp
Returns:
(iterator): all the associated elements in rspm
"""
num_elem = len(elem)
net = [elem]
for i in range(num_elem):
left = (elem[j] for j in range(i))
middle = ((elem[i][0],elem[i][0]),(elem[i][1],elem[i][1]))
right = (elem[j] for j in range(i+1,num_elem))
net.append(tuple(left) + tuple(middle) + tuple(right))
for i in net:
yield i



def rspm(s):
"""Generates the restricted set of single-pair matchings.
Args:
s (tuple): tuple of labels to be used
Returns:
generator: the set of restricted perfect matching permutations of the tuple s
"""
gen = rpmp(s)
return chain(*(splitter(i) for i in gen))

0 comments on commit c23cae3

Please sign in to comment.