diff --git a/tensilelite/Tensile/TensileInstructions/Math.py b/tensilelite/Tensile/TensileInstructions/Math.py index aac4edf41e..f45a813a88 100644 --- a/tensilelite/Tensile/TensileInstructions/Math.py +++ b/tensilelite/Tensile/TensileInstructions/Math.py @@ -1,6 +1,6 @@ ################################################################################ # -# Copyright (C) 2022-2023 Advanced Micro Devices, Inc. All rights reserved. +# Copyright (C) 2022-2024 Advanced Micro Devices, Inc. All rights reserved. # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -349,6 +349,8 @@ def scalarUInt32DivideAndRemainder(qReg, dReg, divReg, rReg, tmpVgprRes: Registe tmpVgpr0 = tmpVgprRes.idx tmpVgpr1 = tmpVgprRes.idx + 1 + SMovBX = SMovB64 if wavewidth == 64 else SMovB32 + module = Module("scalarUInt32DivideAndRemainder") module.add(VCvtU32toF32(dst=vgpr(tmpVgpr0), src=sgpr(divReg), comment=dComment)) module.add(VRcpIFlagF32(dst=vgpr(tmpVgpr0), src=vgpr(tmpVgpr0), comment=dComment)) @@ -361,8 +363,13 @@ def scalarUInt32DivideAndRemainder(qReg, dReg, divReg, rReg, tmpVgprRes: Registe module.add(VAddU32(dst=vgpr(tmpVgpr0), src0=1, src1=vgpr(tmpVgpr0), comment=dComment)) if doRemainder: module.add(VMovB32(dst=vgpr(tmpVgpr1), src=0, comment=rComment)) - SMovBX = SMovB64 if wavewidth == 64 else SMovB32 - module.add(SMovBX(dst=EXEC(), src=-1, comment=dComment)) + module.add(SMovBX(dst=EXEC(), src=-1, comment="Reset exec")) + module.add(VCmpXGtU32(dst=EXEC(), src0=vgpr(tmpVgpr1), src1=sgpr(divReg), comment="overflow happened in remainder")) + module.add(VSubU32(dst=vgpr(tmpVgpr0), src0=vgpr(tmpVgpr0), src1=1, comment="quotient - 1")) + if doRemainder: + module.add(VMulU32U24(dst=vgpr(tmpVgpr1), src0=vgpr(tmpVgpr0), src1=sgpr(divReg), comment="re-calculate remainder")) + module.add(VSubU32(dst=vgpr(tmpVgpr1), src0=sgpr(dReg), src1=vgpr(tmpVgpr1), comment="re-calculate remainder")) + module.add(SMovBX(dst=EXEC(), src=-1, comment="Reset exec")) module.add(VReadfirstlaneB32(dst=sgpr(qReg), src=vgpr(tmpVgpr0), comment="quotient")) if doRemainder: module.add(VReadfirstlaneB32(dst=sgpr(rReg), src=vgpr(tmpVgpr1), comment="remainder"))