Skip to content

Commit

Permalink
Fixed block dims and grid dims through syclcompat (#50)
Browse files Browse the repository at this point in the history
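Fix the BlockDim* and GridDim* helpers on the SYCL path. The existing mappings give incorrect output when a CUDA kernel is run through SYCL on an NVIDIA A100: BlockDim* now returns syclcompat::local_range (previously work_group_range), and GridDim* now returns syclcompat::work_group_range (previously global_range).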

Co-authored-by: Mehdi Goli <[email protected]>

---------

Co-authored-by: Mehdi Goli <[email protected]>
  • Loading branch information
muhammad-tanvir-1211 and mehdi-goli authored Apr 29, 2024
1 parent b2746a2 commit 6a8aee6
Showing 1 changed file with 8 additions and 6 deletions.
14 changes: 8 additions & 6 deletions include/cutlass/cutlass.h
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ CUTLASS_HOST_DEVICE uint BlockDimX() {
#if defined(__CUDA_ARCH__)
return blockDim.x;
#elif defined(__SYCL_DEVICE_ONLY__)
return syclcompat::work_group_range::x();
return syclcompat::local_range::x();
#else
return 0;
#endif
Expand All @@ -188,7 +188,7 @@ CUTLASS_HOST_DEVICE uint BlockDimY() {
#if defined(__CUDA_ARCH__)
return blockDim.y;
#elif defined(__SYCL_DEVICE_ONLY__)
return syclcompat::work_group_range::y();
return syclcompat::local_range::y();
#else
return 0;
#endif
Expand All @@ -198,7 +198,7 @@ CUTLASS_HOST_DEVICE uint BlockDimZ() {
#if defined(__CUDA_ARCH__)
return blockDim.z;
#elif defined(__SYCL_DEVICE_ONLY__)
return syclcompat::work_group_range::z();
return syclcompat::local_range::z();
#else
return 0;
#endif
Expand All @@ -208,7 +208,7 @@ CUTLASS_HOST_DEVICE uint GridDimX() {
#if defined(__CUDA_ARCH__)
return gridDim.x;
#elif defined(__SYCL_DEVICE_ONLY__)
return syclcompat::global_range::x();
return syclcompat::work_group_range::x();
#else
return 0;
#endif
Expand All @@ -218,7 +218,7 @@ CUTLASS_HOST_DEVICE uint GridDimY() {
#if defined(__CUDA_ARCH__)
return gridDim.y;
#elif defined(__SYCL_DEVICE_ONLY__)
return syclcompat::global_range::y();
return syclcompat::work_group_range::y();
#else
return 0;
#endif
Expand All @@ -228,7 +228,7 @@ CUTLASS_HOST_DEVICE uint GridDimZ() {
#if defined(__CUDA_ARCH__)
return gridDim.z;
#elif defined(__SYCL_DEVICE_ONLY__)
return syclcompat::global_range::z();
return syclcompat::work_group_range::z();
#else
return 0;
#endif
Expand Down Expand Up @@ -372,6 +372,8 @@ CUTLASS_DEVICE int atomicCAS(int *address, int compare, int val) {
CUTLASS_HOST_DEVICE bool thread0() {
#if defined(__CUDA_ARCH__)
return (!threadIdx.x && !threadIdx.y && !threadIdx.z) && (!blockIdx.x && !blockIdx.y && !blockIdx.z);
#elif defined(CUTLASS_ENABLE_SYCL)
return (!syclcompat::global_id::x() && !syclcompat::global_id::y() && !syclcompat::global_id::z());
#else
return false;
#endif
Expand Down

0 comments on commit 6a8aee6

Please sign in to comment.