Skip to content

Commit

Permalink
Bounce averaging (#854)
Browse files Browse the repository at this point in the history
This PR adds functionality to compute bounce averages in DESC

- [x] Differentiable algorithm to compute bounce points and integrals.
- [x] Works with any numerical quadrature
- [x] Fixed bugs with numpy compatibility.

Related
- #1003 
- #1042 
- #1229 
- #1196
  • Loading branch information
unalmis authored Sep 3, 2024
2 parents 60de6fc + 917ad1c commit d2e9a2c
Show file tree
Hide file tree
Showing 31 changed files with 3,095 additions and 81 deletions.
119 changes: 77 additions & 42 deletions desc/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,19 +66,23 @@
)

if use_jax: # noqa: C901 - FIXME: simplify this, define globally and then assign?
jit = jax.jit
fori_loop = jax.lax.fori_loop
cond = jax.lax.cond
switch = jax.lax.switch
while_loop = jax.lax.while_loop
vmap = jax.vmap
bincount = jnp.bincount
repeat = jnp.repeat
take = jnp.take
scan = jax.lax.scan
from jax import custom_jvp
from jax import custom_jvp, jit, vmap

imap = jax.lax.map
from jax.experimental.ode import odeint
from jax.scipy.linalg import block_diag, cho_factor, cho_solve, qr, solve_triangular
from jax.lax import cond, fori_loop, scan, switch, while_loop
from jax.nn import softmax as softargmax
from jax.numpy import bincount, flatnonzero, repeat, take
from jax.numpy.fft import irfft, rfft, rfft2
from jax.scipy.fft import dct, idct
from jax.scipy.linalg import (
block_diag,
cho_factor,
cho_solve,
eigh_tridiagonal,
qr,
solve_triangular,
)
from jax.scipy.special import gammaln, logsumexp
from jax.tree_util import (
register_pytree_node,
Expand All @@ -90,6 +94,10 @@
treedef_is_leaf,
)

trapezoid = (
jnp.trapezoid if hasattr(jnp, "trapezoid") else jax.scipy.integrate.trapezoid
)

def put(arr, inds, vals):
"""Functional interface for array "fancy indexing".
Expand Down Expand Up @@ -328,6 +336,8 @@ def root(
This routine may be used on over or under-determined systems, in which case it
will solve it in a least squares / least norm sense.
"""
from desc.compute.utils import safenorm

if fixup is None:
fixup = lambda x, *args: x
if jac is None:
Expand Down Expand Up @@ -392,7 +402,7 @@ def tangent_solve(g, y):
x, (res, niter) = jax.lax.custom_root(
res, x0, solve, tangent_solve, has_aux=True
)
return x, (jnp.linalg.norm(res), niter)
return x, (safenorm(res), niter)


# we can't really test the numpy backend stuff in automated testing, so we ignore it
Expand All @@ -401,15 +411,54 @@ def tangent_solve(g, y):
jit = lambda func, *args, **kwargs: func
execute_on_cpu = lambda func: func
import scipy.optimize
from numpy.fft import irfft, rfft, rfft2 # noqa: F401
from scipy.fft import dct, idct # noqa: F401
from scipy.integrate import odeint # noqa: F401
from scipy.linalg import ( # noqa: F401
block_diag,
cho_factor,
cho_solve,
eigh_tridiagonal,
qr,
solve_triangular,
)
from scipy.special import gammaln, logsumexp # noqa: F401
from scipy.special import softmax as softargmax # noqa: F401

trapezoid = np.trapezoid if hasattr(np, "trapezoid") else np.trapz

def imap(f, xs, batch_size=None, in_axes=0, out_axes=0):
"""Generalizes jax.lax.map; uses numpy."""
if not isinstance(xs, np.ndarray):
raise NotImplementedError(
"Require numpy array input, or install jax to support pytrees."
)
xs = np.moveaxis(xs, source=in_axes, destination=0)
return np.stack([f(x) for x in xs], axis=out_axes)

