Skip to content

Commit

Permalink
Automerge: [MLIR][NVVM] Declare InferIntRangeInterface for RangeableR…
Browse files Browse the repository at this point in the history
…egisterOp (#122263)
  • Loading branch information
grypp authored and github-actions[bot] committed Jan 10, 2025
2 parents 2533bb2 + 66e41a1 commit 314646d
Show file tree
Hide file tree
Showing 4 changed files with 61 additions and 2 deletions.
1 change: 1 addition & 0 deletions mlir/include/mlir/Dialect/LLVMIR/NVVMDialect.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/Interfaces/InferIntRangeInterface.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "llvm/IR/IntrinsicsNVPTX.h"

Expand Down
16 changes: 14 additions & 2 deletions mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ include "mlir/Dialect/GPU/IR/CompilationAttrInterfaces.td"
include "mlir/Dialect/LLVMIR/LLVMOpBase.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/Dialect/LLVMIR/BasicPtxBuilderInterface.td"
include "mlir/Interfaces/InferIntRangeInterface.td"

def LLVM_PointerGeneric : LLVM_PointerInAddressSpace<0>;
def LLVM_PointerGlobal : LLVM_PointerInAddressSpace<1>;
Expand Down Expand Up @@ -134,8 +135,8 @@ class NVVM_SpecialRegisterOp<string mnemonic, list<Trait> traits = []> :
let assemblyFormat = "attr-dict `:` type($res)";
}

class NVVM_SpecialRangeableRegisterOp<string mnemonic, list<Trait> traits = []> :
NVVM_SpecialRegisterOp<mnemonic, traits> {
class NVVM_SpecialRangeableRegisterOp<string mnemonic> :
NVVM_SpecialRegisterOp<mnemonic, [DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>]> {
let arguments = (ins OptionalAttr<LLVM_ConstantRangeAttr>:$range);
let assemblyFormat = "(`range` $range^)? attr-dict `:` type($res)";
let llvmBuilder = baseLlvmBuilder # setRangeRetAttrCode # baseLlvmBuilderCoda;
Expand All @@ -147,6 +148,17 @@ class NVVM_SpecialRangeableRegisterOp<string mnemonic, list<Trait> traits = []>
build($_builder, $_state, resultType, ::mlir::LLVM::ConstantRangeAttr{});
}]>
];

// Define this method for the InferIntRangeInterface.
let extraClassDefinition = [{
// Infer the result ranges based on the range attribute.
void $cppClass::inferResultRanges(
ArrayRef<::mlir::ConstantIntRanges> argRanges,
SetIntRangeFn setResultRanges) {
nvvmInferResultRanges(getOperation(), getResult(), argRanges, setResultRanges);
}
}];

}

//===----------------------------------------------------------------------===//
Expand Down
11 changes: 11 additions & 0 deletions mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1158,6 +1158,17 @@ llvm::Intrinsic::ID CpAsyncBulkTensorReduceOp::getIntrinsicID(
llvm_unreachable("Invalid Reduction Op for CpAsyncBulkTensorReduceOp");
}

/// Infer the result ranges for the NVVM SpecialRangeableRegisterOp that might
/// have ConstantRangeAttr.
static void nvvmInferResultRanges(Operation *op, Value result,
ArrayRef<::mlir::ConstantIntRanges> argRanges,
SetIntRangeFn setResultRanges) {
if (auto rangeAttr = op->getAttrOfType<LLVM::ConstantRangeAttr>("range")) {
setResultRanges(result, {rangeAttr.getLower(), rangeAttr.getUpper(),
rangeAttr.getLower(), rangeAttr.getUpper()});
}
}

//===----------------------------------------------------------------------===//
// NVVMDialect initialization, type parsing, and registration.
//===----------------------------------------------------------------------===//
Expand Down
35 changes: 35 additions & 0 deletions mlir/test/Dialect/LLVMIR/nvvm-test-range.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
// RUN: mlir-opt -int-range-optimizations %s | FileCheck %s
gpu.module @module{
gpu.func @kernel_1() kernel {
%tidx = nvvm.read.ptx.sreg.tid.x range <i32, 0, 32> : i32
%tidy = nvvm.read.ptx.sreg.tid.y range <i32, 0, 128> : i32
%tidz = nvvm.read.ptx.sreg.tid.z range <i32, 0, 4> : i32
%c64 = arith.constant 64 : i32

%1 = arith.cmpi sgt, %tidx, %c64 : i32
scf.if %1 {
gpu.printf "threadidx"
}
%2 = arith.cmpi sgt, %tidy, %c64 : i32
scf.if %2 {
gpu.printf "threadidy"
}
%3 = arith.cmpi sgt, %tidz, %c64 : i32
scf.if %3 {
gpu.printf "threadidz"
}
gpu.return
}
}

// CHECK-LABEL: gpu.func @kernel_1
// CHECK: %[[false:.+]] = arith.constant false
// CHECK: %[[c64_i32:.+]] = arith.constant 64 : i32
// CHECK: %[[S0:.+]] = nvvm.read.ptx.sreg.tid.y range <i32, 0, 128> : i32
// CHECK: scf.if %[[false]] {
// CHECK: gpu.printf "threadidx"
// CHECK: %[[S1:.+]] = arith.cmpi sgt, %[[S0]], %[[c64_i32]] : i32
// CHECK: scf.if %[[S1]] {
// CHECK: gpu.printf "threadidy"
// CHECK: scf.if %[[false]] {
// CHECK: gpu.printf "threadidz"

0 comments on commit 314646d

Please sign in to comment.