Skip to content

Commit

Permalink
cleanup reduction of vj and vk from multiple gpus
Browse files Browse the repository at this point in the history
  • Loading branch information
cjknight committed May 30, 2024
1 parent 9847f20 commit c199000
Show file tree
Hide file tree
Showing 3 changed files with 78 additions and 33 deletions.
11 changes: 9 additions & 2 deletions gpu/src/device.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,6 @@ Device::Device()
printf("LIBGPU: created device\n");

pm = new PM();

n = 0;

update_dfobj = 0;

Expand All @@ -39,6 +37,12 @@ Device::Device()
buf4 = nullptr;

buf_fdrv = nullptr;

size_buf_vj = 0;
size_buf_vk = 0;

buf_vj = nullptr;
buf_vk = nullptr;

#if defined(_USE_GPU)
d_bPpj = nullptr;
Expand Down Expand Up @@ -126,6 +130,9 @@ Device::~Device()
pm->dev_free_host(buf3);
pm->dev_free_host(buf4);

pm->dev_free_host(buf_vj);
pm->dev_free_host(buf_vk);

pm->dev_free_host(buf_fdrv);

// for(int i=0; i<size_tril_map.size(); ++i) {
Expand Down
9 changes: 6 additions & 3 deletions gpu/src/device.h
Original file line number Diff line number Diff line change
Expand Up @@ -92,9 +92,6 @@ public :
void profile_start(const char *);
void profile_stop();
void profile_next(const char *);

int n;
int size_data;

size_t grid_size, block_size;

Expand All @@ -108,6 +105,9 @@ public :
int nset;
int nao_pair;

int size_buf_vj;
int size_buf_vk;

// get_jk

double * rho;
Expand All @@ -119,6 +119,9 @@ public :
double * buf4;
double * buf_fdrv;

double * buf_vj;
double * buf_vk;

// hessop_get_veff

int size_bPpj;
Expand Down
91 changes: 63 additions & 28 deletions gpu/src/device_cuda.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,20 @@ void Device::init_get_jk(py::array_t<double> _eri1, py::array_t<double> _dmtril,

dd->d_tril_map_ptr = dd->d_tril_map[indx];

int _size_buf_vj = num_devices * nset * nao_pair;
if(_size_buf_vj > size_buf_vj) {
size_buf_vj = _size_buf_vj;
if(buf_vj) pm->dev_free_host(buf_vj);
buf_vj = (double *) pm->dev_malloc_host(_size_buf_vj*sizeof(double));
}

int _size_buf_vk = num_devices * nset * nao * nao;
if(_size_buf_vk > size_buf_vk) {
size_buf_vk = _size_buf_vk;
if(buf_vk) pm->dev_free_host(buf_vk);
buf_vk = (double *) pm->dev_malloc_host(_size_buf_vk*sizeof(double));
}

#ifdef _SIMPLE_TIMER
double t2 = omp_get_wtime();
t_array_jk[0] += t2 - t0;
Expand Down Expand Up @@ -191,52 +205,73 @@ void Device::pull_get_jk(py::array_t<double> _vj, py::array_t<double> _vk, int w
py::buffer_info info_vj = _vj.request(); // 2D array (nset, nao_pair)

double * vj = static_cast<double*>(info_vj.ptr);

double * tmp = (double *) pm->dev_malloc_host(nset * nao_pair * sizeof(double));

int size = nset * nao_pair * sizeof(double);

double * tmp;

for(int i=0; i<num_devices; ++i) {
pm->dev_set_device(i);

my_device_data * dd = &(device_data[i]);

if(i == 0) tmp = vj;
else tmp = &(buf_vj[i * nset * nao_pair]);

pm->dev_pull_async(dd->d_vj, tmp, size, dd->stream);
}

for(int i=0; i<num_devices; ++i) {
my_device_data * dd = &(device_data[i]);
pm->dev_stream_wait(dd->stream);

if(i == 0) pm->dev_pull(dd->d_vj, vj, nset * nao_pair * sizeof(double));
else if(dd->d_vj) {
pm->dev_pull(dd->d_vj, tmp, nset*nao_pair*sizeof(double));

if(i > 0) {

tmp = &(buf_vj[i * nset * nao_pair]);
#pragma omp parallel for
for(int j=0; j<nset*nao_pair; ++j) vj[j] += tmp[j];

}
}

pm->dev_free_host(tmp);

if(with_k) {
pm->dev_set_device(0);

py::buffer_info info_vk = _vk.request(); // 3D array (nset, nao, nao)

update_dfobj = 0;
if(!with_k) {
#ifdef _DEBUG_DEVICE
printf("LIBGPU :: -- Leaving Device::pull_get_jk()\n");
#endif
return;
}

double * vk = static_cast<double*>(info_vk.ptr);
py::buffer_info info_vk = _vk.request(); // 3D array (nset, nao, nao)

tmp = (double *) pm->dev_malloc_host(nset * nao * nao * sizeof(double));
double * vk = static_cast<double*>(info_vk.ptr);

for(int i=0; i<num_devices; ++i) {
pm->dev_set_device(i);
my_device_data * dd = &(device_data[i]);
size = nset * nao * nao * sizeof(double);

for(int i=0; i<num_devices; ++i) {
pm->dev_set_device(i);

pm->dev_stream_wait(dd->stream);
my_device_data * dd = &(device_data[i]);

if(i == 0) pm->dev_pull(dd->d_vkk, vk, nset * nao * nao * sizeof(double));
else if(dd->d_vkk) {
pm->dev_pull(dd->d_vkk, tmp, nset*nao*nao*sizeof(double));
for(int j=0; j<nset*nao*nao; ++j) vk[j] += tmp[j];
}
}
if(i == 0) tmp = vk;
else tmp = &(buf_vk[i * nset * nao * nao]);

pm->dev_free_host(tmp);
pm->dev_pull_async(dd->d_vkk, tmp, size, dd->stream);
}

for(int i=0; i<num_devices; ++i) {
my_device_data * dd = &(device_data[i]);
pm->dev_stream_wait(dd->stream);

if(i > 0) {

tmp = &(buf_vk[i * nset * nao * nao]);
#pragma omp parallel for
for(int j=0; j<nset*nao*nao; ++j) vk[j] += tmp[j];

update_dfobj = 0;
}

}

#ifdef _DEBUG_DEVICE
printf("LIBGPU :: -- Leaving Device::pull_get_jk()\n");
Expand Down

0 comments on commit c199000

Please sign in to comment.