Commit 9537f4d

refactoring and cleanup; index recomputation work in progress
1 parent 0f017bf commit 9537f4d

File tree

1 file changed: +51 -66 lines changed

pyscf_ipu/nanoDFT/sparse_symmetric_intor_ERI.py

@@ -90,49 +90,6 @@ def num_repetitions_fast(ij, kl):
         k*N+j, k*N+i, l*N+j, l*N+i, i*N+l, i*N+k, j*N+l, j*N+k,
         i*N+l, j*N+l, i*N+k, j*N+k, k*N+j, l*N+j, k*N+i, l*N+i])[symmetry]
 
-def sparse_symmetric_einsum(nonzero_distinct_ERI, nonzero_indices, dm, backend):
-
-
-    dm = dm.reshape(-1)
-    diff_JK = jnp.zeros(dm.shape)
-    N = int(np.sqrt(dm.shape[0]))
-
-    def iteration(symmetry, vals):
-        diff_JK = vals
-        is_K_matrix = (symmetry >= 8)
-
-        def sequentialized_iter(i, vals):
-            # Generalized J/K computation: does J when symmetry is in range(0,8) and K when symmetry is in range(8,16)
-            # Trade-off: Using one function leads to smaller always-live memory.
-            diff_JK = vals
-
-            indices = nonzero_indices[i]
-
-            indices = jax.lax.bitcast_convert_type(indices, np.int16).astype(np.int32)
-            eris = nonzero_distinct_ERI[i]
-            print(indices.shape)
-
-            if backend == "cpu": dm_indices = cpu_ijkl(indices, symmetry+is_K_matrix*8, indices_func)
-            else:                dm_indices = ipu_ijkl(indices, symmetry+is_K_matrix*8, N)
-            dm_values = jnp.take(dm, dm_indices, axis=0)
-
-            print('nonzero_distinct_ERI.shape', nonzero_distinct_ERI.shape)
-            print('dm_values.shape', dm_values.shape)
-            print('eris.shape', eris.shape)
-            dm_values = dm_values.at[:].mul( eris ) # this is prod, but re-use variable for inplace update.
-
-            if backend == "cpu": ss_indices = cpu_ijkl(indices, symmetry+8+is_K_matrix*8, indices_func)
-            else:                ss_indices = ipu_ijkl(indices, symmetry+8+is_K_matrix*8, N)
-            diff_JK = diff_JK + jax.ops.segment_sum(dm_values, ss_indices, N**2) * (-HYB_B3LYP/2)**is_K_matrix
-
-            return diff_JK
-
-        batches = nonzero_indices.shape[0] # before pmap, tensor had shape (nipus, batches, -1) so [0]=batches after pmap
-        diff_JK = jax.lax.fori_loop(0, batches, sequentialized_iter, diff_JK)
-        return diff_JK
-
-    return jax.lax.fori_loop(0, 16, iteration, diff_JK)
-
 def get_shapes(input_ijkl, bas):
     i_sh, j_sh, k_sh, l_sh = input_ijkl[0]
     BAS_SLOTS = 8
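The removed `sparse_symmetric_einsum` (now folded into `compute_diff_jk`) follows a gather-multiply-scatter pattern: gather density-matrix entries at symmetry-permuted indices, multiply by the distinct ERI values, and scatter-add into the flattened N x N result. A minimal sketch of that pattern with toy values (the arrays and indices below are made up, not the repo's data):

```python
import jax
import jax.numpy as jnp

N = 4
dm = jnp.arange(N * N, dtype=jnp.float32)   # toy flattened density matrix
eris = jnp.array([0.5, 1.5, 2.0])           # toy distinct ERI values
dm_idx = jnp.array([1, 6, 6])               # toy gather indices (i*N+j style)
out_idx = jnp.array([0, 5, 5])              # toy scatter indices

dm_values = jnp.take(dm, dm_idx, axis=0) * eris          # gather, then multiply
diff_J = jax.ops.segment_sum(dm_values, out_idx, N**2)   # scatter-add
print(diff_J.reshape(N, N))   # flattened entries 0 and 5 hold the sums
```

Iterating this over the 16 index permutations (8 for J, 8 for K, with the K terms scaled by -HYB_B3LYP/2) reproduces the full contraction while storing each distinct integral only once.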
@@ -439,11 +396,10 @@ def compute_diff_jk(dm, mol, nprog, nbatch, tolerance, backend):
 
     temp = 0
     for zip_counter, (eri, idx) in enumerate(zip(all_eris, all_indices)):
+        # go from our memory layout to mol.intor("int2e_sph", "s8")
 
-        # comp_list_index = 0
+        shell_size = eri.shape[-1] # save original tensor shape
 
-        # go from our memory layout to mol.intor("int2e_sph", "s8")
-        # for zip_counter, (eri, idx) in enumerate(zip(all_eris, all_indices)):
         comp_distinct_idx_list = []
         print(eri.shape)
         for ind in range(eri.shape[0]):
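For reference, aosym="s8" keeps only symmetry-distinct integrals: (ij|kl) is unchanged under i<->j, k<->l, and (ij)<->(kl), so one canonical representative with i>=j, k>=l, ij>=kl suffices. A sketch of the standard triangular compound indexing behind that layout (textbook arithmetic, not code from this file):

```python
def s8_compound_index(i, j, k, l):
    # Canonicalize so that i >= j, k >= l, and pair ij >= pair kl.
    if i < j: i, j = j, i
    if k < l: k, l = l, k
    ij = i * (i + 1) // 2 + j          # triangular index over (i, j)
    kl = k * (k + 1) // 2 + l
    if ij < kl: ij, kl = kl, ij
    return ij * (ij + 1) // 2 + kl     # triangular index over pairs

# permutations of (1,0|2,0) all land in the same distinct slot
assert s8_compound_index(1, 0, 2, 0) == s8_compound_index(0, 1, 0, 2)
```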
@@ -454,7 +410,7 @@ def compute_diff_jk(dm, mol, nprog, nbatch, tolerance, backend):
                                  _i0:(_i0+_di),
                                  _j0:(_j0+_dj),
                                  _k0:(_k0+_dk),
-                                 _l0:(_l0+_dl)].transpose(4, 3, 2, 1, 0).astype(np.int16)
+                                 _l0:(_l0+_dl)].transpose(4, 3, 2, 1, 0) #.astype(np.int16)
 
             comp_distinct_idx_list.append(block_idx.reshape(-1, 4))
             # comp_list_index += 1
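Commenting out the `.astype(np.int16)` cast is consistent with the `foreach_batch` change below, which replaces the int16 bitcast decode with a plain `astype(jnp.int32)`: the AO index tuples themselves fit in int16 for small basis sets, but flattened products like i*N+j overflow past 32767. A quick numeric illustration (assumed N, not from the repo):

```python
import numpy as np

N = 256                               # hypothetical AO count
i = j = N - 1
flat = np.int32(i) * N + j            # 65535 is fine as int32
print(flat, flat.astype(np.int16))    # 65535 -1  (silent int16 wraparound)
```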
@@ -472,7 +428,9 @@ def compute_diff_jk(dm, mol, nprog, nbatch, tolerance, backend):
             print('padding', remainder, nprog*nbatch-remainder, comp_distinct_idx.shape)
             comp_distinct_idx = np.pad(comp_distinct_idx, ((0, nprog*nbatch-remainder), (0, 0)))
             eri = jnp.pad(eri.reshape(-1), ((0, nprog*nbatch-remainder)))
-
+
+        print('eri.shape', eri.shape)
+        print('comp_distinct_idx.shape', comp_distinct_idx.shape)
 
         # output of mol.intor("int2e_ssph", aosym="s8")
         comp_distinct_ERI = eri.reshape(nprog, nbatch, -1) #jnp.concatenate([eri.reshape(-1) for eri in all_eris]).reshape(nprog, nbatch, -1)
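The two new print statements make it easy to confirm that eri and comp_distinct_idx end up with matching padded lengths: both are rounded up to a multiple of nprog*nbatch so they reshape cleanly to (nprog, nbatch, -1) for the per-device split. The round-up arithmetic in isolation (toy sizes, not the repo's shapes):

```python
import numpy as np

nprog, nbatch = 2, 3
eri = np.arange(17, dtype=np.float32)         # length not divisible by 6
remainder = eri.shape[0] % (nprog * nbatch)   # 5
eri = np.pad(eri, (0, nprog * nbatch - remainder))   # pad by 1 -> length 18
print(eri.reshape(nprog, nbatch, -1).shape)   # (2, 3, 3)
```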
@@ -498,50 +456,77 @@ def compute_diff_jk(dm, mol, nprog, nbatch, tolerance, backend):
         diff_JK = jnp.zeros(dm.shape)
         N = int(np.sqrt(dm.shape[0]))
 
-        def iteration(i, vals):
-            diff_JK = vals
+        def foreach_batch(i, vals):
+            diff_JK, nonzero_indices, ao_loc = vals
 
             indices = nonzero_indices[i]
 
             # indices = jax.lax.bitcast_convert_type(indices, np.int16).astype(np.int32)
             indices = indices.astype(jnp.int32)
             eris = nonzero_distinct_ERI[i]
             print(indices.shape)
+
+            # exp_distinct_idx = jnp.zeros((eris.shape[0], 4), dtype=jnp.int32)
+
+            # # exp_distinct_idx_list = []
+            # print(eris.shape)
+            # # for ind in range(eris.shape[0]):
+            # def gen_all_shell_idx(idx, arr):
+            #     i, j, k, l = [indices[ind, z] for z in range(4)]
+            #     _di, _dj, _dk, _dl = ao_loc[i+1] - ao_loc[i], ao_loc[j+1] - ao_loc[j], ao_loc[k+1] - ao_loc[k], ao_loc[l+1] - ao_loc[l]
+            #     _i0, _j0, _k0, _l0 = ao_loc[i], ao_loc[j], ao_loc[k], ao_loc[l]
+
+            #     block_idx = jnp.zeros((shell_size, 4), dtype=jnp.int32)
+
+            #     def gen_shell_idx(idx_sh, arr):
+            #         # Compute the indices
+            #         ind_i = (idx_sh                 ) % _di + _i0
+            #         ind_j = (idx_sh // (_di)        ) % _dj + _j0
+            #         ind_k = (idx_sh // (_di*_dj)    ) % _dk + _k0
+            #         ind_l = (idx_sh // (_di*_dj*_dk)) % _dl + _l0
+
+            #         # Update the array with the computed indices
+            #         return arr.at[idx_sh, :].set(jnp.array([ind_i, ind_j, ind_k, ind_l]))
+
+            #     block_idx = jax.lax.fori_loop(0, shell_size, gen_shell_idx, block_idx)
+
+            #     # return arr.at[idx*shell_size:(idx+1)*shell_size, :].set(block_idx)
+            #     return jax.lax.dynamic_update_slice(arr, block_idx, (idx*shell_size, 0))
+
+            # exp_distinct_idx = jax.lax.fori_loop(0, eris.shape[0]//shell_size, gen_all_shell_idx, exp_distinct_idx) #jnp.concatenate(exp_distinct_idx_list)
+
+            # indices = exp_distinct_idx
 
 
-            def sequentialized_iter(symmetry, vals):
+            def foreach_symmetry(sym, vals):
                 # Generalized J/K computation: does J when symmetry is in range(0,8) and K when symmetry is in range(8,16)
                 # Trade-off: Using one function leads to smaller always-live memory.
-                is_K_matrix = (symmetry >= 8)
-
+                is_K_matrix = (sym >= 8)
                 diff_JK = vals
 
-
-
-                if backend == "cpu": dm_indices = cpu_ijkl(indices, symmetry+is_K_matrix*8, indices_func)
-                else:                dm_indices = ipu_ijkl(indices, symmetry+is_K_matrix*8, N)
+                if backend == "cpu": dm_indices = cpu_ijkl(indices, sym+is_K_matrix*8, indices_func)
+                else:                dm_indices = ipu_ijkl(indices, sym+is_K_matrix*8, N)
                 dm_values = jnp.take(dm, dm_indices, axis=0)
 
-                print('nonzero_distinct_ERI.shape', nonzero_distinct_ERI.shape)
+                print('indices.shape', indices.shape)
                 print('dm_values.shape', dm_values.shape)
                 print('eris.shape', eris.shape)
                 dm_values = dm_values.at[:].mul( eris ) # this is prod, but re-use variable for inplace update.
 
-                if backend == "cpu": ss_indices = cpu_ijkl(indices, symmetry+8+is_K_matrix*8, indices_func)
-                else:                ss_indices = ipu_ijkl(indices, symmetry+8+is_K_matrix*8, N)
+                if backend == "cpu": ss_indices = cpu_ijkl(indices, sym+8+is_K_matrix*8, indices_func)
+                else:                ss_indices = ipu_ijkl(indices, sym+8+is_K_matrix*8, N)
                 diff_JK = diff_JK + jax.ops.segment_sum(dm_values, ss_indices, N**2) * (-HYB_B3LYP/2)**is_K_matrix
 
                 return diff_JK
 
-            # diff_JK = jax.lax.fori_loop(0, batches, sequentialized_iter, diff_JK)
-            diff_JK = jax.lax.fori_loop(0, 16, sequentialized_iter, diff_JK)
-            # diff_JK = sequentialized_iter(0, diff_JK)
-            return diff_JK
+            diff_JK = jax.lax.fori_loop(0, 16, foreach_symmetry, diff_JK)
+
+            return (diff_JK, nonzero_indices, ao_loc)
 
         batches = nonzero_indices.shape[0] # before pmap, tensor had shape (nipus, batches, -1) so [0]=batches after pmap
-        for bi in range(batches):
-            diff_JK = iteration(bi, diff_JK)
+
+        diff_JK, _, _ = jax.lax.fori_loop(0, batches, foreach_batch, (diff_JK, nonzero_indices, ao_loc))
 
         temp += diff_JK
 