Skip to content

Commit

Permalink
Implement (in)equality operator and serialization for TileSwizzle. (i…
Browse files Browse the repository at this point in the history
…ree-org#19257)

The revision implements the conversion between the TileSwizzle struct
and DictionaryAttr. The utilities are tested in unittests.

To verify if the deserialization result has the same content, it also
implements the (in)equality operators (i.e., `==` and `!=`) for the
struct.

---------

Signed-off-by: hanhanW <[email protected]>
  • Loading branch information
hanhanW authored Nov 22, 2024
1 parent 3201efb commit a67b00b
Show file tree
Hide file tree
Showing 6 changed files with 304 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ iree_compiler_cc_library(
deps = [
"//compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR:IREECodegenDialect",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:DialectUtils",
"@llvm-project//mlir:IR",
],
)
133 changes: 133 additions & 0 deletions compiler/src/iree/compiler/Codegen/Dialect/Codegen/Utils/Utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,34 @@

#include "iree/compiler/Codegen/Dialect/Codegen/Utils/Utils.h"
#include "llvm/ADT/STLExtras.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"

namespace mlir::iree_compiler::IREE::Codegen {

//===----------------------------------------------------------------------===//
// Layout Structs.
//===----------------------------------------------------------------------===//

bool operator==(TileSwizzle::Dim lhs, TileSwizzle::Dim rhs) {
return lhs.kind == rhs.kind && lhs.size == rhs.size;
}

bool operator!=(TileSwizzle::Dim lhs, TileSwizzle::Dim rhs) {
return !(lhs == rhs);
}

bool operator==(const TileSwizzle &lhs, const TileSwizzle &rhs) {
return lhs.expandShape == rhs.expandShape &&
lhs.permutation == rhs.permutation;
}

bool operator!=(const TileSwizzle &lhs, const TileSwizzle &rhs) {
return !(lhs == rhs);
}

llvm::raw_ostream &operator<<(llvm::raw_ostream &os,
TileSwizzle::Dim::Kind kind) {
switch (kind) {
Expand Down Expand Up @@ -55,6 +76,118 @@ llvm::raw_ostream &operator<<(llvm::raw_ostream &os,
// Layout Utilities.
//===----------------------------------------------------------------------===//

std::string convertSwizzleKindToString(TileSwizzle::Dim::Kind kind) {
switch (kind) {
case TileSwizzle::Dim::Kind::Internal:
return "Internal";
case TileSwizzle::Dim::Kind::CrossThread:
return "CrossThread";
case TileSwizzle::Dim::Kind::CrossIntrinsic:
return "CrossIntrinsic";
default:
assert(false && "unhandled enum type");
}
return "";
}

std::optional<TileSwizzle::Dim::Kind>
convertStringToSwizzleKind(StringRef str) {
if (str == "Internal") {
return TileSwizzle::Dim::Kind::Internal;
}
if (str == "CrossThread") {
return TileSwizzle::Dim::Kind::CrossThread;
}
if (str == "CrossIntrinsic") {
return TileSwizzle::Dim::Kind::CrossIntrinsic;
}
return std::nullopt;
}

static ArrayAttr swizzleDimToArrayAttr(MLIRContext *ctx, TileSwizzle::Dim dim) {
Builder b(ctx);
return b.getArrayAttr({b.getStringAttr(convertSwizzleKindToString(dim.kind)),
b.getI16IntegerAttr(dim.size)});
}

static std::optional<TileSwizzle::Dim> arrayAttrToSwizzleDim(Attribute attr) {
auto arrayAttr = dyn_cast<ArrayAttr>(attr);
if (!arrayAttr) {
return std::nullopt;
}
ArrayRef<Attribute> attrs = arrayAttr.getValue();
if (attrs.size() != 2) {
return std::nullopt;
}
auto kindAttr = dyn_cast<StringAttr>(attrs[0]);
auto sizeAttr = dyn_cast<IntegerAttr>(attrs[1]);
if (!kindAttr || !sizeAttr) {
return std::nullopt;
}
std::optional<TileSwizzle::Dim::Kind> maybeKind =
convertStringToSwizzleKind(kindAttr.getValue());
if (!maybeKind) {
return std::nullopt;
}
return TileSwizzle::Dim(maybeKind.value(), sizeAttr.getInt());
}

DictionaryAttr serializeTileSwizzle(MLIRContext *ctx,
const TileSwizzle &swizzle) {
Builder b(ctx);
SmallVector<NamedAttribute> items;

SmallVector<Attribute> expandShape;
for (auto expandConfig : swizzle.expandShape) {
Attribute expandAttr = b.getArrayAttr(
llvm::map_to_vector(expandConfig, [&](TileSwizzle::Dim dim) {
return cast<Attribute>(swizzleDimToArrayAttr(ctx, dim));
}));
expandShape.push_back(expandAttr);
}

items.emplace_back(b.getStringAttr("expandShape"),
b.getArrayAttr(expandShape));
items.emplace_back(b.getStringAttr("permutation"),
b.getI64ArrayAttr(swizzle.permutation));

return b.getDictionaryAttr(items);
}

std::optional<TileSwizzle> deserializeTileSwizzle(DictionaryAttr attr) {
TileSwizzle swizzle;

auto expandShapeAttr = attr.getNamed("expandShape");
if (!expandShapeAttr) {
return std::nullopt;
}
auto expandShapeArrayAttr = dyn_cast<ArrayAttr>(expandShapeAttr->getValue());
if (!expandShapeArrayAttr) {
return std::nullopt;
}

for (auto expandConfig : expandShapeArrayAttr.getAsRange<ArrayAttr>()) {
TileSwizzle::ExpandShapeDimVectorType vec;
for (auto dimAttr : expandConfig.getAsRange<ArrayAttr>()) {
auto maybeDim = arrayAttrToSwizzleDim(dimAttr);
if (!maybeDim) {
return std::nullopt;
}
vec.push_back(maybeDim.value());
}
swizzle.expandShape.push_back(vec);
}

auto permAttr = attr.getNamed("permutation");
if (!permAttr || !isa<ArrayAttr>(permAttr->getValue())) {
return std::nullopt;
}
swizzle.permutation =
extractFromIntegerArrayAttr<int64_t>(permAttr->getValue());

return swizzle;
}

SmallVector<int64_t>
getExpandedTileShape(const TileSwizzle::ExpandShapeType &expandShape) {
SmallVector<int64_t> result;
Expand Down
16 changes: 16 additions & 0 deletions compiler/src/iree/compiler/Codegen/Dialect/Codegen/Utils/Utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
#include "llvm-c/TargetMachine.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/Support/LLVM.h"

namespace mlir::iree_compiler::IREE::Codegen {
Expand Down Expand Up @@ -79,6 +81,11 @@ struct TileSwizzle {
llvm::SmallVector<int64_t> permutation;
};

bool operator==(TileSwizzle::Dim lhs, TileSwizzle::Dim rhs);
bool operator!=(TileSwizzle::Dim lhs, TileSwizzle::Dim rhs);
bool operator==(const TileSwizzle &lhs, const TileSwizzle &rhs);
bool operator!=(const TileSwizzle &lhs, const TileSwizzle &rhs);

llvm::raw_ostream &operator<<(llvm::raw_ostream &os,
TileSwizzle::Dim::Kind kind);

Expand All @@ -104,6 +111,15 @@ struct MaterializeEncodingInfo {
// Layout Utilities.
//===----------------------------------------------------------------------===//

/// Conversion between TileSwizzle::Dim::Kind and string.
std::string convertSwizzleKindToString(TileSwizzle::Dim::Kind kind);
std::optional<TileSwizzle::Dim::Kind> convertStringToSwizzleKind(StringRef str);

/// Conversion between TileSwizzle struct and DictionaryAttr.
DictionaryAttr serializeTileSwizzle(MLIRContext *ctx,
const TileSwizzle &swizzle);
std::optional<TileSwizzle> deserializeTileSwizzle(DictionaryAttr attr);

/// Concatenates the vectors.
SmallVector<int64_t>
getExpandedTileShape(const TileSwizzle::ExpandShapeType &expandShape);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
# Copyright 2024 The IREE Authors
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

load("//build_tools/bazel:build_defs.oss.bzl", "iree_compiler_cc_test")

package(
features = ["layering_check"],
licenses = ["notice"], # Apache 2.0
)

iree_compiler_cc_test(
name = "UtilsTest",
testonly = True,
srcs = ["UtilsTest.cpp"],
deps = [
"//compiler/src/iree/compiler/Codegen/Dialect/Codegen/Utils",
"//compiler/src/iree/testing:gtest_main",
"@com_google_googletest//:gtest",
"@llvm-project//mlir:DialectUtils",
"@llvm-project//mlir:IR",
],
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
################################################################################
# Autogenerated by build_tools/bazel_to_cmake/bazel_to_cmake.py from #
# compiler/src/iree/compiler/Codegen/Dialect/Codegen/Utils/unittests/BUILD.bazel#
# #
# Use iree_cmake_extra_content from iree/build_defs.oss.bzl to add arbitrary #
# CMake-only content. #
# #
# To disable autogeneration for this file entirely, delete this header. #
################################################################################

iree_add_all_subdirs()

iree_cc_test(
NAME
UtilsTest
SRCS
"UtilsTest.cpp"
DEPS
MLIRIR
gmock
gtest
iree::compiler::Codegen::Dialect::Codegen::Utils
iree::testing::gtest_main
)

### BAZEL_TO_CMAKE_PRESERVES_ALL_CONTENT_BELOW_THIS_LINE ###
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
// Copyright 2024 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#include <gmock/gmock.h>
#include <gtest/gtest.h>

#include "iree/compiler/Codegen/Dialect/Codegen/Utils/Utils.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/Attributes.h"

namespace mlir::iree_compiler::IREE::Codegen {
namespace {

using testing::Optional;
using Kind = TileSwizzle::Dim::Kind;

TEST(TileSwizzle, RelationalOperator) {
TileSwizzle swizzle1;
swizzle1.permutation = {1, 2, 0};
TileSwizzle swizzle2;
EXPECT_NE(swizzle1, swizzle2);
swizzle2.permutation = swizzle1.permutation;
EXPECT_EQ(swizzle1, swizzle2);
swizzle1.expandShape.push_back({TileSwizzle::Dim(Kind::CrossThread, 16)});
swizzle2.expandShape.push_back({TileSwizzle::Dim(Kind::Internal, 16)});
EXPECT_NE(swizzle1, swizzle2);
swizzle2.expandShape[0][0].kind = Kind::CrossThread;
EXPECT_EQ(swizzle1, swizzle2);
}

TEST(TileSwizzle, DimKindToString) {
EXPECT_EQ(convertSwizzleKindToString(Kind::Internal), "Internal");
EXPECT_EQ(convertSwizzleKindToString(Kind::CrossIntrinsic), "CrossIntrinsic");
EXPECT_EQ(convertSwizzleKindToString(Kind::CrossThread), "CrossThread");
}

TEST(TileSwizzle, StringToDimKind) {
std::optional<TileSwizzle::Dim::Kind> maybeKind;
maybeKind = convertStringToSwizzleKind("Internal");
EXPECT_THAT(maybeKind, Optional(TileSwizzle::Dim::Kind::Internal));
maybeKind = convertStringToSwizzleKind("CrossIntrinsic");
EXPECT_THAT(maybeKind, Optional(TileSwizzle::Dim::Kind::CrossIntrinsic));
maybeKind = convertStringToSwizzleKind("CrossThread");
EXPECT_THAT(maybeKind, Optional(TileSwizzle::Dim::Kind::CrossThread));
maybeKind = convertStringToSwizzleKind("deadbeef");
EXPECT_FALSE(maybeKind.has_value());
}

TEST(TileSwizzle, Serialization) {
TileSwizzle swizzle;
swizzle.expandShape.push_back({TileSwizzle::Dim(Kind::CrossThread, 16)});
swizzle.expandShape.push_back({TileSwizzle::Dim(Kind::CrossIntrinsic, 4),
TileSwizzle::Dim(Kind::Internal, 4)});
swizzle.permutation = {1, 2, 0};

MLIRContext ctx;
DictionaryAttr dictAttr = serializeTileSwizzle(&ctx, swizzle);

EXPECT_TRUE(dictAttr.contains("expandShape"));
EXPECT_TRUE(dictAttr.contains("permutation"));

// Verify if the sizes match. The check of values is done by the comparison
// between deserialzation result and the original struct.
auto expandShapeArrayAttr =
dyn_cast<ArrayAttr>(dictAttr.getNamed("expandShape")->getValue());
EXPECT_EQ(expandShapeArrayAttr.size(), swizzle.expandShape.size());
for (auto [expectedShape, actualShape] : llvm::zip_equal(
swizzle.expandShape, expandShapeArrayAttr.getAsRange<ArrayAttr>())) {
EXPECT_EQ(expectedShape.size(), actualShape.size());
}

SmallVector<int64_t> extractedPerm = extractFromIntegerArrayAttr<int64_t>(
dictAttr.getNamed("permutation")->getValue());
EXPECT_EQ(extractedPerm, swizzle.permutation);

std::optional<TileSwizzle> deserializedSwizzle =
deserializeTileSwizzle(dictAttr);
EXPECT_THAT(deserializedSwizzle, Optional(swizzle));
}

TEST(TileSwizzle, Deserialization) {
MLIRContext ctx;
Builder b(&ctx);

auto emptyDictAttr = b.getDictionaryAttr({});
EXPECT_FALSE(deserializeTileSwizzle(emptyDictAttr).has_value());

SmallVector<NamedAttribute> items;
items.emplace_back(b.getStringAttr("expandShape"), b.getArrayAttr({}));
EXPECT_FALSE(deserializeTileSwizzle(b.getDictionaryAttr(items)).has_value());

items.emplace_back(b.getStringAttr("permutation"), b.getArrayAttr({}));
EXPECT_TRUE(deserializeTileSwizzle(b.getDictionaryAttr(items)).has_value());

items.back().setValue(b.getUnitAttr());
EXPECT_FALSE(deserializeTileSwizzle(b.getDictionaryAttr(items)).has_value());
}

} // namespace
} // namespace mlir::iree_compiler::IREE::Codegen

0 comments on commit a67b00b

Please sign in to comment.