From 99a9d161ccc63e0bda97d287a52fd6f234063a6e Mon Sep 17 00:00:00 2001 From: Xing Zhang Date: Mon, 15 Apr 2024 18:22:33 -0700 Subject: [PATCH] Fix LRC-RKS (#24) --- pyscfad/dft/test/test_rks.py | 11 +++++++++++ pyscfad/scf/hf.py | 21 ++++++++++++--------- pyscfad/scf/uhf.py | 16 ---------------- 3 files changed, 23 insertions(+), 25 deletions(-) diff --git a/pyscfad/dft/test/test_rks.py b/pyscfad/dft/test/test_rks.py index dd1a09a0..d1b89c9e 100644 --- a/pyscfad/dft/test/test_rks.py +++ b/pyscfad/dft/test/test_rks.py @@ -63,6 +63,17 @@ def test_rks_nuc_grad_hybrid(get_mol): assert abs(g1-g0).max() < 1e-6 assert abs(g2-g0).max() < 1e-6 +def test_rks_nuc_grad_lrc(get_mol): + mol = get_mol() + mf = dft.RKS(mol) + mf.xc = 'HYB_GGA_XC_LRC_WPBE' + g1 = mf.energy_grad(mode="rev").coords + mf.kernel() + g2 = mf.energy_grad(mode="rev").coords + g0 = mf.nuc_grad_method().kernel() + assert abs(g1-g0).max() < 1e-6 + assert abs(g2-g0).max() < 1e-6 + #FIXME MGGA is broken since pyscf v2.1 def test_rks_nuc_grad_mgga_skip(get_mol, get_mol_p, get_mol_m): mol = get_mol diff --git a/pyscfad/scf/hf.py b/pyscfad/scf/hf.py index 0fba5f42..33313d39 100644 --- a/pyscfad/scf/hf.py +++ b/pyscfad/scf/hf.py @@ -116,10 +116,8 @@ def kernel(mf, conv_tol=1e-10, conv_tol_grad=None, mf.with_df.build() else: if mf._eri is None: - if config.moleintor_opt: - mf._eri = mol.intor('int2e', aosym='s4') - else: - mf._eri = mol.intor('int2e', aosym='s1') + aosym = 's4' if config.moleintor_opt else 's1' + mf._eri = mol.intor('int2e', aosym=aosym) scf_conv = False mo_energy = mo_coeff = mo_occ = None @@ -371,12 +369,17 @@ def get_jk(self, mol=None, dm=None, hermi=1, with_j=True, with_k=True, mol = self.mol if dm is None: dm = self.make_rdm1() + + aosym = 's4' if config.moleintor_opt else 's1' if self._eri is None: - if config.moleintor_opt: - self._eri = self.mol.intor('int2e', aosym='s4') - else: - self._eri = self.mol.intor('int2e', aosym='s1') - vj, vk = dot_eri_dm(self._eri, dm, hermi, with_j, with_k) + self._eri = mol.intor('int2e', aosym=aosym) + if omega: + with mol.with_range_coulomb(omega): + _eri = mol.intor('int2e', aosym=aosym) + else: + _eri = self._eri + + vj, vk = dot_eri_dm(_eri, dm, hermi, with_j, with_k) return vj, vk def get_init_guess(self, mol=None, key='minao'): diff --git a/pyscfad/scf/uhf.py b/pyscfad/scf/uhf.py index 18232a61..596d2952 100644 --- a/pyscfad/scf/uhf.py +++ b/pyscfad/scf/uhf.py @@ -3,7 +3,6 @@ from jax import numpy as np from pyscf.lib import module_method from pyscf.scf import uhf as pyscf_uhf -from pyscfad import config from pyscfad import util from pyscfad.lib import logger, stop_grad from pyscfad.scf import hf @@ -100,21 +99,6 @@ def eig(self, h, s): e_b, c_b = self._eigh(h[1], s) return np.array((e_a,e_b)), np.array((c_a,c_b)) - @wraps(pyscf_uhf.UHF.get_jk) - def get_jk(self, mol=None, dm=None, hermi=1, with_j=True, with_k=True, - omega=None): - if mol is None: - mol = self.mol - if dm is None: - dm = self.make_rdm1() - if self._eri is None: - if config.moleintor_opt: - self._eri = mol.intor('int2e', aosym='s4') - else: - self._eri = mol.intor('int2e', aosym='s1') - vj, vk = hf.dot_eri_dm(self._eri, dm, hermi, with_j, with_k) - return vj, vk - @wraps(pyscf_uhf.UHF.get_veff) def get_veff(self, mol=None, dm=None, dm_last=0, vhf_last=0, hermi=1): if mol is None: