Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add num_error visualisation on demand #118

Merged
merged 1 commit into from
Oct 10, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 32 additions & 1 deletion pyscf_ipu/nanoDFT/nanoDFT.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,8 +81,27 @@ def nanoDFT_iteration(i, vals, opts, mol):
log["matrices"] = jax.lax.dynamic_update_slice(log["matrices"], H. reshape(1, 1, N, N), (i, 2, 0, 0))
log["energy"] = log["energy"].at[i].set(energy(density_matrix, H_core, diff_JK, E_xc, E_nuc)) # (iterations, 6)

if opts.vis_num_error is True:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Neat. I debated with myself whether it'd be nice to have it in separate function. I prefer current form: I like that current form allow quick change of callback which is neat.

import os
dir_label = opts.molecule_name
num_error_dir = f'num_error/{dir_label}/'
os.makedirs(num_error_dir , exist_ok=True)

def host_callback(data, i):
# labels are adjusted to the `data` that will be passed to the callback - keep that in mind when passing different list of tensors
labels = ["density_matrix", "V_xc", "diff_JK", "O", "H_core", "L_inv", "E_nuc", "occupancy", "ERI", "grid_weights", "grid_AO", "diis_history", "E_xc", "eigvects", "H"]
for l, d in zip(labels, data):
if l == "diis_history" or l == "ERI":
for idx, arr in enumerate(d):
np.savez(f'{num_error_dir}{i}_{l}{idx}.npz', v = np.array(arr))
else:
np.savez(f'{num_error_dir}{i}_{l}.npz', v = d)

jax.debug.callback(host_callback, vals[:-1] + [E_xc, eigvects, H], i)

return [density_matrix, V_xc, diff_JK, O, H_core, L_inv, E_nuc, occupancy, ERI, grid_weights, grid_AO, diis_history, log]


def exchange_correlation(density_matrix, grid_AO, grid_weights):
"""Compute exchange correlation integral using atomic orbitals (AO) evalauted on a grid. """
# Perfectly SIMD parallelizable over grid_size axis.
Expand Down Expand Up @@ -538,7 +557,9 @@ def nanoDFT_options(
ndevices: int = 1,
dense_ERI: bool = False,
v: bool = False, # verbose
profile: bool = False # if we only want profile exit after IPU finishes.
profile: bool = False, # if we only want profile exit after IPU finishes.
vis_num_error: bool = False,
molecule_name: str = None
):
"""
nanoDFT
Expand All @@ -562,6 +583,10 @@ def nanoDFT_options(
ao_threshold (float): Zero out grid_AO that are below the threshold in absolute value.
dense_ERI (bool): Whether to use dense ERI (s1) or sparse symmtric ERI.
"""
if molecule_name is None:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok

# use mol_str as a molecule name (in case it has not been provided)
# before mol_str CLI arg is preprocessed and overwritten
molecule_name = mol_str

# From a compound name or CID, get a list of its atoms and their coordinates
mol_str = utils.process_mol_str(mol_str)
Expand Down Expand Up @@ -603,6 +628,12 @@ def main():
nanoDFT_forces = grad(mol, grid_coords, grid_weights, mo_coeff, mo_energy)
pyscf_E, pyscf_hlgap, pyscf_forces = pyscf_reference(mol_str, opts)
print_difference(nanoDFT_E, nanoDFT_forces, nanoDFT_logged_E, nanoDFT_hlgap, pyscf_E, pyscf_forces, pyscf_hlgap)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok

if opts.vis_num_error is True:
from utils import save_plot
import sys
_plot_title = f"Created with: python {' '.join(sys.argv)}"
save_plot("num_error/", opts.molecule_name, opts.its, _plot_title)
else:
# pip install mogli imageio[ffmpeg] matplotlib
import mogli
Expand Down
78 changes: 77 additions & 1 deletion pyscf_ipu/nanoDFT/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,4 +142,80 @@ def min_interatomic_distance(mol_str):
"""This computes the minimum distance between atoms."""
coords = map(itemgetter(1), mol_str)
distances = map(lambda x: np.linalg.norm(np.array(x[0]) - np.array(x[1])), combinations(coords, 2))
return min(distances)
return min(distances)


def save_plot(base_data_dir: str, molecule_name: str, iterations: int, _plot_title: str = "Default Title"):
import matplotlib.pyplot as plt
import matplotlib
import os
matplotlib.rc('font', **{'family': 'serif', 'serif': ['Computer Modern']})

import seaborn as sns
sns.set_theme()
sns.set_style("white")

data_dir = base_data_dir + molecule_name + '/'

def prepare(val):
val = np.abs(val[val == val])
val[np.logical_and(val<1e-15, val!=0)] = 2e-15 # show the ones that go out of plot
val[val==0] = 1e-17 # remove zeros.
return val

xticks = []
xticklabels = []

fig, ax = plt.subplots(1, 1, figsize=(14,8))
images_subdir = f'{base_data_dir}/tmp_images/num_error/'
os.makedirs(images_subdir, exist_ok=True)

for outer_num, i in enumerate(range(iterations)):
skip = 0
print(f'figure [{i+1} / {iterations}]\r', end="")
plt.cla()
plt.title("[Iterations %i] \n"%(i+1) + _plot_title)
files = sorted([a for a in os.listdir(data_dir) if "[" not in a and int(a.split("_")[0]) == i and ".jpg" not in a and ".gif" not in a])

for num, file in enumerate(files):
val= np.load(data_dir+file, allow_pickle=True)["v"]
shape = val.shape
if np.prod(shape) <= 1:
skip += 1
continue

val = prepare(val)
val = np.sort(val)
num_max_dots = 500

if val.size > num_max_dots: val= val[::int(val.size)//num_max_dots]

ys = -np.ones(val.shape[0])*(num - skip)
ax.plot([1e-15, 1e18], [ys[0], ys[1]], 'C%i-'%(num%10), lw=10, alpha=0.2)
ax.plot(val, ys, 'C%io'%(num%10), ms=6, alpha=0.2)

if i == 0:
xticks.append(ys[0])
xticklabels.append(file.replace(".npz", "").replace("%i_"%i, ""))

plt.plot( [10**(-10), 10**(-10)], [0, xticks[-1]], 'C7--', alpha=0.6)
plt.plot( [10**(10), 10**10], [0, xticks[-1]], 'C7--', alpha=0.6)
plt.plot( [10**(0), 10**0], [0, xticks[-1]], 'C7-', alpha=1)

for x, label in zip(xticks, xticklabels):
ax.text(1e10, x+0.25, label, horizontalalignment='left', size='small', color='black', weight='normal')
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think if we remove the 0.25 here the text would go on top of the line which may look at little nicer.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's true but in case the tensor values go beyond 10e10 the text would cover the dots. We might anticipate that it will never happen but as it is now it is more generic being ready for much higher values.


plt.yticks([], [])
plt.xscale("log")
plt.xlim([10**(-15), 10**18])
if i == 0: plt.tight_layout()

plt.savefig(f'{images_subdir}num_error{outer_num}.jpg')

import imageio
gif_path = f'{base_data_dir}visualize_DFT_num_error_{molecule_name}.gif'
writer = imageio.get_writer(gif_path, loop=0, duration=7)
for i in range(iterations):
writer.append_data(imageio.v2.imread(f'{images_subdir}num_error{i}.jpg'))
writer.close()
print("Numerical error visualisation saved in", gif_path)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice addition!