Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add jac_chunk_size keyword argument to ObjectiveFunction to reduce memory usage of forward mode Jacobian calculation #1052

Merged
merged 74 commits into from
Sep 26, 2024

Conversation

dpanici
Copy link
Collaborator

@dpanici dpanici commented Jun 13, 2024

  • changes most jnp.vectorize calls to instead use batched_vectorize which performs the function vectorization in smaller chunks, which reduces the memory cost of the calculation, at the expense of taking longer the smaller the chunk size is.
  • Add jac_chunk_size to ObjectiveFunction and _Objective to control the above chunk size for the fwd mode Jacobian calculation
    • if None, the chunk size is equal to dim_x, so no chunking is done
    • if an int, this is the chunk size to be used.
    • if "auto" for the ObjectiveFunction, will use a heuristic for the maximum jac_chunk_size needed to fit the jacobian calculation on the available device memory, according to the formula: max_jac_chunk_size = (desc_config.get("avail_mem") / estimated_memory_usage - 0.22) / 0.85 * self.dim_x
  • the ObjectiveFunction jac_chunk_size is used if deriv_mode="batched", and the _Objective jac_chunk_size will be used if deriv_mode="blocked"

This works well, this is LMN18 equilibrium solve with 1.5 oversampled grid and maxiter=10 memory trace vs time on GPU, where we get 4x memory decrease with negligible runtime increase:

image

Also, I can do up to an LMN=20 eq ForceBalance objective with the default double grid oversampling, and with the "auto" chunk sizing, the jacobian compiles and computes without going OOM on an 80gb GPU (on master this would go OOM).

TODO

  • re-implement without relying on netket
  • change chunk_size to a better default value (something like 100 would be fine, maybe can dynamically choose based off of size of dim_x)
  • Add chunk_size argument to every Objective class
    • I am choosing right now to not to add it as an arg to the LinearObjective classes, though technically you could
  • Add "chunked" as a deriv_mode to Derivative (or, just as an argument to Derivative to be used when "batched" is used) - > I don't remember what this was exactly, I think we can keep just for Objectives
  • change chunk_size to jacobian_chunk_size for Objective kwarg
  • use in constraint wrappers

TODO Later

  • add to singular integral calculation as well

Resolves #826

Copy link

codecov bot commented Jun 13, 2024

Codecov Report

Attention: Patch coverage is 89.44099% with 17 lines in your changes missing coverage. Please review.

Project coverage is 92.19%. Comparing base (15d95f1) to head (3e99510).
Report is 5 commits behind head on master.

Files with missing lines Patch % Lines
desc/batching.py 85.83% 17 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##           master    #1052      +/-   ##
==========================================
- Coverage   95.37%   92.19%   -3.19%     
==========================================
  Files          96       96              
  Lines       23560    23560              
==========================================
- Hits        22471    21721     -750     
- Misses       1089     1839     +750     
Files with missing lines Coverage Δ
desc/continuation.py 93.26% <100.00%> (ø)
desc/derivatives.py 92.85% <100.00%> (-0.43%) ⬇️
desc/objectives/_bootstrap.py 97.14% <ø> (ø)
desc/objectives/_coils.py 99.17% <ø> (ø)
desc/objectives/_equilibrium.py 94.53% <ø> (-0.43%) ⬇️
desc/objectives/_free_boundary.py 82.83% <ø> (-14.20%) ⬇️
desc/objectives/_generic.py 67.76% <ø> (-29.76%) ⬇️
desc/objectives/_geometry.py 96.93% <ø> (ø)
desc/objectives/_omnigenity.py 96.30% <ø> (ø)
desc/objectives/_power_balance.py 87.50% <ø> (-2.09%) ⬇️
... and 7 more

... and 22 files with indirect coverage changes

Copy link
Contributor

github-actions bot commented Jun 13, 2024

