Skip to content

Commit

Permalink
prepare for release
Browse files Browse the repository at this point in the history
  • Loading branch information
fishjojo committed Mar 3, 2024
1 parent 5b1659c commit 36a800a
Show file tree
Hide file tree
Showing 68 changed files with 1,383 additions and 318 deletions.
3 changes: 1 addition & 2 deletions examples/dft/00-simple.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
import pyscf
from pyscfad import gto, dft

"""
Analytic nuclear gradient for RKS computed by auto-differentiation
"""

mol = gto.Mole()
mol.atom = 'H 0 0 0; H 0 0 0.74' # in Angstrom
mol.atom = 'H 0 0 0; H 0 0 0.74'
mol.basis = '631g'
mol.verbose=5
mol.build()
Expand Down
2 changes: 1 addition & 1 deletion examples/mp/10-oomp2.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from scipy.optimize import minimize
from jax import value_and_grad
from pyscf import numpy as np
from jax import numpy as np
from pyscfad import util
from pyscfad.tools import rotate_mo1
from pyscfad import gto, scf, mp
Expand Down
2 changes: 1 addition & 1 deletion pyscfad/ao2mo/_ao2mo.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from pyscf import numpy as np
from jax import numpy as np
from pyscfad import config
from ._ao2mo_opt import nr_e2 as nr_e2_opt

Expand Down
2 changes: 1 addition & 1 deletion pyscfad/ao2mo/incore.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from pyscf import numpy as np
from jax import numpy as np
from pyscf.ao2mo import incore
from pyscf.ao2mo.incore import iden_coeffs
from pyscfad import lib
Expand Down
5 changes: 2 additions & 3 deletions pyscfad/cc/ccsd.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
from functools import reduce
import numpy
from pyscf import numpy as np
from pyscf.lib import logger
from jax import numpy as np
from pyscf.cc import ccsd as pyscf_ccsd
from pyscf.mp.mp2 import _mo_without_core
from pyscfad import lib
from pyscfad.lib import ops
from pyscfad.lib import ops, logger
#from pyscfad.lib import jit
#from pyscfad import util
from pyscfad import config
Expand Down
2 changes: 1 addition & 1 deletion pyscfad/cc/ccsd_rdm.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from pyscf import numpy as np
from jax import numpy as np

def _make_rdm1(mycc, d1, with_frozen=True, ao_repr=False, with_mf=True):
doo, dov, dvo, dvv = d1
Expand Down
3 changes: 1 addition & 2 deletions pyscfad/cc/ccsd_t.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,13 @@
from jax import custom_vjp
from jax.tree_util import tree_flatten, tree_unflatten
from pyscf.lib import (
logger,
prange_tril,
num_threads,
current_memory,
# load_library
)
from pyscf.cc import ccsd_t as pyscf_ccsd_t

from pyscfad.lib import logger
#libcc = load_library('libcc')
from pyscfadlib import libcc_vjp as libcc

Expand Down
5 changes: 2 additions & 3 deletions pyscfad/cc/ccsd_t_slow.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
'''
CCSD(T)
'''
from pyscf import numpy as np
from pyscf.lib import logger
from jax import numpy as np
from pyscfad import config, config_update
from pyscfad import lib
from pyscfad.lib import jit, vmap
from pyscfad.lib import logger, jit, vmap
from pyscfad.implicit_diff import make_implicit_diff
from pyscfad.tools.linear_solver import gen_gmres

Expand Down
2 changes: 1 addition & 1 deletion pyscfad/cc/dfccsd.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from pyscf import numpy as np
from jax import numpy as np
from pyscf.lib import square_mat_in_trilu_indices
from pyscfad import util
from pyscfad import lib
Expand Down
5 changes: 2 additions & 3 deletions pyscfad/cc/rccsd.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
from pyscf import numpy as np
from pyscf.lib import logger
from jax import numpy as np
from pyscf.lib import current_memory
from pyscfad import util
from pyscfad.lib import jit
from pyscfad.lib import logger, jit
from pyscfad import ao2mo
from pyscfad.cc import ccsd
from pyscfad.cc import rintermediates as imd
Expand Down
2 changes: 1 addition & 1 deletion pyscfad/cc/rintermediates.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
'''
Intermediates for restricted CCSD. Complex integrals are supported.
'''
from pyscf import numpy as np
from jax import numpy as np
#from pyscfad.lib import jit

