[cpu] Don't reuse shuffle dummies (#88)
Reusing shuffle dummies results in the following compile-time assertion error (in debug Triton builds):

    Assertion `Index < size() && "invalid index for value range"' failed.

This occurs when the same function contains more than one tt.reduce call and a
later call takes more arguments than an earlier one. Because the dummy values
are reused, the later call ends up with fewer dummy values than expected,
hence the error.

This bug also resulted in type mismatch errors between the reused dummy value
and the current input value.
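To make the failure mode concrete, here is a minimal standalone C++ sketch of the caching pattern at fault. It is illustrative only: the class and method names are made up, and plain std::string stands in for the MLIR values that the real createShuffleDummies helper produces.

// Hypothetical sketch of the bug: a lazily filled, cached dummy list is
// reused across calls, so a later call with more inputs comes up short.
#include <cstdio>
#include <string>
#include <vector>

struct ConversionSketch {
  // Buggy variant: populate the member vector once, then reuse it forever.
  std::vector<std::string> cachedDummies(const std::vector<std::string> &inputs) const {
    if (dummies.empty()) {
      for (const auto &in : inputs)
        dummies.push_back("zero<" + in + ">");
    }
    return dummies; // later calls get whatever the first call produced
  }

  // Fixed variant (the approach the commit switches to): build fresh per call.
  std::vector<std::string> freshDummies(const std::vector<std::string> &inputs) const {
    std::vector<std::string> result;
    for (const auto &in : inputs)
      result.push_back("zero<" + in + ">");
    return result;
  }

  mutable std::vector<std::string> dummies;
};

int main() {
  ConversionSketch conv;
  auto first = conv.cachedDummies({"f32"});         // 1 dummy, as expected
  auto second = conv.cachedDummies({"f32", "i32"}); // still 1 dummy, typed for f32
  auto fixed = conv.freshDummies({"f32", "i32"});   // 2 dummies, one per input
  std::printf("cached: %zu then %zu, fresh: %zu\n", // prints "cached: 1 then 1, fresh: 2"
              first.size(), second.size(), fixed.size());
}

In the actual pass the dummies are zero-constant vectors created through the rewriter, and the undersized, stale list is what triggers the range assertion and the type mismatches described above; the change to ReduceScanCommon.h below builds the vector fresh on every call.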
int3 authored Aug 5, 2024
1 parent 240ada1 commit 58d232e
Showing 5 changed files with 37 additions and 12 deletions.
9 changes: 9 additions & 0 deletions .github/workflows/build-test.yml
@@ -83,3 +83,12 @@ jobs:
            python/test/unit/language/test_annotations.py \
            python/test/unit/language/test_block_pointer.py \
            python/test/unit/language/test_conversions.py
      - name: Run lit tests
        run: |
          cd python
          LIT_TEST_DIR="build/$(ls build | grep -i cmake)/test"
          if [ ! -d "${LIT_TEST_DIR}" ]; then
            echo "Could not find '${LIT_TEST_DIR}'" ; exit -1
          fi
          lit -v "${LIT_TEST_DIR}/TritonCPU"
18 changes: 18 additions & 0 deletions test/TritonCPU/reduction.mlir
@@ -0,0 +1,18 @@
// RUN: triton-opt %s -split-input-file -triton-cpu-convert-reduction -canonicalize

// Regression test: Check that we handle consecutive calls to tt.reduce with
// different types & number of arguments.

module {
  tt.func public @triton_(%arg0: tensor<1x4xf32>, %arg1: tensor<1x4xi32>) {
    %0 = "tt.reduce"(%arg0) <{axis = 1 : i32}> ({
    ^bb0(%arg3: f32, %arg4: f32):
      tt.reduce.return %arg3 : f32
    }) : (tensor<1x4xf32>) -> tensor<1xf32>
    %1:2 = "tt.reduce"(%arg0, %arg1) <{axis = 1 : i32}> ({
    ^bb0(%arg3: f32, %arg4: i32, %arg5: f32, %arg6: i32):
      tt.reduce.return %arg3, %arg4 : f32, i32
    }) : (tensor<1x4xf32>, tensor<1x4xi32>) -> (tensor<1xf32>, tensor<1xi32>)
    tt.return
  }
}
2 changes: 1 addition & 1 deletion third_party/cpu/lib/TritonToTritonCPU/ConvertReduceOp.cpp
@@ -66,7 +66,7 @@ struct ReduceOpConversion
     SmallVector<int64_t> range(vecSize);
     std::iota(range.begin(), range.end(), 0);
 
-    ArrayRef<Value> dummies = createShuffleDummies(loc, inputs, rewriter);
+    SmallVector<Value> dummies = createShuffleDummies(loc, inputs, rewriter);
     SmallVector<Value> res = inputs;
     for (int64_t stride = vecSize / 2; stride > 0; stride = stride / 2) {
       SmallVector<int64_t> shuffleIndices = range;
2 changes: 1 addition & 1 deletion third_party/cpu/lib/TritonToTritonCPU/ConvertScanOp.cpp
@@ -52,7 +52,7 @@ struct ScanOpConversion
     int64_t vecSize = cast<VectorType>(inputs[0].getType()).getShape()[0];
     Type maskTy = VectorType::get(vecSize, rewriter.getI1Type());
 
-    ArrayRef<Value> dummies = createShuffleDummies(loc, inputs, rewriter);
+    SmallVector<Value> dummies = createShuffleDummies(loc, inputs, rewriter);
     SmallVector<Value> res = inputs;
     for (int64_t stride = 1; stride < vecSize; stride *= 2) {
       SmallVector<int64_t> shuffleIndices(vecSize, 0);
18 changes: 8 additions & 10 deletions third_party/cpu/lib/TritonToTritonCPU/ReduceScanCommon.h
@@ -221,24 +221,22 @@ struct ReduceScanOpConversionBase : public OpConversionPattern<OpT> {
 
   // Dummy vectors are required for shuffles that cannot work on a single
   // vector.
-  ArrayRef<Value>
+  SmallVector<Value>
   createShuffleDummies(Location loc, ValueRange inputs,
                        ConversionPatternRewriter &rewriter) const {
-    if (shuffleDummies.empty()) {
-      SmallVector<int64_t, 1> dummyShape({1});
-      for (auto val : inputs) {
-        auto ty = cast<VectorType>(val.getType());
-        shuffleDummies.push_back(rewriter.create<arith::ConstantOp>(
-            loc, rewriter.getZeroAttr(
-                     ty.cloneWith(dummyShape, ty.getElementType()))));
-      }
+    SmallVector<Value> shuffleDummies;
+    SmallVector<int64_t, 1> dummyShape({1});
+    for (auto val : inputs) {
+      auto ty = cast<VectorType>(val.getType());
+      shuffleDummies.push_back(rewriter.create<arith::ConstantOp>(
+          loc,
+          rewriter.getZeroAttr(ty.cloneWith(dummyShape, ty.getElementType()))));
     }
     return shuffleDummies;
   }
 
 private:
   mutable IRMapping invariantsMap;
-  mutable SmallVector<Value> shuffleDummies;
 };
 
 } // namespace cpu
