Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

sparse grid ao along (0,2) axes #114

Merged
merged 2 commits into from
Oct 10, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 1 addition & 56 deletions notebooks/sparse_grid_ao.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@
" \n",
" for line in output.split(\"\\n\"):\n",
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Did we check if the unit [eV] was correct (or if it should be [meV])? I think it'd be neat to have a cut-off around 42 meV (chemical accuracy). If the units are [eV], then we're probably sparsifying a bit too much (a 150 eV error is >1000x worse than the error of ML models).

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, the unit was correct. The high error we saw was for 32 carbon atoms placed in a line - the error for placing atoms in a line is high regardless of the grid sparsification (c_32 on the graph).

image

" if \"Number of electrons\" in line: num_electrons = line.split(\" \")[-1].replace(' ', '')\n",
" elif \"sparsity_grid_AO\" in line: sparsity_level = line.split(\"=\")[-1]\n",
" elif \"axis=( , ) sparsity in grid_AO:\" in line: sparsity_level = line.split(\"=\")[-1]\n",
" elif \"Error Energy\" in line: error_energy = dict(zip(iterations, line.split()[-7:])) \n",
" elif \"Error HLGAP\" in line: error_hlgap = dict(zip(iterations, line.split()[-7:]))\n",
" \n",
Expand All @@ -182,47 +182,6 @@
"compute_results(results)"
]
},
{
"cell_type": "code",
"execution_count": 96,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"60\n",
"54\n"
]
}
],
"source": [
"to_del = []\n",
"for pair, value in results.items():\n",
" (mol, thresh) = pair\n",
" if float(thresh) == 0.1:\n",
" to_del.append(pair)\n",
"to_del\n",
"\n",
"print(len(results))\n",
"for delli in to_del:\n",
" results.pop(delli)\n",
"print(len(results))"
]
},
{
"cell_type": "code",
"execution_count": 97,
"metadata": {},
"outputs": [],
"source": [
"# Convert tuple keys to string before saving\n",
"str_results = {f\"{key[0]}:{key[1]}\": value for key, value in results.items()}\n",
"# Cache the results after each computation\n",
"with open(results_file, 'w') as f:\n",
" json.dump(str_results, f, indent=4)"
]
},
{
"cell_type": "markdown",
"metadata": {},
Expand Down Expand Up @@ -371,20 +330,6 @@
"# Call the function to plot\n",
"plot_sparsity_3d_surface(results)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
Expand Down
46 changes: 27 additions & 19 deletions pyscf_ipu/nanoDFT/nanoDFT.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,21 +87,21 @@ def exchange_correlation(density_matrix, grid_AO, grid_weights):
"""Compute exchange correlation integral using atomic orbitals (AO) evaluated on a grid. """
# Perfectly SIMD parallelizable over grid_size axis.
# Only need one reduce_sum in the end.
grid_AO_dm = grid_AO[0] @ density_matrix # (gsize, N) @ (N, N) -> (gsize, N)
grid_AO_dm = jnp.expand_dims(grid_AO_dm, axis=0) # (1, gsize, N)
grid_AO_dm = grid_AO[0] @ density_matrix # (gsize, N) @ (N, N) -> (gsize, N)
grid_AO_dm = jnp.expand_dims(grid_AO_dm, axis=0) # (1, gsize, N)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Amazing we can use the same code for this!

mult = grid_AO_dm * grid_AO
rho = jnp.sum(mult, axis=2) # (4, grid_size)=(4, 45624) for C6H6.
E_xc, vrho, vgamma = b3lyp(rho, EPSILON_B3LYP) # (gridsize,) (gridsize,) (gridsize,)
E_xc = jax.lax.psum(jnp.sum(rho[0] * grid_weights * E_xc), axis_name="p") # float=-27.968[Ha] for C6H6 at convergence.
rho = jnp.concatenate([vrho.reshape(1, -1)/2, 4*vgamma*rho[1:4]], axis=0) * grid_weights # (4, grid_size)=(4, 45624)
grid_AO_T = grid_AO[0].T # (N, gsize)
rho = jnp.expand_dims(rho, axis=2) # (4, gsize, 1)
grid_AO_rho = grid_AO * rho # (4, gsize, N)
sum_grid_AO_rho = jnp.sum(grid_AO_rho, axis=0) # (gsize, N)
V_xc = grid_AO_T @ sum_grid_AO_rho # (N, N)
V_xc = jax.lax.psum(V_xc, axis_name="p") #(N, N)
V_xc = V_xc + V_xc.T # (N, N)
return E_xc, V_xc # (float) (N, N)
rho = jnp.sum(mult, axis=2) # (4, grid_size)=(4, 45624) for C6H6.
E_xc, vrho, vgamma = b3lyp(rho, EPSILON_B3LYP) # (gridsize,) (gridsize,) (gridsize,)
E_xc = jax.lax.psum(jnp.sum(rho[0] * grid_weights * E_xc), axis_name="p") # float=-27.968[Ha] for C6H6 at convergence.
rho = jnp.concatenate([vrho.reshape(1, -1)/2, 4*vgamma*rho[1:4]], axis=0) * grid_weights # (4, grid_size)=(4, 45624)
grid_AO_T = grid_AO[0].T # (N, gsize)
rho = jnp.expand_dims(rho, axis=2) # (4, gsize, 1)
grid_AO_rho = grid_AO * rho # (4, gsize, N)
sum_grid_AO_rho = jnp.sum(grid_AO_rho, axis=0) # (gsize, N)
V_xc = grid_AO_T @ sum_grid_AO_rho # (N, N)
V_xc = jax.lax.psum(V_xc, axis_name="p") # (N, N)
V_xc = V_xc + V_xc.T # (N, N)
return E_xc, V_xc # (float) (N, N)

def get_JK(density_matrix, ERI, dense_ERI, backend):
"""Computes the (N, N) matrices J and K. Density matrix is (N, N) and ERI is (N, N, N, N). """
Expand Down Expand Up @@ -199,10 +199,18 @@ def init_dft_tensors_cpu(mol, opts, DIIS_iters=9):
grid_weights = grids.weights # (grid_size,) = (45624,) for C6H6
coord_str = 'GTOval_cart_deriv1' if mol.cart else 'GTOval_sph_deriv1'
grid_AO = mol.eval_gto(coord_str, grids.coords, 4) # (4, grid_size, N) = (4, 45624, 9) for C6H6.
if opts.ao_threshold is not None:
if opts.ao_threshold > 0.0:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good!

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Really like the simplicity: we just remove the all-zero grid points during pre-processing and only change grid_size, so the exchange_correlation code remains the same!

grid_AO[np.abs(grid_AO)<opts.ao_threshold] = 0
sparsity_grid_AO = np.sum(grid_AO==0) / grid_AO.size
print(f"{sparsity_grid_AO=:.4f}")
sparsity_mask = np.where(np.all(grid_AO == 0, axis=0), 0, 1)
sparse_rows = np.where(np.all(sparsity_mask == 0, axis=1), 0, 1).reshape(-1, 1)
print(f"axis=( , ) sparsity in grid_AO: {np.sum(grid_AO==0) / grid_AO.size:.4f}")
print(f"axis=(0, ) sparsity in grid_AO: {np.sum(sparsity_mask==0) / sparsity_mask.size:.4f}")
print(f"axis=(0, 2) sparsity in grid_AO: {np.sum(sparse_rows==0) / sparse_rows.size:.4f}")
grid_AO = jnp.delete(grid_AO, jnp.where(sparse_rows == 0)[0], axis=1)
grid_weights = jnp.delete(grid_weights, jnp.where(sparse_rows == 0)[0], axis=0)
grid_coords = jnp.delete(grids.coords, jnp.where(sparse_rows == 0)[0], axis=0)
else:
grid_coords = grids.coords
density_matrix = pyscf.scf.hf.init_guess_by_minao(mol) # (N,N)=(66,66) for C6H6.

# TODO(): Add integral math formulas for kinetic/nuclear/O/ERI.
Expand All @@ -227,7 +235,7 @@ def init_dft_tensors_cpu(mol, opts, DIIS_iters=9):
L_inv=L_inv, diis_history=diis_history)


return state, n_electrons_half, E_nuc, N, L_inv, grid_weights, grids.coords, grid_AO
return state, n_electrons_half, E_nuc, N, L_inv, grid_weights, grid_coords, grid_AO

def nanoDFT(mol, opts):
# Init DFT tensors on CPU using PySCF.
Expand Down Expand Up @@ -533,7 +541,7 @@ def nanoDFT_options(
diis: bool = True,
structure_optimization: bool = False, # AKA gradient descent on energy wrt nuclei
eri_threshold : float = 0.0,
ao_threshold: float = None,
ao_threshold: float = 0.0,
batches: int = 32,
ndevices: int = 1,
dense_ERI: bool = False,
Expand Down