Skip to content

Commit

Permalink
[Codegen] Implement serialization for MaterializeEncodingInfo struct. (
Browse files Browse the repository at this point in the history
…iree-org#19260)

The revision adds the support for conversion between
MaterializeEncodingInfo struct and DictionaryAttr. It also implements
the in(equality) operators for the struct to verify if the deserialized
result match the original struct.

Signed-off-by: hanhanW <[email protected]>
  • Loading branch information
hanhanW authored Nov 22, 2024
1 parent a67b00b commit 5de0f06
Show file tree
Hide file tree
Showing 3 changed files with 161 additions and 0 deletions.
63 changes: 63 additions & 0 deletions compiler/src/iree/compiler/Codegen/Dialect/Codegen/Utils/Utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,18 @@ llvm::raw_ostream &operator<<(llvm::raw_ostream &os,
return os;
}

bool operator==(const MaterializeEncodingInfo &lhs,
const MaterializeEncodingInfo &rhs) {
return lhs.innerDimsPos == rhs.innerDimsPos &&
lhs.innerTileSizes == rhs.innerTileSizes &&
lhs.outerDimsPerm == rhs.outerDimsPerm && lhs.swizzle == rhs.swizzle;
}

bool operator!=(const MaterializeEncodingInfo &lhs,
const MaterializeEncodingInfo &rhs) {
return !(lhs == rhs);
}

//===----------------------------------------------------------------------===//
// Layout Utilities.
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -188,6 +200,57 @@ std::optional<TileSwizzle> deserializeTileSwizzle(DictionaryAttr attr) {
return swizzle;
}

DictionaryAttr serializeEncodingInfo(MLIRContext *ctx,
const MaterializeEncodingInfo &info) {
Builder b(ctx);
SmallVector<NamedAttribute> items;
items.emplace_back(b.getStringAttr("innerDimsPos"),
b.getI64ArrayAttr(info.innerDimsPos));
items.emplace_back(b.getStringAttr("innerTileSizes"),
b.getI64ArrayAttr(info.innerTileSizes));
items.emplace_back(b.getStringAttr("outerDimsPerm"),
b.getI64ArrayAttr(info.outerDimsPerm));
if (info.swizzle) {
items.emplace_back(b.getStringAttr("swizzle"),
serializeTileSwizzle(ctx, info.swizzle.value()));
}

return b.getDictionaryAttr(items);
}

std::optional<MaterializeEncodingInfo>
deserializeEncodingInfo(DictionaryAttr attr) {
MaterializeEncodingInfo info;

#define extractArrayAttrItem(name) \
{ \
auto value = attr.getNamed(#name); \
if (!value || !isa<ArrayAttr>(value->getValue())) { \
return std::nullopt; \
} \
info.name = extractFromIntegerArrayAttr<int64_t>(value->getValue()); \
}

extractArrayAttrItem(innerDimsPos);
extractArrayAttrItem(innerTileSizes);
extractArrayAttrItem(outerDimsPerm);
#undef extractArrayAttrItem

if (attr.contains("swizzle")) {
auto dictAttr =
dyn_cast<DictionaryAttr>(attr.getNamed("swizzle")->getValue());
if (!dictAttr) {
return std::nullopt;
}
info.swizzle = deserializeTileSwizzle(dictAttr);
if (!info.swizzle) {
return std::nullopt;
}
}

return info;
}

SmallVector<int64_t>
getExpandedTileShape(const TileSwizzle::ExpandShapeType &expandShape) {
SmallVector<int64_t> result;
Expand Down
11 changes: 11 additions & 0 deletions compiler/src/iree/compiler/Codegen/Dialect/Codegen/Utils/Utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,11 @@ struct MaterializeEncodingInfo {
std::optional<TileSwizzle> swizzle;
};

bool operator==(const MaterializeEncodingInfo &lhs,
const MaterializeEncodingInfo &rhs);
bool operator!=(const MaterializeEncodingInfo &lhs,
const MaterializeEncodingInfo &rhs);

//===----------------------------------------------------------------------===//
// Layout Utilities.
//===----------------------------------------------------------------------===//
Expand All @@ -120,6 +125,12 @@ DictionaryAttr serializeTileSwizzle(MLIRContext *ctx,
const TileSwizzle &swizzle);
std::optional<TileSwizzle> deserializeTileSwizzle(DictionaryAttr attr);

/// Conversion between MaterializeEncodingInfo struct and DictionaryAttr.
DictionaryAttr serializeEncodingInfo(MLIRContext *ctx,
const MaterializeEncodingInfo &info);
std::optional<MaterializeEncodingInfo>
deserializeEncodingInfo(DictionaryAttr attr);

/// Concatenates the vectors.
SmallVector<int64_t>
getExpandedTileShape(const TileSwizzle::ExpandShapeType &expandShape);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -99,5 +99,92 @@ TEST(TileSwizzle, Deserialization) {
EXPECT_FALSE(deserializeTileSwizzle(b.getDictionaryAttr(items)).has_value());
}

TEST(MaterializeEncodingInfo, RelationalOperator) {
MaterializeEncodingInfo info1;
info1.innerDimsPos = {0, 1};
info1.innerTileSizes = {16, 1};
info1.outerDimsPerm = {0, 2, 1, 3};

MaterializeEncodingInfo info2;
info2.innerDimsPos = {1, 0};
info2.innerTileSizes = {16, 1};
info2.outerDimsPerm = {0, 2, 1, 3};

EXPECT_EQ(info1, info1);
EXPECT_EQ(info2, info2);
EXPECT_NE(info1, info2);

// They mismatch if one has a swizzle, but not the other.
info2 = info1;
info1.swizzle = TileSwizzle();
EXPECT_NE(info1, info2);

// They match because they all have an empty swizzle.
info2.swizzle = TileSwizzle();
EXPECT_EQ(info1, info2);
}

TEST(MaterializeEncodingInfo, Serialization) {
MaterializeEncodingInfo info;
info.innerDimsPos = {0, 1};
info.innerTileSizes = {16, 1};
info.outerDimsPerm = {0, 2, 1, 3};

MLIRContext ctx;
DictionaryAttr dictAttr = serializeEncodingInfo(&ctx, info);

EXPECT_TRUE(dictAttr.contains("innerDimsPos"));
EXPECT_TRUE(dictAttr.contains("innerTileSizes"));
EXPECT_TRUE(dictAttr.contains("outerDimsPerm"));
EXPECT_FALSE(dictAttr.contains("swizzle"));

EXPECT_TRUE(isa<ArrayAttr>(dictAttr.getNamed("innerDimsPos")->getValue()));
EXPECT_TRUE(isa<ArrayAttr>(dictAttr.getNamed("innerTileSizes")->getValue()));
EXPECT_TRUE(isa<ArrayAttr>(dictAttr.getNamed("outerDimsPerm")->getValue()));

auto extractedInnerDimsPos = extractFromIntegerArrayAttr<int64_t>(
dictAttr.getNamed("innerDimsPos")->getValue());
EXPECT_EQ(extractedInnerDimsPos, info.innerDimsPos);
auto extractedInnerTileSizes = extractFromIntegerArrayAttr<int64_t>(
dictAttr.getNamed("innerTileSizes")->getValue());
EXPECT_EQ(extractedInnerTileSizes, info.innerTileSizes);
auto extractedOuterDimsPerm = extractFromIntegerArrayAttr<int64_t>(
dictAttr.getNamed("outerDimsPerm")->getValue());
EXPECT_EQ(extractedOuterDimsPerm, info.outerDimsPerm);

std::optional<MaterializeEncodingInfo> deserializedInfo =
deserializeEncodingInfo(dictAttr);
EXPECT_THAT(deserializedInfo, Optional(info));
}

TEST(MaterializeEncodingInfo, Deserialization) {
MLIRContext ctx;
Builder b(&ctx);

auto emptyDictAttr = b.getDictionaryAttr({});
EXPECT_FALSE(deserializeEncodingInfo(emptyDictAttr).has_value());

SmallVector<NamedAttribute> items;
items.emplace_back(b.getStringAttr("innerDimsPos"),
b.getI64ArrayAttr({0, 1}));
EXPECT_FALSE(deserializeEncodingInfo(b.getDictionaryAttr(items)).has_value());

items.emplace_back(b.getStringAttr("innerTileSizes"),
b.getI64ArrayAttr({16, 1}));
EXPECT_FALSE(deserializeEncodingInfo(b.getDictionaryAttr(items)).has_value());

items.emplace_back(b.getStringAttr("outerDimsPerm"),
b.getI64ArrayAttr({0, 2, 1, 3}));
EXPECT_TRUE(deserializeEncodingInfo(b.getDictionaryAttr(items)).has_value());

// If the swizzle presents, it needs to be deserializable to TileSwizzle.
items.emplace_back(b.getStringAttr("swizzle"), b.getUnitAttr());
EXPECT_FALSE(deserializeEncodingInfo(b.getDictionaryAttr(items)).has_value());

TileSwizzle swizzle;
items.back().setValue(serializeTileSwizzle(&ctx, swizzle));
EXPECT_TRUE(deserializeEncodingInfo(b.getDictionaryAttr(items)).has_value());
}

} // namespace
} // namespace mlir::iree_compiler::IREE::Codegen

0 comments on commit 5de0f06

Please sign in to comment.