From 7b8174c430ff45b977df030f1c365be0d72085e3 Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner Date: Tue, 14 Nov 2023 09:57:35 -0600 Subject: [PATCH] TensorProductBasis: drop orth_weight argument --- modepy/modes.py | 36 ++++++++---------------------------- 1 file changed, 8 insertions(+), 28 deletions(-) diff --git a/modepy/modes.py b/modepy/modes.py index 55cd1c4..b755dc8 100644 --- a/modepy/modes.py +++ b/modepy/modes.py @@ -126,8 +126,6 @@ def _cse(expr, prefix): else: return expr - return expr - def _where(op_a, comp, op_b, then, else_): from pymbolic.primitives import Comparison, Expression, If @@ -940,6 +938,9 @@ def orthonormality_weight(self) -> float: """ :raises: :exc:`BasisNotOrthonormal` if the basis does not have a weight, i.e. it is not orthogonal. + For now, only scalar orthonormality weights are supported. + In the future, this may become a symbolic expression involving + symbols ``"rstuvw"``. """ @property @@ -1115,28 +1116,23 @@ class TensorProductBasis(Basis): def __init__(self, bases: Sequence[Basis], - orth_weight: Optional[float], dims_per_basis: Optional[Tuple[int, ...]] = None) -> None: """ :param bases: a sequence of 1D bases used to construct the tensor product approximation basis. - :param orth_weight: if *bases* forms an orthogonal basis, this should - be the normalizing weight. If *None*, then the basis is assumed to - not be orthogonal (this is not checked). """ if dims_per_basis is None: dims_per_basis = (1,) * len(bases) self._bases = tuple(bases) - self._orth_weight = orth_weight self._dims_per_basis = tuple(dims_per_basis) - def orthonormality_weight(self): - if self._orth_weight is None: - raise BasisNotOrthonormal - else: - return self._orth_weight + def orthonormality_weight(self) -> float: + orth_weight: float = 1.0 + for b in self._bases: + orth_weight *= b.orthonormality_weight() + return orth_weight @property def bases(self) -> Sequence[Basis]: @@ -1216,19 +1212,6 @@ def gradients(self): for mid in self._mode_index_tuples) -def _get_orth_weight(bases: Sequence[Basis]) -> Optional[float]: - orth_weight: Optional[float] = 1.0 - for b in bases: - try: - assert orth_weight is not None - orth_weight *= b.orthonormality_weight() - except BasisNotOrthonormal: - orth_weight = None - break - - return orth_weight - - @orthonormal_basis_for_space.register(TensorProductSpace) def _orthonormal_basis_for_tp( space: TensorProductSpace, @@ -1244,7 +1227,6 @@ def _orthonormal_basis_for_tp( return TensorProductBasis( bases, - orth_weight=_get_orth_weight(bases), dims_per_basis=tuple([b.spatial_dim for b in space.bases])) @@ -1259,7 +1241,6 @@ def _basis_for_tp(space: TensorProductSpace, shape: TensorProductShape): bases = [basis_for_space(b, s) for b, s in zip(space.bases, shape.bases)] return TensorProductBasis( bases, - orth_weight=_get_orth_weight(bases), dims_per_basis=tuple([b.spatial_dim for b in space.bases])) @@ -1274,7 +1255,6 @@ def _monomial_basis_for_tp(space: TensorProductSpace, shape: TensorProductShape) return TensorProductBasis( bases, - orth_weight=None, dims_per_basis=tuple([b.spatial_dim for b in space.bases])) # }}}