working shell-based indexing with lower memory consumption with batching (shell local index recomputation); still using int32, work in progress
mihaipgc committed Oct 11, 2023
1 parent 9537f4d commit 3d2cc20
Showing 1 changed file with 55 additions and 42 deletions.
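
Editor's note: the core of this change is shell-local index recomputation: rather than materialising an int32 AO index quadruple for every ERI element, only the shell quadruple is kept per block, and the AO-level indices are rebuilt on the fly from the ao_loc offsets with modular arithmetic (see gen_shell_idx in the diff below). A minimal NumPy sketch of that expansion follows; the ao_loc values and the expand_shell_block helper are illustrative and not part of the commit.

import numpy as np

# Illustrative AO offsets per shell: shell s spans AOs ao_loc[s] .. ao_loc[s+1]-1.
ao_loc = np.array([0, 3, 5, 9, 10])

def expand_shell_block(i, j, k, l):
    # AO-block sizes and starting offsets of the four shells.
    di, dj, dk, dl = (ao_loc[s + 1] - ao_loc[s] for s in (i, j, k, l))
    i0, j0, k0, l0 = (ao_loc[s] for s in (i, j, k, l))
    flat = np.arange(di * dj * dk * dl)            # flat position inside the shell block
    ind_i = (flat % di) + i0                       # fastest-varying AO index
    ind_j = (flat // di) % dj + j0
    ind_k = (flat // (di * dj)) % dk + k0
    ind_l = (flat // (di * dj * dk)) % dl + l0     # slowest-varying AO index
    return np.stack([ind_i, ind_j, ind_k, ind_l], axis=1)

print(expand_shell_block(0, 1, 2, 3).shape)        # (3*2*4*1, 4) == (24, 4)
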
97 changes: 55 additions & 42 deletions pyscf_ipu/nanoDFT/sparse_symmetric_intor_ERI.py
@@ -398,7 +398,7 @@ def compute_diff_jk(dm, mol, nprog, nbatch, tolerance, backend):
for zip_counter, (eri, idx) in enumerate(zip(all_eris, all_indices)):
# go from our memory layout to mol.intor("int2e_sph", "s8")

shell_size = eri.shape[-1] # save original tensor shape
num_shells, shell_size = eri.shape # save original tensor shape

comp_distinct_idx_list = []
print(eri.shape)
@@ -421,28 +421,32 @@ def compute_diff_jk(dm, mol, nprog, nbatch, tolerance, backend):
print('input_ijkl.nbytes/1e6', ijkl_arr*2/1e6)
print('comp_distinct_idx.nbytes/1e6', comp_distinct_idx.astype(np.int16).nbytes/1e6)

remainder = comp_distinct_idx.shape[0] % (nprog*nbatch)
remainder = (eri.shape[0]) % (nprog*nbatch)

# unused for nipu==batches==1
if remainder != 0:
print('padding', remainder, nprog*nbatch-remainder, comp_distinct_idx.shape)
comp_distinct_idx = np.pad(comp_distinct_idx, ((0, nprog*nbatch-remainder), (0, 0)))
eri = jnp.pad(eri.reshape(-1), ((0, nprog*nbatch-remainder)))
comp_distinct_idx = np.pad(comp_distinct_idx.reshape(-1, shell_size, 4), ((0, (nprog*nbatch-remainder)), (0, 0), (0, 0))).reshape(-1, 4)
# eri = jnp.pad(eri.reshape(-1), ((0, (nprog*nbatch-remainder))))
eri = jnp.pad(eri, ((0, nprog*nbatch-remainder), (0, 0))).reshape(-1)
idx = jnp.pad(idx, ((0, nprog*nbatch-remainder), (0, 0)))

print('eri.shape', eri.shape)
print('comp_distinct_idx.shape', comp_distinct_idx.shape)
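
# --- Editor's sketch (not part of the commit): the padding step above rounds the
# --- leading shell-block dimension up to a multiple of nprog*nbatch so the data can
# --- later be reshaped to (nprog, nbatch, -1). Shapes below are made up.
import numpy as np
nprog, nbatch = 2, 4
eri_blocks = np.ones((21, 16))                          # 21 shell blocks of 16 elements each
remainder = eri_blocks.shape[0] % (nprog * nbatch)
if remainder != 0:
    eri_blocks = np.pad(eri_blocks, ((0, nprog * nbatch - remainder), (0, 0)))
print(eri_blocks.reshape(nprog, nbatch, -1).shape)      # (2, 4, 48): 24 padded blocks over 2x4
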

# output of mol.intor("int2e_ssph", aosym="s8")
comp_distinct_ERI = eri.reshape(nprog, nbatch, -1) #jnp.concatenate([eri.reshape(-1) for eri in all_eris]).reshape(nprog, nbatch, -1)
# comp_distinct_ERI = eri.reshape(nprog, nbatch, -1, shell_size) #jnp.concatenate([eri.reshape(-1) for eri in all_eris]).reshape(nprog, nbatch, -1)
comp_distinct_ERI = eri.reshape(nprog, nbatch, -1)
comp_distinct_idx = comp_distinct_idx.reshape(nprog, nbatch, -1, 4)
idx = idx.reshape(nprog, nbatch, -1, 4)

print('comp_distinct_ERI.shape', comp_distinct_ERI.shape)
print('comp_distinct_idx.shape', comp_distinct_idx.shape)


# compute repetitions caused by 8x symmetry when computing from the distinct_ERI form and scale accordingly
drep = num_repetitions_fast_4d(comp_distinct_idx[:, :, :, 0], comp_distinct_idx[:, :, :, 1], comp_distinct_idx[:, :, :, 2], comp_distinct_idx[:, :, :, 3], xnp=np, dtype=np.uint32)
comp_distinct_ERI = comp_distinct_ERI / drep
# drep = num_repetitions_fast_4d(comp_distinct_idx[:, :, :, 0], comp_distinct_idx[:, :, :, 1], comp_distinct_idx[:, :, :, 2], comp_distinct_idx[:, :, :, 3], xnp=np, dtype=np.uint32)
# comp_distinct_ERI = comp_distinct_ERI / drep


# int16 storage supported but not slicing; use conversion trick to enable slicing
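
# --- Editor's sketch (not part of the commit): on the int16 remark above, the idea is
# --- to store the indices in 16-bit form to halve their memory footprint and widen them
# --- to int32 only right before they are used for indexing (the commented-out
# --- bitcast_convert_type call below reinterprets a 16-bit pattern before that widening
# --- cast). Values and shapes here are purely illustrative.
import jax.numpy as jnp
import numpy as np
idx16 = jnp.asarray(np.array([[0, 1, 2, 3], [1, 1, 2, 2]], dtype=np.int16))  # compact storage
dm_demo = jnp.arange(16.0).reshape(4, 4)
idx32 = idx16.astype(jnp.int32)                          # widen just-in-time for indexing
print(dm_demo[idx32[:, 0], idx32[:, 1]])                 # gather dm[i, j] per stored quadruple
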
@@ -451,51 +455,60 @@ def compute_diff_jk(dm, mol, nprog, nbatch, tolerance, backend):

#diff_JK = jax.pmap(sparse_symmetric_einsum, in_axes=(0,0,None,None), static_broadcasted_argnums=(3,), backend=backend, axis_name="p")(comp_distinct_ERI, comp_distinct_idx, dm, backend)
#diff_JK = sparse_symmetric_einsum(comp_distinct_ERI[0], comp_distinct_idx[0], dm, backend)
nonzero_distinct_ERI, nonzero_indices, dm, backend = comp_distinct_ERI[0], comp_distinct_idx[0], dm, backend


# nonzero_distinct_ERI, nonzero_indices, dm, backend = comp_distinct_ERI[0], comp_distinct_idx[0], dm, backend
nonzero_distinct_ERI, nonzero_indices, dm, backend = comp_distinct_ERI[0], idx[0], dm, backend


dm = dm.reshape(-1)
diff_JK = jnp.zeros(dm.shape)
N = int(np.sqrt(dm.shape[0]))

def foreach_batch(i, vals):
diff_JK, nonzero_indices, ao_loc = vals

indices = nonzero_indices[i]

# indices = jax.lax.bitcast_convert_type(indices, np.int16).astype(np.int32)
indices = indices.astype(jnp.int32)
eris = nonzero_distinct_ERI[i]
print(indices.shape)

# exp_distinct_idx = jnp.zeros((eris.shape[0], 4), dtype=jnp.int32)

# # exp_distinct_idx_list = []
# print(eris.shape)
# # for ind in range(eris.shape[0]):
# def gen_all_shell_idx(idx, arr):
# i, j, k, l = [indices[ind, z] for z in range(4)]
# _di, _dj, _dk, _dl = ao_loc[i+1] - ao_loc[i], ao_loc[j+1] - ao_loc[j], ao_loc[k+1] - ao_loc[k], ao_loc[l+1] - ao_loc[l]
# _i0, _j0, _k0, _l0 = ao_loc[i], ao_loc[j], ao_loc[k], ao_loc[l]

# block_idx = jnp.zeros((shell_size, 4), dtype=jnp.int32)

# def gen_shell_idx(idx_sh, arr):
# # Compute the indices
# ind_i = (idx_sh ) % _di + _i0
# ind_j = (idx_sh // (_di) ) % _dj + _j0
# ind_k = (idx_sh // (_di*_dj) ) % _dk + _k0
# ind_l = (idx_sh // (_di*_dj*_dk)) % _dl + _l0

# # Update the array with the computed indices
# return arr.at[idx_sh, :].set(jnp.array([ind_i, ind_j, ind_k, ind_l]))
eris = nonzero_distinct_ERI[i].reshape(-1)

if False:

indices = nonzero_indices[i]

# block_idx = jax.lax.fori_loop(0, shell_size, gen_shell_idx, block_idx)

# # return arr.at[idx*shell_size:(idx+1)*shell_size, :].set(block_idx)
# return jax.lax.dynamic_update_slice(arr, block_idx, (idx*shell_size, 0))

# exp_distinct_idx = jax.lax.fori_loop(0, eris.shape[0]//shell_size, gen_all_shell_idx, exp_distinct_idx) #jnp.concatenate(exp_distinct_idx_list)
# # indices = jax.lax.bitcast_convert_type(indices, np.int16).astype(np.int32)
indices = indices.astype(jnp.int32)


print('eris.shape', eris.shape)
print('indices.shape', indices.shape)

else:

# Compute offsets and sizes
idx = nonzero_indices[i]
_i, _j, _k, _l = [idx[:, z] for z in range(4)]
_di, _dj, _dk, _dl = [(ao_loc[z+1] - ao_loc[z]).reshape(-1, 1) for z in [_i, _j, _k, _l]]
_i0, _j0, _k0, _l0 = [ao_loc[z].reshape(-1, 1) for z in [_i, _j, _k, _l]]

def gen_shell_idx(idx_sh):
idx_sh = idx_sh.reshape(-1, shell_size)
# Compute the indices
ind_i = (idx_sh ) % _di + _i0
ind_j = (idx_sh // (_di) ) % _dj + _j0
ind_k = (idx_sh // (_di*_dj) ) % _dk + _k0
ind_l = (idx_sh // (_di*_dj*_dk)) % _dl + _l0
print('>>', ind_i.shape)
# Update the array with the computed indices
return jnp.stack([ind_i.reshape(-1), ind_j.reshape(-1), ind_k.reshape(-1), ind_l.reshape(-1)], axis=1)

indices = gen_shell_idx(jnp.arange((eris.shape[0]))) # <<<<<<<<<<<<<<<<<<<<<<<<<

print('eris.shape', eris.shape)
print('indices.shape', indices.shape)

# indices = exp_distinct_idx
# compute repetitions caused by 8x symmetry when computing from the distinct_ERI form and scale accordingly
drep = num_repetitions_fast_4d(indices[:, 0], indices[:, 1], indices[:, 2], indices[:, 3], xnp=jnp, dtype=jnp.uint32)
eris = eris / drep


def foreach_symmetry(sym, vals):
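Editor's note on the drep scaling used above: the distinct-ERI form exploits the 8-fold permutation symmetry (ij|kl) = (ji|kl) = (ij|lk) = (kl|ij) = ..., so when all eight permutations are later applied to every stored element, entries whose permutations coincide would be over-counted; dividing each value by its repetition count beforehand compensates. Below is a sketch of one way to count those repetitions; the num_repetitions helper is hypothetical, and the actual num_repetitions_fast_4d is defined elsewhere in the file and may differ in detail.

def num_repetitions(i, j, k, l):
    # How many of the eight symmetry-related slots coincide for this quadruple.
    rep = 1
    rep *= 2 if i == j else 1                        # (ij|kl) == (ji|kl)
    rep *= 2 if k == l else 1                        # (ij|kl) == (ij|lk)
    rep *= 2 if (i, j) in ((k, l), (l, k)) else 1    # bra/ket swap maps onto itself
    return rep

print(num_repetitions(0, 0, 0, 0))  # 8: all eight permutations hit the same slot
print(num_repetitions(1, 0, 1, 0))  # 2: only the bra/ket swap coincides
print(num_repetitions(3, 2, 1, 0))  # 1: fully distinct quadruple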
