
WIP: nuclear gradients #101 (Draft)

Wants to merge 15 commits into base: main
Conversation

hatemhelal (Contributor) commented Sep 20, 2023

To perform tasks such as geometry optimisation or molecular dynamics we need a way to compute the gradient of the total energy with respect to the nuclear centers. This PR adds the analytic evaluation of the gradient of each component of the energy (a standalone sketch of the nuclear-nuclear term follows the list):

  • overlap -> $\langle \nabla i | j \rangle$
  • kinetic energy -> $-\frac{1}{2} \langle \nabla i | \nabla^2 | j \rangle$
  • nuclear attraction -> $\langle \nabla i | V_N | j \rangle$
  • two-electron integrals -> $\langle \nabla i j | k l \rangle$
  • exchange-correlation -> TBD
  • nuclear-nuclear energy -> $\nabla V_{NN}$
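As a standalone illustration of the simplest term, the nuclear-nuclear gradient $\nabla V_{NN}$ can be evaluated with automatic differentiation and used as a reference. This is a sketch, not code from the PR; `nuclear_energy` and its arguments are hypothetical names:

import jax
import jax.numpy as jnp

def nuclear_energy(positions, charges):
    # V_NN = sum_{i<j} Z_i * Z_j / |R_i - R_j|
    n = charges.shape[0]
    rij = positions[:, None, :] - positions[None, :, :]
    # pad the diagonal so the norm (and its gradient) stays finite there
    dist = jnp.linalg.norm(rij + jnp.eye(n)[..., None], axis=-1)
    zz = charges[:, None] * charges[None, :]
    off_diagonal = ~jnp.eye(n, dtype=bool)
    return 0.5 * jnp.sum(jnp.where(off_diagonal, zz / dist, 0.0))

# reference gradient with respect to the nuclear centers
grad_vnn = jax.grad(nuclear_energy)

# example: H2 at 1.4 bohr separation
pos = jnp.array([[0.0, 0.0, 0.0], [0.0, 0.0, 1.4]])
z = jnp.array([1.0, 1.0])
print(grad_vnn(pos, z))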

@hatemhelal hatemhelal linked an issue Sep 20, 2023 that may be closed by this pull request
@hatemhelal hatemhelal self-assigned this Sep 21, 2023
def grad_primitive_integral(
    primitive_op: Callable, atom_index: int, a: Primitive, b: Primitive
) -> Float3:
    """Generic gradient of a one-electron integral with respect to the atom_index center"""
Collaborator:

Would be very nice to link to a derivation in sympy if possible.

Contributor (Author):

I had a half-hearted attempt at doing this in sympy but found it difficult to encourage it to collect terms together. A derivation would be easy enough to write out in the gto_integrals notebook.
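For reference, a minimal sympy sketch (not from the PR) of the 1D derivative rule that the generic code implements; the symbol names are arbitrary:

import sympy as sp

x, A = sp.symbols("x A", real=True)
alpha = sp.symbols("alpha", positive=True)
l = sp.symbols("l", integer=True, nonnegative=True)

# 1D Cartesian Gaussian primitive centered at A
g = (x - A) ** l * sp.exp(-alpha * (x - A) ** 2)

# differentiating with respect to the center gives
#   d/dA g_l = 2*alpha*g_{l+1} - l*g_{l-1}
# i.e. the promote/demote angular momentum rule used in the code below
print(sp.simplify(sp.diff(g, A)))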

Comment on lines 19 to 26
# t1: promote the angular momentum by one along each axis -> 2*alpha*<i+1|op|j>
lhs_p1 = vmap(a.offset_lmn, (0, None))(axes, 1)
t1 = 2 * a.alpha * vmap(primitive_op, (0, None))(lhs_p1, b)

# t2: demote the angular momentum by one -> l*<i-1|op|j>, with the
# coefficient zeroed on axes where l == 0
lhs_m1 = vmap(a.offset_lmn, (0, None))(axes, -1)
t2 = jnp.where(a.lmn > 0, a.lmn, jnp.zeros_like(a.lmn))
t2 *= vmap(primitive_op, (0, None))(lhs_m1, b)
Collaborator:

I realise this is sacrilege, but I wonder if a Python loop over 0, 1, 2 might be easier to read, and ultimately be no slower, as it will be vectorisable by a higher-level vmap.
E.g. something like (artist's impression):

grad = jnp.zeros(3)
for axis in range(3):
    t1 = 2 * a.alpha * primitive_op(a.offset_lmn(axis, 1), b)
    t2 = a.lmn[axis] * primitive_op(a.offset_lmn(axis, -1), b)
    grad = grad.at[axis].set(t1 - t2)

Collaborator:

Or if that turns out to be slower, it might serve as reference code in another function so testing can compare them.

Contributor (Author):

trying out a list comprehension like this:

    t1 = [primitive_op(a.offset_lmn(ax, 1), b) for ax in range(3)]
    t2 = [primitive_op(a.offset_lmn(ax, -1), b) for ax in range(3)]
    grad_out = 2 * a.alpha * jnp.stack(t1) - a.lmn * jnp.stack(t2)
    return grad_out

better or worse?

Collaborator:

Definitely easier to read - what's the perf impact?

Contributor (Author):

I did a very quick benchmark using the test case on the CPU backend:

|        | no-jit           | with jit         |
|--------|------------------|------------------|
| before | 1.06 s ± 2.43 ms | 359 µs ± 1.93 µs |
| after  | 2.77 s ± 9.36 ms | 368 µs ± 2.65 µs |

so arguably no difference when we use JIT but quite a bit slower without. Happy to go with the more readable code over the "vectorise or die trying" version.
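For orientation, the pieces from this thread might assemble into something like the following. This is a sketch rather than the PR's final code; `Primitive`, `Float3`, `offset_lmn`, and the masking on `atom_index` (discussed further below) are assumed from the surrounding module:

from typing import Callable

import jax.numpy as jnp

def grad_primitive_integral(
    primitive_op: Callable, atom_index: int, a: "Primitive", b: "Primitive"
) -> "Float3":
    """Gradient of a one-electron integral with respect to the atom_index center."""
    t1 = [primitive_op(a.offset_lmn(ax, 1), b) for ax in range(3)]
    t2 = [primitive_op(a.offset_lmn(ax, -1), b) for ax in range(3)]
    grad_out = 2 * a.alpha * jnp.stack(t1) - a.lmn * jnp.stack(t2)
    # only a primitive centered on atom_index contributes to this gradient
    return jnp.where(a.atom_index == atom_index, grad_out, jnp.zeros(3))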



def grad_primitive_integral(
    primitive_op: Callable, atom_index: int, a: Primitive, b: Primitive
Collaborator:

What's the signature of this Callable? It's great to point to an example, even from this more generic code to the less generic.

Contributor (Author):

I've added some doc-strings throughout this file and tried to document the signature required by callables being passed around.
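An illustrative example of the kind of signature documentation described (hypothetical wording, not copied from the PR):

def overlap_primitives(a: "Primitive", b: "Primitive") -> float:
    """Evaluate the overlap <a|b> between two Gaussian primitives.

    Functions passed as the primitive_op argument of grad_primitive_integral
    are expected to share this (Primitive, Primitive) -> scalar signature.
    """
    ...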

pyscf_ipu/experimental/nuclear_gradients.py: two outdated review threads, resolved
Comment on lines +32 to +103
def take_primitives(indices):
    p = tree_map(lambda x: jnp.take(x, indices, axis=0), primitives)
    c = jnp.take(coefficients, indices)
    return p, c
Collaborator:

I like to define these local lambdas after the variables to which they refer have been defined. It puts the computation closer to the point of use.

(i.e. move it after line 37)

OTOH, batch_orbitals is already doing a lot of Python list comprehensions, so it might be clearer and the same complexity to treat b as a list of lists of primitives, and just inline the list comprehensions here?

Contributor (Author):

I'll have a think; this lambda is actually used in a few places now, so maybe it should be promoted from a function-local to a module-level definition?
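A hypothetical module-level version of the helper, along the lines the author is considering (sketch only; `primitives` is assumed to be a pytree of stacked `Primitive` fields):

import jax.numpy as jnp
from jax.tree_util import tree_map

def take_primitives(primitives, coefficients, indices):
    """Gather a subset of primitives and their coefficients by index."""
    p = tree_map(lambda x: jnp.take(x, indices, axis=0), primitives)
    c = jnp.take(coefficients, indices)
    return p, c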

rhs, cr = take_primitives(jj.reshape(-1))

op = vmap(primitive_op, (None, 0, 0))
op = jit(vmap(op, (0, None, None)))
Collaborator:

Does this jit change much? It feels a bit out of place, but I get that it might be needed.

Contributor (Author):

Hard to say... I've removed it for now since I'd like to compare jit to no-jit when this is a bit more complete.
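The underlying point is that jit composes from the outside, so the inner jit can be deferred to the call site. A minimal self-contained illustration (unrelated to the PR's operators):

import jax
import jax.numpy as jnp
from jax import vmap

f = lambda x, y: jnp.dot(x, y)
op = vmap(vmap(f, (None, 0)), (0, None))  # all-pairs over two batches
op_jit = jax.jit(op)                      # jit applied at the outermost level

xs = jnp.ones((4, 3))
ys = jnp.ones((5, 3))
assert op_jit(xs, ys).shape == (4, 5)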

pyscf_ipu/experimental/nuclear_gradients.py: outdated review thread, resolved
@@ -14,6 +15,7 @@ class Primitive:
     alpha: float = 1.0
     lmn: Int3 = jnp.zeros(3, dtype=jnp.int32)
     norm: Optional[float] = None
+    atom_index: Optional[int] = None
Collaborator:

This is a back-pointer into the list of which this is a member?
That's also a hint that the where(a.atom_index == atom_index) could be lifted.

Contributor (Author):

Indeed; the problem I haven't solved is how to avoid redundantly storing the center on both the primitives and in the Structure.
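One hypothetical way to resolve that duplication: keep the centers only on the Structure and derive a primitive's center from its atom_index on demand (sketch; the Structure field name is assumed):

from dataclasses import dataclass

import jax.numpy as jnp

@dataclass
class Structure:
    position: jnp.ndarray  # shape (num_atoms, 3), one row per nuclear center

def primitive_center(structure: Structure, atom_index: int) -> jnp.ndarray:
    # a Primitive would then carry only atom_index; no duplicated center field
    return structure.position[atom_index]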

@hatemhelal force-pushed the 66-compute-forces-on-the-ipu branch 2 times, most recently from 27bf761 to 417ad4c on September 26, 2023 at 08:39
@hatemhelal force-pushed the 66-compute-forces-on-the-ipu branch from 417ad4c to 58e6383 on October 16, 2023 at 14:56
Successfully merging this pull request may close these issues.

Compute forces on the IPU.
2 participants