# This is restricted (R)CCSD
Expand Down
2 changes: 1 addition & 1 deletion pyscfad/df/addons.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from pyscf import numpy as np
from jax import numpy as np
from pyscf.df import addons as pyscf_addons
from pyscfad import lib
from pyscfad import ao2mo
Expand Down
3 changes: 2 additions & 1 deletion pyscfad/df/df.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@
class DF(pyscf_df.DF):
# pylint: disable=redefined-outer-name
def __init__(self, mol, auxbasis=None, incore=True, **kwargs):
pyscf_df.DF.__init__(self, mol, auxbasis=auxbasis, incore=incore)
pyscf_df.DF.__init__(self, mol, auxbasis=auxbasis)
self.incore = incore
self.__dict__.update(kwargs)

def build(self):
Expand Down
2 changes: 1 addition & 1 deletion pyscfad/df/df_jk.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from pyscf import numpy as np
from jax import numpy as np
from pyscfad import config
from .addons import restore
from ._df_jk_opt import get_jk as get_jk_opt
Expand Down
2 changes: 1 addition & 1 deletion pyscfad/df/incore.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from functools import partial
from jax import scipy
from jax import custom_jvp
from pyscf import numpy as np
from jax import numpy as np
from pyscf.lib import logger
from pyscf import gto
from pyscf.df.outcore import _guess_shell_ranges
Expand Down
2 changes: 1 addition & 1 deletion pyscfad/dft/libxc.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from functools import partial
from jax import numpy as np
from jax import jit, custom_jvp
from pyscf.dft import libxc
from pyscf.dft.libxc import parse_xc, is_lda, is_meta_gga
from pyscfad.lib import numpy as np

def eval_xc(xc_code, rho, spin=0, relativity=0, deriv=1, omega=None, verbose=None):
# NOTE only consider exc and vxc
Expand Down
2 changes: 1 addition & 1 deletion pyscfad/dft/numint.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import warnings
from functools import partial
import numpy
from pyscf import numpy as np
from jax import numpy as np
from pyscf.lib import load_library
from pyscf.dft import numint
from pyscf.dft.numint import SWITCH_SIZE
Expand Down
10 changes: 9 additions & 1 deletion pyscfad/dft/rks.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
import numpy
from jax import numpy as np
from pyscf import __config__
from pyscf.lib import current_memory
from pyscf.lib import logger
from pyscf.dft import rks as pyscf_rks
from pyscf.dft import gen_grid
from pyscfad import util
from pyscfad.lib import numpy as np
from pyscfad.lib import stop_grad
from pyscfad.scf import hf
from pyscfad.dft import numint
Expand Down Expand Up @@ -164,6 +164,14 @@ class KohnShamDFT(pyscf_rks.KohnShamDFT):
__init__ = _dft_common_init_
__post_init__ = _dft_common_post_init_

def reset(self, mol=None):
hf.SCF.reset(self, mol)
if getattr(self, 'grids', None) is not None:
self.grids.reset(mol)
if getattr(self, 'nlcgrids', None) is not None:
self.nlcgrids.reset(mol)
return self

@util.pytree_node(hf.Traced_Attributes, num_args=1)
class RKS(KohnShamDFT, hf.RHF):
def __init__(self, mol, xc='LDA,VWN', **kwargs):
Expand Down
2 changes: 1 addition & 1 deletion pyscfad/dft/test/test_rks_hess.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,5 +29,5 @@ def test_rks_nuc_hess_gga(get_mol):

