Skip to content

Commit

Permalink
isolated prescreening and orthogonal zero removal with dynamic gathering to estimated final shape; implemented batch mode for ortho; work in progress
Browse files Browse the repository at this point in the history
  • Loading branch information
mihaipgc committed Oct 3, 2023
1 parent 9681f73 commit f0102cd
Show file tree
Hide file tree
Showing 2 changed files with 152 additions and 3 deletions.
101 changes: 100 additions & 1 deletion pyscf_ipu/nanoDFT/compute_eri_utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,22 @@
# Copyright (c) 2023 Graphcore Ltd. All rights reserved.
import numpy as np
import numpy as np
import jax
import jax.numpy as jnp

def reconstruct_ERI(ERI, nonzero_idx, N, sym=True):
    """Rebuild a dense ``(N, N, N, N)`` tensor from selected entries of ``ERI``.

    Every row ``(i, j, k, l)`` of ``nonzero_idx`` names one entry of ``ERI``
    that is copied into an otherwise all-zero tensor.  With ``sym=True`` the
    seven further index orderings obtained by swapping ``i <-> j``,
    ``k <-> l`` and ``(i, j) <-> (k, l)`` are copied as well.

    Args:
        ERI: Array that can be indexed as ``ERI[i, j, k, l]`` with integer
            index arrays.
        nonzero_idx: Integer array-like of shape ``(n, 4)``; one index
            quadruple per row.
        N: Edge length of the reconstructed tensor.
        sym: When true, also copy the seven permuted positions of each row.

    Returns:
        A new ``np.ndarray`` of shape ``(N, N, N, N)`` (NumPy's default float
        dtype) that is zero everywhere except at the copied positions.
    """
    nonzero_idx = np.asarray(nonzero_idx)
    i, j, k, l = (nonzero_idx[:, col] for col in range(4))

    orderings = [(i, j, k, l)]
    if sym:
        orderings += [
            (j, i, k, l),
            (i, j, l, k),
            (j, i, l, k),
            (k, l, i, j),
            (k, l, j, i),
            (l, k, i, j),
            (l, k, j, i),
        ]

    rec_ERI = np.zeros((N, N, N, N))
    for order in orderings:
        # A tuple of four index arrays addresses the same elements as
        # ``x[a, b, c, d]``, so source and destination stay aligned.
        rec_ERI[order] = ERI[order]

    return rec_ERI

def inverse_permutation(a):
b = np.arange(a.shape[0])
Expand Down Expand Up @@ -106,6 +123,88 @@ def get_shapes(input_ijkl, bas):

return len, nf, buflen

def prescreening(mol, itol, dtype=jnp.int16):
    """Return the index quadruples that survive a pair-magnitude screening.

    The function requests the ``'int2e_sph'`` integrals from ``mol`` with
    ``aosym='s8'`` and, for every pair ``a >= b``, reads the packed entry at
    position ``ab * (ab + 3) // 2`` where ``ab = a * (a + 1) // 2 + b``
    (presumably the ``(ab|ab)`` element of the packed layout -- confirm
    against the integral library).  A pair is kept when

        ``|value(ab)| * max_ab |value(ab)| >= itol ** 2``.

    All combinations ``(a, b, c, d)`` of a kept pair ``(a, b)`` with a kept
    pair ``(c, d)`` that comes at the same or a later position are returned,
    in the order of the kept pairs.

    Args:
        mol: Object providing ``nao_nr()`` and
            ``intor('int2e_sph', aosym='s8')``.
        itol: Screening tolerance; it is squared before the comparison.
        dtype: Integer type of the returned table; one of ``jnp.int16``,
            ``jnp.int32`` or ``jnp.uint32``.

    Returns:
        ``np.ndarray`` of shape ``(M * (M + 1) // 2, 4)`` and type ``dtype``,
        where ``M`` is the number of kept pairs.
    """
    assert dtype in [jnp.int16, jnp.int32, jnp.uint32]

    N = mol.nao_nr()
    tolerance = itol

    print('computing ERI s8 to sample N*(N+1)/2 values... ', end='')
    ERI_s8 = mol.intor('int2e_sph', aosym='s8')
    print('done')

    # Sample one packed entry per pair (a, b) with a >= b in a single pass.
    a, b = np.tril_indices(N)
    index_ab_s8 = a * (a + 1) // 2 + b
    index_s8 = index_ab_s8 * (index_ab_s8 + 3) // 2
    abab = np.abs(np.asarray(ERI_s8)[index_s8])

    # Largest sampled magnitude; 0 when there are no pairs at all.
    I_max = abab.max() if abab.size else 0

    # Candidate pairs, kept in their original lower-triangular order.
    keep = abab * I_max >= tolerance ** 2
    considered_indices = np.stack([a[keep], b[keep]], axis=1)

    # Row-major upper-triangular enumeration reproduces "each pair with
    # itself and every later pair" without a Python double loop.
    first, second = np.triu_indices(considered_indices.shape[0])
    screened_indices_s8_4d = np.concatenate(
        [considered_indices[first], considered_indices[second]], axis=1
    )

    return screened_indices_s8_4d.astype(dtype)

