diff --git a/pyscf_ipu/nanoDFT/sparse_symmetric_intor_ERI.py b/pyscf_ipu/nanoDFT/sparse_symmetric_intor_ERI.py index 55c9b00..0dd4e14 100644 --- a/pyscf_ipu/nanoDFT/sparse_symmetric_intor_ERI.py +++ b/pyscf_ipu/nanoDFT/sparse_symmetric_intor_ERI.py @@ -398,7 +398,7 @@ def compute_diff_jk(dm, mol, nprog, nbatch, tolerance, backend): for zip_counter, (eri, idx) in enumerate(zip(all_eris, all_indices)): # go from our memory layout to mol.intor("int2e_sph", "s8") - shell_size = eri.shape[-1] # save original tensor shape + num_shells, shell_size = eri.shape # save original tensor shape comp_distinct_idx_list = [] print(eri.shape) @@ -421,28 +421,32 @@ def compute_diff_jk(dm, mol, nprog, nbatch, tolerance, backend): print('input_ijkl.nbytes/1e6', ijkl_arr*2/1e6) print('comp_distinct_idx.nbytes/1e6', comp_distinct_idx.astype(np.int16).nbytes/1e6) - remainder = comp_distinct_idx.shape[0] % (nprog*nbatch) + remainder = (eri.shape[0]) % (nprog*nbatch) # unused for nipu==batches==1 if remainder != 0: print('padding', remainder, nprog*nbatch-remainder, comp_distinct_idx.shape) - comp_distinct_idx = np.pad(comp_distinct_idx, ((0, nprog*nbatch-remainder), (0, 0))) - eri = jnp.pad(eri.reshape(-1), ((0, nprog*nbatch-remainder))) + comp_distinct_idx = np.pad(comp_distinct_idx.reshape(-1, shell_size, 4), ((0, (nprog*nbatch-remainder)), (0, 0), (0, 0))).reshape(-1, 4) + # eri = jnp.pad(eri.reshape(-1), ((0, (nprog*nbatch-remainder)))) + eri = jnp.pad(eri, ((0, nprog*nbatch-remainder), (0, 0))).reshape(-1) + idx = jnp.pad(idx, ((0, nprog*nbatch-remainder), (0, 0))) print('eri.shape', eri.shape) print('comp_distinct_idx.shape', comp_distinct_idx.shape) # output of mol.intor("int2e_ssph", aosym="s8") - comp_distinct_ERI = eri.reshape(nprog, nbatch, -1) #jnp.concatenate([eri.reshape(-1) for eri in all_eris]).reshape(nprog, nbatch, -1) + # comp_distinct_ERI = eri.reshape(nprog, nbatch, -1, shell_size) #jnp.concatenate([eri.reshape(-1) for eri in all_eris]).reshape(nprog, nbatch, -1) + comp_distinct_ERI = eri.reshape(nprog, nbatch, -1) comp_distinct_idx = comp_distinct_idx.reshape(nprog, nbatch, -1, 4) + idx = idx.reshape(nprog, nbatch, -1, 4) print('comp_distinct_ERI.shape', comp_distinct_ERI.shape) print('comp_distinct_idx.shape', comp_distinct_idx.shape) # compute repetitions caused by 8x symmetry when computing from the distinct_ERI form and scale accordingly - drep = num_repetitions_fast_4d(comp_distinct_idx[:, :, :, 0], comp_distinct_idx[:, :, :, 1], comp_distinct_idx[:, :, :, 2], comp_distinct_idx[:, :, :, 3], xnp=np, dtype=np.uint32) - comp_distinct_ERI = comp_distinct_ERI / drep + # drep = num_repetitions_fast_4d(comp_distinct_idx[:, :, :, 0], comp_distinct_idx[:, :, :, 1], comp_distinct_idx[:, :, :, 2], comp_distinct_idx[:, :, :, 3], xnp=np, dtype=np.uint32) + # comp_distinct_ERI = comp_distinct_ERI / drep # int16 storage supported but not slicing; use conversion trick to enable slicing @@ -451,7 +455,12 @@ def compute_diff_jk(dm, mol, nprog, nbatch, tolerance, backend): #diff_JK = jax.pmap(sparse_symmetric_einsum, in_axes=(0,0,None,None), static_broadcasted_argnums=(3,), backend=backend, axis_name="p")(comp_distinct_ERI, comp_distinct_idx, dm, backend) #diff_JK = sparse_symmetric_einsum(comp_distinct_ERI[0], comp_distinct_idx[0], dm, backend) - nonzero_distinct_ERI, nonzero_indices, dm, backend = comp_distinct_ERI[0], comp_distinct_idx[0], dm, backend + + + # nonzero_distinct_ERI, nonzero_indices, dm, backend = comp_distinct_ERI[0], comp_distinct_idx[0], dm, backend + nonzero_distinct_ERI, nonzero_indices, dm, backend = comp_distinct_ERI[0], idx[0], dm, backend + + dm = dm.reshape(-1) diff_JK = jnp.zeros(dm.shape) N = int(np.sqrt(dm.shape[0])) @@ -459,43 +468,47 @@ def compute_diff_jk(dm, mol, nprog, nbatch, tolerance, backend): def foreach_batch(i, vals): diff_JK, nonzero_indices, ao_loc = vals - indices = nonzero_indices[i] - - # indices = jax.lax.bitcast_convert_type(indices, np.int16).astype(np.int32) - indices = indices.astype(jnp.int32) - eris = nonzero_distinct_ERI[i] - print(indices.shape) - - # exp_distinct_idx = jnp.zeros((eris.shape[0], 4), dtype=jnp.int32) - - # # exp_distinct_idx_list = [] - # print(eris.shape) - # # for ind in range(eris.shape[0]): - # def gen_all_shell_idx(idx, arr): - # i, j, k, l = [indices[ind, z] for z in range(4)] - # _di, _dj, _dk, _dl = ao_loc[i+1] - ao_loc[i], ao_loc[j+1] - ao_loc[j], ao_loc[k+1] - ao_loc[k], ao_loc[l+1] - ao_loc[l] - # _i0, _j0, _k0, _l0 = ao_loc[i], ao_loc[j], ao_loc[k], ao_loc[l] - - # block_idx = jnp.zeros((shell_size, 4), dtype=jnp.int32) - - # def gen_shell_idx(idx_sh, arr): - # # Compute the indices - # ind_i = (idx_sh ) % _di + _i0 - # ind_j = (idx_sh // (_di) ) % _dj + _j0 - # ind_k = (idx_sh // (_di*_dj) ) % _dk + _k0 - # ind_l = (idx_sh // (_di*_dj*_dk)) % _dl + _l0 - - # # Update the array with the computed indices - # return arr.at[idx_sh, :].set(jnp.array([ind_i, ind_j, ind_k, ind_l])) + eris = nonzero_distinct_ERI[i].reshape(-1) + + if False: + + indices = nonzero_indices[i] - # block_idx = jax.lax.fori_loop(0, shell_size, gen_shell_idx, block_idx) - # # return arr.at[idx*shell_size:(idx+1)*shell_size, :].set(block_idx) - # return jax.lax.dynamic_update_slice(arr, block_idx, (idx*shell_size, 0)) - - # exp_distinct_idx = jax.lax.fori_loop(0, eris.shape[0]//shell_size, gen_all_shell_idx, exp_distinct_idx) #jnp.concatenate(exp_distinct_idx_list) + # # indices = jax.lax.bitcast_convert_type(indices, np.int16).astype(np.int32) + indices = indices.astype(jnp.int32) + + + print('eris.shape', eris.shape) + print('indices.shape', indices.shape) + + else: + + # Compute offsets and sizes + idx = nonzero_indices[i] + _i, _j, _k, _l = [idx[:, z] for z in range(4)] + _di, _dj, _dk, _dl = [(ao_loc[z+1] - ao_loc[z]).reshape(-1, 1) for z in [_i, _j, _k, _l]] + _i0, _j0, _k0, _l0 = [ao_loc[z].reshape(-1, 1) for z in [_i, _j, _k, _l]] + + def gen_shell_idx(idx_sh): + idx_sh = idx_sh.reshape(-1, shell_size) + # Compute the indices + ind_i = (idx_sh ) % _di + _i0 + ind_j = (idx_sh // (_di) ) % _dj + _j0 + ind_k = (idx_sh // (_di*_dj) ) % _dk + _k0 + ind_l = (idx_sh // (_di*_dj*_dk)) % _dl + _l0 + print('>>', ind_i.shape) + # Update the array with the computed indices + return jnp.stack([ind_i.reshape(-1), ind_j.reshape(-1), ind_k.reshape(-1), ind_l.reshape(-1)], axis=1) + + indices = gen_shell_idx(jnp.arange((eris.shape[0]))) # <<<<<<<<<<<<<<<<<<<<<<<<< + + print('eris.shape', eris.shape) + print('indices.shape', indices.shape) - # indices = exp_distinct_idx + # compute repetitions caused by 8x symmetry when computing from the distinct_ERI form and scale accordingly + drep = num_repetitions_fast_4d(indices[:, 0], indices[:, 1], indices[:, 2], indices[:, 3], xnp=jnp, dtype=jnp.uint32) + eris = eris / drep def foreach_symmetry(sym, vals):