Skip to content

Commit

Permalink
Calculate storage bytes through interface method for encoding types. (iree-org#19413)
Browse files Browse the repository at this point in the history

The revision moves the implementation of storage bytes calculation to
`EncodingAttr::calculateStorageSizeInBytes`. If the encoding attribute
implements the interface, the interface implementation takes priority,
because it knows all the layout details, including whether the data is
packed back-to-back or not.

The change is not NFC because it also fixes a bug for dynamic cases. The
`dynamicDims` value range is not a mixed value range. It is only for
dynamic cases. To make the logic correct, we need to use
`getDynamicDimIndex()` to get the corresponding dimension index before
the update.

The revision duplicates two methods from Util to Encoding dialect
because we do not want the dependency (i.e., encoding -> util):
  - getTypeBitWidth
  - getRoundedElementByteWidth

The arguments of the `calculateStorageSizeInBytes` method are changed
to accommodate these needs.

---------

Signed-off-by: hanhanW <[email protected]>
  • Loading branch information
hanhanW authored Dec 13, 2024
1 parent 900ef1d commit c618134
Show file tree
Hide file tree
Showing 5 changed files with 147 additions and 50 deletions.
107 changes: 104 additions & 3 deletions compiler/src/iree/compiler/Dialect/Encoding/IR/EncodingAttrs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/TypeSwitch.h"
#include "mlir/Dialect/Affine/Utils.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/IR/Attributes.h"
Expand Down Expand Up @@ -41,7 +42,7 @@ EncodingAttr EncodingAttr::get(MLIRContext *ctx, int64_t operandIndex,
bcastMapAttr, roundDimsToAttr, layoutsAttr);
}