def vmap(fun, in_axes=0, out_axes=0):
"""A numpy implementation of jax.lax.map whose API is a subset of jax.vmap.
Like Python's builtin map,
except inputs and outputs are in the form of stacked arrays,
and the returned object is a vectorized version of the input function.
Parameters
----------
fun: callable
Function (A -> B)
in_axes: int
Axis to map over.
out_axes: int
An integer indicating where the mapped axis should appear in the output.
Returns
-------
fun_vmap: callable
Vectorized version of fun.
"""
return lambda xs: imap(fun, xs, in_axes=in_axes, out_axes=out_axes)

def tree_stack(*args, **kwargs):
"""Stack pytree for numpy backend."""
Expand Down Expand Up @@ -592,32 +641,6 @@ def while_loop(cond_fun, body_fun, init_val):
val = body_fun(val)
return val

def vmap(fun, out_axes=0):
"""A numpy implementation of jax.lax.map whose API is a subset of jax.vmap.
Like Python's builtin map,
except inputs and outputs are in the form of stacked arrays,
and the returned object is a vectorized version of the input function.
Parameters
----------
fun: callable
Function (A -> B)
out_axes: int
An integer indicating where the mapped axis should appear in the output.
Returns
-------
fun_vmap: callable
Vectorized version of fun.
"""

def fun_vmap(fun_inputs):
return np.stack([fun(fun_input) for fun_input in fun_inputs], axis=out_axes)

return fun_vmap

