diff --git a/pyscfad/gto/_mole_helper.py b/pyscfad/gto/_mole_helper.py index 2bd194a0..1944b383 100644 --- a/pyscfad/gto/_mole_helper.py +++ b/pyscfad/gto/_mole_helper.py @@ -178,8 +178,8 @@ def setup_ctr_coeff(mol): cs_of = cs_of[idx] return cs, cs_of, env_of -def get_fakemol_exp(mol, order=2): - mol1 = uncontract(mol) +def get_fakemol_exp(mol, order=2, shls_slice=None): + mol1 = uncontract(mol, shls_slice=shls_slice) mol1._bas[:,ANG_OF] += order return mol1 diff --git a/pyscfad/gto/eval_gto.py b/pyscfad/gto/eval_gto.py index c76e3111..47135c46 100644 --- a/pyscfad/gto/eval_gto.py +++ b/pyscfad/gto/eval_gto.py @@ -8,11 +8,11 @@ from pyscf.gto.eval_gto import eval_gto as pyscf_eval_gto from pyscfad.lib import jit, custom_jvp, vmap from pyscfad.gto import mole -#from pyscfad.gto.moleintor import get_bas_label +from pyscfad.gto.moleintor import get_bas_label, promote_xyz from pyscfad.gto._mole_helper import ( -# setup_exp, + setup_exp, setup_ctr_coeff, -# get_fakemol_exp, + get_fakemol_exp, get_fakemol_cs, ) @@ -83,6 +83,9 @@ def _eval_gto(mol, eval_name, grid_coords, @_eval_gto.defjvp def _eval_gto_jvp(eval_name, comp, shls_slice, non0tab, ao_loc, cutoff, out, primals, tangents): + if 'spinor' in eval_name or 'ip' in eval_name: + raise NotImplementedError + mol, grid_coords = primals mol_t, grid_coords_t = tangents @@ -107,15 +110,15 @@ def _eval_gto_jvp(eval_name, comp, shls_slice, non0tab, ao_loc, cutoff, out, def _eval_gto_jvp_r(mol, eval_name, grid_coords, grid_coords_t, comp, shls_slice, non0tab, ao_loc, nao): - if 'deriv'+str(_MAX_DERIV_ORDER) in eval_name: - raise NotImplementedError if 'deriv' not in eval_name: - new_eval = eval_name + '_deriv1' order = 0 + new_eval = eval_name + '_deriv1' else: tmp = eval_name.split('deriv', 1) order = int(tmp[1]) new_eval = tmp[0] + 'deriv' + str(order + 1) + if order + 1 > _MAX_DERIV_ORDER: + raise NotImplementedError ng = grid_coords.shape[0] ao1 = _eval_gto(mol, new_eval, grid_coords, None, shls_slice, non0tab, @@ -203,15 +206,15 @@ def body(slices, coords_t): return fn(-ao1, ids) def _eval_gto_jvp_r0(mol, mol_t, eval_name, grid_coords, comp, shls_slice, non0tab, ao_loc): - if 'deriv'+str(_MAX_DERIV_ORDER) in eval_name: - raise NotImplementedError if 'deriv' not in eval_name: - new_eval = eval_name + '_deriv1' order = 0 + new_eval = eval_name + '_deriv1' else: tmp = eval_name.split('deriv', 1) order = int(tmp[1]) new_eval = tmp[0] + 'deriv' + str(order + 1) + if order + 1 > _MAX_DERIV_ORDER: + raise NotImplementedError ao1 = _eval_gto(mol, new_eval, grid_coords, None, shls_slice, non0tab, ao_loc, None, None) @@ -235,8 +238,8 @@ def _eval_gto_jvp_cs(mol, mol_t, eval_name, grid_coords, comp, shls_slice, non0t shl0, shl1 = shls_slice mol1 = get_fakemol_cs(mol, shls_slice) - ao1 = _eval_gto(mol1, eval_name, grid_coords, None, None, non0tab, - ao_loc, None, None) + ao1 = _eval_gto(mol1, eval_name, grid_coords, comp, None, non0tab, + None, None, None) ao1 = ao1.reshape(comp,ngrids,-1) nao_id0, nao_id1 = mole.nao_nr_range(mol, shl0, shl1) @@ -266,4 +269,67 @@ def _eval_gto_jvp_cs(mol, mol_t, eval_name, grid_coords, comp, shls_slice, non0t return tangent_out def _eval_gto_jvp_exp(mol, mol_t, eval_name, grid_coords, comp, shls_slice, non0tab, ao_loc): - return 0 + if 'sph' in eval_name: + eval_name = eval_name.replace('sph', 'cart') + cart = False + elif 'cart' in eval_name: + cart = True + else: + raise KeyError + + if shls_slice is None: + shls_slice = (0, mol.nbas) + shl0, shl1 = shls_slice + + mol1 = get_fakemol_exp(mol, shls_slice=shls_slice) + ao1 = _eval_gto(mol1, eval_name, grid_coords, comp, None, non0tab, + None, None, None) + + es, es_of, _env_of = setup_exp(mol) + + ngrids = len(grid_coords) + nao = pyscf_mole.nao_cart(mol) + ao1 = ao1.reshape(comp, ngrids, -1) + # 1st order derivative only + grad = numpy.zeros((comp, ngrids, len(es), nao), dtype=ao1.dtype) + + off = 0 + ibas = 0 + for i in range(shl0, shl1): + ioff = es_of[i] + + l = mol._bas[i,pyscf_mole.ANG_OF] + nbas = (l+1)*(l+2)//2 + nbas1 = (l+3)*(l+4)//2 + nprim = mol._bas[i,pyscf_mole.NPRIM_OF] + nctr = mol._bas[i,pyscf_mole.NCTR_OF] + ptr_ctr_coeff = mol._bas[i,pyscf_mole.PTR_COEFF] + g = ao1[:,:,off:off+nprim*nbas1].reshape(comp, ngrids, nprim, nbas1) + + xyz = get_bas_label(l) + xyz1 = get_bas_label(l+2) + for k in range(nctr): + for j in range(nprim): + c = mol._env[ptr_ctr_coeff + k*nprim + j] + if l == 0: + c *= 0.282094791773878143 # normalization factor for s orbital + elif l == 1: + c *= 0.488602511902919921 # normalization factor for p orbital + jbas = ibas + for orb in xyz: + idx_x = xyz1.index(promote_xyz(orb, 'x', 2)) + idx_y = xyz1.index(promote_xyz(orb, 'y', 2)) + idx_z = xyz1.index(promote_xyz(orb, 'z', 2)) + gc = -(g[:,:,j,idx_x] + g[:,:,j,idx_y] + g[:,:,j,idx_z]) * c + grad[:,:,ioff+j,jbas] += gc + jbas += 1 + ibas += nbas + off += nprim * nbas1 + + tangent_out = np.einsum('cgxi,x->cgi', grad, mol_t.exp) + if not cart: + c2s = np.asarray(mol.cart2sph_coeff()) + tangent_out = np.einsum('cgp,pi->cgi', tangent_out, c2s) + if comp == 1: + tangent_out = tangent_out[0] + return tangent_out diff --git a/pyscfad/gto/test/test_eval_gto.py b/pyscfad/gto/test/test_eval_gto.py index f377dc16..6ea03543 100644 --- a/pyscfad/gto/test/test_eval_gto.py +++ b/pyscfad/gto/test/test_eval_gto.py @@ -73,6 +73,12 @@ def cs_grad_fd(mol, eval_name, coords): g = four_point_fd(mol, eval_name, coords, _env_of, disp) return g +def exp_grad_fd(mol, eval_name, coords): + disp = 1e-4 + _, _, _env_of = gto.mole.setup_exp(mol) + g = four_point_fd(mol, eval_name, coords, _env_of, disp) + return g + def test_eval_gto_cs(): mol = gto.Mole() mol.atom = 'H 0 0 0; H 0 0 0.74' # in Angstrom @@ -87,10 +93,32 @@ def test_eval_gto_cs(): tol = [1e-6, 1e-6, 1e-6, 1e-6] for i, eval_name in enumerate(eval_names): - ao0 = pyscf_eval_gto(mol, eval_name, coords) - ao = mol.eval_gto(eval_name, coords) - - jac_fwd = jax.jacfwd(mol.__class__.eval_gto)(mol, eval_name, coords) g_fd = cs_grad_fd(mol, eval_name, coords) + jac_fwd = jax.jacfwd(mol.__class__.eval_gto)(mol, eval_name, coords) assert abs(jac_fwd.ctr_coeff - g_fd).max() < tol[i] + + jac_rev = jax.jacrev(mol.__class__.eval_gto)(mol, eval_name, coords) + assert abs(jac_rev.ctr_coeff - g_fd).max() < tol[i] + +def test_eval_gto_exp(): + mol = gto.Mole() + mol.atom = 'H 0 0 0; H 0 0 0.74' # in Angstrom + mol.basis = bas + mol.build(trace_coords=False, trace_ctr_coeff=False, trace_exp=True) + + grids = gen_grid.Grids(mol) + grids.atom_grid = {'H': (10, 14)} + grids.build(with_non0tab=True) + coords = grids.coords + + tol = [1e-6, 1e-6, 1e-6, 1e-6] + + for i, eval_name in enumerate(eval_names): + g_fd = exp_grad_fd(mol, eval_name, coords) + + jac_fwd = jax.jacfwd(mol.__class__.eval_gto)(mol, eval_name, coords) + assert abs(jac_fwd.exp - g_fd).max() < tol[i] + + jac_rev = jax.jacrev(mol.__class__.eval_gto)(mol, eval_name, coords) + assert abs(jac_rev.exp - g_fd).max() < tol[i]