def test_rks_nuc_hess_gga_hybrid(get_mol):
mol = get_mol
hess = jacrev(jacrev(energy))(mol, 'b3lyp').coords.coords
hess = jacrev(jacrev(energy))(mol, 'b3lyp5').coords.coords
assert abs(fp(hess) - -0.5114248632559669) < 1e-6
3 changes: 1 addition & 2 deletions pyscfad/fci/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from functools import reduce
from pyscf import numpy as np
#from pyscfad.lib import numpy as np
from jax import numpy as np
from pyscfad import ao2mo
from pyscfad.fci import fci_slow
from pyscfad.fci.fci_slow import fci_ovlp
Expand Down
102 changes: 98 additions & 4 deletions pyscfad/fci/fci_slow.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
from pyscf import numpy as np
import numpy
from jax import numpy as np
from pyscf.fci import cistring
from pyscf.fci import fci_slow as pyscf_fci_slow
from pyscfad.lib import vmap
from pyscfad import lib
from pyscfad.lib import vmap, ops, stop_grad
from pyscfad.lib.linalg_helper import davidson
from pyscfad.gto import mole

def get_occ_loc(strs, norb):
Expand Down Expand Up @@ -92,5 +94,97 @@ def body(mo_ia, mo_ib, ida, idb):
res += ci1[ia,ib] * (val * ci2.ravel()).sum()
return res

def contract_2e(eri, fcivec, norb, nelec, opt=None):
'''Compute E_{pq}E_{rs}|CI>'''
if isinstance(nelec, (int, np.integer)):
nelecb = nelec//2
neleca = nelec - nelecb
else:
neleca, nelecb = nelec
link_indexa = cistring.gen_linkstr_index(range(norb), neleca)
link_indexb = cistring.gen_linkstr_index(range(norb), nelecb)
na = cistring.num_strings(norb, neleca)
nb = cistring.num_strings(norb, nelecb)
ci0 = fcivec.reshape(na,nb)
t1 = np.zeros((norb,norb,na,nb))
for str0, tab in enumerate(link_indexa):
for a, i, str1, sign in tab:
t1 = ops.index_add(t1, ops.index[a,i,str1], sign * ci0[str0])
for str0, tab in enumerate(link_indexb):
for a, i, str1, sign in tab:
t1 = ops.index_add(t1, ops.index[a,i,:,str1], sign * ci0[:,str0])

t1 = np.einsum('bjai,aiAB->bjAB', eri.reshape([norb]*4), t1)

fcinew = np.zeros_like(ci0)
for str0, tab in enumerate(link_indexa):
for a, i, str1, sign in tab:
fcinew = ops.index_add(fcinew, ops.index[str1], sign * t1[a,i,str0])
for str0, tab in enumerate(link_indexb):
for a, i, str1, sign in tab:
fcinew = ops.index_add(fcinew, ops.index[:,str1], sign * t1[a,i,:,str0])
return fcinew.reshape(fcivec.shape)


def absorb_h1e(h1e, eri, norb, nelec, fac=1):
if not isinstance(nelec, (int, np.integer)):
nelec = sum(nelec)
if eri.size != norb**4:
h2e = ao2mo.restore(1, eri.copy(), norb)
else:
h2e = eri.copy().reshape(norb,norb,norb,norb)
f1e = h1e - np.einsum('jiik->jk', h2e) * .5
f1e = f1e * (1./(nelec+1e-100))
for k in range(norb):
h2e = ops.index_add(h2e, ops.index[k,k,:,:], f1e)
h2e = ops.index_add(h2e, ops.index[:,:,k,k], f1e)
return h2e * fac

def make_hdiag(h1e, eri, norb, nelec, opt=None):
if isinstance(nelec, (int, np.integer)):
nelecb = nelec//2
neleca = nelec - nelecb
else:
neleca, nelecb = nelec

occslista = cistring.gen_occslst(range(norb), neleca)
occslistb = cistring.gen_occslst(range(norb), nelecb)
if eri.size != norb**4:
eri = ao2mo.restore(1, eri, norb)
else:
eri = eri.reshape(norb,norb,norb,norb)
diagj = np.einsum('iijj->ij', eri)
diagk = np.einsum('ijji->ij', eri)
hdiag = []
for aocc in occslista:
for bocc in occslistb:
e1 = h1e[aocc,aocc].sum() + h1e[bocc,bocc].sum()
e2 = diagj[aocc][:,aocc].sum() + diagj[aocc][:,bocc].sum() \
+ diagj[bocc][:,aocc].sum() + diagj[bocc][:,bocc].sum() \
- diagk[aocc][:,aocc].sum() - diagk[bocc][:,bocc].sum()
hdiag.append(e1 + e2*.5)
return np.array(hdiag)

