Skip to content

Commit

Permalink
examples: update benchmarking script
Browse files Browse the repository at this point in the history
  • Loading branch information
BradyPlanden committed Sep 2, 2024
1 parent d3a18d5 commit eb08c00
Showing 1 changed file with 10 additions and 55 deletions.
65 changes: 10 additions & 55 deletions examples/scripts/jaxified-idaklu-benchmarks.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,14 @@
import time

import jax
import jax.numpy as jnp
import numpy as np
import pybamm

import pybop

n = 30 # Number of solves
output_vars = [
"Voltage [V]",
"Current [A]",
"Time [s]",
]

n = 1 # Number of solves
solvers = [
pybamm.CasadiSolver(mode="fast with events", atol=1e-6, rtol=1e-6),
pybamm.IDAKLUSolver(atol=1e-6, rtol=1e-6, output_variables=output_vars),
pybamm.IDAKLUSolver(atol=1e-6, rtol=1e-6),
]

# Parameter set and model definition
Expand Down Expand Up @@ -96,56 +88,19 @@ def inputs():
out = cost(inputs(), calculate_grad=True)
print(f"({solver.name}) Time PyBOP Cost w/grad: {time.time() - start_time:.4f}")

# Jaxified benchmarks
ida_solver = pybamm.IDAKLUSolver(atol=1e-6, rtol=1e-6, output_variables=output_vars)
solver = ida_solver.jaxify(model=model.built_model, t_eval=t_eval)
solver_no_grad = ida_solver.jaxify(
model=model.built_model, t_eval=t_eval, calculate_sensitivities=False
)
f = solver.get_jaxpr()
k = solver_no_grad.get_jaxpr()

start_time = time.time()
for _i in range(n):
kout = k(t_eval, inputs())
print(f"Time jax expression w/o grad: {time.time() - start_time:.4f}")

start_time = time.time()
for _i in range(n):
fout = f(t_eval, inputs())
print(f"Time jax expression w/ grad: {time.time() - start_time:.4f}")

start_time = time.time()
for _i in range(n):
y = solver_no_grad.get_var("Voltage [V]")(t_eval, inputs())
print(f"Time jax solver_no_grad.get_var: {time.time() - start_time:.4f}")

start_time = time.time()
for _i in range(n):
y = solver.get_var("Voltage [V]")(t_eval, inputs())
print(f"Time jax solver.get_var: {time.time() - start_time:.4f}")


# Sum-of-squared errors
def sse_no_grad(t_eval, inputs):
y = solver_no_grad.get_var("Voltage [V]")(t_eval, inputs)
r = jnp.asarray([y - problem.target[signal] for signal in problem.signal])
return jnp.sum(jnp.sum(r**2, axis=0), axis=0)


# Sum-of-squared errors
def sse_grad(t_eval, inputs):
y = solver.get_var("Voltage [V]")(t_eval, inputs)
r = jnp.asarray([y - problem.target[signal] for signal in problem.signal])
return jnp.sum(jnp.sum(r**2, axis=0), axis=0)

# Recreate for Jax IDAKLU solver
ida_solver = pybamm.IDAKLUSolver(atol=1e-6, rtol=1e-6)
model = pybop.lithium_ion.DFN(parameter_set=parameter_set, solver=ida_solver, jax=True)
problem = pybop.FittingProblem(model, parameters, dataset)
cost = pybop.JaxSumSquaredError(problem)

# Jaxified benchmarks
start_time = time.time()
for _i in range(n):
out = sse_no_grad(t_eval, inputs())
out = cost(inputs(), calculate_grad=False)
print(f"Time Jax SumSquaredError w/o grad: {time.time() - start_time:.4f}")

start_time = time.time()
for _i in range(n):
out = jax.value_and_grad(sse_grad, argnums=1)(t_eval, inputs())
out = cost(inputs(), calculate_grad=True)
print(f"Time Jax SumSquaredError w/ grad: {time.time() - start_time:.4f}")

0 comments on commit eb08c00

Please sign in to comment.