Skip to content

Commit

Permalink
code-gen: improved tail loop and edge tile of swizzled A
Browse files Browse the repository at this point in the history
* Opt swizzleA tail-loop and minor bug fix

* Add test cases for tail loop

* disable big sizes
  • Loading branch information
solaslin authored Dec 20, 2024
1 parent 6b413e3 commit 5ca877f
Show file tree
Hide file tree
Showing 5 changed files with 50 additions and 27 deletions.
7 changes: 4 additions & 3 deletions tensilelite/Tensile/Contractions.py
Original file line number Diff line number Diff line change
Expand Up @@ -507,12 +507,13 @@ def CompoundPredicates(cls, state, problemType):
if ('WorkGroupMappingXCC' in state) and ('WorkGroupMappingXCCGroup' in state):
rv += [cls("WorkgroupMappingXCCCheck", value=[state['WorkGroupMappingXCC'], state['WorkGroupMappingXCCGroup']])]

# TODO- Improve the performance of these non-multiple cases
if state['ProblemType']['SwizzleTensorA']:
rv += [cls('SwizzleTensorA', value=state['ProblemType']['SwizzleTensorA'])]
rv += [cls("Free0SizeMultiple", index=0, value=state['MacroTile0'])]
rv += [cls("BoundSizeMultiple", index=-1, value=state['DepthU'])]
# TODO- (TT + DTVA) tail-loop is not working yet.
if state['ProblemType']['TransposeB']:
rv += [cls("BoundSizeMultiple", index=-1, value=state['DepthU'])]

# TODO- Will remove the size predicate once we have SWZ-B request
if state['ProblemType']['SwizzleTensorB']:
rv += [cls('SwizzleTensorB', value=state['ProblemType']['SwizzleTensorB'])]
rv += [cls("Free1SizeMultiple", index=0, value=state['MacroTile1'])]
Expand Down
24 changes: 15 additions & 9 deletions tensilelite/Tensile/KernelWriter.py
Original file line number Diff line number Diff line change
Expand Up @@ -994,7 +994,7 @@ def checkLocalReadFIFOFull(currentMFMA, fifo, lrItems, numLR, numLREven):
return numToBeIssued

oneBufferScheduling = kernel["1LDSBuffer"] or kernel["DirectToLdsA"] or kernel["DirectToLdsB"]

def hasDependency(lr: DSLoadInstruction, inst: Instruction) -> bool:
lrDataReg = lr.dst

Expand Down Expand Up @@ -2569,7 +2569,7 @@ def kernelBody( self, kernel, tensorParametersA, tensorParametersB ):
# last NLL or pack DTV case, no deep copy for pack
# pack code for local prefetch is generated in noLoadLoopBody and used for DTV even
deepCopyPack = pack
else:
else:
# deepCopy packCode for OptNLL noLoadLoop
deepCopyPack = fastdeepcopy(pack)
module.add(self.noLoadLoop(kernel, tensorParametersA, tensorParametersB, isOptNLL=False, isNGLL=False, pack=deepCopyPack, NLLindex=NLLindex, NLLnum=NLLnum))
Expand Down Expand Up @@ -2636,15 +2636,21 @@ def kernelBody( self, kernel, tensorParametersA, tensorParametersB ):
kernel["tailLoopOpt"] == False) else 0
globalReadMode2nd = 2 if (((tensorParameters2nd["glvw"] * tensorParameters2nd["bpeGR"]) < 4) or \
kernel["tailLoopOpt"] == False) else 0

# if we have swizzled A or B, then size-K is already guarded, so we don't have to use guarded-K GR again
hasSwizzled = tensorParametersA["isSwizzled"] or tensorParametersB["isSwizzled"]
globalReadMode1st = 0 if hasSwizzled else globalReadMode1st
globalReadMode2nd = 0 if hasSwizzled else globalReadMode2nd

