Skip to content

Commit

Permalink
cleaned code
Browse files Browse the repository at this point in the history
  • Loading branch information
mihaipgc committed Oct 11, 2023
1 parent 2f5be9c commit d3d716a
Showing 1 changed file with 26 additions and 72 deletions.
98 changes: 26 additions & 72 deletions pyscf_ipu/nanoDFT/sparse_symmetric_intor_ERI.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ def get_shapes(input_ijkl, bas):

return len, nf

def compute_diff_jk(dm, mol, nprog, nbatch, tolerance, backend):
def compute_diff_jk(dm, mol, nbatch, tolerance, backend):
dm = dm.reshape(-1)
diff_JK = jnp.zeros(dm.shape)
N = int(np.sqrt(dm.shape[0]))
Expand Down Expand Up @@ -392,91 +392,45 @@ def compute_diff_jk(dm, mol, nprog, nbatch, tolerance, backend):

num_shells, shell_size = eri.shape # save original tensor shape

def compute_full_shell_idx(idx):
comp_distinct_idx_list = []
for ind in range(eri.shape[0]):
i, j, k, l = [idx[ind, z] for z in range(4)]
_di, _dj, _dk, _dl = ao_loc[i+1] - ao_loc[i], ao_loc[j+1] - ao_loc[j], ao_loc[k+1] - ao_loc[k], ao_loc[l+1] - ao_loc[l]
_i0, _j0, _k0, _l0 = ao_loc[i], ao_loc[j], ao_loc[k], ao_loc[l]
block_idx = np.mgrid[
_i0:(_i0+_di),
_j0:(_j0+_dj),
_k0:(_k0+_dk),
_l0:(_l0+_dl)].transpose(4, 3, 2, 1, 0) #.astype(np.int16)

comp_distinct_idx_list.append(block_idx.reshape(-1, 4))
comp_distinct_idx = np.concatenate(comp_distinct_idx_list)
return comp_distinct_idx

comp_distinct_idx = compute_full_shell_idx(idx)

ijkl_arr = np.sum([np.prod(np.array(a).shape) for a in input_ijkl])
print('input_ijkl.nbytes/1e6', ijkl_arr*2/1e6)
print('comp_distinct_idx.nbytes/1e6', comp_distinct_idx.astype(np.int16).nbytes/1e6)

remainder = (eri.shape[0]) % (nprog*nbatch)
remainder = (eri.shape[0]) % (nbatch)

# unused for nipu==batches==1
# pad tensors; unused for nipu==batches==1
if remainder != 0:
print('padding', remainder, nprog*nbatch-remainder, comp_distinct_idx.shape)
comp_distinct_idx = np.pad(comp_distinct_idx.reshape(-1, shell_size, 4), ((0, (nprog*nbatch-remainder)), (0, 0), (0, 0))).reshape(-1, 4)
eri = jnp.pad(eri, ((0, nprog*nbatch-remainder), (0, 0)))
idx = jnp.pad(idx, ((0, nprog*nbatch-remainder), (0, 0)))
eri = jnp.pad(eri, ((0, nbatch-remainder), (0, 0)))
idx = jnp.pad(idx, ((0, nbatch-remainder), (0, 0)))

comp_distinct_ERI = eri.reshape(nprog, nbatch, -1)
comp_distinct_idx = comp_distinct_idx.reshape(nprog, nbatch, -1, 4)
idx = idx.reshape(nprog, nbatch, -1, 4)

nonzero_distinct_ERI = eri.reshape(nbatch, -1)
nonzero_indices = idx.reshape(nbatch, -1, 4)

# nonzero_distinct_ERI, nonzero_indices, dm, backend = comp_distinct_ERI[0], comp_distinct_idx[0], dm, backend
nonzero_distinct_ERI, nonzero_indices, dm, backend = comp_distinct_ERI[0], idx[0], dm, backend


dm = dm.reshape(-1)
diff_JK = jnp.zeros(dm.shape)
N = int(np.sqrt(dm.shape[0]))

