From c6f98b9079c56953a89c45c4f4af3d5b397764ff Mon Sep 17 00:00:00 2001 From: Mihai Polceanu Date: Wed, 8 Nov 2023 20:58:31 +0000 Subject: [PATCH] implemented exchange correlation batching and eliminated memory spikes --- pyscf_ipu/nanoDFT/nanoDFT.py | 51 +++++++++++++++++++++++++----------- 1 file changed, 35 insertions(+), 16 deletions(-) diff --git a/pyscf_ipu/nanoDFT/nanoDFT.py b/pyscf_ipu/nanoDFT/nanoDFT.py index 1664877..44518ef 100644 --- a/pyscf_ipu/nanoDFT/nanoDFT.py +++ b/pyscf_ipu/nanoDFT/nanoDFT.py @@ -42,7 +42,7 @@ def energy(density_matrix, H_core, diff_JK, E_xc, E_nuc, _np=jax.numpy): def nanoDFT_iteration(i, vals, opts, mol): """Each call updates density_matrix attempting to minimize energy(density_matrix, ... ). """ density_matrix, V_xc, diff_JK, O, H_core, L_inv = vals[:6] # All (N, N) matrices - E_nuc, occupancy, ERI, grid_weights, grid_AO, shell_ijkl, diis_history, log = vals[6:] # Varying types/shapes. + E_nuc, occupancy, ERI, batched_grid_weights, batched_grid_AO, shell_ijkl, diis_history, log = vals[6:] # Varying types/shapes. if opts.v: print("---------- MEMORY CONSUMPTION ----------") @@ -72,8 +72,9 @@ def nanoDFT_iteration(i, vals, opts, mol): # Step 3: Use result from eigenproblem to update density_matrix. density_matrix = (eigvects*occupancy*2) @ eigvects.T # (N, N) - E_xc, V_xc = exchange_correlation(density_matrix, grid_AO, grid_weights) # float (N, N) - diff_JK = get_JK(shell_ijkl, density_matrix, ERI, opts.dense_ERI, opts.screen_tol, opts.backend, mol, opts.ndevices) # (N, N) (N, N) + # E_xc, V_xc = exchange_correlation(density_matrix, grid_AO, grid_weights) # float (N, N) + E_xc, V_xc, _, _, _ = jax.lax.fori_loop(0, opts.batches_grid, foreach_exchange_correlation_batch, (0, V_xc, density_matrix, batched_grid_AO, batched_grid_weights)) + diff_JK = get_JK(shell_ijkl, density_matrix, ERI, opts.dense_ERI, opts.screen_tol, opts.backend, mol, opts.batches_einsum, opts.ndevices) # (N, N) (N, N) # Log SCF matrices and energies (not used by DFT algorithm). #log["matrices"] = log["matrices"].at[i].set(jnp.stack((density_matrix, J, K, H))) # (iterations, 4, N, N) @@ -91,7 +92,7 @@ def nanoDFT_iteration(i, vals, opts, mol): def host_callback(data, i): # labels are adjusted to the `data` that will be passed to the callback - keep that in mind when passing different list of tensors - labels = ["density_matrix", "V_xc", "diff_JK", "O", "H_core", "L_inv", "E_nuc", "occupancy", "ERI", "grid_weights", "grid_AO", "shell_ijkl", "diis_history", "E_xc", "eigvects", "H"] + labels = ["density_matrix", "V_xc", "diff_JK", "O", "H_core", "L_inv", "E_nuc", "occupancy", "ERI", "batched_grid_weights", "batched_grid_AO", "shell_ijkl", "diis_history", "E_xc", "eigvects", "H"] for l, d in zip(labels, data): if l == "diis_history" or l == "ERI": for idx, arr in enumerate(d): @@ -102,7 +103,7 @@ def host_callback(data, i): jax.debug.callback(host_callback, vals[:-1] + [E_xc, eigvects, H], i) return [density_matrix, V_xc, diff_JK, O, H_core, L_inv, E_nuc, occupancy, ERI, - grid_weights, grid_AO, shell_ijkl, diis_history, log] + batched_grid_weights, batched_grid_AO, shell_ijkl, diis_history, log] def exchange_correlation(density_matrix, grid_AO, grid_weights): @@ -125,7 +126,7 @@ def exchange_correlation(density_matrix, grid_AO, grid_weights): V_xc = V_xc + V_xc.T # (N, N) return E_xc, V_xc # (float) (N, N) -def get_JK(shell_ijkl, density_matrix, ERI, dense_ERI, tolerance, backend, mol, ndevices): +def get_JK(shell_ijkl, density_matrix, ERI, dense_ERI, tolerance, backend, mol, nbatches, ndevices): """Computes the (N, N) matrices J and K. Density matrix is (N, N) and ERI is (N, N, N, N). """ N = density_matrix.shape[0] @@ -137,23 +138,43 @@ def get_JK(shell_ijkl, density_matrix, ERI, dense_ERI, tolerance, backend, mol, #from pyscf_ipu.nanoDFT.sparse_symmetric_ERI import sparse_symmetric_einsum #diff_JK = sparse_symmetric_einsum(ERI[0], ERI[1], density_matrix, backend) - diff_JK, cycles = compute_diff_jk(shell_ijkl, density_matrix, mol, 32, tolerance, ndevices=ndevices, backend="ipu") + diff_JK, cycles = compute_diff_jk(shell_ijkl, density_matrix, mol, nbatches, tolerance, ndevices=ndevices, backend="ipu") diff_JK = diff_JK.reshape(N, N) return diff_JK +def foreach_exchange_correlation_batch(i, vals): + E_xc, V_xc, dm, ao, w = vals + crt_grid_AO = ao[i] + crt_weights = w[i] + crt_grid_AO = jnp.transpose(crt_grid_AO, (1,0,2)) + crt_E_xc, crt_V_xc = exchange_correlation(dm, crt_grid_AO, crt_weights) # float (N, N) + E_xc += crt_E_xc + V_xc += crt_V_xc + return (E_xc, V_xc, dm, ao, w) + def _nanoDFT(state, ERI, grid_AO, grid_weights, shell_ijkl, opts, mol): # if opts.backend == "ipu": grid_weights, start = utils.get_ipu_cycles(grid_weights) # Utilize the IPUs MIMD parallism to compute the electron repulsion integrals (ERIs) in parallel. #if opts.backend == "ipu": state.ERI = electron_repulsion_integrals(state.input_floats, state.input_ints, mol, opts.threads_int, opts.intv) #else: pass # Compute on CPU. - grid_AO = jnp.transpose(grid_AO, (1,0,2)) # (padded_gsize/16, 4, N) -> (4, pgsize, N) + # grid_AO = jnp.transpose(grid_AO, (1,0,2)) # (padded_gsize/16, 4, N) -> (4, pgsize, N) + + E_xc = jnp.zeros(1) + V_xc = jnp.zeros(state.density_matrix.shape) + + batched_grid_AO = grid_AO.reshape(opts.batches_grid, -1, grid_AO.shape[-2], grid_AO.shape[-1]) + batched_grid_weights = grid_weights.reshape(opts.batches_grid, -1) # Precompute the remaining tensors. - E_xc, V_xc = exchange_correlation(state.density_matrix, grid_AO, grid_weights) # float (N, N) - diff_JK = get_JK(shell_ijkl, state.density_matrix, ERI, opts.dense_ERI, opts.screen_tol, opts.backend, mol, opts.ndevices) # (N, N) (N, N) + + E_xc, V_xc, _, _, _ = jax.lax.fori_loop(0, opts.batches_grid, foreach_exchange_correlation_batch, (E_xc, V_xc, state.density_matrix, batched_grid_AO, batched_grid_weights)) + + # grid_AO = jnp.transpose(grid_AO, (1,0,2)) # (padded_gsize/16, 4, N) -> (4, pgsize, N) + + diff_JK = get_JK(shell_ijkl, state.density_matrix, ERI, opts.dense_ERI, opts.screen_tol, opts.backend, mol, opts.batches_einsum, opts.ndevices) # (N, N) (N, N) H_core = state.kinetic + state.nuclear # (N, N) # Log matrices from all DFT iterations (not used by DFT algorithm). @@ -162,7 +183,7 @@ def _nanoDFT(state, ERI, grid_AO, grid_weights, shell_ijkl, opts, mol): # Perform DFT iterations. log = jax.lax.fori_loop(0, opts.its, partial(nanoDFT_iteration, opts=opts, mol=mol), [state.density_matrix, V_xc, diff_JK, state.O, H_core, state.L_inv, # all (N, N) matrices - state.E_nuc, state.mask, ERI, grid_weights, grid_AO, shell_ijkl, state.diis_history, log])[-1] + state.E_nuc, state.mask, ERI, batched_grid_weights, batched_grid_AO, shell_ijkl, state.diis_history, log])[-1] # cycles = -1 # if opts.backend == "ipu": @@ -281,15 +302,12 @@ def nanoDFT(mol, opts): grid_weights = _grid_weights gsize = grid_AO.shape[0] - grid_batches = 2 - - num_grid_shards = (opts.ndevices * grid_batches) + num_grid_shards = (opts.ndevices * opts.batches_grid) remainder = gsize % num_grid_shards if remainder != 0: grid_AO = jnp.pad(grid_AO, ((0,num_grid_shards-remainder), (0,0), (0,0)) ) grid_weights = jnp.pad(grid_weights, ((0,num_grid_shards-remainder)) ) - grid_coords = np.pad(grid_coords, ((0,num_grid_shards-remainder), (0,0))) grid_AO = grid_AO.reshape(opts.ndevices, -1, 4, N) grid_weights = grid_weights.reshape(opts.ndevices, -1) @@ -592,7 +610,8 @@ def nanoDFT_options( structure_optimization: bool = False, # AKA gradient descent on energy wrt nuclei eri_threshold : float = 0.0, ao_threshold: float = 0.0, - batches: int = 32, + batches_einsum: int = 32, + batches_grid: int = 32, ndevices: int = 1, dense_ERI: bool = False, v: bool = False, # verbose