WIP: nuclear gradients #101
base: main
Conversation
```python
def grad_primitive_integral(
    primitive_op: Callable, atom_index: int, a: Primitive, b: Primitive
) -> Float3:
    """Generic gradient of a one-electron integral with respect to the atom_index center"""
```
Would be very nice to link to a derivation in sympy if possible.
I had a half-hearted attempt at doing this in sympy but found it difficult to encourage it to collect terms together. A derivation would be easy enough to write out in the gto_integrals notebook.
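For the one-dimensional building block, the identity is small enough for sympy to verify directly without any term-collecting gymnastics. A hedged sketch (the symbols and `g` helper are illustrative, not from this repo):

```python
# Hedged sketch: verify the 1D gradient rule for a Cartesian Gaussian
# primitive g_l(r; A) = (r - A)^l * exp(-alpha * (r - A)^2) with sympy:
#   d/dA g_l = 2*alpha*g_{l+1} - l*g_{l-1}
# which is the identity behind the raised/lowered angular-momentum terms.
import sympy as sp

r, A, alpha = sp.symbols("r A alpha", real=True)
l = 2  # example angular momentum component

def g(n):
    return (r - A) ** n * sp.exp(-alpha * (r - A) ** 2)

lhs = sp.diff(g(l), A)
rhs = 2 * alpha * g(l + 1) - l * g(l - 1)
assert sp.simplify(lhs - rhs) == 0
```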
```python
lhs_p1 = vmap(a.offset_lmn, (0, None))(axes, 1)
t1 = 2 * a.alpha * vmap(primitive_op, (0, None))(lhs_p1, b)

lhs_m1 = vmap(a.offset_lmn, (0, None))(axes, -1)
t2 = jnp.where(a.lmn > 0, a.lmn, jnp.zeros_like(a.lmn))
t2 *= vmap(primitive_op, (0, None))(lhs_m1, b)
```
I realise this is sacrilege, but I wonder if a Python loop over 0, 1, 2 might be easier to read, and ultimately no slower, as it will be vectorisable by a higher-level vmap. E.g. something like (artist's impression):
```python
grad = jnp.zeros(3)
for axis in range(3):
    t1 = 2 * a.alpha * primitive_op(a.offset_lmn(axis, 1), b)
    t2 = a.lmn[axis] * primitive_op(a.offset_lmn(axis, -1), b)
    grad = grad.at[axis].set(t1 - t2)
```
Or if that turns out to be slower, it might serve as reference code in another function so testing can compare them.
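A minimal sketch of that reference-implementation testing pattern, with a toy `op` standing in for the real primitive integral (all names here are illustrative, not the PR's API):

```python
import jax.numpy as jnp
from jax import vmap

def op(x, p):
    # toy stand-in for a primitive integral
    return jnp.sum(x) * p

xs = jnp.arange(9.0).reshape(3, 3)

def grad_loop(xs, p):
    # readable reference version: plain Python loop over the three axes
    out = jnp.zeros(3)
    for axis in range(3):
        out = out.at[axis].set(op(xs[axis], p))
    return out

def grad_vmap(xs, p):
    # vectorised version under test
    return vmap(op, (0, None))(xs, p)

# the test simply compares the two implementations
assert jnp.allclose(grad_loop(xs, 2.0), grad_vmap(xs, 2.0))
```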
Trying out a list comprehension like this:

```python
t1 = [primitive_op(a.offset_lmn(ax, 1), b) for ax in range(3)]
t2 = [primitive_op(a.offset_lmn(ax, -1), b) for ax in range(3)]
grad_out = 2 * a.alpha * jnp.stack(t1) - a.lmn * jnp.stack(t2)
return grad_out
```
better or worse?
Definitely easier to read - what's the perf impact?
I did a very quick benchmark using the test case on the CPU backend:
| | no-jit | with jit |
| --- | --- | --- |
| before | 1.06 s ± 2.43 ms | 359 µs ± 1.93 µs |
| after | 2.77 s ± 9.36 ms | 368 µs ± 2.65 µs |

so arguably no difference when we use JIT, but quite a bit slower without. Happy to go with the more readable code over the "vectorise or die trying" version.
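For anyone reproducing numbers like these, a hedged sketch of a fair jit/no-jit timing in JAX (toy function, not the PR's kernel): `block_until_ready` is needed because JAX dispatches asynchronously, and the first jitted call should run outside the timed region so compilation isn't counted.

```python
import timeit

import jax
import jax.numpy as jnp

def f(x):
    # toy stand-in for the gradient kernel
    return jnp.sin(x) ** 2

xs = jnp.linspace(0.0, 1.0, 1024)

f_jit = jax.jit(f)
f_jit(xs).block_until_ready()  # warm up: compile outside the timed region

t_nojit = timeit.timeit(lambda: f(xs).block_until_ready(), number=100)
t_jit = timeit.timeit(lambda: f_jit(xs).block_until_ready(), number=100)
print(f"no-jit: {t_nojit:.4f}s, jit: {t_jit:.4f}s")
```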
```python
def grad_primitive_integral(
    primitive_op: Callable, atom_index: int, a: Primitive, b: Primitive
```
What's the signature of this Callable? It would be great to point to an example, even linking from this more generic code to the less generic call sites.
I've added some doc-strings throughout this file and tried to document the signature required by callables being passed around.
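One hedged option for documenting that signature statically, beyond docstrings, is a typing.Protocol (the names here are illustrative stand-ins, not the PR's actual types):

```python
from dataclasses import dataclass
from typing import Protocol

@dataclass
class Primitive:
    # minimal stand-in for the real Primitive class
    alpha: float = 1.0

class PrimitiveOp(Protocol):
    """Signature expected of the primitive_op callables passed around."""
    def __call__(self, a: Primitive, b: Primitive) -> float: ...

def toy_overlap(a: Primitive, b: Primitive) -> float:
    return a.alpha * b.alpha  # toy body, not a real integral

# type checkers verify toy_overlap satisfies the protocol
op: PrimitiveOp = toy_overlap
assert op(Primitive(2.0), Primitive(3.0)) == 6.0
```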
```python
def take_primitives(indices):
    p = tree_map(lambda x: jnp.take(x, indices, axis=0), primitives)
    c = jnp.take(coefficients, indices)
    return p, c
```
I like to define these local lambdas after the variables to which they refer have been defined. It puts the computation closer to the point of use (i.e. move it after line 37).

OTOH, `batch_orbitals` is already doing a lot of Python list comprehensions, so it might be clearer, and the same complexity, to treat `b` as a list of lists of primitives and just inline the list comprehensions here?
I'll have a think; this lambda is actually used in a few places now -> maybe it should be promoted from a function-local definition to a module-level one?
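A hedged sketch of what the promoted, module-level helper might look like, taking its former closure variables as explicit arguments (names illustrative):

```python
import jax.numpy as jnp
from jax import tree_util

def take_primitives(primitives, coefficients, indices):
    """Gather primitive leaves and coefficients at the given flat indices."""
    p = tree_util.tree_map(lambda x: jnp.take(x, indices, axis=0), primitives)
    c = jnp.take(coefficients, indices)
    return p, c

# toy pytree of primitive attributes
prims = {"alpha": jnp.array([0.5, 1.0, 1.5, 2.0])}
coefs = jnp.array([1.0, 2.0, 3.0, 4.0])

p, c = take_primitives(prims, coefs, jnp.array([0, 2]))
assert jnp.allclose(p["alpha"], jnp.array([0.5, 1.5]))
assert jnp.allclose(c, jnp.array([1.0, 3.0]))
```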
```python
rhs, cr = take_primitives(jj.reshape(-1))

op = vmap(primitive_op, (None, 0, 0))
op = jit(vmap(op, (0, None, None)))
```
Does this jit change much? It feels a bit out of place, but I get that it might be needed.
Hard to say... I've removed it for now since I'd like to compare jit to no-jit when this is a bit more complete.
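For readers following along, a hedged sketch of what the nested vmap pattern computes: an all-pairs (outer-product) map over the two primitive batches (toy op and shapes, not the PR's):

```python
import jax.numpy as jnp
from jax import vmap

def op(x, y, c):
    # toy stand-in for primitive_op
    return c * jnp.dot(x, y)

lhs = jnp.ones((4, 3))
rhs = jnp.ones((5, 3))
coef = jnp.arange(5.0)

inner = vmap(op, (None, 0, 0))        # maps over rhs rows and coefficients
outer = vmap(inner, (0, None, None))  # maps over lhs rows
out = outer(lhs, rhs, coef)

assert out.shape == (4, 5)                # one entry per (lhs, rhs) pair
assert jnp.allclose(out[0], 3.0 * coef)   # dot of ones(3) with ones(3) is 3
```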
```diff
@@ -14,6 +15,7 @@ class Primitive:
     alpha: float = 1.0
     lmn: Int3 = jnp.zeros(3, dtype=jnp.int32)
     norm: Optional[float] = None
+    atom_index: Optional[int] = None
```
This is a back-pointer into the list of which this is a member? That's also a hint that the `where(a.atom_index == atom_index)` could be lifted.
Indeed; the problem I haven't solved is how to avoid redundantly storing the center on both the primitives and in the Structure.
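As a hedged illustration of the selection being discussed (toy data; the real code operates on batched Primitive attributes), the where-mask keeps only the contributions of primitives centred on the requested atom:

```python
import jax.numpy as jnp

atom_index = 1
prim_atoms = jnp.array([0, 0, 1, 1, 2])  # back-pointers: one atom per primitive
contribs = jnp.array([10.0, 20.0, 1.0, 2.0, 40.0])

# zero out contributions from primitives centred on other atoms
masked = jnp.where(prim_atoms == atom_index, contribs, 0.0)
assert float(masked.sum()) == 3.0  # only the two primitives on atom 1 survive
```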
To perform tasks such as geometry optimisation or molecular dynamics, we need a way to compute the gradient of the total energy with respect to the nuclear centers. This PR adds the analytic evaluation of gradients of components of the energy: