From 5de0f06ef5ca695e89efaace70bdc8db94a59605 Mon Sep 17 00:00:00 2001 From: Han-Chung Wang Date: Fri, 22 Nov 2024 13:04:41 -0800 Subject: [PATCH] [Codegen] Implement serialization for MaterializeEncodingInfo struct. (#19260) The revision adds the support for conversion between MaterializeEncodingInfo struct and DictionaryAttr. It also implements the in(equality) operators for the struct to verify if the deserialized result match the original struct. Signed-off-by: hanhanW --- .../Codegen/Dialect/Codegen/Utils/Utils.cpp | 63 ++++++++++++++ .../Codegen/Dialect/Codegen/Utils/Utils.h | 11 +++ .../Codegen/Utils/unittests/UtilsTest.cpp | 87 +++++++++++++++++++ 3 files changed, 161 insertions(+) diff --git a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/Utils/Utils.cpp b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/Utils/Utils.cpp index 1ef10b52ad87..266153b042af 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/Utils/Utils.cpp +++ b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/Utils/Utils.cpp @@ -72,6 +72,18 @@ llvm::raw_ostream &operator<<(llvm::raw_ostream &os, return os; } +bool operator==(const MaterializeEncodingInfo &lhs, + const MaterializeEncodingInfo &rhs) { + return lhs.innerDimsPos == rhs.innerDimsPos && + lhs.innerTileSizes == rhs.innerTileSizes && + lhs.outerDimsPerm == rhs.outerDimsPerm && lhs.swizzle == rhs.swizzle; +} + +bool operator!=(const MaterializeEncodingInfo &lhs, + const MaterializeEncodingInfo &rhs) { + return !(lhs == rhs); +} + //===----------------------------------------------------------------------===// // Layout Utilities. //===----------------------------------------------------------------------===// @@ -188,6 +200,57 @@ std::optional deserializeTileSwizzle(DictionaryAttr attr) { return swizzle; } +DictionaryAttr serializeEncodingInfo(MLIRContext *ctx, + const MaterializeEncodingInfo &info) { + Builder b(ctx); + SmallVector items; + items.emplace_back(b.getStringAttr("innerDimsPos"), + b.getI64ArrayAttr(info.innerDimsPos)); + items.emplace_back(b.getStringAttr("innerTileSizes"), + b.getI64ArrayAttr(info.innerTileSizes)); + items.emplace_back(b.getStringAttr("outerDimsPerm"), + b.getI64ArrayAttr(info.outerDimsPerm)); + if (info.swizzle) { + items.emplace_back(b.getStringAttr("swizzle"), + serializeTileSwizzle(ctx, info.swizzle.value())); + } + + return b.getDictionaryAttr(items); +} + +std::optional +deserializeEncodingInfo(DictionaryAttr attr) { + MaterializeEncodingInfo info; + +#define extractArrayAttrItem(name) \ + { \ + auto value = attr.getNamed(#name); \ + if (!value || !isa(value->getValue())) { \ + return std::nullopt; \ + } \ + info.name = extractFromIntegerArrayAttr(value->getValue()); \ + } + + extractArrayAttrItem(innerDimsPos); + extractArrayAttrItem(innerTileSizes); + extractArrayAttrItem(outerDimsPerm); +#undef extractArrayAttrItem + + if (attr.contains("swizzle")) { + auto dictAttr = + dyn_cast(attr.getNamed("swizzle")->getValue()); + if (!dictAttr) { + return std::nullopt; + } + info.swizzle = deserializeTileSwizzle(dictAttr); + if (!info.swizzle) { + return std::nullopt; + } + } + + return info; +} + SmallVector getExpandedTileShape(const TileSwizzle::ExpandShapeType &expandShape) { SmallVector result; diff --git a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/Utils/Utils.h b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/Utils/Utils.h index a56e15a95aab..98e4bf4ea56a 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/Utils/Utils.h +++ b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/Utils/Utils.h @@ -107,6 +107,11 @@ struct MaterializeEncodingInfo { std::optional swizzle; }; +bool operator==(const MaterializeEncodingInfo &lhs, + const MaterializeEncodingInfo &rhs); +bool operator!=(const MaterializeEncodingInfo &lhs, + const MaterializeEncodingInfo &rhs); + //===----------------------------------------------------------------------===// // Layout Utilities. //===----------------------------------------------------------------------===// @@ -120,6 +125,12 @@ DictionaryAttr serializeTileSwizzle(MLIRContext *ctx, const TileSwizzle &swizzle); std::optional deserializeTileSwizzle(DictionaryAttr attr); +/// Conversion between MaterializeEncodingInfo struct and DictionaryAttr. +DictionaryAttr serializeEncodingInfo(MLIRContext *ctx, + const MaterializeEncodingInfo &info); +std::optional +deserializeEncodingInfo(DictionaryAttr attr); + /// Concatenates the vectors. SmallVector getExpandedTileShape(const TileSwizzle::ExpandShapeType &expandShape); diff --git a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/Utils/unittests/UtilsTest.cpp b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/Utils/unittests/UtilsTest.cpp index 71f3725642d3..82f482761da7 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/Utils/unittests/UtilsTest.cpp +++ b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/Utils/unittests/UtilsTest.cpp @@ -99,5 +99,92 @@ TEST(TileSwizzle, Deserialization) { EXPECT_FALSE(deserializeTileSwizzle(b.getDictionaryAttr(items)).has_value()); } +TEST(MaterializeEncodingInfo, RelationalOperator) { + MaterializeEncodingInfo info1; + info1.innerDimsPos = {0, 1}; + info1.innerTileSizes = {16, 1}; + info1.outerDimsPerm = {0, 2, 1, 3}; + + MaterializeEncodingInfo info2; + info2.innerDimsPos = {1, 0}; + info2.innerTileSizes = {16, 1}; + info2.outerDimsPerm = {0, 2, 1, 3}; + + EXPECT_EQ(info1, info1); + EXPECT_EQ(info2, info2); + EXPECT_NE(info1, info2); + + // They mismatch if one has a swizzle, but not the other. + info2 = info1; + info1.swizzle = TileSwizzle(); + EXPECT_NE(info1, info2); + + // They match because they all have an empty swizzle. + info2.swizzle = TileSwizzle(); + EXPECT_EQ(info1, info2); +} + +TEST(MaterializeEncodingInfo, Serialization) { + MaterializeEncodingInfo info; + info.innerDimsPos = {0, 1}; + info.innerTileSizes = {16, 1}; + info.outerDimsPerm = {0, 2, 1, 3}; + + MLIRContext ctx; + DictionaryAttr dictAttr = serializeEncodingInfo(&ctx, info); + + EXPECT_TRUE(dictAttr.contains("innerDimsPos")); + EXPECT_TRUE(dictAttr.contains("innerTileSizes")); + EXPECT_TRUE(dictAttr.contains("outerDimsPerm")); + EXPECT_FALSE(dictAttr.contains("swizzle")); + + EXPECT_TRUE(isa(dictAttr.getNamed("innerDimsPos")->getValue())); + EXPECT_TRUE(isa(dictAttr.getNamed("innerTileSizes")->getValue())); + EXPECT_TRUE(isa(dictAttr.getNamed("outerDimsPerm")->getValue())); + + auto extractedInnerDimsPos = extractFromIntegerArrayAttr( + dictAttr.getNamed("innerDimsPos")->getValue()); + EXPECT_EQ(extractedInnerDimsPos, info.innerDimsPos); + auto extractedInnerTileSizes = extractFromIntegerArrayAttr( + dictAttr.getNamed("innerTileSizes")->getValue()); + EXPECT_EQ(extractedInnerTileSizes, info.innerTileSizes); + auto extractedOuterDimsPerm = extractFromIntegerArrayAttr( + dictAttr.getNamed("outerDimsPerm")->getValue()); + EXPECT_EQ(extractedOuterDimsPerm, info.outerDimsPerm); + + std::optional deserializedInfo = + deserializeEncodingInfo(dictAttr); + EXPECT_THAT(deserializedInfo, Optional(info)); +} + +TEST(MaterializeEncodingInfo, Deserialization) { + MLIRContext ctx; + Builder b(&ctx); + + auto emptyDictAttr = b.getDictionaryAttr({}); + EXPECT_FALSE(deserializeEncodingInfo(emptyDictAttr).has_value()); + + SmallVector items; + items.emplace_back(b.getStringAttr("innerDimsPos"), + b.getI64ArrayAttr({0, 1})); + EXPECT_FALSE(deserializeEncodingInfo(b.getDictionaryAttr(items)).has_value()); + + items.emplace_back(b.getStringAttr("innerTileSizes"), + b.getI64ArrayAttr({16, 1})); + EXPECT_FALSE(deserializeEncodingInfo(b.getDictionaryAttr(items)).has_value()); + + items.emplace_back(b.getStringAttr("outerDimsPerm"), + b.getI64ArrayAttr({0, 2, 1, 3})); + EXPECT_TRUE(deserializeEncodingInfo(b.getDictionaryAttr(items)).has_value()); + + // If the swizzle presents, it needs to be deserializable to TileSwizzle. + items.emplace_back(b.getStringAttr("swizzle"), b.getUnitAttr()); + EXPECT_FALSE(deserializeEncodingInfo(b.getDictionaryAttr(items)).has_value()); + + TileSwizzle swizzle; + items.back().setValue(serializeTileSwizzle(&ctx, swizzle)); + EXPECT_TRUE(deserializeEncodingInfo(b.getDictionaryAttr(items)).has_value()); +} + } // namespace } // namespace mlir::iree_compiler::IREE::Codegen