def kernel(h1e, eri, norb, nelec, ecore=0, nroots=1):
h2e = absorb_h1e(h1e, eri, norb, nelec, .5)
na = cistring.num_strings(norb, nelec//2)

hdiag = make_hdiag(h1e, eri, norb, nelec)
try:
from pyscf.fci.direct_spin1 import pspace
addrs, h0 = pspace(stop_grad(h1e), stop_grad(eri),
norb, nelec, stop_grad(hdiag), nroots)
except:
addrs = numpy.argsort(hdiag)[:nroots]
ci0 = []
for addr in addrs:
x = numpy.zeros((na*na))
x[addr] = 1.
ci0.append(x.ravel())

kernel = pyscf_fci_slow.kernel
def hop(c):
hc = contract_2e(h2e, c, norb, nelec)
return hc.ravel()
precond = lambda x, e, *args: x/(hdiag-e+1e-4)
e, c = davidson(hop, ci0, precond, nroots=nroots)
return e+ecore, c
4 changes: 2 additions & 2 deletions pyscfad/fci/test/test_fci.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from functools import reduce
import pytest
import jax
from jax import numpy as np
from pyscfad import gto, scf, ao2mo
from pyscfad.lib import numpy as np
from pyscfad.fci import fci_slow

@pytest.fixture()
Expand Down Expand Up @@ -31,6 +31,6 @@ def test_nuc_grad(get_h2):
e = fci_energy(mol)
assert abs(e - -1.1372838344885026) < 1e-9

g = jax.jacrev(fci_energy)(mol).coords
g = jax.grad(fci_energy)(mol).coords
assert abs(g[0,2] - -0.00455429) < 1e-6
assert abs(g.sum()) < 1e-6
22 changes: 18 additions & 4 deletions pyscfad/gto/_moleintor_jvp.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from functools import partial
import numpy
from pyscf import numpy as np
from jax import numpy as np
from pyscf import ao2mo
from pyscf.gto import mole as pyscf_mole
from pyscf.gto import ATOM_OF
from pyscf.gto.moleintor import _get_intor_and_comp
from pyscfad.lib import ops, custom_jvp, jit, vmap
from ._mole_helper import (
Expand Down Expand Up @@ -331,12 +332,25 @@ def _int1e_nuc_jvp_rc(mol, mol_t, intor):
coords_t = mol_t.coords
atmlst = range(mol.natm)
nao = mol.nao
grad = np.zeros((mol.natm,3,nao,nao), dtype=float)

ecp_intor = 'ECP' in intor
if ecp_intor:
if not mol.has_ecp():
return 0
else:
ecp_atoms = set(mol._ecpbas[:,ATOM_OF])

grad = np.zeros((mol.natm,3,nao,nao))
for k, ia in enumerate(atmlst):
with mol.with_rinv_at_nucleus(ia):
vrinv = getints2c_rc(mol, intor, comp=3, rc_deriv=ia)
if 'ECP' not in intor:
if not ecp_intor:
vrinv = getints2c_rc(mol, intor, comp=3, rc_deriv=ia)
vrinv *= -mol.atom_charge(ia)
else:
if ia in ecp_atoms:
vrinv = getints2c_rc(mol, intor, comp=3, rc_deriv=ia)
else:
vrinv = 0
grad = ops.index_update(grad, ops.index[k], vrinv)
tangent_out = _int1e_dot_grad_tangent_r0(grad, coords_t)
return tangent_out
Expand Down
2 changes: 1 addition & 1 deletion pyscfad/gto/eval_gto.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from functools import partial
import numpy
#import jax
from pyscf import numpy as np
from jax import numpy as np
from pyscf.gto import mole as pyscf_mole
from pyscf.gto.moleintor import make_loc
from pyscf.gto.eval_gto import _get_intor_and_comp
Expand Down
2 changes: 1 addition & 1 deletion pyscfad/gto/mole.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from functools import wraps
from pyscf import numpy as np
from jax import numpy as np
from pyscf.gto import mole as pyscf_mole
from pyscf.lib import logger, param
from pyscfad import util
Expand Down
Loading

0 comments on commit 36a800a

Please sign in to comment.