diff --git a/pyscf_ipu/nanoDFT/nanoDFT.py b/pyscf_ipu/nanoDFT/nanoDFT.py index 01cc10e..6660efe 100644 --- a/pyscf_ipu/nanoDFT/nanoDFT.py +++ b/pyscf_ipu/nanoDFT/nanoDFT.py @@ -328,7 +328,7 @@ def nanoDFT(mol, opts): print(grid_AO.shape, grid_weights.shape) vals = jitted_nanoDFT(state, ERI, grid_AO, grid_weights) logged_matrices, H_core, logged_energies, cycles = [np.asarray(a[0]).astype(np.float64) for a in vals] # Ensure CPU - print("Cycle Count: ", cycles/10**6, "[M]") + if opts.backend == "ipu": print("Cycle Count: ", cycles/10**6, "[M]") # It's cheap to compute energy/hlgap on CPU in float64 from the logged values/matrices. logged_E_xc = logged_energies[:, 3].copy()