Skip to content

Commit

Permalink
cleaned code; improved index generation loops; added flag to use thre…
Browse files Browse the repository at this point in the history
…shold
  • Loading branch information
mihaipgc committed Sep 28, 2023
1 parent 05ced9c commit 209eb33
Showing 1 changed file with 84 additions and 205 deletions.
289 changes: 84 additions & 205 deletions pyscf_ipu/nanoDFT/compute_eri_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,240 +133,119 @@ def prepare_integrals_2_inputs(mol):
# Compute how many distinct integrals after 8x symmetry.
num_calls = 0
for i in range(n_bas):
for j in range(n_bas):
for k in range(n_bas):
for l in range(n_bas):
# * 8-fold symmetry, k>=l, k>=i>=j,
if not ( i >= j and k >= l and i*j >= k * l): continue
num_calls+=1
for j in range(i+1):
for k in range(i, n_bas):
for l in range(k+1):
# almost all 8-fold symmetry rules (except horizontal fade)
num_calls+=1
print('num_calls', num_calls)

# Input/outputs for calling the IPU vertex.
input_ijkl = np.zeros((num_calls, 4), dtype=np.int32)
cpu_output = np.zeros((num_calls, n_eri), dtype=np.float32)
output_sizes = np.zeros((num_calls, 5))


USE_TOLERANCE_THRESHOLD = True
screened_indices_s8_4d = []
tolerance = 1e-9

# sample symmetry pattern and do safety check
ERI = mol.intor('int2e_sph')
# if N % 2 == 0:
# nonzero_seed = ERI[N-1, N-1, :N//2, 0] != 0
# nonzero_seed = np.concatenate([nonzero_seed, np.flip(nonzero_seed)])
# else:
# nonzero_seed = ERI[N-1, N-1, :(N+1)//2, 0] != 0
# nonzero_seed = np.concatenate([nonzero_seed, np.flip(nonzero_seed[:-1])])
# if not np.equal(nonzero_seed, ERI[N-1, N-1, :, 0]!=0).all():
# print('# -------------------------------------------------------------- #')
# print('# WARNING: Experimental symmetry pattern sample is inconsistent. #')
# print('pred', nonzero_seed)
# print('real', ERI[N-1, N-1, :, 0]!=0)
# print('# -------------------------------------------------------------- #')

if USE_TOLERANCE_THRESHOLD:
tolerance = 1e-9
ERI = mol.intor('int2e_sph')

# sample symmetry pattern and do safety check
# if N % 2 == 0:
# nonzero_seed = ERI[N-1, N-1, :N//2, 0] != 0
# nonzero_seed = np.concatenate([nonzero_seed, np.flip(nonzero_seed)])
# else:
# nonzero_seed = ERI[N-1, N-1, :(N+1)//2, 0] != 0
# nonzero_seed = np.concatenate([nonzero_seed, np.flip(nonzero_seed[:-1])])
# if not np.equal(nonzero_seed, ERI[N-1, N-1, :, 0]!=0).all():
# print('# -------------------------------------------------------------- #')
# print('# WARNING: Experimental symmetry pattern sample is inconsistent. #')
# print('pred', nonzero_seed)
# print('real', ERI[N-1, N-1, :, 0]!=0)
# print('# -------------------------------------------------------------- #')

# hardcoded symmetry pattern
sym_pattern = np.array([(i+3)%5!=0 for i in range(N)])
nonzero_seed = sym_pattern

# print('test:')
# for k in range(N):
# for l in range(k+1):
# is_nonzero = ~(nonzero_seed[k] ^ nonzero_seed[l]) # not XOR
# print(is_nonzero, end=' ')
# print()
# exit()

# find max value
I_max = 0
tril_idx = np.tril_indices(N)
for a, b in zip(tril_idx[0], tril_idx[1]):
abab = np.abs(ERI[a,b,a,b])
if abab > I_max:
I_max = abab
if USE_TOLERANCE_THRESHOLD:
# find max value
I_max = 0
tril_idx = np.tril_indices(N)
for a, b in zip(tril_idx[0], tril_idx[1]):
abab = np.abs(ERI[a,b,a,b])
if abab > I_max:
I_max = abab

# collect candidate pairs for s8
considered_indices = []
tril_idx = np.tril_indices(N)
for a, b in zip(tril_idx[0], tril_idx[1]):
abab = np.abs(ERI[a,b,a,b])
if abab*I_max>=tolerance**2:
if USE_TOLERANCE_THRESHOLD:
abab = np.abs(ERI[a,b,a,b])
if abab*I_max>=tolerance**2:
considered_indices.append((a, b)) # collect candidate pairs for s8
else:
considered_indices.append((a, b)) # collect candidate pairs for s8
# considered_indices = set(considered_indices)

