Skip to content

Commit

Permalink
address Rory's comments
Browse files Browse the repository at this point in the history
  • Loading branch information
YigitElma committed Dec 21, 2024
1 parent 601da8a commit a4ab21e
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 81 deletions.
2 changes: 1 addition & 1 deletion docs/performance_tips.rst
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ Although the compiled code is fast, it still takes time to compile. If you are r
This will use a directory called ``jax-caches`` in the parent directory of the script to store the compiled code. The ``jax_persistent_cache_min_entry_size_bytes`` and ``jax_persistent_cache_min_compile_time_secs`` parameters are set to -1 and 0, respectively, to ensure that all compiled code is cached. For more details on caching, refer to official JAX documentation [`here <https://jax.readthedocs.io/en/latest/persistent_compilation_cache.html#persistent-compilation-cache>`__].

Note: Updating JAX version might re-compile some previously cached code, and thi might increase the cache size. Every once in a while, you might need to clear your cache directory.
Note: Updating JAX version might re-compile some previously cached code, and this might increase the cache size. Every once in a while, you might need to clear your cache directory.


Reducing Memory Size of Objective Jacobian Calculation
Expand Down
15 changes: 7 additions & 8 deletions tests/benchmarks/benchmark_cpu_small.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
"""Benchmarks for timing comparison on cpu (that are small enough to run on CI)."""

import jax
import numpy as np
import pytest

import desc

desc.set_device("cpu")
import desc.examples
from desc.backend import jax
from desc.basis import FourierZernikeBasis
from desc.equilibrium import Equilibrium
from desc.grid import ConcentricGrid, LinearGrid
Expand Down Expand Up @@ -99,7 +99,7 @@ def build():
N = 5
_ = Equilibrium(L=L, M=M, N=N)

benchmark.pedantic(build, setup=setup, iterations=1, rounds=15)
benchmark.pedantic(build, setup=setup, iterations=1, rounds=20)


@pytest.mark.benchmark()
Expand All @@ -115,7 +115,7 @@ def build():
N = 15
_ = Equilibrium(L=L, M=M, N=N)

benchmark.pedantic(build, setup=setup, iterations=1, rounds=15)
benchmark.pedantic(build, setup=setup, iterations=1, rounds=20)


@pytest.mark.benchmark()
Expand All @@ -131,7 +131,7 @@ def build():
N = 25
_ = Equilibrium(L=L, M=M, N=N)

benchmark.pedantic(build, setup=setup, iterations=1, rounds=10)
benchmark.pedantic(build, setup=setup, iterations=1, rounds=20)


@pytest.mark.slow
Expand Down Expand Up @@ -234,7 +234,7 @@ def test_objective_jac_dshape_current(benchmark):
def run(x, objective):
objective.jac_scaled_error(x, objective.constants).block_until_ready()

benchmark.pedantic(run, args=(x, objective), rounds=50, iterations=1)
benchmark.pedantic(run, args=(x, objective), rounds=80, iterations=1)


@pytest.mark.slow
Expand Down Expand Up @@ -288,7 +288,7 @@ def setup():
}
return args, kwargs

benchmark.pedantic(perturb, setup=setup, rounds=8, iterations=1)
benchmark.pedantic(perturb, setup=setup, rounds=10, iterations=1)


@pytest.mark.slow
Expand Down Expand Up @@ -321,14 +321,13 @@ def setup():
}
return args, kwargs

benchmark.pedantic(perturb, setup=setup, rounds=8, iterations=1)
benchmark.pedantic(perturb, setup=setup, rounds=10, iterations=1)


@pytest.mark.slow
@pytest.mark.benchmark
def test_proximal_jac_atf(benchmark):
"""Benchmark computing jacobian of constrained proximal projection."""
jax.clear_caches()
eq = desc.examples.get("ATF")
grid = LinearGrid(M=eq.M_grid, N=eq.N_grid, NFP=eq.NFP, rho=np.linspace(0.1, 1, 10))
objective = ObjectiveFunction(QuasisymmetryTwoTerm(eq, grid=grid))
Expand Down
117 changes: 45 additions & 72 deletions tests/benchmarks/benchmark_gpu_small.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
"""Benchmarks for timing comparison on gpu (that are small enough to run on CI)."""

import jax
import numpy as np
import pytest

import desc

desc.set_device("gpu")
import desc.examples
from desc.backend import jax
from desc.basis import FourierZernikeBasis
from desc.equilibrium import Equilibrium
from desc.grid import ConcentricGrid, LinearGrid
Expand Down Expand Up @@ -45,7 +45,7 @@ def build():
transf = Transform(grid, basis, method="fft", build=False)
transf.build()

benchmark.pedantic(build, setup=setup, iterations=1, rounds=50)
benchmark.pedantic(build, setup=setup, iterations=1, rounds=20)


@pytest.mark.benchmark()
Expand All @@ -64,7 +64,7 @@ def build():
transf = Transform(grid, basis, method="fft", build=False)
transf.build()

