Skip to content

Commit

Permalink
add sycl kernel get_bufaa()
Browse files Browse the repository at this point in the history
cjknight committed Dec 20, 2024

Verified

This commit was created on GitHub.com and signed with GitHub’s verified signature.
1 parent aacc98e commit dcd2d78
Showing 2 changed files with 44 additions and 1 deletion.
1 change: 0 additions & 1 deletion gpu/src/device_cuda.cpp
Original file line number Diff line number Diff line change
@@ -275,7 +275,6 @@ __global__ void _get_bufaa (const double* bufpp, double* bufaa, int naux, int nm

/* ---------------------------------------------------------------------- */


__global__ void _transpose_120(double * in, double * out, int naux, int nao, int ncas) {
//Pum->muP
int i = blockIdx.x * blockDim.x + threadIdx.x;
44 changes: 44 additions & 0 deletions gpu/src/device_sycl.cpp
Original file line number Diff line number Diff line change
@@ -290,6 +290,25 @@ void _get_bufpa (const double* bufpp, double* bufpa, int naux, int nmo, int ncor
bufpa[outputIndex] = bufpp[inputIndex];
}

/* ---------------------------------------------------------------------- */
void _get_bufaa (const double* bufpp, double* bufaa, int naux, int nmo, int ncore, int ncas,
const sycl::nd_item<3> &item_ct1){
const int i = item_ct1.get_group(2) * item_ct1.get_local_range(2) +
item_ct1.get_local_id(2);
const int j = item_ct1.get_group(1) * item_ct1.get_local_range(1) +
item_ct1.get_local_id(1);
const int k = item_ct1.get_group(0) * item_ct1.get_local_range(0) +
item_ct1.get_local_id(0);

if(i >= naux) return;
if(j >= ncas) return;
if(k >= ncas) return;

int inputIndex = (i*nmo + (j+ncore))*nmo + k+ncore;
int outputIndex = (i*ncas + j)*ncas + k;
bufaa[outputIndex] = bufpp[inputIndex];
}

/* ---------------------------------------------------------------------- */

void _transpose_120(double * in, double * out, int naux, int nao, int ncas,
@@ -675,6 +694,31 @@ void Device::get_bufpa(const double* bufpp, double* bufpa, int naux, int nmo, in

/* ---------------------------------------------------------------------- */

void Device::get_bufaa(const double* bufpp, double* bufaa, int naux, int nmo, int ncore, int ncas)
{
sycl::range<3> block_size(1, 1, _UNPACK_BLOCK_SIZE);
sycl::range<3> grid_size(ncas, ncas, _TILE(naux, block_size[2]));

sycl::queue * s = pm->dev_get_queue();

/*
DPCT1049:2: The work-group size passed to the SYCL kernel may exceed the
limit. To get the device limit, query info::device::max_work_group_size.
Adjust the work-group size if needed.
*/
{
// dpct::has_capability_or_fail(s->get_device(), {sycl::aspect::fp64});

s->parallel_for(sycl::nd_range<3>(grid_size * block_size, block_size),
[=](sycl::nd_item<3> item_ct1) {
_get_bufaa(bufpp, bufaa, naux, nmo, ncore, ncas,
item_ct1);
});
}
}

/* ---------------------------------------------------------------------- */

void Device::transpose_120(double * in, double * out, int naux, int nao, int ncas, int order)
{
sycl::queue * s = pm->dev_get_queue();

0 comments on commit dcd2d78

Please sign in to comment.