add algebraic trick for ipu_eigh initialisation #104

Open
wants to merge 1 commit into
base: main
from

akrzgc
Contributor

@akrzgc akrzgc commented Sep 21, 2023

The change:

  • adds an algebraic trick for the ipu_eigh initialisation
  • adds an ipu_eigh_its parameter so that the number of ipu_eigh iterations can be selected from the CLI
  • adds a visualisation of the eigenvectors at each nanoDFT iteration, written as an animated GIF when the new parameter --visualise_eigvects true is passed
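A minimal sketch of how the two new options could be parsed, assuming plain `argparse`; the project's actual CLI wiring is not shown in this thread. The default of 12 mirrors the previously hard-coded `num_iters=12` in the diff; the `str2bool` helper and the `False` default for `--visualise_eigvects` are assumptions made for illustration.

```python
import argparse


def str2bool(text):
    """Parse 'true'/'false' style values so that `--flag true` works."""
    value = text.strip().lower()
    if value in ("true", "1", "yes"):
        return True
    if value in ("false", "0", "no"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean, got {text!r}")


def build_parser():
    # Hypothetical parser: only the two options named in this PR are modelled.
    parser = argparse.ArgumentParser(description="nanoDFT-style options (sketch)")
    parser.add_argument("--ipu_eigh_its", type=int, default=12,
                        help="number of ipu_eigh iterations")
    parser.add_argument("--visualise_eigvects", type=str2bool, default=False,
                        help="write an animated GIF of the eigenvectors")
    return parser


opts = build_parser().parse_args(
    ["--ipu_eigh_its", "6", "--visualise_eigvects", "true"]
)
print(opts.ipu_eigh_its, opts.visualise_eigvects)
```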

@@ -39,8 +39,8 @@ def energy(density_matrix, H_core, diff_JK, E_xc, E_nuc, _np=jax.numpy):

def nanoDFT_iteration(i, vals, opts, mol):
"""Each call updates density_matrix attempting to minimize energy(density_matrix, ... ). """
- density_matrix, V_xc, diff_JK, O, H_core, L_inv = vals[:6] # All (N, N) matrices
- E_nuc, occupancy, ERI, grid_weights, grid_AO, diis_history, log = vals[6:] # Varying types/shapes.
+ density_matrix, V_xc, diff_JK, O, H_core, L_inv, prev_eigvects = vals[:7] # All (N, N) matrices
Contributor

Assuming initialization prev_eigvects=np.eye(N).

@@ -66,10 +66,14 @@ def nanoDFT_iteration(i, vals, opts, mol):
if opts.diis: H, diis_history = DIIS(i, H, density_matrix, O, diis_history, opts) # H_{i+1}=c_1H_i+...+c9H_{i-9}.

# Step 2: Solve eigh (L_inv turns generalized eigh into eigh).
- eigvects = L_inv.T @ linalg_eigh(L_inv @ H @ L_inv.T, opts)[1] # (N, N)
+ linalg_input = L_inv @ H @ L_inv.T
Contributor

@AlexanderMath AlexanderMath Sep 21, 2023

We might consider merging all the eigh tricks together:

L_inv.T @ prev_eigvects @ eigh(prev_eigvects.T @ L_inv @ H @ L_inv.T @ prev_eigvects)

This hints that the L_inv trick (to turn generalized eigh into eigh) is the same kind of change of basis as the initial_guess stuff. Unfortunately, merging them has the downside of removing the intermediate variable names, so there is less documentation.
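The merged expression can be checked numerically. The sketch below uses NumPy's `eigh` as a stand-in for `ipu_eigh`, random matrices as stand-ins for `H` and `O`, and assumes `L_inv` is the inverse of the Cholesky factor of `O`. The "previous" eigenvectors come from a slightly perturbed matrix, which is a hypothetical model of consecutive DFT iterations. It checks that both forms solve the generalized problem `H C = O C diag(eigvals)` and that the rotated input is much closer to diagonal.

```python
import numpy as np

rng = np.random.default_rng(0)
N = 6


def random_symmetric(n):
    M = rng.standard_normal((n, n))
    return (M + M.T) / 2


def off_diag_norm(M):
    return np.linalg.norm(M - np.diag(np.diag(M)))


H = random_symmetric(N)                        # stand-in for the Hamiltonian
B = rng.standard_normal((N, N))
O = B @ B.T + N * np.eye(N)                    # stand-in SPD overlap matrix
L_inv = np.linalg.inv(np.linalg.cholesky(O))   # O = L L^T

# Plain L_inv trick, as in the removed line of the diff.
A = L_inv @ H @ L_inv.T
vals_plain, vecs_plain = np.linalg.eigh(A)
eigvects_plain = L_inv.T @ vecs_plain

# Merged form: rotate into the previous eigenbasis, solve, rotate back.
prev_eigvects = np.linalg.eigh(A + 1e-4 * random_symmetric(N))[1]
A_rot = prev_eigvects.T @ A @ prev_eigvects
vals_merged, vecs_rot = np.linalg.eigh(A_rot)
eigvects_merged = L_inv.T @ prev_eigvects @ vecs_rot

# Residual of the generalized eigenproblem H C = O C diag(vals).
residual = np.abs(H @ eigvects_merged - O @ eigvects_merged * vals_merged).max()
print("max residual:", residual)
print("off-diagonal norm, plain vs rotated:",
      off_diag_norm(A), off_diag_norm(A_rot))
```

With `prev_eigvects = np.eye(N)` the merged form reduces to the plain `L_inv` trick, which matches the initialisation assumed earlier in the review.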

@@ -132,13 +137,14 @@ def _nanoDFT(state, ERI, grid_AO, grid_weights, opts, mol):

# Log matrices from all DFT iterations (not used by DFT algorithm).
N = H_core.shape[0]
- log = {"matrices": np.zeros((opts.its, 4, N, N)), "E_xc": np.zeros((opts.its)), "energy": np.zeros((opts.its, 5))}
+ log = {"matrices": np.zeros((opts.its, 4, N, N)), "E_xc": np.zeros((opts.its)), "energy": np.zeros((opts.its, 5)), "eigenvectors" : np.zeros((opts.its, N, N))}
Contributor

Assuming this is mainly for plotting. Could you try to profile at[i].set(..) against jax.lax.dynamic_update_slice(..)? I've had issues with at[i].set(..) being 10x slower due to an unfortunate memory layout.
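A minimal sketch of the two update styles side by side, assuming a log array of shape `(its, N, N)` as in the diff. The sizes, function names, and the `time_it` helper are hypothetical. The helper measures host wall-clock time on whatever backend JAX is using, so it is not a substitute for the IPU cycle counts asked for here.

```python
import time

import jax
import jax.numpy as jnp

its, N = 8, 4  # small hypothetical sizes


@jax.jit
def log_with_at_set(log, i, eigvects):
    return log.at[i].set(eigvects)


@jax.jit
def log_with_dynamic_update_slice(log, i, eigvects):
    i = jnp.asarray(i, dtype=jnp.int32)
    zero = jnp.zeros((), dtype=jnp.int32)
    # The update must have the same rank as log, hence the leading axis.
    return jax.lax.dynamic_update_slice(log, eigvects[None], (i, zero, zero))


def time_it(fn, *args, repeats=100):
    fn(*args).block_until_ready()  # compile outside the timed region
    start = time.perf_counter()
    for _ in range(repeats):
        out = fn(*args)
    out.block_until_ready()
    return (time.perf_counter() - start) / repeats


log = jnp.zeros((its, N, N))
eigvects = jnp.arange(N * N, dtype=log.dtype).reshape(N, N)

log_a = log_with_at_set(log, 3, eigvects)
log_b = log_with_dynamic_update_slice(log, 3, eigvects)

print("at[i].set            :", time_it(log_with_at_set, log, 3, eigvects))
print("dynamic_update_slice :", time_it(log_with_dynamic_update_slice, log, 3, eigvects))
```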

@@ -350,7 +366,7 @@ def linalg_eigh(x, opts):
x = jnp.pad(x, [(0, 1), (0, 1)], mode='constant')
#assert False

- eigvects, eigvals = ipu_eigh(x, sort_eigenvalues=True, num_iters=12)
+ eigvects, eigvals = ipu_eigh(x, sort_eigenvalues=True, num_iters=opts.ipu_eigh_its)
Contributor

neat.

@AlexanderMath
Contributor

AlexanderMath commented Sep 21, 2023

Looks great! Minor comments:
(1) some naming suggestions and linear-algebra hints (likely not too useful), and
(2) a potential performance issue with at[i].set(.) vs jax.lax.dynamic_update_slice(.) (could you try running e.g. benzene on IPU with both? If the cycle counts are within ±5%, my concerns are misplaced).

@awf
Collaborator

awf commented Sep 21, 2023

We need a before/after graph or table to show the effect of the change, all other things being fixed

@akrzgc
Contributor Author

akrzgc commented Sep 21, 2023

> We need a before/after graph or table to show the effect of the change, all other things being fixed

Do you mean before/after of ipu_eigh input?
Should I attach it to this review or would you like it to be plotted each time nanoDFT.py is run?
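One possible shape for such a before/after table, sketched off-device in NumPy. It assumes the eigensolver behaves like a cyclic Jacobi iteration, which is an assumption about `ipu_eigh` rather than something stated in this thread. The `jacobi_sweeps` helper is written here for illustration only, and the matrices are random stand-ins for two consecutive DFT iterations.

```python
import numpy as np


def off_diag_norm(A):
    return np.linalg.norm(A - np.diag(np.diag(A)))


def jacobi_sweeps(A, tol=1e-10, max_sweeps=50):
    """Count cyclic Jacobi sweeps until the off-diagonal norm drops below tol."""
    A = np.array(A, dtype=float)
    n = A.shape[0]
    for sweep in range(max_sweeps):
        if off_diag_norm(A) < tol:
            return sweep
        for p in range(n - 1):
            for q in range(p + 1, n):
                if A[p, q] == 0.0:
                    continue
                # Rotation angle that zeroes A[p, q], kept within [-pi/4, pi/4].
                theta = 0.5 * np.arctan2(2.0 * A[p, q], A[q, q] - A[p, p])
                if theta > np.pi / 4:
                    theta -= np.pi / 2
                elif theta < -np.pi / 4:
                    theta += np.pi / 2
                c, s = np.cos(theta), np.sin(theta)
                J = np.eye(n)
                J[p, p] = J[q, q] = c
                J[p, q], J[q, p] = s, -s
                A = J.T @ A @ J
    return max_sweeps


rng = np.random.default_rng(0)
n = 8
M = rng.standard_normal((n, n))
E = rng.standard_normal((n, n))
A_prev = (M + M.T) / 2                     # stand-in for the previous iteration
A_next = A_prev + 1e-4 * (E + E.T) / 2     # small change between iterations

prev_eigvects = np.linalg.eigh(A_prev)[1]
A_rotated = prev_eigvects.T @ A_next @ prev_eigvects

cold = jacobi_sweeps(A_next)      # before: solver sees the raw matrix
warm = jacobi_sweeps(A_rotated)   # after: solver sees the rotated matrix

print(f"{'input':<12}{'off-diag norm':>16}{'sweeps':>8}")
print(f"{'before':<12}{off_diag_norm(A_next):>16.3e}{cold:>8d}")
print(f"{'after':<12}{off_diag_norm(A_rotated):>16.3e}{warm:>8d}")
```

On IPU the same comparison would presumably be made with everything else fixed, for example by recording the energy error at several values of `--ipu_eigh_its` with and without the rotation.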

3 participants