Skip to content

Commit

Permalink
print sparsity llevels before deleting elements
Browse files Browse the repository at this point in the history
  • Loading branch information
Blazej Banaszewski committed Oct 10, 2023
1 parent 5d1e67e commit cd2b2d5
Showing 1 changed file with 17 additions and 18 deletions.
35 changes: 17 additions & 18 deletions pyscf_ipu/nanoDFT/nanoDFT.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,21 +87,21 @@ def exchange_correlation(density_matrix, grid_AO, grid_weights):
"""Compute exchange correlation integral using atomic orbitals (AO) evalauted on a grid. """
# Perfectly SIMD parallelizable over grid_size axis.
# Only need one reduce_sum in the end.
grid_AO_dm = grid_AO[0] @ density_matrix # (gsize, N) @ (N, N) -> (gsize, N)
grid_AO_dm = jnp.expand_dims(grid_AO_dm, axis=0) # (1, gsize, N)
grid_AO_dm = grid_AO[0] @ density_matrix # (gsize, N) @ (N, N) -> (gsize, N)
grid_AO_dm = jnp.expand_dims(grid_AO_dm, axis=0) # (1, gsize, N)
mult = grid_AO_dm * grid_AO
rho = jnp.sum(mult, axis=2) # (4, grid_size)=(4, 45624) for C6H6.
E_xc, vrho, vgamma = b3lyp(rho, EPSILON_B3LYP) # (gridsize,) (gridsize,) (gridsize,)
E_xc = jax.lax.psum(jnp.sum(rho[0] * grid_weights * E_xc), axis_name="p") # float=-27.968[Ha] for C6H6 at convergence.
rho = jnp.concatenate([vrho.reshape(1, -1)/2, 4*vgamma*rho[1:4]], axis=0) * grid_weights # (4, grid_size)=(4, 45624)
grid_AO_T = grid_AO[0].T # (N, gsize)
rho = jnp.expand_dims(rho, axis=2) # (4, gsize, 1)
grid_AO_rho = grid_AO * rho # (4, gsize, N)
sum_grid_AO_rho = jnp.sum(grid_AO_rho, axis=0) # (gsize, N)
V_xc = grid_AO_T @ sum_grid_AO_rho # (N, N)
V_xc = jax.lax.psum(V_xc, axis_name="p") #(N, N)
V_xc = V_xc + V_xc.T # (N, N)
return E_xc, V_xc # (float) (N, N)
rho = jnp.sum(mult, axis=2) # (4, grid_size)=(4, 45624) for C6H6.
E_xc, vrho, vgamma = b3lyp(rho, EPSILON_B3LYP) # (gridsize,) (gridsize,) (gridsize,)
E_xc = jax.lax.psum(jnp.sum(rho[0] * grid_weights * E_xc), axis_name="p") # float=-27.968[Ha] for C6H6 at convergence.
rho = jnp.concatenate([vrho.reshape(1, -1)/2, 4*vgamma*rho[1:4]], axis=0) * grid_weights # (4, grid_size)=(4, 45624)
grid_AO_T = grid_AO[0].T # (N, gsize)
rho = jnp.expand_dims(rho, axis=2) # (4, gsize, 1)
grid_AO_rho = grid_AO * rho # (4, gsize, N)
sum_grid_AO_rho = jnp.sum(grid_AO_rho, axis=0) # (gsize, N)
V_xc = grid_AO_T @ sum_grid_AO_rho # (N, N)
V_xc = jax.lax.psum(V_xc, axis_name="p") # (N, N)
V_xc = V_xc + V_xc.T # (N, N)
return E_xc, V_xc # (float) (N, N)

def get_JK(density_matrix, ERI, dense_ERI, backend):
"""Computes the (N, N) matrices J and K. Density matrix is (N, N) and ERI is (N, N, N, N). """
Expand Down Expand Up @@ -201,15 +201,14 @@ def init_dft_tensors_cpu(mol, opts, DIIS_iters=9):
grid_AO = mol.eval_gto(coord_str, grids.coords, 4) # (4, grid_size, N) = (4, 45624, 9) for C6H6.
if opts.ao_threshold > 0.0:
grid_AO[np.abs(grid_AO)<opts.ao_threshold] = 0
# trim grid_size (4, grid_size, N) by removing slices with all zeros along (0, 2) axes
sparsity_mask = np.where(np.all(grid_AO == 0, axis=0), 0, 1)
sparse_rows = np.where(np.all(sparsity_mask == 0, axis=1), 0, 1).reshape(-1, 1)
grid_AO = jnp.delete(grid_AO, jnp.where(sparse_rows == 0)[0], axis=1)
grid_weights = jnp.delete(grid_weights, jnp.where(sparse_rows == 0)[0], axis=0)
grid_coords = jnp.delete(grids.coords, jnp.where(sparse_rows == 0)[0], axis=0)
print(f"axis=( , ) sparsity in grid_AO: {np.sum(grid_AO==0) / grid_AO.size:.4f}")
print(f"axis=(0, ) sparsity in grid_AO: {np.sum(sparsity_mask==0) / sparsity_mask.size:.4f}")
print(f"axis=(0, 2) sparsity in grid_AO: {np.sum(sparse_rows==0) / sparse_rows.size:.4f}")
grid_AO = jnp.delete(grid_AO, jnp.where(sparse_rows == 0)[0], axis=1)
grid_weights = jnp.delete(grid_weights, jnp.where(sparse_rows == 0)[0], axis=0)
grid_coords = jnp.delete(grids.coords, jnp.where(sparse_rows == 0)[0], axis=0)
else:
grid_coords = grids.coords
density_matrix = pyscf.scf.hf.init_guess_by_minao(mol) # (N,N)=(66,66) for C6H6.
Expand Down

0 comments on commit cd2b2d5

Please sign in to comment.