diff --git a/vkFFT/vkFFT.h b/vkFFT/vkFFT.h index 37518132..e86b7a4a 100644 --- a/vkFFT/vkFFT.h +++ b/vkFFT/vkFFT.h @@ -283,6 +283,7 @@ typedef struct { uint64_t streamCounter;//Filled at app creation uint64_t streamID;//Filled at app creation int64_t useStrict32BitAddress; // guarantee 32 bit addresses in bytes instead of number of elements. This results in fewer instructions generated. -1: Disable, 0: Infer based on size, 1: enable. Has no effect with useUint64. + int64_t useStaticWorkGroupSize; // Embed the compile time known block dimensions into kernels instead of using blockDim, for potentially better performance. -1: Disable, 0: Automatically enable where beneficial, 1: Always enable. #elif(VKFFT_BACKEND==3) cl_command_queue* commandQueue; #elif(VKFFT_BACKEND==4) @@ -813,7 +814,8 @@ typedef struct { uint64_t performBufferSetUpdate; uint64_t useUint64; #if(VKFFT_BACKEND==2) - int64_t useStrict32BitAddress; + int64_t useStrict32BitAddress; + int64_t useStaticWorkGroupSize; #endif uint64_t disableSetLocale; @@ -25942,38 +25944,39 @@ static inline VkFFTResult shaderGenVkFFT_R2C_decomposition(char* output, VkFFTSp if (!strcmp(floatTypeOutputMemory, "half")) sprintf(vecTypeOutput, "f16vec2"); if (!strcmp(floatTypeOutputMemory, "float")) sprintf(vecTypeOutput, "float2"); if (!strcmp(floatTypeOutputMemory, "double")) sprintf(vecTypeOutput, "double2"); - sprintf(sc->gl_LocalInvocationID_x, "threadIdx.x"); - sprintf(sc->gl_LocalInvocationID_y, "threadIdx.y"); - sprintf(sc->gl_LocalInvocationID_z, "threadIdx.z"); + sprintf(sc->gl_LocalInvocationID_x, sc->localSize[0] > 1 ? "threadIdx.x" : "0u"); + sprintf(sc->gl_LocalInvocationID_y, sc->localSize[1] > 1 ? "threadIdx.y" : "0u"); + sprintf(sc->gl_LocalInvocationID_z, sc->localSize[2] > 1 ? "threadIdx.z" : "0u"); switch (sc->swapComputeWorkGroupID) { case 0: - sprintf(sc->gl_GlobalInvocationID_x, "(threadIdx.x + blockIdx.x * blockDim.x)"); - sprintf(sc->gl_GlobalInvocationID_y, "(threadIdx.y + blockIdx.y * blockDim.y)"); - sprintf(sc->gl_GlobalInvocationID_z, "(threadIdx.z + blockIdx.z * blockDim.z)"); sprintf(sc->gl_WorkGroupID_x, "blockIdx.x"); sprintf(sc->gl_WorkGroupID_y, "blockIdx.y"); sprintf(sc->gl_WorkGroupID_z, "blockIdx.z"); break; case 1: - sprintf(sc->gl_GlobalInvocationID_x, "(threadIdx.x + blockIdx.y * blockDim.x)"); - sprintf(sc->gl_GlobalInvocationID_y, "(threadIdx.y + blockIdx.x * blockDim.y)"); - sprintf(sc->gl_GlobalInvocationID_z, "(threadIdx.z + blockIdx.z * blockDim.z)"); sprintf(sc->gl_WorkGroupID_x, "blockIdx.y"); sprintf(sc->gl_WorkGroupID_y, "blockIdx.x"); sprintf(sc->gl_WorkGroupID_z, "blockIdx.z"); break; case 2: - sprintf(sc->gl_GlobalInvocationID_x, "(threadIdx.x + blockIdx.z * blockDim.x)"); - sprintf(sc->gl_GlobalInvocationID_y, "(threadIdx.y + blockIdx.y * blockDim.y)"); - sprintf(sc->gl_GlobalInvocationID_z, "(threadIdx.z + blockIdx.x * blockDim.z)"); sprintf(sc->gl_WorkGroupID_x, "blockIdx.z"); sprintf(sc->gl_WorkGroupID_y, "blockIdx.y"); sprintf(sc->gl_WorkGroupID_z, "blockIdx.x"); break; } - sprintf(sc->gl_WorkGroupSize_x, "blockDim.x"); - sprintf(sc->gl_WorkGroupSize_y, "blockDim.y"); - sprintf(sc->gl_WorkGroupSize_z, "blockDim.z"); + if(sc->useStaticWorkGroupSize > 0) { + sprintf(sc->gl_WorkGroupSize_x, "%" PRIu64 "u", sc->localSize[0]); + sprintf(sc->gl_WorkGroupSize_y, "%" PRIu64 "u", sc->localSize[1]); + sprintf(sc->gl_WorkGroupSize_z, "%" PRIu64 "u", sc->localSize[2]); + } + else { + sprintf(sc->gl_WorkGroupSize_x, "blockDim.x"); + sprintf(sc->gl_WorkGroupSize_y, "blockDim.y"); + sprintf(sc->gl_WorkGroupSize_z, "blockDim.z"); + } + sprintf(sc->gl_GlobalInvocationID_x, "(%s + %s * %s)", sc->gl_LocalInvocationID_x, sc->gl_WorkGroupID_x, sc->gl_WorkGroupSize_x); + sprintf(sc->gl_GlobalInvocationID_y, "(%s + %s * %s)", sc->gl_LocalInvocationID_y, sc->gl_WorkGroupID_y, sc->gl_WorkGroupSize_y); + sprintf(sc->gl_GlobalInvocationID_z, "(%s + %s * %s)", sc->gl_LocalInvocationID_z, sc->gl_WorkGroupID_z, sc->gl_WorkGroupSize_z); sprintf(sc->gl_SubgroupInvocationID, "(threadIdx.x %% %" PRIu64 ")", sc->warpSize); sprintf(sc->gl_SubgroupID, "(threadIdx.x / %" PRIu64 ")", sc->warpSize); if (!strcmp(floatType, "double")) sprintf(LFending, "l"); @@ -26833,38 +26836,39 @@ static inline VkFFTResult shaderGenVkFFT(char* output, VkFFTSpecializationConsta if (!strcmp(floatTypeOutputMemory, "half")) sprintf(vecTypeOutput, "f16vec2"); if (!strcmp(floatTypeOutputMemory, "float")) sprintf(vecTypeOutput, "float2"); if (!strcmp(floatTypeOutputMemory, "double")) sprintf(vecTypeOutput, "double2"); - sprintf(sc->gl_LocalInvocationID_x, "threadIdx.x"); - sprintf(sc->gl_LocalInvocationID_y, "threadIdx.y"); - sprintf(sc->gl_LocalInvocationID_z, "threadIdx.z"); + sprintf(sc->gl_LocalInvocationID_x, sc->localSize[0] > 1 ? "threadIdx.x" : "0u"); + sprintf(sc->gl_LocalInvocationID_y, sc->localSize[1] > 1 ? "threadIdx.y" : "0u"); + sprintf(sc->gl_LocalInvocationID_z, sc->localSize[2] > 1 ? "threadIdx.z" : "0u"); switch (sc->swapComputeWorkGroupID) { case 0: - sprintf(sc->gl_GlobalInvocationID_x, "(threadIdx.x + blockIdx.x * blockDim.x)"); - sprintf(sc->gl_GlobalInvocationID_y, "(threadIdx.y + blockIdx.y * blockDim.y)"); - sprintf(sc->gl_GlobalInvocationID_z, "(threadIdx.z + blockIdx.z * blockDim.z)"); sprintf(sc->gl_WorkGroupID_x, "blockIdx.x"); sprintf(sc->gl_WorkGroupID_y, "blockIdx.y"); sprintf(sc->gl_WorkGroupID_z, "blockIdx.z"); break; case 1: - sprintf(sc->gl_GlobalInvocationID_x, "(threadIdx.x + blockIdx.y * blockDim.x)"); - sprintf(sc->gl_GlobalInvocationID_y, "(threadIdx.y + blockIdx.x * blockDim.y)"); - sprintf(sc->gl_GlobalInvocationID_z, "(threadIdx.z + blockIdx.z * blockDim.z)"); sprintf(sc->gl_WorkGroupID_x, "blockIdx.y"); sprintf(sc->gl_WorkGroupID_y, "blockIdx.x"); sprintf(sc->gl_WorkGroupID_z, "blockIdx.z"); break; case 2: - sprintf(sc->gl_GlobalInvocationID_x, "(threadIdx.x + blockIdx.z * blockDim.x)"); - sprintf(sc->gl_GlobalInvocationID_y, "(threadIdx.y + blockIdx.y * blockDim.y)"); - sprintf(sc->gl_GlobalInvocationID_z, "(threadIdx.z + blockIdx.x * blockDim.z)"); sprintf(sc->gl_WorkGroupID_x, "blockIdx.z"); sprintf(sc->gl_WorkGroupID_y, "blockIdx.y"); sprintf(sc->gl_WorkGroupID_z, "blockIdx.x"); break; } - sprintf(sc->gl_WorkGroupSize_x, "blockDim.x"); - sprintf(sc->gl_WorkGroupSize_y, "blockDim.y"); - sprintf(sc->gl_WorkGroupSize_z, "blockDim.z"); + if(sc->useStaticWorkGroupSize > 0) { + sprintf(sc->gl_WorkGroupSize_x, "%" PRIu64 "u", sc->localSize[0]); + sprintf(sc->gl_WorkGroupSize_y, "%" PRIu64 "u", sc->localSize[1]); + sprintf(sc->gl_WorkGroupSize_z, "%" PRIu64 "u", sc->localSize[2]); + } + else { + sprintf(sc->gl_WorkGroupSize_x, "blockDim.x"); + sprintf(sc->gl_WorkGroupSize_y, "blockDim.y"); + sprintf(sc->gl_WorkGroupSize_z, "blockDim.z"); + } + sprintf(sc->gl_GlobalInvocationID_x, "(%s + %s * %s)", sc->gl_LocalInvocationID_x, sc->gl_WorkGroupID_x, sc->gl_WorkGroupSize_x); + sprintf(sc->gl_GlobalInvocationID_y, "(%s + %s * %s)", sc->gl_LocalInvocationID_y, sc->gl_WorkGroupID_y, sc->gl_WorkGroupSize_y); + sprintf(sc->gl_GlobalInvocationID_z, "(%s + %s * %s)", sc->gl_LocalInvocationID_z, sc->gl_WorkGroupID_z, sc->gl_WorkGroupSize_z); sprintf(sc->gl_SubgroupInvocationID, "(threadIdx.x %% %" PRIu64 ")", sc->warpSize); sprintf(sc->gl_SubgroupID, "(threadIdx.x / %" PRIu64 ")", sc->warpSize); #elif((VKFFT_BACKEND==3)||(VKFFT_BACKEND==4)) @@ -33973,6 +33977,7 @@ static inline VkFFTResult VkFFTPlanR2CMultiUploadDecomposition(VkFFTApplication* axis->specializationConstants.useUint64 = app->configuration.useUint64; #if(VKFFT_BACKEND==2) axis->specializationConstants.useStrict32BitAddress = app->configuration.useStrict32BitAddress; + axis->specializationConstants.useStaticWorkGroupSize = app->configuration.useStaticWorkGroupSize; #endif axis->specializationConstants.disableSetLocale = app->configuration.disableSetLocale; @@ -35735,6 +35740,7 @@ static inline VkFFTResult VkFFTPlanAxis(VkFFTApplication* app, VkFFTPlan* FFTPla axis->specializationConstants.useUint64 = app->configuration.useUint64; #if(VKFFT_BACKEND==2) axis->specializationConstants.useStrict32BitAddress = app->configuration.useStrict32BitAddress; + axis->specializationConstants.useStaticWorkGroupSize = app->configuration.useStaticWorkGroupSize; #endif axis->specializationConstants.disableSetLocale = app->configuration.disableSetLocale; @@ -39681,6 +39687,9 @@ static inline VkFFTResult initializeVkFFT(VkFFTApplication* app, VkFFTConfigurat return VKFFT_ERROR_FAILED_TO_GET_ATTRIBUTE; } app->configuration.warpSize = value; + if(inputLaunchConfiguration.useStaticWorkGroupSize != 0) app->configuration.useStaticWorkGroupSize = inputLaunchConfiguration.useStaticWorkGroupSize; + else if (app->configuration.warpSize == 32) app->configuration.useStaticWorkGroupSize = -1; // Embedding the work group size slows down kernels on RDNA + else app->configuration.useStaticWorkGroupSize = 1; app->configuration.sharedMemorySizePow2 = (uint64_t)pow(2, (uint64_t)log2(app->configuration.sharedMemorySize)); app->configuration.useRaderUintLUT = 0; if (app->configuration.num_streams > 1) {