From a15ad3e464b49dc31b2aae56af9e702b01cbaaf3 Mon Sep 17 00:00:00 2001 From: Dario Panici Date: Thu, 3 Aug 2023 17:28:53 -0400 Subject: [PATCH 01/15] change field line integration to use diffrax --- desc/magnetic_fields.py | 72 +++++++++++++++++++++++++++++++---- requirements.txt | 1 + tests/test_magnetic_fields.py | 35 +++++++++++++++++ 3 files changed, 101 insertions(+), 7 deletions(-) diff --git a/desc/magnetic_fields.py b/desc/magnetic_fields.py index 442c10e278..28f860c1cf 100644 --- a/desc/magnetic_fields.py +++ b/desc/magnetic_fields.py @@ -3,9 +3,18 @@ from abc import ABC, abstractmethod import numpy as np +from diffrax import ( + AbstractSolver, + DiscreteTerminatingEvent, + ODETerm, + PIDController, + SaveAt, + Tsit5, + diffeqsolve, +) from netCDF4 import Dataset -from desc.backend import jit, jnp, odeint +from desc.backend import jit, jnp from desc.derivatives import Derivative from desc.geometry.utils import rpz2xyz_vec, xyz2rpz from desc.grid import Grid @@ -747,9 +756,18 @@ def compute_magnetic_field(self, coords, params=None, basis="rpz"): def field_line_integrate( - r0, z0, phis, field, params={}, rtol=1e-8, atol=1e-8, maxstep=1000 + r0, + z0, + phis, + field, + params={}, + rtol=1e-8, + atol=1e-8, + maxstep=1000, + solver=Tsit5(), + terminating_event=None, ): - """Trace field lines by integration. + """Trace field lines by integration, using diffrax package. Parameters ---------- @@ -767,6 +785,16 @@ def field_line_integrate( relative and absolute tolerances for ode integration maxstep : int maximum number of steps between different phis + solver: diffrax.Solver + diffrax Solver object to use in integration, + defaults to Tsit5(), a RK45 explicit solver + terminating_event_fxn: fxn + Function which takes as input the state of the ODE solution at each timestep and + outputs a bool which, if True, terminates the solve at that timestep. + NOTE: If the solve is terminated early, the output returned is still + length(phis), however all values from the step point the fxn evaluated + to True and on will be inf + Returns ------- @@ -774,6 +802,16 @@ def field_line_integrate( arrays of r, z coordinates at specified phi angles """ + if not isinstance(solver, AbstractSolver): + try: + # maybe they passed in the correct object but did not + # instantiate it + solver = solver() + except TypeError: + raise TypeError( + "Expected a diffrax Solver object as solver," + + f"instead got type {type(solver)}!" + ) r0, z0, phis = map(jnp.asarray, (r0, z0, phis)) assert r0.shape == z0.shape, "r0 and z0 must have the same shape" rshape = r0.shape @@ -782,7 +820,7 @@ def field_line_integrate( x0 = jnp.array([r0, phis[0] * jnp.ones_like(r0), z0]).T @jit - def odefun(rpz, s): + def odefun(s, rpz, args): rpz = rpz.reshape((3, -1)).T r = rpz[:, 0] br, bp, bz = field.compute_magnetic_field(rpz, params, basis="rpz").T @@ -790,8 +828,28 @@ def odefun(rpz, s): [r * br / bp * jnp.sign(bp), jnp.sign(bp), r * bz / bp * jnp.sign(bp)] ).squeeze() - intfun = lambda x: odeint(odefun, x, phis, rtol=rtol, atol=atol, mxstep=maxstep) + # diffrax parameters + stepsize_controller = PIDController(rtol=rtol, atol=atol) + terminating_event = ( + DiscreteTerminatingEvent(terminating_event) if terminating_event else None + ) + term = ODETerm(odefun) + saveat = SaveAt(ts=phis) + + intfun = lambda x: diffeqsolve( + term, + solver, + y0=x, + t0=phis[0], + t1=phis[-1], + saveat=saveat, + max_steps=maxstep, + dt0=None, # have diffrax automatically choose it + stepsize_controller=stepsize_controller, + discrete_terminating_event=terminating_event, + ).ys + x = jnp.vectorize(intfun, signature="(k)->(n,k)")(x0) - r = x[:, :, 0].T.reshape((len(phis), *rshape)) - z = x[:, :, 2].T.reshape((len(phis), *rshape)) + r = x[:, :, 0].squeeze().T.reshape((len(phis), *rshape)) + z = x[:, :, 2].squeeze().T.reshape((len(phis), *rshape)) return r, z diff --git a/requirements.txt b/requirements.txt index b455bcabbe..da16efa623 100644 --- a/requirements.txt +++ b/requirements.txt @@ -9,3 +9,4 @@ nvgpu psutil scipy >= 1.5.0, < 1.11.0 termcolor +diffrax diff --git a/tests/test_magnetic_fields.py b/tests/test_magnetic_fields.py index 23e98b97b9..7aafd2b354 100644 --- a/tests/test_magnetic_fields.py +++ b/tests/test_magnetic_fields.py @@ -2,6 +2,7 @@ import numpy as np import pytest +from diffrax import Dopri5 from desc.backend import jnp from desc.magnetic_fields import ( @@ -130,3 +131,37 @@ def test_field_line_integrate(self): r, z = field_line_integrate(r0, z0, phis, field) np.testing.assert_allclose(r[-1], 10, rtol=1e-6, atol=1e-6) np.testing.assert_allclose(z[-1], 0.001, rtol=1e-6, atol=1e-6) + + @pytest.mark.unit + def test_field_line_integrate_long(self): + """Test field line integration for long distance along line.""" + # q=4, field line should rotate 1/4 turn after 1 toroidal transit + # from outboard midplane to top center + field = ToroidalMagneticField(2, 10) + PoloidalMagneticField(2, 10, 0.25) + r0 = [10.001] + z0 = [0.0] + phis = [0, 2 * np.pi * 25] + r, z = field_line_integrate(r0, z0, phis, field, solver=Dopri5) + np.testing.assert_allclose(r[-1], 10, rtol=1e-6, atol=1e-6) + np.testing.assert_allclose(z[-1], 0.001, rtol=1e-6, atol=1e-6) + + def test_field_line_integrate_early_terminate(self): + """Test field line integration with early termination criterion.""" + # q=4, field line should rotate 1/4 turn after 1 toroidal transit + # from outboard midplane to top center + # early terminate at 2pi, if fails to terminate correctly + # then the assert statements would not hold + field = ToroidalMagneticField(2, 10) + PoloidalMagneticField(2, 10, 0.25) + r0 = [10.001] + z0 = [0.0] + phis = [0, 2 * np.pi, 2 * np.pi * 2] + + def cond_fxn(state, **kwargs): + return jnp.any(state.y[1] > 8) + + r, z = field_line_integrate(r0, z0, phis, field, terminating_event=cond_fxn) + np.testing.assert_allclose(r[1], 10, rtol=1e-6, atol=1e-6) + np.testing.assert_allclose(z[1], 0.001, rtol=1e-6, atol=1e-6) + # if early terinated, the values at the un-integrated phi points are inf + assert np.isinf(r[-1]) + assert np.isinf(z[-1]) From 68961714e46ad3723983c3f7346e9111dd3f2c25 Mon Sep 17 00:00:00 2001 From: Dario Panici Date: Thu, 3 Aug 2023 17:33:21 -0400 Subject: [PATCH 02/15] limit diffrax version as 0.4.1 requires jax > 0.4.13 --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index da16efa623..496967ca9a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -9,4 +9,4 @@ nvgpu psutil scipy >= 1.5.0, < 1.11.0 termcolor -diffrax +diffrax<=0.4.0 From c3fde2c57f1d6a0437bb09c40c7fe833b4b788dd Mon Sep 17 00:00:00 2001 From: Dario Panici Date: Thu, 3 Aug 2023 17:56:33 -0400 Subject: [PATCH 03/15] add to docstring --- desc/magnetic_fields.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/desc/magnetic_fields.py b/desc/magnetic_fields.py index 28f860c1cf..0a90b814f2 100644 --- a/desc/magnetic_fields.py +++ b/desc/magnetic_fields.py @@ -794,6 +794,9 @@ def field_line_integrate( NOTE: If the solve is terminated early, the output returned is still length(phis), however all values from the step point the fxn evaluated to True and on will be inf + state has attributes such a state.y (array of length 3 with current (R,phi,Z)) + see diffrax documentation for more in-depth information + Returns From 6ca2471fe172459d64037b4287bebb2fd4c7ce98 Mon Sep 17 00:00:00 2001 From: Dario Panici Date: Thu, 10 Aug 2023 15:22:13 -0400 Subject: [PATCH 04/15] increase JAX version, restrict ml_dtypes to avoid DeprecationWarning hopefully --- requirements.txt | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/requirements.txt b/requirements.txt index 496967ca9a..f30fc94cae 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,7 +1,9 @@ colorama +diffrax h5py >= 3.0.0 -jax[cpu] >= 0.3.2, <= 0.4.11 +jax[cpu] >= 0.4.1, <= 0.4.11 matplotlib >= 3.3.0, <= 3.6.0, != 3.4.3 +ml_dtypes<0.2.0 mpmath >= 1.0.0 netcdf4 >= 1.5.4 numpy >= 1.20.0, < 1.25.0 @@ -9,4 +11,3 @@ nvgpu psutil scipy >= 1.5.0, < 1.11.0 termcolor -diffrax<=0.4.0 From 0bc53406b00dfa34147f92a9399e7825afb2f92c Mon Sep 17 00:00:00 2001 From: Dario Panici Date: Thu, 10 Aug 2023 17:54:47 -0400 Subject: [PATCH 05/15] add NaN check (not working currently) and ignore RuntimeWarnings (from netCDF4 but cannot only target that module) --- desc/magnetic_fields.py | 17 +++++++++++++---- requirements.txt | 2 +- setup.cfg | 3 +++ tests/test_magnetic_fields.py | 27 +++++++++++++++++++++++++++ 4 files changed, 44 insertions(+), 5 deletions(-) diff --git a/desc/magnetic_fields.py b/desc/magnetic_fields.py index 0a90b814f2..c0f2178f10 100644 --- a/desc/magnetic_fields.py +++ b/desc/magnetic_fields.py @@ -789,12 +789,14 @@ def field_line_integrate( diffrax Solver object to use in integration, defaults to Tsit5(), a RK45 explicit solver terminating_event_fxn: fxn - Function which takes as input the state of the ODE solution at each timestep and - outputs a bool which, if True, terminates the solve at that timestep. + Function which takes as input the state of the ODE solution at each timestep + and **kwargs, and outputs a bool which, if True, terminates the solve at that + timestep. + fxn must have signature of (state, **kwargs) -> Bool NOTE: If the solve is terminated early, the output returned is still length(phis), however all values from the step point the fxn evaluated to True and on will be inf - state has attributes such a state.y (array of length 3 with current (R,phi,Z)) + state has attributes such as state.y (array of length 3 with current (R,phi,Z)) see diffrax documentation for more in-depth information @@ -833,8 +835,15 @@ def odefun(s, rpz, args): # diffrax parameters stepsize_controller = PIDController(rtol=rtol, atol=atol) + + def default_terminating_event_fxn(state, **kwargs): + terms = kwargs.get("terms", lambda a, x, b: x) + return jnp.any(jnp.isnan(terms.vf(0, state.y, 0))) + terminating_event = ( - DiscreteTerminatingEvent(terminating_event) if terminating_event else None + DiscreteTerminatingEvent(terminating_event) + if terminating_event + else DiscreteTerminatingEvent(default_terminating_event_fxn) ) term = ODETerm(odefun) saveat = SaveAt(ts=phis) diff --git a/requirements.txt b/requirements.txt index f30fc94cae..e0360b5504 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,7 +1,7 @@ colorama diffrax h5py >= 3.0.0 -jax[cpu] >= 0.4.1, <= 0.4.11 +jax[cpu] >= 0.4.1, <= 0.4.13 matplotlib >= 3.3.0, <= 3.6.0, != 3.4.3 ml_dtypes<0.2.0 mpmath >= 1.0.0 diff --git a/setup.cfg b/setup.cfg index 0870f8b3b8..7cae1db428 100644 --- a/setup.cfg +++ b/setup.cfg @@ -49,6 +49,9 @@ filterwarnings= ignore::pytest.PytestUnraisableExceptionWarning ignore::RuntimeWarning:desc.compute # Ignore division by zero warnings. + ignore::RuntimeWarning + + [flake8] # Primarily ignoring whitespace, indentation, and commenting etiquette that black does not catch diff --git a/tests/test_magnetic_fields.py b/tests/test_magnetic_fields.py index 7aafd2b354..aaaf2438f1 100644 --- a/tests/test_magnetic_fields.py +++ b/tests/test_magnetic_fields.py @@ -165,3 +165,30 @@ def cond_fxn(state, **kwargs): # if early terinated, the values at the un-integrated phi points are inf assert np.isinf(r[-1]) assert np.isinf(z[-1]) + + def test_field_line_integrate_early_terminate_NaN(self): + """Test field line integration with default early termination criterion.""" + # q=4, field line should rotate 1/4 turn after 1 toroidal transit + # from outboard midplane to top center + # early terminate at 2pi, if fails to terminate correctly + # then the assert statements would not hold + field1 = ToroidalMagneticField(2, 10) + PoloidalMagneticField(2, 10, 0.25) + # make a SplineMagneticField only defined in a tiny region around initial point + field = SplineMagneticField.from_field( + field=field1, + R=np.linspace(10.0, 10.002, 10), + phi=np.linspace(0, 2 * np.pi, 10), + Z=np.linspace(0, 1.5e-3, 10), + extrap=False, + period=2 * np.pi, + ) + r0 = [10.001] + z0 = [0.0] + phis = [0, 2 * np.pi, 2 * np.pi * 2] + + r, z = field_line_integrate(r0, z0, phis, field) + np.testing.assert_allclose(r[1], 10, rtol=1e-6, atol=1e-6) + np.testing.assert_allclose(z[1], 0.001, rtol=1e-6, atol=1e-6) + # if early terinated, the values at the un-integrated phi points are inf + assert np.isinf(r[-1]) + assert np.isinf(z[-1]) From 2900dfac1cdcf6f91faaede093c72932683c1231 Mon Sep 17 00:00:00 2001 From: Dario Panici Date: Fri, 11 Aug 2023 08:40:44 -0400 Subject: [PATCH 06/15] change ignore to only catch the bening ndarray size change warning (if it is not benign, Cython would raise a ValueError, not a RuntimeWarning), and add kwrgs to integrate function --- desc/magnetic_fields.py | 8 ++++++-- setup.cfg | 2 +- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/desc/magnetic_fields.py b/desc/magnetic_fields.py index f32f3136e7..f1e48ed7aa 100644 --- a/desc/magnetic_fields.py +++ b/desc/magnetic_fields.py @@ -766,6 +766,7 @@ def field_line_integrate( maxstep=1000, solver=Tsit5(), terminating_event=None, + kwargs={}, ): """Trace field lines by integration, using diffrax package. @@ -798,6 +799,8 @@ def field_line_integrate( to True and on will be inf state has attributes such as state.y (array of length 3 with current (R,phi,Z)) see diffrax documentation for more in-depth information + kwargs: dict + keyword arguments to be passed into the diffrax diffeqsolve function call @@ -838,7 +841,7 @@ def odefun(s, rpz, args): def default_terminating_event_fxn(state, **kwargs): terms = kwargs.get("terms", lambda a, x, b: x) - return jnp.any(jnp.isnan(terms.vf(0, state.y, 0))) + return jnp.any(jnp.isnan(terms.vf(state.tnext, state.y, 0))) terminating_event = ( DiscreteTerminatingEvent(terminating_event) @@ -846,7 +849,7 @@ def default_terminating_event_fxn(state, **kwargs): else DiscreteTerminatingEvent(default_terminating_event_fxn) ) term = ODETerm(odefun) - saveat = SaveAt(ts=phis) + saveat = kwargs.get("saveat", SaveAt(ts=phis)) intfun = lambda x: diffeqsolve( term, @@ -859,6 +862,7 @@ def default_terminating_event_fxn(state, **kwargs): dt0=None, # have diffrax automatically choose it stepsize_controller=stepsize_controller, discrete_terminating_event=terminating_event, + **kwargs, ).ys x = jnp.vectorize(intfun, signature="(k)->(n,k)")(x0) diff --git a/setup.cfg b/setup.cfg index 7cae1db428..956509efe4 100644 --- a/setup.cfg +++ b/setup.cfg @@ -49,7 +49,7 @@ filterwarnings= ignore::pytest.PytestUnraisableExceptionWarning ignore::RuntimeWarning:desc.compute # Ignore division by zero warnings. - ignore::RuntimeWarning + ignore:numpy.ndarray size changed:RuntimeWarning From ab4b1d403cb0bed3101317b296c1f3b78cdbb885 Mon Sep 17 00:00:00 2001 From: Dario Panici Date: Fri, 11 Aug 2023 14:51:30 -0400 Subject: [PATCH 07/15] add missing pytest markers --- tests/test_magnetic_fields.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/test_magnetic_fields.py b/tests/test_magnetic_fields.py index aaaf2438f1..ffe9660347 100644 --- a/tests/test_magnetic_fields.py +++ b/tests/test_magnetic_fields.py @@ -145,6 +145,7 @@ def test_field_line_integrate_long(self): np.testing.assert_allclose(r[-1], 10, rtol=1e-6, atol=1e-6) np.testing.assert_allclose(z[-1], 0.001, rtol=1e-6, atol=1e-6) + @pytest.mark.unit def test_field_line_integrate_early_terminate(self): """Test field line integration with early termination criterion.""" # q=4, field line should rotate 1/4 turn after 1 toroidal transit @@ -166,6 +167,7 @@ def cond_fxn(state, **kwargs): assert np.isinf(r[-1]) assert np.isinf(z[-1]) + @pytest.mark.unit def test_field_line_integrate_early_terminate_NaN(self): """Test field line integration with default early termination criterion.""" # q=4, field line should rotate 1/4 turn after 1 toroidal transit From 140633c0d348716591e854055a3d6f6a2099f900 Mon Sep 17 00:00:00 2001 From: Dario Panici Date: Mon, 14 Aug 2023 14:29:06 -0400 Subject: [PATCH 08/15] add min stepsize (NaN event works now but compile time takes ages on diffrax's part), undo accidental ml_dtypes change, add ignore for benign equinox warning --- desc/magnetic_fields.py | 8 ++++++-- requirements.txt | 1 - setup.cfg | 1 + 3 files changed, 7 insertions(+), 3 deletions(-) diff --git a/desc/magnetic_fields.py b/desc/magnetic_fields.py index f1e48ed7aa..dec30ed0ae 100644 --- a/desc/magnetic_fields.py +++ b/desc/magnetic_fields.py @@ -764,6 +764,7 @@ def field_line_integrate( rtol=1e-8, atol=1e-8, maxstep=1000, + min_step_size=1e-8, solver=Tsit5(), terminating_event=None, kwargs={}, @@ -786,6 +787,8 @@ def field_line_integrate( relative and absolute tolerances for ode integration maxstep : int maximum number of steps between different phis + min_step_size: float + minimum step size (in phi) that the integration can take. default is 1e-8 solver: diffrax.Solver diffrax Solver object to use in integration, defaults to Tsit5(), a RK45 explicit solver @@ -837,7 +840,7 @@ def odefun(s, rpz, args): ).squeeze() # diffrax parameters - stepsize_controller = PIDController(rtol=rtol, atol=atol) + stepsize_controller = PIDController(rtol=rtol, atol=atol, dtmin=min_step_size) def default_terminating_event_fxn(state, **kwargs): terms = kwargs.get("terms", lambda a, x, b: x) @@ -859,7 +862,7 @@ def default_terminating_event_fxn(state, **kwargs): t1=phis[-1], saveat=saveat, max_steps=maxstep, - dt0=None, # have diffrax automatically choose it + dt0=min_step_size, stepsize_controller=stepsize_controller, discrete_terminating_event=terminating_event, **kwargs, @@ -868,4 +871,5 @@ def default_terminating_event_fxn(state, **kwargs): x = jnp.vectorize(intfun, signature="(k)->(n,k)")(x0) r = x[:, :, 0].squeeze().T.reshape((len(phis), *rshape)) z = x[:, :, 2].squeeze().T.reshape((len(phis), *rshape)) + return r, z diff --git a/requirements.txt b/requirements.txt index 160a7f65e8..0e7d8e82d1 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,7 +3,6 @@ diffrax h5py >= 3.0.0 jax[cpu] >= 0.3.2, <= 0.4.14 matplotlib >= 3.3.0, <= 3.6.0, != 3.4.3 -ml_dtypes<0.2.0 mpmath >= 1.0.0 netcdf4 >= 1.5.4 numpy >= 1.20.0, < 1.25.0 diff --git a/setup.cfg b/setup.cfg index 0cecf568dc..d34b271ca3 100644 --- a/setup.cfg +++ b/setup.cfg @@ -53,6 +53,7 @@ filterwarnings= # ignore benign Cython warnings on ndarray size ignore::DeprecationWarning:ml_dtypes.* # ignore benign ml_dtypes DeprecationWarning + ignore:As of Equinox 0.10.7, `equinox.filter_custom_vjp.defvjp`: UserWarning [flake8] # Primarily ignoring whitespace, indentation, and commenting etiquette that black does not catch From 8107318be1b312fafb56b3229a7a3920fca8c46b Mon Sep 17 00:00:00 2001 From: Dario Panici Date: Tue, 15 Aug 2023 19:46:44 -0400 Subject: [PATCH 09/15] fix test by updating diffrax version --- requirements.txt | 2 +- tests/test_magnetic_fields.py | 10 +++++----- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/requirements.txt b/requirements.txt index 0e7d8e82d1..7dbece083d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ colorama -diffrax +diffrax>=0.4.1 h5py >= 3.0.0 jax[cpu] >= 0.3.2, <= 0.4.14 matplotlib >= 3.3.0, <= 3.6.0, != 3.4.3 diff --git a/tests/test_magnetic_fields.py b/tests/test_magnetic_fields.py index ffe9660347..aa7a290cce 100644 --- a/tests/test_magnetic_fields.py +++ b/tests/test_magnetic_fields.py @@ -168,7 +168,7 @@ def cond_fxn(state, **kwargs): assert np.isinf(z[-1]) @pytest.mark.unit - def test_field_line_integrate_early_terminate_NaN(self): + def test_field_line_integrate_early_terminate_default(self): """Test field line integration with default early termination criterion.""" # q=4, field line should rotate 1/4 turn after 1 toroidal transit # from outboard midplane to top center @@ -178,11 +178,11 @@ def test_field_line_integrate_early_terminate_NaN(self): # make a SplineMagneticField only defined in a tiny region around initial point field = SplineMagneticField.from_field( field=field1, - R=np.linspace(10.0, 10.002, 10), - phi=np.linspace(0, 2 * np.pi, 10), - Z=np.linspace(0, 1.5e-3, 10), + R=np.linspace(-9.995, 10.005, 40), + phi=np.linspace(0, 3 * np.pi, 60), + Z=np.linspace(-1.5e-3, 1.5e-3, 40), extrap=False, - period=2 * np.pi, + period=np.inf, ) r0 = [10.001] z0 = [0.0] From ee285b441a130855e76c7a0f89c765c1580403cc Mon Sep 17 00:00:00 2001 From: Dario Panici Date: Tue, 15 Aug 2023 21:54:03 -0400 Subject: [PATCH 10/15] change default terminating event to end integration if exit domain of interest defined by passed-in bounds --- desc/magnetic_fields.py | 82 +++++++++++++++++++++++++++++++---- tests/test_magnetic_fields.py | 15 +++++-- 2 files changed, 85 insertions(+), 12 deletions(-) diff --git a/desc/magnetic_fields.py b/desc/magnetic_fields.py index dec30ed0ae..e7dc85547b 100644 --- a/desc/magnetic_fields.py +++ b/desc/magnetic_fields.py @@ -767,6 +767,9 @@ def field_line_integrate( min_step_size=1e-8, solver=Tsit5(), terminating_event=None, + bounds_R=None, + bounds_Z=None, + bounds_phi=None, kwargs={}, ): """Trace field lines by integration, using diffrax package. @@ -792,16 +795,35 @@ def field_line_integrate( solver: diffrax.Solver diffrax Solver object to use in integration, defaults to Tsit5(), a RK45 explicit solver - terminating_event_fxn: fxn + terminating_event: fxn Function which takes as input the state of the ODE solution at each timestep and **kwargs, and outputs a bool which, if True, terminates the solve at that timestep. + fxn must have signature of (state, **kwargs) -> Bool + + If not given and one of bounds_R,Z,phi are given, will + default to a terminating event which ends integration + once R,Z,or phi exit the domain defined by the given bounds + NOTE: If the solve is terminated early, the output returned is still length(phis), however all values from the step point the fxn evaluated to True and on will be inf state has attributes such as state.y (array of length 3 with current (R,phi,Z)) see diffrax documentation for more in-depth information + bounds_R: tuple + tuple of (R_min,R_max) of the R bounds for the domain of interest. + Integration will terminate when the field line exits this domain + (when using the default terminating event) + bounds_Z: tuple + tuple of (Z_min,Z_max) of the Z bounds for the domain of interest. + Integration will terminate when the field line exits this domain + (when using the default terminating event) + bounds_phi: tuple + tuple of (phi_min,phi_max) of the phi bounds for the domain of interest. + Integration will terminate when the field line exits this domain + (when using the default terminating event) + kwargs: dict keyword arguments to be passed into the diffrax diffeqsolve function call @@ -842,15 +864,57 @@ def odefun(s, rpz, args): # diffrax parameters stepsize_controller = PIDController(rtol=rtol, atol=atol, dtmin=min_step_size) - def default_terminating_event_fxn(state, **kwargs): - terms = kwargs.get("terms", lambda a, x, b: x) - return jnp.any(jnp.isnan(terms.vf(state.tnext, state.y, 0))) + if np.all( + [ + bounds_R is None, + bounds_Z is None, + bounds_phi is None, + terminating_event is None, + ] + ): + # no bounds or terminating event given, so dont use a terminating event + terminating_event = None + else: + bounds_R = (-np.inf, np.inf) if bounds_R is None else bounds_R + bounds_Z = (-np.inf, np.inf) if bounds_Z is None else bounds_Z + bounds_phi = (-np.inf, np.inf) if bounds_phi is None else bounds_phi + + def default_terminating_event_fxn(state, **kwargs): + R_out_of_bounds = jnp.any( + jnp.array( + [ + jnp.less(state.y[0], bounds_R[0]), + jnp.greater(state.y[0], bounds_R[1]), + ] + ) + ) + Z_out_of_bounds = jnp.any( + jnp.array( + [ + jnp.less(state.y[2], bounds_Z[0]), + jnp.greater(state.y[2], bounds_Z[1]), + ] + ) + ) + phi_out_of_bounds = jnp.any( + jnp.array( + [ + jnp.less(state.y[1], bounds_phi[0]), + jnp.greater(state.y[1], bounds_phi[1]), + ] + ) + ) + + return jnp.any( + jnp.array([R_out_of_bounds, Z_out_of_bounds, phi_out_of_bounds]) + ) + + terminating_event = ( + DiscreteTerminatingEvent(terminating_event) + if terminating_event + else DiscreteTerminatingEvent(default_terminating_event_fxn) + ) - terminating_event = ( - DiscreteTerminatingEvent(terminating_event) - if terminating_event - else DiscreteTerminatingEvent(default_terminating_event_fxn) - ) term = ODETerm(odefun) saveat = kwargs.get("saveat", SaveAt(ts=phis)) diff --git a/tests/test_magnetic_fields.py b/tests/test_magnetic_fields.py index aa7a290cce..69d4ab9683 100644 --- a/tests/test_magnetic_fields.py +++ b/tests/test_magnetic_fields.py @@ -179,16 +179,25 @@ def test_field_line_integrate_early_terminate_default(self): field = SplineMagneticField.from_field( field=field1, R=np.linspace(-9.995, 10.005, 40), - phi=np.linspace(0, 3 * np.pi, 60), + phi=np.linspace(0, 2 * np.pi, 40), Z=np.linspace(-1.5e-3, 1.5e-3, 40), extrap=False, - period=np.inf, + period=2 * np.pi, ) r0 = [10.001] z0 = [0.0] phis = [0, 2 * np.pi, 2 * np.pi * 2] - r, z = field_line_integrate(r0, z0, phis, field) + r, z = field_line_integrate( + r0, + z0, + phis, + field, + bounds_R=(np.min(field._R), np.max(field._R)), + bounds_Z=(np.min(field._Z), np.max(field._Z)), + bounds_phi=(np.min(field._phi), 3 * np.pi), + min_step_size=1e-2, + ) np.testing.assert_allclose(r[1], 10, rtol=1e-6, atol=1e-6) np.testing.assert_allclose(z[1], 0.001, rtol=1e-6, atol=1e-6) # if early terinated, the values at the un-integrated phi points are inf From 1705508587f47f9c8be1ad268606944af0622cf9 Mon Sep 17 00:00:00 2001 From: Dario Panici Date: Tue, 15 Aug 2023 22:53:49 -0400 Subject: [PATCH 11/15] remove 3.8 testing as is not compatible with diffrax --- .github/workflows/unittest.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/unittest.yml b/.github/workflows/unittest.yml index e8aae7d1d9..a5278c2754 100644 --- a/.github/workflows/unittest.yml +++ b/.github/workflows/unittest.yml @@ -23,7 +23,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: ['3.8', '3.10'] + python-version: ['3.9', '3.10'] group: [1, 2, 3, 4, 5] steps: From 9709281a3970b312fd8b69e18e91ebb87ef3b3ce Mon Sep 17 00:00:00 2001 From: Dario Panici Date: Sat, 20 Jan 2024 21:46:19 -0500 Subject: [PATCH 12/15] fix params arg --- desc/magnetic_fields.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/desc/magnetic_fields.py b/desc/magnetic_fields.py index 95b01cea5b..45f5e970b2 100644 --- a/desc/magnetic_fields.py +++ b/desc/magnetic_fields.py @@ -1212,7 +1212,7 @@ def field_line_integrate( z0, phis, field, - params={}, + params=None, grid=None, rtol=1e-8, atol=1e-8, @@ -1237,7 +1237,7 @@ def field_line_integrate( and the negative toroidal angle for negative Bphi field : MagneticField source of magnetic field to integrate - params: dict + params: dict, optional parameters passed to field grid : Grid, optional Collocation points used to discretize source field. From ee6fa850c55a10c5799704460f2597949166c743 Mon Sep 17 00:00:00 2001 From: Rory Conlin Date: Tue, 27 Aug 2024 01:25:52 -0400 Subject: [PATCH 13/15] Simplify API a bit, get tests working --- desc/magnetic_fields/_core.py | 56 +++++++++++------------------------ tests/test_magnetic_fields.py | 54 ++++----------------------------- 2 files changed, 23 insertions(+), 87 deletions(-) diff --git a/desc/magnetic_fields/_core.py b/desc/magnetic_fields/_core.py index 4b1f6b7ad2..12341298f1 100644 --- a/desc/magnetic_fields/_core.py +++ b/desc/magnetic_fields/_core.py @@ -1,5 +1,6 @@ """Classes for magnetic fields.""" +import warnings from abc import ABC, abstractmethod from collections.abc import MutableSequence @@ -1566,43 +1567,18 @@ def odefun(s, rpz, args): ).squeeze() # diffrax parameters - stepsize_controller = PIDController(rtol=rtol, atol=atol, dtmin=min_step_size) - - if "discrete_terminating_event" not in kwargs: - bounds_phi = (-np.inf, np.inf) - - def default_terminating_event_fxn(state, **kwargs): - R_out_of_bounds = jnp.any( - jnp.array( - [ - jnp.less(state.y[0], bounds_R[0]), - jnp.greater(state.y[0], bounds_R[1]), - ] - ) - ) - Z_out_of_bounds = jnp.any( - jnp.array( - [ - jnp.less(state.y[2], bounds_Z[0]), - jnp.greater(state.y[2], bounds_Z[1]), - ] - ) - ) - phi_out_of_bounds = jnp.any( - jnp.array( - [ - jnp.less(state.y[1], bounds_phi[0]), - jnp.greater(state.y[1], bounds_phi[1]), - ] - ) - ) - return jnp.any( - jnp.array([R_out_of_bounds, Z_out_of_bounds, phi_out_of_bounds]) - ) + def default_terminating_event_fxn(state, **kwargs): + R_out = jnp.any(jnp.array([state.y[0] < bounds_R[0], state.y[0] > bounds_R[1]])) + Z_out = jnp.any(jnp.array([state.y[2] < bounds_Z[0], state.y[2] > bounds_Z[1]])) + return jnp.any(jnp.array([R_out, Z_out])) - kwargs["discrete_terminating_event"] = DiscreteTerminatingEvent( - default_terminating_event_fxn + kwargs.setdefault( + "stepsize_controller", PIDController(rtol=rtol, atol=atol, dtmin=min_step_size) + ) + kwargs.setdefault( + "discrete_terminating_event", + DiscreteTerminatingEvent(default_terminating_event_fxn), ) term = ODETerm(odefun) @@ -1615,13 +1591,17 @@ def default_terminating_event_fxn(state, **kwargs): t0=phis[0], t1=phis[-1], saveat=saveat, - max_steps=maxstep, + max_steps=maxstep * len(phis), dt0=min_step_size, - stepsize_controller=stepsize_controller, **kwargs, ).ys - x = jnp.vectorize(intfun, signature="(k)->(n,k)")(x0) + # suppress warnings till its fixed upstream: + # https://github.com/patrick-kidger/diffrax/issues/445 + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", message="unhashable type") + x = jnp.vectorize(intfun, signature="(k)->(n,k)")(x0) + r = x[:, :, 0].squeeze().T.reshape((len(phis), *rshape)) z = x[:, :, 2].squeeze().T.reshape((len(phis), *rshape)) diff --git a/tests/test_magnetic_fields.py b/tests/test_magnetic_fields.py index 320651df93..5d09e2c4f4 100644 --- a/tests/test_magnetic_fields.py +++ b/tests/test_magnetic_fields.py @@ -728,48 +728,24 @@ def test_field_line_integrate_long(self): r0 = [10.001] z0 = [0.0] phis = [0, 2 * np.pi * 25] - r, z = field_line_integrate(r0, z0, phis, field, solver=Dopri5) + r, z = field_line_integrate(r0, z0, phis, field, solver=Dopri5()) np.testing.assert_allclose(r[-1], 10, rtol=1e-6, atol=1e-6) np.testing.assert_allclose(z[-1], 0.001, rtol=1e-6, atol=1e-6) - @pytest.mark.unit - def test_field_line_integrate_early_terminate(self): - """Test field line integration with early termination criterion.""" - # q=4, field line should rotate 1/4 turn after 1 toroidal transit - # from outboard midplane to top center - # early terminate at 2pi, if fails to terminate correctly - # then the assert statements would not hold - field = ToroidalMagneticField(2, 10) + PoloidalMagneticField(2, 10, 0.25) - r0 = [10.001] - z0 = [0.0] - phis = [0, 2 * np.pi, 2 * np.pi * 2] - - def cond_fxn(state, **kwargs): - return jnp.any(state.y[1] > 8) - - r, z = field_line_integrate(r0, z0, phis, field, terminating_event=cond_fxn) - np.testing.assert_allclose(r[1], 10, rtol=1e-6, atol=1e-6) - np.testing.assert_allclose(z[1], 0.001, rtol=1e-6, atol=1e-6) - # if early terinated, the values at the un-integrated phi points are inf - assert np.isinf(r[-1]) - assert np.isinf(z[-1]) - @pytest.mark.unit def test_field_line_integrate_early_terminate_default(self): """Test field line integration with default early termination criterion.""" # q=4, field line should rotate 1/4 turn after 1 toroidal transit # from outboard midplane to top center - # early terminate at 2pi, if fails to terminate correctly - # then the assert statements would not hold + # early terminate when it crosses towards the inboard side (R=10), field1 = ToroidalMagneticField(2, 10) + PoloidalMagneticField(2, 10, 0.25) # make a SplineMagneticField only defined in a tiny region around initial point field = SplineMagneticField.from_field( field=field1, - R=np.linspace(-9.995, 10.005, 40), + R=np.linspace(10.0, 10.005, 40), phi=np.linspace(0, 2 * np.pi, 40), - Z=np.linspace(-1.5e-3, 1.5e-3, 40), - extrap=False, - period=2 * np.pi, + Z=np.linspace(-5e-3, 5e-3, 40), + extrap=True, ) r0 = [10.001] z0 = [0.0] @@ -782,7 +758,6 @@ def test_field_line_integrate_early_terminate_default(self): field, bounds_R=(np.min(field._R), np.max(field._R)), bounds_Z=(np.min(field._Z), np.max(field._Z)), - bounds_phi=(np.min(field._phi), 3 * np.pi), min_step_size=1e-2, ) np.testing.assert_allclose(r[1], 10, rtol=1e-6, atol=1e-6) @@ -791,25 +766,6 @@ def test_field_line_integrate_early_terminate_default(self): assert np.isinf(r[-1]) assert np.isinf(z[-1]) - def test_field_line_integrate_bounds(self): - """Test field line integration with bounding box.""" - # q=4, field line should rotate 1/4 turn after 1 toroidal transit - # from outboard midplane to top center - field = ToroidalMagneticField(2, 10) + PoloidalMagneticField(2, 10, 0.25) - # test that bounds work correctly, and stop integration when trajectory - # hits the bounds - r0 = [10.1] - z0 = [0.0] - phis = [0, 2 * np.pi] - # this will hit the R bound - # (there is no Z bound, and R would go to 10.0 if not bounded) - r, z = field_line_integrate(r0, z0, phis, field, bounds_R=(10.05, np.inf)) - np.testing.assert_allclose(r[-1], 10.05, rtol=3e-4) - # this will hit the Z bound - # (there is no R bound, and Z would go to 0.1 if not bounded) - r, z = field_line_integrate(r0, z0, phis, field, bounds_Z=(-np.inf, 0.05)) - np.testing.assert_allclose(z[-1], 0.05, atol=3e-3) - @pytest.mark.unit def test_Bnormal_calculation(self): """Tests Bnormal calculation for simple toroidal field.""" From 59f9175a6de4770018810f2eaedfd06b8f9bb292 Mon Sep 17 00:00:00 2001 From: Rory Conlin Date: Tue, 27 Aug 2024 01:30:34 -0400 Subject: [PATCH 14/15] Return NaN for points that leave box --- desc/magnetic_fields/_core.py | 17 ++++++----------- tests/test_magnetic_fields.py | 4 ++-- 2 files changed, 8 insertions(+), 13 deletions(-) diff --git a/desc/magnetic_fields/_core.py b/desc/magnetic_fields/_core.py index 12341298f1..7cc1841d80 100644 --- a/desc/magnetic_fields/_core.py +++ b/desc/magnetic_fields/_core.py @@ -1526,18 +1526,12 @@ def field_line_integrate( diffrax Solver object to use in integration, defaults to Tsit5(), a RK45 explicit solver bounds_R : tuple of (float,float), optional - R bounds for field line integration bounding box. - If supplied, the RHS of the field line equations will be - multiplied by exp(-r) where r is the distance to the bounding box, - this is meant to prevent the field lines which escape to infinity from - slowing the integration down by being traced to infinity. - defaults to (0,np.inf) + R bounds for field line integration bounding box. Trajectories that leave this + box will be stopped, and NaN returned for points outside the box. + Defaults to (0,np.inf) bounds_Z : tuple of (float,float), optional - Z bounds for field line integration bounding box. - If supplied, the RHS of the field line equations will be - multiplied by exp(-r) where r is the distance to the bounding box, - this is meant to prevent the field lines which escape to infinity from - slowing the integration down by being traced to infinity. + Z bounds for field line integration bounding box. Trajectories that leave this + box will be stopped, and NaN returned for points outside the box. Defaults to (-np.inf,np.inf) kwargs: dict keyword arguments to be passed into the ``diffrax.diffeqsolve`` @@ -1602,6 +1596,7 @@ def default_terminating_event_fxn(state, **kwargs): warnings.filterwarnings("ignore", message="unhashable type") x = jnp.vectorize(intfun, signature="(k)->(n,k)")(x0) + x = jnp.where(jnp.isinf(x), jnp.nan, x) r = x[:, :, 0].squeeze().T.reshape((len(phis), *rshape)) z = x[:, :, 2].squeeze().T.reshape((len(phis), *rshape)) diff --git a/tests/test_magnetic_fields.py b/tests/test_magnetic_fields.py index 5d09e2c4f4..262df0fafe 100644 --- a/tests/test_magnetic_fields.py +++ b/tests/test_magnetic_fields.py @@ -763,8 +763,8 @@ def test_field_line_integrate_early_terminate_default(self): np.testing.assert_allclose(r[1], 10, rtol=1e-6, atol=1e-6) np.testing.assert_allclose(z[1], 0.001, rtol=1e-6, atol=1e-6) # if early terinated, the values at the un-integrated phi points are inf - assert np.isinf(r[-1]) - assert np.isinf(z[-1]) + assert np.isnan(r[-1]) + assert np.isnan(z[-1]) @pytest.mark.unit def test_Bnormal_calculation(self): From 2dab0a36d3e434e054ccf5cc6b364c9d1a7a0ff3 Mon Sep 17 00:00:00 2001 From: Rory Conlin Date: Tue, 27 Aug 2024 02:12:04 -0400 Subject: [PATCH 15/15] Ignore diffrax deprecation warning for now --- desc/magnetic_fields/_core.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/desc/magnetic_fields/_core.py b/desc/magnetic_fields/_core.py index 7cc1841d80..a598a15659 100644 --- a/desc/magnetic_fields/_core.py +++ b/desc/magnetic_fields/_core.py @@ -1592,8 +1592,10 @@ def default_terminating_event_fxn(state, **kwargs): # suppress warnings till its fixed upstream: # https://github.com/patrick-kidger/diffrax/issues/445 + # also ignore deprecation warning for now until we actually need to deal with it with warnings.catch_warnings(): warnings.filterwarnings("ignore", message="unhashable type") + warnings.filterwarnings("ignore", message="`diffrax.*discrete_terminating") x = jnp.vectorize(intfun, signature="(k)->(n,k)")(x0) x = jnp.where(jnp.isinf(x), jnp.nan, x)