Skip to content

Commit

Permalink
Add jac_chunk_size keyword argument to ObjectiveFunction to reduc…
Browse files Browse the repository at this point in the history
…e memory usage of forward mode Jacobian calculation (#1052)

- changes most `jnp.vectorize` calls to instead use `batched_vectorize`
which performs the function vectorization in smaller chunks, which
reduces the memory cost of the calculation, at the expense of taking
longer the smaller the chunk size is.
- Add `jac_chunk_size` to `ObjectiveFunction` and `_Objective` to
control the above chunk size for the `fwd` mode Jacobian calculation
- if `None`, the chunk size is equal to `dim_x`, so no chunking is done
  - if an `int`, this is the chunk size to be used.
- if `"auto"` for the `ObjectiveFunction`, will use a heuristic for the
maximum `jac_chunk_size` needed to fit the jacobian calculation on the
available device memory, according to the formula: `max_jac_chunk_size =
(desc_config.get("avail_mem") / estimated_memory_usage - 0.22) / 0.85 *
self.dim_x`
- the `ObjectiveFunction` `jac_chunk_size` is used if
`deriv_mode="batched"`, and the `_Objective` `jac_chunk_size` will be
used if `deriv_mode="blocked"`


This works well, this is LMN18 equilibrium solve with 1.5 oversampled
grid and `maxiter=10` memory trace vs time on GPU, where we get 4x
memory decrease with negligible runtime increase:

<img width="501" alt="image"
src="https://github.com/PlasmaControl/DESC/assets/37969854/0ed2eba2-a887-4e51-b748-29ffa599d67c">

Also, I can do up to an `LMN=20` eq `ForceBalance` objective with the
default double grid oversampling, and with the `"auto"` chunk sizing,
the jacobian compiles and computes without going OOM on an 80gb GPU (on
master this would go OOM).

TODO
- [x] re-implement without relying on `netket`
- [x] change chunk_size to a better default value (something like 100
would be fine, maybe can dynamically choose based off of size of
`dim_x`)
- [x] Add `chunk_size` argument to every Objective class
- [x] I am choosing right now to not to add it as an arg to the
`LinearObjective` classes, though technically you could
- [x] Add `"chunked"` as a deriv_mode to `Derivative` (or, just as an
argument to `Derivative` to be used when `"batched"` is used) - > I
don't remember what this was exactly, I think we can keep just for
Objectives
- [x] change `chunk_size` to `jacobian_chunk_size` for Objective kwarg
- [x] use in constraint wrappers

TODO Later
- [ ] add to singular integral calculation as well


Resolves #826
  • Loading branch information
dpanici authored Sep 26, 2024
2 parents 8419b4d + 3e99510 commit 17c2b15
Show file tree
Hide file tree
Showing 25 changed files with 1,159 additions and 99 deletions.
8 changes: 8 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,20 @@ New Features
- Changes ``ToroidalFlux`` objective to default using a 1D loop integral of the vector potential
to compute the toroidal flux when possible, as opposed to a 2D surface integral of the magnetic field dotted with ``n_zeta``.
- Allow specification of Nyquist spectrum maximum modenumbers when using ``VMECIO.save`` to save a DESC .h5 file as a VMEC-format wout file
- Add ``jac_chunk_size`` to ``ObjectiveFunction`` and ``_Objective`` to control the above chunk size for the ``fwd`` mode Jacobian calculation
- if ``None``, the chunk size is equal to ``dim_x``, so no chunking is done
- if an ``int``, this is the chunk size to be used.
- if ``"auto"`` for the ``ObjectiveFunction``, will use a heuristic for the maximum ``jac_chunk_size`` needed to fit the jacobian calculation on the available device memory, according to the formula: ``max_jac_chunk_size = (desc_config.get("avail_mem") / estimated_memory_usage - 0.22) / 0.85 * self.dim_x`` with ``estimated_memory_usage = 2.4e-7 * self.dim_f * self.dim_x + 1``
- the ``ObjectiveFunction`` ``jac_chunk_size`` is used if ``deriv_mode="batched"``, and the ``_Objective`` ``jac_chunk_size`` will be used if ``deriv_mode="blocked"``

Bug Fixes

- Fixes bugs that occur when saving asymmetric equilibria as wout files
- Fixes bug that occurs when using ``VMECIO.plot_vmec_comparison`` to compare to an asymmetric wout file

Deprecations

- ``deriv_mode="looped"`` in ``ObjectiveFunction`` is deprecated and will be removed in a future version in favored of ``deriv_mode="batched"`` with ``jac_chunk_size=1``,


v0.12.1
Expand Down
322 changes: 322 additions & 0 deletions desc/batching.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,322 @@
"""Utility functions for the ``batched_vectorize`` function."""

import functools
from typing import Callable, Optional

from desc.backend import jax, jnp

if jax.__version_info__ >= (0, 4, 16):
from jax.extend import linear_util as lu
else:
from jax import linear_util as lu

from jax._src.numpy.vectorize import (
_apply_excluded,
_check_output_dims,
_parse_gufunc_signature,
_parse_input_dimensions,
)

# The following section of this code is derived from the NetKet project
# https://github.com/netket/netket/blob/9881c9fb217a2ac4dc9274a054bf6e6a2993c519/
# netket/jax/_chunk_utils.py
#
# The original copyright notice is as follows
# Copyright 2021 The NetKet Authors - All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");


def _treeify(f):
def _f(x, *args, **kwargs):
return jax.tree_util.tree_map(lambda y: f(y, *args, **kwargs), x)

return _f


@_treeify
def _unchunk(x):
return x.reshape((-1,) + x.shape[2:])


@_treeify
def _chunk(x, chunk_size=None):
# chunk_size=None -> add just a dummy chunk dimension,
# same as np.expand_dims(x, 0)
if x.ndim == 0:
raise ValueError("x cannot be chunked as it has 0 dimensions.")
n = x.shape[0]
if chunk_size is None:
chunk_size = n

n_chunks, residual = divmod(n, chunk_size)
if residual != 0:
raise ValueError(
"The first dimension of x must be divisible by chunk_size."
+ f"\n Got x.shape={x.shape} but chunk_size={chunk_size}."
)
return x.reshape((n_chunks, chunk_size) + x.shape[1:])


####

# The following section of this code is derived from the NetKet project
# https://github.com/netket/netket/blob/9881c9fb217a2ac4dc9274a054bf6e6a2993c519/
# netket/jax/_scanmap.py


def scan_append(f, x):
"""Evaluate f element by element in x while appending the results.
Parameters
----------
f: a function that takes elements of the leading dimension of x
x: a pytree where each leaf array has the same leading dimension
Returns
-------
a (pytree of) array(s) with leading dimension same as x,
containing the evaluation of f at each element in x
"""
carry_init = True

def f_(carry, x):
return False, f(x)

_, res_append = jax.lax.scan(f_, carry_init, x, unroll=1)
return res_append


# TODO in_axes a la vmap?
def _scanmap(fun, scan_fun, argnums=0):
"""A helper function to wrap f with a scan_fun."""

def f_(*args, **kwargs):
f = lu.wrap_init(fun, kwargs)
f_partial, dyn_args = jax.api_util.argnums_partial(
f, argnums, args, require_static_args_hashable=False
)
return scan_fun(lambda x: f_partial.call_wrapped(*x), dyn_args)

return f_


# The following section of this code is derived from the NetKet project
# https://github.com/netket/netket/blob/9881c9fb217a2ac4dc9274a054bf6e6a2993c519/
# netket/jax/_vmap_chunked.py


def _eval_fun_in_chunks(vmapped_fun, chunk_size, argnums, *args, **kwargs):
n_elements = jax.tree_util.tree_leaves(args[argnums[0]])[0].shape[0]
n_chunks, n_rest = divmod(n_elements, chunk_size)

if n_chunks == 0 or chunk_size >= n_elements:
y = vmapped_fun(*args, **kwargs)
else:
# split inputs
def _get_chunks(x):
x_chunks = jax.tree_util.tree_map(
lambda x_: x_[: n_elements - n_rest, ...], x
)
x_chunks = _chunk(x_chunks, chunk_size)
return x_chunks

def _get_rest(x):
x_rest = jax.tree_util.tree_map(
lambda x_: x_[n_elements - n_rest :, ...], x
)
return x_rest

args_chunks = [
_get_chunks(a) if i in argnums else a for i, a in enumerate(args)
]
args_rest = [_get_rest(a) if i in argnums else a for i, a in enumerate(args)]

y_chunks = _unchunk(
_scanmap(vmapped_fun, scan_append, argnums)(*args_chunks, **kwargs)
)

if n_rest == 0:
y = y_chunks
else:
y_rest = vmapped_fun(*args_rest, **kwargs)
y = jax.tree_util.tree_map(
lambda y1, y2: jnp.concatenate((y1, y2)), y_chunks, y_rest
)
return y


def _chunk_vmapped_function(
vmapped_fun: Callable,
chunk_size: Optional[int],
argnums=0,
) -> Callable:
"""Takes a vmapped function and computes it in chunks."""
if chunk_size is None:
return vmapped_fun

if isinstance(argnums, int):
argnums = (argnums,)
return functools.partial(_eval_fun_in_chunks, vmapped_fun, chunk_size, argnums)


def _parse_in_axes(in_axes):
if isinstance(in_axes, int):
in_axes = (in_axes,)

if not set(in_axes).issubset((0, None)):
raise NotImplementedError("Only in_axes 0/None are currently supported")

argnums = tuple(
map(lambda ix: ix[0], filter(lambda ix: ix[1] is not None, enumerate(in_axes)))
)
return in_axes, argnums


def vmap_chunked(
f: Callable,
in_axes=0,
*,
chunk_size: Optional[int],
) -> Callable:
"""Behaves like jax.vmap but uses scan to chunk the computations in smaller chunks.
Parameters
----------
f: The function to be vectorised.
in_axes: The axes that should be scanned along. Only supports `0` or `None`
chunk_size: The maximum size of the chunks to be used. If it is `None`,
chunking is disabled
Returns
-------
f: A vectorised and chunked function
"""
in_axes, argnums = _parse_in_axes(in_axes)
vmapped_fun = jax.vmap(f, in_axes=in_axes)
return _chunk_vmapped_function(vmapped_fun, chunk_size, argnums)


def batched_vectorize(pyfunc, *, excluded=frozenset(), signature=None, chunk_size=None):
"""Define a vectorized function with broadcasting and batching.
:func:`vectorize` is a convenience wrapper for defining vectorized
functions with broadcasting, in the style of NumPy's
`generalized universal functions
<https://numpy.org/doc/stable/reference/c-api/generalized-ufuncs.html>`_.
It allows for defining functions that are automatically repeated across
any leading dimensions, without the implementation of the function needing to
be concerned about how to handle higher dimensional inputs.
:func:`jax.numpy.vectorize` has the same interface as
:class:`numpy.vectorize`, but it is syntactic sugar for an auto-batching
transformation (:func:`vmap`) rather than a Python loop. This should be
considerably more efficient, but the implementation must be written in terms
of functions that act on JAX arrays.
Parameters
----------
pyfunc: callable,function to vectorize.
excluded: optional set of integers representing positional arguments for
which the function will not be vectorized. These will be passed directly
to ``pyfunc`` unmodified.
signature: optional generalized universal function signature, e.g.,
``(m,n),(n)->(m)`` for vectorized matrix-vector multiplication. If
provided, ``pyfunc`` will be called with (and expected to return) arrays
with shapes given by the size of corresponding core dimensions. By
default, pyfunc is assumed to take scalars arrays as input and output.
chunk_size: the size of the batches to pass to vmap. If None, defaults to
the largest possible chunk_size (like the default behavior of ``vectorize11)
Returns
-------
Batch-vectorized version of the given function.
"""
if any(not isinstance(exclude, (str, int)) for exclude in excluded):
raise TypeError(
"jax.numpy.vectorize can only exclude integer or string arguments, "
"but excluded={!r}".format(excluded)
)
if any(isinstance(e, int) and e < 0 for e in excluded):
raise ValueError(f"excluded={excluded!r} contains negative numbers")

@functools.wraps(pyfunc)
def wrapped(*args, **kwargs):
error_context = (
"on vectorized function with excluded={!r} and "
"signature={!r}".format(excluded, signature)
)
excluded_func, args, kwargs = _apply_excluded(pyfunc, excluded, args, kwargs)

if signature is not None:
input_core_dims, output_core_dims = _parse_gufunc_signature(signature)
else:
input_core_dims = [()] * len(args)
output_core_dims = None

none_args = {i for i, arg in enumerate(args) if arg is None}
if any(none_args):
if any(input_core_dims[i] != () for i in none_args):
raise ValueError(
f"Cannot pass None at locations {none_args} with {signature=}"
)
excluded_func, args, _ = _apply_excluded(excluded_func, none_args, args, {})
input_core_dims = [
dim for i, dim in enumerate(input_core_dims) if i not in none_args
]

args = tuple(map(jnp.asarray, args))

broadcast_shape, dim_sizes = _parse_input_dimensions(
args, input_core_dims, error_context
)

checked_func = _check_output_dims(
excluded_func, dim_sizes, output_core_dims, error_context
)

# Rather than broadcasting all arguments to full broadcast shapes, prefer
# expanding dimensions using vmap. By pushing broadcasting
# into vmap, we can make use of more efficient batching rules for
# primitives where only some arguments are batched (e.g., for
# lax_linalg.triangular_solve), and avoid instantiating large broadcasted
# arrays.

squeezed_args = []
rev_filled_shapes = []

for arg, core_dims in zip(args, input_core_dims):
noncore_shape = arg.shape[: arg.ndim - len(core_dims)]

pad_ndim = len(broadcast_shape) - len(noncore_shape)
filled_shape = pad_ndim * (1,) + noncore_shape
rev_filled_shapes.append(filled_shape[::-1])

squeeze_indices = tuple(
i for i, size in enumerate(noncore_shape) if size == 1
)
squeezed_arg = jnp.squeeze(arg, axis=squeeze_indices)
squeezed_args.append(squeezed_arg)

vectorized_func = checked_func
dims_to_expand = []
for negdim, axis_sizes in enumerate(zip(*rev_filled_shapes)):
in_axes = tuple(None if size == 1 else 0 for size in axis_sizes)
if all(axis is None for axis in in_axes):
dims_to_expand.append(len(broadcast_shape) - 1 - negdim)
else:
# change the vmap here to chunked_vmap
vectorized_func = vmap_chunked(
vectorized_func, in_axes, chunk_size=chunk_size
)
result = vectorized_func(*squeezed_args)

if not dims_to_expand:
return result
elif isinstance(result, tuple):
return tuple(jnp.expand_dims(r, axis=dims_to_expand) for r in result)
else:
return jnp.expand_dims(result, axis=dims_to_expand)

return wrapped
Loading

0 comments on commit 17c2b15

Please sign in to comment.