JIT Error encountered when optimizing GammaC #1288

Open
dpanici opened this issue Oct 2, 2024 · 7 comments
Labels
bug Something isn't working

Comments

@dpanici
Collaborator

dpanici commented Oct 2, 2024

The error occurs when optimizing the GammaC objective on the gh/Gamma_c branch. It happens on the second optimization step and seems related to the JIT cache. It also only occurs when attempting an optimization at a resolution that has already been optimized at; changing the eq resolution between steps seems to avoid the issue, so I assume it is related to the caching.

MWE:

from desc import set_device

set_device("gpu")

import jax
import numpy as np

import desc.examples
from desc.continuation import solve_continuation_automatic
from desc.equilibrium import EquilibriaFamily, Equilibrium
from desc.geometry import FourierRZToroidalSurface
from desc.grid import ConcentricGrid, LinearGrid
from desc.io import load
from desc.objectives import (  # FixIota,
    AspectRatio,
    Elongation,
    FixBoundaryR,
    FixBoundaryZ,
    FixCurrent,
    FixPressure,
    FixPsi,
    ForceBalance,
    GammaC,
    GenericObjective,
    ObjectiveFunction,
    QuasisymmetryTwoTerm,
)
from desc.optimize import Optimizer
from desc.plotting import plot_boozer_surface
import pdb
from desc.backend import jnp
from desc.examples import get


def run_opt_step(k, eq):
    """Run a step of the optimization example."""
    # this step will only optimize boundary modes with |m|,|n| <= k
    # we create an ObjectiveFunction, in this case made up of multiple objectives
    # which will be combined in a least squares sense

    shape_grid = LinearGrid(
        M=int(eq.M), N=int(eq.N), rho=np.array([1.0]), NFP=eq.NFP, sym=True, axis=False
    )

    ntransits = 8

    zeta_field_line = np.linspace(0, 2 * np.pi * ntransits, 64 * ntransits)
    alpha = jnp.array([0.0])
    rho = jnp.linspace(0.85, 1.0, 2)
    # rho = np.linspace(0.85, 1.0, 2)
    flux_surface_grid = LinearGrid(
        rho=rho, M=eq.M_grid, N=eq.N_grid, sym=eq.sym, NFP=eq.NFP
    )

    objective = ObjectiveFunction(
        (
            GammaC(
                eq=eq,
                rho=rho,
                alpha=alpha,
                deriv_mode="fwd",
                batch=False,
                weight=1e3,
                Nemov=False,
            ),
            Elongation(eq=eq, grid=shape_grid, target=1),  # bounds=(0.5, 2.0), weight=1e3),
            GenericObjective(
                f="curvature_k2_rho",
                thing=eq,
                grid=shape_grid,
                bounds=(-75, 15),
                weight=2e3,
            ),
        ),
    )
    R_modes = np.vstack(
        (
            [0, 0, 0],
            eq.surface.R_basis.modes[
                np.max(np.abs(eq.surface.R_basis.modes), 1) > k, :
            ],
        )
    )
    Z_modes = eq.surface.Z_basis.modes[
        np.max(np.abs(eq.surface.Z_basis.modes), 1) > k, :
    ]
    constraints = (
        ForceBalance(
            eq,
            grid=ConcentricGrid(
                L=round(2 * eq.L),
                M=round(1.5 * eq.M),
                N=round(1.5 * eq.N),
                NFP=eq.NFP,
                sym=eq.sym,
            ),
        ),
        FixBoundaryR(eq=eq, modes=R_modes),
        FixBoundaryZ(eq=eq, modes=Z_modes),
        FixPressure(eq=eq),
        FixCurrent(eq=eq),
        FixPsi(eq=eq),
    )
    # this is the default optimizer, which re-solves the equilibrium at each step
    optimizer = Optimizer("proximal-lsq-exact")
    eq_new, result = optimizer.optimize(
        things=eq,
        objective=objective,
        constraints=constraints,
        maxiter=3,  # we don't need to solve to optimality at each multigrid step
        verbose=3,
        copy=True,  # don't modify original, return a new optimized copy
        options={
            # Sometimes the default initial trust radius is too big, allowing the
            # optimizer to take too large a step in a bad direction. If this happens,
            # we can manually specify a smaller starting radius. Each optimizer has a
            # number of different options that can be used to tune the performance.
            # See the documentation for more info.
            "initial_trust_ratio": 1e-2,
            "maxiter": 125,
            "ftol": 1e-3,
            "xtol": 1e-8,
        },
    )
    eq_new = eq_new[0]

    return eq_new

eq = get("ESTELL")
for k in np.arange(1, eq.M + 1, 1):
    if not eq.is_nested():
        print("NOT NESTED")
        assert eq.is_nested()
        break
    jax.clear_caches()
    eq = run_opt_step(k, eq)
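
The boundary-mode selection in `run_opt_step` can be checked on a small hand-built mode table. This is a sketch using NumPy only; the toy `modes` array (columns assumed to be `[l, m, n]`) stands in for `eq.surface.R_basis.modes` and is not taken from ESTELL.

```python
import numpy as np

# Toy mode table, columns assumed to be [l, m, n], with |m|, |n| <= 2.
modes = np.array([[0, m, n] for m in range(-2, 3) for n in range(-2, 3)])

k = 1
largest = np.max(np.abs(modes), 1)  # max(|l|, |m|, |n|) for each mode

# Same mask as in the MWE: modes above k are held fixed at this step.
fixed = modes[largest > k, :]
# The complement is what the optimizer is allowed to vary.
free = modes[largest <= k, :]

# As in the MWE, the R constraint additionally pins the [0, 0, 0] mode.
R_fixed = np.vstack(([0, 0, 0], fixed))

print(len(free), "free modes,", len(fixed), "fixed modes")
```

With `k = 1` only the modes with `|m| <= 1` and `|n| <= 1` stay free, which matches the comment at the top of `run_opt_step`.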

Error:

ValueError                                Traceback (most recent call last)
Cell In[1], line 137
    135     break
    136 jax.clear_caches()
--> 137 eq = run_opt_step(k, eq)

Cell In[1], line 107, in run_opt_step(k, eq)
    103 optimizer = Optimizer("proximal-lsq-exact")
    105 print("spot 1:", type(eq))
--> 107 eq_new, result = optimizer.optimize(
    108     things = eq,
    109     objective=objective,
    110     constraints=constraints,
    111     maxiter=3,  # we don't need to solve to optimality at each multigrid step
    112     verbose=3,
    113     copy=True,  # don't modify original, return a new optimized copy
    114     options={
    115         # Sometimes the default initial trust radius is too big, allowing the
    116         # optimizer to take too large a step in a bad direction. If this happens,
    117         # we can manually specify a smaller starting radius. Each optimizer has a
    118         # number of different options that can be used to tune the performance.
    119         # See the documentation for more info.
    120         "initial_trust_ratio": 1e-2,
    121         "maxiter": 125,
    122         "ftol": 1e-3,
    123         "xtol": 1e-8,
    124     },
    125 )
    126 eq_new = eq_new[0]
    128 return eq_new

File ~/DESC/desc/optimize/optimizer.py:311, in Optimizer.optimize(self, things, objective, constraints, ftol, xtol, gtol, ctol, x_scale, verbose, maxiter, options, copy)
    307     print("Using method: " + str(self.method))
    309 timer.start("Solution time")
--> 311 result = optimizers[method]["fun"](
    312     objective,
    313     nonlinear_constraint,
    314     x0,
    315     method,
    316     x_scale,
    317     verbose,
    318     stoptol,
    319     options,
    320 )
    322 if isinstance(objective, LinearConstraintProjection):
    323     # remove wrapper to get at underlying objective
    324     result["allx"] = [objective.recover(x) for x in result["allx"]]

File ~/DESC/desc/optimize/_desc_wrappers.py:270, in _optimize_desc_least_squares(objective, constraint, x0, method, x_scale, verbose, stoptol, options)
    267     options.setdefault("initial_trust_ratio", 0.1)
    268 options["max_nfev"] = stoptol["max_nfev"]
--> 270 result = lsqtr(
    271     objective.compute_scaled_error,
    272     x0=x0,
    273     jac=objective.jac_scaled_error,
    274     args=(objective.constants,),
    275     x_scale=x_scale,
    276     ftol=stoptol["ftol"],
    277     xtol=stoptol["xtol"],
    278     gtol=stoptol["gtol"],
    279     maxiter=stoptol["maxiter"],
    280     verbose=verbose,
    281     callback=None,
    282     options=options,
    283 )
    284 return result

File ~/DESC/desc/optimize/least_squares.py:176, in lsqtr(fun, x0, jac, bounds, args, x_scale, ftol, xtol, gtol, verbose, maxiter, callback, options)
    173 assert in_bounds(x, lb, ub), "x0 is infeasible"
    174 x = make_strictly_feasible(x, lb, ub)
--> 176 f = fun(x, *args)
    177 nfev += 1
    178 cost = 0.5 * jnp.dot(f, f)

File ~/DESC/desc/optimize/_constraint_wrappers.py:224, in LinearConstraintProjection.compute_scaled_error(self, x_reduced, constants)
    208 """Compute the objective function and apply weighting / bounds.
    209 
    210 Parameters
   (...)
    221 
    222 """
    223 x = self.recover(x_reduced)
--> 224 f = self._objective.compute_scaled_error(x, constants)
    225 return f

File ~/DESC/desc/optimize/_constraint_wrappers.py:843, in ProximalProjection.compute_scaled_error(self, x, constants)
    841 constants = setdefault(constants, self.constants)
    842 xopt, _ = self._update_equilibrium(x, store=False)
--> 843 return self._objective.compute_scaled_error(xopt, constants[0])

    [... skipping hidden 6 frame]

File ~/.conda/envs/desc-env-latest/lib/python3.11/site-packages/jax/_src/pjit.py:1339, in seen_attrs_get(fun, in_type)
   1337 cache = _seen_attrs.setdefault(fun.f, defaultdict(list))
   1338 assert fun.in_type is None or fun.in_type == in_type
-> 1339 return cache[(fun.transforms, fun.params, in_type)]

ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()
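
One way a plain dictionary lookup such as the `cache[...]` line in the last frame can raise this `ValueError` is when part of the key compares elementwise, so `==` returns an array rather than a bool. The sketch below reproduces that mechanism with NumPy only; `ArrayKey` is a hypothetical stand-in, and nothing here confirms which part of the real key (if any) behaves this way.

```python
import numpy as np


class ArrayKey:
    """Hypothetical hashable wrapper whose equality is elementwise."""

    def __init__(self, values):
        self.values = np.asarray(values)

    def __hash__(self):
        # Hash depends only on the shape, so same-shape keys collide.
        return hash(self.values.shape)

    def __eq__(self, other):
        # Returns an array, not a bool; the dict will call bool() on it.
        return self.values == other.values


cache = {}
first = ArrayKey([0.85, 1.0])
cache[(first,)] = "compiled"  # empty dict: no key comparison is needed

second = ArrayKey([0.85, 1.0])  # distinct object with the same hash
try:
    cache[(second,)]  # equal hashes force an == comparison of the keys
    outcome = "hit"
except ValueError as err:
    outcome = str(err)

print(outcome)
```

The first insert succeeds because an empty dict never compares keys; the failure needs an earlier entry with the same hash, which would be consistent with the error appearing only when a resolution is repeated.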
@dpanici dpanici added the bug Something isn't working label Oct 2, 2024
@unalmis
Collaborator

unalmis commented Oct 2, 2024

Here are the steps that should be taken to debug:

  1. Has this error ever occurred on the Gamma_c branch, where this quantity is developed? Which JAX version?
  2. If yes to 1, then check which recent changes to master, perhaps the Jacobian changes, have caused this.

Unjitting the compute function tends to help for debugging.
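
A minimal sketch of unjitting with plain JAX, assuming only the public `jax.jit` and `jax.disable_jit` APIs; the function `f` and the `calls` list are illustrative and not part of DESC. Under `jax.disable_jit()` the Python body runs on every call, so prints and breakpoints inside it become usable.

```python
import jax
import jax.numpy as jnp

calls = []  # records every time the Python body of f actually executes


@jax.jit
def f(x):
    calls.append(1)  # Python side effect: under jit this runs only while tracing
    return jnp.sum(x**2)


x = jnp.arange(3.0)

f(x)
f(x)
jitted_calls = len(calls)  # body ran during the first (tracing) call only

with jax.disable_jit():
    f(x)
    f(x)
eager_calls = len(calls) - jitted_calls  # body ran on both calls

print("jitted:", jitted_calls, "eager:", eager_calls)
```

The same idea applies to a larger compute function: with JIT disabled, a failing comparison surfaces at the Python line that triggers it rather than inside the compilation cache.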

@f0uriest
Member

f0uriest commented Oct 3, 2024

It seems to be unique to the bounce integral objectives: if I comment out GammaC it works fine, and if I change to EffectiveRipple it still happens.

Other things:

  • the bounce integral objectives take a loooong time to compile, even at very low resolution. Like a few minutes, compared to a few seconds for the other objectives
  • when either bounce integral objective is included it seems to make the optimizer stall out (rejects a lot of steps and exits)

These probably aren't related to the error above, but might be another source of concern

@f0uriest
Member

f0uriest commented Oct 3, 2024

with use_jit=False (and commenting out the jit in constraint wrappers) I'm unable to reproduce

@unalmis
Collaborator

unalmis commented Oct 3, 2024

OK, I ran optimizations leading up to ISHW, so commit 5cd7ebd should not have this issue. State of the branch at that commit: https://github.com/PlasmaControl/DESC/tree/5cd7ebde563258f754a0401d9da6aa143bc3376f

@unalmis
Collaborator

unalmis commented Oct 3, 2024

> with use_jit=False (and commenting out the jit in constraint wrappers) I'm unable to reproduce

There is also a jit call wrapping the compute function in _compute. When you could no longer reproduce, was this JIT call still enabled?

> the bounce integral objectives take a loooong time to compile, even at very low resolution. Like a few minutes, compared to a few seconds for the other objectives

Aren't these compiled once? The BallooningStability objective requires less resolution than bounce integrals along a field line, but it still does a coordinate mapping inside the objective and builds transforms on the resulting grid. How does compilation time / optimization stalling compare when "low resolution" is typical resolution for BallooningStability?

> the optimizer stall out (rejects a lot of steps and exits)

Can memory usage affect this? Is this forward or reverse mode? I ran forward-mode optimizations before ISHW and did not see the optimizer exit.

@dpanici
Collaborator Author

dpanici commented Oct 3, 2024

5cd7ebd...Gamma_c
The diff page between the commit Kaya mentioned and the current Gamma_c branch.

@dpanici
Collaborator Author

dpanici commented Oct 3, 2024

I won't have time to debug tonight or tomorrow, but will look more this weekend. Thanks for starting to look into this so quickly, though. On Gamma_c I see the same bug for both the GammaC objective and EffectiveRipple.
