Skip to content

Commit

Permalink
Stop modding of non-periodic angles (#1254)
Browse files Browse the repository at this point in the history
Resolves #1180 and resolves #1173 .
  • Loading branch information
unalmis authored Sep 19, 2024
2 parents 54becdb + 97d1747 commit 75eafcc
Show file tree
Hide file tree
Showing 7 changed files with 87 additions and 89 deletions.
6 changes: 3 additions & 3 deletions desc/compute/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1494,7 +1494,7 @@ def _Z_zzz(params, transforms, profiles, data, **kwargs):
label="\\alpha",
units="~",
units_long="None",
description="Field line label, defined on [0, 2pi)",
description="Field line label",
dim=1,
params=[],
transforms={},
Expand All @@ -1503,7 +1503,7 @@ def _Z_zzz(params, transforms, profiles, data, **kwargs):
data=["theta_PEST", "phi", "iota"],
)
def _alpha(params, transforms, profiles, data, **kwargs):
data["alpha"] = (data["theta_PEST"] - data["iota"] * data["phi"]) % (2 * jnp.pi)
data["alpha"] = data["theta_PEST"] - data["iota"] * data["phi"]
return data


Expand Down Expand Up @@ -3077,7 +3077,7 @@ def _theta(params, transforms, profiles, data, **kwargs):
data=["theta", "lambda"],
)
def _theta_PEST(params, transforms, profiles, data, **kwargs):
data["theta_PEST"] = (data["theta"] + data["lambda"]) % (2 * jnp.pi)
data["theta_PEST"] = data["theta"] + data["lambda"]
return data


Expand Down
131 changes: 71 additions & 60 deletions desc/equilibrium/coords.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,12 @@ def _periodic(x, period):
return jnp.where(jnp.isfinite(period), x % period, x)


def _fixup_residual(r, period):
r = _periodic(r, period)
# r should be between -period and period
return jnp.where((r > period / 2) & jnp.isfinite(period), -period + r, r)


