diff --git a/pyscf_ipu/nanoDFT/compute_eri_utils.py b/pyscf_ipu/nanoDFT/compute_eri_utils.py index 921d111..e63717c 100644 --- a/pyscf_ipu/nanoDFT/compute_eri_utils.py +++ b/pyscf_ipu/nanoDFT/compute_eri_utils.py @@ -147,11 +147,20 @@ def prepare_integrals_2_inputs(mol): screened_indices_s8_4d = [] + tolerance = 1e-7 - # sample complete column as pattern seed + # sample symmetry pattern and do safety check ERI = mol.intor('int2e_sph') - nonzero_seed = ERI[N-1, N-1, :, 0] != 0 - tolerance = 1e-7 + if N % 2 == 0: + nonzero_seed = ERI[N-1, N-1, :N//2, 0] != 0 + nonzero_seed = np.concatenate([nonzero_seed, np.flip(nonzero_seed)]) + else: + nonzero_seed = ERI[N-1, N-1, :(N+1)//2, 0] != 0 + nonzero_seed = np.concatenate([nonzero_seed, np.flip(nonzero_seed[:-1])]) + if not np.equal(nonzero_seed, ERI[N-1, N-1, :, 0]).all(): + print('# -------------------------------------------------------------- #') + print('# WARNING: Experimental symmetry pattern sample is inconsistent. #') + print('# -------------------------------------------------------------- #') # print('test:') # for k in range(N): @@ -191,6 +200,9 @@ def prepare_integrals_2_inputs(mol): if ok: screened_indices_s8_4d.append((d, c, b, a)) + print('n_bas', n_bas) + print('ao_loc', ao_loc) + # Fill input_ijkl and output_sizes with the necessary indices. c = 0 for i in range(n_bas):