Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Neo wrapper #1472

Draft
wants to merge 29 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
d81a0dd
working on gx wrapper.
Aug 9, 2022
3664f28
finished implementing wrapper, about to test and fail.
Aug 9, 2022
ae9da04
continuing implementation of gx wrapper. works individually, still do…
Aug 10, 2022
09c35ea
fixed gx wrapper, working now, but very slow. requires many gx calls.
Aug 12, 2022
f3c7c61
still testing gx wrapper. working, not sure how effectively.
Aug 13, 2022
dd4d4d7
small updates.
Aug 13, 2022
427f459
fixed gxwrapper to do all the finite differences for each arg at once…
Aug 14, 2022
0e07705
did more testing, seems to be working pretty robustly if I set target…
Aug 15, 2022
348d29d
implemented projections before calling jacobian in wrapper. not reall…
Aug 18, 2022
8190011
Merge branch 'master' into neo_wrapper
ddudt Sep 7, 2022
b562d8a
delete unnecessary GX files
ddudt Sep 7, 2022
359b977
basic structure for NEOWrapper
ddudt Sep 8, 2022
a47facf
add missing quantities to VMECIO.save()
ddudt Sep 12, 2022
8b7f6ac
bug fixes
ddudt Sep 12, 2022
6348c20
more bug fixes
ddudt Sep 12, 2022
892b4ec
merge constrain-toroidal-current-density
ddudt Sep 13, 2022
16b00a9
Merge branch 'master' into neo_wrapper
ddudt Oct 3, 2022
9fcda60
add NEOWrapper to objectives init, ns=8
ddudt Oct 3, 2022
af4e4f5
fixing things
ddudt Oct 3, 2022
c6d8149
suppress debugging print statements
ddudt Oct 3, 2022
bef260d
merge with master
ddudt Oct 5, 2022
ed175ac
add print_value_fmt to NEOWrapper
ddudt Oct 5, 2022
f40add8
wrapped objective jac speedup
ddudt Oct 5, 2022
0a12b0a
remove obsolete use_jit call
ddudt Oct 5, 2022
cd59b7c
add linear_objective reference
ddudt Oct 5, 2022
de0ea6f
remove extra_args kwarg
ddudt Oct 5, 2022
ffcc12c
fix use_jit issues with wrapped objective
ddudt Oct 5, 2022
c3c4396
remove debugging print statement
ddudt Oct 5, 2022
72fe3ec
revert wrapped objective jac speedup
ddudt Oct 6, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 49 additions & 18 deletions desc/derivatives.py
Original file line number Diff line number Diff line change
Expand Up @@ -339,13 +339,13 @@ def f(x):
tempargs = args[0 : self._argnum] + (x,) + args[self._argnum + 1 :]
return self._fun(*tempargs)

x = np.atleast_1d(args[self._argnum])
x = jnp.atleast_1d(args[self._argnum])
n = len(x)
fx = f(x)
h = np.maximum(1.0, np.abs(x)) * self.rel_step
ee = np.diag(h)
h = jnp.maximum(1.0, np.abs(x)) * self.rel_step
ee = jnp.diag(h)
dtype = fx.dtype
hess = np.outer(h, h)
hess = jnp.outer(h, h)

for i in range(n):
eei = ee[i, :]
Expand Down Expand Up @@ -383,13 +383,13 @@ def f(x):
tempargs = args[0 : self._argnum] + (x,) + args[self._argnum + 1 :]
return self._fun(*tempargs)

x0 = np.atleast_1d(args[self._argnum])
x0 = jnp.atleast_1d(args[self._argnum])
f0 = f(x0)
m = f0.size
n = x0.size
J = np.zeros((m, n))
h = np.maximum(1.0, np.abs(x0)) * self.rel_step
h_vecs = np.diag(np.atleast_1d(h))
J = jnp.zeros((m, n))
h = jnp.maximum(1.0, np.abs(x0)) * self.rel_step
h_vecs = jnp.diag(np.atleast_1d(h))
for i in range(n):
x1 = x0 - h_vecs[i]
x2 = x0 + h_vecs[i]
Expand All @@ -400,11 +400,46 @@ def f(x):
dfdx = df / dx
J = put(J.T, i, dfdx.flatten()).T
if m == 1:
J = np.ravel(J)
J = jnp.ravel(J)
return J