# generate s8 indices
# for index, ab in enumerate(considered_indices):
# a, b = ab
# for cd in considered_indices[index:]:
# c, d = cd
# # if b<=d:
# ok = True
# if ~(nonzero_seed[d] ^ nonzero_seed[c]):
# ok = ~(nonzero_seed[b] ^ nonzero_seed[a])
# else:
# ok = (nonzero_seed[b] ^ nonzero_seed[a])
# if ok:
# screened_indices_s8_4d.append((d, c, b, a))
considered_indices = set(considered_indices)

print('n_bas', n_bas)
print('ao_loc', ao_loc)

# ---------------------------------------------------------------------- #
# OLD CODE !!!
if True:
# generate s8 indices
for index, ab in enumerate(considered_indices):
a, b = ab
for cd in considered_indices[index:]:
c, d = cd
# if b<=d:
ok = True
if ~(nonzero_seed[d] ^ nonzero_seed[c]):
ok = ~(nonzero_seed[b] ^ nonzero_seed[a])
else:
ok = (nonzero_seed[b] ^ nonzero_seed[a])
if ok:
screened_indices_s8_4d.append((d, c, b, a))

# Fill input_ijkl and output_sizes with the necessary indices.
c = 0
for i in range(n_bas):
for j in range(n_bas):
for k in range(n_bas):
for l in range(n_bas):
di = ao_loc[i+1] - ao_loc[i]
dj = ao_loc[j+1] - ao_loc[j]
dk = ao_loc[k+1] - ao_loc[k]
dl = ao_loc[l+1] - ao_loc[l]
skip = True # !!!!
for ni, nj, nk, nl in screened_indices_s8_4d:
if ao_loc[i] <= ni and ni < ao_loc[i+1] and \
ao_loc[j] <= nj and nj < ao_loc[j+1] and \
ao_loc[k] <= nk and nk < ao_loc[k+1] and \
ao_loc[l] <= nl and nl < ao_loc[l+1]:
skip = False
break
if skip:
print('skipping', i, j, k, l, di*dj*dk*dl)
continue
# * 8-fold symmetry, k>=l, k>=i>=j,
# if not ( i >= j and k >= l and i*j >= k * l): continue



# print('>>>>>', i, j, k, l)



input_ijkl[c] = [i, j, k, l]

output_sizes[c] = [di, dj, dk, dl, di*dj*dk*dl]

c += 1
print('!!! saved', num_calls - c, 'calls i.e.', num_calls, '-', c)
print('vs')
# ---------------------------------------------------------------------- #

# Fill input_ijkl and output_sizes with the necessary indices.
n_items = 0
n_all_integrals = 0
n_new_integrals = 0
for i in range(n_bas):
for j in range(n_bas):
for k in range(n_bas):
for l in range(n_bas):
di = ao_loc[i+1] - ao_loc[i]
dj = ao_loc[j+1] - ao_loc[j]
dk = ao_loc[k+1] - ao_loc[k]
dl = ao_loc[l+1] - ao_loc[l]

ia, ib = ao_loc[i], ao_loc[i+1]
ja, jb = ao_loc[j], ao_loc[j+1]
ka, kb = ao_loc[k], ao_loc[k+1]
la, lb = ao_loc[l], ao_loc[l+1]

n_all_integrals += di*dj*dk*dl

found_nonzero = False
# check i,j boxes
for bi in range(ia, ib):
for bj in range(ja, jb):
if (bi, bj) in considered_indices: # if ij box is considered
# check if kl pairs are considered
for bk in range(ka, kb):
# mla = la
if bk>=bi: # apply symmetry - tril fade vertical
mla = la
if bk == bi:
mla = max(bj, la)
for bl in range(mla, lb):
if (bk, bl) in considered_indices:
# apply grid pattern to find final nonzeros
if ~(nonzero_seed[bi] ^ nonzero_seed[bj]) ^ (nonzero_seed[bk] ^ nonzero_seed[bl]):
found_nonzero = True
break
if found_nonzero: break
if found_nonzero: break
if found_nonzero: break
if not found_nonzero: continue




# skip = True # !!!!
# for ni, nj, nk, nl in screened_indices_s8_4d:
# if ao_loc[i] <= ni and ni < ao_loc[i+1] and \
# ao_loc[j] <= nj and nj < ao_loc[j+1] and \
# ao_loc[k] <= nk and nk < ao_loc[k+1] and \
# ao_loc[l] <= nl and nl < ao_loc[l+1]:
# skip = False
# break

