Skip to content

Commit

Permalink
Adding cycle count and benchmarks .
Browse files Browse the repository at this point in the history
  • Loading branch information
alexandermath committed Oct 28, 2023
1 parent cd2d875 commit 2b1b739
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 2 deletions.
12 changes: 10 additions & 2 deletions pyscf_ipu/nanoDFT/nanoDFT.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,8 @@ def get_JK(density_matrix, ERI, dense_ERI, backend):
return diff_JK

def _nanoDFT(state, ERI, grid_AO, grid_weights, opts, mol):
if opts.backend == "ipu": grid_weights, start = utils.get_ipu_cycles(grid_weights)

# Utilize the IPUs MIMD parallism to compute the electron repulsion integrals (ERIs) in parallel.
#if opts.backend == "ipu": state.ERI = electron_repulsion_integrals(state.input_floats, state.input_ints, mol, opts.threads_int, opts.intv)
#else: pass # Compute on CPU.
Expand All @@ -157,7 +159,12 @@ def _nanoDFT(state, ERI, grid_AO, grid_weights, opts, mol):
log = jax.lax.fori_loop(0, opts.its, partial(nanoDFT_iteration, opts=opts, mol=mol), [state.density_matrix, V_xc, diff_JK, state.O, H_core, state.L_inv, # all (N, N) matrices
state.E_nuc, state.mask, ERI, grid_weights, grid_AO, state.diis_history, log])[-1]

return log["matrices"], H_core, log["energy"]
cycles = -1
if opts.backend == "ipu":
log["energy"], stop = utils.get_ipu_cycles(log["energy"])
cycles = (stop.array-start.array)[0,0]

return log["matrices"], H_core, log["energy"], cycles


FloatN = Float[Array, "N"]
Expand Down Expand Up @@ -320,7 +327,8 @@ def nanoDFT(mol, opts):
axis_name="p")
print(grid_AO.shape, grid_weights.shape)
vals = jitted_nanoDFT(state, ERI, grid_AO, grid_weights)
logged_matrices, H_core, logged_energies = [np.asarray(a[0]).astype(np.float64) for a in vals] # Ensure CPU
logged_matrices, H_core, logged_energies, cycles = [np.asarray(a[0]).astype(np.float64) for a in vals] # Ensure CPU
print("Cycle Count: ", cycles/10**6, "[M]")

# It's cheap to compute energy/hlgap on CPU in float64 from the logged values/matrices.
logged_E_xc = logged_energies[:, 3].copy()
Expand Down
Loading

0 comments on commit 2b1b739

Please sign in to comment.