Skip to content

Commit

Permalink
temporarily deactivated cycles; preparing to serialise exchange corre…
Browse files Browse the repository at this point in the history
…lation and add batching control
  • Loading branch information
mihaipgc committed Nov 8, 2023
1 parent 40eb71c commit a07eba6
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 23 deletions.
50 changes: 30 additions & 20 deletions pyscf_ipu/nanoDFT/nanoDFT.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def energy(density_matrix, H_core, diff_JK, E_xc, E_nuc, _np=jax.numpy):
def nanoDFT_iteration(i, vals, opts, mol):
"""Each call updates density_matrix attempting to minimize energy(density_matrix, ... ). """
density_matrix, V_xc, diff_JK, O, H_core, L_inv = vals[:6] # All (N, N) matrices
E_nuc, occupancy, ERI, grid_weights, grid_AO, grid_ijkl, diis_history, log = vals[6:] # Varying types/shapes.
E_nuc, occupancy, ERI, grid_weights, grid_AO, shell_ijkl, diis_history, log = vals[6:] # Varying types/shapes.

if opts.v:
print("---------- MEMORY CONSUMPTION ----------")
Expand Down Expand Up @@ -73,7 +73,7 @@ def nanoDFT_iteration(i, vals, opts, mol):
# Step 3: Use result from eigenproblem to update density_matrix.
density_matrix = (eigvects*occupancy*2) @ eigvects.T # (N, N)
E_xc, V_xc = exchange_correlation(density_matrix, grid_AO, grid_weights) # float (N, N)
diff_JK = get_JK(grid_ijkl, density_matrix, ERI, opts.dense_ERI, opts.screen_tol, opts.backend, mol, opts.ndevices) # (N, N) (N, N)
diff_JK = get_JK(shell_ijkl, density_matrix, ERI, opts.dense_ERI, opts.screen_tol, opts.backend, mol, opts.ndevices) # (N, N) (N, N)

# Log SCF matrices and energies (not used by DFT algorithm).
#log["matrices"] = log["matrices"].at[i].set(jnp.stack((density_matrix, J, K, H))) # (iterations, 4, N, N)
Expand All @@ -91,7 +91,7 @@ def nanoDFT_iteration(i, vals, opts, mol):

def host_callback(data, i):
# labels are adjusted to the `data` that will be passed to the callback - keep that in mind when passing different list of tensors
labels = ["density_matrix", "V_xc", "diff_JK", "O", "H_core", "L_inv", "E_nuc", "occupancy", "ERI", "grid_weights", "grid_AO", "grid_ijkl", "diis_history", "E_xc", "eigvects", "H"]
labels = ["density_matrix", "V_xc", "diff_JK", "O", "H_core", "L_inv", "E_nuc", "occupancy", "ERI", "grid_weights", "grid_AO", "shell_ijkl", "diis_history", "E_xc", "eigvects", "H"]
for l, d in zip(labels, data):
if l == "diis_history" or l == "ERI":
for idx, arr in enumerate(d):
Expand All @@ -102,7 +102,7 @@ def host_callback(data, i):
jax.debug.callback(host_callback, vals[:-1] + [E_xc, eigvects, H], i)

return [density_matrix, V_xc, diff_JK, O, H_core, L_inv, E_nuc, occupancy, ERI,
grid_weights, grid_AO, grid_ijkl, diis_history, log]
grid_weights, grid_AO, shell_ijkl, diis_history, log]


def exchange_correlation(density_matrix, grid_AO, grid_weights):
Expand All @@ -125,7 +125,7 @@ def exchange_correlation(density_matrix, grid_AO, grid_weights):
V_xc = V_xc + V_xc.T # (N, N)
return E_xc, V_xc # (float) (N, N)

