Skip to content

Commit

Permalink
Fix PBC lattice response (#35)
Browse files Browse the repository at this point in the history
  • Loading branch information
fishjojo authored Jun 27, 2024
1 parent 3906545 commit 1921947
Show file tree
Hide file tree
Showing 34 changed files with 482 additions and 1,254 deletions.
16 changes: 13 additions & 3 deletions examples/pbc/scf/00-simple.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,11 @@
'''
--------------- KRHF gradients ---------------
x y z
0 Si -0.0077806264 -0.0077806264 -0.0077806264
1 Si 0.0077806264 0.0077806264 0.0077806264
----------------------------------------------
'''
import jax
from pyscfad.pbc import gto, scf

basis = 'gth-szv'
Expand All @@ -19,7 +27,9 @@
cell.verbose = 4
cell.build()

mf = scf.RHF(cell, exxdiv=None)
mf.kernel()
jac = mf.energy_grad()
def hf_energy(cell):
mf = scf.RHF(cell, exxdiv=None)
e_tot = mf.kernel()
return e_tot
e_tot, jac = jax.value_and_grad(hf_energy)(cell)
print(f'Nuclaer gradient:\n{jac.coords}')
13 changes: 8 additions & 5 deletions examples/pbc/scf/01-krhf.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
1 Si 0.0109162228 0.0028233781 0.0028233781
----------------------------------------------
'''

import jax
from pyscfad.pbc import gto
from pyscfad.pbc import scf, df

Expand All @@ -25,10 +25,13 @@
cell.a = lattice
cell.basis = basis
cell.pseudo = pseudo
cell.verbose = 4
cell.build()
kpts = cell.make_kpts([2,1,1])

mf = scf.KRHF(cell, kpts=kpts, exxdiv=None)
mf.kernel()
jac = mf.energy_grad()
print(jac.coords)
def hf_energy(cell, kpts):
mf = scf.KRHF(cell, kpts=kpts, exxdiv=None)
e_tot = mf.kernel()
return e_tot
e_tot, jac = jax.value_and_grad(hf_energy)(cell, kpts)
print(f'Nuclaer gradient:\n{jac.coords}')
1 change: 1 addition & 0 deletions examples/pbc/scf/02-krhf-hessian.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
cell.a = lattice
cell.basis = basis
cell.pseudo = pseudo
cell.verbose = 4
cell.build(trace_exp=False, trace_ctr_coeff=False)

def ehf(cell):
Expand Down
49 changes: 25 additions & 24 deletions examples/pbc/scf/10-stress.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,26 @@
'''
stress tensor
----------------------------
[[-7.57812381e-01 1.65085403e-09 1.65085540e-09]
[ 1.65085006e-09 -7.57812381e-01 1.65084496e-09]
[ 1.65084745e-09 1.65084512e-09 -7.57812381e-01]]
[[-3.59347869e-03 7.82820248e-12 7.82820901e-12]
[ 7.82818367e-12 -3.59347869e-03 7.82815947e-12]
[ 7.82817132e-12 7.82816026e-12 -3.59347869e-03]]
[[-6.59876007e-01 1.43750484e-09 1.43750604e-09]
[ 1.43750138e-09 -6.59876007e-01 1.43749694e-09]
[ 1.43749911e-09 1.43749708e-09 -6.59876007e-01]]
----------------------------
'''
import numpy
import jax
from jax import numpy as np
from pyscf.data.nist import BOHR, HARTREE2EV
from pyscfad.pbc import gto as pbcgto
from pyscfad.pbc import scf as pbcscf

aas = numpy.arange(5,6.0,0.1,dtype=float)
for aa in aas:
#aas = numpy.arange(5,6.0,0.1,dtype=float)
for aa in [5.]:
basis = 'gth-szv'
pseudo = 'gth-pade'
lattice = numpy.asarray([[0., aa/2, aa/2],
Expand All @@ -15,30 +29,17 @@
atom = [['Si', [0., 0., 0.]],
['Si', [aa/4, aa/4, aa/4]]]

cell0 = pbcgto.Cell()
cell0.atom = atom
cell0.a = lattice
cell0.basis = basis
cell0.pseudo = pseudo
cell0.verbose = 4
cell0.exp_to_discard=0.1
cell0.build()

coords = []
for i, a in enumerate(atom):
coords.append(a[1])
coords = numpy.asarray(coords)

strain = numpy.zeros((3,3))
def khf_energy(strain, lattice, coords):
strain = np.zeros((3,3))
def khf_energy(strain, lattice):
cell = pbcgto.Cell()
cell.atom = atom
cell.a = lattice
cell.basis = basis
cell.pseudo = pseudo
cell.verbose = 4
cell.exp_to_discard=0.1
cell.max_memory=40000
cell.mesh = [21]*3
cell.max_memory=100000
cell.build(trace_lattice_vectors=True)

cell.abc += np.einsum('ab,nb->na', strain, cell.lattice_vectors())
Expand All @@ -47,13 +48,13 @@ def khf_energy(strain, lattice, coords):
kpts = cell.make_kpts([2,2,2])

mf = pbcscf.KRHF(cell, kpts=kpts, exxdiv=None)
ehf = mf.kernel(dm0=None)
return ehf
ehf = mf.kernel()
return ehf, cell

jac = jax.jacrev(khf_energy)(strain, lattice, coords)
jac, cell = jax.jacrev(khf_energy, has_aux=True)(strain, lattice)
print('stress tensor')
print('----------------------------')
print(jac)
print(jac / cell0.vol)
print(jac*HARTREE2EV / (cell0.vol*(BOHR**3)))
print(jac / cell.vol)
print(jac*HARTREE2EV / (cell.vol*(BOHR**3)))
print('----------------------------')
6 changes: 3 additions & 3 deletions pyscfad/_src/scipy/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def eigh(a, b=None, lower=True, eigvals_only=False, overwrite_a=False,
@partial(custom_jvp, nondiff_argnums=(2,))
def _eigh(a, b, deg_thresh=DEG_THRESH):
w, v = scipy.linalg.eigh(a, b=b)
w = np.asarray(w, dtype=v.dtype)
w = np.asarray(w, dtype=float)
return w, v

@_eigh.defjvp
Expand Down Expand Up @@ -63,15 +63,15 @@ def _eigh_jvp_jitted(w, v, Fmat, at, bt, bmask):
vt_bt_v = np.dot(v.conj().T, np.dot(bt, v))
vt_bt_v_w = np.dot(vt_bt_v, np.diag(w))
da_minus_ds = vt_at_v - vt_bt_v_w
dw = np.diag(da_minus_ds)#.real
dw = np.diag(da_minus_ds).real

dv = np.dot(v, np.multiply(Fmat, da_minus_ds) - np.multiply(bmask, vt_bt_v) * .5)
return dw, dv

@jit
def _eigh_jvp_jitted_nob(v, Fmat, at):
vt_at_v = np.dot(v.conj().T, np.dot(at, v))
dw = np.diag(vt_at_v)
dw = np.diag(vt_at_v).real
dv = np.dot(v, np.multiply(Fmat, vt_at_v))
return dw, dv

Expand Down
12 changes: 6 additions & 6 deletions pyscfad/df/_df_jk_opt.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,20 +58,20 @@ def get_jk_bwd(hermi, with_j, with_k, direct_scf_tol,
rargs = (ctypes.c_int(nao), (ctypes.c_int*4)(0, nao, 0, nao),
null, ctypes.c_int(0))
max_memory = dfobj.max_memory - lib.current_memory()[0]
blksize = max(4, int(min(dfobj.blockdim, max_memory*.4e6/8/nao**2)))
blksize = max(4, int(min(dfobj.blockdim, max_memory*.3e6/8/nao**2)))
buf = numpy.empty((blksize,nao,nao))
p1 = 0
for eri1 in dfobj.loop(blksize):
naux, nao_pair = eri1.shape
p0, p1 = p1, p1 + naux
if with_j:
rho_bar = numpy.einsum('ix,px->ip', vj_bar_tril, eri1)
dmtril_bar = numpy.einsum('ip,px->ix', rho_bar, eri1)
rho_bar = vj_bar_tril @ eri1.T
dmtril_bar = rho_bar @ eri1
dms_bar += lib.unpack_tril(dmtril_bar)

rho = numpy.einsum('ix,px->ip', dmtril, eri1)
eri_bar[p0:p1] += numpy.einsum('ip,ix->px', rho, vj_bar_tril)
eri_bar[p0:p1] += numpy.einsum('ix,ip->px', dmtril, rho_bar)
rho = dmtril @ eri1.T
eri_bar[p0:p1] += rho.T @ vj_bar_tril
eri_bar[p0:p1] += rho_bar.T @ dmtril

for k in range(nset):
#TODO save buf1 on disk to avoid recomputation
Expand Down
10 changes: 3 additions & 7 deletions pyscfad/pbc/df/fft.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,14 @@
from jax import numpy as np
from pyscf import __config__
from pyscf import lib
from pyscf.pbc.gto import Cell
from pyscf.pbc.df import fft as pyscf_fft
#from pyscfad import util
from pyscfad.pbc import tools
from pyscfad.pbc.lib.kpts_helper import gamma_point

def get_pp(mydf, kpts=None):
from pyscf import gto
from pyscf.pbc.gto import pseudo
from pyscfad.pbc.gto import pseudo
from pyscfad.gto.mole import Mole
cell = mydf.cell
if kpts is None:
Expand All @@ -31,7 +30,7 @@ def get_pp(mydf, kpts=None):
for ao_ks_etc, p0, p1 in mydf.aoR_loop(mydf.grids, kpts_lst):
ao_ks = ao_ks_etc[0]
for k, ao in enumerate(ao_ks):
vpp[k] += np.dot(ao.T.conj()*vpplocR[p0:p1], ao)
vpp[k] += (ao.T.conj()*vpplocR[p0:p1]) @ ao
ao = ao_ks = None

# vppnonloc evaluated in reciprocal space
Expand Down Expand Up @@ -124,7 +123,7 @@ def vppnl_by_k(kpt):
#@util.pytree_node(['cell','kpts'])
class FFTDF(pyscf_fft.FFTDF):
def __init__(self, cell, kpts=numpy.zeros((1,3))):#, **kwargs):
from pyscf.pbc.dft import gen_grid
from pyscfad.pbc.dft import gen_grid
from pyscfad.pbc.dft import numint
self.cell = cell
self.stdout = cell.stdout
Expand Down Expand Up @@ -173,9 +172,6 @@ def aoR_loop(self, grids=None, kpts=None, deriv=0):
cell = self.cell
else:
cell = grids.cell

# NOTE stop tracing cell through grids
grids.cell = cell.view(Cell)
if grids.non0tab is None:
grids.build(with_non0tab=True)

Expand Down
32 changes: 32 additions & 0 deletions pyscfad/pbc/dft/gen_grid.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
import numpy
from pyscf.pbc.gto import Cell as pyscf_Cell
from pyscf.pbc.dft import gen_grid as pyscf_gen_grid
from pyscfad import numpy as np
from pyscfad.ops import stop_grad
from pyscfad.pbc.gto.cell import get_uniform_grids

class UniformGrids(pyscf_gen_grid.UniformGrids):
@property
def coords(self):
if self._coords is not None:
return self._coords
else:
return get_uniform_grids(self.cell, self.mesh)

@property
def weights(self):
if self._weights is not None:
return self._weights
else:
ngrids = numpy.prod(self.mesh)
weights = np.full((ngrids,), self.cell.vol / ngrids)
return weights

def make_mask(self, cell=None, coords=None, relativity=0, shls_slice=None,
verbose=None):
if cell is None:
cell = self.cell
if coords is None:
coords = self.coords
return pyscf_gen_grid.make_mask(cell.view(pyscf_Cell), stop_grad(coords),
relativity, shls_slice, verbose)
Loading

0 comments on commit 1921947

Please sign in to comment.