diff --git a/pyscf_ipu/nanoDFT/experimental_pmap_nanoDFT.py b/pyscf_ipu/nanoDFT/experimental_pmap_nanoDFT.py
index 5db00f0..3791e10 100644
--- a/pyscf_ipu/nanoDFT/experimental_pmap_nanoDFT.py
+++ b/pyscf_ipu/nanoDFT/experimental_pmap_nanoDFT.py
@@ -1,4 +1,6 @@
 # Copyright (c) 2023 Graphcore Ltd. All rights reserved.
+import os
+os.environ["OMP_NUM_THREADS"] = "16"
 import jax
 import jax.numpy as jnp
 import numpy as np
@@ -39,22 +41,26 @@ def energy(density_matrix, H_core, diff_JK, E_xc, E_nuc, _np=jax.numpy):
 def nanoDFT_iteration(i, vals, opts, mol):
     """Each call updates density_matrix attempting to minimize energy(density_matrix, ...)."""
     density_matrix, V_xc, O, H_core, L_inv, diff_JK = vals[:6]  # All (N, N) matrices.
-    E_nuc, occupancy, ERI, grid_weights, sharded_grid_weights, grid_AO, sharded_grid_AO, diis_history, log= vals[6:] # Varying types/shapes.
-
-    print("@@@@@@@@@@")
-    mb = 0
-    for t in vals:
-        try:
-            if type(t) != type(()) and len(np.shape(t)) > 0:
-                print(np.shape(t), t.nbytes/10**6)
-                mb += t.nbytes/10**6
-        except:
-            print(type(t))
-    for a in ERI:
-        print(a.shape, a.nbytes)
-    print(mb)
-    print("@@@@@@@@@@")
-    print("")
+    E_nuc, occupancy, ERI, sharded_grid_weights, sharded_grid_AO, diis_history, log = vals[6:]  # Varying types/shapes.
+
+    if opts.v:
+        print("---------- MEMORY CONSUMPTION ----------")
+        MB = 0
+        for t in vals:
+            try:
+                if type(t) != type(()) and len(np.shape(t)) > 0:
+                    print(t.nbytes/10**6, t.shape, t.dtype)
+                    MB += t.nbytes/10**6
+            except:
+                print(type(t))
+        print("ERI")
+        for a in ERI:  # Prints incorrectly in the dense_ERI case.
+            print(a.nbytes/10**6, a.shape)
+            MB += a.nbytes/10**6
+        print("__________")
+        print("Total: ", MB)
+        print("----------------------------------------")
+        print("")

     # Step 1: Update Hamiltonian (optionally use DIIS to improve DFT convergence).
     H = H_core + diff_JK + V_xc  # (N, N)
@@ -65,22 +71,23 @@ def nanoDFT_iteration(i, vals, opts, mol):

     # Step 3: Use result from eigenproblem to update density_matrix.
     density_matrix = (eigvects*occupancy*2) @ eigvects.T  # (N, N)
-    E_xc, V_xc = exchange_correlation(density_matrix, grid_AO, sharded_grid_AO, grid_weights, sharded_grid_weights)  # float (N, N)
-    diff_JK = get_JK(density_matrix, ERI, opts, mol)  # (N, N)
+    E_xc, V_xc = exchange_correlation(density_matrix, sharded_grid_AO, sharded_grid_weights)  # float (N, N)
+    diff_JK = get_JK(density_matrix, ERI, opts.dense_ERI, opts.backend)  # (N, N)

     # Log SCF matrices and energies (not used by DFT algorithm).
     #log["matrices"] = log["matrices"].at[i].set(jnp.stack((density_matrix, J, K, H)))  # (iterations, 4, N, N)
     ## Removing this reduces memory from 50kb to 18kb.
     N = density_matrix.shape[0]
     log["matrices"] = jax.lax.dynamic_update_slice(log["matrices"], density_matrix.reshape(1, 1, N, N), (i, 0, 0, 0))
-    log["matrices"] = jax.lax.dynamic_update_slice(log["matrices"], diff_JK.reshape(1, 1, N, N), (i, 1, 0, 0))
-    log["matrices"] = jax.lax.dynamic_update_slice(log["matrices"], H.reshape(1, 1, N, N), (i, 2, 0, 0))
-
-    log["energy"] = log["energy"].at[i].set(energy(density_matrix, H_core, diff_JK, E_xc, E_nuc))  # (iterations, 6)
+    log["matrices"] = jax.lax.dynamic_update_slice(log["matrices"], diff_JK.reshape(1, 1, N, N), (i, 1, 0, 0))
+    log["matrices"] = jax.lax.dynamic_update_slice(log["matrices"], H.reshape(1, 1, N, N), (i, 2, 0, 0))
+    log["energy"] = log["energy"].at[i].set(energy(density_matrix, H_core, diff_JK, E_xc, E_nuc))  # (iterations, 6)

-    return [density_matrix, V_xc, O, H_core, L_inv, diff_JK, E_nuc, occupancy, ERI, grid_weights, sharded_grid_weights, grid_AO, sharded_grid_AO, diis_history, log]
+    return [density_matrix, V_xc, O, H_core, L_inv, diff_JK, E_nuc, occupancy, ERI, sharded_grid_weights, sharded_grid_AO, diis_history, log]

-def exchange_correlation(density_matrix, grid_AO, sharded_grid_AO, grid_weights, sharded_grid_weights):
+def exchange_correlation(density_matrix, sharded_grid_AO, sharded_grid_weights):
     """Compute exchange correlation integral using atomic orbitals (AO) evaluated on a grid."""
+    # TODO: we can also use sparsity on the grid! Visualization shows it's similarly sparse to ERI!
+
     # Perfectly SIMD-parallelizable over the grid_size axis.
     # Only need one reduce_sum at the end.
     grid_AO_dm = sharded_grid_AO[0] @ density_matrix  # (gsize, N) @ (N, N) -> (gsize, N)
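The new `opts.v` block above walks `vals` by hand and special-cases the `ERI` list. Not part of the patch, but the same bookkeeping can be written once over any nested structure with `jax.tree_util`; a minimal sketch:

```python
import numpy as np
import jax

def report_memory(vals):
    # Sum nbytes over every array leaf in an arbitrarily nested structure.
    leaves = [x for x in jax.tree_util.tree_leaves(vals) if hasattr(x, "nbytes")]
    for x in leaves:
        print(x.nbytes / 10**6, np.shape(x), x.dtype)
    print("Total:", sum(x.nbytes for x in leaves) / 10**6, "MB")

report_memory([np.zeros((4, 4)), [np.zeros(10, np.float32), np.ones(3)], 7])
```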
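The `exchange_correlation` change drops the unsharded `grid_AO`/`grid_weights` and keeps only per-device shards, so each device computes a partial sum over its slice of the grid. The cross-device reduction is not visible in this hunk; my assumption is that with `axis_name="p"` (as passed to `jax.pmap` below) it amounts to a single `jax.lax.psum`. A toy sketch of that pattern only (simplified density, no XC functional; the `XLA_FLAGS` line just emulates two devices on CPU):

```python
import os
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=2"  # emulate 2 devices on CPU
import jax
import jax.numpy as jnp
import numpy as np

ndev, gsize, N = 2, 8, 4

def partial_xc(grid_AO, grid_weights, density_matrix):
    # Per-device partial sum over this device's slice of the grid.
    rho = jnp.einsum("gi,ij,gj->g", grid_AO, density_matrix, grid_AO)  # (gsize,)
    E_partial = jnp.sum(grid_weights * rho)
    return jax.lax.psum(E_partial, axis_name="p")  # full grid sum, replicated on every device

dm = jnp.eye(N)
sharded_AO = jnp.array(np.random.rand(ndev, gsize, N), dtype=jnp.float32)
sharded_w  = jnp.array(np.random.rand(ndev, gsize), dtype=jnp.float32)
E = jax.pmap(partial_xc, axis_name="p", in_axes=(0, 0, None))(sharded_AO, sharded_w, dm)
ref = jnp.einsum("dg,dgi,ij,dgj->", sharded_w, sharded_AO, dm, sharded_AO)
assert jnp.allclose(E[0], ref, atol=1e-5)
```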
@@ -99,40 +106,32 @@ def exchange_correlation(density_matrix, grid_AO, sharded_grid_AO, grid_weights,
     V_xc = V_xc + V_xc.T  # (N, N)
     return E_xc, V_xc  # (float) (N, N)

-def get_JK(density_matrix, ERI, opts, mol):
+def get_JK(density_matrix, ERI, dense_ERI, backend):
     """Computes the (N, N) matrices J and K. Density matrix is (N, N) and ERI is (N, N, N, N)."""
     N = density_matrix.shape[0]

-    def sparse_mult(sparse, vector):
-        rows, cols, values = sparse
-        in_ = vector.take(cols, axis=0)
-        prod = in_*values[:, None]
-        return jax.ops.segment_sum(prod, rows, N**2)
-
-    if opts.backend == "ipu":
-        from sparse_symmetric_ERI import sparse_symmetric_einsum
-        mask = jnp.ones(1)
-        diff_JK = sparse_symmetric_einsum(ERI[0], ERI[1], density_matrix, mask, opts.backend)
-    else:
+    if dense_ERI:
         J = jnp.einsum('ijkl,ji->kl', ERI, density_matrix)  # (N, N)
         K = jnp.einsum('ijkl,jk->il', ERI, density_matrix)  # (N, N)
         diff_JK = J - (K / 2 * HYB_B3LYP)
+    else:
+        from sparse_symmetric_ERI import sparse_symmetric_einsum
+        diff_JK = sparse_symmetric_einsum(ERI[0], ERI[1], density_matrix, backend)

     diff_JK = diff_JK.reshape(N, N)

     return diff_JK  # (N, N)

-def _nanoDFT(E_nuc, density_matrix, kinetic, nuclear, O, grid_AO, grid_weights,
+def _nanoDFT(E_nuc, density_matrix, kinetic, nuclear, O,
              mask, _input_floats, _input_ints, L_inv, diis_history, sharded_grid_AO, sharded_grid_weights, ERI, opts, mol):
     # Utilize the IPU's MIMD parallelism to compute the electron repulsion integrals (ERIs) in parallel.
     #if opts.backend == "ipu": ERI = electron_repulsion_integrals(_input_floats, _input_ints, mol, opts.threads_int, opts.intv)
     #else: pass # Compute on CPU.
-    sharded_grid_AO = jnp.transpose(sharded_grid_AO, (1,0,2))  # (padded_gsize/16, 4, N) -> (4, pgsize, N)

     # Precompute the remaining tensors.
-    E_xc, V_xc = exchange_correlation(density_matrix, grid_AO, sharded_grid_AO, grid_weights, sharded_grid_weights)  # float (N, N)
-    diff_JK = get_JK(density_matrix, ERI, opts, mol)  # (N, N)
+    E_xc, V_xc = exchange_correlation(density_matrix, sharded_grid_AO, sharded_grid_weights)  # float (N, N)
+    diff_JK = get_JK(density_matrix, ERI, opts.dense_ERI, opts.backend)  # (N, N)
     H_core = kinetic + nuclear  # (N, N)

     # Log matrices from all DFT iterations (not used by DFT algorithm).
@@ -141,7 +140,7 @@ def _nanoDFT(E_nuc, density_matrix, kinetic, nuclear, O, grid_AO, grid_weights,

     # Perform DFT iterations.
     log = jax.lax.fori_loop(0, opts.its, partial(nanoDFT_iteration, opts=opts, mol=mol), [density_matrix, V_xc, O, H_core, L_inv, diff_JK,  # all (N, N) matrices
-                                                E_nuc, mask, ERI, grid_weights, sharded_grid_weights, grid_AO, sharded_grid_AO, diis_history, log])[-1]
+                                                E_nuc, mask, ERI, sharded_grid_weights, sharded_grid_AO, diis_history, log])[-1]

     return log["matrices"], H_core, log["energy"]

 def init_dft_tensors_cpu(mol, opts, DIIS_iters=9):
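For reference, the dense branch's einsum strings implement the textbook Coulomb and exchange contractions, with `HYB_B3LYP = 0.2` being B3LYP's exact-exchange fraction. A self-contained check against explicit loops (toy random tensor, not a real ERI):

```python
import numpy as np

HYB_B3LYP = 0.2
N = 6
rng = np.random.default_rng(0)
ERI = rng.standard_normal((N, N, N, N))
dm = rng.standard_normal((N, N)); dm = dm + dm.T  # symmetric density matrix

J = np.einsum("ijkl,ji->kl", ERI, dm)  # Coulomb:  J[k,l] = sum_ij ERI[i,j,k,l] dm[j,i]
K = np.einsum("ijkl,jk->il", ERI, dm)  # Exchange: K[i,l] = sum_jk ERI[i,j,k,l] dm[j,k]
diff_JK = J - K / 2 * HYB_B3LYP

# Reference with explicit loops.
J_ref = np.zeros((N, N)); K_ref = np.zeros((N, N))
for i in range(N):
    for j in range(N):
        for k in range(N):
            for l in range(N):
                J_ref[k, l] += ERI[i, j, k, l] * dm[j, i]
                K_ref[i, l] += ERI[i, j, k, l] * dm[j, k]
assert np.allclose(J, J_ref) and np.allclose(K, K_ref)
```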
@@ -164,9 +163,6 @@ def init_dft_tensors_cpu(mol, opts, DIIS_iters=9):
     nuclear = mol.intor_symmetric('int1e_nuc')  # (N,N)
     O = mol.intor_symmetric('int1e_ovlp')  # (N,N)
     L_inv = np.linalg.inv(np.linalg.cholesky(O))  # (N,N)
-    #if opts.backend != "ipu": ERI = mol.intor("int2e_sph")  # (N,N,N,N)=(66,66,66,66) for C6H6.
-    #else: ERI = None # will be computed on device
-    ERI = 0  # todo: move ERI defined later in here.
     input_floats, input_ints = 0, 0  #prepare_electron_repulsion_integrals(mol)[:2]
     mask = np.concatenate([np.ones(n_electrons_half), np.zeros(N-n_electrons_half)])
@@ -176,78 +172,77 @@ def init_dft_tensors_cpu(mol, opts, DIIS_iters=9):
     DIIS_H[0,1:] = DIIS_H[1:,0] = 1
     diis_history = (np.zeros((DIIS_iters, N**2)), np.zeros((DIIS_iters, N**2)), DIIS_H)

-    tensors = [E_nuc, density_matrix, kinetic, nuclear, O, grid_AO,
-               grid_weights, mask, input_floats, input_ints, L_inv, diis_history]
+    tensors = [E_nuc, density_matrix, kinetic, nuclear, O, mask, input_floats, input_ints, L_inv, diis_history]

-    return tensors, ERI, n_electrons_half, E_nuc, N, L_inv, grid_weights, grids.coords
+    return tensors, n_electrons_half, E_nuc, N, L_inv, grid_weights, grids.coords, grid_AO

 def nanoDFT(mol, opts):
     # Init DFT tensors on CPU using PySCF.
-    tensors, ERI, n_electrons_half, E_nuc, N, L_inv, grid_weights, grid_coords = init_dft_tensors_cpu(mol, opts)
+    tensors, n_electrons_half, E_nuc, N, L_inv, grid_weights, grid_coords, grid_AO = init_dft_tensors_cpu(mol, opts)

-    sharded_grid_AO = jnp.transpose(tensors[5], (1, 0, 2))  # (4,gsize,N) -> (gsize,4,N)
-    sharded_grid_weights = tensors[6]
+    sharded_grid_AO = jnp.transpose(grid_AO, (1, 0, 2))  # (4,gsize,N) -> (gsize,4,N)
+    sharded_grid_weights = grid_weights
     gsize = sharded_grid_AO.shape[0]

-    remainder = gsize % 16
+    remainder = gsize % opts.ndevices
     if remainder != 0:
-        sharded_grid_AO = jnp.pad(sharded_grid_AO, ((0,remainder), (0,0), (0,0)) )
-        sharded_grid_weights = jnp.pad(sharded_grid_weights, ((0,remainder)) )
-        # tensors[6] = jnp.pad(tensors[6], ((0, remainder)))
-        # tensors[5] = jnp.pad(tensors[5], ((0,0), (0,remainder), (0,0)) )
-    sharded_grid_AO = sharded_grid_AO.reshape(16, -1, 4, N)
-    sharded_grid_weights = sharded_grid_weights.reshape(16, -1)
+        sharded_grid_AO = jnp.pad(sharded_grid_AO, ((0, opts.ndevices-remainder), (0,0), (0,0)))  # Pad up to the next multiple of ndevices.
+        sharded_grid_weights = jnp.pad(sharded_grid_weights, ((0, opts.ndevices-remainder)))
+    sharded_grid_AO = sharded_grid_AO.reshape(opts.ndevices, -1, 4, N)
+    sharded_grid_weights = sharded_grid_weights.reshape(opts.ndevices, -1)

-    # tensors[5] = 0
-    # tensors[6] = 0
+    tensors.append(sharded_grid_AO)
+    tensors.append(sharded_grid_weights)

     # Run DFT algorithm (can be hardware accelerated).
     # We could take an "ipus" argument instead, or pod16 as backend?
-    if opts.backend == "ipu":
-        jitted_nanoDFT = jax.pmap(partial(_nanoDFT, opts=opts, mol=mol), axis_name="p", backend=opts.backend,
-                                  in_axes=(None, None, None, None, None, None, None, None, None, None, None, (None, None, None), 0, 0, [0, 0]))
-
-        from sparse_symmetric_ERI import vmap_num_repetitions_fast
-        import time
-        start = time.time()
+    if opts.dense_ERI:
+        assert opts.ndevices == 1, "dense_ERI=True is only supported with ndevices=1."
+        eri_in_axes = 0
+        ERI = mol.intor("int2e_sph")
+        ERI = np.expand_dims(ERI, 0)
+        tensors.append(ERI)
+    else:
+        from sparse_symmetric_ERI import num_repetitions_fast, get_i_j
         distinct_ERI = mol.intor("int2e_sph", aosym="s8")
-        nonzero_indices = np.nonzero(distinct_ERI)[0].astype(np.int32)
+        print(distinct_ERI.size)
+        indxs = np.abs(distinct_ERI) < 1e-7
+        distinct_ERI[indxs] = 0
+        nonzero_indices = np.nonzero(distinct_ERI)[0].astype(np.uint64)
         nonzero_distinct_ERI = distinct_ERI[nonzero_indices].astype(np.float32)
-        print("compute eri normalization ", time.time()-start)
-        rep = jax.jit(vmap_num_repetitions_fast, backend="cpu")(nonzero_indices)
-        print("perform normalization ", time.time()-start)
-        nonzero_distinct_ERI = nonzero_distinct_ERI / rep
-        print("pad remainder ", time.time()-start)
-        print(nonzero_distinct_ERI.shape, nonzero_indices.shape)
-        batches = 64  # perhaps make 10 batches?
-        nipu = 16
+        ij, kl = get_i_j(nonzero_indices)
+        rep = num_repetitions_fast(ij, kl)
+        nonzero_distinct_ERI = nonzero_distinct_ERI / rep
+        batches = int(opts.batches)  # perhaps make 10 batches?
+        nipu = opts.ndevices
         remainder = nonzero_indices.shape[0] % (nipu*batches)
+
         if remainder != 0:
-            nonzero_indices = np.pad(nonzero_indices, (0,nipu*batches-remainder))
+            print(nipu*batches-remainder, ij.shape)
+            ij = np.pad(ij, ((0, nipu*batches-remainder)))
+            kl = np.pad(kl, ((0, nipu*batches-remainder)))
             nonzero_distinct_ERI = np.pad(nonzero_distinct_ERI, (0, nipu*batches-remainder))
-        print("reshape to 16 IPUs pmap", time.time()-start)
-        nonzero_indices = nonzero_indices.reshape(nipu, batches, -1)
+        ij = ij.reshape(nipu, batches, -1)
+        kl = kl.reshape(nipu, batches, -1)
         nonzero_distinct_ERI = nonzero_distinct_ERI.reshape(nipu, batches, -1)
-        print(nonzero_indices.shape)
-        print(nonzero_distinct_ERI.shape)
-        sparse_ERI = [nonzero_distinct_ERI, nonzero_indices]
+        i, j = get_i_j(ij.reshape(-1))
+        k, l = get_i_j(kl.reshape(-1))
+        nonzero_indices = np.vstack([i,j,k,l]).T.reshape(nipu, batches, -1, 4).astype(np.int16)
+        nonzero_indices = jax.lax.bitcast_convert_type(nonzero_indices, np.float16)

-        tensors.append(sharded_grid_AO)
-        tensors.append(sharded_grid_weights)
+        sparse_ERI = [nonzero_distinct_ERI, nonzero_indices]
         tensors.append(sparse_ERI)
+        eri_in_axes = [0, 0]

-        vals = jitted_nanoDFT(*tensors)
-        logged_matrices, H_core, logged_energies = [np.asarray(a[0]).astype(np.float64) for a in vals]  # Ensure CPU
-    else:
-        jitted_nanoDFT = jax.jit(partial(_nanoDFT, opts=opts, mol=mol), backend=opts.backend)
-        tensors = tensors + (sharded_grid_AO,)
+    jitted_nanoDFT = jax.pmap(partial(_nanoDFT, opts=opts, mol=mol), axis_name="p", backend=opts.backend,
+                              in_axes=(None, None, None, None, None, None, None, None, None, (None, None, None), 0, 0, eri_in_axes))

-        vals = jitted_nanoDFT(*tensors)
-        logged_matrices, H_core, logged_energies = [np.asarray(a).astype(np.float64) for a in vals]  # Ensure CPU
+    jitted_nanoDFT(*tensors)  # First call triggers compilation; call again below to time execution only.
+    vals = jitted_nanoDFT(*tensors)
+    logged_matrices, H_core, logged_energies = [np.asarray(a[0]).astype(np.float64) for a in vals]  # Ensure CPU.

     # It's cheap to compute energy/hlgap on CPU in float64 from the logged values/matrices.
     logged_E_xc = logged_energies[:, 3].copy()
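The sparse path stores each ERI value once (8-fold `s8` symmetry), recovers `(i, j)` and `(k, l)` from the flat triangular index with `get_i_j`, and packs the four int16 indices into float16 via `bitcast_convert_type` so they ride along as device data. A small round-trip check of both tricks; NumPy's `.view` plays the role of the bitcast here:

```python
import numpy as np

def get_i_j(val):
    # Invert the triangular pairing val = i*(i+1)//2 + j with i >= j (as in sparse_symmetric_ERI.py).
    i = (np.sqrt(1 + 8 * val.astype(np.uint64)) - 1) // 2
    j = ((val - i) - (i**2 - val)) // 2
    return i, j

# Round-trip over all pairs with i >= j for small i.
idx, pairs = [], []
for i in range(50):
    for j in range(i + 1):
        idx.append(i * (i + 1) // 2 + j)
        pairs.append((i, j))
ii, jj = get_i_j(np.array(idx, dtype=np.uint64))
assert all((int(a), int(b)) == p for a, b, p in zip(ii, jj, pairs))

# int16 -> float16 bitcast preserves the bit pattern exactly.
packed = np.array([[3, 1, 2, 0]], dtype=np.int16).view(np.float16)
assert (packed.view(np.int16) == [3, 1, 2, 0]).all()
```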
" + eri_in_axes = 0 + ERI = mol.intor("int2e_sph") + ERI = np.expand_dims(ERI, 0) + tensors.append(ERI) + else: + from sparse_symmetric_ERI import num_repetitions_fast, get_i_j distinct_ERI = mol.intor("int2e_sph", aosym="s8") - nonzero_indices = np.nonzero(distinct_ERI)[0].astype(np.int32) + print(distinct_ERI.size) + indxs = np.abs(distinct_ERI)<1e-7 + distinct_ERI[indxs] = 0 + nonzero_indices = np.nonzero(distinct_ERI)[0].astype(np.uint64) nonzero_distinct_ERI = distinct_ERI[nonzero_indices].astype(np.float32) - print("compute eri normalization ", time.time()-start) - rep = jax.jit(vmap_num_repetitions_fast, backend="cpu")(nonzero_indices) - print("perform normalization ", time.time()-start) - nonzero_distinct_ERI = nonzero_distinct_ERI / rep - print("pad remainder ", time.time()-start) - print(nonzero_distinct_ERI.shape, nonzero_indices.shape) - batches = 64 # perhaps make 10 batches? - nipu = 16 + ij, kl = get_i_j(nonzero_indices) + rep = num_repetitions_fast(ij, kl) + nonzero_distinct_ERI = nonzero_distinct_ERI / rep + batches = int(opts.batches) # perhaps make 10 batches? + nipu = opts.ndevices remainder = nonzero_indices.shape[0] % (nipu*batches) + if remainder != 0: - nonzero_indices = np.pad(nonzero_indices, (0,nipu*batches-remainder)) + print(nipu*batches-remainder, ij.shape) + ij = np.pad(ij, ((0,nipu*batches-remainder))) + kl = np.pad(kl, ((0,nipu*batches-remainder))) nonzero_distinct_ERI = np.pad(nonzero_distinct_ERI, (0,nipu*batches-remainder)) - print("reshape to 16 IPUs pmap", time.time()-start) - nonzero_indices = nonzero_indices.reshape(nipu, batches, -1) + ij = ij.reshape(nipu, batches, -1) + kl = kl.reshape(nipu, batches, -1) nonzero_distinct_ERI = nonzero_distinct_ERI.reshape(nipu, batches, -1) - print(nonzero_indices.shape) - print(nonzero_distinct_ERI.shape) - sparse_ERI = [nonzero_distinct_ERI, nonzero_indices] + i, j = get_i_j(ij.reshape(-1)) + k, l = get_i_j(kl.reshape(-1)) + nonzero_indices = np.vstack([i,j,k,l]).T.reshape(nipu, batches, -1, 4).astype(np.int16) + nonzero_indices = jax.lax.bitcast_convert_type(nonzero_indices, np.float16) - tensors.append(sharded_grid_AO) - tensors.append(sharded_grid_weights) + sparse_ERI = [nonzero_distinct_ERI, nonzero_indices] tensors.append(sparse_ERI) + eri_in_axes = [0,0] - vals = jitted_nanoDFT(*tensors) - logged_matrices, H_core, logged_energies = [np.asarray(a[0]).astype(np.float64) for a in vals] # Ensure CPU - else: - jitted_nanoDFT = jax.jit(partial(_nanoDFT, opts=opts, mol=mol), backend=opts.backend) - tensors = tensors + (sharded_grid_AO,) + jitted_nanoDFT = jax.pmap(partial(_nanoDFT, opts=opts, mol=mol), axis_name="p", backend=opts.backend, + in_axes=(None, None, None, None, None, None, None, None, None, (None, None, None), 0, 0, eri_in_axes)) - vals = jitted_nanoDFT(*tensors) - logged_matrices, H_core, logged_energies = [np.asarray(a).astype(np.float64) for a in vals] # Ensure CPU + jitted_nanoDFT(*tensors) + vals = jitted_nanoDFT(*tensors) + logged_matrices, H_core, logged_energies = [np.asarray(a[0]).astype(np.float64) for a in vals] # Ensure CPU # It's cheap to compute energy/hlgap on CPU in float64 from the logged values/matrices. 
diff --git a/pyscf_ipu/nanoDFT/sparse_symmetric_ERI.py b/pyscf_ipu/nanoDFT/sparse_symmetric_ERI.py
index dd7c661..d72a5e6 100644
--- a/pyscf_ipu/nanoDFT/sparse_symmetric_ERI.py
+++ b/pyscf_ipu/nanoDFT/sparse_symmetric_ERI.py
@@ -16,10 +16,10 @@ def get_i_j(val):
     j = (((val - i) - (i**2 - val))//2)
     return i, j

-def cpu_ijkl(value, symmetry, f):
+def cpu_ijkl(value, symmetry, N, f):
     i, j, k, l = value[0].astype(np.uint32), value[1].astype(np.uint32), value[2].astype(np.uint32), value[3].astype(np.uint32)
-    return f(i,j,k,l,symmetry)
-cpu_ijkl = jax.vmap(cpu_ijkl, in_axes=(0, None, None))
+    return f(i,j,k,l,symmetry,N)
+cpu_ijkl = jax.vmap(cpu_ijkl, in_axes=(0, None, None, None))

 @partial(jax.jit, backend="ipu")
 def ipu_ijkl(nonzero_indices, symmetry, N):
@@ -79,10 +79,10 @@ def num_repetitions_fast(ij, kl):

     return repetitions

-indices_func = lambda i,j,k,l,symmetry: jnp.array([i*N+j, j*N+i, i*N+j, j*N+i, k*N+l, l*N+k, k*N+l, l*N+k,
-                                                   k*N+l, k*N+l, l*N+k, l*N+k, i*N+j, i*N+j, j*N+i, j*N+i,
-                                                   k*N+j, k*N+i, l*N+j, l*N+i, i*N+l, i*N+k, j*N+l, j*N+k,
-                                                   i*N+l, j*N+l, i*N+k, j*N+k, k*N+j, l*N+j, k*N+i, l*N+i])[symmetry]
+indices_func = lambda i,j,k,l,symmetry,N: jnp.array([i*N+j, j*N+i, i*N+j, j*N+i, k*N+l, l*N+k, k*N+l, l*N+k,
+                                                     k*N+l, k*N+l, l*N+k, l*N+k, i*N+j, i*N+j, j*N+i, j*N+i,
+                                                     k*N+j, k*N+i, l*N+j, l*N+i, i*N+l, i*N+k, j*N+l, j*N+k,
+                                                     i*N+l, j*N+l, i*N+k, j*N+k, k*N+j, l*N+j, k*N+i, l*N+i])[symmetry]

 def sparse_symmetric_einsum(nonzero_distinct_ERI, nonzero_indices, dm, backend):
@@ -91,6 +91,16 @@ def sparse_symmetric_einsum(nonzero_distinct_ERI, nonzero_indices, dm, backend):
     diff_JK = jnp.zeros(dm.shape)
     N = int(np.sqrt(dm.shape[0]))

+    dnums = jax.lax.GatherDimensionNumbers(
+        offset_dims=(),
+        collapsed_slice_dims=(0,),
+        start_index_map=(0,))
+    scatter_dnums = jax.lax.ScatterDimensionNumbers(
+        update_window_dims=(),
+        inserted_window_dims=(0,),
+        scatter_dims_to_operand_dims=(0,))
+
     def iteration(symmetry, vals):
         diff_JK = vals
         is_K_matrix = (symmetry >= 8)
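The dimension numbers above configure the raw XLA gather/scatter for the 1-D case: index vectors of shape `(num_indices, 1)` into the flattened `(N**2,)` density matrix. A quick equivalence check against `jnp.take` and `.at[].add`; note that the hunk below also passes `indices_are_sorted=True, unique_indices=True`, a performance hint that is only safe when those properties actually hold:

```python
import jax
import jax.numpy as jnp

dnums = jax.lax.GatherDimensionNumbers(
    offset_dims=(), collapsed_slice_dims=(0,), start_index_map=(0,))

x = jnp.arange(10.0)
idx = jnp.array([[3], [7], [0]])  # shape (num_indices, 1), as in the diff's .reshape(-1, 1)
out = jax.lax.gather(x, idx, dimension_numbers=dnums, slice_sizes=(1,),
                     mode=jax.lax.GatherScatterMode.FILL_OR_DROP)
assert jnp.allclose(out, jnp.take(x, idx[:, 0]))

# The matching scatter-add configuration reduces values into a zero vector.
scatter_dnums = jax.lax.ScatterDimensionNumbers(
    update_window_dims=(), inserted_window_dims=(0,),
    scatter_dims_to_operand_dims=(0,))
vals = jnp.ones(3)
acc = jax.lax.scatter_add(jnp.zeros(10), idx, vals, scatter_dnums,
                          mode=jax.lax.GatherScatterMode.FILL_OR_DROP)
assert jnp.allclose(acc, jnp.zeros(10).at[idx[:, 0]].add(vals))
```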
@@ -104,21 +114,25 @@ def sequentialized_iter(i, vals):
             indices = jax.lax.bitcast_convert_type(indices, np.int16).astype(np.int32)
             eris = nonzero_distinct_ERI[i]
-            print(indices.shape)

-            if backend == "cpu": dm_indices = cpu_ijkl(indices, symmetry+is_K_matrix*8, indices_func)
-            else:                dm_indices = ipu_ijkl(indices, symmetry+is_K_matrix*8, N)
-            dm_values = jnp.take(dm, dm_indices, axis=0)
+            if backend == "cpu": dm_indices = cpu_ijkl(indices, symmetry+is_K_matrix*8, N, indices_func).reshape(-1, 1)
+            else:                dm_indices = ipu_ijkl(indices, symmetry+is_K_matrix*8, N).reshape(-1, 1)
+            # dm_values = jnp.take(dm, dm_indices, axis=0)  # For our special case, jnp.take's ~50 lines of lowering reduce to the one line below.
+            dm_values = jax.lax.gather(dm, dm_indices, dimension_numbers=dnums, slice_sizes=(1,), mode=jax.lax.GatherScatterMode.FILL_OR_DROP)

             dm_values = dm_values.at[:].mul(eris)  # This is a product, but we re-use the variable for an in-place update.

-            if backend == "cpu": ss_indices = cpu_ijkl(indices, symmetry+8+is_K_matrix*8, indices_func)
-            else:                ss_indices = ipu_ijkl(indices, symmetry+8+is_K_matrix*8, N)
-            diff_JK = diff_JK + jax.ops.segment_sum(dm_values, ss_indices, N**2) * (-HYB_B3LYP/2)**is_K_matrix
+            if backend == "cpu": ss_indices = cpu_ijkl(indices, symmetry+8+is_K_matrix*8, N, indices_func).reshape(-1, 1)
+            else:                ss_indices = ipu_ijkl(indices, symmetry+8+is_K_matrix*8, N).astype(np.int32).reshape(-1, 1)
+            # diff_JK = diff_JK + jax.ops.segment_sum(...)  # For our special case, segment_sum's ~100 lines of lowering reduce to the scatter_add below.
+            diff_JK = diff_JK + jax.lax.scatter_add(jnp.zeros((N**2,)),
+                                                    ss_indices, dm_values,
+                                                    scatter_dnums, indices_are_sorted=True, unique_indices=True,
+                                                    mode=jax.lax.GatherScatterMode.FILL_OR_DROP) \
+                                * (-HYB_B3LYP/2)**is_K_matrix

             return diff_JK

-        batches = nonzero_indices.shape[0]  # before pmap, tensor had shape (nipus, batches, -1) so [0]=batches after pmap
+        batches = nonzero_indices.shape[0]  # Before pmap the tensor had shape (nipus, batches, -1), so [0] = batches after pmap.
         diff_JK = jax.lax.fori_loop(0, batches, sequentialized_iter, diff_JK)
         return diff_JK
@@ -151,6 +165,8 @@ def sequentialized_iter(i, vals):
     mol.build()
     N = mol.nao_nr()
     print("N %i"%mol.nao_nr())
+    print("NxN:", (N**2, N**2))
+    print("Naive operations: ", N**4*2/10**9, "[Giga]")
     if not args.skip: dense_ERI = mol.intor("int2e_sph", aosym="s1")
     distinct_ERI = mol.intor("int2e_sph", aosym="s8")
     distinct_ERI[np.abs(distinct_ERI)<1e-9] = 0  # Zero out numerically negligible entries.
@@ -163,6 +179,7 @@ def sequentialized_iter(i, vals):
     nonzero_indices = np.nonzero(distinct_ERI)[0].astype(np.uint64)
     nonzero_distinct_ERI = distinct_ERI[nonzero_indices].astype(np.float32)
+    print("Nonzero Operations:", nonzero_indices.size*8*2/10**9, "[Giga]")
     ij, kl = get_i_j(nonzero_indices)
     rep = num_repetitions_fast(ij, kl)
     nonzero_distinct_ERI = nonzero_distinct_ERI / rep
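The new prints compare the dense contraction cost (2*N^4 multiply-adds per J/K pass) against the cost implied by the thresholded, symmetry-distinct nonzeros (each distinct value touches 8 symmetry copies). The same numbers can be reproduced standalone; water in 6-31G is my example choice, not from the patch:

```python
import numpy as np
from pyscf import gto

mol = gto.M(atom="O 0 0 0; H 0 0 1; H 0 1 0", basis="6-31g")
N = mol.nao_nr()
distinct_ERI = mol.intor("int2e_sph", aosym="s8")  # 8-fold symmetry-reduced, 1-D
distinct_ERI[np.abs(distinct_ERI) < 1e-9] = 0
nnz = np.count_nonzero(distinct_ERI)
print("N:", N)
print("Naive operations:   ", N**4 * 2 / 10**9, "[Giga]")   # dense einsum, mult+add
print("Nonzero operations: ", nnz * 8 * 2 / 10**9, "[Giga]") # 8 symmetry copies per distinct value
print("distinct nnz fraction:", nnz / distinct_ERI.size)
```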