
Commit

100x+ speedup in shell computation with c++ and openmp; can generate shell indices and sizes for 100 atoms
mihaipgc committed Nov 1, 2023
1 parent c28e84e commit 4e36058
Showing 4 changed files with 190 additions and 50 deletions.
15 changes: 15 additions & 0 deletions pyscf_ipu/nanoDFT/Makefile
@@ -0,0 +1,15 @@
# Copyright (c) 2023 Graphcore Ltd. All rights reserved.

CXX = g++
CXXFLAGS = -fopenmp -O3 -Wall -shared -std=c++11 -fPIC
PYBIND11_INCLUDES = $(shell python3 -m pybind11 --includes)
PYTHON_EXTENSION_SUFFIX = $(shell python3-config --extension-suffix)
TARGET = cpp_shell_gen$(PYTHON_EXTENSION_SUFFIX)

all: $(TARGET)

$(TARGET): cpp_shell_gen.cpp
$(CXX) $(CXXFLAGS) $(PYBIND11_INCLUDES) $< -o $@

clean:
rm -f $(TARGET)
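
For orientation, a minimal smoke test (not part of the commit) that builds the extension with this Makefile and checks that it imports; the paths assume the pyscf_ipu/nanoDFT layout shown above and a toolchain with g++, OpenMP and pybind11 installed:

# Hypothetical build-and-import check.
import subprocess
import sys

subprocess.run(["make"], cwd="pyscf_ipu/nanoDFT", check=True)  # produces cpp_shell_gen<suffix>.so

sys.path.insert(0, "pyscf_ipu/nanoDFT")  # make the freshly built module importable
import cpp_shell_gen

print(cpp_shell_gen.compute_indices)  # the function bound in cpp_shell_gen.cpp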
105 changes: 105 additions & 0 deletions pyscf_ipu/nanoDFT/cpp_shell_gen.cpp
@@ -0,0 +1,105 @@
// Copyright (c) 2023 Graphcore Ltd. All rights reserved.

#include <pybind11/pybind11.h>
#include <pybind11/numpy.h>
#include <pybind11/stl.h>
#include <unordered_set>
#include <omp.h>

namespace py = pybind11;

// Define a custom hash function for std::pair<int, int>
struct pair_hash {
template <class T1, class T2>
std::size_t operator () (std::pair<T1, T2> const& pair) const {
auto h1 = std::hash<T1>{}(pair.first);
auto h2 = std::hash<T2>{}(pair.second);
return h1 ^ h2; // note: XOR is order-insensitive, so (a, b) and (b, a) hash identically; correctness is unaffected, at the cost of extra bucket collisions
}
};

std::tuple<py::array_t<int>, py::array_t<int>, int> compute_indices(int n_bas, py::array_t<int> ao_loc_array, std::unordered_set<std::pair<int, int>, pair_hash> considered_indices) {
py::buffer_info ao_loc_info = ao_loc_array.request();
int* ao_loc_ptr = static_cast<int*>(ao_loc_info.ptr);

py::ssize_t n_upper_bound = (n_bas * (n_bas) / 2) * (n_bas * (n_bas - 1) / 2);
// py::print(n_upper_bound);
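// Both output buffers are preallocated to this upper bound and filled with -1;
// rows that are never written keep the -1 padding and are filtered out by the Python caller.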
auto input_ijkl_array = py::array_t<int>({n_upper_bound, (py::ssize_t)4});
input_ijkl_array[py::make_tuple(py::ellipsis())] = -1;
auto output_sizes_array = py::array_t<int>({n_upper_bound, (py::ssize_t)5});
output_sizes_array[py::make_tuple(py::ellipsis())] = -1;

py::buffer_info input_ijkl_info = input_ijkl_array.request();
int* input_ijkl_ptr = static_cast<int*>(input_ijkl_info.ptr);

py::buffer_info output_sizes_info = output_sizes_array.request();
int* output_sizes_ptr = static_cast<int*>(output_sizes_info.ptr);

py::ssize_t num_calls = 0;
#pragma omp parallel for reduction(+:num_calls) // Parallelize the outermost loop
for (py::ssize_t i = 0; i < n_bas; ++i) {
py::ssize_t partial_num_calls = 0;
// py::print(".", py::arg("end")="", py::arg("flush") = true); // do not print from threads
for (py::ssize_t j = 0; j <= i; ++j) {
for (py::ssize_t k = i; k < n_bas; ++k) {
for (py::ssize_t l = 0; l <= k; ++l) {
int di = ao_loc_ptr[i + 1] - ao_loc_ptr[i];
int dj = ao_loc_ptr[j + 1] - ao_loc_ptr[j];
int dk = ao_loc_ptr[k + 1] - ao_loc_ptr[k];
int dl = ao_loc_ptr[l + 1] - ao_loc_ptr[l];

bool found_nonzero = false;
for (int bi = ao_loc_ptr[i]; bi < ao_loc_ptr[i + 1]; ++bi) {
for (int bj = ao_loc_ptr[j]; bj < ao_loc_ptr[j + 1]; ++bj) {
if (considered_indices.count({bi, bj}) > 0) {
for (int bk = ao_loc_ptr[k]; bk < ao_loc_ptr[k + 1]; ++bk) {
if (bk >= bi) {
int mla = ao_loc_ptr[l];
if (bk == bi) {
mla = std::max(bj, ao_loc_ptr[l]);
}
for (int bl = mla; bl < ao_loc_ptr[l + 1]; ++bl) {
if (considered_indices.count({bk, bl}) > 0) {
found_nonzero = true;
break;
}
}
if (found_nonzero) break;
}
}
if (found_nonzero) break;
}
if (found_nonzero) break;
}
if (found_nonzero) break;
}

if (!found_nonzero) continue;
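// Rows are grouped into per-i blocks: block i starts at row (i * n_bas / 2) * (n_bas * (n_bas - 1) / 2)
// and partial_num_calls indexes within that block, so each i writes into its own
// region of the preallocated arrays and OpenMP threads do not contend.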

py::ssize_t offset_ijkl = (i * (n_bas) / 2) * (n_bas * (n_bas - 1) / 2) * 4 + partial_num_calls * 4;
input_ijkl_ptr[offset_ijkl + 0] = i;
input_ijkl_ptr[offset_ijkl + 1] = j;
input_ijkl_ptr[offset_ijkl + 2] = k;
input_ijkl_ptr[offset_ijkl + 3] = l;

py::ssize_t offset_sizes = (i * (n_bas) / 2) * (n_bas * (n_bas - 1) / 2) * 5 + partial_num_calls * 5;
output_sizes_ptr[offset_sizes + 0] = di;
output_sizes_ptr[offset_sizes + 1] = dj;
output_sizes_ptr[offset_sizes + 2] = dk;
output_sizes_ptr[offset_sizes + 3] = dl;
output_sizes_ptr[offset_sizes + 4] = di * dj * dk * dl;

partial_num_calls += 1;

}
}
}
num_calls += partial_num_calls;
}

return std::make_tuple(input_ijkl_array, output_sizes_array, num_calls);
}

PYBIND11_MODULE(cpp_shell_gen, m) {
m.def("compute_indices", &compute_indices);
}
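
To make the calling convention concrete, here is a hedged, self-contained sketch (toy shell layout invented for illustration, not data from the repo) of invoking the binding and stripping the -1 padding the same way gen_shells does below:

# Three toy shells spanning 1, 2 and 3 AOs; every lower-triangular AO pair
# is treated as surviving the screening step.
import numpy as np
import cpp_shell_gen

n_bas = 3
ao_loc = np.array([0, 1, 3, 6], dtype=np.int32)  # shell i covers AOs ao_loc[i]:ao_loc[i+1]
N = int(ao_loc[-1])
considered = {(a, b) for a in range(N) for b in range(a + 1)}

ijkl, sizes, num_calls = cpp_shell_gen.compute_indices(n_bas, ao_loc, considered)

# Rows never written by the C++ loops keep their -1 padding; drop them and
# check the count, mirroring the Python caller.
ijkl = ijkl[np.any(ijkl >= 0, axis=1)]
sizes = sizes[np.any(sizes >= 0, axis=1)]
assert ijkl.shape[0] == sizes.shape[0] == num_calls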
5 changes: 3 additions & 2 deletions pyscf_ipu/nanoDFT/nanoDFT.py
@@ -320,7 +320,7 @@ def nanoDFT(mol, opts):
ERI = [nonzero_distinct_ERI, nonzero_indices]
eri_in_axes = [0,0]

input_ijkl, _, _, _ = gen_shells(mol, opts.screen_tol, nipu)
input_ijkl, _, _, _ = gen_shells(mol, opts.screen_tol, nipu, fast_shells=opts.fast_shells)
grid_ijkl = np.concatenate([np.array(ijkl, dtype=int).reshape(nipu, -1) for ijkl in input_ijkl], axis=-1)

#jitted_nanoDFT = jax.jit(partial(_nanoDFT, opts=opts, mol=mol), backend=opts.backend)
@@ -577,7 +577,8 @@ def nanoDFT_options(
profile: bool = False, # if we only want profile exit after IPU finishes.
vis_num_error: bool = False,
molecule_name: str = None,
screen_tol: float = 1e-9
screen_tol: float = 1e-9,
fast_shells: bool = False
):
"""
nanoDFT
115 changes: 67 additions & 48 deletions pyscf_ipu/nanoDFT/sparse_symmetric_intor_ERI.py
@@ -9,6 +9,7 @@
from functools import partial, lru_cache
from icecream import ic
from tqdm import tqdm

jax.config.update('jax_platform_name', "cpu")
#jax.config.update('jax_enable_x64', True)
HYB_B3LYP = 0.2
@@ -69,7 +70,7 @@ def num_repetitions_fast_4d(i, j, k, l, xnp=np, dtype=np.uint64):

# gen_shells is designed to be called multiple times to simplify the rest of the code, but computes shells only once (lru_cache)
@lru_cache(maxsize=None) # cache results
def gen_shells(mol, tolerance, ndevices):
def gen_shells(mol, tolerance, ndevices, fast_shells=True):
N = mol.nao_nr()
bas = mol._bas
n_bas = bas.shape[0]
@@ -90,60 +91,77 @@ def gen_shells(mol, tolerance, ndevices):
lst_abab[c] = abab
lst_ab[c, :] = (a, b)
abab_max = np.max(lst_abab)
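# Cauchy-Schwarz screening: |(ab|cd)|^2 <= (ab|ab)*(cd|cd), so an AO pair (a, b)
# is kept only if (ab|ab) * max_cd (cd|cd) can still reach tolerance^2.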
considered_indices = set([(a,b) for abab, (a, b) in tqdm(zip(lst_abab, lst_ab)) if abab*abab_max >= tolerance**2])
considered_indices_list = [(a,b) for abab, (a, b) in tqdm(zip(lst_abab, lst_ab)) if abab*abab_max >= tolerance**2]
considered_indices_array = np.array(considered_indices_list)
considered_indices = set(considered_indices_list)
else:
lst_ab = np.zeros((N*(N+1)//2, 2), dtype=np.int32)
tril_idx = np.tril_indices(N)
considered_indices = set([(a, b) for a, b in zip(tril_idx[0], tril_idx[1])])
considered_indices_list = [(a, b) for a, b in zip(tril_idx[0], tril_idx[1])]
considered_indices_array = np.array(considered_indices_list)
considered_indices = set(considered_indices_list)
print('n_bas', n_bas)
print('ao_loc', ao_loc)

# Step 2: Remove zeros by orthogonality and match indices with shells.
# Precompute and fill input_ijkl and output_sizes with the necessary indices.
ortho_pattern = np.array([(i+3)%5!=0 for i in range(N)]) # hardcoded
n_upper_bound = (n_bas*(n_bas-1))**2
input_ijkl = np.zeros((n_upper_bound, 4), dtype=np.int32)
output_sizes = np.zeros((n_upper_bound, 5))

num_calls = 0

for i in tqdm(range(n_bas), desc="Computing shells"): # consider all shells << all ijkl
for j in range(i+1):
for k in range(i, n_bas):
for l in range(k+1):
di, dj, dk, dl = [ao_loc[z+1] - ao_loc[z] for z in [i,j,k,l]]

found_nonzero = False
# check i,j boxes
for bi in range(ao_loc[i], ao_loc[i+1]):
for bj in range(ao_loc[j], ao_loc[j+1]):
if (bi, bj) in considered_indices: # if ij box is considered
# check if kl pairs are considered
for bk in range(ao_loc[k], ao_loc[k+1]):
if bk>=bi: # apply symmetry - tril fade vertical
mla = ao_loc[l]
if bk == bi:
mla = max(bj, ao_loc[l]) # apply symmetry - tril fade horizontal
for bl in range(mla, ao_loc[l+1]):
if (bk, bl) in considered_indices:
# apply grid pattern to find final nonzeros
if False and not ortho_pattern[bi] ^ ortho_pattern[bj] ^ ortho_pattern[bk] ^ ortho_pattern[bl]:
found_nonzero = True
break
else:
found_nonzero = True
break
if found_nonzero: break

if fast_shells:
import cpp_shell_gen
print('Computing shells', end=' ')
test_input_ijkl, test_output_sizes, num_calls = cpp_shell_gen.compute_indices(n_bas, ao_loc, considered_indices)
print('done.')
print(test_input_ijkl.shape)
print(test_output_sizes.shape)

input_ijkl = test_input_ijkl[np.any(test_input_ijkl>=0, axis=1)] #test_input_ijkl[:num_calls, :]
output_sizes = test_output_sizes[np.any(test_output_sizes>=0, axis=1)] #test_output_sizes[:num_calls, :]
assert input_ijkl.shape[0] == num_calls, (input_ijkl.shape[0], num_calls)
assert output_sizes.shape[0] == num_calls, (output_sizes.shape[0], num_calls)
else:
# Step 2: Remove zeros by orthogonality and match indices with shells.
# Precompute and fill input_ijkl and output_sizes with the necessary indices.
ortho_pattern = np.array([(i+3)%5!=0 for i in range(N)]) # hardcoded
n_upper_bound = (n_bas*(n_bas-1))**2
input_ijkl = np.zeros((n_upper_bound, 4), dtype=np.int32)
output_sizes = np.zeros((n_upper_bound, 5))

num_calls = 0

for i in tqdm(range(n_bas), desc="Computing shells"): # consider all shells << all ijkl
for j in range(i+1):
for k in range(i, n_bas):
for l in range(k+1):
di, dj, dk, dl = [ao_loc[z+1] - ao_loc[z] for z in [i,j,k,l]]

found_nonzero = False
# check i,j boxes
for bi in range(ao_loc[i], ao_loc[i+1]):
for bj in range(ao_loc[j], ao_loc[j+1]):
if (bi, bj) in considered_indices: # if ij box is considered
# check if kl pairs are considered
for bk in range(ao_loc[k], ao_loc[k+1]):
if bk>=bi: # apply symmetry - tril fade vertical
mla = ao_loc[l]
if bk == bi:
mla = max(bj, ao_loc[l]) # apply symmetry - tril fade horizontal
for bl in range(mla, ao_loc[l+1]):
if (bk, bl) in considered_indices:
# apply grid pattern to find final nonzeros
if False and not ortho_pattern[bi] ^ ortho_pattern[bj] ^ ortho_pattern[bk] ^ ortho_pattern[bl]:
found_nonzero = True
break
else:
found_nonzero = True
break
if found_nonzero: break
if found_nonzero: break
if found_nonzero: break
if found_nonzero: break
if not found_nonzero: continue
if not found_nonzero: continue

input_ijkl[num_calls] = [i, j, k, l]
output_sizes[num_calls] = [di, dj, dk, dl, di*dj*dk*dl]
num_calls += 1
input_ijkl[num_calls] = [i, j, k, l]
output_sizes[num_calls] = [di, dj, dk, dl, di*dj*dk*dl]
num_calls += 1

input_ijkl = input_ijkl[:num_calls, :]
output_sizes = output_sizes[:num_calls, :]
input_ijkl = input_ijkl[:num_calls, :]
output_sizes = output_sizes[:num_calls, :]

sizes, counts = [tuple(out.astype(np.int32).tolist()) for out in np.unique(output_sizes[:, -1], return_counts=True)]
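
Because both code paths stay selectable through fast_shells, a hypothetical equivalence check (assuming gen_shells is importable from this module and returns the per-size quartet groups used at the call sites) could look like:

import numpy as np
from pyscf import gto
from pyscf_ipu.nanoDFT.sparse_symmetric_intor_ERI import gen_shells  # assumed import path

mol = gto.M(atom="O 0 0 0; H 0 0 1; H 0 1 0", basis="6-311G")

slow = gen_shells(mol, 1e-9, 1, fast_shells=False)  # pure-Python reference
fast = gen_shells(mol, 1e-9, 1, fast_shells=True)   # C++/OpenMP path

for ref, new in zip(slow[0], fast[0]):  # compare the per-size groups of (i, j, k, l)
    assert np.array_equal(np.asarray(ref), np.asarray(new))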

@@ -427,6 +445,7 @@ def foreach_symmetry(sym, vals):
parser.add_argument('-skip', action="store_true")
parser.add_argument('-itol', default=1e-9, type=float)
parser.add_argument('-basis', default="6-311G", type=str)
parser.add_argument('-fast_shells', action="store_true")

args = parser.parse_args()
backend = args.backend
@@ -471,7 +490,7 @@ def foreach_symmetry(sym, vals):

# ------------------------------------ #

input_ijkl, sizes, counts, shapes = gen_shells(mol, args.itol, args.nipu)
input_ijkl, sizes, counts, shapes = gen_shells(mol, args.itol, args.nipu, fast_shells=args.fast_shells)
sliced_input_ijkl = np.concatenate([np.array(ijkl, dtype=int).reshape(args.nipu, -1) for ijkl in input_ijkl], axis=-1)

diff_JK = jax.pmap(compute_diff_jk, in_axes=(0, None, None, None, None, None, None), static_broadcasted_argnums=(2, 3, 4, 5, 6), backend=backend, axis_name="p")(sliced_input_ijkl, dm, mol, args.batches, args.itol, args.nipu, args.backend)
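
For intuition, a toy numpy sketch (numbers invented) of the sliced_input_ijkl layout above: each per-size group of quartet indices is reshaped to (nipu, -1) so every device gets an equal contiguous slice of every group, and the slices are concatenated along the last axis:

import numpy as np

nipu = 2
group_a = np.arange(8)   # pretend: 2 quartets of one shell shape, 4 ints each
group_b = np.arange(16)  # pretend: 4 quartets of another shape

sliced = np.concatenate([g.reshape(nipu, -1) for g in (group_a, group_b)], axis=-1)
print(sliced.shape)  # (2, 12): one row of quartet indices per device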
