diff --git a/desc/compute/_basis_vectors.py b/desc/compute/_basis_vectors.py index 72803a1d7..4787eb860 100644 --- a/desc/compute/_basis_vectors.py +++ b/desc/compute/_basis_vectors.py @@ -3223,17 +3223,61 @@ def _e_sub_zeta_zz(params, transforms, profiles, data, **kwargs): transforms={}, profiles=[], coordinates="rtz", - data=["e^rho", "e^theta", "e^zeta", "alpha_r", "alpha_t", "alpha_z"], + data=["periodic(grad(alpha))", "secular(grad(alpha))"], ) def _grad_alpha(params, transforms, profiles, data, **kwargs): - data["grad(alpha)"] = ( - data["alpha_r"] * data["e^rho"].T + data["grad(alpha)"] = data["periodic(grad(alpha))"] + data["secular(grad(alpha))"] + return data + + +@register_compute_fun( + name="periodic(grad(alpha))", + label="\\mathrm{periodic}(\\nabla \\alpha)", + units="m^{-1}", + units_long="Inverse meters", + description=( + "Gradient of field line label, which is perpendicular to the field line, " + "periodic component" + ), + dim=3, + params=[], + transforms={}, + profiles=[], + coordinates="rtz", + data=["e^rho", "e^theta", "e^zeta", "periodic(alpha_r)", "alpha_t", "alpha_z"], +) +def _periodic_grad_alpha(params, transforms, profiles, data, **kwargs): + data["periodic(grad(alpha))"] = ( + data["periodic(alpha_r)"] * data["e^rho"].T + data["alpha_t"] * data["e^theta"].T + data["alpha_z"] * data["e^zeta"].T ).T return data +@register_compute_fun( + name="secular(grad(alpha))", + label="\\mathrm{secular}(\\nabla \\alpha)", + units="m^{-1}", + units_long="Inverse meters", + description=( + "Gradient of field line label, which is perpendicular to the field line, " + "periodic component" + ), + dim=3, + params=[], + transforms={}, + profiles=[], + coordinates="rtz", + data=["e^rho", "secular(alpha_r)"], +) +def _secular_grad_alpha(params, transforms, profiles, data, **kwargs): + data["secular(grad(alpha))"] = ( + data["secular(alpha_r)"][:, jnp.newaxis] * data["e^rho"] + ) + return data + + @register_compute_fun( name="grad(psi)", label="\\nabla\\psi", diff --git a/desc/compute/_core.py b/desc/compute/_core.py index 784e296f7..ef0466f8b 100644 --- a/desc/compute/_core.py +++ b/desc/compute/_core.py @@ -1518,14 +1518,48 @@ def _alpha(params, transforms, profiles, data, **kwargs): transforms={}, profiles=[], coordinates="rtz", - data=["theta_PEST_r", "phi", "phi_r", "iota", "iota_r"], + data=["periodic(alpha_r)", "secular(alpha_r)"], ) def _alpha_r(params, transforms, profiles, data, **kwargs): - data["alpha_r"] = ( - data["theta_PEST_r"] - - data["iota_r"] * data["phi"] - - data["iota"] * data["phi_r"] - ) + data["alpha_r"] = data["periodic(alpha_r)"] + data["secular(alpha_r)"] + return data + + +@register_compute_fun( + name="periodic(alpha_r)", + label="\\mathrm{periodic}(\\partial_\\rho \\alpha)", + units="~", + units_long="None", + description="Field line label, derivative wrt radial coordinate, " + "periodic component", + dim=1, + params=[], + transforms={}, + profiles=[], + coordinates="rtz", + data=["theta_PEST_r", "iota", "phi_r"], +) +def _periodic_alpha_r(params, transforms, profiles, data, **kwargs): + data["periodic(alpha_r)"] = data["theta_PEST_r"] - data["iota"] * data["phi_r"] + return data + + +@register_compute_fun( + name="secular(alpha_r)", + label="\\mathrm{secular}(\\partial_\\rho \\alpha)", + units="~", + units_long="None", + description="Field line label, derivative wrt radial coordinate, " + "secular component", + dim=1, + params=[], + transforms={}, + profiles=[], + coordinates="rtz", + data=["iota_r", "phi"], +) +def _secular_alpha_r(params, transforms, profiles, data, **kwargs): + data["secular(alpha_r)"] = -data["iota_r"] * data["phi"] return data diff --git a/desc/compute/_metric.py b/desc/compute/_metric.py index da4353ad0..0d66b8305 100644 --- a/desc/compute/_metric.py +++ b/desc/compute/_metric.py @@ -1941,8 +1941,8 @@ def _g_sup_ra(params, transforms, profiles, data, **kwargs): # Exact definition of the magnetic drifts taken from # eqn. 48 of Introduction to Quasisymmetry by Landreman # https://tinyurl.com/54udvaa4 - label="\\mathrm{gbdrift} = 1/B^{2} (\\mathbf{b}\\times\\nabla B) \\cdot" - + "\\nabla \\alpha", + label="(\\nabla \\vert B \\vert)_{\\mathrm{drift}} = " + "(\\mathbf{b} \\times \\nabla B) \\cdot \\nabla \\alpha / \\vert B \\vert^{2}", units="1/(T-m^{2})", units_long="inverse Tesla meters^2", description="Binormal component of the geometric part of the gradB drift" @@ -1952,13 +1952,61 @@ def _g_sup_ra(params, transforms, profiles, data, **kwargs): transforms={}, profiles=[], coordinates="rtz", - data=["|B|^2", "b", "grad(alpha)", "grad(|B|)"], + data=["periodic(gbdrift)", "secular(gbdrift)"], ) def _gbdrift(params, transforms, profiles, data, **kwargs): - data["gbdrift"] = ( - 1 + data["gbdrift"] = data["periodic(gbdrift)"] + data["secular(gbdrift)"] + return data + + +@register_compute_fun( + name="periodic(gbdrift)", + # Exact definition of the magnetic drifts taken from + # eqn. 48 of Introduction to Quasisymmetry by Landreman + # https://tinyurl.com/54udvaa4 + label="\\mathrm{periodic}(\\nabla \\vert B \\vert)_{\\mathrm{drift}}", + units="1/(T-m^{2})", + units_long="inverse Tesla meters^2", + description="Binormal component of the geometric part of the gradB drift" + + " used for local stability analyses, Gamma_c, epsilon_eff etc." + " Periodic component.", + dim=1, + params=[], + transforms={}, + profiles=[], + coordinates="rtz", + data=["|B|^2", "b", "periodic(grad(alpha))", "grad(|B|)"], +) +def _periodic_gbdrift(params, transforms, profiles, data, **kwargs): + data["periodic(gbdrift)"] = ( + dot(data["b"], cross(data["grad(|B|)"], data["periodic(grad(alpha))"])) + / data["|B|^2"] + ) + return data + + +@register_compute_fun( + name="secular(gbdrift)", + # Exact definition of the magnetic drifts taken from + # eqn. 48 of Introduction to Quasisymmetry by Landreman + # https://tinyurl.com/54udvaa4 + label="\\mathrm{secular}(\\nabla \\vert B \\vert)_{\\mathrm{drift}}", + units="1/(T-m^{2})", + units_long="inverse Tesla meters^2", + description="Binormal component of the geometric part of the gradB drift" + + " used for local stability analyses, Gamma_c, epsilon_eff etc. " + "Secular component.", + dim=1, + params=[], + transforms={}, + profiles=[], + coordinates="rtz", + data=["|B|^2", "b", "secular(grad(alpha))", "grad(|B|)"], +) +def _secular_gbdrift(params, transforms, profiles, data, **kwargs): + data["secular(gbdrift)"] = ( + dot(data["b"], cross(data["grad(|B|)"], data["secular(grad(alpha))"])) / data["|B|^2"] - * dot(data["b"], cross(data["grad(|B|)"], data["grad(alpha)"])) ) return data @@ -1979,11 +2027,35 @@ def _gbdrift(params, transforms, profiles, data, **kwargs): transforms={}, profiles=[], coordinates="rtz", - data=["p_r", "psi_r", "|B|^2", "gbdrift"], + data=["periodic(cvdrift)", "secular(gbdrift)"], ) def _cvdrift(params, transforms, profiles, data, **kwargs): - dp_dpsi = mu_0 * data["p_r"] / data["psi_r"] - data["cvdrift"] = 1 / data["|B|^2"] * dp_dpsi + data["gbdrift"] + data["cvdrift"] = data["periodic(cvdrift)"] + data["secular(gbdrift)"] + return data + + +@register_compute_fun( + name="periodic(cvdrift)", + # Exact definition of the magnetic drifts taken from + # eqn. 48 of Introduction to Quasisymmetry by Landreman + # https://tinyurl.com/54udvaa4 + label="\\mathrm{periodic(cvdrift)}", + units="1/(T-m^{2})", + units_long="inverse Tesla meters^2", + description="Binormal component of the geometric part of the curvature drift" + + " used for local stability analyses, Gamma_c, epsilon_eff etc. " + "Periodic component.", + dim=1, + params=[], + transforms={}, + profiles=[], + coordinates="rtz", + data=["p_r", "psi_r", "|B|^2", "periodic(gbdrift)"], +) +def _periodic_cvdrift(params, transforms, profiles, data, **kwargs): + data["periodic(cvdrift)"] = ( + mu_0 * data["p_r"] / data["psi_r"] / data["|B|^2"] + data["periodic(gbdrift)"] + ) return data diff --git a/desc/grid.py b/desc/grid.py index c4b8a46e7..4b7bb281e 100644 --- a/desc/grid.py +++ b/desc/grid.py @@ -671,6 +671,7 @@ class Grid(_Grid): Use np.inf to denote no periodicity. NFP : int Number of field periods (Default = 1). + Change this only if your nodes are placed within one field period. source_grid : Grid Grid from which coordinates were mapped from. sort : bool @@ -794,7 +795,8 @@ def create_meshgrid( NFP : int Number of field periods (Default = 1). Only makes sense to change from 1 if last coordinate is periodic - with some constant divided by ``NFP``. + with some constant divided by ``NFP`` and the nodes are placed + within one field period. Returns ------- @@ -916,6 +918,8 @@ class LinearGrid(_Grid): Toroidal grid resolution. NFP : int Number of field periods (Default = 1). + Change this only if your nodes are placed within one field period + or should be interpreted as spanning one field period. sym : bool True for stellarator symmetry, False otherwise (Default = False). axis : bool @@ -1011,6 +1015,8 @@ def _create_nodes( # noqa: C901 Toroidal grid resolution. NFP : int Number of field periods (Default = 1). + Only change this if your nodes are placed within one field period + or should be interpreted as spanning one field period. axis : bool True to include a point at rho=0 (default), False for rho[0] = rho[1]/2. endpoint : bool @@ -1037,8 +1043,10 @@ def _create_nodes( # noqa: C901 """ self._NFP = check_posint(NFP, "NFP", False) self._period = (np.inf, 2 * np.pi, 2 * np.pi / self._NFP) - # TODO: + # FIXME: # https://github.com/PlasmaControl/DESC/pull/1204#pullrequestreview-2246771337 + # Quantities like alpha, grad(alpha), etc. are computed incorrectly at + # phi > 2pi / NFP and theta > 2pi / NFP. axis = bool(axis) endpoint = bool(endpoint) theta_period = self.period[1] diff --git a/desc/integrals/__init__.py b/desc/integrals/__init__.py index 88cc3001c..e3d59d02e 100644 --- a/desc/integrals/__init__.py +++ b/desc/integrals/__init__.py @@ -1,6 +1,6 @@ """Classes for function integration.""" -from .bounce_integral import Bounce1D +from .bounce_integral import Bounce1D, Bounce2D from .singularities import ( DFTInterpolator, FFTInterpolator, diff --git a/desc/integrals/basis.py b/desc/integrals/basis.py index 91a31edf6..18379fc34 100644 --- a/desc/integrals/basis.py +++ b/desc/integrals/basis.py @@ -1,9 +1,36 @@ -"""Fast transformable basis.""" +"""Fast transformable series.""" from functools import partial -from desc.backend import flatnonzero, jnp, put -from desc.utils import setdefault +import numpy as np +from matplotlib import pyplot as plt + +from desc.backend import dct, flatnonzero, idct, irfft, jnp, put, rfft +from desc.integrals.interp_utils import ( + _eps, + _filter_distinct, + _subtract_first, + cheb_from_dct, + cheb_pts, + chebroots_vec, + dct_from_cheb, + fourier_pts, + harmonic, + idct_non_uniform, + irfft_non_uniform, +) +from desc.integrals.quad_utils import bijection_from_disc, bijection_to_disc +from desc.io import IOAble +from desc.utils import ( + atleast_2d_end, + atleast_3d_mid, + atleast_nd, + errorif, + flatten_matrix, + isposint, + setdefault, + take_mask, +) @partial(jnp.vectorize, signature="(m),(m)->(m)") @@ -58,6 +85,615 @@ def _in_epigraph_and(is_intersect, df_dy_sign, /): return put(is_intersect, idx[0], edge_case) +def _chebcast(cheb, arr): + """Add leftmost axis to ``cheb`` depending on ``arr.ndim``. + + Input ``arr`` should not have rightmost dimension of cheb that iterates + coefficients, but may have additional leftmost dimension for batch operation. + """ + errorif( + jnp.ndim(arr) > cheb.ndim, + NotImplementedError, + msg=f"Only one additional axis for batch dimension is allowed. " + f"Got {jnp.ndim(arr) - cheb.ndim + 1} additional axes.", + ) + return cheb if jnp.ndim(arr) < cheb.ndim else cheb[jnp.newaxis] + + +class FourierChebyshevSeries(IOAble): + """Fourier-Chebyshev series. + + f(x, y) = ∑ₘₙ aₘₙ ψₘ(x) Tₙ(y) + where ψₘ are trigonometric polynomials on [0, 2π] + and Tₙ are Chebyshev polynomials on [−yₘᵢₙ, yₘₐₓ]. + + Notes + ----- + Performance may improve significantly + if the spectral resolutions ``M`` and ``N`` are powers of two. + + + Parameters + ---------- + f : jnp.ndarray + Shape (..., M, N). + Samples of real function on the ``FourierChebyshevSeries.nodes`` grid. + domain : tuple[float] + Domain for y coordinates. Default is [-1, 1]. + lobatto : bool + Whether ``f`` was sampled on the Gauss-Lobatto (extrema-plus-endpoint) + instead of the interior roots grid for Chebyshev points. + + Attributes + ---------- + M : int + Fourier spectral resolution. + N : int + Chebyshev spectral resolution. + + """ + + def __init__(self, f, domain=(-1, 1), lobatto=False): + """Interpolate Fourier-Chebyshev series to ``f``.""" + self.M = f.shape[-2] + self.N = f.shape[-1] + errorif(domain[0] > domain[-1], msg="Got inverted domain.") + self.domain = tuple(domain) + errorif(lobatto, NotImplementedError, "JAX hasn't implemented type 1 DCT.") + self.lobatto = bool(lobatto) + self._c = FourierChebyshevSeries._transform(f, self.lobatto) + + @staticmethod + def _transform(f, lobatto): + N = f.shape[-1] + return rfft( + dct(f, type=2 - lobatto, axis=-1) / (N - lobatto), + axis=-2, + norm="forward", + ) + + @staticmethod + def nodes(M, N, L=None, domain=(-1, 1), lobatto=False): + """Tensor product grid of optimal collocation nodes for this basis. + + Parameters + ---------- + M : int + Grid resolution in x direction. Preferably power of 2. + N : int + Grid resolution in y direction. Preferably power of 2. + L : int or jnp.ndarray + Optional, resolution in radial direction of domain [0, 1]. + May also be an array of coordinates values. If given, then the + returned ``coords`` is a 3D tensor-product with shape (L * M * N, 3). + domain : tuple[float] + Domain for y coordinates. Default is [-1, 1]. + lobatto : bool + Whether to use the Gauss-Lobatto (Extrema-plus-Endpoint) + instead of the interior roots grid for Chebyshev points. + + Returns + ------- + coords : jnp.ndarray + Shape (M * N, 2). + Grid of (x, y) points for optimal interpolation. + + """ + x = fourier_pts(M) + y = cheb_pts(N, domain, lobatto) + if L is None: + coords = (x, y) + else: + if isposint(L): + L = jnp.flipud(jnp.linspace(1, 0, L, endpoint=False)) + coords = (jnp.atleast_1d(L), x, y) + coords = tuple(map(jnp.ravel, jnp.meshgrid(*coords, indexing="ij"))) + return jnp.column_stack(coords) + + def evaluate(self, M, N): + """Evaluate Fourier-Chebyshev series. + + Parameters + ---------- + M : int + Grid resolution in x direction. Preferably power of 2. + N : int + Grid resolution in y direction. Preferably power of 2. + + Returns + ------- + fq : jnp.ndarray + Shape (..., M, N) + Fourier-Chebyshev series evaluated at + ``FourierChebyshevSeries.nodes(M,N,L,self.domain,self.lobatto)``. + + """ + return idct( + irfft(self._c, n=M, axis=-2, norm="forward"), + type=2 - self.lobatto, + n=N, + axis=-1, + ) * (N - self.lobatto) + + def harmonics(self): + """Spectral coefficients aₘₙ of the interpolating trigonometric polynomial. + + Transform Fourier interpolant harmonics to Nyquist trigonometric + interpolant harmonics so that the coefficients are all real. + + Returns + ------- + a_mn : jnp.ndarray + Shape (..., M, N). + Real valued spectral coefficients for Fourier-Chebyshev series. + + """ + a_mn = harmonic(cheb_from_dct(self._c), self.M, axis=-2) + assert a_mn.shape[-2:] == (self.M, self.N) + return a_mn + + def compute_cheb(self, x): + """Evaluate Fourier series at ``x`` to obtain set of 1D Chebyshev coefficients. + + Parameters + ---------- + x : jnp.ndarray + Points to evaluate Fourier series. + + Returns + ------- + cheb : PiecewiseChebyshevSeries + Chebyshev coefficients αₙ(x=``x``) for f(x, y) = ∑ₙ₌₀ᴺ⁻¹ αₙ(x) Tₙ(y). + + """ + # Add axis to broadcast against Chebyshev coefficients. + x = jnp.atleast_1d(x)[..., jnp.newaxis] + # Add axis to broadcast against multiple x values. + cheb = cheb_from_dct( + irfft_non_uniform(x, self._c[..., jnp.newaxis, :, :], self.M, axis=-2) + ) + assert cheb.shape[-2:] == (x.shape[-2], self.N) + return PiecewiseChebyshevSeries(cheb, self.domain) + + +class PiecewiseChebyshevSeries(IOAble): + """Chebyshev series. + + { fₓ | fₓ : y ↦ ∑ₙ₌₀ᴺ⁻¹ aₙ(x) Tₙ(y) } + and Tₙ are Chebyshev polynomials on [−yₘᵢₙ, yₘₐₓ] + + Parameters + ---------- + cheb : jnp.ndarray + Shape (..., M, N). + Chebyshev coefficients αₙ(x) for f(x, y) = ∑ₙ₌₀ᴺ⁻¹ αₙ(x) Tₙ(y). + domain : tuple[float] + Domain for y coordinates. Default is [-1, 1]. + + """ + + def __init__(self, cheb, domain=(-1, 1)): + """Make piecewise series from given Chebyshev coefficients.""" + self.cheb = jnp.atleast_2d(cheb) + errorif(domain[0] > domain[-1], msg="Got inverted domain.") + self.domain = tuple(domain) + + @property + def M(self): + """Number of cuts.""" + return self.cheb.shape[-2] + + @property + def N(self): + """Chebyshev spectral resolution.""" + return self.cheb.shape[-1] + + def stitch(self): + """Enforce the piecewise series is continuous.""" + # evaluate at left boundary + f_0 = self.cheb[..., ::2].sum(axis=-1) - self.cheb[..., 1::2].sum(axis=-1) + # evaluate at right boundary + f_1 = self.cheb.sum(axis=-1) + dfx = f_1[..., :-1] - f_0[..., 1:] # Δf = f(xᵢ, y₁) - f(xᵢ₊₁, y₀) + self.cheb = self.cheb.at[..., 1:, 0].add(dfx.cumsum(axis=-1)) + + def evaluate(self, N): + """Evaluate Chebyshev series at N Chebyshev points. + + Evaluate each function in this set + { fₓ | fₓ : y ↦ ∑ₙ₌₀ᴺ⁻¹ aₙ(x) Tₙ(y) } + at y points given by the N Chebyshev points. + + Parameters + ---------- + N : int + Grid resolution in y direction. Preferably power of 2. + + Returns + ------- + fq : jnp.ndarray + Shape (..., M, N) + Chebyshev series evaluated at N Chebyshev points. + + """ + return idct(dct_from_cheb(self.cheb), type=2, n=N, axis=-1) * N + + def isomorphism_to_C1(self, y): + """Return coordinates z ∈ ℂ isomorphic to (x, y) ∈ ℂ². + + Maps row x of y to z = y + f(x) where f(x) = x * |domain|. + + Parameters + ---------- + y : jnp.ndarray + Shape (..., y.shape[-2], y.shape[-1]). + Second to last axis iterates the rows. + + Returns + ------- + z : jnp.ndarray + Shape y.shape. + Isomorphic coordinates. + + """ + assert y.ndim >= 2 + z_shift = jnp.arange(y.shape[-2]) * (self.domain[-1] - self.domain[0]) + z = y + z_shift[:, jnp.newaxis] + return z + + def isomorphism_to_C2(self, z): + """Return coordinates (x, y) ∈ ℂ² isomorphic to z ∈ ℂ. + + Returns index x and minimum value y such that + z = f(x) + y where f(x) = x * |domain|. + + Parameters + ---------- + z : jnp.ndarray + Shape z.shape. + + Returns + ------- + x_idx, y_val : (jnp.ndarray, jnp.ndarray) + Shape z.shape. + Isomorphic coordinates. + + """ + x_idx, y_val = jnp.divmod(z - self.domain[0], self.domain[-1] - self.domain[0]) + x_idx = x_idx.astype(int) + y_val += self.domain[0] + return x_idx, y_val + + def eval1d(self, z, cheb=None): + """Evaluate piecewise Chebyshev series at coordinates z. + + Parameters + ---------- + z : jnp.ndarray + Shape (..., *cheb.shape[:-2], z.shape[-1]). + Coordinates in [self.domain[0], ∞). + The coordinates z ∈ ℝ are assumed isomorphic to (x, y) ∈ ℝ² where + ``z // domain`` yields the index into the proper Chebyshev series + along the second to last axis of ``cheb`` and ``z % domain`` is + the coordinate value on the domain of that Chebyshev series. + cheb : jnp.ndarray + Shape (..., M, N). + Chebyshev coefficients to use. If not given, uses ``self.cheb``. + + Returns + ------- + f : jnp.ndarray + Shape z.shape. + Chebyshev series evaluated at z. + + """ + cheb = _chebcast(setdefault(cheb, self.cheb), z) + N = cheb.shape[-1] + x_idx, y = self.isomorphism_to_C2(z) + y = bijection_to_disc(y, self.domain[0], self.domain[1]) + # Chebyshev coefficients αₙ for f(z) = ∑ₙ₌₀ᴺ⁻¹ αₙ(x[z]) Tₙ(y[z]) + # are held in cheb with shape (..., num cheb series, N). + cheb = jnp.take_along_axis(cheb, x_idx[..., jnp.newaxis], axis=-2) + f = idct_non_uniform(y, cheb, N) + assert f.shape == z.shape + return f + + def intersect2d(self, k=0.0, *, eps=_eps): + """Coordinates yᵢ such that f(x, yᵢ) = k(x). + + Parameters + ---------- + k : jnp.ndarray + Shape must broadcast with (..., *cheb.shape[:-1]). + Specify to find solutions yᵢ to f(x, yᵢ) = k(x). Default 0. + eps : float + Absolute tolerance with which to consider value as zero. + + Returns + ------- + y : jnp.ndarray + Shape (..., *cheb.shape[:-1], N - 1). + Solutions yᵢ of f(x, yᵢ) = k(x), in ascending order. + is_intersect : jnp.ndarray + Shape y.shape. + Boolean array into ``y`` indicating whether element is an intersect. + df_dy_sign : jnp.ndarray + Shape y.shape. + Sign of ∂f/∂y (x, yᵢ). + + """ + c = _subtract_first(_chebcast(self.cheb, k), k) + # roots yᵢ of f(x, y) = ∑ₙ₌₀ᴺ⁻¹ αₙ(x) Tₙ(y) - k(x) + y = chebroots_vec(c) + assert y.shape == (*c.shape[:-1], self.N - 1) + + # Intersects must satisfy y ∈ [-1, 1]. + # Pick sentinel such that only distinct roots are considered intersects. + y = _filter_distinct(y, sentinel=-2.0, eps=eps) + is_intersect = (jnp.abs(y.imag) <= eps) & (jnp.abs(y.real) < 1.0) + # Ensure y ∈ (-1, 1), i.e. where arccos is differentiable. + y = jnp.where(is_intersect, y.real, 0.0) + + # TODO: Multipoint evaluation with FFT. + # Chapter 10, https://doi.org/10.1017/CBO9781139856065. + n = jnp.arange(self.N) + # ∂f/∂y = ∑ₙ₌₀ᴺ⁻¹ aₙ(x) n Uₙ₋₁(y) + # sign ∂f/∂y = sign ∑ₙ₌₀ᴺ⁻¹ aₙ(x) n sin(n arcos y) + df_dy_sign = jnp.sign( + jnp.linalg.vecdot( + n * jnp.sin(n * jnp.arccos(y)[..., jnp.newaxis]), + self.cheb[..., jnp.newaxis, :], + ) + ) + y = bijection_from_disc(y, self.domain[0], self.domain[-1]) + return y, is_intersect, df_dy_sign + + def intersect1d(self, k=0.0, *, num_intersect=None, pad_value=0.0): + """Coordinates z(x, yᵢ) such that fₓ(yᵢ) = k for every x. + + Parameters + ---------- + k : jnp.ndarray + Shape must broadcast with (..., *cheb.shape[:-2]). + Specify to find solutions yᵢ to fₓ(yᵢ) = k. Default 0. + num_intersect : int or None + Specify to return the first ``num_intersect`` intersects. + This is useful if ``num_intersect`` tightly bounds the actual number. + + If not specified, then all intersects are returned. If there were fewer + intersects detected than the size of the last axis of the returned arrays, + then that axis is padded with ``pad_value``. + pad_value : float + Value with which to pad array. Default 0. + + Returns + ------- + z1, z2 : tuple[jnp.ndarray] + Shape broadcasts with (..., *self.cheb.shape[:-2], num_intersect). + Tuple of length two (z1, z2) of coordinates of intersects. + The points are ordered and grouped such that the straight line path + between ``z1`` and ``z2`` resides in the epigraph of f. + + """ + errorif( + self.N < 2, + NotImplementedError, + "This method requires a Chebyshev spectral resolution of N > 1, " + f"but got N = {self.N}.", + ) + + # Add axis to use same k over all Chebyshev series of the piecewise spline. + y, is_intersect, df_dy_sign = self.intersect2d( + jnp.atleast_1d(k)[..., jnp.newaxis] + ) + # Flatten so that last axis enumerates intersects along the piecewise spline. + y, is_intersect, df_dy_sign = map( + flatten_matrix, (self.isomorphism_to_C1(y), is_intersect, df_dy_sign) + ) + + # Note for bounce point applications: + # We ignore the degenerate edge case where the boundary shared by adjacent + # polynomials is a left intersection i.e. ``is_z1`` because the subset of + # pitch values that generate this edge case has zero measure. By ignoring + # this, for those subset of pitch values the integrations will be done in + # the hypograph of |B|, which will yield zero. If in far future decide to + # not ignore this, note the solution is to disqualify intersects within + # ``_eps`` from ``domain[-1]``. Edit: For differentiability, we cannot + # consider intersects at boundary of Chebyshev polynomial. Again, cases + # where this would be incorrect have measure zero. + is_z1 = (df_dy_sign <= 0) & is_intersect + is_z2 = (df_dy_sign >= 0) & _in_epigraph_and(is_intersect, df_dy_sign) + + sentinel = self.domain[0] - 1.0 + z1 = take_mask(y, is_z1, size=num_intersect, fill_value=sentinel) + z2 = take_mask(y, is_z2, size=num_intersect, fill_value=sentinel) + + mask = (z1 > sentinel) & (z2 > sentinel) + # Set outside mask to same value so integration is over set of measure zero. + z1 = jnp.where(mask, z1, pad_value) + z2 = jnp.where(mask, z2, pad_value) + return z1, z2 + + def _check_shape(self, z1, z2, k): + """Return shapes that broadcast with (k.shape[0], *self.cheb.shape[:-2], W).""" + assert z1.shape == z2.shape + # Ensure pitch batch dim exists and add back dim to broadcast with wells. + k = atleast_nd(self.cheb.ndim - 1, k)[..., jnp.newaxis] + # Same but back dim already exists. + z1 = atleast_nd(self.cheb.ndim, z1) + z2 = atleast_nd(self.cheb.ndim, z2) + # Cheb has shape (..., M, N) and others + # have shape (K, ..., W) + errorif(not (z1.ndim == z2.ndim == k.ndim == self.cheb.ndim)) + return z1, z2, k + + def check_intersect1d(self, z1, z2, k, plot=True, **kwargs): + """Check that intersects are computed correctly. + + Parameters + ---------- + z1, z2 : jnp.ndarray + Shape must broadcast with (*self.cheb.shape[:-2], W). + Tuple of length two (z1, z2) of coordinates of intersects. + The points are ordered and grouped such that the straight line path + between ``z1`` and ``z2`` resides in the epigraph of f. + k : jnp.ndarray + Shape must broadcast with self.cheb.shape[:-2]. + k such that fₓ(yᵢ) = k. + plot : bool + Whether to plot the piecewise spline and intersects for the given ``k``. + kwargs : dict + Keyword arguments into ``self.plot``. + + Returns + ------- + plots : list + Matplotlib (fig, ax) tuples for the 1D plot of each field line. + + """ + plots = [] + z1, z2, k = self._check_shape(z1, z2, k) + mask = (z1 - z2) != 0.0 + z1 = jnp.where(mask, z1, jnp.nan) + z2 = jnp.where(mask, z2, jnp.nan) + + err_1 = jnp.any(z1 > z2, axis=-1) + err_2 = jnp.any(z1[..., 1:] < z2[..., :-1], axis=-1) + f_midpoint = self.eval1d((z1 + z2) / 2) + eps = kwargs.pop("eps", jnp.finfo(jnp.array(1.0).dtype).eps * 10) + err_3 = jnp.any(f_midpoint > k + eps, axis=-1) + if not (plot or jnp.any(err_1 | err_2 | err_3)): + return plots + + # Ensure l axis exists for iteration in below loop. + cheb = atleast_nd(3, self.cheb) + mask, z1, z2, f_midpoint = map(atleast_3d_mid, (mask, z1, z2, f_midpoint)) + err_1, err_2, err_3 = map(atleast_2d_end, (err_1, err_2, err_3)) + + for l in np.ndindex(cheb.shape[:-2]): + for p in range(k.shape[0]): + idx = (p, *l) + if not (err_1[idx] or err_2[idx] or err_3[idx]): + continue + _z1 = z1[idx][mask[idx]] + _z2 = z2[idx][mask[idx]] + if plot: + self.plot1d( + cheb=cheb[l], + z1=_z1, + z2=_z2, + k=k[idx], + title=kwargs.pop( + "title", r"Intersects $z$ in epigraph($f$) s.t. $f(z) = k$" + ) + + f", (p,l)={idx}", + **kwargs, + ) + print(" z1 | z2") + print(jnp.column_stack([_z1, _z2])) + assert not err_1[idx], "Intersects have an inversion.\n" + assert not err_2[idx], "Detected discontinuity.\n" + assert not err_3[idx], ( + f"Detected f = {f_midpoint[idx][mask[idx]]} > {k[idx] + _eps} = k" + "in well, implying the straight line path between z1 and z2 is in" + "hypograph(f). Increase spectral resolution.\n" + ) + idx = (slice(None), *l) + if plot: + plots.append( + self.plot1d( + cheb=cheb[l], + z1=z1[idx], + z2=z2[idx], + k=k[idx], + **kwargs, + ) + ) + return plots + + def plot1d( + self, + cheb, + num=5000, + z1=None, + z2=None, + k=None, + k_transparency=0.5, + klabel=r"$k$", + title=r"Intersects $z$ in epigraph($f$) s.t. $f(z) = k$", + hlabel=r"$z$", + vlabel=r"$f$", + show=True, + include_legend=True, + ): + """Plot the piecewise Chebyshev series ``cheb``. + + Parameters + ---------- + cheb : jnp.ndarray + Shape (M, N). + Piecewise Chebyshev series f. + num : int + Number of points to evaluate ``cheb`` for plot. + z1 : jnp.ndarray + Shape (k.shape[0], W). + Optional, intersects with ∂f/∂y <= 0. + z2 : jnp.ndarray + Shape (k.shape[0], W). + Optional, intersects with ∂f/∂y >= 0. + k : jnp.ndarray + Shape (k.shape[0], ). + Optional, k such that fₓ(yᵢ) = k. + k_transparency : float + Transparency of pitch lines. + klabel : float + Label of intersect lines. + title : str + Plot title. + hlabel : str + Horizontal axis label. + vlabel : str + Vertical axis label. + show : bool + Whether to show the plot. Default is true. + include_legend : bool + Whether to include the legend in the plot. Default is true. + + Returns + ------- + fig, ax + Matplotlib (fig, ax) tuple. + + """ + fig, ax = plt.subplots() + legend = {} + z = jnp.linspace( + start=self.domain[0], + stop=self.domain[0] + (self.domain[1] - self.domain[0]) * self.M, + num=num, + ) + _add2legend(legend, ax.plot(z, self.eval1d(z, cheb), label=vlabel)) + _plot_intersect( + ax=ax, + legend=legend, + z1=z1, + z2=z2, + k=k, + k_transparency=k_transparency, + klabel=klabel, + ) + ax.set_xlabel(hlabel) + ax.set_ylabel(vlabel) + if include_legend: + ax.legend(legend.values(), legend.keys(), loc="lower right") + ax.set_title(title) + plt.tight_layout() + if show: + plt.show() + plt.close() + return fig, ax + + def _add2legend(legend, lines): """Add lines to legend if it's not already in it.""" for line in setdefault(lines, [lines], hasattr(lines, "__iter__")): diff --git a/desc/integrals/bounce_integral.py b/desc/integrals/bounce_integral.py index 577c600e6..aea90fe00 100644 --- a/desc/integrals/bounce_integral.py +++ b/desc/integrals/bounce_integral.py @@ -3,24 +3,931 @@ from interpax import CubicHermiteSpline, PPoly from orthax.legendre import leggauss -from desc.backend import jnp +from desc.backend import dct, jnp, rfft2 +from desc.integrals.basis import FourierChebyshevSeries, PiecewiseChebyshevSeries from desc.integrals.bounce_utils import ( _bounce_quadrature, _check_bounce_points, + _check_interp, _set_default_plot_kwargs, bounce_points, + get_alpha, get_pitch_inv_quad, interp_to_argmin, plot_ppoly, ) -from desc.integrals.interp_utils import polyder_vec +from desc.integrals.interp_utils import ( + cheb_from_dct, + cheb_pts, + idct_non_uniform, + interp_rfft2, + irfft2_non_uniform, + polyder_vec, +) from desc.integrals.quad_utils import ( automorphism_sin, + bijection_from_disc, get_quadrature, grad_automorphism_sin, + grad_bijection_from_disc, ) from desc.io import IOAble -from desc.utils import errorif, setdefault, warnif +from desc.utils import ( + atleast_nd, + check_posint, + errorif, + flatten_matrix, + setdefault, + warnif, +) + + +def _transform_to_desc(grid, f, is_reshaped=False): + """Transform to DESC spectral domain. + + Parameters + ---------- + grid : Grid + Tensor-product grid in (θ, ζ) with uniformly spaced nodes [0, 2π) × [0, 2π/NFP). + Preferably power of 2 for ``grid.num_theta`` and ``grid.num_zeta``. + f : jnp.ndarray + Function evaluated on ``grid``. + + Returns + ------- + a : jnp.ndarray + Shape (..., grid.num_theta // 2 + 1, grid.num_zeta) + Complex coefficients of 2D real FFT of ``f``. + + """ + if not is_reshaped: + f = grid.meshgrid_reshape(f, "rtz") + # real fft over poloidal since usually m > n + return rfft2(f, axes=(-1, -2), norm="forward") + + +def _transform_to_clebsch(grid, nodes, f, is_reshaped=False): + """Transform to Clebsch spectral domain. + + Parameters + ---------- + grid : Grid + Tensor-product grid in (θ, ζ) with uniformly spaced nodes [0, 2π) × [0, 2π/NFP). + Preferably power of 2 for ``grid.num_theta`` and ``grid.num_zeta``. + nodes : jnp.ndarray + Shape (L, M, N, 2) or (M, N, 2). + DESC coordinates (θ, ζ) sourced from the Clebsch coordinates + ``FourierChebyshevSeries.nodes(M,N,domain=(0,2*jnp.pi))``. + f : jnp.ndarray + Function evaluated on ``grid``. + + Returns + ------- + a : FourierChebyshevSeries + Spectral coefficients of f(α, ζ). + + """ + assert nodes.shape[-1] == 2 + if not is_reshaped: + f = grid.meshgrid_reshape(f, "rtz") + + M, N = nodes.shape[-3], nodes.shape[-2] + nodes = nodes.reshape(*nodes.shape[:-3], M * N, 2) + return FourierChebyshevSeries( + f=interp_rfft2( + # Interpolate to nodes in Clebsch space, + # which is not a tensor product node set in DESC space. + xq0=nodes[..., 0], + xq1=nodes[..., 1], + f=f[..., jnp.newaxis, :, :], + domain1=(0, 2 * jnp.pi / grid.NFP), + axes=(-1, -2), + ).reshape(*nodes.shape[:-2], M, N), + domain=(0, 2 * jnp.pi), + ) + + +def _transform_to_clebsch_1d(grid, alpha, theta, B, N_B, is_reshaped=False): + """Transform to single variable Clebsch spectral domain. + + Notes + ----- + The field line label α changes discontinuously, so the approximation + g defined with basis function in (α, ζ) coordinates to some continuous + function f does not guarantee continuity between cuts of the field line + until full convergence of g to f. + + Note if g were defined with basis functions in straight field line + coordinates, then continuity between cuts of the field line, as + determined by the straight field line coordinates (ϑ, ζ), is + guaranteed even with incomplete convergence (because the + parameters (ϑ, ζ) change continuously along the field line). + + Do not interpret this as superior function approximation. + Indeed, if g is defined with basis functions in (α, ζ) coordinates, then + g(α=α₀, ζ) will sample the approximation to f(α=α₀, ζ) for the full domain in ζ. + This holds even with incomplete convergence of g to f. + However, if g is defined with basis functions in (ϑ, ζ) coordinates, then + g(ϑ(α=α₀,ζ), ζ) will sample the approximation to f(α=α₀ ± ε, ζ) with ε → 0 as + g converges to f. + + (Visually, the small discontinuity apparent in g(α, ζ) at cuts of the field + line will not be visible in g(ϑ, ζ) because when moving along the field line + with g(ϑ, ζ) one is continuously flowing away from the starting field line, + (whereas g(α, ζ) has to "decide" at the cut what the next field line is). + (If full convergence is difficult to achieve, then in the context of surface + averaging bounce integrals, function approximation in (α, ζ) coordinates + might be preferable because most of the bounce integrals do not stretch + across toroidal transits).) + + Now, it appears the Fourier transform of θ may have small oscillatory bumps + outside reasonable bandwidths. This impedes full convergence of any + approximation, and in particular the poloidal Fourier series for, θ(α, ζ=ζ₀). + Maybe this is because the Chebyshev interpolation is detecting root-finding + errors where the nodes are more densely clustered. (Note the Fourier series + converges fast for |B|, even in non-omnigenous configurations where + (∂|B|/∂α)|ρ,ζ is not small, so this is indeed some feature with θ). + + Therefore, we explicitly enforce continuity of our approximation of θ between + cuts to short-circuit the convergence. This works to remove the small + discontinuity between cuts of the field line because the first cut is on α=0, + which is a knot of the Fourier series, and the Chebyshev points include a knot + near endpoints, so θ at the next cut of the field line is known with precision. + + Parameters + ---------- + grid : Grid + Tensor-product grid in (θ, ζ) with uniformly spaced nodes [0, 2π) × [0, 2π/NFP). + Preferably power of 2 for ``grid.num_theta`` and ``grid.num_zeta``. + alpha : jnp.ndarray + Shape (L, num_transit) or (num_transit, ). + Sequence of poloidal coordinates A = (α₀, α₁, …, αₘ₋₁) that specify field line. + theta : jnp.ndarray + Shape (L, M, N) or (M, N). + DESC coordinates θ sourced from the Clebsch coordinates + ``FourierChebyshevSeries.nodes(M,N,domain=(0,2*jnp.pi))``. + B : jnp.ndarray + |B| evaluated on ``grid``. + N_B : int + Desired Chebyshev spectral resolution for |B|. Preferably power of 2. + + Returns + ------- + T, B : tuple[PiecewiseChebyshevSeries] + Set of 1D Chebyshev spectral coefficients of θ along field line. + {θ_α : ζ ↦ θ(α, ζ) | α ∈ A} where A = (α₀, α₁, …, αₘ₋₁) is ``alpha``. + Likewise with |B|. + + """ + if not is_reshaped: + B = grid.meshgrid_reshape(B, "rtz") + + # Evaluating set of single variable maps is more efficient than evaluating + # multivariable map, so we project θ to a set of Chebyshev series. + T = FourierChebyshevSeries(f=theta, domain=(0, 2 * jnp.pi)).compute_cheb(alpha) + T.stitch() + theta = T.evaluate(N_B) + zeta = jnp.broadcast_to(cheb_pts(N_B, domain=T.domain), theta.shape) + + shape = (*alpha.shape[:-1], alpha.shape[-1] * N_B) + B = interp_rfft2( + theta.reshape(shape), + zeta.reshape(shape), + f=B[..., jnp.newaxis, :, :], + domain1=(0, 2 * jnp.pi / grid.NFP), + axes=(-1, -2), + ).reshape(*alpha.shape, N_B) + # Parameterize |B| by single variable to compute roots. + B = PiecewiseChebyshevSeries(cheb_from_dct(dct(B, type=2, axis=-1)) / N_B, T.domain) + # |B| guaranteed to be continuous because it was interpolated from B(θ(α, ζ),ζ). + return T, B + + +def _swap_pl(f): + # Given shape (L, num_pitch, -1) or (num_pitch, L, -1) or (num_pitch, -1) + # swap L and num_pitch axes. + assert f.ndim <= 3 + return jnp.swapaxes(f, 0, -2) + + +# TODO: After GitHub issue #1034 is resolved, we should pass in the previous +# θ(α, ζ) coordinates as an initial guess for the next coordinate mapping. +# Perhaps tell the optimizer to perturb the coefficients of the +# θ(α, ζ) directly? think this is equivalent to perturbing lambda. + + +class Bounce2D(IOAble): + """Computes bounce integrals using two-dimensional pseudo-spectral methods. + + The bounce integral is defined as ∫ f(λ, ℓ) dℓ, where + dℓ parameterizes the distance along the field line in meters, + f(λ, ℓ) is the quantity to integrate along the field line, + and the boundaries of the integral are bounce points ℓ₁, ℓ₂ s.t. λ|B|(ℓᵢ) = 1, + where λ is a constant defining the integral proportional to the magnetic moment + over energy and |B| is the norm of the magnetic field. + + For a particle with fixed λ, bounce points are defined to be the location on the + field line such that the particle's velocity parallel to the magnetic field is zero. + The bounce integral is defined up to a sign. We choose the sign that corresponds to + the particle's guiding center trajectory traveling in the direction of increasing + field-line-following coordinate ζ. + + Notes + ----- + Brief description of algorithm. + + Magnetic field line with label α, defined by B = ∇ρ × ∇α, is determined from + α : ρ, θ, ζ ↦ θ + λ(ρ,θ,ζ) − ι(ρ) [ζ + ω(ρ,θ,ζ)] + Interpolate Fourier-Chebyshev series to DESC poloidal coordinate. + θ : α, ζ ↦ tₘₙ exp(jmα) Tₙ(ζ) + Compute |B| along field lines. + |B| : α, ζ ↦ bₙ(θ(α, ζ)) Tₙ(ζ) + Compute bounce points. + r(ζₖ) = |B|(ζₖ) − 1/λ = 0 + Interpolate smooth components of integrand with FFTs. + G : α, ζ ↦ gₘₙ exp(j [m θ(α,ζ) + n ζ] ) + Perform Gaussian quadrature after removing singularities. + Fᵢ : λ, ζ₁, ζ₂ ↦ ∫ᵢ f(λ, ζ, {Gⱼ}) dζ + + If the map G is multivalued at a physical location, then it is still + permissible if separable into a single valued and multivalued parts. + In that case, supply the single valued parts, which will be interpolated + with FFTs, and use the provided coordinates θ,ζ ∈ ℝ to compose G. + + Longer description for developers. + + For applications which reduce to computing a nonlinear function of distance + along field lines between bounce points, it is required to identify these + points with field-line-following coordinates. (In the special case of a linear + function summing integrals between bounce points over a flux surface, arbitrary + coordinate systems may be used as that task reduces to a surface integral, + which is invariant to the order of summation). + + The DESC coordinate system is related to field-line-following coordinate + systems by a relation whose solution is best found with Newton iteration + since this solution is unique. Newton iteration is not a globally + convergent algorithm to find the real roots of r : ζ ↦ |B|(ζ) − 1/λ where + ζ is a field-line-following coordinate. For this, function approximation + of |B| is necessary. + + Therefore, to compute bounce points {(ζ₁, ζ₂)}, we approximate |B| by a + series expansion of basis functions parameterized by a single variable ζ, + restricting the class of basis functions to low order (e.g. N = 2ᵏ where + k is small) algebraic or trigonometric polynomial with integer frequencies. + These are the two classes useful for function approximation and for which + there exists globally convergent root-finding algorithms. We require low + order because the computation expenses grow with the number of potential + roots, and the theorem of algebra states that number is N (2N) for algebraic + (trigonometric) polynomials of degree N. + + The frequency transform of a map under the chosen basis must be concentrated + at low frequencies for the series to converge fast. For periodic + (non-periodic) maps, the best basis is a Fourier (Chebyshev) series. Both + converge exponentially, but the larger region of convergence in the complex + plane of Fourier series make it preferable in practice to choose coordinate + systems such that the function to approximate is periodic. The Chebyshev + polynomials are preferred to other orthogonal polynomial series since + fast discrete polynomial transforms (DPT) are implemented via fast transform + to Chebyshev then DCT. Although nothing prohibits a direct DPT, we want to + rely on existing libraries. Therefore, a Fourier-Chebyshev series is chosen + to interpolate θ(α,ζ), and a piecewise Chebyshev series interpolates |B|(ζ). + + Computing accurate series expansions in (α, ζ) coordinates demands + particular interpolation points in that coordinate system. Newton iteration + is used to compute θ at these points. Note that interpolation is necessary + because there is no transformation that converts series coefficients in + periodic coordinates, e.g. (ϑ, ϕ), to a low order polynomial basis in + non-periodic coordinates. For example, one can obtain series coefficients in + (α, ϕ) coordinates from those in (ϑ, ϕ) as follows + g : ϑ, ϕ ↦ ∑ₘₙ aₘₙ exp(j [mϑ + nϕ]) + + g : α, ϕ ↦ ∑ₘₙ aₘₙ exp(j [mα + (m ι + n)ϕ]) + However, the basis for the latter are trigonometric functions with + irrational frequencies, courtesy of the irrational rotational transform. + Globally convergent root-finding schemes for that basis (at fixed α) are + not known. The denominator of a close rational could be absorbed into the + coordinate ϕ, but this balloons the frequency, and hence the degree of the + series. + + Recall that periodicity enables faster convergence, motivating the desire + to instead interpolate |B|(ϑ, ϕ) with a double Fourier series and applying + bisection methods to find roots with mesh size inversely + proportional to the max frequency along the field line: M ι + N. ``Bounce2D`` + does not use that approach as that root-finding scheme is inferior. + The reason θ is not interpolated with a double Fourier series θ(ϑ, ζ) is + because quadrature points along |B|(α=α₀, ζ) can be identified by a single + variable; evaluating the multivariable map θ(ϑ(α, ζ), ζ) is expensive + compared to evaluating the single variable map θ(α=α₀, ζ). Also, the advantage + of DESC coordinates is that they use the spectrally condensed variable + ζ* = NFP ζ. This cannot be done in any other coordinate system, regardless of + whether it is periodic or not, so (ϑ, ϕ) coordinates are no better than (α, ζ) + coordinates in this aspect. (Another option is to use a filtered Fourier + series, doi.org/10.1016/j.aml.2006.10.001). + + After computing the bounce points, the supplied quadrature is performed. + By default, this is a Gauss quadrature after removing the singularity. + Fast fourier transforms interpolate smooth functions in the integrand to the + quadrature nodes. Quadrature is chosen over Runge-Kutta methods of the form + ∂Fᵢ/∂ζ = f(λ,ζ,{Gⱼ}) subject to Fᵢ(ζ₁) = 0 + because a fourth order Runge-Kutta method is equivalent to a quadrature + with Simpson's rule. Our quadratures resolve these integrals more + efficiently, and the fixed nature of quadrature performs better on GPUs. + + Fast transforms are used where possible. Fast multipoint methods are not + implemented. For non-uniform interpolation, MMTs are used. It will be + worthwhile to use the inverse non-uniform fast transforms. + + Additional notes on multivalued coordinates. + The definition of α in B = ∇ρ × ∇α on an irrational magnetic surface + implies the angle θ(α, ζ) is multivalued at a physical location. + In particular, following an irrational field, the single-valued θ grows + to ∞ (always non-monotonically) as ζ → ∞. Therefore, it is impossible to + approximate this map using single-valued basis functions defined on a + bounded subset of ℝ² (recall continuous functions on compact sets attain + their maximum). + + Still, it suffices to interpolate θ over one branch cut. + DESC chooses the branch cut defined by (α, ζ) ∈ [0, 2π]² and we must + maintain that convention with our basis functions. On such a branch + cut, the bound θ ∈ [0, 4π] holds. + + Likewise, α is multivalued. As the field line is followed, the label + jumps to α ∉ [0, 2π] after completing some toroidal transit. Therefore, + the map θ(α, ζ) must be periodic in α with period 2π. At every point + ζₚ ∈ [2π k, 2π ℓ] where k, ℓ ∈ ℤ where the field line completes a + poloidal transit there is guaranteed to exist a discrete jump + discontinuity in θ at ζ = 2π ℓ(p), starting the toroidal transit. + Recall a jump discontinuity appears as an infinitely sharp cut; + nearby the cut, the function must be blind to the cut. + + To recover the single-valued θ(α, ζ) from the function approximation + over one branch cut, at every ζ = 2π ℓ we can add either 0 or 2π or + 4π to the next cut of θ. + + See Also + -------- + Bounce1D + Uses one-dimensional local spline methods for the same task. + + Below are some advantages of ``Bounce2D`` over ``Bounce1D``. + The coordinates on which the root-finding must be done to map from DESC + to Clebsch coords is fixed to ``L*M*N``, independent of the number of + toroidal transits; generating the same data with DESC for input to + ``Bounce1D`` requires ``L*M*N*num_transit``. Furthermore, pseudo-spectral + interpolation of smooth functions such as |B| on each flux surface is more + efficient. This reduces the number of bounce integrals to be done by an + order of magnitude. Also, we have noticed C1 cubic spline interpolation + is inefficient to reconstruct smooth local maxima of |B|, which might be + important for (strongly) singular bounce integrals whose estimation + depends on ∂|B|/∂ζ there. + + Attributes + ---------- + required_names : list + Names in ``data_index`` required to compute bounce integrals. + + """ + + required_names = ["B^zeta", "|B|", "iota"] + get_pitch_inv_quad = staticmethod(get_pitch_inv_quad) + + def __init__( + self, + grid, + data, + iota, + theta, + N_B=None, + num_transit=16, + # TODO: Allow multiple starting labels for near-rational surfaces. + # think can just concatenate along second to last axis of cheb. + # Do this in different PR. + alpha=0.0, + quad=leggauss(32), + automorphism=(automorphism_sin, grad_automorphism_sin), + *, + Bref=1.0, + Lref=1.0, + is_reshaped=False, + check=False, + **kwargs, + ): + """Returns an object to compute bounce integrals. + + Notes + ----- + Performance may improve significantly if the spectral + resolutions ``m``, ``n``, ``M``, ``N``, and ``N_B`` are powers of two. + + Parameters + ---------- + grid : Grid + Tensor-product grid in (ρ, θ, ζ) with uniformly spaced nodes + [0, 2π) × [0, 2π/NFP). Note that below shape notation defines + L = ``grid.num_rho``, m = ``grid.num_theta``, and n = ``grid.num_zeta``. + data : dict[str, jnp.ndarray] + Data evaluated on ``grid``. + Must include names in ``Bounce2D.required_names``. + iota : jnp.ndarray + Shape (L, ). + Rotational transform. + theta : jnp.ndarray + Shape (L, M, N). + DESC coordinates θ sourced from the Clebsch coordinates + ``FourierChebyshevSeries.nodes(M,N,L,domain=(0,2*jnp.pi))``. + N_B : int + Desired Chebyshev spectral resolution for |B|. + Default is to double the resolution of ``theta``. + alpha : float + Starting field line poloidal label. + num_transit : int + Number of toroidal transits to follow field line. + quad : tuple[jnp.ndarray] + Quadrature points xₖ and weights wₖ for the approximate evaluation of an + integral ∫₋₁¹ g(x) dx = ∑ₖ wₖ g(xₖ). Default is 32 points. + For weak singular integrals, use ``chebgauss2`` from + ``desc.integrals.quad_utils``. + For strong singular integrals, use ``leggauss``. + automorphism : tuple[Callable] or None + The first callable should be an automorphism of the real interval [-1, 1]. + The second callable should be the derivative of the first. This map defines + a change of variable for the bounce integral. The choice made for the + automorphism will affect the performance of the quadrature method. + For weak singular integrals, use ``None``. + For strong singular integrals, use + ``(automorphism_sin,grad_automorphism_sin)`` from + ``desc.integrals.quad_utils``. + Bref : float + Optional. Reference magnetic field strength for normalization. + Lref : float + Optional. Reference length scale for normalization. + is_reshaped : bool + Whether the arrays in ``data`` are already reshaped to the expected form of + shape (..., m, n) or (L, m, n). This option can be used to iteratively + compute bounce integrals one flux surface at a time, reducing memory usage + To do so, set to true and provide only those axes of the reshaped data. + Default is false. + check : bool + Flag for debugging. Must be false for JAX transformations. + + """ + errorif(grid.sym, NotImplementedError, msg="Need grid that works with FFTs.") + # Strictly increasing zeta knots enforces dζ > 0. + # To retain dℓ = (|B|/B^ζ) dζ > 0 after fixing dζ > 0, we require + # B^ζ = B⋅∇ζ > 0. This is equivalent to changing the sign of ∇ζ or [∂ℓ/∂ζ]|ρ,a. + # Recall dζ = ∇ζ⋅dR, implying 1 = ∇ζ⋅(e_ζ|ρ,a). Hence, a sign change in ∇ζ + # requires the same sign change in e_ζ|ρ,a to retain the metric identity. + warnif( + check and kwargs.pop("warn", True) and jnp.any(data["B^zeta"] <= 0), + msg="(∂ℓ/∂ζ)|ρ,a > 0 is required. Enforcing positive B^ζ.", + ) + N_B = setdefault(N_B, theta.shape[-1] * 2) + self._alpha = alpha + self._m = grid.num_theta + self._n = grid.num_zeta + self._NFP = grid.NFP + self._x, self._w = get_quadrature(quad, automorphism) + + # peel off field lines + alpha = get_alpha(alpha, iota, num_transit, 2 * jnp.pi) + # Compute spectral coefficients. + self._T, self._B = _transform_to_clebsch_1d( + grid, alpha, theta, data["|B|"] / Bref, N_B, is_reshaped + ) + self._B_sup_z = _transform_to_desc( + grid, jnp.abs(data["B^zeta"]) * Lref / Bref, is_reshaped + ) + assert self._T.M == self._B.M == num_transit + assert self._T.N == theta.shape[-1] + assert self._B.N == N_B + + @staticmethod + def compute_theta(eq, M=16, N=32, rho=1.0, clebsch=None, **kwargs): + """Return DESC coordinates θ of (α,ζ) Fourier Chebyshev basis nodes. + + Parameters + ---------- + eq : Equilibrium + Equilibrium to use defining the coordinate mapping. + M : int + Grid resolution in poloidal direction for Clebsch coordinate grid. + Preferably power of 2. + N : int + Grid resolution in toroidal direction for Clebsch coordinate grid. + Preferably power of 2. + rho : float or jnp.ndarray + Flux surfaces labels in [0, 1] on which to compute. + clebsch : jnp.ndarray + Optional, Clebsch coordinate tensor-product grid (ρ, α, ζ). + ``FourierChebyshevSeries.nodes(M,N,L,domain=(0,2*jnp.pi))``. + If given, ``rho`` is ignored. + kwargs + Additional parameters to supply to the coordinate mapping function. + See ``desc.equilibrium.Equilibrium.map_coordinates``. + + Returns + ------- + theta : jnp.ndarray + Shape (L, M, N). + DESC coordinates θ sourced from the Clebsch coordinates + ``FourierChebyshevSeries.nodes(M,N,L,domain=(0,2*jnp.pi))``. + + """ + if clebsch is None: + clebsch = FourierChebyshevSeries.nodes( + check_posint(M), + check_posint(N), + rho, + domain=(0, 2 * jnp.pi), + ) + return eq.map_coordinates( + coords=clebsch, + inbasis=("rho", "alpha", "zeta"), + period=(jnp.inf, jnp.inf, jnp.inf), + **kwargs, + ).reshape(-1, M, N, 3)[..., 1] + + @staticmethod + def reshape_data(grid, *arys): + """Reshape ``data`` arrays for acceptable input to ``integrate``. + + Parameters + ---------- + grid : Grid + Tensor-product grid in (ρ, θ, ζ). + arys : jnp.ndarray + Data evaluated on grid. + + Returns + ------- + f : jnp.ndarray + Shape (L, M, N). + Reshaped data which may be given to ``integrate``. + + """ + f = [grid.meshgrid_reshape(d, "rtz") for d in arys] + return f if len(f) > 1 else f[0] + + def points(self, pitch_inv, *, num_well=None): + """Compute bounce points. + + Parameters + ---------- + pitch_inv : jnp.ndarray + Shape (L, num_pitch). + 1/λ values to compute the bounce integrals. 1/λ(ρ) is specified by + ``pitch_inv[ρ]`` where in the latter the labels are interpreted + as the indices that correspond to that field line. + num_well : int or None + Specify to return the first ``num_well`` pairs of bounce points for each + pitch along each field line. This is useful if ``num_well`` tightly + bounds the actual number. As a reference, there are typically 20 wells + per toroidal transit for a given pitch. You can check this by plotting + the field lines with the ``check_points`` method. + + If not specified, then all bounce points are returned. If there were fewer + wells detected along a field line than the size of the last axis of the + returned arrays, then that axis is padded with zero. + + Returns + ------- + z1, z2 : tuple[jnp.ndarray] + Shape (L, num_pitch, num_well). + Tuple of length two (z1, z2) that stores ζ coordinates of bounce points. + The points are ordered and grouped such that the straight line path + between ``z1`` and ``z2`` resides in the epigraph of |B|. + + If there were less than ``num_well`` wells detected along a field line, + then the last axis, which enumerates bounce points for a particular field + line and pitch, is padded with zero. + + """ + pitch_inv = atleast_nd(self._B.cheb.ndim - 1, pitch_inv).T + # Expects pitch_inv shape (num_pitch, L) if B.cheb.shape[0] is L. + z1, z2 = map(_swap_pl, self._B.intersect1d(pitch_inv, num_intersect=num_well)) + return z1, z2 + + def check_points(self, points, pitch_inv, *, plot=True, **kwargs): + """Check that bounce points are computed correctly. + + Parameters + ---------- + points : tuple[jnp.ndarray] + Shape (L, num_pitch, num_well). + Output of method ``self.points``. + Tuple of length two (z1, z2) that stores ζ coordinates of bounce points. + The points are ordered and grouped such that the straight line path + between ``z1`` and ``z2`` resides in the epigraph of |B|. + pitch_inv : jnp.ndarray + Shape (L, num_pitch). + 1/λ values to compute the bounce integrals. 1/λ(ρ) is specified by + ``pitch_inv[ρ]`` where in the latter the labels are interpreted + as the indices that correspond to that field line. + plot : bool + Whether to plot the field lines and bounce points of the given pitch angles. + kwargs : dict + Keyword arguments into + ``desc/integrals/basis.py::PiecewiseChebyshevSeries.plot1d``. + + Returns + ------- + plots : list + Matplotlib (fig, ax) tuples for the 1D plot of each field line. + + """ + kwargs.setdefault("hlabel", r"$\alpha = $" + str(self._alpha) + r", $\zeta$") + return self._B.check_intersect1d( + z1=_swap_pl(points[0]), + z2=_swap_pl(points[1]), + k=atleast_nd(self._B.cheb.ndim - 1, pitch_inv).T, + plot=plot, + **_set_default_plot_kwargs(kwargs), + ) + + def integrate( + self, + integrand, + pitch_inv, + f=None, + f_vec=None, + weight=None, + points=None, + *, + check=False, + plot=False, + ): + """Bounce integrate ∫ f(λ, ℓ) dℓ. + + Computes the bounce integral ∫ f(λ, ℓ) dℓ for every field line and pitch. + + Parameters + ---------- + integrand : callable + The composition operator on the set of functions in ``f`` that maps the + functions in ``f`` to the integrand f(λ, ℓ) in ∫ f(λ, ℓ) dℓ. It should + accept the arrays in ``f`` as arguments as well as the additional keyword + arguments: ``B``, ``pitch``, and ``zeta``. A quadrature will be performed + to approximate the bounce integral of + ``integrand(*f,B=B,pitch=pitch,zeta=zeta)``. + pitch_inv : jnp.ndarray + Shape (L, num_pitch). + 1/λ values to compute the bounce integrals. 1/λ(ρ) is specified by + ``pitch_inv[ρ]`` where in the latter the labels are interpreted + as the indices that correspond to that field line. + f : list[jnp.ndarray] or jnp.ndarray + Shape (L, m, n). + Real scalar-valued (2π × 2π/NFP) periodic in (θ, ζ) functions evaluated + on the ``grid`` supplied to construct this object. These functions + should be arguments to the callable ``integrand``. Use the method + ``Bounce2D.reshape_data`` to reshape the data into the expected shape. + f_vec : list[jnp.ndarray] or jnp.ndarray + Shape (L, m, n, 3). + Real vector-valued (2π × 2π/NFP) periodic in (θ, ζ) functions evaluated + on the ``grid`` supplied to construct this object. These functions + should be arguments to the callable ``integrand``. Use the method + ``Bounce2D.reshape_data`` to reshape the data into the expected shape. + weight : jnp.ndarray + Shape (L, m, n). + If supplied, the bounce integral labeled by well j is weighted such that + the returned value is w(j) ∫ f(λ, ℓ) dℓ, where w(j) is ``weight`` + interpolated to the deepest point in that magnetic well. Use the method + ``Bounce2D.reshape_data`` to reshape the data into the expected shape. + points : tuple[jnp.ndarray] + Shape (L, num_pitch, num_well). + Optional, output of method ``self.points``. + Tuple of length two (z1, z2) that stores ζ coordinates of bounce points. + The points are ordered and grouped such that the straight line path + between ``z1`` and ``z2`` resides in the epigraph of |B|. + check : bool + Flag for debugging. Must be false for JAX transformations. + plot : bool + Whether to plot the quantities in the integrand interpolated to the + quadrature points of each integral. Ignored if ``check`` is false. + + Returns + ------- + result : jnp.ndarray + Shape (M, L, num_pitch, num_well). + Last axis enumerates the bounce integrals for a given field line, + flux surface, and pitch value. + + """ + errorif(weight is not None, NotImplementedError, msg="See Bounce1D") + f = setdefault(f, []) + f_vec = setdefault(f_vec, []) + if not isinstance(f, (list, tuple)): + f = [f] + if not isinstance(f_vec, (list, tuple)): + f_vec = [f_vec] + + points = map(_swap_pl, points) + pitch_inv = atleast_nd(self._B.cheb.ndim - 1, pitch_inv).T + result = self._integrate(integrand, points, pitch_inv, f, f_vec, check, plot) + return result + + def _integrate(self, integrand, points, pitch_inv, f, f_vec, check, plot): + """Bounce integrate ∫ f(λ, ℓ) dℓ. + + Parameters + ---------- + points : jnp.ndarray + Shape (num_pitch, num_well) or (num_pitch, L, num_well). + pitch_inv : jnp.ndarray + Shape (num_pitch, ) or (num_pitch, L). + f : list[jnp.ndarray] + Shape (m, n) or (L, m, n). + f_vec : list[jnp.ndarray] + Shape (m, n, 3) or (L, m, n, 3). + + """ + z1, z2 = points + shape = [*z1.shape, self._x.size] + + # These are the ζ coordinates of the quadrature points. + # Shape is (num_pitch, L, number of points to interpolate onto). + zeta = flatten_matrix( + bijection_from_disc(self._x, z1[..., jnp.newaxis], z2[..., jnp.newaxis]) + ) + # Note self._T expects shape (num_pitch, L) if T.cheb.shape[0] is L. + # These are the θ coordinates of the quadrature points. + theta = self._T.eval1d(zeta) + + B_sup_z = irfft2_non_uniform( + theta, + zeta, + a=self._B_sup_z[..., jnp.newaxis, :, :], + M=self._n, + N=self._m, + domain1=(0, 2 * jnp.pi / self._NFP), + axes=(-1, -2), + ) + B = self._B.eval1d(zeta) + f = [ + interp_rfft2( + theta, + zeta, + f_i[..., jnp.newaxis, :, :], + domain1=(0, 2 * jnp.pi / self._NFP), + axes=(-1, -2), + ) + for f_i in f + ] + f_vec = [ + interp_rfft2( + theta[..., jnp.newaxis], + zeta[..., jnp.newaxis], + f_i[..., jnp.newaxis, :, :, :], + domain1=(0, 2 * jnp.pi / self._NFP), + axes=(-2, -3), + ) + for f_i in f_vec + ] + result = _swap_pl( + ( + integrand( + *f, + *f_vec, + B=B, + pitch=1 / pitch_inv[..., jnp.newaxis], + zeta=zeta, + ) + * B + / B_sup_z + ) + .reshape(shape) + .dot(self._w) + * grad_bijection_from_disc(z1, z2) + ) + + if check: + shape[-3], shape[0] = shape[0], shape[-3] + _check_interp( + # num_alpha is 1, num_rho, num_pitch, num_well, num_quad + (1, *shape), + *map(_swap_pl, (zeta, B_sup_z, B)), + result, + list(map(_swap_pl, f)), + plot, + ) + return result + + def compute_length(self, quad=None): + """Compute the proper length of the field line ∫ dℓ / |B|. + + Parameters + ---------- + quad : tuple[jnp.ndarray] + Quadrature points xₖ and weights wₖ for the approximate evaluation + of the integral ∫₋₁¹ f(x) dx ≈ ∑ₖ wₖ f(xₖ). + Should not use more points than half Chebyshev resolution of |B|. + + Returns + ------- + length : jnp.ndarray + Shape (L, ). + + """ + # Integrating an analytic map, so a fixed high order quadrature is ideal. + # The integration domain is not periodic, so the best candidates to choose + # are Gauss-Legendre quadrature and Clenshaw-Curtis. GL is more efficient + # than CC for analytic maps by a factor of 2. Advantage of CC is that θ at + # the quadrature points can be computed with fast cosine transform. However, + # one still needs to perform a non-uniform inverse fourier transform of B^ζ + # at those points, which will be the dominating expense. The spectral width + # of θ along field lines is also narrower than B^ζ, especially at high NFP. + x, w = leggauss(self._B.N // 2) if quad is None else quad + + # TODO: Use fast Chebyshev to Legendre inverse transform. + # When converted to a Legendre series, θ at the quadrature points can + # be computed with a fast transform (without needing to compute leggauss nodes). + # This is likely preferable to even a true non-uniform transform. + theta = idct_non_uniform(x, self._T.cheb[..., jnp.newaxis, :], self._T.N) + zeta = jnp.broadcast_to(bijection_from_disc(x, 0, 2 * jnp.pi), theta.shape) + + shape = (-1, self._T.M * w.size) # (num_rho, num transit * w.size) + B_sup_z = irfft2_non_uniform( + theta.reshape(shape), + zeta.reshape(shape), + a=self._B_sup_z[..., jnp.newaxis, :, :], + M=self._n, + N=self._m, + domain1=(0, 2 * jnp.pi / self._NFP), + axes=(-1, -2), + ).reshape(*self._T.cheb.shape[:-2], self._T.M, w.size) + + # Gradient of change of variable bijection from [−1, 1] → [0, 2π] is π. + return (1 / B_sup_z).dot(w).sum(axis=-1) * jnp.pi + + def plot(self, l, pitch_inv=None, **kwargs): + """Plot the field line and bounce points of the given pitch angles. + + Parameters + ---------- + l : int + Index into the nodes of the grid supplied to make this object. + ``rho=grid.compress(grid.nodes[:,0])[l]``. + pitch_inv : jnp.ndarray + Shape (num_pitch, ). + Optional, 1/λ values whose corresponding bounce points on the field line + specified by Clebsch coordinate ρ(l) will be plotted. + kwargs + Keyword arguments into + ``desc/integrals/basis.py::PiecewiseChebyshevSeries.plot1d``. + + Returns + ------- + fig, ax + Matplotlib (fig, ax) tuple. + + """ + B = self._B + if B.cheb.ndim > 2: + B = PiecewiseChebyshevSeries(B.cheb[l], B.domain) + if pitch_inv is not None: + errorif( + pitch_inv.ndim > 1, + msg=f"Got pitch_inv.ndim={pitch_inv.ndim}, but expected 1.", + ) + z1, z2 = B.intersect1d(pitch_inv) + kwargs["z1"] = z1 + kwargs["z2"] = z2 + kwargs["k"] = pitch_inv + kwargs.setdefault("hlabel", r"$\alpha = $" + str(self._alpha) + r", $\zeta$") + fig, ax = B.plot1d(B.cheb, **_set_default_plot_kwargs(kwargs)) + return fig, ax + + def plot_theta(self, l, **kwargs): + """Plot θ(α, ζ) on the specified flux surface. + + Parameters + ---------- + l : int + Index into the nodes of the grid supplied to make this object. + ``rho=grid.compress(grid.nodes[:,0])[l]``. + kwargs + Keyword arguments into + ``desc/integrals/basis.py::PiecewiseChebyshevSeries.plot1d``. + + Returns + ------- + fig, ax + Matplotlib (fig, ax) tuple. + + """ + T = self._T + if T.cheb.ndim > 2: + T = PiecewiseChebyshevSeries(T.cheb[l], T.domain) + kwargs.setdefault( + "title", + r"DESC poloidal angle $\theta($" + + r"$\alpha=$" + + str(self._alpha) + + r"$, \zeta)$", + ) + kwargs.setdefault("hlabel", r"$\alpha = $" + str(self._alpha) + r", $\zeta$") + kwargs.setdefault("vlabel", r"$\theta$") + fig, ax = T.plot1d(T.cheb, **_set_default_plot_kwargs(kwargs)) + return fig, ax class Bounce1D(IOAble): @@ -51,12 +958,11 @@ class Bounce1D(IOAble): which is invariant to the order of summation). The DESC coordinate system is related to field-line-following coordinate - systems by a relation whose solution is best found with Newton iteration. - There is a unique real solution to that relation, so Newton iteration is a - globally convergent root-finding algorithm here. For the task of finding - bounce points, Newton iteration is not a globally convergent algorithm to - find the real roots of r : ζ ↦ |B|(ζ) − 1/λ where ζ is a field-line-following - coordinate. For this, function approximation of |B| is necessary. + systems by a relation whose solution is best found with Newton iteration + since this solution is unique. Newton iteration is not a globally + convergent algorithm to find the real roots of r : ζ ↦ |B|(ζ) − 1/λ where + ζ is a field-line-following coordinate. For this, function approximation + of |B| is necessary. The function approximation in ``Bounce1D`` is ignorant that the objects to approximate are defined on a bounded subset of ℝ². Instead, the domain is @@ -69,7 +975,12 @@ class Bounce1D(IOAble): After computing the bounce points, the supplied quadrature is performed. By default, this is a Gauss quadrature after removing the singularity. - Local splines interpolate functions in the integrand to the quadrature nodes. + Local splines interpolate smooth functions in the integrand to the quadrature + nodes. Quadrature is chosen over Runge-Kutta methods of the form + ∂Fᵢ/∂ζ = f(λ,ζ,{Gⱼ}) subject to Fᵢ(ζ₁) = 0 + because a fourth order Runge-Kutta method is equivalent to a quadrature + with Simpson's rule. Our quadratures resolve these integrals more + efficiently, and the fixed nature of quadrature performs better on GPUs. See Also -------- @@ -139,6 +1050,7 @@ def __init__( For weak singular integrals, use ``None``. For strong singular integrals, use ``(automorphism_sin,grad_automorphism_sin)`` from + ``desc.integrals.quad_utils``. Bref : float Optional. Reference magnetic field strength for normalization. Lref : float @@ -179,10 +1091,10 @@ def __init__( self._x, self._w = get_quadrature(quad, automorphism) # Compute local splines. - self._zeta = grid.compress(grid.nodes[:, 2], surface_label="zeta") + self.zeta = grid.compress(grid.nodes[:, 2], surface_label="zeta") self.B = jnp.moveaxis( CubicHermiteSpline( - x=self._zeta, + x=self.zeta, y=self._data["|B|"], dydx=self._data["|B|_z|r,a"], axis=-1, @@ -191,7 +1103,7 @@ def __init__( source=(0, 1), destination=(-1, -2), ) - self._dB_dz = polyder_vec(self.B) + self.dB_dz = polyder_vec(self.B) # Add axis here instead of in ``_bounce_quadrature``. for name in self._data: @@ -252,7 +1164,7 @@ def points(self, pitch_inv, *, num_well=None): line and pitch, is padded with zero. """ - return bounce_points(pitch_inv, self._zeta, self.B, self._dB_dz, num_well) + return bounce_points(pitch_inv, self.zeta, self.B, self.dB_dz, num_well) def check_points(self, points, pitch_inv, *, plot=True, **kwargs): """Check that bounce points are computed correctly. @@ -285,7 +1197,7 @@ def check_points(self, points, pitch_inv, *, plot=True, **kwargs): z1=points[0], z2=points[1], pitch_inv=pitch_inv, - knots=self._zeta, + knots=self.zeta, B=self.B, plot=plot, **kwargs, @@ -303,6 +1215,7 @@ def integrate( batch=True, check=False, plot=False, + quad=None, ): """Bounce integrate ∫ f(λ, ℓ) dℓ. @@ -350,11 +1263,14 @@ def integrate( plot : bool Whether to plot the quantities in the integrand interpolated to the quadrature points of each integral. Ignored if ``check`` is false. + quad : tuple[jnp.ndarray] + Optional quadrature points and weights. If given this overrides + the quadrature chosen when this object was made. Returns ------- result : jnp.ndarray - Shape is same as input points. + Shape (M, L, num_pitch, num_well). Last axis enumerates the bounce integrals for a given field line, flux surface, and pitch value. @@ -362,14 +1278,14 @@ def integrate( if points is None: points = self.points(pitch_inv) result = _bounce_quadrature( - x=self._x, - w=self._w, + x=self._x if quad is None else quad[0], + w=self._w if quad is None else quad[1], integrand=integrand, points=points, pitch_inv=pitch_inv, f=setdefault(f, []), data=self._data, - knots=self._zeta, + knots=self.zeta, method=method, batch=batch, check=check, @@ -379,9 +1295,9 @@ def integrate( result *= interp_to_argmin( weight, points, - self._zeta, + self.zeta, self.B, - self._dB_dz, + self.dB_dz, method, ) return result @@ -407,7 +1323,7 @@ def plot(self, m, l, pitch_inv=None, **kwargs): Matplotlib (fig, ax) tuple. """ - B, dB_dz = self.B, self._dB_dz + B, dB_dz = self.B, self.dB_dz if B.ndim == 4: B = B[m] dB_dz = dB_dz[m] @@ -419,9 +1335,9 @@ def plot(self, m, l, pitch_inv=None, **kwargs): pitch_inv.ndim > 1, msg=f"Got pitch_inv.ndim={pitch_inv.ndim}, but expected 1.", ) - z1, z2 = bounce_points(pitch_inv, self._zeta, B, dB_dz) + z1, z2 = bounce_points(pitch_inv, self.zeta, B, dB_dz) kwargs["z1"] = z1 kwargs["z2"] = z2 kwargs["k"] = pitch_inv - fig, ax = plot_ppoly(PPoly(B.T, self._zeta), **_set_default_plot_kwargs(kwargs)) + fig, ax = plot_ppoly(PPoly(B.T, self.zeta), **_set_default_plot_kwargs(kwargs)) return fig, ax diff --git a/desc/integrals/bounce_utils.py b/desc/integrals/bounce_utils.py index 656c4e588..26e2756c4 100644 --- a/desc/integrals/bounce_utils.py +++ b/desc/integrals/bounce_utils.py @@ -27,6 +27,34 @@ ) +# TODO: Generalize this beyond ζ = ϕ or just map to Clebsch with ϕ. +def get_alpha(alpha_0, iota, num_transit, period): + """Get sequence of poloidal coordinates A = (α₀, α₁, …, αₘ₋₁) of field line. + + Parameters + ---------- + alpha_0 : float + Starting field line poloidal label. + iota : jnp.ndarray + Shape (iota.size, ). + Rotational transform normalized by 2π. + num_transit : float + Number of ``period``s to follow field line. + period : float + Toroidal period after which to update label. + + Returns + ------- + alpha : jnp.ndarray + Shape (iota.size, num_transit). + Sequence of poloidal coordinates A = (α₀, α₁, …, αₘ₋₁) that specify field line. + + """ + # Δϕ (∂α/∂ϕ) = Δϕ ι̅ = Δϕ ι/2π = Δϕ data["iota"] + alpha = alpha_0 + period * jnp.expand_dims(iota, -1) * jnp.arange(num_transit) + return alpha + + def get_pitch_inv_quad(min_B, max_B, num_pitch): """Return 1/λ values and weights for quadrature between ``min_B`` and ``max_B``. @@ -207,14 +235,18 @@ def bounce_points( def _set_default_plot_kwargs(kwargs): + vlabel = r"$\vert B \vert$" kwargs.setdefault( "title", - r"Intersects $\zeta$ in epigraph($\vert B \vert$) s.t. " - r"$\vert B \vert(\zeta) = 1/\lambda$", + r"Intersects $\zeta$ in epigraph(" + + vlabel + + ") s.t. " + + vlabel + + r"$(\zeta) = 1/\lambda$", ) kwargs.setdefault("klabel", r"$1/\lambda$") kwargs.setdefault("hlabel", r"$\zeta$") - kwargs.setdefault("vlabel", r"$\vert B \vert$") + kwargs.setdefault("vlabel", vlabel) return kwargs @@ -726,7 +758,7 @@ def interp_to_argmin_hard(h, points, knots, g, dg_dz, method="cubic"): def plot_ppoly( ppoly, - num=1000, + num=5000, z1=None, z2=None, k=None, diff --git a/desc/integrals/interp_utils.py b/desc/integrals/interp_utils.py index c00506fce..cfff90487 100644 --- a/desc/integrals/interp_utils.py +++ b/desc/integrals/interp_utils.py @@ -9,10 +9,438 @@ from functools import partial +import numpy as np from interpax import interp1d +from orthax.chebyshev import chebroots + +from desc.backend import dct, jnp, rfft, rfft2, take +from desc.integrals.quad_utils import bijection_from_disc +from desc.utils import Index, errorif, safediv + +# TODO: Boyd's method 𝒪(N²) instead of Chebyshev companion matrix 𝒪(N³). +# John P. Boyd, Computing real roots of a polynomial in Chebyshev series +# form through subdivision. https://doi.org/10.1016/j.apnum.2005.09.007. +# Use that once to find extrema of |B| if N_B > 64. Then to find roots +# of bounce points use the closed formula in Boyd's spectral methods +# section 19.6. Can isolate interval to search for root by observing +# whether B - 1/pitch changes sign at extrema. This is significantly +# cheaper and non-iterative, so jax and gpu will like it. +chebroots_vec = jnp.vectorize(chebroots, signature="(m)->(n)") + + +# TODO: Transformation to make nodes more uniform Boyd eq. 16.46 pg. 336. +# More uniformly spaced nodes might speed up convergence. + + +def cheb_pts(N, domain=(-1, 1), lobatto=False): + """Get ``N`` Chebyshev points mapped to given domain. + + Warnings + -------- + This is a common definition of the Chebyshev points (see Boyd, Chebyshev and + Fourier Spectral Methods p. 498). These are the points demanded by discrete + cosine transformations to interpolate Chebyshev series because the cosine + basis for the DCT is defined on [0, π]. They differ in ordering from the + points returned by ``numpy.polynomial.chebyshev.chebpts1`` and + ``numpy.polynomial.chebyshev.chebpts2``. + + Parameters + ---------- + N : int + Number of points. + domain : tuple[float] + Domain for points. + lobatto : bool + Whether to return the Gauss-Lobatto (extrema-plus-endpoint) + instead of the interior roots for Chebyshev points. + + Returns + ------- + pts : jnp.ndarray + Shape (N, ). + Chebyshev points mapped to given domain. + + """ + n = jnp.arange(N) + if lobatto: + y = jnp.cos(jnp.pi * n / (N - 1)) + else: + y = jnp.cos(jnp.pi * (2 * n + 1) / (2 * N)) + return bijection_from_disc(y, domain[0], domain[-1]) + + +def fourier_pts(M): + """Get ``M`` Fourier points in [0, 2π).""" + # [0, 2π] instead of [-π, π] required to match our definition of α. + return 2 * jnp.pi * jnp.arange(M) / M + + +def harmonic(a, M, axis=-1): + """Spectral coefficients of the Nyquist trigonometric interpolant. + + Parameters + ---------- + a : jnp.ndarray + Fourier coefficients ``a=rfft(f,norm="forward",axis=axis)``. + M : int + Spectral resolution of ``a``. + axis : int + Axis along which coefficients are stored. + + Returns + ------- + h : jnp.ndarray + Nyquist trigonometric interpolant coefficients. + Coefficients ordered along ``axis`` of size ``M`` to match ordering of + [1, cos(x), ..., cos(mx), sin(x), sin(2x), ..., sin(mx)] basis. + + """ + is_even = (M % 2) == 0 + # cos(mx) coefficients + an = 2.0 * ( + a.real.at[Index.get(0, axis, a.ndim)] + .divide(2.0) + .at[Index.get(-1, axis, a.ndim)] + .divide(1.0 + is_even) + ) + # sin(mx) coefficients + bn = -2.0 * take( + a.imag, + jnp.arange(1, a.shape[axis] - is_even), + axis, + unique_indices=True, + indices_are_sorted=True, + ) + h = jnp.concatenate([an, bn], axis=axis) + assert h.shape[axis] == M + return h + + +def harmonic_vander(x, M, domain=(0, 2 * np.pi)): + """Nyquist trigonometric interpolant basis evaluated at ``x``. + + Parameters + ---------- + x : jnp.ndarray + Points at which to evaluate pseudo-Vandermonde matrix. + M : int + Spectral resolution. + domain : tuple[float] + Domain over which samples will be taken. + This domain should span an open period of the function to interpolate. + + Returns + ------- + basis : jnp.ndarray + Shape (*x.shape, M). + Pseudo-Vandermonde matrix of degree ``M-1`` and sample points ``x``. + Last axis ordered as [1, cos(x), ..., cos(mx), sin(x), sin(2x), ..., sin(mx)]. + + """ + m = jnp.fft.rfftfreq(M, d=np.diff(domain) / (2 * jnp.pi) / M) + mx = m * (x - domain[0])[..., jnp.newaxis] + basis = jnp.concatenate( + [jnp.cos(mx), jnp.sin(mx[..., 1 : m.size - ((M % 2) == 0)])], axis=-1 + ) + assert basis.shape == (*x.shape, M) + return basis + + +# TODO: For inverse transforms, use non-uniform fast transforms (NFFT). +# https://github.com/flatironinstitute/jax-finufft. +# Let spectral resolution be F, (e.g. F = M N for 2D transform), +# and number of points (non-uniform) to evaluate be Q. A non-uniform +# fast transform cost is 𝒪([F+Q] log[F] log[1/ε]) where ε is the +# interpolation error term (depending on implementation how ε appears +# may change, but it is always logarithmic). Direct evaluation is 𝒪(F Q). +# Note that for the inverse Chebyshev transforms, we can also use fast +# multipoint methods Chapter 10, https://doi.org/10.1017/CBO9781139856065. +# Unlike NFFTs, multipoint methods are exact and reduce to using FFTs. +# The cost is 𝒪([F+Q] łog²[F + Q]). This is a good candidate for evaluating +# |B|, since the integrands are not smooth functions of |B|, which we know +# as a Chebyshev series, and the nodes are packed more tightly near the edges, +# in particular for the strongly singular integrals. + + +def interp_rfft(xq, f, domain=(0, 2 * jnp.pi), axis=-1): + """Interpolate real-valued ``f`` to ``xq`` with FFT. + + Parameters + ---------- + xq : jnp.ndarray + Real query points where interpolation is desired. + Shape of ``xq`` must broadcast with arrays of shape ``np.delete(f.shape,axis)``. + f : jnp.ndarray + Real function values on uniform grid over an open period to interpolate. + domain : tuple[float] + Domain over which samples were taken. + axis : int + Axis along which to transform. + + Returns + ------- + fq : jnp.ndarray + Real function value at query points. + + """ + a = rfft(f, axis=axis, norm="forward") + fq = irfft_non_uniform(xq, a, f.shape[axis], domain, axis) + return fq + + +def irfft_non_uniform(xq, a, n, domain=(0, 2 * jnp.pi), axis=-1): + """Evaluate Fourier coefficients ``a`` at ``xq``. + + Parameters + ---------- + xq : jnp.ndarray + Real query points where interpolation is desired. + Shape of ``xq`` must broadcast with arrays of shape ``np.delete(a.shape,axis)``. + a : jnp.ndarray + Fourier coefficients ``a=rfft(f,axis=axis,norm="forward")``. + n : int + Spectral resolution of ``a``. + domain : tuple[float] + Domain over which samples were taken. + axis : int + Axis along which to transform. + + Returns + ------- + fq : jnp.ndarray + Real function value at query points. + + """ + # |a| << |basis|, so move a instead of basis + a = ( + jnp.moveaxis(a, axis, -1) + .at[..., 0] + .divide(2.0) + .at[..., -1] + .divide(1.0 + ((n % 2) == 0)) + ) + m = jnp.fft.rfftfreq(n, d=np.diff(domain) / (2 * jnp.pi) / n) + xq = xq - domain[0] + basis = jnp.exp(-1j * m * xq[..., jnp.newaxis]) + fq = 2.0 * jnp.linalg.vecdot(basis, a).real + # ℜ〈 basis, a 〉= cos(m xq)⋅ℜ(a) − sin(m xq)⋅ℑ(a) + return fq + + +def interp_rfft2( + xq0, xq1, f, domain0=(0, 2 * jnp.pi), domain1=(0, 2 * jnp.pi), axes=(-2, -1) +): + """Interpolate real-valued ``f`` to coordinates ``(xq0,xq1)`` with FFT. + + Parameters + ---------- + xq0 : jnp.ndarray + Real query points of coordinate in ``domain0`` where interpolation is desired. + Shape must broadcast with shape ``np.delete(a.shape,axes)``. + The coordinates stored here must be the same coordinate enumerated + across axis ``min(axes)`` of the function values ``f``. + xq1 : jnp.ndarray + Real query points of coordinate in ``domain1`` where interpolation is desired. + Shape must broadcast with shape ``np.delete(a.shape,axes)``. + The coordinates stored here must be the same coordinate enumerated + across axis ``max(axes)`` of the function values ``f``. + f : jnp.ndarray + Shape (..., f.shape[-2], f.shape[-1]). + Real function values on uniform tensor-product grid over an open period. + domain0 : tuple[float] + Domain of coordinate specified by ``xq0`` over which samples were taken. + domain1 : tuple[float] + Domain of coordinate specified by ``xq1`` over which samples were taken. + axes : tuple[int] + Axes along which to transform. + The real transform is done along ``axes[1]``, so it will be more + efficient for that to denote the larger size axis in ``axes``. + + Returns + ------- + fq : jnp.ndarray + Real function value at query points. + + """ + a = rfft2(f, axes=axes, norm="forward") + fq = irfft2_non_uniform( + xq0, xq1, a, f.shape[axes[0]], f.shape[axes[1]], domain0, domain1, axes + ) + return fq + + +def irfft2_non_uniform( + xq0, xq1, a, M, N, domain0=(0, 2 * jnp.pi), domain1=(0, 2 * jnp.pi), axes=(-2, -1) +): + """Evaluate Fourier coefficients ``a`` at coordinates ``(xq0,xq1)``. + + Parameters + ---------- + xq0 : jnp.ndarray + Real query points of coordinate in ``domain0`` where interpolation is desired. + Shape must broadcast with shape ``np.delete(a.shape,axes)``. + The coordinates stored here must be the same coordinate enumerated + across axis ``min(axes)`` of the Fourier coefficients ``a``. + xq1 : jnp.ndarray + Real query points of coordinate in ``domain1`` where interpolation is desired. + Shape must broadcast with shape ``np.delete(a.shape,axes)``. + The coordinates stored here must be the same coordinate enumerated + across axis ``max(axes)`` of the Fourier coefficients ``a``. + a : jnp.ndarray + Shape (..., a.shape[-2], a.shape[-1]). + Fourier coefficients ``a=rfft2(f,axes=axes,norm="forward")``. + M : int + Spectral resolution of ``a`` along ``axes[0]``. + N : int + Spectral resolution of ``a`` along ``axes[1]``. + domain0 : tuple[float] + Domain of coordinate specified by ``xq0`` over which samples were taken. + domain1 : tuple[float] + Domain of coordinate specified by ``xq1`` over which samples were taken. + axes : tuple[int] + Axes along which to transform. + + Returns + ------- + fq : jnp.ndarray + Real function value at query points. + + """ + errorif(len(axes) != 2, msg="This is a 2D transform.") + errorif(a.ndim < 2, msg=f"Dimension mismatch, a.shape: {a.shape}.") + + # |a| << |basis|, so move a instead of basis + a = ( + jnp.moveaxis(a, source=axes, destination=(-2, -1)) + .at[..., 0] + .divide(2.0) + .at[..., -1] + .divide(1.0 + ((N % 2) == 0)) + ) + + idx = np.argsort(axes) + domain = (domain0, domain1) + m = jnp.fft.fftfreq(M, d=np.diff(domain[idx[0]]) / (2 * jnp.pi) / M) + n = jnp.fft.rfftfreq(N, d=np.diff(domain[idx[1]]) / (2 * jnp.pi) / N) + xq0 = xq0 - domain0[0] + xq1 = xq1 - domain1[0] + xq = (xq0, xq1) + + basis = jnp.exp( + 1j + * ( + (m * xq[idx[0]][..., jnp.newaxis])[..., jnp.newaxis] + + (n * xq[idx[1]][..., jnp.newaxis])[..., jnp.newaxis, :] + ) + ) + fq = 2.0 * (basis * a).real.sum(axis=(-2, -1)) + return fq + + +def cheb_from_dct(a, axis=-1): + """Get discrete Chebyshev transform from discrete cosine transform. + + Parameters + ---------- + a : jnp.ndarray + Discrete cosine transform coefficients, e.g. + ``a=dct(f,type=2,axis=axis,norm="forward")``. + The discrete cosine transformation used by scipy is defined here: + https://docs.scipy.org/doc/scipy/reference/generated/scipy.fft.dct.html. + axis : int + Axis along which to transform. + + Returns + ------- + cheb : jnp.ndarray + Chebyshev coefficients along ``axis``. + + """ + cheb = a.copy().at[Index.get(0, axis, a.ndim)].divide(2.0) + return cheb + + +def dct_from_cheb(cheb, axis=-1): + """Get discrete cosine transform from discrete Chebyshev transform. + + Parameters + ---------- + cheb : jnp.ndarray + Discrete Chebyshev transform coefficients, e.g.``cheb_from_dct(a)``. + axis : int + Axis along which to transform. + + Returns + ------- + a : jnp.ndarray + Chebyshev coefficients along ``axis``. + + """ + a = cheb.copy().at[Index.get(0, axis, cheb.ndim)].multiply(2.0) + return a + + +def interp_dct(xq, f, lobatto=False, axis=-1): + """Interpolate ``f`` to ``xq`` with discrete Chebyshev transform. + + Parameters + ---------- + xq : jnp.ndarray + Real query points where interpolation is desired. + Shape of ``xq`` must broadcast with shape ``np.delete(f.shape,axis)``. + f : jnp.ndarray + Real function values on Chebyshev points to interpolate. + lobatto : bool + Whether ``f`` was sampled on the Gauss-Lobatto (extrema-plus-endpoint) + or interior roots grid for Chebyshev points. + axis : int + Axis along which to transform. + + Returns + ------- + fq : jnp.ndarray + Real function value at query points. + + """ + lobatto = bool(lobatto) + errorif(lobatto, NotImplementedError, "JAX hasn't implemented type 1 DCT.") + a = cheb_from_dct(dct(f, type=2 - lobatto, axis=axis), axis) / ( + f.shape[axis] - lobatto + ) + fq = idct_non_uniform(xq, a, f.shape[axis], axis) + return fq + + +def idct_non_uniform(xq, a, n, axis=-1): + """Evaluate discrete Chebyshev transform coefficients ``a`` at ``xq`` ∈ [-1, 1]. + + Parameters + ---------- + xq : jnp.ndarray + Real query points where interpolation is desired. + Shape of ``xq`` must broadcast with shape ``np.delete(a.shape,axis)``. + a : jnp.ndarray + Discrete Chebyshev transform coefficients. + n : int + Spectral resolution of ``a``. + axis : int + Axis along which to transform. + + Returns + ------- + fq : jnp.ndarray + Real function value at query points. + + """ + a = jnp.moveaxis(a, axis, -1) + # Equivalent to + # Clenshaw recursion: chebval(xq, a, tensor=False), + # Vandermode product: jnp.linalg.vecdot(chebvander(xq, n - 1), a) + # but performs better on GPU. + n = jnp.arange(n) + fq = jnp.linalg.vecdot(jnp.cos(n * jnp.arccos(xq)[..., jnp.newaxis]), a) + return fq -from desc.backend import jnp -from desc.utils import safediv # Warning: method must be specified as keyword argument. interp1d_vec = jnp.vectorize( @@ -85,6 +513,23 @@ def polyval_vec(*, x, c): # TODO: Eventually do a PR to move this stuff into interpax. +def _subtract_first(c, k): + """Subtract ``k`` from first index of last axis of ``c``. + + Semantically same as ``return c.copy().at[...,0].add(-k)``, + but allows dimension to increase. + """ + c_0 = c[..., 0] - k + c = jnp.concatenate( + [ + c_0[..., jnp.newaxis], + jnp.broadcast_to(c[..., 1:], (*c_0.shape, c.shape[-1] - 1)), + ], + axis=-1, + ) + return c + + def _subtract_last(c, k): """Subtract ``k`` from last index of last axis of ``c``. @@ -111,7 +556,10 @@ def _filter_distinct(r, sentinel, eps): return r -_roots = jnp.vectorize(partial(jnp.roots, strip_zeros=False), signature="(m)->(n)") +_polyroots_vec = jnp.vectorize( + partial(jnp.roots, strip_zeros=False), signature="(m)->(n)" +) +_eps = max(jnp.finfo(jnp.array(1.0).dtype).eps, 2.5e-12) def polyroot_vec( @@ -121,7 +569,7 @@ def polyroot_vec( a_max=None, sort=False, sentinel=jnp.nan, - eps=max(jnp.finfo(jnp.array(1.0).dtype).eps, 2.5e-12), + eps=_eps, distinct=False, ): """Roots of polynomial with given coefficients. @@ -179,7 +627,7 @@ def polyroot_vec( distinct = distinct and num_coef > 3 else: # Compute from eigenvalues of polynomial companion matrix. - r = _roots(c) + r = _polyroots_vec(c) if get_only_real_roots: a_min = -jnp.inf if a_min is None else a_min[..., jnp.newaxis] diff --git a/desc/integrals/quad_utils.py b/desc/integrals/quad_utils.py index 7d33b67fd..4f70dec61 100644 --- a/desc/integrals/quad_utils.py +++ b/desc/integrals/quad_utils.py @@ -58,7 +58,7 @@ def automorphism_sin(x, s=0, m=10): """[-1, 1] ∋ x ↦ y ∈ [−1, 1]. This map increases node density near the boundary by the asymptotic factor - 1/√(1−x²) and adds a √(1−x²) factor to the integrand. + 1/√(1−x²) and adds a cosine factor to the integrand. Parameters ---------- @@ -242,7 +242,7 @@ def get_quadrature(quad, automorphism): Parameters ---------- - quad : (jnp.ndarray, jnp.ndarray) + quad : tuple[jnp.ndarray] Quadrature points xₖ and weights wₖ for the approximate evaluation of the integral ∫₋₁¹ g(x) dx = ∑ₖ wₖ g(xₖ). automorphism : (Callable, Callable) or None diff --git a/desc/transform.py b/desc/transform.py index a5c1798ad..2ab54ccc9 100644 --- a/desc/transform.py +++ b/desc/transform.py @@ -8,7 +8,13 @@ from desc.backend import jnp, put from desc.io import IOAble -from desc.utils import combination_permutation, isalmostequal, islinspaced, issorted +from desc.utils import ( + combination_permutation, + isalmostequal, + islinspaced, + issorted, + warnif, +) class Transform(IOAble): @@ -59,21 +65,22 @@ def __init__( self._basis = basis self._rcond = rcond if rcond is not None else "auto" - if ( + warnif( + self.grid.coordinates != "rtz", + msg=f"Expected coordinates rtz got {self.grid.coordinates}.", + ) + # DESC truncates the computational domain to ζ ∈ [0, 2π/grid.NFP] + # and changes variables to the spectrally condensed ζ* = basis.NFP ζ, + # so basis.NFP must equal grid.NFP. + warnif( method != "jitable" - and grid.node_pattern != "custom" - and self.basis.N != 0 and self.grid.NFP != self.basis.NFP - and np.any(self.grid.nodes[:, 2] != 0) - ): - warnings.warn( - colored( - "Unequal number of field periods for grid {} and basis {}.".format( - self.grid.NFP, self.basis.NFP - ), - "yellow", - ) - ) + and self.basis.N != 0 + and grid.node_pattern != "custom" + and np.any(self.grid.nodes[:, 2] != 0), + msg=f"Unequal number of field periods for grid {self.grid.NFP} and " + f"basis {self.basis.NFP}.", + ) self._built = False self._built_pinv = False diff --git a/desc/utils.py b/desc/utils.py index 50c1db1b5..fe68c94ff 100644 --- a/desc/utils.py +++ b/desc/utils.py @@ -744,6 +744,18 @@ def atleast_nd(ndmin, ary): return jnp.array(ary, ndmin=ndmin) if jnp.ndim(ary) < ndmin else ary +def atleast_3d_mid(ary): + """Like np.atleast_3d but if adds dim at axis 1 for 2d arrays.""" + ary = jnp.atleast_2d(ary) + return ary[:, jnp.newaxis] if ary.ndim == 2 else ary + + +def atleast_2d_end(ary): + """Like np.atleast_2d but if adds dim at axis 1 for 1d arrays.""" + ary = jnp.atleast_1d(ary) + return ary[:, jnp.newaxis] if ary.ndim == 1 else ary + + PRINT_WIDTH = 60 # current longest name is BootstrapRedlConsistency with pre-text diff --git a/docs/notebooks/tutorials/ideal_ballooning_stability.ipynb b/docs/notebooks/tutorials/ideal_ballooning_stability.ipynb index 1ccb88234..a2ffbf30b 100644 --- a/docs/notebooks/tutorials/ideal_ballooning_stability.ipynb +++ b/docs/notebooks/tutorials/ideal_ballooning_stability.ipynb @@ -783,6 +783,7 @@ } ], "source": [ + "ball_data0 = eq1.compute(\"gbdrift\", grid=grid, data=ball_data0)\n", "fig, ax = plt.subplots(7, sharex=True, figsize=(8, 16))\n", "\n", "ax[0].plot(zeta, (ball_data0[\"B^zeta\"] / ball_data0[\"|B|\"])[::nalpha], \"-or\", ms=1.5)\n", @@ -827,7 +828,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.9" + "version": "3.12.6" } }, "nbformat": 4, diff --git a/tests/baseline/test_binormal_drift_bounce2d.png b/tests/baseline/test_binormal_drift_bounce2d.png new file mode 100644 index 000000000..c78d2471c Binary files /dev/null and b/tests/baseline/test_binormal_drift_bounce2d.png differ diff --git a/tests/baseline/test_bounce1d_checks.png b/tests/baseline/test_bounce1d_checks.png index f3927bec6..71b757f20 100644 Binary files a/tests/baseline/test_bounce1d_checks.png and b/tests/baseline/test_bounce1d_checks.png differ diff --git a/tests/baseline/test_bounce2d_checks.png b/tests/baseline/test_bounce2d_checks.png new file mode 100644 index 000000000..472f3aeef Binary files /dev/null and b/tests/baseline/test_bounce2d_checks.png differ diff --git a/tests/inputs/master_compute_data_rpz.pkl b/tests/inputs/master_compute_data_rpz.pkl index 7a887be3d..ad8205555 100644 Binary files a/tests/inputs/master_compute_data_rpz.pkl and b/tests/inputs/master_compute_data_rpz.pkl differ diff --git a/tests/test_axis_limits.py b/tests/test_axis_limits.py index b8d173f7e..a2f556264 100644 --- a/tests/test_axis_limits.py +++ b/tests/test_axis_limits.py @@ -61,12 +61,12 @@ "g^tz_r", "g^tz_t", "g^tz_z", - "grad(alpha)", "g^aa", "g^ra", + "grad(alpha)", + "periodic(grad(alpha))", "gbdrift", "cvdrift", - "grad(alpha)", "|e^helical|", "|grad(theta)|", " Redl", # may not exist for all configurations diff --git a/tests/test_integrals.py b/tests/test_integrals.py index e25426900..90cc80485 100644 --- a/tests/test_integrals.py +++ b/tests/test_integrals.py @@ -6,14 +6,14 @@ import pytest from jax import grad from matplotlib import pyplot as plt -from numpy.polynomial.chebyshev import chebgauss, chebweight +from numpy.polynomial.chebyshev import chebgauss, chebinterpolate, chebroots, chebweight from numpy.polynomial.legendre import leggauss from scipy import integrate from scipy.interpolate import CubicHermiteSpline from scipy.special import ellipe, ellipkm1 from tests.test_plotting import tol_1d -from desc.backend import jnp +from desc.backend import jit, jnp from desc.basis import FourierZernikeBasis from desc.equilibrium import Equilibrium from desc.equilibrium.coords import get_rtz_grid @@ -21,6 +21,7 @@ from desc.grid import ConcentricGrid, Grid, LinearGrid, QuadratureGrid from desc.integrals import ( Bounce1D, + Bounce2D, DFTInterpolator, FFTInterpolator, line_integrals, @@ -33,13 +34,16 @@ surface_variance, virtual_casing_biot_savart, ) +from desc.integrals.basis import FourierChebyshevSeries from desc.integrals.bounce_utils import ( _get_extrema, bounce_points, + get_alpha, get_pitch_inv_quad, interp_to_argmin, interp_to_argmin_hard, ) +from desc.integrals.interp_utils import fourier_pts from desc.integrals.quad_utils import ( automorphism_sin, bijection_from_disc, @@ -53,7 +57,7 @@ from desc.integrals.singularities import _get_quadrature_nodes from desc.integrals.surface_integral import _get_grid_surface from desc.transform import Transform -from desc.utils import dot, safediv +from desc.utils import dot, errorif, safediv class TestSurfaceIntegral: @@ -724,6 +728,14 @@ def test_biest_interpolators(self): class TestBouncePoints: """Test that bounce points are computed correctly.""" + @staticmethod + def _cheb_intersect(cheb, k): + cheb = cheb.copy() + cheb[0] = cheb[0] - k + roots = chebroots(cheb) + intersect = roots[np.logical_and(np.isreal(roots), np.abs(roots.real) < 1)].real + return intersect + @staticmethod def filter(z1, z2): """Remove bounce points whose integrals have zero measure.""" @@ -892,6 +904,27 @@ def test_get_extrema(self): np.testing.assert_allclose(ext[idx], ext_scipy) np.testing.assert_allclose(B_ext[idx], B_ext_scipy) + @pytest.mark.unit + def test_z1_first_chebyshev(self): + """Test that bounce points are computed correctly.""" + + def f(z): + return -2 * np.cos(1 / (0.1 + z**2)) + 2 + + M, N = 1, 10 + alpha, zeta = FourierChebyshevSeries.nodes(M, N).T + cheb = FourierChebyshevSeries(f(zeta).reshape(M, N)).compute_cheb( + fourier_pts(M) + ) + pitch_inv = 3 + z1, z2 = cheb.intersect1d(pitch_inv) + cheb.check_intersect1d(z1, z2, pitch_inv) + z1, z2 = TestBouncePoints.filter(z1, z2) + + r = self._cheb_intersect(chebinterpolate(f, N - 1), pitch_inv) + np.testing.assert_allclose(z1, r[np.isclose(r, -0.24, atol=1e-1)]) + np.testing.assert_allclose(z2, r[np.isclose(r, 0.24, atol=1e-1)]) + def _chebgauss1(deg): x, w = chebgauss(deg) @@ -926,14 +959,14 @@ def test_bounce_quadrature(self, is_strong, quad, automorphism): Notes ----- Empirical testing shows asymptotic density of nodes needs to be at least - 1/√(1−x²) and quadrature needs √(1−x²) factor in Jacobian for accurate + 1/√(1−x²) and quadrature needs a cosine factor in Jacobian for accurate bounce integrals. This is satisfied by ``chebgauss2`` and ``leggauss`` with the sin automorphism. The former has less clustering near boundary by a factor of 1/√(1−x²), so we choose it for weakly singular bounce integrals. This will capture more features in the integral, especially the W shaped wells. Less clustering will also make non-uniform FFTs more accurate. - For the strongly singular bounce integrals, another √(1−x²) factor is preferred + For the strongly singular bounce integrals, another cosine factor is preferred to supress the derivative (as expected from chain rule), so we need to use the sin automorphism. We choose to apply that map to ``leggauss`` instead of ``_chebgauss1`` because the extra cosine term in ``_chebgauss1`` increases the @@ -964,10 +997,10 @@ def test_bounce_quadrature(self, is_strong, quad, automorphism): integrand = lambda B, pitch: jnp.sqrt(1 - m * pitch * B) truth = v * 2 * ellipe(m) bounce = Bounce1D( - Grid.create_meshgrid([1, 0, knots], coordinates="raz"), - data, - quad, - automorphism, + grid=Grid.create_meshgrid([1, 0, knots], coordinates="raz"), + data=data, + quad=quad, + automorphism=automorphism, check=True, ) points = bounce.points(pitch_inv, num_well=1) @@ -1106,6 +1139,7 @@ def test_bounce1d_checks(self): zeta = np.linspace(-2 * np.pi, 2 * np.pi, 200) eq = get("HELIOTRON") + # 3. Convert above coordinates to DESC computational coordinates. grid = get_rtz_grid(eq, rho, alpha, zeta, coordinates="raz") # 4. Compute input data. @@ -1139,13 +1173,17 @@ def test_bounce1d_checks(self): batch=False, ) avg = safediv(num, den) - assert np.isfinite(avg).all() and np.count_nonzero(avg) + errorif(not np.isfinite(avg).all()) + errorif( + np.count_nonzero(avg) == 0, + msg="Detected 0 wells on this cut of the field line. Make sure enough " + "toroidal transits were followed for this test, or plot the field line " + "to see if this is expected.", + ) # 9. Example manipulation of the output # Sum all bounce averages across a particular field line, for every field line. result = avg.sum(axis=-1) - # Group the result by pitch and flux surface. - result = result.reshape(alpha.size, rho.size, pitch_inv.shape[-1]) # The result stored at m, l, p = 0, 1, 3 print("Result(α, ρ, λ):", result[m, l, p]) @@ -1153,7 +1191,7 @@ def test_bounce1d_checks(self): print("1/λ(α, ρ):", pitch_inv[l, p]) # for the Clebsch-type field line coordinates nodes = grid.source_grid.meshgrid_reshape(grid.source_grid.nodes[:, :2], "arz") - print("(α, ρ):", nodes[m, l, 0]) + print("(ρ, α):", nodes[m, l, 0]) # 10. Plotting fig, ax = bounce.plot(m, l, pitch_inv[l], include_legend=False, show=False) @@ -1192,7 +1230,7 @@ def dg_dz(z): points = (np.array(0, ndmin=4), np.array(2 * np.pi, ndmin=4)) argmin = 5.61719 h_min = h(argmin) - result = func(h(zeta), points, zeta, bounce.B, bounce._dB_dz) + result = func(h(zeta), points, zeta, bounce.B, bounce.dB_dz) assert result.shape == points[0].shape np.testing.assert_allclose(h_min, result, rtol=1e-3) @@ -1223,6 +1261,7 @@ def get_drift_analytic_data(): "iota", "psi", "a", + "theta_PEST", ], grid=grid, ) @@ -1265,15 +1304,14 @@ def drift_analytic(data): # is independent of normalization length scales, like "effective r/R0". epsilon = data["a"] * data["rho"] # Aspect ratio of the flux surface. np.testing.assert_allclose(epsilon, 0.05) - theta_PEST = data["alpha"] + data["iota"] * data["zeta"] # same as 1 / (1 + epsilon cos(theta)) assuming epsilon << 1 - B_analytic = B0 * (1 - epsilon * np.cos(theta_PEST)) + B_analytic = B0 * (1 - epsilon * np.cos(data["theta_PEST"])) np.testing.assert_allclose(B, B_analytic, atol=3e-3) gradpar = data["a"] * data["B^zeta"] / data["|B|"] # This method of computing G0 suggests a fixed point iteration. G0 = data["a"] - gradpar_analytic = G0 * (1 - epsilon * np.cos(theta_PEST)) + gradpar_analytic = G0 * (1 - epsilon * np.cos(data["theta_PEST"])) gradpar_theta_analytic = data["iota"] * gradpar_analytic G0 = np.mean(gradpar_theta_analytic) np.testing.assert_allclose(gradpar, gradpar_analytic, atol=5e-3) @@ -1291,10 +1329,12 @@ def drift_analytic(data): / data["Bref"] ) gds21_analytic = -data["shear"] * ( - data["shear"] * theta_PEST - alpha_MHD / B**4 * np.sin(theta_PEST) + data["shear"] * data["theta_PEST"] + - alpha_MHD / B**4 * np.sin(data["theta_PEST"]) ) gds21_analytic_low_order = -data["shear"] * ( - data["shear"] * theta_PEST - alpha_MHD / B0**4 * np.sin(theta_PEST) + data["shear"] * data["theta_PEST"] + - alpha_MHD / B0**4 * np.sin(data["theta_PEST"]) ) np.testing.assert_allclose(gds21, gds21_analytic, atol=2e-2) np.testing.assert_allclose(gds21, gds21_analytic_low_order, atol=2.7e-2) @@ -1302,13 +1342,13 @@ def drift_analytic(data): fudge_1 = 0.19 gbdrift_analytic = fudge_1 * ( -data["shear"] - + np.cos(theta_PEST) - - gds21_analytic / data["shear"] * np.sin(theta_PEST) + + np.cos(data["theta_PEST"]) + - gds21_analytic / data["shear"] * np.sin(data["theta_PEST"]) ) gbdrift_analytic_low_order = fudge_1 * ( -data["shear"] - + np.cos(theta_PEST) - - gds21_analytic_low_order / data["shear"] * np.sin(theta_PEST) + + np.cos(data["theta_PEST"]) + - gds21_analytic_low_order / data["shear"] * np.sin(data["theta_PEST"]) ) fudge_2 = 0.07 cvdrift_analytic = gbdrift_analytic + fudge_2 * alpha_MHD / B**2 @@ -1382,6 +1422,7 @@ def test_binormal_drift_bounce1d(self): f=f, points=points, check=True, + plot=False, ) drift_numerical_den = bounce.integrate( integrand=TestBounce1D.drift_den_integrand, @@ -1389,6 +1430,7 @@ def test_binormal_drift_bounce1d(self): weight=np.ones(data["zeta"].size), points=points, check=True, + plot=False, ) drift_numerical = np.squeeze(drift_numerical_num / drift_numerical_den) msg = "There should be one bounce integral per pitch in this example." @@ -1477,3 +1519,224 @@ def fun2(pitch): # wrt λ but the boundary derivative: f(λ,ζ₂) (∂ζ₂/∂λ)(λ) - f(λ,ζ₁) (∂ζ₁/∂λ)(λ). # smooths out because the bounce points ζ₁ and ζ₂ are smooth functions of λ. np.testing.assert_allclose(fun2(pitch), -171500, rtol=1e-1) + + +class TestBounce2D: + """Test bounce integration that uses 2D pseudo-spectral methods.""" + + @pytest.mark.unit + @pytest.mark.parametrize( + "alpha_0, iota, num_period, period", + [ + (0, np.sqrt(2), 1, 2 * np.pi), + (0, np.arange(1, 3) * np.sqrt(2), 5, 2 * np.pi), + ], + ) + def test_alpha_sequence(self, alpha_0, iota, num_period, period): + """Test field line label updating works with jit.""" + alphas = jit(get_alpha, static_argnums=2)(alpha_0, iota, num_period, period) + if np.ndim(iota): + assert alphas.shape == (iota.size, num_period) + else: + assert alphas.shape == (num_period,) + print(alphas) + + @staticmethod + def _example_numerator(g_zz, B, pitch, zeta): + f = (1 - 0.5 * pitch * B) * g_zz + return safediv(f, jnp.sqrt(jnp.abs(1 - pitch * B))) + + @staticmethod + def _example_denominator(B, pitch, zeta): + return safediv(1, jnp.sqrt(jnp.abs(1 - pitch * B))) + + # TODO: Could test integration against bounce1d for this stellarator + @pytest.mark.unit + @pytest.mark.mpl_image_compare(remove_text=True, tolerance=tol_1d * 4) + def test_bounce2d_checks(self): + """Test that all the internal correctness checks pass for real example.""" + # noqa: D202 + # Suppose we want to compute a bounce average of the function + # f(ℓ) = (1 − λ|B|/2) * g_zz, where g_zz is the squared norm of the + # toroidal basis vector on some set of field lines specified by (ρ, α) + # coordinates. This is defined as + # [∫ f(ℓ) / √(1 − λ|B|) dℓ] / [∫ 1 / √(1 − λ|B|) dℓ] + + # 1. Define python functions for the integrands. We do that above. + # 2. Pick flux surfaces and grid resolution. + rho = np.linspace(0.1, 1, 6) + eq = get("HELIOTRON") + grid = Grid.create_meshgrid( + [rho, fourier_pts(eq.M_grid), fourier_pts(eq.N_grid) / eq.NFP], + period=(np.inf, 2 * np.pi, 2 * np.pi / eq.NFP), + NFP=eq.NFP, + ) + # 3. Compute input data. + data = eq.compute( + Bounce2D.required_names + ["min_tz |B|", "max_tz |B|", "g_zz"], grid=grid + ) + # 4. Compute DESC coordinates of optimal interpolation nodes. + theta = Bounce2D.compute_theta(eq, M=8, N=64, rho=rho) + # 5. Make the bounce integration operator. + bounce = Bounce2D( + grid, + data, + iota=grid.compress(data["iota"]), + theta=theta, + num_transit=2, + quad=leggauss(3), + check=True, + ) + pitch_inv, _ = bounce.get_pitch_inv_quad( + min_B=grid.compress(data["min_tz |B|"]), + max_B=grid.compress(data["max_tz |B|"]), + num_pitch=10, + ) + # 6. Compute bounce points. + points = bounce.points(pitch_inv) + # 7. Optionally check for correctness of bounce points. + bounce.check_points(points, pitch_inv, plot=False) + # 8. Integrate. + num = bounce.integrate( + integrand=TestBounce2D._example_numerator, + pitch_inv=pitch_inv, + f=Bounce2D.reshape_data(grid, data["g_zz"]), + points=points, + check=True, + ) + den = bounce.integrate( + integrand=TestBounce2D._example_denominator, + pitch_inv=pitch_inv, + points=points, + check=True, + ) + avg = safediv(num, den) + errorif(not np.isfinite(avg).all()) + errorif( + np.count_nonzero(avg) == 0, + msg="Detected 0 wells on this cut of the field line. Make sure enough " + "toroidal transits were followed for this test, or plot the field line " + "to see if this is expected.", + ) + + # 9. Example manipulation of the output + # Sum all bounce averages across a particular field line, for every field line. + result = avg.sum(axis=-1) + # The result stored at + l, p = 1, 3 + print("Result(ρ, λ):", result[l, p]) + # corresponds to the 1/λ value + print("1/λ(ρ):", pitch_inv[l, p]) + # for the flux surface + print("ρ:", rho[l]) + + np.testing.assert_allclose( + bounce.compute_length(), + # Computed data below through with Simpson's rule at 800 nodes. + # The difference is likely due to interpolation and floating point error. + # (On the version of JAX on which rtol was set, there is a bug with DCT + # and FFT that limit the accuracy to something comparable to 32 bit). + [ + 384.77892007, + 361.60220181, + 345.33817065, + 333.00781712, + 352.16277188, + 440.09424799, + ], + rtol=3e-3, + ) + + # 10. Plotting + fig, ax = bounce.plot(l, pitch_inv[l], include_legend=False, show=False) + return fig + + @staticmethod + def drift_num_integrand(cvdrift, gbdrift, B, pitch, zeta): + """Integrand of numerator of bounce averaged binormal drift.""" + g = jnp.sqrt(1 - pitch * B) + return (cvdrift * g) - (0.5 * g * gbdrift) + (0.5 * gbdrift / g) + + @staticmethod + def drift_den_integrand(B, pitch, zeta): + """Integrand of denominator of bounce averaged binormal drift.""" + return 1 / jnp.sqrt(1 - pitch * B) + + @pytest.mark.unit + @pytest.mark.mpl_image_compare(remove_text=True, tolerance=tol_1d) + def test_binormal_drift_bounce2d(self): + """Test bounce-averaged drift with analytical expressions.""" + data, things = TestBounce1D.get_drift_analytic_data() + # Compute analytic approximation. + drift_analytic, _, _, pitch_inv = TestBounce1D.drift_analytic(data) + + # Recompute on non-symmetric, fft compatible grid. + eq = things["eq"] + grid = Grid.create_meshgrid( + [ + data["rho"], + fourier_pts(eq._M_grid), + fourier_pts(max(1, eq.N_grid)) / eq.NFP, + ], + NFP=eq.NFP, + ) + # TODO: request periodic terms and construct secular terms in the + # integrand functions + grid_data = eq.compute( + names=Bounce2D.required_names + ["cvdrift", "gbdrift"], grid=grid + ) + grid_data["cvdrift"] = grid_data["cvdrift"] * data["normalization"] + grid_data["gbdrift"] = grid_data["gbdrift"] * data["normalization"] + + # Compute numerical result. + M, N = 32, 32 # todo: lower these to minimum once we get match + bounce = Bounce2D( + grid=grid, + data=grid_data, + iota=data["iota"], + theta=Bounce2D.compute_theta( + eq, + M, + N, + data["rho"], + iota=jnp.broadcast_to(data["iota"], shape=(M * N)), + ), + num_transit=2, + alpha=data["alpha"] - 2 * np.pi * data["iota"], + Bref=data["Bref"], + Lref=data["a"], + check=True, + ) + points = bounce.points(pitch_inv, num_well=1) + bounce.check_points(points, pitch_inv, plot=False) + + f = Bounce2D.reshape_data(grid, grid_data["cvdrift"], grid_data["gbdrift"]) + drift_numerical_num = bounce.integrate( + integrand=TestBounce2D.drift_num_integrand, + pitch_inv=pitch_inv, + f=f, + points=points, + check=True, + plot=False, + ) + drift_numerical_den = bounce.integrate( + integrand=TestBounce2D.drift_den_integrand, + pitch_inv=pitch_inv, + points=points, + check=True, + plot=False, + ) + drift_numerical = np.squeeze(drift_numerical_num / drift_numerical_den) + msg = "There should be one bounce integral per pitch in this example." + assert drift_numerical.size == drift_analytic.size, msg + + fig, ax = plt.subplots() + ax.plot(pitch_inv, drift_analytic) + ax.plot(pitch_inv, drift_numerical) + + # TODO: need to make integrands multivalued + # np.testing.assert_allclose( # noqa: E800 + # drift_numerical, drift_analytic, atol=5e-3, rtol=5e-2 # noqa: E800 + # ) # noqa: E800 + + return fig diff --git a/tests/test_interp_utils.py b/tests/test_interp_utils.py index 606b0fe09..e0843acd7 100644 --- a/tests/test_interp_utils.py +++ b/tests/test_interp_utils.py @@ -2,9 +2,32 @@ import numpy as np import pytest +from matplotlib import pyplot as plt +from numpy.polynomial.chebyshev import ( + cheb2poly, + chebinterpolate, + chebpts1, + chebpts2, + chebval, +) from numpy.polynomial.polynomial import polyvander +from scipy.fft import dct as sdct +from scipy.fft import idct as sidct -from desc.integrals.interp_utils import polyder_vec, polyroot_vec, polyval_vec +from desc.backend import dct, idct, rfft +from desc.integrals.interp_utils import ( + cheb_from_dct, + cheb_pts, + harmonic, + harmonic_vander, + interp_dct, + interp_rfft, + interp_rfft2, + polyder_vec, + polyroot_vec, + polyval_vec, +) +from desc.integrals.quad_utils import bijection_to_disc class TestPolyUtils: @@ -101,3 +124,228 @@ def test(x, c): assert c.shape[:-1] == x.shape[x.ndim - (c.ndim - 1) :] assert np.unique((c.shape[-1],) + x.shape[c.ndim - 1 :]).size == x.ndim - 1 test(x, c) + + +def _f_1d(x): + """Test function for 1D FFT.""" + return np.cos(7 * x) + np.sin(x) - 33.2 + + +def _f_1d_nyquist_freq(): + return 7 + + +def _f_2d(x, y): + """Test function for 2D FFT.""" + x_freq, y_freq = 3, 5 + return ( + # something that's not separable + np.cos(x_freq * x) * np.sin(2 * x + y) + + np.sin(y_freq * y) * np.cos(x + 3 * y) + # DC terms + - 33.2 + + np.cos(x) + + np.cos(y) + ) + + +def _f_2d_nyquist_freq(): + # can just sum frequencies multiplied above thanks to fourier + x_freq, y_freq = 3, 5 + x_freq_nyquist = x_freq + 2 + y_freq_nyquist = y_freq + 3 + return x_freq_nyquist, y_freq_nyquist + + +def _identity(x): + return x + + +def _f_non_periodic(z): + return np.sin(np.sqrt(2) * z) * np.cos(1 / (2 + z)) * np.cos(z**2) * z + + +def _f_algebraic(z): + return z**3 - 10 * z**6 - z - np.e + z**4 + + +class TestFastInterp: + """Test fast interpolation.""" + + @pytest.mark.unit + @pytest.mark.parametrize("N", [2, 6, 7]) + def test_cheb_pts(self, N): + """Test we use Chebyshev points compatible with standard definition of DCT.""" + np.testing.assert_allclose(cheb_pts(N), chebpts1(N)[::-1], atol=1e-15) + np.testing.assert_allclose( + cheb_pts(N, domain=(-np.pi, np.pi), lobatto=True), + np.pi * chebpts2(N)[::-1], + atol=1e-15, + ) + + @pytest.mark.unit + @pytest.mark.parametrize("M", [1, 8, 9]) + def test_rfftfreq(self, M): + """Make sure numpy uses Nyquist interpolant frequencies.""" + np.testing.assert_allclose(np.fft.rfftfreq(M, d=1 / M), np.arange(M // 2 + 1)) + + @pytest.mark.unit + @pytest.mark.parametrize( + "func, n, domain", + [ + # Test cases chosen with purpose, don't remove any. + (_f_1d, 2 * _f_1d_nyquist_freq() + 1, (0, 2 * np.pi)), + (_f_1d, 2 * _f_1d_nyquist_freq(), (0, 2 * np.pi)), + (_f_1d, 2 * _f_1d_nyquist_freq() + 1, (-np.pi, np.pi)), + (_f_1d, 2 * _f_1d_nyquist_freq(), (-np.pi, np.pi)), + (lambda x: np.cos(7 * x), 2, (-np.pi / 7, np.pi / 7)), + (lambda x: np.sin(7 * x), 3, (-np.pi / 7, np.pi / 7)), + ], + ) + def test_interp_rfft(self, func, n, domain): + """Test non-uniform FFT interpolation.""" + x = np.linspace(domain[0], domain[1], n, endpoint=False) + f = func(x) + xq = np.array([7.34, 1.10134, 2.28]) + fq = func(xq) + np.testing.assert_allclose(interp_rfft(xq, f, domain), fq) + M = f.shape[-1] + basis = harmonic_vander(xq, M, domain) + coef = harmonic(rfft(f, norm="forward"), M) + np.testing.assert_allclose((basis * coef).sum(axis=-1), fq) + + @pytest.mark.unit + @pytest.mark.parametrize( + "func, m, n, domain0, domain1", + [ + # Test cases chosen with purpose, don't remove any. + ( + _f_2d, + 2 * _f_2d_nyquist_freq()[0] + 1, + 2 * _f_2d_nyquist_freq()[1] + 1, + (0, 2 * np.pi), + (0, 2 * np.pi), + ), + ( + _f_2d, + 2 * _f_2d_nyquist_freq()[0] + 1, + 2 * _f_2d_nyquist_freq()[1] + 1, + (-np.pi / 3, 5 * np.pi / 3), + (np.pi, 3 * np.pi), + ), + ( + lambda x, y: np.cos(30 * x) + np.sin(y) ** 2 + 1, + 2 * 30 // 30 + 1, + 2 * 2 + 1, + (0, 2 * np.pi / 30), + (np.pi, 3 * np.pi), + ), + ], + ) + def test_interp_rfft2(self, func, m, n, domain0, domain1): + """Test non-uniform FFT interpolation.""" + theta = np.array([7.34, 1.10134, 2.28, 1e3 * np.e]) + zeta = np.array([1.1, 3.78432, 8.542, 0]) + x = np.linspace(domain0[0], domain0[1], m, endpoint=False) + y = np.linspace(domain1[0], domain1[1], n, endpoint=False) + x, y = map(np.ravel, list(np.meshgrid(x, y, indexing="ij"))) + truth = func(theta, zeta) + f = func(x, y).reshape(m, n) + np.testing.assert_allclose( + interp_rfft2(theta, zeta, f, domain0, domain1, axes=(-2, -1)), + truth, + ) + np.testing.assert_allclose( + interp_rfft2(theta, zeta, f, domain0, domain1, axes=(-1, -2)), + truth, + ) + + @pytest.mark.unit + @pytest.mark.parametrize( + "f, M, lobatto", + [ + # Identity map known for bad Gibbs; if discrete Chebyshev transform + # implemented correctly then won't see Gibbs. + (_identity, 2, False), + (_identity, 3, False), + (_identity, 3, True), + (_identity, 4, True), + ], + ) + def test_dct(self, f, M, lobatto): + """Test discrete cosine transform interpolation. + + Parameters + ---------- + f : callable + Function to test. + M : int + Fourier spectral resolution. + lobatto : bool + Whether ``f`` should be sampled on the Gauss-Lobatto (extrema-plus-endpoint) + or interior roots grid for Chebyshev points. + + """ + # Need to test fft used in Fourier Chebyshev interpolation due to issues like + # https://github.com/scipy/scipy/issues/15033 + # https://github.com/scipy/scipy/issues/21198 + # https://github.com/google/jax/issues/22466. + domain = (0, 2 * np.pi) + m = cheb_pts(M, domain, lobatto) + n = cheb_pts(m.size * 10, domain, lobatto) + norm = (n.size - lobatto) / (m.size - lobatto) + + dct_type = 2 - lobatto + fq_1 = np.sqrt(norm) * sidct( + sdct(f(m), type=dct_type, norm="ortho", orthogonalize=False), + type=dct_type, + n=n.size, + norm="ortho", + orthogonalize=False, + ) + if lobatto: + # JAX has yet to implement type 1 DCT. + fq_2 = norm * sidct(sdct(f(m), type=dct_type), n=n.size, type=dct_type) + else: + fq_2 = norm * idct(dct(f(m), type=dct_type), n=n.size, type=dct_type) + np.testing.assert_allclose(fq_1, f(n), atol=1e-14) + # JAX is less accurate than scipy. + np.testing.assert_allclose(fq_2, f(n), atol=1e-6) + + fig, ax = plt.subplots() + ax.scatter(m, f(m)) + ax.plot(n, fq_1) + ax.plot(n, fq_2) + return fig + + @pytest.mark.unit + @pytest.mark.parametrize( + "f, M", + [(_f_non_periodic, 5), (_f_non_periodic, 6), (_f_algebraic, 7)], + ) + def test_interp_dct(self, f, M): + """Test non-uniform DCT interpolation.""" + c0 = chebinterpolate(f, M - 1) + assert not np.allclose( + c0, + cheb_from_dct(dct(f(chebpts1(M)), 2)) / M, + ), ( + "Interpolation should fail because cosine basis is in wrong domain, " + "yet the supplied test function was interpolated fine using this wrong " + "domain. Pick a better test function." + ) + # test interpolation + z = cheb_pts(M) + fz = f(z) + np.testing.assert_allclose(c0, cheb_from_dct(dct(fz, 2) / M), atol=1e-13) + if np.allclose(_f_algebraic(z), fz): # Should reconstruct exactly. + np.testing.assert_allclose( + cheb2poly(c0), + np.array([-np.e, -1, 0, 1, 1, 0, -10]), + atol=1e-13, + ) + # test evaluation + xq = np.arange(10 * 3 * 2).reshape(10, 3, 2) + xq = bijection_to_disc(xq, 0, xq.size) + fq = chebval(xq, c0, tensor=False) + np.testing.assert_allclose(fq, interp_dct(xq, fz), atol=1e-13) diff --git a/tests/test_stability_funs.py b/tests/test_stability_funs.py index 72b4819b3..7233b5ce4 100644 --- a/tests/test_stability_funs.py +++ b/tests/test_stability_funs.py @@ -416,6 +416,7 @@ def test_ballooning_geometry(tmpdir_factory): "g^ra", "g^rr", "cvdrift", + "gbdrift", "cvdrift0", "|B|", "B^zeta", @@ -448,11 +449,9 @@ def test_ballooning_geometry(tmpdir_factory): sign_psi = psi_s / np.abs(psi_s) sign_iota = iotas / np.abs(iotas) - modB = data["|B|"] x = Lref * np.sqrt(psi) shat = -x / iotas * shears / Lref - psi_r = data["psi_r"] grad_alpha = data["grad(alpha)"] g_sup_rr = data["g^rr"]