Skip to content

Commit

Permalink
Implement GL_NV_cooperative_vector
Browse files Browse the repository at this point in the history
  • Loading branch information
jeffbolznv authored and jeremy-lunarg committed Feb 3, 2025
1 parent 0549c71 commit 1b65bd6
Show file tree
Hide file tree
Showing 41 changed files with 6,831 additions and 4,858 deletions.
2 changes: 2 additions & 0 deletions SPIRV/GLSL.ext.KHR.h
Original file line number Diff line number Diff line change
Expand Up @@ -64,4 +64,6 @@ static const char* const E_SPV_KHR_expect_assume = "SPV_KHR_expec
// SPIR-V extension name strings, one constant per extension. Each constant's
// identifier is the extension name prefixed with "E_".
static const char* const E_SPV_EXT_replicated_composites = "SPV_EXT_replicated_composites";
static const char* const E_SPV_KHR_relaxed_extended_instruction = "SPV_KHR_relaxed_extended_instruction";
static const char* const E_SPV_KHR_integer_dot_product = "SPV_KHR_integer_dot_product";
static const char* const E_SPV_NV_cooperative_vector = "SPV_NV_cooperative_vector";

#endif // #ifndef GLSLextKHR_H
223 changes: 193 additions & 30 deletions SPIRV/GlslangToSpv.cpp

Large diffs are not rendered by default.

55 changes: 46 additions & 9 deletions SPIRV/SpvBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -515,6 +515,28 @@ Id Builder::makeCooperativeMatrixTypeWithSameShape(Id component, Id otherType)
}
}

// Return the result id of the OpTypeCooperativeVectorNV type whose two id
// operands are 'componentType' and 'components'. A type instruction with the
// same operands is reused when one was already created; otherwise a new one
// is built, registered, and returned.
Id Builder::makeCooperativeVectorTypeNV(Id componentType, Id components)
{
    // reuse an existing declaration that has identical operands
    for (Instruction* candidate : groupedTypes[OpTypeCooperativeVectorNV]) {
        const bool sameComponentType = candidate->getIdOperand(0) == componentType;
        if (sameComponentType && candidate->getIdOperand(1) == components)
            return candidate->getResultId();
    }

    // no match: build the type instruction (operand 0 = component type,
    // operand 1 = components)
    Instruction* const created = new Instruction(getUniqueId(), NoType, OpTypeCooperativeVectorNV);
    created->addIdOperand(componentType);
    created->addIdOperand(components);

    // record it for later lookups; constantsTypesGlobals takes ownership
    groupedTypes[OpTypeCooperativeVectorNV].push_back(created);
    constantsTypesGlobals.push_back(std::unique_ptr<Instruction>(created));
    module.mapInstruction(created);

    return created->getResultId();
}

