Skip to content

Commit

Permalink
Fixed kinetic grad and added a todo.
Browse files Browse the repository at this point in the history
  • Loading branch information
alexandermath committed Oct 12, 2023
1 parent 6909bb6 commit 2fb81a0
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 52 deletions.
10 changes: 3 additions & 7 deletions pyscf_ipu/electron_repulsion/popcint/libcint.c
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,6 @@ extern "C" {
#define PTR_GRIDS 12
#define PTR_ENV_START 20


#define CHARGE_OF 0
#define PTR_COORD 1
#define NUC_MOD_OF 2
Expand All @@ -62,13 +61,11 @@ extern "C" {
#define RESERVE_BASLOT 7
#define BAS_SLOTS 8


#define POSX 0
#define POSY 1
#define POSZ 2
#define POS1 3


#define POSXX 0
#define POSYX 1
#define POSZX 2
Expand Down Expand Up @@ -576,7 +573,7 @@ dtype CINTgto_norm(FINT n, dtype a);
var = reinterpret_cast<decltype(var)>(new char[(n) * sizeof(*var)]);

#define MALLOC(type, var) \
type var[64];
type var[256];

#else
#define MALLOC_INSTACK(var, n) \
Expand Down Expand Up @@ -1050,7 +1047,6 @@ CACHE_SIZE_T CINT1e_drv(dtype *out, FINT *dims, CINTEnvVars *envs,
printf("\n");*/



FINT counts[4];
if (dims == NULL) {
dims = counts;
Expand Down Expand Up @@ -28592,7 +28588,7 @@ void GTOint2c(int (*intor)(), dtype *mat, int comp, int hermi,
int shls[2];
#ifdef __cplusplus
//dtype cache[128];
dtype cache[128];
dtype cache[128*2];
#else
dtype *cache = malloc(sizeof(dtype) * cache_size);
#endif
Expand Down Expand Up @@ -28637,7 +28633,7 @@ void GTOint2c(int (*intor)(), dtype *mat, int comp, int hermi,
printf("\n");
return; */

//int1e_kin_sph(mat+j0*naoi+i0, dims, shls, atm, natm, bas, nbas, env, opt, cache);
//int1e_ipkin_sph(mat+j0*naoi+i0, dims, shls, atm, natm, bas, nbas, env, opt, cache);
//printf("%f\n", (mat+j0*naoi+i0)[0]);
/*for (int asd = 0; asd < 12; asd++){
printf("%f ", mat[asd]);
Expand Down
70 changes: 25 additions & 45 deletions pyscf_ipu/electron_repulsion/popcint/libcint.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,20 +3,41 @@
# [ ] Add tile-mapping of the integral computation (what is the basic unit we tile-map?).
# [ ] Consider how to interface this code into nanoDFT.
# [ ] Remove hard-coding of tensor (i.e. move shape computation to python/jax.trace).
# [ ] Add direct matmul.
# [ ] For the -all test, compile the graph once, then run all tests (add more molecules for this as well).
import os
import pyscf
import numpy as np
import ctypes
import ctypes
import numpy
from pyscf import lib
from icecream import ic
from functools import partial
import os.path as osp
import jax
import jax.numpy as jnp
from tessellate_ipu import create_ipu_tile_primitive, ipu_cycle_count, tile_map, tile_put_sharded, tile_put_replicated
vertex_filename = osp.join(osp.dirname(__file__), "libcint.cpp")
int2e = create_ipu_tile_primitive(
"Int2e" ,
"Int2e" ,
inputs=["mat", "shls_slice", "ao_loc", "atm", "bas", "env", "natm", "nbas", "which_integral", "comp"],
outputs={"out": 0},
gp_filename=vertex_filename,
perf_estimate=100,
)
grad = create_ipu_tile_primitive(
"Grad" ,
"Grad" ,
inputs=["mat", "shls_slice", "ao_loc", "atm", "bas", "env", "natm", "nbas", "which_integral"],
outputs={"out": 0},
gp_filename=vertex_filename,
perf_estimate=100,
)


libcgto = numpy.ctypeslib.load_library("libcint.so", "")
float32 = "#define dtype float" in open("libcint.c", "r").read()

ANG_OF = 1
NPRIM_OF = 2
Expand Down Expand Up @@ -73,7 +94,6 @@ def getints2c(intor_name, N, atm, bas, env, shls_slice=None, comp=1, hermi=0,
cintopt = None

# type
float32 = "#define dtype float" in open("libcint.c", "r").read()
if float32:
mat = mat.astype(np.float32)
env = env.astype(np.float32)
Expand All @@ -97,17 +117,8 @@ def cpu_intor1e(self, intor, N, comp=None, hermi=0, aosym='s1', out=None, shls_s

@partial(jax.jit, backend="ipu", static_argnums=(0,1,2,3,4,5,6,7,8))
def ipu_intor1e(self, intor, which_integral, N, comp=None, hermi=0, aosym='s1', out=None, shls_slice=None, grids=None):
from tessellate_ipu import create_ipu_tile_primitive, ipu_cycle_count, tile_map, tile_put_sharded, tile_put_replicated
vertex_filename = osp.join(osp.dirname(__file__), "libcint.cpp")
#mat, shls_slice, ao_loc, atm, bas, env
grad = create_ipu_tile_primitive(
"Grad" ,
"Grad" ,
inputs=["mat", "shls_slice", "ao_loc", "atm", "bas", "env", "natm", "nbas", "which_integral"],
outputs={"out": 0},
gp_filename=vertex_filename,
perf_estimate=100,
)


intor_name, N, atm, bas, env, shls_slice, comp, hermi, ao_loc, cintopt, out=\
intor+"_sph", N, self._atm, self._bas, self._env, shls_slice, comp, hermi, None, None, out
Expand All @@ -124,7 +135,6 @@ def ipu_intor1e(self, intor, which_integral, N, comp=None, hermi=0, aosym='s1',
mat = numpy.ndarray(shape, dtype, out, order='F')

# type
float32 = "#define dtype float" in open("libcint.c", "r").read()
if float32:
mat = mat.astype(np.float32)
env = env.astype(np.float32)
Expand Down Expand Up @@ -169,7 +179,6 @@ def cpu_getints4c(intor_name, atm, bas, env, N, shls_slice=None, comp=1,
out = numpy.zeros(shape)

# type
float32 = "#define dtype float" in open("libcint.c", "r").read()
if float32:
out = out.astype(np.float32)
env = env.astype(np.float32)
Expand Down Expand Up @@ -232,7 +241,6 @@ def cpu_tile_map_getints4c(intor_name, atm, bas, env, N, shls_slice=None, comp=1
eri = numpy.zeros(shape)

# type
float32 = "#define dtype float" in open("libcint.c", "r").read()
if float32:
out = out.astype(np.float32)
eri = eri.astype(np.float32)
Expand All @@ -242,7 +250,6 @@ def cpu_tile_map_getints4c(intor_name, atm, bas, env, N, shls_slice=None, comp=1

cintopt = None
prescreen = lib.c_null_ptr()
ic(intor_name, 'GTOnr2e_fill_'+aosym, "GTOnr2e_fill_drv")
drv(getattr(libcgto, intor_name), fill, prescreen,
out.ctypes.data_as(ctypes.c_void_p), ctypes.c_int(comp),
(ctypes.c_int*8)(*shls_slice),
Expand Down Expand Up @@ -279,29 +286,15 @@ def ipu_getints4c(intor_name, atm, bas, env, N, shls_slice=None, comp=1,
out = numpy.ndarray(shape, buffer=out)

# type
float32 = "#define dtype float" in open("libcint.c", "r").read()
if float32:
out = out.astype(np.float32)
env = env.astype(np.float32)


from tessellate_ipu import create_ipu_tile_primitive, ipu_cycle_count, tile_map, tile_put_sharded, tile_put_replicated
vertex_filename = osp.join(osp.dirname(__file__), "libcint.cpp")
grad = create_ipu_tile_primitive(
"Int2e" ,
"Int2e" ,
inputs=["mat", "shls_slice", "ao_loc", "atm", "bas", "env", "natm", "nbas", "which_integral", "comp"],
outputs={"out": 0},
gp_filename=vertex_filename,
perf_estimate=100,
)

natm = atm.shape[0]
nbas = bas.shape[0]

prefix = 'GTO'

float32 = "#define dtype float" in open("libcint.c", "r").read()
if float32:
out = out.astype(np.float32)
env = env.astype(np.float32)
Expand All @@ -319,7 +312,7 @@ def ipu_getints4c(intor_name, atm, bas, env, N, shls_slice=None, comp=1,

which_integral = tile_put_replicated(np.array(which_integral, dtype=jnp.int32), (1,))

value = tile_map(grad, out, shls_slice, ao_loc, atm, bas, env, natm, nbas, which_integral, comp)
value = tile_map(int2e, out, shls_slice, ao_loc, atm, bas, env, natm, nbas, which_integral, comp)

out = value.array

Expand All @@ -343,29 +336,16 @@ def ipu_tile_map_getints4c(intor_name, atm, bas, env, N, shls_slice=None, comp=1
out = numpy.ndarray(shape, buffer=out)

# type
float32 = "#define dtype float" in open("libcint.c", "r").read()
if float32:
out = out.astype(np.float32)
env = env.astype(np.float32)


from tessellate_ipu import create_ipu_tile_primitive, ipu_cycle_count, tile_map, tile_put_sharded, tile_put_replicated
vertex_filename = osp.join(osp.dirname(__file__), "libcint.cpp")
grad = create_ipu_tile_primitive(
"Int2e" ,
"Int2e" ,
inputs=["mat", "shls_slice", "ao_loc", "atm", "bas", "env", "natm", "nbas", "which_integral", "comp"],
outputs={"out": 0},
gp_filename=vertex_filename,
perf_estimate=100,
)

natm = atm.shape[0]
nbas = bas.shape[0]

prefix = 'GTO'

float32 = "#define dtype float" in open("libcint.c", "r").read()
if float32:
out = out.astype(np.float32)
env = env.astype(np.float32)
Expand All @@ -383,7 +363,7 @@ def ipu_tile_map_getints4c(intor_name, atm, bas, env, N, shls_slice=None, comp=1

which_integral = tile_put_replicated(np.array(which_integral, dtype=jnp.int32), (1,))

value = tile_map(grad, out, shls_slice, ao_loc, atm, bas, env, natm, nbas, which_integral, comp)
value = tile_map(int2e, out, shls_slice, ao_loc, atm, bas, env, natm, nbas, which_integral, comp)

out = value.array

Expand Down

0 comments on commit 2fb81a0

Please sign in to comment.