[NFC] Remove dead code related to IndexCastOp (#5596)
IndexCast shouldn't exist at the TTIR or TTGIR level.
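
In practice this removes the index round-trip that used to surround scf.for loops in the IR: loop bounds were materialized as index-typed constants, and the induction variable was immediately cast back to i32 inside the body. A minimal before/after sketch of the pattern (illustrative only; it mirrors the test updates below):

    // Before: index-typed bounds, plus a cast on every use of the induction variable.
    %lb = arith.constant 0 : index
    %ub = arith.constant 128 : index
    %step = arith.constant 16 : index
    scf.for %iv = %lb to %ub step %step {
      %i = arith.index_cast %iv : index to i32
    }

    // After: i32 bounds throughout, so the cast disappears and %iv is usable directly.
    %lb = arith.constant 0 : i32
    %ub = arith.constant 128 : i32
    %step = arith.constant 16 : i32
    scf.for %iv = %lb to %ub step %step : i32 {
    }

This is also why the alignment tests below now track divisibility through an arith.addi on the i32 induction variable rather than through arith.index_cast.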
ThomasRaoux authored Jan 13, 2025
1 parent 6aa2df9 commit 8f6e9d2
Showing 9 changed files with 37 additions and 86 deletions.
1 change: 0 additions & 1 deletion lib/Analysis/AxisInfo.cpp
@@ -1010,7 +1010,6 @@ AxisInfoAnalysis::AxisInfoAnalysis(DataFlowSolver &solver)
visitors.append<CastOpAxisInfoVisitor<arith::ExtSIOp>,
CastOpAxisInfoVisitor<arith::ExtUIOp>,
CastOpAxisInfoVisitor<arith::TruncIOp>,
- CastOpAxisInfoVisitor<arith::IndexCastOp>,
CastOpAxisInfoVisitor<triton::gpu::ConvertLayoutOp>,
CastOpAxisInfoVisitor<mlir::UnrealizedConversionCastOp>,
CastOpAxisInfoVisitor<triton::BitcastOp>>();
30 changes: 0 additions & 30 deletions lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp
@@ -458,35 +458,6 @@ struct AbsFOpConversion
return {rewriter.create<LLVM::FAbsOp>(loc, elemTy, operands[0][0])};
}
};
- /// The lowering of index_cast becomes an integer conversion since index
- /// becomes an integer. If the bit width of the source and target integer
- /// types is the same, just erase the cast. If the target type is wider,
- /// sign-extend the value, otherwise truncate it.
- struct IndexCastOpLowering
-     : public ElementwiseOpConversionBase<arith::IndexCastOp,
-                                          IndexCastOpLowering> {
-   using Base =
-       ElementwiseOpConversionBase<arith::IndexCastOp, IndexCastOpLowering>;
-   using Base::Base;
-   using Adaptor = typename Base::OpAdaptor;
-
-   SmallVector<Value> createDestOps(arith::IndexCastOp op, OpAdaptor adaptor,
-                                    ConversionPatternRewriter &rewriter,
-                                    Type elemTy, MultipleOperandsRange operands,
-                                    Location loc) const {
-     auto inElemTy =
-         this->getTypeConverter()->convertType(getElementType(op.getIn()));
-     unsigned targetBits = elemTy.getIntOrFloatBitWidth();
-     unsigned sourceBits = inElemTy.getIntOrFloatBitWidth();
-
-     if (targetBits == sourceBits)
-       return {operands[0][0]};
-     if (targetBits < sourceBits)
-       return {
-           rewriter.create<LLVM::TruncOp>(op.getLoc(), elemTy, operands[0][0])};
-     return {rewriter.create<LLVM::SExtOp>(op.getLoc(), elemTy, operands[0][0])};
-   }
- };

struct SelectOpConversion
: ElementwiseOpConversionBase<arith::SelectOp, SelectOpConversion> {
@@ -705,6 +676,5 @@ void mlir::triton::populateElementwiseOpToLLVMPatterns(
patterns.add<ElementwiseInlineAsmOpConversion>(typeConverter, benefit);
patterns.add<AbsIOpConversion>(typeConverter, axisInfoAnalysis, benefit);
patterns.add<AbsFOpConversion>(typeConverter, axisInfoAnalysis, benefit);
- patterns.add<IndexCastOpLowering>(typeConverter, axisInfoAnalysis, benefit);
patterns.add<SelectOpConversion>(typeConverter, axisInfoAnalysis, benefit);
}
12 changes: 1 addition & 11 deletions python/src/ir.cc
@@ -1,4 +1,4 @@
- #include <optional>
+ #include <optional>
#include <pybind11/functional.h>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
@@ -1033,16 +1033,6 @@ void init_triton_ir(py::module &&m) {
else
return self.create<arith::ExtUIOp>(dstType, src);
})
-       .def("create_to_index",
-            [](TritonOpBuilder &self, Value &input) -> Value {
-              return self.create<arith::IndexCastOp>(
-                  self.getBuilder().getIndexType(), input);
-            })
-       .def("create_index_to_si",
-            [](TritonOpBuilder &self, Value &input) -> Value {
-              return self.create<arith::IndexCastOp>(
-                  self.getBuilder().getI64Type(), input);
-            })
.def("create_fmul",
[](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value {
return self.create<arith::MulFOp>(lhs, rhs);
18 changes: 10 additions & 8 deletions test/Analysis/test-alignment.mlir
@@ -473,14 +473,14 @@ tt.func @for() {
// CHECK-NEXT: contiguity = [1, 1], divisibility = [4, 4], constancy = [128, 32], constant_value = 4
%c_init = arith.constant dense<4> : tensor<128x32xi32>
// CHECK-NEXT: contiguity = [1], divisibility = [128], constancy = [1], constant_value = 128
- %ub = arith.constant 128 : index
+ %ub = arith.constant 128 : i32
// CHECK-NEXT: contiguity = [1], divisibility = [4611686018427387904], constancy = [1], constant_value = 0
- %lb = arith.constant 0 : index
+ %lb = arith.constant 0 : i32
// CHECK-NEXT: contiguity = [1], divisibility = [16], constancy = [1], constant_value = 16
- %step = arith.constant 16 : index
- %a, %b, %c = scf.for %iv = %lb to %ub step %step iter_args(%a = %a_init, %b = %b_init, %c = %c_init) -> (tensor<128x32xi32>, tensor<128x32xi32>, tensor<128x32xi32>) {
+ %step = arith.constant 16 : i32
+ %a, %b, %c = scf.for %iv = %lb to %ub step %step iter_args(%a = %a_init, %b = %b_init, %c = %c_init) -> (tensor<128x32xi32>, tensor<128x32xi32>, tensor<128x32xi32>) : i32 {
// CHECK-NEXT: contiguity = [1], divisibility = [16], constancy = [1], constant_value = <none>
- %t = arith.index_cast %iv : index to i32
+ %t = arith.addi %iv, %lb : i32
// CHECK: contiguity = [1, 1], divisibility = [1, 1], constancy = [128, 32], constant_value = <none>
// CHECK: contiguity = [1, 1], divisibility = [1, 1], constancy = [128, 32], constant_value = <none>
// CHECK: contiguity = [1, 1], divisibility = [4, 4], constancy = [128, 32], constant_value = 4
@@ -492,10 +492,12 @@ tt.func @for() {
// -----

// CHECK-LABEL: @for_dynamic
- tt.func @for_dynamic(%lb: index {tt.divisibility = 16 : i32}, %step: index {tt.divisibility = 8 : i32}, %ub: index) {
- scf.for %iv = %lb to %ub step %step {
+ tt.func @for_dynamic(%lb: i32 {tt.divisibility = 16 : i32}, %step: i32 {tt.divisibility = 8 : i32}, %ub: i32) {
+ // CHECK-NEXT: contiguity = [1], divisibility = [4611686018427387904], constancy = [1], constant_value = 0
+ %c0 = arith.constant 0 : i32
+ scf.for %iv = %lb to %ub step %step : i32 {
// CHECK-NEXT: contiguity = [1], divisibility = [8], constancy = [1], constant_value = <none>
- %t = arith.index_cast %iv : index to i32
+ %t = arith.addi %iv, %c0 : i32
}
tt.return
}
5 changes: 1 addition & 4 deletions test/Triton/vecadd.mlir
@@ -18,10 +18,7 @@ module {
%11 = tt.splat %cst : f32 -> tensor<256xf32>
%c0_i32 = arith.constant 0 : i32
%c32_i32 = arith.constant 32 : i32
- %12 = arith.index_cast %c0_i32 : i32 to index
- %13 = arith.index_cast %arg4 : i32 to index
- %14 = arith.index_cast %c32_i32 : i32 to index
- %15:3 = scf.for %arg6 = %12 to %13 step %14 iter_args(%arg7 = %11, %arg8 = %8, %arg9 = %10) -> (tensor<256xf32>, tensor<256x!tt.ptr<f32>>, tensor<256x!tt.ptr<f32>>) {
+ %15:3 = scf.for %arg6 = %c0_i32 to %arg4 step %c32_i32 iter_args(%arg7 = %11, %arg8 = %8, %arg9 = %10) -> (tensor<256xf32>, tensor<256x!tt.ptr<f32>>, tensor<256x!tt.ptr<f32>>) : i32 {
%cst_0 = arith.constant 0.000000e+00 : f32
%18 = tt.splat %cst_0 : f32 -> tensor<256xf32>
%19 = tt.load %arg8, %6, %18 : tensor<256x!tt.ptr<f32>>
8 changes: 3 additions & 5 deletions test/TritonGPU/amd/amd-convert-buffer-ops.mlir
@@ -344,23 +344,21 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
#blocked = #ttg.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
// CHECK-LABEL: unsigned_ops
- tt.func @unsigned_ops(%arg0: !tt.ptr<bf16> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg1: !tt.ptr<bf16> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg2 : i32, %arg3 : i32, %arg4 : f32, %arg5 : index) {
+ tt.func @unsigned_ops(%arg0: !tt.ptr<bf16> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg1: !tt.ptr<bf16> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg2 : i32, %arg3 : i32, %arg4 : f32) {
%c5_i32 = arith.constant 5 : i32
%0 = arith.ceildivui %arg2, %c5_i32 : i32
%1 = arith.divui %arg3, %c5_i32 : i32
%2 = arith.fptoui %arg4 : f32 to i32
- %3 = arith.index_castui %arg5 : index to i32
%4 = arith.maxui %arg2, %arg3 : i32
%5 = arith.minui %arg2, %arg3 : i32
%6 = arith.remui %arg2, %c5_i32 : i32
%7 = arith.shrui %arg3, %c5_i32 : i32
%8 = arith.addi %0, %1 : i32
- %9 = arith.addi %2, %3 : i32
%10 = arith.addi %4, %5 : i32
%11 = arith.addi %6, %7 : i32
- %12 = arith.addi %8, %9 : i32
+ %12 = arith.addi %8, %2 : i32
%13 = arith.addi %10, %11 : i32
- %14 = arith.addi %12, %13 : i32
+ %14 = arith.addi %8, %13 : i32
%15 = tt.splat %14 : i32 -> tensor<8xi32, #blocked>
%16 = tt.make_range {end = 8 : i32, start = 0 : i32} : tensor<8xi32, #blocked>
%17 = arith.addi %15, %16 : tensor<8xi32, #blocked>
38 changes: 17 additions & 21 deletions test/TritonGPU/combine.mlir
@@ -453,9 +453,9 @@ tt.func @select(%arg0: !tt.ptr<f64> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr
// CHECK-NOT: ttg.convert_layout
%cst = arith.constant dense<30000> : tensor<1x1xi32, #blocked2>
%cst_0 = arith.constant dense<30000> : tensor<1x512xi32, #blocked2>
- %c512 = arith.constant 512 : index
- %c30000 = arith.constant 30000 : index
- %c0 = arith.constant 0 : index
+ %c512 = arith.constant 512 : i32
+ %c30000 = arith.constant 30000 : i32
+ %c0 = arith.constant 0 : i32
%cst_1 = arith.constant dense<2048> : tensor<1x1xi32, #blocked2>
%cst_2 = arith.constant dense<0.000000e+00> : tensor<1x512xf64, #blocked2>
%0 = tt.get_program_id x : i32
@@ -473,9 +473,8 @@ tt.func @select(%arg0: !tt.ptr<f64> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr
%12 = tt.broadcast %11 : tensor<1x1xi32, #blocked2> -> tensor<1x512xi32, #blocked2>
%13 = tt.splat %arg0 : !tt.ptr<f64> -> tensor<1x512x!tt.ptr<f64>, #blocked2>
%14 = tt.broadcast %7 : tensor<1x1xi1, #blocked2> -> tensor<1x512xi1, #blocked2>
- %15 = scf.for %arg3 = %c0 to %c30000 step %c512 iter_args(%arg4 = %cst_2) -> (tensor<1x512xf64, #blocked2>) {
- %16 = arith.index_cast %arg3 : index to i32
- %17 = tt.splat %16 : i32 -> tensor<1x512xi32, #blocked2>
+ %15 = scf.for %arg3 = %c0 to %c30000 step %c512 iter_args(%arg4 = %cst_2) -> (tensor<1x512xf64, #blocked2>) : i32 {
+ %17 = tt.splat %arg3 : i32 -> tensor<1x512xi32, #blocked2>
%18 = arith.addi %17, %10 : tensor<1x512xi32, #blocked2>
%19 = arith.cmpi "slt", %18, %cst_0 : tensor<1x512xi32, #blocked2>
%20 = arith.addi %18, %12 : tensor<1x512xi32, #blocked2>
@@ -999,9 +998,9 @@ tt.func public @mnist(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !
// CHECK-LABEL: cmp
module attributes {"ttg.num-warps" = 8 : i32, "ttg.num-ctas" = 1 : i32} {
tt.func public @cmp(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg3: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}) {
- %c64 = arith.constant 64 : index
- %c2048 = arith.constant 2048 : index
- %c0 = arith.constant 0 : index
+ %c64 = arith.constant 64 : i32
+ %c2048 = arith.constant 2048 : i32
+ %c0 = arith.constant 0 : i32
%c64_i32 = arith.constant 64 : i32
%cst = arith.constant dense<-3.40282347E+38> : tensor<64x64xf32, #blocked2>
%cst_0 = arith.constant dense<4194304> : tensor<64x1xi32, #blocked2>
@@ -1036,9 +1035,8 @@ tt.func public @cmp(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: !tt
%22 = arith.muli %21, %cst_0 : tensor<64x1xi32, #blocked2>
%23 = tt.broadcast %22 : tensor<64x1xi32, #blocked2> -> tensor<64x64xi32, #blocked2>
%24 = tt.splat %arg1 : !tt.ptr<f32> -> tensor<64x64x!tt.ptr<f32>, #blocked2>
- %25 = scf.for %arg6 = %c0 to %c2048 step %c64 iter_args(%arg7 = %14) -> (tensor<64x64xf32, #blocked2>) {
- %44 = arith.index_cast %arg6 : index to i32
- %45 = tt.splat %44 : i32 -> tensor<1x64xi32, #blocked3>
+ %25 = scf.for %arg6 = %c0 to %c2048 step %c64 iter_args(%arg7 = %14) -> (tensor<64x64xf32, #blocked2>) : i32 {
+ %45 = tt.splat %arg6 : i32 -> tensor<1x64xi32, #blocked3>
%46 = arith.addi %45, %10 : tensor<1x64xi32, #blocked3>
%47 = arith.cmpi "slt", %46, %cst_2 : tensor<1x64xi32, #blocked3>
%48 = tt.broadcast %46 : tensor<1x64xi32, #blocked3> -> tensor<64x64xi32, #blocked3>
@@ -1092,9 +1090,8 @@ tt.func public @cmp(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: !tt
%41 = tt.broadcast %30 : tensor<64x1xf32, #blocked2> -> tensor<64x64xf32, #blocked2>
%42 = tt.splat %arg2 : !tt.ptr<f32> -> tensor<64x64x!tt.ptr<f32>, #blocked2>
%43 = tt.splat %arg3 : !tt.ptr<f16> -> tensor<64x64x!tt.ptr<f16>, #blocked2>
- scf.for %arg6 = %c0 to %c2048 step %c64 {
- %44 = arith.index_cast %arg6 : index to i32
- %45 = tt.splat %44 : i32 -> tensor<1x64xi32, #blocked3>
+ scf.for %arg6 = %c0 to %c2048 step %c64 : i32 {
+ %45 = tt.splat %arg6 : i32 -> tensor<1x64xi32, #blocked3>
%46 = arith.addi %45, %10 : tensor<1x64xi32, #blocked3>
%47 = arith.cmpi "slt", %46, %cst_2 : tensor<1x64xi32, #blocked3>
%48 = tt.broadcast %46 : tensor<1x64xi32, #blocked3> -> tensor<64x64xi32, #blocked3>
@@ -1226,9 +1223,9 @@ module attributes {"ttg.num-warps" = 2 : i32, "ttg.num-ctas" = 1 : i32} {
module attributes {"ttg.num-warps" = 4 : i32, "ttg.num-ctas" = 1 : i32} {
tt.func public @reduce_cvt2(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}) {
%cst = arith.constant dense<0.000000e+00> : tensor<1x256xf32, #blocked>
- %c3136_i32 = arith.constant 3136 : index
- %c256_i32 = arith.constant 256 : index
- %c0_i32 = arith.constant 0 : index
+ %c3136_i32 = arith.constant 3136 : i32
+ %c256_i32 = arith.constant 256 : i32
+ %c0_i32 = arith.constant 0 : i32
%cst_0 = arith.constant dense<3.136000e+03> : tensor<1x1xf32, #blocked>
%cst_1 = arith.constant dense<50176> : tensor<1x256xi32, #blocked>
%cst_2 = arith.constant dense<196> : tensor<1x1xi32, #blocked>
@@ -1250,9 +1247,8 @@ module attributes {"ttg.num-warps" = 4 : i32, "ttg.num-ctas" = 1 : i32} {
%12 = tt.broadcast %11 : tensor<1x1xi32, #blocked> -> tensor<1x256xi32, #blocked>
%13 = tt.splat %arg1 : !tt.ptr<f32> -> tensor<1x256x!tt.ptr<f32>, #blocked>
%14 = tt.broadcast %7 : tensor<1x1xi1, #blocked> -> tensor<1x256xi1, #blocked>
- %15 = scf.for %arg5 = %c0_i32 to %c3136_i32 step %c256_i32 iter_args(%arg6 = %cst) -> (tensor<1x256xf32, #blocked>) {
- %42 = arith.index_cast %arg5 : index to i32
- %43 = tt.splat %42 : i32 -> tensor<1x256xi32, #blocked>
+ %15 = scf.for %arg5 = %c0_i32 to %c3136_i32 step %c256_i32 iter_args(%arg6 = %cst) -> (tensor<1x256xf32, #blocked>) : i32 {
+ %43 = tt.splat %arg5 : i32 -> tensor<1x256xi32, #blocked>
%44 = arith.addi %43, %10 : tensor<1x256xi32, #blocked>
%45 = arith.cmpi "slt", %44, %cst_4 : tensor<1x256xi32, #blocked>
%46 = arith.remsi %44, %cst_3 : tensor<1x256xi32, #blocked>
7 changes: 3 additions & 4 deletions test/TritonGPU/matmul.mlir
@@ -6,8 +6,8 @@
module {
tt.func @matmul_kernel__Pfp32_Pfp32_Pfp32_i32_i32_i32_i32_i32_i32_i32_i32_i32__12c64_13c64_14c64_15c8(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg3: i32, %arg4: i32, %arg5: i32, %arg6: i32 {tt.divisibility = 16 : i32}, %arg7: i32, %arg8: i32 {tt.divisibility = 16 : i32}, %arg9: i32, %arg10: i32 {tt.divisibility = 16 : i32}, %arg11: i32) {
%cst = arith.constant dense<true> : tensor<64x64xi1>
- %c64 = arith.constant 64 : index
- %c0 = arith.constant 0 : index
+ %c64 = arith.constant 64 : i32
+ %c0 = arith.constant 0 : i32
%cst_0 = arith.constant dense<0.000000e+00> : tensor<64x64xf32>
%c64_i32 = arith.constant 64 : i32
%c63_i32 = arith.constant 63 : i32
@@ -58,8 +58,7 @@ tt.func @matmul_kernel__Pfp32_Pfp32_Pfp32_i32_i32_i32_i32_i32_i32_i32_i32_i32__1
%43 = arith.addi %41, %42 : tensor<64x64xi32>
%44 = tt.splat %arg1 : !tt.ptr<f32> -> tensor<64x64x!tt.ptr<f32>>
%45 = tt.addptr %44, %43 : tensor<64x64x!tt.ptr<f32>>, tensor<64x64xi32>
- %46 = arith.index_cast %arg5 : i32 to index
- %47:3 = scf.for %arg12 = %c0 to %46 step %c64 iter_args(%arg13 = %cst_0, %arg14 = %34, %arg15 = %45) -> (tensor<64x64xf32>, tensor<64x64x!tt.ptr<f32>>, tensor<64x64x!tt.ptr<f32>>) {
+ %47:3 = scf.for %arg12 = %c0 to %arg5 step %c64 iter_args(%arg13 = %cst_0, %arg14 = %34, %arg15 = %45) -> (tensor<64x64xf32>, tensor<64x64x!tt.ptr<f32>>, tensor<64x64x!tt.ptr<f32>>) : i32 {
%76 = tt.load %arg14, %cst, %cst_0 : tensor<64x64x!tt.ptr<f32>>
%77 = tt.load %arg15, %cst, %cst_0 : tensor<64x64x!tt.ptr<f32>>
%78 = tt.dot %76, %77, %cst_0 : tensor<64x64xf32> * tensor<64x64xf32> -> tensor<64x64xf32>
@@ -151,8 +151,8 @@ bool verifyNonNegativeExpr(Value expr, const DenseSet<Value> &assumptions) {
.Case<triton::PtrToIntOp, triton::BitcastOp>(
[&](auto) { return false; })
.Case<arith::CeilDivUIOp, arith::DivUIOp, arith::ExtUIOp,
- arith::FPToUIOp, arith::IndexCastUIOp, arith::MaxUIOp,
- arith::MinUIOp, arith::RemUIOp, arith::ShRUIOp>(
+ arith::FPToUIOp, arith::MaxUIOp, arith::MinUIOp, arith::RemUIOp,
+ arith::ShRUIOp>(
// These OPs also return unsigned values.
// TODO: We can also sniff whether a Value is unsigned by looking
// for whether or not it's used as an argument to one of
