Skip to content

Commit

Permalink
add eval_gto exponent gradient
Browse files Browse the repository at this point in the history
  • Loading branch information
fishjojo committed Jan 8, 2024
1 parent 5e7d27d commit 0d420a7
Show file tree
Hide file tree
Showing 3 changed files with 112 additions and 18 deletions.
4 changes: 2 additions & 2 deletions pyscfad/gto/_mole_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,8 +178,8 @@ def setup_ctr_coeff(mol):
cs_of = cs_of[idx]
return cs, cs_of, env_of

def get_fakemol_exp(mol, order=2):
mol1 = uncontract(mol)
def get_fakemol_exp(mol, order=2, shls_slice=None):
mol1 = uncontract(mol, shls_slice=shls_slice)
mol1._bas[:,ANG_OF] += order
return mol1

Expand Down
90 changes: 78 additions & 12 deletions pyscfad/gto/eval_gto.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,11 @@
from pyscf.gto.eval_gto import eval_gto as pyscf_eval_gto
from pyscfad.lib import jit, custom_jvp, vmap
from pyscfad.gto import mole
#from pyscfad.gto.moleintor import get_bas_label
from pyscfad.gto.moleintor import get_bas_label, promote_xyz
from pyscfad.gto._mole_helper import (
# setup_exp,
setup_exp,
setup_ctr_coeff,
# get_fakemol_exp,
get_fakemol_exp,
get_fakemol_cs,
)