def remove_ortho(arr, nonzero_pattern, output_size, dtype=jnp.int16):
    """Compact the rows of ``arr`` that pass a parity test on ``nonzero_pattern``.

    Rows are visited in order with ``jax.lax.scan``.  A row ``(i, j, k, l)``
    is kept when

        ``~(p[i] ^ p[j]) ^ (p[k] ^ p[l])``

    is true, with ``p = nonzero_pattern``; for a boolean pattern that is
    exactly when ``p[i] ^ p[j]`` equals ``p[k] ^ p[l]``.  Kept rows are written
    one after another into a zero-initialised buffer, so unused trailing rows
    of the result stay zero.

    Args:
        arr: Integer array of shape ``(n, 4)`` and type ``dtype``.
        nonzero_pattern: One-dimensional mask indexed by the entries of
            ``arr`` (a boolean mask in the visible caller).
        output_size: Number of rows of the output buffer.  It fixes an array
            shape, so it has to be a static Python integer.
            NOTE(review): what happens when more than ``output_size`` rows
            pass the test depends on JAX's out-of-bounds update semantics --
            confirm before relying on it.
        dtype: Integer type of ``arr`` and of the result; one of
            ``jnp.int16``, ``jnp.int32`` or ``jnp.uint32``.

    Returns:
        Array of shape ``(output_size, arr.shape[1])`` and type ``dtype``.
    """
    assert dtype in [jnp.int16, jnp.int32, jnp.uint32]
    # 16-bit data travels through the scan reinterpreted as float16 (same
    # bits).  NOTE(review): presumably a backend workaround -- confirm.
    reinterpret_dtype = jnp.float16 if dtype == jnp.int16 else None

    # Lets a NumPy mask be indexed with traced indices outside of jit as well.
    nonzero_pattern = jnp.asarray(nonzero_pattern)

    def condition(i, j, k, l):
        return ~(nonzero_pattern[i] ^ nonzero_pattern[j]) ^ (nonzero_pattern[k] ^ nonzero_pattern[l])

    def keep_row(res, count, row):
        # Store the row in the next free slot and advance the write cursor.
        return res.at[count].set(row), count + 1

    def skip_row(res, count, row):
        return res, count

    def body_fun(carry, x):
        results, counter = carry
        # Recover the integer indices from the (possibly reinterpreted) row.
        i, j, k, l = jax.lax.bitcast_convert_type(x, dtype).astype(jnp.uint32)
        # Current ``cond`` signature: (pred, true_fun, false_fun, *operands).
        carry = jax.lax.cond(condition(i, j, k, l), keep_row, skip_row, results, counter, x)
        return carry, ()

    init_results = jnp.zeros((output_size, arr.shape[1]), dtype=dtype)
    init_count = jnp.array(0, dtype=jnp.int32)

    if reinterpret_dtype is not None:
        init_results = jax.lax.bitcast_convert_type(init_results, reinterpret_dtype)
        arr = jax.lax.bitcast_convert_type(arr, reinterpret_dtype)

    (final_results, _), _ = jax.lax.scan(body_fun, (init_results, init_count), arr)

    # Undo the reinterpretation (a no-op when none was applied).
    final_results = jax.lax.bitcast_convert_type(final_results, dtype)

    return final_results


# Batched variant: maps over a leading batch axis of ``arr`` while sharing
# ``nonzero_pattern`` and ``output_size`` across the batch.
vmap_remove_ortho = jax.vmap(remove_ortho, in_axes=(0, None, None), out_axes=0)

def prepare_integrals_2_inputs(mol, itol):
# Shapes/sizes.
atm, bas, env = mol._atm, mol._bas, mol._env
Expand Down
54 changes: 52 additions & 2 deletions pyscf_ipu/nanoDFT/sparse_symmetric_intor_ERI.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from tessellate_ipu import create_ipu_tile_primitive, ipu_cycle_count, tile_map, tile_put_sharded, tile_put_replicated
from functools import partial
from icecream import ic
from compute_eri_utils import prepare_integrals_2_inputs
from compute_eri_utils import prepare_integrals_2_inputs, prescreening, remove_ortho, vmap_remove_ortho, reconstruct_ERI
jax.config.update('jax_platform_name', "cpu")
#jax.config.update('jax_enable_x64', True)
HYB_B3LYP = 0.2
Expand Down Expand Up @@ -458,12 +458,62 @@ def ijkl2c(i, j, k, l):
dm = dm.reshape(-1)
diff_JK = np.zeros(dm.shape)

