Skip to content

Commit

Permalink
changed to ERI s8 N^2 value sampling for faster testing; removed unne…
Browse files Browse the repository at this point in the history
…cessary code slowing program down; added integral value tolerance as program argument
  • Loading branch information
mihaipgc committed Sep 28, 2023
1 parent 209eb33 commit f628757
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 32 deletions.
16 changes: 11 additions & 5 deletions pyscf_ipu/nanoDFT/compute_eri_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ def get_shapes(input_ijkl, bas):

return len, nf, buflen

def prepare_integrals_2_inputs(mol):
def prepare_integrals_2_inputs(mol, itol):
# Shapes/sizes.
atm, bas, env = mol._atm, mol._bas, mol._env
n_atm, n_bas, N = atm.shape[0], bas.shape[0], mol.nao_nr()
Expand Down Expand Up @@ -149,8 +149,10 @@ def prepare_integrals_2_inputs(mol):
screened_indices_s8_4d = []

if USE_TOLERANCE_THRESHOLD:
tolerance = 1e-9
ERI = mol.intor('int2e_sph')
tolerance = itol
print('computing ERI s8 to sample N*(N+1)/2 values... ', end='')
ERI_s8 = mol.intor('int2e_sph', aosym='s8')
print('done')

# sample symmetry pattern and do safety check
# if N % 2 == 0:
Expand All @@ -175,7 +177,9 @@ def prepare_integrals_2_inputs(mol):
I_max = 0
tril_idx = np.tril_indices(N)
for a, b in zip(tril_idx[0], tril_idx[1]):
abab = np.abs(ERI[a,b,a,b])
index_ab_s8 = a*(a+1)//2 + b
index_s8 = index_ab_s8*(index_ab_s8+3)//2
abab = np.abs(ERI_s8[index_s8])
if abab > I_max:
I_max = abab

Expand All @@ -184,7 +188,9 @@ def prepare_integrals_2_inputs(mol):
tril_idx = np.tril_indices(N)
for a, b in zip(tril_idx[0], tril_idx[1]):
if USE_TOLERANCE_THRESHOLD:
abab = np.abs(ERI[a,b,a,b])
index_ab_s8 = a*(a+1)//2 + b
index_s8 = index_ab_s8*(index_ab_s8+3)//2
abab = np.abs(ERI_s8[index_s8])
if abab*I_max>=tolerance**2:
considered_indices.append((a, b)) # collect candidate pairs for s8
else:
Expand Down
55 changes: 28 additions & 27 deletions pyscf_ipu/nanoDFT/sparse_symmetric_intor_ERI.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,8 +134,8 @@ def sequentialized_iter(i, vals):



def compute_eri(mol):
input_floats, input_ints, input_ijkl, shapes, sizes, counts, indxs, indxs_inv, ao_loc = prepare_integrals_2_inputs(mol)
def compute_eri(mol, itol):
input_floats, input_ints, input_ijkl, shapes, sizes, counts, indxs, indxs_inv, ao_loc = prepare_integrals_2_inputs(mol, itol)

# Load vertex using TileJax.
vertex_filename = osp.join(osp.dirname(__file__), "intor_int2e_sph.cpp")
Expand Down Expand Up @@ -266,12 +266,12 @@ def compute_eri(mol):

return all_outputs, all_indices, ao_loc

def compute_diff_jk(dm, mol, nprog, nbatch, backend):
def compute_diff_jk(dm, mol, nprog, nbatch, itol, backend):
dm = dm.reshape(-1)
diff_JK = jnp.zeros(dm.shape)
N = int(np.sqrt(dm.shape[0]))

all_eris, all_indices, ao_loc = compute_eri(mol)
all_eris, all_indices, ao_loc = compute_eri(mol, itol)

BLOCK_ERI_SIZE = np.sum(np.array([eri.shape[0] for eri in all_eris]))