module.addComment1("Update M0 for DTLDS")
moduleTmp = self.directToLdsM0Update(kernel, 1, tensorParameters1st)
module.add(replaceHolder(moduleTmp, 0))
module.addComment1("global read %s"%tc1)
module.addComment1("Tail global read %s"%tc1)
module.add(self.globalReadDo(kernel, globalReadMode1st, tensorParameters1st))
module.addComment1("Update M0 for DTLDS")
moduleTmp = self.directToLdsM0Update(kernel, 1, tensorParameters2nd)
module.add(replaceHolder(moduleTmp, 0))
module.addComment1("global read %s"%tc2)
module.addComment1("Tail global read %s"%tc2)
module.add(self.globalReadDo(kernel, globalReadMode2nd, tensorParameters2nd))
if kernel["tailLoopOpt"] and \
(((tensorParameters1st["glvw"] * tensorParameters1st["bpeGR"]) >= 4) or \
Expand Down Expand Up @@ -2679,13 +2685,13 @@ def kernelBody( self, kernel, tensorParametersA, tensorParametersB ):

# tail: re-init local read addresses
if kernel["PrefetchGlobalRead"]:
module.addComment1("local read reset offsets a")
module.addComment1("Tail: local read reset offsets a")
module.add(self.localReadResetOffsets(kernel, tensorParametersA))
module.addComment1("local read reset offsets b")
module.addComment1("Tail: local read reset offsets b")
module.add(self.localReadResetOffsets(kernel, tensorParametersB))
module.addComment1("local read init pointers a")
module.addComment1("Tail: local read init pointers a")
module.add(self.localReadInitPointers(kernel, tensorParametersA, tensorParametersA))
module.addComment1("local read init pointers b")
module.addComment1("Tail: local read init pointers b")
module.add(self.localReadInitPointers(kernel, tensorParametersA, tensorParametersB))
if kernel["ProblemType"]["Sparse"] and not kernel["DirectToVgprSparseMetadata"]:
module.addComment1("local read reset offsets metadata")
Expand Down Expand Up @@ -2908,7 +2914,7 @@ def initKernel(self, kernel, tensorParametersA, tensorParametersB):
self.states.asmCaps = self.ti.getAsmCaps()
self.states.archCaps = self.ti.getArchCaps()
self.states.regCaps = self.ti.getRegCaps()

self.asmAssert = Assert(self.states.laneSGPRCount, kernel["WavefrontSize"], self.db["EnableAsserts"])

# Only assembly supports scheduling
Expand Down
12 changes: 7 additions & 5 deletions tensilelite/Tensile/KernelWriterAssembly.py
Original file line number Diff line number Diff line change
Expand Up @@ -5328,11 +5328,13 @@ def mfmaIter(self, kernel, tPA, tPB, u, innerUnroll, vregSetIdx, unrollLoopIdx =
if kernel["LocalSplitU"] > 1:
shiftK.add(SMinI32(dst=sgpr(loopCntSgpr), src0=sgpr(loopCounterName), src1=sgpr("LSUTailLoopOffset"), comment="check lsu bound"))
shiftK.add(VCmpGEI32(dst=sgpr(tmpSgprX2, self.states.laneSGPRCount), src0=vgpr(kReg), src1=sgpr(loopCntSgpr), comment="check K index >= Size L"))
for bk in range(0, vgprPerSet0Group):
for a in range(0, kernel["MIWaveTileA"]):
for iui in range(0, innerUnroll):
aStr = vgpr(self.generateSrcStrForMFMA(kernel, tPA, innerUnroll, vregSetIdx, vgprPerInputA, m, u, iui, a, bk=bk + group * vgprPerSet0Group), 1)
shiftK.add(VCndMaskB32(dst=aStr, src0=aStr, src1=hex(0), src2=sgpr(tmpSgprX2, self.states.laneSGPRCount), comment="set 0 if K_idx >= sizeL"))
# if A is swizzled, then no need to do this since MatB will be set to 0
if not tPA["isSwizzled"]:
for bk in range(0, vgprPerSet0Group):
for a in range(0, kernel["MIWaveTileA"]):
for iui in range(0, innerUnroll):
aStr = vgpr(self.generateSrcStrForMFMA(kernel, tPA, innerUnroll, vregSetIdx, vgprPerInputA, m, u, iui, a, bk=bk + group * vgprPerSet0Group), 1)
shiftK.add(VCndMaskB32(dst=aStr, src0=aStr, src1=hex(0), src2=sgpr(tmpSgprX2, self.states.laneSGPRCount), comment="set 0 if K_idx >= sizeL"))

if kernel["ProblemType"]["Sparse"] == 2 and numMIInput//8 >= 1:
shiftK.add(vectorStaticRemainder(dummy, kReg, "Serial", kernel["WavefrontSize"], tmpVgpr, tmpSgprInfo))
Expand Down
13 changes: 7 additions & 6 deletions tensilelite/Tensile/SolutionStructs.py
Original file line number Diff line number Diff line change
Expand Up @@ -2776,16 +2776,17 @@ def calcLdsNumBytes(ldsPadA: int, LdsBlockSizePerPadA: int, ldsPadB: int, LdsBlo
reject(state, f"SwizzleTensor{tc} requires VectorWidth{tc} ({VW_TC}) == 1")

if state["ProblemType"]["SwizzleTensorA"]:
if state["ProblemType"]["TransposeA"] is False:
reject(state, f"Tensor A swizzling supports TN or TT only")
if state["DirectToVgprA"] is False:
if not state["DirectToVgprA"]:
reject(state, f"Tensor A swizzling requires DirectToVgprA")
if not state["ProblemType"]["TransposeA"]:
reject(state, f"Tensor A swizzling supports TN or TT only")

if state["ProblemType"]["SwizzleTensorB"]:
if state["ProblemType"]["TransposeB"] is True:
reject(state, f"Tensor B swizzling supports NN or TN only")
if state["DirectToVgprB"] is False:
if not state["DirectToVgprB"]:
reject(state, f"Tensor B swizzling requires DirectToVgprB")
# TODO- NN fails validation because DTVB + Tail-Loop is not working correctly
if not (state["ProblemType"]["TransposeA"] and not state["ProblemType"]["TransposeB"]):
reject(state, f"Tensor B swizzling supports TN only")

def calcOptGRVW(lrvw: int, unrollMajorLDS: bool, datatype: DataType) -> int:
# with UnrollMajorLDS, GRVW need to less or equal than LRVW to have conflict free LDS read with padding.
Expand Down
21 changes: 17 additions & 4 deletions tensilelite/Tensile/Tests/common/gemm/dtvA_swizzleA.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ GlobalParameters:
MergeFiles: False
KernelTime: True
MaxWorkspaceSize: 13421772800
DataInitTypeA: 13
DataInitTypeB: 13
DataInitTypeAlpha: 1
DataInitTypeBeta: 1
BoundsCheck: 2
Expand Down Expand Up @@ -55,9 +57,11 @@ BenchmarkProblems:
- [16, 16, 16, 1, 1, 4, 16, 4,1 ] # MT = 256x256
- [16, 16, 16, 1, 1, 8, 8, 4,1 ] # MT = 512x128
- [16, 16, 16, 1, 1, 8, 16, 4,1 ] # MT = 512x256
- [16, 16, 16, 1, 1, 5, 8, 2, 1 ] # MT = 160x128
- [16, 16, 16, 1, 1, 5, 8, 2,1 ] # MT = 160x128
- AssertFree0ElementMultiple: [16]
- AssertSummationElementMultiple: [32]
- GlobalReadVectorWidthA: [8]
- GlobalReadVectorWidthB: [2,4,8]
- GlobalReadVectorWidthB: [-1]
- PrefetchGlobalRead: [1,2]
- PrefetchLocalRead: [1,2,4]
- ClusterLocalRead: [1]
Expand All @@ -68,7 +72,7 @@ BenchmarkProblems:
- LocalWritePerMfma: [-1]
- StaggerU: [4]
- StaggerUStride: [256]
- StaggerUMapping: [2]
- StaggerUMapping: [0]
- WorkGroupMappingXCC: [8]
- ScheduleIterAlg: [3]
- LdsBlockSizePerPadA: [-1]
Expand All @@ -87,9 +91,18 @@ BenchmarkProblems:
BenchmarkJoinParameters:
BenchmarkFinalParameters:
- ProblemSizes:
- Exact: [160, 256, 1, 224]
- Exact: [160, 256, 1, 256]
- Exact: [1600, 512, 1, 1024]
- Exact: [160, 256, 1, 288]

# Enable the big sizes when GSU3 + MBSK is ready.
# - Exact: [1600, 512, 1, 992]
# - Exact: [1600, 512, 1, 1024]
# - Exact: [1600, 512, 1, 1056]

- Exact: [512, 256, 1, 224]
- Exact: [512, 256, 1, 256]
- Exact: [512, 256, 1, 288]
- BiasTypeArgs: ['h']
- ActivationArgs:
- [Enum: none]
Expand Down

0 comments on commit 5ca877f

Please sign in to comment.