From dbf4e92f29be19f65a7e1990a39f72b57c3b12ce Mon Sep 17 00:00:00 2001
From: Adam Krzywaniak
Date: Fri, 6 Oct 2023 09:14:10 +0000
Subject: [PATCH] add num_error visualisation on demand

---
Note: this patch was re-flowed from a whitespace-collapsed copy. The
indentation and inner spacing of context lines are reconstructed, and the
blob ids on the `index` lines are those of the original commit, so apply it
with `git apply --ignore-whitespace`.

 pyscf_ipu/nanoDFT/nanoDFT.py |  38 ++++++++++-
 pyscf_ipu/nanoDFT/utils.py   | 106 +++++++++++++++++++++++++++++-
 2 files changed, 142 insertions(+), 2 deletions(-)

diff --git a/pyscf_ipu/nanoDFT/nanoDFT.py b/pyscf_ipu/nanoDFT/nanoDFT.py
index 5dfe9db..99bd7e3 100644
--- a/pyscf_ipu/nanoDFT/nanoDFT.py
+++ b/pyscf_ipu/nanoDFT/nanoDFT.py
@@ -81,8 +81,30 @@ def nanoDFT_iteration(i, vals, opts, mol):
     log["matrices"] = jax.lax.dynamic_update_slice(log["matrices"], H. reshape(1, 1, N, N), (i, 2, 0, 0))
     log["energy"] = log["energy"].at[i].set(energy(density_matrix, H_core, diff_JK, E_xc, E_nuc)) # (iterations, 6)
 
+    if opts.vis_num_error:
+        # Dump the tensors of this iteration to disk; utils.save_plot reads them
+        # back from num_error/<molecule_name>/ to build the visualisation.
+        import os
+        num_error_dir = f'num_error/{opts.molecule_name}/'
+        os.makedirs(num_error_dir, exist_ok=True)
+
+        def host_callback(data, i):
+            """Write every tensor in `data` to `<i>_<label>.npz` (array stored under key "v")."""
+            # labels are adjusted to the `data` that will be passed to the callback - keep that in mind when passing different list of tensors
+            labels = ["density_matrix", "V_xc", "diff_JK", "O", "H_core", "L_inv", "E_nuc", "occupancy", "ERI", "grid_weights", "grid_AO", "diis_history", "E_xc", "eigvects", "H"]
+            for label, tensor in zip(labels, data):
+                if label in ("diis_history", "ERI"):
+                    # saved element by element, one file per top-level entry
+                    for idx, arr in enumerate(tensor):
+                        np.savez(f'{num_error_dir}{i}_{label}{idx}.npz', v=np.array(arr))
+                else:
+                    np.savez(f'{num_error_dir}{i}_{label}.npz', v=tensor)
+
+        jax.debug.callback(host_callback, vals[:-1] + [E_xc, eigvects, H], i)
+
     return [density_matrix, V_xc, diff_JK, O, H_core, L_inv, E_nuc, occupancy, ERI, grid_weights, grid_AO, diis_history, log]
 
+
 def exchange_correlation(density_matrix, grid_AO, grid_weights):
     """Compute exchange correlation integral using atomic orbitals (AO) evalauted on a grid. """
     # Perfectly SIMD parallelizable over grid_size axis.
@@ -538,7 +560,9 @@ def nanoDFT_options(
         ndevices: int = 1,
         dense_ERI: bool = False,
         v: bool = False, # verbose
-        profile: bool = False # if we only want profile exit after IPU finishes.
+        profile: bool = False, # if we only want profile exit after IPU finishes.
+        vis_num_error: bool = False,
+        molecule_name: str = None
 ):
     """
     nanoDFT
@@ -562,6 +586,12 @@ def nanoDFT_options(
         ao_threshold (float): Zero out grid_AO that are below the threshold in absolute value.
         dense_ERI (bool): Whether to use dense ERI (s1) or sparse symmtric ERI.
+        vis_num_error (bool): Save the tensors of every iteration under num_error/<molecule_name>/ so that main() can render them with utils.save_plot.
+        molecule_name (str): Directory label used for the num_error output; defaults to mol_str as passed in.
     """
+    if molecule_name is None:
+        # use mol_str as a molecule name (in case it has not been provided)
+        # before mol_str CLI arg is preprocessed and overwritten
+        molecule_name = mol_str
 
     # From a compound name or CID, get a list of its atoms and their coordinates
     mol_str = utils.process_mol_str(mol_str)
@@ -603,6 +633,12 @@ def main():
         nanoDFT_forces = grad(mol, grid_coords, grid_weights, mo_coeff, mo_energy)
         pyscf_E, pyscf_hlgap, pyscf_forces = pyscf_reference(mol_str, opts)
         print_difference(nanoDFT_E, nanoDFT_forces, nanoDFT_logged_E, nanoDFT_hlgap, pyscf_E, pyscf_forces, pyscf_hlgap)
+
+        if opts.vis_num_error:
+            # `utils` is already in scope in this module (see utils.process_mol_str), so no path-dependent `from utils import ...` is needed.
+            import sys
+            plot_title = f"Created with: python {' '.join(sys.argv)}"
+            utils.save_plot("num_error/", opts.molecule_name, opts.its, plot_title)
     else:
         # pip install mogli imageio[ffmpeg] matplotlib
         import mogli