def map_coordinates( # noqa: C901
eq,
coords,
Expand Down Expand Up @@ -87,9 +93,9 @@ def map_coordinates( # noqa: C901
ValueError,
f"tol must be a positive float, got {tol}",
)
params = setdefault(params, eq.params_dict)
inbasis = tuple(inbasis)
outbasis = tuple(outbasis)
params = setdefault(params, eq.params_dict)

basis_derivs = tuple(f"{X}_{d}" for X in inbasis for d in ("r", "t", "z"))
for key in basis_derivs:
Expand All @@ -111,25 +117,27 @@ def map_coordinates( # noqa: C901
profiles["iota"] = eq.get_profile(["iota", "iota_r"], params=params)
iota = profiles["iota"].compute(Grid(coords, sort=False, jitable=True))
return _map_clebsch_coordinates(
coords,
iota,
params["L_lmn"],
eq.L_basis,
guess[:, 1] if guess is not None else None,
tol,
maxiter,
full_output,
coords=coords,
iota=iota,
L_lmn=params["L_lmn"],
L_basis=eq.L_basis,
guess=guess[:, 1] if guess is not None else None,
period=period[1] if period is not None else np.inf,
tol=tol,
maxiter=maxiter,
full_output=full_output,
**kwargs,
)
if inbasis == ("rho", "theta_PEST", "zeta"):
return _map_PEST_coordinates(
coords,
params["L_lmn"],
eq.L_basis,
guess[:, 1] if guess is not None else None,
tol,
maxiter,
full_output,
coords=coords,
L_lmn=params["L_lmn"],
L_basis=eq.L_basis,
guess=guess[:, 1] if guess is not None else None,
period=period[1] if period is not None else np.inf,
tol=tol,
maxiter=maxiter,
full_output=full_output,
**kwargs,
)

Expand All @@ -139,7 +147,6 @@ def map_coordinates( # noqa: C901
params["i_l"] = profiles["iota"].params

rhomin = kwargs.pop("rhomin", tol / 10)
warnif(period is None, msg="Assuming no periodicity.")
period = np.asarray(setdefault(period, (np.inf, np.inf, np.inf)))
coords = _periodic(coords, period)

Expand All @@ -165,8 +172,7 @@ def compute(y, basis):
@jit
def residual(y, coords):
xk = compute(y, inbasis)
r = _periodic(xk, period) - _periodic(coords, period)
return jnp.where((r > period / 2) & jnp.isfinite(period), -period + r, r)
return _fixup_residual(xk - coords, period)

@jit
def jac(y, coords):
Expand Down Expand Up @@ -212,7 +218,6 @@ def fixup(y, *args):
yk, (res, niter) = vecroot(yk, coords)

out = compute(yk, outbasis)

if full_output:
return out, (res, niter)
return out
Expand Down Expand Up @@ -253,7 +258,7 @@ def _initial_guess_heuristic(yk, coords, inbasis, eq, profiles):
zero = jnp.zeros_like(rho)
grid = Grid(nodes=jnp.column_stack([rho, zero, zero]), sort=False, jitable=True)
iota = profiles["iota"].compute(grid)
theta = (alpha + iota * zeta) % (2 * jnp.pi)
theta = alpha + iota * zeta

yk = jnp.column_stack([rho, theta, zeta])
return yk
Expand Down Expand Up @@ -284,6 +289,7 @@ def _map_PEST_coordinates(
L_lmn,
L_basis,
guess,
period=np.inf,
tol=1e-6,
maxiter=30,
full_output=False,
Expand All @@ -304,6 +310,9 @@ def _map_PEST_coordinates(
guess : jnp.ndarray
Shape (k, ).
Optional initial guess for the computational coordinates.
period : float
Assumed periodicity for ϑ.
Use ``np.inf`` to denote no periodicity.
tol : float
Stopping tolerance.
maxiter : int
Expand All @@ -325,36 +334,25 @@ def _map_PEST_coordinates(
Only returned if ``full_output`` is True.
"""
rho, theta_PEST, zeta = coords.T
theta_PEST = theta_PEST % (2 * np.pi)
# Assume λ=0 for initial guess.
guess = setdefault(guess, theta_PEST)
# noqa: D202

# Root finding for θₖ such that r(θₖ) = ϑₖ(ρ, θₖ, ζ) − ϑ = 0.
def rootfun(theta_DESC, theta_PEST, rho, zeta):
nodes = jnp.array(
[rho.squeeze(), theta_DESC.squeeze(), zeta.squeeze()], ndmin=2
)
def rootfun(theta, theta_PEST, rho, zeta):
nodes = jnp.array([rho.squeeze(), theta.squeeze(), zeta.squeeze()], ndmin=2)
A = L_basis.evaluate(nodes)
lmbda = A @ L_lmn
theta_PEST_k = (theta_DESC + lmbda) % (2 * np.pi)
r = theta_PEST_k - theta_PEST
# r should be between -pi and pi
r = jnp.where(r > np.pi, r - 2 * np.pi, r)
r = jnp.where(r < -np.pi, r + 2 * np.pi, r)
return r.squeeze()

def jacfun(theta_DESC, theta_PEST, rho, zeta):
# Valid everywhere except θ such that θ+λ = k 2π where k ∈ ℤ.
nodes = jnp.array(
[rho.squeeze(), theta_DESC.squeeze(), zeta.squeeze()], ndmin=2
)
theta_PEST_k = theta + lmbda
return _fixup_residual(theta_PEST_k - theta_PEST, period).squeeze()

def jacfun(theta, theta_PEST, rho, zeta):
# Valid everywhere except θ such that θ+λ = k period where k ∈ ℤ.
nodes = jnp.array([rho.squeeze(), theta.squeeze(), zeta.squeeze()], ndmin=2)
A1 = L_basis.evaluate(nodes, (0, 1, 0))
lmbda_t = jnp.dot(A1, L_lmn)
return 1 + lmbda_t.squeeze()

def fixup(x, *args):
return x % (2 * np.pi)
return _periodic(x, period)

vecroot = jit(
vmap(
Expand All @@ -370,10 +368,15 @@ def fixup(x, *args):
)
)
)
theta_DESC, (res, niter) = vecroot(guess, theta_PEST, rho, zeta)

out = jnp.column_stack([rho, jnp.atleast_1d(theta_DESC.squeeze()), zeta])

rho, theta_PEST, zeta = coords.T
theta, (res, niter) = vecroot(
# Assume λ=0 for default initial guess.
setdefault(guess, theta_PEST),
theta_PEST,
rho,
zeta,
)
out = jnp.column_stack([rho, jnp.atleast_1d(theta.squeeze()), zeta])
if full_output:
return out, (res, niter)
return out
Expand All @@ -386,6 +389,7 @@ def _map_clebsch_coordinates(
L_lmn,
L_basis,
guess=None,
period=np.inf,
tol=1e-6,
maxiter=30,
full_output=False,
Expand All @@ -409,6 +413,9 @@ def _map_clebsch_coordinates(
guess : jnp.ndarray
Shape (k, ).
Optional initial guess for the computational coordinates.
period : float
Assumed periodicity for α.
Use ``np.inf`` to denote no periodicity.
tol : float
Stopping tolerance.
maxiter : int
Expand All @@ -430,32 +437,25 @@ def _map_clebsch_coordinates(
Only returned if ``full_output`` is True.
"""
rho, alpha, zeta = coords.T
if guess is None:
# Assume λ=0 for initial guess.
guess = (alpha + iota * zeta) % (2 * np.pi)
# noqa: D202

# Root finding for θₖ such that r(θₖ) = αₖ(ρ, θₖ, ζ) − α = 0.
def rootfun(theta, alpha, rho, zeta, iota):
nodes = jnp.array([rho.squeeze(), theta.squeeze(), zeta.squeeze()], ndmin=2)
A = L_basis.evaluate(nodes)
lmbda = A @ L_lmn
alpha_k = theta + lmbda - iota * zeta
r = (alpha_k - alpha) % (2 * np.pi)
# r should be between -pi and pi
r = jnp.where(r > np.pi, r - 2 * np.pi, r)
r = jnp.where(r < -np.pi, r + 2 * np.pi, r)
return r.squeeze()
return _fixup_residual(alpha_k - alpha, period).squeeze()

def jacfun(theta, alpha, rho, zeta, iota):
# Valid everywhere except θ such that θ+λ = k where k ∈ ℤ.
# Valid everywhere except θ such that θ+λ = k period where k ∈ ℤ.
nodes = jnp.array([rho.squeeze(), theta.squeeze(), zeta.squeeze()], ndmin=2)
A1 = L_basis.evaluate(nodes, (0, 1, 0))
lmbda_t = jnp.dot(A1, L_lmn)
return 1 + lmbda_t.squeeze()

def fixup(x, *args):
return x % (2 * np.pi)
return _periodic(x, period)

vecroot = jit(
vmap(
Expand All @@ -471,9 +471,13 @@ def fixup(x, *args):
)
)
)
rho, alpha, zeta = coords.T
if guess is None:
# Assume λ=0 for default initial guess.
guess = alpha + iota * zeta
theta, (res, niter) = vecroot(guess, alpha, rho, zeta, iota)
out = jnp.column_stack([rho, jnp.atleast_1d(theta.squeeze()), zeta])

out = jnp.column_stack([rho, jnp.atleast_1d(theta.squeeze()), zeta])
if full_output:
return out, (res, niter)
return out
Expand Down Expand Up @@ -662,7 +666,14 @@ def to_sfl(


def get_rtz_grid(
eq, radial, poloidal, toroidal, coordinates, period, jitable=True, **kwargs
eq,
radial,
poloidal,
toroidal,
coordinates,
period=(np.inf, np.inf, np.inf),
jitable=True,
**kwargs,
):
"""Return DESC grid in rtz (rho, theta, zeta) coordinates from given coordinates.
Expand All @@ -685,7 +696,7 @@ def get_rtz_grid(
rvp : rho, theta_PEST, phi
rtz : rho, theta, zeta
period : tuple of float
Assumed periodicity for functions of the given coordinates.
Assumed periodicity of the given coordinates.
Use ``np.inf`` to denote no periodicity.
jitable : bool, optional
If false the returned grid has additional attributes.
Expand Down
2 changes: 1 addition & 1 deletion desc/equilibrium/equilibrium.py
Original file line number Diff line number Diff line change
Expand Up @@ -1262,7 +1262,7 @@ def compute_theta_coords(
)
return map_coordinates(
self,
flux_coords,
coords=flux_coords,
inbasis=("rho", "theta_PEST", "zeta"),
outbasis=("rho", "theta", "zeta"),
params=self.params_dict if L_lmn is None else {"L_lmn": L_lmn},
Expand Down
2 changes: 2 additions & 0 deletions desc/grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -1037,6 +1037,8 @@ def _create_nodes( # noqa: C901
"""
self._NFP = check_posint(NFP, "NFP", False)
self._period = (np.inf, 2 * np.pi, 2 * np.pi / self._NFP)
# TODO:
# https://github.com/PlasmaControl/DESC/pull/1204#pullrequestreview-2246771337
axis = bool(axis)
endpoint = bool(endpoint)
theta_period = self.period[1]
Expand Down
15 changes: 7 additions & 8 deletions desc/vmec.py
Original file line number Diff line number Diff line change
Expand Up @@ -1664,7 +1664,7 @@ def vmec_interpolate(Cmn, Smn, xm, xn, theta, phi, s=None, si=None, sym=True):
def compute_theta_coords(
cls, lmns, xm, xn, s, theta_star, zeta, si=None, lmnc=None
):
"""Find theta such that theta + lambda(theta) == theta_star.
"""Find θ (theta_DESC) for given PEST straight field line ϑ (theta_star).
Parameters
----------
Expand Down Expand Up @@ -1693,6 +1693,7 @@ def compute_theta_coords(
theta such that theta + lambda(theta) == theta_star
"""
theta_PEST = theta_star
if si is None:
si = np.linspace(0, 1, lmns.shape[0])
si[1:] = si[0:-1] + 0.5 / (lmns.shape[0] - 1)
Expand All @@ -1702,9 +1703,7 @@ def compute_theta_coords(
else:
lmbda_mnc = interpolate.CubicSpline(si, lmnc)

# Note: theta* (also known as vartheta) is the poloidal straight field line
# angle in PEST-like flux coordinates

# Root finding for θₖ such that r(θₖ) = ϑₖ(ρ, θₖ, ζ) − ϑ = 0.
def root_fun(theta):
lmbda = np.sum(
lmbda_mns(s)
Expand All @@ -1721,12 +1720,12 @@ def root_fun(theta):
),
axis=-1,
)
theta_star_k = theta + lmbda # theta* = theta + lambda
err = theta_star - theta_star_k # FIXME: mod by 2pi
return err
theta_PEST_k = theta + lmbda
r = theta_PEST_k - theta_PEST
return -r # the negative sign is necessary

out = optimize.root(
root_fun, x0=theta_star, method="diagbroyden", options={"ftol": 1e-6}
root_fun, x0=theta_PEST, method="diagbroyden", options={"ftol": 1e-6}
)
return out.x

Expand Down
6 changes: 1 addition & 5 deletions tests/test_equilibrium.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,11 +53,7 @@ def test_map_coordinates():
eq.change_resolution(3, 3, 3, 6, 6, 6)
n = 100
coords = np.array([np.ones(n), np.zeros(n), np.linspace(0, 10 * np.pi, n)]).T
out = eq.map_coordinates(
coords,
inbasis=["rho", "alpha", "zeta"],
period=(np.inf, 2 * np.pi, np.inf),
)
out = eq.map_coordinates(coords, inbasis=["rho", "alpha", "zeta"])
assert np.isfinite(out).all()

eq = get("DSHAPE")
Expand Down
Loading

0 comments on commit 75eafcc

Please sign in to comment.