From 66e41a1a20f2190a800669028a0e80bd86e735ce Mon Sep 17 00:00:00 2001
From: Guray Ozen
Date: Fri, 10 Jan 2025 10:32:25 +0100
Subject: [PATCH] [MLIR][NVVM] Declare InferIntRangeInterface for RangeableRegisterOp (#122263)

---
Reviewer notes (text in this section is not part of the commit):

* What the patch does: NVVM_SpecialRangeableRegisterOp now declares the
  InferIntRangeInterface methods. The generated inferResultRanges() forwards
  to the file-local helper nvvmInferResultRanges(), which reads the optional
  `range` attribute and, when it is present, reports its lower/upper values
  as both the unsigned and the signed bounds of the result. When the
  attribute is absent no range is reported.
* The new test runs -int-range-optimizations and expects the `sgt 64`
  comparisons on tid.x and tid.z to fold to `false`, while the comparison on
  tid.y is kept.
* NOTE(review): the `range <...>` operands in the test, the template
  arguments in the TableGen/C++ hunks and the indentation of context lines
  were reconstructed; confirm them against the upstream commit before
  applying. The author e-mail address is missing from the From: header.
* NOTE(review): LLVM's ConstantRangeAttr is documented as the half-open
  interval [lower, upper), whereas ConstantIntRanges bounds look inclusive.
  Passing getUpper() unchanged presumably widens the inferred range by one;
  confirm whether that is intended.

 .../include/mlir/Dialect/LLVMIR/NVVMDialect.h |  1 +
 mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td   | 16 +++++++--
 mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp    | 11 ++++++
 mlir/test/Dialect/LLVMIR/nvvm-test-range.mlir | 35 +++++++++++++++++++
 4 files changed, 61 insertions(+), 2 deletions(-)
 create mode 100644 mlir/test/Dialect/LLVMIR/nvvm-test-range.mlir

diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMDialect.h b/mlir/include/mlir/Dialect/LLVMIR/NVVMDialect.h
index 4fd00ff929bd..50d1a39126ea 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMDialect.h
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMDialect.h
@@ -19,6 +19,7 @@
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/IR/Dialect.h"
 #include "mlir/IR/OpDefinition.h"
+#include "mlir/Interfaces/InferIntRangeInterface.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 #include "llvm/IR/IntrinsicsNVPTX.h"
 
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index a2d2102b59de..0b9097e9bbca 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -18,6 +18,7 @@ include "mlir/Dialect/GPU/IR/CompilationAttrInterfaces.td"
 include "mlir/Dialect/LLVMIR/LLVMOpBase.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
 include "mlir/Dialect/LLVMIR/BasicPtxBuilderInterface.td"
+include "mlir/Interfaces/InferIntRangeInterface.td"
 
 def LLVM_PointerGeneric : LLVM_PointerInAddressSpace<0>;
 def LLVM_PointerGlobal : LLVM_PointerInAddressSpace<1>;
@@ -134,8 +135,8 @@ class NVVM_SpecialRegisterOp<string mnemonic, list<Trait> traits = []> :
   let assemblyFormat = "attr-dict `:` type($res)";
 }
 
-class NVVM_SpecialRangeableRegisterOp<string mnemonic, list<Trait> traits = []> :
-  NVVM_SpecialRegisterOp<mnemonic, traits> {
+class NVVM_SpecialRangeableRegisterOp<string mnemonic> :
+  NVVM_SpecialRegisterOp<mnemonic, [DeclareOpInterfaceMethods<InferIntRangeInterface>]> {
   let arguments = (ins OptionalAttr<LLVM_ConstantRangeAttr>:$range);
   let assemblyFormat = "(`range` $range^)? attr-dict `:` type($res)";
   let llvmBuilder = baseLlvmBuilder # setRangeRetAttrCode # baseLlvmBuilderCoda;
@@ -147,6 +148,17 @@ class NVVM_SpecialRangeableRegisterOp<string mnemonic, list<Trait> traits = []>
       build($_builder, $_state, resultType, ::mlir::LLVM::ConstantRangeAttr{});
     }]>
   ];
+
+  // Define this method for the InferIntRangeInterface.
+  let extraClassDefinition = [{
+    // Infer the result ranges based on the range attribute.
+    void $cppClass::inferResultRanges(
+        ArrayRef<::mlir::ConstantIntRanges> argRanges,
+        SetIntRangeFn setResultRanges) {
+        nvvmInferResultRanges(getOperation(), getResult(), argRanges, setResultRanges);
+    }
+  }];
+
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index 8b09c0f386d6..838159d67654 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -1158,6 +1158,17 @@ llvm::Intrinsic::ID CpAsyncBulkTensorReduceOp::getIntrinsicID(
   llvm_unreachable("Invalid Reduction Op for CpAsyncBulkTensorReduceOp");
 }
 
+/// Infer the result ranges for the NVVM SpecialRangeableRegisterOp that might
+/// have ConstantRangeAttr.
+static void nvvmInferResultRanges(Operation *op, Value result,
+                                  ArrayRef<::mlir::ConstantIntRanges> argRanges,
+                                  SetIntRangeFn setResultRanges) {
+  if (auto rangeAttr = op->getAttrOfType<LLVM::ConstantRangeAttr>("range")) {
+    setResultRanges(result, {rangeAttr.getLower(), rangeAttr.getUpper(),
+                             rangeAttr.getLower(), rangeAttr.getUpper()});
+  }
+}
+
 //===----------------------------------------------------------------------===//
 // NVVMDialect initialization, type parsing, and registration.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/LLVMIR/nvvm-test-range.mlir b/mlir/test/Dialect/LLVMIR/nvvm-test-range.mlir
new file mode 100644
index 000000000000..fae40dc7806b
--- /dev/null
+++ b/mlir/test/Dialect/LLVMIR/nvvm-test-range.mlir
@@ -0,0 +1,35 @@
+// RUN: mlir-opt -int-range-optimizations %s | FileCheck %s
+gpu.module @module{
+  gpu.func @kernel_1() kernel {
+    %tidx = nvvm.read.ptx.sreg.tid.x range <i32, 0, 32> : i32
+    %tidy = nvvm.read.ptx.sreg.tid.y range <i32, 0, 128> : i32
+    %tidz = nvvm.read.ptx.sreg.tid.z range <i32, 0, 4> : i32
+    %c64 = arith.constant 64 : i32
+
+    %1 = arith.cmpi sgt, %tidx, %c64 : i32
+    scf.if %1 {
+      gpu.printf "threadidx"
+    }
+    %2 = arith.cmpi sgt, %tidy, %c64 : i32
+    scf.if %2 {
+      gpu.printf "threadidy"
+    }
+    %3 = arith.cmpi sgt, %tidz, %c64 : i32
+    scf.if %3 {
+      gpu.printf "threadidz"
+    }
+    gpu.return
+  }
+}
+
+// CHECK-LABEL: gpu.func @kernel_1
+// CHECK: %[[false:.+]] = arith.constant false
+// CHECK: %[[c64_i32:.+]] = arith.constant 64 : i32
+// CHECK: %[[S0:.+]] = nvvm.read.ptx.sreg.tid.y range <i32, 0, 128> : i32
+// CHECK: scf.if %[[false]] {
+// CHECK: gpu.printf "threadidx"
+// CHECK: %[[S1:.+]] = arith.cmpi sgt, %[[S0]], %[[c64_i32]] : i32
+// CHECK: scf.if %[[S1]] {
+// CHECK: gpu.printf "threadidy"
+// CHECK: scf.if %[[false]] {
+// CHECK: gpu.printf "threadidz"