Skip to content

Commit

Permalink
improve eval_gto performance
Browse files Browse the repository at this point in the history
  • Loading branch information
fishjojo committed Jan 5, 2024
1 parent 553030b commit 45f5197
Show file tree
Hide file tree
Showing 2 changed files with 144 additions and 130 deletions.
199 changes: 120 additions & 79 deletions pyscfad/gto/eval_gto.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,63 @@
from functools import partial
from jax import custom_jvp
import numpy
#import jax
from pyscf import numpy as np
from pyscf.gto.moleintor import make_loc
from pyscf.gto.eval_gto import _get_intor_and_comp
from pyscf.gto.eval_gto import eval_gto as pyscf_eval_gto
from .moleintor import get_bas_label
from pyscfad.lib import jit, custom_jvp, vmap
from pyscfad.gto.moleintor import get_bas_label

_MAX_DERIV_ORDER = 4
_DERIV_LABEL = []
for i in range(_MAX_DERIV_ORDER+1):
if i == 0:
_label = ['',]
else:
_label = get_bas_label(i)
_DERIV_LABEL += _label
#_DERIV_LABEL = []
#for i in range(_MAX_DERIV_ORDER+1):
# if i == 0:
# _label = ['',]
# else:
# _label = get_bas_label(i)
# _DERIV_LABEL += _label
#
#order = _MAX_DERIV_ORDER - 1
#_X_ID = [[] for iorder in range(order+1)]
#_Y_ID = [[] for iorder in range(order+1)]
#_Z_ID = [[] for iorder in range(order+1)]
#for iorder in range(order+1):
# start = (iorder+1) * (iorder+2) * (iorder+3) // 6
# end = (iorder+2) * (iorder+3) * (iorder+4) // 6
# for il, label in enumerate(get_bas_label(iorder)):
# idx_x = _DERIV_LABEL.index(''.join(sorted(label + 'x')), start, end)
# idx_y = _DERIV_LABEL.index(''.join(sorted(label + 'y')), start, end)
# idx_z = _DERIV_LABEL.index(''.join(sorted(label + 'z')), start, end)
# _X_ID[iorder].append(idx_x)
# _Y_ID[iorder].append(idx_y)
# _Z_ID[iorder].append(idx_z)
#
#_X_ID = [[1], [4, 5, 6], [10, 11, 12, 13, 14, 15], [20, 21, 22, 23, 24, 25, 26, 27, 28, 29]]
#_Y_ID = [[2], [5, 7, 8], [11, 13, 14, 16, 17, 18], [21, 23, 24, 26, 27, 28, 30, 31, 32, 33]]
#_Z_ID = [[3], [6, 8, 9], [12, 14, 15, 17, 18, 19], [22, 24, 25, 27, 28, 29, 31, 32, 33, 34]]

_XYZ_ID = [
numpy.array(
[[1,],
[2,],
[3,],]
),
numpy.array(
[[4, 5, 6],
[5, 7, 8],
[6, 8, 9],]
),
numpy.array(
[[10, 11, 12, 13, 14, 15],
[11, 13, 14, 16, 17, 18],
[12, 14, 15, 17, 18, 19],]
),
numpy.array(
[[20, 21, 22, 23, 24, 25, 26, 27, 28, 29],
[21, 23, 24, 26, 27, 28, 30, 31, 32, 33],
[22, 24, 25, 27, 28, 29, 31, 32, 33, 34],]
),
]

