Skip to content

Commit

Permalink
Enable WORD_0 of 2-nd VGPR with vi=4 for vw=8. (#1333)
Browse files Browse the repository at this point in the history
  • Loading branch information
geotseng-amd authored Nov 18, 2024
1 parent 91ba995 commit 9f30df5
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 10 deletions.
4 changes: 2 additions & 2 deletions clients/gtest/matmul_gtest.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -1593,7 +1593,7 @@ Tests:
algo_method: [0,1]
transA_transB: *transA_transB_range
alpha: 1
beta: 0
beta: [0,1]
requested_solution_num: -1
unit_check: 1

Expand All @@ -1607,7 +1607,7 @@ Tests:
algo_method: [0,1]
transA_transB: *transA_transB_range
alpha: 1
beta: 0
beta: [0,1]
requested_solution_num: -1
unit_check: 1
gpu_arch: '94[0-2]'
Expand Down
20 changes: 12 additions & 8 deletions tensilelite/Tensile/Components/GlobalWriteBatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -2011,7 +2011,7 @@ def _addSumAlphaWithCBeta(self, kernel, ss, gwvw, elementIdx, vc0, tmpVgpr, cvtV
# Generate single f32 code if edge is detected.
isPK = False
if ((vi + 1) == self.gwvw) and ((self.gwvw % 2) == 1):
if self.parentWriter.states.archCaps["NoSDWA"]: #cm review
if self.parentWriter.states.archCaps["NoSDWA"]:
sb = 0 if self.gwvw == 1 else 1
module.add(VCvtFP8toF32(dst=vgpr(tmpVgpr), src=vgpr(dataV), vop3=VOP3PModifiers(op_sel=[0,sb])))
else:
Expand All @@ -2022,11 +2022,13 @@ def _addSumAlphaWithCBeta(self, kernel, ss, gwvw, elementIdx, vc0, tmpVgpr, cvtV
continue
else:
isPK = True
if self.parentWriter.states.archCaps["NoSDWA"]: #cm review
sb = 0 if vi ==0 else 1
if self.parentWriter.states.archCaps["NoSDWA"]:
# Enable WORD_0 of 2-nd VGPR with vi=4 for vw=8
sb = 0 if vi%4 == 0 else 1
module.add(VCvtPkFP8toF32(dst=vgpr(tmpVgpr, 2), src=vgpr(dataV), vop3=VOP3PModifiers(op_sel=[sb])))
else:
sb = SelectBit.WORD_0 if vi == 0 else SelectBit.WORD_1
# Enable WORD_0 of 2-nd VGPR with vi=4 for vw=8
sb = SelectBit.WORD_0 if vi%4 == 0 else SelectBit.WORD_1
module.add(VCvtPkFP8toF32(dst=vgpr(tmpVgpr, 2), src=vgpr(dataV), sdwa=SDWAModifiers(src0_sel=sb)))
module.add(SNop(waitState=0))
if kernel["ProblemType"]["ComputeDataType"].isSingle():
Expand All @@ -2040,7 +2042,7 @@ def _addSumAlphaWithCBeta(self, kernel, ss, gwvw, elementIdx, vc0, tmpVgpr, cvtV
# Generate single f32 code if edge is detected.
isPK = False
if ((vi + 1) == self.gwvw) and ((self.gwvw % 2) == 1):
if self.parentWriter.states.archCaps["NoSDWA"]: #cm review
if self.parentWriter.states.archCaps["NoSDWA"]:
sb = 0 if self.gwvw == 1 else 1
module.add(VCvtFP8toF32(dst=vgpr(tmpVgpr), src=vgpr(dataV), vop3=VOP3PModifiers(op_sel=[0,sb])))
else:
Expand All @@ -2051,11 +2053,13 @@ def _addSumAlphaWithCBeta(self, kernel, ss, gwvw, elementIdx, vc0, tmpVgpr, cvtV
continue
else:
isPK = True
if self.parentWriter.states.archCaps["NoSDWA"]: #cm review
sb = 0 if vi ==0 else 1
if self.parentWriter.states.archCaps["NoSDWA"]:
# Enable WORD_0 of 2-nd VGPR with vi=4 for vw=8
sb = 0 if vi%4 == 0 else 1
module.add(VCvtPkFP8toF32(dst=vgpr(tmpVgpr, 2), src=vgpr(dataV), vop3=VOP3PModifiers(op_sel=[sb])))
else:
sb = SelectBit.WORD_0 if vi == 0 else SelectBit.WORD_1
# Enable WORD_0 of 2-nd VGPR with vi=4 for vw=8
sb = SelectBit.WORD_0 if vi%4 == 0 else SelectBit.WORD_1
module.add(VCvtPkBF8toF32(dst=vgpr(tmpVgpr, 2), src=vgpr(dataV), sdwa=SDWAModifiers(src0_sel=sb)))
module.add(SNop(waitState=0))
if kernel["ProblemType"]["ComputeDataType"].isSingle():
Expand Down

0 comments on commit 9f30df5

Please sign in to comment.