diff --git a/lib/Analysis/AxisInfo.cpp b/lib/Analysis/AxisInfo.cpp
index fc6a2c73befc..004275c1cfb3 100644
--- a/lib/Analysis/AxisInfo.cpp
+++ b/lib/Analysis/AxisInfo.cpp
@@ -1010,7 +1010,6 @@ AxisInfoAnalysis::AxisInfoAnalysis(DataFlowSolver &solver)
   visitors.append<CastOpAxisInfoVisitor<arith::ExtSIOp>,
                   CastOpAxisInfoVisitor<arith::ExtUIOp>,
                   CastOpAxisInfoVisitor<arith::TruncIOp>,
-                  CastOpAxisInfoVisitor<arith::IndexCastOp>,
                   CastOpAxisInfoVisitor<triton::gpu::ConvertLayoutOp>,
                   CastOpAxisInfoVisitor<mlir::UnrealizedConversionCastOp>,
                   CastOpAxisInfoVisitor<triton::BitcastOp>>();
diff --git a/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp
index 5869ab36f08b..335912c778fc 100644
--- a/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp
+++ b/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp
@@ -458,35 +458,6 @@ struct AbsFOpConversion
     return {rewriter.create<LLVM::FAbsOp>(loc, elemTy, operands[0][0])};
   }
 };
-/// The lowering of index_cast becomes an integer conversion since index
-/// becomes an integer. If the bit width of the source and target integer
-/// types is the same, just erase the cast. If the target type is wider,
-/// sign-extend the value; otherwise truncate it.
-struct IndexCastOpLowering
-    : public ElementwiseOpConversionBase<arith::IndexCastOp,
-                                         IndexCastOpLowering> {
-  using Base =
-      ElementwiseOpConversionBase<arith::IndexCastOp, IndexCastOpLowering>;
-  using Base::Base;
-  using Adaptor = typename Base::OpAdaptor;
-
-  SmallVector<Value> createDestOps(arith::IndexCastOp op, OpAdaptor adaptor,
-                                   ConversionPatternRewriter &rewriter,
-                                   Type elemTy, MultipleOperandsRange operands,
-                                   Location loc) const {
-    auto inElemTy =
-        this->getTypeConverter()->convertType(getElementType(op.getIn()));
-    unsigned targetBits = elemTy.getIntOrFloatBitWidth();
-    unsigned sourceBits = inElemTy.getIntOrFloatBitWidth();
-
-    if (targetBits == sourceBits)
-      return {operands[0][0]};
-    if (targetBits < sourceBits)
-      return {
-          rewriter.create<LLVM::TruncOp>(op.getLoc(), elemTy, operands[0][0])};
-    return {rewriter.create<LLVM::SExtOp>(op.getLoc(), elemTy, operands[0][0])};
-  }
-};
 
 struct SelectOpConversion
     : ElementwiseOpConversionBase<arith::SelectOp, SelectOpConversion> {
@@ -705,6 +676,5 @@ void mlir::triton::populateElementwiseOpToLLVMPatterns(
   patterns.add<ElementwiseInlineAsmOpConversion>(typeConverter, benefit);
   patterns.add<AbsIOpConversion>(typeConverter, axisInfoAnalysis, benefit);
   patterns.add<AbsFOpConversion>(typeConverter, axisInfoAnalysis, benefit);
-  patterns.add<IndexCastOpLowering>(typeConverter, axisInfoAnalysis, benefit);
   patterns.add<SelectOpConversion>(typeConverter, axisInfoAnalysis, benefit);
 }
diff --git a/python/src/ir.cc b/python/src/ir.cc
index 6c31946d664a..9f6e2bd56496 100644
--- a/python/src/ir.cc
+++ b/python/src/ir.cc
@@ -1,4 +1,4 @@
-#include <...>
+#include <...>
 #include <...>
 #include <...>
 #include <...>
@@ -1033,16 +1033,6 @@ void init_triton_ir(py::module &&m) {
             else
               return self.create<arith::ExtUIOp>(dstType, src);
           })