def get_JK(grid_ijkl, density_matrix, ERI, dense_ERI, tolerance, backend, mol, ndevices):
def get_JK(shell_ijkl, density_matrix, ERI, dense_ERI, tolerance, backend, mol, ndevices):
"""Computes the (N, N) matrices J and K. Density matrix is (N, N) and ERI is (N, N, N, N). """
N = density_matrix.shape[0]

Expand All @@ -137,14 +137,14 @@ def get_JK(grid_ijkl, density_matrix, ERI, dense_ERI, tolerance, backend, mol, n
#from pyscf_ipu.nanoDFT.sparse_symmetric_ERI import sparse_symmetric_einsum
#diff_JK = sparse_symmetric_einsum(ERI[0], ERI[1], density_matrix, backend)

diff_JK = compute_diff_jk(grid_ijkl, density_matrix, mol, 1, tolerance, ndevices=ndevices, backend="ipu")
diff_JK, cycles = compute_diff_jk(shell_ijkl, density_matrix, mol, 32, tolerance, ndevices=ndevices, backend="ipu")

diff_JK = diff_JK.reshape(N, N)

return diff_JK

def _nanoDFT(state, ERI, grid_AO, grid_weights, grid_ijkl, opts, mol):
if opts.backend == "ipu": grid_weights, start = utils.get_ipu_cycles(grid_weights)
def _nanoDFT(state, ERI, grid_AO, grid_weights, shell_ijkl, opts, mol):
# if opts.backend == "ipu": grid_weights, start = utils.get_ipu_cycles(grid_weights)

# Utilize the IPUs MIMD parallism to compute the electron repulsion integrals (ERIs) in parallel.
#if opts.backend == "ipu": state.ERI = electron_repulsion_integrals(state.input_floats, state.input_ints, mol, opts.threads_int, opts.intv)
Expand All @@ -153,7 +153,7 @@ def _nanoDFT(state, ERI, grid_AO, grid_weights, grid_ijkl, opts, mol):

# Precompute the remaining tensors.
E_xc, V_xc = exchange_correlation(state.density_matrix, grid_AO, grid_weights) # float (N, N)
diff_JK = get_JK(grid_ijkl, state.density_matrix, ERI, opts.dense_ERI, opts.screen_tol, opts.backend, mol, opts.ndevices) # (N, N) (N, N)
diff_JK = get_JK(shell_ijkl, state.density_matrix, ERI, opts.dense_ERI, opts.screen_tol, opts.backend, mol, opts.ndevices) # (N, N) (N, N)
H_core = state.kinetic + state.nuclear # (N, N)

# Log matrices from all DFT iterations (not used by DFT algorithm).
Expand All @@ -162,14 +162,14 @@ def _nanoDFT(state, ERI, grid_AO, grid_weights, grid_ijkl, opts, mol):

# Perform DFT iterations.
log = jax.lax.fori_loop(0, opts.its, partial(nanoDFT_iteration, opts=opts, mol=mol), [state.density_matrix, V_xc, diff_JK, state.O, H_core, state.L_inv, # all (N, N) matrices
state.E_nuc, state.mask, ERI, grid_weights, grid_AO, grid_ijkl, state.diis_history, log])[-1]
state.E_nuc, state.mask, ERI, grid_weights, grid_AO, shell_ijkl, state.diis_history, log])[-1]

cycles = -1
if opts.backend == "ipu":
log["energy"], stop = utils.get_ipu_cycles(log["energy"])
cycles = (stop.array-start.array)[0,0]
# cycles = -1
# if opts.backend == "ipu":
# log["energy"], stop = utils.get_ipu_cycles(log["energy"])
# cycles = (stop.array-start.array)[0,0]

return log["matrices"], H_core, log["energy"], cycles
return log["matrices"], H_core, log["energy"], 0 #cycles


FloatN = Float[Array, "N"]
Expand Down Expand Up @@ -227,16 +227,21 @@ def init_dft_tensors_cpu(mol, opts, DIIS_iters=9):
grids = pyscf.dft.gen_grid.Grids(mol)
grids.level = opts.level
grids.build()

