diff --git a/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp b/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp index ef0508e7ef5f0ea..b11511f21d03d4d 100644 --- a/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp +++ b/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp @@ -1057,17 +1057,21 @@ static LLVM::CallOp createSPIRVBuiltinCall(Location loc, OpBuilder &builder, return call; } -class ControlBarrierPattern - : public SPIRVToLLVMConversion { +template +class ControlBarrierPattern : public SPIRVToLLVMConversion { public: - using SPIRVToLLVMConversion::SPIRVToLLVMConversion; + using OpAdaptor = typename SPIRVToLLVMConversion::OpAdaptor; + + using SPIRVToLLVMConversion::SPIRVToLLVMConversion; + + static constexpr StringRef getFuncName(); LogicalResult - matchAndRewrite(spirv::ControlBarrierOp controlBarrierOp, OpAdaptor adaptor, + matchAndRewrite(BarrierOpTy controlBarrierOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - constexpr StringLiteral funcName = "_Z22__spirv_ControlBarrieriii"; + constexpr StringRef funcName = getFuncName(); Operation *symbolTable = - controlBarrierOp->getParentWithTrait(); + controlBarrierOp->template getParentWithTrait(); Type i32 = rewriter.getI32Type(); @@ -1266,6 +1270,24 @@ class GroupReducePattern : public SPIRVToLLVMConversion { } }; +template <> +constexpr StringRef +ControlBarrierPattern::getFuncName() { + return "_Z22__spirv_ControlBarrieriii"; +} + +template <> +constexpr StringRef +ControlBarrierPattern::getFuncName() { + return "_Z33__spirv_ControlBarrierArriveINTELiii"; +} + +template <> +constexpr StringRef +ControlBarrierPattern::getFuncName() { + return "_Z31__spirv_ControlBarrierWaitINTELiii"; +} + /// Converts `spirv.mlir.loop` to LLVM dialect. All blocks within selection /// should be reachable for conversion to succeed. The structure of the loop in /// LLVM dialect will be the following: @@ -1899,7 +1921,9 @@ void mlir::populateSPIRVToLLVMConversionPatterns( ReturnPattern, ReturnValuePattern, // Barrier ops - ControlBarrierPattern, + ControlBarrierPattern, + ControlBarrierPattern, + ControlBarrierPattern, // Group reduction operations GroupReducePattern, diff --git a/mlir/test/Conversion/SPIRVToLLVM/barrier-ops-to-llvm.mlir b/mlir/test/Conversion/SPIRVToLLVM/barrier-ops-to-llvm.mlir index d53afeeea15d103..a5cae67a3d5c5d2 100644 --- a/mlir/test/Conversion/SPIRVToLLVM/barrier-ops-to-llvm.mlir +++ b/mlir/test/Conversion/SPIRVToLLVM/barrier-ops-to-llvm.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt -convert-spirv-to-llvm %s | FileCheck %s +// RUN: mlir-opt -convert-spirv-to-llvm -split-input-file %s | FileCheck %s //===----------------------------------------------------------------------===// // spirv.ControlBarrierOp @@ -21,3 +21,28 @@ spirv.func @control_barrier() "None" { spirv.ControlBarrier , , spirv.Return } + +// ----- + +//===----------------------------------------------------------------------===// +// spirv.INTEL.SplitBarrier +//===----------------------------------------------------------------------===// + +// CHECK-DAG: llvm.func spir_funccc @_Z33__spirv_ControlBarrierArriveINTELiii(i32, i32, i32) attributes {convergent, no_unwind, will_return} +// CHECK-DAG: llvm.func spir_funccc @_Z31__spirv_ControlBarrierWaitINTELiii(i32, i32, i32) attributes {convergent, no_unwind, will_return} + +// CHECK-LABEL: @split_barrier +spirv.func @split_barrier() "None" { + // CHECK: [[EXECUTION:%.*]] = llvm.mlir.constant(2 : i32) : i32 + // CHECK: [[MEMORY:%.*]] = llvm.mlir.constant(2 : i32) : i32 + // CHECK: [[SEMANTICS:%.*]] = llvm.mlir.constant(768 : i32) : i32 + // CHECK: llvm.call spir_funccc @_Z33__spirv_ControlBarrierArriveINTELiii([[EXECUTION]], [[MEMORY]], [[SEMANTICS]]) {convergent, no_unwind, will_return} : (i32, i32, i32) -> () + spirv.INTEL.ControlBarrierArrive , , + + // CHECK: [[EXECUTION:%.*]] = llvm.mlir.constant(2 : i32) : i32 + // CHECK: [[MEMORY:%.*]] = llvm.mlir.constant(2 : i32) : i32 + // CHECK: [[SEMANTICS:%.*]] = llvm.mlir.constant(256 : i32) : i32 + // CHECK: llvm.call spir_funccc @_Z31__spirv_ControlBarrierWaitINTELiii([[EXECUTION]], [[MEMORY]], [[SEMANTICS]]) {convergent, no_unwind, will_return} : (i32, i32, i32) -> () + spirv.INTEL.ControlBarrierWait , , + spirv.Return +}