-      .def("create_to_index",
-           [](TritonOpBuilder &self, Value &input) -> Value {
-             return self.create<arith::IndexCastOp>(
-                 self.getBuilder().getIndexType(), input);
-           })
-      .def("create_index_to_si",
-           [](TritonOpBuilder &self, Value &input) -> Value {
-             return self.create<arith::IndexCastOp>(
-                 self.getBuilder().getI64Type(), input);
-           })
       .def("create_fmul",
            [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value {
             return self.create<arith::MulFOp>(lhs, rhs);
           })
diff --git a/test/Analysis/test-alignment.mlir b/test/Analysis/test-alignment.mlir
index ebc5383b9321..206239ff0910 100644
--- a/test/Analysis/test-alignment.mlir
+++ b/test/Analysis/test-alignment.mlir
@@ -473,14 +473,14 @@ tt.func @for() {
   // CHECK-NEXT: contiguity = [1, 1], divisibility = [4, 4], constancy = [128, 32], constant_value = 4
   %c_init = arith.constant dense<4> : tensor<128x32xi32>
   // CHECK-NEXT: contiguity = [1], divisibility = [128], constancy = [1], constant_value = 128
-  %ub = arith.constant 128 : index
+  %ub = arith.constant 128 : i32
   // CHECK-NEXT: contiguity = [1], divisibility = [4611686018427387904], constancy = [1], constant_value = 0
-  %lb = arith.constant 0 : index
+  %lb = arith.constant 0 : i32
   // CHECK-NEXT: contiguity = [1], divisibility = [16], constancy = [1], constant_value = 16
-  %step = arith.constant 16 : index
-  %a, %b, %c = scf.for %iv = %lb to %ub step %step iter_args(%a = %a_init, %b = %b_init, %c = %c_init) -> (tensor<128x32xi32>, tensor<128x32xi32>, tensor<128x32xi32>) {
+  %step = arith.constant 16 : i32
+  %a, %b, %c = scf.for %iv = %lb to %ub step %step iter_args(%a = %a_init, %b = %b_init, %c = %c_init) -> (tensor<128x32xi32>, tensor<128x32xi32>, tensor<128x32xi32>) : i32 {
     // CHECK-NEXT: contiguity = [1], divisibility = [16], constancy = [1], constant_value = <none>
-    %t = arith.index_cast %iv : index to i32
+    %t = arith.addi %iv, %lb : i32
     // CHECK: contiguity = [1, 1], divisibility = [1, 1], constancy = [128, 32], constant_value = <none>
     // CHECK: contiguity = [1, 1], divisibility = [1, 1], constancy = [128, 32], constant_value = <none>
     // CHECK: contiguity = [1, 1], divisibility = [4, 4], constancy = [128, 32], constant_value = 4
@@ -492,10 +492,12 @@
 // -----
 
 // CHECK-LABEL: @for_dynamic
-tt.func @for_dynamic(%lb: index {tt.divisibility = 16 : i32}, %step: index {tt.divisibility = 8 : i32}, %ub: index) {
-  scf.for %iv = %lb to %ub step %step {
+tt.func @for_dynamic(%lb: i32 {tt.divisibility = 16 : i32}, %step: i32 {tt.divisibility = 8 : i32}, %ub: i32) {
+  // CHECK-NEXT: contiguity = [1], divisibility = [4611686018427387904], constancy = [1], constant_value = 0
+  %c0 = arith.constant 0 : i32
+  scf.for %iv = %lb to %ub step %step : i32 {
     // CHECK-NEXT: contiguity = [1], divisibility = [8], constancy = [1], constant_value = <none>
-    %t = arith.index_cast %iv : index to i32
+    %t = arith.addi %iv, %c0 : i32
   }
   tt.return
 }
diff --git a/test/Triton/vecadd.mlir b/test/Triton/vecadd.mlir
index 1f7de7d6d939..40c4210d2115 100644
--- a/test/Triton/vecadd.mlir
+++ b/test/Triton/vecadd.mlir
@@ -18,10 +18,7 @@ module {
     %11 = tt.splat %cst : f32 -> tensor<256xf32>
     %c0_i32 = arith.constant 0 : i32
     %c32_i32 = arith.constant 32 : i32
-    %12 = arith.index_cast %c0_i32 : i32 to index
-    %13 = arith.index_cast %arg4 : i32 to index
-    %14 = arith.index_cast %c32_i32 : i32 to index
-    %15:3 = scf.for %arg6 = %12 to %13 step %14 iter_args(%arg7 = %11, %arg8 = %8, %arg9 = %10) -> (tensor<256xf32>, tensor<256x!tt.ptr<f32>>, tensor<256x!tt.ptr<f32>>) {
+    %15:3 = scf.for %arg6 = %c0_i32 to %arg4 step %c32_i32 iter_args(%arg7 = %11, %arg8 = %8, %arg9 = %10) -> (tensor<256xf32>, tensor<256x!tt.ptr<f32>>, tensor<256x!tt.ptr<f32>>) : i32 {
       %cst_0 = arith.constant 0.000000e+00 : f32
       %18 = tt.splat %cst_0 : f32 -> tensor<256xf32>
       %19 = tt.load %arg8, %6, %18 : tensor<256x!tt.ptr<f32>>
diff --git a/test/TritonGPU/amd/amd-convert-buffer-ops.mlir b/test/TritonGPU/amd/amd-convert-buffer-ops.mlir
index 26ddbad4e067..b1f13396639c 100644
--- a/test/TritonGPU/amd/amd-convert-buffer-ops.mlir
+++ b/test/TritonGPU/amd/amd-convert-buffer-ops.mlir
@@ -344,23 +344,21 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
 #blocked = #ttg.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
 module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
   // CHECK-LABEL: unsigned_ops
-  tt.func @unsigned_ops(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg2 : i32, %arg3 : i32, %arg4 : f32, %arg5 : index) {
+  tt.func @unsigned_ops(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg2 : i32, %arg3 : i32, %arg4 : f32) {
     %c5_i32 = arith.constant 5 : i32
     %0 = arith.ceildivui %arg2, %c5_i32 : i32
     %1 = arith.divui %arg3, %c5_i32 : i32
     %2 = arith.fptoui %arg4 : f32 to i32
-    %3 = arith.index_castui %arg5 : index to i32
     %4 = arith.maxui %arg2, %arg3 : i32
     %5 = arith.minui %arg2, %arg3 : i32
     %6 = arith.remui %arg2, %c5_i32 : i32
     %7 = arith.shrui %arg3, %c5_i32 : i32
     %8 = arith.addi %0, %1 : i32
-    %9 = arith.addi %2, %3 : i32
     %10 = arith.addi %4, %5 : i32
     %11 = arith.addi %6, %7 : i32
-    %12 = arith.addi %8, %9 : i32
+    %12 = arith.addi %8, %2 : i32
     %13 = arith.addi %10, %11 : i32
     %14 = arith.addi %12, %13 : i32
     %15 = tt.splat %14 : i32 -> tensor<8xi32, #blocked>
     %16 = tt.make_range {end = 8 : i32, start = 0 : i32} : tensor<8xi32, #blocked>
     %17 = arith.addi %15, %16 : tensor<8xi32, #blocked>
diff --git a/test/TritonGPU/combine.mlir b/test/TritonGPU/combine.mlir
index cd45d1ee05b5..5959adbed154 100644
--- a/test/TritonGPU/combine.mlir
+++ b/test/TritonGPU/combine.mlir
@@ -453,9 +453,9 @@ tt.func @select(%arg0: !tt.ptr<f64> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f64>
   // CHECK-NOT: ttg.convert_layout
   %cst = arith.constant dense<30000> : tensor<1x1xi32, #blocked2>
   %cst_0 = arith.constant dense<30000> : tensor<1x512xi32, #blocked2>
-  %c512 = arith.constant 512 : index
-  %c30000 = arith.constant 30000 : index
-  %c0 = arith.constant 0 : index
+  %c512 = arith.constant 512 : i32
+  %c30000 = arith.constant 30000 : i32
+  %c0 = arith.constant 0 : i32
   %cst_1 = arith.constant dense<2048> : tensor<1x1xi32, #blocked2>
   %cst_2 = arith.constant dense<0.000000e+00> : tensor<1x512xf64, #blocked2>
   %0 = tt.get_program_id x : i32
@@ -473,9 +473,8 @@ tt.func @select(%arg0: !tt.ptr<f64> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f64>
   %12 = tt.broadcast %11 : tensor<1x1xi32, #blocked2> -> tensor<1x512xi32, #blocked2>
   %13 = tt.splat %arg0 : !tt.ptr<f64> -> tensor<1x512x!tt.ptr<f64>, #blocked2>
   %14 = tt.broadcast %7 : tensor<1x1xi1, #blocked2> -> tensor<1x512xi1, #blocked2>
-  %15 = scf.for %arg3 = %c0 to %c30000 step %c512 iter_args(%arg4 = %cst_2) -> (tensor<1x512xf64, #blocked2>) {
-    %16 = arith.index_cast %arg3 : index to i32
-    %17 = tt.splat %16 : i32 -> tensor<1x512xi32, #blocked2>
+  %15 = scf.for %arg3 = %c0 to %c30000 step %c512 iter_args(%arg4 = %cst_2) -> (tensor<1x512xf64, #blocked2>) : i32 {
+    %17 = tt.splat %arg3 : i32 -> tensor<1x512xi32, #blocked2>
     %18 = arith.addi %17, %10 : tensor<1x512xi32, #blocked2>
     %19 = arith.cmpi "slt", %18, %cst_0 : tensor<1x512xi32, #blocked2>
     %20 = arith.addi %18, %12 : tensor<1x512xi32, #blocked2>
@@ -999,9 +998,9 @@ tt.func public @mnist(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f32>
 // CHECK-LABEL: cmp
 module attributes {"ttg.num-warps" = 8 : i32, "ttg.num-ctas" = 1 : i32} {
 tt.func public @cmp(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg3: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}) {
-  %c64 = arith.constant 64 : index
-  %c2048 = arith.constant 2048 : index
-  %c0 = arith.constant 0 : index
+  %c64 = arith.constant 64 : i32
+  %c2048 = arith.constant 2048 : i32
+  %c0 = arith.constant 0 : i32
   %c64_i32 = arith.constant 64 : i32
   %cst = arith.constant dense<-3.40282347E+38> : tensor<64x64xf32, #blocked2>
   %cst_0 = arith.constant dense<4194304> : tensor<64x1xi32, #blocked2>
@@ -1036,9 +1035,8 @@ tt.func public @cmp(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !tt
   %22 = arith.muli %21, %cst_0 : tensor<64x1xi32, #blocked2>
   %23 = tt.broadcast %22 : tensor<64x1xi32, #blocked2> -> tensor<64x64xi32, #blocked2>
   %24 = tt.splat %arg1 : !tt.ptr<f32> -> tensor<64x64x!tt.ptr<f32>, #blocked2>
-  %25 = scf.for %arg6 = %c0 to %c2048 step %c64 iter_args(%arg7 = %14) -> (tensor<64x64xf32, #blocked2>) {
-    %44 = arith.index_cast %arg6 : index to i32
-    %45 = tt.splat %44 : i32 -> tensor<1x64xi32, #blocked3>
+  %25 = scf.for %arg6 = %c0 to %c2048 step %c64 iter_args(%arg7 = %14) -> (tensor<64x64xf32, #blocked2>) : i32 {
+    %45 = tt.splat %arg6 : i32 -> tensor<1x64xi32, #blocked3>
     %46 = arith.addi %45, %10 : tensor<1x64xi32, #blocked3>
     %47 = arith.cmpi "slt", %46, %cst_2 : tensor<1x64xi32, #blocked3>
     %48 = tt.broadcast %46 : tensor<1x64xi32, #blocked3> -> tensor<64x64xi32, #blocked3>
@@ -1092,9 +1090,8 @@ tt.func public @cmp(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !tt
   %41 = tt.broadcast %30 : tensor<64x1xf32, #blocked2> -> tensor<64x64xf32, #blocked2>
   %42 = tt.splat %arg2 : !tt.ptr<f32> -> tensor<64x64x!tt.ptr<f32>, #blocked2>
   %43 = tt.splat %arg3 : !tt.ptr<f32> -> tensor<64x64x!tt.ptr<f32>, #blocked2>
-  scf.for %arg6 = %c0 to %c2048 step %c64 {
-    %44 = arith.index_cast %arg6 : index to i32
-    %45 = tt.splat %44 : i32 -> tensor<1x64xi32, #blocked3>
+  scf.for %arg6 = %c0 to %c2048 step %c64 : i32 {
+    %45 = tt.splat %arg6 : i32 -> tensor<1x64xi32, #blocked3>
     %46 = arith.addi %45, %10 : tensor<1x64xi32, #blocked3>
     %47 = arith.cmpi "slt", %46, %cst_2 : tensor<1x64xi32, #blocked3>
     %48 = tt.broadcast %46 : tensor<1x64xi32, #blocked3> -> tensor<64x64xi32, #blocked3>
@@ -1226,9 +1223,9 @@ module attributes {"ttg.num-warps" = 2 : i32, "ttg.num-ctas" = 1 : i32} {
 module attributes {"ttg.num-warps" = 4 : i32, "ttg.num-ctas" = 1 : i32} {
 tt.func public @reduce_cvt2(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}) {
   %cst = arith.constant dense<0.000000e+00> : tensor<1x256xf32, #blocked>
-  %c3136_i32 = arith.constant 3136 : index
-  %c256_i32 = arith.constant 256 : index
-  %c0_i32 = arith.constant 0 : index
+  %c3136_i32 = arith.constant 3136 : i32
+  %c256_i32 = arith.constant 256 : i32
+  %c0_i32 = arith.constant 0 : i32
   %cst_0 = arith.constant dense<3.136000e+03> : tensor<1x1xf32, #blocked>
   %cst_1 = arith.constant dense<50176> : tensor<1x256xi32, #blocked>
   %cst_2 = arith.constant dense<196> : tensor<1x1xi32, #blocked>
@@ -1250,9 +1247,8 @@ module attributes {"ttg.num-warps" = 4 : i32, "ttg.num-ctas" = 1 : i32} {
   %12 = tt.broadcast %11 : tensor<1x1xi32, #blocked> -> tensor<1x256xi32, #blocked>
   %13 = tt.splat %arg1 : !tt.ptr<f32> -> tensor<1x256x!tt.ptr<f32>, #blocked>
   %14 = tt.broadcast %7 : tensor<1x1xi1, #blocked> -> tensor<1x256xi1, #blocked>
-  %15 = scf.for %arg5 = %c0_i32 to %c3136_i32 step %c256_i32 iter_args(%arg6 = %cst) -> (tensor<1x256xf32, #blocked>) {
-    %42 = arith.index_cast %arg5 : index to i32
-    %43 = tt.splat %42 : i32 -> tensor<1x256xi32, #blocked>
+  %15 = scf.for %arg5 = %c0_i32 to %c3136_i32 step %c256_i32 iter_args(%arg6 = %cst) -> (tensor<1x256xf32, #blocked>) : i32 {
+    %43 = tt.splat %arg5 : i32 -> tensor<1x256xi32, #blocked>
     %44 = arith.addi %43, %10 : tensor<1x256xi32, #blocked>
     %45 = arith.cmpi "slt", %44, %cst_4 : tensor<1x256xi32, #blocked>
     %46 = arith.remsi %44, %cst_3 : tensor<1x256xi32, #blocked>
diff --git a/test/TritonGPU/matmul.mlir b/test/TritonGPU/matmul.mlir
index 50b90037e7ae..d83aa38f5841 100644
--- a/test/TritonGPU/matmul.mlir
+++ b/test/TritonGPU/matmul.mlir
@@ -6,8 +6,8 @@ module {
 tt.func @matmul_kernel__Pfp32_Pfp32_Pfp32_i32_i32_i32_i32_i32_i32_i32_i32_i32__12c64_13c64_14c64_15c8(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg3: i32, %arg4: i32, %arg5: i32, %arg6: i32 {tt.divisibility = 16 : i32}, %arg7: i32, %arg8: i32 {tt.divisibility = 16 : i32}, %arg9: i32, %arg10: i32 {tt.divisibility = 16 : i32}, %arg11: i32) {
   %cst = arith.constant dense<true> : tensor<64x64xi1>
-  %c64 = arith.constant 64 : index
-  %c0 = arith.constant 0 : index
+  %c64 = arith.constant 64 : i32
+  %c0 = arith.constant 0 : i32
   %cst_0 = arith.constant dense<0.000000e+00> : tensor<64x64xf32>
   %c64_i32 = arith.constant 64 : i32
   %c63_i32 = arith.constant 63 : i32
@@ -58,8 +58,7 @@ tt.func @matmul_kernel__Pfp32_Pfp32_Pfp32_i32_i32_i32_i32_i32_i32_i32_i32_i32__1
   %43 = arith.addi %41, %42 : tensor<64x64xi32>
   %44 = tt.splat %arg1 : !tt.ptr<f32> -> tensor<64x64x!tt.ptr<f32>>
   %45 = tt.addptr %44, %43 : tensor<64x64x!tt.ptr<f32>>, tensor<64x64xi32>
-  %46 = arith.index_cast %arg5 : i32 to index
-  %47:3 = scf.for %arg12 = %c0 to %46 step %c64 iter_args(%arg13 = %cst_0, %arg14 = %34, %arg15 = %45) -> (tensor<64x64xf32>, tensor<64x64x!tt.ptr<f32>>, tensor<64x64x!tt.ptr<f32>>) {
+  %47:3 = scf.for %arg12 = %c0 to %arg5 step %c64 iter_args(%arg13 = %cst_0, %arg14 = %34, %arg15 = %45) -> (tensor<64x64xf32>, tensor<64x64x!tt.ptr<f32>>, tensor<64x64x!tt.ptr<f32>>) : i32 {
     %76 = tt.load %arg14, %cst, %cst_0 : tensor<64x64x!tt.ptr<f32>>
     %77 = tt.load %arg15, %cst, %cst_0 : tensor<64x64x!tt.ptr<f32>>
     %78 = tt.dot %76, %77, %cst_0 : tensor<64x64xf32> * tensor<64x64xf32> -> tensor<64x64xf32>
diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/ConvertToBufferOps.cpp b/third_party/amd/lib/TritonAMDGPUTransforms/ConvertToBufferOps.cpp
index fdc1e37b71fc..5eb4c54de778 100644
--- a/third_party/amd/lib/TritonAMDGPUTransforms/ConvertToBufferOps.cpp
+++ b/third_party/amd/lib/TritonAMDGPUTransforms/ConvertToBufferOps.cpp
@@ -151,8 +151,8 @@ bool verifyNonNegativeExpr(Value expr, const DenseSet<Value> &assumptions) {
           .Case<...>(
              [&](auto) { return false; })
           .Case<arith::CeilDivUIOp, arith::DivUIOp, arith::ExtUIOp,
-                arith::FPToUIOp, arith::IndexCastUIOp, arith::MaxUIOp,
-                arith::MinUIOp, arith::RemUIOp, arith::ShRUIOp>(
+                arith::FPToUIOp, arith::MaxUIOp, arith::MinUIOp, arith::RemUIOp,
+                arith::ShRUIOp>(
              // These OPs also return unsigned values.
              // TODO: We can also sniff whether a Value is unsigned by looking
              // for whether or not it's used as an argument to one of