Skip to content

Commit

Permalink
change field line integration to use diffrax (#610)
Browse files Browse the repository at this point in the history
Updated comment by @f0uriest 

- `field_line_integrate` now uses `diffrax` instead of
`jax.experimental.ode.odeint` which has been soft deprecated.
- Allows better control of bounding box (hard stop of field lines rather
than ad-hoc expoential decay)
- In practice seems to be a bit faster. `test_plot_poincare` which calls
`field_line_integrate` under the hood went from ~65s to ~55s with these
changes.


Resolves #609
  • Loading branch information
ddudt authored Sep 18, 2024
2 parents a4220a4 + acf1c98 commit 54becdb
Show file tree
Hide file tree
Showing 5 changed files with 118 additions and 77 deletions.
136 changes: 73 additions & 63 deletions desc/magnetic_fields/_core.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,23 @@
"""Classes for magnetic fields."""

import warnings
from abc import ABC, abstractmethod
from collections.abc import MutableSequence

import numpy as np
import scipy.linalg
from diffrax import (
DiscreteTerminatingEvent,
ODETerm,
PIDController,
SaveAt,
Tsit5,
diffeqsolve,
)
from interpax import approx_df, interp1d, interp2d, interp3d
from netCDF4 import Dataset, chartostring, stringtochar

from desc.backend import fori_loop, jit, jnp, odeint, sign
from desc.backend import fori_loop, jit, jnp, sign
from desc.basis import (
ChebyshevDoubleFourierBasis,
ChebyshevPolynomial,
Expand Down Expand Up @@ -2175,11 +2184,13 @@ def field_line_integrate(
rtol=1e-8,
atol=1e-8,
maxstep=1000,
min_step_size=1e-8,
solver=Tsit5(),
bounds_R=(0, np.inf),
bounds_Z=(-np.inf, np.inf),
decay_accel=1e6,
**kwargs,
):
"""Trace field lines by integration.
"""Trace field lines by integration, using diffrax package.
Parameters
----------
Expand All @@ -2191,34 +2202,29 @@ def field_line_integrate(
and the negative toroidal angle for negative Bphi
field : MagneticField
source of magnetic field to integrate
params: dict
params: dict, optional
parameters passed to field
source_grid : Grid, optional
Collocation points used to discretize source field.
rtol, atol : float
relative and absolute tolerances for ode integration
maxstep : int
maximum number of steps between different phis
min_step_size: float
minimum step size (in phi) that the integration can take. default is 1e-8
solver: diffrax.Solver
diffrax Solver object to use in integration,
defaults to Tsit5(), a RK45 explicit solver
bounds_R : tuple of (float,float), optional
R bounds for field line integration bounding box.
If supplied, the RHS of the field line equations will be
multiplied by exp(-r) where r is the distance to the bounding box,
this is meant to prevent the field lines which escape to infinity from
slowing the integration down by being traced to infinity.
defaults to (0,np.inf)
R bounds for field line integration bounding box. Trajectories that leave this
box will be stopped, and NaN returned for points outside the box.
Defaults to (0,np.inf)
bounds_Z : tuple of (float,float), optional
Z bounds for field line integration bounding box.
If supplied, the RHS of the field line equations will be
multiplied by exp(-r) where r is the distance to the bounding box,
this is meant to prevent the field lines which escape to infinity from
slowing the integration down by being traced to infinity.
Z bounds for field line integration bounding box. Trajectories that leave this
box will be stopped, and NaN returned for points outside the box.
Defaults to (-np.inf,np.inf)
decay_accel : float, optional
An extra factor to the exponential that decays the RHS, i.e.
the RHS is multiplied by exp(-r * decay_accel), this is to
accelerate the decay of the RHS and stop the integration sooner
after exiting the bounds. Defaults to 1e6
kwargs: dict
keyword arguments to be passed into the ``diffrax.diffeqsolve``
Returns
-------
Expand All @@ -2228,60 +2234,64 @@ def field_line_integrate(
"""
r0, z0, phis = map(jnp.asarray, (r0, z0, phis))
assert r0.shape == z0.shape, "r0 and z0 must have the same shape"
assert decay_accel > 0, "decay_accel must be positive"
rshape = r0.shape
r0 = r0.flatten()
z0 = z0.flatten()
x0 = jnp.array([r0, phis[0] * jnp.ones_like(r0), z0]).T

@jit
def odefun(rpz, s):
def odefun(s, rpz, args):
rpz = rpz.reshape((3, -1)).T
r = rpz[:, 0]
z = rpz[:, 2]
# if bounds are given, will decay the magnetic field line eqn
# RHS if the trajectory is outside of bounds to avoid
# integrating the field line to infinity, which is costly
# and not useful in most cases
decay_factor = jnp.where(
jnp.array(
[
jnp.less(r, bounds_R[0]),
jnp.greater(r, bounds_R[1]),
jnp.less(z, bounds_Z[0]),
jnp.greater(z, bounds_Z[1]),
]
),
jnp.array(
[
# we multiply by decay_accel to accelerate the decay so that the
# integration is stopped soon after the bounds are exited.
jnp.exp(-(decay_accel * (r - bounds_R[0]) ** 2)),
jnp.exp(-(decay_accel * (r - bounds_R[1]) ** 2)),
jnp.exp(-(decay_accel * (z - bounds_Z[0]) ** 2)),
jnp.exp(-(decay_accel * (z - bounds_Z[1]) ** 2)),
]
),
1.0,
)
# multiply all together, the conditions that are not violated
# are just one while the violated ones are continuous decaying exponentials
decay_factor = jnp.prod(decay_factor, axis=0)

br, bp, bz = field.compute_magnetic_field(
rpz, params, basis="rpz", source_grid=source_grid
).T
return (
decay_factor
* jnp.array(
[r * br / bp * jnp.sign(bp), jnp.sign(bp), r * bz / bp * jnp.sign(bp)]
).squeeze()
)
return jnp.array(
[r * br / bp * jnp.sign(bp), jnp.sign(bp), r * bz / bp * jnp.sign(bp)]
).squeeze()

# diffrax parameters

def default_terminating_event_fxn(state, **kwargs):
R_out = jnp.any(jnp.array([state.y[0] < bounds_R[0], state.y[0] > bounds_R[1]]))
Z_out = jnp.any(jnp.array([state.y[2] < bounds_Z[0], state.y[2] > bounds_Z[1]]))
return jnp.any(jnp.array([R_out, Z_out]))

kwargs.setdefault(
"stepsize_controller", PIDController(rtol=rtol, atol=atol, dtmin=min_step_size)
)
kwargs.setdefault(
"discrete_terminating_event",
DiscreteTerminatingEvent(default_terminating_event_fxn),
)

term = ODETerm(odefun)
saveat = SaveAt(ts=phis)

intfun = lambda x: diffeqsolve(
term,
solver,
y0=x,
t0=phis[0],
t1=phis[-1],
saveat=saveat,
max_steps=maxstep * len(phis),
dt0=min_step_size,
**kwargs,
).ys

# suppress warnings till its fixed upstream:
# https://github.com/patrick-kidger/diffrax/issues/445
# also ignore deprecation warning for now until we actually need to deal with it
with warnings.catch_warnings():
warnings.filterwarnings("ignore", message="unhashable type")
warnings.filterwarnings("ignore", message="`diffrax.*discrete_terminating")
x = jnp.vectorize(intfun, signature="(k)->(n,k)")(x0)

x = jnp.where(jnp.isinf(x), jnp.nan, x)
r = x[:, :, 0].squeeze().T.reshape((len(phis), *rshape))
z = x[:, :, 2].squeeze().T.reshape((len(phis), *rshape))

intfun = lambda x: odeint(odefun, x, phis, rtol=rtol, atol=atol, mxstep=maxstep)
x = jnp.vectorize(intfun, signature="(k)->(n,k)")(x0)
r = x[:, :, 0].T.reshape((len(phis), *rshape))
z = x[:, :, 2].T.reshape((len(phis), *rshape))
return r, z


Expand Down
1 change: 1 addition & 0 deletions devtools/dev-requirements_conda.yml
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ dependencies:
- pip:
# Conda only parses a single list of pip requirements.
# If two pip lists are given, all but the last list is skipped.
- diffrax >= 0.4.1
- interpax >= 0.3.3
- jax[cpu] >= 0.3.2, < 0.5.0
- nvgpu
Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
colorama
diffrax >= 0.4.1
h5py >= 3.0.0, < 4.0
interpax >= 0.3.3
jax[cpu] >= 0.3.2, < 0.5.0
Expand Down
1 change: 1 addition & 0 deletions requirements_conda.yml
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ name: desc-env
dependencies:
# standard install
- colorama
- diffrax >= 0.4.1
- h5py >= 3.0.0, < 4.0
- matplotlib >= 3.5.0, < 4.0.0
- mpmath >= 1.0.0, < 2.0
Expand Down
56 changes: 42 additions & 14 deletions tests/test_magnetic_fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import numpy as np
import pytest
from diffrax import Dopri5
from scipy.constants import mu_0

from desc.backend import jit, jnp
Expand Down Expand Up @@ -1038,24 +1039,51 @@ def test_field_line_integrate(self):
np.testing.assert_allclose(z[-1], 0.001, rtol=1e-6, atol=1e-6)

@pytest.mark.unit
def test_field_line_integrate_bounds(self):
"""Test field line integration with bounding box."""
def test_field_line_integrate_long(self):
"""Test field line integration for long distance along line."""
# q=4, field line should rotate 1/4 turn after 1 toroidal transit
# from outboard midplane to top center
field = ToroidalMagneticField(2, 10) + PoloidalMagneticField(2, 10, 0.25)
# test that bounds work correctly, and stop integration when trajectory
# hits the bounds
r0 = [10.1]
r0 = [10.001]
z0 = [0.0]
phis = [0, 2 * np.pi]
# this will hit the R bound
# (there is no Z bound, and R would go to 10.0 if not bounded)
r, z = field_line_integrate(r0, z0, phis, field, bounds_R=(10.05, np.inf))
np.testing.assert_allclose(r[-1], 10.05, rtol=3e-4)
# this will hit the Z bound
# (there is no R bound, and Z would go to 0.1 if not bounded)
r, z = field_line_integrate(r0, z0, phis, field, bounds_Z=(-np.inf, 0.05))
np.testing.assert_allclose(z[-1], 0.05, atol=3e-3)
phis = [0, 2 * np.pi * 25]
r, z = field_line_integrate(r0, z0, phis, field, solver=Dopri5())
np.testing.assert_allclose(r[-1], 10, rtol=1e-6, atol=1e-6)
np.testing.assert_allclose(z[-1], 0.001, rtol=1e-6, atol=1e-6)

@pytest.mark.unit
def test_field_line_integrate_early_terminate_default(self):
"""Test field line integration with default early termination criterion."""
# q=4, field line should rotate 1/4 turn after 1 toroidal transit
# from outboard midplane to top center
# early terminate when it crosses towards the inboard side (R=10),
field1 = ToroidalMagneticField(2, 10) + PoloidalMagneticField(2, 10, 0.25)
# make a SplineMagneticField only defined in a tiny region around initial point
field = SplineMagneticField.from_field(
field=field1,
R=np.linspace(10.0, 10.005, 40),
phi=np.linspace(0, 2 * np.pi, 40),
Z=np.linspace(-5e-3, 5e-3, 40),
extrap=True,
)
r0 = [10.001]
z0 = [0.0]
phis = [0, 2 * np.pi, 2 * np.pi * 2]

r, z = field_line_integrate(
r0,
z0,
phis,
field,
bounds_R=(np.min(field._R), np.max(field._R)),
bounds_Z=(np.min(field._Z), np.max(field._Z)),
min_step_size=1e-2,
)
np.testing.assert_allclose(r[1], 10, rtol=1e-6, atol=1e-6)
np.testing.assert_allclose(z[1], 0.001, rtol=1e-6, atol=1e-6)
# if early terinated, the values at the un-integrated phi points are inf
assert np.isnan(r[-1])
assert np.isnan(z[-1])

@pytest.mark.unit
def test_Bnormal_calculation(self):
Expand Down

0 comments on commit 54becdb

Please sign in to comment.