@classmethod
def compute_jvp(cls, fun, argnum, v, *args, **kwargs):
    """Compute df/dx * v via a one-sided (forward) finite difference.

    The argument and tangent pytrees are flattened to single 1D arrays,
    the function is evaluated at the base point and at the point shifted
    by ``h`` along the tangent direction, and the scaled difference is
    returned.

    Parameters
    ----------
    fun : callable
        Function to differentiate.
    argnum : int or tuple of int
        Kept for interface compatibility with other jvp methods; the
        difference here is taken over the full flattened argument tree,
        so this value is not used.
    v : array-like or tuple
        Tangent vector(s); a single vector is promoted to a 1-tuple.
    args : tuple
        Point at which to evaluate the directional derivative.
    kwargs : dict
        Can contain ``rel_step`` (float), the finite difference step
        size. Default 1e-3.

    Returns
    -------
    jvp : ndarray
        Approximation of (fun(x + h*v) - fun(x)) / h.
    """
    h = kwargs.get("rel_step", 1e-3)
    v = (v,) if not isinstance(v, tuple) else v
    # Coerce the final leaf of each tree to a jnp array so the flattened
    # trees concatenate cleanly.
    # NOTE(review): this assumes all earlier elements of v/args are
    # already array-like — confirm against callers.
    v = v[:-1] + (jnp.array(v[-1]),)
    args = args[:-1] + (jnp.array(args[-1]),)
    # Only the flat tangent is needed from v; the structure/indices used
    # for unflattening come from args.
    varr, _, _ = cls._tree2arr(v)
    xarr, xtreedef, xidx = cls._tree2arr(args)

    def fwrap(x):
        # Evaluate fun at a flat vector by unflattening it back to args.
        x = cls._arr2tree(x, xtreedef, xidx)
        return fun(*x)

    # Evaluate the base point first, then the perturbed point, so any
    # side effects of fun occur in the same order as before.
    fx = fwrap(xarr)
    return (fwrap(xarr + h * varr) - fx) / h

@classmethod
def _tree2arr(cls, tree):
    """Flatten a pytree into a single 1D array.

    Returns the concatenated leaves, the tree structure definition, and
    the interior split indices needed by ``_arr2tree`` to invert the
    flattening.
    """
    leaves, treedef = jax.tree_util.tree_flatten(tree)
    # Split points between consecutive leaves. These must be static
    # Python ints for jnp.split, hence plain numpy here.
    split_idx = np.cumsum([leaf.size for leaf in leaves])[:-1]
    return jnp.concatenate(leaves), treedef, split_idx

@classmethod
def _arr2tree(cls, arr, treedef, idx):
    """Unflatten a 1D array back into a pytree (inverse of ``_tree2arr``)."""
    pieces = jnp.split(arr, idx)
    return jax.tree_util.tree_unflatten(treedef, pieces)

@classmethod
def compute_jvp_old(cls, fun, argnum, v, *args, **kwargs):
"""Compute df/dx*v.

Parameters
Expand All @@ -426,20 +461,19 @@ def compute_jvp(cls, fun, argnum, v, *args, **kwargs):
"""
rel_step = kwargs.get("rel_step", 1e-3)

if np.isscalar(argnum):
if jnp.isscalar(argnum):
nargs = 1
argnum = (argnum,)
else:
nargs = len(argnum)
v = (v,) if not isinstance(v, tuple) else v

f = np.array(
f = jnp.array(
[
cls._compute_jvp_1arg(fun, argnum[i], v[i], *args, rel_step=rel_step)
for i in range(nargs)
]
)
return np.sum(f, axis=0)
return jnp.sum(f, axis=0)

@classmethod
def compute_jvp2(cls, fun, argnum1, argnum2, v1, v2, *args):
Expand Down Expand Up @@ -538,11 +572,8 @@ def _compute_jvp(self, v, *args):
def _compute_jvp_1arg(cls, fun, argnum, v, *args, **kwargs):
"""Compute a jvp wrt a single argument."""
rel_step = kwargs.get("rel_step", 1e-3)
normv = np.linalg.norm(v)
if normv != 0:
vh = v / normv
else:
vh = v
normv = jnp.linalg.norm(v)
vh = jnp.where(normv != 0, v / normv, v)
x = args[argnum]

def f(x):
Expand Down
Loading
Loading