From 7680b8490ad742028d4bce98131b55e3c1571b6f Mon Sep 17 00:00:00 2001 From: Martin Erhart Date: Fri, 13 Dec 2024 18:33:53 +0000 Subject: [PATCH 1/5] [RTG][Elaboration] Do not internalize primitive values --- .../RTG/Transforms/ElaborationPass.cpp | 842 ++++++++++-------- 1 file changed, 466 insertions(+), 376 deletions(-) diff --git a/lib/Dialect/RTG/Transforms/ElaborationPass.cpp b/lib/Dialect/RTG/Transforms/ElaborationPass.cpp index 86c79fe7c703..02856ced2884 100644 --- a/lib/Dialect/RTG/Transforms/ElaborationPass.cpp +++ b/lib/Dialect/RTG/Transforms/ElaborationPass.cpp @@ -22,6 +22,7 @@ #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/IR/IRMapping.h" #include "mlir/IR/PatternMatch.h" +#include "llvm/ADT/FoldingSet.h" #include "llvm/Support/Debug.h" #include #include @@ -85,301 +86,436 @@ static uint32_t getUniformlyInRange(std::mt19937 &rng, uint32_t a, uint32_t b) { namespace { /// The abstract base class for elaborated values. -struct ElaboratorValue { +class ElaboratorValue { public: - enum class ValueKind { Attribute, Set, Bag, Sequence, Index, Bool }; + enum class ValueKind { + Attribute = 0U, + Set, + Bag, + Sequence, + Index, + Bool, + None + }; + + union StorageTy { + StorageTy() : ptr(nullptr) {} + StorageTy(const void *ptr) : ptr(ptr) {} + StorageTy(size_t index) : index(index) {} + StorageTy(bool boolean) : boolean(boolean) {} + + const void *ptr; + size_t index; + bool boolean; + }; + + ElaboratorValue(ValueKind kind = ValueKind::None, + StorageTy storage = StorageTy()) + : kind(kind), storage(storage) {} + + // This constructor is needed for LLVM RTTI + ElaboratorValue(StorageTy storage) : ElaboratorValue() {} + + llvm::hash_code getHashValue() const { + switch (kind) { + case ValueKind::Attribute: + case ValueKind::Set: + case ValueKind::Bag: + case ValueKind::Sequence: + return llvm::hash_combine(kind, storage.ptr); + case ValueKind::Index: + return llvm::hash_combine(kind, storage.index); + case ValueKind::Bool: + return llvm::hash_combine(kind, storage.boolean); + case ValueKind::None: + return llvm::hash_value(kind); + } + llvm::llvm_unreachable_internal("all cases handled above"); + } - ElaboratorValue(ValueKind kind) : kind(kind) {} - virtual ~ElaboratorValue() {} + bool operator==(const ElaboratorValue &other) const { + if (kind != other.kind) + return false; - virtual llvm::hash_code getHashValue() const = 0; - virtual bool isEqual(const ElaboratorValue &other) const = 0; + switch (kind) { + case ValueKind::Attribute: + case ValueKind::Set: + case ValueKind::Bag: + case ValueKind::Sequence: + return storage.ptr == other.storage.ptr; + case ValueKind::Index: + return storage.index == other.storage.index; + case ValueKind::Bool: + return storage.boolean == other.storage.boolean; + case ValueKind::None: + return true; + } + llvm::llvm_unreachable_internal("all cases handled above"); + } -#ifndef NDEBUG - virtual void print(llvm::raw_ostream &os) const = 0; -#endif + operator bool() const { return kind != ValueKind::None; } ValueKind getKind() const { return kind; } + StorageTy getStorage() const { return storage; } private: - const ValueKind kind; + ValueKind kind; + StorageTy storage; }; -/// Holds any typed attribute. Wrapping around an MLIR `Attribute` allows us to -/// use this elaborator value class for any values that have a corresponding -/// MLIR attribute rather than one per kind of attribute. We only support typed -/// attributes because for materialization we need to provide the type to the -/// dialect's materializer. -class AttributeValue : public ElaboratorValue { -public: - AttributeValue(TypedAttr attr) - : ElaboratorValue(ValueKind::Attribute), attr(attr) { - assert(attr && "null attributes not allowed"); - } +} // namespace - // Implement LLVMs RTTI - static bool classof(const ElaboratorValue *val) { - return val->getKind() == ValueKind::Attribute; +namespace llvm { +/// Add support for llvm style casts. We provide a cast between To and From if +/// From is mlir::Attribute or derives from it. +template +struct CastInfo> || + std::is_base_of_v>> + : DefaultDoCastIfPossible> { + /// Arguments are taken as mlir::Attribute here and not as `From`, because + /// when casting from an intermediate type of the hierarchy to one of its + /// children, the val.getTypeID() inside T::classof will use the static + /// getTypeID of the parent instead of the non-static Type::getTypeID that + /// returns the dynamic ID. This means that T::classof would end up comparing + /// the static TypeID of the children to the static TypeID of its parent, + /// making it impossible to downcast from the parent to the child. + static inline bool isPossible(ElaboratorValue ty) { + /// Return a constant true instead of a dynamic true when casting to self or + /// up the hierarchy. + if constexpr (std::is_base_of_v) { + return true; + } else { + return To::classof(ty); + } + } + static inline To doCast(ElaboratorValue value) { + return To(value.getStorage()); } + static To castFailed() { return To(); } +}; - llvm::hash_code getHashValue() const override { - return llvm::hash_combine(attr); +template <> +struct DenseMapInfo { + static inline ElaboratorValue getEmptyKey() { return ElaboratorValue(); } + static inline ElaboratorValue getTombstoneKey() { + return ElaboratorValue(ElaboratorValue::ValueKind::None, + reinterpret_cast(~0ULL)); + } + static unsigned getHashValue(const ElaboratorValue &value) { + return value.getHashValue(); + } + static bool isEqual(const ElaboratorValue &lhs, const ElaboratorValue &rhs) { + return lhs == rhs; } +}; - bool isEqual(const ElaboratorValue &other) const override { - auto *attrValue = dyn_cast(&other); - if (!attrValue) - return false; +} // namespace llvm - return attr == attrValue->attr; - } +namespace { -#ifndef NDEBUG - void print(llvm::raw_ostream &os) const override { - os << ""; +struct SetStorage : public llvm::FoldingSetNode { + SetStorage(SetVector &&set, Type type) + : set(std::move(set)), type(type) {} + + // NOLINTNEXTLINE(readability-identifier-naming) + static void Profile(llvm::FoldingSetNodeID &ID, + const SetVector &set, Type type) { + for (auto el : set) { + ID.AddPointer(el.getStorage().ptr); + ID.AddInteger(static_cast(el.getKind())); + } + ID.AddPointer(type.getAsOpaquePointer()); } -#endif - TypedAttr getAttr() const { return attr; } + // NOLINTNEXTLINE(readability-identifier-naming) + void Profile(llvm::FoldingSetNodeID &ID) const { Profile(ID, set, type); } -private: - const TypedAttr attr; -}; + // Stores the elaborated values of the set. + SetVector set; -/// Holds an evaluated value of a `IndexType`'d value. -class IndexValue : public ElaboratorValue { -public: - IndexValue(size_t index) : ElaboratorValue(ValueKind::Index), index(index) {} + // Store the set type such that we can materialize this evaluated value + // also in the case where the set is empty. + Type type; +}; - // Implement LLVMs RTTI - static bool classof(const ElaboratorValue *val) { - return val->getKind() == ValueKind::Index; +struct BagStorage : public llvm::FoldingSetNode { + BagStorage(MapVector &&bag, Type type) + : bag(std::move(bag)), type(type) {} + + // NOLINTNEXTLINE(readability-identifier-naming) + static void Profile(llvm::FoldingSetNodeID &ID, + const MapVector &bag, + Type type) { + for (auto el : bag) { + ID.AddPointer(el.first.getStorage().ptr); + ID.AddInteger(static_cast(el.first.getKind())); + ID.AddInteger(el.second); + } + ID.AddPointer(type.getAsOpaquePointer()); } - llvm::hash_code getHashValue() const override { - return llvm::hash_value(index); - } + // NOLINTNEXTLINE(readability-identifier-naming) + void Profile(llvm::FoldingSetNodeID &ID) const { Profile(ID, bag, type); } - bool isEqual(const ElaboratorValue &other) const override { - auto *indexValue = dyn_cast(&other); - if (!indexValue) - return false; + // Stores the elaborated values of the bag. + MapVector bag; - return index == indexValue->index; - } + // Store the bag type such that we can materialize this evaluated value + // also in the case where the bag is empty. + Type type; +}; -#ifndef NDEBUG - void print(llvm::raw_ostream &os) const override { - os << ""; +struct SequenceStorage : public llvm::FoldingSetNode { + SequenceStorage(StringRef name, StringAttr familyName, + SmallVector &&args) + : name(name), familyName(familyName), args(std::move(args)) {} + + // NOLINTNEXTLINE(readability-identifier-naming) + static void Profile(llvm::FoldingSetNodeID &ID, StringRef name, + StringAttr familyName, ArrayRef args) { + ID.AddString(name); + ID.AddPointer(familyName.getAsOpaquePointer()); + for (auto el : args) { + ID.AddPointer(el.getStorage().ptr); + ID.AddInteger(static_cast(el.getKind())); + } } -#endif - size_t getIndex() const { return index; } + // NOLINTNEXTLINE(readability-identifier-naming) + void Profile(llvm::FoldingSetNodeID &ID) const { + Profile(ID, name, familyName, args); + } -private: - const size_t index; + StringRef name; + StringAttr familyName; + SmallVector args; }; -/// Holds an evaluated value of an `i1` type'd value. -class BoolValue : public ElaboratorValue { +class Internalizer { public: - BoolValue(bool value) : ElaboratorValue(ValueKind::Bool), value(value) {} - - // Implement LLVMs RTTI - static bool classof(const ElaboratorValue *val) { - return val->getKind() == ValueKind::Bool; + template + StorageTy *internalize(Args &&...args) { + llvm::FoldingSetNodeID profile; + StorageTy::Profile(profile, args...); + void *insertPos = nullptr; + if (auto *storage = + getInternSet().FindNodeOrInsertPos(profile, insertPos)) + return static_cast(storage); + auto *storagePtr = new (allocator.Allocate()) + StorageTy(std::forward(args)...); + getInternSet().InsertNode(storagePtr, insertPos); + return storagePtr; } - llvm::hash_code getHashValue() const override { - return llvm::hash_value(value); + template + llvm::FoldingSet &getInternSet() { + assert(false && "no generic internalization set"); } - bool isEqual(const ElaboratorValue &other) const override { - auto *val = dyn_cast(&other); - if (!val) - return false; + template <> + llvm::FoldingSet &getInternSet() { + return internedSets; + } - return value == val->value; + template <> + llvm::FoldingSet &getInternSet() { + return internedBags; } -#ifndef NDEBUG - void print(llvm::raw_ostream &os) const override { - os << ""; + template <> + llvm::FoldingSet &getInternSet() { + return internedSequences; } -#endif - bool getBool() const { return value; } + // BagStorage *internalize(BagStorage &&storage) { + // llvm::FoldingSetNodeID profile; + // storage.Profile(profile); + // void *insertPos = nullptr; + // if (auto *bag = internedBags.FindNodeOrInsertPos(profile, insertPos)) + // return bag; + // auto *storagePtr = new BagStorage(std::move(storage)); + // internedBags.InsertNode(storagePtr, insertPos); + // return storagePtr; + // } private: - const bool value; + llvm::BumpPtrAllocator allocator; + // A map used to intern elaborator values. We do this such that we can + // compare pointers when, e.g., computing set differences, uniquing the + // elements in a set, etc. Otherwise, we'd need to do a deep value comparison + // in those situations. + // Use a pointer as the key with custom MapInfo because of object slicing when + // inserting an object of a derived class of ElaboratorValue. + // The custom MapInfo makes sure that we do a value comparison instead of + // comparing the pointers. + llvm::FoldingSet internedSets; + llvm::FoldingSet internedBags; + llvm::FoldingSet internedSequences; }; -/// Holds an evaluated value of a `SetType`'d value. -class SetValue : public ElaboratorValue { -public: - SetValue(SetVector &&set, Type type) - : ElaboratorValue(ValueKind::Set), set(std::move(set)), type(type), - cachedHash(llvm::hash_combine( - llvm::hash_combine_range(set.begin(), set.end()), type)) {} - - // Implement LLVMs RTTI - static bool classof(const ElaboratorValue *val) { - return val->getKind() == ValueKind::Set; +/// Holds any typed attribute. Wrapping around an MLIR `Attribute` allows us to +/// use this elaborator value class for any values that have a corresponding +/// MLIR attribute rather than one per kind of attribute. We only support typed +/// attributes because for materialization we need to provide the type to the +/// dialect's materializer. +struct AttributeValue : public ElaboratorValue { + AttributeValue() = default; + AttributeValue(StorageTy storage) + : ElaboratorValue(ValueKind::Attribute, storage) {} + AttributeValue(TypedAttr attr) + : ElaboratorValue(ValueKind::Attribute, attr.getAsOpaquePointer()) { + assert(attr && "null attributes not allowed"); } - llvm::hash_code getHashValue() const override { return cachedHash; } - - bool isEqual(const ElaboratorValue &other) const override { - auto *otherSet = dyn_cast(&other); - if (!otherSet) - return false; - - if (cachedHash != otherSet->cachedHash) - return false; - - // Make sure empty sets of different types are not considered equal - return set == otherSet->set && type == otherSet->type; + // Implement LLVMs RTTI + static bool classof(const ElaboratorValue &val) { + return val.getKind() == ValueKind::Attribute; } -#ifndef NDEBUG - void print(llvm::raw_ostream &os) const override { - os << "print(os); }); - os << "} at " << this << ">"; + TypedAttr getAttr() const { + return cast(Attribute::getFromOpaquePointer(getStorage().ptr)); } -#endif - - const SetVector &getSet() const { return set; } - - Type getType() const { return type; } +}; -private: - // We currently use a sorted vector to represent sets. Note that it is sorted - // by the pointer value and thus non-deterministic. - // We probably want to do some profiling in the future to see if a DenseSet or - // other representation is better suited. - const SetVector set; +/// Holds an evaluated value of a `IndexType`'d value. +struct IndexValue : public ElaboratorValue { + IndexValue() = default; + IndexValue(StorageTy storage) : ElaboratorValue(ValueKind::Index, storage) {} + IndexValue(size_t index) : ElaboratorValue(ValueKind::Index, index) {} - // Store the set type such that we can materialize this evaluated value - // also in the case where the set is empty. - const Type type; + // Implement LLVMs RTTI + static bool classof(const ElaboratorValue &val) { + return val.getKind() == ValueKind::Index; + } - // Compute the hash only once at constructor time. - const llvm::hash_code cachedHash; + size_t getIndex() const { return getStorage().index; } }; -/// Holds an evaluated value of a `BagType`'d value. -class BagValue : public ElaboratorValue { -public: - BagValue(MapVector &&bag, Type type) - : ElaboratorValue(ValueKind::Bag), bag(std::move(bag)), type(type), - cachedHash(llvm::hash_combine( - llvm::hash_combine_range(bag.begin(), bag.end()), type)) {} +/// Holds an evaluated value of an `i1` type'd value. +struct BoolValue : public ElaboratorValue { + BoolValue() = default; + BoolValue(StorageTy storage) : ElaboratorValue(ValueKind::Bool, storage) {} + BoolValue(bool value) : ElaboratorValue(ValueKind::Bool, value) {} // Implement LLVMs RTTI - static bool classof(const ElaboratorValue *val) { - return val->getKind() == ValueKind::Bag; + static bool classof(const ElaboratorValue &val) { + return val.getKind() == ValueKind::Bool; } - llvm::hash_code getHashValue() const override { return cachedHash; } - - bool isEqual(const ElaboratorValue &other) const override { - auto *otherBag = dyn_cast(&other); - if (!otherBag) - return false; + bool getBool() const { return getStorage().boolean; } +}; - if (cachedHash != otherBag->cachedHash) - return false; +/// Holds an evaluated value of a `SetType`'d value. +struct SetValue : public ElaboratorValue { + SetValue() = default; + SetValue(StorageTy storage) : ElaboratorValue(ValueKind::Set, storage) {} + SetValue(Internalizer &internalizer, SetVector &&set, + Type type) + : ElaboratorValue(ValueKind::Set, internalizer.internalize( + std::move(set), type)) {} - return llvm::equal(bag, otherBag->bag) && type == otherBag->type; + // Implement LLVMs RTTI + static bool classof(const ElaboratorValue &val) { + return val.getKind() == ValueKind::Set; } -#ifndef NDEBUG - void print(llvm::raw_ostream &os) const override { - os << " el) { - el.first->print(os); - os << " -> " << el.second; - }); - os << "} at " << this << ">"; + const SetVector &getSet() const { + return static_cast(getStorage().ptr)->set; } -#endif - const MapVector &getBag() const { return bag; } + Type getType() const { + return static_cast(getStorage().ptr)->type; + } +}; - Type getType() const { return type; } +/// Holds an evaluated value of a `BagType`'d value. +struct BagValue : public ElaboratorValue { + BagValue() = default; + BagValue(StorageTy storage) : ElaboratorValue(ValueKind::Bag, storage) {} + BagValue(Internalizer &internalizer, + MapVector &&bag, Type type) + : ElaboratorValue(ValueKind::Bag, internalizer.internalize( + std::move(bag), type)) {} -private: - // Stores the elaborated values of the bag. - const MapVector bag; + // Implement LLVMs RTTI + static bool classof(const ElaboratorValue &val) { + return val.getKind() == ValueKind::Bag; + } - // Store the type of the bag such that we can materialize this evaluated value - // also in the case where the bag is empty. - const Type type; + const MapVector &getBag() const { + return static_cast(getStorage().ptr)->bag; + } - // Compute the hash only once at constructor time. - const llvm::hash_code cachedHash; + Type getType() const { + return static_cast(getStorage().ptr)->type; + } }; /// Holds an evaluated value of a `SequenceType`'d value. -class SequenceValue : public ElaboratorValue { -public: - SequenceValue(StringRef name, StringAttr familyName, - SmallVector &&args) - : ElaboratorValue(ValueKind::Sequence), name(name), - familyName(familyName), args(std::move(args)), - cachedHash(llvm::hash_combine( - llvm::hash_combine_range(this->args.begin(), this->args.end()), - name, familyName)) {} +struct SequenceValue : public ElaboratorValue { + SequenceValue() = default; + SequenceValue(StorageTy storage) + : ElaboratorValue(ValueKind::Sequence, storage) {} + SequenceValue(Internalizer &internalizer, StringRef name, + StringAttr familyName, SmallVector &&args) + : ElaboratorValue(ValueKind::Sequence, + internalizer.internalize( + name, familyName, std::move(args))) {} // Implement LLVMs RTTI - static bool classof(const ElaboratorValue *val) { - return val->getKind() == ValueKind::Sequence; + static bool classof(const ElaboratorValue &val) { + return val.getKind() == ValueKind::Sequence; } - llvm::hash_code getHashValue() const override { return cachedHash; } - - bool isEqual(const ElaboratorValue &other) const override { - auto *otherSeq = dyn_cast(&other); - if (!otherSeq) - return false; - - if (cachedHash != otherSeq->cachedHash) - return false; - - return name == otherSeq->name && familyName == otherSeq->familyName && - args == otherSeq->args; + StringRef getName() const { + return static_cast(getStorage().ptr)->name; } - -#ifndef NDEBUG - void print(llvm::raw_ostream &os) const override { - os << "print(os); }); - os << ") at " << this << ">"; + StringAttr getFamilyName() const { + return static_cast(getStorage().ptr)->familyName; + } + ArrayRef getArgs() const { + return static_cast(getStorage().ptr)->args; } -#endif - - StringRef getName() const { return name; } - StringAttr getFamilyName() const { return familyName; } - ArrayRef getArgs() const { return args; } - -private: - const StringRef name; - const StringAttr familyName; - const SmallVector args; - - // Compute the hash only once at constructor time. - const llvm::hash_code cachedHash; }; } // namespace #ifndef NDEBUG static llvm::raw_ostream &operator<<(llvm::raw_ostream &os, const ElaboratorValue &value) { - value.print(os); + TypeSwitch(value) + .Case( + [&](auto val) { os << ""; }) + .Case( + [&](auto val) { os << ""; }) + .Case( + [&](auto val) { os << ""; }) + .Case([&](auto val) { + os << ""; + }) + .Case([&](auto val) { + os << " &el) { + os << el.first << " -> " << el.second; + }); + os << "} at " << val.getStorage().ptr << ">"; + }) + .Case([&](auto val) { + os << ""; + }) + .Default([](auto val) { + assert(false && "all cases must be covered above"); + return Value(); + }); return os; } #endif @@ -388,32 +524,6 @@ static llvm::raw_ostream &operator<<(llvm::raw_ostream &os, // Hash Map Helpers //===----------------------------------------------------------------------===// -// NOLINTNEXTLINE(readability-identifier-naming) -static llvm::hash_code hash_value(const ElaboratorValue &val) { - return val.getHashValue(); -} - -namespace { -struct InternMapInfo : public DenseMapInfo { - static unsigned getHashValue(const ElaboratorValue *value) { - assert(value != getTombstoneKey() && value != getEmptyKey()); - return hash_value(*value); - } - - static bool isEqual(const ElaboratorValue *lhs, const ElaboratorValue *rhs) { - if (lhs == rhs) - return true; - - auto *tk = getTombstoneKey(); - auto *ek = getEmptyKey(); - if (lhs == tk || rhs == tk || lhs == ek || rhs == ek) - return false; - - return lhs->isEqual(*rhs); - } -}; -} // namespace - //===----------------------------------------------------------------------===// // Main Elaborator Implementation //===----------------------------------------------------------------------===// @@ -427,16 +537,16 @@ class Materializer { /// Materialize IR representing the provided `ElaboratorValue` and return the /// `Value` or a null value on failure. - Value materialize(ElaboratorValue *val, Location loc, - std::queue &elabRequests, + Value materialize(ElaboratorValue val, Location loc, + std::queue &elabRequests, function_ref emitError) { auto iter = materializedValues.find(val); if (iter != materializedValues.end()) return iter->second; - LLVM_DEBUG(llvm::dbgs() << "Materializing " << *val << "\n\n"); + LLVM_DEBUG(llvm::dbgs() << "Materializing " << val << "\n\n"); - return TypeSwitch(val) + return TypeSwitch(val) .Case( [&](auto val) { return visit(val, loc, elabRequests, emitError); }) @@ -453,8 +563,8 @@ class Materializer { /// deleted until `op` is reached. An error is returned if the operation is /// before the insertion point. LogicalResult materialize(Operation *op, - DenseMap &state, - std::queue &elabRequests) { + DenseMap &state, + std::queue &elabRequests) { if (op->getNumRegions() > 0) return op->emitOpError("ops with nested regions must be elaborated away"); @@ -521,10 +631,10 @@ class Materializer { } private: - Value visit(AttributeValue *val, Location loc, - std::queue &elabRequests, + Value visit(const AttributeValue &val, Location loc, + std::queue &elabRequests, function_ref emitError) { - auto attr = val->getAttr(); + auto attr = val.getAttr(); // For index attributes (and arithmetic operations on them) we use the // index dialect. @@ -552,28 +662,28 @@ class Materializer { return res; } - Value visit(IndexValue *val, Location loc, - std::queue &elabRequests, + Value visit(const IndexValue &val, Location loc, + std::queue &elabRequests, function_ref emitError) { - Value res = builder.create(loc, val->getIndex()); + Value res = builder.create(loc, val.getIndex()); materializedValues[val] = res; return res; } - Value visit(BoolValue *val, Location loc, - std::queue &elabRequests, + Value visit(const BoolValue &val, Location loc, + std::queue &elabRequests, function_ref emitError) { - Value res = builder.create(loc, val->getBool()); + Value res = builder.create(loc, val.getBool()); materializedValues[val] = res; return res; } - Value visit(SetValue *val, Location loc, - std::queue &elabRequests, + Value visit(const SetValue &val, Location loc, + std::queue &elabRequests, function_ref emitError) { SmallVector elements; - elements.reserve(val->getSet().size()); - for (auto *el : val->getSet()) { + elements.reserve(val.getSet().size()); + for (auto el : val.getSet()) { auto materialized = materialize(el, loc, elabRequests, emitError); if (!materialized) return Value(); @@ -581,47 +691,38 @@ class Materializer { elements.push_back(materialized); } - auto res = builder.create(loc, val->getType(), elements); + auto res = builder.create(loc, val.getType(), elements); materializedValues[val] = res; return res; } - Value visit(BagValue *val, Location loc, - std::queue &elabRequests, + Value visit(const BagValue &val, Location loc, + std::queue &elabRequests, function_ref emitError) { SmallVector values, weights; - values.reserve(val->getBag().size()); - weights.reserve(val->getBag().size()); - for (auto [val, weight] : val->getBag()) { + values.reserve(val.getBag().size()); + weights.reserve(val.getBag().size()); + for (auto [val, weight] : val.getBag()) { auto materializedVal = materialize(val, loc, elabRequests, emitError); - if (!materializedVal) + auto materializedWeight = + materialize(IndexValue(weight), loc, elabRequests, emitError); + if (!materializedVal || !materializedWeight) return Value(); - auto iter = integerValues.find(weight); - Value materializedWeight; - if (iter != integerValues.end()) { - materializedWeight = iter->second; - } else { - materializedWeight = builder.create( - loc, builder.getIndexAttr(weight)); - integerValues[weight] = materializedWeight; - } - values.push_back(materializedVal); weights.push_back(materializedWeight); } - auto res = - builder.create(loc, val->getType(), values, weights); + auto res = builder.create(loc, val.getType(), values, weights); materializedValues[val] = res; return res; } - Value visit(SequenceValue *val, Location loc, - std::queue &elabRequests, + Value visit(const SequenceValue &val, Location loc, + std::queue &elabRequests, function_ref emitError) { elabRequests.push(val); - return builder.create(loc, val->getName(), ValueRange()); + return builder.create(loc, val.getName(), ValueRange()); } private: @@ -630,8 +731,7 @@ class Materializer { /// insertion point such that future materializations can also reuse previous /// materializations without running into dominance issues (or requiring /// additional checks to avoid them). - DenseMap materializedValues; - DenseMap integerValues; + DenseMap materializedValues; /// Cache the builder to continue insertions at their current insertion point /// for the reason stated above. @@ -652,21 +752,11 @@ struct ElaboratorSharedState { SymbolTable &table; std::mt19937 rng; Namespace names; - - // A map used to intern elaborator values. We do this such that we can - // compare pointers when, e.g., computing set differences, uniquing the - // elements in a set, etc. Otherwise, we'd need to do a deep value comparison - // in those situations. - // Use a pointer as the key with custom MapInfo because of object slicing when - // inserting an object of a derived class of ElaboratorValue. - // The custom MapInfo makes sure that we do a value comparison instead of - // comparing the pointers. - DenseMap, InternMapInfo> - interned; + Internalizer internalizer; /// The worklist used to keep track of the test and sequence operations to /// make sure they are processed top-down (BFS traversal). - std::queue worklist; + std::queue worklist; }; /// Interprets the IR to perform and lower the represented randomizations. @@ -679,17 +769,17 @@ class Elaborator : public RTGOpVisitor> { Elaborator(ElaboratorSharedState &sharedState, Materializer &materializer) : sharedState(sharedState), materializer(materializer) {} - /// Helper to perform internalization and keep track of interpreted value for - /// the given SSA value. - template - void internalizeResult(Value val, Args &&...args) { - // TODO: this isn't the most efficient way to internalize - auto ptr = std::make_unique(std::forward(args)...); - auto *e = ptr.get(); - auto [iter, _] = sharedState.interned.insert({e, std::move(ptr)}); - state[val] = iter->second.get(); + inline void store(Value val, const ElaboratorValue &eval) { + state[val] = eval; + } + + template + inline ValueTy get(Value val) { + return dyn_cast(state.at(val)); } + inline ElaboratorValue get(Value val) { return state.at(val); } + /// Print a nice error message for operations we don't support yet. FailureOr visitUnhandledOp(Operation *op) { return op->emitOpError("elaboration not supported"); @@ -706,14 +796,14 @@ class Elaborator : public RTGOpVisitor> { } FailureOr visitOp(SequenceClosureOp op) { - SmallVector args; + SmallVector args; for (auto arg : op.getArgs()) - args.push_back(state.at(arg)); + args.push_back(get(arg)); auto familyName = op.getSequenceAttr(); auto name = sharedState.names.newName(familyName.getValue()); - internalizeResult(op.getResult(), name, familyName, - std::move(args)); + store(op.getResult(), SequenceValue(sharedState.internalizer, name, + familyName, std::move(args))); return DeletionKind::Delete; } @@ -722,84 +812,84 @@ class Elaborator : public RTGOpVisitor> { } FailureOr visitOp(SetCreateOp op) { - SetVector set; + SetVector set; for (auto val : op.getElements()) - set.insert(state.at(val)); + set.insert(get(val)); - internalizeResult(op.getSet(), std::move(set), - op.getSet().getType()); + store(op.getSet(), SetValue(sharedState.internalizer, std::move(set), + op.getSet().getType())); return DeletionKind::Delete; } FailureOr visitOp(SetSelectRandomOp op) { - auto *set = cast(state.at(op.getSet())); + auto set = cast(get(op.getSet())); size_t selected; if (auto intAttr = op->getAttrOfType("rtg.elaboration_custom_seed")) { std::mt19937 customRng(intAttr.getInt()); - selected = getUniformlyInRange(customRng, 0, set->getSet().size() - 1); + selected = getUniformlyInRange(customRng, 0, set.getSet().size() - 1); } else { selected = - getUniformlyInRange(sharedState.rng, 0, set->getSet().size() - 1); + getUniformlyInRange(sharedState.rng, 0, set.getSet().size() - 1); } - state[op.getResult()] = set->getSet()[selected]; + store(op.getResult(), set.getSet()[selected]); return DeletionKind::Delete; } FailureOr visitOp(SetDifferenceOp op) { - auto original = cast(state.at(op.getOriginal()))->getSet(); - auto diff = cast(state.at(op.getDiff()))->getSet(); + auto original = get(op.getOriginal()).getSet(); + auto diff = get(op.getDiff()).getSet(); - SetVector result(original); + SetVector result(original); result.set_subtract(diff); - internalizeResult(op.getResult(), std::move(result), - op.getResult().getType()); + store(op.getResult(), SetValue(sharedState.internalizer, std::move(result), + op.getResult().getType())); return DeletionKind::Delete; } FailureOr visitOp(SetUnionOp op) { - SetVector result; + SetVector result; for (auto set : op.getSets()) - result.set_union(cast(state.at(set))->getSet()); + result.set_union(get(set).getSet()); - internalizeResult(op.getResult(), std::move(result), - op.getType()); + store(op.getResult(), + SetValue(sharedState.internalizer, std::move(result), op.getType())); return DeletionKind::Delete; } FailureOr visitOp(SetSizeOp op) { - auto size = cast(state.at(op.getSet()))->getSet().size(); - auto sizeAttr = IntegerAttr::get(IndexType::get(op->getContext()), size); - internalizeResult(op.getResult(), sizeAttr); + auto size = get(op.getSet()).getSet().size(); + store(op.getResult(), IndexValue(size)); return DeletionKind::Delete; } FailureOr visitOp(BagCreateOp op) { - MapVector bag; + MapVector bag; for (auto [val, multiple] : llvm::zip(op.getElements(), op.getMultiples())) { - auto *interpValue = state.at(val); + auto interpValue = get(val); // If the multiple is not stored as an AttributeValue, the elaboration // must have already failed earlier (since we don't have // unevaluated/opaque values). - auto *interpMultiple = cast(state.at(multiple)); - bag[interpValue] += interpMultiple->getIndex(); + auto interpMultiple = get(multiple); + bag[interpValue] += interpMultiple.getIndex(); } - internalizeResult(op.getBag(), std::move(bag), op.getType()); + store(op.getBag(), + BagValue(sharedState.internalizer, std::move(bag), op.getType())); return DeletionKind::Delete; } FailureOr visitOp(BagSelectRandomOp op) { - auto *bag = cast(state.at(op.getBag())); + auto bag = get(op.getBag()); - SmallVector> prefixSum; - prefixSum.reserve(bag->getBag().size()); + SmallVector> prefixSum; + prefixSum.reserve(bag.getBag().size()); uint32_t accumulator = 0; - for (auto [val, weight] : bag->getBag()) { + for (auto [val, weight] : bag.getBag()) { accumulator += weight; prefixSum.push_back({val, accumulator}); } @@ -813,20 +903,21 @@ class Elaborator : public RTGOpVisitor> { auto idx = getUniformlyInRange(customRng, 0, accumulator - 1); auto *iter = llvm::upper_bound( prefixSum, idx, - [](uint32_t a, const std::pair &b) { + [](uint32_t a, const std::pair &b) { return a < b.second; }); - state[op.getResult()] = iter->first; + + store(op.getResult(), iter->first); return DeletionKind::Delete; } FailureOr visitOp(BagDifferenceOp op) { - auto *original = cast(state.at(op.getOriginal())); - auto *diff = cast(state.at(op.getDiff())); + auto original = get(op.getOriginal()); + auto diff = get(op.getDiff()); - MapVector result; - for (const auto &el : original->getBag()) { - if (!diff->getBag().contains(el.first)) { + MapVector result; + for (const auto &el : original.getBag()) { + if (!diff.getBag().contains(el.first)) { result.insert(el); continue; } @@ -834,40 +925,39 @@ class Elaborator : public RTGOpVisitor> { if (op.getInf()) continue; - auto toDiff = diff->getBag().lookup(el.first); + auto toDiff = diff.getBag().lookup(el.first); if (el.second <= toDiff) continue; result.insert({el.first, el.second - toDiff}); } - internalizeResult(op.getResult(), std::move(result), - op.getType()); + store(op.getResult(), + BagValue(sharedState.internalizer, std::move(result), op.getType())); return DeletionKind::Delete; } FailureOr visitOp(BagUnionOp op) { - MapVector result; + MapVector result; for (auto bag : op.getBags()) { - auto *val = cast(state.at(bag)); - for (auto [el, multiple] : val->getBag()) + auto val = get(bag); + for (auto [el, multiple] : val.getBag()) result[el] += multiple; } - internalizeResult(op.getResult(), std::move(result), - op.getType()); + store(op.getResult(), + BagValue(sharedState.internalizer, std::move(result), op.getType())); return DeletionKind::Delete; } FailureOr visitOp(BagUniqueSizeOp op) { - auto size = cast(state.at(op.getBag()))->getBag().size(); - auto sizeAttr = IntegerAttr::get(IndexType::get(op->getContext()), size); - internalizeResult(op.getResult(), sizeAttr); + auto size = get(op.getBag()).getBag().size(); + store(op.getResult(), IndexValue(size)); return DeletionKind::Delete; } FailureOr visitOp(scf::IfOp op) { - bool cond = cast(state.at(op.getCondition()))->getBool(); + bool cond = get(op.getCondition()).getBool(); auto &toElaborate = cond ? op.getThenRegion() : op.getElseRegion(); if (toElaborate.empty()) return DeletionKind::Delete; @@ -889,9 +979,9 @@ class Elaborator : public RTGOpVisitor> { } FailureOr visitOp(scf::ForOp op) { - auto *lowerBound = dyn_cast(state.at(op.getLowerBound())); - auto *step = dyn_cast(state.at(op.getStep())); - auto *upperBound = dyn_cast(state.at(op.getUpperBound())); + auto lowerBound = get(op.getLowerBound()); + auto step = get(op.getStep()); + auto upperBound = get(op.getUpperBound()); if (!lowerBound || !step || !upperBound) return op->emitOpError("can only elaborate index type iterator"); @@ -906,14 +996,14 @@ class Elaborator : public RTGOpVisitor> { state[iterArg] = state.at(initArg); // This loop performs the actual 'scf.for' loop iterations. - for (size_t i = lowerBound->getIndex(); i < upperBound->getIndex(); - i += step->getIndex()) { + for (size_t i = lowerBound.getIndex(); i < upperBound.getIndex(); + i += step.getIndex()) { if (failed(elaborate(op.getBodyRegion()))) return failure(); // Prepare for the next iteration by updating the mapping of the nested // regions block arguments - internalizeResult(op.getInductionVar(), i + step->getIndex()); + store(op.getInductionVar(), IndexValue(i + step.getIndex())); for (auto [iterArg, prevIterArg] : llvm::zip(op.getRegionIterArgs(), op.getBody()->getTerminator()->getOperands())) @@ -933,15 +1023,15 @@ class Elaborator : public RTGOpVisitor> { } FailureOr visitOp(index::AddOp op) { - size_t lhs = cast(state.at(op.getLhs()))->getIndex(); - size_t rhs = cast(state.at(op.getRhs()))->getIndex(); - internalizeResult(op.getResult(), lhs + rhs); + size_t lhs = get(op.getLhs()).getIndex(); + size_t rhs = get(op.getRhs()).getIndex(); + store(op.getResult(), IndexValue(lhs + rhs)); return DeletionKind::Delete; } FailureOr visitOp(index::CmpOp op) { - size_t lhs = cast(state.at(op.getLhs()))->getIndex(); - size_t rhs = cast(state.at(op.getRhs()))->getIndex(); + size_t lhs = get(op.getLhs()).getIndex(); + size_t rhs = get(op.getRhs()).getIndex(); bool result; switch (op.getPred()) { case index::IndexCmpPredicate::EQ: @@ -965,7 +1055,7 @@ class Elaborator : public RTGOpVisitor> { default: return op->emitOpError("elaboration not supported"); } - internalizeResult(op.getResult(), result); + store(op.getResult(), BoolValue(result)); return DeletionKind::Delete; } @@ -983,11 +1073,11 @@ class Elaborator : public RTGOpVisitor> { auto intAttr = dyn_cast(attr); if (intAttr && isa(attr.getType())) - internalizeResult(op->getResult(0), intAttr.getInt()); + store(op->getResult(0), IndexValue(intAttr.getInt())); else if (intAttr && intAttr.getType().isSignlessInteger(1)) - internalizeResult(op->getResult(0), intAttr.getInt()); + store(op->getResult(0), BoolValue(intAttr.getInt())); else - internalizeResult(op->getResult(0), attr); + store(op->getResult(0), AttributeValue(attr)); return DeletionKind::Delete; } @@ -1004,7 +1094,7 @@ class Elaborator : public RTGOpVisitor> { // NOLINTNEXTLINE(misc-no-recursion) LogicalResult elaborate(Region ®ion, - ArrayRef regionArguments = {}) { + ArrayRef regionArguments = {}) { if (region.getBlocks().size() > 1) return region.getParentOp()->emitOpError( "regions with more than one block are not supported"); @@ -1028,7 +1118,7 @@ class Elaborator : public RTGOpVisitor> { llvm::interleaveComma(op.getResults(), llvm::dbgs(), [&](auto res) { if (state.contains(res)) - llvm::dbgs() << *state.at(res); + llvm::dbgs() << get(res); else llvm::dbgs() << "unknown"; }); @@ -1049,7 +1139,7 @@ class Elaborator : public RTGOpVisitor> { Materializer &materializer; // A map from SSA values to a pointer of an interned elaborator value. - DenseMap state; + DenseMap state; }; } // namespace @@ -1147,21 +1237,21 @@ LogicalResult ElaborationPass::elaborateModule(ModuleOp moduleOp, // Do top-down BFS traversal such that elaborating a sequence further down // does not fix the outcome for multiple placements. while (!state.worklist.empty()) { - auto *curr = state.worklist.front(); + auto curr = state.worklist.front(); state.worklist.pop(); - if (table.lookup(curr->getName())) + if (table.lookup(curr.getName())) continue; - auto familyOp = table.lookup(curr->getFamilyName()); + auto familyOp = table.lookup(curr.getFamilyName()); // TODO: don't clone if this is the only remaining reference to this // sequence OpBuilder builder(familyOp); auto seqOp = builder.cloneWithoutRegions(familyOp); seqOp.getBodyRegion().emplaceBlock(); - seqOp.setSymName(curr->getName()); + seqOp.setSymName(curr.getName()); table.insert(seqOp); - assert(seqOp.getSymName() == curr->getName() && + assert(seqOp.getSymName() == curr.getName() && "should not have been renamed"); LLVM_DEBUG(llvm::dbgs() @@ -1170,7 +1260,7 @@ LogicalResult ElaborationPass::elaborateModule(ModuleOp moduleOp, Materializer materializer(OpBuilder::atBlockBegin(seqOp.getBody())); Elaborator elaborator(state, materializer); - if (failed(elaborator.elaborate(familyOp.getBodyRegion(), curr->getArgs()))) + if (failed(elaborator.elaborate(familyOp.getBodyRegion(), curr.getArgs()))) return failure(); materializer.finalize(); From b2d402cd80dd809e41427640f139f3a232436c11 Mon Sep 17 00:00:00 2001 From: Martin Erhart Date: Sat, 14 Dec 2024 17:15:06 +0000 Subject: [PATCH 2/5] Don't use complicated union --- .../RTG/Transforms/ElaborationPass.cpp | 146 ++++++++---------- 1 file changed, 63 insertions(+), 83 deletions(-) diff --git a/lib/Dialect/RTG/Transforms/ElaborationPass.cpp b/lib/Dialect/RTG/Transforms/ElaborationPass.cpp index 02856ced2884..5a82ac97e28c 100644 --- a/lib/Dialect/RTG/Transforms/ElaborationPass.cpp +++ b/lib/Dialect/RTG/Transforms/ElaborationPass.cpp @@ -98,69 +98,28 @@ class ElaboratorValue { None }; - union StorageTy { - StorageTy() : ptr(nullptr) {} - StorageTy(const void *ptr) : ptr(ptr) {} - StorageTy(size_t index) : index(index) {} - StorageTy(bool boolean) : boolean(boolean) {} - - const void *ptr; - size_t index; - bool boolean; - }; - - ElaboratorValue(ValueKind kind = ValueKind::None, - StorageTy storage = StorageTy()) + ElaboratorValue(ValueKind kind = ValueKind::None, uintptr_t storage = 0) : kind(kind), storage(storage) {} // This constructor is needed for LLVM RTTI - ElaboratorValue(StorageTy storage) : ElaboratorValue() {} + ElaboratorValue(uintptr_t storage) : ElaboratorValue() {} llvm::hash_code getHashValue() const { - switch (kind) { - case ValueKind::Attribute: - case ValueKind::Set: - case ValueKind::Bag: - case ValueKind::Sequence: - return llvm::hash_combine(kind, storage.ptr); - case ValueKind::Index: - return llvm::hash_combine(kind, storage.index); - case ValueKind::Bool: - return llvm::hash_combine(kind, storage.boolean); - case ValueKind::None: - return llvm::hash_value(kind); - } - llvm::llvm_unreachable_internal("all cases handled above"); + return llvm::hash_combine(kind, storage); } bool operator==(const ElaboratorValue &other) const { - if (kind != other.kind) - return false; - - switch (kind) { - case ValueKind::Attribute: - case ValueKind::Set: - case ValueKind::Bag: - case ValueKind::Sequence: - return storage.ptr == other.storage.ptr; - case ValueKind::Index: - return storage.index == other.storage.index; - case ValueKind::Bool: - return storage.boolean == other.storage.boolean; - case ValueKind::None: - return true; - } - llvm::llvm_unreachable_internal("all cases handled above"); + return kind == other.kind && storage == other.storage; } operator bool() const { return kind != ValueKind::None; } ValueKind getKind() const { return kind; } - StorageTy getStorage() const { return storage; } + uintptr_t getRawStorage() const { return storage; } -private: +protected: ValueKind kind; - StorageTy storage; + uintptr_t storage; }; } // namespace @@ -191,7 +150,7 @@ struct CastInfo struct DenseMapInfo { static inline ElaboratorValue getEmptyKey() { return ElaboratorValue(); } static inline ElaboratorValue getTombstoneKey() { - return ElaboratorValue(ElaboratorValue::ValueKind::None, - reinterpret_cast(~0ULL)); + return ElaboratorValue(ElaboratorValue::ValueKind::None, ~uintptr_t()); } static unsigned getHashValue(const ElaboratorValue &value) { return value.getHashValue(); @@ -223,7 +181,7 @@ struct SetStorage : public llvm::FoldingSetNode { static void Profile(llvm::FoldingSetNodeID &ID, const SetVector &set, Type type) { for (auto el : set) { - ID.AddPointer(el.getStorage().ptr); + ID.AddInteger(el.getRawStorage()); ID.AddInteger(static_cast(el.getKind())); } ID.AddPointer(type.getAsOpaquePointer()); @@ -249,7 +207,7 @@ struct BagStorage : public llvm::FoldingSetNode { const MapVector &bag, Type type) { for (auto el : bag) { - ID.AddPointer(el.first.getStorage().ptr); + ID.AddInteger(el.first.getRawStorage()); ID.AddInteger(static_cast(el.first.getKind())); ID.AddInteger(el.second); } @@ -278,7 +236,7 @@ struct SequenceStorage : public llvm::FoldingSetNode { ID.AddString(name); ID.AddPointer(familyName.getAsOpaquePointer()); for (auto el : args) { - ID.AddPointer(el.getStorage().ptr); + ID.AddInteger(el.getRawStorage()); ID.AddInteger(static_cast(el.getKind())); } } @@ -361,11 +319,14 @@ class Internalizer { /// attributes because for materialization we need to provide the type to the /// dialect's materializer. struct AttributeValue : public ElaboratorValue { + static_assert(sizeof(uintptr_t) == sizeof(const void *)); + AttributeValue() = default; - AttributeValue(StorageTy storage) + AttributeValue(uintptr_t storage) : ElaboratorValue(ValueKind::Attribute, storage) {} AttributeValue(TypedAttr attr) - : ElaboratorValue(ValueKind::Attribute, attr.getAsOpaquePointer()) { + : ElaboratorValue(ValueKind::Attribute, reinterpret_cast( + attr.getAsOpaquePointer())) { assert(attr && "null attributes not allowed"); } @@ -375,46 +336,55 @@ struct AttributeValue : public ElaboratorValue { } TypedAttr getAttr() const { - return cast(Attribute::getFromOpaquePointer(getStorage().ptr)); + return cast(Attribute::getFromOpaquePointer( + reinterpret_cast(storage))); } }; /// Holds an evaluated value of a `IndexType`'d value. struct IndexValue : public ElaboratorValue { + static_assert(sizeof(uintptr_t) >= sizeof(size_t)); + IndexValue() = default; - IndexValue(StorageTy storage) : ElaboratorValue(ValueKind::Index, storage) {} - IndexValue(size_t index) : ElaboratorValue(ValueKind::Index, index) {} + IndexValue(uintptr_t storage) : ElaboratorValue(ValueKind::Index, storage) {} + // IndexValue(size_t index) : ElaboratorValue(ValueKind::Index, index) {} // Implement LLVMs RTTI static bool classof(const ElaboratorValue &val) { return val.getKind() == ValueKind::Index; } - size_t getIndex() const { return getStorage().index; } + size_t getIndex() const { return storage; } }; /// Holds an evaluated value of an `i1` type'd value. struct BoolValue : public ElaboratorValue { + static_assert(sizeof(uintptr_t) >= sizeof(bool)); + BoolValue() = default; - BoolValue(StorageTy storage) : ElaboratorValue(ValueKind::Bool, storage) {} - BoolValue(bool value) : ElaboratorValue(ValueKind::Bool, value) {} + BoolValue(uintptr_t storage) : ElaboratorValue(ValueKind::Bool, storage) {} + BoolValue(bool value) : ElaboratorValue(ValueKind::Bool, uintptr_t(value)) {} // Implement LLVMs RTTI static bool classof(const ElaboratorValue &val) { return val.getKind() == ValueKind::Bool; } - bool getBool() const { return getStorage().boolean; } + bool getBool() const { return storage; } }; /// Holds an evaluated value of a `SetType`'d value. struct SetValue : public ElaboratorValue { + static_assert(sizeof(uintptr_t) == sizeof(const void *)); + SetValue() = default; - SetValue(StorageTy storage) : ElaboratorValue(ValueKind::Set, storage) {} + SetValue(uintptr_t storage) : ElaboratorValue(ValueKind::Set, storage) {} SetValue(Internalizer &internalizer, SetVector &&set, Type type) - : ElaboratorValue(ValueKind::Set, internalizer.internalize( - std::move(set), type)) {} + : ElaboratorValue( + ValueKind::Set, + reinterpret_cast( + internalizer.internalize(std::move(set), type))) {} // Implement LLVMs RTTI static bool classof(const ElaboratorValue &val) { @@ -422,22 +392,26 @@ struct SetValue : public ElaboratorValue { } const SetVector &getSet() const { - return static_cast(getStorage().ptr)->set; + return reinterpret_cast(storage)->set; } Type getType() const { - return static_cast(getStorage().ptr)->type; + return reinterpret_cast(storage)->type; } }; /// Holds an evaluated value of a `BagType`'d value. struct BagValue : public ElaboratorValue { + static_assert(sizeof(uintptr_t) == sizeof(const void *)); + BagValue() = default; - BagValue(StorageTy storage) : ElaboratorValue(ValueKind::Bag, storage) {} + BagValue(uintptr_t storage) : ElaboratorValue(ValueKind::Bag, storage) {} BagValue(Internalizer &internalizer, MapVector &&bag, Type type) - : ElaboratorValue(ValueKind::Bag, internalizer.internalize( - std::move(bag), type)) {} + : ElaboratorValue( + ValueKind::Bag, + reinterpret_cast( + internalizer.internalize(std::move(bag), type))) {} // Implement LLVMs RTTI static bool classof(const ElaboratorValue &val) { @@ -445,24 +419,27 @@ struct BagValue : public ElaboratorValue { } const MapVector &getBag() const { - return static_cast(getStorage().ptr)->bag; + return reinterpret_cast(storage)->bag; } Type getType() const { - return static_cast(getStorage().ptr)->type; + return reinterpret_cast(storage)->type; } }; /// Holds an evaluated value of a `SequenceType`'d value. struct SequenceValue : public ElaboratorValue { + static_assert(sizeof(uintptr_t) == sizeof(const void *)); + SequenceValue() = default; - SequenceValue(StorageTy storage) + SequenceValue(uintptr_t storage) : ElaboratorValue(ValueKind::Sequence, storage) {} SequenceValue(Internalizer &internalizer, StringRef name, StringAttr familyName, SmallVector &&args) : ElaboratorValue(ValueKind::Sequence, - internalizer.internalize( - name, familyName, std::move(args))) {} + reinterpret_cast( + internalizer.internalize( + name, familyName, std::move(args)))) {} // Implement LLVMs RTTI static bool classof(const ElaboratorValue &val) { @@ -470,13 +447,13 @@ struct SequenceValue : public ElaboratorValue { } StringRef getName() const { - return static_cast(getStorage().ptr)->name; + return reinterpret_cast(storage)->name; } StringAttr getFamilyName() const { - return static_cast(getStorage().ptr)->familyName; + return reinterpret_cast(storage)->familyName; } ArrayRef getArgs() const { - return static_cast(getStorage().ptr)->args; + return reinterpret_cast(storage)->args; } }; } // namespace @@ -494,7 +471,8 @@ static llvm::raw_ostream &operator<<(llvm::raw_ostream &os, .Case([&](auto val) { os << ""; + os << "} at " << reinterpret_cast(val.getRawStorage()) + << ">"; }) .Case([&](auto val) { os << " &el) { os << el.first << " -> " << el.second; }); - os << "} at " << val.getStorage().ptr << ">"; + os << "} at " << reinterpret_cast(val.getRawStorage()) + << ">"; }) .Case([&](auto val) { os << ""; + os << ") at " << reinterpret_cast(val.getRawStorage()) + << ">"; }) .Default([](auto val) { assert(false && "all cases must be covered above"); @@ -1075,7 +1055,7 @@ class Elaborator : public RTGOpVisitor> { if (intAttr && isa(attr.getType())) store(op->getResult(0), IndexValue(intAttr.getInt())); else if (intAttr && intAttr.getType().isSignlessInteger(1)) - store(op->getResult(0), BoolValue(intAttr.getInt())); + store(op->getResult(0), BoolValue(bool(intAttr.getInt()))); else store(op->getResult(0), AttributeValue(attr)); From 30b874841c367ef1d9873b3e5321f3d869a097e5 Mon Sep 17 00:00:00 2001 From: Martin Erhart Date: Sat, 14 Dec 2024 20:43:52 +0000 Subject: [PATCH 3/5] Simpler approach --- .../RTG/Transforms/ElaborationPass.cpp | 201 ++++++++++++------ 1 file changed, 140 insertions(+), 61 deletions(-) diff --git a/lib/Dialect/RTG/Transforms/ElaborationPass.cpp b/lib/Dialect/RTG/Transforms/ElaborationPass.cpp index 5a82ac97e28c..e47f0d03b5d5 100644 --- a/lib/Dialect/RTG/Transforms/ElaborationPass.cpp +++ b/lib/Dialect/RTG/Transforms/ElaborationPass.cpp @@ -122,9 +122,33 @@ class ElaboratorValue { uintptr_t storage; }; +template +struct HashedStorage { + HashedStorage(unsigned hashcode = 0, StorageTy *storage = nullptr) + : hashcode(hashcode), storage(storage) {} + + unsigned hashcode; + StorageTy *storage; +}; + +// struct SetStorage; +// struct BagStorage; +// struct SequenceStorage; + } // namespace namespace llvm { +// llvm::hash_code hash_value(const HashedStorage &storage) { +// return storage.hashcode; +// } +// llvm::hash_code hash_value(const HashedStorage &storage) { +// return storage.hashcode; +// } +// llvm::hash_code hash_value(const HashedStorage &storage) { +// return storage.hashcode; +// } + + /// Add support for llvm style casts. We provide a cast between To and From if /// From is mlir::Attribute or derives from it. template @@ -172,23 +196,63 @@ struct DenseMapInfo { } // namespace llvm namespace { +llvm::hash_code hash_value(const ElaboratorValue &val) { + return val.getHashValue(); +} + +template +struct StorageKeyInfo { + static inline HashedStorage getEmptyKey() { + return HashedStorage(0, DenseMapInfo::getEmptyKey()); + } + static inline HashedStorage getTombstoneKey() { + return HashedStorage(0, DenseMapInfo::getTombstoneKey()); + } + + static inline unsigned getHashValue(const HashedStorage &key) { + return key.hashcode; + } + static inline unsigned getHashValue(const StorageTy &key) { + return key.hashcode; + } + + static inline bool isEqual(const HashedStorage &lhs, + const HashedStorage &rhs) { + return lhs.storage == rhs.storage; + } + static inline bool isEqual(const StorageTy &lhs, const HashedStorage &rhs) { + if (isEqual(rhs, getEmptyKey()) || isEqual(rhs, getTombstoneKey())) + return false; + // Invoke the equality function on the lookup key. + return lhs.isEqual(rhs.storage); + } +}; + +template +struct StorageInfo : public DenseMapInfo { + using Base = DenseMapInfo; + static inline unsigned getHashValue(const StorageTy *key) { + return key->hashcode; + } + + static inline bool isEqual(const StorageTy *lhs, const StorageTy *rhs) { + if (lhs == rhs) + return true; + if (lhs == Base::getEmptyKey() || lhs == Base::getTombstoneKey() || rhs == Base::getEmptyKey() || rhs == Base::getTombstoneKey()) + return false; + return lhs->isEqual(rhs); + } +}; -struct SetStorage : public llvm::FoldingSetNode { +struct SetStorage { SetStorage(SetVector &&set, Type type) - : set(std::move(set)), type(type) {} - - // NOLINTNEXTLINE(readability-identifier-naming) - static void Profile(llvm::FoldingSetNodeID &ID, - const SetVector &set, Type type) { - for (auto el : set) { - ID.AddInteger(el.getRawStorage()); - ID.AddInteger(static_cast(el.getKind())); - } - ID.AddPointer(type.getAsOpaquePointer()); + : hashcode(llvm::hash_combine(type, llvm::hash_combine_range(set.begin(), set.end()))), set(std::move(set)), type(type) {} + + bool isEqual(const SetStorage *other) const { + return set == other->set && type == other->type; } - // NOLINTNEXTLINE(readability-identifier-naming) - void Profile(llvm::FoldingSetNodeID &ID) const { Profile(ID, set, type); } + unsigned hashcode; // Stores the elaborated values of the set. SetVector set; @@ -198,24 +262,15 @@ struct SetStorage : public llvm::FoldingSetNode { Type type; }; -struct BagStorage : public llvm::FoldingSetNode { +struct BagStorage { BagStorage(MapVector &&bag, Type type) - : bag(std::move(bag)), type(type) {} - - // NOLINTNEXTLINE(readability-identifier-naming) - static void Profile(llvm::FoldingSetNodeID &ID, - const MapVector &bag, - Type type) { - for (auto el : bag) { - ID.AddInteger(el.first.getRawStorage()); - ID.AddInteger(static_cast(el.first.getKind())); - ID.AddInteger(el.second); - } - ID.AddPointer(type.getAsOpaquePointer()); + : hashcode(llvm::hash_combine(type, llvm::hash_combine_range(bag.begin(), bag.end()))), bag(std::move(bag)), type(type) {} + + bool isEqual(const BagStorage *other) const { + return llvm::equal(bag, other->bag) && type == other->type; } - // NOLINTNEXTLINE(readability-identifier-naming) - void Profile(llvm::FoldingSetNodeID &ID) const { Profile(ID, bag, type); } + unsigned hashcode; // Stores the elaborated values of the bag. MapVector bag; @@ -225,65 +280,86 @@ struct BagStorage : public llvm::FoldingSetNode { Type type; }; -struct SequenceStorage : public llvm::FoldingSetNode { +struct SequenceStorage { SequenceStorage(StringRef name, StringAttr familyName, SmallVector &&args) - : name(name), familyName(familyName), args(std::move(args)) {} - - // NOLINTNEXTLINE(readability-identifier-naming) - static void Profile(llvm::FoldingSetNodeID &ID, StringRef name, - StringAttr familyName, ArrayRef args) { - ID.AddString(name); - ID.AddPointer(familyName.getAsOpaquePointer()); - for (auto el : args) { - ID.AddInteger(el.getRawStorage()); - ID.AddInteger(static_cast(el.getKind())); - } - } + : hashcode(llvm::hash_combine(name, familyName, llvm::hash_combine_range(args.begin(), args.end()))), name(name), familyName(familyName), args(std::move(args)) {} - // NOLINTNEXTLINE(readability-identifier-naming) - void Profile(llvm::FoldingSetNodeID &ID) const { - Profile(ID, name, familyName, args); + bool isEqual(const SequenceStorage *other) const { + return name == other->name && familyName == other->familyName && args == other->args; } + unsigned hashcode; StringRef name; StringAttr familyName; SmallVector args; }; +// struct LookupKey { +// unsigned hascode; +// function_ref isEqual; +// }; + class Internalizer { public: + // template + // StorageTy *internalize(Args &&...args) { + // StorageTy storage(std::forward(args)...); + + // auto existing = getInternSet().insert_as(HashedStorage(storage.hashcode), storage); + // StorageTy *&storagePtr = existing.first->storage; + // if (existing.second) + // storagePtr = new (allocator.Allocate()) StorageTy(std::move(storage)); + // return storagePtr; + // } + + // template + // DenseSet, StorageKeyInfo> &getInternSet() { + // assert(false && "no generic internalization set"); + // } + + // template <> + // DenseSet, StorageKeyInfo> &getInternSet() { + // return internedSets; + // } + + // template <> + // DenseSet, StorageKeyInfo> &getInternSet() { + // return internedBags; + // } + + // template <> + // DenseSet, StorageKeyInfo> &getInternSet() { + // return internedSequences; + // } + template StorageTy *internalize(Args &&...args) { - llvm::FoldingSetNodeID profile; - StorageTy::Profile(profile, args...); - void *insertPos = nullptr; - if (auto *storage = - getInternSet().FindNodeOrInsertPos(profile, insertPos)) - return static_cast(storage); - auto *storagePtr = new (allocator.Allocate()) - StorageTy(std::forward(args)...); - getInternSet().InsertNode(storagePtr, insertPos); - return storagePtr; + auto *storagePtr = new (allocator.Allocate()) StorageTy(std::forward(args)...); + auto existing = getInternSet().insert(storagePtr); + if (!existing.second) + allocator.Deallocate(storagePtr); + + return *existing.first; } template - llvm::FoldingSet &getInternSet() { + DenseSet> &getInternSet() { assert(false && "no generic internalization set"); } template <> - llvm::FoldingSet &getInternSet() { + DenseSet> &getInternSet() { return internedSets; } template <> - llvm::FoldingSet &getInternSet() { + DenseSet> &getInternSet() { return internedBags; } template <> - llvm::FoldingSet &getInternSet() { + DenseSet> &getInternSet() { return internedSequences; } @@ -308,9 +384,12 @@ class Internalizer { // inserting an object of a derived class of ElaboratorValue. // The custom MapInfo makes sure that we do a value comparison instead of // comparing the pointers. - llvm::FoldingSet internedSets; - llvm::FoldingSet internedBags; - llvm::FoldingSet internedSequences; + // DenseSet, StorageKeyInfo> internedSets; + // DenseSet, StorageKeyInfo> internedBags; + // DenseSet, StorageKeyInfo> internedSequences; + DenseSet> internedSets; + DenseSet> internedBags; + DenseSet> internedSequences; }; /// Holds any typed attribute. Wrapping around an MLIR `Attribute` allows us to From 33cc1b8c10335ebc69158af99a22096a260074fe Mon Sep 17 00:00:00 2001 From: Martin Erhart Date: Sat, 14 Dec 2024 21:00:42 +0000 Subject: [PATCH 4/5] More complicated, but faster approach --- .../RTG/Transforms/ElaborationPass.cpp | 305 ++++++++---------- 1 file changed, 137 insertions(+), 168 deletions(-) diff --git a/lib/Dialect/RTG/Transforms/ElaborationPass.cpp b/lib/Dialect/RTG/Transforms/ElaborationPass.cpp index e47f0d03b5d5..7fc08aa87a63 100644 --- a/lib/Dialect/RTG/Transforms/ElaborationPass.cpp +++ b/lib/Dialect/RTG/Transforms/ElaborationPass.cpp @@ -22,7 +22,6 @@ #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/IR/IRMapping.h" #include "mlir/IR/PatternMatch.h" -#include "llvm/ADT/FoldingSet.h" #include "llvm/Support/Debug.h" #include #include @@ -80,7 +79,7 @@ static uint32_t getUniformlyInRange(std::mt19937 &rng, uint32_t a, uint32_t b) { } //===----------------------------------------------------------------------===// -// Elaborator Values +// Elaborator Value Base //===----------------------------------------------------------------------===// namespace { @@ -89,13 +88,13 @@ namespace { class ElaboratorValue { public: enum class ValueKind { - Attribute = 0U, - Set, + None = 0U, + Attribute, Bag, - Sequence, - Index, Bool, - None + Index, + Sequence, + Set, }; ElaboratorValue(ValueKind kind = ValueKind::None, uintptr_t storage = 0) @@ -122,32 +121,14 @@ class ElaboratorValue { uintptr_t storage; }; -template -struct HashedStorage { - HashedStorage(unsigned hashcode = 0, StorageTy *storage = nullptr) - : hashcode(hashcode), storage(storage) {} - - unsigned hashcode; - StorageTy *storage; -}; - -// struct SetStorage; -// struct BagStorage; -// struct SequenceStorage; +// NOLINTNEXTLINE(readability-identifier-naming) +llvm::hash_code hash_value(const ElaboratorValue &val) { + return val.getHashValue(); +} } // namespace namespace llvm { -// llvm::hash_code hash_value(const HashedStorage &storage) { -// return storage.hashcode; -// } -// llvm::hash_code hash_value(const HashedStorage &storage) { -// return storage.hashcode; -// } -// llvm::hash_code hash_value(const HashedStorage &storage) { -// return storage.hashcode; -// } - /// Add support for llvm style casts. We provide a cast between To and From if /// From is mlir::Attribute or derives from it. @@ -195,18 +176,38 @@ struct DenseMapInfo { } // namespace llvm +//===----------------------------------------------------------------------===// +// Elaborator Value Storages and Internalization +//===----------------------------------------------------------------------===// + namespace { -llvm::hash_code hash_value(const ElaboratorValue &val) { - return val.getHashValue(); -} -template +/// Lightweight object to be used as the key for internalization sets. It caches +/// the hashcode of the internalized object and a pointer to it. This allows a +/// delayed allocation and construction of the actual object and thus only has +/// to happen if the object is not already in the set. +template +struct HashedStorage { + HashedStorage(unsigned hashcode = 0, StorageTy *storage = nullptr) + : hashcode(hashcode), storage(storage) {} + + unsigned hashcode; + StorageTy *storage; +}; + +/// A DenseMapInfo implementation to support 'insert_as' for the internalization +/// sets. When comparing two 'HashedStorage's we can just compare the already +/// internalized storage pointers, otherwise we have to call the costly +/// 'isEqual' method. +template struct StorageKeyInfo { static inline HashedStorage getEmptyKey() { - return HashedStorage(0, DenseMapInfo::getEmptyKey()); + return HashedStorage(0, + DenseMapInfo::getEmptyKey()); } static inline HashedStorage getTombstoneKey() { - return HashedStorage(0, DenseMapInfo::getTombstoneKey()); + return HashedStorage( + 0, DenseMapInfo::getTombstoneKey()); } static inline unsigned getHashValue(const HashedStorage &key) { @@ -217,181 +218,148 @@ struct StorageKeyInfo { } static inline bool isEqual(const HashedStorage &lhs, - const HashedStorage &rhs) { + const HashedStorage &rhs) { return lhs.storage == rhs.storage; } - static inline bool isEqual(const StorageTy &lhs, const HashedStorage &rhs) { + static inline bool isEqual(const StorageTy &lhs, + const HashedStorage &rhs) { if (isEqual(rhs, getEmptyKey()) || isEqual(rhs, getTombstoneKey())) return false; - // Invoke the equality function on the lookup key. - return lhs.isEqual(rhs.storage); - } -}; - -template -struct StorageInfo : public DenseMapInfo { - using Base = DenseMapInfo; - static inline unsigned getHashValue(const StorageTy *key) { - return key->hashcode; - } - static inline bool isEqual(const StorageTy *lhs, const StorageTy *rhs) { - if (lhs == rhs) - return true; - if (lhs == Base::getEmptyKey() || lhs == Base::getTombstoneKey() || rhs == Base::getEmptyKey() || rhs == Base::getTombstoneKey()) - return false; - return lhs->isEqual(rhs); + return lhs.isEqual(rhs.storage); } }; +/// Storage object for an '!rtg.set'. struct SetStorage { SetStorage(SetVector &&set, Type type) - : hashcode(llvm::hash_combine(type, llvm::hash_combine_range(set.begin(), set.end()))), set(std::move(set)), type(type) {} + : hashcode(llvm::hash_combine( + type, llvm::hash_combine_range(set.begin(), set.end()))), + set(std::move(set)), type(type) {} bool isEqual(const SetStorage *other) const { - return set == other->set && type == other->type; + return hashcode == other->hashcode && set == other->set && + type == other->type; } - unsigned hashcode; + // The cached hashcode to avoid repeated computations. + const unsigned hashcode; - // Stores the elaborated values of the set. - SetVector set; + // Stores the elaborated values contained in the set. + const SetVector set; // Store the set type such that we can materialize this evaluated value // also in the case where the set is empty. - Type type; + const Type type; }; +/// Storage object for an '!rtg.bag'. struct BagStorage { BagStorage(MapVector &&bag, Type type) - : hashcode(llvm::hash_combine(type, llvm::hash_combine_range(bag.begin(), bag.end()))), bag(std::move(bag)), type(type) {} + : hashcode(llvm::hash_combine( + type, llvm::hash_combine_range(bag.begin(), bag.end()))), + bag(std::move(bag)), type(type) {} bool isEqual(const BagStorage *other) const { - return llvm::equal(bag, other->bag) && type == other->type; + return hashcode == other->hashcode && llvm::equal(bag, other->bag) && + type == other->type; } - unsigned hashcode; + // The cached hashcode to avoid repeated computations. + const unsigned hashcode; - // Stores the elaborated values of the bag. - MapVector bag; + // Stores the elaborated values contained in the bag with their number of + // occurences. + const MapVector bag; // Store the bag type such that we can materialize this evaluated value // also in the case where the bag is empty. - Type type; + const Type type; }; +/// Storage object for an '!rtg.sequence'. struct SequenceStorage { SequenceStorage(StringRef name, StringAttr familyName, SmallVector &&args) - : hashcode(llvm::hash_combine(name, familyName, llvm::hash_combine_range(args.begin(), args.end()))), name(name), familyName(familyName), args(std::move(args)) {} + : hashcode(llvm::hash_combine( + name, familyName, + llvm::hash_combine_range(args.begin(), args.end()))), + name(name), familyName(familyName), args(std::move(args)) {} bool isEqual(const SequenceStorage *other) const { - return name == other->name && familyName == other->familyName && args == other->args; + return hashcode == other->hashcode && name == other->name && + familyName == other->familyName && args == other->args; } - unsigned hashcode; - StringRef name; - StringAttr familyName; - SmallVector args; -}; + // The cached hashcode to avoid repeated computations. + const unsigned hashcode; -// struct LookupKey { -// unsigned hascode; -// function_ref isEqual; -// }; + // The name of this fully substituted and elaborated sequence. + const StringRef name; + // The name of the sequence family this sequence is derived from. + const StringAttr familyName; + + // The elaborator values used during substitution of the sequence family. + const SmallVector args; +}; + +/// An 'Internalizer' object internalizes storages and takes ownership of them. +/// When the initializer object is destroyed, all owned storages are also +/// deallocated and thus must not be accessed anymore. class Internalizer { public: - // template - // StorageTy *internalize(Args &&...args) { - // StorageTy storage(std::forward(args)...); - - // auto existing = getInternSet().insert_as(HashedStorage(storage.hashcode), storage); - // StorageTy *&storagePtr = existing.first->storage; - // if (existing.second) - // storagePtr = new (allocator.Allocate()) StorageTy(std::move(storage)); - // return storagePtr; - // } - - // template - // DenseSet, StorageKeyInfo> &getInternSet() { - // assert(false && "no generic internalization set"); - // } - - // template <> - // DenseSet, StorageKeyInfo> &getInternSet() { - // return internedSets; - // } - - // template <> - // DenseSet, StorageKeyInfo> &getInternSet() { - // return internedBags; - // } - - // template <> - // DenseSet, StorageKeyInfo> &getInternSet() { - // return internedSequences; - // } - + /// Internalize a storage of type `StorageTy` constructed with arguments + /// `args`. The pointers returned by this method can be used to compare + /// objects when, e.g., computing set differences, uniquing the elements in a + /// set, etc. Otherwise, we'd need to do a deep value comparison in those + /// situations. template StorageTy *internalize(Args &&...args) { - auto *storagePtr = new (allocator.Allocate()) StorageTy(std::forward(args)...); - auto existing = getInternSet().insert(storagePtr); - if (!existing.second) - allocator.Deallocate(storagePtr); - - return *existing.first; - } + StorageTy storage(std::forward(args)...); - template - DenseSet> &getInternSet() { - assert(false && "no generic internalization set"); - } - - template <> - DenseSet> &getInternSet() { - return internedSets; - } + auto existing = getInternSet().insert_as( + HashedStorage(storage.hashcode), storage); + StorageTy *&storagePtr = existing.first->storage; + if (existing.second) + storagePtr = + new (allocator.Allocate()) StorageTy(std::move(storage)); - template <> - DenseSet> &getInternSet() { - return internedBags; + return storagePtr; } - template <> - DenseSet> &getInternSet() { - return internedSequences; - } - - // BagStorage *internalize(BagStorage &&storage) { - // llvm::FoldingSetNodeID profile; - // storage.Profile(profile); - // void *insertPos = nullptr; - // if (auto *bag = internedBags.FindNodeOrInsertPos(profile, insertPos)) - // return bag; - // auto *storagePtr = new BagStorage(std::move(storage)); - // internedBags.InsertNode(storagePtr, insertPos); - // return storagePtr; - // } - private: + template + DenseSet, StorageKeyInfo> & + getInternSet() { + if constexpr (std::is_same_v) + return internedSets; + else if constexpr (std::is_same_v) + return internedBags; + else if constexpr (std::is_same_v) + return internedSequences; + else + static_assert(!sizeof(StorageTy), + "no intern set available for this storage type."); + } + + // This allocator allocates on the heap. It automatically deallocates all + // objects it allocated once the allocator itself is destroyed. llvm::BumpPtrAllocator allocator; - // A map used to intern elaborator values. We do this such that we can - // compare pointers when, e.g., computing set differences, uniquing the - // elements in a set, etc. Otherwise, we'd need to do a deep value comparison - // in those situations. - // Use a pointer as the key with custom MapInfo because of object slicing when - // inserting an object of a derived class of ElaboratorValue. - // The custom MapInfo makes sure that we do a value comparison instead of - // comparing the pointers. - // DenseSet, StorageKeyInfo> internedSets; - // DenseSet, StorageKeyInfo> internedBags; - // DenseSet, StorageKeyInfo> internedSequences; - DenseSet> internedSets; - DenseSet> internedBags; - DenseSet> internedSequences; + + // The sets holding the internalized objects. We use one set per storage type + // such that we can have a simpler equality checking function (no need to + // compare some sort of TypeIDs). + DenseSet, StorageKeyInfo> internedSets; + DenseSet, StorageKeyInfo> internedBags; + DenseSet, StorageKeyInfo> + internedSequences; }; +//===----------------------------------------------------------------------===// +// Concrete Elaborator Values +//===----------------------------------------------------------------------===// + /// Holds any typed attribute. Wrapping around an MLIR `Attribute` allows us to /// use this elaborator value class for any values that have a corresponding /// MLIR attribute rather than one per kind of attribute. We only support typed @@ -550,8 +518,8 @@ static llvm::raw_ostream &operator<<(llvm::raw_ostream &os, .Case([&](auto val) { os << "(val.getRawStorage()) - << ">"; + os << "} at " + << reinterpret_cast(val.getRawStorage()) << ">"; }) .Case([&](auto val) { os << " &el) { os << el.first << " -> " << el.second; }); - os << "} at " << reinterpret_cast(val.getRawStorage()) - << ">"; + os << "} at " + << reinterpret_cast(val.getRawStorage()) << ">"; }) .Case([&](auto val) { os << "(val.getRawStorage()) + os << ") at " + << reinterpret_cast(val.getRawStorage()) << ">"; }) .Default([](auto val) { @@ -580,11 +549,7 @@ static llvm::raw_ostream &operator<<(llvm::raw_ostream &os, #endif //===----------------------------------------------------------------------===// -// Hash Map Helpers -//===----------------------------------------------------------------------===// - -//===----------------------------------------------------------------------===// -// Main Elaborator Implementation +// Elaborator Value Materialization //===----------------------------------------------------------------------===// namespace { @@ -799,6 +764,10 @@ class Materializer { SmallVector toDelete; }; +//===----------------------------------------------------------------------===// +// Elaboration Visitor +//===----------------------------------------------------------------------===// + /// Used to signal to the elaboration driver whether the operation should be /// removed. enum class DeletionKind { Keep, Delete }; @@ -1160,7 +1129,7 @@ class Elaborator : public RTGOpVisitor> { for (auto [arg, elabArg] : llvm::zip(region.getArguments(), regionArguments)) - state[arg] = elabArg; + store(arg, elabArg); Block *block = ®ion.front(); for (auto &op : *block) { From 47522325eb734204c0ef03ad359e1d6c8f39719b Mon Sep 17 00:00:00 2001 From: Martin Erhart Date: Fri, 17 Jan 2025 14:28:00 +0000 Subject: [PATCH 5/5] Use std::variant --- .../RTG/Transforms/ElaborationPass.cpp | 533 ++++++------------ 1 file changed, 160 insertions(+), 373 deletions(-) diff --git a/lib/Dialect/RTG/Transforms/ElaborationPass.cpp b/lib/Dialect/RTG/Transforms/ElaborationPass.cpp index 7fc08aa87a63..f1f2350613b9 100644 --- a/lib/Dialect/RTG/Transforms/ElaborationPass.cpp +++ b/lib/Dialect/RTG/Transforms/ElaborationPass.cpp @@ -22,6 +22,7 @@ #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/IR/IRMapping.h" #include "mlir/IR/PatternMatch.h" +#include "llvm/ADT/DenseMapInfoVariant.h" #include "llvm/Support/Debug.h" #include #include @@ -79,99 +80,40 @@ static uint32_t getUniformlyInRange(std::mt19937 &rng, uint32_t a, uint32_t b) { } //===----------------------------------------------------------------------===// -// Elaborator Value Base +// Elaborator Value //===----------------------------------------------------------------------===// namespace { +struct BagStorage; +struct SequenceStorage; +struct SetStorage; /// The abstract base class for elaborated values. -class ElaboratorValue { -public: - enum class ValueKind { - None = 0U, - Attribute, - Bag, - Bool, - Index, - Sequence, - Set, - }; - - ElaboratorValue(ValueKind kind = ValueKind::None, uintptr_t storage = 0) - : kind(kind), storage(storage) {} - - // This constructor is needed for LLVM RTTI - ElaboratorValue(uintptr_t storage) : ElaboratorValue() {} - - llvm::hash_code getHashValue() const { - return llvm::hash_combine(kind, storage); - } - - bool operator==(const ElaboratorValue &other) const { - return kind == other.kind && storage == other.storage; - } - - operator bool() const { return kind != ValueKind::None; } - - ValueKind getKind() const { return kind; } - uintptr_t getRawStorage() const { return storage; } - -protected: - ValueKind kind; - uintptr_t storage; -}; +using ElaboratorValue = std::variant; // NOLINTNEXTLINE(readability-identifier-naming) llvm::hash_code hash_value(const ElaboratorValue &val) { - return val.getHashValue(); + return std::visit( + [&val](const auto &alternative) { + // Include index in hash to make sure same value as different + // alternatives don't collide. + return llvm::hash_combine(val.index(), alternative); + }, + val); } } // namespace namespace llvm { -/// Add support for llvm style casts. We provide a cast between To and From if -/// From is mlir::Attribute or derives from it. -template -struct CastInfo> || - std::is_base_of_v>> - : DefaultDoCastIfPossible> { - /// Arguments are taken as mlir::Attribute here and not as `From`, because - /// when casting from an intermediate type of the hierarchy to one of its - /// children, the val.getTypeID() inside T::classof will use the static - /// getTypeID of the parent instead of the non-static Type::getTypeID that - /// returns the dynamic ID. This means that T::classof would end up comparing - /// the static TypeID of the children to the static TypeID of its parent, - /// making it impossible to downcast from the parent to the child. - static inline bool isPossible(ElaboratorValue ty) { - /// Return a constant true instead of a dynamic true when casting to self or - /// up the hierarchy. - if constexpr (std::is_base_of_v) { - return true; - } else { - return To::classof(ty); - } - } - static inline To doCast(ElaboratorValue value) { - return To(value.getRawStorage()); - } - static To castFailed() { return To(); } -}; - template <> -struct DenseMapInfo { - static inline ElaboratorValue getEmptyKey() { return ElaboratorValue(); } - static inline ElaboratorValue getTombstoneKey() { - return ElaboratorValue(ElaboratorValue::ValueKind::None, ~uintptr_t()); - } - static unsigned getHashValue(const ElaboratorValue &value) { - return value.getHashValue(); - } - static bool isEqual(const ElaboratorValue &lhs, const ElaboratorValue &rhs) { - return lhs == rhs; - } +struct DenseMapInfo { + static inline unsigned getEmptyKey() { return false; } + static inline unsigned getTombstoneKey() { return true; } + static unsigned getHashValue(const bool &val) { return val * 37U; } + + static bool isEqual(const bool &lhs, const bool &rhs) { return lhs == rhs; } }; } // namespace llvm @@ -356,196 +298,56 @@ class Internalizer { internedSequences; }; -//===----------------------------------------------------------------------===// -// Concrete Elaborator Values -//===----------------------------------------------------------------------===// - -/// Holds any typed attribute. Wrapping around an MLIR `Attribute` allows us to -/// use this elaborator value class for any values that have a corresponding -/// MLIR attribute rather than one per kind of attribute. We only support typed -/// attributes because for materialization we need to provide the type to the -/// dialect's materializer. -struct AttributeValue : public ElaboratorValue { - static_assert(sizeof(uintptr_t) == sizeof(const void *)); - - AttributeValue() = default; - AttributeValue(uintptr_t storage) - : ElaboratorValue(ValueKind::Attribute, storage) {} - AttributeValue(TypedAttr attr) - : ElaboratorValue(ValueKind::Attribute, reinterpret_cast( - attr.getAsOpaquePointer())) { - assert(attr && "null attributes not allowed"); - } - - // Implement LLVMs RTTI - static bool classof(const ElaboratorValue &val) { - return val.getKind() == ValueKind::Attribute; - } - - TypedAttr getAttr() const { - return cast(Attribute::getFromOpaquePointer( - reinterpret_cast(storage))); - } -}; - -/// Holds an evaluated value of a `IndexType`'d value. -struct IndexValue : public ElaboratorValue { - static_assert(sizeof(uintptr_t) >= sizeof(size_t)); - - IndexValue() = default; - IndexValue(uintptr_t storage) : ElaboratorValue(ValueKind::Index, storage) {} - // IndexValue(size_t index) : ElaboratorValue(ValueKind::Index, index) {} - - // Implement LLVMs RTTI - static bool classof(const ElaboratorValue &val) { - return val.getKind() == ValueKind::Index; - } - - size_t getIndex() const { return storage; } -}; - -/// Holds an evaluated value of an `i1` type'd value. -struct BoolValue : public ElaboratorValue { - static_assert(sizeof(uintptr_t) >= sizeof(bool)); - - BoolValue() = default; - BoolValue(uintptr_t storage) : ElaboratorValue(ValueKind::Bool, storage) {} - BoolValue(bool value) : ElaboratorValue(ValueKind::Bool, uintptr_t(value)) {} - - // Implement LLVMs RTTI - static bool classof(const ElaboratorValue &val) { - return val.getKind() == ValueKind::Bool; - } - - bool getBool() const { return storage; } -}; +} // namespace -/// Holds an evaluated value of a `SetType`'d value. -struct SetValue : public ElaboratorValue { - static_assert(sizeof(uintptr_t) == sizeof(const void *)); - - SetValue() = default; - SetValue(uintptr_t storage) : ElaboratorValue(ValueKind::Set, storage) {} - SetValue(Internalizer &internalizer, SetVector &&set, - Type type) - : ElaboratorValue( - ValueKind::Set, - reinterpret_cast( - internalizer.internalize(std::move(set), type))) {} - - // Implement LLVMs RTTI - static bool classof(const ElaboratorValue &val) { - return val.getKind() == ValueKind::Set; - } +#ifndef NDEBUG - const SetVector &getSet() const { - return reinterpret_cast(storage)->set; - } +static llvm::raw_ostream &operator<<(llvm::raw_ostream &os, + const ElaboratorValue &value); - Type getType() const { - return reinterpret_cast(storage)->type; - } -}; +static void print(TypedAttr val, llvm::raw_ostream &os) { + os << ""; +} -/// Holds an evaluated value of a `BagType`'d value. -struct BagValue : public ElaboratorValue { - static_assert(sizeof(uintptr_t) == sizeof(const void *)); - - BagValue() = default; - BagValue(uintptr_t storage) : ElaboratorValue(ValueKind::Bag, storage) {} - BagValue(Internalizer &internalizer, - MapVector &&bag, Type type) - : ElaboratorValue( - ValueKind::Bag, - reinterpret_cast( - internalizer.internalize(std::move(bag), type))) {} - - // Implement LLVMs RTTI - static bool classof(const ElaboratorValue &val) { - return val.getKind() == ValueKind::Bag; - } +static void print(BagStorage *val, llvm::raw_ostream &os) { + os << "bag, os, + [&](const std::pair &el) { + os << el.first << " -> " << el.second; + }); + os << "} at " << val << ">"; +} - const MapVector &getBag() const { - return reinterpret_cast(storage)->bag; - } +static void print(bool val, llvm::raw_ostream &os) { + os << ""; +} - Type getType() const { - return reinterpret_cast(storage)->type; - } -}; +static void print(size_t val, llvm::raw_ostream &os) { + os << ""; +} -/// Holds an evaluated value of a `SequenceType`'d value. -struct SequenceValue : public ElaboratorValue { - static_assert(sizeof(uintptr_t) == sizeof(const void *)); - - SequenceValue() = default; - SequenceValue(uintptr_t storage) - : ElaboratorValue(ValueKind::Sequence, storage) {} - SequenceValue(Internalizer &internalizer, StringRef name, - StringAttr familyName, SmallVector &&args) - : ElaboratorValue(ValueKind::Sequence, - reinterpret_cast( - internalizer.internalize( - name, familyName, std::move(args)))) {} - - // Implement LLVMs RTTI - static bool classof(const ElaboratorValue &val) { - return val.getKind() == ValueKind::Sequence; - } +static void print(SequenceStorage *val, llvm::raw_ostream &os) { + os << "name << " derived from @" + << val->familyName.getValue() << "("; + llvm::interleaveComma(val->args, os, + [&](const ElaboratorValue &val) { os << val; }); + os << ") at " << val << ">"; +} - StringRef getName() const { - return reinterpret_cast(storage)->name; - } - StringAttr getFamilyName() const { - return reinterpret_cast(storage)->familyName; - } - ArrayRef getArgs() const { - return reinterpret_cast(storage)->args; - } -}; -} // namespace +static void print(SetStorage *val, llvm::raw_ostream &os) { + os << "set, os, + [&](const ElaboratorValue &val) { os << val; }); + os << "} at " << val << ">"; +} -#ifndef NDEBUG static llvm::raw_ostream &operator<<(llvm::raw_ostream &os, const ElaboratorValue &value) { - TypeSwitch(value) - .Case( - [&](auto val) { os << ""; }) - .Case( - [&](auto val) { os << ""; }) - .Case( - [&](auto val) { os << ""; }) - .Case([&](auto val) { - os << "(val.getRawStorage()) << ">"; - }) - .Case([&](auto val) { - os << " &el) { - os << el.first << " -> " << el.second; - }); - os << "} at " - << reinterpret_cast(val.getRawStorage()) << ">"; - }) - .Case([&](auto val) { - os << "(val.getRawStorage()) - << ">"; - }) - .Default([](auto val) { - assert(false && "all cases must be covered above"); - return Value(); - }); + std::visit([&](auto val) { print(val, os); }, value); + return os; } + #endif //===----------------------------------------------------------------------===// @@ -562,7 +364,7 @@ class Materializer { /// Materialize IR representing the provided `ElaboratorValue` and return the /// `Value` or a null value on failure. Value materialize(ElaboratorValue val, Location loc, - std::queue &elabRequests, + std::queue &elabRequests, function_ref emitError) { auto iter = materializedValues.find(val); if (iter != materializedValues.end()) @@ -570,14 +372,9 @@ class Materializer { LLVM_DEBUG(llvm::dbgs() << "Materializing " << val << "\n\n"); - return TypeSwitch(val) - .Case( - [&](auto val) { return visit(val, loc, elabRequests, emitError); }) - .Default([](auto val) { - assert(false && "all cases must be covered above"); - return Value(); - }); + return std::visit( + [&](auto val) { return visit(val, loc, elabRequests, emitError); }, + val); } /// If `op` is not in the same region as the materializer insertion point, a @@ -588,7 +385,7 @@ class Materializer { /// before the insertion point. LogicalResult materialize(Operation *op, DenseMap &state, - std::queue &elabRequests) { + std::queue &elabRequests) { if (op->getNumRegions() > 0) return op->emitOpError("ops with nested regions must be elaborated away"); @@ -655,15 +452,13 @@ class Materializer { } private: - Value visit(const AttributeValue &val, Location loc, - std::queue &elabRequests, + Value visit(TypedAttr val, Location loc, + std::queue &elabRequests, function_ref emitError) { - auto attr = val.getAttr(); - // For index attributes (and arithmetic operations on them) we use the // index dialect. - if (auto intAttr = dyn_cast(attr); - intAttr && isa(attr.getType())) { + if (auto intAttr = dyn_cast(val); + intAttr && isa(val.getType())) { Value res = builder.create(loc, intAttr); materializedValues[val] = res; return res; @@ -671,12 +466,12 @@ class Materializer { // For any other attribute, we just call the materializer of the dialect // defining that attribute. - auto *op = attr.getDialect().materializeConstant(builder, attr, - attr.getType(), loc); + auto *op = + val.getDialect().materializeConstant(builder, val, val.getType(), loc); if (!op) { emitError() << "materializer of dialect '" - << attr.getDialect().getNamespace() - << "' unable to materialize value for attribute '" << attr + << val.getDialect().getNamespace() + << "' unable to materialize value for attribute '" << val << "'"; return Value(); } @@ -686,28 +481,28 @@ class Materializer { return res; } - Value visit(const IndexValue &val, Location loc, - std::queue &elabRequests, + Value visit(size_t val, Location loc, + std::queue &elabRequests, function_ref emitError) { - Value res = builder.create(loc, val.getIndex()); + Value res = builder.create(loc, val); materializedValues[val] = res; return res; } - Value visit(const BoolValue &val, Location loc, - std::queue &elabRequests, + Value visit(bool val, Location loc, + std::queue &elabRequests, function_ref emitError) { - Value res = builder.create(loc, val.getBool()); + Value res = builder.create(loc, val); materializedValues[val] = res; return res; } - Value visit(const SetValue &val, Location loc, - std::queue &elabRequests, + Value visit(SetStorage *val, Location loc, + std::queue &elabRequests, function_ref emitError) { SmallVector elements; - elements.reserve(val.getSet().size()); - for (auto el : val.getSet()) { + elements.reserve(val->set.size()); + for (auto el : val->set) { auto materialized = materialize(el, loc, elabRequests, emitError); if (!materialized) return Value(); @@ -715,21 +510,21 @@ class Materializer { elements.push_back(materialized); } - auto res = builder.create(loc, val.getType(), elements); + auto res = builder.create(loc, val->type, elements); materializedValues[val] = res; return res; } - Value visit(const BagValue &val, Location loc, - std::queue &elabRequests, + Value visit(BagStorage *val, Location loc, + std::queue &elabRequests, function_ref emitError) { SmallVector values, weights; - values.reserve(val.getBag().size()); - weights.reserve(val.getBag().size()); - for (auto [val, weight] : val.getBag()) { + values.reserve(val->bag.size()); + weights.reserve(val->bag.size()); + for (auto [val, weight] : val->bag) { auto materializedVal = materialize(val, loc, elabRequests, emitError); auto materializedWeight = - materialize(IndexValue(weight), loc, elabRequests, emitError); + materialize(weight, loc, elabRequests, emitError); if (!materializedVal || !materializedWeight) return Value(); @@ -737,16 +532,16 @@ class Materializer { weights.push_back(materializedWeight); } - auto res = builder.create(loc, val.getType(), values, weights); + auto res = builder.create(loc, val->type, values, weights); materializedValues[val] = res; return res; } - Value visit(const SequenceValue &val, Location loc, - std::queue &elabRequests, + Value visit(SequenceStorage *val, Location loc, + std::queue &elabRequests, function_ref emitError) { elabRequests.push(val); - return builder.create(loc, val.getName(), ValueRange()); + return builder.create(loc, val->name, ValueRange()); } private: @@ -784,7 +579,7 @@ struct ElaboratorSharedState { /// The worklist used to keep track of the test and sequence operations to /// make sure they are processed top-down (BFS traversal). - std::queue worklist; + std::queue worklist; }; /// Interprets the IR to perform and lower the represented randomizations. @@ -797,17 +592,11 @@ class Elaborator : public RTGOpVisitor> { Elaborator(ElaboratorSharedState &sharedState, Materializer &materializer) : sharedState(sharedState), materializer(materializer) {} - inline void store(Value val, const ElaboratorValue &eval) { - state[val] = eval; - } - template inline ValueTy get(Value val) { - return dyn_cast(state.at(val)); + return std::get(state.at(val)); } - inline ElaboratorValue get(Value val) { return state.at(val); } - /// Print a nice error message for operations we don't support yet. FailureOr visitUnhandledOp(Operation *op) { return op->emitOpError("elaboration not supported"); @@ -826,12 +615,13 @@ class Elaborator : public RTGOpVisitor> { FailureOr visitOp(SequenceClosureOp op) { SmallVector args; for (auto arg : op.getArgs()) - args.push_back(get(arg)); + args.push_back(state.at(arg)); auto familyName = op.getSequenceAttr(); auto name = sharedState.names.newName(familyName.getValue()); - store(op.getResult(), SequenceValue(sharedState.internalizer, name, - familyName, std::move(args))); + state[op.getResult()] = + sharedState.internalizer.internalize(name, familyName, + std::move(args)); return DeletionKind::Delete; } @@ -842,55 +632,54 @@ class Elaborator : public RTGOpVisitor> { FailureOr visitOp(SetCreateOp op) { SetVector set; for (auto val : op.getElements()) - set.insert(get(val)); + set.insert(state.at(val)); - store(op.getSet(), SetValue(sharedState.internalizer, std::move(set), - op.getSet().getType())); + state[op.getSet()] = sharedState.internalizer.internalize( + std::move(set), op.getSet().getType()); return DeletionKind::Delete; } FailureOr visitOp(SetSelectRandomOp op) { - auto set = cast(get(op.getSet())); + auto set = get(op.getSet())->set; size_t selected; if (auto intAttr = op->getAttrOfType("rtg.elaboration_custom_seed")) { std::mt19937 customRng(intAttr.getInt()); - selected = getUniformlyInRange(customRng, 0, set.getSet().size() - 1); + selected = getUniformlyInRange(customRng, 0, set.size() - 1); } else { - selected = - getUniformlyInRange(sharedState.rng, 0, set.getSet().size() - 1); + selected = getUniformlyInRange(sharedState.rng, 0, set.size() - 1); } - store(op.getResult(), set.getSet()[selected]); + state[op.getResult()] = set[selected]; return DeletionKind::Delete; } FailureOr visitOp(SetDifferenceOp op) { - auto original = get(op.getOriginal()).getSet(); - auto diff = get(op.getDiff()).getSet(); + auto original = get(op.getOriginal())->set; + auto diff = get(op.getDiff())->set; SetVector result(original); result.set_subtract(diff); - store(op.getResult(), SetValue(sharedState.internalizer, std::move(result), - op.getResult().getType())); + state[op.getResult()] = sharedState.internalizer.internalize( + std::move(result), op.getResult().getType()); return DeletionKind::Delete; } FailureOr visitOp(SetUnionOp op) { SetVector result; for (auto set : op.getSets()) - result.set_union(get(set).getSet()); + result.set_union(get(set)->set); - store(op.getResult(), - SetValue(sharedState.internalizer, std::move(result), op.getType())); + state[op.getResult()] = sharedState.internalizer.internalize( + std::move(result), op.getType()); return DeletionKind::Delete; } FailureOr visitOp(SetSizeOp op) { - auto size = get(op.getSet()).getSet().size(); - store(op.getResult(), IndexValue(size)); + auto size = get(op.getSet())->set.size(); + state[op.getResult()] = size; return DeletionKind::Delete; } @@ -898,26 +687,24 @@ class Elaborator : public RTGOpVisitor> { MapVector bag; for (auto [val, multiple] : llvm::zip(op.getElements(), op.getMultiples())) { - auto interpValue = get(val); // If the multiple is not stored as an AttributeValue, the elaboration // must have already failed earlier (since we don't have // unevaluated/opaque values). - auto interpMultiple = get(multiple); - bag[interpValue] += interpMultiple.getIndex(); + bag[state.at(val)] += get(multiple); } - store(op.getBag(), - BagValue(sharedState.internalizer, std::move(bag), op.getType())); + state[op.getBag()] = sharedState.internalizer.internalize( + std::move(bag), op.getType()); return DeletionKind::Delete; } FailureOr visitOp(BagSelectRandomOp op) { - auto bag = get(op.getBag()); + auto bag = get(op.getBag())->bag; SmallVector> prefixSum; - prefixSum.reserve(bag.getBag().size()); + prefixSum.reserve(bag.size()); uint32_t accumulator = 0; - for (auto [val, weight] : bag.getBag()) { + for (auto [val, weight] : bag) { accumulator += weight; prefixSum.push_back({val, accumulator}); } @@ -935,17 +722,17 @@ class Elaborator : public RTGOpVisitor> { return a < b.second; }); - store(op.getResult(), iter->first); + state[op.getResult()] = iter->first; return DeletionKind::Delete; } FailureOr visitOp(BagDifferenceOp op) { - auto original = get(op.getOriginal()); - auto diff = get(op.getDiff()); + auto original = get(op.getOriginal())->bag; + auto diff = get(op.getDiff())->bag; MapVector result; - for (const auto &el : original.getBag()) { - if (!diff.getBag().contains(el.first)) { + for (const auto &el : original) { + if (!diff.contains(el.first)) { result.insert(el); continue; } @@ -953,39 +740,39 @@ class Elaborator : public RTGOpVisitor> { if (op.getInf()) continue; - auto toDiff = diff.getBag().lookup(el.first); + auto toDiff = diff.lookup(el.first); if (el.second <= toDiff) continue; result.insert({el.first, el.second - toDiff}); } - store(op.getResult(), - BagValue(sharedState.internalizer, std::move(result), op.getType())); + state[op.getResult()] = sharedState.internalizer.internalize( + std::move(result), op.getType()); return DeletionKind::Delete; } FailureOr visitOp(BagUnionOp op) { MapVector result; for (auto bag : op.getBags()) { - auto val = get(bag); - for (auto [el, multiple] : val.getBag()) + auto val = get(bag)->bag; + for (auto [el, multiple] : val) result[el] += multiple; } - store(op.getResult(), - BagValue(sharedState.internalizer, std::move(result), op.getType())); + state[op.getResult()] = sharedState.internalizer.internalize( + std::move(result), op.getType()); return DeletionKind::Delete; } FailureOr visitOp(BagUniqueSizeOp op) { - auto size = get(op.getBag()).getBag().size(); - store(op.getResult(), IndexValue(size)); + auto size = get(op.getBag())->bag.size(); + state[op.getResult()] = size; return DeletionKind::Delete; } FailureOr visitOp(scf::IfOp op) { - bool cond = get(op.getCondition()).getBool(); + bool cond = get(op.getCondition()); auto &toElaborate = cond ? op.getThenRegion() : op.getElseRegion(); if (toElaborate.empty()) return DeletionKind::Delete; @@ -1007,13 +794,15 @@ class Elaborator : public RTGOpVisitor> { } FailureOr visitOp(scf::ForOp op) { - auto lowerBound = get(op.getLowerBound()); - auto step = get(op.getStep()); - auto upperBound = get(op.getUpperBound()); - - if (!lowerBound || !step || !upperBound) + if (!(std::holds_alternative(state.at(op.getLowerBound())) && + std::holds_alternative(state.at(op.getStep())) && + std::holds_alternative(state.at(op.getUpperBound())))) return op->emitOpError("can only elaborate index type iterator"); + auto lowerBound = get(op.getLowerBound()); + auto step = get(op.getStep()); + auto upperBound = get(op.getUpperBound()); + // Prepare for first iteration by assigning the nested regions block // arguments. We can just reuse this elaborator because we need access to // values elaborated in the parent region anyway and materialize everything @@ -1024,14 +813,13 @@ class Elaborator : public RTGOpVisitor> { state[iterArg] = state.at(initArg); // This loop performs the actual 'scf.for' loop iterations. - for (size_t i = lowerBound.getIndex(); i < upperBound.getIndex(); - i += step.getIndex()) { + for (size_t i = lowerBound; i < upperBound; i += step) { if (failed(elaborate(op.getBodyRegion()))) return failure(); // Prepare for the next iteration by updating the mapping of the nested // regions block arguments - store(op.getInductionVar(), IndexValue(i + step.getIndex())); + state[op.getInductionVar()] = i + step; for (auto [iterArg, prevIterArg] : llvm::zip(op.getRegionIterArgs(), op.getBody()->getTerminator()->getOperands())) @@ -1051,15 +839,15 @@ class Elaborator : public RTGOpVisitor> { } FailureOr visitOp(index::AddOp op) { - size_t lhs = get(op.getLhs()).getIndex(); - size_t rhs = get(op.getRhs()).getIndex(); - store(op.getResult(), IndexValue(lhs + rhs)); + size_t lhs = get(op.getLhs()); + size_t rhs = get(op.getRhs()); + state[op.getResult()] = lhs + rhs; return DeletionKind::Delete; } FailureOr visitOp(index::CmpOp op) { - size_t lhs = get(op.getLhs()).getIndex(); - size_t rhs = get(op.getRhs()).getIndex(); + size_t lhs = get(op.getLhs()); + size_t rhs = get(op.getRhs()); bool result; switch (op.getPred()) { case index::IndexCmpPredicate::EQ: @@ -1083,7 +871,7 @@ class Elaborator : public RTGOpVisitor> { default: return op->emitOpError("elaboration not supported"); } - store(op.getResult(), BoolValue(result)); + state[op.getResult()] = result; return DeletionKind::Delete; } @@ -1101,11 +889,11 @@ class Elaborator : public RTGOpVisitor> { auto intAttr = dyn_cast(attr); if (intAttr && isa(attr.getType())) - store(op->getResult(0), IndexValue(intAttr.getInt())); + state[op->getResult(0)] = size_t(intAttr.getInt()); else if (intAttr && intAttr.getType().isSignlessInteger(1)) - store(op->getResult(0), BoolValue(bool(intAttr.getInt()))); + state[op->getResult(0)] = bool(intAttr.getInt()); else - store(op->getResult(0), AttributeValue(attr)); + state[op->getResult(0)] = attr; return DeletionKind::Delete; } @@ -1129,7 +917,7 @@ class Elaborator : public RTGOpVisitor> { for (auto [arg, elabArg] : llvm::zip(region.getArguments(), regionArguments)) - store(arg, elabArg); + state[arg] = elabArg; Block *block = ®ion.front(); for (auto &op : *block) { @@ -1146,7 +934,7 @@ class Elaborator : public RTGOpVisitor> { llvm::interleaveComma(op.getResults(), llvm::dbgs(), [&](auto res) { if (state.contains(res)) - llvm::dbgs() << get(res); + llvm::dbgs() << state.at(res); else llvm::dbgs() << "unknown"; }); @@ -1265,22 +1053,21 @@ LogicalResult ElaborationPass::elaborateModule(ModuleOp moduleOp, // Do top-down BFS traversal such that elaborating a sequence further down // does not fix the outcome for multiple placements. while (!state.worklist.empty()) { - auto curr = state.worklist.front(); + auto *curr = state.worklist.front(); state.worklist.pop(); - if (table.lookup(curr.getName())) + if (table.lookup(curr->name)) continue; - auto familyOp = table.lookup(curr.getFamilyName()); + auto familyOp = table.lookup(curr->familyName); // TODO: don't clone if this is the only remaining reference to this // sequence OpBuilder builder(familyOp); auto seqOp = builder.cloneWithoutRegions(familyOp); seqOp.getBodyRegion().emplaceBlock(); - seqOp.setSymName(curr.getName()); + seqOp.setSymName(curr->name); table.insert(seqOp); - assert(seqOp.getSymName() == curr.getName() && - "should not have been renamed"); + assert(seqOp.getSymName() == curr->name && "should not have been renamed"); LLVM_DEBUG(llvm::dbgs() << "\n=== Elaborating sequence family @" << familyOp.getSymName() @@ -1288,7 +1075,7 @@ LogicalResult ElaborationPass::elaborateModule(ModuleOp moduleOp, Materializer materializer(OpBuilder::atBlockBegin(seqOp.getBody())); Elaborator elaborator(state, materializer); - if (failed(elaborator.elaborate(familyOp.getBodyRegion(), curr.getArgs()))) + if (failed(elaborator.elaborate(familyOp.getBodyRegion(), curr->args))) return failure(); materializer.finalize();