diff --git a/khd/dtypes.h b/khd/utils/dtypes.h
similarity index 100%
rename from khd/dtypes.h
rename to khd/utils/dtypes.h
diff --git a/khd/utils/threads.h b/khd/utils/threads.h
new file mode 100644
index 00000000..1c97524a
--- /dev/null
+++ b/khd/utils/threads.h
@@ -0,0 +1,26 @@
+#pragma once
+
+#include
+#include
+
+// Helpers that flatten CUDA's three-component grid/block coordinates into single integers.
+// They are defined in this header, so they are marked inline: otherwise including the header
+// from more than one translation unit would yield multiple definitions of the same function.
+
+// Returns the number of threads in one block: the product of the x, y and z extents of blockDim.
+__device__ inline int get_threads_per_block() { return blockDim.x * blockDim.y * blockDim.z; }
+
+// Returns the number of blocks in the grid: the product of the x, y and z extents of gridDim.
+__device__ inline int get_num_blocks() { return gridDim.x * gridDim.y * gridDim.z; }
+
+// Returns the linear index of the calling thread's block within the grid (x varies fastest, then y, then z).
+__device__ inline int get_block_id() { return gridDim.x * gridDim.y * blockIdx.z + gridDim.x * blockIdx.y + blockIdx.x; }
+
+// Returns the linear index of the calling thread within its block (x varies fastest, then y, then z).
+__device__ inline int get_local_thread_id() {
+    return blockDim.x * blockDim.y * threadIdx.z + blockDim.x * threadIdx.y + threadIdx.x;
+}
+
+// Returns the linear index of the calling thread across the whole grid: block index times block size, plus
+// the index inside the block. The value is returned as int, so the total thread count must fit in an int.
+__device__ inline int get_global_thread_id() { return get_threads_per_block() * get_block_id() + get_local_thread_id(); }
diff --git a/khd/vector_addition/cuda_implementation/kernels.cu b/khd/vector_addition/cuda_implementation/kernels.cu
index c7523805..db8db065 100644
--- a/khd/vector_addition/cuda_implementation/kernels.cu
+++ b/khd/vector_addition/cuda_implementation/kernels.cu
@@ -1,4 +1,5 @@
-#include "../../dtypes.h"
+#include "../../utils/dtypes.h"
+#include "../../utils/threads.h"
 #include
 #include
 #include
@@ -18,7 +19,7 @@ __global__ void vector_addition_forward_kernel(const scalar_t *x,
                                                scalar_t *output,
                                                const int num_elements,
                                                const int num_elements_per_thread) {
-    const int thread_id = blockIdx.x * blockDim.x + threadIdx.x;
+    const int thread_id = get_global_thread_id();
     const int start = thread_id * num_elements_per_thread;
     const int end = (thread_id + 1) * num_elements_per_thread - 1; // inclusive of last element
 
@@ -33,24 +34,24 @@ __global__ void vector_addition_forward_kernel(const scalar_t *x,
         const fp32 *_y = (fp32 *)(&y4[thread_id]);
 
         // tmp is initialized here to avoid doing multiple writes
-        fp32_4 tmp;
-        fp32 *_tmp = (fp32 *)(&tmp);
+        fp32_4 tmp4;
+        fp32 *tmp = (fp32 *)(&tmp4);
 
         // clang-format off
         #pragma unroll
         // clang-format on
         for (int i = 0; i < NUM_FP32_ELEMENTS_PER_THREAD; i++) {
             if (std::is_same_v) {
-                _tmp[i] = _x[i] + _y[i];
+                tmp[i] = _x[i] + _y[i];
             } else if constexpr (std::is_same_v || std::is_same_v) {
                 DType q;
-                _tmp[i] = q.pack_to_fp32(__hadd2(q.unpack_from_fp32(_x[i]), q.unpack_from_fp32(_y[i])));
+                tmp[i] = q.pack_to_fp32(__hadd2(q.unpack_from_fp32(_x[i]), q.unpack_from_fp32(_y[i])));
             } else {
                 assert(false && "Function not implemented");
             }
         }
 
-        output4[thread_id] = tmp;
+        output4[thread_id] = tmp4;
     } else if (start < num_elements) {
         // clang-format off
         #pragma unroll