Skip to content

Commit

Permalink
Enable generic functions for 64-bit shuffles (#171)
Browse files Browse the repository at this point in the history
* enable generic functions for 64-bit shuffles

* Update include/cutlass/gpu_generics.h

Co-authored-by: Joe Todd <[email protected]>

---------

Co-authored-by: Joe Todd <[email protected]>
  • Loading branch information
t4c1 and joeatodd authored Jan 9, 2025
1 parent 25b7875 commit 096165f
Showing 1 changed file with 12 additions and 8 deletions.
20 changes: 12 additions & 8 deletions include/cutlass/gpu_generics.h
Original file line number Diff line number Diff line change
Expand Up @@ -249,10 +249,11 @@ unsigned int byte_perm(unsigned int x, unsigned int y, unsigned int s) {

// shfl

template<typename T>
CUTLASS_DEVICE
unsigned int shfl_up_sync(
T shfl_up_sync(
unsigned int const mask,
unsigned int const var,
T const var,
int const delta,
int const width = NumThreadsPerWarp) {
#if defined(__CUDA_ARCH__)
Expand All @@ -264,10 +265,11 @@ unsigned int shfl_up_sync(
#endif
}

template<typename T>
CUTLASS_DEVICE
unsigned int shfl_down_sync(
T shfl_down_sync(
unsigned int const mask,
unsigned int const var,
T const var,
int const delta,
int const width = NumThreadsPerWarp) {
#if defined(__CUDA_ARCH__)
Expand All @@ -279,10 +281,11 @@ unsigned int shfl_down_sync(
#endif
}

template<typename T>
CUTLASS_DEVICE
unsigned int shfl_sync(
T shfl_sync(
unsigned int const mask,
unsigned int const var,
T const var,
int const delta,
int const width = NumThreadsPerWarp) {
#if defined(__CUDA_ARCH__)
Expand All @@ -296,10 +299,11 @@ unsigned int shfl_sync(
#endif
}

template<typename T>
CUTLASS_DEVICE
unsigned int shfl_xor_sync(
T shfl_xor_sync(
unsigned int const mask,
unsigned int const var,
T const var,
int const laneMask,
int const width = NumThreadsPerWarp) {
#if defined(__CUDA_ARCH__)
Expand Down

0 comments on commit 096165f

Please sign in to comment.