Error in map_coordinates with jax==0.4.34 #1291

Open
dpanici opened this issue Oct 4, 2024 · 3 comments · May be fixed by #1293
Labels
bug Something isn't working


dpanici (Collaborator) commented Oct 4, 2024

Calling, for example, `plot_comparison` in our basic equilibrium notebook yields a JAX error on jax 0.4.34.

The error is related to `_jacobi_jvp`, where we did something slightly hacky to avoid escaped tracer values; maybe JAX no longer accepts that.

{
	"name": "TypeError",
	"message": "Called multiply with a float0 array. float0s do not support any operations by design because they are not compatible with non-trivial vector spaces. No implicit dtype conversion is done. You can use np.zeros_like(arr, dtype=np.float) to cast a float0 array to a regular zeros array. 
If you didn't expect to get a float0 you might have accidentally taken a gradient with respect to an integer argument.",
	"stack": "---------------------------------------------------------------------------
JaxStackTraceBeforeTransformation         Traceback (most recent call last)
File <frozen runpy>:198, in _run_module_as_main()

File <frozen runpy>:88, in _run_code()

File ~/miniconda3/envs/desc-env-10-24/lib/python3.12/site-packages/ipykernel_launcher.py:18
     16 from ipykernel import kernelapp as app
---> 18 app.launch_new_instance()

File ~/miniconda3/envs/desc-env-10-24/lib/python3.12/site-packages/traitlets/config/application.py:1075, in launch_instance()
   1074 app.initialize(argv)
-> 1075 app.start()

File ~/miniconda3/envs/desc-env-10-24/lib/python3.12/site-packages/ipykernel/kernelapp.py:739, in start()
    738 try:
--> 739     self.io_loop.start()
    740 except KeyboardInterrupt:

File ~/miniconda3/envs/desc-env-10-24/lib/python3.12/site-packages/tornado/platform/asyncio.py:205, in start()
    204 def start(self) -> None:
--> 205     self.asyncio_loop.run_forever()

File ~/miniconda3/envs/desc-env-10-24/lib/python3.12/asyncio/base_events.py:618, in run_forever()
    617 while True:
--> 618     self._run_once()
    619     if self._stopping:

File ~/miniconda3/envs/desc-env-10-24/lib/python3.12/asyncio/base_events.py:1951, in _run_once()
   1950     else:
-> 1951         handle._run()
   1952 handle = None

File ~/miniconda3/envs/desc-env-10-24/lib/python3.12/asyncio/events.py:84, in _run()
     83 try:
---> 84     self._context.run(self._callback, *self._args)
     85 except (SystemExit, KeyboardInterrupt):

File ~/miniconda3/envs/desc-env-10-24/lib/python3.12/site-packages/ipykernel/kernelbase.py:545, in dispatch_queue()
    544 try:
--> 545     await self.process_one()
    546 except Exception:

File ~/miniconda3/envs/desc-env-10-24/lib/python3.12/site-packages/ipykernel/kernelbase.py:534, in process_one()
    533         return
--> 534 await dispatch(*args)

File ~/miniconda3/envs/desc-env-10-24/lib/python3.12/site-packages/ipykernel/kernelbase.py:437, in dispatch_shell()
    436     if inspect.isawaitable(result):
--> 437         await result
    438 except Exception:

File ~/miniconda3/envs/desc-env-10-24/lib/python3.12/site-packages/ipykernel/ipkernel.py:362, in execute_request()
    361 self._associate_new_top_level_threads_with(parent_header)
--> 362 await super().execute_request(stream, ident, parent)

File ~/miniconda3/envs/desc-env-10-24/lib/python3.12/site-packages/ipykernel/kernelbase.py:778, in execute_request()
    777 if inspect.isawaitable(reply_content):
--> 778     reply_content = await reply_content
    780 # Flush output before sending the reply.

File ~/miniconda3/envs/desc-env-10-24/lib/python3.12/site-packages/ipykernel/ipkernel.py:449, in do_execute()
    448 if accepts_params["cell_id"]:
--> 449     res = shell.run_cell(
    450         code,
    451         store_history=store_history,
    452         silent=silent,
    453         cell_id=cell_id,
    454     )
    455 else:

File ~/miniconda3/envs/desc-env-10-24/lib/python3.12/site-packages/ipykernel/zmqshell.py:549, in run_cell()
    548 self._last_traceback = None
--> 549 return super().run_cell(*args, **kwargs)

File ~/miniconda3/envs/desc-env-10-24/lib/python3.12/site-packages/IPython/core/interactiveshell.py:3075, in run_cell()
   3074 try:
-> 3075     result = self._run_cell(
   3076         raw_cell, store_history, silent, shell_futures, cell_id
   3077     )
   3078 finally:

File ~/miniconda3/envs/desc-env-10-24/lib/python3.12/site-packages/IPython/core/interactiveshell.py:3130, in _run_cell()
   3129 try:
-> 3130     result = runner(coro)
   3131 except BaseException as e:

File ~/miniconda3/envs/desc-env-10-24/lib/python3.12/site-packages/IPython/core/async_helpers.py:128, in _pseudo_sync_runner()
    127 try:
--> 128     coro.send(None)
    129 except StopIteration as exc:

File ~/miniconda3/envs/desc-env-10-24/lib/python3.12/site-packages/IPython/core/interactiveshell.py:3334, in run_cell_async()
   3331 interactivity = "none" if silent else self.ast_node_interactivity
-> 3334 has_raised = await self.run_ast_nodes(code_ast.body, cell_name,
   3335        interactivity=interactivity, compiler=compiler, result=result)
   3337 self.last_execution_succeeded = not has_raised

File ~/miniconda3/envs/desc-env-10-24/lib/python3.12/site-packages/IPython/core/interactiveshell.py:3517, in run_ast_nodes()
   3516     asy = compare(code)
-> 3517 if await self.run_code(code, result, async_=asy):
   3518     return True

File ~/miniconda3/envs/desc-env-10-24/lib/python3.12/site-packages/IPython/core/interactiveshell.py:3577, in run_code()
   3576     else:
-> 3577         exec(code_obj, self.user_global_ns, self.user_ns)
   3578 finally:
   3579     # Reset our crash handler in place

Cell In[6], line 3
      2 print("Number of equilibria in the EquilibriaFamily:", len(eq_fam))
----> 3 fig, ax = plot_comparison(
      4     eqs=[eq_fam[1], eq_fam[3], eq_fam[-1]],
      5     labels=[
      6         "Axisymmetric w/o pressure",
      7         "Axisymmetric w/ pressure",
      8         "Nonaxisymmetric w/ pressure",
      9     ],
     10 )

File ~/Research/DESC/desc/plotting.py:2388, in plot_comparison()
   2380 for i, eq in enumerate(eqs):
   2381     fig, ax, _plot_data = plot_surfaces(
   2382         eq,
   2383         rho,
   2384         theta,
   2385         phi,
   2386         ax,
   2387         theta_color=color[i % len(color)],
-> 2388         theta_ls=ls[i % len(ls)],
   2389         theta_lw=lw[i % len(lw)],
   2390         rho_color=color[i % len(color)],
   2391         rho_ls=ls[i % len(ls)],
   2392         rho_lw=lw[i % len(lw)],
   2393         lcfs_color=color[i % len(color)],
   2394         lcfs_ls=ls[i % len(ls)],
   2395         lcfs_lw=lw[i % len(lw)],
   2396         axis_color=color[i % len(color)],
   2397         axis_alpha=0,
   2398         axis_marker="o",
   2399         axis_size=0,
   2400         label=labels[i % len(labels)],
   2401         title_fontsize=title_fontsize,
   2402         xlabel_fontsize=xlabel_fontsize,
   2403         ylabel_fontsize=ylabel_fontsize,
   2404         return_data=True,
   2405     )
   2406     for key in _plot_data.keys():

File ~/Research/DESC/desc/plotting.py:1685, in plot_surfaces()
   1676     tnr, tnt, tnz = t_grid.num_rho, t_grid.num_theta, t_grid.num_zeta
   1677     v_grid = Grid(
   1678         map_coordinates(
   1679             eq,
   1680             t_grid.nodes,
   1681             ["rho", "theta_PEST", "phi"],
   1682             ["rho", "theta", "zeta"],
   1683             period=(np.inf, 2 * np.pi, 2 * np.pi),
   1684             guess=t_grid.nodes,
-> 1685         ),
   1686         sort=False,
   1687     )
   1688 rows = np.floor(np.sqrt(nphi)).astype(int)

File ~/Research/DESC/desc/equilibrium/coords.py:218, in map_coordinates()
    215 # See description here
    216 # https://github.com/PlasmaControl/DESC/pull/504#discussion_r1194172532
    217 # except we make sure properly handle periodic coordinates.
--> 218 yk, (res, niter) = vecroot(yk, coords)
    220 out = compute(yk, outbasis)

File ~/Research/DESC/desc/equilibrium/coords.py:203, in map_coordinates.<locals>.<lambda>()
    199 yk = fixup(yk)
    201 vecroot = jit(
    202     vmap(
--> 203         lambda x0, *p: root(
    204             residual,
    205             x0,
    206             jac=jac,
    207             args=p,
    208             fixup=fixup,
    209             tol=tol,
    210             maxiter=maxiter,
    211             **kwargs,
    212         )
    213     )
    214 )
    215 # See description here
    216 # https://github.com/PlasmaControl/DESC/pull/504#discussion_r1194172532
    217 # except we make sure properly handle periodic coordinates.

File ~/Research/DESC/desc/backend.py:398, in root()
    396     return _lstsq(A, jnp.atleast_1d(y))
--> 398 x, (res, niter) = jax.lax.custom_root(
    399     res, x0, solve, tangent_solve, has_aux=True
    400 )
    401 return x, (safenorm(res), niter)

File ~/Research/DESC/desc/backend.py:344, in root.<locals>.<lambda>()
    342     jac2 = lambda x: jnp.atleast_2d(jac(x, *args))
--> 344 res = lambda x: jnp.atleast_1d(fun(x, *args)).flatten()
    346 # want to use least squares for rank-defficient systems, but
    347 # jnp.linalg.lstsq doesn't have JVP defined and is slower than needed
    348 # so we use the normal equations with regularized cholesky

File ~/Research/DESC/desc/equilibrium/coords.py:174, in residual()
    172 @jit
    173 def residual(y, coords):
--> 174     xk = compute(y, inbasis)
    175     return _fixup_residual(xk - coords, period)

File ~/Research/DESC/desc/equilibrium/coords.py:167, in compute()
    166     data["iota_rr"] = profiles["iota"].compute(grid, dr=2, params=params["i_l"])
--> 167 transforms = get_transforms(basis, eq, grid, jitable=True)
    168 data = compute_fun(eq, basis, params, transforms, profiles, data)

File ~/Research/DESC/desc/backend.py:112, in wrapper()
    111 with jax.default_device(jax.devices("cpu")[0]):
--> 112     return func(*args, **kwargs)

File ~/Research/DESC/desc/compute/utils.py:631, in get_transforms()
    630     if hasattr(t, "build"):
--> 631         t.build()
    633 return transforms

File ~/Research/DESC/desc/transform.py:385, in build()
    384     for d in self.derivatives:
--> 385         self.matrices["direct1"][d[0]][d[1]][d[2]] = self.basis.evaluate(
    386             self.grid.nodes, d, unique=False
    387         )
    389 if self.method in ["fft", "direct2"]:

File ~/Research/DESC/desc/basis.py:1134, in evaluate()
   1132     n = n[nidx]
-> 1134 radial = zernike_radial(r[:, np.newaxis], lm[:, 0], lm[:, 1], dr=derivatives[0])
   1135 poloidal = fourier(t[:, np.newaxis], m, dt=derivatives[1])

File ~/Research/DESC/desc/basis.py:1533, in zernike_radial()
   1532 if dr == 0:
-> 1533     out = r**m * _jacobi(n, alpha, beta, jacobi_arg, 0)
   1534 elif dr == 1:

JaxStackTraceBeforeTransformation: TypeError: Called multiply with a float0 array. float0s do not support any operations by design because they are not compatible with non-trivial vector spaces. No implicit dtype conversion is done. You can use np.zeros_like(arr, dtype=np.float) to cast a float0 array to a regular zeros array. 
If you didn't expect to get a float0 you might have accidentally taken a gradient with respect to an integer argument.

The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.

--------------------

The above exception was the direct cause of the following exception:

TypeError                                 Traceback (most recent call last)
Cell In[6], line 3
      1 eq_fam = desc.io.load("input.HELIOTRON_output.h5")
      2 print("Number of equilibria in the EquilibriaFamily:", len(eq_fam))
----> 3 fig, ax = plot_comparison(
      4     eqs=[eq_fam[1], eq_fam[3], eq_fam[-1]],
      5     labels=[
      6         "Axisymmetric w/o pressure",
      7         "Axisymmetric w/ pressure",
      8         "Nonaxisymmetric w/ pressure",
      9     ],
     10 )

File ~/Research/DESC/desc/plotting.py:2388, in plot_comparison(eqs, rho, theta, phi, ax, cmap, color, lw, ls, labels, return_data, **kwargs)
   2386     plot_data[string] = []
   2387 for i, eq in enumerate(eqs):
-> 2388     fig, ax, _plot_data = plot_surfaces(
   2389         eq,
   2390         rho,
   2391         theta,
   2392         phi,
   2393         ax,
   2394         theta_color=color[i % len(color)],
   2395         theta_ls=ls[i % len(ls)],
   2396         theta_lw=lw[i % len(lw)],
   2397         rho_color=color[i % len(color)],
   2398         rho_ls=ls[i % len(ls)],
   2399         rho_lw=lw[i % len(lw)],
   2400         lcfs_color=color[i % len(color)],
   2401         lcfs_ls=ls[i % len(ls)],
   2402         lcfs_lw=lw[i % len(lw)],
   2403         axis_color=color[i % len(color)],
   2404         axis_alpha=0,
   2405         axis_marker="o",
   2406         axis_size=0,
   2407         label=labels[i % len(labels)],
   2408         title_fontsize=title_fontsize,
   2409         xlabel_fontsize=xlabel_fontsize,
   2410         ylabel_fontsize=ylabel_fontsize,
   2411         return_data=True,
   2412     )
   2413     for key in _plot_data.keys():
   2414         plot_data[key].append(_plot_data[key])

File ~/Research/DESC/desc/plotting.py:1685, in plot_surfaces(eq, rho, theta, phi, ax, return_data, **kwargs)
   1682     t_grid = _get_grid(**grid_kwargs)
   1683     tnr, tnt, tnz = t_grid.num_rho, t_grid.num_theta, t_grid.num_zeta
   1684     v_grid = Grid(
-> 1685         map_coordinates(
   1686             eq,
   1687             t_grid.nodes,
   1688             ["rho", "theta_PEST", "phi"],
   1689             ["rho", "theta", "zeta"],
   1690             period=(np.inf, 2 * np.pi, 2 * np.pi),
   1691             guess=t_grid.nodes,
   1692         ),
   1693         sort=False,
   1694     )
   1695 rows = np.floor(np.sqrt(nphi)).astype(int)
   1696 cols = np.ceil(nphi / rows).astype(int)

File ~/Research/DESC/desc/equilibrium/coords.py:218, in map_coordinates(eq, coords, inbasis, outbasis, guess, params, period, tol, maxiter, full_output, **kwargs)
    201 vecroot = jit(
    202     vmap(
    203         lambda x0, *p: root(
   (...)
    213     )
    214 )
    215 # See description here
    216 # https://github.com/PlasmaControl/DESC/pull/504#discussion_r1194172532
    217 # except we make sure properly handle periodic coordinates.
--> 218 yk, (res, niter) = vecroot(yk, coords)
    220 out = compute(yk, outbasis)
    221 if full_output:

    [... skipping hidden 14 frame]

File ~/Research/DESC/desc/equilibrium/coords.py:203, in map_coordinates.<locals>.<lambda>(x0, *p)
    197     yk = _initial_guess_nn_search(coords, inbasis, eq, period, compute)
    199 yk = fixup(yk)
    201 vecroot = jit(
    202     vmap(
--> 203         lambda x0, *p: root(
    204             residual,
    205             x0,
    206             jac=jac,
    207             args=p,
    208             fixup=fixup,
    209             tol=tol,
    210             maxiter=maxiter,
    211             **kwargs,
    212         )
    213     )
    214 )
    215 # See description here
    216 # https://github.com/PlasmaControl/DESC/pull/504#discussion_r1194172532
    217 # except we make sure properly handle periodic coordinates.
    218 yk, (res, niter) = vecroot(yk, coords)

File ~/Research/DESC/desc/backend.py:398, in root(fun, x0, jac, args, tol, maxiter, maxiter_ls, alpha, fixup)
    395     A = jnp.atleast_2d(jax.jacfwd(g)(y))
    396     return _lstsq(A, jnp.atleast_1d(y))
--> 398 x, (res, niter) = jax.lax.custom_root(
    399     res, x0, solve, tangent_solve, has_aux=True
    400 )
    401 return x, (safenorm(res), niter)

    [... skipping hidden 14 frame]

File ~/Research/DESC/desc/backend.py:344, in root.<locals>.<lambda>(x)
    341 else:
    342     jac2 = lambda x: jnp.atleast_2d(jac(x, *args))
--> 344 res = lambda x: jnp.atleast_1d(fun(x, *args)).flatten()
    346 # want to use least squares for rank-defficient systems, but
    347 # jnp.linalg.lstsq doesn't have JVP defined and is slower than needed
    348 # so we use the normal equations with regularized cholesky
    349 def _lstsq(a, b):

    [... skipping hidden 47 frame]

File ~/Research/DESC/desc/basis.py:1816, in _jacobi_jvp(x, xdot)
   1811 df = _jacobi(n, alpha, beta, x, dx + 1)
   1812 # in theory n, alpha, beta, dx aren't differentiable (they're integers)
   1813 # but marking them as non-diff argnums seems to cause escaped tracer values.
   1814 # probably a more elegant fix, but just setting those derivatives to zero seems
   1815 # to work fine.
-> 1816 return f, df * xdot + 0 * ndot + 0 * alphadot + 0 * betadot + 0 * dxdot

File ~/miniconda3/envs/desc-env-10-24/lib/python3.12/site-packages/jax/_src/numpy/array_methods.py:1036, in _forward_operator_to_aval.<locals>.op(self, *args)
   1035 def op(self, *args):
-> 1036   return getattr(self.aval, f"_{name}")(self, *args)

File ~/miniconda3/envs/desc-env-10-24/lib/python3.12/site-packages/jax/_src/numpy/array_methods.py:573, in _defer_to_unrecognized_arg.<locals>.deferring_binary_op(self, other)
    571 args = (other, self) if swap else (self, other)
    572 if isinstance(other, _accepted_binop_types):
--> 573   return binary_op(*args)
    574 # Note: don't use isinstance here, because we don't want to raise for
    575 # subclasses, e.g. NamedTuple objects that may override operators.
    576 if type(other) in _rejected_binop_types:

File ~/miniconda3/envs/desc-env-10-24/lib/python3.12/site-packages/jax/_src/numpy/ufunc_api.py:177, in ufunc.__call__(self, out, where, *args)
    175   raise NotImplementedError(f"where argument of {self}")
    176 call = self.__static_props['call'] or self._call_vectorized
--> 177 return call(*args)

    [... skipping hidden 11 frame]

File ~/miniconda3/envs/desc-env-10-24/lib/python3.12/site-packages/jax/_src/numpy/ufuncs.py:1142, in _multiply(x, y)
   1115 @partial(jit, inline=True)
   1116 def _multiply(x: ArrayLike, y: ArrayLike, /) -> Array:
   1117   """Multiply two arrays element-wise.
   1118 
   1119   JAX implementation of :obj:`numpy.multiply`. This is a universal function,
   (...)
   1140     Array([ 0, 10, 20, 30], dtype=int32)
   1141   """
-> 1142   x, y = promote_args("multiply", x, y)
   1143   return lax.mul(x, y) if x.dtype != bool else lax.bitwise_and(x, y)

File ~/miniconda3/envs/desc-env-10-24/lib/python3.12/site-packages/jax/_src/numpy/util.py:355, in promote_args(fun_name, *args)
    353 """Convenience function to apply Numpy argument shape and dtype promotion."""
    354 check_arraylike(fun_name, *args)
--> 355 _check_no_float0s(fun_name, *args)
    356 check_for_prngkeys(fun_name, *args)
    357 return promote_shapes(fun_name, *promote_dtypes(*args))

File ~/miniconda3/envs/desc-env-10-24/lib/python3.12/site-packages/jax/_src/numpy/util.py:326, in check_no_float0s(fun_name, *args)
    324 """Check if none of the args have dtype float0."""
    325 if any(dtypes.dtype(arg) == dtypes.float0 for arg in args):
--> 326   raise TypeError(
    327       f"Called {fun_name} with a float0 array. "
    328       "float0s do not support any operations by design because they "
    329       "are not compatible with non-trivial vector spaces. No implicit dtype "
    330       "conversion is done. You can use np.zeros_like(arr, dtype=np.float) "
    331       "to cast a float0 array to a regular zeros array. \n"
    332       "If you didn't expect to get a float0 you might have accidentally "
    333       "taken a gradient with respect to an integer argument.")

TypeError: Called multiply with a float0 array. float0s do not support any operations by design because they are not compatible with non-trivial vector spaces. No implicit dtype conversion is done. You can use np.zeros_like(arr, dtype=np.float) to cast a float0 array to a regular zeros array. 
If you didn't expect to get a float0 you might have accidentally taken a gradient with respect to an integer argument."
}
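The error message suggests `np.zeros_like(arr, dtype=np.float)`, but the `np.float` alias was removed in NumPy 1.24, so that exact spelling now raises `AttributeError`. A minimal sketch, using a hypothetical function `f`, of where a float0 array comes from (`jax.grad(..., allow_int=True)` taken with respect to an integer argument) and how to cast it with a dtype that still exists:

```python
import jax
import numpy as np


def f(n, x):
    # Hypothetical function of an integer argument n and a float argument x.
    return x * n


# allow_int=True lets jax.grad differentiate with respect to the integer n.
# The result is a zeros array of the trivial dtype float0, which supports no arithmetic.
g_n = jax.grad(f, argnums=0, allow_int=True)(3, 2.0)

# Cast to an ordinary float zeros array before using it in arithmetic.
# np.float no longer exists in NumPy >= 1.24, so use np.float64 (or the builtin float).
g_n_zeros = np.zeros_like(g_n, dtype=np.float64)
scaled = 2.0 * g_n_zeros
```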
dpanici added the bug label Oct 4, 2024
dpanici added a commit that referenced this issue Oct 4, 2024
Pin jax to `<0.4.34` until we fix bug listed in #1291
YigitElma (Collaborator)

There are some other bugs related to JAX 0.4.34, for example in equinox: `AttributeError: jax.core.pp_eqn_rules was removed in JAX v0.4.34.`

YigitElma linked a pull request on Oct 4, 2024 that will close this issue
YigitElma (Collaborator)


OK, that equinox error turned out to be due to an old version of equinox (equinox==0.11.3).

ddudt (Collaborator) commented Oct 14, 2024

We think the issue is caused by returning metadata from the root solve, which can include non-differentiable data types (e.g. integers).
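To illustrate this hypothesis (not the DESC root solver; `solve_with_metadata` is a hypothetical stand-in): when a differentiated function also returns integer metadata such as an iteration count, the tangent JAX produces for that output has dtype `float0`, so downstream arithmetic on it fails in the same way `0 * ndot` did.

```python
import jax
import jax.numpy as jnp


def solve_with_metadata(x):
    # Hypothetical stand-in for a root solve that returns a float solution
    # together with integer metadata (here a fixed "iteration count").
    root = jnp.sqrt(x)
    niter = jnp.int32(5)
    return root, niter


(root, niter), (root_dot, niter_dot) = jax.jvp(solve_with_metadata, (4.0,), (1.0,))

# root_dot is an ordinary float tangent: d/dx sqrt(x) at x = 4 is 0.25.
# niter_dot has dtype float0, so any arithmetic on it (e.g. 0 * niter_dot) fails.
```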
