Skip to content

Commit 1cd7be4

Browse files
committed
Complex dms for DFHF (fix pyscf#2670)
1 parent ccedc56 commit 1cd7be4

File tree

2 files changed

+41
-2
lines changed

2 files changed

+41
-2
lines changed

pyscf/df/df_jk.py

+29-2
Original file line numberDiff line numberDiff line change
@@ -268,6 +268,27 @@ def get_jk(dfobj, dm, hermi=0, with_j=True, with_k=True, direct_scf_tol=1e-13):
268268
# uses numpy.matmul
269269
vj += dmtril.dot(eri1.T).dot(eri1)
270270

271+
elif dms.dtype != numpy.float64:
272+
if with_j:
273+
vj = numpy.zeros_like(dms)
274+
max_memory = dfobj.max_memory - lib.current_memory()[0]
275+
blksize = max(4, int(min(dfobj.blockdim, max_memory*.22e6/8/nao**2)))
276+
buf = numpy.empty((blksize,nao,nao))
277+
buf1 = numpy.empty((nao,blksize,nao))
278+
for eri1 in dfobj.loop(blksize):
279+
naux, nao_pair = eri1.shape
280+
eri1 = lib.unpack_tril(eri1, out=buf)
281+
if with_j:
282+
tmp = numpy.einsum('pij,nji->pn', eri1, dms)
283+
vj += numpy.einsum('pn,pij->nij', tmp, eri1)
284+
buf2 = numpy.ndarray((nao,naux,nao), buffer=buf1)
285+
for k in range(nset):
286+
buf2[:] = lib.einsum('pij,jk->ipk', eri1, dms[k].real)
287+
vk[k].real += lib.einsum('ipk,pkj->ij', buf2, eri1)
288+
buf2[:] = lib.einsum('pij,jk->ipk', eri1, dms[k].imag)
289+
vk[k].imag += lib.einsum('ipk,pkj->ij', buf2, eri1)
290+
t1 = log.timer_debug1('jk', *t1)
291+
271292
elif getattr(dm, 'mo_coeff', None) is not None:
272293
#TODO: test whether dm.mo_coeff matching dm
273294
mo_coeff = numpy.asarray(dm.mo_coeff, order='F')
@@ -322,6 +343,7 @@ def get_jk(dfobj, dm, hermi=0, with_j=True, with_k=True, direct_scf_tol=1e-13):
322343
buf = numpy.empty((2,blksize,nao,nao))
323344
for eri1 in dfobj.loop(blksize):
324345
naux, nao_pair = eri1.shape
346+
assert (nao_pair == nao*(nao+1)//2)
325347
if with_j:
326348
# uses numpy.matmul
327349
vj += dmtril.dot(eri1.T).dot(eri1)
@@ -338,8 +360,12 @@ def get_jk(dfobj, dm, hermi=0, with_j=True, with_k=True, direct_scf_tol=1e-13):
338360
vk[k] += lib.dot(buf1.reshape(-1,nao).T, buf2.reshape(-1,nao))
339361
t1 = log.timer_debug1('jk', *t1)
340362

341-
if with_j: vj = lib.unpack_tril(vj, 1).reshape(dm_shape)
342-
if with_k: vk = vk.reshape(dm_shape)
363+
if with_j:
364+
if dms.dtype == numpy.float64:
365+
vj = lib.unpack_tril(vj, 1)
366+
vj = vj.reshape(dm_shape)
367+
if with_k:
368+
vk = vk.reshape(dm_shape)
343369
logger.timer(dfobj, 'df vj and vk', *t0)
344370
return vj, vk
345371

@@ -348,6 +374,7 @@ def get_j(dfobj, dm, hermi=0, direct_scf_tol=1e-13):
348374
from pyscf.scf import jk
349375
from pyscf.df import addons
350376
t0 = t1 = (logger.process_clock(), logger.perf_counter())
377+
assert dm.dtype == numpy.float64
351378

352379
mol = dfobj.mol
353380
if dfobj._vjopt is None:

pyscf/df/test/test_df_jk.py

+12
Original file line numberDiff line numberDiff line change
@@ -194,6 +194,18 @@ def test_get_j(self):
194194
self.assertAlmostEqual(abs(vj0-vj1).max(), 0, 12)
195195
self.assertAlmostEqual(lib.fp(vj0), -194.15910890730052, 9)
196196

197+
def test_df_jk_complex_dm(self):
198+
mol = gto.M(atom='H 0 0 0; H 0 0 1')
199+
mf = mol.RHF().run()
200+
dm = mf.make_rdm1() + 0j
201+
dm[0,:] += .1j
202+
dm[:,0] -= .1j
203+
mf.kernel(dm)
204+
self.assertTrue(mf.mo_coeff.dtype == numpy.complex128)
205+
dfmf = mf.density_fit()
206+
self.assertAlmostEqual(dfmf.energy_tot(), -1.0661355663696201, 9)
207+
self.assertAlmostEqual(dfmf.energy_tot(), mf.e_tot, 3)
208+
197209

198210
if __name__ == "__main__":
199211
print("Full Tests for df")

0 commit comments

Comments
 (0)