Skip to content

Commit

Permalink
Increase coverage
Browse files Browse the repository at this point in the history
  • Loading branch information
unalmis committed Aug 30, 2024
1 parent 446c0b7 commit 239e441
Show file tree
Hide file tree
Showing 5 changed files with 55 additions and 44 deletions.
4 changes: 2 additions & 2 deletions desc/equilibrium/coords.py
Original file line number Diff line number Diff line change
Expand Up @@ -685,12 +685,12 @@ def get_rtz_grid(
rvp : rho, theta_PEST, phi
rtz : rho, theta, zeta
period : tuple of float
Assumed periodicity for each quantity in inbasis.
Assumed periodicity for functions of the given coordinates.
Use ``np.inf`` to denote no periodicity.
jitable : bool, optional
If false the returned grid has additional attributes.
Required to be false to retain nodes at magnetic axis.
kwargs : dict
kwargs
Additional parameters to supply to the coordinate mapping function.
See ``desc.equilibrium.coords.map_coordinates``.
Expand Down
7 changes: 6 additions & 1 deletion desc/integrals/bounce_integral.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,7 @@ def check_points(self, z1, z2, pitch_inv, plot=True, **kwargs):
specified by ``pitch_inv[...,α,ρ]`` where in the latter the labels
are interpreted as the indices that correspond to that field line.
plot : bool
Whether to plot stuff.
Whether to plot the field lines and bounce points of the given pitch angles.
kwargs
Keyword arguments into ``desc/integrals/bounce_utils.py::plot_ppoly``.
Expand Down Expand Up @@ -285,6 +285,7 @@ def integrate(
method="cubic",
batch=True,
check=False,
plot=False,
):
"""Bounce integrate ∫ f(ℓ) dℓ.
Expand Down Expand Up @@ -337,6 +338,9 @@ def integrate(
Whether to perform computation in a batched manner. Default is true.
check : bool
Flag for debugging. Must be false for JAX transformations.
plot : bool
Whether to plot the quantities in the integrand interpolated to the
quadrature points of each integral. Ignored if ``check`` is false.
Returns
-------
Expand All @@ -361,6 +365,7 @@ def integrate(
method=method,
batch=batch,
check=check,
plot=plot,
)
if weight is not None:
result *= interp_to_argmin(
Expand Down
30 changes: 12 additions & 18 deletions desc/integrals/bounce_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@


def get_pitch_inv(min_B, max_B, num, relative_shift=1e-6):
"""Return 1/λ values uniformly spaced between ``min_B`` and ``max_B``.
"""Return 1/λ values for quadrature between ``min_B`` and ``max_B``.
Parameters
----------
Expand Down Expand Up @@ -262,6 +262,7 @@ def _check_bounce_points(z1, z2, pitch_inv, knots, B, plot=True, **kwargs):
z1=_z1,
z2=_z2,
k=pitch_inv[idx],
title=kwargs.pop("title") + f", (p,m,l)={idx}",
**kwargs,
)

Expand Down Expand Up @@ -350,7 +351,8 @@ def bounce_quadrature(
Flag for debugging. Must be false for JAX transformations.
Ignored if ``batch`` is false.
plot : bool
Whether to plot stuff if ``check`` is true. Default is false.
Whether to plot the quantities in the integrand interpolated to the
quadrature points of each integral. Ignored if ``check`` is false.
Returns
-------
Expand Down Expand Up @@ -418,8 +420,8 @@ def _interpolate_and_integrate(
data,
knots,
method,
check=False,
plot=False,
check,
plot,
):
"""Interpolate given functions to points ``Q`` and perform quadrature.
Expand Down Expand Up @@ -526,28 +528,18 @@ def _check_interp(shape, Q, f, b_sup_z, B, result, plot):


def _plot_check_interp(Q, V, name=""):
"""Plot V[λ, α, ρ, (ζ₁, ζ₂)](Q).
These are pretty, but likely only useful for developers
doing debugging, so we don't include an option to plot these
in the public API of Bounce1D.
"""
"""Plot V[λ, α, ρ, (ζ₁, ζ₂)](Q)."""
for idx in np.ndindex(Q.shape[:3]):
marked = jnp.nonzero(jnp.any(Q[idx] != 0.0, axis=-1))[0]
if marked.size == 0:
continue

Check warning on line 535 in desc/integrals/bounce_utils.py

View check run for this annotation

Codecov / codecov/patch

desc/integrals/bounce_utils.py#L535

Added line #L535 was not covered by tests
fig, ax = plt.subplots()
ax.set_xlabel(r"$\zeta$")
ax.set_ylabel(name)
ax.set_title(f"Interpolation of {name} to quadrature points. Index {idx}.")
ax.set_title(f"Interpolation of {name} to quadrature points, (p,m,l)={idx}")
for i in marked:
ax.plot(Q[(*idx, i)], V[(*idx, i)], marker="o")
fig.text(
0.01,
0.01,
f"Each color specifies {name} interpolated to the quadrature "
"points of a particular integral.",
)
fig.text(0.01, 0.01, "Each color specifies a particular integral.")
plt.tight_layout()
plt.show()

Expand Down Expand Up @@ -765,7 +757,7 @@ def plot_ppoly(
start=None,
stop=None,
include_knots=False,
knot_transparency=0.1,
knot_transparency=0.2,
include_legend=True,
):
"""Plot the piecewise polynomial ``ppoly``.
Expand Down Expand Up @@ -805,6 +797,8 @@ def plot_ppoly(
Whether to plot vertical lines at the knots.
knot_transparency : float
Transparency of knot lines.
include_legend : bool
Whether to include the legend in the plot. Default is true.
Returns
-------
Expand Down
8 changes: 2 additions & 6 deletions desc/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -739,21 +739,17 @@ def flatten_matrix(y):
# https://github.com/numpy/numpy/issues/25805
def atleast_nd(ndmin, ary):
"""Adds dimensions to front if necessary."""
if ndmin == 1:
return jnp.atleast_1d(ary)
if ndmin == 2:
return jnp.atleast_2d(ary)
return jnp.array(ary, ndmin=ndmin) if jnp.ndim(ary) < ndmin else ary


def atleast_3d_mid(ary):
"""Like np.atleast3d but if adds dim at axis 1 for 2d arrays."""
"""Like np.atleast_3d but if adds dim at axis 1 for 2d arrays."""
ary = jnp.atleast_2d(ary)
return ary[:, jnp.newaxis] if ary.ndim == 2 else ary


def atleast_2d_end(ary):
"""Like np.atleast2d but if adds dim at axis 1 for 1d arrays."""
"""Like np.atleast_2d but if adds dim at axis 1 for 1d arrays."""
ary = jnp.atleast_1d(ary)
return ary[:, jnp.newaxis] if ary.ndim == 1 else ary

Expand Down
50 changes: 33 additions & 17 deletions tests/test_integrals.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,11 +40,11 @@
get_pitch_inv,
interp_to_argmin,
interp_to_argmin_hard,
plot_ppoly,
)
from desc.integrals.quad_utils import (
automorphism_sin,
bijection_from_disc,
get_quadrature,
grad_automorphism_sin,
grad_bijection_from_disc,
leggauss_lob,
Expand Down Expand Up @@ -738,7 +738,9 @@ def test_z1_first(self):
B = CubicHermiteSpline(knots, np.cos(knots), -np.sin(knots))
pitch_inv = 0.5
intersect = B.solve(pitch_inv, extrapolate=False)
z1, z2 = bounce_points(pitch_inv, knots, B.c.T, B.derivative().c.T, check=True)
z1, z2 = bounce_points(
pitch_inv, knots, B.c.T, B.derivative().c.T, check=True, include_knots=True
)
z1, z2 = TestBounce1DPoints.filter(z1, z2)
assert z1.size and z2.size
np.testing.assert_allclose(z1, intersect[0::2])
Expand All @@ -753,7 +755,9 @@ def test_z2_first(self):
B = CubicHermiteSpline(k, np.cos(k), -np.sin(k))
pitch_inv = 0.5
intersect = B.solve(pitch_inv, extrapolate=False)
z1, z2 = bounce_points(pitch_inv, k, B.c.T, B.derivative().c.T, check=True)
z1, z2 = bounce_points(
pitch_inv, k, B.c.T, B.derivative().c.T, check=True, include_knots=True
)
z1, z2 = TestBounce1DPoints.filter(z1, z2)
assert z1.size and z2.size
np.testing.assert_allclose(z1, intersect[1:-1:2])
Expand All @@ -772,7 +776,9 @@ def test_z1_before_extrema(self):
)
dB_dz = B.derivative()
pitch_inv = B(dB_dz.roots(extrapolate=False))[3] - 1e-13
z1, z2 = bounce_points(pitch_inv, k, B.c.T, dB_dz.c.T, check=True)
z1, z2 = bounce_points(
pitch_inv, k, B.c.T, dB_dz.c.T, check=True, include_knots=True
)
z1, z2 = TestBounce1DPoints.filter(z1, z2)
assert z1.size and z2.size
intersect = B.solve(pitch_inv, extrapolate=False)
Expand All @@ -797,7 +803,9 @@ def test_z2_before_extrema(self):
)
dB_dz = B.derivative()
pitch_inv = B(dB_dz.roots(extrapolate=False))[2]
z1, z2 = bounce_points(pitch_inv, k, B.c.T, dB_dz.c.T, check=True)
z1, z2 = bounce_points(
pitch_inv, k, B.c.T, dB_dz.c.T, check=True, include_knots=True
)
z1, z2 = TestBounce1DPoints.filter(z1, z2)
assert z1.size and z2.size
intersect = B.solve(pitch_inv, extrapolate=False)
Expand All @@ -819,9 +827,14 @@ def test_extrema_first_and_before_z1(self):
dB_dz = B.derivative()
pitch_inv = B(dB_dz.roots(extrapolate=False))[2] + 1e-13
z1, z2 = bounce_points(
pitch_inv, k[2:], B.c[:, 2:].T, dB_dz.c[:, 2:].T, check=True, plot=False
pitch_inv,
k[2:],
B.c[:, 2:].T,
dB_dz.c[:, 2:].T,
check=True,
start=k[2],
include_knots=True,
)
plot_ppoly(B, z1=z1, z2=z2, k=pitch_inv, start=k[2])
z1, z2 = TestBounce1DPoints.filter(z1, z2)
assert z1.size and z2.size
intersect = B.solve(pitch_inv, extrapolate=False)
Expand All @@ -844,7 +857,9 @@ def test_extrema_first_and_before_z2(self):
)
dB_dz = B.derivative()
pitch_inv = B(dB_dz.roots(extrapolate=False))[1] - 1e-13
z1, z2 = bounce_points(pitch_inv, k, B.c.T, dB_dz.c.T, check=True)
z1, z2 = bounce_points(
pitch_inv, k, B.c.T, dB_dz.c.T, check=True, include_knots=True
)
z1, z2 = TestBounce1DPoints.filter(z1, z2)
assert z1.size and z2.size
# Our routine correctly detects intersection, while scipy, jnp.root fails.
Expand Down Expand Up @@ -937,7 +952,7 @@ def test_bounce_quadrature(self, is_strong, quad, automorphism):
check=True,
**kwargs,
)
result = bounce.integrate(pitch_inv, integrand, check=True)
result = bounce.integrate(pitch_inv, integrand, check=True, plot=True)
assert np.count_nonzero(result) == 1
np.testing.assert_allclose(result.sum(), truth, rtol=1e-4)

Expand All @@ -950,14 +965,10 @@ def _adaptive_elliptic(integrand, k):

@staticmethod
def _fixed_elliptic(integrand, k, deg):
# Can use this test to benchmark quadrature performance.
# Just
k = np.atleast_1d(k)
a = np.zeros_like(k)
b = 2 * np.arcsin(k)
x, w = leggauss(deg)
w = w * grad_automorphism_sin(x)
x = automorphism_sin(x)
x, w = get_quadrature(leggauss(deg), (automorphism_sin, grad_automorphism_sin))
Z = bijection_from_disc(x, a[..., np.newaxis], b[..., np.newaxis])
k = k[..., np.newaxis]
quad = np.dot(integrand(Z, k), w) * grad_bijection_from_disc(a, b)
Expand Down Expand Up @@ -1118,7 +1129,10 @@ def test_bounce1d_checks(self):
nodes = grid.source_grid.meshgrid_reshape(grid.source_grid.nodes[:, :2], "arz")
print("(α, ρ):", nodes[m, l, 0])

# 7. Plotting
# 7. Optionally check for correctness of bounce points
bounce.check_points(*bounce.points(pitch_inv), pitch_inv, plot=False)

# 8. Plotting
fig, ax = bounce.plot(m, l, pitch_inv[..., l], include_legend=False, show=False)
return fig

Expand Down Expand Up @@ -1343,6 +1357,8 @@ def test_binormal_drift_bounce1d(self):
Lref=data["a"],
check=True,
)
bounce.check_points(*bounce.points(pitch_inv), pitch_inv, plot=False)

f = Bounce1D.reshape_data(grid.source_grid, cvdrift, gbdrift)
drift_numerical_num = bounce.integrate(
pitch_inv=pitch_inv,
Expand Down Expand Up @@ -1389,8 +1405,8 @@ def _test_bounce_autodiff(bounce, integrand, **kwargs):
If the AD tool works properly, then these operations should be assigned
zero gradients while the gradients wrt parameters of our physics computations
accumulate correctly. Less mature AD tools may have subtle bugs that cause
the gradients to not accumulate correctly. (There's more than a few
GitHub issues that JAX has fixed related to this in the past!)
the gradients to not accumulate correctly. (There's a few
GitHub issues that JAX has fixed related to this in the past.)
This test first confirms the gradients computed by reverse mode AD matches
the analytic approximation of the true gradient. Then we confirm that the
Expand Down

0 comments on commit 239e441

Please sign in to comment.