map_coordinates
jax==0.4.34
Calling, for example, `plot_comparison` in our basic equilibrium notebook yields a JAX error on 0.4.34 related to `_jacobi_jvp`.
It looks like we did something slightly hacky there to avoid escaped tracer values, and maybe JAX no longer accepts that.
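For context, here is a minimal sketch of one way a custom JVP rule can avoid arithmetic on tangents of integer arguments. It is not the DESC implementation: `int_power` and `int_power_jvp` are hypothetical names invented for illustration, and the sketch assumes the integer tangent is never needed, so the rule unpacks it and ignores it instead of multiplying it by zero.

```python
import jax
import jax.numpy as jnp


@jax.custom_jvp
def int_power(x, n):
    # Hypothetical toy function: n is integer-valued, only x is differentiable.
    return x**n


@int_power.defjvp
def int_power_jvp(primals, tangents):
    x, n = primals
    # The tangent of an integer argument may arrive with dtype float0.
    # It is unpacked here but never used in any arithmetic.
    xdot, _ndot = tangents
    f = int_power(x, n)
    df = n * x ** (n - 1)
    return f, df * xdot


# Forward-mode derivative with respect to x only, at x = 2.0 and n = 3.
val, tan = jax.jvp(lambda x: int_power(x, 3), (2.0,), (1.0,))
print(val, tan)
```

Whether this pattern also avoids the escaped-tracer problem mentioned above would need to be checked against the actual `_jacobi` code.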
{ "name": "TypeError", "message": "Called multiply with a float0 array. float0s do not support any operations by design because they are not compatible with non-trivial vector spaces. No implicit dtype conversion is done. You can use np.zeros_like(arr, dtype=np.float) to cast a float0 array to a regular zeros array. If you didn't expect to get a float0 you might have accidentally taken a gradient with respect to an integer argument.", "stack": "--------------------------------------------------------------------------- JaxStackTraceBeforeTransformation Traceback (most recent call last) File <frozen runpy>:198, in _run_module_as_main() File <frozen runpy>:88, in _run_code() File ~/miniconda3/envs/desc-env-10-24/lib/python3.12/site-packages/ipykernel_launcher.py:18 16 from ipykernel import kernelapp as app ---> 18 app.launch_new_instance() File ~/miniconda3/envs/desc-env-10-24/lib/python3.12/site-packages/traitlets/config/application.py:1075, in launch_instance() 1074 app.initialize(argv) -> 1075 app.start() File ~/miniconda3/envs/desc-env-10-24/lib/python3.12/site-packages/ipykernel/kernelapp.py:739, in start() 738 try: --> 739 self.io_loop.start() 740 except KeyboardInterrupt: File ~/miniconda3/envs/desc-env-10-24/lib/python3.12/site-packages/tornado/platform/asyncio.py:205, in start() 204 def start(self) -> None: --> 205 self.asyncio_loop.run_forever() File ~/miniconda3/envs/desc-env-10-24/lib/python3.12/asyncio/base_events.py:618, in run_forever() 617 while True: --> 618 self._run_once() 619 if self._stopping: File ~/miniconda3/envs/desc-env-10-24/lib/python3.12/asyncio/base_events.py:1951, in _run_once() 1950 else: -> 1951 handle._run() 1952 handle = None File ~/miniconda3/envs/desc-env-10-24/lib/python3.12/asyncio/events.py:84, in _run() 83 try: ---> 84 self._context.run(self._callback, *self._args) 85 except (SystemExit, KeyboardInterrupt): File ~/miniconda3/envs/desc-env-10-24/lib/python3.12/site-packages/ipykernel/kernelbase.py:545, in dispatch_queue() 544 
try: --> 545 await self.process_one() 546 except Exception: File ~/miniconda3/envs/desc-env-10-24/lib/python3.12/site-packages/ipykernel/kernelbase.py:534, in process_one() 533 return --> 534 await dispatch(*args) File ~/miniconda3/envs/desc-env-10-24/lib/python3.12/site-packages/ipykernel/kernelbase.py:437, in dispatch_shell() 436 if inspect.isawaitable(result): --> 437 await result 438 except Exception: File ~/miniconda3/envs/desc-env-10-24/lib/python3.12/site-packages/ipykernel/ipkernel.py:362, in execute_request() 361 self._associate_new_top_level_threads_with(parent_header) --> 362 await super().execute_request(stream, ident, parent) File ~/miniconda3/envs/desc-env-10-24/lib/python3.12/site-packages/ipykernel/kernelbase.py:778, in execute_request() 777 if inspect.isawaitable(reply_content): --> 778 reply_content = await reply_content 780 # Flush output before sending the reply. File ~/miniconda3/envs/desc-env-10-24/lib/python3.12/site-packages/ipykernel/ipkernel.py:449, in do_execute() 448 if accepts_params[\"cell_id\"]: --> 449 res = shell.run_cell( 450 code, 451 store_history=store_history, 452 silent=silent, 453 cell_id=cell_id, 454 ) 455 else: File ~/miniconda3/envs/desc-env-10-24/lib/python3.12/site-packages/ipykernel/zmqshell.py:549, in run_cell() 548 self._last_traceback = None --> 549 return super().run_cell(*args, **kwargs) File ~/miniconda3/envs/desc-env-10-24/lib/python3.12/site-packages/IPython/core/interactiveshell.py:3075, in run_cell() 3074 try: -> 3075 result = self._run_cell( 3076 raw_cell, store_history, silent, shell_futures, cell_id 3077 ) 3078 finally: File ~/miniconda3/envs/desc-env-10-24/lib/python3.12/site-packages/IPython/core/interactiveshell.py:3130, in _run_cell() 3129 try: -> 3130 result = runner(coro) 3131 except BaseException as e: File ~/miniconda3/envs/desc-env-10-24/lib/python3.12/site-packages/IPython/core/async_helpers.py:128, in _pseudo_sync_runner() 127 try: --> 128 coro.send(None) 129 except StopIteration as exc: File 
~/miniconda3/envs/desc-env-10-24/lib/python3.12/site-packages/IPython/core/interactiveshell.py:3334, in run_cell_async() 3331 interactivity = \"none\" if silent else self.ast_node_interactivity -> 3334 has_raised = await self.run_ast_nodes(code_ast.body, cell_name, 3335 interactivity=interactivity, compiler=compiler, result=result) 3337 self.last_execution_succeeded = not has_raised File ~/miniconda3/envs/desc-env-10-24/lib/python3.12/site-packages/IPython/core/interactiveshell.py:3517, in run_ast_nodes() 3516 asy = compare(code) -> 3517 if await self.run_code(code, result, async_=asy): 3518 return True File ~/miniconda3/envs/desc-env-10-24/lib/python3.12/site-packages/IPython/core/interactiveshell.py:3577, in run_code() 3576 else: -> 3577 exec(code_obj, self.user_global_ns, self.user_ns) 3578 finally: 3579 # Reset our crash handler in place Cell In[6], line 3 2 print(\"Number of equilibria in the EquilibriaFamily:\", len(eq_fam)) ----> 3 fig, ax = plot_comparison( 4 eqs=[eq_fam[1], eq_fam[3], eq_fam[-1]], 5 labels=[ 6 \"Axisymmetric w/o pressure\", 7 \"Axisymmetric w/ pressure\", 8 \"Nonaxisymmetric w/ pressure\", 9 ], 10 ) File ~/Research/DESC/desc/plotting.py:2388, in plot_comparison() 2380 for i, eq in enumerate(eqs): 2381 fig, ax, _plot_data = plot_surfaces( 2382 eq, 2383 rho, 2384 theta, 2385 phi, 2386 ax, 2387 theta_color=color[i % len(color)], -> 2388 theta_ls=ls[i % len(ls)], 2389 theta_lw=lw[i % len(lw)], 2390 rho_color=color[i % len(color)], 2391 rho_ls=ls[i % len(ls)], 2392 rho_lw=lw[i % len(lw)], 2393 lcfs_color=color[i % len(color)], 2394 lcfs_ls=ls[i % len(ls)], 2395 lcfs_lw=lw[i % len(lw)], 2396 axis_color=color[i % len(color)], 2397 axis_alpha=0, 2398 axis_marker=\"o\", 2399 axis_size=0, 2400 label=labels[i % len(labels)], 2401 title_fontsize=title_fontsize, 2402 xlabel_fontsize=xlabel_fontsize, 2403 ylabel_fontsize=ylabel_fontsize, 2404 return_data=True, 2405 ) 2406 for key in _plot_data.keys(): File ~/Research/DESC/desc/plotting.py:1685, in 
plot_surfaces() 1676 tnr, tnt, tnz = t_grid.num_rho, t_grid.num_theta, t_grid.num_zeta 1677 v_grid = Grid( 1678 map_coordinates( 1679 eq, 1680 t_grid.nodes, 1681 [\"rho\", \"theta_PEST\", \"phi\"], 1682 [\"rho\", \"theta\", \"zeta\"], 1683 period=(np.inf, 2 * np.pi, 2 * np.pi), 1684 guess=t_grid.nodes, -> 1685 ), 1686 sort=False, 1687 ) 1688 rows = np.floor(np.sqrt(nphi)).astype(int) File ~/Research/DESC/desc/equilibrium/coords.py:218, in map_coordinates() 215 # See description here 216 # https://github.com/PlasmaControl/DESC/pull/504#discussion_r1194172532 217 # except we make sure properly handle periodic coordinates. --> 218 yk, (res, niter) = vecroot(yk, coords) 220 out = compute(yk, outbasis) File ~/Research/DESC/desc/equilibrium/coords.py:203, in map_coordinates.<locals>.<lambda>() 199 yk = fixup(yk) 201 vecroot = jit( 202 vmap( --> 203 lambda x0, *p: root( 204 residual, 205 x0, 206 jac=jac, 207 args=p, 208 fixup=fixup, 209 tol=tol, 210 maxiter=maxiter, 211 **kwargs, 212 ) 213 ) 214 ) 215 # See description here 216 # https://github.com/PlasmaControl/DESC/pull/504#discussion_r1194172532 217 # except we make sure properly handle periodic coordinates. 
File ~/Research/DESC/desc/backend.py:398, in root() 396 return _lstsq(A, jnp.atleast_1d(y)) --> 398 x, (res, niter) = jax.lax.custom_root( 399 res, x0, solve, tangent_solve, has_aux=True 400 ) 401 return x, (safenorm(res), niter) File ~/Research/DESC/desc/backend.py:344, in root.<locals>.<lambda>() 342 jac2 = lambda x: jnp.atleast_2d(jac(x, *args)) --> 344 res = lambda x: jnp.atleast_1d(fun(x, *args)).flatten() 346 # want to use least squares for rank-defficient systems, but 347 # jnp.linalg.lstsq doesn't have JVP defined and is slower than needed 348 # so we use the normal equations with regularized cholesky File ~/Research/DESC/desc/equilibrium/coords.py:174, in residual() 172 @jit 173 def residual(y, coords): --> 174 xk = compute(y, inbasis) 175 return _fixup_residual(xk - coords, period) File ~/Research/DESC/desc/equilibrium/coords.py:167, in compute() 166 data[\"iota_rr\"] = profiles[\"iota\"].compute(grid, dr=2, params=params[\"i_l\"]) --> 167 transforms = get_transforms(basis, eq, grid, jitable=True) 168 data = compute_fun(eq, basis, params, transforms, profiles, data) File ~/Research/DESC/desc/backend.py:112, in wrapper() 111 with jax.default_device(jax.devices(\"cpu\")[0]): --> 112 return func(*args, **kwargs) File ~/Research/DESC/desc/compute/utils.py:631, in get_transforms() 630 if hasattr(t, \"build\"): --> 631 t.build() 633 return transforms File ~/Research/DESC/desc/transform.py:385, in build() 384 for d in self.derivatives: --> 385 self.matrices[\"direct1\"][d[0]][d[1]][d[2]] = self.basis.evaluate( 386 self.grid.nodes, d, unique=False 387 ) 389 if self.method in [\"fft\", \"direct2\"]: File ~/Research/DESC/desc/basis.py:1134, in evaluate() 1132 n = n[nidx] -> 1134 radial = zernike_radial(r[:, np.newaxis], lm[:, 0], lm[:, 1], dr=derivatives[0]) 1135 poloidal = fourier(t[:, np.newaxis], m, dt=derivatives[1]) File ~/Research/DESC/desc/basis.py:1533, in zernike_radial() 1532 if dr == 0: -> 1533 out = r**m * _jacobi(n, alpha, beta, jacobi_arg, 0) 1534 
elif dr == 1: JaxStackTraceBeforeTransformation: TypeError: Called multiply with a float0 array. float0s do not support any operations by design because they are not compatible with non-trivial vector spaces. No implicit dtype conversion is done. You can use np.zeros_like(arr, dtype=np.float) to cast a float0 array to a regular zeros array. If you didn't expect to get a float0 you might have accidentally taken a gradient with respect to an integer argument. The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception. -------------------- The above exception was the direct cause of the following exception: TypeError Traceback (most recent call last) Cell In[6], line 3 1 eq_fam = desc.io.load(\"input.HELIOTRON_output.h5\") 2 print(\"Number of equilibria in the EquilibriaFamily:\", len(eq_fam)) ----> 3 fig, ax = plot_comparison( 4 eqs=[eq_fam[1], eq_fam[3], eq_fam[-1]], 5 labels=[ 6 \"Axisymmetric w/o pressure\", 7 \"Axisymmetric w/ pressure\", 8 \"Nonaxisymmetric w/ pressure\", 9 ], 10 ) File ~/Research/DESC/desc/plotting.py:2388, in plot_comparison(eqs, rho, theta, phi, ax, cmap, color, lw, ls, labels, return_data, **kwargs) 2386 plot_data[string] = [] 2387 for i, eq in enumerate(eqs): -> 2388 fig, ax, _plot_data = plot_surfaces( 2389 eq, 2390 rho, 2391 theta, 2392 phi, 2393 ax, 2394 theta_color=color[i % len(color)], 2395 theta_ls=ls[i % len(ls)], 2396 theta_lw=lw[i % len(lw)], 2397 rho_color=color[i % len(color)], 2398 rho_ls=ls[i % len(ls)], 2399 rho_lw=lw[i % len(lw)], 2400 lcfs_color=color[i % len(color)], 2401 lcfs_ls=ls[i % len(ls)], 2402 lcfs_lw=lw[i % len(lw)], 2403 axis_color=color[i % len(color)], 2404 axis_alpha=0, 2405 axis_marker=\"o\", 2406 axis_size=0, 2407 label=labels[i % len(labels)], 2408 title_fontsize=title_fontsize, 2409 xlabel_fontsize=xlabel_fontsize, 2410 ylabel_fontsize=ylabel_fontsize, 2411 return_data=True, 2412 ) 2413 for key in _plot_data.keys(): 2414 
plot_data[key].append(_plot_data[key]) File ~/Research/DESC/desc/plotting.py:1685, in plot_surfaces(eq, rho, theta, phi, ax, return_data, **kwargs) 1682 t_grid = _get_grid(**grid_kwargs) 1683 tnr, tnt, tnz = t_grid.num_rho, t_grid.num_theta, t_grid.num_zeta 1684 v_grid = Grid( -> 1685 map_coordinates( 1686 eq, 1687 t_grid.nodes, 1688 [\"rho\", \"theta_PEST\", \"phi\"], 1689 [\"rho\", \"theta\", \"zeta\"], 1690 period=(np.inf, 2 * np.pi, 2 * np.pi), 1691 guess=t_grid.nodes, 1692 ), 1693 sort=False, 1694 ) 1695 rows = np.floor(np.sqrt(nphi)).astype(int) 1696 cols = np.ceil(nphi / rows).astype(int) File ~/Research/DESC/desc/equilibrium/coords.py:218, in map_coordinates(eq, coords, inbasis, outbasis, guess, params, period, tol, maxiter, full_output, **kwargs) 201 vecroot = jit( 202 vmap( 203 lambda x0, *p: root( (...) 213 ) 214 ) 215 # See description here 216 # https://github.com/PlasmaControl/DESC/pull/504#discussion_r1194172532 217 # except we make sure properly handle periodic coordinates. --> 218 yk, (res, niter) = vecroot(yk, coords) 220 out = compute(yk, outbasis) 221 if full_output: [... skipping hidden 14 frame] File ~/Research/DESC/desc/equilibrium/coords.py:203, in map_coordinates.<locals>.<lambda>(x0, *p) 197 yk = _initial_guess_nn_search(coords, inbasis, eq, period, compute) 199 yk = fixup(yk) 201 vecroot = jit( 202 vmap( --> 203 lambda x0, *p: root( 204 residual, 205 x0, 206 jac=jac, 207 args=p, 208 fixup=fixup, 209 tol=tol, 210 maxiter=maxiter, 211 **kwargs, 212 ) 213 ) 214 ) 215 # See description here 216 # https://github.com/PlasmaControl/DESC/pull/504#discussion_r1194172532 217 # except we make sure properly handle periodic coordinates. 
218 yk, (res, niter) = vecroot(yk, coords) File ~/Research/DESC/desc/backend.py:398, in root(fun, x0, jac, args, tol, maxiter, maxiter_ls, alpha, fixup) 395 A = jnp.atleast_2d(jax.jacfwd(g)(y)) 396 return _lstsq(A, jnp.atleast_1d(y)) --> 398 x, (res, niter) = jax.lax.custom_root( 399 res, x0, solve, tangent_solve, has_aux=True 400 ) 401 return x, (safenorm(res), niter) [... skipping hidden 14 frame] File ~/Research/DESC/desc/backend.py:344, in root.<locals>.<lambda>(x) 341 else: 342 jac2 = lambda x: jnp.atleast_2d(jac(x, *args)) --> 344 res = lambda x: jnp.atleast_1d(fun(x, *args)).flatten() 346 # want to use least squares for rank-defficient systems, but 347 # jnp.linalg.lstsq doesn't have JVP defined and is slower than needed 348 # so we use the normal equations with regularized cholesky 349 def _lstsq(a, b): [... skipping hidden 47 frame] File ~/Research/DESC/desc/basis.py:1816, in _jacobi_jvp(x, xdot) 1811 df = _jacobi(n, alpha, beta, x, dx + 1) 1812 # in theory n, alpha, beta, dx aren't differentiable (they're integers) 1813 # but marking them as non-diff argnums seems to cause escaped tracer values. 1814 # probably a more elegant fix, but just setting those derivatives to zero seems 1815 # to work fine. -> 1816 return f, df * xdot + 0 * ndot + 0 * alphadot + 0 * betadot + 0 * dxdot File ~/miniconda3/envs/desc-env-10-24/lib/python3.12/site-packages/jax/_src/numpy/array_methods.py:1036, in _forward_operator_to_aval.<locals>.op(self, *args) 1035 def op(self, *args): -> 1036 return getattr(self.aval, f\"_{name}\")(self, *args) File ~/miniconda3/envs/desc-env-10-24/lib/python3.12/site-packages/jax/_src/numpy/array_methods.py:573, in _defer_to_unrecognized_arg.<locals>.deferring_binary_op(self, other) 571 args = (other, self) if swap else (self, other) 572 if isinstance(other, _accepted_binop_types): --> 573 return binary_op(*args) 574 # Note: don't use isinstance here, because we don't want to raise for 575 # subclasses, e.g. 
NamedTuple objects that may override operators. 576 if type(other) in _rejected_binop_types: File ~/miniconda3/envs/desc-env-10-24/lib/python3.12/site-packages/jax/_src/numpy/ufunc_api.py:177, in ufunc.__call__(self, out, where, *args) 175 raise NotImplementedError(f\"where argument of {self}\") 176 call = self.__static_props['call'] or self._call_vectorized --> 177 return call(*args) [... skipping hidden 11 frame] File ~/miniconda3/envs/desc-env-10-24/lib/python3.12/site-packages/jax/_src/numpy/ufuncs.py:1142, in _multiply(x, y) 1115 @partial(jit, inline=True) 1116 def _multiply(x: ArrayLike, y: ArrayLike, /) -> Array: 1117 \"\"\"Multiply two arrays element-wise. 1118 1119 JAX implementation of :obj:`numpy.multiply`. This is a universal function, (...) 1140 Array([ 0, 10, 20, 30], dtype=int32) 1141 \"\"\" -> 1142 x, y = promote_args(\"multiply\", x, y) 1143 return lax.mul(x, y) if x.dtype != bool else lax.bitwise_and(x, y) File ~/miniconda3/envs/desc-env-10-24/lib/python3.12/site-packages/jax/_src/numpy/util.py:355, in promote_args(fun_name, *args) 353 \"\"\"Convenience function to apply Numpy argument shape and dtype promotion.\"\"\" 354 check_arraylike(fun_name, *args) --> 355 _check_no_float0s(fun_name, *args) 356 check_for_prngkeys(fun_name, *args) 357 return promote_shapes(fun_name, *promote_dtypes(*args)) File ~/miniconda3/envs/desc-env-10-24/lib/python3.12/site-packages/jax/_src/numpy/util.py:326, in check_no_float0s(fun_name, *args) 324 \"\"\"Check if none of the args have dtype float0.\"\"\" 325 if any(dtypes.dtype(arg) == dtypes.float0 for arg in args): --> 326 raise TypeError( 327 f\"Called {fun_name} with a float0 array. \" 328 \"float0s do not support any operations by design because they \" 329 \"are not compatible with non-trivial vector spaces. No implicit dtype \" 330 \"conversion is done. You can use np.zeros_like(arr, dtype=np.float) \" 331 \"to cast a float0 array to a regular zeros array. 
\ \" 332 \"If you didn't expect to get a float0 you might have accidentally \" 333 \"taken a gradient with respect to an integer argument.\") TypeError: Called multiply with a float0 array. float0s do not support any operations by design because they are not compatible with non-trivial vector spaces. No implicit dtype conversion is done. You can use np.zeros_like(arr, dtype=np.float) to cast a float0 array to a regular zeros array. If you didn't expect to get a float0 you might have accidentally taken a gradient with respect to an integer argument." }
Pin jax version temporarily (#1292)
6c599b4
Pin jax to `<0.4.34` until we fix bug listed in #1291
There are some other bugs related to JAX 0.4.34, for example in equinox: `AttributeError: jax.core.pp_eqn_rules was removed in JAX v0.4.34.`
OK, the equinox error was due to an old version, `equinox==0.11.3`.
We think the issue is caused by returning the metadata from the root solve, which can include non-differentiable data types (for example, an integer iteration count).
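As a rough illustration of that hypothesis, the toy below returns a differentiable value together with integer metadata. It is only a sketch: `toy_solve` is a hypothetical stand-in, not DESC's `root`, and the iteration count is a made-up constant. The sketch assumes the metadata is not needed in the derivative, so `has_aux=True` is used to keep it out of differentiation.

```python
import jax
import jax.numpy as jnp


def toy_solve(a):
    # Hypothetical stand-in for a root solve: the "solution" is sqrt(a),
    # and the metadata is an integer iteration count (made-up constant).
    x = jnp.sqrt(a)
    aux = {"niter": jnp.int32(5)}
    return x, aux


# Differentiating the full output pairs the integer metadata with a tangent
# that carries no usable information; inspect its dtype here.
(_, aux_primal), (xdot, aux_tangent) = jax.jvp(toy_solve, (4.0,), (1.0,))
print(aux_tangent["niter"].dtype)

# With has_aux=True, only the first output is differentiated and the
# metadata is passed through untouched.
dx_da, aux = jax.grad(toy_solve, has_aux=True)(4.0)
print(dx_da, aux["niter"])
```

The same idea applies to `jax.jacfwd` and `jax.value_and_grad`, which also accept `has_aux=True`.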