grid_weights = grids.weights # (grid_size,) = (45624,) for C6H6
coord_str = 'GTOval_cart_deriv1' if mol.cart else 'GTOval_sph_deriv1'
grid_AO = mol.eval_gto(coord_str, grids.coords, 4) # (4, grid_size, N) = (4, 45624, 9) for C6H6.

if opts.ao_threshold > 0.0:
grid_AO[np.abs(grid_AO)<opts.ao_threshold] = 0
sparsity_mask = np.where(np.all(grid_AO == 0, axis=0), 0, 1)
sparse_rows = np.where(np.all(sparsity_mask == 0, axis=1), 0, 1).reshape(-1, 1)
print('grid_AO.shape', grid_AO.shape)
print('sparsity_mask.shape', sparsity_mask.shape)
print(f"axis=( , ) sparsity in grid_AO: {np.sum(grid_AO==0) / grid_AO.size:.4f}")
print(f"axis=(0, ) sparsity in grid_AO: {np.sum(sparsity_mask==0) / sparsity_mask.size:.4f}")
print(f"axis=(0, 2) sparsity in grid_AO: {np.sum(sparse_rows==0) / sparse_rows.size:.4f}")
print(f"axis=( , ) sparsity in grid_weights: {np.sum(grid_weights==0) / grid_weights.size:.4f}")
grid_AO = jnp.delete(grid_AO, jnp.where(sparse_rows == 0)[0], axis=1)
grid_weights = jnp.delete(grid_weights, jnp.where(sparse_rows == 0)[0], axis=0)
grid_coords = jnp.delete(grids.coords, jnp.where(sparse_rows == 0)[0], axis=0)
Expand Down Expand Up @@ -276,10 +281,15 @@ def nanoDFT(mol, opts):
grid_weights = _grid_weights
gsize = grid_AO.shape[0]

remainder = gsize % opts.ndevices
grid_batches = 2

num_grid_shards = (opts.ndevices * grid_batches)

remainder = gsize % num_grid_shards
if remainder != 0:
grid_AO = jnp.pad(grid_AO, ((0,remainder), (0,0), (0,0)) )
grid_weights = jnp.pad(grid_weights, ((0,remainder)) )
grid_AO = jnp.pad(grid_AO, ((0,num_grid_shards-remainder), (0,0), (0,0)) )
grid_weights = jnp.pad(grid_weights, ((0,num_grid_shards-remainder)) )
grid_coords = np.pad(grid_coords, ((0,num_grid_shards-remainder), (0,0)))
grid_AO = grid_AO.reshape(opts.ndevices, -1, 4, N)
grid_weights = grid_weights.reshape(opts.ndevices, -1)

Expand Down Expand Up @@ -332,14 +342,14 @@ def nanoDFT(mol, opts):
eri_in_axes = [0,0]

input_ijkl, _, _, _ = gen_shells(mol, opts.screen_tol, nipu, fast_shells=opts.fast_shells)
grid_ijkl = np.concatenate([np.array(ijkl, dtype=int).reshape(nipu, -1) for ijkl in input_ijkl], axis=-1)
shell_ijkl = np.concatenate([np.array(ijkl, dtype=int).reshape(nipu, -1) for ijkl in input_ijkl], axis=-1)

#jitted_nanoDFT = jax.jit(partial(_nanoDFT, opts=opts, mol=mol), backend=opts.backend)
jitted_nanoDFT = jax.pmap(partial(_nanoDFT, opts=opts, mol=mol), backend=opts.backend,
in_axes=(None, eri_in_axes, 0, 0, 0),
axis_name="p")
print(grid_AO.shape, grid_weights.shape)
vals = jitted_nanoDFT(state, ERI, grid_AO, grid_weights, grid_ijkl)
vals = jitted_nanoDFT(state, ERI, grid_AO, grid_weights, shell_ijkl)
logged_matrices, H_core, logged_energies, cycles = [np.asarray(a[0]).astype(np.float64) for a in vals] # Ensure CPU
if opts.backend == "ipu": print("Cycle Count: ", cycles/10**6, "[M]")

