Skip to content

Commit

Permalink
Add integer range inference to hal.buffer_view.dim and rank ops. (#18943
Browse files Browse the repository at this point in the history
)

This matches that default range behavior of runtime dimensions we get
from frontends.

---------

Signed-off-by: Stella Laurenzo <[email protected]>
  • Loading branch information
stellaraccident authored Oct 30, 2024
1 parent 49ffdac commit d1dd3e3
Show file tree
Hide file tree
Showing 4 changed files with 85 additions and 2 deletions.
45 changes: 45 additions & 0 deletions compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,27 @@

namespace mlir::iree_compiler::IREE::HAL {

namespace {

// We aribtrarily say that unbounded dimensions in a torch program cannot
// exceed 53bits, making the maximum safe dimension 9007199254740991. The
// astute reader will note that this is also the maximum safe value in
// JavaScript, which also "happens" to be the largest mantissa value in a
// 64bit double. We need a maximum and in the absence of a better choice,
// with this one we are at least in good company. This limit is also used
// in the frontends.
static constexpr uint64_t MAX_DIM_VALUE = (static_cast<uint64_t>(1) << 53) - 1;

// Similarly we use a very conservative maximum rank value for specifying
// ranges of runtime rank resolution functions. Various frameworks have hard
// and practical limits ranging from 32 (numpy) to hundreds. At the time of
// writing, PyTorch throws weird errors if trying to print a tensor with a rank
// greater than 992. We really just want a smallish integer value to bound
// arithmetic, so we use an arbitrary maximum.
static constexpr uint64_t MAX_RANK_VALUE = 4096;

} // namespace

//===----------------------------------------------------------------------===//
// custom<DescriptorType>($descriptor_type)
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -1024,6 +1045,30 @@ void BufferViewBufferOp::getAsmResultNames(
setNameFn(getResult(), "buffer");
}

//===----------------------------------------------------------------------===//
// hal.buffer_view.dim
//===----------------------------------------------------------------------===//

void BufferViewDimOp::inferResultRangesFromOptional(
ArrayRef<IntegerValueRange> argRanges, SetIntLatticeFn setResultRange) {
const unsigned indexTypeNumBits = 64;
setResultRange(getResult(), IntegerValueRange(ConstantIntRanges::fromUnsigned(
APInt::getZero(indexTypeNumBits),
APInt(indexTypeNumBits, MAX_DIM_VALUE))));
}

//===----------------------------------------------------------------------===//
// hal.buffer_view.dim
//===----------------------------------------------------------------------===//

void BufferViewRankOp::inferResultRangesFromOptional(
ArrayRef<IntegerValueRange> argRanges, SetIntLatticeFn setResultRange) {
const unsigned indexTypeNumBits = 64;
setResultRange(getResult(), IntegerValueRange(ConstantIntRanges::fromUnsigned(
APInt::getZero(indexTypeNumBits),
APInt(indexTypeNumBits, MAX_RANK_VALUE))));
}

//===----------------------------------------------------------------------===//
// hal.channel.create
//===----------------------------------------------------------------------===//
Expand Down
1 change: 1 addition & 0 deletions compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/Interfaces/InferIntRangeInterface.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Interfaces/ViewLikeInterface.h"

Expand Down
11 changes: 9 additions & 2 deletions compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ include "mlir/IR/OpAsmInterface.td"
include "mlir/IR/SymbolInterfaces.td"
include "mlir/Interfaces/CallInterfaces.td"
include "mlir/Interfaces/ControlFlowInterfaces.td"
include "mlir/Interfaces/InferIntRangeInterface.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/Interfaces/ViewLikeInterface.td"

Expand Down Expand Up @@ -1010,7 +1011,10 @@ def HAL_BufferViewEncodingTypeOp : HAL_PureOp<"buffer_view.encoding_type"> {
}];
}

def HAL_BufferViewRankOp : HAL_PureOp<"buffer_view.rank"> {
def HAL_BufferViewRankOp : HAL_PureOp<"buffer_view.rank", [
DeclareOpInterfaceMethods<InferIntRangeInterface,
["inferResultRangesFromOptional"]>,
]> {
let summary = [{buffer view rank query}];
let description = [{
Returns the rank of the buffer view.
Expand All @@ -1030,7 +1034,10 @@ def HAL_BufferViewRankOp : HAL_PureOp<"buffer_view.rank"> {
}];
}

def HAL_BufferViewDimOp : HAL_PureOp<"buffer_view.dim"> {
def HAL_BufferViewDimOp : HAL_PureOp<"buffer_view.dim", [
DeclareOpInterfaceMethods<InferIntRangeInterface,
["inferResultRangesFromOptional"]>,
]> {
let summary = [{buffer view dimension value query}];
let description = [{
Returns the value of the given dimension.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -493,3 +493,33 @@ util.func @util_align_zero(%arg0 : i64) -> i64 {
%rem16 = arith.remui %0, %c16 : i64
util.return %rem16 : i64
}

// -----

util.func @hal_buffer_view_dim_min_max(%bv : !hal.buffer_view) -> (i1, i1, i1) {
%zero = arith.constant 0 : index
%max = arith.constant 9007199254740991 : index
%0 = hal.buffer_view.dim<%bv : !hal.buffer_view>[0] : index
%1 = arith.cmpi slt, %0, %zero : index
%2 = arith.cmpi uge, %0, %zero : index
%3 = arith.cmpi ugt, %0, %max : index
// CHECK-DAG: %[[FALSE:.*]] = arith.constant false
// CHECK-DAG: %[[TRUE:.*]] = arith.constant true
// CHECK: util.return %[[FALSE]], %[[TRUE]], %[[FALSE]]
util.return %1, %2, %3 : i1, i1, i1
}

// -----

util.func @hal_buffer_view_rank_min_max(%bv : !hal.buffer_view) -> (i1, i1, i1) {
%zero = arith.constant 0 : index
%max = arith.constant 4096 : index
%0 = hal.buffer_view.rank<%bv : !hal.buffer_view> : index
%1 = arith.cmpi slt, %0, %zero : index
%2 = arith.cmpi uge, %0, %zero : index
%3 = arith.cmpi ugt, %0, %max : index
// CHECK-DAG: %[[FALSE:.*]] = arith.constant false
// CHECK-DAG: %[[TRUE:.*]] = arith.constant true
// CHECK: util.return %[[FALSE]], %[[TRUE]], %[[FALSE]]
util.return %1, %2, %3 : i1, i1, i1
}

0 comments on commit d1dd3e3

Please sign in to comment.