@@ -268,6 +268,27 @@ def get_jk(dfobj, dm, hermi=0, with_j=True, with_k=True, direct_scf_tol=1e-13):
268
268
# uses numpy.matmul
269
269
vj += dmtril .dot (eri1 .T ).dot (eri1 )
270
270
271
+ elif dms .dtype != numpy .float64 :
272
+ if with_j :
273
+ vj = numpy .zeros_like (dms )
274
+ max_memory = dfobj .max_memory - lib .current_memory ()[0 ]
275
+ blksize = max (4 , int (min (dfobj .blockdim , max_memory * .22e6 / 8 / nao ** 2 )))
276
+ buf = numpy .empty ((blksize ,nao ,nao ))
277
+ buf1 = numpy .empty ((nao ,blksize ,nao ))
278
+ for eri1 in dfobj .loop (blksize ):
279
+ naux , nao_pair = eri1 .shape
280
+ eri1 = lib .unpack_tril (eri1 , out = buf )
281
+ if with_j :
282
+ tmp = numpy .einsum ('pij,nji->pn' , eri1 , dms )
283
+ vj += numpy .einsum ('pn,pij->nij' , tmp , eri1 )
284
+ buf2 = numpy .ndarray ((nao ,naux ,nao ), buffer = buf1 )
285
+ for k in range (nset ):
286
+ buf2 [:] = lib .einsum ('pij,jk->ipk' , eri1 , dms [k ].real )
287
+ vk [k ].real += lib .einsum ('ipk,pkj->ij' , buf2 , eri1 )
288
+ buf2 [:] = lib .einsum ('pij,jk->ipk' , eri1 , dms [k ].imag )
289
+ vk [k ].imag += lib .einsum ('ipk,pkj->ij' , buf2 , eri1 )
290
+ t1 = log .timer_debug1 ('jk' , * t1 )
291
+
271
292
elif getattr (dm , 'mo_coeff' , None ) is not None :
272
293
#TODO: test whether dm.mo_coeff matching dm
273
294
mo_coeff = numpy .asarray (dm .mo_coeff , order = 'F' )
@@ -322,6 +343,7 @@ def get_jk(dfobj, dm, hermi=0, with_j=True, with_k=True, direct_scf_tol=1e-13):
322
343
buf = numpy .empty ((2 ,blksize ,nao ,nao ))
323
344
for eri1 in dfobj .loop (blksize ):
324
345
naux , nao_pair = eri1 .shape
346
+ assert (nao_pair == nao * (nao + 1 )// 2 )
325
347
if with_j :
326
348
# uses numpy.matmul
327
349
vj += dmtril .dot (eri1 .T ).dot (eri1 )
@@ -338,8 +360,12 @@ def get_jk(dfobj, dm, hermi=0, with_j=True, with_k=True, direct_scf_tol=1e-13):
338
360
vk [k ] += lib .dot (buf1 .reshape (- 1 ,nao ).T , buf2 .reshape (- 1 ,nao ))
339
361
t1 = log .timer_debug1 ('jk' , * t1 )
340
362
341
- if with_j : vj = lib .unpack_tril (vj , 1 ).reshape (dm_shape )
342
- if with_k : vk = vk .reshape (dm_shape )
363
+ if with_j :
364
+ if dms .dtype == numpy .float64 :
365
+ vj = lib .unpack_tril (vj , 1 )
366
+ vj = vj .reshape (dm_shape )
367
+ if with_k :
368
+ vk = vk .reshape (dm_shape )
343
369
logger .timer (dfobj , 'df vj and vk' , * t0 )
344
370
return vj , vk
345
371
@@ -348,6 +374,7 @@ def get_j(dfobj, dm, hermi=0, direct_scf_tol=1e-13):
348
374
from pyscf .scf import jk
349
375
from pyscf .df import addons
350
376
t0 = t1 = (logger .process_clock (), logger .perf_counter ())
377
+ assert dm .dtype == numpy .float64
351
378
352
379
mol = dfobj .mol
353
380
if dfobj ._vjopt is None :
0 commit comments