Skip to content

Commit

Permalink
[Encoding] Introduce "layouts" field to EncodingAttr.
Browse files Browse the repository at this point in the history
The revision introduces an optional "layouts" field to EncodingAttr. It
is an array of attributes that describes the potential layouts on the
device. It is an array because a device could have several executable
targets. Note that it can be any attribute with encoding attribute
interface implementation. The expectation of the field is to bridge the
logics between host codes and device codes. If an attribute does not
implement the interface, it could be discarded anytime.

It is a step towards #17924

Signed-off-by: hanhanW <[email protected]>
  • Loading branch information
hanhanW committed Nov 20, 2024
1 parent 73c8b00 commit c69bc93
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@ EncodingAttr EncodingAttr::get(MLIRContext *ctx, int64_t operandIndex,
EncodingOpType opType, ArrayRef<Type> elemTypes,
ArrayRef<AffineMap> maps,
std::optional<AffineMap> bcastMap,
ArrayRef<int64_t> roundDimsTo) {
ArrayRef<int64_t> roundDimsTo,
ArrayRef<Attribute> layouts) {
Builder b(ctx);
auto opTypeAttr = EncodingOpTypeAttr::get(ctx, opType);
auto roundDimsToAttr = roundDimsTo.empty()
Expand All @@ -34,9 +35,10 @@ EncodingAttr EncodingAttr::get(MLIRContext *ctx, int64_t operandIndex,
auto bcastMapAttr = bcastMap.has_value()
? AffineMapAttr::get(bcastMap.value())
: AffineMapAttr();
auto layoutsAttr = layouts.empty() ? ArrayAttr() : b.getArrayAttr(layouts);
return get(ctx, b.getIndexAttr(operandIndex), opTypeAttr,
b.getTypeArrayAttr(elemTypes), b.getAffineMapArrayAttr(maps),
bcastMapAttr, roundDimsToAttr);
bcastMapAttr, roundDimsToAttr, layoutsAttr);
}

AffineMap EncodingAttr::getMapForOperandIndex() {
Expand Down Expand Up @@ -106,7 +108,7 @@ SmallVector<Type> EncodingAttr::getElementTypesArray() {
EncodingAttr EncodingAttr::clone(AffineMap bcastMap) {
return get(bcastMap.getContext(), getOperandIndex(), getOpType(),
getElementTypes(), getUserIndexingMaps(),
AffineMapAttr::get(bcastMap), getRoundDimsTo());
AffineMapAttr::get(bcastMap), getRoundDimsTo(), getLayouts());
}

MatmulNarrowDim getMatmulNarrowDim(EncodingAttr encoding) {
Expand Down
18 changes: 15 additions & 3 deletions compiler/src/iree/compiler/Dialect/Encoding/IR/EncodingAttrs.td
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,19 @@ def EncodingAttr :
AttrParameter<"ArrayAttr", "element types of the user's operands">:$element_types,
OptionalParameter<"ArrayAttr", "Indexing maps of the operation using this tensor">:$user_indexing_maps,
OptionalParameter<"AffineMapAttr", "Indexing map that represents the broadcasting dims in the producer">:$bcast_map,
// TODO(hanchung): The round_dims_to parameter can be revisited. We explicitly map them to M,N,K dimension for now.
OptionalParameter<"DenseArrayAttr", "Values for padding M,N,K dimensions">:$round_dims_to
// TODO(hanchung): Deprecate the round_dims_to field when we plumb the layouts
// field through the whole stack. See https://github.com/iree-org/iree/issues/17924
// for details. Note that today we abuse the attribute to carry narrow
// matrix information. The end goal is deprecating the field and add a
// "iteration_space_size" field to describe the shape. It is useful to
// handle narrow matrix cases.
OptionalParameter<"DenseArrayAttr", "Values for padding M,N,K dimensions">:$round_dims_to,
OptionalParameter<"ArrayAttr", "An array of attributes that describes the "
"potential layouts on the device. It is an array because a device could "
"have several executable targets. Note that it can be any attribute that "
"implements EncodingLayoutAttrInterface. The expectation of the field "
"is to bridge the logics between host codes and device codes. If an "
"attribute does not implement the interface, it could be discarded anytime.">:$layouts
);

let builders = [
Expand All @@ -73,7 +84,8 @@ def EncodingAttr :
"ArrayRef<Type>":$elemTypes,
CArg<"ArrayRef<AffineMap>", "{}">:$maps,
CArg<"std::optional<AffineMap>", "{}">:$bcastMap,
CArg<"ArrayRef<int64_t>", "{}">:$roundDimsTo)>
CArg<"ArrayRef<int64_t>", "{}">:$roundDimsTo,
CArg<"ArrayRef<Attribute>", "{}">:$layouts)>
];

let extraClassDeclaration = [{
Expand Down
12 changes: 12 additions & 0 deletions compiler/src/iree/compiler/Dialect/Encoding/IR/test/roundtrip.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -157,3 +157,15 @@ func.func @set_encoding_ops_with_indexing_maps(%arg0: tensor<?x?xf32>) -> tensor
// CHECK: func.func @set_encoding_ops_with_indexing_maps(
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]:
// CHECK: iree_encoding.set_encoding %[[ARG0]] : tensor<?x?xf32> -> tensor<?x?xf32, #[[ENCODING]]>

// -----

#encoding = #iree_encoding.encoding<operand_index = 0 : i64, op_type = matmul, element_types = [f32, f32, f32], layouts = [{}]>
func.func @set_encoding_ops_with_layouts(%arg0: tensor<?x?xf32>) -> tensor<?x?xf32, #encoding> {
%0 = iree_encoding.set_encoding %arg0 : tensor<?x?xf32> -> tensor<?x?xf32, #encoding>
return %0 : tensor<?x?xf32, #encoding>
}
// CHECK-DAG: #[[ENCODING:.+]] = #iree_encoding.encoding<operand_index = 0 : i64, op_type = matmul, element_types = [f32, f32, f32], layouts = [{}]>
// CHECK: func.func @set_encoding_ops_with_layouts(
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]:
// CHECK: iree_encoding.set_encoding %[[ARG0]] : tensor<?x?xf32> -> tensor<?x?xf32, #[[ENCODING]]>

0 comments on commit c69bc93

Please sign in to comment.