diff --git a/src/igl/tests/vulkan/SpvConstantSpecializationTest.cpp b/src/igl/tests/vulkan/SpvConstantSpecializationTest.cpp new file mode 100644 index 0000000000..784be77d52 --- /dev/null +++ b/src/igl/tests/vulkan/SpvConstantSpecializationTest.cpp @@ -0,0 +1,95 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include + +namespace igl::tests { +// GLSL: +// +// layout(constant_id = 0) const int kConstant0 = 10; +// layout(constant_id = 1) const int kConstant1 = 11; +// +// out vec4 fragColor; +// +// void main() { +// fragColor = vec4(float(kConstant0), float(kConstant1), 0.0, 1.0); +// } + +// SPIR-V: +// +// OpCapability Shader +// OpMemoryModel Logical Simple +// OpEntryPoint Fragment %main "main" %fragColor +// OpName %kConstant0 "kConstant0" +// OpName %kConstant1 "kConstant1" +// OpName %fragColor "fragColor" +// OpName %main "main" +// OpDecorate %kConstant0 SpecId 0 +// OpDecorate %kConstant1 SpecId 1 +// %int = OpTypeInt 32 1 +// %kConstant0 = OpSpecConstant %int 10 +// %kConstant1 = OpSpecConstant %int 11 +// %float = OpTypeFloat 32 +// %v4float = OpTypeVector %float 4 +// %ptr_Output_v4float = OpTypePointer Output %v4float +// %fragColor = OpVariable %ptr_Output_v4float Output +// %void = OpTypeVoid +// %func = OpTypeFunction %void +// %_0_0f = OpConstant %float 0.0 +// %_1_0f = OpConstant %float 1.0 +// %main = OpFunction %void None %func +// %_1 = OpLabel +// %_2 = OpConvertSToF %float %kConstant0 +// %_3 = OpConvertSToF %float %kConstant1 +// %_4 = OpCompositeConstruct %v4float %_2 %_3 %_0_0f %_1_0f +// OpStore %fragColor %_4 +// OpReturn +// OpFunctionEnd + +namespace { +uint32_t getWord(int32_t val) { + return *reinterpret_cast(&val); +} +} // namespace + +TEST(SpvConstantSpecializationTest, intSpecialization) { + using namespace vulkan::util; + std::vector spv = { + 0x07230203, 0x00010300, 0xdeadbeef, 0x00000011, 0x00000000, 0x00020011, 0x00000001, + 0x0003000e, 0x00000000, 0x00000000, 0x0006000f, 0x00000004, 0x00000001, 0x6e69616d, + 0x00000000, 0x00000002, 0x00050005, 0x00000003, 0x6e6f436b, 0x6e617473, 0x00003074, + 0x00050005, 0x00000004, 0x6e6f436b, 0x6e617473, 0x00003174, 0x00050005, 0x00000002, + 0x67617266, 0x6f6c6f43, 0x00000072, 0x00040005, 0x00000001, 0x6e69616d, 0x00000000, + 0x00040047, 0x00000003, 0x00000001, 0x00000000, 0x00040047, 0x00000004, 0x00000001, + 0x00000001, 0x00040015, 0x00000005, 0x00000020, 0x00000001, 0x00040032, 0x00000005, + 0x00000003, 0x0000000a, 0x00040032, 0x00000005, 0x00000004, 0x0000000b, 0x00030016, + 0x00000006, 0x00000020, 0x00040017, 0x00000007, 0x00000006, 0x00000004, 0x00040020, + 0x00000008, 0x00000003, 0x00000007, 0x0004003b, 0x00000008, 0x00000002, 0x00000003, + 0x00020013, 0x00000009, 0x00030021, 0x0000000a, 0x00000009, 0x0004002b, 0x00000006, + 0x0000000b, 0x00000000, 0x0004002b, 0x00000006, 0x0000000c, 0x3f800000, 0x00050036, + 0x00000009, 0x00000001, 0x00000000, 0x0000000a, 0x000200f8, 0x0000000d, 0x0004006f, + 0x00000006, 0x0000000e, 0x00000003, 0x0004006f, 0x00000006, 0x0000000f, 0x00000004, + 0x00070050, 0x00000007, 0x00000010, 0x0000000e, 0x0000000f, 0x0000000b, 0x0000000c, + 0x0003003e, 0x00000002, 0x00000010, 0x000100fd, 0x00010038, + + }; + + // Specialize kConstant0 to 0 and kConstant1 to 1 + const std::vector values = {getWord(0), getWord(1)}; + + EXPECT_EQ(spv[50], getWord(10)); // 0x0000000a above + EXPECT_EQ(spv[54], getWord(11)); // 0x0000000b + + specializeConstants(spv.data(), spv.size() * sizeof(uint32_t), values); + + EXPECT_EQ(spv[50], getWord(0)); + EXPECT_EQ(spv[54], getWord(1)); +} + +} // namespace igl::tests diff --git a/src/igl/vulkan/util/SpvConstantSpecialization.cpp b/src/igl/vulkan/util/SpvConstantSpecialization.cpp new file mode 100644 index 0000000000..88beba53c2 --- /dev/null +++ b/src/igl/vulkan/util/SpvConstantSpecialization.cpp @@ -0,0 +1,105 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include + +#define IGL_COMMON_SKIP_CHECK +#include + +namespace igl::vulkan::util { +namespace { +uint32_t makeOpCode(uint32_t opCode, uint32_t wordCount) { + return opCode | (wordCount << SpvWordCountShift); +} +} // namespace + +void specializeConstants(uint32_t* spirv, size_t numBytes, const std::vector& values) { + const uint32_t bound = spirv[3]; + const size_t size = numBytes / sizeof(uint32_t); + + if (!IGL_DEBUG_VERIFY(bound < 1024 * 1024)) { + return; + } + + if (!IGL_DEBUG_VERIFY(spirv[0] == SpvMagicNumber)) { + return; + } + + std::vector idToValue(bound, kNoValue); + + uint32_t* instruction = spirv + 5; + while (instruction < spirv + size) { + const uint16_t instructionSize = static_cast(instruction[0] >> SpvWordCountShift); + const uint16_t opCode = static_cast(instruction[0] & SpvOpCodeMask); + + switch (opCode) { + case SpvOpDecorate: { + constexpr uint32_t kOpDecorateTargetId = 1; + constexpr uint32_t kOpDecorateDecoration = 2; + constexpr uint32_t kOpDecorateOperandIds = 3; + + IGL_DEBUG_ASSERT(instruction + kOpDecorateDecoration <= spirv + size, + "OpDecorate out of bounds"); + + const uint32_t decoration = instruction[kOpDecorateDecoration]; + const uint32_t targetId = instruction[kOpDecorateTargetId]; + IGL_DEBUG_ASSERT(targetId < bound); + + switch (decoration) { + case SpvDecorationSpecId: { + IGL_DEBUG_ASSERT(instruction + kOpDecorateOperandIds <= spirv + size, + "OpDecorate out of bounds"); + const uint32_t specId = instruction[kOpDecorateOperandIds]; + idToValue[targetId] = values.size() > specId ? values[specId] : kNoValue; + break; + } + default: + break; + } + + break; + } + case SpvOpSpecConstantFalse: + case SpvOpSpecConstantTrue: { + constexpr uint32_t kOpSpecConstantTrueResultId = 2; + + const uint32_t resultId = instruction[kOpSpecConstantTrueResultId]; + const uint32_t specializedValue = idToValue[resultId]; + if (specializedValue == kNoValue) { + break; + } + instruction[0] = + makeOpCode(specializedValue ? SpvOpConstantTrue : SpvOpConstantFalse, instructionSize); + break; + } + + case SpvOpSpecConstant: { + constexpr uint32_t kOpSpecConstantResultId = 2; + constexpr uint32_t kOpSpecConstantValue = 3; + + uint32_t resultId = instruction[kOpSpecConstantResultId]; + uint32_t specializedValue = idToValue[resultId]; + if (specializedValue == kNoValue) { + break; + } + instruction[0] = makeOpCode(SpvOpConstant, instructionSize); + instruction[kOpSpecConstantValue] = specializedValue; + break; + } + + default: + break; + } + + IGL_DEBUG_ASSERT(instruction + instructionSize <= spirv + size); + instruction += instructionSize; + } +} + +} // namespace igl::vulkan::util diff --git a/src/igl/vulkan/util/SpvConstantSpecialization.h b/src/igl/vulkan/util/SpvConstantSpecialization.h new file mode 100644 index 0000000000..4beec4308c --- /dev/null +++ b/src/igl/vulkan/util/SpvConstantSpecialization.h @@ -0,0 +1,25 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include +#include +#include + +namespace igl::vulkan::util { + +constexpr uint32_t kNoValue = 0xffffffff; + +// Specializes integer, float and boolean constants in-place in the given SPIR-V binary. The value +// at the given index corrosponds the specialization constants constantId. Note that while we can't +// specialize OpSpecConstantOp, we could specialize OpSpecConstantComposite, but we would need +// support for variable size spec-constant values. +void specializeConstants(uint32_t* spirv, size_t numBytes, const std::vector& values); + +} // namespace igl::vulkan::util