Skip to content

Commit

Permalink
[Backport to 16] Align translation of OpCooperativeMatrixLengthKHR
Browse files Browse the repository at this point in the history
…to match the spec (#2964) (#2996)

`SPV_KHR_cooperative_matrix` extension defines that the only argument
accepted in this instruction is `Matrix Type <id>`, not the pointer to
an actual matrix.
  • Loading branch information
vmaksimo authored Feb 5, 2025
1 parent 252b6d2 commit 13d9eb5
Show file tree
Hide file tree
Showing 6 changed files with 25 additions and 4 deletions.
3 changes: 1 addition & 2 deletions lib/SPIRV/SPIRVReader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3397,8 +3397,7 @@ Instruction *SPIRVToLLVM::transBuiltinFromInst(const std::string &FuncName,
Func->addFnAttr(Attribute::Convergent);
}
CallInst *Call;
if (BI->getOpCode() == OpCooperativeMatrixLengthKHR &&
Ops[0]->getOpCode() == OpTypeCooperativeMatrixKHR) {
if (BI->getOpCode() == OpCooperativeMatrixLengthKHR) {
// OpCooperativeMatrixLengthKHR needs special handling as its operand is
// a Type instead of a Value.
llvm::Type *MatTy = transType(reinterpret_cast<SPIRVType *>(Ops[0]));
Expand Down
4 changes: 4 additions & 0 deletions lib/SPIRV/SPIRVWriter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5943,6 +5943,10 @@ LLVMToSPIRVBase::transBuiltinToInstWithoutDecoration(Op OC, CallInst *CI,
transValue(CI->getArgOperand(2), BB), BB);
return BM->addStoreInst(transValue(CI->getArgOperand(0), BB), V, {}, BB);
}
case OpCooperativeMatrixLengthKHR: {
return BM->addCooperativeMatrixLengthKHRInst(
transScavengedType(CI), transType(CI->getArgOperand(0)->getType()), BB);
}
default: {
if (isCvtOpCode(OC) && OC != OpGenericCastToPtrExplicit) {
return BM->addUnaryInst(OC, transType(CI->getType()),
Expand Down
11 changes: 11 additions & 0 deletions lib/SPIRV/libSPIRV/SPIRVModule.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -293,6 +293,9 @@ class SPIRVModuleImpl : public SPIRVModule {
addJointMatrixINTELType(SPIRVType *, std::vector<SPIRVValue *>) override;
SPIRVTypeCooperativeMatrixKHR *
addCooperativeMatrixKHRType(SPIRVType *, std::vector<SPIRVValue *>) override;
SPIRVInstruction *
addCooperativeMatrixLengthKHRInst(SPIRVType *, SPIRVType *,
SPIRVBasicBlock *) override;
SPIRVType *addOpaqueGenericType(Op) override;
SPIRVTypeDeviceEvent *addDeviceEventType() override;
SPIRVTypeQueue *addQueueType() override;
Expand Down Expand Up @@ -992,6 +995,14 @@ SPIRVModuleImpl::addCooperativeMatrixKHRType(SPIRVType *CompType,
new SPIRVTypeCooperativeMatrixKHR(this, getId(), CompType, Args));
}

SPIRVInstruction *SPIRVModuleImpl::addCooperativeMatrixLengthKHRInst(
SPIRVType *RetTy, SPIRVType *MatTy, SPIRVBasicBlock *BB) {
return addInstruction(
SPIRVInstTemplateBase::create(OpCooperativeMatrixLengthKHR, RetTy,
getId(), getVec(MatTy->getId()), BB, this),
BB);
}

SPIRVType *SPIRVModuleImpl::addOpaqueGenericType(Op TheOpCode) {
return addType(new SPIRVTypeOpaqueGeneric(TheOpCode, this, getId()));
}
Expand Down
3 changes: 3 additions & 0 deletions lib/SPIRV/libSPIRV/SPIRVModule.h
Original file line number Diff line number Diff line change
Expand Up @@ -251,6 +251,9 @@ class SPIRVModule {
addJointMatrixINTELType(SPIRVType *, std::vector<SPIRVValue *>) = 0;
virtual SPIRVTypeCooperativeMatrixKHR *
addCooperativeMatrixKHRType(SPIRVType *, std::vector<SPIRVValue *>) = 0;
virtual SPIRVInstruction *
addCooperativeMatrixLengthKHRInst(SPIRVType *, SPIRVType *,
SPIRVBasicBlock *) = 0;
virtual SPIRVTypeVoid *addVoidType() = 0;
virtual SPIRVType *addOpaqueGenericType(Op) = 0;
virtual SPIRVTypeDeviceEvent *addDeviceEventType() = 0;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,7 @@
; CHECK-SPIRV-DAG: TypeCooperativeMatrixKHR [[#MatTy3:]] [[#Int8Ty]] [[#Const2]] [[#Const48]] [[#Const12]] [[#Const1]]
; CHECK-SPIRV: CooperativeMatrixConstructCheckedINTEL [[#MatTy1]]
; CHECK-SPIRV: CooperativeMatrixLoadCheckedINTEL [[#MatTy2]] [[#Load1:]]
; TODO: Pass Matrix Type Id instead of Matrix Id to CooperativeMatrixLengthKHR.
; CHECK-SPIRV: CooperativeMatrixLengthKHR [[#Int32Ty]] [[#]] [[#Load1]]
; CHECK-SPIRV: CooperativeMatrixLengthKHR [[#Int32Ty]] [[#]] [[#MatTy2]]
; CHECK-SPIRV: CooperativeMatrixLoadCheckedINTEL [[#MatTy3]]
; CHECK-SPIRV: CooperativeMatrixMulAddKHR [[#MatTy1]]
; CHECK-SPIRV: CooperativeMatrixStoreCheckedINTEL
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,15 @@
; CHECK-SPIRV-DAG: TypeCooperativeMatrixKHR [[#MatTy3:]] [[#Int8Ty]] [[#Const2]] [[#Const48]] [[#Const12]] [[#Const3]]
; CHECK-SPIRV: CompositeConstruct [[#MatTy1]]
; CHECK-SPIRV: CooperativeMatrixLoadKHR [[#MatTy2]]
; CHECK-SPIRV: CooperativeMatrixLengthKHR [[#Int32Ty]] [[#]] [[#MatTy2]]
; CHECK-SPIRV: CooperativeMatrixLoadKHR [[#MatTy3]]
; CHECK-SPIRV: CooperativeMatrixMulAddKHR [[#MatTy1]]
; CHECK-SPIRV: CooperativeMatrixStoreKHR


; CHECK-LLVM: call spir_func target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 3) @_Z26__spirv_CompositeConstructi(i32 0)
; CHECK-LLVM: call spir_func target("spirv.CooperativeMatrixKHR", i8, 0, 12, 48, 3) @_Z86__spirv_CooperativeMatrixLoadKHR_RPU3AS144__spirv_CooperativeMatrixKHR__char_0_12_48_3PU3AS4clii
; CHECK-LLVM: call spir_func i32 @_Z34__spirv_CooperativeMatrixLengthKHRPU3AS144__spirv_CooperativeMatrixKHR__char_0_12_48_3(target("spirv.CooperativeMatrixKHR", i8, 0, 12, 48, 3)
; CHECK-LLVM: call spir_func target("spirv.CooperativeMatrixKHR", i8, 2, 48, 12, 3) @_Z86__spirv_CooperativeMatrixLoadKHR_RPU3AS144__spirv_CooperativeMatrixKHR__char_2_48_12_3PU3AS4cl
; CHECK-LLVM: call spir_func target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 3) @_Z34__spirv_CooperativeMatrixMulAddKHR{{.*}}(target("spirv.CooperativeMatrixKHR", i8, 0, 12, 48, 3) %{{.*}}, target("spirv.CooperativeMatrixKHR", i8, 2, 48, 12, 3) %{{.*}}, target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 3)
; CHECK-LLVM: call spir_func void @_Z33__spirv_CooperativeMatrixStoreKHR{{.*}}(ptr addrspace(4) %call.ascast.i.i, target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 3)
Expand Down Expand Up @@ -105,6 +107,7 @@ for.body.i: ; preds = %for.cond.i
%add.ptr.i96.i = getelementptr inbounds i8, ptr addrspace(1) %add.ptr.i93.i, i64 %conv13.i
%call.ascast.i66.i = addrspacecast ptr addrspace(1) %add.ptr.i96.i to ptr addrspace(4)
%call1.i.i = tail call spir_func noundef target("spirv.CooperativeMatrixKHR", i8, 0, 12, 48, 3) @_Z32__spirv_CooperativeMatrixLoadKHR_1(ptr addrspace(4) noundef %call.ascast.i66.i, i64 noundef %_arg_K, i32 noundef 0, i32 noundef 1) #4
%len = tail call spir_func noundef i32 @_Z34__spirv_CooperativeMatrixLengthKHR(target("spirv.CooperativeMatrixKHR", i8, 0, 12, 48, 3) %call1.i.i)
%div20.i = mul nsw i32 %k.0.i, 12
%conv21.i = zext i32 %div20.i to i64
%mul23.i = mul i64 %mul22.i, %conv21.i
Expand Down Expand Up @@ -136,6 +139,8 @@ _ZZZ15matrix_multiplyIiaLm24ELm96ELm24ELm96ELm24ELm24EEvR10big_matrixIT_XT5_EXT6
; Function Attrs: convergent
declare dso_local spir_func noundef target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 3) @_Z26__spirv_CompositeConstruct(i32 noundef) local_unnamed_addr #2

declare dso_local spir_func noundef i32 @_Z34__spirv_CooperativeMatrixLengthKHR(target("spirv.CooperativeMatrixKHR", i8, 0, 12, 48, 3) noundef)

; Function Attrs: convergent
declare dso_local spir_func noundef target("spirv.CooperativeMatrixKHR", i8, 0, 12, 48, 3) @_Z32__spirv_CooperativeMatrixLoadKHR_1(ptr addrspace(4) noundef, i64 noundef, i32 noundef, i32 noundef) local_unnamed_addr #2

Expand Down

0 comments on commit 13d9eb5

Please sign in to comment.