Skip to content

Commit

Permalink
Added comments
Browse files Browse the repository at this point in the history
  • Loading branch information
alexandermath committed Oct 4, 2023
1 parent 6fb7df0 commit 441d206
Showing 1 changed file with 3 additions and 0 deletions.
3 changes: 3 additions & 0 deletions pyscf_ipu/nanoDFT/sparse_symmetric_intor_ERI.py
Original file line number Diff line number Diff line change
Expand Up @@ -680,6 +680,7 @@ def compute_diff_jk(dm, mol, nprog, nbatch, tolerance, backend):
print('[a.shape for a in all_eris]', [a.shape for a in all_eris])
print('[a.shape for a in all_indices]', [a.shape for a in all_indices])


# go from our memory layout to mol.intor("int2e_sph", "s8")
for zip_counter, (eri, idx) in enumerate(zip(all_eris, all_indices)):
print(eri.shape)
Expand Down Expand Up @@ -726,6 +727,7 @@ def ijkl2c(i, j, k, l):
comp_do = np.pad(comp_do, ((0, nprog*nbatch-remainder)))
all_eris.append(jnp.zeros((nprog*nbatch-remainder), dtype=jnp.float32))

# output of mol.intor("int2e_ssph", aosym="s8")
comp_distinct_ERI = jnp.concatenate([eri.reshape(-1) for eri in all_eris]).reshape(nprog, nbatch, -1)
comp_distinct_idx = comp_distinct_idx.reshape(nprog, nbatch, -1, 4)
comp_do = comp_do.reshape(nprog, nbatch, -1)
Expand All @@ -740,6 +742,7 @@ def ijkl2c(i, j, k, l):

# int16 storage supported but not slicing; use conversion trick to enable slicing
comp_distinct_idx = jax.lax.bitcast_convert_type(comp_distinct_idx, jnp.float16)
# reduce this from |eri_floats| to num_calls*4 ~ perhaps 10x smaller

#diff_JK = jax.pmap(sparse_symmetric_einsum, in_axes=(0,0,None,None), static_broadcasted_argnums=(3,), backend=backend, axis_name="p")(comp_distinct_ERI, comp_distinct_idx, dm, backend)
#diff_JK = sparse_symmetric_einsum(comp_distinct_ERI[0], comp_distinct_idx[0], dm, backend)
Expand Down

0 comments on commit 441d206

Please sign in to comment.