Skip to content

Commit

Permalink
add examples
Browse files Browse the repository at this point in the history
  • Loading branch information
fishjojo committed Aug 4, 2022
1 parent 389a696 commit f69e185
Show file tree
Hide file tree
Showing 4 changed files with 158 additions and 23 deletions.
33 changes: 33 additions & 0 deletions examples/fci/10-nac.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
import jax
from pyscfad import gto, scf, fci

# molecular structure
mol = gto.Mole()
mol.atom = 'H 0 0 0; H 0 0 1.1'
mol.basis = 'ccpvdz'
mol.build()

# HF and FCI calculation
nroots = 8
mf = scf.RHF(mol)
mf.kernel()
e, fcivec = fci.solve_fci(mf, nroots=nroots)
print(e)

nelec = mol.nelectron
norb = mf.mo_coeff.shape[-1]
stateI, stateJ = 2, 7
def ovlp(mol1):
mf1 = scf.RHF(mol1)
mf1.kernel()
e1, fcivec1 = fci.solve_fci(mf1, nroots=nroots)
# wavefunction overlap
s = fci.fci_ovlp(mol, mol1, fcivec[stateI], fcivec1[stateJ],
norb, norb, nelec, nelec, mf.mo_coeff, mf1.mo_coeff)
return s

# Only the ket state is differentiated
mol1 = mol.copy()
jac = jax.jacrev(ovlp)(mol1)
print("FCI derivative coupling:")
print(jac.coords)
59 changes: 59 additions & 0 deletions examples/pbc/scf/10-stress.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
import numpy
import jax
from jax import numpy as np
from pyscf.data.nist import BOHR, HARTREE2EV
from pyscfad.pbc import gto as pbcgto
from pyscfad.pbc import scf as pbcscf

aas = numpy.arange(5,6.0,0.1,dtype=float)
for aa in aas:
basis = 'gth-szv'
pseudo = 'gth-pade'
lattice = numpy.asarray([[0., aa/2, aa/2],
[aa/2, 0., aa/2],
[aa/2, aa/2, 0.]])
atom = [['Si', [0., 0., 0.]],
['Si', [aa/4, aa/4, aa/4]]]

cell0 = pbcgto.Cell()
cell0.atom = atom
cell0.a = lattice
cell0.basis = basis
cell0.pseudo = pseudo
cell0.verbose = 4
cell0.exp_to_discard=0.1
cell0.build()

coords = []
for i, a in enumerate(atom):
coords.append(a[1])
coords = numpy.asarray(coords)

strain = numpy.zeros((3,3))
def khf_energy(strain, lattice, coords):
cell = pbcgto.Cell()
cell.atom = atom
cell.a = lattice
cell.basis = basis
cell.pseudo = pseudo
cell.verbose = 4
cell.exp_to_discard=0.1
cell.max_memory=40000
cell.build(trace_lattice_vectors=True)

cell.abc += np.einsum('ab,nb->na', strain, cell.lattice_vectors())
cell.coords += np.einsum('xy,ny->nx', strain, cell.atom_coords())

kpts = cell.make_kpts([2,2,2])

mf = pbcscf.KRHF(cell, kpts=kpts, exxdiv=None)
ehf = mf.kernel(dm0=None)
return ehf

jac = jax.jacrev(khf_energy)(strain, lattice, coords)
print('stress tensor')
print('----------------------------')
print(jac)
print(jac / cell0.vol)
print(jac*HARTREE2EV / (cell0.vol*(BOHR**3)))
print('----------------------------')
55 changes: 32 additions & 23 deletions examples/tdscf/10-cis_nac.py
Original file line number Diff line number Diff line change
@@ -1,34 +1,43 @@
import numpy
import jax
from jax import numpy as jnp
from pyscfad import gto, scf, tdscf
from pyscfad import gto, scf
from pyscfad.tdscf.rhf import CIS, cis_ovlp

# molecular structure
mol = gto.Mole()
mol.atom = 'H 0 0 0; H 0 0 1.1'
mol.basis = 'ccpvtz'
mol.verbose = 4
mol.build(trace_exp=False, trace_ctr_coeff=False)
mol.basis = 'cc-pvdz'
mol.build()

# HF and CIS calculations
mf = scf.RHF(mol)
mf.kernel()
mytd = tdscf.rhf.CIS(mf)
mytd.nstates = 3
mytd = CIS(mf)
mytd.nstates = 8
e, xy = mytd.kernel()

i, j = 0, 2
xi = xy[i][0] * jnp.sqrt(2.)
xj = xy[j][0] * jnp.sqrt(2.)
# Target excited states I and J (1st and 4th)
stateI, stateJ = 0, 2
# CI coefficients of state I
xi = xy[stateI][0] * numpy.sqrt(2.)
nmo = mf.mo_coeff.shape[-1]
nocc = mol.nelectron // 2

# Using Hellman-Feynman formalism.
# The amplitude is closed over, so there is no tracing through the Davidson iteration.
def hellman(mol):
mf = scf.RHF(mol)
mf.kernel()
mytd = tdscf.rhf.CIS(mf)
mytd.nstates = 3
def ovlp(mol1):
mf1 = scf.RHF(mol1)
mf1.kernel()
mytd1 = CIS(mf1)
mytd1.nstates = 8
_, xy1 = mytd1.kernel()
# CI coefficients of state J
xj = xy1[stateJ][0] * numpy.sqrt(2.)
# CIS wavefunction overlap
s = cis_ovlp(mol, mol1, mf.mo_coeff, mf1.mo_coeff,
nocc, nocc, nmo, nmo, xi, xj)
return s

vind, _ = mytd.gen_vind(mytd._scf)
e = jnp.dot(xi.ravel(), vind(xj).ravel())
return e

nac = jax.grad(hellman)(mol)
print(nac.coords / (e[j]-e[i]))
# Only the ket state is differentiated
mol1 = mol.copy()
jac = jax.jacrev(ovlp)(mol1)
print("CIS derivative coupling:")
print(jac.coords)
34 changes: 34 additions & 0 deletions examples/tdscf/11-cis_nac_hellman.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
import jax
from jax import numpy as jnp
from pyscfad import gto, scf, tdscf

mol = gto.Mole()
mol.atom = 'H 0 0 0; H 0 0 1.1'
mol.basis = 'ccpvtz'
mol.verbose = 4
mol.build(trace_exp=False, trace_ctr_coeff=False)

mf = scf.RHF(mol)
mf.kernel()
mytd = tdscf.rhf.CIS(mf)
mytd.nstates = 3
e, xy = mytd.kernel()

i, j = 0, 2
xi = xy[i][0] * jnp.sqrt(2.)
xj = xy[j][0] * jnp.sqrt(2.)

# Using Hellman-Feynman formalism.
# The amplitude is closed over, so there is no tracing through the Davidson iteration.
def hellman(mol):
mf = scf.RHF(mol)
mf.kernel()
mytd = tdscf.rhf.CIS(mf)
mytd.nstates = 3

vind, _ = mytd.gen_vind(mytd._scf)
e = jnp.dot(xi.ravel(), vind(xj).ravel())
return e

nac = jax.grad(hellman)(mol)
print(nac.coords / (e[j]-e[i]))

0 comments on commit f69e185

Please sign in to comment.