From 2fb81a0ea5eb8fcaa752fea413146da4abf2d582 Mon Sep 17 00:00:00 2001 From: alexandermath Date: Thu, 12 Oct 2023 11:48:37 +0000 Subject: [PATCH] Fixed kinetic grad and added a todo. --- .../electron_repulsion/popcint/libcint.c | 10 +-- .../electron_repulsion/popcint/libcint.py | 70 +++++++------------ 2 files changed, 28 insertions(+), 52 deletions(-) diff --git a/pyscf_ipu/electron_repulsion/popcint/libcint.c b/pyscf_ipu/electron_repulsion/popcint/libcint.c index 7fd37da..f45cae4 100644 --- a/pyscf_ipu/electron_repulsion/popcint/libcint.c +++ b/pyscf_ipu/electron_repulsion/popcint/libcint.c @@ -42,7 +42,6 @@ extern "C" { #define PTR_GRIDS 12 #define PTR_ENV_START 20 - #define CHARGE_OF 0 #define PTR_COORD 1 #define NUC_MOD_OF 2 @@ -62,13 +61,11 @@ extern "C" { #define RESERVE_BASLOT 7 #define BAS_SLOTS 8 - #define POSX 0 #define POSY 1 #define POSZ 2 #define POS1 3 - #define POSXX 0 #define POSYX 1 #define POSZX 2 @@ -576,7 +573,7 @@ dtype CINTgto_norm(FINT n, dtype a); var = reinterpret_cast(new char[(n) * sizeof(*var)]); #define MALLOC(type, var) \ - type var[64]; + type var[256]; #else #define MALLOC_INSTACK(var, n) \ @@ -1050,7 +1047,6 @@ CACHE_SIZE_T CINT1e_drv(dtype *out, FINT *dims, CINTEnvVars *envs, printf("\n");*/ - FINT counts[4]; if (dims == NULL) { dims = counts; @@ -28592,7 +28588,7 @@ void GTOint2c(int (*intor)(), dtype *mat, int comp, int hermi, int shls[2]; #ifdef __cplusplus //dtype cache[128]; - dtype cache[128]; + dtype cache[128*2]; #else dtype *cache = malloc(sizeof(dtype) * cache_size); #endif @@ -28637,7 +28633,7 @@ void GTOint2c(int (*intor)(), dtype *mat, int comp, int hermi, printf("\n"); return; */ - //int1e_kin_sph(mat+j0*naoi+i0, dims, shls, atm, natm, bas, nbas, env, opt, cache); + //int1e_ipkin_sph(mat+j0*naoi+i0, dims, shls, atm, natm, bas, nbas, env, opt, cache); //printf("%f\n", (mat+j0*naoi+i0)[0]); /*for (int asd = 0; asd < 12; asd++){ printf("%f ", mat[asd]); diff --git a/pyscf_ipu/electron_repulsion/popcint/libcint.py b/pyscf_ipu/electron_repulsion/popcint/libcint.py index 3650a57..b0cb85d 100644 --- a/pyscf_ipu/electron_repulsion/popcint/libcint.py +++ b/pyscf_ipu/electron_repulsion/popcint/libcint.py @@ -3,6 +3,8 @@ # [ ] Add tile-mapping of integral computation (what is basic unit we tmap? ). # [ ] Consider how to interface this code into nanoDFT. # [ ] Remove hard-coding of tensor (i.e. move shape computation to python/jax.trace). +# [ ] Add direct matmul. +# [ ] For -all test, compile graph once, then do all tests (add more molecules for this as well). import os import pyscf import numpy as np @@ -10,13 +12,32 @@ import ctypes import numpy from pyscf import lib -from icecream import ic from functools import partial import os.path as osp import jax import jax.numpy as jnp +from tessellate_ipu import create_ipu_tile_primitive, ipu_cycle_count, tile_map, tile_put_sharded, tile_put_replicated +vertex_filename = osp.join(osp.dirname(__file__), "libcint.cpp") +int2e = create_ipu_tile_primitive( + "Int2e" , + "Int2e" , + inputs=["mat", "shls_slice", "ao_loc", "atm", "bas", "env", "natm", "nbas", "which_integral", "comp"], + outputs={"out": 0}, + gp_filename=vertex_filename, + perf_estimate=100, +) +grad = create_ipu_tile_primitive( + "Grad" , + "Grad" , + inputs=["mat", "shls_slice", "ao_loc", "atm", "bas", "env", "natm", "nbas", "which_integral"], + outputs={"out": 0}, + gp_filename=vertex_filename, + perf_estimate=100, +) + libcgto = numpy.ctypeslib.load_library("libcint.so", "") +float32 = "#define dtype float" in open("libcint.c", "r").read() ANG_OF = 1 NPRIM_OF = 2 @@ -73,7 +94,6 @@ def getints2c(intor_name, N, atm, bas, env, shls_slice=None, comp=1, hermi=0, cintopt = None # type - float32 = "#define dtype float" in open("libcint.c", "r").read() if float32: mat = mat.astype(np.float32) env = env.astype(np.float32) @@ -97,17 +117,8 @@ def cpu_intor1e(self, intor, N, comp=None, hermi=0, aosym='s1', out=None, shls_s @partial(jax.jit, backend="ipu", static_argnums=(0,1,2,3,4,5,6,7,8)) def ipu_intor1e(self, intor, which_integral, N, comp=None, hermi=0, aosym='s1', out=None, shls_slice=None, grids=None): - from tessellate_ipu import create_ipu_tile_primitive, ipu_cycle_count, tile_map, tile_put_sharded, tile_put_replicated - vertex_filename = osp.join(osp.dirname(__file__), "libcint.cpp") #mat, shls_slice, ao_loc, atm, bas, env - grad = create_ipu_tile_primitive( - "Grad" , - "Grad" , - inputs=["mat", "shls_slice", "ao_loc", "atm", "bas", "env", "natm", "nbas", "which_integral"], - outputs={"out": 0}, - gp_filename=vertex_filename, - perf_estimate=100, - ) + intor_name, N, atm, bas, env, shls_slice, comp, hermi, ao_loc, cintopt, out=\ intor+"_sph", N, self._atm, self._bas, self._env, shls_slice, comp, hermi, None, None, out @@ -124,7 +135,6 @@ def ipu_intor1e(self, intor, which_integral, N, comp=None, hermi=0, aosym='s1', mat = numpy.ndarray(shape, dtype, out, order='F') # type - float32 = "#define dtype float" in open("libcint.c", "r").read() if float32: mat = mat.astype(np.float32) env = env.astype(np.float32) @@ -169,7 +179,6 @@ def cpu_getints4c(intor_name, atm, bas, env, N, shls_slice=None, comp=1, out = numpy.zeros(shape) # type - float32 = "#define dtype float" in open("libcint.c", "r").read() if float32: out = out.astype(np.float32) env = env.astype(np.float32) @@ -232,7 +241,6 @@ def cpu_tile_map_getints4c(intor_name, atm, bas, env, N, shls_slice=None, comp=1 eri = numpy.zeros(shape) # type - float32 = "#define dtype float" in open("libcint.c", "r").read() if float32: out = out.astype(np.float32) eri = eri.astype(np.float32) @@ -242,7 +250,6 @@ def cpu_tile_map_getints4c(intor_name, atm, bas, env, N, shls_slice=None, comp=1 cintopt = None prescreen = lib.c_null_ptr() - ic(intor_name, 'GTOnr2e_fill_'+aosym, "GTOnr2e_fill_drv") drv(getattr(libcgto, intor_name), fill, prescreen, out.ctypes.data_as(ctypes.c_void_p), ctypes.c_int(comp), (ctypes.c_int*8)(*shls_slice), @@ -279,29 +286,15 @@ def ipu_getints4c(intor_name, atm, bas, env, N, shls_slice=None, comp=1, out = numpy.ndarray(shape, buffer=out) # type - float32 = "#define dtype float" in open("libcint.c", "r").read() if float32: out = out.astype(np.float32) env = env.astype(np.float32) - - from tessellate_ipu import create_ipu_tile_primitive, ipu_cycle_count, tile_map, tile_put_sharded, tile_put_replicated - vertex_filename = osp.join(osp.dirname(__file__), "libcint.cpp") - grad = create_ipu_tile_primitive( - "Int2e" , - "Int2e" , - inputs=["mat", "shls_slice", "ao_loc", "atm", "bas", "env", "natm", "nbas", "which_integral", "comp"], - outputs={"out": 0}, - gp_filename=vertex_filename, - perf_estimate=100, - ) - natm = atm.shape[0] nbas = bas.shape[0] prefix = 'GTO' - float32 = "#define dtype float" in open("libcint.c", "r").read() if float32: out = out.astype(np.float32) env = env.astype(np.float32) @@ -319,7 +312,7 @@ def ipu_getints4c(intor_name, atm, bas, env, N, shls_slice=None, comp=1, which_integral = tile_put_replicated(np.array(which_integral, dtype=jnp.int32), (1,)) - value = tile_map(grad, out, shls_slice, ao_loc, atm, bas, env, natm, nbas, which_integral, comp) + value = tile_map(int2e, out, shls_slice, ao_loc, atm, bas, env, natm, nbas, which_integral, comp) out = value.array @@ -343,29 +336,16 @@ def ipu_tile_map_getints4c(intor_name, atm, bas, env, N, shls_slice=None, comp=1 out = numpy.ndarray(shape, buffer=out) # type - float32 = "#define dtype float" in open("libcint.c", "r").read() if float32: out = out.astype(np.float32) env = env.astype(np.float32) - from tessellate_ipu import create_ipu_tile_primitive, ipu_cycle_count, tile_map, tile_put_sharded, tile_put_replicated - vertex_filename = osp.join(osp.dirname(__file__), "libcint.cpp") - grad = create_ipu_tile_primitive( - "Int2e" , - "Int2e" , - inputs=["mat", "shls_slice", "ao_loc", "atm", "bas", "env", "natm", "nbas", "which_integral", "comp"], - outputs={"out": 0}, - gp_filename=vertex_filename, - perf_estimate=100, - ) - natm = atm.shape[0] nbas = bas.shape[0] prefix = 'GTO' - float32 = "#define dtype float" in open("libcint.c", "r").read() if float32: out = out.astype(np.float32) env = env.astype(np.float32) @@ -383,7 +363,7 @@ def ipu_tile_map_getints4c(intor_name, atm, bas, env, N, shls_slice=None, comp=1 which_integral = tile_put_replicated(np.array(which_integral, dtype=jnp.int32), (1,)) - value = tile_map(grad, out, shls_slice, ao_loc, atm, bas, env, natm, nbas, which_integral, comp) + value = tile_map(int2e, out, shls_slice, ao_loc, atm, bas, env, natm, nbas, which_integral, comp) out = value.array