Skip to content

Commit

Permalink
define new OuterProductBasis class. Unclear if we'll want to keep this.'
Browse files Browse the repository at this point in the history
  • Loading branch information
rileyjmurray committed Sep 28, 2024
1 parent bb55d13 commit 81ff50c
Showing 1 changed file with 72 additions and 0 deletions.
72 changes: 72 additions & 0 deletions pygsti/baseobjs/basis.py
Original file line number Diff line number Diff line change
Expand Up @@ -1762,6 +1762,78 @@ def create_simple_equivalent(self, builtin_basis_name=None):
return BuiltinBasis(builtin_basis_name, self.elsize, sparse=self.sparse)


class OuterProdBasis(TensorProdBasis):

def __init__(self, component_bases, name=None, longname=None, squeeze=True):
"""
Let ⨂ denote an infix operator where (a ⨂ b) := numpy.tensordot(a, b, axes=0).
Suppose component_bases has length k. Under the default setting of squeeze=True,
the elements of this OuterProdBasis are all arrays of the form
v1.squeeze() ⨂ v2.squeeze() ⨂ ... ⨂ vk.squeeze(),
where vi belongs to component_bases[i].elements. The definition is changed in the
natural way if squeeze=False.
"""
TensorProdBasis.__init__(self, component_bases, name, longname)
if self.sparse:
raise NotImplementedError()
cnames = [c.name for c in self.component_bases]
if name is None:
self.name = "⨂".join(cnames)
if longname is None:
self.longname = "Outer-product basis with components " + ", ".join(cnames)
self._squeeze = squeeze
return

@property
def elshape(self):
shape = []
if self._squeeze:
for c in self.component_bases:
shape.extend(d for d in c.elshape if d > 1)
else:
for c in self.component_bases:
shape.extend(c.elshape)
return tuple(shape)

def _lazy_build_elements(self):
compMxs = _np.zeros((self.size,) + self.elshape, 'complex')
comp_els = [c.elements for c in self.component_bases]
for i, factors in enumerate(_itertools.product(*comp_els)):
M = factors[0].squeeze() if self._squeeze else factors[0]
for f in factors[1:]:
if self._squeeze:
f = f.squeeze()
M = _np.tensordot(M, f, axes=0)
compMxs[i] = M
self._elements = compMxs

def _lazy_build_labels(self):
self._labels = []
comp_lbls = [c.labels for c in self.component_bases]
for i, factor_lbls in enumerate(_itertools.product(*comp_lbls)):
self._labels.append('⨂'.join(factor_lbls))

def is_equivalent(self, other, sparseness_must_match=True):
if not sparseness_must_match:
raise NotImplementedError()
otherIsBasis = isinstance(other, OuterProdBasis)
if not otherIsBasis: return False # can't be equal to a non-DirectSumBasis
return all([c1.is_equivalent(c2, True)
for (c1, c2) in zip(self.component_bases, other.component_bases)])

def _copy_with_toggled_sparsity(self):
raise NotImplementedError()

def create_equivalent(self, builtin_basis_name):
raise NotImplementedError()

def create_simple_equivalent(self, builtin_basis_name=None):
raise NotImplementedError()


class EmbeddedBasis(LazyBasis):
"""
A basis that embeds a basis for a smaller state space within a larger state space.
Expand Down

0 comments on commit 81ff50c

Please sign in to comment.