Skip to content

Commit

Permalink
LSU supports larger MT and reuses LDS. (#1433)
Browse files Browse the repository at this point in the history
1. Move LSU into LSU.py.
2. Perform a partial LSU reduction when the LDS is not large enough.
  • Loading branch information
hcman2 authored Dec 12, 2024
1 parent ebd940f commit e13e133
Show file tree
Hide file tree
Showing 7 changed files with 720 additions and 315 deletions.
547 changes: 547 additions & 0 deletions tensilelite/Tensile/Components/LSU.py

Large diffs are not rendered by default.

3 changes: 2 additions & 1 deletion tensilelite/Tensile/Components/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,5 +48,6 @@ def use(): pass
"SumUnroll",
"GSU",
"StreamK",
"PersistentLoop"
"PersistentLoop",
"LSU",
]
43 changes: 5 additions & 38 deletions tensilelite/Tensile/KernelWriter.py
Original file line number Diff line number Diff line change
Expand Up @@ -2824,29 +2824,17 @@ def kernelBody( self, kernel, tensorParametersA, tensorParametersB ):
####################################
#if kernel["NumThreads"]%kernel["MacroTile0"] == 0:
if kernel["LocalSplitU"] > 1:
module.addComment2("LocalSplitU Reduction")
module.add(self._syncThreads(kernel))

# LocalSplitU: local write
module.addComment1("LocalSplitU: local write")
module.add(self.localSplitULocalWrite(kernel))

# LocalSplitU: local read
module.addComment1("LocalSplitU: local read")
module.add(self.localSplitULocalRead(kernel))
module.addComment1("LocalSplitU: local write and read")
lsuComponent = Component.LSU.find(self)
module.add(lsuComponent.writeReadReduction(self, kernel))

# LocalSplitU: global write indices
# Hide instructions in local read latency
module.addComment1("LocalSplitU: global write indices")
module.add(self.localSplitUGlobalWriteIndices(kernel))

# LocalSplitU: Reduction
module.addComment1("LocalSplitU: reduction")
module.add(self.localSplitUReduction(kernel))
module.add(lsuComponent.globalWriteIndices(self, kernel))

# LocalSplitU: global write
module.addComment1("LocalSplitU: global write")
module.add(self.localSplitUGlobalWrite(kernel, tensorParametersA, tensorParametersB))
module.add(lsuComponent.globalWrite(self, kernel, tensorParametersA, tensorParametersB))

else:
####################################
Expand Down Expand Up @@ -4839,27 +4827,6 @@ def localReadDo(self, kernel, bufferIdx, innerUnrollIndex, epsi, tP):
def shiftVectorComponents(self, kernel, tP):
return ""

##############################################################################
# LocalSplitU: Local Write
##############################################################################
@abc.abstractmethod
def localSplitULocalWrite(self, kernel):
return ""

##############################################################################
# LocalSplitU: Local Read
##############################################################################
@abc.abstractmethod
def localSplitULocalRead(self, kernel):
return ""

##############################################################################
# LocalSplitU: Reduction
##############################################################################
@abc.abstractmethod
def localSplitUReduction(self, kernel):
return ""

##############################################################################
# globalWriteWorkGroupInit:
# Perform work-group granularity init
Expand Down
275 changes: 0 additions & 275 deletions tensilelite/Tensile/KernelWriterAssembly.py
Original file line number Diff line number Diff line change
Expand Up @@ -7924,281 +7924,6 @@ def shiftVectorComponents(self, kernel, tP):
if component:
return component(self, kernel, tP)

