From e13e13351a8167fefec51974786f3822a6e38281 Mon Sep 17 00:00:00 2001 From: hcman2 <52367956+hcman2@users.noreply.github.com> Date: Thu, 12 Dec 2024 10:36:01 +0800 Subject: [PATCH] LSU supports larger MT and reuse LDS. (#1433) 1. Move LSU into LSU.py. 2. Do partial LSU when the LDS is not enough. --- tensilelite/Tensile/Components/LSU.py | 547 ++++++++++++++++++ tensilelite/Tensile/Components/__init__.py | 3 +- tensilelite/Tensile/KernelWriter.py | 43 +- tensilelite/Tensile/KernelWriterAssembly.py | 275 --------- tensilelite/Tensile/SolutionStructs.py | 6 + .../TensileInstructions/Instructions.py | 2 +- .../Tensile/Tests/common/gemm/lsu.yaml | 159 +++++ 7 files changed, 720 insertions(+), 315 deletions(-) create mode 100644 tensilelite/Tensile/Components/LSU.py diff --git a/tensilelite/Tensile/Components/LSU.py b/tensilelite/Tensile/Components/LSU.py new file mode 100644 index 0000000000..ae57264c0e --- /dev/null +++ b/tensilelite/Tensile/Components/LSU.py @@ -0,0 +1,547 @@ +################################################################################ +# +# Copyright (C) 2024 Advanced Micro Devices, Inc. All rights reserved. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell cop- +# ies of the Software, and to permit persons to whom the Software is furnished +# to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IM- +# PLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS +# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE AUTHORS OR +# COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER +# IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNE- +# CTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. +################################################################################ + +from ..TensileInstructions import Module, Label, RegisterPoolResource, SCmpEQU32, \ + SMovB32, log2, ceilDivide, SCBranchSCC0, Instruction, \ + SAndB32, RegSet, vectorStaticDivide +from ..Component import Component +from ..KernelWriterModules import * +from ..AsmStoreState import StoreState, VectorDataTypes +#import abc + +class LSU(Component): + """ + LSU block. + """ + +class LSUOff(LSU): + kernel = {"LocalSplitU": 1} + +class LSUOn(LSU): + + @classmethod + def matches(cls, writer, debug=False): + return writer.states.kernel["LocalSplitU"] > 1 + + def __call__(self): + assert(0) + + def splitOutputData(self, writer, kernel): + self.LSUelemCoord0 = [] + self.LSUelemCoord1 = [] + self.LSUelements = [] + self.LSUfullVw = [] + (vwdummy, eledummy, self.LSUfullVw, self.LSUelements) = writer.notLocalFullTileElements(kernel, False) + storevw = self.LSUfullVw + atomic = False # atomic is for GSU > 1 + beta = True + vectorDataTypes = VectorDataTypes() + ss = StoreState(writer, kernel, storevw, False, beta, atomic, self.LSUelements, vectorDataTypes, dim=0) + self.LSUelemCoord0, self.LSUelemCoord1 = ss.getStoreElementsInfoForBatch(kernel, self.LSUelements) + + # search for valid lsu wave offset + maxtt1 = 0 + maxtt0 = 0 + maxvc1 = 0 + maxvc0 = 0 + validOffset = -1 + validOffset0 = -1 + validOffset1 = -1 + self.LSUelementsArchIdx = [[] for i in range(4)] + self.LSUelementsPerLSUWave = [] + self.LSUelemCoord0PerLSUWave = [] + self.LSUelemCoord1PerLSUWave = [] + + # Check valid LSU/VW combination + if len(self.LSUelements) >= kernel["LocalSplitU"]: + if kernel["LocalSplitU"] == 4: + idxGrp = 1 + for idxGrp in range(1, 
len(self.LSUelements)//4 + 1): + for i in range(idxGrp): + i0 = i + i1 = i + 1 * idxGrp + i2 = i + 2 * idxGrp + i3 = i + 3 * idxGrp + offset0 = self.LSUelemCoord0[i0] + self.LSUelemCoord1[i0] * kernel["MacroTile0"] + offset1 = self.LSUelemCoord0[i1] + self.LSUelemCoord1[i1] * kernel["MacroTile0"] + offset2 = self.LSUelemCoord0[i2] + self.LSUelemCoord1[i2] * kernel["MacroTile0"] + offset3 = self.LSUelemCoord0[i3] + self.LSUelemCoord1[i3] * kernel["MacroTile0"] + if (offset3 - offset2 == offset2 - offset1) and (offset2 - offset1 == offset1 - offset0): + validOffset0 = self.LSUelemCoord0[i1] - self.LSUelemCoord0[i0] + validOffset1 = self.LSUelemCoord1[i1] - self.LSUelemCoord1[i0] + if self.LSUelemCoord0[i2] - self.LSUelemCoord0[i1] == validOffset0 \ + and self.LSUelemCoord0[i3] - self.LSUelemCoord0[i2] == validOffset0 \ + and self.LSUelemCoord1[i2] - self.LSUelemCoord1[i1] == validOffset1 \ + and self.LSUelemCoord1[i3] - self.LSUelemCoord1[i2] == validOffset1: + validOffset = offset1 - offset0 + break + if validOffset != -1: + break + for idx in range(0, len(self.LSUelements), 4*idxGrp): + for idx2 in range(idxGrp): + self.LSUelementsArchIdx[0].append(self.LSUfullVw*(idx + idx2)) + self.LSUelementsArchIdx[1].append(self.LSUfullVw*(idx + 1*idxGrp + idx2)) + self.LSUelementsArchIdx[2].append(self.LSUfullVw*(idx + 2*idxGrp + idx2)) + self.LSUelementsArchIdx[3].append(self.LSUfullVw*(idx + 3*idxGrp + idx2)) + self.LSUelementsPerLSUWave.append(self.LSUelements[idx + idx2]) + self.LSUelemCoord0PerLSUWave.append(self.LSUelemCoord0[idx + idx2]) + self.LSUelemCoord1PerLSUWave.append(self.LSUelemCoord1[idx + idx2]) + elif kernel["LocalSplitU"] == 2: + i = 0 + offset0 = self.LSUelemCoord0[i] + self.LSUelemCoord1[i] * kernel["MacroTile0"] + offset1 = self.LSUelemCoord0[i + 1] + self.LSUelemCoord1[i + 1] * kernel["MacroTile0"] + validOffset = offset1 - offset0 + validOffset0 = self.LSUelemCoord0[i + 1] - self.LSUelemCoord0[i] + validOffset1 = self.LSUelemCoord1[i + 1] - 
self.LSUelemCoord1[i] + for idx in range(0, len(self.LSUelements), 2): + self.LSUelementsArchIdx[0].append(self.LSUfullVw*idx) + self.LSUelementsArchIdx[1].append(self.LSUfullVw*(idx+1)) + self.LSUelementsPerLSUWave.append(self.LSUelements[idx]) + self.LSUelemCoord0PerLSUWave.append(self.LSUelemCoord0[idx]) + self.LSUelemCoord1PerLSUWave.append(self.LSUelemCoord1[idx]) + else: + assert 0, "No valid LSU offset found." + + if validOffset == -1: + assert 0, "No valid LSU offset found." + self.LSUValidOffset0 = validOffset0 + self.LSUValidOffset1 = validOffset1 + return validOffset + + ############################################################################## + # LocalSplitU: Local Write, Read, and Reduction + ############################################################################## + def writeReadReduction(self, writer, kernel): + module = Module("localSplitULocalWriteAndRead") + module.addComment2("LocalSplitU Reduction") + module.add(writer._syncThreads(kernel)) + module.add(Label("localSplitULocalWriteAndRead", "")) + + acc2arch, arch2acc = accToArchMapper(kernel) + + # prepare the data that is to be Reduction in this wave + # the output LSUelementsArchIdx has all arch-indices. + validOffset = self.splitOutputData(writer, kernel) + + numAccIdx = len(self.LSUelementsArchIdx[0]) + numSetAccIdx = ceilDivide(numAccIdx, kernel["LocalSplitUReuseLDS"]) + + # computeStoreVgprs + if kernel["EnableMatrixInstruction"]: + module.add(writer.computeStoreVgprs(kernel)) + else: + # new method. 
output self.vgprs.coord0InMT/coord1InMT + module.add(writer.computeStoreVgprs(kernel, \ + divisor = kernel["MacroTile0"] // kernel["GlobalWriteVectorWidth"], \ + tid0Scale = kernel["GlobalWriteVectorWidth"], \ + tid1Scale = 1)) + + # Checkout local read resource + bpr = 4 #bytes per register + bytesPerElem = kernel["ProblemType"]["ComputeDataType"].numBytes() + bytesPerVector = self.LSUfullVw * bytesPerElem + numWaves = kernel["MIWaveGroup"][0] * kernel["MIWaveGroup"][1] + regsPerStep = int((bytesPerVector+3)//4) + elementStep = bytesPerVector // bytesPerElem + numTotalAccVgprLdsReduction = len(self.LSUelements)*regsPerStep*(self.LSUfullVw//elementStep) + assert (numTotalAccVgprLdsReduction%kernel["LocalSplitU"]) == 0 + numTotalAccVgprLdsReduction = numTotalAccVgprLdsReduction // kernel["LocalSplitU"] + self.accVgprLdsReduction = writer.vgprPool.checkOutAligned(numTotalAccVgprLdsReduction, 4, "LsuReduction") + module.add(RegSet("v", "vgprLsuReduction", self.accVgprLdsReduction)) + writer.states.c.startVgprValu = self.accVgprLdsReduction + + # Local Read VGPR idx + localReadVgprIdx = 0 + + lsu_id = writer.vgprPool.checkOut(1,"lsu_id") + wave_id = writer.vgprPool.checkOut(1,"wave_id") + tmpVgpr = writer.vgprPool.checkOutAligned(2, 2, "tmpVgpr") + tmpVgprRes = RegisterPoolResource(tmpVgpr, 2) + + module.add(vectorStaticDivide(wave_id, "Serial", \ + kernel["WavefrontSize"], tmpVgprRes)) + module.add(vectorStaticDivide(lsu_id, wave_id, numWaves, tmpVgprRes, \ + comment="Get LSU wave ID")) + module.add(VAndB32(vgpr(wave_id), hex(numWaves - 1), vgpr(wave_id), \ + comment="Get wave ID")) + + for reUseIdx in range(kernel["LocalSplitUReuseLDS"]): + module.addComment1("LocalSplitU: local write %d/%d"%(reUseIdx+1,kernel["LocalSplitUReuseLDS"])) + module.add(Label("localSplitULocalWriteAndRead_%d"%(reUseIdx+1), "")) + + startLSUaccIdxSet = reUseIdx * numSetAccIdx + endLSUaccIdxSet = min(numAccIdx, startLSUaccIdxSet + numSetAccIdx) + + #scan the needed accVGPRIdx + 
neededAccVGPRIdx = [[] for i in range(kernel["LocalSplitU"])] + numAccVgpr = 0 + for lsu in range(kernel["LocalSplitU"]): + for i in range(startLSUaccIdxSet, endLSUaccIdxSet): + for j in range(self.LSUfullVw): + accIdx = arch2acc[self.LSUelementsArchIdx[lsu][i] + j] + neededAccVGPRIdx[lsu].append(accIdx) + numAccVgpr += 1 + + # lsuProcessOffset is used when local read + numVgprPerLSU = len(neededAccVGPRIdx[0]) + lsuProcessOffset = numVgprPerLSU * kernel["WavefrontSize"] * 4 + + assert numAccVgpr > 0,"startLSUaccIdxSet=%u,endLSUaccIdxSet=%u,numAccIdx=%u"%(startLSUaccIdxSet,endLSUaccIdxSet,numAccIdx) + accVgprRes = writer.vgprPool.checkOutAligned(numAccVgpr, 4, "accLSUVgprRes") + + destIdx = 0 + for lsu in range(kernel["LocalSplitU"]): + for i in range(numVgprPerLSU): + srcIdx = neededAccVGPRIdx[lsu][i] + if not kernel["MIArchVgpr"]: + accStr = accvgpr(srcIdx) + module.add(VAccvgprReadB32(dst=vgpr(accVgprRes+destIdx), + src=accStr, + comment="copy acc[%u] to vreg[%u], LSU%u will process" % (srcIdx,destIdx,lsu))) + else: + module.add(VMovB32(dst=vgpr(accVgprRes+destIdx), + src=vgpr("ValuC+%u"%srcIdx), + comment="copy MI out reg to vreg[%u], LSU%u will process" % (destIdx,lsu))) + destIdx += 1 + + dataPerWave = numAccVgpr * kernel["WavefrontSize"] * 4 + ldsStride = dataPerWave * numWaves + + addr = writer.vgprPool.checkOut(1,"addr") + + # Prepare Write/Read instruction info + if bytesPerVector % 16 == 0: + DSStoreBX = DSStoreB128 + DSLoadBX = DSLoadB128 + numInstPerVW = bytesPerVector // 16 + regsPerStore = 4 + elif bytesPerVector % 8 == 0: + DSStoreBX = DSStoreB64 + DSLoadBX = DSLoadB64 + numInstPerVW = bytesPerVector // 8 + regsPerStore = 2 + else: + DSStoreBX = DSStoreB32 + DSLoadBX = DSLoadB32 + numInstPerVW = bytesPerVector // 4 + regsPerStore = 1 + + with writer.allocTmpSgpr(1) as tmpSgprInfo: + tmpSgpr = tmpSgprInfo.idx + module.add(SMovB32(dst=sgpr(tmpSgpr), src=hex(dataPerWave), \ + comment="dataPerWave (%d)"%dataPerWave)) + module.add(VAndB32(vgpr(addr), 
hex(kernel["WavefrontSize"]-1), vgpr("Serial"), \ + comment="initial addr")) + module.add(VMulLOU32(dst=vgpr(tmpVgpr), src0=sgpr(tmpSgpr), src1=vgpr(wave_id), \ + comment="tmp = waveId * dataPerWave")) + module.add(VLShiftLeftAddU32(dst=vgpr(addr), shiftHex=log2(regsPerStore * bpr), src0=vgpr(addr), src1=vgpr(tmpVgpr), \ + comment="addr = initial addr + tmp")) + module.add(SMovB32(dst=sgpr(tmpSgpr), src=hex(ldsStride), \ + comment="ldsStride = waveNum * dataPerWave (%d)"%ldsStride)) + module.add(VMulLOU32(dst=vgpr(tmpVgpr), src0=sgpr(tmpSgpr), src1=vgpr(lsu_id), \ + comment="tmp = (waveNum * dataPerWave) * lsu_id")) + module.add(VAddU32(vgpr(addr), vgpr(tmpVgpr), vgpr(addr), \ + comment="addr += tmp")) + + module.add(SWaitCnt(lgkmcnt=0, vscnt=0, comment="wait for all writes")) + module.add(writer._syncThreads(kernel, "pre-lsu local write")) + + module.add(Label("localSplitULocalWrite_%d"%(reUseIdx+1), "")) + + # Do Local Write + for i in range(0, numAccVgpr // self.LSUfullVw): + for v in range(numInstPerVW): + regIdx = (i * numInstPerVW + v) * regsPerStore + module.add(DSStoreBX(dstAddr=vgpr(addr), src=vgpr(accVgprRes+regIdx, regsPerStore), \ + ds=DSModifiers(offset=(regIdx * (bpr * kernel["WavefrontSize"]))), \ + comment="arch[%d]"%(i * numInstPerVW + v))) + + # Release local write resource + writer.vgprPool.checkIn(accVgprRes) + + module.addComment1("LocalSplitU: local read %d/%d"%(reUseIdx+1,kernel["LocalSplitUReuseLDS"])) + + # Calculate offset for wave id and lsu id + with writer.allocTmpSgpr(1) as tmpSgprInfo: + tmpSgpr = tmpSgprInfo.idx + module.add(VAndB32(vgpr(addr), hex(kernel["WavefrontSize"]-1), vgpr("Serial"), \ + comment="initial addr")) + module.add(SMovB32(dst=sgpr(tmpSgpr), src=hex(dataPerWave), \ + comment="wave offset (%d)"%dataPerWave)) + module.add(VMulLOU32(dst=vgpr(tmpVgpr), src0=sgpr(tmpSgpr), src1=vgpr(wave_id), \ + comment="wave offset = wave_id * wave offset")) + module.add(VLShiftLeftAddU32(dst=vgpr(addr), shiftHex=log2(regsPerStore * 
bpr), src0=vgpr(addr), src1=vgpr(tmpVgpr), \ + comment="addr = initial addr + wave offset")) + module.add(SMovB32(dst=sgpr(tmpSgpr), \ + src=hex(lsuProcessOffset), comment="LSU Process Offset %d"%(lsuProcessOffset))) + module.add(VMulLOU32(dst=vgpr(tmpVgpr), src0=sgpr(tmpSgpr), src1=vgpr(lsu_id), \ + comment="lsu offset = lsu_id * LSU Process Offset")) + module.add(VAddU32(dst=vgpr(addr), src0=vgpr(addr), src1=vgpr(tmpVgpr), \ + comment="addr += lsu offset")) + + module.add(SWaitCnt(lgkmcnt=0, vscnt=0, comment="wait for all writes")) + module.add(writer._syncThreads(kernel, "post-lsu local write")) + module.add(Label("localSplitULocalRead_%d"%(reUseIdx+1), "")) + + moduleReduction = Module("LocalSplitU_Reduction") + inLoopTmpVgpr = writer.vgprPool.checkOutAligned(numVgprPerLSU*(kernel["LocalSplitU"]-1), 4, "TempLsuReduction") + + # Do Local Read + for i in range(0, numVgprPerLSU // self.LSUfullVw): + for v in range(numInstPerVW): + for r in range(0, kernel["LocalSplitU"]): + regIdx = (i * numInstPerVW + v) * regsPerStore + offset = r * ldsStride + regIdx * (bpr * kernel["WavefrontSize"]) + if r == 0: + vgprStr = "LsuReduction+%u"%(localReadVgprIdx) + else: + vgprStr = inLoopTmpVgpr + (numVgprPerLSU * (r - 1) + regIdx) + module.add(DSLoadBX(dst=vgpr(vgprStr, regsPerStore), src=vgpr(addr), \ + ds=DSModifiers(offset=(offset)), \ + comment="r=%u i=%u, from acc[%d]"%(r, (i * numInstPerVW + v), neededAccVGPRIdx[0][(i * numInstPerVW + v)]))) + # Generate Reduction code at the same time. 
+ if r == 0: + # Insert waitcnt code here + numTotalInst = numVgprPerLSU // self.LSUfullVw * numInstPerVW * kernel["LocalSplitU"] + numPassedInst = (i * numInstPerVW + (v + 1)) * kernel["LocalSplitU"] + numLRWaitCnt = numTotalInst - numPassedInst + moduleReduction.add(SWaitCnt(lgkmcnt=numLRWaitCnt, comment="wait count is (%u-%u)"%(numTotalInst, numPassedInst))) + if writer.states.archCaps["SeparateVscnt"]: + moduleReduction.add(SWaitCnt(vscnt=numLRWaitCnt)) + if r > 0: + for regToAdd in range(regsPerStore): + if kernel["ProblemType"]["ComputeDataType"].isSingle(): + moduleReduction.add(VAddF32(dst=vgpr("LsuReduction+%u"%(localReadVgprIdx+regToAdd)), src0=vgpr(vgprStr+regToAdd), \ + src1=vgpr("LsuReduction+%u"%(localReadVgprIdx+regToAdd)), comment="")) + elif kernel["ProblemType"]["ComputeDataType"].isInt32(): + moduleReduction.add(VAddI32(dst=vgpr("LsuReduction+%u"%(localReadVgprIdx+regToAdd)), src0=vgpr(vgprStr+regToAdd), \ + src1=vgpr("LsuReduction+%u"%(localReadVgprIdx+regToAdd)), comment="")) + else: + # TODO: hpa_half, int8 + assert(0) # unsupported data type, need to modify here and LSU write/read code + localReadVgprIdx += regsPerStore + + # Release write/read resource + writer.vgprPool.checkIn(addr) + + # Do Reduction + module.add(moduleReduction) + + # Release reduction resource + writer.vgprPool.checkIn(inLoopTmpVgpr) + + # Release all resource + writer.vgprPool.checkIn(lsu_id) + writer.vgprPool.checkIn(wave_id) + writer.vgprPool.checkIn(tmpVgpr) + # reset vgprValuC register + module.add(RegSet("v", "vgprValuC", self.accVgprLdsReduction)) + + return module + + ############################################################################## + # LocalSplitU: Global Write Indices + ############################################################################## + def globalWriteIndices(self, writer, kernel): + module = Module("localSplitUGlobalWriteIndices") + + # Add LSU Offset back + packedC1 = kernel["PackedC1IndicesX"] + strideC1 = "StrideC%s" % 
(writer.states.indexChars[packedC1[0]]) + strideD1 = "StrideD%s" % (writer.states.indexChars[packedC1[0]]) + wave_id = writer.vgprPool.checkOut(1, "tmpWaveID") + tmpVgpr1 = writer.vgprPool.checkOutAligned(2,2,"tmpVgpr1") + tmpVgpr1Res = RegisterPoolResource(tmpVgpr1, 2) + module.add(vectorStaticDivide(wave_id, "Serial", kernel["WavefrontSize"], tmpVgpr1Res)) + numWaves = kernel["MIWaveGroup"][0] * kernel["MIWaveGroup"][1] + module.add(vectorStaticDivide(wave_id, wave_id, numWaves, tmpVgpr1Res)) + + with writer.allocTmpSgpr(1) as tmpSgprInfo: + tmpSgpr = tmpSgprInfo.idx + if self.LSUValidOffset0 > 0: + module.add(SMovB32(dst=sgpr(tmpSgpr), \ + src=hex(self.LSUValidOffset0), comment="a valid offset")) + module.add(VMulLOU32(dst=vgpr(tmpVgpr1), src0=vgpr(wave_id), src1=sgpr(tmpSgpr), comment="wave LSU offset")) + module.add(VAddU32(dst=vgpr(writer.vgprs.coord0), src0=vgpr(tmpVgpr1), src1=vgpr(writer.vgprs.coord0), comment="coord0 += LSU offset0")) + else: + module.addComment0("valid offset coord0 is zero.") + + if self.LSUValidOffset1 > 0: + module.add(SMovB32(dst=sgpr(tmpSgpr), \ + src=hex(self.LSUValidOffset1), comment="a valid offset")) + module.add(VMulLOU32(dst=vgpr(tmpVgpr1), src0=vgpr(wave_id), src1=sgpr(tmpSgpr), comment="wave LSU offset")) + module.add(VAddU32(dst=vgpr(writer.vgprs.coord1), src0=vgpr(tmpVgpr1), src1=vgpr(writer.vgprs.coord1), comment="coord1 += LSU offset1")) + module.add(VAddU32(dst=vgpr(writer.vgprs.coord1InMT), src0=vgpr(tmpVgpr1), src1=vgpr(writer.vgprs.coord1InMT), comment="coord1InMT += LSU offset1")) + + # this code is from CouputeStoreVgprs. 
coord 1 : offset part + packedC1 = kernel["PackedC1IndicesX"] + strideC1 = "StrideC%s" % (writer.states.indexChars[packedC1[0]]) + strideD1 = "StrideD%s" % (writer.states.indexChars[packedC1[0]]) + module.add(VMulLOU32(dst=vgpr(writer.vgprs.cinRowPtr), src0=vgpr(writer.vgprs.coord1InMT), src1=sgpr(strideC1), comment=" offset 1")) + module.add(VMulLOU32(dst=vgpr(writer.vgprs.coutRowPtrD), src0=vgpr(writer.vgprs.coord1InMT), src1=sgpr(strideD1), comment=" offset 1")) + if kernel["ProblemType"]["UseE"] and (kernel["GlobalSplitU"] == 1): + module.add(VMovB32(dst=vgpr(writer.vgprs.coutRowPtrE), src=vgpr(writer.vgprs.coord1InMT), comment=" save offset 1 for E")) + if writer.vgprs.coutRowPtrBias != -1: + index = packedC1[0] - 1 + strideW1 = "Size%s" % "I" if index == 0 else ("J" if index == 1 else (writer.states.indexChars[index])) + module.add(VMulLOU32(dst=vgpr(writer.vgprs.coutRowPtrBias), src0=vgpr(writer.vgprs.coord1InMT), src1=sgpr(strideW1), comment=" offset 1")) + else: + module.addComment0("valid offset coord1 is zero.") + + writer.vgprPool.checkIn(tmpVgpr1) + writer.vgprPool.checkIn(wave_id) + writer.vgprPool.checkIn(writer.vgprs.coord0InMT) + writer.vgprPool.checkIn(writer.vgprs.coord1InMT) + + if kernel["BufferStore"]: + #print "----AddressC-LocalSplitU" + #print self.vgprPool.state() + writer.vgprs.addrE = -1 + writer.vgprs.addrD = -1 + writer.vgprs.addrC = -1 + writer.vgprs.addrBias = -1 + writer.vgprs.addrScaleAVec = -1 + writer.vgprs.addrScaleBVec = -1 + writer.vgprs.addrScaleAlphaVec = -1 + else: + writer.vgprs.addrD = writer.vgprPool.checkOut(2) + module.add(VMovB32( + dst=vgpr(writer.vgprs.addrD+0), \ + src=sgpr("AddressD+0"), \ + comment="sgpr -> vgpr")) + module.add(VMovB32( + dst=vgpr(writer.vgprs.addrD+1), \ + src=sgpr("AddressD+1"), \ + comment="sgpr -> vgpr")) + writer.vgprs.addrC = writer.vgprPool.checkOut(2) + module.add(VMovB32( + dst=vgpr(writer.vgprs.addrC+0), \ + src=sgpr("AddressC+0"), \ + comment="sgpr -> vgpr")) + module.add(VMovB32( + 
dst=vgpr(writer.vgprs.addrC+1), \ + src=sgpr("AddressC+1"), \ + comment="sgpr -> vgpr")) + + if kernel["GlobalSplitU"] > 0: + gsuLabel = Label(label=writer.labels.getNameInc("GSU"), comment="") + with writer.allocTmpSgpr(1) as tmpSgprGSU: + module.add(SAndB32(dst=sgpr(tmpSgprGSU.idx), src0=sgpr("GSU"), src1=hex(0x3FFF), comment="Restore GSU")) + module.add(SCmpEQU32(src0=sgpr(tmpSgprGSU.idx), src1=1, comment="GSU == 1 ?")) + module.add(SCBranchSCC0(labelName=gsuLabel.getLabelName(), comment="branch if GSU != 1")) + if kernel["ProblemType"]["UseE"]: + writer.vgprs.addrE = writer.vgprPool.checkOut(2, 'addrE') + module.add(VMovB32( \ + dst=vgpr(writer.vgprs.addrE+0), \ + src=sgpr("AddressE+0"), \ + comment="sgpr -> vgpr")) + module.add(VMovB32( \ + dst=vgpr(writer.vgprs.addrE+1), \ + src=sgpr("AddressE+1"), \ + comment="sgpr -> vgpr")) + if writer.states.useBias == DataDirection.READ: + writer.vgprs.addrBias = writer.vgprPool.checkOut(2, 'addrBias') + module.add(VMovB32( \ + dst=vgpr(writer.vgprs.addrBias+0), \ + src=sgpr("AddressBias+0"), \ + comment="sgpr -> vgpr")) + module.add(VMovB32( \ + dst=vgpr(writer.vgprs.addrBias+1), \ + src=sgpr("AddressBias+1"), \ + comment="sgpr -> vgpr")) + if (kernel["ProblemType"]["UseScaleAB"] == "Vector"): + writer.vgprs.addrScaleAVec = writer.vgprPool.checkOut(2, 'addrScaleAVec') + module.add(VMovB32( \ + dst=vgpr(writer.vgprs.addrScaleAVec+0), \ + src=sgpr("AddressScaleA+0"), \ + comment="sgpr -> vgpr")) + module.add(VMovB32( \ + dst=vgpr(writer.vgprs.addrScaleAVec+1), \ + src=sgpr("AddressScaleA+1"), \ + comment="sgpr -> vgpr")) + writer.vgprs.addrScaleBVec = writer.vgprPool.checkOut(2, 'addrScaleBVec') + module.add(VMovB32( \ + dst=vgpr(writer.vgprs.addrScaleBVec+0), \ + src=sgpr("AddressScaleB+0"), \ + comment="sgpr -> vgpr")) + module.add(VMovB32( \ + dst=vgpr(writer.vgprs.addrScaleBVec+1), \ + src=sgpr("AddressScaleB+1"), \ + comment="sgpr -> vgpr")) + if kernel["ProblemType"]["UseScaleAlphaVec"]: + 
writer.vgprs.addrScaleAlphaVec = writer.vgprPool.checkOut(2, 'addrScaleAlphaVec') + module.add(VMovB32( \ + dst=vgpr(writer.vgprs.addrScaleAlphaVec+0), \ + src=sgpr("AddressScaleAlphaVec+0"), \ + comment="sgpr -> vgpr")) + module.add(VMovB32( \ + dst=vgpr(writer.vgprs.addrScaleAlphaVec+1), \ + src=sgpr("AddressScaleAlphaVec+1"), \ + comment="sgpr -> vgpr")) + if kernel["GlobalSplitU"] > 0: + module.add(gsuLabel) + + return module + + + ############################################################################## + # LocalSplitU: Global Write + ############################################################################## + def globalWrite(self, writer, kernel, tPA, tPB): + if not writer.do["PostLoop"]: return "" + + elements_0 = [[] for y in range(2)] + elements_1 = [[] for y in range(2)] + elements_f0 = [[] for y in range(2)] + elements_f1 = [[] for y in range(2)] + (fullVw, elements_0[False], fullVw_1, elements_1[False]) = writer.notLocalFullTileElements(kernel, False) + (edgeVw, elements_0[True], edgeVw_1, elements_1[True] ) = writer.notLocalFullTileElements(kernel, True) + edgeScaled_0 = len(elements_0[True]) // len(elements_1[False]) + edgeScaled_1 = len(elements_1[True]) // len(elements_1[False]) + noEgScaled_0 = len(elements_0[False]) // len(elements_1[False]) + + for i in range(0, len(elements_1[False])): + element = elements_1[False][i] + if element in self.LSUelementsPerLSUWave: + elements_f1[False].append(element) + for j in range(0, edgeScaled_0): + # in general, edge will affect vc0 dimension. + element = elements_0[True][i*edgeScaled_0+j] + elements_f0[True].append(element) + for j in range(0, edgeScaled_1): + # in general, edge will affect vc0 dimension. + element = elements_1[True][i*edgeScaled_1+j] + elements_f1[True].append(element) + for j in range(0, noEgScaled_0): + # in general, edge will affect vc0 dimension. 
+ element = elements_0[False][i*noEgScaled_0+j] + elements_f0[False].append(element) + + vectorWidths = [fullVw, edgeVw] + vectorWidths_1 = [fullVw_1, edgeVw_1] + + noGSUBranch = (kernel["GlobalSplitU"] == 0) + module = Module("localSplitUGlobalWrite") + module.add(writer.globalWriteElements(kernel, tPA, tPB, vectorWidths, vectorWidths_1, elements_f0, elements_f1, noGSUBranch=noGSUBranch)) + writer.cleanupGlobalWrite(kernel) + writer.vgprPool.checkIn(self.accVgprLdsReduction) + return module diff --git a/tensilelite/Tensile/Components/__init__.py b/tensilelite/Tensile/Components/__init__.py index fe3f952c66..b7a2bd7fd8 100644 --- a/tensilelite/Tensile/Components/__init__.py +++ b/tensilelite/Tensile/Components/__init__.py @@ -48,5 +48,6 @@ def use(): pass "SumUnroll", "GSU", "StreamK", - "PersistentLoop" + "PersistentLoop", + "LSU", ] diff --git a/tensilelite/Tensile/KernelWriter.py b/tensilelite/Tensile/KernelWriter.py index b3e3716e4b..572503ca53 100644 --- a/tensilelite/Tensile/KernelWriter.py +++ b/tensilelite/Tensile/KernelWriter.py @@ -2824,29 +2824,17 @@ def kernelBody( self, kernel, tensorParametersA, tensorParametersB ): #################################### #if kernel["NumThreads"]%kernel["MacroTile0"] == 0: if kernel["LocalSplitU"] > 1: - module.addComment2("LocalSplitU Reduction") - module.add(self._syncThreads(kernel)) - - # LocalSplitU: local write - module.addComment1("LocalSplitU: local write") - module.add(self.localSplitULocalWrite(kernel)) - - # LocalSplitU: local read - module.addComment1("LocalSplitU: local read") - module.add(self.localSplitULocalRead(kernel)) + module.addComment1("LocalSplitU: local write and read") + lsuComponent = Component.LSU.find(self) + module.add(lsuComponent.writeReadReduction(self, kernel)) # LocalSplitU: global write indices - # Hide instructions in local read latency module.addComment1("LocalSplitU: global write indices") - module.add(self.localSplitUGlobalWriteIndices(kernel)) - - # LocalSplitU: Reduction - 
module.addComment1("LocalSplitU: reduction") - module.add(self.localSplitUReduction(kernel)) + module.add(lsuComponent.globalWriteIndices(self, kernel)) # LocalSplitU: global write module.addComment1("LocalSplitU: global write") - module.add(self.localSplitUGlobalWrite(kernel, tensorParametersA, tensorParametersB)) + module.add(lsuComponent.globalWrite(self, kernel, tensorParametersA, tensorParametersB)) else: #################################### @@ -4839,27 +4827,6 @@ def localReadDo(self, kernel, bufferIdx, innerUnrollIndex, epsi, tP): def shiftVectorComponents(self, kernel, tP): return "" - ############################################################################## - # LocalSplitU: Local Write - ############################################################################## - @abc.abstractmethod - def localSplitULocalWrite(self, kernel): - return "" - - ############################################################################## - # LocalSplitU: Local Read - ############################################################################## - @abc.abstractmethod - def localSplitULocalRead(self, kernel): - return "" - - ############################################################################## - # LocalSplitU: Reduction - ############################################################################## - @abc.abstractmethod - def localSplitUReduction(self, kernel): - return "" - ############################################################################## # globalWriteWorkGroupInit: # Perform work-group granularity init diff --git a/tensilelite/Tensile/KernelWriterAssembly.py b/tensilelite/Tensile/KernelWriterAssembly.py index 3121fd06a1..77aa0b5c4f 100644 --- a/tensilelite/Tensile/KernelWriterAssembly.py +++ b/tensilelite/Tensile/KernelWriterAssembly.py @@ -7924,281 +7924,6 @@ def shiftVectorComponents(self, kernel, tP): if component: return component(self, kernel, tP) - ############################################################################## - # 
LocalSplitU: Local Write - ############################################################################## - def localSplitULocalWrite(self, kernel): - module = Module("localSplitULocalWrite") - # wait for summation to be done with lds before writing reduction values - module.add(self._syncThreads(kernel, "pre-lsu local write")) - module.add(Label("localSplitULocalWrite", "")) - - tmpVgpr = self.vgprPool.checkOutAligned(2, 2, "tmpVgpr") - tmpVgprRes = RegisterPoolResource(tmpVgpr, 2) - lsu_id = self.vgprPool.checkOut(1,"lsu_id") - addr = self.vgprPool.checkOut(1,"addr") - self.lsuCoordOffset = self.vgprPool.checkOut(1,"lsuCoordOffset") - lr1 = self.vgprPool.checkOut(1,"lr1") - acc2arch, _ = accToArchMapper(kernel) - NumAccVgprRes = len(acc2arch)*kernel["MIRegPerOut"] - accVgprRes = self.vgprPool.checkOutAligned(NumAccVgprRes, 4, "accLSUVgprRes") - for i in range(len(acc2arch)): - for r in range(kernel["MIRegPerOut"]): - destIdx = (acc2arch[i]) * kernel["MIRegPerOut"] + r - srcIdx = ((i * kernel["MIRegPerOut"] + r)) - if not kernel["MIArchVgpr"]: - accStr = accvgpr(srcIdx) - module.add(VAccvgprReadB32(dst=vgpr(accVgprRes+destIdx), - src=accStr, - comment="copy acc to vreg[%u]" % destIdx)) - else: - module.add(VMovB32(dst=vgpr(accVgprRes+destIdx), - src=vgpr("ValuC+%u"%srcIdx), - comment="copy MI out reg to vreg[%u]" % destIdx)) - - ldsStride = kernel["MacroTile0"]*kernel["MacroTile1"] - numWaves = kernel["MIWaveGroup"][0] * kernel["MIWaveGroup"][1] - waveOffset = ldsStride // numWaves - - # new method. output self.vgprs.coord0InMT/coord1InMT - if kernel["EnableMatrixInstruction"]: - module.add(self.computeStoreVgprs(kernel)) - else: - # new method. 
output self.vgprs.coord0InMT/coord1InMT - # lr0 = serial % SG0 - module.add(self.computeStoreVgprs(kernel, \ - divisor = kernel["MacroTile0"] // kernel["GlobalWriteVectorWidth"], \ - tid0Scale = kernel["GlobalWriteVectorWidth"], \ - tid1Scale = 1)) - - self.LSUelemCoord0 = [] - self.LSUelemCoord1 = [] - self.LSUelements = [] - self.LSUfullVw = [] - (vwdummy, eledummy, self.LSUfullVw, self.LSUelements) = self.notLocalFullTileElements(kernel, False) - storevw = self.LSUfullVw - atomic = False # atomic is for GSU > 1 - beta = True - vectorDataTypes = VectorDataTypes() - ss = StoreState(self, kernel, storevw, False, beta, atomic, self.LSUelements, vectorDataTypes, dim=0) - self.LSUelemCoord0, self.LSUelemCoord1 = ss.getStoreElementsInfoForBatch(kernel, self.LSUelements) - - with self.allocTmpSgpr(1) as tmpSgprInfo: - tmpSgpr = tmpSgprInfo.idx - - # lr1 = serial / kernel["WavefrontSize"] - module.add(vectorStaticDivide(lr1, "Serial", \ - kernel["WavefrontSize"], tmpVgprRes)) - - module.add(vectorStaticDivide(lsu_id, lr1, \ - numWaves, tmpVgprRes, comment="Get LSU wave ID")) - - module.add(SMovB32(dst=sgpr(tmpSgpr), \ - src=hex(ldsStride), comment="MT0*MT1")) - module.add(VMulLOU32(dst=vgpr(addr), src0=sgpr(tmpSgpr), src1=vgpr(lsu_id), \ - comment="lsu_id *= MT0*MT1")) - - module.add(SMovB32(dst=sgpr(tmpSgpr), \ - src=hex(kernel["MacroTile0"]), comment="MT0")) - module.add(VMulLOU32(dst=vgpr(self.lsuCoordOffset), src0=sgpr(tmpSgpr), src1=vgpr(self.vgprs.coord1InMT), \ - comment="MT0*coord1InMT")) - module.add(VAddU32(dst=vgpr(self.lsuCoordOffset), src0=vgpr(self.vgprs.coord0InMT), src1=vgpr(self.lsuCoordOffset), comment="coord0InMT")) - - #thread offset - module.add(VAddLShiftLeftU32(dst=vgpr(addr), src0=vgpr(self.lsuCoordOffset), src1=vgpr(addr), shiftHex=hex(log2(self.states.bpeCinternal)), comment="local write LDS address")) - - self.vgprPool.checkIn(lr1) - self.vgprPool.checkIn(lsu_id) - self.vgprPool.checkIn(tmpVgpr) - - bytesPerElem = 
kernel["ProblemType"]["ComputeDataType"].numBytes() - regsPerElem = kernel["ProblemType"]["ComputeDataType"].numRegisters() - bytesPerVector = storevw * bytesPerElem - for i in range(0, len(self.LSUelements)): - (tt1, tt0, vc1, vc0) = self.LSUelements[i] - writeOffset = self.LSUelemCoord0[i] + self.LSUelemCoord1[i] * kernel["MacroTile0"] - regIdx = int(i * regsPerElem * storevw) - regIdxStep = 0 - resedualBPV = bytesPerVector - while resedualBPV > 0: - bps = min(resedualBPV, 16) - regsPerStep = int((bps+3)//4) - DSStoreBX = {128: DSStoreB128, - 64: DSStoreB64, - 32: DSStoreB32, - 16: DSStoreB16, - 8: DSStoreB8}[bps*8] - module.add(DSStoreBX(dstAddr=vgpr(addr), src=vgpr(accVgprRes+regIdx+regIdxStep, regsPerStep), \ - ds=DSModifiers(offset=(writeOffset*self.states.bpeCinternal+(regIdxStep*4))), - comment="tt1=%u tt0=%u vc1=%u vc0=%u"%(tt1, tt0, vc1, vc0))) - regIdxStep += regsPerStep - resedualBPV -= bps - - self.vgprPool.checkIn(accVgprRes) - self.vgprPool.checkIn(addr) - return module - - ############################################################################## - # LocalSplitU: Local Read - ############################################################################## - def localSplitULocalRead(self, kernel): - # search for valid lsu wave offset - maxtt1 = 0 - maxtt0 = 0 - maxvc1 = 0 - maxvc0 = 0 - validOffset = -1 - validOffset0 = -1 - validOffset1 = -1 - self.LSUelementsPerLSUWave = [] - self.LSUelemCoord0PerLSUWave = [] - self.LSUelemCoord1PerLSUWave = [] - # Check valid LSU/VW combination - if len(self.LSUelements) >= kernel["LocalSplitU"]: - if kernel["LocalSplitU"] == 4: - idxGrp = 1 - for idxGrp in range(1, len(self.LSUelements)//4 + 1): - for i in range(idxGrp): - i0 = i - i1 = i + 1 * idxGrp - i2 = i + 2 * idxGrp - i3 = i + 3 * idxGrp - offset0 = self.LSUelemCoord0[i0] + self.LSUelemCoord1[i0] * kernel["MacroTile0"] - offset1 = self.LSUelemCoord0[i1] + self.LSUelemCoord1[i1] * kernel["MacroTile0"] - offset2 = self.LSUelemCoord0[i2] + 
self.LSUelemCoord1[i2] * kernel["MacroTile0"] - offset3 = self.LSUelemCoord0[i3] + self.LSUelemCoord1[i3] * kernel["MacroTile0"] - if (offset3 - offset2 == offset2 - offset1) and (offset2 - offset1 == offset1 - offset0): - validOffset0 = self.LSUelemCoord0[i1] - self.LSUelemCoord0[i0] - validOffset1 = self.LSUelemCoord1[i1] - self.LSUelemCoord1[i0] - if self.LSUelemCoord0[i2] - self.LSUelemCoord0[i1] == validOffset0 \ - and self.LSUelemCoord0[i3] - self.LSUelemCoord0[i2] == validOffset0 \ - and self.LSUelemCoord1[i2] - self.LSUelemCoord1[i1] == validOffset1 \ - and self.LSUelemCoord1[i3] - self.LSUelemCoord1[i2] == validOffset1: - validOffset = offset1 - offset0 - break - if validOffset != -1: - break - for idx in range(0, len(self.LSUelements), 4*idxGrp): - for idx2 in range(idxGrp): - self.LSUelementsPerLSUWave.append(self.LSUelements[idx + idx2]) - self.LSUelemCoord0PerLSUWave.append(self.LSUelemCoord0[idx + idx2]) - self.LSUelemCoord1PerLSUWave.append(self.LSUelemCoord1[idx + idx2]) - elif kernel["LocalSplitU"] == 2: - i = 0 - offset0 = self.LSUelemCoord0[i] + self.LSUelemCoord1[i] * kernel["MacroTile0"] - offset1 = self.LSUelemCoord0[i + 1] + self.LSUelemCoord1[i + 1] * kernel["MacroTile0"] - validOffset = offset1 - offset0 - validOffset0 = self.LSUelemCoord0[i + 1] - self.LSUelemCoord0[i] - validOffset1 = self.LSUelemCoord1[i + 1] - self.LSUelemCoord1[i] - for idx in range(0, len(self.LSUelements), 2): - self.LSUelementsPerLSUWave.append(self.LSUelements[idx]) - self.LSUelemCoord0PerLSUWave.append(self.LSUelemCoord0[idx]) - self.LSUelemCoord1PerLSUWave.append(self.LSUelemCoord1[idx]) - else: - assert 0, "No valid LSU offset found." - - if validOffset == -1: - assert 0, "No valid LSU offset found." 
- self.LSUValidOffset0 = validOffset0 - self.LSUValidOffset1 = validOffset1 - bytesPerElem = kernel["ProblemType"]["ComputeDataType"].numBytes() - bytesPerVector = self.LSUfullVw * bytesPerElem - regsPerElem = kernel["ProblemType"]["ComputeDataType"].numRegisters() - numWaves = kernel["MIWaveGroup"][0] * kernel["MIWaveGroup"][1] - regsPerStep = int((bytesPerVector+3)//4) - elementStep = bytesPerVector // bytesPerElem - lsuStep = kernel["MacroTile0"] * kernel["MacroTile1"] - # alloc resource - baseAddr = self.vgprPool.checkOut(1,"baseAddr") - offsetSgpr = self.sgprPool.checkOut(1) - numTotalAccVgprLdsReduction = len(self.LSUelements)*regsPerStep*(self.LSUfullVw//elementStep) - self.accVgprLdsReduction = self.vgprPool.checkOutAligned(numTotalAccVgprLdsReduction, 4, "LsuReduction") - module = Module("localSplitULocalRead") - module.add(Label("localSplitULocalRead", "")) - module.add(RegSet("v", "vgprLsuReduction", self.accVgprLdsReduction)) - # reset vgprValuC register - module.add(RegSet("v", "vgprValuC", self.accVgprLdsReduction)) - self.states.c.startVgprValu = self.accVgprLdsReduction - - # Calculate offset for wave id and lsu id - # re-use the vgpr from numTotalAccVgprLdsReduction - tmpVgpr0 = self.accVgprLdsReduction - lsu_id = self.accVgprLdsReduction + 1 - - with self.allocTmpSgpr(1) as tmpSgprInfo: - tmpSgpr = tmpSgprInfo.idx - module.add(vectorStaticDivide(lsu_id, "Serial", \ - kernel["WavefrontSize"], tmpVgpr0)) - - module.add(vectorStaticDivide(lsu_id, lsu_id, \ - numWaves, tmpVgpr0, comment="Get LSU wave ID")) - module.add(SMovB32(dst=sgpr(tmpSgpr), \ - src=hex(validOffset), comment="a valid offset")) - module.add(VMulLOU32(dst=vgpr(baseAddr), src0=sgpr(tmpSgpr), src1=vgpr(lsu_id), \ - comment="Addr = lsu_id * a valid offset")) - - # reuse lsuCoordOffset from local write - module.add(VAddLShiftLeftU32(dst=vgpr(baseAddr), src0=vgpr(self.lsuCoordOffset), src1=vgpr(baseAddr), shiftHex=hex(log2(self.states.bpeCinternal)), comment="local read LDS address")) - 
- module.add(SWaitCnt(lgkmcnt=0, vscnt=0, comment="wait for all writes")) - module.add(self._syncThreads(kernel, "post-lsu local write")) - - for r in range(0, kernel["LocalSplitU"]): - for i in range(0, len(self.LSUelementsPerLSUWave)): - offset = r * lsuStep - offset += self.LSUelemCoord0PerLSUWave[i] + self.LSUelemCoord1PerLSUWave[i] * kernel["MacroTile0"] - regIdx = int(((i)*self.LSUfullVw + r*kernel["GlobalWriteVectorWidth"]*kernel["NumGlobalWriteVectorsPerThread"]) * regsPerElem) - # generate source - regIdxStep = 0 - resedualBPV = bytesPerVector - while resedualBPV > 0: - bps = min(resedualBPV, 16) - regsPerStep = int((bps+3)//4) - DSLoadBX = {128: DSLoadB128, - 64: DSLoadB64, - 32: DSLoadB32}[bps*8] - module.add(DSLoadBX(dst=vgpr("LsuReduction+%u"%(regIdx + regIdxStep),regsPerStep), src=vgpr(baseAddr), \ - ds=DSModifiers(offset=(offset*self.states.bpeCinternal+(regIdxStep*4))), comment="r=%u i=%u"%(r,i))) - regIdxStep += regsPerStep - resedualBPV -= bps - - # free resources - self.vgprPool.checkIn(baseAddr) - self.sgprPool.checkIn(offsetSgpr) - - return module - - ############################################################################## - # LocalSplitU: Reduction - ############################################################################## - def localSplitUReduction(self, kernel): - module = Module("localSplitUReduction") - module.add(Label("localSplitUReduction", "")) - is_non_hpa_fp16 = kernel["ProblemType"]["DataType"].isHalf() and (not kernel["ProblemType"]["HighPrecisionAccumulate"]) - elementStep = 2 if is_non_hpa_fp16 else 1 - regsPerElem = kernel["ProblemType"]["ComputeDataType"].numRegisters() - - module.add(SWaitCnt(lgkmcnt=0, vscnt=0, comment="wait for all reads")) - if self.states.archCaps["SeparateVscnt"]: - module.add(SWaitCnt(vscnt=0)) - - for r in range(1, kernel["LocalSplitU"]): - for i in range(0, kernel["NumGlobalWriteVectorsPerThread"]): - for s in range(0, kernel["GlobalWriteVectorWidth"], elementStep): - cIdx = int((s + i * 
kernel["GlobalWriteVectorWidth"]) * regsPerElem) - regIdx = int((s + i * kernel["GlobalWriteVectorWidth"] + r * kernel["GlobalWriteVectorWidth"] * kernel["NumGlobalWriteVectorsPerThread"]) * regsPerElem) - - if kernel["ProblemType"]["ComputeDataType"].isSingle(): - module.add(VAddF32(dst=vgpr("LsuReduction+%u"%cIdx), src0=vgpr(self.accVgprLdsReduction+ regIdx), src1=vgpr(self.accVgprLdsReduction+cIdx), \ - comment="c[%u] += c[%u]"%(cIdx, regIdx))) - elif kernel["ProblemType"]["ComputeDataType"].isInt32(): - module.add(VAddI32(dst=vgpr("LsuReduction+%u"%cIdx), src0=vgpr(self.accVgprLdsReduction+ regIdx), src1=vgpr(self.accVgprLdsReduction+cIdx), \ - comment="c[%u] += c[%u]"%(cIdx, regIdx))) - else: - # TODO: hpa_half, int8 - assert(0) # unsupported data type, need to modify here and LSU write/read code - return module - ############################################################################## # computeStoreSrd # Add tile assignment fields to store srd diff --git a/tensilelite/Tensile/SolutionStructs.py b/tensilelite/Tensile/SolutionStructs.py index 0611d3221d..3fda2f5116 100644 --- a/tensilelite/Tensile/SolutionStructs.py +++ b/tensilelite/Tensile/SolutionStructs.py @@ -3494,7 +3494,13 @@ def subCheckLdsBlockSizePerPad(tc, idx): ldsNumBytesAB = state["LdsOffsetB"] + ldsNumBytesB # lds buffer size for reduction + # if User want to control the LDS usage, we may open this para in the future ldsNumBytesReduction = state["LocalSplitU"] * state["MacroTile0"] * state["MacroTile1"] * state["ProblemType"]["ComputeDataType"].numBytes() if state["LocalSplitU"] > 1 else 0 + state["LocalSplitUReuseLDS"] = 1 + if ldsNumBytesReduction > globalParameters["MaxLDS"]: + state["LocalSplitUReuseLDS"] = math.ceil(ldsNumBytesReduction / globalParameters["MaxLDS"]) + # reserve all the LDS to LSU. 
+ ldsNumBytesReduction = globalParameters["MaxLDS"] # lds max occupancy ldsSizeOccupancy = globalParameters["DeviceLDS"] // state["MaxOccupancy"] diff --git a/tensilelite/Tensile/TensileInstructions/Instructions.py b/tensilelite/Tensile/TensileInstructions/Instructions.py index a2985b873f..be396c67ca 100644 --- a/tensilelite/Tensile/TensileInstructions/Instructions.py +++ b/tensilelite/Tensile/TensileInstructions/Instructions.py @@ -2721,7 +2721,7 @@ def setupInstructions(self): class _VLShiftLeftAddU32(CommonInstruction): def __init__(self, dst, shiftHex, src0, src1, vop3: Optional[VOP3PModifiers] = None, comment="") -> None: - super().__init__(InstType.INST_U32, dst, [src0, src1, shiftHex], None, vop3, comment) + super().__init__(InstType.INST_U32, dst, [src0, shiftHex, src1], None, vop3, comment) self.setInst("v_lshl_add_u32") class VLShiftLeftAddU32(CompositeInstruction): diff --git a/tensilelite/Tensile/Tests/common/gemm/lsu.yaml b/tensilelite/Tensile/Tests/common/gemm/lsu.yaml index bac0bbe6de..f83258393c 100644 --- a/tensilelite/Tensile/Tests/common/gemm/lsu.yaml +++ b/tensilelite/Tensile/Tests/common/gemm/lsu.yaml @@ -524,3 +524,162 @@ BenchmarkProblems: - Exact: [127, 128, 1, 640] - Exact: [129, 128, 1, 640] - BiasTypeArgs: ['h'] + + ######################################## + # LSU4 - 4 waves larger MT + ######################################## + - + - # ProblemType + OperationType: GEMM + DataType: h + DataTypeB: F8 + DestDataType: h + ComputeDataType: s + HighPrecisionAccumulate: True + TransposeA: 0 + TransposeB: 1 + UseBeta: True + Batched: True + - # BenchmarkProblemSizeGroup - Standard - All problem + InitialSolutionParameters: + BenchmarkCommonParameters: + - KernelLanguage: ["Assembly"] + ForkParameters: + - MatrixInstruction: + - [16, 16, 16, 1, 1, 6, 6, 1,1 ] + - [16, 16, 16, 1, 1, 4, 6, 1,1 ] + - [16, 16, 16, 1, 1, 6, 4, 1,1 ] + - [16, 16, 16, 1, 1, 8, 4, 1,1 ] + - [16, 16, 16, 1, 1, 4, 8, 1,1 ] + - [16, 16, 16, 1, 1, 8, 8, 1,1 ] + - WorkGroup: 
+ - [4,4,4] + - GlobalReadVectorWidthA: [8, -1] + - GlobalReadVectorWidthB: [8, -1] + - PrefetchGlobalRead: [2] + - PrefetchLocalRead: [1] + - ClusterLocalRead: [1] + - NumElementsPerBatchStore: [0] + - DepthU: [32,64] + - VectorWidthA: [1,2,-1] + - VectorWidthB: [1,2,-1] + - MIArchVgpr: [0] + - LocalWritePerMfma: [-1] + - StaggerU: [4] + - StaggerUStride: [256] + - StaggerUMapping: [2] + - WorkGroupMapping: [1] + - ScheduleIterAlg: [3] + - ExpandPointerSwap: [0] + - TransposeLDS: [0,1,2] + - LdsBlockSizePerPadA: [-1] + - LdsBlockSizePerPadB: [-1] + - StorePriorityOpt: [0] + - VectorStore: [-1] + - StoreSyncOpt: [0] + - LdsPadA: [-1] + - LdsPadB: [-1] + - 1LDSBuffer: [1] + - GlobalSplitU: [1] + - SourceSwap: [0, 1] + BenchmarkJoinParameters: + BenchmarkFinalParameters: + - ProblemSizes: + - Exact: [1, 127, 1, 127] + - Exact: [2, 127, 1, 127] + - Exact: [3, 127, 1, 127] + - Exact: [1, 1, 1, 127] + - Exact: [127, 1, 1, 127] + - Exact: [127, 2, 1, 127] + - Exact: [127, 3, 1, 127] + - Exact: [127, 127, 1, 127] + - Exact: [128, 128, 1, 128] + - Exact: [129, 129, 1, 129] + - Exact: [127, 127, 1, 128] + - Exact: [127, 127, 1, 129] + - Exact: [127, 128, 1, 640] + - Exact: [129, 128, 1, 640] + - Exact: [512, 512, 1, 640] + + ######################################## + # LSU2 - 2 waves larger MT + ######################################## + - + - # ProblemType + OperationType: GEMM + DataType: h + DataTypeA: F8 + DestDataType: h + ComputeDataType: s + HighPrecisionAccumulate: True + TransposeA: 1 + TransposeB: 1 + UseBeta: True + Batched: True + - # BenchmarkProblemSizeGroup - Standard - All problem + InitialSolutionParameters: + BenchmarkCommonParameters: + - KernelLanguage: ["Assembly"] + ForkParameters: + - MatrixInstruction: + - [16, 16,16, 1, 1, 8, 4, 1, 2 ] # larger than MT128x64 + - [16, 16,16, 1, 1, 8, 8, 1, 1 ] + - [16, 16,16, 1, 1, 4, 7, 2, 1 ] + - [16, 16,16, 1, 1, 3, 7, 2, 1 ] + - [16, 16,16, 1, 1, 7, 7, 1, 1 ] + - [16, 16,16, 1, 1, 5, 5, 1, 1 ] + - [16, 16,16, 1, 
1, 4, 5, 2, 1 ] + - [16, 16,16, 1, 1, 4, 5, 1, 2 ] + - [16, 16,16, 1, 1, 5, 4, 1, 2 ] + - [16, 16,16, 1, 1, 3, 9, 2, 1 ] + - [16, 16,16, 1, 1, 4, 8, 2, 1 ] + - WorkGroup: + - [4,4,2] + - GlobalReadVectorWidthA: [-1] + - GlobalReadVectorWidthB: [-1] + - LocalReadVectorWidth: [-1] + - PrefetchGlobalRead: [2] + - PrefetchLocalRead: [1] + - ClusterLocalRead: [1] + - NumElementsPerBatchStore: [0] + - DepthU: [32,64] + - VectorWidthA: [-1] + - VectorWidthB: [-1] + - MIArchVgpr: [0] + - LocalWritePerMfma: [-1] + - StaggerU: [0] + - StaggerUStride: [-1] + - StaggerUMapping: [2] + - WorkGroupMapping: [1] + - LocalReadVectorWidth: [8] + - ScheduleIterAlg: [3] + - ExpandPointerSwap: [0] + - TransposeLDS: [1] + - LdsBlockSizePerPadA: [-1] + - LdsBlockSizePerPadB: [-1] + - StorePriorityOpt: [0] + - VectorStore: [-1] + - StoreSyncOpt: [0] + - LdsPadA: [-1] + - LdsPadB: [-1] + - 1LDSBuffer: [1] + - GlobalSplitU: [1] + - SourceSwap: [0, 1] + BenchmarkJoinParameters: + BenchmarkFinalParameters: + - ProblemSizes: + - Exact: [1, 127, 1, 127] + - Exact: [2, 127, 1, 127] + - Exact: [3, 127, 1, 127] + - Exact: [1, 1, 1, 127] + - Exact: [127, 1, 1, 127] + - Exact: [127, 2, 1, 127] + - Exact: [127, 3, 1, 127] + - Exact: [127, 127, 1, 127] + - Exact: [128, 128, 1, 128] + - Exact: [129, 129, 1, 129] + - Exact: [127, 127, 1, 128] + - Exact: [127, 127, 1, 129] + - Exact: [127, 128, 1, 640] + - Exact: [129, 128, 1, 640] + - Exact: [512, 512, 1, 640]