Skip to content

Commit

Permalink
[Backport to 18] Align translation of OpCooperativeMatrixLengthKHR
Browse files Browse the repository at this point in the history
…to match the spec (#2964) (#2994)

`SPV_KHR_cooperative_matrix` extension defines that the only argument
accepted in this instruction is `Matrix Type <id>`, not the pointer to
an actual matrix.
  • Loading branch information
vmaksimo authored Feb 5, 2025
1 parent 1af600d commit d8fe736
Show file tree
Hide file tree
Showing 7 changed files with 22 additions and 8 deletions.
3 changes: 1 addition & 2 deletions lib/SPIRV/SPIRVReader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3550,8 +3550,7 @@ Instruction *SPIRVToLLVM::transBuiltinFromInst(const std::string &FuncName,
Func->addFnAttr(Attribute::Convergent);
}
CallInst *Call;
if (BI->getOpCode() == OpCooperativeMatrixLengthKHR &&
Ops[0]->getOpCode() == OpTypeCooperativeMatrixKHR) {
if (BI->getOpCode() == OpCooperativeMatrixLengthKHR) {
// OpCooperativeMatrixLengthKHR needs special handling as its operand is
// a Type instead of a Value.
llvm::Type *MatTy = transType(reinterpret_cast<SPIRVType *>(Ops[0]));
Expand Down
4 changes: 4 additions & 0 deletions lib/SPIRV/SPIRVWriter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6490,6 +6490,10 @@ LLVMToSPIRVBase::transBuiltinToInstWithoutDecoration(Op OC, CallInst *CI,
return BM->addCompositeConstructInst(transType(CI->getType()), Operands,
BB);
}
case OpCooperativeMatrixLengthKHR: {
return BM->addCooperativeMatrixLengthKHRInst(
transScavengedType(CI), transType(CI->getArgOperand(0)->getType()), BB);
}
default: {
if (isCvtOpCode(OC) && OC != OpGenericCastToPtrExplicit) {
return BM->addUnaryInst(OC, transScavengedType(CI),
Expand Down
11 changes: 11 additions & 0 deletions lib/SPIRV/libSPIRV/SPIRVModule.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -275,6 +275,9 @@ class SPIRVModuleImpl : public SPIRVModule {
addJointMatrixINTELType(SPIRVType *, std::vector<SPIRVValue *>) override;
SPIRVTypeCooperativeMatrixKHR *
addCooperativeMatrixKHRType(SPIRVType *, std::vector<SPIRVValue *>) override;
SPIRVInstruction *
addCooperativeMatrixLengthKHRInst(SPIRVType *, SPIRVType *,
SPIRVBasicBlock *) override;
SPIRVType *addOpaqueGenericType(Op) override;
SPIRVTypeDeviceEvent *addDeviceEventType() override;
SPIRVTypeQueue *addQueueType() override;
Expand Down Expand Up @@ -1049,6 +1052,14 @@ SPIRVModuleImpl::addCooperativeMatrixKHRType(SPIRVType *CompType,
new SPIRVTypeCooperativeMatrixKHR(this, getId(), CompType, Args));
}

SPIRVInstruction *SPIRVModuleImpl::addCooperativeMatrixLengthKHRInst(
SPIRVType *RetTy, SPIRVType *MatTy, SPIRVBasicBlock *BB) {
return addInstruction(
SPIRVInstTemplateBase::create(OpCooperativeMatrixLengthKHR, RetTy,
getId(), getVec(MatTy->getId()), BB, this),
BB);
}

SPIRVType *SPIRVModuleImpl::addOpaqueGenericType(Op TheOpCode) {
return addType(new SPIRVTypeOpaqueGeneric(TheOpCode, this, getId()));
}
Expand Down
3 changes: 3 additions & 0 deletions lib/SPIRV/libSPIRV/SPIRVModule.h
Original file line number Diff line number Diff line change
Expand Up @@ -265,6 +265,9 @@ class SPIRVModule {
addJointMatrixINTELType(SPIRVType *, std::vector<SPIRVValue *>) = 0;
virtual SPIRVTypeCooperativeMatrixKHR *
addCooperativeMatrixKHRType(SPIRVType *, std::vector<SPIRVValue *>) = 0;
virtual SPIRVInstruction *
addCooperativeMatrixLengthKHRInst(SPIRVType *, SPIRVType *,
SPIRVBasicBlock *) = 0;
virtual SPIRVTypeVoid *addVoidType() = 0;
virtual SPIRVType *addOpaqueGenericType(Op) = 0;
virtual SPIRVTypeDeviceEvent *addDeviceEventType() = 0;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,7 @@
; CHECK-SPIRV-DAG: TypeCooperativeMatrixKHR [[#MatTy3:]] [[#Int8Ty]] [[#Const2]] [[#Const48]] [[#Const12]] [[#Const1]]
; CHECK-SPIRV: CooperativeMatrixConstructCheckedINTEL [[#MatTy1]]
; CHECK-SPIRV: CooperativeMatrixLoadCheckedINTEL [[#MatTy2]] [[#Load1:]]
; TODO: Pass Matrix Type Id instead of Matrix Id to CooperativeMatrixLengthKHR.
; CHECK-SPIRV: CooperativeMatrixLengthKHR [[#Int32Ty]] [[#]] [[#Load1]]
; CHECK-SPIRV: CooperativeMatrixLengthKHR [[#Int32Ty]] [[#]] [[#MatTy2]]
; CHECK-SPIRV: CooperativeMatrixLoadCheckedINTEL [[#MatTy3]]
; CHECK-SPIRV: CooperativeMatrixMulAddKHR [[#MatTy1]]
; CHECK-SPIRV: CooperativeMatrixStoreCheckedINTEL
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,7 @@
; CHECK-SPIRV-DAG: TypeCooperativeMatrixKHR [[#MatTy3:]] [[#Int8Ty]] [[#Const2]] [[#Const48]] [[#Const12]] [[#Const1]]
; CHECK-SPIRV: CompositeConstruct [[#MatTy1]]
; CHECK-SPIRV: CooperativeMatrixLoadKHR [[#MatTy2]] [[#Load1:]]
; TODO: Pass Matrix Type Id instead of Matrix Id to CooperativeMatrixLengthKHR.
; CHECK-SPIRV: CooperativeMatrixLengthKHR [[#Int32Ty]] [[#]] [[#Load1]]
; CHECK-SPIRV: CooperativeMatrixLengthKHR [[#Int32Ty]] [[#]] [[#MatTy2]]
; CHECK-SPIRV: CooperativeMatrixPrefetchINTEL
; CHECK-SPIRV: CooperativeMatrixLoadKHR [[#MatTy3]]
; CHECK-SPIRV: CooperativeMatrixMulAddKHR [[#MatTy1]]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,7 @@
; CHECK-SPIRV-DAG: TypeCooperativeMatrixKHR [[#MatTy3:]] [[#Int8Ty]] [[#Const2]] [[#Const48]] [[#Const12]] [[#Const1]]
; CHECK-SPIRV: CompositeConstruct [[#MatTy1]]
; CHECK-SPIRV: CooperativeMatrixLoadKHR [[#MatTy2]] [[#Load1:]]
; TODO: Pass Matrix Type Id instead of Matrix Id to CooperativeMatrixLengthKHR.
; CHECK-SPIRV: CooperativeMatrixLengthKHR [[#Int32Ty]] [[#]] [[#Load1]]
; CHECK-SPIRV: CooperativeMatrixLengthKHR [[#Int32Ty]] [[#]] [[#MatTy2]]
; CHECK-SPIRV: CooperativeMatrixLoadKHR [[#MatTy3]]
; CHECK-SPIRV: CooperativeMatrixMulAddKHR [[#MatTy1]]
; CHECK-SPIRV: CooperativeMatrixStoreKHR
Expand Down

0 comments on commit d8fe736

Please sign in to comment.