diff --git a/pyscf_ipu/nanoDFT/nanoDFT.py b/pyscf_ipu/nanoDFT/nanoDFT.py index c5570e6..4176afc 100644 --- a/pyscf_ipu/nanoDFT/nanoDFT.py +++ b/pyscf_ipu/nanoDFT/nanoDFT.py @@ -87,21 +87,21 @@ def exchange_correlation(density_matrix, grid_AO, grid_weights): """Compute exchange correlation integral using atomic orbitals (AO) evalauted on a grid. """ # Perfectly SIMD parallelizable over grid_size axis. # Only need one reduce_sum in the end. - grid_AO_dm = grid_AO[0] @ density_matrix # (gsize, N) @ (N, N) -> (gsize, N) - grid_AO_dm = jnp.expand_dims(grid_AO_dm, axis=0) # (1, gsize, N) + grid_AO_dm = grid_AO[0] @ density_matrix # (gsize, N) @ (N, N) -> (gsize, N) + grid_AO_dm = jnp.expand_dims(grid_AO_dm, axis=0) # (1, gsize, N) mult = grid_AO_dm * grid_AO - rho = jnp.sum(mult, axis=2) # (4, grid_size)=(4, 45624) for C6H6. - E_xc, vrho, vgamma = b3lyp(rho, EPSILON_B3LYP) # (gridsize,) (gridsize,) (gridsize,) - E_xc = jax.lax.psum(jnp.sum(rho[0] * grid_weights * E_xc), axis_name="p") # float=-27.968[Ha] for C6H6 at convergence. - rho = jnp.concatenate([vrho.reshape(1, -1)/2, 4*vgamma*rho[1:4]], axis=0) * grid_weights # (4, grid_size)=(4, 45624) - grid_AO_T = grid_AO[0].T # (N, gsize) - rho = jnp.expand_dims(rho, axis=2) # (4, gsize, 1) - grid_AO_rho = grid_AO * rho # (4, gsize, N) - sum_grid_AO_rho = jnp.sum(grid_AO_rho, axis=0) # (gsize, N) - V_xc = grid_AO_T @ sum_grid_AO_rho # (N, N) - V_xc = jax.lax.psum(V_xc, axis_name="p") #(N, N) - V_xc = V_xc + V_xc.T # (N, N) - return E_xc, V_xc # (float) (N, N) + rho = jnp.sum(mult, axis=2) # (4, grid_size)=(4, 45624) for C6H6. + E_xc, vrho, vgamma = b3lyp(rho, EPSILON_B3LYP) # (gridsize,) (gridsize,) (gridsize,) + E_xc = jax.lax.psum(jnp.sum(rho[0] * grid_weights * E_xc), axis_name="p") # float=-27.968[Ha] for C6H6 at convergence. + rho = jnp.concatenate([vrho.reshape(1, -1)/2, 4*vgamma*rho[1:4]], axis=0) * grid_weights # (4, grid_size)=(4, 45624) + grid_AO_T = grid_AO[0].T # (N, gsize) + rho = jnp.expand_dims(rho, axis=2) # (4, gsize, 1) + grid_AO_rho = grid_AO * rho # (4, gsize, N) + sum_grid_AO_rho = jnp.sum(grid_AO_rho, axis=0) # (gsize, N) + V_xc = grid_AO_T @ sum_grid_AO_rho # (N, N) + V_xc = jax.lax.psum(V_xc, axis_name="p") # (N, N) + V_xc = V_xc + V_xc.T # (N, N) + return E_xc, V_xc # (float) (N, N) def get_JK(density_matrix, ERI, dense_ERI, backend): """Computes the (N, N) matrices J and K. Density matrix is (N, N) and ERI is (N, N, N, N). """ @@ -201,15 +201,14 @@ def init_dft_tensors_cpu(mol, opts, DIIS_iters=9): grid_AO = mol.eval_gto(coord_str, grids.coords, 4) # (4, grid_size, N) = (4, 45624, 9) for C6H6. if opts.ao_threshold > 0.0: grid_AO[np.abs(grid_AO)