|             benchmark_name             |         dt(%)          |         dt(s)          |        t_new(s)        |        t_old(s)        | 
| -------------------------------------- | ---------------------- | ---------------------- | ---------------------- | ---------------------- |
 test_build_transform_fft_lowres         |     +3.16 +/- 3.96     | +1.69e-02 +/- 2.12e-02 |  5.53e-01 +/- 2.0e-02  |  5.36e-01 +/- 6.8e-03  |
 test_equilibrium_init_medres            |     -0.44 +/- 5.40     | -1.94e-02 +/- 2.36e-01 |  4.34e+00 +/- 1.0e-01  |  4.36e+00 +/- 2.1e-01  |
 test_equilibrium_init_highres           |     -0.74 +/- 2.41     | -4.25e-02 +/- 1.39e-01 |  5.73e+00 +/- 1.2e-01  |  5.77e+00 +/- 6.4e-02  |
 test_objective_compile_dshape_current   |     -1.42 +/- 1.53     | -5.72e-02 +/- 6.15e-02 |  3.97e+00 +/- 5.0e-02  |  4.03e+00 +/- 3.6e-02  |
 test_objective_compute_dshape_current   |     -1.97 +/- 3.73     | -7.30e-05 +/- 1.39e-04 |  3.64e-03 +/- 4.4e-05  |  3.71e-03 +/- 1.3e-04  |
 test_objective_jac_dshape_current       |     -0.67 +/- 4.78     | -2.76e-04 +/- 1.96e-03 |  4.08e-02 +/- 1.4e-03  |  4.11e-02 +/- 1.3e-03  |
 test_perturb_2                          |     +0.42 +/- 3.47     | +7.51e-02 +/- 6.14e-01 |  1.78e+01 +/- 5.1e-01  |  1.77e+01 +/- 3.5e-01  |
 test_proximal_freeb_jac                 |     -0.28 +/- 1.56     | -2.12e-02 +/- 1.17e-01 |  7.51e+00 +/- 7.8e-02  |  7.53e+00 +/- 8.8e-02  |
 test_solve_fixed_iter                   |     +0.40 +/- 57.53    | +2.00e-02 +/- 2.88e+00 |  5.03e+00 +/- 2.0e+00  |  5.01e+00 +/- 2.1e+00  |
 test_build_transform_fft_midres         |     -0.76 +/- 5.52     | -4.77e-03 +/- 3.46e-02 |  6.23e-01 +/- 1.1e-02  |  6.28e-01 +/- 3.3e-02  |
 test_build_transform_fft_highres        |     -0.34 +/- 3.28     | -3.46e-03 +/- 3.36e-02 |  1.02e+00 +/- 9.3e-03  |  1.02e+00 +/- 3.2e-02  |
 test_equilibrium_init_lowres            |     +1.50 +/- 3.83     | +5.83e-02 +/- 1.49e-01 |  3.95e+00 +/- 1.5e-01  |  3.89e+00 +/- 3.4e-02  |
 test_objective_compile_atf              |     -0.08 +/- 4.11     | -6.25e-03 +/- 3.25e-01 |  7.90e+00 +/- 2.4e-01  |  7.91e+00 +/- 2.2e-01  |
 test_objective_compute_atf              |     +2.00 +/- 2.81     | +2.10e-04 +/- 2.97e-04 |  1.07e-02 +/- 2.5e-04  |  1.05e-02 +/- 1.5e-04  |
 test_objective_jac_atf                  |     +1.18 +/- 2.10     | +2.33e-02 +/- 4.16e-02 |  2.00e+00 +/- 3.0e-02  |  1.98e+00 +/- 2.8e-02  |
 test_perturb_1                          |     +7.72 +/- 3.83     | +9.70e-01 +/- 4.81e-01 |  1.35e+01 +/- 4.2e-01  |  1.26e+01 +/- 2.3e-01  |
 test_proximal_jac_atf                   |     +1.08 +/- 0.76     | +8.87e-02 +/- 6.27e-02 |  8.29e+00 +/- 4.7e-02  |  8.20e+00 +/- 4.1e-02  |
 test_proximal_freeb_compute             |     +2.84 +/- 1.08     | +5.27e-03 +/- 2.00e-03 |  1.91e-01 +/- 1.8e-03  |  1.86e-01 +/- 9.6e-04  |