def scan(f, init, xs, length=None, reverse=False, unroll=1):
"""Scan a function over leading array axes while carrying along state.
Expand Down Expand Up @@ -657,9 +680,14 @@ def scan(f, init, xs, length=None, reverse=False, unroll=1):
ys.append(y)
return carry, np.stack(ys)

def bincount(x, weights=None, minlength=None, length=None):
"""Same as np.bincount but with a dummy parameter to match jnp.bincount API."""
return np.bincount(x, weights, minlength)
def bincount(x, weights=None, minlength=0, length=None):
"""A numpy implementation of jnp.bincount."""
x = np.clip(x, 0, None)
if length is None:
length = max(minlength, x.max() + 1)
else:
minlength = max(minlength, length)
return np.bincount(x, weights, minlength)[:length]

def repeat(a, repeats, axis=None, total_repeat_length=None):
"""A numpy implementation of jnp.repeat."""
Expand Down Expand Up @@ -778,6 +806,13 @@ def root(
out = scipy.optimize.root(fun, x0, args, jac=jac, tol=tol)
return out.x, out

def flatnonzero(a, size=None, fill_value=0):
"""A numpy implementation of jnp.flatnonzero."""
nz = np.flatnonzero(a)
if size is not None:
nz = np.pad(nz, (0, max(size - nz.size, 0)), constant_values=fill_value)
return nz

def take(
a,
indices,
Expand Down
2 changes: 1 addition & 1 deletion desc/compute/_bootstrap.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from scipy.special import roots_legendre

from ..backend import fori_loop, jnp
from ..integrals import surface_averages_map
from ..integrals.surface_integral import surface_averages_map
from .data_index import register_compute_fun


Expand Down
2 changes: 1 addition & 1 deletion desc/compute/_equil.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

from desc.backend import jnp

from ..integrals import surface_averages
from ..integrals.surface_integral import surface_averages
from .data_index import register_compute_fun
from .utils import cross, dot, safediv, safenorm

Expand Down
2 changes: 1 addition & 1 deletion desc/compute/_field.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

from desc.backend import jnp

from ..integrals import (
from ..integrals.surface_integral import (
surface_averages,
surface_integrals_map,
surface_max,
Expand Down
2 changes: 1 addition & 1 deletion desc/compute/_metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

from desc.backend import jnp

from ..integrals import surface_averages
from ..integrals.surface_integral import surface_averages
from .data_index import register_compute_fun
from .utils import cross, dot, safediv, safenorm

Expand Down
2 changes: 1 addition & 1 deletion desc/compute/_profiles.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

from desc.backend import cond, jnp

from ..integrals import surface_averages, surface_integrals
from ..integrals.surface_integral import surface_averages, surface_integrals
from .data_index import register_compute_fun
from .utils import cumtrapz, dot, safediv

Expand Down
2 changes: 1 addition & 1 deletion desc/compute/_stability.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

from desc.backend import jnp

from ..integrals import surface_integrals_map
from ..integrals.surface_integral import surface_integrals_map
from .data_index import register_compute_fun
from .utils import dot

Expand Down
7 changes: 5 additions & 2 deletions desc/equilibrium/coords.py
Original file line number Diff line number Diff line change
Expand Up @@ -685,11 +685,14 @@ def get_rtz_grid(
rvp : rho, theta_PEST, phi
rtz : rho, theta, zeta
period : tuple of float
Assumed periodicity for each quantity in inbasis.
Assumed periodicity for functions of the given coordinates.
Use ``np.inf`` to denote no periodicity.
jitable : bool, optional
If false the returned grid has additional attributes.
Required to be false to retain nodes at magnetic axis.
kwargs
Additional parameters to supply to the coordinate mapping function.
See ``desc.equilibrium.coords.map_coordinates``.
Returns
-------
Expand All @@ -701,7 +704,7 @@ def get_rtz_grid(
[radial, poloidal, toroidal], coordinates=coordinates, period=period
)
if "iota" in kwargs:
kwargs["iota"] = grid.expand(kwargs["iota"])
kwargs["iota"] = grid.expand(jnp.atleast_1d(kwargs["iota"]))
inbasis = {
"r": "rho",
"t": "theta",
Expand Down
6 changes: 5 additions & 1 deletion desc/equilibrium/equilibrium.py
Original file line number Diff line number Diff line change
Expand Up @@ -1255,7 +1255,11 @@ def compute_theta_coords(
point. Only returned if ``full_output`` is True.
"""
warnif(True, DeprecationWarning, msg="Use map_coordinates instead.")
warnif(
True,
DeprecationWarning,
"Use map_coordinates instead of compute_theta_coords.",
)
return map_coordinates(
self,
flux_coords,
Expand Down
18 changes: 13 additions & 5 deletions desc/grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -619,6 +619,7 @@ def meshgrid_reshape(self, x, order):
-------
x : ndarray
Data reshaped to align with grid nodes.
"""
errorif(
not self.is_meshgrid,
Expand All @@ -637,7 +638,8 @@ def meshgrid_reshape(self, x, order):
vec = True
shape += (-1,)
x = x.reshape(shape, order="F")
x = jnp.moveaxis(x, 1, 0) # now shape rtz/raz etc
# swap to change shape from trz/arz to rtz/raz etc.
x = jnp.swapaxes(x, 1, 0)
newax = tuple(self.coordinates.index(c) for c in order)
if vec:
newax += (3,)
Expand Down Expand Up @@ -788,10 +790,11 @@ def create_meshgrid(
rtz : rho, theta, zeta
period : tuple of float
Assumed periodicity for each coordinate.
Use np.inf to denote no periodicity.
Use ``np.inf`` to denote no periodicity.
NFP : int
Number of field periods (Default = 1).
Only makes sense to change from 1 if ``period[2]==2π``.
Only makes sense to change from 1 if last coordinate is periodic
with some constant divided by ``NFP``.
Returns
-------
Expand Down Expand Up @@ -1885,8 +1888,13 @@ def _periodic_spacing(x, period=2 * jnp.pi, sort=False, jnp=jnp):
x = jnp.sort(x, axis=0)
# choose dx to be half the distance between its neighbors
if x.size > 1:
dx_0 = x[1] + (period - x[-1]) % period
dx_1 = x[0] + (period - x[-2]) % period
if np.isfinite(period):
dx_0 = x[1] + (period - x[-1]) % period
dx_1 = x[0] + (period - x[-2]) % period
else:
# just set to 0 to stop nan gradient, even though above gives expected value
dx_0 = 0
dx_1 = 0
if x.size == 2:
# then dx[0] == period and dx[-1] == 0, so fix this
dx_1 = dx_0
Expand Down
1 change: 1 addition & 0 deletions desc/integrals/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Classes for function integration."""

from .bounce_integral import Bounce1D
from .singularities import (
DFTInterpolator,
FFTInterpolator,
Expand Down
Loading

0 comments on commit d2e9a2c

Please sign in to comment.