Skip to content

Commit

Permalink
[Pipeliner] Fix loop iteration calculation for negative step (#4786)
Browse files Browse the repository at this point in the history
This fixes the loop iteration count calculation when the step is
negative: the delta added to the range before the ceiling division
must be `step+1` instead of `step-1`.
  • Loading branch information
sjw36 authored Sep 24, 2024
1 parent 6152840 commit 4d711bd
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 23 deletions.
27 changes: 17 additions & 10 deletions lib/Dialect/TritonGPU/Transforms/Pipeliner/PipelineExpander.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -657,18 +657,25 @@ LoopPipelinerInternal::emitEpilogue(RewriterBase &rewriter,
// Emit different versions of the induction variable. They will be
// removed by dead code if not used.

// bounds_range = ub - lb
// total_iterations = (bounds_range + step - 1) / step
// range_diff = ub - lb
// total_iterations = (range_diff + step + (step < 0 ? 1 : -1)) / step
Type t = lb.getType();
Value minus1 =
rewriter.create<arith::ConstantOp>(loc, rewriter.getIntegerAttr(t, -1));
Value boundsRange = rewriter.create<arith::SubIOp>(loc, ub, lb);
Value rangeIncr = rewriter.create<arith::AddIOp>(loc, boundsRange, step);
Value rangeDecr = rewriter.create<arith::AddIOp>(loc, rangeIncr, minus1);
Value totalIterations = rewriter.create<arith::DivUIOp>(loc, rangeDecr, step);

Value zero =
rewriter.create<arith::ConstantOp>(loc, rewriter.getIntegerAttr(t, 0));
Value one =
rewriter.create<arith::ConstantOp>(loc, rewriter.getIntegerAttr(t, 1));
Value minusOne =
rewriter.create<arith::ConstantOp>(loc, rewriter.getIntegerAttr(t, -1));
Value stepLessZero = rewriter.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::slt, step, zero);
Value stepDecr =
rewriter.create<arith::SelectOp>(loc, stepLessZero, one, minusOne);

Value rangeDiff = rewriter.create<arith::SubIOp>(loc, ub, lb);
Value rangeIncrStep = rewriter.create<arith::AddIOp>(loc, rangeDiff, step);
Value rangeDecr =
rewriter.create<arith::AddIOp>(loc, rangeIncrStep, stepDecr);
Value totalIterations = rewriter.create<arith::DivUIOp>(loc, rangeDecr, step);

// Capture predicates for dynamic loops.
SmallVector<Value> predicates(maxStage + 1);
Expand All @@ -679,7 +686,7 @@ LoopPipelinerInternal::emitEpilogue(RewriterBase &rewriter,
Value minusI =
rewriter.create<arith::ConstantOp>(loc, rewriter.getIntegerAttr(t, -i));
Value iterI = rewriter.create<arith::AddIOp>(
loc, rewriter.create<arith::AddIOp>(loc, totalIterations, minus1),
loc, rewriter.create<arith::AddIOp>(loc, totalIterations, minusOne),
minusI);
// newLastIter = lb + step * iterI
Value newlastIter = rewriter.create<arith::AddIOp>(
Expand Down
31 changes: 18 additions & 13 deletions test/TritonGPU/loop-pipeline.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,11 @@
// CHECK: scf.yield {{.*}}, %[[INS_IDX_3]], %[[EXT_IDX_3]], %[[NEXT_A]], %[[NEXT_B]]

// AMD-LABEL: tt.func @matmul_loop
// AMD-DAG: %[[CM1:.*]] = arith.constant -1 : index
// AMD-DAG: %[[C1:.*]] = arith.constant 1 : index
// AMD-DAG: %[[C0:.*]] = arith.constant 0 : index
// AMD: %{{.*}}:6 = scf.for %[[ARG5:.*]] = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%[[ARG6:.*]] = %{{.*}}, %[[ARG7:.*]] = %{{.*}}, %[[ARG8:.*]] = %{{.*}}, %[[ARG9:.*]] = %{{.*}}, %[[ARG10:.*]] = %{{.*}}, %[[ARG11:.*]] = %{{.*}})
// AMD: %[[UB1:.*]] = arith.subi %[[UB:.*]], %arg2 : index
// AMD: %[[FOR:.*]]:6 = scf.for %[[ARG5:.*]] = %[[LB:.*]] to %[[UB1]] step %[[STEP:.*]] iter_args(%[[ARG6:.*]] = %{{.*}}, %[[ARG7:.*]] = %{{.*}}, %[[ARG8:.*]] = %{{.*}}, %[[ARG9:.*]] = %{{.*}}, %[[ARG10:.*]] = %{{.*}}, %[[ARG11:.*]] = %{{.*}})
// AMD: %[[LOCAL_LOAD_32:.*]] = triton_gpu.local_load %[[ARG10]]
// AMD: %[[LOCAL_LOAD_33:.*]] = triton_gpu.local_load %[[ARG11]]
// AMD: %[[MULF_34:.*]] = arith.mulf %[[LOCAL_LOAD_33]], %{{.*}}
Expand All @@ -76,22 +79,24 @@
// AMD: triton_gpu.local_store %[[LOAD_39]], %[[MEMDESC_SUBVIEW_44]]
// AMD: scf.yield %[[ADDPTR_36]], %[[ADDPTR_37]], %[[DOT_35]], %[[SELECT_42]], %[[MEMDESC_SUBVIEW_43]], %[[MEMDESC_SUBVIEW_44]]
// AMD: }
// AMD: %[[SUBI_21:.*]] = arith.subi %{{.*}}, %{{.*}}
// AMD: %[[ADDI_22:.*]] = arith.addi %[[SUBI_21]], %{{.*}}
// AMD: %[[ADDI_23:.*]] = arith.addi %[[ADDI_22]], %{{.*}}-1
// AMD: %[[DIVUI_24:.*]] = arith.divui %[[ADDI_23]], %{{.*}}
// AMD: %[[ADDI_25:.*]] = arith.addi %[[DIVUI_24]], %{{.*}}-1
// AMD: %[[CMPI_26:.*]] = arith.cmpi sge, %[[ADDI_25]], %[[C0]]
// AMD: %[[LOCAL_LOAD_27:.*]] = triton_gpu.local_load %{{.*}}#4
// AMD: %[[LOCAL_LOAD_28:.*]] = triton_gpu.local_load %{{.*}}#5
// AMD: %[[CMPI_21:.*]] = arith.cmpi slt, %[[STEP]], %[[C0]]
// AMD: %[[SELECT_22:.*]] = arith.select %[[CMPI_21]], %[[C1]], %[[CM1]]
// AMD: %[[SUBI_23:.*]] = arith.subi %[[UB]], %[[LB]]
// AMD: %[[ADDI_24:.*]] = arith.addi %[[SUBI_23]], %[[STEP]]
// AMD: %[[ADDI_25:.*]] = arith.addi %[[ADDI_24]], %[[SELECT_22]]
// AMD: %[[DIVUI_26:.*]] = arith.divui %[[ADDI_25]], %[[STEP]]
// AMD: %[[ADDI_27:.*]] = arith.addi %[[DIVUI_26]], %[[CM1]]
// AMD: %[[CMPI_28:.*]] = arith.cmpi sge, %[[ADDI_27]], %[[C0]]
// AMD: %[[LOCAL_LOAD_27:.*]] = triton_gpu.local_load %[[FOR]]#4
// AMD: %[[LOCAL_LOAD_28:.*]] = triton_gpu.local_load %[[FOR]]#5
// AMD: %[[MULF_29:.*]] = arith.mulf %[[LOCAL_LOAD_28]], %{{.*}}
// AMD: %[[IF_30:.*]] = scf.if %[[CMPI_26]]
// AMD: %[[DOT_32:.*]] = tt.dot %[[LOCAL_LOAD_27]], %[[MULF_29]], %{{.*}}#2
// AMD: %[[IF_30:.*]] = scf.if %[[CMPI_28]]
// AMD: %[[DOT_32:.*]] = tt.dot %[[LOCAL_LOAD_27]], %[[MULF_29]], %[[FOR]]#2
// AMD: scf.yield %[[DOT_32]]
// AMD: } else {
// AMD: scf.yield %{{.*}}#2
// AMD: scf.yield %[[FOR]]#2
// AMD: }
// AMD: %[[SELECT_31:.*]] = arith.select %[[CMPI_26]], %[[IF_30]], %{{.*}}#2
// AMD: %[[SELECT_31:.*]] = arith.select %[[CMPI_28]], %[[IF_30]], %[[FOR]]#2
// AMD: triton_gpu.local_dealloc %{{.*}}
// AMD: triton_gpu.local_dealloc %{{.*}}

Expand Down

0 comments on commit 4d711bd

Please sign in to comment.