Id Builder::makeGenericType(spv::Op opcode, std::vector<spv::IdImmediate>& operands)
{
// try to find it
Expand Down Expand Up @@ -1363,6 +1385,7 @@ unsigned int Builder::getNumTypeConstituents(Id typeId) const
case OpTypeVector:
case OpTypeMatrix:
return instr->getImmediateOperand(1);
case OpTypeCooperativeVectorNV:
case OpTypeArray:
{
Id lengthId = instr->getIdOperand(1);
Expand Down Expand Up @@ -1401,6 +1424,7 @@ Id Builder::getScalarTypeId(Id typeId) const
case OpTypeArray:
case OpTypeRuntimeArray:
case OpTypePointer:
case OpTypeCooperativeVectorNV:
return getScalarTypeId(getContainedTypeId(typeId));
default:
assert(0);
Expand All @@ -1422,6 +1446,7 @@ Id Builder::getContainedTypeId(Id typeId, int member) const
case OpTypeRuntimeArray:
case OpTypeCooperativeMatrixKHR:
case OpTypeCooperativeMatrixNV:
case OpTypeCooperativeVectorNV:
return instr->getIdOperand(0);
case OpTypePointer:
return instr->getIdOperand(1);
Expand Down Expand Up @@ -1804,7 +1829,7 @@ Id Builder::importNonSemanticShaderDebugInfoInstructions()
return nonSemanticShaderDebugInfo;
}

Id Builder::findCompositeConstant(Op typeClass, Id typeId, const std::vector<Id>& comps)
Id Builder::findCompositeConstant(Op typeClass, Op opcode, Id typeId, const std::vector<Id>& comps, size_t numMembers)
{
Instruction* constant = nullptr;
bool found = false;
Expand All @@ -1814,6 +1839,13 @@ Id Builder::findCompositeConstant(Op typeClass, Id typeId, const std::vector<Id>
if (constant->getTypeId() != typeId)
continue;

if (constant->getOpCode() != opcode) {
continue;
}

if (constant->getNumOperands() != (int)numMembers)
continue;

// same contents?
bool mismatch = false;
for (int op = 0; op < constant->getNumOperands(); ++op) {
Expand Down Expand Up @@ -1863,7 +1895,7 @@ Id Builder::makeCompositeConstant(Id typeId, const std::vector<Id>& members, boo

bool replicate = false;
size_t numMembers = members.size();
if (useReplicatedComposites) {
if (useReplicatedComposites || typeClass == OpTypeCooperativeVectorNV) {
// use replicate if all members are the same
replicate = numMembers > 0 &&
std::equal(members.begin() + 1, members.end(), members.begin());
Expand All @@ -1885,8 +1917,9 @@ Id Builder::makeCompositeConstant(Id typeId, const std::vector<Id>& members, boo
case OpTypeMatrix:
case OpTypeCooperativeMatrixKHR:
case OpTypeCooperativeMatrixNV:
case OpTypeCooperativeVectorNV:
if (! specConstant) {
Id existing = findCompositeConstant(typeClass, typeId, members);
Id existing = findCompositeConstant(typeClass, opcode, typeId, members, numMembers);
if (existing)
return existing;
}
Expand Down Expand Up @@ -3021,7 +3054,7 @@ Id Builder::smearScalar(Decoration precision, Id scalar, Id vectorType)
assert(getTypeId(scalar) == getScalarTypeId(vectorType));

int numComponents = getNumTypeComponents(vectorType);
if (numComponents == 1)
if (numComponents == 1 && !isCooperativeVectorType(vectorType))
return scalar;

Instruction* smear = nullptr;
Expand All @@ -3038,7 +3071,7 @@ Id Builder::smearScalar(Decoration precision, Id scalar, Id vectorType)
auto result_id = makeCompositeConstant(vectorType, members, isSpecConstant(scalar));
smear = module.getInstruction(result_id);
} else {
bool replicate = useReplicatedComposites && (numComponents > 0);
bool replicate = (useReplicatedComposites || isCooperativeVectorType(vectorType)) && (numComponents > 0);

if (replicate) {
numComponents = 1;
Expand Down Expand Up @@ -3425,7 +3458,8 @@ Id Builder::createCompositeCompare(Decoration precision, Id value1, Id value2, b
Id Builder::createCompositeConstruct(Id typeId, const std::vector<Id>& constituents)
{
assert(isAggregateType(typeId) || (getNumTypeConstituents(typeId) > 1 &&
getNumTypeConstituents(typeId) == constituents.size()));
getNumTypeConstituents(typeId) == constituents.size()) ||
(isCooperativeVectorType(typeId) && constituents.size() == 1));

if (generatingOpCodeForSpecConst) {
// Sometime, even in spec-constant-op mode, the constant composite to be
Expand All @@ -3444,7 +3478,7 @@ Id Builder::createCompositeConstruct(Id typeId, const std::vector<Id>& constitue
bool replicate = false;
size_t numConstituents = constituents.size();

if (useReplicatedComposites) {
if (useReplicatedComposites || isCooperativeVectorType(typeId)) {
replicate = numConstituents > 0 &&
std::equal(constituents.begin() + 1, constituents.end(), constituents.begin());
}
Expand Down Expand Up @@ -3510,7 +3544,7 @@ Id Builder::createConstructor(Decoration precision, const std::vector<Id>& sourc

// Special case: when calling a vector constructor with a single scalar
// argument, smear the scalar
if (sources.size() == 1 && isScalar(sources[0]) && numTargetComponents > 1)
if (sources.size() == 1 && isScalar(sources[0]) && (numTargetComponents > 1 || isCooperativeVectorType(resultTypeId)))
return smearScalar(precision, sources[0], resultTypeId);

// Special case: 2 vectors of equal size
Expand Down Expand Up @@ -3574,7 +3608,7 @@ Id Builder::createConstructor(Decoration precision, const std::vector<Id>& sourc

if (isScalar(sources[i]) || isPointer(sources[i]))
latchResult(sources[i]);
else if (isVector(sources[i]))
else if (isVector(sources[i]) || isCooperativeVector(sources[i]))
accumulateVectorConstituents(sources[i]);
else if (isMatrix(sources[i]))
accumulateMatrixConstituents(sources[i]);
Expand Down Expand Up @@ -4021,6 +4055,9 @@ Id Builder::accessChainLoad(Decoration precision, Decoration l_nonUniform,
if (constant) {
id = createCompositeExtract(accessChain.base, swizzleBase, indexes);
setPrecision(id, precision);
} else if (isCooperativeVector(accessChain.base)) {
assert(accessChain.indexChain.size() == 1);
id = createVectorExtractDynamic(accessChain.base, resultType, accessChain.indexChain[0]);
} else {
Id lValue = NoResult;
if (spvVersion >= Spv_1_4 && isValidInitializer(accessChain.base)) {
Expand Down
6 changes: 5 additions & 1 deletion SPIRV/SpvBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,7 @@ class Builder {
Id makeCooperativeMatrixTypeKHR(Id component, Id scope, Id rows, Id cols, Id use);
Id makeCooperativeMatrixTypeNV(Id component, Id scope, Id rows, Id cols);
Id makeCooperativeMatrixTypeWithSameShape(Id component, Id otherType);
Id makeCooperativeVectorTypeNV(Id componentType, Id components);
Id makeGenericType(spv::Op opcode, std::vector<spv::IdImmediate>& operands);

// SPIR-V NonSemantic Shader DebugInfo Instructions
Expand Down Expand Up @@ -280,12 +281,14 @@ class Builder {
{ return (ImageFormat)module.getInstruction(typeId)->getImmediateOperand(6); }
Id getResultingAccessChainType() const;
// Fetch id operand number 'idx' of the instruction whose result id is 'resultId'.
Id getIdOperand(Id resultId, int idx)
{
    auto* instr = module.getInstruction(resultId);
    return instr->getIdOperand(idx);
}
// Return id operand 1 of the type instruction 'typeId'. For a type made by
// makeCooperativeVectorTypeNV that operand is the 'components' id, so the
// result is an id rather than a literal count.
Id getCooperativeVectorNumComponents(Id typeId) const
{
    auto* typeInstr = module.getInstruction(typeId);
    return typeInstr->getIdOperand(1);
}

// True when the type of the value 'resultId' satisfies isPointerType().
bool isPointer(Id resultId) const
{
    const auto valueType = getTypeId(resultId);
    return isPointerType(valueType);
}
// True when the type of the value 'resultId' satisfies isScalarType().
bool isScalar(Id resultId) const
{
    const auto valueType = getTypeId(resultId);
    return isScalarType(valueType);
}
// True when the type of the value 'resultId' satisfies isVectorType().
bool isVector(Id resultId) const
{
    const auto valueType = getTypeId(resultId);
    return isVectorType(valueType);
}
// True when the type of the value 'resultId' satisfies isMatrixType().
bool isMatrix(Id resultId) const
{
    const auto valueType = getTypeId(resultId);
    return isMatrixType(valueType);
}
// True when the type of the value 'resultId' satisfies isCooperativeMatrixType().
bool isCooperativeMatrix(Id resultId) const
{
    const auto valueType = getTypeId(resultId);
    return isCooperativeMatrixType(valueType);
}
// True when the type of the value 'resultId' satisfies isCooperativeVectorType().
bool isCooperativeVector(Id resultId) const
{
    const auto valueType = getTypeId(resultId);
    return isCooperativeVectorType(valueType);
}
// True when the type of the value 'resultId' satisfies isAggregateType().
bool isAggregate(Id resultId) const
{
    const auto valueType = getTypeId(resultId);
    return isAggregateType(valueType);
}
// True when the type of the value 'resultId' satisfies isSampledImageType().
bool isSampledImage(Id resultId) const
{
    const auto valueType = getTypeId(resultId);
    return isSampledImageType(valueType);
}
// True when the type of the value 'resultId' satisfies isTensorViewType().
bool isTensorView(Id resultId) const
{
    const auto valueType = getTypeId(resultId);
    return isTensorViewType(valueType);
}
Expand All @@ -310,6 +313,7 @@ class Builder {
return getTypeClass(typeId) == OpTypeCooperativeMatrixKHR || getTypeClass(typeId) == OpTypeCooperativeMatrixNV;
}
// True when the type class of 'typeId' is OpTypeTensorViewNV.
bool isTensorViewType(Id typeId) const
{
    const auto typeClass = getTypeClass(typeId);
    return typeClass == OpTypeTensorViewNV;
}
// True when the type class of 'typeId' is OpTypeCooperativeVectorNV.
bool isCooperativeVectorType(Id typeId) const
{
    const auto typeClass = getTypeClass(typeId);
    return typeClass == OpTypeCooperativeVectorNV;
}
// True when 'typeId' is an array type, a struct type, or a cooperative
// matrix type. The checks run in that order and stop at the first match.
bool isAggregateType(Id typeId) const
{
    if (isArrayType(typeId))
        return true;
    if (isStructType(typeId))
        return true;
    return isCooperativeMatrixType(typeId);
}
// True when the type class of 'typeId' is OpTypeImage.
bool isImageType(Id typeId) const
{
    const auto typeClass = getTypeClass(typeId);
    return typeClass == OpTypeImage;
}
Expand Down Expand Up @@ -898,7 +902,7 @@ class Builder {
protected:
Id findScalarConstant(Op typeClass, Op opcode, Id typeId, unsigned value);
Id findScalarConstant(Op typeClass, Op opcode, Id typeId, unsigned v1, unsigned v2);
Id findCompositeConstant(Op typeClass, Id typeId, const std::vector<Id>& comps);
Id findCompositeConstant(Op typeClass, Op opcode, Id typeId, const std::vector<Id>& comps, size_t numMembers);
Id findStructConstant(Id typeId, const std::vector<Id>& comps);
Id collapseAccessChain();
void remapDynamicSwizzle();
Expand Down
71 changes: 71 additions & 0 deletions SPIRV/doc.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1046,6 +1046,9 @@ const char* CapabilityString(int info)

case CapabilityShaderSMBuiltinsNV: return "ShaderSMBuiltinsNV";

case CapabilityCooperativeVectorNV: return "CooperativeVectorNV";
case CapabilityCooperativeVectorTrainingNV: return "CooperativeVectorTrainingNV";

case CapabilityFragmentShaderSampleInterlockEXT: return "CapabilityFragmentShaderSampleInterlockEXT";
case CapabilityFragmentShaderPixelInterlockEXT: return "CapabilityFragmentShaderPixelInterlockEXT";
case CapabilityFragmentShaderShadingRateInterlockEXT: return "CapabilityFragmentShaderShadingRateInterlockEXT";
Expand Down Expand Up @@ -1579,6 +1582,14 @@ const char* OpcodeString(int op)
case OpTensorViewSetStrideNV: return "OpTensorViewSetStrideNV";
case OpTensorViewSetClipNV: return "OpTensorViewSetClipNV";

case OpTypeCooperativeVectorNV: return "OpTypeCooperativeVectorNV";
case OpCooperativeVectorMatrixMulNV: return "OpCooperativeVectorMatrixMulNV";
case OpCooperativeVectorMatrixMulAddNV: return "OpCooperativeVectorMatrixMulAddNV";
case OpCooperativeVectorLoadNV: return "OpCooperativeVectorLoadNV";
case OpCooperativeVectorStoreNV: return "OpCooperativeVectorStoreNV";
case OpCooperativeVectorOuterProductAccumulateNV: return "OpCooperativeVectorOuterProductAccumulateNV";
case OpCooperativeVectorReduceSumAccumulateNV: return "OpCooperativeVectorReduceSumAccumulateNV";

case OpBeginInvocationInterlockEXT: return "OpBeginInvocationInterlockEXT";
case OpEndInvocationInterlockEXT: return "OpEndInvocationInterlockEXT";

Expand Down Expand Up @@ -1766,6 +1777,11 @@ void Parameterize()
InstructionDesc[OpTypeTensorLayoutNV].setResultAndType(true, false);
InstructionDesc[OpTypeTensorViewNV].setResultAndType(true, false);
InstructionDesc[OpCooperativeMatrixStoreTensorNV].setResultAndType(false, false);
InstructionDesc[OpTypeCooperativeVectorNV].setResultAndType(true, false);
InstructionDesc[OpCooperativeVectorStoreNV].setResultAndType(false, false);
InstructionDesc[OpCooperativeVectorOuterProductAccumulateNV].setResultAndType(false, false);
InstructionDesc[OpCooperativeVectorReduceSumAccumulateNV].setResultAndType(false, false);

// Specific additional context-dependent operands

ExecutionModeOperands[ExecutionModeInvocations].push(OperandLiteralNumber, "'Number of <<Invocation,invocations>>'");
Expand Down Expand Up @@ -3272,6 +3288,61 @@ void Parameterize()

InstructionDesc[OpCooperativeMatrixLengthKHR].operands.push(OperandId, "'Type'");

InstructionDesc[OpTypeCooperativeVectorNV].operands.push(OperandId, "'Component Type'");
InstructionDesc[OpTypeCooperativeVectorNV].operands.push(OperandId, "'Components'");

InstructionDesc[OpCooperativeVectorMatrixMulNV].operands.push(OperandId, "'Input'");
InstructionDesc[OpCooperativeVectorMatrixMulNV].operands.push(OperandId, "'InputInterpretation'");
InstructionDesc[OpCooperativeVectorMatrixMulNV].operands.push(OperandId, "'Matrix'");
InstructionDesc[OpCooperativeVectorMatrixMulNV].operands.push(OperandId, "'MatrixOffset'");
InstructionDesc[OpCooperativeVectorMatrixMulNV].operands.push(OperandId, "'MatrixInterpretation'");
InstructionDesc[OpCooperativeVectorMatrixMulNV].operands.push(OperandId, "'M'");
InstructionDesc[OpCooperativeVectorMatrixMulNV].operands.push(OperandId, "'K'");
InstructionDesc[OpCooperativeVectorMatrixMulNV].operands.push(OperandId, "'MemoryLayout'");
InstructionDesc[OpCooperativeVectorMatrixMulNV].operands.push(OperandId, "'Transpose'");
InstructionDesc[OpCooperativeVectorMatrixMulNV].operands.push(OperandId, "'MatrixStride'", true);
InstructionDesc[OpCooperativeVectorMatrixMulNV].operands.push(OperandCooperativeMatrixOperands, "'Cooperative Matrix Operands'", true);

InstructionDesc[OpCooperativeVectorMatrixMulAddNV].operands.push(OperandId, "'Input'");
InstructionDesc[OpCooperativeVectorMatrixMulAddNV].operands.push(OperandId, "'InputInterpretation'");
InstructionDesc[OpCooperativeVectorMatrixMulAddNV].operands.push(OperandId, "'Matrix'");
InstructionDesc[OpCooperativeVectorMatrixMulAddNV].operands.push(OperandId, "'MatrixOffset'");
InstructionDesc[OpCooperativeVectorMatrixMulAddNV].operands.push(OperandId, "'MatrixInterpretation'");
InstructionDesc[OpCooperativeVectorMatrixMulAddNV].operands.push(OperandId, "'Bias'");
InstructionDesc[OpCooperativeVectorMatrixMulAddNV].operands.push(OperandId, "'BiasOffset'");
InstructionDesc[OpCooperativeVectorMatrixMulAddNV].operands.push(OperandId, "'BiasInterpretation'");
InstructionDesc[OpCooperativeVectorMatrixMulAddNV].operands.push(OperandId, "'M'");
InstructionDesc[OpCooperativeVectorMatrixMulAddNV].operands.push(OperandId, "'K'");
InstructionDesc[OpCooperativeVectorMatrixMulAddNV].operands.push(OperandId, "'MemoryLayout'");
InstructionDesc[OpCooperativeVectorMatrixMulAddNV].operands.push(OperandId, "'Transpose'");
InstructionDesc[OpCooperativeVectorMatrixMulAddNV].operands.push(OperandId, "'MatrixStride'", true);
InstructionDesc[OpCooperativeVectorMatrixMulAddNV].operands.push(OperandCooperativeMatrixOperands, "'Cooperative Matrix Operands'", true);

InstructionDesc[OpCooperativeVectorLoadNV].operands.push(OperandId, "'Pointer'");
InstructionDesc[OpCooperativeVectorLoadNV].operands.push(OperandId, "'Offset'");
InstructionDesc[OpCooperativeVectorLoadNV].operands.push(OperandMemoryAccess, "'Memory Access'");
InstructionDesc[OpCooperativeVectorLoadNV].operands.push(OperandLiteralNumber, "", true);
InstructionDesc[OpCooperativeVectorLoadNV].operands.push(OperandId, "", true);

InstructionDesc[OpCooperativeVectorStoreNV].operands.push(OperandId, "'Pointer'");
InstructionDesc[OpCooperativeVectorStoreNV].operands.push(OperandId, "'Offset'");
InstructionDesc[OpCooperativeVectorStoreNV].operands.push(OperandId, "'Object'");
InstructionDesc[OpCooperativeVectorStoreNV].operands.push(OperandMemoryAccess, "'Memory Access'");
InstructionDesc[OpCooperativeVectorStoreNV].operands.push(OperandLiteralNumber, "", true);
InstructionDesc[OpCooperativeVectorStoreNV].operands.push(OperandId, "", true);

InstructionDesc[OpCooperativeVectorOuterProductAccumulateNV].operands.push(OperandId, "'Pointer'");
InstructionDesc[OpCooperativeVectorOuterProductAccumulateNV].operands.push(OperandId, "'Offset'");
InstructionDesc[OpCooperativeVectorOuterProductAccumulateNV].operands.push(OperandId, "'A'");
InstructionDesc[OpCooperativeVectorOuterProductAccumulateNV].operands.push(OperandId, "'B'");
InstructionDesc[OpCooperativeVectorOuterProductAccumulateNV].operands.push(OperandId, "'MemoryLayout'");
InstructionDesc[OpCooperativeVectorOuterProductAccumulateNV].operands.push(OperandId, "'MatrixInterpretation'");
InstructionDesc[OpCooperativeVectorOuterProductAccumulateNV].operands.push(OperandId, "'MatrixStride'", true);

InstructionDesc[OpCooperativeVectorReduceSumAccumulateNV].operands.push(OperandId, "'Pointer'");
InstructionDesc[OpCooperativeVectorReduceSumAccumulateNV].operands.push(OperandId, "'Offset'");
InstructionDesc[OpCooperativeVectorReduceSumAccumulateNV].operands.push(OperandId, "'V'");

InstructionDesc[OpDemoteToHelperInvocationEXT].setResultAndType(false, false);

InstructionDesc[OpReadClockKHR].operands.push(OperandScope, "'Scope'");
Expand Down
Loading

0 comments on commit 1b65bd6

Please sign in to comment.