Skip to content

Commit

Permalink
TensorProductBasis: drop orth_weight argument
Browse files Browse the repository at this point in the history
  • Loading branch information
inducer committed Nov 14, 2023
1 parent 65ba365 commit 7b8174c
Showing 1 changed file with 8 additions and 28 deletions.
36 changes: 8 additions & 28 deletions modepy/modes.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,8 +126,6 @@ def _cse(expr, prefix):
else:
return expr

return expr


def _where(op_a, comp, op_b, then, else_):
from pymbolic.primitives import Comparison, Expression, If
Expand Down Expand Up @@ -940,6 +938,9 @@ def orthonormality_weight(self) -> float:
"""
:raises: :exc:`BasisNotOrthonormal` if the basis does not have
a weight, i.e. it is not orthogonal.
For now, only scalar orthonormality weights are supported.
In the future, this may become a symbolic expression involving
symbols ``"rstuvw"``.
"""

@property
Expand Down Expand Up @@ -1115,28 +1116,23 @@ class TensorProductBasis(Basis):

def __init__(self,
bases: Sequence[Basis],
orth_weight: Optional[float],
dims_per_basis: Optional[Tuple[int, ...]] = None) -> None:
"""
:param bases: a sequence of 1D bases used to construct the tensor
product approximation basis.
:param orth_weight: if *bases* forms an orthogonal basis, this should
be the normalizing weight. If *None*, then the basis is assumed to
not be orthogonal (this is not checked).
"""

if dims_per_basis is None:
dims_per_basis = (1,) * len(bases)

self._bases = tuple(bases)
self._orth_weight = orth_weight
self._dims_per_basis = tuple(dims_per_basis)

def orthonormality_weight(self):
if self._orth_weight is None:
raise BasisNotOrthonormal
else:
return self._orth_weight
def orthonormality_weight(self) -> float:
orth_weight: float = 1.0
for b in self._bases:
orth_weight *= b.orthonormality_weight()
return orth_weight

@property
def bases(self) -> Sequence[Basis]:
Expand Down Expand Up @@ -1216,19 +1212,6 @@ def gradients(self):
for mid in self._mode_index_tuples)


def _get_orth_weight(bases: Sequence[Basis]) -> Optional[float]:
orth_weight: Optional[float] = 1.0
for b in bases:
try:
assert orth_weight is not None
orth_weight *= b.orthonormality_weight()
except BasisNotOrthonormal:
orth_weight = None
break

return orth_weight


@orthonormal_basis_for_space.register(TensorProductSpace)
def _orthonormal_basis_for_tp(
space: TensorProductSpace,
Expand All @@ -1244,7 +1227,6 @@ def _orthonormal_basis_for_tp(

return TensorProductBasis(
bases,
orth_weight=_get_orth_weight(bases),
dims_per_basis=tuple([b.spatial_dim for b in space.bases]))


Expand All @@ -1259,7 +1241,6 @@ def _basis_for_tp(space: TensorProductSpace, shape: TensorProductShape):
bases = [basis_for_space(b, s) for b, s in zip(space.bases, shape.bases)]
return TensorProductBasis(
bases,
orth_weight=_get_orth_weight(bases),
dims_per_basis=tuple([b.spatial_dim for b in space.bases]))


Expand All @@ -1274,7 +1255,6 @@ def _monomial_basis_for_tp(space: TensorProductSpace, shape: TensorProductShape)

return TensorProductBasis(
bases,
orth_weight=None,
dims_per_basis=tuple([b.spatial_dim for b in space.bases]))

# }}}
Expand Down

0 comments on commit 7b8174c

Please sign in to comment.