diff --git a/lib/Dialect/TritonGPU/Transforms/Pipeliner/PipelineExpander.cpp b/lib/Dialect/TritonGPU/Transforms/Pipeliner/PipelineExpander.cpp index 24fb9fb6b..81b267457 100644 --- a/lib/Dialect/TritonGPU/Transforms/Pipeliner/PipelineExpander.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Pipeliner/PipelineExpander.cpp @@ -657,18 +657,25 @@ LoopPipelinerInternal::emitEpilogue(RewriterBase &rewriter, // Emit different versions of the induction variable. They will be // removed by dead code if not used. - // bounds_range = ub - lb - // total_iterations = (bounds_range + step - 1) / step + // range_diff = ub - lb + // total_iterations = (range_diff + step + (step < 0 ? 1 : -1)) / step Type t = lb.getType(); - Value minus1 = - rewriter.create(loc, rewriter.getIntegerAttr(t, -1)); - Value boundsRange = rewriter.create(loc, ub, lb); - Value rangeIncr = rewriter.create(loc, boundsRange, step); - Value rangeDecr = rewriter.create(loc, rangeIncr, minus1); - Value totalIterations = rewriter.create(loc, rangeDecr, step); - Value zero = rewriter.create(loc, rewriter.getIntegerAttr(t, 0)); + Value one = + rewriter.create(loc, rewriter.getIntegerAttr(t, 1)); + Value minusOne = + rewriter.create(loc, rewriter.getIntegerAttr(t, -1)); + Value stepLessZero = rewriter.create( + loc, arith::CmpIPredicate::slt, step, zero); + Value stepDecr = + rewriter.create(loc, stepLessZero, one, minusOne); + + Value rangeDiff = rewriter.create(loc, ub, lb); + Value rangeIncrStep = rewriter.create(loc, rangeDiff, step); + Value rangeDecr = + rewriter.create(loc, rangeIncrStep, stepDecr); + Value totalIterations = rewriter.create(loc, rangeDecr, step); // Capture predicates for dynamic loops. SmallVector predicates(maxStage + 1); @@ -679,7 +686,7 @@ LoopPipelinerInternal::emitEpilogue(RewriterBase &rewriter, Value minusI = rewriter.create(loc, rewriter.getIntegerAttr(t, -i)); Value iterI = rewriter.create( - loc, rewriter.create(loc, totalIterations, minus1), + loc, rewriter.create(loc, totalIterations, minusOne), minusI); // newLastIter = lb + step * iterI Value newlastIter = rewriter.create( diff --git a/test/TritonGPU/loop-pipeline.mlir b/test/TritonGPU/loop-pipeline.mlir index 4e8b55b95..81cb2d9a0 100644 --- a/test/TritonGPU/loop-pipeline.mlir +++ b/test/TritonGPU/loop-pipeline.mlir @@ -57,8 +57,11 @@ // CHECK: scf.yield {{.*}}, %[[INS_IDX_3]], %[[EXT_IDX_3]], %[[NEXT_A]], %[[NEXT_B]] // AMD-LABEL: tt.func @matmul_loop +// AMD-DAG: %[[CM1:.*]] = arith.constant -1 : index +// AMD-DAG: %[[C1:.*]] = arith.constant 1 : index // AMD-DAG: %[[C0:.*]] = arith.constant 0 : index -// AMD: %{{.*}}:6 = scf.for %[[ARG5:.*]] = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%[[ARG6:.*]] = %{{.*}}, %[[ARG7:.*]] = %{{.*}}, %[[ARG8:.*]] = %{{.*}}, %[[ARG9:.*]] = %{{.*}}, %[[ARG10:.*]] = %{{.*}}, %[[ARG11:.*]] = %{{.*}}) +// AMD: %[[UB1:.*]] = arith.subi %[[UB:.*]], %arg2 : index +// AMD: %[[FOR:.*]]:6 = scf.for %[[ARG5:.*]] = %[[LB:.*]] to %[[UB1]] step %[[STEP:.*]] iter_args(%[[ARG6:.*]] = %{{.*}}, %[[ARG7:.*]] = %{{.*}}, %[[ARG8:.*]] = %{{.*}}, %[[ARG9:.*]] = %{{.*}}, %[[ARG10:.*]] = %{{.*}}, %[[ARG11:.*]] = %{{.*}}) // AMD: %[[LOCAL_LOAD_32:.*]] = triton_gpu.local_load %[[ARG10]] // AMD: %[[LOCAL_LOAD_33:.*]] = triton_gpu.local_load %[[ARG11]] // AMD: %[[MULF_34:.*]] = arith.mulf %[[LOCAL_LOAD_33]], %{{.*}} @@ -76,22 +79,24 @@ // AMD: triton_gpu.local_store %[[LOAD_39]], %[[MEMDESC_SUBVIEW_44]] // AMD: scf.yield %[[ADDPTR_36]], %[[ADDPTR_37]], %[[DOT_35]], %[[SELECT_42]], %[[MEMDESC_SUBVIEW_43]], %[[MEMDESC_SUBVIEW_44]] // AMD: } -// AMD: %[[SUBI_21:.*]] = arith.subi %{{.*}}, %{{.*}} -// AMD: %[[ADDI_22:.*]] = arith.addi %[[SUBI_21]], %{{.*}} -// AMD: %[[ADDI_23:.*]] = arith.addi %[[ADDI_22]], %{{.*}}-1 -// AMD: %[[DIVUI_24:.*]] = arith.divui %[[ADDI_23]], %{{.*}} -// AMD: %[[ADDI_25:.*]] = arith.addi %[[DIVUI_24]], %{{.*}}-1 -// AMD: %[[CMPI_26:.*]] = arith.cmpi sge, %[[ADDI_25]], %[[C0]] -// AMD: %[[LOCAL_LOAD_27:.*]] = triton_gpu.local_load %{{.*}}#4 -// AMD: %[[LOCAL_LOAD_28:.*]] = triton_gpu.local_load %{{.*}}#5 +// AMD: %[[CMPI_21:.*]] = arith.cmpi slt, %[[STEP]], %[[C0]] +// AMD: %[[SELECT_22:.*]] = arith.select %[[CMPI_21]], %[[C1]], %[[CM1]] +// AMD: %[[SUBI_23:.*]] = arith.subi %[[UB]], %[[LB]] +// AMD: %[[ADDI_24:.*]] = arith.addi %[[SUBI_23]], %[[STEP]] +// AMD: %[[ADDI_25:.*]] = arith.addi %[[ADDI_24]], %[[SELECT_22]] +// AMD: %[[DIVUI_26:.*]] = arith.divui %[[ADDI_25]], %[[STEP]] +// AMD: %[[ADDI_27:.*]] = arith.addi %[[DIVUI_26]], %[[CM1]] +// AMD: %[[CMPI_28:.*]] = arith.cmpi sge, %[[ADDI_27]], %[[C0]] +// AMD: %[[LOCAL_LOAD_27:.*]] = triton_gpu.local_load %[[FOR]]#4 +// AMD: %[[LOCAL_LOAD_28:.*]] = triton_gpu.local_load %[[FOR]]#5 // AMD: %[[MULF_29:.*]] = arith.mulf %[[LOCAL_LOAD_28]], %{{.*}} -// AMD: %[[IF_30:.*]] = scf.if %[[CMPI_26]] -// AMD: %[[DOT_32:.*]] = tt.dot %[[LOCAL_LOAD_27]], %[[MULF_29]], %{{.*}}#2 +// AMD: %[[IF_30:.*]] = scf.if %[[CMPI_28]] +// AMD: %[[DOT_32:.*]] = tt.dot %[[LOCAL_LOAD_27]], %[[MULF_29]], %[[FOR]]#2 // AMD: scf.yield %[[DOT_32]] // AMD: } else { -// AMD: scf.yield %{{.*}}#2 +// AMD: scf.yield %[[FOR]]#2 // AMD: } -// AMD: %[[SELECT_31:.*]] = arith.select %[[CMPI_26]], %[[IF_30]], %{{.*}}#2 +// AMD: %[[SELECT_31:.*]] = arith.select %[[CMPI_28]], %[[IF_30]], %[[FOR]]#2 // AMD: triton_gpu.local_dealloc %{{.*}} // AMD: triton_gpu.local_dealloc %{{.*}}