further reduced memory and preparing for benchmarks
mihaipgc committed Nov 2, 2023
1 parent 4e36058 commit 40eb71c
Showing 3 changed files with 83 additions and 41 deletions.
82 changes: 47 additions & 35 deletions pyscf_ipu/nanoDFT/nanoDFT.py
@@ -15,7 +15,7 @@

from pyscf_ipu.nanoDFT.sparse_symmetric_intor_ERI import compute_diff_jk, gen_shells

HARTREE_TO_EV = 27.2114079527
HARTREE_TO_EV = 1 #27.2114079527
EPSILON_B3LYP = 1e-20
HYB_B3LYP = 0.2

@@ -144,6 +144,8 @@ def get_JK(grid_ijkl, density_matrix, ERI, dense_ERI, tolerance, backend, mol, n
return diff_JK

def _nanoDFT(state, ERI, grid_AO, grid_weights, grid_ijkl, opts, mol):
if opts.backend == "ipu": grid_weights, start = utils.get_ipu_cycles(grid_weights)

# Utilize the IPUs MIMD parallism to compute the electron repulsion integrals (ERIs) in parallel.
#if opts.backend == "ipu": state.ERI = electron_repulsion_integrals(state.input_floats, state.input_ints, mol, opts.threads_int, opts.intv)
#else: pass # Compute on CPU.
@@ -162,7 +164,12 @@ def _nanoDFT(state, ERI, grid_AO, grid_weights, grid_ijkl, opts, mol):
log = jax.lax.fori_loop(0, opts.its, partial(nanoDFT_iteration, opts=opts, mol=mol), [state.density_matrix, V_xc, diff_JK, state.O, H_core, state.L_inv, # all (N, N) matrices
state.E_nuc, state.mask, ERI, grid_weights, grid_AO, grid_ijkl, state.diis_history, log])[-1]

return log["matrices"], H_core, log["energy"]
cycles = -1
if opts.backend == "ipu":
log["energy"], stop = utils.get_ipu_cycles(log["energy"])
cycles = (stop.array-start.array)[0,0]

return log["matrices"], H_core, log["energy"], cycles


FloatN = Float[Array, "N"]
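Note on the benchmarking change above: the SCF loop is bracketed with utils.get_ipu_cycles so the elapsed IPU cycle count is returned alongside the results. A minimal sketch of that pattern follows; it assumes, as the diff suggests, that utils.get_ipu_cycles(x) returns (x, counter) with the counter readable via counter.array, and the 1.33 GHz clock used for the seconds estimate is an assumed IPU clock frequency, not something stated in this commit.

def elapsed_ipu_cycles(start, stop, clock_hz=1.33e9):
    # (stop.array - start.array)[0, 0] mirrors the indexing in the diff;
    # the counter arrays are assumed to carry a leading device axis under pmap
    cycles = (stop.array - start.array)[0, 0]
    return cycles, cycles / clock_hz  # raw cycles and a rough wall-clock estimate

# hypothetical usage, mirroring _nanoDFT:
#   grid_weights, start = utils.get_ipu_cycles(grid_weights)    # before the fori_loop
#   ... run the SCF iterations ...
#   log["energy"], stop = utils.get_ipu_cycles(log["energy"])   # after the fori_loop
#   cycles, seconds = elapsed_ipu_cycles(start, stop)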
@@ -276,6 +283,8 @@ def nanoDFT(mol, opts):
grid_AO = grid_AO.reshape(opts.ndevices, -1, 4, N)
grid_weights = grid_weights.reshape(opts.ndevices, -1)

nipu = opts.ndevices

# Run DFT algorithm (can be hardware accelerated).
if opts.dense_ERI:
assert opts.ndevices == 1, "Only support '--dense_ERI True' for `--ndevices 1`. "
@@ -286,38 +295,40 @@
ERI[below_thr] = 0.0
ic(ERI.size, np.sum(below_thr), np.sum(below_thr)/ERI.size)
else:
from pyscf_ipu.nanoDFT.sparse_symmetric_ERI import get_i_j, num_repetitions_fast
distinct_ERI = mol.intor("int2e_sph", aosym="s8")
print(distinct_ERI.size)
below_thr = np.abs(distinct_ERI) <= opts.eri_threshold
distinct_ERI[below_thr] = 0.0
ic(distinct_ERI.size, np.sum(below_thr), np.sum(below_thr)/distinct_ERI.size)
nonzero_indices = np.nonzero(distinct_ERI)[0].astype(np.uint64)
nonzero_distinct_ERI = distinct_ERI[nonzero_indices].astype(np.float32)

ij, kl = get_i_j(nonzero_indices)
rep = num_repetitions_fast(ij, kl)
nonzero_distinct_ERI = nonzero_distinct_ERI / rep
batches = int(opts.batches) # perhaps make 10 batches?
nipu = opts.ndevices
remainder = nonzero_indices.shape[0] % (nipu*batches)

if remainder != 0:
print(nipu*batches-remainder, ij.shape)
ij = np.pad(ij, ((0,nipu*batches-remainder)))
kl = np.pad(kl, ((0,nipu*batches-remainder)))
nonzero_distinct_ERI = np.pad(nonzero_distinct_ERI, (0,nipu*batches-remainder))

ij = ij.reshape(nipu, batches, -1)
kl = kl.reshape(nipu, batches, -1)
nonzero_distinct_ERI = nonzero_distinct_ERI.reshape(nipu, batches, -1)

i, j = get_i_j(ij.reshape(-1))
k, l = get_i_j(kl.reshape(-1))
nonzero_indices = np.vstack([i,j,k,l]).T.reshape(nipu, batches, -1, 4).astype(np.int16)
nonzero_indices = jax.lax.bitcast_convert_type(nonzero_indices, np.float16)

ERI = [nonzero_distinct_ERI, nonzero_indices]
# from pyscf_ipu.nanoDFT.sparse_symmetric_ERI import get_i_j, num_repetitions_fast
# distinct_ERI = mol.intor("int2e_sph", aosym="s8")
# print(distinct_ERI.size)
# below_thr = np.abs(distinct_ERI) <= opts.eri_threshold
# distinct_ERI[below_thr] = 0.0
# ic(distinct_ERI.size, np.sum(below_thr), np.sum(below_thr)/distinct_ERI.size)
# nonzero_indices = np.nonzero(distinct_ERI)[0].astype(np.uint64)
# nonzero_distinct_ERI = distinct_ERI[nonzero_indices].astype(np.float32)

# ij, kl = get_i_j(nonzero_indices)
# rep = num_repetitions_fast(ij, kl)
# nonzero_distinct_ERI = nonzero_distinct_ERI / rep
# batches = int(opts.batches) # perhaps make 10 batches?
# nipu = opts.ndevices
# remainder = nonzero_indices.shape[0] % (nipu*batches)

# if remainder != 0:
# print(nipu*batches-remainder, ij.shape)
# ij = np.pad(ij, ((0,nipu*batches-remainder)))
# kl = np.pad(kl, ((0,nipu*batches-remainder)))
# nonzero_distinct_ERI = np.pad(nonzero_distinct_ERI, (0,nipu*batches-remainder))

# ij = ij.reshape(nipu, batches, -1)
# kl = kl.reshape(nipu, batches, -1)
# nonzero_distinct_ERI = nonzero_distinct_ERI.reshape(nipu, batches, -1)

# i, j = get_i_j(ij.reshape(-1))
# k, l = get_i_j(kl.reshape(-1))
# nonzero_indices = np.vstack([i,j,k,l]).T.reshape(nipu, batches, -1, 4).astype(np.int16)
# nonzero_indices = jax.lax.bitcast_convert_type(nonzero_indices, np.float16)

# ERI = [nonzero_distinct_ERI, nonzero_indices]
# eri_in_axes = [0,0]
ERI = [np.ones((nipu, 1)), np.ones((nipu, 1))]
eri_in_axes = [0,0]

input_ijkl, _, _, _ = gen_shells(mol, opts.screen_tol, nipu, fast_shells=opts.fast_shells)
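For context on the block commented out above: it packed the nonzero distinct ERIs into equal per-device, per-batch chunks before the pmap, zero-padding the flat array so it splits evenly. A standalone sketch of just that padding/reshape step, with made-up sizes, is:

import numpy as np

nipu, batches = 2, 10                                            # hypothetical device / batch counts
nonzero_distinct_ERI = np.random.rand(1234).astype(np.float32)   # stand-in for the real nonzero ERIs

remainder = nonzero_distinct_ERI.shape[0] % (nipu * batches)
if remainder != 0:
    # zero-pad so every device receives the same number of equally sized batches
    nonzero_distinct_ERI = np.pad(nonzero_distinct_ERI, (0, nipu * batches - remainder))
chunks = nonzero_distinct_ERI.reshape(nipu, batches, -1)         # (ndevices, batches, chunk_size)

In this commit that packing is skipped and ERI is reduced to dummy np.ones placeholders, with compute_diff_jk generating the integrals on-device from the shells produced by gen_shells, consistent with the memory reduction in the commit message.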
@@ -329,7 +340,8 @@
axis_name="p")
print(grid_AO.shape, grid_weights.shape)
vals = jitted_nanoDFT(state, ERI, grid_AO, grid_weights, grid_ijkl)
logged_matrices, H_core, logged_energies = [np.asarray(a[0]).astype(np.float64) for a in vals] # Ensure CPU
logged_matrices, H_core, logged_energies, cycles = [np.asarray(a[0]).astype(np.float64) for a in vals] # Ensure CPU
if opts.backend == "ipu": print("Cycle Count: ", cycles/10**6, "[M]")

# It's cheap to compute energy/hlgap on CPU in float64 from the logged values/matrices.
logged_E_xc = logged_energies[:, 3].copy()
12 changes: 6 additions & 6 deletions pyscf_ipu/nanoDFT/sparse_symmetric_intor_ERI.py
@@ -260,7 +260,7 @@ def get_shapes(input_ijkl, bas):

len = leng + lenl + lenk + lenj + leni + len0

return len, nf
return len, nf, nc

vertex_filename = osp.join(osp.dirname(__file__), "intor_int2e_sph_condensed.cpp")
int2e_sph_forloop = create_ipu_tile_primitive(
@@ -289,7 +289,7 @@ def compute_diff_jk(input_ijkl_slice, dm, mol, nbatch, tolerance, ndevices, back
input_floats = env.reshape(1, -1)
input_ints = np.hstack([0, 0, n_atm, n_bas, 0, n_ao_loc, ao_loc.reshape(-1), atm.reshape(-1), bas.reshape(-1)]) # todo 0s not used (n_buf, n_eri, n_env)

input_ijkl, sizes, counts, shapes = gen_shells(mol, tolerance, ndevices)
input_ijkl, sizes, counts, shapes = gen_shells(mol, tolerance, ndevices, fast_shells=True)

all_eris = []
all_indices = []
@@ -316,16 +316,16 @@
slice_count = count // ndevices
slice_ijkl = input_ijkl_slice[slice_indices[i]:slice_indices[i+1]].reshape(-1, 4)

glen, nf = shapes[i]
glen, nf, nc = shapes[i]
chunk_size = num_tiles * num_threads
num_full_batches = slice_count//chunk_size

tiles = tuple((np.arange(num_tiles*num_threads)%(num_tiles)+1).tolist())
tile_g = tile_put_replicated(jnp.empty(min(int(glen), 3888)+1), tiles)
tile_idx = tile_put_replicated(jnp.empty(max(256, min(int(nf*3), 3888)+1), dtype=jnp.int32), tiles)
tile_buf = tile_put_replicated(jnp.empty(1080*4+1), tiles)
tile_idx = tile_put_replicated(jnp.empty(256, dtype=jnp.int32), tiles) # condensed version only # sto3g, 631g
tile_buf = tile_put_replicated(jnp.empty(81), tiles) # condensed version only # sto3g, 631g
integral_size = tile_put_replicated(jnp.array(size, dtype=jnp.uint32), tiles)
tile_gctr_buf = tile_put_replicated(jnp.empty((1296+1), dtype=jnp.float32), tiles) # condensed version only
tile_gctr_buf = tile_put_replicated(jnp.empty((81+1), dtype=jnp.float32), tiles) # condensed version only # sto3g, 631g

def batched_compute(start, stop, chunk_size, tiles):
assert (stop-start) < chunk_size or (stop-start) % chunk_size == 0
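The buffer changes in this hunk shrink the scratch arrays replicated on every tile (the "# condensed version only # sto3g, 631g" comments mark them as specific to the condensed kernel and small basis sets). A rough upper-bound estimate of the replicated per-tile footprint before and after, assuming 4-byte float32/int32 elements and the maximum sizes visible in the code:

# old lines (upper bounds): tile_g <= 3889, tile_idx <= 3889, tile_buf = 4321, tile_gctr_buf = 1297 elements
old_bytes = (3889 + 3889 + 4321 + 1297) * 4
# new lines: tile_g unchanged, tile_idx = 256, tile_buf = 81, tile_gctr_buf = 82 elements
new_bytes = (3889 + 256 + 81 + 82) * 4
print(old_bytes / 1024, new_bytes / 1024)   # roughly 52 KiB -> 17 KiB replicated per tile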
