diff --git a/src/analysis/lattices/abstraction.h b/src/analysis/lattices/abstraction.h new file mode 100644 index 00000000000..bc503518c9f --- /dev/null +++ b/src/analysis/lattices/abstraction.h @@ -0,0 +1,227 @@ +/* + * Copyright 2025 WebAssembly Community Group participants + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include + +#include "../lattice.h" +#include "support/utilities.h" + +#if __cplusplus >= 202002L +#include "analysis/lattices/bool.h" +#endif + +#ifndef wasm_analysis_lattices_abstraction_h +#define wasm_analysis_lattices_abstraction_h + +namespace wasm::analysis { + +// CRTP lattice composed of increasingly abstract sub-lattices. The subclass is +// responsible for providing two method templates. The first abstracts an +// element of one sub-lattice into an element of the next sub-lattice: +// +// template +// E2 abstract(const E1&) const +// +// The template method should be specialized for each sub-lattice index I, its +// element type E1, and the next element type E2. +// +// The `abstract` method is used to abstract elements of the more specific +// lattice whenever elements from different lattices are compared or joined. It +// may also be used to abstract two joined elements from the same lattice when +// those elements are unrelated and the second method returns true: +// +// template +// bool shouldAbstract(const E&. const E&) const +// +// shouldAbstract is only queried for unrelated elements. Related elements of +// the same sub-lattice are always joined as normal. +// +// `abstract` should be monotonic. Making its input more general should either +// not change its output or make its output more general. +// +// `shouldAbstract` should return true only when no upper bound of its arguments +// in their original sub-lattice is used. If such an upper bound is used in a +// comparison or join, the operation may fail to uphold the properties of a +// lattice. +template struct Abstraction { + using Element = std::variant; + + std::tuple lattices; + + Abstraction(Ls&&... lattices) : lattices({std::move(lattices)...}) {} + + Element getBottom() const noexcept { + return std::get<0>(lattices).getBottom(); + } + + LatticeComparison compare(const Element& a, const Element& b) const noexcept { + if (a.index() < b.index()) { + auto abstractedA = a; + abstractToIndex(abstractedA, b.index()); + switch (compare()[b.index()](lattices, abstractedA, b)) { + case EQUAL: + case LESS: + return LESS; + case NO_RELATION: + case GREATER: + return NO_RELATION; + } + WASM_UNREACHABLE("unexpected comparison"); + } + if (a.index() > b.index()) { + auto abstractedB = b; + abstractToIndex(abstractedB, a.index()); + switch (compare()[a.index()](lattices, a, abstractedB)) { + case EQUAL: + case GREATER: + return GREATER; + case NO_RELATION: + case LESS: + return NO_RELATION; + } + WASM_UNREACHABLE("unexpected comparison"); + } + return compare()[a.index()](lattices, a, b); + } + + bool join(Element& joinee, const Element& _joiner) const noexcept { + Element joiner = _joiner; + bool changed = false; + if (joinee.index() < joiner.index()) { + abstractToIndex(joinee, joiner.index()); + changed = true; + } else if (joinee.index() > joiner.index()) { + abstractToIndex(joiner, joinee.index()); + } + while (true) { + assert(joinee.index() == joiner.index()); + if (joiner.index() == sizeof...(Ls) - 1) { + // Cannot abstract further, so we must join no matter what. + break; + } + switch (compare()[joiner.index()](lattices, joinee, joiner)) { + case NO_RELATION: + if (shouldAbstract()[joiner.index()](self(), joinee, joiner)) { + // Try abstracting further. + joinee = abstract()[joinee.index()](self(), joinee); + joiner = abstract()[joiner.index()](self(), joiner); + changed = true; + continue; + } + break; + case EQUAL: + case LESS: + case GREATER: + break; + } + break; + } + return join()[joiner.index()](lattices, joinee, joiner) || changed; + } + +private: + const Self& self() const noexcept { return *static_cast(this); } + + // TODO: Use C++26 pack indexing. + template using L = std::tuple_element_t>; + + // Compute tables of functions that forward operations to the CRTP subtype or + // the lattices. These tables map the dynamic variant indices to compile-time + // lattice indices. + + template + static constexpr auto makeAbstractFuncs(std::index_sequence) noexcept { + using F = Element (*)(const Self&, const Element& elem); + return std::array{ + [](const Self& self, const Element& elem) -> Element { + if constexpr (I < sizeof...(Ls) - 1) { + using E1 = typename L::Element; + using E2 = typename L::Element; + return Element(std::in_place_index_t{}, + self.template abstract(std::get(elem))); + } else { + WASM_UNREACHABLE("unexpected abstraction"); + } + }...}; + } + static constexpr auto abstract() noexcept { + return makeAbstractFuncs(std::make_index_sequence{}); + } + + void abstractToIndex(Element& elem, std::size_t index) const noexcept { + while (elem.index() < index) { + elem = abstract()[elem.index()](self(), elem); + } + } + + template + static constexpr auto + makeShouldAbstractFuncs(std::index_sequence) noexcept { + using F = bool (*)(const Self&, const Element&, const Element&); + return std::array{ + [](const Self& self, const Element& a, const Element& b) -> bool { + if constexpr (I < sizeof...(Ls) - 1) { + return self.template shouldAbstract(std::get(a), + std::get(b)); + } else { + WASM_UNREACHABLE("unexpected abstraction check"); + } + }...}; + } + static constexpr auto shouldAbstract() noexcept { + return makeShouldAbstractFuncs(std::make_index_sequence{}); + } + + template + static constexpr auto makeCompareFuncs(std::index_sequence) noexcept { + using F = LatticeComparison (*)( + const std::tuple&, const Element&, const Element&); + return std::array{ + [](const std::tuple& lattices, + const Element& a, + const Element& b) -> LatticeComparison { + return std::get(lattices).compare(std::get(a), std::get(b)); + }...}; + } + static constexpr auto compare() noexcept { + return makeCompareFuncs(std::make_index_sequence{}); + } + + template + static constexpr auto makeJoinFuncs(std::index_sequence) noexcept { + using F = bool (*)(const std::tuple&, Element&, const Element&); + return std::array{[](const std::tuple& lattices, + Element& joinee, + const Element& joiner) { + return std::get(lattices).join(std::get(joinee), + std::get(joiner)); + }...}; + } + static constexpr auto join() noexcept { + return makeJoinFuncs(std::make_index_sequence{}); + } +}; + +#if __cplusplus >= 202002L +static_assert(Lattice>); +#endif + +} // namespace wasm::analysis + +#endif // wasm_analysis_lattices_abstraction_h diff --git a/test/gtest/lattices.cpp b/test/gtest/lattices.cpp index 7fff8d0c075..5a80994b2bd 100644 --- a/test/gtest/lattices.cpp +++ b/test/gtest/lattices.cpp @@ -14,6 +14,8 @@ * limitations under the License. */ +#include "analysis/lattice.h" +#include "analysis/lattices/abstraction.h" #include "analysis/lattices/array.h" #include "analysis/lattices/bool.h" #include "analysis/lattices/flat.h" @@ -723,3 +725,130 @@ TEST(StackLattice, Join) { {flat.get(0), flat.get(1)}, {flat.get(0), flat.getTop()}); } + +using OddEvenInt = analysis::Flat; +using OddEvenBool = analysis::Flat; +struct OddEvenAbstraction + : analysis::Abstraction { + OddEvenAbstraction() + : analysis::Abstraction( + OddEvenInt{}, OddEvenBool{}) {} + + template E2 abstract(const E1&) const; + + template + bool shouldAbstract(const E&, const E&) const; +}; + +template<> +OddEvenBool::Element +OddEvenAbstraction::abstract<0>(const OddEvenInt::Element& elem) const { + if (elem.isTop()) { + return OddEvenBool{}.getTop(); + } + if (elem.isBottom()) { + return OddEvenBool{}.getBottom(); + } + return OddEvenBool{}.get((*elem.getVal() & 1) == 0); +} + +template<> +bool OddEvenAbstraction::shouldAbstract<0>(const OddEvenInt::Element&, + const OddEvenInt::Element&) const { + // Since the elements are not related, they must be different integers. + // Always abstract them. + return true; +} + +TEST(AbstractionLattice, GetBottom) { + OddEvenAbstraction abstraction; + auto expected = OddEvenAbstraction::Element(OddEvenInt{}.getBottom()); + EXPECT_EQ(abstraction.getBottom(), expected); +} + +TEST(AbstractionLattice, Join) { + OddEvenAbstraction abstraction; + + auto expectJoin = [&](const char* file, + int line, + const auto& joinee, + const auto& joiner, + const auto& expected) { + testing::ScopedTrace trace(file, line, ""); + switch (abstraction.compare(joinee, joiner)) { + case analysis::NO_RELATION: + EXPECT_NE(joinee, joiner); + EXPECT_EQ(abstraction.compare(joiner, joinee), analysis::NO_RELATION); + EXPECT_EQ(abstraction.compare(joinee, expected), analysis::LESS); + EXPECT_EQ(abstraction.compare(joiner, expected), analysis::LESS); + break; + case analysis::EQUAL: + EXPECT_EQ(joinee, joiner); + EXPECT_EQ(abstraction.compare(joiner, joinee), analysis::EQUAL); + EXPECT_EQ(abstraction.compare(joinee, expected), analysis::EQUAL); + EXPECT_EQ(abstraction.compare(joiner, expected), analysis::EQUAL); + break; + case analysis::LESS: + EXPECT_EQ(joiner, expected); + EXPECT_EQ(abstraction.compare(joiner, joinee), analysis::GREATER); + EXPECT_EQ(abstraction.compare(joinee, expected), analysis::LESS); + EXPECT_EQ(abstraction.compare(joiner, expected), analysis::EQUAL); + break; + case analysis::GREATER: + EXPECT_EQ(joinee, expected); + EXPECT_EQ(abstraction.compare(joiner, joinee), analysis::LESS); + EXPECT_EQ(abstraction.compare(joinee, expected), analysis::EQUAL); + EXPECT_EQ(abstraction.compare(joiner, expected), analysis::LESS); + } + { + auto copy = joinee; + EXPECT_EQ(abstraction.join(copy, joiner), joinee != expected); + EXPECT_EQ(copy, expected); + } + { + auto copy = joiner; + EXPECT_EQ(abstraction.join(copy, joinee), joiner != expected); + EXPECT_EQ(copy, expected); + } + }; + +#define JOIN(a, b, c) expectJoin(__FILE__, __LINE__, a, b, c) + + auto bot = abstraction.getBottom(); + auto one = OddEvenAbstraction::Element(OddEvenInt{}.get(1)); + auto two = OddEvenAbstraction::Element(OddEvenInt{}.get(2)); + auto three = OddEvenAbstraction::Element(OddEvenInt{}.get(3)); + auto four = OddEvenAbstraction::Element(OddEvenInt{}.get(4)); + auto even = OddEvenAbstraction::Element(OddEvenBool{}.get(true)); + auto odd = OddEvenAbstraction::Element(OddEvenBool{}.get(false)); + auto top = OddEvenAbstraction::Element(OddEvenBool{}.getTop()); + + JOIN(bot, bot, bot); + JOIN(bot, one, one); + JOIN(bot, two, two); + JOIN(bot, even, even); + JOIN(bot, odd, odd); + JOIN(bot, top, top); + + JOIN(one, one, one); + JOIN(one, two, top); + JOIN(one, three, odd); + JOIN(one, even, top); + JOIN(one, odd, odd); + + JOIN(two, two, two); + JOIN(two, three, top); + JOIN(two, four, even); + JOIN(two, even, even); + JOIN(two, odd, top); + JOIN(two, top, top); + + JOIN(even, even, even); + JOIN(even, odd, top); + JOIN(even, top, top); + + JOIN(odd, odd, odd); + JOIN(odd, top, top); + +#undef JOIN +}