AffineMap EncodingAttr::getMapForOperandIndex() {
AffineMap EncodingAttr::getMapForOperandIndex() const {
auto index = getOperandIndex().getValue().getZExtValue();
switch (index) {
case MATMUL_LHS:
Expand All @@ -59,7 +60,8 @@ AffineMap EncodingAttr::getMapForOperandIndex() {
}
}

std::optional<unsigned> EncodingAttr::mapDimToOperandIndex(int64_t dimPos) {
/// Returns the result position, in this operand's indexing map, that the
/// iteration-space dimension `dimPos` maps to, or std::nullopt if the
/// dimension does not appear in the map (i.e., this operand does not carry
/// that dimension).
std::optional<unsigned>
EncodingAttr::mapDimToOperandIndex(int64_t dimPos) const {
  return getMapForOperandIndex().getResultPosition(
      getAffineDimExpr(dimPos, getContext()));
}
Expand Down Expand Up @@ -91,7 +93,7 @@ MatmulNarrowDim getMatmulNarrowDim(linalg::LinalgOp linalgOp,
return (narrowM && (!narrowN || mSize <= nSize)) ? narrowM : narrowN;
}

ArrayRef<int64_t> EncodingAttr::getRoundDimsToArray() {
ArrayRef<int64_t> EncodingAttr::getRoundDimsToArray() const {
auto roundDimsTo = getRoundDimsTo();
if (!roundDimsTo) {
return {};
Expand All @@ -111,6 +113,105 @@ EncodingAttr EncodingAttr::clone(AffineMap bcastMap) {
AffineMapAttr::get(bcastMap), getRoundDimsTo(), getLayouts());
}

/// Returns the bit-width of the scalar type. For a complex type this is twice
/// the element bit-width (one real and one imaginary component).
static unsigned getTypeBitWidth(Type type) {
  auto complexType = dyn_cast<ComplexType>(type);
  if (!complexType) {
    return type.getIntOrFloatBitWidth();
  }
  // Complex values store two scalar components back-to-back.
  return complexType.getElementType().getIntOrFloatBitWidth() * 2;
}

/// Returns the number of bytes an element of the given type occupies in
/// memory. This is in the default dense conversion to machine words where
/// sizes must be powers of two aligned to bytes.
///
/// Examples:
///   getRoundedElementByteWidth(i1) = 1
///   getRoundedElementByteWidth(i23) = 4
///   getRoundedElementByteWidth(i32) = 4
///   getRoundedElementByteWidth(bf16) = 2
///   getRoundedElementByteWidth(i33) = 8
///   getRoundedElementByteWidth(complex<f32>) = 8
static int32_t getRoundedElementByteWidth(Type type) {
  const unsigned numBits = getTypeBitWidth(type);
  assert(numBits > 0 && "0-width types unsupported");
  // Bits -> bytes, rounding up to a whole byte.
  const unsigned numBytes = (numBits + 7) / 8;
  // Storage uses power-of-two machine word sizes.
  return llvm::PowerOf2Ceil(numBytes);
}

Value EncodingAttr::calculateStorageSizeInBytes(Location loc,
                                                OpBuilder &builder,
                                                RankedTensorType type,
                                                ValueRange dynamicDims) const {
  // Start from the unpadded shape/dims; entries are padded in-place below
  // when `round_dims_to` is present.
  SmallVector<int64_t> paddedShape(type.getShape());
  SmallVector<Value> paddedDynamicDims(dynamicDims.begin(), dynamicDims.end());
  ArrayRef<int64_t> roundDimsTo = getRoundDimsToArray();
  FailureOr<linalg::ContractionDimensions> cDims =
      getEncodingContractionDims(*this);
  // Padding requires both the alignment values and the inferred contraction
  // dimensions. Without either (e.g., `round_dims_to` is absent, or the
  // indexing maps do not form a contraction), fall through and compute the
  // storage size of the unpadded shape instead of crashing on an empty array
  // or a failed FailureOr.
  if (!roundDimsTo.empty() && succeeded(cDims)) {
    assert(roundDimsTo.size() == 3 &&
           "expected round_dims_to to carry m, n, k alignment values");
    // Pads iteration-space dim `dim` up to a multiple of `value` if it maps
    // to a dimension of this operand's tensor.
    auto pad = [&](int dim, int value) {
      std::optional<unsigned> maybeMappedDim = mapDimToOperandIndex(dim);
      if (!maybeMappedDim) {
        return;
      }
      unsigned mappedDim = maybeMappedDim.value();
      if (type.isDynamicDim(mappedDim)) {
        // `dynamicDims` only carries values for dynamic dims, so translate
        // the tensor dim index into the dynamic-dims index space.
        mappedDim = type.getDynamicDimIndex(mappedDim);
        auto alignment = builder.create<arith::ConstantIndexOp>(loc, value);
        paddedDynamicDims[mappedDim] = builder.create<arith::CeilDivUIOp>(
            loc, paddedDynamicDims[mappedDim], alignment);
        paddedDynamicDims[mappedDim] = builder.create<arith::MulIOp>(
            loc, paddedDynamicDims[mappedDim], alignment);
      } else {
        paddedShape[mappedDim] = llvm::alignTo(paddedShape[mappedDim], value);
      }
    };
    for (auto m : cDims->m) {
      pad(m, roundDimsTo[0]);
    }
    for (auto n : cDims->n) {
      pad(n, roundDimsTo[1]);
    }
    for (auto k : cDims->k) {
      pad(k, roundDimsTo[2]);
    }
  }

  constexpr int64_t kNumBitsInByte = 8;
  unsigned elementBits = getTypeBitWidth(type.getElementType());
  // Sub-byte element types are packed back-to-back (handled at the end), so
  // only round the element size up to whole bytes when it is >= one byte.
  int64_t numBytesPerElem = 1;
  if (elementBits > kNumBitsInByte) {
    numBytesPerElem *= getRoundedElementByteWidth(type.getElementType());
  }

  // Fold all (padded) static dims and the element byte width into a single
  // constant.
  int64_t staticCount = numBytesPerElem;
  for (unsigned i = 0, e = type.getRank(); i < e; ++i) {
    if (!type.isDynamicDim(i)) {
      staticCount *= paddedShape[i];
    }
  }

  // Scale by the (padded) dynamic dims at runtime.
  Value result =
      builder.create<arith::ConstantIndexOp>(loc, staticCount).getResult();
  for (auto dim : paddedDynamicDims) {
    result = builder.create<arith::MulIOp>(loc, result, dim);
  }

  // Always pack the elements back-to-back for subtypes.
  if (elementBits < kNumBitsInByte) {
    if (kNumBitsInByte % elementBits) {
      assert(false && "unsupported subtype");
      return Value();
    }
    Value divisor = builder.create<arith::ConstantIndexOp>(
        loc, kNumBitsInByte / elementBits);
    result = builder.create<arith::CeilDivUIOp>(loc, result, divisor);
  }

  return result;
}

MatmulNarrowDim getMatmulNarrowDim(EncodingAttr encoding) {
if (encoding.getOpType().getValue() != EncodingOpType::matmul) {
return {};
Expand Down
15 changes: 10 additions & 5 deletions compiler/src/iree/compiler/Dialect/Encoding/IR/EncodingAttrs.td
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#define IREE_DIALECT_ENCODING_ATTRS

include "iree/compiler/Dialect/Encoding/IR/EncodingBase.td"
include "iree/compiler/Dialect/Encoding/IR/EncodingInterfaces.td"
include "mlir/IR/AttrTypeBase.td"
include "mlir/IR/EnumAttr.td"

Expand Down Expand Up @@ -41,10 +42,14 @@ def EncodingOpTypeAttr:
IREEEncoding_EnumAttr<EncodingOpType, "optype">;

def EncodingAttr :
IREEEncoding_Attr<"Encoding"> {
IREEEncoding_Attr<"Encoding", [
DeclareAttrInterfaceMethods<IREEEncoding_EncodingLayoutAttrInterface, [
"calculateStorageSizeInBytes",
]>
]> {
let mnemonic = "encoding";
let summary = [{information to decide how to data-tile a tensor}];
let description = [{
let description = [{
This attribute describes the change in the layout for
a given tensor to execute subsequent operations on
the tiled layout. The encoding serves as a way to
Expand Down Expand Up @@ -93,15 +98,15 @@ def EncodingAttr :
/// operand_index. The dimensions of the returned map are those of the
/// data-tiled op's iteration space, and the results of the map are in
/// the domain of the encoded tensor type.
AffineMap getMapForOperandIndex();
AffineMap getMapForOperandIndex() const;

/// Given the dim position of the encoding `user_indexing_maps`, returns the
/// matching index of the given encoding's tensor, using getMapForOperandIndex
/// bcast_map and user_indexing_map.
std::optional<unsigned> mapDimToOperandIndex(int64_t dimPos);
std::optional<unsigned> mapDimToOperandIndex(int64_t dimPos) const;

/// Returns an integer array with values in `round_dims_to`.
ArrayRef<int64_t> getRoundDimsToArray();
ArrayRef<int64_t> getRoundDimsToArray() const;

/// Returns a vector with values in `element_types`.
SmallVector<Type> getElementTypesArray();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,10 @@ def IREEEncoding_EncodingLayoutAttrInterface :
Returns the storage size (in bytes) for the tensor types with an
optional encoding.
}],
/*retTy=*/"::mlir::OpFoldResult",
/*retTy=*/"::mlir::Value",
/*methodName=*/"calculateStorageSizeInBytes",
/*args=*/(ins
"::mlir::Location":$loc,
"::mlir::OpBuilder &":$builder,
"RankedTensorType":$type,
"ValueRange":$dynamicDims
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,24 @@ util.func public @sizeof_lhs_encoding_dynamic(%arg0: index, %arg1: index) -> ind

// -----

#map = affine_map<(d0, d1, d2) -> (d0, d2)>
#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
#encoding = #iree_encoding.encoding<operand_index = 0, op_type = matmul, element_types = [f32, f32, f32], user_indexing_maps = [#map, #map1, #map2], round_dims_to = array<i64: 4, 8, 16>>
// LHS tensor with one static and one dynamic dim: the static dim (10) is
// padded to a multiple of 4 (m alignment) and folded with the f32 byte width
// into the constant 12 * 4 = 48; the dynamic dim is padded at runtime to a
// multiple of 16 (k alignment) before scaling the byte count.
util.func public @sizeof_lhs_encoding_partially_dynamic(%arg0: index) -> index {
  %0 = stream.tensor.sizeof tensor<10x?xf32, #encoding>{%arg0} : index
  util.return %0 : index
}
// CHECK-LABEL: @sizeof_lhs_encoding_partially_dynamic
// CHECK-DAG: %[[C48:.+]] = arith.constant 48 : index
// CHECK-DAG: %[[C16:.+]] = arith.constant 16 : index
// CHECK: %[[CEIL_DIV_D1:.+]] = arith.ceildivui %arg0, %[[C16]]
// CHECK: %[[PAD_D1:.+]] = arith.muli %[[CEIL_DIV_D1]], %[[C16]]
// CHECK: %[[T0:.+]] = arith.muli %[[PAD_D1]], %[[C48]]
// CHECK: return %[[T0]]

// -----

#map = affine_map<(d0, d1, d2) -> (d0, d2)>
#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
Expand Down
54 changes: 13 additions & 41 deletions compiler/src/iree/compiler/Utils/ElementPackingUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@
#include "iree/compiler/Utils/ElementPackingUtils.h"

#include "iree/compiler/Dialect/Encoding/IR/EncodingOps.h"
#include "iree/compiler/Dialect/Encoding/IR/EncodingTypes.h"
#include "iree/compiler/Dialect/Util/IR/UtilTypes.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/MathExtras.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
Expand Down Expand Up @@ -62,6 +64,14 @@ Value calculateStorageElementCountInBytes(Location loc,
RankedTensorType shapedType,
ValueRange dynamicDims,
OpBuilder &builder) {
Attribute encoding = shapedType.getEncoding();
if (auto encodingLayoutAttr =
dyn_cast_or_null<IREE::Encoding::EncodingLayoutAttrInterface>(
encoding)) {
return encodingLayoutAttr.calculateStorageSizeInBytes(
loc, builder, shapedType, dynamicDims);
}

Type alignedElementType =
legalizeStorageElementType(shapedType.getElementType());
unsigned elementBits = IREE::Util::getTypeBitWidth(alignedElementType);
Expand All @@ -72,52 +82,14 @@ Value calculateStorageElementCountInBytes(Location loc,
staticCount *= IREE::Util::getRoundedElementByteWidth(alignedElementType);
}

// TODO: Do we use makeComposedFoldedAffineApply here, so the index
// computation an be much simpler.
SmallVector<int64_t> paddedShape(shapedType.getShape());
SmallVector<Value> paddedDynamicDims(dynamicDims.begin(), dynamicDims.end());
auto encoding = IREE::Encoding::getEncodingAttr(shapedType);
if (encoding && !encoding.getRoundDimsToArray().empty()) {
auto roundDimsTo = encoding.getRoundDimsToArray();
FailureOr<linalg::ContractionDimensions> cDims =
IREE::Encoding::getEncodingContractionDims(encoding);
auto pad = [&](int dim, int value) {
std::optional<unsigned> maybeMappedDim =
encoding.mapDimToOperandIndex(dim);
if (!maybeMappedDim) {
return;
}
unsigned mappedDim = maybeMappedDim.value();
if (shapedType.isDynamicDim(mappedDim)) {
auto alignment = builder.create<arith::ConstantIndexOp>(loc, value);
paddedDynamicDims[mappedDim] = builder.create<arith::CeilDivUIOp>(
loc, paddedDynamicDims[mappedDim], alignment);
paddedDynamicDims[mappedDim] = builder.create<arith::MulIOp>(
loc, paddedDynamicDims[mappedDim], alignment);
} else {
paddedShape[mappedDim] = llvm::alignTo(paddedShape[mappedDim], value);
}
};
for (auto m : cDims->m) {
pad(m, roundDimsTo[0]);
}
for (auto n : cDims->n) {
pad(n, roundDimsTo[1]);
}
for (auto k : cDims->k) {
pad(k, roundDimsTo[2]);
}
}

for (unsigned i = 0; i < shapedType.getRank(); ++i) {
if (!shapedType.isDynamicDim(i))
staticCount *= paddedShape[i];
staticCount *= shapedType.getDimSize(i);
}

// Scale by dynamic dims, if present.
auto value =
builder.create<arith::ConstantIndexOp>(loc, staticCount).getResult();
for (auto dim : paddedDynamicDims) {
for (auto dim : dynamicDims) {
value = builder.createOrFold<arith::MulIOp>(loc, value, dim);
}
// Sub-byte packing requires putting multiple elements in the same byte.
Expand All @@ -127,7 +99,7 @@ Value calculateStorageElementCountInBytes(Location loc,
// TODO(antiagainst): We may want to emit runtime check to make sure this is
// divisible.
auto divisor = builder.create<arith::ConstantIndexOp>(loc, byteElements);
if (!clEnableI1Support && paddedDynamicDims.empty() &&
if (!clEnableI1Support && dynamicDims.empty() &&
(staticCount * elementBits) % 8 != 0) {
return nullptr;
}
Expand Down

0 comments on commit c618134

Please sign in to comment.