##############################################################################
# LocalSplitU: Local Write
##############################################################################
def localSplitULocalWrite(self, kernel):
    """
    LocalSplitU: store each wave's partial accumulators to LDS.

    Copies the MI accumulators (acc registers, or arch vgprs when
    MIArchVgpr) into a contiguous vgpr block, computes this thread's LDS
    write address from its LSU wave id and in-macrotile coordinates, and
    issues ds_store instructions for every store element.
    Note: self.lsuCoordOffset stays allocated on purpose — it is reused
    by localSplitULocalRead.  Returns a Module with the generated code.
    """
    module = Module("localSplitULocalWrite")
    # wait for summation to be done with lds before writing reduction values
    module.add(self._syncThreads(kernel, "pre-lsu local write"))
    module.add(Label("localSplitULocalWrite", ""))

    tmpVgpr = self.vgprPool.checkOutAligned(2, 2, "tmpVgpr")
    tmpVgprRes = RegisterPoolResource(tmpVgpr, 2)
    lsu_id = self.vgprPool.checkOut(1,"lsu_id")
    addr = self.vgprPool.checkOut(1,"addr")
    # intentionally NOT checked in here; reused by the LSU local-read step
    self.lsuCoordOffset = self.vgprPool.checkOut(1,"lsuCoordOffset")
    lr1 = self.vgprPool.checkOut(1,"lr1")
    acc2arch, _ = accToArchMapper(kernel)
    NumAccVgprRes = len(acc2arch)*kernel["MIRegPerOut"]
    accVgprRes = self.vgprPool.checkOutAligned(NumAccVgprRes, 4, "accLSUVgprRes")
    # Gather accumulators into a contiguous vgpr range, applying the
    # acc->arch index permutation.
    for i in range(len(acc2arch)):
        for r in range(kernel["MIRegPerOut"]):
            destIdx = (acc2arch[i]) * kernel["MIRegPerOut"] + r
            srcIdx = ((i * kernel["MIRegPerOut"] + r))
            if not kernel["MIArchVgpr"]:
                accStr = accvgpr(srcIdx)
                module.add(VAccvgprReadB32(dst=vgpr(accVgprRes+destIdx),
                           src=accStr,
                           comment="copy acc to vreg[%u]" % destIdx))
            else:
                module.add(VMovB32(dst=vgpr(accVgprRes+destIdx),
                           src=vgpr("ValuC+%u"%srcIdx),
                           comment="copy MI out reg to vreg[%u]" % destIdx))

    ldsStride = kernel["MacroTile0"]*kernel["MacroTile1"]
    numWaves = kernel["MIWaveGroup"][0] * kernel["MIWaveGroup"][1]

    # new method. output self.vgprs.coord0InMT/coord1InMT
    if kernel["EnableMatrixInstruction"]:
        module.add(self.computeStoreVgprs(kernel))
    else:
        # lr0 = serial % SG0
        module.add(self.computeStoreVgprs(kernel, \
            divisor = kernel["MacroTile0"] // kernel["GlobalWriteVectorWidth"], \
            tid0Scale = kernel["GlobalWriteVectorWidth"], \
            tid1Scale = 1))

    self.LSUelemCoord0 = []
    self.LSUelemCoord1 = []
    self.LSUelements = []
    self.LSUfullVw = []
    (vwdummy, eledummy, self.LSUfullVw, self.LSUelements) = self.notLocalFullTileElements(kernel, False)
    storevw = self.LSUfullVw
    atomic = False # atomic is for GSU > 1
    beta = True
    vectorDataTypes = VectorDataTypes()
    ss = StoreState(self, kernel, storevw, False, beta, atomic, self.LSUelements, vectorDataTypes, dim=0)
    self.LSUelemCoord0, self.LSUelemCoord1 = ss.getStoreElementsInfoForBatch(kernel, self.LSUelements)

    with self.allocTmpSgpr(1) as tmpSgprInfo:
        tmpSgpr = tmpSgprInfo.idx

        # lr1 = serial / kernel["WavefrontSize"]
        module.add(vectorStaticDivide(lr1, "Serial", \
            kernel["WavefrontSize"], tmpVgprRes))

        module.add(vectorStaticDivide(lsu_id, lr1, \
            numWaves, tmpVgprRes, comment="Get LSU wave ID"))

        # Each LSU wave owns one MT0*MT1 slab of the reduction buffer.
        module.add(SMovB32(dst=sgpr(tmpSgpr), \
            src=hex(ldsStride), comment="MT0*MT1"))
        module.add(VMulLOU32(dst=vgpr(addr), src0=sgpr(tmpSgpr), src1=vgpr(lsu_id), \
            comment="lsu_id *= MT0*MT1"))

        # In-tile linear offset: coord1InMT * MT0 + coord0InMT
        module.add(SMovB32(dst=sgpr(tmpSgpr), \
            src=hex(kernel["MacroTile0"]), comment="MT0"))
        module.add(VMulLOU32(dst=vgpr(self.lsuCoordOffset), src0=sgpr(tmpSgpr), src1=vgpr(self.vgprs.coord1InMT), \
            comment="MT0*coord1InMT"))
        module.add(VAddU32(dst=vgpr(self.lsuCoordOffset), src0=vgpr(self.vgprs.coord0InMT), src1=vgpr(self.lsuCoordOffset), comment="coord0InMT"))

        # thread offset, scaled to bytes by bpeCinternal
        module.add(VAddLShiftLeftU32(dst=vgpr(addr), src0=vgpr(self.lsuCoordOffset), src1=vgpr(addr), shiftHex=hex(log2(self.states.bpeCinternal)), comment="local write LDS address"))

    self.vgprPool.checkIn(lr1)
    self.vgprPool.checkIn(lsu_id)
    self.vgprPool.checkIn(tmpVgpr)

    bytesPerElem = kernel["ProblemType"]["ComputeDataType"].numBytes()
    regsPerElem = kernel["ProblemType"]["ComputeDataType"].numRegisters()
    bytesPerVector = storevw * bytesPerElem
    for i in range(0, len(self.LSUelements)):
        (tt1, tt0, vc1, vc0) = self.LSUelements[i]
        writeOffset = self.LSUelemCoord0[i] + self.LSUelemCoord1[i] * kernel["MacroTile0"]
        regIdx = int(i * regsPerElem * storevw)
        regIdxStep = 0
        residualBPV = bytesPerVector
        # Split each vector store into chunks of at most 16 bytes (ds_store_b128).
        while residualBPV > 0:
            bps = min(residualBPV, 16)
            regsPerStep = int((bps+3)//4)
            DSStoreBX = {128: DSStoreB128,
                         64: DSStoreB64,
                         32: DSStoreB32,
                         16: DSStoreB16,
                         8: DSStoreB8}[bps*8]
            module.add(DSStoreBX(dstAddr=vgpr(addr), src=vgpr(accVgprRes+regIdx+regIdxStep, regsPerStep), \
                       ds=DSModifiers(offset=(writeOffset*self.states.bpeCinternal+(regIdxStep*4))),
                       comment="tt1=%u tt0=%u vc1=%u vc0=%u"%(tt1, tt0, vc1, vc0)))
            regIdxStep += regsPerStep
            residualBPV -= bps

    self.vgprPool.checkIn(accVgprRes)
    self.vgprPool.checkIn(addr)
    return module

##############################################################################
# LocalSplitU: Local Read
##############################################################################
def localSplitULocalRead(self, kernel):
    """
    LocalSplitU: read this thread's share of the partial sums from LDS.

    Searches self.LSUelements/LSUelemCoord* (filled in by
    localSplitULocalWrite) for a constant-stride element grouping so
    each LSU wave can address its portion with a single offset, records
    it in self.LSUValidOffset0/1 and self.LSU*PerLSUWave, then emits
    ds_load instructions into a fresh vgpr block aliased "LsuReduction"
    (also re-pointing "vgprValuC" at it).  Asserts when no valid
    constant-stride offset exists or LocalSplitU is not 2 or 4.
    Returns a Module with the generated code.
    """
    # search for valid lsu wave offset
    validOffset = -1
    validOffset0 = -1
    validOffset1 = -1
    self.LSUelementsPerLSUWave = []
    self.LSUelemCoord0PerLSUWave = []
    self.LSUelemCoord1PerLSUWave = []
    # Check valid LSU/VW combination
    if len(self.LSUelements) >= kernel["LocalSplitU"]:
        if kernel["LocalSplitU"] == 4:
            idxGrp = 1
            # Look for 4 elements (stride idxGrp) whose LDS offsets form an
            # arithmetic progression with identical coord0/coord1 deltas.
            for idxGrp in range(1, len(self.LSUelements)//4 + 1):
                for i in range(idxGrp):
                    i0 = i
                    i1 = i + 1 * idxGrp
                    i2 = i + 2 * idxGrp
                    i3 = i + 3 * idxGrp
                    offset0 = self.LSUelemCoord0[i0] + self.LSUelemCoord1[i0] * kernel["MacroTile0"]
                    offset1 = self.LSUelemCoord0[i1] + self.LSUelemCoord1[i1] * kernel["MacroTile0"]
                    offset2 = self.LSUelemCoord0[i2] + self.LSUelemCoord1[i2] * kernel["MacroTile0"]
                    offset3 = self.LSUelemCoord0[i3] + self.LSUelemCoord1[i3] * kernel["MacroTile0"]
                    if (offset3 - offset2 == offset2 - offset1) and (offset2 - offset1 == offset1 - offset0):
                        validOffset0 = self.LSUelemCoord0[i1] - self.LSUelemCoord0[i0]
                        validOffset1 = self.LSUelemCoord1[i1] - self.LSUelemCoord1[i0]
                        if self.LSUelemCoord0[i2] - self.LSUelemCoord0[i1] == validOffset0 \
                           and self.LSUelemCoord0[i3] - self.LSUelemCoord0[i2] == validOffset0 \
                           and self.LSUelemCoord1[i2] - self.LSUelemCoord1[i1] == validOffset1 \
                           and self.LSUelemCoord1[i3] - self.LSUelemCoord1[i2] == validOffset1:
                            validOffset = offset1 - offset0
                            break
                if validOffset != -1:
                    break
            for idx in range(0, len(self.LSUelements), 4*idxGrp):
                for idx2 in range(idxGrp):
                    self.LSUelementsPerLSUWave.append(self.LSUelements[idx + idx2])
                    self.LSUelemCoord0PerLSUWave.append(self.LSUelemCoord0[idx + idx2])
                    self.LSUelemCoord1PerLSUWave.append(self.LSUelemCoord1[idx + idx2])
        elif kernel["LocalSplitU"] == 2:
            i = 0
            offset0 = self.LSUelemCoord0[i] + self.LSUelemCoord1[i] * kernel["MacroTile0"]
            offset1 = self.LSUelemCoord0[i + 1] + self.LSUelemCoord1[i + 1] * kernel["MacroTile0"]
            validOffset = offset1 - offset0
            validOffset0 = self.LSUelemCoord0[i + 1] - self.LSUelemCoord0[i]
            validOffset1 = self.LSUelemCoord1[i + 1] - self.LSUelemCoord1[i]
            for idx in range(0, len(self.LSUelements), 2):
                self.LSUelementsPerLSUWave.append(self.LSUelements[idx])
                self.LSUelemCoord0PerLSUWave.append(self.LSUelemCoord0[idx])
                self.LSUelemCoord1PerLSUWave.append(self.LSUelemCoord1[idx])
        else:
            assert 0, "No valid LSU offset found."

    if validOffset == -1:
        assert 0, "No valid LSU offset found."
    self.LSUValidOffset0 = validOffset0
    self.LSUValidOffset1 = validOffset1
    bytesPerElem = kernel["ProblemType"]["ComputeDataType"].numBytes()
    bytesPerVector = self.LSUfullVw * bytesPerElem
    regsPerElem = kernel["ProblemType"]["ComputeDataType"].numRegisters()
    numWaves = kernel["MIWaveGroup"][0] * kernel["MIWaveGroup"][1]
    regsPerStep = int((bytesPerVector+3)//4)
    elementStep = bytesPerVector // bytesPerElem
    lsuStep = kernel["MacroTile0"] * kernel["MacroTile1"]
    # alloc resource
    baseAddr = self.vgprPool.checkOut(1,"baseAddr")
    numTotalAccVgprLdsReduction = len(self.LSUelements)*regsPerStep*(self.LSUfullVw//elementStep)
    self.accVgprLdsReduction = self.vgprPool.checkOutAligned(numTotalAccVgprLdsReduction, 4, "LsuReduction")
    module = Module("localSplitULocalRead")
    module.add(Label("localSplitULocalRead", ""))
    module.add(RegSet("v", "vgprLsuReduction", self.accVgprLdsReduction))
    # reset vgprValuC register
    module.add(RegSet("v", "vgprValuC", self.accVgprLdsReduction))
    self.states.c.startVgprValu = self.accVgprLdsReduction

    # Calculate offset for wave id and lsu id;
    # re-use the first vgprs of the reduction block as scratch (loads
    # below will overwrite them afterwards).
    tmpVgpr0 = self.accVgprLdsReduction
    lsu_id = self.accVgprLdsReduction + 1

    with self.allocTmpSgpr(1) as tmpSgprInfo:
        tmpSgpr = tmpSgprInfo.idx
        module.add(vectorStaticDivide(lsu_id, "Serial", \
            kernel["WavefrontSize"], tmpVgpr0))

        module.add(vectorStaticDivide(lsu_id, lsu_id, \
            numWaves, tmpVgpr0, comment="Get LSU wave ID"))
        module.add(SMovB32(dst=sgpr(tmpSgpr), \
            src=hex(validOffset), comment="a valid offset"))
        module.add(VMulLOU32(dst=vgpr(baseAddr), src0=sgpr(tmpSgpr), src1=vgpr(lsu_id), \
            comment="Addr = lsu_id * a valid offset"))

        # reuse lsuCoordOffset from local write
        module.add(VAddLShiftLeftU32(dst=vgpr(baseAddr), src0=vgpr(self.lsuCoordOffset), src1=vgpr(baseAddr), shiftHex=hex(log2(self.states.bpeCinternal)), comment="local read LDS address"))

    module.add(SWaitCnt(lgkmcnt=0, vscnt=0, comment="wait for all writes"))
    module.add(self._syncThreads(kernel, "post-lsu local write"))

    for r in range(0, kernel["LocalSplitU"]):
        for i in range(0, len(self.LSUelementsPerLSUWave)):
            offset = r * lsuStep
            offset += self.LSUelemCoord0PerLSUWave[i] + self.LSUelemCoord1PerLSUWave[i] * kernel["MacroTile0"]
            regIdx = int(((i)*self.LSUfullVw + r*kernel["GlobalWriteVectorWidth"]*kernel["NumGlobalWriteVectorsPerThread"]) * regsPerElem)
            # generate source
            regIdxStep = 0
            residualBPV = bytesPerVector
            # Split each vector load into chunks of at most 16 bytes.
            while residualBPV > 0:
                bps = min(residualBPV, 16)
                regsPerStep = int((bps+3)//4)
                DSLoadBX = {128: DSLoadB128,
                            64: DSLoadB64,
                            32: DSLoadB32}[bps*8]
                module.add(DSLoadBX(dst=vgpr("LsuReduction+%u"%(regIdx + regIdxStep),regsPerStep), src=vgpr(baseAddr), \
                    ds=DSModifiers(offset=(offset*self.states.bpeCinternal+(regIdxStep*4))), comment="r=%u i=%u"%(r,i)))
                regIdxStep += regsPerStep
                residualBPV -= bps

    # free resources
    self.vgprPool.checkIn(baseAddr)

    return module

##############################################################################
# LocalSplitU: Reduction
##############################################################################
def localSplitUReduction(self, kernel):
    """
    LocalSplitU: accumulate the partial sums in registers.

    After localSplitULocalRead fills the "LsuReduction" vgpr block,
    partitions r = 1..LocalSplitU-1 are added element-wise into
    partition 0.  Only single-precision float and int32 compute data
    types are handled; anything else asserts.  Returns a Module with
    the generated instructions.
    """
    module = Module("localSplitUReduction")
    module.add(Label("localSplitUReduction", ""))
    problemType = kernel["ProblemType"]
    # non-HPA fp16 packs two elements per 32-bit register, hence step 2
    is_non_hpa_fp16 = problemType["DataType"].isHalf() and (not problemType["HighPrecisionAccumulate"])
    elementStep = 2 if is_non_hpa_fp16 else 1
    computeType = problemType["ComputeDataType"]
    regsPerElem = computeType.numRegisters()
    vectorWidth = kernel["GlobalWriteVectorWidth"]
    vectorsPerThread = kernel["NumGlobalWriteVectorsPerThread"]

    module.add(SWaitCnt(lgkmcnt=0, vscnt=0, comment="wait for all reads"))
    if self.states.archCaps["SeparateVscnt"]:
        module.add(SWaitCnt(vscnt=0))

    for r in range(1, kernel["LocalSplitU"]):
        partitionBase = r * vectorWidth * vectorsPerThread
        for i in range(vectorsPerThread):
            for s in range(0, vectorWidth, elementStep):
                cIdx = int((s + i * vectorWidth) * regsPerElem)
                regIdx = int((s + i * vectorWidth + partitionBase) * regsPerElem)
                if computeType.isSingle():
                    addInstr = VAddF32
                elif computeType.isInt32():
                    addInstr = VAddI32
                else:
                    # TODO: hpa_half, int8
                    assert(0) # unsupported data type, need to modify here and LSU write/read code
                module.add(addInstr(dst=vgpr("LsuReduction+%u"%cIdx),
                                    src0=vgpr(self.accVgprLdsReduction + regIdx),
                                    src1=vgpr(self.accVgprLdsReduction + cIdx),
                                    comment="c[%u] += c[%u]"%(cIdx, regIdx)))
    return module

##############################################################################
# computeStoreSrd
# Add tile assignment fields to store srd
Expand Down
6 changes: 6 additions & 0 deletions tensilelite/Tensile/SolutionStructs.py
Original file line number Diff line number Diff line change
Expand Up @@ -3494,7 +3494,13 @@ def subCheckLdsBlockSizePerPad(tc, idx):
ldsNumBytesAB = state["LdsOffsetB"] + ldsNumBytesB

# lds buffer size for reduction
# If the user wants to control the LDS usage, we may expose this parameter in the future.
ldsNumBytesReduction = state["LocalSplitU"] * state["MacroTile0"] * state["MacroTile1"] * state["ProblemType"]["ComputeDataType"].numBytes() if state["LocalSplitU"] > 1 else 0
state["LocalSplitUReuseLDS"] = 1
if ldsNumBytesReduction > globalParameters["MaxLDS"]:
state["LocalSplitUReuseLDS"] = math.ceil(ldsNumBytesReduction / globalParameters["MaxLDS"])
# reserve all the LDS to LSU.
ldsNumBytesReduction = globalParameters["MaxLDS"]

# lds max occupancy
ldsSizeOccupancy = globalParameters["DeviceLDS"] // state["MaxOccupancy"]
Expand Down
2 changes: 1 addition & 1 deletion tensilelite/Tensile/TensileInstructions/Instructions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2721,7 +2721,7 @@ def setupInstructions(self):

class _VLShiftLeftAddU32(CommonInstruction):
def __init__(self, dst, shiftHex, src0, src1, vop3: Optional[VOP3PModifiers] = None, comment="") -> None:
super().__init__(InstType.INST_U32, dst, [src0, src1, shiftHex], None, vop3, comment)
super().__init__(InstType.INST_U32, dst, [src0, shiftHex, src1], None, vop3, comment)
self.setInst("v_lshl_add_u32")

class VLShiftLeftAddU32(CompositeInstruction):
Expand Down
Loading

0 comments on commit e13e133

Please sign in to comment.