Expand Down Expand Up @@ -380,6 +380,7 @@ def ijkl2c(i, j, k, l):
parser.add_argument('-batches', default=5, type=int)
parser.add_argument('-nipu', default=16, type=int)
parser.add_argument('-skip', action="store_true")
parser.add_argument('-itol', default=1e-9, type=float)

args = parser.parse_args()
backend = args.backend
Expand Down Expand Up @@ -583,35 +584,35 @@ def ijkl2c(i, j, k, l):
print(nonzero_distinct_ERI.shape)
print(nonzero_indices.shape)

distinct_idx = np.zeros((distinct_ERI.shape[0], 4))
for c in range(distinct_ERI.shape[0]):
c = np.array(c).astype(np.int32)
ij, kl = get_i_j(c)
i, j = get_i_j(ij)
k, l = get_i_j(kl)
distinct_idx[c, 0] = i
distinct_idx[c, 1] = j
distinct_idx[c, 2] = k
distinct_idx[c, 3] = l
distinct_idx = distinct_idx.reshape(1, -1, 4).astype(np.uint64)
# distinct_idx = np.zeros((distinct_ERI.shape[0], 4))
# for c in range(distinct_ERI.shape[0]):
# c = np.array(c).astype(np.int32)
# ij, kl = get_i_j(c)
# i, j = get_i_j(ij)
# k, l = get_i_j(kl)
# distinct_idx[c, 0] = i
# distinct_idx[c, 1] = j
# distinct_idx[c, 2] = k
# distinct_idx[c, 3] = l
# distinct_idx = distinct_idx.reshape(1, -1, 4).astype(np.uint64)

distinct_ERI = distinct_ERI.astype(np.float32)
# drep = num_repetitions_fast_4d(distinct_idx[:, :, 0], distinct_idx[:, :, 1], distinct_idx[:, :, 2], distinct_idx[:, :, 3])
# distinct_ERI = distinct_ERI / drep
distinct_ERI = distinct_ERI.reshape(1, -1)
# distinct_ERI = distinct_ERI.astype(np.float32)
# # drep = num_repetitions_fast_4d(distinct_idx[:, :, 0], distinct_idx[:, :, 1], distinct_idx[:, :, 2], distinct_idx[:, :, 3])
# # distinct_ERI = distinct_ERI / drep
# distinct_ERI = distinct_ERI.reshape(1, -1)

# print('distinct_ERI', distinct_ERI)
# print('nonzero_distinct_ERI', nonzero_distinct_ERI)
# # print('distinct_ERI', distinct_ERI)
# # print('nonzero_distinct_ERI', nonzero_distinct_ERI)

# print('distinct_idx', distinct_idx.reshape(-1, 4))
# print('nonzero_indices', nonzero_indices_bk.reshape(-1, 4))
# # print('distinct_idx', distinct_idx.reshape(-1, 4))
# # print('nonzero_indices', nonzero_indices_bk.reshape(-1, 4))

distinct_idx = jax.lax.bitcast_convert_type(distinct_idx.astype(np.int16), np.float16)
# distinct_idx = jax.lax.bitcast_convert_type(distinct_idx.astype(np.int16), np.float16)

print('distinct_ERI.shape', distinct_ERI.shape)
print('distinct_idx.shape', distinct_idx.shape)
# print('distinct_ERI.shape', distinct_ERI.shape)
# print('distinct_idx.shape', distinct_idx.shape)

diff_JK = jax.jit(compute_diff_jk, backend=backend, static_argnames=['mol', 'nprog', 'nbatch', 'backend'])(dm, mol, args.nipu, args.batches, args.backend)
diff_JK = jax.jit(compute_diff_jk, backend=backend, static_argnames=['mol', 'nprog', 'nbatch', 'itol', 'backend'])(dm, mol, args.nipu, args.batches, args.itol, args.backend)
# diff_JK = jax.jit(compute_diff_jk, backend=backend, static_argnames=['mol', 'backend'])(nonzero_distinct_ERI.reshape(1, -1), nonzero_indices.reshape(1, -1, 4), mol, dm, args.backend)

# exit()
Expand Down

0 comments on commit f628757

Please sign in to comment.