diff --git a/pyscf_ipu/nanoDFT/utils.py b/pyscf_ipu/nanoDFT/utils.py
index dce50d2..90154ad 100644
--- a/pyscf_ipu/nanoDFT/utils.py
+++ b/pyscf_ipu/nanoDFT/utils.py
@@ -142,4 +142,108 @@ def min_interatomic_distance(mol_str):
     """This computes the minimum distance between atoms."""
     coords = map(itemgetter(1), mol_str)
     distances = map(lambda x: np.linalg.norm(np.array(x[0]) - np.array(x[1])), combinations(coords, 2))
-    return min(distances)
\ No newline at end of file
+    return min(distances)
+
+
+def save_plot(base_data_dir: str, molecule_name: str, iterations: int, _plot_title: str = "Default Title"):
+    """Render the tensors dumped by nanoDFT (vis_num_error) as a GIF, one frame per iteration.
+
+    Loads the "v" array of every `<iteration>_<label>.npz` file found in
+    `<base_data_dir>/<molecule_name>/`, plots the sorted absolute values of each
+    non-scalar tensor on a log axis (one row per tensor), saves one JPG per
+    iteration under `<base_data_dir>/tmp_images/num_error/` and finally writes
+    `<base_data_dir>/visualize_DFT_num_error_<molecule_name>.gif`.
+
+    Args:
+        base_data_dir: Directory holding one sub-directory of dumps per molecule.
+        molecule_name: Name of the sub-directory to visualise.
+        iterations: Number of frames to render (iteration indices 0..iterations-1).
+        _plot_title: Text appended below the iteration counter in every frame.
+    """
+    import os
+
+    import imageio
+    import matplotlib
+    import matplotlib.pyplot as plt
+    import seaborn as sns
+
+    matplotlib.rc('font', **{'family': 'serif', 'serif': ['Computer Modern']})
+    sns.set_theme()
+    sns.set_style("white")
+
+    # os.path.join keeps the paths valid with or without a trailing separator.
+    data_dir = os.path.join(base_data_dir, molecule_name)
+    images_subdir = os.path.join(base_data_dir, 'tmp_images', 'num_error')
+    gif_path = os.path.join(base_data_dir, f'visualize_DFT_num_error_{molecule_name}.gif')
+    os.makedirs(images_subdir, exist_ok=True)
+    num_max_dots = 500  # larger tensors are subsampled to roughly this many points
+
+    def prepare(val):
+        # `val == val` is False for NaN, so NaNs are dropped; the float cast keeps
+        # the tiny placeholder values below from being truncated for integer tensors.
+        val = np.abs(val[val == val]).astype(float)
+        val[np.logical_and(val<1e-15, val!=0)] = 2e-15 # show the ones that go out of plot
+        val[val==0] = 1e-17 # remove zeros.
+        return val
+
+    def iteration_files(iteration):
+        # Only `<iteration>_*.npz` dumps; anything else in the directory is ignored.
+        names = []
+        for name in os.listdir(data_dir):
+            prefix = name.split("_")[0]
+            if "[" not in name and name.endswith(".npz") and prefix.isdecimal() and int(prefix) == iteration:
+                names.append(name)
+        return sorted(names)
+
+    label_ys = []
+    labels = []
+    fig, ax = plt.subplots(1, 1, figsize=(14,8))
+
+    for i in range(iterations):
+        skip = 0
+        print(f'figure [{i+1} / {iterations}]\r', end="")
+        plt.cla()
+        plt.title("[Iterations %i] \n"%(i+1) + _plot_title)
+
+        for num, file in enumerate(iteration_files(i)):
+            # NOTE: allow_pickle=True is acceptable only for dumps this tool wrote itself; do not point it at untrusted files.
+            with np.load(os.path.join(data_dir, file), allow_pickle=True) as npz:
+                val = npz["v"]
+            if np.prod(val.shape) <= 1:
+                skip += 1  # scalars are not plotted and do not take up a row
+                continue
+
+            val = np.sort(prepare(val))
+            if val.size > num_max_dots: val = val[::val.size//num_max_dots]
+
+            # A scalar row position (instead of indexing a per-point array) also works
+            # for tensors left with fewer than two finite entries.
+            y = -(num - skip)
+            ax.plot([1e-15, 1e18], [y, y], 'C%i-'%(num%10), lw=10, alpha=0.2)
+            ax.plot(val, np.full(val.shape[0], y), 'C%io'%(num%10), ms=6, alpha=0.2)
+
+            if i == 0:
+                label_ys.append(y)
+                labels.append(file.replace(".npz", "").replace("%i_"%i, ""))
+
+        bottom = label_ys[-1] if label_ys else 0  # nothing was plotted if no dumps were found
+        plt.plot( [10**(-10), 10**(-10)], [0, bottom], 'C7--', alpha=0.6)
+        plt.plot( [10**(10), 10**10], [0, bottom], 'C7--', alpha=0.6)
+        plt.plot( [10**(0), 10**0], [0, bottom], 'C7-', alpha=1)
+
+        for y, label in zip(label_ys, labels):
+            ax.text(1e10, y+0.25, label, horizontalalignment='left', size='small', color='black', weight='normal')
+
+        plt.yticks([], [])
+        plt.xscale("log")
+        plt.xlim([10**(-15), 10**18])
+        if i == 0: plt.tight_layout()
+
+        plt.savefig(os.path.join(images_subdir, f'num_error{i}.jpg'))
+
+    plt.close(fig)
+
+    with imageio.get_writer(gif_path, loop=0, duration=7) as writer:
+        for i in range(iterations):
+            writer.append_data(imageio.v2.imread(os.path.join(images_subdir, f'num_error{i}.jpg')))
+    print("Numerical error visualisation saved in", gif_path)