benchmark.pedantic(build, setup=setup, iterations=1, rounds=50)
benchmark.pedantic(build, setup=setup, iterations=1, rounds=20)


@pytest.mark.benchmark()
Expand All @@ -83,7 +83,7 @@ def build():
transf = Transform(grid, basis, method="fft", build=False)
transf.build()

benchmark.pedantic(build, setup=setup, iterations=1, rounds=50)
benchmark.pedantic(build, setup=setup, iterations=1, rounds=20)


@pytest.mark.benchmark()
Expand All @@ -99,7 +99,7 @@ def build():
N = 5
_ = Equilibrium(L=L, M=M, N=N)

benchmark.pedantic(build, setup=setup, iterations=1, rounds=50)
benchmark.pedantic(build, setup=setup, iterations=1, rounds=20)


@pytest.mark.benchmark()
Expand All @@ -115,7 +115,7 @@ def build():
N = 15
_ = Equilibrium(L=L, M=M, N=N)

benchmark.pedantic(build, setup=setup, iterations=1, rounds=50)
benchmark.pedantic(build, setup=setup, iterations=1, rounds=20)


@pytest.mark.benchmark()
Expand All @@ -131,67 +131,53 @@ def build():
N = 25
_ = Equilibrium(L=L, M=M, N=N)

benchmark.pedantic(build, setup=setup, iterations=1, rounds=50)
benchmark.pedantic(build, setup=setup, iterations=1, rounds=20)


@pytest.mark.slow
@pytest.mark.benchmark
def test_objective_compile_dshape_current(benchmark):
"""Benchmark compiling objective."""
eq = desc.examples.get("DSHAPE_current")
objective = LinearConstraintProjection(
get_equilibrium_objective(eq),
ObjectiveFunction(
maybe_add_self_consistency(eq, get_fixed_boundary_constraints(eq)),
),
)
objective.build(eq)

def setup():
def run(objective):
jax.clear_caches()
eq = desc.examples.get("DSHAPE_current")
objective = LinearConstraintProjection(
get_equilibrium_objective(eq),
ObjectiveFunction(
maybe_add_self_consistency(eq, get_fixed_boundary_constraints(eq)),
),
)
objective.build(eq)
args = (
objective,
eq,
)
kwargs = {}
return args, kwargs

def run(objective, eq):
objective.compile()

benchmark.pedantic(run, setup=setup, rounds=10, iterations=1)
benchmark.pedantic(run, args=(objective,), rounds=10, iterations=1)


@pytest.mark.slow
@pytest.mark.benchmark
def test_objective_compile_atf(benchmark):
"""Benchmark compiling objective."""
eq = desc.examples.get("ATF")
objective = LinearConstraintProjection(
get_equilibrium_objective(eq),
ObjectiveFunction(
maybe_add_self_consistency(eq, get_fixed_boundary_constraints(eq)),
),
)
objective.build(eq)

def setup():
def run(objective):
jax.clear_caches()
eq = desc.examples.get("ATF")
objective = LinearConstraintProjection(
get_equilibrium_objective(eq),
ObjectiveFunction(
maybe_add_self_consistency(eq, get_fixed_boundary_constraints(eq)),
),
)
objective.build(eq)
args = (objective, eq)
kwargs = {}
return args, kwargs

def run(objective, eq):
objective.compile()

benchmark.pedantic(run, setup=setup, rounds=10, iterations=1)
benchmark.pedantic(run, args=(objective,), rounds=10, iterations=1)


