Skip to content

Commit

Permalink
fix exp ctr_coeff deriv
Browse files Browse the repository at this point in the history
  • Loading branch information
fishjojo committed Jan 10, 2024
1 parent b2dcfcb commit 9c0c022
Show file tree
Hide file tree
Showing 2 changed files with 83 additions and 1 deletion.
6 changes: 6 additions & 0 deletions pyscfad/gto/moleintor.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,12 @@ def _getints2c_rc_jvp(intor, shls_slice, comp, hermi, out, rc_deriv,
if mol.coords is not None:
intor_ip_bra, intor_ip_ket = int1e_dr1_name(intor)
tangent_out += _gen_int1e_jvp_r0(mol, mol_t, intor_ip_bra, intor_ip_ket, rc_deriv)

if mol.ctr_coeff is not None:
tangent_out += _int1e_jvp_cs(mol, mol_t, intor, shls_slice, comp, hermi)

if mol.exp is not None:
tangent_out += _int1e_jvp_exp(mol, mol_t, intor, shls_slice, comp, hermi)
return primal_out, tangent_out

@partial(custom_jvp, nondiff_argnums=(1,2,3,4,5))
Expand Down
78 changes: 77 additions & 1 deletion pyscfad/gto/test/test_deriv1_cs_exp.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
import pytest
import jax
from functools import partial
import numpy as np
import jax
from jax import numpy as jnp
from pyscfad import gto
from pyscfad.gto._mole_helper import setup_exp, setup_ctr_coeff
from .test_int1e import grad_analyt, nuc_grad_analyt

INTORS = ['int1e_kin_dr10', 'int1e_kin_dr01',
'int1e_ovlp_dr10', 'int1e_ovlp_dr01',
Expand Down Expand Up @@ -39,6 +43,35 @@ def exp_grad_fd(mol, intor):
g = four_point_fd(mol, intor, _env_of, disp)
return g

def two_point_fd(mol, fn, _env_of, disp=1e-4):
grad_fd = []
for _, ptr_exp in enumerate(_env_of):
mol._env[ptr_exp] += disp
sp = fn(mol)

mol._env[ptr_exp] -= 2*disp
sm = fn(mol)

g = (sp-sm) / (2*disp)
grad_fd.append(g)
mol._env[ptr_exp] += disp

grad_fd = np.asarray(grad_fd)
grad_fd = np.moveaxis(grad_fd, 0, -1)
return grad_fd

def cs_grad_fd1(mol, fn):
disp = 1e-3
_, _, _env_of = setup_ctr_coeff(mol)
g = two_point_fd(mol, fn, _env_of, disp)
return g

def exp_grad_fd1(mol, fn):
disp = 1e-4
_, _, _env_of = setup_exp(mol)
g = two_point_fd(mol, fn, _env_of, disp)
return g

def test_cs_exp():
mol = gto.Mole()
mol.atom = 'H 0 0 0; H 0 0 .74' # in Angstrom
Expand All @@ -60,3 +93,46 @@ def test_cs_exp():
assert e_fwd < tol
e_rev1 = abs(jac_rev.exp - g_fd).max()
assert e_rev < tol

def test_chain_deriv():
mol = gto.Mole()
mol.atom = 'H 0 0 0; H 0 0 .74' # in Angstrom
mol.basis = 'sto3g'
mol.build()

def grad_loss_ad(mol, intor):
def fn(mol):
ints = mol.intor(intor)
return jnp.linalg.norm(ints)
g = jax.grad(fn)(mol).coords
g_norm = jnp.linalg.norm(g)
return g_norm

def grad_loss_analyt(mol, intor):
ints = mol.intor(intor)
jac = grad_analyt(mol, intor.replace('int1e_', 'int1e_ip'))
g = np.einsum("ij,ijnx->nx", ints, jac) / np.linalg.norm(ints)
g_norm = np.linalg.norm(g)
return g_norm

def nuc_grad_loss_analyt(mol):
ints = mol.intor('int1e_nuc')
jac = nuc_grad_analyt(mol)
g = np.einsum("ij,ijnx->nx", ints, jac) / np.linalg.norm(ints)
g_norm = np.linalg.norm(g)
return g_norm

tol = 1e-6

for intor in ['int1e_kin', 'int1e_ovlp']:
grad_cs = cs_grad_fd1(mol, partial(grad_loss_analyt, intor=intor))
grad_exp = exp_grad_fd1(mol, partial(grad_loss_analyt, intor=intor))
jac = jax.grad(grad_loss_ad)(mol, intor)
assert abs(jac.ctr_coeff - grad_cs).max() < tol
assert abs(jac.exp - grad_exp).max() < tol

grad_cs = cs_grad_fd1(mol, nuc_grad_loss_analyt)
grad_exp = exp_grad_fd1(mol, nuc_grad_loss_analyt)
jac = jax.grad(grad_loss_ad)(mol, 'int1e_nuc')
assert abs(jac.ctr_coeff - grad_cs).max() < tol
assert abs(jac.exp - grad_exp).max() < tol

0 comments on commit 9c0c022

Please sign in to comment.