nanoDFT computes forces on the CPU using `def grad(..)` on line 230. To run `def grad(..)` on the IPU it is sufficient to port lines 269-273 and line 283.

Different strategies for porting lines 269-273:
1. Compile libcint to Poplar and replace all `mol.intor(..)` calls with corresponding Poplar calls (the ERI is the only problematic part).
2. Use the JAX implementation from D4FT for the forward pass of `mol.intor(..)` and match up the `jax.grad(..)` of those forward passes with lines 269-273 (pyscfad already matched libcint with `jax.grad` on CPU, so their code may be helpful); see the sketch after this list.
3. Reimplement all integrals from first principles in JAX/Tessellate.
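
A minimal sketch of the idea behind strategy 2, using a toy analytic overlap integral between two s-type Gaussians instead of the real `mol.intor(..)` tensors; the function name `s_overlap` and the finite-difference check are illustrative and not part of nanoDFT:

```python
import jax
import jax.numpy as jnp
jax.config.update("jax_enable_x64", True)

def s_overlap(A, B, a=0.5, b=0.8):
    # Overlap of two normalised s-type Gaussians with exponents a, b at centres A, B:
    # S = (4ab/(a+b)^2)^(3/4) * exp(-(ab/(a+b)) * |A-B|^2)
    mu = a * b / (a + b)
    return (4 * a * b / (a + b) ** 2) ** 0.75 * jnp.exp(-mu * jnp.sum((A - B) ** 2))

A = jnp.array([0.0, 0.0, 0.0])
B = jnp.array([0.0, 0.0, 1.4])

# jax.grad of the forward pass w.r.t. the first centre; for the real integrals this
# is the quantity that has to match the derivative tensors used on lines 269-273.
dS_dA = jax.grad(s_overlap, argnums=0)(A, B)

# Finite-difference check of the z-component.
eps = 1e-6
step = jnp.array([0.0, 0.0, eps])
fd = (s_overlap(A + step, B) - s_overlap(A - step, B)) / (2 * eps)
print(dS_dA[2], fd)  # the two should agree closely
```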
Note: Line 230 uses this theorem to compute gradients. We could use `jax.grad(_nanoDFT)` instead of the theorem. That would require fixing all calls inside `_nanoDFT(..)` that don't support derivatives. We currently believe the work involved is the same as fixing `def grad(..)` (see the strategies above). In other words: the non-autodiff calls inside `_nanoDFT` are exactly the ones whose derivatives are computed on lines 269-273 and 283.
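
For illustration, a sketch of what the `jax.grad(_nanoDFT)` route would look like, using the nuclear-repulsion energy as a differentiable stand-in for the full `_nanoDFT` energy (the names `nuclear_repulsion`, `coords` and `charges` are illustrative):

```python
import jax
import jax.numpy as jnp

def nuclear_repulsion(coords, charges):
    # E_nn = sum_{i<j} Z_i * Z_j / |R_i - R_j|; stand-in for the full _nanoDFT energy.
    diff = coords[:, None, :] - coords[None, :, :]
    dist = jnp.sqrt(jnp.sum(diff ** 2, axis=-1) + jnp.eye(len(charges)))  # keep the diagonal finite
    pair = charges[:, None] * charges[None, :] / dist
    return jnp.sum(jnp.triu(pair, k=1))

coords = jnp.array([[0.0, 0.0, 0.0], [0.0, 0.0, 1.4]])  # e.g. H2 in Bohr
charges = jnp.array([1.0, 1.0])

# Forces are minus the gradient of the energy w.r.t. the nuclear coordinates.
forces = -jax.grad(nuclear_repulsion, argnums=0)(coords, charges)
```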
Note: If we use autodiff/`jax.grad` we need to store "activations" during `jax.lax.fori_loop(0, opts.its, nanoDFT_iteration, ..)`. We might use `jax.checkpoint` to store only the density matrix of shape (N, N) in each iteration, and then keep all activations within a single iteration in memory during backprop. This leads to a peak memory consumption of roughly N^2 * num_iterations + floats_within_one_iteration.
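
A minimal sketch of that checkpointing pattern, written with `jax.lax.scan` for clarity and with a dummy (N, N) update standing in for the real `nanoDFT_iteration`:

```python
import jax
import jax.numpy as jnp

N, num_iterations = 8, 20  # illustrative sizes

def scf_iteration(density_matrix, _):
    # Stand-in for one nanoDFT_iteration: any jittable update of the (N, N) density matrix.
    new_dm = jnp.tanh(density_matrix @ density_matrix.T) / N
    return new_dm, None

def run_scf(dm0):
    # jax.checkpoint on the loop body => only the (N, N) carry is stored per iteration;
    # everything inside an iteration is recomputed during the backward pass.
    dm_final, _ = jax.lax.scan(jax.checkpoint(scf_iteration), dm0, xs=None, length=num_iterations)
    return jnp.sum(dm_final)  # scalar output so we can differentiate through the loop

grads = jax.grad(run_scf)(jnp.eye(N))  # peak memory ~ N^2 * num_iterations + one iteration
```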
Note: It takes 3x more memory to store `ERI_grad = mol.intor("int2e_ip")` because `ERI_grad.shape = (3, N, N, N, N)`. ERI_grad is only used twice, see lines 244-245 (a similar einsum to how ERI is used):
```python
vj = - jnp.einsum('sijkl,lk->sij', ERI_grad, dm0)  # (3, N, N)
vk = - jnp.einsum('sijkl,jk->sil', ERI_grad, dm0)  # (3, N, N)
```
Since we only have "one iteration of einsums" (as opposed to 20 iterations reusing ERI in the forward pass), there is no advantage in storing ERI_grad. We might as well compute the needed entries on the fly during the einsums; see the sketch below. This may cause problems with the trick in #65. We may reuse the sparsity pattern found in #63.
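
A sketch of the on-the-fly variant, assuming a hypothetical `eri_grad_block(i0, i1)` that produces one row block of ERI_grad at a time (e.g. from libcint shell slices or a JAX integral kernel); it ignores the symmetry trick in #65 and the sparsity from #63:

```python
import jax.numpy as jnp

def vj_vk_blockwise(eri_grad_block, dm0, N, block_size=16):
    # Accumulate vj/vk one row block at a time so the full (3, N, N, N, N)
    # ERI_grad tensor never has to be materialised.
    vj = jnp.zeros((3, N, N))
    vk = jnp.zeros((3, N, N))
    for i0 in range(0, N, block_size):
        i1 = min(i0 + block_size, N)
        g = eri_grad_block(i0, i1)                               # (3, i1-i0, N, N, N)
        vj = vj.at[:, i0:i1].set(-jnp.einsum('sijkl,lk->sij', g, dm0))
        vk = vk.at[:, i0:i1].set(-jnp.einsum('sijkl,jk->sil', g, dm0))
    return vj, vk

# Dense placeholder just to show the call; a real eri_grad_block would compute its slice on the fly.
N = 8
ERI_grad = jnp.zeros((3, N, N, N, N))
dm0 = jnp.eye(N)
vj, vk = vj_vk_blockwise(lambda i0, i1: ERI_grad[:, i0:i1], dm0, N)
```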