From 13d9eb5e364b3460c71243ff744eaaa6f8beb10d Mon Sep 17 00:00:00 2001 From: Viktoria Maximova Date: Wed, 5 Feb 2025 18:20:35 +0100 Subject: [PATCH] [Backport to 16] Align translation of `OpCooperativeMatrixLengthKHR` to match the spec (#2964) (#2996) `SPV_KHR_cooperative_matrix` extension defines that the only argument accepted in this instruction is `Matrix Type `, not the pointer to an actual matrix. --- lib/SPIRV/SPIRVReader.cpp | 3 +-- lib/SPIRV/SPIRVWriter.cpp | 4 ++++ lib/SPIRV/libSPIRV/SPIRVModule.cpp | 11 +++++++++++ lib/SPIRV/libSPIRV/SPIRVModule.h | 3 +++ .../cooperative_matrix_checked.ll | 3 +-- .../SPV_KHR_cooperative_matrix/cooperative_matrix.ll | 5 +++++ 6 files changed, 25 insertions(+), 4 deletions(-) diff --git a/lib/SPIRV/SPIRVReader.cpp b/lib/SPIRV/SPIRVReader.cpp index 3e4ae46ee..99976cab6 100644 --- a/lib/SPIRV/SPIRVReader.cpp +++ b/lib/SPIRV/SPIRVReader.cpp @@ -3397,8 +3397,7 @@ Instruction *SPIRVToLLVM::transBuiltinFromInst(const std::string &FuncName, Func->addFnAttr(Attribute::Convergent); } CallInst *Call; - if (BI->getOpCode() == OpCooperativeMatrixLengthKHR && - Ops[0]->getOpCode() == OpTypeCooperativeMatrixKHR) { + if (BI->getOpCode() == OpCooperativeMatrixLengthKHR) { // OpCooperativeMatrixLengthKHR needs special handling as its operand is // a Type instead of a Value. llvm::Type *MatTy = transType(reinterpret_cast(Ops[0])); diff --git a/lib/SPIRV/SPIRVWriter.cpp b/lib/SPIRV/SPIRVWriter.cpp index 404dc0cc4..c639e3ff6 100644 --- a/lib/SPIRV/SPIRVWriter.cpp +++ b/lib/SPIRV/SPIRVWriter.cpp @@ -5943,6 +5943,10 @@ LLVMToSPIRVBase::transBuiltinToInstWithoutDecoration(Op OC, CallInst *CI, transValue(CI->getArgOperand(2), BB), BB); return BM->addStoreInst(transValue(CI->getArgOperand(0), BB), V, {}, BB); } + case OpCooperativeMatrixLengthKHR: { + return BM->addCooperativeMatrixLengthKHRInst( + transScavengedType(CI), transType(CI->getArgOperand(0)->getType()), BB); + } default: { if (isCvtOpCode(OC) && OC != OpGenericCastToPtrExplicit) { return BM->addUnaryInst(OC, transType(CI->getType()), diff --git a/lib/SPIRV/libSPIRV/SPIRVModule.cpp b/lib/SPIRV/libSPIRV/SPIRVModule.cpp index 1cfc55335..61b626b41 100644 --- a/lib/SPIRV/libSPIRV/SPIRVModule.cpp +++ b/lib/SPIRV/libSPIRV/SPIRVModule.cpp @@ -293,6 +293,9 @@ class SPIRVModuleImpl : public SPIRVModule { addJointMatrixINTELType(SPIRVType *, std::vector) override; SPIRVTypeCooperativeMatrixKHR * addCooperativeMatrixKHRType(SPIRVType *, std::vector) override; + SPIRVInstruction * + addCooperativeMatrixLengthKHRInst(SPIRVType *, SPIRVType *, + SPIRVBasicBlock *) override; SPIRVType *addOpaqueGenericType(Op) override; SPIRVTypeDeviceEvent *addDeviceEventType() override; SPIRVTypeQueue *addQueueType() override; @@ -992,6 +995,14 @@ SPIRVModuleImpl::addCooperativeMatrixKHRType(SPIRVType *CompType, new SPIRVTypeCooperativeMatrixKHR(this, getId(), CompType, Args)); } +SPIRVInstruction *SPIRVModuleImpl::addCooperativeMatrixLengthKHRInst( + SPIRVType *RetTy, SPIRVType *MatTy, SPIRVBasicBlock *BB) { + return addInstruction( + SPIRVInstTemplateBase::create(OpCooperativeMatrixLengthKHR, RetTy, + getId(), getVec(MatTy->getId()), BB, this), + BB); +} + SPIRVType *SPIRVModuleImpl::addOpaqueGenericType(Op TheOpCode) { return addType(new SPIRVTypeOpaqueGeneric(TheOpCode, this, getId())); } diff --git a/lib/SPIRV/libSPIRV/SPIRVModule.h b/lib/SPIRV/libSPIRV/SPIRVModule.h index 2b5e44859..ae111b45a 100644 --- a/lib/SPIRV/libSPIRV/SPIRVModule.h +++ b/lib/SPIRV/libSPIRV/SPIRVModule.h @@ -251,6 +251,9 @@ class SPIRVModule { addJointMatrixINTELType(SPIRVType *, std::vector) = 0; virtual SPIRVTypeCooperativeMatrixKHR * addCooperativeMatrixKHRType(SPIRVType *, std::vector) = 0; + virtual SPIRVInstruction * + addCooperativeMatrixLengthKHRInst(SPIRVType *, SPIRVType *, + SPIRVBasicBlock *) = 0; virtual SPIRVTypeVoid *addVoidType() = 0; virtual SPIRVType *addOpaqueGenericType(Op) = 0; virtual SPIRVTypeDeviceEvent *addDeviceEventType() = 0; diff --git a/test/extensions/INTEL/SPV_INTEL_joint_matrix/cooperative_matrix_checked.ll b/test/extensions/INTEL/SPV_INTEL_joint_matrix/cooperative_matrix_checked.ll index 4c481df18..ffc1a8f41 100644 --- a/test/extensions/INTEL/SPV_INTEL_joint_matrix/cooperative_matrix_checked.ll +++ b/test/extensions/INTEL/SPV_INTEL_joint_matrix/cooperative_matrix_checked.ll @@ -24,8 +24,7 @@ ; CHECK-SPIRV-DAG: TypeCooperativeMatrixKHR [[#MatTy3:]] [[#Int8Ty]] [[#Const2]] [[#Const48]] [[#Const12]] [[#Const1]] ; CHECK-SPIRV: CooperativeMatrixConstructCheckedINTEL [[#MatTy1]] ; CHECK-SPIRV: CooperativeMatrixLoadCheckedINTEL [[#MatTy2]] [[#Load1:]] -; TODO: Pass Matrix Type Id instead of Matrix Id to CooperativeMatrixLengthKHR. -; CHECK-SPIRV: CooperativeMatrixLengthKHR [[#Int32Ty]] [[#]] [[#Load1]] +; CHECK-SPIRV: CooperativeMatrixLengthKHR [[#Int32Ty]] [[#]] [[#MatTy2]] ; CHECK-SPIRV: CooperativeMatrixLoadCheckedINTEL [[#MatTy3]] ; CHECK-SPIRV: CooperativeMatrixMulAddKHR [[#MatTy1]] ; CHECK-SPIRV: CooperativeMatrixStoreCheckedINTEL diff --git a/test/extensions/KHR/SPV_KHR_cooperative_matrix/cooperative_matrix.ll b/test/extensions/KHR/SPV_KHR_cooperative_matrix/cooperative_matrix.ll index c571e9e25..9c7e9d9f9 100644 --- a/test/extensions/KHR/SPV_KHR_cooperative_matrix/cooperative_matrix.ll +++ b/test/extensions/KHR/SPV_KHR_cooperative_matrix/cooperative_matrix.ll @@ -23,6 +23,7 @@ ; CHECK-SPIRV-DAG: TypeCooperativeMatrixKHR [[#MatTy3:]] [[#Int8Ty]] [[#Const2]] [[#Const48]] [[#Const12]] [[#Const3]] ; CHECK-SPIRV: CompositeConstruct [[#MatTy1]] ; CHECK-SPIRV: CooperativeMatrixLoadKHR [[#MatTy2]] +; CHECK-SPIRV: CooperativeMatrixLengthKHR [[#Int32Ty]] [[#]] [[#MatTy2]] ; CHECK-SPIRV: CooperativeMatrixLoadKHR [[#MatTy3]] ; CHECK-SPIRV: CooperativeMatrixMulAddKHR [[#MatTy1]] ; CHECK-SPIRV: CooperativeMatrixStoreKHR @@ -30,6 +31,7 @@ ; CHECK-LLVM: call spir_func target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 3) @_Z26__spirv_CompositeConstructi(i32 0) ; CHECK-LLVM: call spir_func target("spirv.CooperativeMatrixKHR", i8, 0, 12, 48, 3) @_Z86__spirv_CooperativeMatrixLoadKHR_RPU3AS144__spirv_CooperativeMatrixKHR__char_0_12_48_3PU3AS4clii +; CHECK-LLVM: call spir_func i32 @_Z34__spirv_CooperativeMatrixLengthKHRPU3AS144__spirv_CooperativeMatrixKHR__char_0_12_48_3(target("spirv.CooperativeMatrixKHR", i8, 0, 12, 48, 3) ; CHECK-LLVM: call spir_func target("spirv.CooperativeMatrixKHR", i8, 2, 48, 12, 3) @_Z86__spirv_CooperativeMatrixLoadKHR_RPU3AS144__spirv_CooperativeMatrixKHR__char_2_48_12_3PU3AS4cl ; CHECK-LLVM: call spir_func target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 3) @_Z34__spirv_CooperativeMatrixMulAddKHR{{.*}}(target("spirv.CooperativeMatrixKHR", i8, 0, 12, 48, 3) %{{.*}}, target("spirv.CooperativeMatrixKHR", i8, 2, 48, 12, 3) %{{.*}}, target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 3) ; CHECK-LLVM: call spir_func void @_Z33__spirv_CooperativeMatrixStoreKHR{{.*}}(ptr addrspace(4) %call.ascast.i.i, target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 3) @@ -105,6 +107,7 @@ for.body.i: ; preds = %for.cond.i %add.ptr.i96.i = getelementptr inbounds i8, ptr addrspace(1) %add.ptr.i93.i, i64 %conv13.i %call.ascast.i66.i = addrspacecast ptr addrspace(1) %add.ptr.i96.i to ptr addrspace(4) %call1.i.i = tail call spir_func noundef target("spirv.CooperativeMatrixKHR", i8, 0, 12, 48, 3) @_Z32__spirv_CooperativeMatrixLoadKHR_1(ptr addrspace(4) noundef %call.ascast.i66.i, i64 noundef %_arg_K, i32 noundef 0, i32 noundef 1) #4 + %len = tail call spir_func noundef i32 @_Z34__spirv_CooperativeMatrixLengthKHR(target("spirv.CooperativeMatrixKHR", i8, 0, 12, 48, 3) %call1.i.i) %div20.i = mul nsw i32 %k.0.i, 12 %conv21.i = zext i32 %div20.i to i64 %mul23.i = mul i64 %mul22.i, %conv21.i @@ -136,6 +139,8 @@ _ZZZ15matrix_multiplyIiaLm24ELm96ELm24ELm96ELm24ELm24EEvR10big_matrixIT_XT5_EXT6 ; Function Attrs: convergent declare dso_local spir_func noundef target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 3) @_Z26__spirv_CompositeConstruct(i32 noundef) local_unnamed_addr #2 +declare dso_local spir_func noundef i32 @_Z34__spirv_CooperativeMatrixLengthKHR(target("spirv.CooperativeMatrixKHR", i8, 0, 12, 48, 3) noundef) + ; Function Attrs: convergent declare dso_local spir_func noundef target("spirv.CooperativeMatrixKHR", i8, 0, 12, 48, 3) @_Z32__spirv_CooperativeMatrixLoadKHR_1(ptr addrspace(4) noundef, i64 noundef, i32 noundef, i32 noundef) local_unnamed_addr #2