From 9b7bd6771fff9ef3a0fac8d08b04035db947c237 Mon Sep 17 00:00:00 2001 From: Jez Ng Date: Tue, 6 Aug 2024 11:35:51 -0400 Subject: [PATCH] Make tl.debug_barrier() a no-op on CPU (#89) --- .../lib/TritonCPUToLLVM/DebugOpsToLLVM.cpp | 20 +++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/third_party/cpu/lib/TritonCPUToLLVM/DebugOpsToLLVM.cpp b/third_party/cpu/lib/TritonCPUToLLVM/DebugOpsToLLVM.cpp index e6b6a531059c..2bad397c9b77 100644 --- a/third_party/cpu/lib/TritonCPUToLLVM/DebugOpsToLLVM.cpp +++ b/third_party/cpu/lib/TritonCPUToLLVM/DebugOpsToLLVM.cpp @@ -3,6 +3,8 @@ #include "cpu/include/TritonCPUToLLVM/Passes.h" +#include "mlir/Dialect/GPU/IR/GPUOps.h.inc" + #include "triton/Conversion/TritonGPUToLLVM/Utility.h" #include "triton/Dialect/Triton/IR/Dialect.h" #include "triton/Dialect/TritonCPU/IR/Dialect.h" @@ -164,6 +166,23 @@ struct PrintOpConversion : public ConvertOpToLLVMPattern { } }; +using BarrierOp = mlir::gpu::BarrierOp; + +// This is part of the DebugOps pass because gpu::barrier is generated by +// tl.debug_barrier. +struct BarrierOpConversion : public ConvertOpToLLVMPattern { + explicit BarrierOpConversion(LLVMTypeConverter &typeConverter) + : mlir::ConvertOpToLLVMPattern(typeConverter) {} + + LogicalResult + matchAndRewrite(BarrierOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + // Just make it a no-op for now + rewriter.eraseOp(op); + return success(); + } +}; + struct DebugOpsToLLVM : public triton::impl::DebugOpsToLLVMBase { using DebugOpsToLLVMBase::DebugOpsToLLVMBase; @@ -180,6 +199,7 @@ struct DebugOpsToLLVM RewritePatternSet patterns(context); patterns.add(typeConverter); + patterns.add(typeConverter); // patterns.add(typeConverter); if (failed(applyPartialConversion(mod, convTarget, std::move(patterns)))) {