Expand Down
22 changes: 19 additions & 3 deletions pyscf_ipu/nanoDFT/sparse_symmetric_intor_ERI.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
from functools import partial, lru_cache
from icecream import ic
from tqdm import tqdm
import utils
from utils import process_mol_str

jax.config.update('jax_platform_name', "cpu")
#jax.config.update('jax_enable_x64', True)
Expand Down Expand Up @@ -273,6 +275,8 @@ def get_shapes(input_ijkl, bas):
)

def compute_diff_jk(input_ijkl_slice, dm, mol, nbatch, tolerance, ndevices, backend):
# if backend == "ipu": dm, start = utils.get_ipu_cycles(dm)

dm = dm.reshape(-1)
diff_JK = jnp.zeros(dm.shape)
N = int(np.sqrt(dm.shape[0]))
Expand Down Expand Up @@ -430,7 +434,17 @@ def foreach_symmetry(sym, vals):
total_diff_JK += diff_JK

# return total_diff_JK
return jax.lax.psum(total_diff_JK, axis_name="p")

jk_psum = jax.lax.psum(total_diff_JK, axis_name="p")

# cycles = -1
# if backend == "ipu":
# jk_psum, stop = utils.get_ipu_cycles(jk_psum)
# cycles = (stop.array-start.array)[0,0]

# return jk_psum, cycles

return jk_psum, 0

if __name__ == "__main__":
import time
Expand All @@ -456,7 +470,8 @@ def foreach_symmetry(sym, vals):

start = time.time()

mol = pyscf.gto.Mole(atom="".join(f"C 0 {1.54*j} {1.54*i};" for i in range(natm) for j in range(natm)), basis=args.basis)
# mol = pyscf.gto.Mole(atom="".join(f"C 0 {1.54*j} {1.54*i};" for i in range(natm) for j in range(natm)), basis=args.basis)
mol = pyscf.gto.Mole(atom=process_mol_str('bench10'), basis=args.basis)
#mol = pyscf.gto.Mole(atom="".join(f"C 0 {1.54*j} {1.54*i};" for i in range(natm) for j in range(natm))) # sto-3g by default
# mol = pyscf.gto.Mole(atom="".join(f"C 0 {1.54*j} {1.54*i};" for i in range(1) for j in range(2)), basis="sto3g")
#mol = pyscf.gto.Mole(atom="".join(f"C 0 {1.54*j} {1.54*i};" for i in range(natm) for j in range(natm)), basis="sto3g")
Expand Down Expand Up @@ -493,7 +508,8 @@ def foreach_symmetry(sym, vals):
input_ijkl, sizes, counts, shapes = gen_shells(mol, args.itol, args.nipu, fast_shells=args.fast_shells)
sliced_input_ijkl = np.concatenate([np.array(ijkl, dtype=int).reshape(args.nipu, -1) for ijkl in input_ijkl], axis=-1)

diff_JK = jax.pmap(compute_diff_jk, in_axes=(0, None, None, None, None, None, None), static_broadcasted_argnums=(2, 3, 4, 5, 6), backend=backend, axis_name="p")(sliced_input_ijkl, dm, mol, args.batches, args.itol, args.nipu, args.backend)
diff_JK, cycles = jax.pmap(compute_diff_jk, in_axes=(0, None, None, None, None, None, None), static_broadcasted_argnums=(2, 3, 4, 5, 6), backend=backend, axis_name="p")(sliced_input_ijkl, dm, mol, args.batches, args.itol, args.nipu, args.backend)
if backend == "ipu": print("Cycle Count: ", cycles/10**6, "[M]")

# ------------------------------------ #

Expand Down

0 comments on commit a07eba6

Please sign in to comment.