From 6653168be29ca64ca42be164ab617024f915cac4 Mon Sep 17 00:00:00 2001
From: Christian Sigg
Date: Mon, 16 Sep 2024 13:44:40 -0700
Subject: [PATCH] Integrate Triton up to
 [50d803cd](https://github.com/openai/triton/commits/50d803cdb4e68910ed663251100e168ea4d2519d)

PiperOrigin-RevId: 675275491
---
 .../triton/llvm_integration/cl667589778.patch |  98 -------
 .../triton/llvm_integration/cl670868497.patch |  37 ---
 .../triton/llvm_integration/cl673225353.patch |  26 --
 .../triton/llvm_integration/series.bzl        |   2 -
 .../fix_index_cast_op_lowering_to_llvm.patch  |  16 --
 third_party/triton/temporary/fp8_fix.patch    | 265 ------------------
 third_party/triton/temporary/series.bzl       |   1 -
 .../splat-value-shift-too-large.patch         |  24 --
 third_party/triton/workspace.bzl              |   4 +-
 .../triton/xla_extensions/sparse_dot.patch    |  20 +-
 10 files changed, 12 insertions(+), 481 deletions(-)
 delete mode 100644 third_party/triton/llvm_integration/cl667589778.patch
 delete mode 100644 third_party/triton/llvm_integration/cl670868497.patch
 delete mode 100644 third_party/triton/llvm_integration/cl673225353.patch
 delete mode 100644 third_party/triton/temporary/fix_index_cast_op_lowering_to_llvm.patch
 delete mode 100644 third_party/triton/temporary/fp8_fix.patch
 delete mode 100644 third_party/triton/temporary/splat-value-shift-too-large.patch

diff --git a/third_party/triton/llvm_integration/cl667589778.patch b/third_party/triton/llvm_integration/cl667589778.patch
deleted file mode 100644
index eee23cff21986..0000000000000
--- a/third_party/triton/llvm_integration/cl667589778.patch
+++ /dev/null
@@ -1,98 +0,0 @@
---- a/include/triton/Analysis/Alias.h 2023-10-19 13:35:54.000000000 -0700
-+++ b/include/triton/Analysis/Alias.h 2024-08-26 08:31:25.000000000 -0700
-@@ -85,10 +85,9 @@
- }
-
- /// Computes if the alloc set of the results are changed.
-- void
-- visitOperation(Operation *op,
-- ArrayRef *> operands,
-- ArrayRef *> results) override;
-+ LogicalResult visitOperation(
-+ Operation *op, ArrayRef *> operands,
-+ ArrayRef *> results) override;
- };
-
- } // namespace mlir
-
---- a/lib/Analysis/Alias.cpp 2024-06-07 05:28:31.000000000 -0700
-+++ b/lib/Analysis/Alias.cpp 2024-08-26 08:31:25.000000000 -0700
-@@ -21,7 +21,7 @@
- return ret;
- }
-
--void SharedMemoryAliasAnalysis::visitOperation(
-+LogicalResult SharedMemoryAliasAnalysis::visitOperation(
- Operation *op, ArrayRef *> operands,
- ArrayRef *> results) {
- AliasInfo aliasInfo;
-@@ -31,7 +31,7 @@
- if (auto memdescTy = dyn_cast(result.getType())) {
- if (!isa_and_nonnull(
- memdescTy.getMemorySpace()))
-- return;
-+ return mlir::success();
- }
-
- // Only LocalAllocOp creates a new buffer.
-@@ -49,11 +49,13 @@ - } - - if (pessimistic) { -- return setAllToEntryStates(results); -+ setAllToEntryStates(results); -+ return mlir::success(); - } - // Join all lattice elements - for (auto *result : results) - propagateIfChanged(result, result->join(aliasInfo)); -+ return mlir::success(); - } - - AliasResult SharedMemoryAliasAnalysis::alias(Value lhs, Value rhs) { - ---- a/lib/Analysis/AxisInfo.cpp 2024-07-03 07:14:55.000000000 -0700 -+++ b/lib/Analysis/AxisInfo.cpp 2024-08-26 08:31:25.000000000 -0700 -@@ -195,9 +195,9 @@ - dataflow::Lattice>::getLatticeElement; - using FuncAxisInfoMapT = DenseMap; - -- void visitOperation(Operation *op, -- ArrayRef *> operands, -- ArrayRef *> results) override; -+ LogicalResult visitOperation( -+ Operation *op, ArrayRef *> operands, -+ ArrayRef *> results) override; - void - visitForOpInductionVar(scf::ForOp op, - ArrayRef *> argLattices); -@@ -1039,7 +1039,7 @@ - visitors.append(); - } - --void AxisInfoAnalysis::visitOperation( -+LogicalResult AxisInfoAnalysis::visitOperation( - Operation *op, ArrayRef *> operands, - ArrayRef *> results) { - // TODO: For sure not the right way to do this -@@ -1048,8 +1048,10 @@ - if (op->getValue().getRank() == 0) - setToEntryState((dataflow::Lattice *)op); - AxisInfo curr = visitors.apply(op, operands); -- if (curr.getRank() == 0) -- return setAllToEntryStates(results); -+ if (curr.getRank() == 0) { -+ setAllToEntryStates(results); -+ return mlir::success(); -+ } - // override with hint - auto newContiguity = curr.getContiguity(); - auto newDivisibility = curr.getDivisibility(); -@@ -1071,6 +1073,7 @@ - // join all lattice elements - for (auto *result : results) - propagateIfChanged(result, result->join(curr)); -+ return mlir::success(); - } - - void AxisInfoAnalysis::visitForOpInductionVar( diff --git a/third_party/triton/llvm_integration/cl670868497.patch b/third_party/triton/llvm_integration/cl670868497.patch deleted file mode 100644 index b737800e467ef..0000000000000 --- a/third_party/triton/llvm_integration/cl670868497.patch +++ /dev/null @@ -1,37 +0,0 @@ - ---- a/include/triton/Analysis/Alias.h 2024-08-27 12:43:55.000000000 -0700 -+++ b/include/triton/Analysis/Alias.h 2024-09-04 01:31:58.000000000 -0700 -@@ -81,7 +81,7 @@ - void setToEntryState(dataflow::Lattice *lattice) override { - propagateIfChanged( - lattice, lattice->join( -- AliasInfo::getPessimisticValueState(lattice->getPoint()))); -+ AliasInfo::getPessimisticValueState(lattice->getAnchor()))); - } - - /// Computes if the alloc set of the results are changed. 
- ---- a/lib/Analysis/AxisInfo.cpp 2024-08-27 12:43:55.000000000 -0700 -+++ b/lib/Analysis/AxisInfo.cpp 2024-09-04 01:31:58.000000000 -0700 -@@ -173,7 +173,7 @@ - void setToEntryState(dataflow::Lattice *lattice) override { - propagateIfChanged( - lattice, -- lattice->join(AxisInfo::getPessimisticValueState(lattice->getPoint()))); -+ lattice->join(AxisInfo::getPessimisticValueState(lattice->getAnchor()))); - } - - void visitNonControlFlowArguments( - ---- a/lib/Target/LLVMIR/LLVMDIScope.cpp 2024-05-14 06:33:36.000000000 -0700 -+++ b/lib/Target/LLVMIR/LLVMDIScope.cpp 2024-09-04 01:31:58.000000000 -0700 -@@ -105,7 +105,8 @@ - context, distinctId, compileUnitAttr, fileAttr, funcNameAttr, - funcNameAttr, fileAttr, - /*line=*/line, -- /*scopeline=*/line, subprogramFlags, subroutineTypeAttr); -+ /*scopeline=*/line, subprogramFlags, subroutineTypeAttr, -+ /*retainedNodes=*/{}); - funcOp->setLoc(FusedLoc::get(context, {loc}, subprogramAttr)); - } - diff --git a/third_party/triton/llvm_integration/cl673225353.patch b/third_party/triton/llvm_integration/cl673225353.patch deleted file mode 100644 index 12ad71bb6f4f9..0000000000000 --- a/third_party/triton/llvm_integration/cl673225353.patch +++ /dev/null @@ -1,26 +0,0 @@ - ---- a/include/triton/Analysis/AxisInfo.h 2024-03-11 11:42:57.000000000 -0700 -+++ b/include/triton/Analysis/AxisInfo.h 2024-09-10 21:57:51.000000000 -0700 -@@ -180,8 +180,8 @@ - for (auto funcOp : llvm::reverse(sortedFuncs)) { - initialize(funcOp); - funcOp.walk([&](CallOpInterface callOp) { -- auto callee = -- dyn_cast(callOp.resolveCallable(&symbolTable)); -+ auto callee = dyn_cast( -+ callOp.resolveCallableInTable(&symbolTable)); - update(callOp, callee); - }); - } - ---- a/include/triton/Analysis/Utility.h 2024-08-14 09:36:23.000000000 -0700 -+++ b/include/triton/Analysis/Utility.h 2024-09-10 21:57:51.000000000 -0700 -@@ -316,7 +316,7 @@ - moduleOp.walk([&](Operation *op) { - auto caller = op->getParentOfType(); - if (auto callOp = dyn_cast(op)) { -- auto *callee = callOp.resolveCallable(&symbolTable); -+ auto *callee = callOp.resolveCallableInTable(&symbolTable); - auto funcOp = dyn_cast_or_null(callee); - if (funcOp) { - graph[caller].emplace_back( diff --git a/third_party/triton/llvm_integration/series.bzl b/third_party/triton/llvm_integration/series.bzl index 41d80ac1132c5..656b9c894904d 100644 --- a/third_party/triton/llvm_integration/series.bzl +++ b/third_party/triton/llvm_integration/series.bzl @@ -8,7 +8,5 @@ LLVM nor MLIR integrator, please do not add any patches to this list. 
""" llvm_patch_list = [ - "//third_party/triton:llvm_integration/cl670868497.patch", - "//third_party/triton:llvm_integration/cl673225353.patch", # Add new patches just above this line ] diff --git a/third_party/triton/temporary/fix_index_cast_op_lowering_to_llvm.patch b/third_party/triton/temporary/fix_index_cast_op_lowering_to_llvm.patch deleted file mode 100644 index e10db851f985b..0000000000000 --- a/third_party/triton/temporary/fix_index_cast_op_lowering_to_llvm.patch +++ /dev/null @@ -1,16 +0,0 @@ ---- a/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp -+++ b/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp -@@ -611,10 +611,10 @@ struct IndexCastOpLowering - if (targetBits == sourceBits) - return {operands[0][0]}; - if (targetBits < sourceBits) -- return {rewriter.replaceOpWithNewOp(op, elemTy, -- operands[0][0])}; -+ return { -+ rewriter.create(op.getLoc(), elemTy, operands[0][0])}; - return { -- rewriter.replaceOpWithNewOp(op, elemTy, operands[0][0])}; -+ rewriter.create(op.getLoc(), elemTy, operands[0][0])}; - } - }; - diff --git a/third_party/triton/temporary/fp8_fix.patch b/third_party/triton/temporary/fp8_fix.patch deleted file mode 100644 index 661ce4e7a2f40..0000000000000 --- a/third_party/triton/temporary/fp8_fix.patch +++ /dev/null @@ -1,265 +0,0 @@ -This patch can be removed as part of the next integrate. -The corresponding import patch has already been added. - -==== triton/include/triton/Dialect/Triton/IR/TritonTypes.td#13 - triton/include/triton/Dialect/Triton/IR/TritonTypes.td ==== -# action=edit type=text ---- triton/include/triton/Dialect/Triton/IR/TritonTypes.td 2024-06-07 05:28:31.000000000 -0700 -+++ triton/include/triton/Dialect/Triton/IR/TritonTypes.td 2024-08-20 06:34:55.000000000 -0700 -@@ -15,7 +15,7 @@ - } - - // Floating-point Type --def TT_Float : AnyTypeOf<[F8E4M3FNUZ, F8E5M2, F8E5M2FNUZ, F16, BF16, F32, F64], "floating-point">; -+def TT_Float : AnyTypeOf<[F8E4M3FN, F8E4M3FNUZ, F8E5M2, F8E5M2FNUZ, F16, BF16, F32, F64], "floating-point">; - def TT_FloatTensor : RankedTensorOf<[TT_Float]>; - def TT_FloatLike : AnyTypeOf<[TT_Float, TT_FloatTensor]>; - -==== triton/lib/Analysis/Utility.cpp#42 - triton/lib/Analysis/Utility.cpp ==== -# action=edit type=text ---- triton/lib/Analysis/Utility.cpp 2024-08-14 09:36:23.000000000 -0700 -+++ triton/lib/Analysis/Utility.cpp 2024-08-20 06:34:55.000000000 -0700 -@@ -425,6 +425,7 @@ - if (a.getIntOrFloatBitWidth() != b.getIntOrFloatBitWidth()) - return false; - -+ auto F8E4M3FN = TypeID::get(); - auto F8E5M2 = TypeID::get(); - auto F8E4M3FNUZ = TypeID::get(); - auto F8E5M2FNUZ = TypeID::get(); -@@ -436,6 +437,7 @@ - {F32, F32}, - {F16, F16}, - {BF16, BF16}, -+ {F8E4M3FN, F8E4M3FN}, - {F8E5M2, F8E5M2}, - {F8E4M3FNUZ, F8E4M3FNUZ}, - {F8E4M3FNUZ, F8E5M2FNUZ}, -@@ -495,14 +497,14 @@ - return false; - if (!(numWarps % 4 == 0 && retShapePerCTA[rank - 2] % 64 == 0 && - retShapePerCTA[rank - 1] % 8 == 0 && -- (aElemTy.isFloat8E5M2() || aElemTy.isFloat8E4M3FNUZ() || -+ (aElemTy.isFloat8E5M2() || aElemTy.isFloat8E4M3FN() || - aElemTy.isInteger(8) || aElemTy.isF16() || aElemTy.isBF16() || - aElemTy.isF32()))) { - return false; - } - // We cannot use MMA_V3 if we need to accumulate in F32 within the MMA op. 
- if (op.getMaxNumImpreciseAcc() < 32 && -- (aElemTy.isFloat8E5M2() || aElemTy.isFloat8E4M3FNUZ()) && -+ (aElemTy.isFloat8E5M2() || aElemTy.isFloat8E4M3FN()) && - cast(op.getType()).getElementType().isF32()) { - return false; - } -==== triton/lib/Conversion/TritonGPUToLLVM/TypeConverter.cpp#20 - triton/lib/Conversion/TritonGPUToLLVM/TypeConverter.cpp ==== -# action=edit type=text ---- triton/lib/Conversion/TritonGPUToLLVM/TypeConverter.cpp 2024-06-07 05:28:31.000000000 -0700 -+++ triton/lib/Conversion/TritonGPUToLLVM/TypeConverter.cpp 2024-08-20 06:34:55.000000000 -0700 -@@ -34,6 +34,9 @@ - addConversion([&](mlir::Float8E4M3FNUZType type) -> std::optional { - return IntegerType::get(type.getContext(), 8); - }); -+ addConversion([&](mlir::Float8E4M3FNType type) -> std::optional { -+ return IntegerType::get(type.getContext(), 8); -+ }); - addConversion([&](mlir::Float8E5M2Type type) -> std::optional { - return IntegerType::get(type.getContext(), 8); - }); -==== triton/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp#44 - triton/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp ==== -# action=edit type=text ---- triton/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp 2024-07-31 01:05:00.000000000 -0700 -+++ triton/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp 2024-08-20 06:40:32.000000000 -0700 -@@ -382,7 +382,7 @@ - NvidiaMmaEncodingAttr mmaLayout = - dyn_cast(D.getType().getEncoding()); - if (mmaLayout) { -- bool isNativeFP8 = AElType.isFloat8E5M2() || AElType.isFloat8E4M3FNUZ(); -+ bool isNativeFP8 = AElType.isFloat8E5M2() || AElType.isFloat8E4M3FN(); - // promote operands for sm < 89 since fp8 mma is not natively supported - // promote operands for sm >= 90 when mma is not v3 - if (!isNativeFP8 || -==== triton/lib/Dialect/TritonGPU/Transforms/Utility.cpp#39 - triton/lib/Dialect/TritonGPU/Transforms/Utility.cpp ==== -# action=edit type=text ---- triton/lib/Dialect/TritonGPU/Transforms/Utility.cpp 2024-08-14 09:36:23.000000000 -0700 -+++ triton/lib/Dialect/TritonGPU/Transforms/Utility.cpp 2024-08-20 06:34:55.000000000 -0700 -@@ -45,8 +45,9 @@ - SmallVector validN; - - // MMAv3 with larger instruction shape is preferred. -- if (eltType.isFloat8E5M2() || eltType.isFloat8E4M3FNUZ() || -- eltType.isF16() || eltType.isBF16() || eltType.isF32()) { -+ if (eltType.isFloat8E5M2() || eltType.isFloat8E4M3FN() || -+ eltType.isFloat8E4M3FNUZ() || eltType.isF16() || eltType.isBF16() || -+ eltType.isF32()) { - validN.assign({256, 248, 240, 232, 224, 216, 208, 200, 192, 184, 176, - 168, 160, 152, 144, 136, 128, 120, 112, 104, 96, 88, - 80, 72, 64, 56, 48, 40, 32, 24, 16, 8}); -==== triton/python/src/ir.cc#24 - triton/python/src/ir.cc ==== -# action=edit type=text ---- triton/python/src/ir.cc 2024-08-12 00:24:31.000000000 -0700 -+++ triton/python/src/ir.cc 2024-08-21 01:46:02.000000000 -0700 -@@ -745,10 +745,8 @@ - return self.getBuilder().getI64Type(); - }) - .def("get_fp8e4nv_ty", -- // TODO: fp8e4nv is using Float8E4M3FNUZType, which -- // does not seem right. 
It should use FloatE4M3FNType - [](TritonOpBuilder &self) -> Type { -- return self.getBuilder().getType(); -+ return self.getBuilder().getType(); - }) - .def("get_fp8e4b8_ty", - [](TritonOpBuilder &self) -> Type { -==== triton/test/Conversion/tritongpu_to_llvm_hopper.mlir#25 - triton/test/Conversion/tritongpu_to_llvm_hopper.mlir ==== -# action=edit type=text ---- triton/test/Conversion/tritongpu_to_llvm_hopper.mlir 2024-07-03 07:14:55.000000000 -0700 -+++ triton/test/Conversion/tritongpu_to_llvm_hopper.mlir 2024-08-20 06:34:55.000000000 -0700 -@@ -129,24 +129,24 @@ - module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} { - // CHECK-LABEL: test_fp8_to_f16_conversion - tt.func @test_fp8_to_f16_conversion( -- %in0: tensor<128xf8E5M2, #blocked>, %in1: tensor<128xf8E4M3FNUZ, #blocked>, -+ %in0: tensor<128xf8E5M2, #blocked>, %in1: tensor<128xf8E4M3FN, #blocked>, - %in2: tensor<128xf16, #blocked>, %in3: tensor<128xf32, #blocked>) { - // CHECK-COUNT-2: cvt.rn.f16x2.e5m2x2 {{.*}} "=r,h" %{{.*}} : (i16) -> vector<2xf16> - %out0 = tt.fp_to_fp %in0 : tensor<128xf8E5M2, #blocked> -> tensor<128xf16, #blocked> - // CHECK-COUNT-2: cvt.rn.f16x2.e4m3x2 {{.*}} "=r,h" %{{.*}} : (i16) -> vector<2xf16> -- %out1 = tt.fp_to_fp %in1 : tensor<128xf8E4M3FNUZ, #blocked> -> tensor<128xf16, #blocked> -+ %out1 = tt.fp_to_fp %in1 : tensor<128xf8E4M3FN, #blocked> -> tensor<128xf16, #blocked> - // CHECK-COUNT-2: mul.rn.bf16x2 - %out2 = tt.fp_to_fp %in0 : tensor<128xf8E5M2, #blocked> -> tensor<128xbf16, #blocked> - - // CHECK-COUNT-2: cvt.rn.satfinite.e5m2x2.f16x2 {{.*}} "=h,r" %{{.*}} : (i32) -> vector<2xi8> - %out3 = tt.fp_to_fp %in2, rounding = rtne : tensor<128xf16, #blocked> -> tensor<128xf8E5M2, #blocked> - // CHECK-COUNT-2: cvt.rn.satfinite.e4m3x2.f16x2 {{.*}} "=h,r" %{{.*}} : (i32) -> vector<2xi8> -- %out4 = tt.fp_to_fp %in2, rounding = rtne : tensor<128xf16, #blocked> -> tensor<128xf8E4M3FNUZ, #blocked> -+ %out4 = tt.fp_to_fp %in2, rounding = rtne : tensor<128xf16, #blocked> -> tensor<128xf8E4M3FN, #blocked> - - // CHECK-COUNT-2: cvt.rn.satfinite.e5m2x2.f32 {{.*}} "=h,r,r" %{{.*}}, %{{.*}} : (i32, i32) -> vector<2xi8> - %out5 = tt.fp_to_fp %in3, rounding = rtne : tensor<128xf32, #blocked> -> tensor<128xf8E5M2, #blocked> - // CHECK-COUNT-2: cvt.rn.satfinite.e4m3x2.f32 {{.*}} "=h,r,r" %{{.*}}, %{{.*}} : (i32, i32) -> vector<2xi8> -- %out6 = tt.fp_to_fp %in3, rounding = rtne : tensor<128xf32, #blocked> -> tensor<128xf8E4M3FNUZ, #blocked> -+ %out6 = tt.fp_to_fp %in3, rounding = rtne : tensor<128xf32, #blocked> -> tensor<128xf8E4M3FN, #blocked> - tt.return - } - } -==== triton/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAv2.cpp#4 - triton/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAv2.cpp ==== -# action=edit type=text ---- triton/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAv2.cpp 2024-05-14 06:33:36.000000000 -0700 -+++ triton/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAv2.cpp 2024-08-20 06:34:55.000000000 -0700 -@@ -81,9 +81,9 @@ - FP32_TF32_TF32_FP32, - FP16_FP16_FP16_FP16, - FP32_FP8E5M2_FP8E5M2_FP32, -- FP32_FP8E5M2_FP8E4M3FNUZ_FP32, -- FP32_FP8E4M3FNUZ_FP8E5M2_FP32, -- FP32_FP8E4M3FNUZ_FP8E4M3FNUZ_FP32, -+ FP32_FP8E5M2_FP8E4M3FN_FP32, -+ FP32_FP8E4M3FN_FP8E5M2_FP32, -+ FP32_FP8E4M3FN_FP8E4M3FN_FP32, - // integer tensor core instr - INT32_INT1_INT1_INT32, // Not implemented - INT32_INT4_INT4_INT32, // Not implemented -@@ -112,9 +112,9 @@ - case TensorCoreType::FP16_FP16_FP16_FP16: - return 
fp16x2Pack2Ty; - case TensorCoreType::FP32_FP8E5M2_FP8E5M2_FP32: -- case TensorCoreType::FP32_FP8E5M2_FP8E4M3FNUZ_FP32: -- case TensorCoreType::FP32_FP8E4M3FNUZ_FP8E5M2_FP32: -- case TensorCoreType::FP32_FP8E4M3FNUZ_FP8E4M3FNUZ_FP32: -+ case TensorCoreType::FP32_FP8E5M2_FP8E4M3FN_FP32: -+ case TensorCoreType::FP32_FP8E4M3FN_FP8E5M2_FP32: -+ case TensorCoreType::FP32_FP8E4M3FN_FP8E4M3FN_FP32: - return fp32x4Ty; - case TensorCoreType::INT32_INT8_INT8_INT32: - return i32x4Ty; -@@ -140,14 +140,14 @@ - bTy.getElementType().isFloat8E5M2()) - return TensorCoreType::FP32_FP8E5M2_FP8E5M2_FP32; - if (aTy.getElementType().isFloat8E5M2() && -- bTy.getElementType().isFloat8E4M3FNUZ()) -- return TensorCoreType::FP32_FP8E5M2_FP8E4M3FNUZ_FP32; -- if (aTy.getElementType().isFloat8E4M3FNUZ() && -+ bTy.getElementType().isFloat8E4M3FN()) -+ return TensorCoreType::FP32_FP8E5M2_FP8E4M3FN_FP32; -+ if (aTy.getElementType().isFloat8E4M3FN() && - bTy.getElementType().isFloat8E5M2()) -- return TensorCoreType::FP32_FP8E4M3FNUZ_FP8E5M2_FP32; -- if (aTy.getElementType().isFloat8E4M3FNUZ() && -- bTy.getElementType().isFloat8E4M3FNUZ()) -- return TensorCoreType::FP32_FP8E4M3FNUZ_FP8E4M3FNUZ_FP32; -+ return TensorCoreType::FP32_FP8E4M3FN_FP8E5M2_FP32; -+ if (aTy.getElementType().isFloat8E4M3FN() && -+ bTy.getElementType().isFloat8E4M3FN()) -+ return TensorCoreType::FP32_FP8E4M3FN_FP8E4M3FN_FP32; - if (aTy.getElementType().isF32() && bTy.getElementType().isF32() && - op.getInputPrecision() == InputPrecision::TF32) - return TensorCoreType::FP32_TF32_TF32_FP32; -@@ -193,11 +193,11 @@ - - {TensorCoreType::FP32_FP8E5M2_FP8E5M2_FP32, - "mma.sync.aligned.m16n8k32.row.col.f32.e5m2.e5m2.f32"}, -- {TensorCoreType::FP32_FP8E5M2_FP8E4M3FNUZ_FP32, -+ {TensorCoreType::FP32_FP8E5M2_FP8E4M3FN_FP32, - "mma.sync.aligned.m16n8k32.row.col.f32.e5m2.e4m3.f32"}, -- {TensorCoreType::FP32_FP8E4M3FNUZ_FP8E5M2_FP32, -+ {TensorCoreType::FP32_FP8E4M3FN_FP8E5M2_FP32, - "mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e5m2.f32"}, -- {TensorCoreType::FP32_FP8E4M3FNUZ_FP8E4M3FNUZ_FP32, -+ {TensorCoreType::FP32_FP8E4M3FN_FP8E4M3FN_FP32, - "mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e4m3.f32"}, - }; - -==== triton/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/WGMMA.cpp#9 - triton/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/WGMMA.cpp ==== -# action=edit type=text ---- triton/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/WGMMA.cpp 2024-06-07 05:28:31.000000000 -0700 -+++ triton/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/WGMMA.cpp 2024-08-20 06:34:55.000000000 -0700 -@@ -58,7 +58,7 @@ - return triton::nvgpu::WGMMAEltType::s8; - } else if (aTy.isFloat8E5M2()) { - return triton::nvgpu::WGMMAEltType::e5m2; -- } else if (aTy.isFloat8E4M3FNUZ()) { -+ } else if (aTy.isFloat8E4M3FN()) { - return triton::nvgpu::WGMMAEltType::e4m3; - } else { - llvm::report_fatal_error("Unsupported mma operand type found"); -==== triton/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ElementwiseOpToLLVM.cpp#9 - triton/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ElementwiseOpToLLVM.cpp ==== -# action=edit type=text ---- triton/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ElementwiseOpToLLVM.cpp 2024-07-17 02:05:59.000000000 -0700 -+++ triton/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ElementwiseOpToLLVM.cpp 2024-08-20 06:34:55.000000000 -0700 -@@ -386,7 +386,7 @@ - std::pair - getConversionFunc(Type srcTy, Type dstTy, - std::optional roundingMode) const { -- auto F8E4M3TyID = TypeID::get(); -+ auto F8E4M3TyID = TypeID::get(); - auto 
F8E5M2TyID = TypeID::get(); - auto F16TyID = TypeID::get(); - auto BF16TyID = TypeID::get(); -@@ -430,7 +430,7 @@ - llvm::report_fatal_error("Unsupported rounding mode for conversion."); - } - if (computeCapability < 89 && -- (srcTy.isFloat8E4M3FNUZ() || dstTy.isFloat8E4M3FNUZ())) { -+ (srcTy.isFloat8E4M3FN() || dstTy.isFloat8E4M3FN())) { - llvm::errs() << "Conversion from/to f8e4m3nv is only supported on " - "compute capability >= 89" - << "\n"; -@@ -452,7 +452,7 @@ - auto dstElementType = getElementType(op.getResult()); - auto roundingMode = op.getRounding(); - -- if (dstElementType.isFloat8E5M2() || dstElementType.isFloat8E4M3FNUZ()) { -+ if (dstElementType.isFloat8E5M2() || dstElementType.isFloat8E4M3FN()) { - assert(roundingMode.has_value() && - "Rounding mode must be specified for convertsions to fp8"); - -@@ -489,7 +489,7 @@ - - bool useFP16IntermediateSrc = - srcElementType.isF32() && -- (!(computeCapability >= 90 && (dstElementType.isFloat8E4M3FNUZ() || -+ (!(computeCapability >= 90 && (dstElementType.isFloat8E4M3FN() || - dstElementType.isFloat8E5M2())) || - roundingMode.value() == RoundingMode::RTZ); - bool isDstFP32 = dstElementType.isF32(); diff --git a/third_party/triton/temporary/series.bzl b/third_party/triton/temporary/series.bzl index 9f15f85169752..4fa55269e3323 100644 --- a/third_party/triton/temporary/series.bzl +++ b/third_party/triton/temporary/series.bzl @@ -14,6 +14,5 @@ those to this list. """ temporary_patch_list = [ - "//third_party/triton:temporary/fix_index_cast_op_lowering_to_llvm.patch", # Add new patches just above this line ] diff --git a/third_party/triton/temporary/splat-value-shift-too-large.patch b/third_party/triton/temporary/splat-value-shift-too-large.patch deleted file mode 100644 index 7538c90be79d9..0000000000000 --- a/third_party/triton/temporary/splat-value-shift-too-large.patch +++ /dev/null @@ -1,24 +0,0 @@ ---- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp -+++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp -@@ -272,6 +272,12 @@ struct LoadOpConversion : public Convert - ld(dstsOpr, addrOpr, evictOpr).predicate(pred, "b"); - - if (other) { -+ if (otherIsSplatConstInt) { -+ for (size_t s = valueElemNBits; s < movWidth; s += valueElemNBits) { -+ splatVal |= splatVal << valueElemNBits; -+ } -+ } -+ - for (size_t ii = 0; ii < nWords; ++ii) { - // PTX doesn't support mov.u8, so we need to use mov.u16 - PTXInstr &mov = -@@ -292,8 +298,6 @@ struct LoadOpConversion : public Convert - PTXInstr::Operand *opr{}; - - if (otherIsSplatConstInt) { -- for (size_t s = 0; s < 32; s += valueElemNBits) -- splatVal |= splatVal << valueElemNBits; - opr = ptxBuilder.newConstantOperand(splatVal); - } else - opr = ptxBuilder.newOperand(v, readConstraint); diff --git a/third_party/triton/workspace.bzl b/third_party/triton/workspace.bzl index 05c5703ca94cf..8f68ab621acb1 100644 --- a/third_party/triton/workspace.bzl +++ b/third_party/triton/workspace.bzl @@ -8,8 +8,8 @@ load("//third_party/triton:xla_extensions/series.bzl", "extensions_files_patch_l def repo(): """Imports Triton.""" - TRITON_COMMIT = "cl667508787" - TRITON_SHA256 = "3eb8f4172f8976be0a21003724c0fd9c2664409ed4e64cec305abcf1f4f6dd3a" + TRITON_COMMIT = "cl673813747" + TRITON_SHA256 = "3e901c1b441407b1b7ac601092f64a9141571879b00a1ff54437c8e9370a365f" tf_http_archive( name = "triton", sha256 = TRITON_SHA256, diff --git a/third_party/triton/xla_extensions/sparse_dot.patch b/third_party/triton/xla_extensions/sparse_dot.patch index dadc7732a4f28..8f613badb5398 
100644 --- a/third_party/triton/xla_extensions/sparse_dot.patch +++ b/third_party/triton/xla_extensions/sparse_dot.patch @@ -2,7 +2,7 @@ diff --git a/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td b/include/ index 04ba95196..192d8fab4 100644 --- a/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td +++ b/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td -@@ -1302,6 +1302,18 @@ elements along the K dim, or they use all elements of the tensor along the K dim +@@ -1349,6 +1349,18 @@ elements along the K dim, or they use all elements of the tensor along the K dim }]; } @@ -33,7 +33,7 @@ index a87e1c44a..456a4f224 100644 include "mlir/IR/OpBase.td" include "mlir/Interfaces/SideEffectInterfaces.td" // Pure include "mlir/Interfaces/InferTypeOpInterface.td" // SameOperandsAndResultType -@@ -238,4 +239,19 @@ def TTG_LocalStoreOp : TTG_Op<"local_store", [DeclareOpInterfaceMethods shape, +@@ -497,6 +497,123 @@ getDefaultBlockedEncoding(MLIRContext *context, ArrayRef shape, return encoding; } @@ -280,7 +280,7 @@ index d74e0a224..4e45f7c4c 100644 static void createAsyncCopy(scf::ForOp &forOp, tt::LoadOp loadOp, Value alloc, Value insertIdx, Value extractIdx, tt::CoarseSchedule &schedule, -@@ -248,19 +252,28 @@ getSharedEncIfAllUsersAreDotEnc(Value val, bool &incompatible) { +@@ -236,19 +240,28 @@ getSharedEncIfAllUsersAreDotEnc(Value val, bool &incompatible) { } else { if (!isa(user)) return std::nullopt; @@ -321,7 +321,7 @@ index d74e0a224..4e45f7c4c 100644 } // Check that the shared encodings needed by the users are compatible. if (attr != nullptr && attr != tempAttr) { -@@ -369,7 +382,7 @@ loadOpsToIndirectionLevelAndUse(scf::ForOp forOp) { +@@ -357,7 +370,7 @@ loadOpsToIndirectionLevelAndUse(scf::ForOp forOp) { }; for (Operation &op : forOp.getBody()->without_terminator()) { @@ -330,7 +330,7 @@ index d74e0a224..4e45f7c4c 100644 continue; seen.clear(); dfs(&op, 0, &op); -@@ -446,7 +459,7 @@ assignMemoryLayouts(llvm::SmallVector> +@@ -434,7 +447,7 @@ assignMemoryLayouts(llvm::SmallVector> continue; } @@ -339,7 +339,7 @@ index d74e0a224..4e45f7c4c 100644 loadInfo.usedByDot = true; if (loadIsMMAv3(op)) { loadInfo.loadIsMMAV3 = true; -@@ -472,7 +485,7 @@ assignMemoryLayouts(llvm::SmallVector> +@@ -460,7 +473,7 @@ assignMemoryLayouts(llvm::SmallVector> // The codegen bug is caught by an assertion, so if you think you've // fixed it, feel free to delete this code and see if the assert still // fails. :) @@ -375,7 +375,7 @@ index 7affd8840..52aa2c131 100644 --- a/third_party/nvidia/include/Dialect/NVGPU/IR/NVGPUOps.td +++ b/third_party/nvidia/include/Dialect/NVGPU/IR/NVGPUOps.td @@ -87,6 +87,15 @@ def NVGPU_WGMMAOp : NVGPU_Op<"wgmma", []> { - let assemblyFormat = "$opA `,` $opB (`,` $opC^)? attr-dict `:` functional-type(operands, $res)"; + let assemblyFormat = "$opA `,` $opB `,` $useC (`,` $opC^)? attr-dict `:` functional-type(operands, $res)"; } +def NVGPU_SparseWGMMAOp : NVGPU_Op<"wgmma_sp", []> { @@ -409,11 +409,11 @@ diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TritonGPUToLLVM.cpp b/ index f2742218e..4cb1fae93 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TritonGPUToLLVM.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TritonGPUToLLVM.cpp -@@ -68,6 +68,7 @@ public: +@@ -69,6 +69,7 @@ public: addIllegalDialect(); addIllegalDialect(); addLegalOp(); + addLegalOp(); // Rewritten in a separate pass. } }; - +