@@ -90,49 +90,6 @@ def num_repetitions_fast(ij, kl):
         k*N+j, k*N+i, l*N+j, l*N+i, i*N+l, i*N+k, j*N+l, j*N+k,
         i*N+l, j*N+l, i*N+k, j*N+k, k*N+j, l*N+j, k*N+i, l*N+i])[symmetry]
 
-def sparse_symmetric_einsum(nonzero_distinct_ERI, nonzero_indices, dm, backend):
-
-    dm = dm.reshape(-1)
-    diff_JK = jnp.zeros(dm.shape)
-    N = int(np.sqrt(dm.shape[0]))
-
-    def iteration(symmetry, vals):
-        diff_JK = vals
-        is_K_matrix = (symmetry >= 8)
-
-        def sequentialized_iter(i, vals):
-            # Generalized J/K computation: does J when symmetry is in range(0,8) and K when symmetry is in range(8,16)
-            # Trade-off: Using one function leads to smaller always-live memory.
-            diff_JK = vals
-
-            indices = nonzero_indices[i]
-            indices = jax.lax.bitcast_convert_type(indices, np.int16).astype(np.int32)
-            eris = nonzero_distinct_ERI[i]
-            print(indices.shape)
-
-            if backend == "cpu": dm_indices = cpu_ijkl(indices, symmetry+is_K_matrix*8, indices_func)
-            else:                dm_indices = ipu_ijkl(indices, symmetry+is_K_matrix*8, N)
-            dm_values = jnp.take(dm, dm_indices, axis=0)
-
-            print('nonzero_distinct_ERI.shape', nonzero_distinct_ERI.shape)
-            print('dm_values.shape', dm_values.shape)
-            print('eris.shape', eris.shape)
-            dm_values = dm_values.at[:].mul(eris)  # this is prod, but re-use variable for inplace update.
-
-            if backend == "cpu": ss_indices = cpu_ijkl(indices, symmetry+8+is_K_matrix*8, indices_func)
-            else:                ss_indices = ipu_ijkl(indices, symmetry+8+is_K_matrix*8, N)
-            diff_JK = diff_JK + jax.ops.segment_sum(dm_values, ss_indices, N**2) * (-HYB_B3LYP/2)**is_K_matrix
-
-            return diff_JK
-
-        batches = nonzero_indices.shape[0]  # before pmap, tensor had shape (nipus, batches, -1) so [0]=batches after pmap
-        diff_JK = jax.lax.fori_loop(0, batches, sequentialized_iter, diff_JK)
-        return diff_JK
-
-    return jax.lax.fori_loop(0, 16, iteration, diff_JK)
-
 def get_shapes(input_ijkl, bas):
     i_sh, j_sh, k_sh, l_sh = input_ijkl[0]
     BAS_SLOTS = 8
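The gather-multiply-scatter step that both the deleted `sparse_symmetric_einsum` and its replacement below rely on can be reproduced in isolation. A minimal sketch with toy sizes: the index arrays here are invented stand-ins for what `cpu_ijkl`/`ipu_ijkl` compute from the symmetry permutations, and `N` is tiny.

```python
import jax
import jax.numpy as jnp

N = 4                                        # toy basis size; real code uses N = sqrt(len(dm))
dm = jnp.arange(N * N, dtype=jnp.float32)    # flattened density matrix
eris = jnp.ones(6, dtype=jnp.float32)        # toy "nonzero distinct ERI" values
dm_indices = jnp.array([0, 3, 5, 5, 9, 15])  # gather positions into dm (made up)
ss_indices = jnp.array([0, 1, 1, 2, 3, 3])   # scatter positions into diff_JK (made up)

dm_values = jnp.take(dm, dm_indices, axis=0) * eris
# segment_sum adds together all values that share an output index,
# i.e. a scatter-add into a length-N*N zero vector.
diff_JK = jax.ops.segment_sum(dm_values, ss_indices, num_segments=N * N)
print(diff_JK.shape)  # (16,)
```

`segment_sum` rather than plain indexed assignment matters here because `ss_indices` contains duplicates: every value landing on the same output position must be accumulated, not overwritten.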
@@ -439,11 +396,10 @@ def compute_diff_jk(dm, mol, nprog, nbatch, tolerance, backend):
 
     temp = 0
     for zip_counter, (eri, idx) in enumerate(zip(all_eris, all_indices)):
+        # go from our memory layout to mol.intor("int2e_sph", "s8")
 
-        # comp_list_index = 0
+        shell_size = eri.shape[-1]  # save original tensor shape
 
-        # go from our memory layout to mol.intor("int2e_sph", "s8")
-        # for zip_counter, (eri, idx) in enumerate(zip(all_eris, all_indices)):
         comp_distinct_idx_list = []
         print(eri.shape)
         for ind in range(eri.shape[0]):
@@ -454,7 +410,7 @@ def compute_diff_jk(dm, mol, nprog, nbatch, tolerance, backend):
                             _i0:(_i0+_di),
                             _j0:(_j0+_dj),
                             _k0:(_k0+_dk),
-                            _l0:(_l0+_dl)].transpose(4, 3, 2, 1, 0).astype(np.int16)
+                            _l0:(_l0+_dl)].transpose(4, 3, 2, 1, 0)  # .astype(np.int16)
 
             comp_distinct_idx_list.append(block_idx.reshape(-1, 4))
             # comp_list_index += 1
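The `_i0`/`_di` bookkeeping in this hunk comes from PySCF's `ao_loc` array (e.g. `mol.ao_loc_nr()`): shell `s` spans AO indices `ao_loc[s]` to `ao_loc[s+1]`. A two-index sketch of how one shell pair expands into the AO index block that `block_idx.reshape(-1, 4)` flattens for the full four-index case; the `ao_loc` values below are invented toy numbers.

```python
import numpy as np

# ao_loc[s] is the first AO index of shell s; ao_loc[s+1] - ao_loc[s] is its size.
ao_loc = np.array([0, 1, 4, 9])              # toy: an s-, a p- and a d-like shell
i, j = 1, 2                                  # one shell pair
_i0, _di = ao_loc[i], ao_loc[i + 1] - ao_loc[i]
_j0, _dj = ao_loc[j], ao_loc[j + 1] - ao_loc[j]

# Every AO index pair covered by this shell pair, flattened row by row,
# analogous to block_idx.reshape(-1, 4) above.
block = np.stack(np.meshgrid(np.arange(_i0, _i0 + _di),
                             np.arange(_j0, _j0 + _dj),
                             indexing="ij"), axis=-1).reshape(-1, 2)
print(block.shape)  # (15, 2): 3 * 5 AO pairs
```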
@@ -472,7 +428,9 @@ def compute_diff_jk(dm, mol, nprog, nbatch, tolerance, backend):
     print('padding', remainder, nprog*nbatch-remainder, comp_distinct_idx.shape)
     comp_distinct_idx = np.pad(comp_distinct_idx, ((0, nprog*nbatch-remainder), (0, 0)))
     eri = jnp.pad(eri.reshape(-1), ((0, nprog*nbatch-remainder)))
-
+
+    print('eri.shape', eri.shape)
+    print('comp_distinct_idx.shape', comp_distinct_idx.shape)
 
     # output of mol.intor("int2e_sph", aosym="s8")
     comp_distinct_ERI = eri.reshape(nprog, nbatch, -1)  # jnp.concatenate([eri.reshape(-1) for eri in all_eris]).reshape(nprog, nbatch, -1)
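The padding above exists so the flattened values split evenly into `nprog` device shards of `nbatch` batches each; `pmap` then maps over the leading axis. A standalone sketch of that layout step, with a guard on `remainder` (an assumption here; the diff pads unconditionally and presumably handles the divisible case elsewhere):

```python
import numpy as np

nprog, nbatch = 2, 3                         # devices x batches per device (toy)
vals = np.arange(10, dtype=np.float32)       # flattened values; 10 % 6 != 0

remainder = vals.shape[0] % (nprog * nbatch)
if remainder != 0:
    vals = np.pad(vals, (0, nprog * nbatch - remainder))
# Leading axis becomes the pmap axis: one slice per device.
vals = vals.reshape(nprog, nbatch, -1)
print(vals.shape)  # (2, 3, 2)
```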
@@ -498,50 +456,77 @@ def compute_diff_jk(dm, mol, nprog, nbatch, tolerance, backend):
     diff_JK = jnp.zeros(dm.shape)
     N = int(np.sqrt(dm.shape[0]))
 
-    def iteration(i, vals):
-        diff_JK = vals
+    def foreach_batch(i, vals):
+        diff_JK, nonzero_indices, ao_loc = vals
 
         indices = nonzero_indices[i]
 
         # indices = jax.lax.bitcast_convert_type(indices, np.int16).astype(np.int32)
         indices = indices.astype(jnp.int32)
         eris = nonzero_distinct_ERI[i]
         print(indices.shape)
+
+        # exp_distinct_idx = jnp.zeros((eris.shape[0], 4), dtype=jnp.int32)
+
+        # # exp_distinct_idx_list = []
+        # print(eris.shape)
+        # # for ind in range(eris.shape[0]):
+        # def gen_all_shell_idx(idx, arr):
+        #     i, j, k, l = [indices[ind, z] for z in range(4)]
+        #     _di, _dj, _dk, _dl = ao_loc[i+1]-ao_loc[i], ao_loc[j+1]-ao_loc[j], ao_loc[k+1]-ao_loc[k], ao_loc[l+1]-ao_loc[l]
+        #     _i0, _j0, _k0, _l0 = ao_loc[i], ao_loc[j], ao_loc[k], ao_loc[l]
+
+        #     block_idx = jnp.zeros((shell_size, 4), dtype=jnp.int32)
+
+        #     def gen_shell_idx(idx_sh, arr):
+        #         # Compute the indices
+        #         ind_i = (idx_sh                 ) % _di + _i0
+        #         ind_j = (idx_sh // (_di)        ) % _dj + _j0
+        #         ind_k = (idx_sh // (_di*_dj)    ) % _dk + _k0
+        #         ind_l = (idx_sh // (_di*_dj*_dk)) % _dl + _l0
+
+        #         # Update the array with the computed indices
+        #         return arr.at[idx_sh, :].set(jnp.array([ind_i, ind_j, ind_k, ind_l]))
+
+        #     block_idx = jax.lax.fori_loop(0, shell_size, gen_shell_idx, block_idx)
+
+        #     # return arr.at[idx*shell_size:(idx+1)*shell_size, :].set(block_idx)
+        #     return jax.lax.dynamic_update_slice(arr, block_idx, (idx*shell_size, 0))
+
+        # exp_distinct_idx = jax.lax.fori_loop(0, eris.shape[0]//shell_size, gen_all_shell_idx, exp_distinct_idx)  # jnp.concatenate(exp_distinct_idx_list)
+
+        # indices = exp_distinct_idx
 
-        def sequentialized_iter(symmetry, vals):
+        def foreach_symmetry(sym, vals):
             # Generalized J/K computation: does J when symmetry is in range(0,8) and K when symmetry is in range(8,16)
             # Trade-off: Using one function leads to smaller always-live memory.
-            is_K_matrix = (symmetry >= 8)
-
+            is_K_matrix = (sym >= 8)
             diff_JK = vals
 
-            if backend == "cpu": dm_indices = cpu_ijkl(indices, symmetry+is_K_matrix*8, indices_func)
-            else:                dm_indices = ipu_ijkl(indices, symmetry+is_K_matrix*8, N)
+            if backend == "cpu": dm_indices = cpu_ijkl(indices, sym+is_K_matrix*8, indices_func)
+            else:                dm_indices = ipu_ijkl(indices, sym+is_K_matrix*8, N)
             dm_values = jnp.take(dm, dm_indices, axis=0)
 
-            print('nonzero_distinct_ERI.shape', nonzero_distinct_ERI.shape)
+            print('indices.shape', indices.shape)
             print('dm_values.shape', dm_values.shape)
             print('eris.shape', eris.shape)
             dm_values = dm_values.at[:].mul(eris)  # this is prod, but re-use variable for inplace update.
 
-            if backend == "cpu": ss_indices = cpu_ijkl(indices, symmetry+8+is_K_matrix*8, indices_func)
-            else:                ss_indices = ipu_ijkl(indices, symmetry+8+is_K_matrix*8, N)
+            if backend == "cpu": ss_indices = cpu_ijkl(indices, sym+8+is_K_matrix*8, indices_func)
+            else:                ss_indices = ipu_ijkl(indices, sym+8+is_K_matrix*8, N)
             diff_JK = diff_JK + jax.ops.segment_sum(dm_values, ss_indices, N**2) * (-HYB_B3LYP/2)**is_K_matrix
 
             return diff_JK
 
-        # diff_JK = jax.lax.fori_loop(0, batches, sequentialized_iter, diff_JK)
-        diff_JK = jax.lax.fori_loop(0, 16, sequentialized_iter, diff_JK)
-        # diff_JK = sequentialized_iter(0, diff_JK)
-        return diff_JK
+        diff_JK = jax.lax.fori_loop(0, 16, foreach_symmetry, diff_JK)
+
+        return (diff_JK, nonzero_indices, ao_loc)
 
     batches = nonzero_indices.shape[0]  # before pmap, tensor had shape (nipus, batches, -1) so [0]=batches after pmap
-    for bi in range(batches):
-        diff_JK = iteration(bi, diff_JK)
+
+    diff_JK, _, _ = jax.lax.fori_loop(0, batches, foreach_batch, (diff_JK, nonzero_indices, ao_loc))
 
     temp += diff_JK
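The core change in this last hunk: the Python `for bi in range(batches)` loop becomes a `jax.lax.fori_loop`, so the batch loop no longer unrolls at trace time. `fori_loop` requires the body to map `(i, carry) -> carry` with identical carry structure, which is why the loop-invariant `nonzero_indices` and `ao_loc` are threaded through the tuple unchanged. A self-contained sketch of that pattern, with toy data:

```python
import jax
import jax.numpy as jnp

def foreach_batch(i, vals):
    # The carry must come back with the same structure and dtypes,
    # so loop-invariant arrays are threaded through unchanged.
    acc, xs = vals
    acc = acc + xs[i]          # consume batch i
    return (acc, xs)

xs = jnp.arange(12.0).reshape(4, 3)   # 4 "batches" of 3 values
acc = jnp.zeros(3)
acc, _ = jax.lax.fori_loop(0, xs.shape[0], foreach_batch, (acc, xs))
print(acc)  # equals xs.sum(0)
```

Compared with the deleted Python loop, this keeps the compiled program size constant in the number of batches, at the cost of forcing every iteration to have the same shapes (hence the padding earlier in the file).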