Skip to content

Commit

Permalink
Resolves #1240 (#1241)
Browse files Browse the repository at this point in the history
  • Loading branch information
dpanici authored Sep 5, 2024
2 parents d2deffa + d217334 commit 6ab8327
Show file tree
Hide file tree
Showing 23 changed files with 207 additions and 209 deletions.
2 changes: 1 addition & 1 deletion desc/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,7 +336,7 @@ def root(
This routine may be used on over or under-determined systems, in which case it
will solve it in a least squares / least norm sense.
"""
from desc.compute.utils import safenorm
from desc.utils import safenorm

if fixup is None:
fixup = lambda x, *args: x
Expand Down
3 changes: 1 addition & 2 deletions desc/coils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
from desc.compute import get_params, rpz2xyz, rpz2xyz_vec, xyz2rpz, xyz2rpz_vec
from desc.compute.geom_utils import reflection_matrix
from desc.compute.utils import _compute as compute_fun
from desc.compute.utils import safenorm
from desc.geometry import (
FourierPlanarCurve,
FourierRZCurve,
Expand All @@ -29,7 +28,7 @@
from desc.grid import LinearGrid
from desc.magnetic_fields import _MagneticField
from desc.optimizable import Optimizable, OptimizableCollection, optimizable_parameter
from desc.utils import equals, errorif, flatten_list, warnif
from desc.utils import equals, errorif, flatten_list, safenorm, warnif


@jit
Expand Down
2 changes: 1 addition & 1 deletion desc/compute/_basis_vectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@

from desc.backend import jnp

from ..utils import cross, dot, safediv
from .data_index import register_compute_fun
from .utils import cross, dot, safediv


@register_compute_fun(
Expand Down
2 changes: 1 addition & 1 deletion desc/compute/_curve.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@

from desc.backend import jnp, sign

from ..utils import cross, dot, safenormalize
from .data_index import register_compute_fun
from .geom_utils import rotation_matrix, rpz2xyz, rpz2xyz_vec, xyz2rpz, xyz2rpz_vec
from .utils import cross, dot, safenormalize


@register_compute_fun(
Expand Down
2 changes: 1 addition & 1 deletion desc/compute/_equil.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@
from desc.backend import jnp

from ..integrals.surface_integral import surface_averages
from ..utils import cross, dot, safediv, safenorm
from .data_index import register_compute_fun
from .utils import cross, dot, safediv, safenorm


@register_compute_fun(
Expand Down
2 changes: 1 addition & 1 deletion desc/compute/_field.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@
surface_max,
surface_min,
)
from ..utils import cross, dot, safediv, safenorm
from .data_index import register_compute_fun
from .utils import cross, dot, safediv, safenorm


@register_compute_fun(
Expand Down
2 changes: 1 addition & 1 deletion desc/compute/_geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@
from desc.backend import jnp

from ..integrals.surface_integral import line_integrals, surface_integrals
from ..utils import cross, dot, safenorm
from .data_index import register_compute_fun
from .utils import cross, dot, safenorm


@register_compute_fun(
Expand Down
2 changes: 1 addition & 1 deletion desc/compute/_metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@
from desc.backend import jnp

from ..integrals.surface_integral import surface_averages
from ..utils import cross, dot, safediv, safenorm
from .data_index import register_compute_fun
from .utils import cross, dot, safediv, safenorm


@register_compute_fun(
Expand Down
2 changes: 1 addition & 1 deletion desc/compute/_omnigenity.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@

from desc.backend import jnp, sign, vmap

from ..utils import cross, dot, safediv
from .data_index import register_compute_fun
from .utils import cross, dot, safediv


@register_compute_fun(
Expand Down
2 changes: 1 addition & 1 deletion desc/compute/_profiles.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@
from desc.backend import cond, jnp

from ..integrals.surface_integral import surface_averages, surface_integrals
from ..utils import cumtrapz, dot, safediv
from .data_index import register_compute_fun
from .utils import cumtrapz, dot, safediv


@register_compute_fun(
Expand Down
2 changes: 1 addition & 1 deletion desc/compute/_stability.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@
from desc.backend import jnp

from ..integrals.surface_integral import surface_integrals_map
from ..utils import dot
from .data_index import register_compute_fun
from .utils import dot


@register_compute_fun(
Expand Down
2 changes: 1 addition & 1 deletion desc/compute/geom_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from desc.backend import jnp

from .utils import safenorm, safenormalize
from ..utils import safenorm, safenormalize


def reflection_matrix(normal):
Expand Down
184 changes: 0 additions & 184 deletions desc/compute/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -711,187 +711,3 @@ def _has_transforms(qty, transforms, parameterization):
[d in transforms[key].derivatives.tolist() for d in derivs[key]]
).all()
return all(flags.values())


def dot(a, b, axis=-1):
"""Batched vector dot product.
Parameters
----------
a : array-like
First array of vectors.
b : array-like
Second array of vectors.
axis : int
Axis along which vectors are stored.
Returns
-------
y : array-like
y = sum(a*b, axis=axis)
"""
return jnp.sum(a * b, axis=axis, keepdims=False)


def cross(a, b, axis=-1):
"""Batched vector cross product.
Parameters
----------
a : array-like
First array of vectors.
b : array-like
Second array of vectors.
axis : int
Axis along which vectors are stored.
Returns
-------
y : array-like
y = a x b
"""
return jnp.cross(a, b, axis=axis)


def safenorm(x, ord=None, axis=None, fill=0, threshold=0):
"""Like jnp.linalg.norm, but without nan gradient at x=0.
Parameters
----------
x : ndarray
Vector or array to norm.
ord : {non-zero int, inf, -inf, 'fro', 'nuc'}, optional
Order of norm.
axis : {None, int, 2-tuple of ints}, optional
Axis to take norm along.
fill : float, ndarray, optional
Value to return where x is zero.
threshold : float >= 0
How small is x allowed to be.
"""
is_zero = (jnp.abs(x) <= threshold).all(axis=axis, keepdims=True)
y = jnp.where(is_zero, jnp.ones_like(x), x) # replace x with ones if is_zero
n = jnp.linalg.norm(y, ord=ord, axis=axis)
n = jnp.where(is_zero.squeeze(), fill, n) # replace norm with zero if is_zero
return n


def safenormalize(x, ord=None, axis=None, fill=0, threshold=0):
"""Normalize a vector to unit length, but without nan gradient at x=0.
Parameters
----------
x : ndarray
Vector or array to norm.
ord : {non-zero int, inf, -inf, 'fro', 'nuc'}, optional
Order of norm.
axis : {None, int, 2-tuple of ints}, optional
Axis to take norm along.
fill : float, ndarray, optional
Value to return where x is zero.
threshold : float >= 0
How small is x allowed to be.
"""
is_zero = (jnp.abs(x) <= threshold).all(axis=axis, keepdims=True)
y = jnp.where(is_zero, jnp.ones_like(x), x) # replace x with ones if is_zero
n = safenorm(x, ord, axis, fill, threshold) * jnp.ones_like(x)
# return unit vector with equal components if norm <= threshold
return jnp.where(n <= threshold, jnp.ones_like(y) / jnp.sqrt(y.size), y / n)


def safediv(a, b, fill=0, threshold=0):
"""Divide a/b with guards for division by zero.
Parameters
----------
a, b : ndarray
Numerator and denominator.
fill : float, ndarray, optional
Value to return where b is zero.
threshold : float >= 0
How small is b allowed to be.
"""
mask = jnp.abs(b) <= threshold
num = jnp.where(mask, fill, a)
den = jnp.where(mask, 1, b)
return num / den


def cumtrapz(y, x=None, dx=1.0, axis=-1, initial=None):
"""Cumulatively integrate y(x) using the composite trapezoidal rule.
Taken from SciPy, but changed NumPy references to JAX.NumPy:
https://github.com/scipy/scipy/blob/v1.10.1/scipy/integrate/_quadrature.py
Parameters
----------
y : array_like
Values to integrate.
x : array_like, optional
The coordinate to integrate along. If None (default), use spacing `dx`
between consecutive elements in `y`.
dx : float, optional
Spacing between elements of `y`. Only used if `x` is None.
axis : int, optional
Specifies the axis to cumulate. Default is -1 (last axis).
initial : scalar, optional
If given, insert this value at the beginning of the returned result.
Typically, this value should be 0. Default is None, which means no
value at ``x[0]`` is returned and `res` has one element less than `y`
along the axis of integration.
Returns
-------
res : ndarray
The result of cumulative integration of `y` along `axis`.
If `initial` is None, the shape is such that the axis of integration
has one less value than `y`. If `initial` is given, the shape is equal
to that of `y`.
"""
y = jnp.asarray(y)
if x is None:
d = dx
else:
x = jnp.asarray(x)
if x.ndim == 1:
d = jnp.diff(x)
# reshape to correct shape
shape = [1] * y.ndim
shape[axis] = -1
d = d.reshape(shape)
elif len(x.shape) != len(y.shape):
raise ValueError("If given, shape of x must be 1-D or the " "same as y.")
else:
d = jnp.diff(x, axis=axis)

if d.shape[axis] != y.shape[axis] - 1:
raise ValueError(
"If given, length of x along axis must be the " "same as y."
)

def tupleset(t, i, value):
l = list(t)
l[i] = value
return tuple(l)

nd = len(y.shape)
slice1 = tupleset((slice(None),) * nd, axis, slice(1, None))
slice2 = tupleset((slice(None),) * nd, axis, slice(None, -1))
res = jnp.cumsum(d * (y[slice1] + y[slice2]) / 2.0, axis=axis)

if initial is not None:
if not jnp.isscalar(initial):
raise ValueError("`initial` parameter should be a scalar.")

shape = list(res.shape)
shape[axis] = 1
res = jnp.concatenate(
[jnp.full(shape, initial, dtype=res.dtype), res], axis=axis
)

return res
2 changes: 1 addition & 1 deletion desc/integrals/interp_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from interpax import interp1d

from desc.backend import jnp
from desc.compute.utils import safediv
from desc.utils import safediv

# Warning: method must be specified as keyword argument.
interp1d_vec = jnp.vectorize(
Expand Down
3 changes: 1 addition & 2 deletions desc/integrals/singularities.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,9 @@
from desc.backend import fori_loop, jnp, put, vmap
from desc.basis import DoubleFourierSeries
from desc.compute.geom_utils import rpz2xyz, rpz2xyz_vec, xyz2rpz_vec
from desc.compute.utils import safediv, safenorm
from desc.grid import LinearGrid
from desc.io import IOAble
from desc.utils import isalmostequal, islinspaced
from desc.utils import isalmostequal, islinspaced, safediv, safenorm


def _get_quadrature_nodes(q):
Expand Down
3 changes: 1 addition & 2 deletions desc/objectives/_coils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,9 @@
)
from desc.compute import get_profiles, get_transforms, rpz2xyz
from desc.compute.utils import _compute as compute_fun
from desc.compute.utils import safenorm
from desc.grid import LinearGrid, _Grid
from desc.integrals import compute_B_plasma
from desc.utils import Timer, errorif, warnif
from desc.utils import Timer, errorif, safenorm, warnif

from .normalization import compute_scaling_factors
from .objective_funs import _Objective
Expand Down
3 changes: 1 addition & 2 deletions desc/objectives/_geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,8 @@
from desc.backend import jnp, vmap
from desc.compute import get_profiles, get_transforms, rpz2xyz, xyz2rpz
from desc.compute.utils import _compute as compute_fun
from desc.compute.utils import safenorm
from desc.grid import LinearGrid, QuadratureGrid
from desc.utils import Timer, errorif, parse_argname_change, warnif
from desc.utils import Timer, errorif, parse_argname_change, safenorm, warnif

from .normalization import compute_scaling_factors
from .objective_funs import _Objective
Expand Down
Loading

0 comments on commit 6ab8327

Please sign in to comment.