Expand Down Expand Up @@ -83,6 +83,9 @@ def _eval_gto(mol, eval_name, grid_coords,
@_eval_gto.defjvp
def _eval_gto_jvp(eval_name, comp, shls_slice, non0tab, ao_loc, cutoff, out,
primals, tangents):
if 'spinor' in eval_name or 'ip' in eval_name:
raise NotImplementedError

mol, grid_coords = primals
mol_t, grid_coords_t = tangents

Expand All @@ -107,15 +110,15 @@ def _eval_gto_jvp(eval_name, comp, shls_slice, non0tab, ao_loc, cutoff, out,

def _eval_gto_jvp_r(mol, eval_name, grid_coords, grid_coords_t,
comp, shls_slice, non0tab, ao_loc, nao):
if 'deriv'+str(_MAX_DERIV_ORDER) in eval_name:
raise NotImplementedError
if 'deriv' not in eval_name:
new_eval = eval_name + '_deriv1'
order = 0
new_eval = eval_name + '_deriv1'
else:
tmp = eval_name.split('deriv', 1)
order = int(tmp[1])
new_eval = tmp[0] + 'deriv' + str(order + 1)
if order + 1 > _MAX_DERIV_ORDER:
raise NotImplementedError

ng = grid_coords.shape[0]
ao1 = _eval_gto(mol, new_eval, grid_coords, None, shls_slice, non0tab,
Expand Down Expand Up @@ -203,15 +206,15 @@ def body(slices, coords_t):
return fn(-ao1, ids)

def _eval_gto_jvp_r0(mol, mol_t, eval_name, grid_coords, comp, shls_slice, non0tab, ao_loc):
if 'deriv'+str(_MAX_DERIV_ORDER) in eval_name:
raise NotImplementedError
if 'deriv' not in eval_name:
new_eval = eval_name + '_deriv1'
order = 0
new_eval = eval_name + '_deriv1'
else:
tmp = eval_name.split('deriv', 1)
order = int(tmp[1])
new_eval = tmp[0] + 'deriv' + str(order + 1)
if order + 1 > _MAX_DERIV_ORDER:
raise NotImplementedError

ao1 = _eval_gto(mol, new_eval, grid_coords, None, shls_slice, non0tab,
ao_loc, None, None)
Expand All @@ -235,8 +238,8 @@ def _eval_gto_jvp_cs(mol, mol_t, eval_name, grid_coords, comp, shls_slice, non0t
shl0, shl1 = shls_slice

mol1 = get_fakemol_cs(mol, shls_slice)
ao1 = _eval_gto(mol1, eval_name, grid_coords, None, None, non0tab,
ao_loc, None, None)
ao1 = _eval_gto(mol1, eval_name, grid_coords, comp, None, non0tab,
None, None, None)
ao1 = ao1.reshape(comp,ngrids,-1)

nao_id0, nao_id1 = mole.nao_nr_range(mol, shl0, shl1)
Expand Down Expand Up @@ -266,4 +269,67 @@ def _eval_gto_jvp_cs(mol, mol_t, eval_name, grid_coords, comp, shls_slice, non0t
return tangent_out

def _eval_gto_jvp_exp(mol, mol_t, eval_name, grid_coords, comp, shls_slice, non0tab, ao_loc):
return 0
if 'sph' in eval_name:
eval_name = eval_name.replace('sph', 'cart')
cart = False
elif 'cart' in eval_name:
cart = True
else:
raise KeyError

if shls_slice is None:
shls_slice = (0, mol.nbas)
shl0, shl1 = shls_slice

mol1 = get_fakemol_exp(mol, shls_slice=shls_slice)
ao1 = _eval_gto(mol1, eval_name, grid_coords, comp, None, non0tab,
None, None, None)

es, es_of, _env_of = setup_exp(mol)

ngrids = len(grid_coords)
nao = pyscf_mole.nao_cart(mol)
ao1 = ao1.reshape(comp, ngrids, -1)
# 1st order derivative only
grad = numpy.zeros((comp, ngrids, len(es), nao), dtype=ao1.dtype)

off = 0
ibas = 0
for i in range(shl0, shl1):
ioff = es_of[i]

l = mol._bas[i,pyscf_mole.ANG_OF]
nbas = (l+1)*(l+2)//2
nbas1 = (l+3)*(l+4)//2
nprim = mol._bas[i,pyscf_mole.NPRIM_OF]
nctr = mol._bas[i,pyscf_mole.NCTR_OF]
ptr_ctr_coeff = mol._bas[i,pyscf_mole.PTR_COEFF]
g = ao1[:,:,off:off+nprim*nbas1].reshape(comp, ngrids, nprim, nbas1)

xyz = get_bas_label(l)
xyz1 = get_bas_label(l+2)
for k in range(nctr):
for j in range(nprim):
c = mol._env[ptr_ctr_coeff + k*nprim + j]
if l == 0:
c *= 0.282094791773878143 # normalization factor for s orbital
elif l == 1:
c *= 0.488602511902919921 # normalization factor for p orbital
jbas = ibas
for orb in xyz:
idx_x = xyz1.index(promote_xyz(orb, 'x', 2))
idx_y = xyz1.index(promote_xyz(orb, 'y', 2))
idx_z = xyz1.index(promote_xyz(orb, 'z', 2))
gc = -(g[:,:,j,idx_x] + g[:,:,j,idx_y] + g[:,:,j,idx_z]) * c
grad[:,:,ioff+j,jbas] += gc
jbas += 1
ibas += nbas
off += nprim * nbas1

tangent_out = np.einsum('cgxi,x->cgi', grad, mol_t.exp)
if not cart:
c2s = np.asarray(mol.cart2sph_coeff())
tangent_out = np.einsum('cgp,pi->cgi', tangent_out, c2s)
if comp == 1:
tangent_out = tangent_out[0]
return tangent_out
36 changes: 32 additions & 4 deletions pyscfad/gto/test/test_eval_gto.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,12 @@ def cs_grad_fd(mol, eval_name, coords):
g = four_point_fd(mol, eval_name, coords, _env_of, disp)
return g

def exp_grad_fd(mol, eval_name, coords):
disp = 1e-4
_, _, _env_of = gto.mole.setup_exp(mol)
g = four_point_fd(mol, eval_name, coords, _env_of, disp)
return g

def test_eval_gto_cs():
mol = gto.Mole()
mol.atom = 'H 0 0 0; H 0 0 0.74' # in Angstrom
Expand All @@ -87,10 +93,32 @@ def test_eval_gto_cs():
tol = [1e-6, 1e-6, 1e-6, 1e-6]

for i, eval_name in enumerate(eval_names):
ao0 = pyscf_eval_gto(mol, eval_name, coords)
ao = mol.eval_gto(eval_name, coords)

jac_fwd = jax.jacfwd(mol.__class__.eval_gto)(mol, eval_name, coords)
g_fd = cs_grad_fd(mol, eval_name, coords)

jac_fwd = jax.jacfwd(mol.__class__.eval_gto)(mol, eval_name, coords)
assert abs(jac_fwd.ctr_coeff - g_fd).max() < tol[i]

jac_rev = jax.jacrev(mol.__class__.eval_gto)(mol, eval_name, coords)
assert abs(jac_rev.ctr_coeff - g_fd).max() < tol[i]

def test_eval_gto_exp():
mol = gto.Mole()
mol.atom = 'H 0 0 0; H 0 0 0.74' # in Angstrom
mol.basis = bas
mol.build(trace_coords=False, trace_ctr_coeff=False, trace_exp=True)

grids = gen_grid.Grids(mol)
grids.atom_grid = {'H': (10, 14)}
grids.build(with_non0tab=True)
coords = grids.coords

tol = [1e-6, 1e-6, 1e-6, 1e-6]

for i, eval_name in enumerate(eval_names):
g_fd = exp_grad_fd(mol, eval_name, coords)

jac_fwd = jax.jacfwd(mol.__class__.eval_gto)(mol, eval_name, coords)
assert abs(jac_fwd.exp - g_fd).max() < tol[i]

jac_rev = jax.jacrev(mol.__class__.eval_gto)(mol, eval_name, coords)
assert abs(jac_rev.exp - g_fd).max() < tol[i]

0 comments on commit 0d420a7

Please sign in to comment.