Skip to content

Commit

Permalink
missed timer fixes
Browse files (browse the repository at this point in the history)
  • Loading branch information
cjknight committed Mar 28, 2024
1 parent f4e7432 commit 266b25a
Showing 1 changed file with 11 additions and 4 deletions.
15 changes: 11 additions & 4 deletions gpu/gpu4pyscf/df/df_jk.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,15 +64,19 @@ def get_jk(dfobj, dm, hermi=1, with_j=True, with_k=True, direct_scf_tol=1e-13):
t2 = lib.logger.timer(dfobj, 'get_jk setup', *t0)

if with_j:

t2 = (logger.process_clock(), logger.perf_counter())
idx = numpy.arange(nao)
dmtril = lib.pack_tril(dms + dms.conj().transpose(0,2,1))
dmtril[:,idx*(idx+1)//2+idx] *= .5
t3 = lib.logger.timer(dfobj, 'get_jk with_j',*t2)
t3 = lib.logger.timer(dfobj, 'get_jk with_j',*t2)

if not with_k:
t2 = (logger.process_clock(), logger.perf_counter())
for eri1 in dfobj.loop():
rho = numpy.einsum('ix,px->ip', dmtril, eri1)
vj += numpy.einsum('ip,px->ix', rho, eri1)
t3 = lib.logger.timer(dfobj, 'get_jk not with_k',*t2)

# Commented out on 2-19-2024 in favor of the accelerated implementation below.
# This can be offloaded if the need arises.
Expand Down Expand Up @@ -122,6 +126,8 @@ def get_jk(dfobj, dm, hermi=1, with_j=True, with_k=True, direct_scf_tol=1e-13):
else:
#:vk = numpy.einsum('pij,jk->pki', cderi, dm)
#:vk = numpy.einsum('pki,pkj->ij', cderi, vk)

t2 = (logger.process_clock(), logger.perf_counter())
rargs = (ctypes.c_int(nao), (ctypes.c_int*4)(0, nao, 0, nao),
null, ctypes.c_int(0))
dms = [numpy.asarray(x, order='F') for x in dms]
Expand All @@ -133,7 +139,7 @@ def get_jk(dfobj, dm, hermi=1, with_j=True, with_k=True, direct_scf_tol=1e-13):
count = 0
vj = numpy.zeros_like(dmtril)

t4 = lib.logger.timer(dfobj, 'get_jk with_k setup',*t3)
t3 = lib.logger.timer(dfobj, 'get_jk with_k setup',*t2)
for eri1 in dfobj.loop(blksize): # how much time spent unnecessarily copying eri1 data?
naux, nao_pair = eri1.shape

Expand Down Expand Up @@ -161,14 +167,15 @@ def get_jk(dfobj, dm, hermi=1, with_j=True, with_k=True, direct_scf_tol=1e-13):
count+=1

t1 = log.timer_debug1('jk', *t1)
t5 = lib.logger.timer(dfobj, 'get_jk with_k loop',*t4)
t4 = lib.logger.timer(dfobj, 'get_jk with_k loop',*t3)

if gpu:
libgpu.libgpu_pull_get_jk(gpu, vj, vk, 1)

t2 = (logger.process_clock(), logger.perf_counter())
if with_j: vj = lib.unpack_tril(vj, 1).reshape(dm_shape)
if with_k: vk = vk.reshape(dm_shape)
lib.logger.timer(dfobj, 'get_jk finalize',*t5)
lib.logger.timer(dfobj, 'get_jk finalize',*t2)
logger.timer(dfobj, 'df vj and vk', *t0)
return vj, vk

Expand Down

0 comments on commit 266b25a

Please sign in to comment.