diff --git a/cute_kernels/kernels/swiglu/cuda_implementation/kernels_backward.cu b/cute_kernels/kernels/swiglu/cuda_implementation/kernels_backward.cu index ec6053bc..c5477383 100644 --- a/cute_kernels/kernels/swiglu/cuda_implementation/kernels_backward.cu +++ b/cute_kernels/kernels/swiglu/cuda_implementation/kernels_backward.cu @@ -36,7 +36,7 @@ __global__ void _swiglu_backward_cuda_kernel(const scalar_t *gate, static_assert(vector_instruction_width == 1 || vector_instruction_width == 2 || vector_instruction_width == 4 || vector_instruction_width == 8); - const int64_t thread_id = get_global_thread_id(); + const uint64 thread_id = get_global_thread_id(); using dtype = DType; if constexpr (vector_instruction_width == 1) { diff --git a/cute_kernels/kernels/swiglu/cuda_implementation/kernels_forward.cu b/cute_kernels/kernels/swiglu/cuda_implementation/kernels_forward.cu index 5e9f0598..16b87dda 100644 --- a/cute_kernels/kernels/swiglu/cuda_implementation/kernels_forward.cu +++ b/cute_kernels/kernels/swiglu/cuda_implementation/kernels_forward.cu @@ -16,7 +16,7 @@ __global__ void _swiglu_forward_cuda_kernel(const scalar_t *gate, static_assert(vector_instruction_width == 1 || vector_instruction_width == 2 || vector_instruction_width == 4 || vector_instruction_width == 8); - const int64_t thread_id = get_global_thread_id(); + const uint64 thread_id = get_global_thread_id(); using dtype = DType; if constexpr (vector_instruction_width == 1) {