Skip to content

Commit

Permalink
directly modify totalElementsPerpA
Browse files Browse the repository at this point in the history
  • Loading branch information
solaslin committed Jan 14, 2025
1 parent 5d0320a commit a0ff71e
Showing 1 changed file with 8 additions and 15 deletions.
23 changes: 8 additions & 15 deletions tensilelite/Tensile/SolutionStructs.py
Original file line number Diff line number Diff line change
Expand Up @@ -1458,8 +1458,10 @@ def setGlobalLoadTileDimClassic(state, tc, numLoads, totalVectorsCoalesced, tota
and totalElementsPerp % nlp == 0:
state["NumLoadsCoalesced%s"%tc] = nlc
state["NumLoadsPerpendicular%s"%tc] = nlp
#print("NumLoadsCoalesced",state["NumLoadsCoalesced%s"%tc])
#print("NumLoadsPerpendicular",state["NumLoadsPerpendicular%s"%tc])
# print("NumLoads%s:"%tc,state["NumLoads%s"%tc])
# print("NumLoadsCoalesced%s:"%tc,state["NumLoadsCoalesced%s"%tc])
# print("NumLoadsPerpendicular%s:"%tc,state["NumLoadsPerpendicular%s"%tc])
# print("\n")
foundValid = True
break
if not foundValid:
Expand Down Expand Up @@ -2985,12 +2987,14 @@ def calcOptGRVW(lrvw: int, unrollMajorLDS: bool, datatype: DataType) -> int:
validDepthU = True

# how many elements to load
if state["ProblemType"]["TLUA"]:
if state["ProblemType"]["TLUA"]: # NT/NN
totalElementsCoalescedA = state["MacroTileA"]
totalElementsPerpA = depthUA
else:
else: # TN/TT
totalElementsCoalescedA = depthUA
totalElementsPerpA = state["MacroTileA"]
if state["DirectToVgprA"]:
totalElementsPerpA *= state["MIWaveGroup"][1]

if state["ProblemType"]["TLUB"]:
totalElementsCoalescedB = state["MacroTileB"]
Expand Down Expand Up @@ -3249,17 +3253,6 @@ def calcOptGRVW(lrvw: int, unrollMajorLDS: bool, datatype: DataType) -> int:
if not Solution.isDirectToVgprDoable(state, 'A'):
return # rejected

# This has to be done after isDirectToVgprDoable and setGlobalLoadTileDimClassic are done.
# TODO: this allows MIWG[0] > 1 for DTVA. If pure DTVA causes lots of issues,
# we may consider allowing it for swizzledA only.
waveGroupsAlongN = state["MIWaveGroup"][1]
# When WVG is along N-Dim, each wave needs to load the same part of matA instead of splitting the loads among the waves
state["NumLoadsA"] *= waveGroupsAlongN
state["NumLoadsPerpendicularA"] *= waveGroupsAlongN
if state["ProblemType"]["TLUA"]:
state["LSPA"] = int(math.ceil(float(depthUM) / state["NumLoadsPerpendicularA"]))
else:
state["LSPA"] = state["MacroTileA"] // state["NumLoadsPerpendicularA"]

if state["DirectToVgprB"]:
if not Solution.isDirectToVgprDoable(state, 'B'):
Expand Down

0 comments on commit a0ff71e

Please sign in to comment.