Skip to content

Commit

Permalink
Merge branch 'master' into dp/pyREGCOIL
Browse files Browse the repository at this point in the history
  • Loading branch information
dpanici committed Oct 15, 2024
2 parents 2ff2065 + e5d31c2 commit 80f6f33
Show file tree
Hide file tree
Showing 7 changed files with 313 additions and 67 deletions.
145 changes: 114 additions & 31 deletions desc/coils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Classes for magnetic field coils."""

import functools
import numbers
from abc import ABC
from collections.abc import MutableSequence
Expand Down Expand Up @@ -28,7 +29,7 @@
from desc.grid import LinearGrid
from desc.magnetic_fields import _MagneticField
from desc.optimizable import Optimizable, OptimizableCollection, optimizable_parameter
from desc.utils import equals, errorif, flatten_list, safenorm, warnif
from desc.utils import cross, dot, equals, errorif, flatten_list, safenorm, warnif


@jit
Expand Down Expand Up @@ -245,7 +246,7 @@ def num_coils(self):
"""int: Number of coils."""
return 1

def _compute_position(self, params=None, grid=None, **kwargs):
def _compute_position(self, params=None, grid=None, dx1=False, **kwargs):
"""Compute coil positions accounting for stellarator symmetry.
Parameters
Expand All @@ -255,18 +256,31 @@ def _compute_position(self, params=None, grid=None, **kwargs):
grid : Grid or int, optional
Grid of coordinates to evaluate at. Defaults to a Linear grid.
If an integer, uses that many equally spaced points.
dx1 : bool
If True, also return dx/ds for the curve.
Returns
-------
x : ndarray, shape(len(self),source_grid.num_nodes,3)
Coil positions, in [R,phi,Z] or [X,Y,Z] coordinates.
x_s : ndarray, shape(len(self),source_grid.num_nodes,3)
Coil position derivatives, in [R,phi,Z] or [X,Y,Z] coordinates.
Only returned if dx1=True.
"""
x = self.compute("x", grid=grid, params=params, **kwargs)["x"]
x = jnp.transpose(jnp.atleast_3d(x), [2, 0, 1]) # shape=(1,num_nodes,3)
basis = kwargs.pop("basis", "xyz")
kwargs.setdefault("basis", "xyz")
keys = ["x", "x_s"] if dx1 else ["x"]
data = self.compute(keys, grid=grid, params=params, **kwargs)
x = jnp.transpose(jnp.atleast_3d(data["x"]), [2, 0, 1]) # shape=(1,num_nodes,3)
if dx1:
x_s = jnp.transpose(
jnp.atleast_3d(data["x_s"]), [2, 0, 1]
) # shape=(1,num_nodes,3)
basis = kwargs.get("basis", "xyz")
if basis.lower() == "rpz":
x = x.at[:, :, 1].set(jnp.mod(x[:, :, 1], 2 * jnp.pi))
if dx1:
return x, x_s
return x

def _compute_A_or_B(
Expand Down Expand Up @@ -1359,7 +1373,7 @@ def flip(self, *args, **kwargs):
"""Flip the coils across a plane."""
[coil.flip(*args, **kwargs) for coil in self.coils]

def _compute_position(self, params=None, grid=None, **kwargs):
def _compute_position(self, params=None, grid=None, dx1=False, **kwargs):
"""Compute coil positions accounting for stellarator symmetry.
Parameters
Expand All @@ -1369,25 +1383,35 @@ def _compute_position(self, params=None, grid=None, **kwargs):
grid : Grid or int, optional
Grid of coordinates to evaluate at. Defaults to a Linear grid.
If an integer, uses that many equally spaced points.
dx1 : bool
If True, also return dx/ds for each curve.
Returns
-------
x : ndarray, shape(len(self),source_grid.num_nodes,3)
Coil positions, in [R,phi,Z] or [X,Y,Z] coordinates.
x_s : ndarray, shape(len(self),source_grid.num_nodes,3)
Coil position derivatives, in [R,phi,Z] or [X,Y,Z] coordinates.
Only returned if dx1=True.
"""
basis = kwargs.pop("basis", "xyz")
keys = ["x", "x_s"] if dx1 else ["x"]
if params is None:
params = [get_params("x", coil, basis=basis) for coil in self]
data = self.compute("x", grid=grid, params=params, basis=basis, **kwargs)
params = [get_params(keys, coil, basis=basis) for coil in self]
data = self.compute(keys, grid=grid, params=params, basis=basis, **kwargs)
data = tree_leaves(data, is_leaf=lambda x: isinstance(x, dict))
x = jnp.dstack([d["x"].T for d in data]).T # shape=(ncoils,num_nodes,3)

if dx1:
x_s = jnp.dstack([d["x_s"].T for d in data]).T # shape=(ncoils,num_nodes,3)
# stellarator symmetry is easiest in [X,Y,Z] coordinates
if basis.lower() == "rpz":
xyz = rpz2xyz(x)
else:
xyz = x
xyz = rpz2xyz(x) if basis.lower() == "rpz" else x
if dx1:
xyz_s = (
rpz2xyz_vec(x_s, xyz[:, :, 0], xyz[:, :, 1])
if basis.lower() == "rpz"
else x_s
)

# if stellarator symmetric, add reflected coils from the other half field period
if self.sym:
Expand All @@ -1396,27 +1420,64 @@ def _compute_position(self, params=None, grid=None, **kwargs):
)
xyz_sym = xyz @ reflection_matrix(normal).T @ reflection_matrix([0, 0, 1]).T
xyz = jnp.vstack((xyz, jnp.flipud(xyz_sym)))
if dx1:
xyz_s_sym = (
xyz_s @ reflection_matrix(normal).T @ reflection_matrix([0, 0, 1]).T
)
xyz_s = jnp.vstack((xyz_s, jnp.flipud(xyz_s_sym)))

# field period rotation is easiest in [R,phi,Z] coordinates
rpz = xyz2rpz(xyz)
if dx1:
rpz_s = xyz2rpz_vec(xyz_s, xyz[:, :, 0], xyz[:, :, 1])

# if field period symmetry, add rotated coils from other field periods
if self.NFP > 1:
rpz0 = rpz
for k in range(1, self.NFP):
rpz = jnp.vstack(
(rpz, rpz0 + jnp.array([0, 2 * jnp.pi * k / self.NFP, 0]))
)
rpz0 = rpz
for k in range(1, self.NFP):
rpz = jnp.vstack((rpz, rpz0 + jnp.array([0, 2 * jnp.pi * k / self.NFP, 0])))
if dx1:
rpz_s = jnp.tile(rpz_s, (self.NFP, 1, 1))

# ensure phi in [0, 2pi)
rpz = rpz.at[:, :, 1].set(jnp.mod(rpz[:, :, 1], 2 * jnp.pi))

if basis.lower() == "xyz":
x = rpz2xyz(rpz)
else:
x = rpz
x = rpz2xyz(rpz) if basis.lower() == "xyz" else rpz
if dx1:
x_s = (
rpz2xyz_vec(rpz_s, phi=rpz[:, :, 1])
if basis.lower() == "xyz"
else rpz_s
)
return x, x_s
return x

def _compute_linking_number(self, params=None, grid=None):
"""Calculate linking numbers for coils in the coilset.
Parameters
----------
params : dict or array-like of dict, optional
Parameters to pass to coils, either the same for all coils or one for each.
grid : Grid or int, optional
Grid of coordinates to evaluate at. Defaults to a Linear grid.
If an integer, uses that many equally spaced points.
Returns
-------
link : ndarray, shape(num_coils, num_coils)
Linking number of each coil with each other coil. link=0 means they are not
linked, +/- 1 means the coils link each other in one direction or another.
"""
if grid is None:
grid = LinearGrid(N=50)
dx = grid.spacing[:, 2]
x, x_s = self._compute_position(params, grid, dx1=True, basis="xyz")
link = _linking_number(
x[:, None], x[None, :], x_s[:, None], x_s[None, :], dx, dx
)
return link / (4 * jnp.pi)

def _compute_A_or_B(
self,
coords,
Expand Down Expand Up @@ -2307,7 +2368,7 @@ def compute(
)
]

def _compute_position(self, params=None, grid=None, **kwargs):
def _compute_position(self, params=None, grid=None, dx1=False, **kwargs):
"""Compute coil positions accounting for stellarator symmetry.
Parameters
Expand All @@ -2318,11 +2379,16 @@ def _compute_position(self, params=None, grid=None, **kwargs):
Grid of coordinates to evaluate at. Defaults to a Linear grid.
If an integer, uses that many equally spaced points.
If array-like, should be 1 value per coil.
dx1 : bool
If True, also return dx/ds for each curve.
Returns
-------
x : ndarray, shape(len(self),source_grid.num_nodes,3)
Coil positions, in [R,phi,Z] or [X,Y,Z] coordinates.
x_s : ndarray, shape(len(self),source_grid.num_nodes,3)
Coil position derivatives, in [R,phi,Z] or [X,Y,Z] coordinates.
Only returned if dx1=True.
"""
errorif(
Expand All @@ -2331,15 +2397,17 @@ def _compute_position(self, params=None, grid=None, **kwargs):
"grid must be supplied to MixedCoilSet._compute_position, since the "
+ "default grid for each coil could have a different number of nodes.",
)
kwargs.setdefault("basis", "xyz")
params = self._make_arraylike(params)
grid = self._make_arraylike(grid)
x = jnp.vstack(
[
coil._compute_position(par, grd, **kwargs)
for coil, par, grd in zip(self.coils, params, grid)
]
)
return x
out = []
for coil, par, grd in zip(self.coils, params, grid):
out.append(coil._compute_position(par, grd, dx1, **kwargs))
if dx1:
x = jnp.vstack([foo[0] for foo in out])
x_s = jnp.vstack([foo[1] for foo in out])
return x, x_s
return jnp.vstack(out)

def _compute_A_or_B(
self,
Expand Down Expand Up @@ -2795,3 +2863,18 @@ def flatten_coils(coilset):
if ignore_groups:
cset = cls(*flatten_coils(cset), check_intersection=check_intersection)
return cset


@functools.partial(jnp.vectorize, signature="(m,3),(n,3),(m,3),(n,3),(m),(n)->()")
def _linking_number(x1, x2, x1_s, x2_s, dx1, dx2):
"""Linking number between curves x1 and x2 with tangents x1_s, x2_s."""
x1_s *= dx1[:, None]
x2_s *= dx2[:, None]
dx = x1[:, None, :] - x2[None, :, :] # shape(m,n,3)
dx_norm = safenorm(dx, axis=-1) # shape(m,n)
den = dx_norm**3
dr1xdr2 = cross(x1_s[:, None, :], x2_s[None, :, :], axis=-1) # shape(m,n,3)
num = dot(dx, dr1xdr2, axis=-1) # shape(m,n)
small = dx_norm < jnp.finfo(x1.dtype).eps
ratio = jnp.where(small, 0.0, num / jnp.where(small, 1.0, den))
return ratio.sum()
1 change: 1 addition & 0 deletions desc/objectives/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
CoilCurrentLength,
CoilCurvature,
CoilLength,
CoilSetLinkingNumber,
CoilSetMinDistance,
CoilTorsion,
PlasmaCoilSetMinDistance,
Expand Down
Loading

0 comments on commit 80f6f33

Please sign in to comment.