def foreach_batch(i, vals):
diff_JK, nonzero_indices, ao_loc = vals

eris = nonzero_distinct_ERI[i].reshape(-1)

if False:

indices = nonzero_indices[i]
# # indices = jax.lax.bitcast_convert_type(indices, np.int16).astype(np.int32)
indices = indices.astype(jnp.int32)

print('eris.shape', eris.shape)
print('indices.shape', indices.shape)

else:
# Compute offsets and sizes
idx = nonzero_indices[i]
_i, _j, _k, _l = [idx[:, z] for z in range(4)]
_di, _dj, _dk, _dl = [(ao_loc[z+1] - ao_loc[z]).reshape(-1, 1) for z in [_i, _j, _k, _l]]
_i0, _j0, _k0, _l0 = [ao_loc[z].reshape(-1, 1) for z in [_i, _j, _k, _l]]

def gen_shell_idx(idx_sh):
idx_sh = idx_sh.reshape(-1, shell_size)
# Compute the indices
ind_i = (idx_sh ) % _di + _i0
ind_j = (idx_sh // (_di) ) % _dj + _j0
ind_k = (idx_sh // (_di*_dj) ) % _dk + _k0
ind_l = (idx_sh // (_di*_dj*_dk)) % _dl + _l0
print('>>', ind_i.shape)
# Update the array with the computed indices
return jnp.stack([ind_i.reshape(-1), ind_j.reshape(-1), ind_k.reshape(-1), ind_l.reshape(-1)], axis=1)

indices = gen_shell_idx(jnp.arange((eris.shape[0]))) # <<<<<<<<<<<<<<<<<<<<<<<<<

print('eris.shape', eris.shape)
print('indices.shape', indices.shape)
# Compute offsets and sizes
batch_idx = nonzero_indices[i]
_i, _j, _k, _l = [batch_idx[:, z] for z in range(4)]
_di, _dj, _dk, _dl = [(ao_loc[z+1] - ao_loc[z]).reshape(-1, 1) for z in [_i, _j, _k, _l]]
_i0, _j0, _k0, _l0 = [ao_loc[z].reshape(-1, 1) for z in [_i, _j, _k, _l]]

def gen_shell_idx(idx_sh):
# Compute the indices
ind_i = (idx_sh ) % _di + _i0
ind_j = (idx_sh // (_di) ) % _dj + _j0
ind_k = (idx_sh // (_di*_dj) ) % _dk + _k0
ind_l = (idx_sh // (_di*_dj*_dk)) % _dl + _l0

# Update the array with the computed indices
return jnp.stack([ind_i.reshape(-1), ind_j.reshape(-1), ind_k.reshape(-1), ind_l.reshape(-1)], axis=1)

eris = nonzero_distinct_ERI[i].reshape(-1)
indices = gen_shell_idx(jnp.arange((eris.shape[0])).reshape(-1, shell_size))

# compute repetitions caused by 8x symmetry when computing from the distinct_ERI form and scale accordingly
drep = num_repetitions_fast_4d(indices[:, 0], indices[:, 1], indices[:, 2], indices[:, 3], xnp=jnp, dtype=jnp.uint32)
eris = eris / drep


def foreach_symmetry(sym, vals):
# Generalized J/K computation: does J when symmetry is in range(0,8) and K when symmetry is in range(8,16)
Expand Down Expand Up @@ -570,7 +524,7 @@ def foreach_symmetry(sym, vals):

# ------------------------------------ #

diff_JK = jax.jit(compute_diff_jk, backend=backend, static_argnames=['mol', 'nprog', 'nbatch', 'tolerance', 'backend'])(dm, mol, args.nipu, args.batches, args.itol, args.backend)
diff_JK = jax.jit(compute_diff_jk, backend=backend, static_argnames=['mol', 'nbatch', 'tolerance', 'backend'])(dm, mol, args.batches, args.itol, args.backend)

# ------------------------------------ #

Expand Down

0 comments on commit d3d716a

Please sign in to comment.