# --------------------------------------------------------------------------------------- #
# Prescreening to compute relevant indices
comp_prescreened_distinct_idx = jax.jit(prescreening, backend=backend, static_argnames=['mol', 'itol', 'dtype'])(mol, args.itol)

if not args.skip:
print('ERI stats after prescreening:')
rec_ERI = reconstruct_ERI(dense_ERI, np.array(comp_prescreened_distinct_idx, dtype=np.uint32), N)
absdiff = np.abs(dense_ERI-rec_ERI)
print('avg error:', np.mean(absdiff))
print('std error:', np.std(absdiff))
print('max error:', np.max(absdiff))

# --------------------------------------------------------------------------------------- #
# Filter to remove orthogonal zeros
nonzero_pattern = np.array([(i+3)%5!=0 for i in range(N)]).astype(bool)
num_nonzeros = int(2*len(comp_prescreened_distinct_idx)*(np.count_nonzero(nonzero_pattern.reshape(N, 1) ^ nonzero_pattern.reshape(1, N))/(N*N)) + 1)
print('Estimating', num_nonzeros, 'nonzeros out of', len(comp_prescreened_distinct_idx), 'prescreened values')
print('actual nonzero_indices', len(nonzero_indices))
print('comp_prescreened_distinct_idx', len(comp_prescreened_distinct_idx))

comp_nonzero_distinct_idx = jax.jit(remove_ortho, backend=backend, static_argnames=['output_size', 'dtype'])(comp_prescreened_distinct_idx, nonzero_pattern, num_nonzeros)

if not args.skip:
print('ERI stats after orthogonal zero removal:')
rec_ERI = reconstruct_ERI(dense_ERI, np.array(comp_nonzero_distinct_idx, dtype=np.uint32), N)
absdiff = np.abs(dense_ERI-rec_ERI)
print('shape', comp_nonzero_distinct_idx.shape)
print('avg error:', np.mean(absdiff))
print('std error:', np.std(absdiff))
print('max error:', np.max(absdiff))

num_batches = 2
if comp_prescreened_distinct_idx.shape[0] % num_batches > 0:
comp_prescreened_distinct_idx = jnp.pad(comp_prescreened_distinct_idx, ((0, num_batches-comp_prescreened_distinct_idx.shape[0] % num_batches), (0, 0)))
vmap_comp_prescreened_distinct_idx = comp_prescreened_distinct_idx.reshape(num_batches, -1, 4)
vmap_num_nonzeros = num_nonzeros // num_batches + 1
comp_nonzero_distinct_idx = jax.jit(vmap_remove_ortho, backend=backend, static_argnames=['output_size', 'dtype'])(vmap_comp_prescreened_distinct_idx, nonzero_pattern, vmap_num_nonzeros)
comp_nonzero_distinct_idx = comp_nonzero_distinct_idx.reshape(-1, 4)

if not args.skip:
print('ERI stats after orthogonal zero removal:')
rec_ERI = reconstruct_ERI(dense_ERI, np.array(comp_nonzero_distinct_idx, dtype=np.uint32), N)
absdiff = np.abs(dense_ERI-rec_ERI)
print('shape', comp_nonzero_distinct_idx.shape)
print('avg error:', np.mean(absdiff))
print('std error:', np.std(absdiff))
print('max error:', np.max(absdiff))

exit()

# --------------------------------------------------------------------------------------- #
# Compute ERI values from mol; returned values are already:
# - scaled to account for repetitions in JK computation
# - padded/reshaped to fit the required number of devices and batches

comp_distinct_ERI, comp_distinct_idx = jax.jit(compute_nonzero_distinct_ERI,backend=backend, static_argnames=['mol', 'nprog', 'nbatch', 'itol', 'backend'])(mol, args.nipu, args.batches, args.itol, args.backend)
comp_distinct_ERI, comp_distinct_idx = jax.jit(compute_nonzero_distinct_ERI, backend=backend, static_argnames=['mol', 'nprog', 'nbatch', 'itol', 'backend'])(mol, args.nipu, args.batches, args.itol, args.backend)
# comp_distinct_ERI, comp_distinct_idx = compute_nonzero_distinct_ERI(mol, args.nipu, args.batches, args.itol, args.backend)

# --------------------------------------------------------------------------------------- #
Expand Down

0 comments on commit f0102cd

Please sign in to comment.