Pmap nanodft #84

Merged
merged 10 commits on Sep 16, 2023
Changes from all commits
208 changes: 117 additions & 91 deletions pyscf_ipu/nanoDFT/experimental_pmap_nanoDFT.py
@@ -1,4 +1,6 @@
# Copyright (c) 2023 Graphcore Ltd. All rights reserved.
import os
os.environ["OMP_NUM_THREADS"] = "16"
import jax
import jax.numpy as jnp
import numpy as np
@@ -39,22 +41,26 @@ def energy(density_matrix, H_core, diff_JK, E_xc, E_nuc, _np=jax.numpy):
def nanoDFT_iteration(i, vals, opts, mol):
"""Each call updates density_matrix attempting to minimize energy(density_matrix, ... ). """
density_matrix, V_xc, O, H_core, L_inv, diff_JK = vals[:6] # All (N, N) matrices
E_nuc, occupancy, ERI, grid_weights, sharded_grid_weights, grid_AO, sharded_grid_AO, diis_history, log= vals[6:] # Varying types/shapes.

print("@@@@@@@@@@")
mb = 0
for t in vals:
try:
if type(t) != type(()) and len(np.shape(t)) > 0:
print(np.shape(t), t.nbytes/10**6)
mb += t.nbytes/10**6
except:
print(type(t))
for a in ERI:
print(a.shape, a.nbytes)
print(mb)
print("@@@@@@@@@@")
print("")
E_nuc, occupancy, ERI, sharded_grid_weights, sharded_grid_AO, diis_history, log= vals[6:] # Varying types/shapes.

if opts.v:
print("---------- MEMORY CONSUMPTION ----------")
MB = 0
for t in vals:
try:
if type(t) != type(()) and len(np.shape(t)) > 0:
print(t.nbytes/10**6, t.shape, t.dtype)
MB += t.nbytes/10**6
except:
print(type(t))
print("ERI")
for a in ERI: # prints weird in dense_ERI case
print( a.nbytes/10**6, a.shape)
MB += a.nbytes / 10**6
print("__________")
print("Total: ", MB)
print("----------------------------------------")
print("")

# Step 1: Update Hamiltonian (optionally use DIIS to improve DFT convergence).
H = H_core + diff_JK + V_xc # (N, N)
@@ -65,22 +71,23 @@ def nanoDFT_iteration(i, vals, opts, mol):

# Step 3: Use result from eigenproblem to update density_matrix.
density_matrix = (eigvects*occupancy*2) @ eigvects.T # (N, N)
E_xc, V_xc = exchange_correlation(density_matrix, grid_AO, sharded_grid_AO, grid_weights, sharded_grid_weights) # float (N, N)
diff_JK = get_JK(density_matrix, ERI, opts, mol) #(4,gsize,N) # (N, N)
E_xc, V_xc = exchange_correlation(density_matrix, sharded_grid_AO, sharded_grid_weights) # float (N, N)
diff_JK = get_JK(density_matrix, ERI, opts.dense_ERI) #(4,gsize,N) # (N, N)

# Log SCF matrices and energies (not used by DFT algorithm).
#log["matrices"] = log["matrices"].at[i].set(jnp.stack((density_matrix, J, K, H))) # (iterations, 4, N, N) ## removing this reduces memory from 50kb to 18kb
N = density_matrix.shape[0]
log["matrices"] = jax.lax.dynamic_update_slice(log["matrices"], density_matrix.reshape(1, 1, N, N), (i, 0, 0, 0))
log["matrices"] = jax.lax.dynamic_update_slice(log["matrices"], diff_JK.reshape(1, 1, N, N), (i, 1, 0, 0))
log["matrices"] = jax.lax.dynamic_update_slice(log["matrices"], H.reshape(1, 1, N, N), (i, 2, 0, 0))

log["energy"] = log["energy"].at[i].set(energy(density_matrix, H_core, diff_JK, E_xc, E_nuc)) # (iterations, 6)
log["matrices"] = jax.lax.dynamic_update_slice(log["matrices"], diff_JK.reshape(1, 1, N, N), (i, 1, 0, 0))
log["matrices"] = jax.lax.dynamic_update_slice(log["matrices"], H.reshape(1, 1, N, N), (i, 2, 0, 0))
log["energy"] = log["energy"].at[i].set(energy(density_matrix, H_core, diff_JK, E_xc, E_nuc)) # (iterations, 6)

return [density_matrix, V_xc, O, H_core, L_inv, diff_JK, E_nuc, occupancy, ERI, grid_weights, sharded_grid_weights, grid_AO, sharded_grid_AO, diis_history, log]
return [density_matrix, V_xc, O, H_core, L_inv, diff_JK, E_nuc, occupancy, ERI, sharded_grid_weights, sharded_grid_AO, diis_history, log]

def exchange_correlation(density_matrix, grid_AO, sharded_grid_AO, grid_weights, sharded_grid_weights):
def exchange_correlation(density_matrix, sharded_grid_AO, sharded_grid_weights):
"""Compute exchange correlation integral using atomic orbitals (AO) evaluated on a grid. """
# TODO: we can also use sparsity on grid! Visualization shows it's similarly sparse to ERI!

# Perfectly SIMD parallelizable over grid_size axis.
# Only need one reduce_sum in the end.
grid_AO_dm = sharded_grid_AO[0] @ density_matrix # (gsize, N) @ (N, N) -> (gsize, N)
@@ -99,40 +106,32 @@ def exchange_correlation(density_matrix, grid_AO, sharded_grid_AO, grid_weights,
V_xc = V_xc + V_xc.T # (N, N)
return E_xc, V_xc # (float) (N, N)
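
The comment above notes the grid axis is "perfectly SIMD parallelizable" with only one reduce_sum at the end. A minimal sketch of that pattern, not part of the diff (toy shapes; the axis name "p" matches the axis_name used in the pmap calls further down):

```python
import jax
import jax.numpy as jnp

def per_device_grid_sum(weights_shard):
    # Each device reduces its own slice of the grid points ...
    local = jnp.sum(weights_shard)
    # ... and a single collective combines the partial sums across the pmap axis.
    return jax.lax.psum(local, axis_name="p")

ndev = jax.local_device_count()
total = jax.pmap(per_device_grid_sum, axis_name="p")(jnp.ones((ndev, 5)))
# Every device ends up holding the same total: ndev * 5.
```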

def get_JK(density_matrix, ERI, opts, mol):
def get_JK(density_matrix, ERI, dense_ERI):
"""Computes the (N, N) matrices J and K. Density matrix is (N, N) and ERI is (N, N, N, N). """
N = density_matrix.shape[0]

def sparse_mult(sparse, vector):
rows, cols, values = sparse
in_ = vector.take(cols, axis=0)
prod = in_*values[:, None]
return jax.ops.segment_sum(prod, rows, N**2)

if opts.backend == "ipu":
from sparse_symmetric_ERI import sparse_symmetric_einsum
mask = jnp.ones(1)
diff_JK = sparse_symmetric_einsum(ERI[0], ERI[1], density_matrix, mask, opts.backend)
else:
if dense_ERI:
J = jnp.einsum('ijkl,ji->kl', ERI, density_matrix) # (N, N)
K = jnp.einsum('ijkl,jk->il', ERI, density_matrix) # (N, N)
diff_JK = J - (K / 2 * HYB_B3LYP)
else:
from sparse_symmetric_ERI import sparse_symmetric_einsum
diff_JK = sparse_symmetric_einsum(ERI[0], ERI[1], density_matrix, opts.backend)

diff_JK = diff_JK.reshape(N, N)

return diff_JK # (N, N) (N, N),
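
For reference, a self-contained sketch of the dense branch of get_JK on random stand-in tensors. The shapes and the HYB_B3LYP value of 0.2 (the usual B3LYP exact-exchange fraction, defined elsewhere in nanoDFT) are assumptions, not taken from this diff:

```python
import numpy as np
import jax.numpy as jnp

HYB_B3LYP = 0.2  # assumed B3LYP exact-exchange fraction

N = 4
rng = np.random.default_rng(0)
ERI = rng.normal(size=(N, N, N, N))      # stand-in for mol.intor("int2e_sph")
density_matrix = rng.normal(size=(N, N))

J = jnp.einsum('ijkl,ji->kl', ERI, density_matrix)   # Coulomb matrix
K = jnp.einsum('ijkl,jk->il', ERI, density_matrix)   # exchange matrix
diff_JK = J - (K / 2 * HYB_B3LYP)                    # hybrid combination returned by get_JK
print(diff_JK.shape)                                  # (N, N)
```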

def _nanoDFT(E_nuc, density_matrix, kinetic, nuclear, O, grid_AO, grid_weights,
def _nanoDFT(E_nuc, density_matrix, kinetic, nuclear, O,
mask, _input_floats, _input_ints, L_inv, diis_history, sharded_grid_AO, sharded_grid_weights, ERI, opts, mol):
# Utilize the IPUs MIMD parallism to compute the electron repulsion integrals (ERIs) in parallel.
#if opts.backend == "ipu": ERI = electron_repulsion_integrals(_input_floats, _input_ints, mol, opts.threads_int, opts.intv)
#else: pass # Compute on CPU.

sharded_grid_AO = jnp.transpose(sharded_grid_AO, (1,0,2)) # (padded_gsize/16, 4, N) -> (4, pgsize, N)

# Precompute the remaining tensors.
E_xc, V_xc = exchange_correlation(density_matrix, grid_AO, sharded_grid_AO, grid_weights, sharded_grid_weights) # float (N, N)
diff_JK = get_JK(density_matrix, ERI, opts, mol) # (N, N) (N, N)
E_xc, V_xc = exchange_correlation(density_matrix, sharded_grid_AO, sharded_grid_weights) # float (N, N)
diff_JK = get_JK(density_matrix, ERI, opts.dense_ERI) # (N, N) (N, N)
H_core = kinetic + nuclear # (N, N)

# Log matrices from all DFT iterations (not used by DFT algorithm).
@@ -141,7 +140,7 @@ def _nanoDFT(E_nuc, density_matrix, kinetic, nuclear, O, grid_AO, grid_weights,

# Perform DFT iterations.
log = jax.lax.fori_loop(0, opts.its, partial(nanoDFT_iteration, opts=opts, mol=mol), [density_matrix, V_xc, O, H_core, L_inv, diff_JK, # all (N, N) matrices
E_nuc, mask, ERI, grid_weights, sharded_grid_weights, grid_AO, sharded_grid_AO, diis_history, log])[-1]
E_nuc, mask, ERI, sharded_grid_weights, sharded_grid_AO, diis_history, log])[-1]
return log["matrices"], H_core, log["energy"]

def init_dft_tensors_cpu(mol, opts, DIIS_iters=9):
@@ -164,9 +163,6 @@ def init_dft_tensors_cpu(mol, opts, DIIS_iters=9):
nuclear = mol.intor_symmetric('int1e_nuc') # (N,N)
O = mol.intor_symmetric('int1e_ovlp') # (N,N)
L_inv = np.linalg.inv(np.linalg.cholesky(O)) # (N,N)
#if opts.backend != "ipu": ERI = mol.intor("int2e_sph") # (N,N,N,N)=(66,66,66,66) for C6H6.
#else: ERI = None # will be computed on device
ERI = 0# todo: move ERI defined later in here.

input_floats, input_ints = 0, 0 #prepare_electron_repulsion_integrals(mol)[:2]
mask = np.concatenate([np.ones(n_electrons_half), np.zeros(N-n_electrons_half)])
@@ -176,78 +172,77 @@ def init_dft_tensors_cpu(mol, opts, DIIS_iters=9):
DIIS_H[0,1:] = DIIS_H[1:,0] = 1
diis_history = (np.zeros((DIIS_iters, N**2)), np.zeros((DIIS_iters, N**2)), DIIS_H)

tensors = [E_nuc, density_matrix, kinetic, nuclear, O, grid_AO,
grid_weights, mask, input_floats, input_ints, L_inv, diis_history]
tensors = [E_nuc, density_matrix, kinetic, nuclear, O, mask, input_floats, input_ints, L_inv, diis_history]

return tensors, ERI, n_electrons_half, E_nuc, N, L_inv, grid_weights, grids.coords
return tensors, n_electrons_half, E_nuc, N, L_inv, grid_weights, grids.coords, grid_AO

def nanoDFT(mol, opts):
# Init DFT tensors on CPU using PySCF.
tensors, ERI, n_electrons_half, E_nuc, N, L_inv, grid_weights, grid_coords = init_dft_tensors_cpu(mol, opts)
tensors, n_electrons_half, E_nuc, N, L_inv, grid_weights, grid_coords, grid_AO = init_dft_tensors_cpu(mol, opts)

sharded_grid_AO = jnp.transpose(tensors[5], (1, 0, 2)) # (4,gsize,N) -> (gsize,4,N)
sharded_grid_weights = tensors[6]
sharded_grid_AO = jnp.transpose(grid_AO, (1, 0, 2)) # (4,gsize,N) -> (gsize,4,N)
sharded_grid_weights = grid_weights
gsize = sharded_grid_AO.shape[0]

remainder = gsize % 16
remainder = gsize % opts.ndevices
if remainder != 0:
sharded_grid_AO = jnp.pad(sharded_grid_AO, ((0,remainder), (0,0), (0,0)) )
sharded_grid_weights = jnp.pad(sharded_grid_weights, ((0,remainder)) )
# tensors[6] = jnp.pad(tensors[6], ((0, remainder)))
# tensors[5] = jnp.pad(tensors[5], ((0,0), (0,remainder), (0,0)) )
sharded_grid_AO = sharded_grid_AO.reshape(16, -1, 4, N)
sharded_grid_weights = sharded_grid_weights.reshape(16, -1)
sharded_grid_AO = sharded_grid_AO.reshape(opts.ndevices, -1, 4, N)
sharded_grid_weights = sharded_grid_weights.reshape(opts.ndevices, -1)

# tensors[5] = 0
# tensors[6] = 0
tensors.append(sharded_grid_AO)
tensors.append(sharded_grid_weights)
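
A minimal, stand-alone sketch (assumed toy shapes, not the PR's code) of the step above: the grid axis is padded up to a multiple of opts.ndevices so the leading axis of the reshaped arrays can be split across devices by pmap's in_axes=0:

```python
import jax.numpy as jnp

ndevices, N = 4, 8
gsize = 10                                   # grid size, not a multiple of ndevices
grid_AO = jnp.ones((gsize, 4, N))
grid_weights = jnp.ones((gsize,))

pad = (-gsize) % ndevices                    # points needed to reach the next multiple
grid_AO = jnp.pad(grid_AO, ((0, pad), (0, 0), (0, 0)))
grid_weights = jnp.pad(grid_weights, ((0, pad),))

sharded_grid_AO = grid_AO.reshape(ndevices, -1, 4, N)      # (ndevices, gsize/ndevices, 4, N)
sharded_grid_weights = grid_weights.reshape(ndevices, -1)  # (ndevices, gsize/ndevices)
```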

# Run DFT algorithm (can be hardware accelerated).
# we can have "ipus" argument instead, or pod16 as backend?
if opts.backend == "ipu":
jitted_nanoDFT = jax.pmap(partial(_nanoDFT, opts=opts, mol=mol), axis_name="p", backend=opts.backend,
in_axes=(None, None, None, None, None, None, None, None, None, None, None, (None, None, None), 0, 0, [0, 0]))


from sparse_symmetric_ERI import vmap_num_repetitions_fast
import time
start = time.time()
if opts.dense_ERI:
assert opts.ndevices == 1, "Only support '--dense_ERI True' for `--ndevices 1`. "
eri_in_axes = 0
ERI = mol.intor("int2e_sph")
ERI = np.expand_dims(ERI, 0)
tensors.append(ERI)
else:
from sparse_symmetric_ERI import num_repetitions_fast, get_i_j
distinct_ERI = mol.intor("int2e_sph", aosym="s8")
nonzero_indices = np.nonzero(distinct_ERI)[0].astype(np.int32)
print(distinct_ERI.size)
indxs = np.abs(distinct_ERI)<1e-7
distinct_ERI[indxs] = 0
nonzero_indices = np.nonzero(distinct_ERI)[0].astype(np.uint64)
nonzero_distinct_ERI = distinct_ERI[nonzero_indices].astype(np.float32)
print("compute eri normalization ", time.time()-start)
rep = jax.jit(vmap_num_repetitions_fast, backend="cpu")(nonzero_indices)
print("perform normalization ", time.time()-start)
nonzero_distinct_ERI = nonzero_distinct_ERI / rep

print("pad remainder ", time.time()-start)
print(nonzero_distinct_ERI.shape, nonzero_indices.shape)
batches = 64 # perhaps make 10 batches?
nipu = 16
ij, kl = get_i_j(nonzero_indices)
rep = num_repetitions_fast(ij, kl)
nonzero_distinct_ERI = nonzero_distinct_ERI / rep
batches = int(opts.batches) # perhaps make 10 batches?
nipu = opts.ndevices
remainder = nonzero_indices.shape[0] % (nipu*batches)

if remainder != 0:
nonzero_indices = np.pad(nonzero_indices, (0,nipu*batches-remainder))
print(nipu*batches-remainder, ij.shape)
ij = np.pad(ij, ((0,nipu*batches-remainder)))
kl = np.pad(kl, ((0,nipu*batches-remainder)))
nonzero_distinct_ERI = np.pad(nonzero_distinct_ERI, (0,nipu*batches-remainder))

print("reshape to 16 IPUs pmap", time.time()-start)
nonzero_indices = nonzero_indices.reshape(nipu, batches, -1)
ij = ij.reshape(nipu, batches, -1)
kl = kl.reshape(nipu, batches, -1)
nonzero_distinct_ERI = nonzero_distinct_ERI.reshape(nipu, batches, -1)
print(nonzero_indices.shape)
print(nonzero_distinct_ERI.shape)

sparse_ERI = [nonzero_distinct_ERI, nonzero_indices]
i, j = get_i_j(ij.reshape(-1))
k, l = get_i_j(kl.reshape(-1))
nonzero_indices = np.vstack([i,j,k,l]).T.reshape(nipu, batches, -1, 4).astype(np.int16)
nonzero_indices = jax.lax.bitcast_convert_type(nonzero_indices, np.float16)

tensors.append(sharded_grid_AO)
tensors.append(sharded_grid_weights)
sparse_ERI = [nonzero_distinct_ERI, nonzero_indices]
tensors.append(sparse_ERI)
eri_in_axes = [0,0]
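
The int16-to-float16 bitcast above is only a transport trick: the four AO indices are packed bit-for-bit into a 16-bit float dtype so they can be shipped through the pmap'd inputs and recovered exactly on device. A tiny round-trip sketch with made-up indices (not the PR's data):

```python
import jax
import jax.numpy as jnp
import numpy as np

indices = np.arange(8, dtype=np.int16).reshape(2, 4)                       # pretend (i, j, k, l) rows
packed = jax.lax.bitcast_convert_type(jnp.asarray(indices), jnp.float16)   # same bits, float16 view
restored = jax.lax.bitcast_convert_type(packed, jnp.int16)                 # bit-exact recovery
assert np.array_equal(np.asarray(restored), indices)
```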

vals = jitted_nanoDFT(*tensors)
logged_matrices, H_core, logged_energies = [np.asarray(a[0]).astype(np.float64) for a in vals] # Ensure CPU
else:
jitted_nanoDFT = jax.jit(partial(_nanoDFT, opts=opts, mol=mol), backend=opts.backend)
tensors = tensors + (sharded_grid_AO,)
jitted_nanoDFT = jax.pmap(partial(_nanoDFT, opts=opts, mol=mol), axis_name="p", backend=opts.backend,
in_axes=(None, None, None, None, None, None, None, None, None, (None, None, None), 0, 0, eri_in_axes))

vals = jitted_nanoDFT(*tensors)
logged_matrices, H_core, logged_energies = [np.asarray(a).astype(np.float64) for a in vals] # Ensure CPU
jitted_nanoDFT(*tensors)
vals = jitted_nanoDFT(*tensors)
logged_matrices, H_core, logged_energies = [np.asarray(a[0]).astype(np.float64) for a in vals] # Ensure CPU
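
The in_axes tuple above mixes None (argument replicated to every device) with 0 (leading axis split across devices). A toy illustration of that convention, unrelated to the DFT tensors themselves:

```python
import jax
import jax.numpy as jnp

ndev = jax.local_device_count()

def step(replicated_matrix, sharded_rows):
    # Each device multiplies its own slice of rows by the shared matrix.
    return sharded_rows @ replicated_matrix

out = jax.pmap(step, in_axes=(None, 0))(
    jnp.eye(3),                 # in_axes=None: the same (3, 3) matrix on every device
    jnp.ones((ndev, 2, 3)),     # in_axes=0: each device receives one (2, 3) slice
)
print(out.shape)                # (ndev, 2, 3)
```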

# It's cheap to compute energy/hlgap on CPU in float64 from the logged values/matrices.
logged_E_xc = logged_energies[:, 3].copy()
@@ -402,7 +397,6 @@ def hcore_deriv(atm_id, aoslices, h1): # <\nabla|1/r|>
return _grad_elec(weight, ao, eri, s1, h1aos, mol.natm, tuple([tuple(a) for a in aoslices.tolist()]), mask, mo_energy, mo_coeff) + _grad_nuc(charges, coords)

def pyscf_reference(mol_str, opts):

from pyscf import __config__
__config__.dft_rks_RKS_grids_level = opts.level

@@ -474,6 +468,12 @@ def nanoDFT_options(
threads_int: int = 1,
diis: bool = True,
structure_optimization: bool = False, # AKA gradient descent on energy wrt nuclei
batches: int = 32,
ndevices: int = 1,
dense_ERI: bool = False, # use dense (N, N, N, N) ERI instead of the sparse/distinct representation.
v: bool = False, # verbose
profile: bool = False # if we only want a profile, exit right after the IPU run finishes.

):
"""
nanoDFT
@@ -526,11 +526,37 @@ def nanoDFT_options(
# mol_str = [["H", (0, 0, n) for n in range(8*2*2*2*2)]] # N=256
basis = "6-31g"
elif mol_str == "cgrid":
mol_str = [["C", (0, i*1.5, j*1.5)] for i in range(4) for j in range(4)]
mol_str = [["C", (0, i*1.5, j*1.5)] for i in range(10) for j in range(10)]
#mol_str = [["C", (0, i*1.5, j*1.5)] for i in range(2) for j in range(2)]
#mol_str = [["C", (0, i*1.5, j*1.5)] for i in range(5) for j in range(6)]
#mol_str = [["C", (0, i*1.5, j*1.5)] for i in range(5) for j in range(5)]
#mol_str = [["C", (0, i*1.5, j*1.5)] for i in range(6) for j in range(6)]
#mol_str = [["C", (0, i*1.5, j*1.5)] for i in range(7) for j in range(7)]
elif mol_str == "c20":
mol_str = """C 1.56910 -0.65660 -0.93640 ;
C 1.76690 0.64310 -0.47200 ;
C 0.47050 -0.66520 -1.79270 ;
C 0.01160 0.64780 -1.82550 ;
C 0.79300 1.46730 -1.02840 ;
C -0.48740 -1.48180 -1.21570;
C -1.56350 -0.65720 -0.89520 ;
C -1.26940 0.64900 -1.27670 ;
C -0.00230 -1.96180 -0.00720 ;
C -0.76980 -1.45320 1.03590 ;
C -1.75760 -0.63800 0.47420 ;
C 1.28780 -1.45030 0.16290 ;
C 1.28960 -0.65950 1.30470;
C 0.01150 -0.64600 1.85330 ;
C 1.58300 0.64540 0.89840 ;
C 0.48480 1.43830 1.19370 ;
C -0.50320 0.64690 1.77530 ;
C -1.60620 0.67150 0.92310 ;
C -1.29590 1.48910 -0.16550;
C -0.01020 1.97270 -0.00630 ;"""




args = locals()
mol_str = args["mol_str"]
del args["mol_str"]