From 8197f71ce1c3495ed26597b7ce86966e077bf073 Mon Sep 17 00:00:00 2001 From: unalmis Date: Tue, 20 Aug 2024 12:19:01 -0400 Subject: [PATCH] Force push with lease to avoid diverging branch with remote due to commit 0a5216c --- desc/compute/utils.py | 43 +- desc/grid.py | 2 +- desc/integrals/__init__.py | 3 + desc/{compute => integrals}/_interp_utils.py | 8 +- desc/{compute => integrals}/_quad_utils.py | 0 .../{compute => integrals}/bounce_integral.py | 122 ++-- .../fourier_bounce_integral.py | 526 ++++++++++-------- desc/utils.py | 73 ++- tests/test_bounce_integral.py | 52 +- tests/test_fourier_bounce.py | 114 +++- tests/test_interp_utils.py | 16 +- 11 files changed, 550 insertions(+), 409 deletions(-) create mode 100644 desc/integrals/__init__.py rename desc/{compute => integrals}/_interp_utils.py (98%) rename desc/{compute => integrals}/_quad_utils.py (100%) rename desc/{compute => integrals}/bounce_integral.py (92%) rename desc/{compute => integrals}/fourier_bounce_integral.py (75%) diff --git a/desc/compute/utils.py b/desc/compute/utils.py index 7d7a2562dd..92c41a000f 100644 --- a/desc/compute/utils.py +++ b/desc/compute/utils.py @@ -2,14 +2,13 @@ import copy import inspect -from functools import partial import numpy as np -from desc.backend import cond, execute_on_cpu, flatnonzero, fori_loop, jnp, put, take +from desc.backend import cond, execute_on_cpu, fori_loop, jnp, put from desc.grid import ConcentricGrid, Grid, LinearGrid -from ..utils import errorif, setdefault, warnif +from ..utils import errorif, warnif from .data_index import allowed_kwargs, data_index # map from profile name to equilibrium parameter name @@ -1580,41 +1579,3 @@ def body(i, mins): # The above implementation was benchmarked to be more efficient than # alternatives without explicit loops in GitHub pull request #501. return grid.expand(mins, surface_label) - - -@partial(jnp.vectorize, signature="(m),(m)->(n)", excluded={"size", "fill_value"}) -def take_mask(a, mask, size=None, fill_value=None): - """JIT compilable method to return ``a[mask][:size]`` padded by ``fill_value``. - - Parameters - ---------- - a : jnp.ndarray - The source array. - mask : jnp.ndarray - Boolean mask to index into ``a``. Should have same shape as ``a``. - size : int - Elements of ``a`` at the first size True indices of ``mask`` will be returned. - If there are fewer elements than size indicates, the returned array will be - padded with ``fill_value``. The size default is ``mask.size``. - fill_value : Any - When there are fewer than the indicated number of elements, the remaining - elements will be filled with ``fill_value``. Defaults to NaN for inexact types, - the largest negative value for signed types, the largest positive value for - unsigned types, and True for booleans. - - Returns - ------- - result : jnp.ndarray - Shape (size, ). - - """ - assert a.shape == mask.shape - idx = flatnonzero(mask, size=setdefault(size, mask.size), fill_value=mask.size) - return take( - a, - idx, - mode="fill", - fill_value=fill_value, - unique_indices=True, - indices_are_sorted=True, - ) diff --git a/desc/grid.py b/desc/grid.py index 359917c10b..b5afa3ab16 100644 --- a/desc/grid.py +++ b/desc/grid.py @@ -742,7 +742,7 @@ def create_meshgrid( rtz : rho, theta, zeta period : tuple of float Assumed periodicity for each coordinate. - Use np.inf to denote no periodicity. + Use ``np.inf`` to denote no periodicity. NFP : int Number of field periods (Default = 1). Only makes sense to change from 1 if last coordinate is periodic diff --git a/desc/integrals/__init__.py b/desc/integrals/__init__.py new file mode 100644 index 0000000000..419801a33e --- /dev/null +++ b/desc/integrals/__init__.py @@ -0,0 +1,3 @@ +"""Classes for integration.""" + +from .fourier_bounce_integral import FourierChebyshevBasis, PiecewiseChebyshevBasis diff --git a/desc/compute/_interp_utils.py b/desc/integrals/_interp_utils.py similarity index 98% rename from desc/compute/_interp_utils.py rename to desc/integrals/_interp_utils.py index 8284c1a02d..ea022891c1 100644 --- a/desc/compute/_interp_utils.py +++ b/desc/integrals/_interp_utils.py @@ -6,8 +6,8 @@ from orthax.polynomial import polyvander from desc.backend import dct, jnp, rfft, rfft2, take -from desc.compute._quad_utils import bijection_from_disc from desc.compute.utils import safediv +from desc.integrals._quad_utils import bijection_from_disc from desc.utils import Index, errorif @@ -314,8 +314,8 @@ def interp_dct(xq, f, lobatto=False, axis=-1): lobatto = bool(lobatto) errorif(lobatto, NotImplementedError) assert f.ndim >= 1 - a = cheb_from_dct( - dct(f, type=2 - lobatto, axis=axis) / (f.shape[axis] - lobatto), axis + a = cheb_from_dct(dct(f, type=2 - lobatto, axis=axis), axis) / ( + f.shape[axis] - lobatto ) fq = idct_non_uniform(xq, a, f.shape[axis], axis) return fq @@ -345,7 +345,7 @@ def idct_non_uniform(xq, a, n, axis=-1): assert a.ndim >= 1 a = jnp.moveaxis(a, axis, -1) basis = chebvander(xq, n - 1) - # Could instead use Clenshaw recursion with ``fq=chebval(xq,a,tensor=False)``. + # Could use Clenshaw recursion with fq = chebval(xq, a, tensor=False). fq = jnp.linalg.vecdot(basis, a) return fq diff --git a/desc/compute/_quad_utils.py b/desc/integrals/_quad_utils.py similarity index 100% rename from desc/compute/_quad_utils.py rename to desc/integrals/_quad_utils.py diff --git a/desc/compute/bounce_integral.py b/desc/integrals/bounce_integral.py similarity index 92% rename from desc/compute/bounce_integral.py rename to desc/integrals/bounce_integral.py index bff1b9cdf6..2b61dfece5 100644 --- a/desc/compute/bounce_integral.py +++ b/desc/integrals/bounce_integral.py @@ -2,45 +2,34 @@ from functools import partial -import numpy as np from interpax import CubicHermiteSpline, PPoly, interp1d from jax.nn import softmax from matplotlib import pyplot as plt -from numpy.polynomial.legendre import leggauss +from orthax.legendre import leggauss +from tests.test_interp_utils import filter_not_nan from desc.backend import flatnonzero, imap, jnp, put -from desc.compute._interp_utils import poly_root, polyder_vec, polyval_vec -from desc.compute._quad_utils import ( +from desc.integrals._interp_utils import poly_root, polyder_vec, polyval_vec +from desc.integrals._quad_utils import ( automorphism_sin, bijection_from_disc, grad_automorphism_sin, grad_bijection_from_disc, ) -from desc.compute.utils import take_mask -from desc.utils import errorif, setdefault, warnif +from desc.utils import errorif, setdefault, take_mask, warnif -# use for debugging and testing -def _filter_not_nan(a, check=False): - """Filter out nan from ``a`` while asserting nan is padded at right.""" - is_nan = np.isnan(a) - if check: - assert np.array_equal(is_nan, np.sort(is_nan, axis=-1)) - return a[~is_nan] - - -# use for debugging and testing -def _filter_nonzero_measure(bp1, bp2): +def filter_bounce_points(bp1, bp2): """Return only bounce points such that |bp2 - bp1| > 0.""" - mask = (bp2 - bp1) != 0 + mask = (bp2 - bp1) != 0.0 return bp1[mask], bp2[mask] def plot_field_line( B, pitch=None, - bp1=np.array([]), - bp2=np.array([]), + bp1=jnp.array([]), + bp2=jnp.array([]), start=None, stop=None, num=1000, @@ -57,11 +46,11 @@ def plot_field_line( ---------- B : PPoly Spline of |B| over given field line. - pitch : np.ndarray + pitch : jnp.ndarray λ value. - bp1 : np.ndarray + bp1 : jnp.ndarray Bounce points with (∂|B|/∂ζ)|ρ,α <= 0. - bp2 : np.ndarray + bp2 : jnp.ndarray Bounce points with (∂|B|/∂ζ)|ρ,α >= 0. start : float Minimum ζ on plot. @@ -90,9 +79,7 @@ def plot_field_line( legend = {} def add(lines): - if not hasattr(lines, "__iter__"): - lines = [lines] - for line in lines: + for line in setdefault(lines, [lines], hasattr(lines, "__iter__")): label = line.get_label() if label not in legend: legend[label] = line @@ -101,7 +88,7 @@ def add(lines): if include_knots: for knot in B.x: add(ax.axvline(x=knot, color="tab:blue", alpha=alpha_knot, label="knot")) - z = np.linspace( + z = jnp.linspace( start=setdefault(start, B.x[0]), stop=setdefault(stop, B.x[-1]), num=num, @@ -109,24 +96,24 @@ def add(lines): add(ax.plot(z, B(z), label=r"$\vert B \vert (\zeta)$")) if pitch is not None: - b = 1 / np.atleast_1d(pitch) + b = 1 / jnp.atleast_1d(pitch) for val in b: add( ax.axhline( val, color="tab:purple", alpha=alpha_pitch, label=r"$1 / \lambda$" ) ) - bp1, bp2 = np.atleast_2d(bp1, bp2) + bp1, bp2 = jnp.atleast_2d(bp1, bp2) for i in range(bp1.shape[0]): if bp1.shape == bp2.shape: - bp1_i, bp2_i = _filter_nonzero_measure(bp1[i], bp2[i]) + bp1_i, bp2_i = filter_bounce_points(bp1[i], bp2[i]) else: bp1_i, bp2_i = bp1[i], bp2[i] - bp1_i, bp2_i = map(_filter_not_nan, (bp1_i, bp2_i)) + bp1_i, bp2_i = bp1_i[~jnp.isnan(bp1_i)], bp2_i[~jnp.isnan(bp2_i)] add( ax.scatter( bp1_i, - np.full_like(bp1_i, b[i]), + jnp.full_like(bp1_i, b[i]), marker="v", color="tab:red", label="bp1", @@ -135,7 +122,7 @@ def add(lines): add( ax.scatter( bp2_i, - np.full_like(bp2_i, b[i]), + jnp.full_like(bp2_i, b[i]), marker="^", color="tab:green", label="bp2", @@ -155,44 +142,55 @@ def add(lines): return fig, ax -def _check_bounce_points(bp1, bp2, sentinel, pitch, knots, B_c, plot, **kwargs): +def _check_bounce_points(bp1, bp2, pitch, knots, B_c, plot, **kwargs): """Check that bounce points are computed correctly.""" - bp1 = jnp.where(bp1 > sentinel, bp1, jnp.nan) - bp2 = jnp.where(bp2 > sentinel, bp2, jnp.nan) + assert bp1.shape == bp2.shape + mask = (bp1 - bp2) == 0 + bp1 = jnp.where(mask, jnp.nan, bp1) + bp2 = jnp.where(mask, jnp.nan, bp2) eps = jnp.finfo(jnp.array(1.0).dtype).eps * 10 - P, S = bp1.shape[:-1] - msg_1 = "Bounce points have an inversion." + msg_1 = "Bounce points have an inversion.\n" err_1 = jnp.any(bp1 > bp2, axis=-1) - msg_2 = "Discontinuity detected." + msg_2 = "Discontinuity detected.\n" err_2 = jnp.any(bp1[..., 1:] < bp2[..., :-1], axis=-1) + P, S, _ = bp1.shape for s in range(S): B = PPoly(B_c[:, s], knots) for p in range(P): - B_mid = B((bp1[p, s] + bp2[p, s]) / 2) - err_3 = jnp.any(B_mid > 1 / pitch[p, s] + eps) + B_m_ps = B((bp1[p, s] + bp2[p, s]) / 2) + err_3 = jnp.any(B_m_ps > 1 / pitch[p, s] + eps) if err_1[p, s] or err_2[p, s] or err_3: - bp1_p = _filter_not_nan(bp1[p, s], check=True) - bp2_p = _filter_not_nan(bp2[p, s], check=True) - B_mid = _filter_not_nan(B_mid, check=True) + bp1_ps, bp2_ps, B_m_ps = map( + filter_not_nan, (bp1[p, s], bp2[p, s], B_m_ps) + ) if plot: plot_field_line( - B, pitch[p, s], bp1_p, bp2_p, title_id=f"{p},{s}", **kwargs + B, + pitch[p, s], + bp1_ps, + bp2_ps, + title_id=f"{p},{s}", + **kwargs, ) - print("bp1:", bp1_p) - print("bp2:", bp2_p) + print("bp1:", bp1_ps) + print("bp2:", bp2_ps) assert not err_1[p, s], msg_1 assert not err_2[p, s], msg_2 msg_3 = ( - f"Detected B midpoint = {B_mid}>{1 / pitch[p, s] + eps} = 1/pitch. " - "You need to use more knots or, if that is infeasible, switch to a " - "monotonic spline method.\n" + f"Detected |B| = {B_m_ps} > {1 / pitch[p, s] + eps} = 1/λ in well. " + "Use more knots or switch to a monotonic spline method.\n" ) assert not err_3, msg_3 if plot: plot_field_line( - B, pitch[:, s], bp1[:, s], bp2[:, s], title_id=str(s), **kwargs + B, + pitch[:, s], + bp1[:, s], + bp2[:, s], + title_id=str(s), + **kwargs, ) @@ -334,7 +332,7 @@ def bounce_points( a_min=jnp.array([0.0]), a_max=jnp.diff(knots), sort=True, - sentinel=-1, + sentinel=-1.0, distinct=True, ) assert intersect.shape == (P, S, N, degree) @@ -356,13 +354,14 @@ def bounce_points( bp1 = take_mask(intersect, is_bp1, size=num_well, fill_value=sentinel) bp2 = take_mask(intersect, is_bp2, size=num_well, fill_value=sentinel) - if check: - _check_bounce_points(bp1, bp2, sentinel, pitch, knots, B_c, plot, **kwargs) - mask = (bp1 > sentinel) & (bp2 > sentinel) # Set outside mask to same value so integration is over set of measure zero. - bp1 = jnp.where(mask, bp1, 0) - bp2 = jnp.where(mask, bp2, 0) + bp1 = jnp.where(mask, bp1, 0.0) + bp2 = jnp.where(mask, bp2, 0.0) + + if check: + _check_bounce_points(bp1, bp2, pitch, knots, B_c, plot, **kwargs) + return bp1, bp2 @@ -626,12 +625,7 @@ def _bounce_quadrature( Parameters ---------- - bp1 : jnp.ndarray - Shape (P, S, num_well). - The field line-following ζ coordinates of bounce points for a given pitch along - a field line. The pairs ``bp1[i,j,k]`` and ``bp2[i,j,k]`` form left and right - integration boundaries, respectively, for the bounce integrals. - bp2 : jnp.ndarray + bp1, bp2 : jnp.ndarray Shape (P, S, num_well). The field line-following ζ coordinates of bounce points for a given pitch along a field line. The pairs ``bp1[i,j,k]`` and ``bp2[i,j,k]`` form left and right @@ -876,7 +870,7 @@ def bounce_integral( if automorphism is not None: auto, grad_auto = automorphism w = w * grad_auto(x) - # Recall affine_bijection(auto(x), ζ_b₁, ζ_b₂) = ζ. + # Recall bijection_from_disc(auto(x), ζ_b₁, ζ_b₂) = ζ. x = auto(x) def bounce_integrate( diff --git a/desc/compute/fourier_bounce_integral.py b/desc/integrals/fourier_bounce_integral.py similarity index 75% rename from desc/compute/fourier_bounce_integral.py rename to desc/integrals/fourier_bounce_integral.py index 0459996601..cf03b7596b 100644 --- a/desc/compute/fourier_bounce_integral.py +++ b/desc/integrals/fourier_bounce_integral.py @@ -6,7 +6,7 @@ from orthax.legendre import leggauss from desc.backend import dct, idct, irfft, jnp, rfft, rfft2 -from desc.compute._interp_utils import ( +from desc.integrals._interp_utils import ( _filter_distinct, cheb_from_dct, cheb_pts, @@ -17,17 +17,23 @@ irfft2_non_uniform, irfft_non_uniform, ) -from desc.compute._quad_utils import ( +from desc.integrals._quad_utils import ( automorphism_sin, bijection_from_disc, bijection_to_disc, grad_automorphism_sin, ) -from desc.compute.bounce_integral import _filter_nonzero_measure, _fix_inversion -from desc.compute.utils import take_mask -from desc.utils import errorif, warnif +from desc.integrals.bounce_integral import _fix_inversion, filter_bounce_points +from desc.utils import ( + atleast_2d_end, + atleast_3d_mid, + atleast_nd, + errorif, + setdefault, + take_mask, + warnif, +) -# TODO: There are better techniques to find eigenvalues of Chebyshev colleague matrix. _chebroots_vec = jnp.vectorize(chebroots, signature="(m)->(n)") @@ -36,7 +42,7 @@ def _flatten_matrix(y): return y.reshape(*y.shape[:-2], -1) -def alpha_sequence(alpha_0, iota, num_period, period=2 * jnp.pi): +def alpha_sequence(alpha_0, iota, num_transit, period=2 * jnp.pi): """Get sequence of poloidal coordinates A = (α₀, α₁, …, αₘ₋₁) of field line. Parameters @@ -46,36 +52,23 @@ def alpha_sequence(alpha_0, iota, num_period, period=2 * jnp.pi): iota : jnp.ndarray Shape (iota.size, ). Rotational transform normalized by 2π. - num_period : float - Number of periods to follow field line. + num_transit : float + Number of ``period``s to follow field line. period : float Toroidal period after which to update label. Returns ------- alphas : jnp.ndarray - Shape (iota.size, num_period). + Shape (iota.size, num_transit). Sequence of poloidal coordinates A = (α₀, α₁, …, αₘ₋₁) that specify field line. """ # Δϕ (∂α/∂ϕ) = Δϕ ι̅ = Δϕ ι/2π = Δϕ data["iota"] - alphas = alpha_0 + period * iota[:, jnp.newaxis] * jnp.arange(num_period) + alphas = alpha_0 + period * iota[:, jnp.newaxis] * jnp.arange(num_transit) return alphas -def _subtract(c, k): - # subtract k from last axis of c, obeying numpy broadcasting - c_0 = c[..., 0] - k - c = jnp.concatenate( - [ - jnp.broadcast_to(c[..., 1:], (*c_0.shape, c.shape[-1] - 1)), - c_0[..., jnp.newaxis], - ], - axis=-1, - ) - return c - - class FourierChebyshevBasis: """Fourier-Chebyshev series. @@ -113,19 +106,14 @@ def __init__(self, f, lobatto=False, domain=(0, 2 * jnp.pi)): Domain for y coordinates. Default is [0, 2π]. """ - errorif(domain[0] > domain[-1], msg="Got inverted y coordinate domain.") + lobatto = bool(lobatto) errorif(lobatto, NotImplementedError, "JAX has not implemented type 1 DCT.") + self.lobatto = lobatto + errorif(domain[0] > domain[-1], msg="Got inverted domain.") + self.domain = domain self.M = f.shape[-2] self.N = f.shape[-1] - self.lobatto = bool(lobatto) - self.domain = domain - self._c = ( - rfft( - dct(f, type=2 - self.lobatto, axis=-1) / (self.N - self.lobatto), - axis=-2, - ) - / self.M - ) + self._c = self._fast_transform(f, lobatto) @staticmethod def nodes(M, N, lobatto=False, domain=(0, 2 * jnp.pi), **kwargs): @@ -145,19 +133,23 @@ def nodes(M, N, lobatto=False, domain=(0, 2 * jnp.pi), **kwargs): Returns ------- - coords : jnp.ndarray + coord : jnp.ndarray Shape (M * N, 2). Grid of (x, y) points for optimal interpolation. """ x = fourier_pts(M) y = cheb_pts(N, lobatto, domain) - coords = ( - [jnp.atleast_1d(kwargs.pop("rho")), x, y] if "rho" in kwargs else [x, y] - ) - coords = list(map(jnp.ravel, jnp.meshgrid(*coords, indexing="ij"))) - coords = jnp.column_stack(coords) - return coords + coord = [jnp.atleast_1d(kwargs.pop("rho")), x, y] if "rho" in kwargs else [x, y] + coord = list(map(jnp.ravel, jnp.meshgrid(*coord, indexing="ij"))) + coord = jnp.column_stack(coord) + return coord + + @staticmethod + def _fast_transform(f, lobatto): + M = f.shape[-2] + N = f.shape[-1] + return rfft(dct(f, type=2 - lobatto, axis=-1), axis=-2) / (M * (N - lobatto)) def evaluate(self, M, N): """Evaluate Fourier-Chebyshev series. @@ -176,12 +168,9 @@ def evaluate(self, M, N): Fourier-Chebyshev series evaluated at ``FourierChebyshevBasis.nodes(M, N)``. """ - fq = idct( - irfft(self._c, n=M, axis=-2) * M, - type=2 - self.lobatto, - n=N, - axis=-1, - ) * (N - self.lobatto) + fq = idct(irfft(self._c, n=M, axis=-2), type=2 - self.lobatto, n=N, axis=-1) * ( + M * (N - self.lobatto) + ) return fq def harmonics(self): @@ -213,7 +202,7 @@ def compute_cheb(self, x): Returns ------- - cheb : _PiecewiseChebyshevBasis + cheb : PiecewiseChebyshevBasis Chebyshev coefficients αₙ(x=``x``) for f(x, y) = ∑ₙ₌₀ᴺ⁻¹ αₙ(x) Tₙ(y). """ @@ -221,10 +210,23 @@ def compute_cheb(self, x): x = jnp.atleast_1d(x)[..., jnp.newaxis] cheb = cheb_from_dct(irfft_non_uniform(x, self._c, self.M, axis=-2), axis=-1) assert cheb.shape[-2:] == (x.shape[-2], self.N) - return _PiecewiseChebyshevBasis(cheb, self.domain) + return PiecewiseChebyshevBasis(cheb, self.domain) -class _PiecewiseChebyshevBasis: +def _subtract(c, k): + # subtract k from last axis of c, obeying numpy broadcasting + c_0 = c[..., 0] - k + c = jnp.concatenate( + [ + c_0[..., jnp.newaxis], + jnp.broadcast_to(c[..., 1:], (*c_0.shape, c.shape[-1] - 1)), + ], + axis=-1, + ) + return c + + +class PiecewiseChebyshevBasis: """Chebyshev series. { fₓ | fₓ : y ↦ ∑ₙ₌₀ᴺ⁻¹ aₙ(x) Tₙ(y) } @@ -233,8 +235,10 @@ class _PiecewiseChebyshevBasis: Attributes ---------- cheb : jnp.ndarray - Shape (..., N). + Shape (..., M, N). Chebyshev coefficients αₙ(x) for fₓ(y) = ∑ₙ₌₀ᴺ⁻¹ αₙ(x) Tₙ(y). + M : int + Number of function in this basis set. N : int Chebyshev spectral resolution. domain : (float, float) @@ -250,25 +254,38 @@ def __init__(self, cheb, domain): Parameters ---------- cheb : jnp.ndarray - Shape (..., N). + Shape (..., M, N). Chebyshev coefficients αₙ(x=``x``) for f(x, y) = ∑ₙ₌₀ᴺ⁻¹ αₙ(x) Tₙ(y). """ - self.cheb = cheb - self.N = cheb.shape[-1] + errorif(domain[0] > domain[-1], msg="Got inverted domain.") self.domain = domain + self.cheb = jnp.atleast_2d(cheb) + + @property + def M(self): + """Number of function in this basis set.""" + return self.cheb.shape[-2] - def _chebcast(self, arr): + @property + def N(self): + """Chebyshev spectral resolution.""" + return self.cheb.shape[-1] + + @staticmethod + def _chebcast(cheb, arr): # Input should not have rightmost dimension of cheb that iterates coefficients, - # but may have additional leftmost dimensions for batch operations. + # but may have additional leftmost dimension for batch operation. errorif( - arr.ndim > self.cheb.ndim, + arr.ndim > cheb.ndim, NotImplementedError, - msg=f"Got ndim {arr.ndim} > cheb.ndim {self.cheb.ndim}.", + msg=f"Only one additional axis for batch dimension is allowed. " + f"Got {arr.ndim - cheb.ndim + 1} additional axes.", ) - return self.cheb if arr.ndim < self.cheb.ndim else self.cheb[jnp.newaxis] + # Don't add additional axis unless necessary to appease JIT compilation. + return cheb if arr.ndim < cheb.ndim else cheb[jnp.newaxis] - def intersect(self, k=0, eps=_eps): + def intersect(self, k, eps=_eps): """Coordinates yᵢ such that f(x, yᵢ) = k(x). Parameters @@ -295,15 +312,17 @@ def intersect(self, k=0, eps=_eps): Boolean array into ``y`` indicating whether element is an intersect. """ - c = _subtract(self._chebcast(k), k) + k = jnp.atleast_1d(k) + c = _subtract(self._chebcast(self.cheb, k), k) # roots yᵢ of f(x, y) = ∑ₙ₌₀ᴺ⁻¹ αₙ(x) Tₙ(y) - k(x) y = _chebroots_vec(c) assert y.shape == (*c.shape[:-1], self.N - 1) - y = _filter_distinct(y, sentinel=-2, eps=eps) - # Pick sentinel above such that only distinct roots are considered intersects. - is_intersect = (jnp.abs(y.imag) <= eps) & (jnp.abs(y.real) <= 1) - y = jnp.where(is_intersect, y.real, 0) # ensure y is in domain of arcos + # Intersects must satisfy y ∈ [-1, 1]. + # Pick sentinel such that only distinct roots are considered intersects. + y = _filter_distinct(y, sentinel=-2.0, eps=eps) + is_intersect = (jnp.abs(y.imag) <= eps) & (jnp.abs(y.real) <= 1.0) + y = jnp.where(is_intersect, y.real, 1.0) # ensure y is in domain of arcos # TODO: Multipoint evaluation with FFT. # Chapter 10, https://doi.org/10.1017/CBO9781139856065. @@ -317,7 +336,7 @@ def intersect(self, k=0, eps=_eps): is_decreasing = s <= 0 is_increasing = s >= 0 - y = bijection_from_disc(y, self.domain[0], self.domain[-1]) + y = bijection_from_disc(y, *self.domain) return y, is_decreasing, is_increasing, is_intersect def bounce_points( @@ -357,13 +376,15 @@ def bounce_points( ------- bp1, bp2 : (jnp.ndarray, jnp.ndarray) Shape (*y.shape[:-2], num_well). - The field line-following coordinates of bounce points for a given pitch - along a field line. The pairs ``bp1`` and ``bp2`` form left and right - integration boundaries, respectively, for the bounce integrals. + The field line-following coordinates of bounce points. + The pairs ``bp1`` and ``bp2`` form left and right integration boundaries, + respectively, for the bounce integrals. """ + errorif(self.N < 2, NotImplementedError, f"Got self.N = {self.N} < 2.") + # Flatten so that last axis enumerates intersects of a pitch along a field line. - y = _flatten_matrix(self._isomorphism_1d(y)) + y = _flatten_matrix(self._isomorphism_to_C1(y)) is_decreasing = _flatten_matrix(is_decreasing) is_increasing = _flatten_matrix(is_increasing) is_intersect = _flatten_matrix(is_intersect) @@ -375,44 +396,202 @@ def bounce_points( is_bp1 = is_decreasing & is_intersect is_bp2 = is_increasing & _fix_inversion(is_intersect, is_increasing) - sentinel = self.domain[0] - 1 + sentinel = self.domain[0] - 1.0 bp1 = take_mask(y, is_bp1, size=num_well, fill_value=sentinel) bp2 = take_mask(y, is_bp2, size=num_well, fill_value=sentinel) mask = (bp1 > sentinel) & (bp2 > sentinel) # Set outside mask to same value so integration is over set of measure zero. - bp1 = jnp.where(mask, bp1, 0) - bp2 = jnp.where(mask, bp2, 0) + bp1 = jnp.where(mask, bp1, 0.0) + bp2 = jnp.where(mask, bp2, 0.0) return bp1, bp2 + def eval1d(self, z, cheb=None): + """Evaluate piecewise Chebyshev spline at coordinates z. + + The coordinates z ∈ ℝ are assumed isomorphic to (x, y) ∈ ℝ² + where z integer division domain yields index into the proper + Chebyshev series of the spline and z mod domain is the coordinate + value along the domain of that Chebyshev series. + + Parameters + ---------- + z : jnp.ndarray + Shape (..., *cheb.shape[:-2], z.shape[-1]). + Isomorphic coordinates along field line [0, ∞). + cheb : jnp.ndarray + Shape (..., M, N). + Chebyshev coefficients to use. If not given, uses ``self.cheb``. + + Returns + ------- + f : jnp.ndarray + Shape z.shape. + Chebyshev basis evaluated at z. + + """ + cheb = self._chebcast(setdefault(cheb, self.cheb), z) + N = cheb.shape[-1] + x_idx, y = self._isomorphism_to_C2(z) + y = bijection_to_disc(y, self.domain[0], self.domain[1]) + # Chebyshev coefficients αₙ for f(z) = ∑ₙ₌₀ᴺ⁻¹ αₙ(x[z]) Tₙ(y[z]) + # are held in cheb with shape (..., num cheb series, N). + cheb = jnp.take_along_axis(cheb, x_idx[..., jnp.newaxis], axis=-2) + f = idct_non_uniform(y, cheb, N) + assert f.shape == z.shape + return f + + def _isomorphism_to_C1(self, y): + """Return coordinates z ∈ ℂ isomorphic to (x, y) ∈ ℂ². + + Maps row x of y to z = y + f(x) where f(x) = x * |domain|. + + Parameters + ---------- + y : jnp.ndarray + Shape (..., y.shape[-2], y.shape[-1]). + Second to last axis iterates the rows. + + Returns + ------- + z : jnp.ndarray + Shape y.shape. + Isomorphic coordinates. + + """ + assert y.ndim >= 2 + z_shift = jnp.arange(y.shape[-2]) * (self.domain[-1] - self.domain[0]) + return y + z_shift[:, jnp.newaxis] + + def _isomorphism_to_C2(self, z): + """Return coordinates (x, y) ∈ ℂ² isomorphic to z ∈ ℂ. + + Returns index x and value y such that z = f(x) + y where f(x) = x * |domain|. + + Parameters + ---------- + z : jnp.ndarray + Shape z.shape. + + Returns + ------- + x_idx, y_val : (jnp.ndarray, jnp.ndarray) + Shape z.shape. + Isomorphic coordinates. + + """ + x_idx, y_val = jnp.divmod(z - self.domain[0], self.domain[-1] - self.domain[0]) + return x_idx.astype(int), y_val + self.domain[0] + + def _check_shape(self, bp1, bp2, pitch): + """Return shapes that broadcast with (P, *self.cheb.shape[:-2], W).""" + # Ensure pitch batch dim exists and add back dim to broadcast with wells. + pitch = atleast_nd(self.cheb.ndim - 1, pitch)[..., jnp.newaxis] + # Same but back dim already exists. + bp1, bp2 = atleast_nd(self.cheb.ndim, bp1, bp2) + # Cheb has shape (..., M, N) and others + # have shape (P, ..., W) + errorif(not (bp1.ndim == bp2.ndim == pitch.ndim == self.cheb.ndim)) + return bp1, bp2, pitch + + def check_bounce_points(self, bp1, bp2, pitch, plot=True, **kwargs): + """Check that bounce points are computed correctly. + + Parameters + ---------- + bp1, bp2 : jnp.ndarray + Shape must broadcast with (P, *self.cheb.shape[:-2], W). + The field line-following coordinates of bounce points. + The pairs ``bp1`` and ``bp2`` form left and right integration boundaries, + respectively, for the bounce integrals. + pitch : jnp.ndarray + Shape must broadcast with (P, *self.cheb.shape[:-2]). + λ values to evaluate the bounce integral. + plot : bool + Whether to plot stuff. Default is true. + kwargs : dict + Keyword arguments into ``plot_field_line``. + + """ + assert bp1.shape == bp2.shape + mask = (bp1 - bp2) != 0.0 + bp1 = jnp.where(mask, bp1, jnp.nan) + bp2 = jnp.where(mask, bp2, jnp.nan) + bp1, bp2, pitch = self._check_shape(bp1, bp2, pitch) + + err_1 = jnp.any(bp1 > bp2, axis=-1) + err_2 = jnp.any(bp1[..., 1:] < bp2[..., :-1], axis=-1) + B_m = self.eval1d((bp1 + bp2) / 2) + assert B_m.shape == bp1.shape + err_3 = jnp.any(B_m > 1 / pitch + self._eps, axis=-1) + if not (plot or jnp.any(err_1 | err_2 | err_3)): + return + + # Ensure l axis exists for iteration in below loop. + cheb = atleast_nd(3, self.cheb) + mask, bp1, bp2, B_m = atleast_3d_mid(mask, bp1, bp2, B_m) + err_1, err_2, err_3 = atleast_2d_end(err_1, err_2, err_3) + + print(np.sum(mask)) + + for l in np.ndindex(cheb.shape[:-2]): + for p in range(pitch.shape[0]): + if not (err_1[p, l] or err_2[p, l] or err_3[p, l]): + continue + _bp1 = bp1[p, l][mask[p, l]] + _bp2 = bp2[p, l][mask[p, l]] + if plot: + self.plot_field_line( + cheb[l], + pitch=pitch[p, l], + bp1=_bp1, + bp2=_bp2, + title_id=f"{p},{l}", + **kwargs, + ) + print(" bp1 | bp2") + print(jnp.column_stack([_bp1, _bp2])) + assert not err_1[p, l], "Bounce points have an inversion.\n" + assert not err_2[p, l], "Detected discontinuity.\n" + assert not err_3[p, l], ( + "Detected |B| > 1/λ in well. Increase Chebyshev resolution.\n" + f"{B_m[p, l][mask[p, l]]} > {1 / pitch[p, l] + self._eps}" + ) + if plot: + self.plot_field_line( + cheb[l], + pitch=pitch[:, l], + bp1=bp1[:, l], + bp2=bp2[:, l], + title_id=str(l), + **kwargs, + ) + def plot_field_line( self, - start, - stop, + cheb, + bp1=jnp.array([[]]), + bp2=jnp.array([[]]), + pitch=jnp.array([]), num=1000, - bp1=np.array([]), - bp2=np.array([]), - pitch=np.array([]), title=r"Computed bounce points for $\vert B \vert$ and pitch $\lambda$", title_id=None, - transparency_pitch=0.3, + transparency_pitch=0.5, show=True, ): """Plot the field line given spline of |B|. Parameters ---------- - start : float - Minimum ζ on plot. - stop : float - Maximum ζ on plot. + cheb : jnp.ndarray + Piecewise Chebyshev coefficients of |B| along the field line. num : int Number of ζ points to plot. Pick a big number. - bp1 : np.ndarray + bp1 : jnp.ndarray Bounce points with (∂|B|/∂ζ)|ρ,α <= 0. - bp2 : np.ndarray + bp2 : jnp.ndarray Bounce points with (∂|B|/∂ζ)|ρ,α >= 0. - pitch : np.ndarray + pitch : jnp.ndarray λ value. title : str Plot title. @@ -428,23 +607,24 @@ def plot_field_line( fig, ax : matplotlib figure and axes. """ - errorif(start is None or stop is None) legend = {} def add(lines): - if not hasattr(lines, "__iter__"): - lines = [lines] - for line in lines: + for line in setdefault(lines, [lines], hasattr(lines, "__iter__")): label = line.get_label() if label not in legend: legend[label] = line fig, ax = plt.subplots() - z = np.linspace(start=start, stop=stop, num=num) - add(ax.plot(z, self.eval1d(z), label=r"$\vert B \vert (\zeta)$")) + z = jnp.linspace( + start=self.domain[0], + stop=self.domain[0] + (self.domain[1] - self.domain[0]) * self.M, + num=num, + ) + add(ax.plot(z, self.eval1d(z, cheb), label=r"$\vert B \vert (\zeta)$")) if pitch is not None: - b = 1 / np.atleast_1d(pitch) + b = 1 / jnp.atleast_1d(pitch) for val in b: add( ax.axhline( @@ -454,13 +634,16 @@ def add(lines): label=r"$1 / \lambda$", ) ) - bp1, bp2 = np.atleast_2d(bp1, bp2) + bp1, bp2 = jnp.atleast_2d(bp1, bp2) for i in range(bp1.shape[0]): - bp1_i, bp2_i = _filter_nonzero_measure(bp1[i], bp2[i]) + if bp1.shape == bp2.shape: + _bp1, _bp2 = filter_bounce_points(bp1[i], bp2[i]) + else: + _bp1, _bp2 = bp1[i], bp2[i] add( ax.scatter( - bp1_i, - np.full_like(bp1_i, b[i]), + _bp1, + jnp.full_like(_bp1, b[i]), marker="v", color="tab:red", label="bp1", @@ -468,8 +651,8 @@ def add(lines): ) add( ax.scatter( - bp2_i, - np.full_like(bp2_i, b[i]), + _bp2, + jnp.full_like(_bp2, b[i]), marker="^", color="tab:green", label="bp2", @@ -480,7 +663,7 @@ def add(lines): ax.set_ylabel(r"$\vert B \vert \sim 1 / \lambda$") ax.legend(legend.values(), legend.keys(), loc="lower right") if title_id is not None: - title = f"{title}. id = {title_id}." + title = f"{title}. ID={title_id}." ax.set_title(title) plt.tight_layout() if show: @@ -488,131 +671,6 @@ def add(lines): plt.close() return fig, ax - def check_bounce_points( - self, bp1, bp2, pitch, plot=True, start=None, stop=None, **kwargs - ): - """Check that bounce points are computed correctly.""" - pitch = jnp.atleast_3d(pitch) - errorif(not (pitch.ndim == bp1.ndim == bp2.ndim == 3), NotImplementedError) - errorif(bp1.shape != bp2.shape) - - P, L, num_wells = bp1.shape - msg_1 = "Bounce points have an inversion." - err_1 = jnp.any(bp1 > bp2, axis=-1) - msg_2 = "Discontinuity detected." - err_2 = jnp.any(bp1[..., 1:] < bp2[..., :-1], axis=-1) - - for l in range(L): - for p in range(P): - B_mid = self.eval1d((bp1[p, l] + bp2[p, l]) / 2) - err_3 = jnp.any(B_mid > 1 / pitch[p, l] + self._eps) - if err_1[p, l] or err_2[p, l] or err_3: - bp1_p, bp2_p = _filter_nonzero_measure(bp1[p, l], bp2[p, l]) - B_mid = B_mid[(bp1[p, l] - bp2[p, l]) != 0] - if plot: - self.plot_field_line( - start=start, - stop=stop, - pitch=pitch[p, l], - bp1=bp1_p, - bp2=bp2_p, - title_id=f"{p},{l}", - **kwargs, - ) - print("bp1:", bp1_p) - print("bp2:", bp2_p) - assert not err_1[p, l], msg_1 - assert not err_2[p, l], msg_2 - msg_3 = ( - f"Detected B midpoint = {B_mid}>{1 / pitch[p, l] + self._eps} =" - " 1/pitch. You need to use more knots." - ) - assert not err_3, msg_3 - if plot: - self.plot_field_line( - start=start, - stop=stop, - pitch=pitch[:, l], - bp1=bp1[:, l], - bp2=bp2[:, l], - title_id=str(l), - **kwargs, - ) - - def eval1d(self, z): - """Evaluate piecewise Chebyshev spline at coordinates z. - - The coordinates z ∈ ℝ are assumed isomorphic to (x, y) ∈ ℝ² - where z integer division domain yields index into the proper - Chebyshev series of the spline and z mod domain is the coordinate - value along the domain of that Chebyshev series. - - Parameters - ---------- - z : jnp.ndarray - Shape (..., *cheb.shape[:-2], z.shape[-1]). - Isomorphic coordinates along field line [0, ∞). - - Returns - ------- - f : jnp.ndarray - Shape z.shape. - Chebyshev basis evaluated at z. - - """ - x_idx, y = self._isomorphism_2d(z) - y = bijection_to_disc(y, self.domain[0], self.domain[1]) - # Chebyshev coefficients αₙ for f(z) = ∑ₙ₌₀ᴺ⁻¹ αₙ(x[z]) Tₙ(y[z]) - # are held in self.cheb with shape (..., num cheb series, N). - cheb = jnp.take_along_axis(self._chebcast(z), x_idx[..., jnp.newaxis], axis=-2) - f = idct_non_uniform(y, cheb, self.N) - assert f.shape == z.shape - return f - - def _isomorphism_1d(self, y): - """Return coordinates z ∈ ℂ isomorphic to (x, y) ∈ ℂ². - - Maps row x of y to z = α(x) + y where α(x) = x * |domain|. - - Parameters - ---------- - y : jnp.ndarray - Shape (..., y.shape[-2], y.shape[-1]). - Second to last axis iterates the rows. - - Returns - ------- - z : jnp.ndarray - Shape y.shape. - Isomorphic coordinates. - - """ - assert y.ndim >= 2 - period = self.domain[-1] - self.domain[0] - zeta_shift = period * jnp.arange(y.shape[-2]) - z = zeta_shift[:, jnp.newaxis] + y - return z - - def _isomorphism_2d(self, z): - """Return coordinates (x, y) ∈ ℂ² isomorphic to z ∈ ℂ. - - Returns index x and value y such that z = α(x) + y where α(x) = x * |domain|. - - Parameters - ---------- - z : jnp.ndarray - Shape z.shape. - - Returns - ------- - x_index, y_value : (jnp.ndarray, jnp.ndarray) - Shape z.shape. - Isomorphic coordinates. - - """ - x_index, y_value = jnp.divmod(z, self.domain[-1] - self.domain[0]) - return x_index.astype(int), y_value - def _bounce_quadrature(bp1, bp2, x, w, m, n, integrand, f, b_sup_z, B, T, pitch): """Bounce integrate ∫ f(ℓ) dℓ. @@ -655,12 +713,12 @@ def _bounce_quadrature(bp1, bp2, x, w, m, n, integrand, f, b_sup_z, B, T, pitch) b_sup_z : jnp.ndarray Shape (L, 1, m, n). Set of 2D Fourier spectral coefficients of B^ζ/|B|. - B : _PiecewiseChebyshevBasis + B : PiecewiseChebyshevBasis Set of 1D Chebyshev spectral coefficients of |B| along field line. - {|B|_α : ζ |B|(α, ζ) | α ∈ A } . - T : _PiecewiseChebyshevBasis + {|B|_α : ζ ↦ |B|(α, ζ) | α ∈ A }. + T : PiecewiseChebyshevBasis Set of 1D Chebyshev spectral coefficients of θ along field line. - {θ_α : ζ θ(α, ζ) | α ∈ A }. + {θ_α : ζ ↦ θ(α, ζ) | α ∈ A }. pitch : jnp.ndarray Shape (P, L, 1). λ values to evaluate the bounce integral at each field line. @@ -702,7 +760,7 @@ def _bounce_quadrature(bp1, bp2, x, w, m, n, integrand, f, b_sup_z, B, T, pitch) def required_names(): """Return names in ``data_index`` required to compute bounce integrals.""" - return ["B^zeta", "|B|"] + return ["B^zeta", "|B|", "iota"] # TODO: Assumes zeta = phi (alpha sequence) @@ -784,10 +842,10 @@ def bounce_integral( Poloidal coordinates A = (α₀, α₁, …, αₘ₋₁) that specify field line. B : _PiecewiseChebyshevBasis Set of 1D Chebyshev spectral coefficients of |B| along field line. - {|B|_α : ζ |B|(α, ζ) | α ∈ A } . + {|B|_α : ζ ↦ |B|(α, ζ) | α ∈ A }. T : _PiecewiseChebyshevBasis Set of 1D Chebyshev spectral coefficients of θ along field line. - {θ_α : ζ θ(α, ζ) | α ∈ A }. + {θ_α : ζ ↦ θ(α, ζ) | α ∈ A }. """ # Resolution of periodic DESC coordinate tensor-product grid. @@ -824,7 +882,7 @@ def bounce_integral( if automorphism is not None: auto, grad_auto = automorphism w = w * grad_auto(x) - # Recall affine_bijection(auto(x), ζ_b₁, ζ_b₂) = ζ. + # Recall bijection_from_disc(auto(x), ζ_b₁, ζ_b₂) = ζ. x = auto(x) def bounce_integrate(integrand, f, pitch, weight=None, num_well=None): @@ -885,6 +943,8 @@ def bounce_integrate(integrand, f, pitch, weight=None, num_well=None): or B.cheb.shape[0] == 1 ) bp1, bp2 = B.bounce_points(*B.intersect(1 / pitch), num_well) + if check: + B.check_bounce_points(bp1, bp2, pitch, plot=True) P = pitch.shape[0] num_well = bp1.shape[-1] assert bp1.shape == bp2.shape == (P, L, num_well) diff --git a/desc/utils.py b/desc/utils.py index 1547fc9e34..eb8e459fd7 100644 --- a/desc/utils.py +++ b/desc/utils.py @@ -2,13 +2,14 @@ import operator import warnings +from functools import partial from itertools import combinations_with_replacement, permutations import numpy as np from scipy.special import factorial from termcolor import colored -from desc.backend import fori_loop, jit, jnp +from desc.backend import flatnonzero, fori_loop, jit, jnp, take class Timer: @@ -689,3 +690,73 @@ def broadcast_tree(tree_in, tree_out, dtype=int): # invalid tree structure else: raise ValueError("trees must be nested lists of dicts") + + +@partial(jnp.vectorize, signature="(m),(m)->(n)", excluded={"size", "fill_value"}) +def take_mask(a, mask, size=None, fill_value=None): + """JIT compilable method to return ``a[mask][:size]`` padded by ``fill_value``. + + Parameters + ---------- + a : jnp.ndarray + The source array. + mask : jnp.ndarray + Boolean mask to index into ``a``. Should have same shape as ``a``. + size : int + Elements of ``a`` at the first size True indices of ``mask`` will be returned. + If there are fewer elements than size indicates, the returned array will be + padded with ``fill_value``. The size default is ``mask.size``. + fill_value : Any + When there are fewer than the indicated number of elements, the remaining + elements will be filled with ``fill_value``. Defaults to NaN for inexact types, + the largest negative value for signed types, the largest positive value for + unsigned types, and True for booleans. + + Returns + ------- + result : jnp.ndarray + Shape (size, ). + + """ + assert a.shape == mask.shape + idx = flatnonzero(mask, size=setdefault(size, mask.size), fill_value=mask.size) + return take( + a, + idx, + mode="fill", + fill_value=fill_value, + unique_indices=True, + indices_are_sorted=True, + ) + + +# TODO: Eventually remove and use numpy's stuff. +# https://github.com/numpy/numpy/issues/25805 +def atleast_nd(ndmin, *arys): + """Adds dimensions to front if necessary.""" + if ndmin == 1: + return jnp.atleast_1d(*arys) + if ndmin == 2: + return jnp.atleast_2d(*arys) + tup = tuple(jnp.array(ary, ndmin=ndmin) for ary in arys) + if len(tup) == 1: + tup = tup[0] + return tup + + +def atleast_3d_mid(*arys): + """Like np.atleast3d but if adds dim at axis 1 for 2d arrays.""" + arys = jnp.atleast_2d(*arys) + tup = tuple(ary[:, jnp.newaxis] if ary.ndim == 2 else ary for ary in arys) + if len(tup) == 1: + tup = tup[0] + return tup + + +def atleast_2d_end(*arys): + """Like np.atleast2d but if adds dim at axis 1 for 1d arrays.""" + arys = jnp.atleast_1d(*arys) + tup = tuple(ary[:, jnp.newaxis] if ary.ndim == 1 else ary for ary in arys) + if len(tup) == 1: + tup = tup[0] + return tup diff --git a/tests/test_bounce_integral.py b/tests/test_bounce_integral.py index e6e2719010..a09273657c 100644 --- a/tests/test_bounce_integral.py +++ b/tests/test_bounce_integral.py @@ -15,7 +15,12 @@ from tests.test_plotting import tol_1d from desc.backend import jnp -from desc.compute._quad_utils import ( +from desc.compute.utils import dot +from desc.equilibrium import Equilibrium +from desc.equilibrium.coords import get_rtz_grid +from desc.examples import get +from desc.grid import Grid, LinearGrid +from desc.integrals._quad_utils import ( automorphism_arcsin, automorphism_sin, bijection_from_disc, @@ -26,24 +31,18 @@ leggausslob, tanh_sinh, ) -from desc.compute.bounce_integral import ( +from desc.integrals.bounce_integral import ( _composite_linspace, - _filter_nonzero_measure, - _filter_not_nan, _get_extrema, _interp_to_argmin_B_hard, _interp_to_argmin_B_soft, bounce_integral, bounce_points, + filter_bounce_points, get_pitch, plot_field_line, required_names, ) -from desc.compute.utils import dot -from desc.equilibrium import Equilibrium -from desc.equilibrium.coords import get_rtz_grid -from desc.examples import get -from desc.grid import Grid, LinearGrid from desc.utils import only1 @@ -94,7 +93,8 @@ def test_get_extrema(): ) B_z_ra = B.derivative() extrema, B_extrema = _get_extrema(k, B.c, B_z_ra.c) - extrema, B_extrema = map(_filter_not_nan, (extrema, B_extrema)) + mask = ~np.isnan(extrema) + extrema, B_extrema = extrema[mask], B_extrema[mask] idx = np.argsort(extrema) extrema_scipy = np.sort(B_z_ra.roots(extrapolate=False)) @@ -130,7 +130,7 @@ def test_bp1_first(): pitch = 2.0 intersect = B.solve(1 / pitch, extrapolate=False) bp1, bp2 = bounce_points(pitch, knots, B.c, B.derivative().c, check=True) - bp1, bp2 = _filter_nonzero_measure(bp1, bp2) + bp1, bp2 = filter_bounce_points(bp1, bp2) assert bp1.size and bp2.size np.testing.assert_allclose(bp1, intersect[0::2]) np.testing.assert_allclose(bp2, intersect[1::2]) @@ -146,7 +146,7 @@ def test_bp2_first(): pitch = 2.0 intersect = B.solve(1 / pitch, extrapolate=False) bp1, bp2 = bounce_points(pitch, k, B.c, B.derivative().c, check=True) - bp1, bp2 = _filter_nonzero_measure(bp1, bp2) + bp1, bp2 = filter_bounce_points(bp1, bp2) assert bp1.size and bp2.size np.testing.assert_allclose(bp1, intersect[1:-1:2]) np.testing.assert_allclose(bp2, intersect[0::2][1:]) @@ -164,7 +164,7 @@ def test_bp1_before_extrema(): B_z_ra = B.derivative() pitch = 1 / B(B_z_ra.roots(extrapolate=False))[3] + 1e-13 bp1, bp2 = bounce_points(pitch, k, B.c, B_z_ra.c, check=True) - bp1, bp2 = _filter_nonzero_measure(bp1, bp2) + bp1, bp2 = filter_bounce_points(bp1, bp2) assert bp1.size and bp2.size intersect = B.solve(1 / pitch, extrapolate=False) np.testing.assert_allclose(bp1[1], 1.982767, rtol=1e-6) @@ -188,7 +188,7 @@ def test_bp2_before_extrema(): B_z_ra = B.derivative() pitch = 1 / B(B_z_ra.roots(extrapolate=False))[2] bp1, bp2 = bounce_points(pitch, k, B.c, B_z_ra.c, check=True) - bp1, bp2 = _filter_nonzero_measure(bp1, bp2) + bp1, bp2 = filter_bounce_points(bp1, bp2) assert bp1.size and bp2.size intersect = B.solve(1 / pitch, extrapolate=False) np.testing.assert_allclose(bp1, intersect[[0, -2]]) @@ -198,21 +198,6 @@ def test_bp2_before_extrema(): @pytest.mark.unit def test_extrema_first_and_before_bp1(): """Test that bounce points are computed correctly.""" - # In theory, this test should only pass if distinct=True when computing the - # intersections in bounce points. However, we can get lucky due to floating - # point errors, and it may also pass when distinct=False. - # If a regression fails this test, this note will save many hours of debugging. - # If the filter in place to return only the distinct roots is too coarse, - # in particular atol < 1e-15, then this test will error. In the resulting - # plot that the error will produce the red bounce point on the first hump - # disappears. The true sequence is green, double red, green, red, green. - # The first green was close to the double red and hence the first of the - # double red root pair was erased as it was falsely detected as a duplicate. - # The second of the double red root pair is correctly erased. All that is - # left is the green. Now the bounce_points method assumes the intermediate - # value theorem holds for the continuous spline, so when fed these sequence - # of roots, the correct action is to ignore the first green root since - # otherwise the interior of the bounce points would be hills and not valleys. start = -1.2 * np.pi end = -2 * start k = np.linspace(start, end, 7) @@ -227,7 +212,7 @@ def test_extrema_first_and_before_bp1(): pitch, k[2:], B.c[:, 2:], B_z_ra.c[:, 2:], check=True, plot=False ) plot_field_line(B, pitch, bp1, bp2, start=k[2]) - bp1, bp2 = _filter_nonzero_measure(bp1, bp2) + bp1, bp2 = filter_bounce_points(bp1, bp2) assert bp1.size and bp2.size intersect = B.solve(1 / pitch, extrapolate=False) np.testing.assert_allclose(bp1[0], 0.835319, rtol=1e-6) @@ -250,7 +235,7 @@ def test_extrema_first_and_before_bp2(): B_z_ra = B.derivative() pitch = 1 / B(B_z_ra.roots(extrapolate=False))[1] + 1e-13 bp1, bp2 = bounce_points(pitch, k, B.c, B_z_ra.c, check=True) - bp1, bp2 = _filter_nonzero_measure(bp1, bp2) + bp1, bp2 = filter_bounce_points(bp1, bp2) assert bp1.size and bp2.size # Our routine correctly detects intersection, while scipy, jnp.root fails. intersect = B.solve(1 / pitch, extrapolate=False) @@ -709,10 +694,7 @@ def integrand_den(B, pitch): num_well=1, weight=np.ones(zeta.size), ) - - drift_numerical_num = np.squeeze(drift_numerical_num) - drift_numerical_den = np.squeeze(drift_numerical_den) - drift_numerical = drift_numerical_num / drift_numerical_den + drift_numerical = np.squeeze(drift_numerical_num / drift_numerical_den) msg = "There should be one bounce integral per pitch in this example." assert drift_numerical.size == drift_analytic.size, msg np.testing.assert_allclose(drift_numerical, drift_analytic, atol=5e-3, rtol=5e-2) diff --git a/tests/test_fourier_bounce.py b/tests/test_fourier_bounce.py index 8718695766..e6b44aa4ac 100644 --- a/tests/test_fourier_bounce.py +++ b/tests/test_fourier_bounce.py @@ -3,22 +3,24 @@ import numpy as np import pytest from matplotlib import pyplot as plt +from numpy.polynomial.chebyshev import chebinterpolate, chebroots from numpy.polynomial.legendre import leggauss from tests.test_bounce_integral import _drift_analytic from tests.test_plotting import tol_1d from desc.backend import jnp -from desc.compute.bounce_integral import get_pitch -from desc.compute.fourier_bounce_integral import ( +from desc.equilibrium import Equilibrium +from desc.equilibrium.coords import get_rtz_grid, map_coordinates +from desc.examples import get +from desc.grid import Grid, LinearGrid +from desc.integrals._interp_utils import fourier_pts +from desc.integrals.bounce_integral import filter_bounce_points, get_pitch +from desc.integrals.fourier_bounce_integral import ( FourierChebyshevBasis, alpha_sequence, bounce_integral, required_names, ) -from desc.equilibrium import Equilibrium -from desc.equilibrium.coords import get_rtz_grid, map_coordinates -from desc.examples import get -from desc.grid import LinearGrid @pytest.mark.unit @@ -27,15 +29,56 @@ [(0, np.sqrt(2), 1, 2 * np.pi), (0, np.arange(1, 3) * np.sqrt(2), 5, 2 * np.pi)], ) def test_alpha_sequence(alpha_0, iota, num_period, period): - """Test field line poloidal label tracking utility.""" + """Test field line poloidal label tracking.""" iota = np.atleast_1d(iota) alphas = alpha_sequence(alpha_0, iota, num_period, period) assert alphas.shape == (iota.size, num_period) for i in range(iota.size): - assert np.unique(alphas[i]).size == num_period, "Is iota irrational?" + assert np.unique(alphas[i]).size == num_period, f"{iota} is irrational" print(alphas) +class TestBouncePoints: + """Test that bounce points are computed correctly.""" + + @staticmethod + def _cheb_intersect(cheb, k): + cheb = cheb.copy() + cheb[0] = cheb[0] - k + roots = chebroots(cheb) + intersect = roots[ + np.logical_and(np.isreal(roots), np.abs(roots.real) <= 1) + ].real + return intersect + + @staticmethod + def _periodic_fun(nodes, M, N): + alpha, zeta = nodes.T + f = -2 * np.cos(1 / (0.1 + zeta**2)) + 2 + return f.reshape(M, N) + + @pytest.mark.unit + def test_bp1_first(self): + """Test that bounce points are computed correctly.""" + pitch = 1 / np.linspace(1, 4, 20).reshape(20, 1) + M, N = 1, 10 + domain = (-1, 1) + nodes = FourierChebyshevBasis.nodes(M, N, domain=domain) + f = self._periodic_fun(nodes, M, N) + fcb = FourierChebyshevBasis(f, domain=domain) + pcb = fcb.compute_cheb(fourier_pts(M)) + bp1, bp2 = pcb.bounce_points(*pcb.intersect(1 / pitch)) + pcb.check_bounce_points(bp1, bp2, pitch.ravel()) + bp1, bp2 = filter_bounce_points(bp1, bp2) + + def f(z): + return -2 * np.cos(1 / (0.1 + z**2)) + 2 + + r = self._cheb_intersect(chebinterpolate(f, N), 1 / pitch) + np.testing.assert_allclose(bp1, r[::2], rtol=1e-3) + np.testing.assert_allclose(bp2, r[1::2], rtol=1e-3) + + @pytest.mark.unit def test_fourier_chebyshev(rho=1, M=8, N=32, f=lambda B, pitch: B * pitch): """Test bounce points...""" @@ -71,12 +114,18 @@ def test_drift(): np.testing.assert_allclose(rho, 0.5) # Make a set of nodes along a single fieldline. - grid_fsa = LinearGrid(rho=rho, M=eq.M_grid, N=eq.N_grid, sym=eq.sym, NFP=eq.NFP) - data = eq.compute(["iota"], grid=grid_fsa) - iota = grid_fsa.compress(data["iota"]).item() + grid_rtz = Grid.create_meshgrid( + [ + rho, + np.linspace(0, 2 * np.pi, eq.M_grid), + np.linspace(0, 2 * np.pi, eq.N_grid + 1), + ], + ) + data = eq.compute(["iota"], grid=grid_rtz) + iota = grid_rtz.compress(data["iota"]).item() alpha = 0 zeta = np.linspace(-np.pi / iota, np.pi / iota, (2 * eq.M_grid) * 4 + 1) - grid = get_rtz_grid( + grid_raz = get_rtz_grid( eq, rho, alpha, @@ -97,7 +146,7 @@ def test_drift(): "psi", "a", ], - grid=grid, + grid=grid_raz, ) np.testing.assert_allclose(data["psi"], psi) np.testing.assert_allclose(data["iota"], iota) @@ -107,20 +156,38 @@ def test_drift(): data["rho"] = rho data["alpha"] = alpha data["zeta"] = zeta - data["psi"] = grid.compress(data["psi"]) - data["iota"] = grid.compress(data["iota"]) - data["shear"] = grid.compress(data["shear"]) - + data["psi"] = grid_raz.compress(data["psi"]) + data["iota"] = grid_raz.compress(data["iota"]) + data["shear"] = grid_raz.compress(data["shear"]) # Compute analytic approximation. drift_analytic, cvdrift, gbdrift, pitch = _drift_analytic(data) + # Compute numerical result. + M, N = eq.M_grid, 100 + clebsch = FourierChebyshevBasis.nodes(M=eq.M_grid, N=N, rho=rho) + data_2 = eq.compute(names=required_names() + ["cvdrift", "gbdrift"], grid=grid_rtz) + normalization = -np.sign(data["psi"]) * data["B ref"] * data["a"] ** 2 + cvdrift = data_2["cvdrift"] * normalization + gbdrift = data_2["gbdrift"] * normalization bounce_integrate, _ = bounce_integral( - data, - knots=zeta, - B_ref=B_ref, + grid_rtz, + data_2, + M, + N, + desc_from_clebsch=map_coordinates( + eq, + clebsch, + inbasis=("rho", "alpha", "zeta"), + period=(np.inf, 2 * np.pi, np.inf), + iota=np.broadcast_to(data["iota"], (M * N)), + ), + alpha_0=data["alpha"], + num_transit=5, + B_ref=data["B ref"], L_ref=data["a"], quad=leggauss(28), # converges to absolute and relative tolerance of 1e-7 check=True, + plot=True, ) def integrand_num(cvdrift, gbdrift, B, pitch): @@ -141,12 +208,8 @@ def integrand_den(B, pitch): f=[], pitch=pitch[:, np.newaxis], num_well=1, - weight=np.ones(zeta.size), ) - - drift_numerical_num = np.squeeze(drift_numerical_num) - drift_numerical_den = np.squeeze(drift_numerical_den) - drift_numerical = drift_numerical_num / drift_numerical_den + drift_numerical = np.squeeze(drift_numerical_num / drift_numerical_den) msg = "There should be one bounce integral per pitch in this example." assert drift_numerical.size == drift_analytic.size, msg np.testing.assert_allclose(drift_numerical, drift_analytic, atol=5e-3, rtol=5e-2) @@ -154,4 +217,5 @@ def integrand_den(B, pitch): fig, ax = plt.subplots() ax.plot(1 / pitch, drift_analytic) ax.plot(1 / pitch, drift_numerical) + plt.show() return fig diff --git a/tests/test_interp_utils.py b/tests/test_interp_utils.py index 1f47e74418..9cfd2239eb 100644 --- a/tests/test_interp_utils.py +++ b/tests/test_interp_utils.py @@ -14,7 +14,7 @@ from scipy.fft import idct as sidct from desc.backend import dct, idct, jnp, rfft -from desc.compute._interp_utils import ( +from desc.integrals._interp_utils import ( cheb_from_dct, cheb_pts, harmonic, @@ -26,8 +26,14 @@ polyder_vec, polyval_vec, ) -from desc.compute._quad_utils import bijection_to_disc -from desc.compute.bounce_integral import _filter_not_nan +from desc.integrals._quad_utils import bijection_to_disc + + +def filter_not_nan(a): + """Filter out nan from ``a`` while asserting nan is padded at right.""" + is_nan = jnp.isnan(a) + assert jnp.array_equal(is_nan, jnp.sort(is_nan, axis=-1)) + return a[~is_nan] @pytest.mark.unit @@ -64,7 +70,7 @@ def test_poly_root(): root = poly_root(c.T, sort=True, distinct=True) for j in range(c.shape[0]): unique_roots = np.unique(np.roots(c[j])) - root_filter = _filter_not_nan(root[j], check=True) + root_filter = filter_not_nan(root[j]) assert root_filter.size == unique_roots.size, j np.testing.assert_allclose( actual=root_filter, @@ -72,7 +78,7 @@ def test_poly_root(): err_msg=str(j), ) c = np.array([0, 1, -1, -8, 12]) - root = _filter_not_nan(poly_root(c, sort=True, distinct=True), check=True) + root = filter_not_nan(poly_root(c, sort=True, distinct=True)) unique_root = np.unique(np.roots(c)) assert root.size == unique_root.size np.testing.assert_allclose(root, unique_root)