@pytest.mark.slow
@pytest.mark.benchmark
def test_objective_compute_dshape_current(benchmark):
"""Benchmark computing objective."""
jax.clear_caches()
eq = desc.examples.get("DSHAPE_current")
objective = LinearConstraintProjection(
get_equilibrium_objective(eq),
Expand All @@ -206,14 +192,13 @@ def test_objective_compute_dshape_current(benchmark):
def run(x, objective):
objective.compute_scaled_error(x, objective.constants).block_until_ready()

benchmark.pedantic(run, args=(x, objective), rounds=50, iterations=1)
benchmark.pedantic(run, args=(x, objective), rounds=100, iterations=1)


@pytest.mark.slow
@pytest.mark.benchmark
def test_objective_compute_atf(benchmark):
"""Benchmark computing objective."""
jax.clear_caches()
eq = desc.examples.get("ATF")
objective = LinearConstraintProjection(
get_equilibrium_objective(eq),
Expand All @@ -228,14 +213,13 @@ def test_objective_compute_atf(benchmark):
def run(x, objective):
objective.compute_scaled_error(x, objective.constants).block_until_ready()

benchmark.pedantic(run, args=(x, objective), rounds=50, iterations=1)
benchmark.pedantic(run, args=(x, objective), rounds=100, iterations=1)


@pytest.mark.slow
@pytest.mark.benchmark
def test_objective_jac_dshape_current(benchmark):
"""Benchmark computing jacobian."""
jax.clear_caches()
eq = desc.examples.get("DSHAPE_current")
objective = LinearConstraintProjection(
get_equilibrium_objective(eq),
Expand All @@ -247,17 +231,16 @@ def test_objective_jac_dshape_current(benchmark):
objective.compile()
x = objective.x(eq)

def run(x):
objective.jac_scaled(x, objective.constants).block_until_ready()
def run(x, objective):
objective.jac_scaled_error(x, objective.constants).block_until_ready()

benchmark.pedantic(run, args=(x,), rounds=15, iterations=1)
benchmark.pedantic(run, args=(x, objective), rounds=80, iterations=1)


@pytest.mark.slow
@pytest.mark.benchmark
def test_objective_jac_atf(benchmark):
"""Benchmark computing jacobian."""
jax.clear_caches()
eq = desc.examples.get("ATF")
objective = LinearConstraintProjection(
get_equilibrium_objective(eq),
Expand All @@ -269,10 +252,10 @@ def test_objective_jac_atf(benchmark):
objective.compile()
x = objective.x(eq)

def run(x):
objective.jac_scaled(x, objective.constants).block_until_ready()
def run(x, objective):
objective.jac_scaled_error(x, objective.constants).block_until_ready()

benchmark.pedantic(run, args=(x,), rounds=15, iterations=1)
benchmark.pedantic(run, args=(x, objective), rounds=20, iterations=1)


@pytest.mark.slow
Expand Down Expand Up @@ -345,7 +328,6 @@ def setup():
@pytest.mark.benchmark
def test_proximal_jac_atf(benchmark):
"""Benchmark computing jacobian of constrained proximal projection."""
jax.clear_caches()
eq = desc.examples.get("ATF")
grid = LinearGrid(M=eq.M_grid, N=eq.N_grid, NFP=eq.NFP, rho=np.linspace(0.1, 1, 10))
objective = ObjectiveFunction(QuasisymmetryTwoTerm(eq, grid=grid))
Expand All @@ -355,17 +337,16 @@ def test_proximal_jac_atf(benchmark):
prox.compile()
x = prox.x(eq)

def run(x):
prox.jac_scaled(x, prox.constants).block_until_ready()
def run(x, prox):
prox.jac_scaled_error(x, prox.constants).block_until_ready()

benchmark.pedantic(run, args=(x,), rounds=15, iterations=1)
benchmark.pedantic(run, args=(x, prox), rounds=10, iterations=1)


@pytest.mark.slow
@pytest.mark.benchmark
def test_proximal_freeb_compute(benchmark):
"""Benchmark computing free boundary objective with proximal constraint."""
jax.clear_caches()
eq = desc.examples.get("ESTELL")
with pytest.warns(UserWarning, match="Reducing radial"):
eq.change_resolution(6, 6, 6, 12, 12, 12)
Expand All @@ -380,17 +361,16 @@ def test_proximal_freeb_compute(benchmark):
obj.compile()
x = obj.x(eq)

def run(x):
def run(x, obj):
obj.compute_scaled_error(x, obj.constants).block_until_ready()

benchmark.pedantic(run, args=(x,), rounds=50, iterations=1)
benchmark.pedantic(run, args=(x, obj), rounds=50, iterations=1)


@pytest.mark.slow
@pytest.mark.benchmark
def test_proximal_freeb_jac(benchmark):
"""Benchmark computing free boundary jacobian with proximal constraint."""
jax.clear_caches()
eq = desc.examples.get("ESTELL")
with pytest.warns(UserWarning, match="Reducing radial"):
eq.change_resolution(6, 6, 6, 12, 12, 12)
Expand All @@ -405,10 +385,10 @@ def test_proximal_freeb_jac(benchmark):
obj.compile()
x = obj.x(eq)

def run(x):
obj.jac_scaled(x, prox.constants).block_until_ready()
def run(x, obj, prox):
obj.jac_scaled_error(x, prox.constants).block_until_ready()

benchmark.pedantic(run, args=(x,), rounds=15, iterations=1)
benchmark.pedantic(run, args=(x, obj, prox), rounds=10, iterations=1)


@pytest.mark.slow
Expand Down Expand Up @@ -455,26 +435,19 @@ def run(eq):
@pytest.mark.benchmark
def test_LinearConstraintProjection_build(benchmark):
"""Benchmark LinearConstraintProjection build."""
eq = desc.examples.get("W7-X")

def setup():
def run():
jax.clear_caches()
eq = desc.examples.get("W7-X")

obj = ObjectiveFunction(ForceBalance(eq))
con = get_fixed_boundary_constraints(eq)
con = maybe_add_self_consistency(eq, con)
con = ObjectiveFunction(con)
obj.build()
con.build()
return (obj, con), {}

def run(obj, con):
lc = LinearConstraintProjection(obj, con)
lc.build()

benchmark.pedantic(
run,
setup=setup,
rounds=10,
iterations=1,
)

0 comments on commit a4ab21e

Please sign in to comment.