add algebraic trick for ipu_eigh initialisation #104
@@ -39,8 +39,8 @@ def energy(density_matrix, H_core, diff_JK, E_xc, E_nuc, _np=jax.numpy):

 def nanoDFT_iteration(i, vals, opts, mol):
     """Each call updates density_matrix attempting to minimize energy(density_matrix, ... ). """
-    density_matrix, V_xc, diff_JK, O, H_core, L_inv = vals[:6] # All (N, N) matrices
-    E_nuc, occupancy, ERI, grid_weights, grid_AO, diis_history, log = vals[6:] # Varying types/shapes.
+    density_matrix, V_xc, diff_JK, O, H_core, L_inv, prev_eigvects = vals[:7] # All (N, N) matrices
Assuming initialization prev_eigvects=np.eye(N).
@@ -66,10 +66,14 @@ def nanoDFT_iteration(i, vals, opts, mol):
     if opts.diis: H, diis_history = DIIS(i, H, density_matrix, O, diis_history, opts) # H_{i+1}=c_1H_i+...+c9H_{i-9}.

     # Step 2: Solve eigh (L_inv turns generalized eigh into eigh).
-    eigvects = L_inv.T @ linalg_eigh(L_inv @ H @ L_inv.T, opts)[1] # (N, N)
+    linalg_input = L_inv @ H @ L_inv.T
We might consider merging all the eigh tricks together:
L_inv.T @ prev_eigvects @ eigh(prev_eigvects.T @ L_inv @ H @ L_inv.T @ prev_eigvects)[1]
This suggests that the L_inv trick (which turns the generalized eigh into a standard eigh) is the same kind of change of basis as the initial_guess trick. The downside is that merging them removes the intermediate variable names, and with them some documentation.
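The identity behind the merged expression can be checked numerically. The following is a minimal sketch that uses NumPy's eigh rather than ipu_eigh. The random symmetric H, the random positive-definite O, and the random orthogonal Q standing in for prev_eigvects are all assumptions made for illustration, and L_inv is assumed to be the inverse Cholesky factor of O.

```python
import numpy as np

rng = np.random.default_rng(0)
N = 6

# Stand-in inputs (assumptions): symmetric H, symmetric positive-definite overlap O.
A = rng.standard_normal((N, N))
H = (A + A.T) / 2
B = rng.standard_normal((N, N))
O = B @ B.T + N * np.eye(N)

# Assumed definition: L_inv is the inverse Cholesky factor, so L_inv @ O @ L_inv.T = I.
L_inv = np.linalg.inv(np.linalg.cholesky(O))

# Stand-in for prev_eigvects: any orthogonal matrix.
Q, _ = np.linalg.qr(rng.standard_normal((N, N)))

# Merged form: one combined change of basis, solve, rotate back.
M = L_inv.T @ Q
eigvals, V = np.linalg.eigh(M.T @ H @ M)
eigvects = M @ V

# eigvects solve the generalized problem H C = O C diag(eigvals).
residual = np.abs(H @ eigvects - O @ eigvects @ np.diag(eigvals)).max()

# The eigenvalues agree with the solve that skips the Q rotation.
ref_eigvals = np.linalg.eigvalsh(L_inv @ H @ L_inv.T)
print(residual, np.abs(eigvals - ref_eigvals).max())
```

With Q set to np.eye(N) the merged form reduces to the original L_inv.T @ eigh(L_inv @ H @ L_inv.T)[1], which is why an identity initialization leaves the first iteration unchanged.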
@@ -132,13 +137,14 @@ def _nanoDFT(state, ERI, grid_AO, grid_weights, opts, mol):

     # Log matrices from all DFT iterations (not used by DFT algorithm).
     N = H_core.shape[0]
-    log = {"matrices": np.zeros((opts.its, 4, N, N)), "E_xc": np.zeros((opts.its)), "energy": np.zeros((opts.its, 5))}
+    log = {"matrices": np.zeros((opts.its, 4, N, N)), "E_xc": np.zeros((opts.its)), "energy": np.zeros((opts.its, 5)), "eigenvectors" : np.zeros((opts.its, N, N))}
Assuming this is mainly for plotting. Could you try to profile at[i].set(..) compared to jax.lax.dynamic_update_slice(..)? I've had issues with at[i].set(..) being 10x slower due to unfortunate memory layout.
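A minimal sketch of the comparison being asked for, assuming a log buffer of shape (its, N, N) like the "eigenvectors" entry above. The sizes, the function names update_at_set and update_dus, and the timing loop are hypothetical, and timings on a CPU or GPU backend may say little about the IPU.

```python
import timeit

import jax
import jax.numpy as jnp

its, N = 20, 64  # hypothetical sizes, not taken from the PR


@jax.jit
def update_at_set(log, i, eigvects):
    # Indexed update: write eigvects into row i of the log buffer.
    return log.at[i].set(eigvects)


@jax.jit
def update_dus(log, i, eigvects):
    # The update must have the same rank as the operand, hence eigvects[None].
    return jax.lax.dynamic_update_slice(log, eigvects[None], (i, 0, 0))


log = jnp.zeros((its, N, N))
eigvects = jnp.ones((N, N))
i = 3

# First calls compile; block_until_ready() waits for the result.
out_a = update_at_set(log, i, eigvects).block_until_ready()
out_b = update_dus(log, i, eigvects).block_until_ready()

t_a = timeit.timeit(lambda: update_at_set(log, i, eigvects).block_until_ready(), number=20)
t_b = timeit.timeit(lambda: update_dus(log, i, eigvects).block_until_ready(), number=20)
print(f"at[i].set: {t_a:.4f}s  dynamic_update_slice: {t_b:.4f}s")
```

Both functions produce the same array, so any timing difference comes from how the update is lowered rather than from what it computes.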
@@ -350,7 +366,7 @@ def linalg_eigh(x, opts):
 x = jnp.pad(x, [(0, 1), (0, 1)], mode='constant')
 #assert False

-eigvects, eigvals = ipu_eigh(x, sort_eigenvalues=True, num_iters=12)
+eigvects, eigvals = ipu_eigh(x, sort_eigenvalues=True, num_iters=opts.ipu_eigh_its)
neat.
Looks great! Minor comments:
We need a before/after graph or table to show the effect of the change, all other things being fixed
Do you mean before/after of
The change:
- ipu_eigh initialization
- ipu_eigh_its parameter so that we can now select the number of ipu_eigh iterations from CLI
- --visualise_eigvects true
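To illustrate why the initialization and the iteration count go together, here is a toy before/after comparison. It assumes ipu_eigh behaves like a Jacobi-style iterative solver, which the discussion above does not confirm; the cyclic Jacobi routine jacobi_sweeps, the helper off_norm, and the random matrices H_prev and H_next are hypothetical stand-ins, not the PR's code or data.

```python
import numpy as np


def off_norm(A):
    """Frobenius norm of the off-diagonal part of A."""
    return np.linalg.norm(A - np.diag(np.diag(A)))


def jacobi_sweeps(A, num_iters):
    """Run num_iters cyclic Jacobi sweeps on symmetric A; return (V.T @ A @ V, V)."""
    A = A.copy()
    n = A.shape[0]
    V = np.eye(n)
    for _ in range(num_iters):
        for p in range(n - 1):
            for q in range(p + 1, n):
                # Angle that zeroes A[p, q], wrapped into [-pi/4, pi/4].
                theta = 0.5 * np.arctan2(2.0 * A[p, q], A[q, q] - A[p, p])
                if theta > np.pi / 4:
                    theta -= np.pi / 2
                elif theta < -np.pi / 4:
                    theta += np.pi / 2
                c, s = np.cos(theta), np.sin(theta)
                # Dense rotation matrix: simple but O(n^3) per rotation, fine for a toy.
                J = np.eye(n)
                J[p, p] = J[q, q] = c
                J[p, q], J[q, p] = s, -s
                A = J.T @ A @ J
                V = V @ J
    return A, V


rng = np.random.default_rng(0)
N = 8
X = rng.standard_normal((N, N))
H_prev = (X + X.T) / 2                    # matrix from the previous iteration
E = rng.standard_normal((N, N))
H_next = H_prev + 1e-4 * (E + E.T) / 2    # slightly changed matrix

prev_eigvects = np.linalg.eigh(H_prev)[1]
num_iters = 2

# Cold start: solve H_next directly.
A_cold, V_cold = jacobi_sweeps(H_next, num_iters)

# Warm start: rotate into the previous eigenbasis first, then rotate back.
A_warm, V_warm = jacobi_sweeps(prev_eigvects.T @ H_next @ prev_eigvects, num_iters)
eigvects_warm = prev_eigvects @ V_warm

cold_off, warm_off = off_norm(A_cold), off_norm(A_warm)
print(f"off-diagonal norm after {num_iters} sweeps: cold={cold_off:.2e} warm={warm_off:.2e}")
```

In this toy setting the rotated matrix is already nearly diagonal, so the same number of sweeps leaves a much smaller off-diagonal residual. Whether that carries over to ipu_eigh on real DFT matrices is what the requested before/after table would show.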