diff --git a/pyscf_ipu/nanoDFT/compute_eri_utils.py b/pyscf_ipu/nanoDFT/compute_eri_utils.py
index 9c78414..50fecc9 100644
--- a/pyscf_ipu/nanoDFT/compute_eri_utils.py
+++ b/pyscf_ipu/nanoDFT/compute_eri_utils.py
@@ -106,7 +106,7 @@ def get_shapes(input_ijkl, bas):
 
     return len, nf, buflen
 
-def prepare_integrals_2_inputs(mol):
+def prepare_integrals_2_inputs(mol, itol):
     # Shapes/sizes.
     atm, bas, env = mol._atm, mol._bas, mol._env
     n_atm, n_bas, N = atm.shape[0], bas.shape[0], mol.nao_nr()
@@ -149,8 +149,10 @@ def prepare_integrals_2_inputs(mol):
     screened_indices_s8_4d = []
 
     if USE_TOLERANCE_THRESHOLD:
-        tolerance = 1e-9
-        ERI = mol.intor('int2e_sph')
+        tolerance = itol
+        print('computing ERI s8 to sample N*(N+1)/2 values... ', end='')
+        ERI_s8 = mol.intor('int2e_sph', aosym='s8')
+        print('done')
 
     # sample symmetry pattern and do safety check
     # if N % 2 == 0:
@@ -175,7 +177,9 @@ def prepare_integrals_2_inputs(mol):
     I_max = 0
     tril_idx = np.tril_indices(N)
     for a, b in zip(tril_idx[0], tril_idx[1]):
-        abab = np.abs(ERI[a,b,a,b])
+        index_ab_s8 = a*(a+1)//2 + b
+        index_s8 = index_ab_s8*(index_ab_s8+3)//2
+        abab = np.abs(ERI_s8[index_s8])
         if abab > I_max:
             I_max = abab
 
@@ -184,7 +188,9 @@ def prepare_integrals_2_inputs(mol):
     tril_idx = np.tril_indices(N)
     for a, b in zip(tril_idx[0], tril_idx[1]):
         if USE_TOLERANCE_THRESHOLD:
-            abab = np.abs(ERI[a,b,a,b])
+            index_ab_s8 = a*(a+1)//2 + b
+            index_s8 = index_ab_s8*(index_ab_s8+3)//2
+            abab = np.abs(ERI_s8[index_s8])
             if abab*I_max>=tolerance**2:
                 considered_indices.append((a, b)) # collect candidate pairs for s8
         else:
diff --git a/pyscf_ipu/nanoDFT/sparse_symmetric_intor_ERI.py b/pyscf_ipu/nanoDFT/sparse_symmetric_intor_ERI.py
index d0f33ca..94f1af3 100644
--- a/pyscf_ipu/nanoDFT/sparse_symmetric_intor_ERI.py
+++ b/pyscf_ipu/nanoDFT/sparse_symmetric_intor_ERI.py
@@ -134,8 +134,8 @@ def sequentialized_iter(i, vals):
 
 
 
-def compute_eri(mol):
-    input_floats, input_ints, input_ijkl, shapes, sizes, counts, indxs, indxs_inv, ao_loc = prepare_integrals_2_inputs(mol)
+def compute_eri(mol, itol):
+    input_floats, input_ints, input_ijkl, shapes, sizes, counts, indxs, indxs_inv, ao_loc = prepare_integrals_2_inputs(mol, itol)
 
     # Load vertex using TileJax.
     vertex_filename = osp.join(osp.dirname(__file__), "intor_int2e_sph.cpp")
@@ -266,12 +266,12 @@ def compute_eri(mol):
 
     return all_outputs, all_indices, ao_loc
 
-def compute_diff_jk(dm, mol, nprog, nbatch, backend):
+def compute_diff_jk(dm, mol, nprog, nbatch, itol, backend):
     dm = dm.reshape(-1)
     diff_JK = jnp.zeros(dm.shape)
     N = int(np.sqrt(dm.shape[0]))
 
-    all_eris, all_indices, ao_loc = compute_eri(mol)
+    all_eris, all_indices, ao_loc = compute_eri(mol, itol)
 
     BLOCK_ERI_SIZE = np.sum(np.array([eri.shape[0] for eri in all_eris]))
 
@@ -380,6 +380,7 @@ def ijkl2c(i, j, k, l):
     parser.add_argument('-batches', default=5, type=int)
     parser.add_argument('-nipu', default=16, type=int)
     parser.add_argument('-skip', action="store_true")
+    parser.add_argument('-itol', default=1e-9, type=float)
 
     args = parser.parse_args()
     backend = args.backend
@@ -583,35 +584,35 @@ def ijkl2c(i, j, k, l):
     print(nonzero_distinct_ERI.shape)
     print(nonzero_indices.shape)
 
-    distinct_idx = np.zeros((distinct_ERI.shape[0], 4))
-    for c in range(distinct_ERI.shape[0]):
-        c = np.array(c).astype(np.int32)
-        ij, kl = get_i_j(c)
-        i, j = get_i_j(ij)
-        k, l = get_i_j(kl)
-        distinct_idx[c, 0] = i
-        distinct_idx[c, 1] = j
-        distinct_idx[c, 2] = k
-        distinct_idx[c, 3] = l
-    distinct_idx = distinct_idx.reshape(1, -1, 4).astype(np.uint64)
+    # distinct_idx = np.zeros((distinct_ERI.shape[0], 4))
+    # for c in range(distinct_ERI.shape[0]):
+    #     c = np.array(c).astype(np.int32)
+    #     ij, kl = get_i_j(c)
+    #     i, j = get_i_j(ij)
+    #     k, l = get_i_j(kl)
+    #     distinct_idx[c, 0] = i
+    #     distinct_idx[c, 1] = j
+    #     distinct_idx[c, 2] = k
+    #     distinct_idx[c, 3] = l
+    # distinct_idx = distinct_idx.reshape(1, -1, 4).astype(np.uint64)
 
-    distinct_ERI = distinct_ERI.astype(np.float32)
-    # drep = num_repetitions_fast_4d(distinct_idx[:, :, 0], distinct_idx[:, :, 1], distinct_idx[:, :, 2], distinct_idx[:, :, 3])
-    # distinct_ERI = distinct_ERI / drep
-    distinct_ERI = distinct_ERI.reshape(1, -1)
+    # distinct_ERI = distinct_ERI.astype(np.float32)
+    # # drep = num_repetitions_fast_4d(distinct_idx[:, :, 0], distinct_idx[:, :, 1], distinct_idx[:, :, 2], distinct_idx[:, :, 3])
+    # # distinct_ERI = distinct_ERI / drep
+    # distinct_ERI = distinct_ERI.reshape(1, -1)
 
-    # print('distinct_ERI', distinct_ERI)
-    # print('nonzero_distinct_ERI', nonzero_distinct_ERI)
+    # # print('distinct_ERI', distinct_ERI)
+    # # print('nonzero_distinct_ERI', nonzero_distinct_ERI)
-    # print('distinct_idx', distinct_idx.reshape(-1, 4))
-    # print('nonzero_indices', nonzero_indices_bk.reshape(-1, 4))
+    # # print('distinct_idx', distinct_idx.reshape(-1, 4))
+    # # print('nonzero_indices', nonzero_indices_bk.reshape(-1, 4))
 
-    distinct_idx = jax.lax.bitcast_convert_type(distinct_idx.astype(np.int16), np.float16)
+    # distinct_idx = jax.lax.bitcast_convert_type(distinct_idx.astype(np.int16), np.float16)
 
-    print('distinct_ERI.shape', distinct_ERI.shape)
-    print('distinct_idx.shape', distinct_idx.shape)
+    # print('distinct_ERI.shape', distinct_ERI.shape)
+    # print('distinct_idx.shape', distinct_idx.shape)
 
-    diff_JK = jax.jit(compute_diff_jk, backend=backend, static_argnames=['mol', 'nprog', 'nbatch', 'backend'])(dm, mol, args.nipu, args.batches, args.backend)
+    diff_JK = jax.jit(compute_diff_jk, backend=backend, static_argnames=['mol', 'nprog', 'nbatch', 'itol', 'backend'])(dm, mol, args.nipu, args.batches, args.itol, args.backend)
 
     # diff_JK = jax.jit(compute_diff_jk, backend=backend, static_argnames=['mol', 'backend'])(nonzero_distinct_ERI.reshape(1, -1), nonzero_indices.reshape(1, -1, 4), mol, dm, args.backend)
 
     # exit()
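
Note on the s8 indexing above (not part of the patch itself): with a >= b, the packed pair index is ab = a*(a+1)//2 + b, and the diagonal element (ab|ab) of the 8-fold-symmetry-packed ERI sits at ab*(ab+1)//2 + ab = ab*(ab+3)//2, which is the index_s8 computed in prepare_integrals_2_inputs. The pairs kept by abab*I_max >= tolerance**2 follow from the Schwarz bound |(ab|cd)| <= sqrt((ab|ab)*(cd|cd)). A minimal sketch checking the index arithmetic against the dense tensor, assuming PySCF is available and using a small throwaway water molecule:

    # Sanity check (illustrative, not part of the patch): verify that index_s8
    # addresses the diagonal element (ab|ab) of the s8-packed ERI.
    import numpy as np
    from pyscf import gto

    mol = gto.M(atom="O 0 0 0; H 0 0 1; H 0 1 0", basis="sto-3g")
    N = mol.nao_nr()

    ERI_full = mol.intor('int2e_sph')                # dense (N, N, N, N) tensor
    ERI_s8   = mol.intor('int2e_sph', aosym='s8')    # 1D, 8-fold-symmetry packed

    for a, b in zip(*np.tril_indices(N)):            # a >= b
        ab  = a*(a+1)//2 + b                         # packed pair index of (a, b)
        idx = ab*(ab+3)//2                           # = ab*(ab+1)//2 + ab, i.e. (ab|ab)
        assert np.isclose(ERI_s8[idx], ERI_full[a, b, a, b])

The aosym='s8' call still computes the full symmetry-packed ERI (roughly 8x fewer values than the dense tensor), from which only the N*(N+1)/2 diagonal elements (ab|ab) are read for screening.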
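
One consequence of listing 'itol' in static_argnames (an observation, nothing the patch needs to change): JAX treats static arguments as compile-time constants, so each distinct -itol value triggers a fresh trace and compile, while repeated calls with the same value reuse the cached executable. Floats are hashable, so this is valid; it only matters if the tolerance were swept inside a loop. A small self-contained illustration, using a hypothetical screen helper that is not part of the repo:

    # Illustration (hypothetical helper): a float in static_argnames is baked
    # into the compiled program as a constant.
    import functools
    import jax
    import jax.numpy as jnp

    @functools.partial(jax.jit, static_argnames=['tol'])
    def screen(x, tol):
        # zero out entries below the tolerance; tol is a compile-time constant here
        return jnp.where(jnp.abs(x) >= tol, x, 0.0)

    x = jnp.array([1e-3, 1e-12])
    print(screen(x, 1e-9))   # compiles for tol=1e-9
    print(screen(x, 1e-9))   # cache hit, no recompilation
    print(screen(x, 1e-6))   # new static value -> new compilation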