# print('nonzero_seed[ao_loc[i]:ao_loc[i+1]]', nonzero_seed[ao_loc[i]:ao_loc[i+1]])
# print('nonzero_seed[ao_loc[j]:ao_loc[j+1]]', nonzero_seed[ao_loc[j]:ao_loc[j+1]])
# print('~(nonzero_seed[ao_loc[i]:ao_loc[i+1]] ^ nonzero_seed[ao_loc[j]:ao_loc[j+1]])', ~(nonzero_seed[ao_loc[i]:ao_loc[i+1]] ^ nonzero_seed[ao_loc[j]:ao_loc[j+1]]))
# print('(nonzero_seed[ao_loc[k]:ao_loc[k+1]] ^ nonzero_seed[ao_loc[l]:ao_loc[l+1]])', (nonzero_seed[ao_loc[k]:ao_loc[k+1]] ^ nonzero_seed[ao_loc[l]:ao_loc[l+1]]))
# print('>', nonzero_seed[ao_loc[i]:ao_loc[i+1]] ^ nonzero_seed[ao_loc[j]:ao_loc[j+1]])

# if not ((nonzero_seed[ao_loc[i]:ao_loc[i+1]] ^ nonzero_seed[ao_loc[j]:ao_loc[j+1]]) ^ \
# ~(nonzero_seed[ao_loc[k]:ao_loc[k+1]] ^ nonzero_seed[ao_loc[l]:ao_loc[l+1]])).any():
# print('skip ^^')
# continue


# found = False # !!!!
# for index, ab in enumerate(considered_indices):
# a, b = ab
# if a < ao_loc[i] or a >= ao_loc[i+1] or \
# b < ao_loc[j] or b >= ao_loc[j+1]:
# continue

# for cd in considered_indices:
# c, d = cd
# if c < ao_loc[k] or c >= ao_loc[k+1] or \
# d < ao_loc[l] or d >= ao_loc[l+1]:
# continue
# # found = True
# # break
# if ~(nonzero_seed[b] ^ nonzero_seed[a]) ^ (nonzero_seed[d] ^ nonzero_seed[c]):
# found = True
# break
# if found: break
# if not found:
# # print('skipping', i, j, k, l, di*dj*dk*dl)
# # print(considered_indices)
# continue

n_new_integrals += di*dj*dk*dl

input_ijkl[n_items] = [i, j, k, l]

output_sizes[n_items] = [di, dj, dk, dl, di*dj*dk*dl]

n_items += 1
for j in range(i+1):
for k in range(i, n_bas):
for l in range(k+1):
di = ao_loc[i+1] - ao_loc[i]
dj = ao_loc[j+1] - ao_loc[j]
dk = ao_loc[k+1] - ao_loc[k]
dl = ao_loc[l+1] - ao_loc[l]

ia, ib = ao_loc[i], ao_loc[i+1]
ja, jb = ao_loc[j], ao_loc[j+1]
ka, kb = ao_loc[k], ao_loc[k+1]
la, lb = ao_loc[l], ao_loc[l+1]

n_all_integrals += di*dj*dk*dl

found_nonzero = False
# check i,j boxes
for bi in range(ia, ib):
for bj in range(ja, jb):
if (bi, bj) in considered_indices: # if ij box is considered
# check if kl pairs are considered
for bk in range(ka, kb):
if bk>=bi: # apply symmetry - tril fade vertical
mla = la
if bk == bi:
mla = max(bj, la) # apply symmetry - tril fade horizontal
for bl in range(mla, lb):
if (bk, bl) in considered_indices:
# apply grid pattern to find final nonzeros
if ~(nonzero_seed[bi] ^ nonzero_seed[bj]) ^ (nonzero_seed[bk] ^ nonzero_seed[bl]):
found_nonzero = True
break
if found_nonzero: break
if found_nonzero: break
if found_nonzero: break
if not found_nonzero: continue

n_new_integrals += di*dj*dk*dl

input_ijkl[n_items] = [i, j, k, l]

output_sizes[n_items] = [di, dj, dk, dl, di*dj*dk*dl]

n_items += 1
print('!!! saved', num_calls - n_items, 'calls i.e.', num_calls, '-', n_items)
print('!!! saved', n_all_integrals - n_new_integrals, 'integrals i.e.', n_all_integrals, '-', n_new_integrals)

exit()

num_calls = n_items
input_ijkl = input_ijkl[:num_calls, :]
cpu_output = cpu_output[:num_calls, :]
Expand Down

0 comments on commit 209eb33

Please sign in to comment.