diff --git a/examples/pbc/scf/00-simple.py b/examples/pbc/scf/00-simple.py index 804c876b..e56d8e2b 100644 --- a/examples/pbc/scf/00-simple.py +++ b/examples/pbc/scf/00-simple.py @@ -1,3 +1,11 @@ +''' +--------------- KRHF gradients --------------- + x y z +0 Si -0.0077806264 -0.0077806264 -0.0077806264 +1 Si 0.0077806264 0.0077806264 0.0077806264 +---------------------------------------------- +''' +import jax from pyscfad.pbc import gto, scf basis = 'gth-szv' @@ -19,7 +27,9 @@ cell.verbose = 4 cell.build() -mf = scf.RHF(cell, exxdiv=None) -mf.kernel() -jac = mf.energy_grad() +def hf_energy(cell): + mf = scf.RHF(cell, exxdiv=None) + e_tot = mf.kernel() + return e_tot +e_tot, jac = jax.value_and_grad(hf_energy)(cell) print(f'Nuclear gradient:\n{jac.coords}') diff --git a/examples/pbc/scf/01-krhf.py b/examples/pbc/scf/01-krhf.py index fd277ebb..a31fa7db 100644 --- a/examples/pbc/scf/01-krhf.py +++ b/examples/pbc/scf/01-krhf.py @@ -5,7 +5,7 @@ 1 Si 0.0109162228 0.0028233781 0.0028233781 ---------------------------------------------- ''' - +import jax from pyscfad.pbc import gto from pyscfad.pbc import scf, df @@ -25,10 +25,13 @@ cell.a = lattice cell.basis = basis cell.pseudo = pseudo +cell.verbose = 4 cell.build() kpts = cell.make_kpts([2,1,1]) -mf = scf.KRHF(cell, kpts=kpts, exxdiv=None) -mf.kernel() -jac = mf.energy_grad() -print(jac.coords) +def hf_energy(cell, kpts): + mf = scf.KRHF(cell, kpts=kpts, exxdiv=None) + e_tot = mf.kernel() + return e_tot +e_tot, jac = jax.value_and_grad(hf_energy)(cell, kpts) +print(f'Nuclear gradient:\n{jac.coords}') diff --git a/examples/pbc/scf/02-krhf-hessian.py b/examples/pbc/scf/02-krhf-hessian.py index 405a59ff..09bccf4d 100644 --- a/examples/pbc/scf/02-krhf-hessian.py +++ b/examples/pbc/scf/02-krhf-hessian.py @@ -18,6 +18,7 @@ cell.a = lattice cell.basis = basis cell.pseudo = pseudo +cell.verbose = 4 cell.build(trace_exp=False, trace_ctr_coeff=False) def ehf(cell): diff --git a/examples/pbc/scf/10-stress.py b/examples/pbc/scf/10-stress.py index 9f7f4046..5c353076 100644 --- a/examples/pbc/scf/10-stress.py +++ b/examples/pbc/scf/10-stress.py @@ -1,3 +1,17 @@ +''' +stress tensor +---------------------------- +[[-7.57812381e-01 1.65085403e-09 1.65085540e-09] + [ 1.65085006e-09 -7.57812381e-01 1.65084496e-09] + [ 1.65084745e-09 1.65084512e-09 -7.57812381e-01]] +[[-3.59347869e-03 7.82820248e-12 7.82820901e-12] + [ 7.82818367e-12 -3.59347869e-03 7.82815947e-12] + [ 7.82817132e-12 7.82816026e-12 -3.59347869e-03]] +[[-6.59876007e-01 1.43750484e-09 1.43750604e-09] + [ 1.43750138e-09 -6.59876007e-01 1.43749694e-09] + [ 1.43749911e-09 1.43749708e-09 -6.59876007e-01]] +---------------------------- +''' import numpy import jax from jax import numpy as np @@ -5,8 +19,8 @@ from pyscfad.pbc import gto as pbcgto from pyscfad.pbc import scf as pbcscf -aas = numpy.arange(5,6.0,0.1,dtype=float) -for aa in aas: +#aas = numpy.arange(5,6.0,0.1,dtype=float) +for aa in [5.]: basis = 'gth-szv' pseudo = 'gth-pade' lattice = numpy.asarray([[0., aa/2, aa/2], @@ -15,22 +29,8 @@ atom = [['Si', [0., 0., 0.]], ['Si', [aa/4, aa/4, aa/4]]] - cell0 = pbcgto.Cell() - cell0.atom = atom - cell0.a = lattice - cell0.basis = basis - cell0.pseudo = pseudo - cell0.verbose = 4 - cell0.exp_to_discard=0.1 - cell0.build() - - coords = [] - for i, a in enumerate(atom): - coords.append(a[1]) - coords = numpy.asarray(coords) - - strain = numpy.zeros((3,3)) - def khf_energy(strain, lattice, coords): + strain = np.zeros((3,3)) + def khf_energy(strain, lattice): cell = pbcgto.Cell() cell.atom = atom 
cell.a = lattice @@ -38,7 +38,8 @@ def khf_energy(strain, lattice, coords): cell.pseudo = pseudo cell.verbose = 4 cell.exp_to_discard=0.1 - cell.max_memory=40000 + cell.mesh = [21]*3 + cell.max_memory=100000 cell.build(trace_lattice_vectors=True) cell.abc += np.einsum('ab,nb->na', strain, cell.lattice_vectors()) @@ -47,13 +48,13 @@ def khf_energy(strain, lattice, coords): kpts = cell.make_kpts([2,2,2]) mf = pbcscf.KRHF(cell, kpts=kpts, exxdiv=None) - ehf = mf.kernel(dm0=None) - return ehf + ehf = mf.kernel() + return ehf, cell - jac = jax.jacrev(khf_energy)(strain, lattice, coords) + jac, cell = jax.jacrev(khf_energy, has_aux=True)(strain, lattice) print('stress tensor') print('----------------------------') print(jac) - print(jac / cell0.vol) - print(jac*HARTREE2EV / (cell0.vol*(BOHR**3))) + print(jac / cell.vol) + print(jac*HARTREE2EV / (cell.vol*(BOHR**3))) print('----------------------------') diff --git a/pyscfad/_src/scipy/linalg.py b/pyscfad/_src/scipy/linalg.py index b2e7ab20..51e970b3 100644 --- a/pyscfad/_src/scipy/linalg.py +++ b/pyscfad/_src/scipy/linalg.py @@ -35,7 +35,7 @@ def eigh(a, b=None, lower=True, eigvals_only=False, overwrite_a=False, @partial(custom_jvp, nondiff_argnums=(2,)) def _eigh(a, b, deg_thresh=DEG_THRESH): w, v = scipy.linalg.eigh(a, b=b) - w = np.asarray(w, dtype=v.dtype) + w = np.asarray(w, dtype=float) return w, v @_eigh.defjvp @@ -63,7 +63,7 @@ def _eigh_jvp_jitted(w, v, Fmat, at, bt, bmask): vt_bt_v = np.dot(v.conj().T, np.dot(bt, v)) vt_bt_v_w = np.dot(vt_bt_v, np.diag(w)) da_minus_ds = vt_at_v - vt_bt_v_w - dw = np.diag(da_minus_ds)#.real + dw = np.diag(da_minus_ds).real dv = np.dot(v, np.multiply(Fmat, da_minus_ds) - np.multiply(bmask, vt_bt_v) * .5) return dw, dv @@ -71,7 +71,7 @@ def _eigh_jvp_jitted(w, v, Fmat, at, bt, bmask): @jit def _eigh_jvp_jitted_nob(v, Fmat, at): vt_at_v = np.dot(v.conj().T, np.dot(at, v)) - dw = np.diag(vt_at_v) + dw = np.diag(vt_at_v).real dv = np.dot(v, np.multiply(Fmat, vt_at_v)) return dw, dv diff --git a/pyscfad/df/_df_jk_opt.py b/pyscfad/df/_df_jk_opt.py index 87010919..2c90567f 100644 --- a/pyscfad/df/_df_jk_opt.py +++ b/pyscfad/df/_df_jk_opt.py @@ -58,20 +58,20 @@ def get_jk_bwd(hermi, with_j, with_k, direct_scf_tol, rargs = (ctypes.c_int(nao), (ctypes.c_int*4)(0, nao, 0, nao), null, ctypes.c_int(0)) max_memory = dfobj.max_memory - lib.current_memory()[0] - blksize = max(4, int(min(dfobj.blockdim, max_memory*.4e6/8/nao**2))) + blksize = max(4, int(min(dfobj.blockdim, max_memory*.3e6/8/nao**2))) buf = numpy.empty((blksize,nao,nao)) p1 = 0 for eri1 in dfobj.loop(blksize): naux, nao_pair = eri1.shape p0, p1 = p1, p1 + naux if with_j: - rho_bar = numpy.einsum('ix,px->ip', vj_bar_tril, eri1) - dmtril_bar = numpy.einsum('ip,px->ix', rho_bar, eri1) + rho_bar = vj_bar_tril @ eri1.T + dmtril_bar = rho_bar @ eri1 dms_bar += lib.unpack_tril(dmtril_bar) - rho = numpy.einsum('ix,px->ip', dmtril, eri1) - eri_bar[p0:p1] += numpy.einsum('ip,ix->px', rho, vj_bar_tril) - eri_bar[p0:p1] += numpy.einsum('ix,ip->px', dmtril, rho_bar) + rho = dmtril @ eri1.T + eri_bar[p0:p1] += rho.T @ vj_bar_tril + eri_bar[p0:p1] += rho_bar.T @ dmtril for k in range(nset): #TODO save buf1 on disk to avoid recomputation diff --git a/pyscfad/pbc/df/fft.py b/pyscfad/pbc/df/fft.py index 10e03b21..763b4dc2 100644 --- a/pyscfad/pbc/df/fft.py +++ b/pyscfad/pbc/df/fft.py @@ -2,7 +2,6 @@ from jax import numpy as np from pyscf import __config__ from pyscf import lib -from pyscf.pbc.gto import Cell from pyscf.pbc.df import fft as pyscf_fft #from pyscfad import 
util from pyscfad.pbc import tools @@ -10,7 +9,7 @@ def get_pp(mydf, kpts=None): from pyscf import gto - from pyscf.pbc.gto import pseudo + from pyscfad.pbc.gto import pseudo from pyscfad.gto.mole import Mole cell = mydf.cell if kpts is None: @@ -31,7 +30,7 @@ def get_pp(mydf, kpts=None): for ao_ks_etc, p0, p1 in mydf.aoR_loop(mydf.grids, kpts_lst): ao_ks = ao_ks_etc[0] for k, ao in enumerate(ao_ks): - vpp[k] += np.dot(ao.T.conj()*vpplocR[p0:p1], ao) + vpp[k] += (ao.T.conj()*vpplocR[p0:p1]) @ ao ao = ao_ks = None # vppnonloc evaluated in reciprocal space @@ -124,7 +123,7 @@ def vppnl_by_k(kpt): #@util.pytree_node(['cell','kpts']) class FFTDF(pyscf_fft.FFTDF): def __init__(self, cell, kpts=numpy.zeros((1,3))):#, **kwargs): - from pyscf.pbc.dft import gen_grid + from pyscfad.pbc.dft import gen_grid from pyscfad.pbc.dft import numint self.cell = cell self.stdout = cell.stdout @@ -173,9 +172,6 @@ def aoR_loop(self, grids=None, kpts=None, deriv=0): cell = self.cell else: cell = grids.cell - - # NOTE stop tracing cell through grids - grids.cell = cell.view(Cell) if grids.non0tab is None: grids.build(with_non0tab=True) diff --git a/pyscfad/pbc/dft/gen_grid.py b/pyscfad/pbc/dft/gen_grid.py new file mode 100644 index 00000000..9c1eb786 --- /dev/null +++ b/pyscfad/pbc/dft/gen_grid.py @@ -0,0 +1,32 @@ +import numpy +from pyscf.pbc.gto import Cell as pyscf_Cell +from pyscf.pbc.dft import gen_grid as pyscf_gen_grid +from pyscfad import numpy as np +from pyscfad.ops import stop_grad +from pyscfad.pbc.gto.cell import get_uniform_grids + +class UniformGrids(pyscf_gen_grid.UniformGrids): + @property + def coords(self): + if self._coords is not None: + return self._coords + else: + return get_uniform_grids(self.cell, self.mesh) + + @property + def weights(self): + if self._weights is not None: + return self._weights + else: + ngrids = numpy.prod(self.mesh) + weights = np.full((ngrids,), self.cell.vol / ngrids) + return weights + + def make_mask(self, cell=None, coords=None, relativity=0, shls_slice=None, + verbose=None): + if cell is None: + cell = self.cell + if coords is None: + coords = self.coords + return pyscf_gen_grid.make_mask(cell.view(pyscf_Cell), stop_grad(coords), + relativity, shls_slice, verbose) diff --git a/pyscfad/pbc/gto/cell.py b/pyscfad/pbc/gto/cell.py index 3fe908c9..a4c893c8 100644 --- a/pyscfad/pbc/gto/cell.py +++ b/pyscfad/pbc/gto/cell.py @@ -1,7 +1,9 @@ +from functools import wraps import warnings import numpy from jax.scipy.special import erf, erfc from pyscf import __config__ +from pyscf.lib import cartesian_prod from pyscf.gto.mole import PTR_COORD from pyscf.gto.moleintor import _get_intor_and_comp from pyscf.pbc.gto import cell as pyscf_cell @@ -17,15 +19,76 @@ from pyscfad.pbc.gto.eval_gto import eval_gto as pbc_eval_gto from pyscfad.pbc import tools as pbctools -def get_SI(cell, Gv=None): +@wraps(pyscf_cell.get_Gv) +def get_Gv(cell, mesh=None, **kwargs): + return get_Gv_weights(cell, mesh, **kwargs)[0] + +@wraps(pyscf_cell.get_Gv_weights) +def get_Gv_weights(cell, mesh=None, **kwargs): + if mesh is None: + mesh = cell.mesh + if 'gs' in kwargs: + warnings.warn('cell.gs is deprecated. 
It is replaced by cell.mesh,' + 'the number of PWs (=2*gs+1) along each direction.') + mesh = [2*n+1 for n in kwargs['gs']] + + # Default, the 3D uniform grids + rx = numpy.fft.fftfreq(mesh[0], 1./mesh[0]) + ry = numpy.fft.fftfreq(mesh[1], 1./mesh[1]) + rz = numpy.fft.fftfreq(mesh[2], 1./mesh[2]) + b = cell.reciprocal_vectors() + weights = abs(np.linalg.det(b)) + + if (cell.dimension < 2 or + (cell.dimension == 2 and cell.low_dim_ft_type == 'inf_vacuum')): + raise NotImplementedError + + Gvbase = (rx, ry, rz) + Gv = cartesian_prod(Gvbase) @ b + Gv = Gv.reshape(-1, 3) + + # 1/cell.vol == det(b)/(2pi)^3 + weights *= 1/(2*numpy.pi)**3 + return Gv, Gvbase, weights + +@wraps(pyscf_cell.get_SI) +def get_SI(cell, Gv=None, mesh=None, atmlst=None): coords = cell.atom_coords() - ngrids = numpy.prod(cell.mesh) - if Gv is None or Gv.shape[0] == ngrids: - Gv = cell.get_Gv() - GvT = Gv.T - SI = np.exp(-1j*np.dot(coords, GvT)) + if atmlst is not None: + coords = coords[numpy.asarray(atmlst)] + if Gv is None: + if mesh is None: + mesh = cell.mesh + basex, basey, basez = cell.get_Gv_weights(mesh)[1] + b = cell.reciprocal_vectors() + rb = coords @ b.T + SIx = np.exp(-1j*np.einsum('z,g->zg', rb[:,0], basex)) + SIy = np.exp(-1j*np.einsum('z,g->zg', rb[:,1], basey)) + SIz = np.exp(-1j*np.einsum('z,g->zg', rb[:,2], basez)) + SI = SIx[:,:,None,None] * SIy[:,None,:,None] * SIz[:,None,None,:] + natm = coords.shape[0] + SI = SI.reshape(natm, -1) + else: + SI = np.exp(-1j * (coords @ Gv.T)) return SI +@wraps(pyscf_cell.get_uniform_grids) +def get_uniform_grids(cell, mesh=None, wrap_around=True): + if mesh is None: + mesh = cell.mesh + + a = cell.lattice_vectors() + if wrap_around: + qv = cartesian_prod([numpy.fft.fftfreq(x) for x in mesh]) + coords = qv @ a + else: + mesh = numpy.asarray(mesh, float) + qv = cartesian_prod([numpy.arange(x) for x in mesh]) + a_frac = (1./mesh)[:,None] * a + coords = qv @ a_frac + return coords +gen_uniform_grids = get_uniform_grids + def shift_bas_center(cell0, r): cell = cell0.copy() cell.coords = cell0.atom_coords() + r[None,:] @@ -79,7 +142,31 @@ def pbc_intor(cell, intor, comp=None, hermi=0, kpts=None, kpt=None, kpt=kpt, shls_slice=shls_slice, **kwargs) return res +@wraps(pyscf_cell.get_ewald_params) +def get_ewald_params(cell, precision=None, mesh=None): + if cell.natm == 0: + return 0, 0 + + if precision is None: + precision = cell.precision + + if (cell.dimension < 2 or + (cell.dimension == 2 and cell.low_dim_ft_type == 'inf_vacuum')): + ew_cut = cell.rcut + ew_eta = numpy.sqrt(max(numpy.log(4*numpy.pi*ew_cut**2/precision)/ew_cut**2, .1)) + elif cell.dimension == 2: + a = cell.lattice_vectors() + ew_cut = a[2,2] / 2 + # ewovrl ~ erfc(eta*rcut) / rcut ~ e^{(-eta**2 rcut*2)} < precision + log_precision = numpy.log(precision / (cell.atom_charges().sum()*16*numpy.pi**2)) + ew_eta = (-log_precision)**.5 / ew_cut + else: # dimension == 3 + ew_eta = 1./cell.vol**(1./6) + ew_cut = pyscf_cell._estimate_rcut(stop_grad(ew_eta)**2, 0, 1., precision) + return ew_eta, ew_cut + # modified from pyscf v2.3 +@wraps(pyscf_cell.ewald) def ewald(cell, ew_eta=None, ew_cut=None): if cell.a is None: return mole.energy_nuc(cell) @@ -222,12 +309,48 @@ def build(self, *args, **kwargs): if trace_lattice_vectors: self.abc = np.asarray(self.lattice_vectors()) + @property + def vol(self): + return abs(np.linalg.det(self.lattice_vectors())) + def lattice_vectors(self): if self.abc is None: return pyscf_cell.Cell.lattice_vectors(self) else: return self.abc + @wraps(pyscf_cell.Cell.get_scaled_atom_coords) + def 
get_scaled_atom_coords(self, a=None): + if a is None: + a = self.lattice_vectors() + return np.dot(self.atom_coords(), np.linalg.inv(a)) + + @wraps(pyscf_cell.Cell.reciprocal_vectors) + def reciprocal_vectors(self, norm_to=2*numpy.pi): + a = self.lattice_vectors() + if self.dimension == 1: + assert(abs(a[0] @ a[1]) < 1e-9 and + abs(a[0] @ a[2]) < 1e-9 and + abs(a[1] @ a[2]) < 1e-9) + elif self.dimension == 2: + assert(abs(a[0] @ a[2]) < 1e-9 and + abs(a[1] @ a[2]) < 1e-9) + b = np.linalg.inv(a.T) + return norm_to * b + + @wraps(pyscf_cell.Cell.get_abs_kpts) + def get_abs_kpts(self, scaled_kpts): + return np.dot(scaled_kpts, self.reciprocal_vectors()) + + @wraps(pyscf_cell.Cell.cutoff_to_mesh) + def cutoff_to_mesh(self, ke_cutoff): + a = self.lattice_vectors() + dim = self.dimension + mesh = pbctools.cutoff_to_mesh(a, ke_cutoff) + if dim < 2 or (dim == 2 and self.low_dim_ft_type == 'inf_vacuum'): + mesh[dim:] = self.mesh[dim:] + return mesh + def pbc_eval_gto(self, eval_name, coords, comp=None, kpts=None, kpt=None, shls_slice=None, non0tab=None, ao_loc=None, out=None): return pbc_eval_gto(self, eval_name, coords, comp, kpts, kpt, @@ -242,9 +365,14 @@ def eval_gto(self, eval_name, coords, comp=None, kpts=None, kpt=None, else: return mole.eval_gto(self, eval_name, coords, comp, shls_slice, non0tab, ao_loc, out) + eval_ao = eval_gto pbc_intor = pbc_intor + get_Gv = get_Gv + get_Gv_weights = get_Gv_weights get_SI = get_SI + gen_uniform_grids = get_uniform_grids = get_uniform_grids + get_ewald_params = get_ewald_params ewald = ewald energy_nuc = ewald get_lattice_Ls = pbctools.get_lattice_Ls diff --git a/pyscfad/pbc/gto/pseudo/__init__.py b/pyscfad/pbc/gto/pseudo/__init__.py new file mode 100644 index 00000000..9518583d --- /dev/null +++ b/pyscfad/pbc/gto/pseudo/__init__.py @@ -0,0 +1,2 @@ +from pyscfad.pbc.gto.pseudo.pp import * +from pyscfad.pbc.gto.pseudo import pp_int diff --git a/pyscfad/pbc/gto/pseudo/pp.py b/pyscfad/pbc/gto/pseudo/pp.py new file mode 100644 index 00000000..3003bb19 --- /dev/null +++ b/pyscfad/pbc/gto/pseudo/pp.py @@ -0,0 +1,84 @@ +from functools import wraps +import numpy +from pyscf.pbc.gto.pseudo import pp as pyscf_pp + +from pyscfad import numpy as np +from pyscfad import ops +from pyscfad.pbc.gto.pseudo import pp_int + +@wraps(pyscf_pp.get_vlocG) +def get_vlocG(cell, Gv=None): + if Gv is None: + Gv = cell.Gv + vlocG = get_gth_vlocG(cell, Gv) + return vlocG + +@wraps(pyscf_pp.get_gth_vlocG) +def get_gth_vlocG(cell, Gv): + vlocG = pp_int.get_gth_vlocG_part1(cell, Gv) + + # Add the C1, C2, C3, C4 contributions + G2 = np.einsum("ix,ix->i", Gv, Gv) + for ia in range(cell.natm): + symb = cell.atom_symbol(ia) + if symb not in cell._pseudo: + continue + + pp = cell._pseudo[symb] + rloc, nexp, cexp = pp[1:3+1] + + G2_red = G2 * rloc**2 + cfacs = 0 + if nexp >= 1: + cfacs += cexp[0] + if nexp >= 2: + cfacs += cexp[1] * (3 - G2_red) + if nexp >= 3: + cfacs += cexp[2] * (15 - 10*G2_red + G2_red**2) + if nexp >= 4: + cfacs += cexp[3] * (105 - 105*G2_red + 21*G2_red**2 - G2_red**3) + + vlocG = ops.index_add(vlocG, ops.index[ia,:], + -(2*numpy.pi)**(3/2.)*rloc**3*np.exp(-0.5*G2_red) * cfacs) + + return vlocG + +def _qli(x,l,i): + sqrt = np.sqrt + if l==0 and i==0: + return 4*sqrt(2.) 
+ elif l==0 and i==1: + return 8*sqrt(2/15.)*(3-x**2) # MH & GTH (right) + #return sqrt(8*2/15.)*(3-x**2) # HGH (wrong) + elif l==0 and i==2: + #return 16/3.*sqrt(2/105.)*(15-20*x**2+4*x**4) # MH (wrong) + return 16/3.*sqrt(2/105.)*(15-10*x**2+x**4) # HGH (right) + elif l==1 and i==0: + return 8*sqrt(1/3.) + elif l==1 and i==1: + return 16*sqrt(1/105.)*(5-x**2) + elif l==1 and i==2: + #return 32/3.*sqrt(1/1155.)*(35-28*x**2+4*x**4) # MH (wrong) + return 32/3.*sqrt(1/1155.)*(35-14*x**2+x**4) # HGH (right) + elif l==2 and i==0: + return 8*sqrt(2/15.) + elif l==2 and i==1: + return 16/3.*sqrt(2/105.)*(7-x**2) + elif l==2 and i==2: + #return 32/3.*sqrt(2/15015.)*(63-36*x**2+4*x**4) # MH (wrong I think) + return 32/3.*sqrt(2/15015.)*(63-18*x**2+x**4) # TCB + elif l==3 and i==0: + return 16*sqrt(1/105.) + elif l==3 and i==1: + return 32/3.*sqrt(1/1155.)*(9-x**2) + elif l==3 and i==2: + return 64/45.*sqrt(1/1001.)*(99-22*x**2+x**4) + elif l==4 and i==0: + return 16/3.*sqrt(2/105.) + elif l==4 and i==1: + return 32/3.*sqrt(2/15015.)*(11-x**2) + elif l==4 and i==2: + return 64/45.*sqrt(2/17017.)*(143-26*x**2+x**4) + else: + print("*** WARNING *** l =", l, ", i =", i, "not yet implemented for NL PP!") + return 0. diff --git a/pyscfad/pbc/gto/pseudo/pp_int.py b/pyscfad/pbc/gto/pseudo/pp_int.py new file mode 100644 index 00000000..44eb0db8 --- /dev/null +++ b/pyscfad/pbc/gto/pseudo/pp_int.py @@ -0,0 +1,30 @@ +import numpy +from pyscfad import numpy as np +from pyscfad import ops + +def get_gth_vlocG_part1(cell, Gv): + from pyscfad.pbc import tools + coulG = tools.get_coulG(cell, Gv=Gv) + G2 = np.einsum('ix,ix->i', Gv, Gv) + G0idx = np.where(G2==0)[0] + + if cell.dimension != 2 or cell.low_dim_ft_type == 'inf_vacuum': + vlocG = np.zeros((cell.natm, len(G2))) + for ia in range(cell.natm): + Zia = cell.atom_charge(ia) + symb = cell.atom_symbol(ia) + # Note the signs -- potential here is positive + vlocG = ops.index_update(vlocG, ops.index[ia], Zia * coulG) + if symb in cell._pseudo: + pp = cell._pseudo[symb] + rloc, nexp, cexp = pp[1:3+1] + vlocG = ops.index_mul(vlocG, ops.index[ia], np.exp(-0.5*rloc**2 * G2)) + # alpha parameters from the non-divergent Hartree+Vloc G=0 term. 
+ vlocG = ops.index_update(vlocG, ops.index[ia,G0idx], -2*numpy.pi*Zia*rloc**2) + + elif cell.dimension == 2: + raise NotImplementedError + else: + raise NotImplementedError(f'Low dimension ft_type {cell.low_dim_ft_type}' + f' not implemented for dimension {cell.dimension}') + return vlocG diff --git a/pyscfad/pbc/scf/hf.py b/pyscfad/pbc/scf/hf.py index 8afee05e..a31485fa 100644 --- a/pyscfad/pbc/scf/hf.py +++ b/pyscfad/pbc/scf/hf.py @@ -120,8 +120,12 @@ def dump_chk(self, envs): fh5['scf/kpt'] = stop_grad(self.kpt) return self + def energy_nuc(self): + # recompute nuclear energy to trace it + return self.cell.energy_nuc() + + check_sanity = stop_trace(pyscf_pbc_hf.SCF.check_sanity) get_veff = pyscf_pbc_hf.SCF.get_veff - energy_nuc = pyscf_pbc_hf.SCF.energy_nuc energy_grad = NotImplemented diff --git a/pyscfad/pbc/scf/khf.py b/pyscfad/pbc/scf/khf.py index 5ca1b4ad..c440f225 100644 --- a/pyscfad/pbc/scf/khf.py +++ b/pyscfad/pbc/scf/khf.py @@ -5,7 +5,7 @@ from pyscf.pbc.scf import khf as pyscf_khf from pyscfad import util from pyscfad import numpy as np -from pyscfad.ops import stop_grad +from pyscfad.ops import stop_grad, stop_trace from pyscfad.lib import logger from pyscfad.scf import hf as mol_hf from pyscfad.pbc import df @@ -203,3 +203,5 @@ def get_init_guess(self, cell=None, key='minao', s1e=None): 'of electrons %s', ne/nkpts, nelectron/nkpts) dm_kpts *= (nelectron / ne).reshape(-1,1,1) return dm_kpts + + check_sanity = stop_trace(pyscf_khf.KRHF.check_sanity) diff --git a/pyscfad/pbc/tools/pbc.py b/pyscfad/pbc/tools/pbc.py index 5b6b1202..0bac9dd8 100644 --- a/pyscfad/pbc/tools/pbc.py +++ b/pyscfad/pbc/tools/pbc.py @@ -1,17 +1,16 @@ +from functools import wraps import warnings import copy import numpy as np from pyscf import lib -from pyscf.pbc.tools import get_monkhorst_pack_size, cutoff_to_mesh +from pyscf.pbc import tools as pyscf_pbctools from pyscfad import numpy as jnp from pyscfad import ops -from pyscfad.ops import stop_grad +from pyscfad.ops import stop_grad, stop_trace from pyscfad.lib import logger +@wraps(pyscf_pbctools.fft) def fft(f, mesh): - ''' - 3D FFT with jax.numpy backend - ''' if f.size == 0: return np.zeros_like(f) @@ -24,10 +23,8 @@ def fft(f, mesh): else: return g3d.reshape(-1, ngrids) +@wraps(pyscf_pbctools.ifft) def ifft(g, mesh): - ''' - 3D inverse FFT with jax.numpy backend - ''' if g.size == 0: return np.zeros_like(g) @@ -40,22 +37,16 @@ def ifft(g, mesh): else: return f3d.reshape(-1, ngrids) +@wraps(pyscf_pbctools.fftk) def fftk(f, mesh, expmikr): - r'''Perform the 3D FFT of a real-space function which is (periodic*e^{ikr}). - - fk(k+G) = \sum_r fk(r) e^{-i(k+G)r} = \sum_r [f(k)e^{-ikr}] e^{-iGr} - ''' return fft(f*expmikr, mesh) - +@wraps(pyscf_pbctools.ifftk) def ifftk(g, mesh, expikr): - r'''Perform the 3D inverse FFT of f(k+G) into a function which is (periodic*e^{ikr}). 
- - fk(r) = (1/Ng) \sum_G fk(k+G) e^{i(k+G)r} = (1/Ng) \sum_G [fk(k+G)e^{iGr}] e^{ikr} - ''' return ifft(g, mesh) * expikr -# modified from pyscf v2.3 +# modified from pyscf v2.6 +@wraps(pyscf_pbctools.get_lattice_Ls) def get_lattice_Ls(cell, nimgs=None, rcut=None, dimension=None, discard=True): if dimension is None: if cell.dimension < 2 or cell.low_dim_ft_type == 'inf_vacuum': @@ -68,9 +59,10 @@ def get_lattice_Ls(cell, nimgs=None, rcut=None, dimension=None, discard=True): if dimension == 0 or rcut <= 0: return np.zeros((1, 3)) - a = cell.lattice_vectors() + a1 = cell.lattice_vectors() + a = ops.to_numpy(a1) - scaled_atom_coords = np.linalg.solve(stop_grad(a.T), stop_grad(cell.atom_coords().T)).T + scaled_atom_coords = ops.to_numpy(cell.get_scaled_atom_coords()) atom_boundary_max = scaled_atom_coords[:,:dimension].max(axis=0) atom_boundary_min = scaled_atom_coords[:,:dimension].min(axis=0) if (np.any(atom_boundary_max > 1) or np.any(atom_boundary_min < -1)): @@ -86,29 +78,31 @@ def find_boundary(a): ub = (rcut + abs(r[2,3:]).sum()) / abs(r[2,2]) return ub - xb = find_boundary(stop_grad(a[[1,2,0]])) + xb = find_boundary(a[[1,2,0]]) if dimension > 1: - yb = find_boundary(stop_grad(a[[2,0,1]])) + yb = find_boundary(a[[2,0,1]]) else: yb = 0 if dimension > 2: - zb = find_boundary(stop_grad(a)) + zb = find_boundary(a) else: zb = 0 bounds = np.ceil([xb, yb, zb]).astype(int) Ts = lib.cartesian_prod((np.arange(-bounds[0], bounds[0]+1), np.arange(-bounds[1], bounds[1]+1), np.arange(-bounds[2], bounds[2]+1))) - Ls = jnp.dot(Ts[:,:dimension], a[:dimension]) - - ovlp_penalty += 1e-200 # avoid /0 - Ts_scaled = (Ts[:,:dimension] + 1e-200) / ovlp_penalty - ovlp_penalty_fac = 1. / abs(Ts_scaled).min(axis=1) - Ls_mask = np.linalg.norm(stop_grad(Ls), axis=1) * (1-ovlp_penalty_fac) < rcut - Ls = Ls[Ls_mask] + Ls = jnp.dot(Ts[:,:dimension], a1[:dimension]) + + if discard: + ovlp_penalty += 1e-200 # avoid /0 + Ts_scaled = (Ts[:,:dimension] + 1e-200) / ovlp_penalty + ovlp_penalty_fac = 1. / abs(Ts_scaled).min(axis=1) + Ls_mask = np.linalg.norm(stop_grad(Ls), axis=1) * (1-ovlp_penalty_fac) < rcut + Ls = Ls[Ls_mask] return jnp.asarray(Ls) +@wraps(pyscf_pbctools.get_coulG) def get_coulG(cell, k=np.zeros(3), exx=False, mf=None, mesh=None, Gv=None, wrap_around=True, omega=None, **kwargs): exxdiv = exx @@ -123,7 +117,6 @@ def get_coulG(cell, k=np.zeros(3), exx=False, mf=None, mesh=None, Gv=None, warnings.warn('cell.gs is deprecated. 
It is replaced by cell.mesh,' 'the number of PWs (=2*gs+1) along each direction.') mesh = [2*n+1 for n in kwargs['gs']] - if Gv is None: Gv = cell.get_Gv(mesh) @@ -136,8 +129,8 @@ def get_coulG(cell, k=np.zeros(3), exx=False, mf=None, mesh=None, Gv=None, if wrap_around and abs(k).sum() > 1e-9: b = cell.reciprocal_vectors() box_edge = jnp.einsum('i,ij->ij', np.asarray(mesh)//2+0.5, b) - assert (all(np.linalg.solve(stop_grad(box_edge.T), k).round(9).astype(int)==0)) - reduced_coords = np.linalg.solve(stop_grad(box_edge.T), stop_grad(kG.T)).T.round(9) + assert (all(stop_trace(np.linalg.solve)(box_edge.T, k).round(9).astype(int)==0)) + reduced_coords = stop_trace(np.linalg.solve)(box_edge.T, kG.T).T.round(9) on_edge = reduced_coords.astype(int) if cell.dimension >= 1: equal2boundary |= reduced_coords[:,0] == 1 @@ -156,6 +149,8 @@ def get_coulG(cell, k=np.zeros(3), exx=False, mf=None, mesh=None, Gv=None, kG = ops.index_add(kG, ops.index[on_edge[:,2]==-1], 2 * box_edge[2]) absG2 = jnp.einsum('gi,gi->g', kG, kG) + G0_idx = jnp.where(absG2==0)[0] + absG2 = jnp.where(absG2!=0, absG2, 0) if getattr(mf, 'kpts', None) is not None: kpts = mf.kpts @@ -167,7 +162,8 @@ def get_coulG(cell, k=np.zeros(3), exx=False, mf=None, mesh=None, Gv=None, Rc = (3*Nk*cell.vol/(4*np.pi))**(1./3) with np.errstate(divide='ignore',invalid='ignore'): coulG = 4*np.pi/absG2*(1.0 - jnp.cos(jnp.sqrt(absG2)*Rc)) - coulG = ops.index_update(coulG, ops.index[absG2==0], 4*np.pi*0.5*Rc**2) + if len(G0_idx) > 0: + coulG = ops.index_update(coulG, ops.index[G0_idx], 4*np.pi*0.5*Rc**2) if cell.dimension < 3: raise NotImplementedError @@ -177,10 +173,10 @@ def get_coulG(cell, k=np.zeros(3), exx=False, mf=None, mesh=None, Gv=None, # Ewald probe charge method to get the leading term of the finite size # error in exchange integrals - G0_idx = np.where(stop_grad(absG2)==0)[0] if cell.dimension != 2 or cell.low_dim_ft_type == 'inf_vacuum': with np.errstate(divide='ignore'): coulG = 4*np.pi/absG2 + if len(G0_idx) > 0: coulG = ops.index_update(coulG, ops.index[G0_idx], 0) elif cell.dimension == 2: @@ -195,6 +191,7 @@ def get_coulG(cell, k=np.zeros(3), exx=False, mf=None, mesh=None, Gv=None, coulG = weights*4*np.pi/absG2 if len(G0_idx) > 0: coulG = ops.index_update(coulG, ops.index[G0_idx], -2*np.pi*Ld2**2) + elif cell.dimension == 1: logger.warn(cell, 'No method for PBC dimension 1, dim-type %s.' 
' cell.low_dim_ft_type="inf_vacuum" should be set.', @@ -204,7 +201,9 @@ def get_coulG(cell, k=np.zeros(3), exx=False, mf=None, mesh=None, Gv=None, if cell.dimension > 0 and exxdiv == 'ewald' and len(G0_idx) > 0: coulG = ops.index_add(coulG, ops.index[G0_idx], Nk*cell.vol*madelung(cell, kpts)) - coulG = ops.index_update(coulG, ops.index[equal2boundary], 0) + if equal2boundary is not None: + coulG = ops.index_update(coulG, ops.index[equal2boundary], 0) + if omega is not None: if omega > 0: # long range part @@ -219,8 +218,11 @@ def get_coulG(cell, k=np.zeros(3), exx=False, mf=None, mesh=None, Gv=None, return coulG +get_monkhorst_pack_size = stop_trace(pyscf_pbctools.get_monkhorst_pack_size) +cutoff_to_mesh = stop_trace(pyscf_pbctools.cutoff_to_mesh) + def madelung(cell, kpts): - Nk = get_monkhorst_pack_size(stop_grad(cell), stop_grad(kpts)) + Nk = get_monkhorst_pack_size(cell, kpts) ecell = copy.copy(cell) ecell._atm = np.array([[1, cell._env.size, 0, 0, 0, 0]]) ecell._env = np.append(cell._env, [0., 0., 0.]) diff --git a/pyscfadlib/pyscfadlib/gto/CMakeLists.txt b/pyscfadlib/pyscfadlib/gto/CMakeLists.txt index b81d133e..2c3407d0 100644 --- a/pyscfadlib/pyscfadlib/gto/CMakeLists.txt +++ b/pyscfadlib/pyscfadlib/gto/CMakeLists.txt @@ -18,7 +18,6 @@ add_library(cgto_ad SHARED ft_ao.c ft_ao_deriv.c fill_grids_int2c.c grid_ao_drv.c deriv1.c deriv2.c nr_ecp.c nr_ecp_deriv.c autocode/auto_eval1.c) -add_dependencies(cgto_ad np_helper_ad) set_target_properties(cgto_ad PROPERTIES LIBRARY_OUTPUT_DIRECTORY ${PROJECT_SOURCE_DIR}) diff --git a/pyscfadlib/pyscfadlib/np_helper/CMakeLists.txt b/pyscfadlib/pyscfadlib/np_helper/CMakeLists.txt index d1cc879a..760dc231 100644 --- a/pyscfadlib/pyscfadlib/np_helper/CMakeLists.txt +++ b/pyscfadlib/pyscfadlib/np_helper/CMakeLists.txt @@ -13,9 +13,9 @@ # limitations under the License. add_library(np_helper_ad SHARED - transpose.c pack_tril.c npdot.c condense.c omp_reduce.c np_helper.c) + pack_tril.c np_helper.c) set_target_properties(np_helper_ad PROPERTIES LIBRARY_OUTPUT_DIRECTORY ${PROJECT_SOURCE_DIR}) -target_link_libraries(np_helper_ad ${BLAS_LIBRARIES} ${OPENMP_C_PROPERTIES}) +target_link_libraries(np_helper_ad ${OPENMP_C_PROPERTIES}) diff --git a/pyscfadlib/pyscfadlib/np_helper/condense.c b/pyscfadlib/pyscfadlib/np_helper/condense.c deleted file mode 100644 index 586ca80e..00000000 --- a/pyscfadlib/pyscfadlib/np_helper/condense.c +++ /dev/null @@ -1,283 +0,0 @@ -/* Copyright 2014-2018 The PySCF Developers. All Rights Reserved. - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - - * - * Author: Qiming Sun - */ - -#include -#include -#include -#define MIN(X,Y) ((X)<(Y) ? (X) : (Y)) -#define MAX(X,Y) ((X)>(Y) ? 
(X) : (Y)) - -/* - * def condense(op, a, loc): - * nd = loc[-1] - * out = numpy.empty((nd,nd)) - * for i,i0 in enumerate(loc): - * i1 = loc[i+1] - * for j,j0 in enumerate(loc): - * j1 = loc[j+1] - * out[i,j] = op(a[i0:i1,j0:j1]) - * return out - */ - -void NPcondense(double (*op)(double *, int, int, int), double *out, double *a, - int *loc_x, int *loc_y, int nloc_x, int nloc_y) -{ - const size_t nj = loc_y[nloc_y]; - const size_t Nloc_y = nloc_y; -#pragma omp parallel -{ - int i, j, i0, j0, di, dj; -#pragma omp for - for (i = 0; i < nloc_x; i++) { - i0 = loc_x[i]; - di = loc_x[i+1] - i0; - for (j = 0; j < nloc_y; j++) { - j0 = loc_y[j]; - dj = loc_y[j+1] - j0; - out[i*Nloc_y+j] = op(a+i0*nj+j0, nj, di, dj); - } - } -} -} - -double NP_sum(double *a, int nd, int di, int dj) -{ - int i, j; - double out = 0; - for (i = 0; i < di; i++) { - for (j = 0; j < dj; j++) { - out += a[i*nd+j]; - } } - return out; -} -double NP_max(double *a, int nd, int di, int dj) -{ - if (di == 0 || dj == 0) { - return 0.; - } - int i, j; - double out = a[0]; - for (i = 0; i < di; i++) { - for (j = 0; j < dj; j++) { - out = MAX(out, a[i*nd+j]); - } } - return out; -} -double NP_min(double *a, int nd, int di, int dj) -{ - if (di == 0 || dj == 0) { - return 0.; - } - int i, j; - double out = a[0]; - for (i = 0; i < di; i++) { - for (j = 0; j < dj; j++) { - out = MIN(out, a[i*nd+j]); - } } - return out; -} -double NP_abssum(double *a, int nd, int di, int dj) -{ - int i, j; - double out = 0; - for (i = 0; i < di; i++) { - for (j = 0; j < dj; j++) { - out += fabs(a[i*nd+j]); - } } - return out; -} -double NP_absmax(double *a, int nd, int di, int dj) -{ - if (di == 0 || dj == 0) { - return 0.; - } - int i, j; - double out = fabs(a[0]); - for (i = 0; i < di; i++) { - for (j = 0; j < dj; j++) { - out = MAX(out, fabs(a[i*nd+j])); - } } - return out; -} -double NP_absmin(double *a, int nd, int di, int dj) -{ - if (di == 0 || dj == 0) { - return 0.; - } - int i, j; - double out = fabs(a[0]); - for (i = 0; i < di; i++) { - for (j = 0; j < dj; j++) { - out = MIN(out, fabs(a[i*nd+j])); - } } - return out; -} -double NP_norm(double *a, int nd, int di, int dj) -{ - if (di == 0 || dj == 0) { - return 0.; - } - int i, j; - double out = 0; - for (i = 0; i < di; i++) { - for (j = 0; j < dj; j++) { - out += a[i*nd+j] * a[i*nd+j]; - } } - return sqrt(out); -} - -void NPbcondense(int8_t (*op)(int8_t *, int, int, int), int8_t *out, int8_t *a, - int *loc_x, int *loc_y, int nloc_x, int nloc_y) -{ - size_t nj = loc_y[nloc_y]; - size_t Nloc_y = nloc_y; -#pragma omp parallel -{ - int i, j, i0, j0, di, dj; -#pragma omp for - for (i = 0; i < nloc_x; i++) { - i0 = loc_x[i]; - di = loc_x[i+1] - i0; - for (j = 0; j < nloc_y; j++) { - j0 = loc_y[j]; - dj = loc_y[j+1] - j0; - out[i*Nloc_y+j] = op(a+i0*nj+j0, nj, di, dj); - } - } -} -} - -void NPBcondense(uint8_t (*op)(uint8_t *, int, int, int), uint8_t *out, uint8_t *a, - int *loc_x, int *loc_y, int nloc_x, int nloc_y) -{ - size_t nj = loc_y[nloc_y]; - size_t Nloc_y = nloc_y; -#pragma omp parallel -{ - int i, j, i0, j0, di, dj; -#pragma omp for - for (i = 0; i < nloc_x; i++) { - i0 = loc_x[i]; - di = loc_x[i+1] - i0; - for (j = 0; j < nloc_y; j++) { - j0 = loc_y[j]; - dj = loc_y[j+1] - j0; - out[i*Nloc_y+j] = op(a+i0*nj+j0, nj, di, dj); - } - } -} -} - -void NPicondense(int (*op)(int *, int, int, int), int *out, int *a, - int *loc_x, int *loc_y, int nloc_x, int nloc_y) -{ - size_t nj = loc_y[nloc_y]; - size_t Nloc_y = nloc_y; -#pragma omp parallel -{ - int i, j, i0, j0, di, dj; -#pragma omp for - for 
(i = 0; i < nloc_x; i++) { - i0 = loc_x[i]; - di = loc_x[i+1] - i0; - for (j = 0; j < nloc_y; j++) { - j0 = loc_y[j]; - dj = loc_y[j+1] - j0; - out[i*Nloc_y+j] = op(a+i0*nj+j0, nj, di, dj); - } - } -} -} - -void NPfcondense(float (*op)(float *, int, int, int), float *out, float *a, - int *loc_x, int *loc_y, int nloc_x, int nloc_y) -{ - size_t nj = loc_y[nloc_y]; - size_t Nloc_y = nloc_y; -#pragma omp parallel -{ - int i, j, i0, j0, di, dj; -#pragma omp for - for (i = 0; i < nloc_x; i++) { - i0 = loc_x[i]; - di = loc_x[i+1] - i0; - for (j = 0; j < nloc_y; j++) { - j0 = loc_y[j]; - dj = loc_y[j+1] - j0; - out[i*Nloc_y+j] = op(a+i0*nj+j0, nj, di, dj); - } - } -} -} - -int8_t NP_any(int8_t *a, int nd, int di, int dj) -{ - int i, j; - for (i = 0; i < di; i++) { - for (j = 0; j < dj; j++) { - if (a[i*nd+j]) { - return 1; - } - } } - return 0; -} - -int8_t NP_all(int8_t *a, int nd, int di, int dj) -{ - int i, j; - for (i = 0; i < di; i++) { - for (j = 0; j < dj; j++) { - if (!a[i*nd+j]) { - return 0; - } - } } - return 1; -} - -uint8_t NP_Bmax(uint8_t *a, int nd, int di, int dj) -{ - int i, j; - uint8_t out = a[0]; - for (i = 0; i < di; i++) { - for (j = 0; j < dj; j++) { - out = MAX(out, a[i*nd+j]); - } } - return out; -} - -int NP_imax(int *a, int nd, int di, int dj) -{ - int i, j; - int out = a[0]; - for (i = 0; i < di; i++) { - for (j = 0; j < dj; j++) { - out = MAX(out, a[i*nd+j]); - } } - return out; -} - -float NP_fmax(float *a, int nd, int di, int dj) -{ - int i, j; - float out = a[0]; - for (i = 0; i < di; i++) { - for (j = 0; j < dj; j++) { - out = MAX(out, a[i*nd+j]); - } } - return out; -} diff --git a/pyscfadlib/pyscfadlib/np_helper/np_helper.c b/pyscfadlib/pyscfadlib/np_helper/np_helper.c index a997a242..e59e01f3 100644 --- a/pyscfadlib/pyscfadlib/np_helper/np_helper.c +++ b/pyscfadlib/pyscfadlib/np_helper/np_helper.c @@ -24,14 +24,6 @@ void NPdset0(double *p, const size_t n) } } -void NPzset0(double complex *p, const size_t n) -{ - size_t i; - for (i = 0; i < n; i++) { - p[i] = 0; - } -} - void NPdcopy(double *out, const double *in, const size_t n) { size_t i; @@ -40,10 +32,3 @@ void NPdcopy(double *out, const double *in, const size_t n) } } -void NPzcopy(double complex *out, const double complex *in, const size_t n) -{ - size_t i; - for (i = 0; i < n; i++) { - out[i] = in[i]; - } -} diff --git a/pyscfadlib/pyscfadlib/np_helper/np_helper.h b/pyscfadlib/pyscfadlib/np_helper/np_helper.h index 3ed8d055..0eac73e0 100644 --- a/pyscfadlib/pyscfadlib/np_helper/np_helper.h +++ b/pyscfadlib/pyscfadlib/np_helper/np_helper.h @@ -35,36 +35,6 @@ void NPdsymm_triu(int n, double *mat, int hermi); void NPzhermi_triu(int n, double complex *mat, int hermi); void NPdunpack_tril(int n, double *tril, double *mat, int hermi); -void NPdunpack_row(int ndim, int row_id, double *tril, double *row); -void NPzunpack_tril(int n, double complex *tril, double complex *mat, - int hermi); -void NPdpack_tril(int n, double *tril, double *mat); -void NPzpack_tril(int n, double complex *tril, double complex *mat); - -void NPdtranspose(int n, int m, double *a, double *at); -void NPztranspose(int n, int m, double complex *a, double complex *at); -void NPdtranspose_021(int *shape, double *a, double *at); -void NPztranspose_021(int *shape, double complex *a, double complex *at); - -void NPdunpack_tril_2d(int count, int n, double *tril, double *mat, int hermi); -void NPzunpack_tril_2d(int count, int n, - double complex *tril, double complex *mat, int hermi); -void NPdpack_tril_2d(int count, int n, double *tril, double 
*mat); - -void NPomp_split(size_t *start, size_t *end, size_t n); -void NPomp_dsum_reduce_inplace(double **vec, size_t count); -void NPomp_dprod_reduce_inplace(double **vec, size_t count); -void NPomp_zsum_reduce_inplace(double complex **vec, size_t count); -void NPomp_zprod_reduce_inplace(double complex **vec, size_t count); void NPdset0(double *p, const size_t n); -void NPzset0(double complex *p, const size_t n); void NPdcopy(double *out, const double *in, const size_t n); -void NPzcopy(double complex *out, const double complex *in, const size_t n); - -void NPdgemm(const char trans_a, const char trans_b, - const int m, const int n, const int k, - const int lda, const int ldb, const int ldc, - const int offseta, const int offsetb, const int offsetc, - double *a, double *b, double *c, - const double alpha, const double beta); diff --git a/pyscfadlib/pyscfadlib/np_helper/npdot.c b/pyscfadlib/pyscfadlib/np_helper/npdot.c deleted file mode 100644 index 0ae03987..00000000 --- a/pyscfadlib/pyscfadlib/np_helper/npdot.c +++ /dev/null @@ -1,244 +0,0 @@ -/* Copyright 2014-2018 The PySCF Developers. All Rights Reserved. - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - - * - * Author: Qiming Sun - */ - -#include -#include -//#include -#include "config.h" -#include "vhf/fblas.h" -#include "np_helper/np_helper.h" - -#define MIN(X,Y) ((X) < (Y) ? (X) : (Y)) -#define MAX(X,Y) ((X) > (Y) ? 
(X) : (Y)) - -/* - * numpy.dot may call unoptimized blas - */ -void NPdgemm(const char trans_a, const char trans_b, - const int m, const int n, const int k, - const int lda, const int ldb, const int ldc, - const int offseta, const int offsetb, const int offsetc, - double *a, double *b, double *c, - const double alpha, const double beta) -{ - const size_t Ldc = ldc; - int i, j; - if (m == 0 || n == 0) { - return; - } else if (k == 0) { - for (i = 0; i < n; i++) { - for (j = 0; j < m; j++) { - c[i*Ldc+j] = 0; - } } - return; - } - a += offseta; - b += offsetb; - c += offsetc; - - if ((k/m) > 3 && (k/n) > 3) { // parallelize k - - if (beta == 0) { - for (i = 0; i < n; i++) { - for (j = 0; j < m; j++) { - c[i*Ldc+j] = 0; - } - } - } else { - for (i = 0; i < n; i++) { - for (j = 0; j < m; j++) { - c[i*Ldc+j] *= beta; - } - } - } - -#pragma omp parallel private(i, j) -{ - double D0 = 0; - double *cpriv = malloc(sizeof(double) * (m*n+2)); - size_t k0, k1, ij; - NPomp_split(&k0, &k1, k); - int dk = k1 - k0; - if (dk > 0) { - size_t astride = k0; - size_t bstride = k0; - if (trans_a == 'N') { - astride *= lda; - } - if (trans_b != 'N') { - bstride *= ldb; - } - dgemm_(&trans_a, &trans_b, &m, &n, &dk, - &alpha, a+astride, &lda, b+bstride, &ldb, - &D0, cpriv, &m); - } -#pragma omp critical - if (dk > 0) { - for (ij = 0, i = 0; i < n; i++) { - for (j = 0; j < m; j++, ij++) { - c[i*Ldc+j] += cpriv[ij]; - } } - } - - free(cpriv); -} - - } else if (m > n*2) { // parallelize m - -#pragma omp parallel -{ - size_t m0, m1; - NPomp_split(&m0, &m1, m); - int dm = m1 - m0; - if (dm > 0) { - size_t astride = m0; - if (trans_a != 'N') { - astride *= lda; - } - dgemm_(&trans_a, &trans_b, &dm, &n, &k, - &alpha, a+astride, &lda, b, &ldb, - &beta, c+m0, &ldc); - } -} - - } else { // parallelize n - -#pragma omp parallel -{ - size_t n0, n1; - NPomp_split(&n0, &n1, n); - int dn = n1 - n0; - if (dn > 0) { - size_t bstride = n0; - if (trans_b == 'N') { - bstride *= ldb; - } - dgemm_(&trans_a, &trans_b, &m, &dn, &k, - &alpha, a, &lda, b+bstride, &ldb, - &beta, c+Ldc*n0, &ldc); - } -} - } -} - - -void NPzgemm(const char trans_a, const char trans_b, - const int m, const int n, const int k, - const int lda, const int ldb, const int ldc, - const int offseta, const int offsetb, const int offsetc, - double complex *a, double complex *b, double complex *c, - const double complex *alpha, const double complex *beta) -{ - const size_t Ldc = ldc; - int i, j; - if (m == 0 || n == 0) { - return; - } else if (k == 0) { - for (i = 0; i < n; i++) { - for (j = 0; j < m; j++) { - c[i*Ldc+j] = 0; - } } - return; - } - a += offseta; - b += offsetb; - c += offsetc; - - if ((k/m) > 3 && (k/n) > 3) { // parallelize k - - if (creal(*beta) == 0 && cimag(*beta) == 0) { - for (i = 0; i < n; i++) { - for (j = 0; j < m; j++) { - c[i*Ldc+j] = 0; - } - } - } else { - for (i = 0; i < n; i++) { - for (j = 0; j < m; j++) { - c[i*Ldc+j] *= beta[0]; - } - } - } - -#pragma omp parallel private(i, j) -{ - double complex Z0 = 0; - double complex *cpriv = malloc(sizeof(double complex) * (m*n+2)); - size_t k0, k1, ij; - NPomp_split(&k0, &k1, k); - int dk = k1 - k0; - if (dk > 0) { - size_t astride = k0; - size_t bstride = k0; - if (trans_a == 'N') { - astride *= lda; - } - if (trans_b != 'N') { - bstride *= ldb; - } - zgemm_(&trans_a, &trans_b, &m, &n, &dk, - alpha, a+astride, &lda, b+bstride, &ldb, - &Z0, cpriv, &m); - } -#pragma omp critical - if (dk > 0) { - for (ij = 0, i = 0; i < n; i++) { - for (j = 0; j < m; j++, ij++) { - c[i*Ldc+j] += cpriv[ij]; - } } - 
} - free(cpriv); -} - - } else if (m > n*2) { // parallelize m - -#pragma omp parallel -{ - size_t m0, m1; - NPomp_split(&m0, &m1, m); - int dm = m1 - m0; - if (dm > 0) { - size_t astride = m0; - if (trans_a != 'N') { - astride *= lda; - } - zgemm_(&trans_a, &trans_b, &dm, &n, &k, - alpha, a+astride, &lda, b, &ldb, - beta, c+m0, &ldc); - } -} - - } else { // parallelize n - -#pragma omp parallel -{ - size_t n0, n1; - NPomp_split(&n0, &n1, n); - int dn = n1 - n0; - if (dn > 0) { - size_t bstride = n0; - if (trans_b == 'N') { - bstride *= ldb; - } - zgemm_(&trans_a, &trans_b, &m, &dn, &k, - alpha, a, &lda, b+bstride, &ldb, - beta, c+Ldc*n0, &ldc); - } -} - } -} diff --git a/pyscfadlib/pyscfadlib/np_helper/omp_reduce.c b/pyscfadlib/pyscfadlib/np_helper/omp_reduce.c deleted file mode 100644 index c59620a1..00000000 --- a/pyscfadlib/pyscfadlib/np_helper/omp_reduce.c +++ /dev/null @@ -1,166 +0,0 @@ -/* Copyright 2014-2020 The PySCF Developers. All Rights Reserved. - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - - * - * Author: Qiming Sun - */ - -#include -#include -#include "config.h" - -#define MIN(x, y) ((x) < (y) ? (x) : (y)) - -void NPomp_split(size_t *start, size_t *end, size_t n) { - int nthread = omp_get_num_threads(); - int thread_id = omp_get_thread_num(); - int rest = n % nthread; - size_t blksize = n / nthread; - if (thread_id < rest) { - blksize++; - *start = blksize * thread_id; - *end = blksize * (thread_id + 1); - } else{ - *start = blksize * thread_id + rest; - *end = *start + blksize; - } -} - -static int _highest_power2(int n) -{ - int v = n - 1; - v |= v >> 1; - v |= v >> 2; - v |= v >> 4; - v |= v >> 8; - return (v + 1) >> 1; -} - -void NPomp_dsum_reduce_inplace1(double **vec, size_t count) -{ - if (count <= 1) { - return; - } - unsigned int nthreads = omp_get_num_threads(); - unsigned int thread_id = omp_get_thread_num(); - double *src = vec[thread_id]; - double *dst; - int n; - size_t i; -#pragma omp barrier - for (n = _highest_power2(nthreads); n > 0; n >>= 1) { - if (thread_id >= n) { - dst = vec[thread_id - n]; - for (i = 0; i < count; i++) { - dst[i] += src[i]; - } - } -#pragma omp barrier - } -} - -void NPomp_dsum_reduce_inplace(double **vec, size_t count) -{ - unsigned int nthreads = omp_get_num_threads(); - unsigned int thread_id = omp_get_thread_num(); - size_t blksize = (count + nthreads - 1) / nthreads; - size_t start = thread_id * blksize; - size_t end = MIN(start + blksize, count); - double *dst = vec[0]; - double *src; - size_t it, i; -#pragma omp barrier - for (it = 1; it < nthreads; it++) { - src = vec[it]; - for (i = start; i < end; i++) { - dst[i] += src[i]; - } - } -#pragma omp barrier -} - -void NPomp_dprod_reduce_inplace(double **vec, size_t count) -{ - unsigned int nthreads = omp_get_num_threads(); - unsigned int thread_id = omp_get_thread_num(); - size_t blksize = (count + nthreads - 1) / nthreads; - size_t start = thread_id * blksize; - size_t end = MIN(start + blksize, count); - double *dst = vec[0]; - double *src; - size_t it, i; -#pragma omp barrier - for (it = 
1; it < nthreads; it++) { - src = vec[it]; - for (i = start; i < end; i++) { - dst[i] *= src[i]; - } - } -#pragma omp barrier -} - -void NPomp_zsum_reduce_inplace(double complex **vec, size_t count) -{ - unsigned int nthreads = omp_get_num_threads(); - unsigned int thread_id = omp_get_thread_num(); - size_t blksize = (count + nthreads - 1) / nthreads; - size_t start = thread_id * blksize; - size_t end = MIN(start + blksize, count); - double complex *dst = vec[0]; - double complex *src; - size_t it, i; -#pragma omp barrier - for (it = 1; it < nthreads; it++) { - src = vec[it]; - for (i = start; i < end; i++) { - dst[i] += src[i]; - } - } -#pragma omp barrier -} - -void NPomp_zprod_reduce_inplace(double complex **vec, size_t count) -{ - unsigned int nthreads = omp_get_num_threads(); - unsigned int thread_id = omp_get_thread_num(); - size_t blksize = (count + nthreads - 1) / nthreads; - size_t start = thread_id * blksize; - size_t end = MIN(start + blksize, count); - double complex *dst = vec[0]; - double complex *src; - size_t it, i; -#pragma omp barrier - for (it = 1; it < nthreads; it++) { - src = vec[it]; - for (i = start; i < end; i++) { - dst[i] *= src[i]; - } - } -#pragma omp barrier -} - - -#ifdef _OPENMP -int get_omp_threads() { - return omp_get_max_threads(); -} -int set_omp_threads(int n) { - omp_set_num_threads(n); - return n; -} -#else -// mimic omp_get_max_threads omp_set_num_threads function of libgomp -int get_omp_threads() { return 1; } -int set_omp_threads(int n) { return 0; } -#endif diff --git a/pyscfadlib/pyscfadlib/np_helper/pack_tril.c b/pyscfadlib/pyscfadlib/np_helper/pack_tril.c index 0859c443..44c80e5f 100644 --- a/pyscfadlib/pyscfadlib/np_helper/pack_tril.c +++ b/pyscfadlib/pyscfadlib/np_helper/pack_tril.c @@ -17,8 +17,6 @@ */ #include "stdlib.h" -#include -#include "config.h" #include "np_helper.h" void NPdsymm_triu(int n, double *mat, int hermi) @@ -69,206 +67,3 @@ void NPdunpack_tril(int n, double *tril, double *mat, int hermi) } } -// unpack one row from the compact matrix-tril coefficients -void NPdunpack_row(int ndim, int row_id, double *tril, double *row) -{ - int i; - size_t idx = ((size_t)row_id) * (row_id + 1) / 2; - NPdcopy(row, tril+idx, row_id); - for (i = row_id; i < ndim; i++) { - idx += i; - row[i] = tril[idx]; - } -} - -void NPzunpack_tril(int n, double complex *tril, double complex *mat, - int hermi) -{ - size_t i, j, ij; - for (ij = 0, i = 0; i < n; i++) { - for (j = 0; j <= i; j++, ij++) { - mat[i*n+j] = tril[ij]; - } - } - if (hermi) { - NPzhermi_triu(n, mat, hermi); - } -} - -void NPdpack_tril(int n, double *tril, double *mat) -{ - size_t i, j, ij; - for (ij = 0, i = 0; i < n; i++) { - for (j = 0; j <= i; j++, ij++) { - tril[ij] = mat[i*n+j]; - } - } -} - -void NPzpack_tril(int n, double complex *tril, double complex *mat) -{ - size_t i, j, ij; - for (ij = 0, i = 0; i < n; i++) { - for (j = 0; j <= i; j++, ij++) { - tril[ij] = mat[i*n+j]; - } - } -} - -/* out += in[idx[:,None],idy] */ -void NPdtake_2d(double *out, double *in, int *idx, int *idy, - int odim, int idim, int nx, int ny) -{ -#pragma omp parallel default(none) \ - shared(out, in, idx,idy, odim, idim, nx, ny) -{ - size_t i, j; - double *pin; -#pragma omp for schedule (static) - for (i = 0; i < nx; i++) { - pin = in + (size_t)idim * idx[i]; - for (j = 0; j < ny; j++) { - out[i*odim+j] = pin[idy[j]]; - } - } -} -} - -void NPztake_2d(double complex *out, double complex *in, int *idx, int *idy, - int odim, int idim, int nx, int ny) -{ -#pragma omp parallel default(none) \ - shared(out, 
in, idx,idy, odim, idim, nx, ny) -{ - size_t i, j; - double complex *pin; -#pragma omp for schedule (static) - for (i = 0; i < nx; i++) { - pin = in + (size_t)idim * idx[i]; - for (j = 0; j < ny; j++) { - out[i*odim+j] = pin[idy[j]]; - } - } -} -} - -/* out[idx[:,None],idy] += in */ -void NPdtakebak_2d(double *out, double *in, int *idx, int *idy, - int odim, int idim, int nx, int ny, int thread_safe) -{ - if (thread_safe) { -#pragma omp parallel default(none) \ - shared(out, in, idx,idy, odim, idim, nx, ny) -{ - size_t i, j; - double *pout; -#pragma omp for schedule (static) - for (i = 0; i < nx; i++) { - pout = out + (size_t)odim * idx[i]; - for (j = 0; j < ny; j++) { - pout[idy[j]] += in[i*idim+j]; - } - } -} - } else { - size_t i, j; - double *pout; - for (i = 0; i < nx; i++) { - pout = out + (size_t)odim * idx[i]; - for (j = 0; j < ny; j++) { - pout[idy[j]] += in[i*idim+j]; - } - } - } -} - -void NPztakebak_2d(double complex *out, double complex *in, int *idx, int *idy, - int odim, int idim, int nx, int ny, int thread_safe) -{ - if (thread_safe) { -#pragma omp parallel default(none) \ - shared(out, in, idx,idy, odim, idim, nx, ny) -{ - size_t i, j; - double complex *pout; -#pragma omp for schedule (static) - for (i = 0; i < nx; i++) { - pout = out + (size_t)odim * idx[i]; - for (j = 0; j < ny; j++) { - pout[idy[j]] += in[i*idim+j]; - } - } -} - } else { - size_t i, j; - double complex *pout; - for (i = 0; i < nx; i++) { - pout = out + (size_t)odim * idx[i]; - for (j = 0; j < ny; j++) { - pout[idy[j]] += in[i*idim+j]; - } - } - } -} - -void NPdunpack_tril_2d(int count, int n, double *tril, double *mat, int hermi) -{ -#pragma omp parallel default(none) \ - shared(count, n, tril, mat, hermi) -{ - int ic; - size_t nn = n * n; - size_t n2 = n*(n+1)/2; -#pragma omp for schedule (static) - for (ic = 0; ic < count; ic++) { - NPdunpack_tril(n, tril+n2*ic, mat+nn*ic, hermi); - } -} -} - -void NPzunpack_tril_2d(int count, int n, - double complex *tril, double complex *mat, int hermi) -{ -#pragma omp parallel default(none) \ - shared(count, n, tril, mat, hermi) -{ - int ic; - size_t nn = n * n; - size_t n2 = n*(n+1)/2; -#pragma omp for schedule (static) - for (ic = 0; ic < count; ic++) { - NPzunpack_tril(n, tril+n2*ic, mat+nn*ic, hermi); - } -} -} - -void NPdpack_tril_2d(int count, int n, double *tril, double *mat) -{ -#pragma omp parallel default(none) \ - shared(count, n, tril, mat) -{ - int ic; - size_t nn = n * n; - size_t n2 = n*(n+1)/2; -#pragma omp for schedule (static) - for (ic = 0; ic < count; ic++) { - NPdpack_tril(n, tril+n2*ic, mat+nn*ic); - } -} -} - -void NPzpack_tril_2d(int count, int n, double complex *tril, double complex *mat) -{ -#pragma omp parallel default(none) \ - shared(count, n, tril, mat) -{ - int ic; - size_t nn = n * n; - size_t n2 = n*(n+1)/2; -#pragma omp for schedule (static) - for (ic = 0; ic < count; ic++) { - NPzpack_tril(n, tril+n2*ic, mat+nn*ic); - } -} -} - diff --git a/pyscfadlib/pyscfadlib/np_helper/transpose.c b/pyscfadlib/pyscfadlib/np_helper/transpose.c deleted file mode 100644 index cb4ba042..00000000 --- a/pyscfadlib/pyscfadlib/np_helper/transpose.c +++ /dev/null @@ -1,155 +0,0 @@ -/* Copyright 2014-2018 The PySCF Developers. All Rights Reserved. - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. 
- You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - - * - * Author: Qiming Sun - */ - -#include -#include -#include "np_helper.h" - -/* - * matrix a[n,m] - */ -void NPdtranspose(int n, int m, double *a, double *at) -{ - size_t i, j, j0, j1; - for (j0 = 0; j0 < n; j0+=BLOCK_DIM) { - j1 = MIN(j0+BLOCK_DIM, n); - for (i = 0; i < m; i++) { - for (j = j0; j < j1; j++) { - at[i*n+j] = a[j*m+i]; - } - } - } -} - -void NPztranspose(int n, int m, double complex *a, double complex *at) -{ - size_t i, j, j0, j1; - for (j0 = 0; j0 < n; j0+=BLOCK_DIM) { - j1 = MIN(j0+BLOCK_DIM, n); - for (i = 0; i < m; i++) { - for (j = j0; j < j1; j++) { - at[i*n+j] = a[j*m+i]; - } - } - } -} - - -void NPdtranspose_021(int *shape, double *a, double *at) -{ -#pragma omp parallel default(none) \ - shared(shape, a, at) -{ - int ic; - size_t nm = shape[1] * shape[2]; -#pragma omp for schedule (static) - for (ic = 0; ic < shape[0]; ic++) { - NPdtranspose(shape[1], shape[2], a+ic*nm, at+ic*nm); - } -} -} - -void NPztranspose_021(int *shape, double complex *a, double complex *at) -{ -#pragma omp parallel default(none) \ - shared(shape, a, at) -{ - int ic; - size_t nm = shape[1] * shape[2]; -#pragma omp for schedule (static) - for (ic = 0; ic < shape[0]; ic++) { - NPztranspose(shape[1], shape[2], a+ic*nm, at+ic*nm); - } -} -} - - -void NPdsymm_sum(int n, double *a, double *out, int hermi) -{ - size_t i, j, j0, j1; - double tmp; - - if (hermi == HERMITIAN || hermi == SYMMETRIC) { - TRIU_LOOP(i, j) { - tmp = a[i*n+j] + a[j*n+i]; - out[i*n+j] = tmp; - out[j*n+i] = tmp; - } - } else { - TRIU_LOOP(i, j) { - tmp = a[i*n+j] - a[j*n+i]; - out[i*n+j] = tmp; - out[j*n+i] =-tmp; - } - } -} - -void NPzhermi_sum(int n, double complex *a, double complex *out, int hermi) -{ - size_t i, j, j0, j1; - double complex tmp; - - if (hermi == HERMITIAN) { - TRIU_LOOP(i, j) { - tmp = a[i*n+j] + conj(a[j*n+i]); - out[i*n+j] = tmp; - out[j*n+i] = conj(tmp); - } - } else if (hermi == SYMMETRIC) { - TRIU_LOOP(i, j) { - tmp = a[i*n+j] + a[j*n+i]; - out[i*n+j] = tmp; - out[j*n+i] = tmp; - } - } else { - TRIU_LOOP(i, j) { - tmp = a[i*n+j] - conj(a[j*n+i]); - out[i*n+j] = tmp; - out[j*n+i] =-conj(tmp); - } - } -} - - -void NPdsymm_021_sum(int *shape, double *a, double *out, int hermi) -{ -#pragma omp parallel default(none) \ - shared(shape, a, out, hermi) -{ - int ic; - size_t nn = shape[1] * shape[1]; -#pragma omp for schedule (static) - for (ic = 0; ic < shape[0]; ic++) { - NPdsymm_sum(shape[1], a+ic*nn, out+ic*nn, hermi); - } -} -} - -void NPzhermi_021_sum(int *shape, double complex *a, double complex *out, int hermi) -{ -#pragma omp parallel default(none) \ - shared(shape, a, out, hermi) -{ - int ic; - size_t nn = shape[1] * shape[1]; -#pragma omp for schedule (static) - for (ic = 0; ic < shape[0]; ic++) { - NPzhermi_sum(shape[1], a+ic*nn, out+ic*nn, hermi); - } -} -} diff --git a/pyscfadlib/pyscfadlib/vjp/CMakeLists.txt b/pyscfadlib/pyscfadlib/vjp/CMakeLists.txt index 5894e431..689e2ab4 100644 --- a/pyscfadlib/pyscfadlib/vjp/CMakeLists.txt +++ b/pyscfadlib/pyscfadlib/vjp/CMakeLists.txt @@ -25,7 +25,7 @@ add_library(ao2mo_vjp SHARED set_target_properties(ao2mo_vjp PROPERTIES LIBRARY_OUTPUT_DIRECTORY 
${PROJECT_SOURCE_DIR} ) -target_link_libraries(ao2mo_vjp np_helper_ad ${OPENMP_C_PROPERTIES}) +target_link_libraries(ao2mo_vjp np_helper_ad vjp_util ${BLAS_LIBRARIES} ${OPENMP_C_PROPERTIES}) #vhf add_library(cvhf_vjp SHARED @@ -35,7 +35,7 @@ add_library(cvhf_vjp SHARED set_target_properties(cvhf_vjp PROPERTIES LIBRARY_OUTPUT_DIRECTORY ${PROJECT_SOURCE_DIR} ) -target_link_libraries(cvhf_vjp np_helper_ad ${OPENMP_C_PROPERTIES}) +target_link_libraries(cvhf_vjp np_helper_ad vjp_util ${BLAS_LIBRARIES} ${OPENMP_C_PROPERTIES}) #cc add_library(cc_vjp SHARED @@ -47,4 +47,15 @@ add_library(cc_vjp SHARED set_target_properties(cc_vjp PROPERTIES LIBRARY_OUTPUT_DIRECTORY ${PROJECT_SOURCE_DIR} ) -target_link_libraries(cc_vjp np_helper_ad ${OPENMP_C_PROPERTIES}) +target_link_libraries(cc_vjp vjp_util ${BLAS_LIBRARIES} ${OPENMP_C_PROPERTIES}) + +#util +add_library(vjp_util SHARED + util/omp_reduce.c + util/pack_tril.c +) +set_target_properties(vjp_util PROPERTIES + LIBRARY_OUTPUT_DIRECTORY ${PROJECT_SOURCE_DIR} +) +target_link_libraries(vjp_util ${OPENMP_C_PROPERTIES}) + diff --git a/pyscfadlib/pyscfadlib/vjp/ao2mo/nr_ao2mo_vjp.c b/pyscfadlib/pyscfadlib/vjp/ao2mo/nr_ao2mo_vjp.c index d53cb7ba..bacc46d4 100644 --- a/pyscfadlib/pyscfadlib/vjp/ao2mo/nr_ao2mo_vjp.c +++ b/pyscfadlib/pyscfadlib/vjp/ao2mo/nr_ao2mo_vjp.c @@ -2,6 +2,7 @@ #include "config.h" #include "np_helper/np_helper.h" #include "vhf/fblas.h" +#include "vjp/util/util.h" #define MAX_THREADS 128 #define OUTPUTIJ 1 @@ -18,20 +19,6 @@ struct _AO2MOvjpEnvs { }; -static void pack_tril(int n, double *tril, double *mat) -{ - size_t i, j, ij; - for (ij = 0, i = 0; i < n; i++) { - for (j = 0; j < i; j++, ij++) { - tril[ij] += mat[i*n+j]; - tril[ij] += mat[j*n+i]; - } - tril[ij] += mat[i*n+i]; - ij++; - } -} - - int AO2MOmmm_nr_vjp_s2_iltj(double *eri_bar, double *mo_coeff_bar, double *eri, double *ybar, double *buf, struct _AO2MOvjpEnvs *envs, int seekdim) @@ -204,7 +191,7 @@ void AO2MOnr_e2_vjp_drv(void (*ftrans)(), int (*fmmm)(), } free(buf); - NPomp_dsum_reduce_inplace(mo_coeff_bar_bufs, nao*nmo); + omp_dsum_reduce_inplace(mo_coeff_bar_bufs, nao*nmo); if (thread_id != 0) { free(mo_coeff_bar_priv); } diff --git a/pyscfadlib/pyscfadlib/vjp/cc/ccsd_t.c b/pyscfadlib/pyscfadlib/vjp/cc/ccsd_t.c index 3a577217..b3ade393 100644 --- a/pyscfadlib/pyscfadlib/vjp/cc/ccsd_t.c +++ b/pyscfadlib/pyscfadlib/vjp/cc/ccsd_t.c @@ -1,6 +1,6 @@ #include #include "vhf/fblas.h" -#include "ccsd_t.h" +#include "vjp/cc/ccsd_t.h" /* copied from pyscf */ size_t _ccsd_t_gen_jobs(CacheJob *jobs, int nocc, int nvir, diff --git a/pyscfadlib/pyscfadlib/vjp/cc/ccsd_t_vjp.c b/pyscfadlib/pyscfadlib/vjp/cc/ccsd_t_vjp.c index eb17f51a..b5f6e7c3 100644 --- a/pyscfadlib/pyscfadlib/vjp/cc/ccsd_t_vjp.c +++ b/pyscfadlib/pyscfadlib/vjp/cc/ccsd_t_vjp.c @@ -1,8 +1,8 @@ #include #include "config.h" -#include "np_helper/np_helper.h" #include "vhf/fblas.h" -#include "ccsd_t.h" +#include "vjp/cc/ccsd_t.h" +#include "vjp/util/util.h" #define MAX_THREADS 128 @@ -457,12 +457,12 @@ void ccsd_t_energy_vjp(double *mo_energy, double *t1T, double *t2T, free(jobs_vjp); free(cache1); - NPomp_dsum_reduce_inplace(mo_energy_bar_bufs, nmo); - NPomp_dsum_reduce_inplace(t1T_bar_bufs, nvir*nocc); - NPomp_dsum_reduce_inplace(t2T_bar_bufs, nvir*nvir*nocc*nocc); - NPomp_dsum_reduce_inplace(vooo_bar_bufs, nvir*nocc*nocc*nocc); - NPomp_dsum_reduce_inplace(fvo_bar_bufs, nvir*nocc); - NPomp_dsum_reduce_inplace(cache_row_a_bar_bufs, da*a1*nocc*nmo); + omp_dsum_reduce_inplace(mo_energy_bar_bufs, nmo); + 
+    omp_dsum_reduce_inplace(t1T_bar_bufs, nvir*nocc);
+    omp_dsum_reduce_inplace(t2T_bar_bufs, nvir*nvir*nocc*nocc);
+    omp_dsum_reduce_inplace(vooo_bar_bufs, nvir*nocc*nocc*nocc);
+    omp_dsum_reduce_inplace(fvo_bar_bufs, nvir*nocc);
+    omp_dsum_reduce_inplace(cache_row_a_bar_bufs, da*a1*nocc*nmo);
     if (thread_id != 0) {
         free(mo_energy_bar_priv);
         free(t1T_bar_priv);
@@ -473,18 +473,18 @@
     }
 
     if (a0 > 0) {
-        NPomp_dsum_reduce_inplace(cache_col_a_bar_bufs, a0*da*nocc*nmo);
+        omp_dsum_reduce_inplace(cache_col_a_bar_bufs, a0*da*nocc*nmo);
         if (thread_id != 0) {
             free(cache_col_a_bar_priv);
        }
     }
     if (b1 <= a0) {
-        NPomp_dsum_reduce_inplace(cache_row_b_bar_bufs, db*b1*nocc*nmo);
+        omp_dsum_reduce_inplace(cache_row_b_bar_bufs, db*b1*nocc*nmo);
         if (thread_id != 0) {
             free(cache_row_b_bar_priv);
         }
         if (b0 > 0) {
-            NPomp_dsum_reduce_inplace(cache_col_b_bar_bufs, b0*db*nocc*nmo);
+            omp_dsum_reduce_inplace(cache_col_b_bar_bufs, b0*db*nocc*nmo);
             if (thread_id != 0) {
                 free(cache_col_b_bar_priv);
             }
diff --git a/pyscfadlib/pyscfadlib/vjp/cc/lno_ccsd_t.c b/pyscfadlib/pyscfadlib/vjp/cc/lno_ccsd_t.c
index ce591c23..aa7ab962 100644
--- a/pyscfadlib/pyscfadlib/vjp/cc/lno_ccsd_t.c
+++ b/pyscfadlib/pyscfadlib/vjp/cc/lno_ccsd_t.c
@@ -1,7 +1,7 @@
 #include <stdlib.h>
 #include "config.h"
 #include "vhf/fblas.h"
-#include "ccsd_t.h"
+#include "vjp/cc/ccsd_t.h"
 
 static double lnoccsdt_get_energy(double *mat, double *w, double *v,
                                   double *mo_energy, int nocc,
diff --git a/pyscfadlib/pyscfadlib/vjp/cc/lno_ccsd_t_vjp.c b/pyscfadlib/pyscfadlib/vjp/cc/lno_ccsd_t_vjp.c
index 45b2eb83..69b49f3d 100644
--- a/pyscfadlib/pyscfadlib/vjp/cc/lno_ccsd_t_vjp.c
+++ b/pyscfadlib/pyscfadlib/vjp/cc/lno_ccsd_t_vjp.c
@@ -1,8 +1,8 @@
 #include <stdlib.h>
 #include "config.h"
-#include "np_helper/np_helper.h"
 #include "vhf/fblas.h"
-#include "ccsd_t.h"
+#include "vjp/cc/ccsd_t.h"
+#include "vjp/util/util.h"
 
 #define MAX_THREADS 128
 
@@ -528,13 +528,13 @@ void lnoccsdt_energy_vjp(double *mat, double *mo_energy, double *t1T, double *t2
     free(jobs_vjp);
     free(cache1);
 
-    NPomp_dsum_reduce_inplace(mat_bar_bufs, nocc*nocc);
-    NPomp_dsum_reduce_inplace(mo_energy_bar_bufs, nmo);
-    NPomp_dsum_reduce_inplace(t1T_bar_bufs, nvir*nocc);
-    NPomp_dsum_reduce_inplace(t2T_bar_bufs, nvir*nvir*nocc*nocc);
-    NPomp_dsum_reduce_inplace(vooo_bar_bufs, nvir*nocc*nocc*nocc);
-    NPomp_dsum_reduce_inplace(fvo_bar_bufs, nvir*nocc);
-    NPomp_dsum_reduce_inplace(cache_row_a_bar_bufs, da*db*nocc*nmo);
+    omp_dsum_reduce_inplace(mat_bar_bufs, nocc*nocc);
+    omp_dsum_reduce_inplace(mo_energy_bar_bufs, nmo);
+    omp_dsum_reduce_inplace(t1T_bar_bufs, nvir*nocc);
+    omp_dsum_reduce_inplace(t2T_bar_bufs, nvir*nvir*nocc*nocc);
+    omp_dsum_reduce_inplace(vooo_bar_bufs, nvir*nocc*nocc*nocc);
+    omp_dsum_reduce_inplace(fvo_bar_bufs, nvir*nocc);
+    omp_dsum_reduce_inplace(cache_row_a_bar_bufs, da*db*nocc*nmo);
     if (thread_id != 0) {
         free(mat_bar_priv);
         free(mo_energy_bar_priv);
@@ -546,14 +546,14 @@ void lnoccsdt_energy_vjp(double *mat, double *mo_energy, double *t1T, double *t2
     }
 
     if (b0 != 0 || b1 != nvir) {
-        NPomp_dsum_reduce_inplace(cache_col_a_bar_bufs, da*db*nocc*nmo);
+        omp_dsum_reduce_inplace(cache_col_a_bar_bufs, da*db*nocc*nmo);
         if (thread_id != 0) {
             free(cache_col_a_bar_priv);
         }
     }
     if (c1 <= b0) {
-        NPomp_dsum_reduce_inplace(cache_row_b_bar_bufs, da*dc*nocc*nmo);
-        NPomp_dsum_reduce_inplace(cache_col_b_bar_bufs, da*dc*nocc*nmo);
+        omp_dsum_reduce_inplace(cache_row_b_bar_bufs, da*dc*nocc*nmo);
+        omp_dsum_reduce_inplace(cache_col_b_bar_bufs, da*dc*nocc*nmo);
        if (thread_id != 0) {
             free(cache_row_b_bar_priv);
             free(cache_col_b_bar_priv);
diff --git a/pyscfadlib/pyscfadlib/vjp/util/omp_reduce.c b/pyscfadlib/pyscfadlib/vjp/util/omp_reduce.c
new file mode 100644
index 00000000..3b772788
--- /dev/null
+++ b/pyscfadlib/pyscfadlib/vjp/util/omp_reduce.c
@@ -0,0 +1,24 @@
+#include <stdlib.h>
+#include "config.h"
+
+#define MIN(X, Y) ((X) < (Y) ? (X) : (Y))
+
+void omp_dsum_reduce_inplace(double **vec, size_t count)
+{
+    unsigned int nthreads = omp_get_num_threads();
+    unsigned int thread_id = omp_get_thread_num();
+    size_t blksize = (count + nthreads - 1) / nthreads;
+    size_t start = thread_id * blksize;
+    size_t end = MIN(start + blksize, count);
+    double *dst = vec[0];
+    double *src;
+    size_t it, i;
+#pragma omp barrier
+    for (it = 1; it < nthreads; it++) {
+        src = vec[it];
+        for (i = start; i < end; i++) {
+            dst[i] += src[i];
+        }
+    }
+#pragma omp barrier
+}
diff --git a/pyscfadlib/pyscfadlib/vjp/util/pack_tril.c b/pyscfadlib/pyscfadlib/vjp/util/pack_tril.c
new file mode 100644
index 00000000..89ae0f50
--- /dev/null
+++ b/pyscfadlib/pyscfadlib/vjp/util/pack_tril.c
@@ -0,0 +1,14 @@
+#include <stdlib.h>
+
+void pack_tril(int n, double *tril, double *mat)
+{
+    size_t i, j, ij;
+    for (ij = 0, i = 0; i < n; i++) {
+        for (j = 0; j < i; j++, ij++) {
+            tril[ij] += mat[i*n+j];
+            tril[ij] += mat[j*n+i];
+        }
+        tril[ij] += mat[i*n+i];
+        ij++;
+    }
+}
diff --git a/pyscfadlib/pyscfadlib/vjp/util/util.h b/pyscfadlib/pyscfadlib/vjp/util/util.h
new file mode 100644
index 00000000..05ba69f9
--- /dev/null
+++ b/pyscfadlib/pyscfadlib/vjp/util/util.h
@@ -0,0 +1,8 @@
+#ifndef HAVE_DEFINED_VJPUTIL_H
+#define HAVE_DEFINED_VJPUTIL_H
+
+void omp_dsum_reduce_inplace(double **vec, size_t count);
+
+void pack_tril(int n, double *tril, double *mat);
+
+#endif
diff --git a/pyscfadlib/pyscfadlib/vjp/vhf/df_jk_vjp.c b/pyscfadlib/pyscfadlib/vjp/vhf/df_jk_vjp.c
index 6dd4b837..332fd169 100644
--- a/pyscfadlib/pyscfadlib/vjp/vhf/df_jk_vjp.c
+++ b/pyscfadlib/pyscfadlib/vjp/vhf/df_jk_vjp.c
@@ -2,22 +2,10 @@
 #include "config.h"
 #include "np_helper/np_helper.h"
 #include "vhf/fblas.h"
+#include "vjp/util/util.h"
 
 #define MAX_THREADS 128
 
-static void pack_tril(int n, double *tril, double *mat)
-{
-    size_t i, j, ij;
-    for (ij = 0, i = 0; i < n; i++) {
-        for (j = 0; j < i; j++, ij++) {
-            tril[ij] += mat[i*n+j];
-            tril[ij] += mat[j*n+i];
-        }
-        tril[ij] += mat[i*n+i];
-        ij++;
-    }
-}
-
 // vk_bar in F order
 // eri_bar = einsum('pki,ij->pkj', buf1, vk_bar)
 // buf1_bar = einsum('ij,pkj->pki', vk_bar, eri)
@@ -64,8 +52,8 @@ void df_vk_vjp(double *eri_tril_bar, double *dm_bar,
                double *eri_tril, double *dm,
                int naux, int nao)
 {
-    const size_t nao2 = nao * nao;
-    const size_t nao_pair = nao * (nao+1) /2;
+    const size_t nao2 = (size_t)nao * nao;
+    const size_t nao_pair = (size_t)nao * (nao+1) /2;
     double *dm_bar_bufs[MAX_THREADS];
 #pragma omp parallel
 {
@@ -78,7 +66,8 @@ void df_vk_vjp(double *eri_tril_bar, double *dm_bar,
         dm_bar_priv = calloc(nao2, sizeof(double));
     }
     dm_bar_bufs[thread_id] = dm_bar_priv;
-    double *cache = malloc((nao2*2+2) * sizeof(double));
+
+    double *cache = malloc(nao2*2 * sizeof(double));
 #pragma omp for schedule(dynamic)
     for (i = 0; i < naux; i++) {
         _contract_vk(eri_tril_bar+i*nao_pair, dm_bar_priv,
@@ -87,7 +75,7 @@ void df_vk_vjp(double *eri_tril_bar, double *dm_bar,
     }
     free(cache);
 
-    NPomp_dsum_reduce_inplace(dm_bar_bufs, nao2);
+    omp_dsum_reduce_inplace(dm_bar_bufs, nao2);
     if (thread_id != 0) {
         free(dm_bar_priv);
     }