Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add num_error visualisation on demand #118

Merged
merged 1 commit into from
Oct 10, 2023
Merged

add num_error visualisation on demand #118

merged 1 commit into from
Oct 10, 2023

Conversation

akrzgc
Copy link
Contributor

@akrzgc akrzgc commented Oct 6, 2023

The change introduces functionality of numerical error visualisation for selected tensors (all possible at the moment).
In order to visualise numerical error for a given nanoDFT.py run one should use new parameter --vis_num_error true.
The script uses host callback to be able to save the tensors from the device to a number of files so that option shall not be used when profiling performance.
The script assumes that the parameter provided as a --mol_str from the command line will be treated as molecule name and based on that the name of required helper directories and files will be created.
However, one may also use a newly introduced parameter --molecule_name which would overwrite the --mol_str provided by the CLI.

The script executed with --vis_num_error true shall provide similar result to below ones:
visualize_DFT_num_error_benzene
visualize_DFT_num_error_c20

The change is addressing the first part of: #43

@AlexanderMath
Copy link
Contributor

Looks good. Only suggestion: I think if we remove the +15 here the label would be put on top of the line.

Copy link
Contributor

@AlexanderMath AlexanderMath left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good! Only comment: The +0.25 in label utils.py line 206 may be removed/adjusted to text appear on top of the lines.

@@ -81,8 +81,27 @@ def nanoDFT_iteration(i, vals, opts, mol):
log["matrices"] = jax.lax.dynamic_update_slice(log["matrices"], H. reshape(1, 1, N, N), (i, 2, 0, 0))
log["energy"] = log["energy"].at[i].set(energy(density_matrix, H_core, diff_JK, E_xc, E_nuc)) # (iterations, 6)

if opts.vis_num_error is True:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Neat. I debated with myself whether it'd be nice to have it in separate function. I prefer current form: I like that current form allow quick change of callback which is neat.

@@ -562,6 +583,10 @@ def nanoDFT_options(
ao_threshold (float): Zero out grid_AO that are below the threshold in absolute value.
dense_ERI (bool): Whether to use dense ERI (s1) or sparse symmtric ERI.
"""
if molecule_name is None:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok

@@ -603,6 +628,12 @@ def main():
nanoDFT_forces = grad(mol, grid_coords, grid_weights, mo_coeff, mo_energy)
pyscf_E, pyscf_hlgap, pyscf_forces = pyscf_reference(mol_str, opts)
print_difference(nanoDFT_E, nanoDFT_forces, nanoDFT_logged_E, nanoDFT_hlgap, pyscf_E, pyscf_forces, pyscf_hlgap)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok


for x, label in zip(xticks, xticklabels):
ax.text(1e10, x+0.25, label, horizontalalignment='left', size='small', color='black', weight='normal')
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think if we remove the 0.25 here the text would go on top of the line which may look at little nicer.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's true but in case the tensor values go beyond 10e10 the text would cover the dots. We might anticipate that it will never happen but as it is now it is more generic being ready for much higher values.

writer.append_data(imageio.v2.imread(f'{images_subdir}num_error{i}.jpg'))
writer.close()
print("Numerical error visualisation saved in", gif_path)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice addition!

@AlexanderMath AlexanderMath removed the request for review from awf October 9, 2023 15:34
Copy link
Contributor

@AlexanderMath AlexanderMath left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry clicked "comment" instead of "approve". Feel free to decide whether text should be on top of lines by modifying the +0.25.

@akrzgc akrzgc merged commit bd5c7ff into main Oct 10, 2023
4 checks passed
@akrzgc akrzgc deleted the visualise_num_error branch October 10, 2023 13:02
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants