diff --git a/desc/magnetic_fields/_core.py b/desc/magnetic_fields/_core.py index 760f4e372..bc1d7bad8 100644 --- a/desc/magnetic_fields/_core.py +++ b/desc/magnetic_fields/_core.py @@ -1,14 +1,23 @@ """Classes for magnetic fields.""" +import warnings from abc import ABC, abstractmethod from collections.abc import MutableSequence import numpy as np import scipy.linalg +from diffrax import ( + DiscreteTerminatingEvent, + ODETerm, + PIDController, + SaveAt, + Tsit5, + diffeqsolve, +) from interpax import approx_df, interp1d, interp2d, interp3d from netCDF4 import Dataset, chartostring, stringtochar -from desc.backend import fori_loop, jit, jnp, odeint, sign +from desc.backend import fori_loop, jit, jnp, sign from desc.basis import ( ChebyshevDoubleFourierBasis, ChebyshevPolynomial, @@ -2175,11 +2184,13 @@ def field_line_integrate( rtol=1e-8, atol=1e-8, maxstep=1000, + min_step_size=1e-8, + solver=Tsit5(), bounds_R=(0, np.inf), bounds_Z=(-np.inf, np.inf), - decay_accel=1e6, + **kwargs, ): - """Trace field lines by integration. + """Trace field lines by integration, using diffrax package. Parameters ---------- @@ -2191,7 +2202,7 @@ def field_line_integrate( and the negative toroidal angle for negative Bphi field : MagneticField source of magnetic field to integrate - params: dict + params: dict, optional parameters passed to field source_grid : Grid, optional Collocation points used to discretize source field. @@ -2199,26 +2210,21 @@ def field_line_integrate( relative and absolute tolerances for ode integration maxstep : int maximum number of steps between different phis + min_step_size: float + minimum step size (in phi) that the integration can take. default is 1e-8 + solver: diffrax.Solver + diffrax Solver object to use in integration, + defaults to Tsit5(), a RK45 explicit solver bounds_R : tuple of (float,float), optional - R bounds for field line integration bounding box. - If supplied, the RHS of the field line equations will be - multiplied by exp(-r) where r is the distance to the bounding box, - this is meant to prevent the field lines which escape to infinity from - slowing the integration down by being traced to infinity. - defaults to (0,np.inf) + R bounds for field line integration bounding box. Trajectories that leave this + box will be stopped, and NaN returned for points outside the box. + Defaults to (0,np.inf) bounds_Z : tuple of (float,float), optional - Z bounds for field line integration bounding box. - If supplied, the RHS of the field line equations will be - multiplied by exp(-r) where r is the distance to the bounding box, - this is meant to prevent the field lines which escape to infinity from - slowing the integration down by being traced to infinity. + Z bounds for field line integration bounding box. Trajectories that leave this + box will be stopped, and NaN returned for points outside the box. Defaults to (-np.inf,np.inf) - decay_accel : float, optional - An extra factor to the exponential that decays the RHS, i.e. - the RHS is multiplied by exp(-r * decay_accel), this is to - accelerate the decay of the RHS and stop the integration sooner - after exiting the bounds. Defaults to 1e6 - + kwargs: dict + keyword arguments to be passed into the ``diffrax.diffeqsolve`` Returns ------- @@ -2228,60 +2234,64 @@ def field_line_integrate( """ r0, z0, phis = map(jnp.asarray, (r0, z0, phis)) assert r0.shape == z0.shape, "r0 and z0 must have the same shape" - assert decay_accel > 0, "decay_accel must be positive" rshape = r0.shape r0 = r0.flatten() z0 = z0.flatten() x0 = jnp.array([r0, phis[0] * jnp.ones_like(r0), z0]).T @jit - def odefun(rpz, s): + def odefun(s, rpz, args): rpz = rpz.reshape((3, -1)).T r = rpz[:, 0] - z = rpz[:, 2] - # if bounds are given, will decay the magnetic field line eqn - # RHS if the trajectory is outside of bounds to avoid - # integrating the field line to infinity, which is costly - # and not useful in most cases - decay_factor = jnp.where( - jnp.array( - [ - jnp.less(r, bounds_R[0]), - jnp.greater(r, bounds_R[1]), - jnp.less(z, bounds_Z[0]), - jnp.greater(z, bounds_Z[1]), - ] - ), - jnp.array( - [ - # we multiply by decay_accel to accelerate the decay so that the - # integration is stopped soon after the bounds are exited. - jnp.exp(-(decay_accel * (r - bounds_R[0]) ** 2)), - jnp.exp(-(decay_accel * (r - bounds_R[1]) ** 2)), - jnp.exp(-(decay_accel * (z - bounds_Z[0]) ** 2)), - jnp.exp(-(decay_accel * (z - bounds_Z[1]) ** 2)), - ] - ), - 1.0, - ) - # multiply all together, the conditions that are not violated - # are just one while the violated ones are continuous decaying exponentials - decay_factor = jnp.prod(decay_factor, axis=0) - br, bp, bz = field.compute_magnetic_field( rpz, params, basis="rpz", source_grid=source_grid ).T - return ( - decay_factor - * jnp.array( - [r * br / bp * jnp.sign(bp), jnp.sign(bp), r * bz / bp * jnp.sign(bp)] - ).squeeze() - ) + return jnp.array( + [r * br / bp * jnp.sign(bp), jnp.sign(bp), r * bz / bp * jnp.sign(bp)] + ).squeeze() + + # diffrax parameters + + def default_terminating_event_fxn(state, **kwargs): + R_out = jnp.any(jnp.array([state.y[0] < bounds_R[0], state.y[0] > bounds_R[1]])) + Z_out = jnp.any(jnp.array([state.y[2] < bounds_Z[0], state.y[2] > bounds_Z[1]])) + return jnp.any(jnp.array([R_out, Z_out])) + + kwargs.setdefault( + "stepsize_controller", PIDController(rtol=rtol, atol=atol, dtmin=min_step_size) + ) + kwargs.setdefault( + "discrete_terminating_event", + DiscreteTerminatingEvent(default_terminating_event_fxn), + ) + + term = ODETerm(odefun) + saveat = SaveAt(ts=phis) + + intfun = lambda x: diffeqsolve( + term, + solver, + y0=x, + t0=phis[0], + t1=phis[-1], + saveat=saveat, + max_steps=maxstep * len(phis), + dt0=min_step_size, + **kwargs, + ).ys + + # suppress warnings till its fixed upstream: + # https://github.com/patrick-kidger/diffrax/issues/445 + # also ignore deprecation warning for now until we actually need to deal with it + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", message="unhashable type") + warnings.filterwarnings("ignore", message="`diffrax.*discrete_terminating") + x = jnp.vectorize(intfun, signature="(k)->(n,k)")(x0) + + x = jnp.where(jnp.isinf(x), jnp.nan, x) + r = x[:, :, 0].squeeze().T.reshape((len(phis), *rshape)) + z = x[:, :, 2].squeeze().T.reshape((len(phis), *rshape)) - intfun = lambda x: odeint(odefun, x, phis, rtol=rtol, atol=atol, mxstep=maxstep) - x = jnp.vectorize(intfun, signature="(k)->(n,k)")(x0) - r = x[:, :, 0].T.reshape((len(phis), *rshape)) - z = x[:, :, 2].T.reshape((len(phis), *rshape)) return r, z diff --git a/devtools/dev-requirements_conda.yml b/devtools/dev-requirements_conda.yml index 5aa77689d..107e3857e 100644 --- a/devtools/dev-requirements_conda.yml +++ b/devtools/dev-requirements_conda.yml @@ -15,6 +15,7 @@ dependencies: - pip: # Conda only parses a single list of pip requirements. # If two pip lists are given, all but the last list is skipped. + - diffrax >= 0.4.1 - interpax >= 0.3.3 - jax[cpu] >= 0.3.2, < 0.5.0 - nvgpu diff --git a/requirements.txt b/requirements.txt index fa5b86bba..227f95eac 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,5 @@ colorama +diffrax >= 0.4.1 h5py >= 3.0.0, < 4.0 interpax >= 0.3.3 jax[cpu] >= 0.3.2, < 0.5.0 diff --git a/requirements_conda.yml b/requirements_conda.yml index da2996429..93d8e2720 100644 --- a/requirements_conda.yml +++ b/requirements_conda.yml @@ -2,6 +2,7 @@ name: desc-env dependencies: # standard install - colorama + - diffrax >= 0.4.1 - h5py >= 3.0.0, < 4.0 - matplotlib >= 3.5.0, < 4.0.0 - mpmath >= 1.0.0, < 2.0 diff --git a/tests/test_magnetic_fields.py b/tests/test_magnetic_fields.py index 86f417454..b5f36dc8f 100644 --- a/tests/test_magnetic_fields.py +++ b/tests/test_magnetic_fields.py @@ -2,6 +2,7 @@ import numpy as np import pytest +from diffrax import Dopri5 from scipy.constants import mu_0 from desc.backend import jit, jnp @@ -1038,24 +1039,51 @@ def test_field_line_integrate(self): np.testing.assert_allclose(z[-1], 0.001, rtol=1e-6, atol=1e-6) @pytest.mark.unit - def test_field_line_integrate_bounds(self): - """Test field line integration with bounding box.""" + def test_field_line_integrate_long(self): + """Test field line integration for long distance along line.""" # q=4, field line should rotate 1/4 turn after 1 toroidal transit # from outboard midplane to top center field = ToroidalMagneticField(2, 10) + PoloidalMagneticField(2, 10, 0.25) - # test that bounds work correctly, and stop integration when trajectory - # hits the bounds - r0 = [10.1] + r0 = [10.001] z0 = [0.0] - phis = [0, 2 * np.pi] - # this will hit the R bound - # (there is no Z bound, and R would go to 10.0 if not bounded) - r, z = field_line_integrate(r0, z0, phis, field, bounds_R=(10.05, np.inf)) - np.testing.assert_allclose(r[-1], 10.05, rtol=3e-4) - # this will hit the Z bound - # (there is no R bound, and Z would go to 0.1 if not bounded) - r, z = field_line_integrate(r0, z0, phis, field, bounds_Z=(-np.inf, 0.05)) - np.testing.assert_allclose(z[-1], 0.05, atol=3e-3) + phis = [0, 2 * np.pi * 25] + r, z = field_line_integrate(r0, z0, phis, field, solver=Dopri5()) + np.testing.assert_allclose(r[-1], 10, rtol=1e-6, atol=1e-6) + np.testing.assert_allclose(z[-1], 0.001, rtol=1e-6, atol=1e-6) + + @pytest.mark.unit + def test_field_line_integrate_early_terminate_default(self): + """Test field line integration with default early termination criterion.""" + # q=4, field line should rotate 1/4 turn after 1 toroidal transit + # from outboard midplane to top center + # early terminate when it crosses towards the inboard side (R=10), + field1 = ToroidalMagneticField(2, 10) + PoloidalMagneticField(2, 10, 0.25) + # make a SplineMagneticField only defined in a tiny region around initial point + field = SplineMagneticField.from_field( + field=field1, + R=np.linspace(10.0, 10.005, 40), + phi=np.linspace(0, 2 * np.pi, 40), + Z=np.linspace(-5e-3, 5e-3, 40), + extrap=True, + ) + r0 = [10.001] + z0 = [0.0] + phis = [0, 2 * np.pi, 2 * np.pi * 2] + + r, z = field_line_integrate( + r0, + z0, + phis, + field, + bounds_R=(np.min(field._R), np.max(field._R)), + bounds_Z=(np.min(field._Z), np.max(field._Z)), + min_step_size=1e-2, + ) + np.testing.assert_allclose(r[1], 10, rtol=1e-6, atol=1e-6) + np.testing.assert_allclose(z[1], 0.001, rtol=1e-6, atol=1e-6) + # if early terinated, the values at the un-integrated phi points are inf + assert np.isnan(r[-1]) + assert np.isnan(z[-1]) @pytest.mark.unit def test_Bnormal_calculation(self):