def eval_gto(mol, eval_name, grid_coords,
comp=None, shls_slice=None, non0tab=None,
Expand All @@ -33,8 +77,6 @@ def _eval_gto_jvp(eval_name, comp, shls_slice, non0tab, ao_loc, cutoff, out,
primals, tangents):
mol, grid_coords = primals
mol_t, grid_coords_t = tangents
#mol, = primals
#mol_t, = tangents

primal_out = _eval_gto(mol, eval_name, grid_coords, comp, shls_slice, non0tab,
ao_loc, cutoff, out)
Expand Down Expand Up @@ -71,27 +113,28 @@ def _eval_gto_jvp_r(mol, eval_name, grid_coords, grid_coords_t,
ao1 = _eval_gto(mol, new_eval, grid_coords, None, shls_slice, non0tab,
ao_loc, None, None)

nc = (order+1) * (order+2) * (order+3) // 6
grad = np.zeros((3,nc,ng,nao))
for iorder in range(order+1):
start0 = iorder * (iorder+1) * (iorder+2) // 6
start = (iorder+1) * (iorder+2) * (iorder+3) // 6
end = (iorder+2) * (iorder+3) * (iorder+4) // 6
for il, label in enumerate(get_bas_label(iorder)):
idx_x = _DERIV_LABEL.index(''.join(sorted(label + 'x')), start, end)
idx_y = _DERIV_LABEL.index(''.join(sorted(label + 'y')), start, end)
idx_z = _DERIV_LABEL.index(''.join(sorted(label + 'z')), start, end)
grad = grad.at[0,start0+il].add(ao1[idx_x])
grad = grad.at[1,start0+il].add(ao1[idx_y])
grad = grad.at[2,start0+il].add(ao1[idx_z])

tangent_out = np.einsum('xygi,gx->ygi', grad, grid_coords_t)
@jit
def _contract(ao1, grid_coords_t):
tangent_out = []
for iorder in range(order+1):
tmp = 0
for i in range(3):
tmp += np.einsum('ygi,g->ygi',
ao1[_XYZ_ID[iorder][i]],
grid_coords_t[:,i])
tangent_out.append(tmp)
tangent_out = np.concatenate(tangent_out)
return tangent_out

tangent_out = _contract(ao1, grid_coords_t)
if order == 0:
tangent_out = tangent_out[0]
return tangent_out

def _eval_gto_fill_grad_r0(mol, intor, shls_slice, ao_loc, ao1, order, ngrids):
nc = (order+1) * (order+2) * (order+3) // 6
def _eval_gto_dot_grad_tangent_r0(mol, mol_t, intor,
shls_slice, ao_loc, ao1,
order, ngrids):
coords_t = mol_t.coords
if shls_slice is None:
shls_slice = (0, mol.nbas)
sh0, sh1 = shls_slice
Expand All @@ -100,57 +143,55 @@ def _eval_gto_fill_grad_r0(mol, intor, shls_slice, ao_loc, ao1, order, ngrids):
ao_start = ao_loc[sh0]
ao_end = ao_loc[sh1]
nao = ao_end - ao_start
ng = ngrids
atmlst = np.asarray(range(mol.natm))
aoslices = mol.aoslice_by_atom(ao_loc)

#if nc == 1:
# tangent_out = np.zeros((ng,nao))
#else:
# tangent_out = np.zeros((nc,ng,nao))
#for iorder in range(order+1):
# for k, ia in enumerate(atmlst):
# p0, p1 = aoslices [ia, 2:]
# if p1 <= ao_start:
# continue
# id0 = max(0, p0 - ao_start)
# id1 = min(p1, ao_end) - ao_start
# if order == 0:
# tmp = np.einsum('xgi,x->gi', ao1[1:4,:,id0:id1], coords_t[k])
# tangent_out = ops.index_add(tangent_out, ops.index[:,id0:id1], tmp)
# else:
# start0 = iorder * (iorder+1) * (iorder+2) // 6
# start = (iorder+1) * (iorder+2) * (iorder+3) // 6
# end = (iorder+2) * (iorder+3) * (iorder+4) // 6
# for il, label in enumerate(get_bas_label(iorder)):
# idx_x = _DERIV_LABEL.index(''.join(sorted(label + 'x')), start, end)
# idx_y = _DERIV_LABEL.index(''.join(sorted(label + 'y')), start, end)
# idx_z = _DERIV_LABEL.index(''.join(sorted(label + 'z')), start, end)
# tmp = ( ao1[idx_x,:,id0:id1] * coords_t[k,0]
# + ao1[idx_y,:,id0:id1] * coords_t[k,1]
# + ao1[idx_z,:,id0:id1] * coords_t[k,2])
# tangent_out = ops.index_add(tangent_out, ops.index[start0+il,:,id0:id1], tmp)

grad = np.zeros([mol.natm,3,nc,ng,nao], dtype=ao1.dtype)
for iorder in range(order+1):
for k, ia in enumerate(atmlst):
p0, p1 = aoslices [ia, 2:]
if p1 <= ao_start:
continue
id0 = max(0, p0 - ao_start)
id1 = min(p1, ao_end) - ao_start

start0 = iorder * (iorder+1) * (iorder+2) // 6
start = (iorder+1) * (iorder+2) * (iorder+3) // 6
end = (iorder+2) * (iorder+3) * (iorder+4) // 6
for il, label in enumerate(get_bas_label(iorder)):
idx_x = _DERIV_LABEL.index(''.join(sorted(label + 'x')), start, end)
idx_y = _DERIV_LABEL.index(''.join(sorted(label + 'y')), start, end)
idx_z = _DERIV_LABEL.index(''.join(sorted(label + 'z')), start, end)
grad = grad.at[k,0,start0+il,:,id0:id1].add(-ao1[idx_x,:,id0:id1])
grad = grad.at[k,1,start0+il,:,id0:id1].add(-ao1[idx_y,:,id0:id1])
grad = grad.at[k,2,start0+il,:,id0:id1].add(-ao1[idx_z,:,id0:id1])
return grad
ids = numpy.zeros((mol.natm, 2))
for k, ia in enumerate(atmlst):
p0, p1 = aoslices[ia, 2:]
if p1 <= ao_start:
ids[k,0] = ids[k,1] = 0
else:
ids[k,0] = max(0, p0 - ao_start)
ids[k,1] = min(p1, ao_end) - ao_start

#FIXME scan does not work for reverse mode
#@jit
#def fn(ao1, ids):
# tangent_out = []
# for iorder in range(order+1):
# def body(carry, xs):
# slices, coords_t = xs
# _zero = np.array(0, dtype=ao1.dtype)
# idx = np.arange(nao)[None,None,:]
# p0, p1 = slices[:]
# mask = (idx >= p0) & (idx < p1)
# for i in range(3):
# carry += np.where(mask, ao1[_XYZ_ID[iorder][i]], _zero) * coords_t[i]
# return carry, None
# nl = (iorder+1)*(iorder+2)//2
# tangent_out.append(jax.lax.scan(body, np.zeros((nl,ngrids,nao)), (ids, coords_t))[0])
# tangent_out = np.concatenate(tangent_out)
# return tangent_out

@jit
def fn(ao1, ids):
tangent_out = []
for iorder in range(order+1):
def body(slices, coords_t):
_zero = np.array(0, dtype=ao1.dtype)
idx = np.arange(nao)[None,None,:]
p0, p1 = slices[:]
mask = (idx >= p0) & (idx < p1)
out = np.where(mask, ao1[_XYZ_ID[iorder][0]], _zero) * coords_t[0]
out += np.where(mask, ao1[_XYZ_ID[iorder][1]], _zero) * coords_t[1]
out += np.where(mask, ao1[_XYZ_ID[iorder][2]], _zero) * coords_t[2]
return out
out = vmap(body)(ids, coords_t)
tangent_out.append(np.sum(out, axis=0))
tangent_out = np.concatenate(tangent_out)
return tangent_out
return fn(-ao1, ids)

def _eval_gto_jvp_r0(mol, mol_t, eval_name, grid_coords, comp, shls_slice, non0tab, ao_loc):
if 'deriv'+str(_MAX_DERIV_ORDER) in eval_name:
Expand All @@ -166,10 +207,10 @@ def _eval_gto_jvp_r0(mol, mol_t, eval_name, grid_coords, comp, shls_slice, non0t
ao1 = _eval_gto(mol, new_eval, grid_coords, None, shls_slice, non0tab,
ao_loc, None, None)
ngrids = len(grid_coords)
grad = _eval_gto_fill_grad_r0(mol, new_eval, shls_slice, ao_loc, ao1, order, ngrids)
tangent_out = _eval_gto_dot_grad_tangent_r0(mol, mol_t, new_eval,
shls_slice, ao_loc, ao1,
order, ngrids)
ao1 = None
tangent_out = np.einsum('nxlgi,nx->lgi', grad, mol_t.coords)
grad = None
if order == 0:
tangent_out = tangent_out[0]
return tangent_out
Expand Down
Loading

0 comments on commit 45f5197

Please sign in to comment.