diff --git a/examples/scripts/jaxified-idaklu-benchmarks.py b/examples/scripts/jaxified-idaklu-benchmarks.py index 5c9e7214..50eddd84 100644 --- a/examples/scripts/jaxified-idaklu-benchmarks.py +++ b/examples/scripts/jaxified-idaklu-benchmarks.py @@ -1,22 +1,14 @@ import time -import jax -import jax.numpy as jnp import numpy as np import pybamm import pybop -n = 30 # Number of solves -output_vars = [ - "Voltage [V]", - "Current [A]", - "Time [s]", -] - +n = 1 # Number of solves solvers = [ pybamm.CasadiSolver(mode="fast with events", atol=1e-6, rtol=1e-6), - pybamm.IDAKLUSolver(atol=1e-6, rtol=1e-6, output_variables=output_vars), + pybamm.IDAKLUSolver(atol=1e-6, rtol=1e-6), ] # Parameter set and model definition @@ -96,56 +88,19 @@ def inputs(): out = cost(inputs(), calculate_grad=True) print(f"({solver.name}) Time PyBOP Cost w/grad: {time.time() - start_time:.4f}") -# Jaxified benchmarks -ida_solver = pybamm.IDAKLUSolver(atol=1e-6, rtol=1e-6, output_variables=output_vars) -solver = ida_solver.jaxify(model=model.built_model, t_eval=t_eval) -solver_no_grad = ida_solver.jaxify( - model=model.built_model, t_eval=t_eval, calculate_sensitivities=False -) -f = solver.get_jaxpr() -k = solver_no_grad.get_jaxpr() - -start_time = time.time() -for _i in range(n): - kout = k(t_eval, inputs()) -print(f"Time jax expression w/o grad: {time.time() - start_time:.4f}") - -start_time = time.time() -for _i in range(n): - fout = f(t_eval, inputs()) -print(f"Time jax expression w/ grad: {time.time() - start_time:.4f}") - -start_time = time.time() -for _i in range(n): - y = solver_no_grad.get_var("Voltage [V]")(t_eval, inputs()) -print(f"Time jax solver_no_grad.get_var: {time.time() - start_time:.4f}") - -start_time = time.time() -for _i in range(n): - y = solver.get_var("Voltage [V]")(t_eval, inputs()) -print(f"Time jax solver.get_var: {time.time() - start_time:.4f}") - - -# Sum-of-squared errors -def sse_no_grad(t_eval, inputs): - y = solver_no_grad.get_var("Voltage [V]")(t_eval, inputs) - r = jnp.asarray([y - problem.target[signal] for signal in problem.signal]) - return jnp.sum(jnp.sum(r**2, axis=0), axis=0) - - -# Sum-of-squared errors -def sse_grad(t_eval, inputs): - y = solver.get_var("Voltage [V]")(t_eval, inputs) - r = jnp.asarray([y - problem.target[signal] for signal in problem.signal]) - return jnp.sum(jnp.sum(r**2, axis=0), axis=0) - +# Recreate for Jax IDAKLU solver +ida_solver = pybamm.IDAKLUSolver(atol=1e-6, rtol=1e-6) +model = pybop.lithium_ion.DFN(parameter_set=parameter_set, solver=ida_solver, jax=True) +problem = pybop.FittingProblem(model, parameters, dataset) +cost = pybop.JaxSumSquaredError(problem) +# Jaxified benchmarks start_time = time.time() for _i in range(n): - out = sse_no_grad(t_eval, inputs()) + out = cost(inputs(), calculate_grad=False) print(f"Time Jax SumSquaredError w/o grad: {time.time() - start_time:.4f}") start_time = time.time() for _i in range(n): - out = jax.value_and_grad(sse_grad, argnums=1)(t_eval, inputs()) + out = cost(inputs(), calculate_grad=True) print(f"Time Jax SumSquaredError w/ grad: {time.time() - start_time:.4f}")