Skip to content

Commit

Permalink
Merge pull request #118 from graphcore-research/visualise_num_error
Browse files Browse the repository at this point in the history
add num_error visualisation on demand
  • Loading branch information
akrzgc authored Oct 10, 2023
2 parents 9a1e9b3 + dbf4e92 commit bd5c7ff
Show file tree
Hide file tree
Showing 2 changed files with 109 additions and 2 deletions.
33 changes: 32 additions & 1 deletion pyscf_ipu/nanoDFT/nanoDFT.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,8 +81,27 @@ def nanoDFT_iteration(i, vals, opts, mol):
log["matrices"] = jax.lax.dynamic_update_slice(log["matrices"], H. reshape(1, 1, N, N), (i, 2, 0, 0))
log["energy"] = log["energy"].at[i].set(energy(density_matrix, H_core, diff_JK, E_xc, E_nuc)) # (iterations, 6)

if opts.vis_num_error is True:
import os
dir_label = opts.molecule_name
num_error_dir = f'num_error/{dir_label}/'
os.makedirs(num_error_dir , exist_ok=True)

def host_callback(data, i):
# labels are adjusted to the `data` that will be passed to the callback - keep that in mind when passing different list of tensors
labels = ["density_matrix", "V_xc", "diff_JK", "O", "H_core", "L_inv", "E_nuc", "occupancy", "ERI", "grid_weights", "grid_AO", "diis_history", "E_xc", "eigvects", "H"]
for l, d in zip(labels, data):
if l == "diis_history" or l == "ERI":
for idx, arr in enumerate(d):
np.savez(f'{num_error_dir}{i}_{l}{idx}.npz', v = np.array(arr))
else:
np.savez(f'{num_error_dir}{i}_{l}.npz', v = d)

jax.debug.callback(host_callback, vals[:-1] + [E_xc, eigvects, H], i)

return [density_matrix, V_xc, diff_JK, O, H_core, L_inv, E_nuc, occupancy, ERI, grid_weights, grid_AO, diis_history, log]


def exchange_correlation(density_matrix, grid_AO, grid_weights):
"""Compute exchange correlation integral using atomic orbitals (AO) evalauted on a grid. """
# Perfectly SIMD parallelizable over grid_size axis.
Expand Down Expand Up @@ -546,7 +565,9 @@ def nanoDFT_options(
ndevices: int = 1,
dense_ERI: bool = False,
v: bool = False, # verbose
profile: bool = False # if we only want profile exit after IPU finishes.
profile: bool = False, # if we only want profile exit after IPU finishes.
vis_num_error: bool = False,
molecule_name: str = None
):
"""
nanoDFT
Expand All @@ -570,6 +591,10 @@ def nanoDFT_options(
ao_threshold (float): Zero out grid_AO that are below the threshold in absolute value.
dense_ERI (bool): Whether to use dense ERI (s1) or sparse symmtric ERI.
"""
if molecule_name is None:
# use mol_str as a molecule name (in case it has not been provided)
# before mol_str CLI arg is preprocessed and overwritten
molecule_name = mol_str

# From a compound name or CID, get a list of its atoms and their coordinates
mol_str = utils.process_mol_str(mol_str)
Expand Down Expand Up @@ -611,6 +636,12 @@ def main():
nanoDFT_forces = grad(mol, grid_coords, grid_weights, mo_coeff, mo_energy)
pyscf_E, pyscf_hlgap, pyscf_forces = pyscf_reference(mol_str, opts)
print_difference(nanoDFT_E, nanoDFT_forces, nanoDFT_logged_E, nanoDFT_hlgap, pyscf_E, pyscf_forces, pyscf_hlgap)

if opts.vis_num_error is True:
from utils import save_plot
import sys
_plot_title = f"Created with: python {' '.join(sys.argv)}"
save_plot("num_error/", opts.molecule_name, opts.its, _plot_title)
else:
# pip install mogli imageio[ffmpeg] matplotlib
import mogli
Expand Down
78 changes: 77 additions & 1 deletion pyscf_ipu/nanoDFT/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,4 +147,80 @@ def min_interatomic_distance(mol_str):
"""This computes the minimum distance between atoms."""
coords = map(itemgetter(1), mol_str)
distances = map(lambda x: np.linalg.norm(np.array(x[0]) - np.array(x[1])), combinations(coords, 2))
return min(distances)
return min(distances)


def save_plot(base_data_dir: str, molecule_name: str, iterations: int, _plot_title: str = "Default Title"):
import matplotlib.pyplot as plt
import matplotlib
import os
matplotlib.rc('font', **{'family': 'serif', 'serif': ['Computer Modern']})

import seaborn as sns
sns.set_theme()
sns.set_style("white")

data_dir = base_data_dir + molecule_name + '/'

def prepare(val):
val = np.abs(val[val == val])
val[np.logical_and(val<1e-15, val!=0)] = 2e-15 # show the ones that go out of plot
val[val==0] = 1e-17 # remove zeros.
return val

xticks = []
xticklabels = []

fig, ax = plt.subplots(1, 1, figsize=(14,8))
images_subdir = f'{base_data_dir}/tmp_images/num_error/'
os.makedirs(images_subdir, exist_ok=True)

for outer_num, i in enumerate(range(iterations)):
skip = 0
print(f'figure [{i+1} / {iterations}]\r', end="")
plt.cla()
plt.title("[Iterations %i] \n"%(i+1) + _plot_title)
files = sorted([a for a in os.listdir(data_dir) if "[" not in a and int(a.split("_")[0]) == i and ".jpg" not in a and ".gif" not in a])

for num, file in enumerate(files):
val= np.load(data_dir+file, allow_pickle=True)["v"]
shape = val.shape
if np.prod(shape) <= 1:
skip += 1
continue

val = prepare(val)
val = np.sort(val)
num_max_dots = 500

if val.size > num_max_dots: val= val[::int(val.size)//num_max_dots]

ys = -np.ones(val.shape[0])*(num - skip)
ax.plot([1e-15, 1e18], [ys[0], ys[1]], 'C%i-'%(num%10), lw=10, alpha=0.2)
ax.plot(val, ys, 'C%io'%(num%10), ms=6, alpha=0.2)

if i == 0:
xticks.append(ys[0])
xticklabels.append(file.replace(".npz", "").replace("%i_"%i, ""))

plt.plot( [10**(-10), 10**(-10)], [0, xticks[-1]], 'C7--', alpha=0.6)
plt.plot( [10**(10), 10**10], [0, xticks[-1]], 'C7--', alpha=0.6)
plt.plot( [10**(0), 10**0], [0, xticks[-1]], 'C7-', alpha=1)

for x, label in zip(xticks, xticklabels):
ax.text(1e10, x+0.25, label, horizontalalignment='left', size='small', color='black', weight='normal')

plt.yticks([], [])
plt.xscale("log")
plt.xlim([10**(-15), 10**18])
if i == 0: plt.tight_layout()

plt.savefig(f'{images_subdir}num_error{outer_num}.jpg')

import imageio
gif_path = f'{base_data_dir}visualize_DFT_num_error_{molecule_name}.gif'
writer = imageio.get_writer(gif_path, loop=0, duration=7)
for i in range(iterations):
writer.append_data(imageio.v2.imread(f'{images_subdir}num_error{i}.jpg'))
writer.close()
print("Numerical error visualisation saved in", gif_path)

0 comments on commit bd5c7ff

Please sign in to comment.