Skip to content

Commit

Permalink
fix exp ctr_coeff gradient
Browse files Browse the repository at this point in the history
  • Loading branch information
fishjojo committed Jan 10, 2024
1 parent 0d420a7 commit b2dcfcb
Show file tree
Hide file tree
Showing 5 changed files with 627 additions and 223 deletions.
4 changes: 4 additions & 0 deletions pyscfad/gto/_mole_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,10 @@ def uncontract(mol, shls_slice=None):
_bas = numpy.vstack(tuple(_bas))
_bas[:,-1] = 0
mol1._bas = _bas

# stop tracing
mol1.ctr_coeff = None
mol1.exp = None
return mol1

def shlmap_ctr2unctr(mol):
Expand Down
124 changes: 68 additions & 56 deletions pyscfad/gto/eval_gto.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,6 @@ def _eval_gto_jvp_r(mol, eval_name, grid_coords, grid_coords_t,
if order + 1 > _MAX_DERIV_ORDER:
raise NotImplementedError

ng = grid_coords.shape[0]
ao1 = _eval_gto(mol, new_eval, grid_coords, None, shls_slice, non0tab,
ao_loc, None, None)

Expand Down Expand Up @@ -154,7 +153,7 @@ def _eval_gto_dot_grad_tangent_r0(mol, mol_t, intor,
ao_start = ao_loc[sh0]
ao_end = ao_loc[sh1]
nao = ao_end - ao_start
atmlst = np.asarray(range(mol.natm))
atmlst = numpy.asarray(range(mol.natm))
aoslices = mol.aoslice_by_atom(ao_loc)

ids = numpy.zeros((mol.natm, 2))
Expand Down Expand Up @@ -238,32 +237,40 @@ def _eval_gto_jvp_cs(mol, mol_t, eval_name, grid_coords, comp, shls_slice, non0t
shl0, shl1 = shls_slice

mol1 = get_fakemol_cs(mol, shls_slice)
comp = _get_intor_and_comp(mol, eval_name, comp)[1]
ao1 = _eval_gto(mol1, eval_name, grid_coords, comp, None, non0tab,
None, None, None)
ao1 = ao1.reshape(comp,ngrids,-1)

_, cs_of, _ = setup_ctr_coeff(mol)

nao_id0, nao_id1 = mole.nao_nr_range(mol, shl0, shl1)
nao = nao_id1 - nao_id0
grad = numpy.zeros((comp,ngrids,len(ctr_coeff),nao), dtype=ao1.dtype)
_, cs_of, _ = setup_ctr_coeff(mol)

off = 0
ibas = 0
for i in range(shl0, shl1):
l = mol._bas[i,pyscf_mole.ANG_OF]
if mol.cart:
nbas = (l+1)*(l+2)//2
else:
nbas = 2*l + 1
nprim = mol._bas[i,pyscf_mole.NPRIM_OF]
nctr = mol._bas[i,pyscf_mole.NCTR_OF]
g = ao1[:,:,off:(off+nprim*nbas)].reshape(comp,ngrids,nprim,nbas)
for j in range(nctr):
grad[:,:,(cs_of[i]+j*nprim):(cs_of[i]+(j+1)*nprim),ibas:(ibas+nbas)] += g
ibas += nbas
off += nprim*nbas

tangent_out = np.einsum('cgxi,x->cgi', grad, ctr_coeff_t)
# TODO improve performance
@jit
def _dot_grad_tangent(ao1, tangent):
tangent_out = np.empty((comp,ngrids,nao), dtype=ao1.dtype)
off = 0
ibas = 0
for i in range(shl0, shl1):
l = mol._bas[i,pyscf_mole.ANG_OF]
if mol.cart:
nbas = (l+1)*(l+2)//2
else:
nbas = 2*l + 1
nprim = mol._bas[i,pyscf_mole.NPRIM_OF]
nctr = mol._bas[i,pyscf_mole.NCTR_OF]
g = ao1[...,off:(off+nprim*nbas)].reshape(comp,ngrids,nprim,nbas)
for j in range(nctr):
out = np.einsum('cgxi,x->cgi', g,
tangent[(cs_of[i]+j*nprim):(cs_of[i]+(j+1)*nprim)])
tangent_out = tangent_out.at[...,ibas:(ibas+nbas)].set(out)
ibas += nbas
off += nprim*nbas
return tangent_out
tangent_out = _dot_grad_tangent(ao1, ctr_coeff_t)

if comp == 1:
tangent_out = tangent_out[0]
return tangent_out
Expand All @@ -282,6 +289,7 @@ def _eval_gto_jvp_exp(mol, mol_t, eval_name, grid_coords, comp, shls_slice, non0
shl0, shl1 = shls_slice

mol1 = get_fakemol_exp(mol, shls_slice=shls_slice)
comp = _get_intor_and_comp(mol, eval_name, comp)[1]
ao1 = _eval_gto(mol1, eval_name, grid_coords, comp, None, non0tab,
None, None, None)

Expand All @@ -290,42 +298,46 @@ def _eval_gto_jvp_exp(mol, mol_t, eval_name, grid_coords, comp, shls_slice, non0
ngrids = len(grid_coords)
nao = pyscf_mole.nao_cart(mol)
ao1 = ao1.reshape(comp, ngrids, -1)
# 1st order derivative only
grad = numpy.zeros((comp, ngrids, len(es), nao), dtype=ao1.dtype)

off = 0
ibas = 0
for i in range(shl0, shl1):
ioff = es_of[i]

l = mol._bas[i,pyscf_mole.ANG_OF]
nbas = (l+1)*(l+2)//2
nbas1 = (l+3)*(l+4)//2
nprim = mol._bas[i,pyscf_mole.NPRIM_OF]
nctr = mol._bas[i,pyscf_mole.NCTR_OF]
ptr_ctr_coeff = mol._bas[i,pyscf_mole.PTR_COEFF]
g = ao1[:,:,off:off+nprim*nbas1].reshape(comp, ngrids, nprim, nbas1)

xyz = get_bas_label(l)
xyz1 = get_bas_label(l+2)
for k in range(nctr):
for j in range(nprim):
c = mol._env[ptr_ctr_coeff + k*nprim + j]
if l == 0:
c *= 0.282094791773878143 # normalization factor for s orbital
elif l == 1:
c *= 0.488602511902919921 # normalization factor for p orbital
jbas = ibas
for orb in xyz:
idx_x = xyz1.index(promote_xyz(orb, 'x', 2))
idx_y = xyz1.index(promote_xyz(orb, 'y', 2))
idx_z = xyz1.index(promote_xyz(orb, 'z', 2))
gc = -(g[:,:,j,idx_x] + g[:,:,j,idx_y] + g[:,:,j,idx_z]) * c
grad[:,:,ioff+j,jbas] += gc
jbas += 1
ibas += nbas
off += nprim * nbas1

# TODO improve performance
@jit
def _fill_grad(ao1):
grad = np.zeros((comp,ngrids,len(es),nao), dtype=ao1.dtype)
off = 0
ibas = 0
for i in range(shl0, shl1):
ioff = es_of[i]

l = mol._bas[i,pyscf_mole.ANG_OF]
nbas = (l+1)*(l+2)//2
nbas1 = (l+3)*(l+4)//2
nprim = mol._bas[i,pyscf_mole.NPRIM_OF]
nctr = mol._bas[i,pyscf_mole.NCTR_OF]
ptr_ctr_coeff = mol._bas[i,pyscf_mole.PTR_COEFF]
g = ao1[:,:,off:off+nprim*nbas1].reshape(comp, ngrids, nprim, nbas1)

xyz = get_bas_label(l)
xyz1 = get_bas_label(l+2)
for k in range(nctr):
for j in range(nprim):
c = mol._env[ptr_ctr_coeff + k*nprim + j]
if l == 0:
c *= 0.282094791773878143 # normalization factor for s orbital
elif l == 1:
c *= 0.488602511902919921 # normalization factor for p orbital
jbas = ibas
for orb in xyz:
idx_x = xyz1.index(promote_xyz(orb, 'x', 2))
idx_y = xyz1.index(promote_xyz(orb, 'y', 2))
idx_z = xyz1.index(promote_xyz(orb, 'z', 2))
gc = -(g[:,:,j,idx_x] + g[:,:,j,idx_y] + g[:,:,j,idx_z]) * c
grad = grad.at[:,:,ioff+j,jbas].add(gc)
jbas += 1
ibas += nbas
off += nprim * nbas1
return grad

grad = _fill_grad(ao1)
tangent_out = np.einsum('cgxi,x->cgi', grad, mol_t.exp)
if not cart:
c2s = np.asarray(mol.cart2sph_coeff())
Expand Down
Loading

0 comments on commit b2dcfcb

Please sign in to comment.