Skip to content

Commit

Permalink
Force push with lease to avoid diverging branch with remote due to co…
Browse files Browse the repository at this point in the history
…mmit 0a5216c
  • Loading branch information
unalmis committed Aug 20, 2024
1 parent ade0a5e commit 8197f71
Show file tree
Hide file tree
Showing 11 changed files with 550 additions and 409 deletions.
43 changes: 2 additions & 41 deletions desc/compute/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,13 @@

import copy
import inspect
from functools import partial

import numpy as np

from desc.backend import cond, execute_on_cpu, flatnonzero, fori_loop, jnp, put, take
from desc.backend import cond, execute_on_cpu, fori_loop, jnp, put
from desc.grid import ConcentricGrid, Grid, LinearGrid

from ..utils import errorif, setdefault, warnif
from ..utils import errorif, warnif
from .data_index import allowed_kwargs, data_index

# map from profile name to equilibrium parameter name
Expand Down Expand Up @@ -1580,41 +1579,3 @@ def body(i, mins):
# The above implementation was benchmarked to be more efficient than
# alternatives without explicit loops in GitHub pull request #501.
return grid.expand(mins, surface_label)


@partial(jnp.vectorize, signature="(m),(m)->(n)", excluded={"size", "fill_value"})
def take_mask(a, mask, size=None, fill_value=None):
"""JIT compilable method to return ``a[mask][:size]`` padded by ``fill_value``.
Parameters
----------
a : jnp.ndarray
The source array.
mask : jnp.ndarray
Boolean mask to index into ``a``. Should have same shape as ``a``.
size : int
Elements of ``a`` at the first size True indices of ``mask`` will be returned.
If there are fewer elements than size indicates, the returned array will be
padded with ``fill_value``. The size default is ``mask.size``.
fill_value : Any
When there are fewer than the indicated number of elements, the remaining
elements will be filled with ``fill_value``. Defaults to NaN for inexact types,
the largest negative value for signed types, the largest positive value for
unsigned types, and True for booleans.
Returns
-------
result : jnp.ndarray
Shape (size, ).
"""
assert a.shape == mask.shape
idx = flatnonzero(mask, size=setdefault(size, mask.size), fill_value=mask.size)
return take(
a,
idx,
mode="fill",
fill_value=fill_value,
unique_indices=True,
indices_are_sorted=True,
)
2 changes: 1 addition & 1 deletion desc/grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -742,7 +742,7 @@ def create_meshgrid(
rtz : rho, theta, zeta
period : tuple of float
Assumed periodicity for each coordinate.
Use np.inf to denote no periodicity.
Use ``np.inf`` to denote no periodicity.
NFP : int
Number of field periods (Default = 1).
Only makes sense to change from 1 if last coordinate is periodic
Expand Down
3 changes: 3 additions & 0 deletions desc/integrals/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
"""Classes for integration."""

from .fourier_bounce_integral import FourierChebyshevBasis, PiecewiseChebyshevBasis
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
from orthax.polynomial import polyvander

from desc.backend import dct, jnp, rfft, rfft2, take
from desc.compute._quad_utils import bijection_from_disc
from desc.compute.utils import safediv
from desc.integrals._quad_utils import bijection_from_disc
from desc.utils import Index, errorif


Expand Down Expand Up @@ -314,8 +314,8 @@ def interp_dct(xq, f, lobatto=False, axis=-1):
lobatto = bool(lobatto)
errorif(lobatto, NotImplementedError)
assert f.ndim >= 1
a = cheb_from_dct(
dct(f, type=2 - lobatto, axis=axis) / (f.shape[axis] - lobatto), axis
a = cheb_from_dct(dct(f, type=2 - lobatto, axis=axis), axis) / (
f.shape[axis] - lobatto
)
fq = idct_non_uniform(xq, a, f.shape[axis], axis)
return fq
Expand Down Expand Up @@ -345,7 +345,7 @@ def idct_non_uniform(xq, a, n, axis=-1):
assert a.ndim >= 1
a = jnp.moveaxis(a, axis, -1)
basis = chebvander(xq, n - 1)
# Could instead use Clenshaw recursion with ``fq=chebval(xq,a,tensor=False)``.
# Could use Clenshaw recursion with fq = chebval(xq, a, tensor=False).
fq = jnp.linalg.vecdot(basis, a)
return fq

Expand Down
File renamed without changes.
122 changes: 58 additions & 64 deletions desc/compute/bounce_integral.py → desc/integrals/bounce_integral.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,45 +2,34 @@

from functools import partial

import numpy as np
from interpax import CubicHermiteSpline, PPoly, interp1d
from jax.nn import softmax
from matplotlib import pyplot as plt
from numpy.polynomial.legendre import leggauss
from orthax.legendre import leggauss
from tests.test_interp_utils import filter_not_nan

from desc.backend import flatnonzero, imap, jnp, put
from desc.compute._interp_utils import poly_root, polyder_vec, polyval_vec
from desc.compute._quad_utils import (
from desc.integrals._interp_utils import poly_root, polyder_vec, polyval_vec
from desc.integrals._quad_utils import (
automorphism_sin,
bijection_from_disc,
grad_automorphism_sin,
grad_bijection_from_disc,
)
from desc.compute.utils import take_mask
from desc.utils import errorif, setdefault, warnif
from desc.utils import errorif, setdefault, take_mask, warnif


# use for debugging and testing
def _filter_not_nan(a, check=False):
"""Filter out nan from ``a`` while asserting nan is padded at right."""
is_nan = np.isnan(a)
if check:
assert np.array_equal(is_nan, np.sort(is_nan, axis=-1))
return a[~is_nan]


# use for debugging and testing
def _filter_nonzero_measure(bp1, bp2):
def filter_bounce_points(bp1, bp2):
"""Return only bounce points such that |bp2 - bp1| > 0."""
mask = (bp2 - bp1) != 0
mask = (bp2 - bp1) != 0.0
return bp1[mask], bp2[mask]


def plot_field_line(
B,
pitch=None,
bp1=np.array([]),
bp2=np.array([]),
bp1=jnp.array([]),
bp2=jnp.array([]),
start=None,
stop=None,
num=1000,
Expand All @@ -57,11 +46,11 @@ def plot_field_line(
----------
B : PPoly
Spline of |B| over given field line.
pitch : np.ndarray
pitch : jnp.ndarray
λ value.
bp1 : np.ndarray
bp1 : jnp.ndarray
Bounce points with (∂|B|/∂ζ)|ρ,α <= 0.
bp2 : np.ndarray
bp2 : jnp.ndarray
Bounce points with (∂|B|/∂ζ)|ρ,α >= 0.
start : float
Minimum ζ on plot.
Expand Down Expand Up @@ -90,9 +79,7 @@ def plot_field_line(
legend = {}

def add(lines):
if not hasattr(lines, "__iter__"):
lines = [lines]
for line in lines:
for line in setdefault(lines, [lines], hasattr(lines, "__iter__")):
label = line.get_label()
if label not in legend:
legend[label] = line
Expand All @@ -101,32 +88,32 @@ def add(lines):
if include_knots:
for knot in B.x:
add(ax.axvline(x=knot, color="tab:blue", alpha=alpha_knot, label="knot"))
z = np.linspace(
z = jnp.linspace(
start=setdefault(start, B.x[0]),
stop=setdefault(stop, B.x[-1]),
num=num,
)
add(ax.plot(z, B(z), label=r"$\vert B \vert (\zeta)$"))

if pitch is not None:
b = 1 / np.atleast_1d(pitch)
b = 1 / jnp.atleast_1d(pitch)
for val in b:
add(
ax.axhline(
val, color="tab:purple", alpha=alpha_pitch, label=r"$1 / \lambda$"
)
)
bp1, bp2 = np.atleast_2d(bp1, bp2)
bp1, bp2 = jnp.atleast_2d(bp1, bp2)
for i in range(bp1.shape[0]):
if bp1.shape == bp2.shape:
bp1_i, bp2_i = _filter_nonzero_measure(bp1[i], bp2[i])
bp1_i, bp2_i = filter_bounce_points(bp1[i], bp2[i])
else:
bp1_i, bp2_i = bp1[i], bp2[i]
bp1_i, bp2_i = map(_filter_not_nan, (bp1_i, bp2_i))
bp1_i, bp2_i = bp1_i[~jnp.isnan(bp1_i)], bp2_i[~jnp.isnan(bp2_i)]
add(
ax.scatter(
bp1_i,
np.full_like(bp1_i, b[i]),
jnp.full_like(bp1_i, b[i]),
marker="v",
color="tab:red",
label="bp1",
Expand All @@ -135,7 +122,7 @@ def add(lines):
add(
ax.scatter(
bp2_i,
np.full_like(bp2_i, b[i]),
jnp.full_like(bp2_i, b[i]),
marker="^",
color="tab:green",
label="bp2",
Expand All @@ -155,44 +142,55 @@ def add(lines):
return fig, ax


def _check_bounce_points(bp1, bp2, sentinel, pitch, knots, B_c, plot, **kwargs):
def _check_bounce_points(bp1, bp2, pitch, knots, B_c, plot, **kwargs):
"""Check that bounce points are computed correctly."""
bp1 = jnp.where(bp1 > sentinel, bp1, jnp.nan)
bp2 = jnp.where(bp2 > sentinel, bp2, jnp.nan)
assert bp1.shape == bp2.shape
mask = (bp1 - bp2) == 0
bp1 = jnp.where(mask, jnp.nan, bp1)
bp2 = jnp.where(mask, jnp.nan, bp2)

eps = jnp.finfo(jnp.array(1.0).dtype).eps * 10
P, S = bp1.shape[:-1]
msg_1 = "Bounce points have an inversion."
msg_1 = "Bounce points have an inversion.\n"
err_1 = jnp.any(bp1 > bp2, axis=-1)
msg_2 = "Discontinuity detected."
msg_2 = "Discontinuity detected.\n"
err_2 = jnp.any(bp1[..., 1:] < bp2[..., :-1], axis=-1)

P, S, _ = bp1.shape
for s in range(S):
B = PPoly(B_c[:, s], knots)
for p in range(P):
B_mid = B((bp1[p, s] + bp2[p, s]) / 2)
err_3 = jnp.any(B_mid > 1 / pitch[p, s] + eps)
B_m_ps = B((bp1[p, s] + bp2[p, s]) / 2)
err_3 = jnp.any(B_m_ps > 1 / pitch[p, s] + eps)
if err_1[p, s] or err_2[p, s] or err_3:
bp1_p = _filter_not_nan(bp1[p, s], check=True)
bp2_p = _filter_not_nan(bp2[p, s], check=True)
B_mid = _filter_not_nan(B_mid, check=True)
bp1_ps, bp2_ps, B_m_ps = map(
filter_not_nan, (bp1[p, s], bp2[p, s], B_m_ps)
)
if plot:
plot_field_line(
B, pitch[p, s], bp1_p, bp2_p, title_id=f"{p},{s}", **kwargs
B,
pitch[p, s],
bp1_ps,
bp2_ps,
title_id=f"{p},{s}",
**kwargs,
)
print("bp1:", bp1_p)
print("bp2:", bp2_p)
print("bp1:", bp1_ps)
print("bp2:", bp2_ps)
assert not err_1[p, s], msg_1
assert not err_2[p, s], msg_2
msg_3 = (
f"Detected B midpoint = {B_mid}>{1 / pitch[p, s] + eps} = 1/pitch. "
"You need to use more knots or, if that is infeasible, switch to a "
"monotonic spline method.\n"
f"Detected |B| = {B_m_ps} > {1 / pitch[p, s] + eps} = 1/λ in well. "
"Use more knots or switch to a monotonic spline method.\n"
)
assert not err_3, msg_3
if plot:
plot_field_line(
B, pitch[:, s], bp1[:, s], bp2[:, s], title_id=str(s), **kwargs
B,
pitch[:, s],
bp1[:, s],
bp2[:, s],
title_id=str(s),
**kwargs,
)


Expand Down Expand Up @@ -334,7 +332,7 @@ def bounce_points(
a_min=jnp.array([0.0]),
a_max=jnp.diff(knots),
sort=True,
sentinel=-1,
sentinel=-1.0,
distinct=True,
)
assert intersect.shape == (P, S, N, degree)
Expand All @@ -356,13 +354,14 @@ def bounce_points(
bp1 = take_mask(intersect, is_bp1, size=num_well, fill_value=sentinel)
bp2 = take_mask(intersect, is_bp2, size=num_well, fill_value=sentinel)

if check:
_check_bounce_points(bp1, bp2, sentinel, pitch, knots, B_c, plot, **kwargs)

mask = (bp1 > sentinel) & (bp2 > sentinel)
# Set outside mask to same value so integration is over set of measure zero.
bp1 = jnp.where(mask, bp1, 0)
bp2 = jnp.where(mask, bp2, 0)
bp1 = jnp.where(mask, bp1, 0.0)
bp2 = jnp.where(mask, bp2, 0.0)

if check:
_check_bounce_points(bp1, bp2, pitch, knots, B_c, plot, **kwargs)

return bp1, bp2


Expand Down Expand Up @@ -626,12 +625,7 @@ def _bounce_quadrature(
Parameters
----------
bp1 : jnp.ndarray
Shape (P, S, num_well).
The field line-following ζ coordinates of bounce points for a given pitch along
a field line. The pairs ``bp1[i,j,k]`` and ``bp2[i,j,k]`` form left and right
integration boundaries, respectively, for the bounce integrals.
bp2 : jnp.ndarray
bp1, bp2 : jnp.ndarray
Shape (P, S, num_well).
The field line-following ζ coordinates of bounce points for a given pitch along
a field line. The pairs ``bp1[i,j,k]`` and ``bp2[i,j,k]`` form left and right
Expand Down Expand Up @@ -876,7 +870,7 @@ def bounce_integral(
if automorphism is not None:
auto, grad_auto = automorphism
w = w * grad_auto(x)
# Recall affine_bijection(auto(x), ζ_b₁, ζ_b₂) = ζ.
# Recall bijection_from_disc(auto(x), ζ_b₁, ζ_b₂) = ζ.
x = auto(x)

def bounce_integrate(
Expand Down
Loading

0 comments on commit 8197f71

Please sign in to comment.