-
Notifications
You must be signed in to change notification settings - Fork 2
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
add num_error visualisation on demand #118
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -81,8 +81,27 @@ def nanoDFT_iteration(i, vals, opts, mol): | |
log["matrices"] = jax.lax.dynamic_update_slice(log["matrices"], H. reshape(1, 1, N, N), (i, 2, 0, 0)) | ||
log["energy"] = log["energy"].at[i].set(energy(density_matrix, H_core, diff_JK, E_xc, E_nuc)) # (iterations, 6) | ||
|
||
if opts.vis_num_error is True: | ||
import os | ||
dir_label = opts.molecule_name | ||
num_error_dir = f'num_error/{dir_label}/' | ||
os.makedirs(num_error_dir , exist_ok=True) | ||
|
||
def host_callback(data, i): | ||
# labels are adjusted to the `data` that will be passed to the callback - keep that in mind when passing a different list of tensors | ||
labels = ["density_matrix", "V_xc", "diff_JK", "O", "H_core", "L_inv", "E_nuc", "occupancy", "ERI", "grid_weights", "grid_AO", "diis_history", "E_xc", "eigvects", "H"] | ||
for l, d in zip(labels, data): | ||
if l == "diis_history" or l == "ERI": | ||
for idx, arr in enumerate(d): | ||
np.savez(f'{num_error_dir}{i}_{l}{idx}.npz', v = np.array(arr)) | ||
else: | ||
np.savez(f'{num_error_dir}{i}_{l}.npz', v = d) | ||
|
||
jax.debug.callback(host_callback, vals[:-1] + [E_xc, eigvects, H], i) | ||
|
||
return [density_matrix, V_xc, diff_JK, O, H_core, L_inv, E_nuc, occupancy, ERI, grid_weights, grid_AO, diis_history, log] | ||
|
||
|
||
def exchange_correlation(density_matrix, grid_AO, grid_weights): | ||
"""Compute exchange correlation integral using atomic orbitals (AO) evaluated on a grid. """ | ||
# Perfectly SIMD parallelizable over grid_size axis. | ||
|
@@ -538,7 +557,9 @@ def nanoDFT_options( | |
ndevices: int = 1, | ||
dense_ERI: bool = False, | ||
v: bool = False, # verbose | ||
profile: bool = False # if we only want profile exit after IPU finishes. | ||
profile: bool = False, # if we only want profile exit after IPU finishes. | ||
vis_num_error: bool = False, | ||
molecule_name: str = None | ||
): | ||
""" | ||
nanoDFT | ||
|
@@ -562,6 +583,10 @@ def nanoDFT_options( | |
ao_threshold (float): Zero out grid_AO that are below the threshold in absolute value. | ||
dense_ERI (bool): Whether to use dense ERI (s1) or sparse symmetric ERI. | ||
""" | ||
if molecule_name is None: | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Ok |
||
# use mol_str as a molecule name (in case it has not been provided) | ||
# before mol_str CLI arg is preprocessed and overwritten | ||
molecule_name = mol_str | ||
|
||
# From a compound name or CID, get a list of its atoms and their coordinates | ||
mol_str = utils.process_mol_str(mol_str) | ||
|
@@ -603,6 +628,12 @@ def main(): | |
nanoDFT_forces = grad(mol, grid_coords, grid_weights, mo_coeff, mo_energy) | ||
pyscf_E, pyscf_hlgap, pyscf_forces = pyscf_reference(mol_str, opts) | ||
print_difference(nanoDFT_E, nanoDFT_forces, nanoDFT_logged_E, nanoDFT_hlgap, pyscf_E, pyscf_forces, pyscf_hlgap) | ||
|
||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Ok |
||
if opts.vis_num_error is True: | ||
from utils import save_plot | ||
import sys | ||
_plot_title = f"Created with: python {' '.join(sys.argv)}" | ||
save_plot("num_error/", opts.molecule_name, opts.its, _plot_title) | ||
else: | ||
# pip install mogli imageio[ffmpeg] matplotlib | ||
import mogli | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -142,4 +142,80 @@ def min_interatomic_distance(mol_str): | |
"""This computes the minimum distance between atoms.""" | ||
coords = map(itemgetter(1), mol_str) | ||
distances = map(lambda x: np.linalg.norm(np.array(x[0]) - np.array(x[1])), combinations(coords, 2)) | ||
return min(distances) | ||
return min(distances) | ||
|
||
|
||
def save_plot(base_data_dir: str, molecule_name: str, iterations: int, _plot_title: str = "Default Title"): | ||
import matplotlib.pyplot as plt | ||
import matplotlib | ||
import os | ||
matplotlib.rc('font', **{'family': 'serif', 'serif': ['Computer Modern']}) | ||
|
||
import seaborn as sns | ||
sns.set_theme() | ||
sns.set_style("white") | ||
|
||
data_dir = base_data_dir + molecule_name + '/' | ||
|
||
def prepare(val): | ||
val = np.abs(val[val == val]) | ||
val[np.logical_and(val<1e-15, val!=0)] = 2e-15 # show the ones that go out of plot | ||
val[val==0] = 1e-17 # remove zeros. | ||
return val | ||
|
||
xticks = [] | ||
xticklabels = [] | ||
|
||
fig, ax = plt.subplots(1, 1, figsize=(14,8)) | ||
images_subdir = f'{base_data_dir}/tmp_images/num_error/' | ||
os.makedirs(images_subdir, exist_ok=True) | ||
|
||
for outer_num, i in enumerate(range(iterations)): | ||
skip = 0 | ||
print(f'figure [{i+1} / {iterations}]\r', end="") | ||
plt.cla() | ||
plt.title("[Iterations %i] \n"%(i+1) + _plot_title) | ||
files = sorted([a for a in os.listdir(data_dir) if "[" not in a and int(a.split("_")[0]) == i and ".jpg" not in a and ".gif" not in a]) | ||
|
||
for num, file in enumerate(files): | ||
val= np.load(data_dir+file, allow_pickle=True)["v"] | ||
shape = val.shape | ||
if np.prod(shape) <= 1: | ||
skip += 1 | ||
continue | ||
|
||
val = prepare(val) | ||
val = np.sort(val) | ||
num_max_dots = 500 | ||
|
||
if val.size > num_max_dots: val= val[::int(val.size)//num_max_dots] | ||
|
||
ys = -np.ones(val.shape[0])*(num - skip) | ||
ax.plot([1e-15, 1e18], [ys[0], ys[1]], 'C%i-'%(num%10), lw=10, alpha=0.2) | ||
ax.plot(val, ys, 'C%io'%(num%10), ms=6, alpha=0.2) | ||
|
||
if i == 0: | ||
xticks.append(ys[0]) | ||
xticklabels.append(file.replace(".npz", "").replace("%i_"%i, "")) | ||
|
||
plt.plot( [10**(-10), 10**(-10)], [0, xticks[-1]], 'C7--', alpha=0.6) | ||
plt.plot( [10**(10), 10**10], [0, xticks[-1]], 'C7--', alpha=0.6) | ||
plt.plot( [10**(0), 10**0], [0, xticks[-1]], 'C7-', alpha=1) | ||
|
||
for x, label in zip(xticks, xticklabels): | ||
ax.text(1e10, x+0.25, label, horizontalalignment='left', size='small', color='black', weight='normal') | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I think if we remove the 0.25 here the text would go on top of the line, which may look a little nicer. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. That's true, but in case the tensor values go beyond 10e10 the text would cover the dots. We might anticipate that it will never happen, but as it is now it is more generic, being ready for much higher values. |
||
|
||
plt.yticks([], []) | ||
plt.xscale("log") | ||
plt.xlim([10**(-15), 10**18]) | ||
if i == 0: plt.tight_layout() | ||
|
||
plt.savefig(f'{images_subdir}num_error{outer_num}.jpg') | ||
|
||
import imageio | ||
gif_path = f'{base_data_dir}visualize_DFT_num_error_{molecule_name}.gif' | ||
writer = imageio.get_writer(gif_path, loop=0, duration=7) | ||
for i in range(iterations): | ||
writer.append_data(imageio.v2.imread(f'{images_subdir}num_error{i}.jpg')) | ||
writer.close() | ||
print("Numerical error visualisation saved in", gif_path) | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Nice addition! |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Neat. I debated with myself whether it'd be nice to have it in a separate function. I prefer the current form: I like that it allows a quick change of the callback, which is neat.