Skip to content

Commit

Permalink
Merge pull request #114 from graphcore-research/bb/low-hanging-sparse…
Browse files Browse the repository at this point in the history
…-grid-ao

sparse grid ao along (0,2) axes
  • Loading branch information
blazejba authored Oct 10, 2023
2 parents 8b0ee07 + cd2b2d5 commit 39527d5
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 75 deletions.
57 changes: 1 addition & 56 deletions notebooks/sparse_grid_ao.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@
" \n",
" for line in output.split(\"\\n\"):\n",
" if \"Number of electrons\" in line: num_electrons = line.split(\" \")[-1].replace(' ', '')\n",
" elif \"sparsity_grid_AO\" in line: sparsity_level = line.split(\"=\")[-1]\n",
" elif \"axis=( , ) sparsity in grid_AO:\" in line: sparsity_level = line.split(\"=\")[-1]\n",
" elif \"Error Energy\" in line: error_energy = dict(zip(iterations, line.split()[-7:])) \n",
" elif \"Error HLGAP\" in line: error_hlgap = dict(zip(iterations, line.split()[-7:]))\n",
" \n",
Expand All @@ -182,47 +182,6 @@
"compute_results(results)"
]
},
{
"cell_type": "code",
"execution_count": 96,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"60\n",
"54\n"
]
}
],
"source": [
"to_del = []\n",
"for pair, value in results.items():\n",
" (mol, thresh) = pair\n",
" if float(thresh) == 0.1:\n",
" to_del.append(pair)\n",
"to_del\n",
"\n",
"print(len(results))\n",
"for delli in to_del:\n",
" results.pop(delli)\n",
"print(len(results))"
]
},
{
"cell_type": "code",
"execution_count": 97,
"metadata": {},
"outputs": [],
"source": [
"# Convert tuple keys to string before saving\n",
"str_results = {f\"{key[0]}:{key[1]}\": value for key, value in results.items()}\n",
"# Cache the results after each computation\n",
"with open(results_file, 'w') as f:\n",
" json.dump(str_results, f, indent=4)"
]
},
{
"cell_type": "markdown",
"metadata": {},
Expand Down Expand Up @@ -371,20 +330,6 @@
"# Call the function to plot\n",
"plot_sparsity_3d_surface(results)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
Expand Down
46 changes: 27 additions & 19 deletions pyscf_ipu/nanoDFT/nanoDFT.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,21 +87,21 @@ def exchange_correlation(density_matrix, grid_AO, grid_weights):
"""Compute exchange correlation integral using atomic orbitals (AO) evalauted on a grid. """
# Perfectly SIMD parallelizable over grid_size axis.
# Only need one reduce_sum in the end.
grid_AO_dm = grid_AO[0] @ density_matrix # (gsize, N) @ (N, N) -> (gsize, N)
grid_AO_dm = jnp.expand_dims(grid_AO_dm, axis=0) # (1, gsize, N)
grid_AO_dm = grid_AO[0] @ density_matrix # (gsize, N) @ (N, N) -> (gsize, N)
grid_AO_dm = jnp.expand_dims(grid_AO_dm, axis=0) # (1, gsize, N)
mult = grid_AO_dm * grid_AO
rho = jnp.sum(mult, axis=2) # (4, grid_size)=(4, 45624) for C6H6.
E_xc, vrho, vgamma = b3lyp(rho, EPSILON_B3LYP) # (gridsize,) (gridsize,) (gridsize,)
E_xc = jax.lax.psum(jnp.sum(rho[0] * grid_weights * E_xc), axis_name="p") # float=-27.968[Ha] for C6H6 at convergence.
rho = jnp.concatenate([vrho.reshape(1, -1)/2, 4*vgamma*rho[1:4]], axis=0) * grid_weights # (4, grid_size)=(4, 45624)
grid_AO_T = grid_AO[0].T # (N, gsize)
rho = jnp.expand_dims(rho, axis=2) # (4, gsize, 1)
grid_AO_rho = grid_AO * rho # (4, gsize, N)
sum_grid_AO_rho = jnp.sum(grid_AO_rho, axis=0) # (gsize, N)
V_xc = grid_AO_T @ sum_grid_AO_rho # (N, N)
V_xc = jax.lax.psum(V_xc, axis_name="p") #(N, N)
V_xc = V_xc + V_xc.T # (N, N)
return E_xc, V_xc # (float) (N, N)
rho = jnp.sum(mult, axis=2) # (4, grid_size)=(4, 45624) for C6H6.
E_xc, vrho, vgamma = b3lyp(rho, EPSILON_B3LYP) # (gridsize,) (gridsize,) (gridsize,)
E_xc = jax.lax.psum(jnp.sum(rho[0] * grid_weights * E_xc), axis_name="p") # float=-27.968[Ha] for C6H6 at convergence.
rho = jnp.concatenate([vrho.reshape(1, -1)/2, 4*vgamma*rho[1:4]], axis=0) * grid_weights # (4, grid_size)=(4, 45624)
grid_AO_T = grid_AO[0].T # (N, gsize)
rho = jnp.expand_dims(rho, axis=2) # (4, gsize, 1)
grid_AO_rho = grid_AO * rho # (4, gsize, N)
sum_grid_AO_rho = jnp.sum(grid_AO_rho, axis=0) # (gsize, N)
V_xc = grid_AO_T @ sum_grid_AO_rho # (N, N)
V_xc = jax.lax.psum(V_xc, axis_name="p") # (N, N)
V_xc = V_xc + V_xc.T # (N, N)
return E_xc, V_xc # (float) (N, N)

def get_JK(density_matrix, ERI, dense_ERI, backend):
"""Computes the (N, N) matrices J and K. Density matrix is (N, N) and ERI is (N, N, N, N). """
Expand Down Expand Up @@ -199,10 +199,18 @@ def init_dft_tensors_cpu(mol, opts, DIIS_iters=9):
grid_weights = grids.weights # (grid_size,) = (45624,) for C6H6
coord_str = 'GTOval_cart_deriv1' if mol.cart else 'GTOval_sph_deriv1'
grid_AO = mol.eval_gto(coord_str, grids.coords, 4) # (4, grid_size, N) = (4, 45624, 9) for C6H6.
if opts.ao_threshold is not None:
if opts.ao_threshold > 0.0:
grid_AO[np.abs(grid_AO)<opts.ao_threshold] = 0
sparsity_grid_AO = np.sum(grid_AO==0) / grid_AO.size
print(f"{sparsity_grid_AO=:.4f}")
sparsity_mask = np.where(np.all(grid_AO == 0, axis=0), 0, 1)
sparse_rows = np.where(np.all(sparsity_mask == 0, axis=1), 0, 1).reshape(-1, 1)
print(f"axis=( , ) sparsity in grid_AO: {np.sum(grid_AO==0) / grid_AO.size:.4f}")
print(f"axis=(0, ) sparsity in grid_AO: {np.sum(sparsity_mask==0) / sparsity_mask.size:.4f}")
print(f"axis=(0, 2) sparsity in grid_AO: {np.sum(sparse_rows==0) / sparse_rows.size:.4f}")
grid_AO = jnp.delete(grid_AO, jnp.where(sparse_rows == 0)[0], axis=1)
grid_weights = jnp.delete(grid_weights, jnp.where(sparse_rows == 0)[0], axis=0)
grid_coords = jnp.delete(grids.coords, jnp.where(sparse_rows == 0)[0], axis=0)
else:
grid_coords = grids.coords
density_matrix = pyscf.scf.hf.init_guess_by_minao(mol) # (N,N)=(66,66) for C6H6.

# TODO(): Add integral math formulas for kinetic/nuclear/O/ERI.
Expand All @@ -227,7 +235,7 @@ def init_dft_tensors_cpu(mol, opts, DIIS_iters=9):
L_inv=L_inv, diis_history=diis_history)


return state, n_electrons_half, E_nuc, N, L_inv, grid_weights, grids.coords, grid_AO
return state, n_electrons_half, E_nuc, N, L_inv, grid_weights, grid_coords, grid_AO

def nanoDFT(mol, opts):
# Init DFT tensors on CPU using PySCF.
Expand Down Expand Up @@ -533,7 +541,7 @@ def nanoDFT_options(
diis: bool = True,
structure_optimization: bool = False, # AKA gradient descent on energy wrt nuclei
eri_threshold : float = 0.0,
ao_threshold: float = None,
ao_threshold: float = 0.0,
batches: int = 32,
ndevices: int = 1,
dense_ERI: bool = False,
Expand Down

0 comments on commit 39527d5

Please sign in to comment.