
Compute forces on the IPU. #66

AlexanderMath opened this issue Sep 4, 2023 · 3 comments · May be fixed by #101

AlexanderMath commented Sep 4, 2023

nanoDFT computes forces on the CPU using def grad(..) on line 230. To run def grad(..) on the IPU, it is sufficient to port lines 269-273 and line 283.

Different strategies for porting lines 269-273:

  1. Compile libcint to Poplar and replace all mol.intor(..) calls with the corresponding Poplar calls (the ERI is the only problematic part).
  2. Use the JAX implementation from D4FT for the forward pass of mol.intor(..) and match up the jax.grad(..) of those forward passes with lines 269-273; see the sketch after this list (pyscfad already matched up libcint with jax.grad on CPU, so their code may be helpful).
  3. Reimplement all integrals from first principles in JAX/tessellate.
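
A minimal sketch of the consistency check behind strategy 2, using a toy 1D Gaussian overlap integral (overlap_1d and overlap_1d_grad_A are hypothetical helpers, not the real libcint integrals): jax.grad(..) of the JAX forward pass should reproduce the derivative that the explicit *_ip-style derivative integrals provide.

```python
import jax
import jax.numpy as jnp

def overlap_1d(A, B, a=0.8, b=1.2):
    # Overlap of two unnormalised 1D Gaussians exp(-a(x-A)^2) and exp(-b(x-B)^2),
    # via the Gaussian product theorem.
    p = a + b
    return jnp.sqrt(jnp.pi / p) * jnp.exp(-a * b / p * (A - B) ** 2)

def overlap_1d_grad_A(A, B, a=0.8, b=1.2):
    # Analytic derivative with respect to the first centre, i.e. the quantity an
    # explicit derivative-integral routine would return.
    p = a + b
    return overlap_1d(A, B, a, b) * (-2.0 * a * b / p) * (A - B)

A, B = 0.3, 1.1
autodiff = jax.grad(overlap_1d)(A, B)   # derivative obtained from the forward pass
analytic = overlap_1d_grad_A(A, B)      # derivative obtained from the explicit formula
assert jnp.allclose(autodiff, analytic)
```

The same kind of check against mol.intor(..)'s derivative integrals would validate a D4FT-based forward pass before wiring it into def grad(..).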

Note: Line 230 uses this theorem to compute gradients. We could use jax.grad(_nanoDFT) instead of the theorem. That would require fixing all calls inside _nanoDFT(..) which don't support derivatives. We currently believe the work involved is the same as fixing def grad(..) (see the strategies above). In other words: the calls inside _nanoDFT that don't support autodiff are exactly the ones whose derivatives are computed explicitly on lines 269-273 and 283.
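
A minimal sketch of the jax.grad route, assuming a toy energy that contains only the nuclear-nuclear repulsion term (already pure JAX); the real _nanoDFT would additionally need every mol.intor(..) call replaced by something differentiable, which is exactly the work listed above.

```python
import jax
import jax.numpy as jnp

def toy_energy(coords, charges):
    # Toy stand-in for _nanoDFT: nuclear-nuclear repulsion only, sum_{i<j} q_i q_j / r_ij.
    diff = coords[:, None, :] - coords[None, :, :]
    n = charges.shape[0]
    dist = jnp.sqrt(jnp.sum(diff ** 2, axis=-1) + jnp.eye(n))  # guard the diagonal
    pair = charges[:, None] * charges[None, :] / dist
    return 0.5 * jnp.sum(pair * (1.0 - jnp.eye(n)))

coords = jnp.array([[0.0, 0.0, 0.0], [0.0, 0.0, 1.4]])  # H2-like geometry (Bohr)
charges = jnp.array([1.0, 1.0])

forces = -jax.grad(toy_energy)(coords, charges)  # forces = -dE/dR, shape (2, 3)
print(forces)
```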

AlexanderMath commented Sep 4, 2023

Note: If we use autodiff/jax.grad we need to store "activations" during jax.lax.fori_loop(0, opts.its, nanoDFT_iteration, ..). We could use jax.checkpoint to store only the density_matrix of shape (N, N) in each iteration, and then, during backprop, keep all activations within a single iteration in memory. This leads to a peak memory consumption of roughly N^2 * num_iterations + floats_within_one_iteration.
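
A minimal sketch of that checkpointing pattern, with scf_like_iteration as a hypothetical stand-in for nanoDFT_iteration. Note that reverse-mode differentiation through jax.lax.fori_loop requires a static trip count (Python int bounds).

```python
import jax
import jax.numpy as jnp

N, num_iterations = 8, 20  # static trip count so fori_loop supports reverse-mode

def scf_like_iteration(density_matrix, hamiltonian):
    # Stand-in for one nanoDFT_iteration: build an effective Hamiltonian from the
    # current density matrix, diagonalise it, and assemble the new density matrix.
    fock = hamiltonian + 0.1 * density_matrix @ density_matrix
    _, vectors = jnp.linalg.eigh(fock)
    occupied = vectors[:, : N // 2]
    return 2.0 * occupied @ occupied.T

# jax.checkpoint around the loop body: only the (N, N) density_matrix carried between
# iterations is stored; activations inside one iteration are recomputed during backprop.
@jax.checkpoint
def body(i, carry):
    density_matrix, hamiltonian = carry
    return scf_like_iteration(density_matrix, hamiltonian), hamiltonian

def energy(hamiltonian):
    carry = (jnp.eye(N), hamiltonian)
    density_matrix, _ = jax.lax.fori_loop(0, num_iterations, body, carry)
    return jnp.sum(density_matrix * hamiltonian)  # toy scalar objective

H = jnp.diag(jnp.arange(N, dtype=jnp.float32))
print(jax.grad(energy)(H).shape)  # (8, 8): backprop through all 20 iterations
```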

Note: It takes 3x more memory to store ERI_grad = mol.intor("int2e_ip") because ERI_grad.shape = (3, N, N, N, N). ERI_grad is only used twice, see lines 244-245 (a similar einsum to how ERI is used):

```python
vj = - jnp.einsum('sijkl,lk->sij', ERI_grad, dm0)  # (3, N, N)
vk = - jnp.einsum('sijkl,jk->sil', ERI_grad, dm0)  # (3, N, N)
```

Since we only have "one iteration of einsums" (as opposed to ~20 in the forward pass), there is no advantage in storing ERI_grad. We might as well compute the needed entries on the fly during the einsums (see the sketch below). This may cause problems with the trick in #65. We may be able to reuse the sparsity pattern found in #63.
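
A minimal sketch of that blocked, on-the-fly contraction. Here eri_grad_slab (a hypothetical helper) just slices a random dense tensor so the example runs end to end; in practice each slab would be produced directly by the integral kernel, so the full (3, N, N, N, N) ERI_grad is never materialised.

```python
import jax
import jax.numpy as jnp

N, block = 16, 4
dm0 = jax.random.normal(jax.random.PRNGKey(0), (N, N))

# Dense reference tensor, only so the sketch is self-contained and checkable.
ERI_grad_full = jax.random.normal(jax.random.PRNGKey(1), (3, N, N, N, N))

def eri_grad_slab(i0, i1):
    # Placeholder: returns the (3, block, N, N, N) slab for orbitals i0:i1. A real
    # implementation would compute these entries on the fly instead of slicing.
    return ERI_grad_full[:, i0:i1]

vj = jnp.zeros((3, N, N))
vk = jnp.zeros((3, N, N))
for i0 in range(0, N, block):
    slab = eri_grad_slab(i0, i0 + block)
    # Same contractions as lines 244-245, restricted to one block of the i index.
    vj = vj.at[:, i0:i0 + block].set(-jnp.einsum('sijkl,lk->sij', slab, dm0))
    vk = vk.at[:, i0:i0 + block].set(-jnp.einsum('sijkl,jk->sil', slab, dm0))

# Agreement with the dense contraction over the full ERI_grad tensor.
assert jnp.allclose(vj, -jnp.einsum('sijkl,lk->sij', ERI_grad_full, dm0), atol=1e-4)
assert jnp.allclose(vk, -jnp.einsum('sijkl,jk->sil', ERI_grad_full, dm0), atol=1e-4)
```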

AlexanderMath commented

Another implementation (watch out for the license): https://github.com/theochem/gbasis

hatemhelal linked a pull request on Sep 20, 2023 that will close this issue.