@unalmis
Copy link
Collaborator

unalmis commented Jun 14, 2024

You might already be aware but fyi: jax-ml/jax#19614

@PhilipVinc
Copy link

If you don't care about jax's native multi-GPU sharding support it should be easy to just vendor our implementation.
In that case, you can just vendor our netket/jax/_chunk_utils.py, netket/jax/_scanmap.py and netket/jax/_vmap_chunked.py .

The former 2 files are on purpose standalone. Only _vmap_chunked depends on other things but only if you are using sharding.

Remove all branches hitting of axis_0_is_sharded == True and config.netket_experimental_sharding, which will allow you to remove sharding_decorator, which is a mess only needed to support efficient sharding of jax arrays.

Also replace HashablePartial with functools.partial

@dpanici
Copy link
Collaborator Author

dpanici commented Jul 10, 2024

@dpanici jax batched vmap has been merged to master

@dpanici
Copy link
Collaborator Author

dpanici commented Jul 23, 2024

@dpanici make separate branch with the implementation using JAX's version, and in this PR implement the one based off of netket

@dpanici
Copy link
Collaborator Author

dpanici commented Aug 7, 2024

@kianorr @YigitElma

@dpanici dpanici marked this pull request as ready for review August 22, 2024 20:19
@dpanici dpanici requested review from f0uriest and a team September 24, 2024 16:17
f0uriest
f0uriest previously approved these changes Sep 25, 2024
f0uriest
f0uriest previously approved these changes Sep 25, 2024
YigitElma
YigitElma previously approved these changes Sep 26, 2024
CHANGELOG.md Outdated Show resolved Hide resolved

Parameters
----------
f: a function that takes elements of the leading dimension of x
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is not consistent with our docstring format but no problem. Just pointing out.

of functions that act on JAX arrays.

Parameters
----------
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The above docstrings are not too important but this one might be checked more, so maybe make it consistent with out doc format? Again, not too important, we can change it in a later PR.

@@ -474,8 +553,6 @@ def jac_scaled_error(self, x, constants=None):

if self._deriv_mode == "batched":
J = Derivative(self.compute_scaled_error, mode="fwd")(x, constants)
if self._deriv_mode == "looped":
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a completely unrelated comment but I think we should explain this jacobian, jvp, etc better in some dev guide. For example, an individual objectives jac_scaled_error doesn't use batched_vmap it is usually jax.jacfwd. On the other hand, when we wrap the objective and constraints, the jacobian is calculated by corresponding jvp_ method, not jac_ method. Even now it confuses me. What is the case we use jac_ method instead of jvp_? Good thing to clarify in long-waiting dev-guide

@dpanici dpanici merged commit 17c2b15 into master Sep 26, 2024
22 of 24 checks passed
@dpanici dpanici deleted the dp/jacobian-batched-vmap branch September 26, 2024 18:56
# into vmap, we can make use of more efficient batching rules for
# primitives where only some arguments are batched (e.g., for
# lax_linalg.triangular_solve), and avoid instantiating large broadcasted
# arrays.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

commenting for reference

YigitElma added a commit that referenced this pull request Oct 8, 2024
Most of the documentation of our objectives has the same parameters that
are inherited from the main `_Objective` class. This PR removes the
repeated docstring from each objective and updates the docstring by
inheriting which reduces the lines of code and also facilitates the
maintenance. For example, when we add `jac_chunk_size` in #1052, we have
to copy-paste the docs to every single objective which is tedious.

Introduces `collect_docs` function that creates docstring for common
parameters and with option to overwrite user can give a custom
definition for a parameter without changing the order of the docs

Resolves #879
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Use batched vmap to reduce memory usage
6 participants