-
Notifications
You must be signed in to change notification settings - Fork 2
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
add num_error visualisation on demand #118
Conversation
Looks good. Only suggestion: I think if we remove the |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks good! Only comment: The +0.25
in label utils.py line 206 may be removed/adjusted to text appear on top of the lines.
@@ -81,8 +81,27 @@ def nanoDFT_iteration(i, vals, opts, mol): | |||
log["matrices"] = jax.lax.dynamic_update_slice(log["matrices"], H. reshape(1, 1, N, N), (i, 2, 0, 0)) | |||
log["energy"] = log["energy"].at[i].set(energy(density_matrix, H_core, diff_JK, E_xc, E_nuc)) # (iterations, 6) | |||
|
|||
if opts.vis_num_error is True: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Neat. I debated with myself whether it'd be nice to have it in separate function. I prefer current form: I like that current form allow quick change of callback which is neat.
@@ -562,6 +583,10 @@ def nanoDFT_options( | |||
ao_threshold (float): Zero out grid_AO that are below the threshold in absolute value. | |||
dense_ERI (bool): Whether to use dense ERI (s1) or sparse symmtric ERI. | |||
""" | |||
if molecule_name is None: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ok
@@ -603,6 +628,12 @@ def main(): | |||
nanoDFT_forces = grad(mol, grid_coords, grid_weights, mo_coeff, mo_energy) | |||
pyscf_E, pyscf_hlgap, pyscf_forces = pyscf_reference(mol_str, opts) | |||
print_difference(nanoDFT_E, nanoDFT_forces, nanoDFT_logged_E, nanoDFT_hlgap, pyscf_E, pyscf_forces, pyscf_hlgap) | |||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ok
|
||
for x, label in zip(xticks, xticklabels): | ||
ax.text(1e10, x+0.25, label, horizontalalignment='left', size='small', color='black', weight='normal') |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think if we remove the 0.25 here the text would go on top of the line which may look at little nicer.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That's true but in case the tensor values go beyond 10e10 the text would cover the dots. We might anticipate that it will never happen but as it is now it is more generic being ready for much higher values.
writer.append_data(imageio.v2.imread(f'{images_subdir}num_error{i}.jpg')) | ||
writer.close() | ||
print("Numerical error visualisation saved in", gif_path) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nice addition!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sorry clicked "comment" instead of "approve". Feel free to decide whether text should be on top of lines by modifying the +0.25
.
The change introduces functionality of numerical error visualisation for selected tensors (all possible at the moment).
In order to visualise numerical error for a given
nanoDFT.py
run one should use new parameter--vis_num_error true
.The script uses host callback to be able to save the tensors from the device to a number of files so that option shall not be used when profiling performance.
The script assumes that the parameter provided as a
--mol_str
from the command line will be treated as molecule name and based on that the name of required helper directories and files will be created.However, one may also use a newly introduced parameter
--molecule_name
which would overwrite the--mol_str
provided by the CLI.The script executed with
--vis_num_error true
shall provide similar result to below ones:The change is addressing the first part of: #43