diff --git a/src/executor/execution_kernel.cu b/src/executor/execution_kernel.cu index ed0fdc505..a60317c51 100644 --- a/src/executor/execution_kernel.cu +++ b/src/executor/execution_kernel.cu @@ -22,7 +22,7 @@ void ExecutionKernel::launchKernel(int rank, int nthreadblocks, int nthreads, vo #endif break; case DataType::UINT32: - executionKernel<<>>( + executionKernel<<>>( rank, (uint32_t*)src, (uint32_t*)dst, (uint32_t*)scratch, scratchSize, plan, flag #if defined(ENABLE_NPKIT) , @@ -32,7 +32,7 @@ void ExecutionKernel::launchKernel(int rank, int nthreadblocks, int nthreads, vo #endif break; case DataType::FLOAT16: - executionKernel<<>>( + executionKernel<<>>( rank, (half*)src, (half*)dst, (half*)scratch, scratchSize, plan, flag #if defined(ENABLE_NPKIT) , @@ -42,7 +42,7 @@ void ExecutionKernel::launchKernel(int rank, int nthreadblocks, int nthreads, vo #endif break; case DataType::FLOAT32: - executionKernel<<>>( + executionKernel<<>>( rank, (float*)src, (float*)dst, (float*)scratch, scratchSize, plan, flag #if defined(ENABLE_NPKIT) , @@ -52,7 +52,7 @@ void ExecutionKernel::launchKernel(int rank, int nthreadblocks, int nthreads, vo #endif break; case DataType::BFLOAT16: - executionKernel<__bfloat16><<>>( + executionKernel<__bfloat16, PacketType><<>>( rank, (__bfloat16*)src, (__bfloat16*)dst, (__bfloat16*)scratch, scratchSize, plan, flag #if defined(ENABLE_NPKIT) ,