From e6c39757daccb2da23292021ed159a5509f6f8ac Mon Sep 17 00:00:00 2001 From: fruffy Date: Thu, 6 Jun 2024 09:35:20 -0400 Subject: [PATCH] Replace boost::container::flat_map with a custom flat_map implementation in P4Tools. --- backends/p4tools/BUILD.bazel | 1 - backends/p4tools/common/lib/model.cpp | 2 - backends/p4tools/common/lib/model.h | 6 +- backends/p4tools/common/lib/symbolic_env.cpp | 2 - backends/p4tools/common/options.h | 1 - .../modules/testgen/lib/execution_state.cpp | 1 - .../modules/testgen/lib/final_state.cpp | 2 - backends/p4tools/p4tools.def | 60 +++- ir/solver.h | 9 +- lib/flat_map.h | 318 ++++++++++++++++++ test/CMakeLists.txt | 1 + test/gtest/flat_map.cpp | 168 +++++++++ 12 files changed, 550 insertions(+), 21 deletions(-) create mode 100644 lib/flat_map.h create mode 100644 test/gtest/flat_map.cpp diff --git a/backends/p4tools/BUILD.bazel b/backends/p4tools/BUILD.bazel index 383e198a1e1..33cbc4b582a 100644 --- a/backends/p4tools/BUILD.bazel +++ b/backends/p4tools/BUILD.bazel @@ -181,7 +181,6 @@ cc_binary( deps = [ ":testgen_lib", "//:lib", - "@boost//:filesystem", "@boost//:multiprecision", ], ) diff --git a/backends/p4tools/common/lib/model.cpp b/backends/p4tools/common/lib/model.cpp index 382be531640..f5cacf24d31 100644 --- a/backends/p4tools/common/lib/model.cpp +++ b/backends/p4tools/common/lib/model.cpp @@ -4,8 +4,6 @@ #include #include -#include - #include "frontends/p4/optimizeExpressions.h" #include "ir/indexed_vector.h" #include "ir/irutils.h" diff --git a/backends/p4tools/common/lib/model.h b/backends/p4tools/common/lib/model.h index ffc63aef3ea..df1b325dad4 100644 --- a/backends/p4tools/common/lib/model.h +++ b/backends/p4tools/common/lib/model.h @@ -3,9 +3,6 @@ #include #include -#include - -#include #include "ir/ir.h" #include "ir/solver.h" @@ -14,7 +11,8 @@ namespace P4Tools { /// Symbolic maps map a state variable to a IR::Expression. 
-using SymbolicMapType = boost::container::flat_map; +using SymbolicMapType = + P4C::flat_map>; /// Represents a solution found by the solver. A model is a concretized form of a symbolic /// environment. All the expressions in a Model must be of type IR::Literal. diff --git a/backends/p4tools/common/lib/symbolic_env.cpp b/backends/p4tools/common/lib/symbolic_env.cpp index e66ca624fc6..26bf02f063c 100644 --- a/backends/p4tools/common/lib/symbolic_env.cpp +++ b/backends/p4tools/common/lib/symbolic_env.cpp @@ -3,8 +3,6 @@ #include #include -#include - #include "backends/p4tools/common/lib/model.h" #include "ir/indexed_vector.h" #include "ir/vector.h" diff --git a/backends/p4tools/common/options.h b/backends/p4tools/common/options.h index a1b145a7433..66664c7ee22 100644 --- a/backends/p4tools/common/options.h +++ b/backends/p4tools/common/options.h @@ -1,7 +1,6 @@ #ifndef BACKENDS_P4TOOLS_COMMON_OPTIONS_H_ #define BACKENDS_P4TOOLS_COMMON_OPTIONS_H_ -// Boost #include #include #include diff --git a/backends/p4tools/modules/testgen/lib/execution_state.cpp b/backends/p4tools/modules/testgen/lib/execution_state.cpp index 6fb6c5db013..925525d3bfe 100644 --- a/backends/p4tools/modules/testgen/lib/execution_state.cpp +++ b/backends/p4tools/modules/testgen/lib/execution_state.cpp @@ -11,7 +11,6 @@ #include #include -#include #include #include "backends/p4tools/common/compiler/convert_hs_index.h" diff --git a/backends/p4tools/modules/testgen/lib/final_state.cpp b/backends/p4tools/modules/testgen/lib/final_state.cpp index a4eb7436ee1..c21069db95a 100644 --- a/backends/p4tools/modules/testgen/lib/final_state.cpp +++ b/backends/p4tools/modules/testgen/lib/final_state.cpp @@ -5,8 +5,6 @@ #include #include -#include - #include "backends/p4tools/common/lib/model.h" #include "backends/p4tools/common/lib/symbolic_env.h" #include "backends/p4tools/common/lib/trace_event.h" diff --git a/backends/p4tools/p4tools.def b/backends/p4tools/p4tools.def index 56721564f2f..a8030a6725e 100644 --- 
a/backends/p4tools/p4tools.def +++ b/backends/p4tools/p4tools.def @@ -23,12 +23,21 @@ class StateVariable : Expression { StateVariable(ArrayIndex arr) : Expression(arr->getSourceInfo(), arr->type), ref(arr) {} /// Implements comparisons so that StateVariables can be used as map keys. + // Delegate to IR's notion of equality. bool operator==(const StateVariable &other) const override { - // Delegate to IR's notion of equality. return *ref == *other.ref; } /// Implements comparisons so that StateVariables can be used as map keys. + /// Note that we ignore the type of the variable in the comparison. + equiv { + // We use a custom compare function. + // TODO: Is there a faster way to implement this comparison? + return compare(ref, a.ref) == 0; + } + + /// Implements comparisons so that StateVariables can be used as map keys. + /// Note that we ignore the type of the variable in the comparison. bool operator<(const StateVariable &other) const { // We use a custom compare function. // TODO: Is there a faster way to implement this comparison? @@ -122,6 +131,30 @@ class StateVariable : Expression { dbprint { ref->dbprint(out); } } +#emit +namespace IR { +/// Equals for StateVariable pointers and references. Delegates to StateVariable::equiv. +struct StateVariableEqual { + bool operator()(const IR::StateVariable *s1, const IR::StateVariable *s2) const { + return s1->equiv(*s2); + } + bool operator()(const IR::StateVariable &s1, const IR::StateVariable &s2) const { + return s1.equiv(s2); + } +}; + +/// Less for StateVariable pointers and references. Delegates to StateVariable::operator<. +struct StateVariableLess { + bool operator()(const IR::StateVariable *s1, const IR::StateVariable *s2) const { + return s1->operator<(*s2); + } + bool operator()(const IR::StateVariable &s1, const IR::StateVariable &s2) const { + return s1.operator<(s2); + } +}; +} // namespace IR +#end + /// Signifies that a particular expression is tainted. /// This tainted expression must be resolved explicitly. 
class TaintExpression : Expression { @@ -151,6 +184,31 @@ class SymbolicVariable : Expression { dbprint { out << "|" + label +"(" << type << ")|"; } } +#emit +namespace IR { +/// Equals for SymbolicVariable pointers and references. We only compare the label. +struct SymbolicVariableEqual { + bool operator()(const IR::SymbolicVariable *s1, const IR::SymbolicVariable *s2) const { + return s1->label == s2->label; + } + bool operator()(const IR::SymbolicVariable &s1, const IR::SymbolicVariable &s2) const { + return s1.label == s2.label; + } +}; + +/// Less for SymbolicVariable pointers and references. Delegates to SymbolicVariable::operator<. +struct SymbolicVariableLess { + bool operator()(const IR::SymbolicVariable *s1, const IR::SymbolicVariable *s2) const { + return s1->operator<(*s2); + } + bool operator()(const IR::SymbolicVariable &s1, const IR::SymbolicVariable &s2) const { + return s1.operator<(s2); + } +}; +} // namespace IR +#end + + /// This type replaces Type_Varbits and can store information about the current size class Extracted_Varbits : Type_Bits { public: diff --git a/ir/solver.h b/ir/solver.h index 459be62fa28..80e0c8d79f4 100644 --- a/ir/solver.h +++ b/ir/solver.h @@ -4,12 +4,10 @@ #include #include -#include -#include - #include "ir/ir.h" #include "lib/castable.h" #include "lib/cstring.h" +#include "lib/flat_map.h" /// Represents a constraint that can be shipped to and asserted within a solver. // TODO: This should implement AbstractRepCheckedNode. @@ -23,10 +21,7 @@ struct SymbolicVarComp { }; /// This type maps symbolic variables to their value assigned by the solver. -using SymbolicMapping = boost::container::flat_map; - -using SymbolicSet = boost::container::flat_set; +using SymbolicMapping = P4C::flat_map; /// Provides a higher-level interface for an SMT solver. 
class AbstractSolver : public ICastable {
diff --git a/lib/flat_map.h b/lib/flat_map.h
new file mode 100644
index 00000000000..ae7342e6190
--- /dev/null
+++ b/lib/flat_map.h
@@ -0,0 +1,350 @@
+/*
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+#ifndef LIB_FLAT_MAP_H_
+#define LIB_FLAT_MAP_H_
+
+#include <algorithm>
+#include <functional>
+#include <initializer_list>
+#include <stdexcept>
+#include <utility>
+#include <vector>
+
+namespace P4C {
+
+/// A header-only implementation of a memory-efficient flat_map.
+/// Elements are key/value pairs stored in a single Container that is kept sorted by key (as
+/// ordered by Compare). Keys are unique and lookups are binary searches over that container.
+/// at() throws std::out_of_range for a missing key, matching std::map.
+/// TODO: Replace this map with std::flat_map once available in C++23:
+/// https://en.cppreference.com/w/cpp/container/flat_map
+template <typename K, typename V, typename Compare = std::less<K>,
+          typename Container = std::vector<std::pair<K, V>>>
+struct flat_map {
+    using key_type = K;
+    using mapped_type = V;
+    using value_type = typename Container::value_type;
+    using key_compare = Compare;
+
+    /// Orders two stored elements by their keys only; the mapped values are ignored.
+    struct value_compare {
+        bool operator()(const value_type &lhs, const value_type &rhs) const {
+            return key_compare()(lhs.first, rhs.first);
+        }
+    };
+
+    using allocator_type = typename Container::allocator_type;
+    using reference = typename Container::reference;
+    using const_reference = typename Container::const_reference;
+    using pointer = typename Container::pointer;
+    using const_pointer = typename Container::const_pointer;
+    using iterator = typename Container::iterator;
+    using const_iterator = typename Container::const_iterator;
+    using reverse_iterator = typename Container::reverse_iterator;
+    using const_reverse_iterator = typename Container::const_reverse_iterator;
+    using difference_type = typename Container::difference_type;
+    using size_type = typename Container::size_type;
+
+    flat_map() = default;
+
+    /// Builds the map from the range [begin, end) by forwarding to insert(begin, end).
+    template <typename It>
+    flat_map(It begin, It end) {
+        insert(begin, end);
+    }
+
+    flat_map(std::initializer_list<value_type> il) : flat_map(il.begin(), il.end()) {}
+
+    iterator begin() { return data_.begin(); }
+    iterator end() { return data_.end(); }
+    const_iterator begin() const { return data_.begin(); }
+    const_iterator end() const { return data_.end(); }
+    const_iterator cbegin() const { return data_.cbegin(); }
+    const_iterator cend() const { return data_.cend(); }
+    reverse_iterator rbegin() { return data_.rbegin(); }
+    reverse_iterator rend() { return data_.rend(); }
+    const_reverse_iterator rbegin() const { return data_.rbegin(); }
+    const_reverse_iterator rend() const { return data_.rend(); }
+    const_reverse_iterator crbegin() const { return data_.crbegin(); }
+    const_reverse_iterator crend() const { return data_.crend(); }
+
+    bool empty() const { return data_.empty(); }
+    size_type size() const { return data_.size(); }
+    size_type max_size() const { return data_.max_size(); }
+    size_type capacity() const { return data_.capacity(); }
+    void reserve(size_type size) { data_.reserve(size); }
+    void shrink_to_fit() { data_.shrink_to_fit(); }
+    /// Bytes held by the map: the container's allocated capacity plus the container object.
+    size_type bytes_used() const { return capacity() * sizeof(value_type) + sizeof(data_); }
+
+    /// Returns the value mapped to @p key, inserting mapped_type() first if the key is missing.
+    mapped_type &operator[](const key_type &key) {
+        KeyOrValueCompare comp;
+        auto lower = lower_bound(key);
+        if (lower == end() || comp(key, *lower))
+            return data_.emplace(lower, key, mapped_type())->second;
+
+        return lower->second;
+    }
+
+    /// Same as above for an rvalue key, which is moved into the map when it has to be inserted.
+    mapped_type &operator[](key_type &&key) {
+        KeyOrValueCompare comp;
+        auto lower = lower_bound(key);
+        if (lower == end() || comp(key, *lower))
+            return data_.emplace(lower, std::move(key), mapped_type())->second;
+
+        return lower->second;
+    }
+
+    /// Returns the value mapped to @p key.
+    /// @throws std::out_of_range if the map holds no element with that key.
+    mapped_type &at(const key_type &key) {
+        auto found = find(key);
+        if (found == end()) throw std::out_of_range("flat_map::at: key not found");
+        return found->second;
+    }
+
+    /// Read-only overload of at(); throws std::out_of_range for a missing key as well.
+    const mapped_type &at(const key_type &key) const {
+        auto found = find(key);
+        if (found == end()) throw std::out_of_range("flat_map::at: key not found");
+        return found->second;
+    }
+
+    std::pair<iterator, bool> insert(value_type &&value) { return emplace(std::move(value)); }
+
+    std::pair<iterator, bool> insert(const value_type &value) { return emplace(value); }
+
+    iterator insert(const_iterator hint, value_type &&value) {
+        return emplace_hint(hint, std::move(value));
+    }
+
+    iterator insert(const_iterator hint, const value_type &value) {
+        return emplace_hint(hint, value);
+    }
+
+    /// Inserts the elements of [begin, end) whose keys are not yet present in the map.
+    template <typename It>
+    void insert(It begin, It end) {
+        // If we need to increase the capacity, utilize this fact and emplace
+        // the stuff.
+        for (; begin != end && size() == capacity(); ++begin) {
+            emplace(*begin);
+        }
+        if (begin == end) return;
+
+        // If we don't need to increase capacity, then we can use a more efficient
+        // insert method where everything is just put in the same vector
+        // and then merge in place.
+        size_type size_before = data_.size();
+        try {
+            for (size_t i = capacity(); i > size_before && begin != end; --i, ++begin) {
+                data_.emplace_back(*begin);
+            }
+        } catch (...) {
+            // If emplace_back throws an exception, the easiest way to make sure
+            // that our invariants are still in place is to resize to the state
+            // we were in before
+            for (size_t i = data_.size(); i > size_before; --i) {
+                data_.pop_back();
+            }
+            throw;
+        }
+
+        value_compare comp;
+        auto mid = data_.begin() + size_before;
+        std::stable_sort(mid, data_.end(), comp);
+        std::inplace_merge(data_.begin(), mid, data_.end(), comp);
+        data_.erase(std::unique(data_.begin(), data_.end(), std::not_fn(comp)), data_.end());
+
+        // Make sure that we inserted at least one element before
+        // recursing. Otherwise we'd recurse too often if we were to insert the
+        // same element many times
+        if (data_.size() == size_before) {
+            for (; begin != end; ++begin) {
+                if (emplace(*begin).second) {
+                    ++begin;
+                    break;
+                }
+            }
+        }
+
+        // Insert the remaining elements that didn't fit by calling this function recursively.
+        return insert(begin, end);
+    }
+
+    void insert(std::initializer_list<value_type> il) { insert(il.begin(), il.end()); }
+
+    iterator erase(iterator it) { return data_.erase(it); }
+
+    iterator erase(const_iterator it) { return erase(iterator_const_cast(it)); }
+
+    /// Removes the element with the given key, if present. @returns the number removed (0 or 1).
+    size_type erase(const key_type &key) {
+        auto found = find(key);
+        if (found == end()) return 0;
+        erase(found);
+        return 1;
+    }
+
+    iterator erase(const_iterator first, const_iterator last) {
+        return data_.erase(iterator_const_cast(first), iterator_const_cast(last));
+    }
+
+    void swap(flat_map &other) { data_.swap(other.data_); }
+
+    void clear() { data_.clear(); }
+
+    /// Inserts a new element unless its key is already present. @p first is either the key
+    /// (followed by the mapped-value arguments) or a whole value_type.
+    /// @returns the position of the element and whether an insertion took place.
+    template <typename First, typename... Args>
+    std::pair<iterator, bool> emplace(First &&first, Args &&...args) {
+        KeyOrValueCompare comp;
+        auto lower_bound = std::lower_bound(data_.begin(), data_.end(), first, comp);
+        if (lower_bound == data_.end() || comp(first, *lower_bound))
+            return {
+                data_.emplace(lower_bound, std::forward<First>(first), std::forward<Args>(args)...),
+                true};
+
+        return {lower_bound, false};
+    }
+
+    std::pair<iterator, bool> emplace() { return emplace(value_type()); }
+
+    /// Like emplace(), but first tries to place the element directly before @p hint.
+    template <typename First, typename... Args>
+    iterator emplace_hint(const_iterator hint, First &&first, Args &&...args) {
+        KeyOrValueCompare comp;
+        if (hint == cend() || comp(first, *hint)) {
+            if (hint == cbegin() || comp(*(hint - 1), first))
+                return data_.emplace(iterator_const_cast(hint), std::forward<First>(first),
+                                     std::forward<Args>(args)...);
+
+            return emplace(std::forward<First>(first), std::forward<Args>(args)...).first;
+        } else if (!comp(*hint, first)) {
+            return begin() + (hint - cbegin());
+        }
+
+        return emplace(std::forward<First>(first), std::forward<Args>(args)...).first;
+    }
+
+    iterator emplace_hint(const_iterator hint) { return emplace_hint(hint, value_type()); }
+
+    key_compare key_comp() const { return key_compare(); }
+    value_compare value_comp() const { return value_compare(); }
+
+    template <typename T>
+    iterator find(const T &key) {
+        return binary_find(begin(), end(), key, KeyOrValueCompare());
+    }
+    template <typename T>
+    const_iterator find(const T &key) const {
+        return binary_find(begin(), end(), key, KeyOrValueCompare());
+    }
+    template <typename T>
+    size_type count(const T &key) const {
+        return std::binary_search(begin(), end(), key, KeyOrValueCompare()) ? 1 : 0;
+    }
+    template <typename T>
+    iterator lower_bound(const T &key) {
+        return std::lower_bound(begin(), end(), key, KeyOrValueCompare());
+    }
+    template <typename T>
+    const_iterator lower_bound(const T &key) const {
+        return std::lower_bound(begin(), end(), key, KeyOrValueCompare());
+    }
+    template <typename T>
+    iterator upper_bound(const T &key) {
+        return std::upper_bound(begin(), end(), key, KeyOrValueCompare());
+    }
+    template <typename T>
+    const_iterator upper_bound(const T &key) const {
+        return std::upper_bound(begin(), end(), key, KeyOrValueCompare());
+    }
+    template <typename T>
+    std::pair<iterator, iterator> equal_range(const T &key) {
+        return std::equal_range(begin(), end(), key, KeyOrValueCompare());
+    }
+    template <typename T>
+    std::pair<const_iterator, const_iterator> equal_range(const T &key) const {
+        return std::equal_range(begin(), end(), key, KeyOrValueCompare());
+    }
+    allocator_type get_allocator() const { return data_.get_allocator(); }
+
+    bool operator==(const flat_map &other) const { return data_ == other.data_; }
+    bool operator!=(const flat_map &other) const { return !(*this == other); }
+    bool operator<(const flat_map &other) const { return data_ < other.data_; }
+    bool operator>(const flat_map &other) const { return other < *this; }
+    bool operator<=(const flat_map &other) const { return !(other < *this); }
+    bool operator>=(const flat_map &other) const { return !(*this < other); }
+
+ private:
+    Container data_;
+
+    /// Converts a const_iterator into the iterator designating the same position.
+    iterator iterator_const_cast(const_iterator it) { return begin() + (it - cbegin()); }
+
+    /// Less-than functor accepting a key, a stored element, or any key-comparable type on either
+    /// side, so the standard search algorithms can be applied to the element container directly.
+    struct KeyOrValueCompare {
+        bool operator()(const key_type &lhs, const key_type &rhs) const {
+            return key_compare()(lhs, rhs);
+        }
+        bool operator()(const key_type &lhs, const value_type &rhs) const {
+            return key_compare()(lhs, rhs.first);
+        }
+        template <typename T>
+        bool operator()(const key_type &lhs, const T &rhs) const {
+            return key_compare()(lhs, rhs);
+        }
+        template <typename T>
+        bool operator()(const T &lhs, const key_type &rhs) const {
+            return key_compare()(lhs, rhs);
+        }
+        bool operator()(const value_type &lhs, const key_type &rhs) const {
+            return key_compare()(lhs.first, rhs);
+        }
+        bool operator()(const value_type &lhs, const value_type &rhs) const {
+            return key_compare()(lhs.first, rhs.first);
+        }
+        template <typename T>
+        bool operator()(const value_type &lhs, const T &rhs) const {
+            return key_compare()(lhs.first, rhs);
+        }
+        template <typename T>
+        bool operator()(const T &lhs, const value_type &rhs) const {
+            return key_compare()(lhs, rhs.first);
+        }
+    };
+
+    /// Binary search: returns the element equivalent to @p value, or @p end if there is none.
+    template <typename It, typename T, typename Comp>
+    static It binary_find(It begin, It end, const T &value, const Comp &cmp) {
+        auto lower_bound = std::lower_bound(begin, end, value, cmp);
+        if (lower_bound == end || cmp(value, *lower_bound)) return end;
+
+        return lower_bound;
+    }
+};
+
+template <typename K, typename V, typename C, typename Cont>
+void swap(flat_map<K, V, C, Cont> &lhs, flat_map<K, V, C, Cont> &rhs) {
+    lhs.swap(rhs);
+}
+
+}  // namespace P4C
+
+#endif  // LIB_FLAT_MAP_H_
diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt
index 1c3c24235ba..e3b9c0a0962 100644
--- a/test/CMakeLists.txt
+++ b/test/CMakeLists.txt
@@ -31,6 +31,7 @@ set (GTEST_UNITTEST_SOURCES
   gtest/equiv_test.cpp
   gtest/exception_test.cpp
   gtest/expr_uses_test.cpp
+  gtest/flat_map.cpp
   gtest/format_test.cpp
   gtest/helpers.cpp
   gtest/hash.cpp
diff --git a/test/gtest/flat_map.cpp b/test/gtest/flat_map.cpp
new file mode 100644
index 00000000000..e6ab1081c82
--- /dev/null
+++ b/test/gtest/flat_map.cpp
@@ -0,0 +1,168 @@
+/*
+Copyright 2013-present Barefoot Networks, Inc.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and +limitations under the License. +*/ + +#include "lib/flat_map.h" + +#include + +#include "lib/map.h" + +namespace Test { + +TEST(FlatMap, MapEqual) { + P4C::flat_map a; + P4C::flat_map b; + + EXPECT_TRUE(a == b); + + a[1] = 111; + a[2] = 222; + a[3] = 333; + a[4] = 444; + + b[1] = 111; + b[2] = 222; + b[3] = 333; + b[4] = 444; + + EXPECT_TRUE(a == b); + + a.erase(2); + b.erase(2); + + EXPECT_TRUE(a == b); + + a.clear(); + b.clear(); + + EXPECT_TRUE(a == b); +} + +TEST(FlatMap, MapNotEqual) { + P4C::flat_map a; + P4C::flat_map b; + + EXPECT_TRUE(a == b); + + a[1] = 111; + a[2] = 222; + a[3] = 333; + a[4] = 444; + + b[4] = 444; + b[3] = 333; + b[2] = 222; + b[1] = 111; + // flat_map keeps its entries sorted by key, so insertion order must not affect equality. + EXPECT_TRUE(a == b); + + a.clear(); + b.clear(); + + EXPECT_TRUE(a == b); + + a[1] = 111; + a[2] = 222; + + b[1] = 111; + b[2] = 222; + b[3] = 333; + + EXPECT_TRUE(a != b); + + a.clear(); + b.clear(); + + EXPECT_TRUE(a == b); + + a[1] = 111; + a[2] = 222; + a[3] = 333; + a[4] = 444; + + b[4] = 111; + b[3] = 222; + b[2] = 333; + b[1] = 444; + + EXPECT_TRUE(a != b); + + a.clear(); + b.clear(); + + EXPECT_TRUE(a == b); + + a[1] = 111; + a[2] = 222; + a[3] = 333; + a[4] = 444; + + b[1] = 111; + b[2] = 111; + b[3] = 111; + b[4] = 111; + + EXPECT_TRUE(a != b); +} + +TEST(FlatMap, InsertEmplaceErase) { + P4C::flat_map om; + std::map sm; + + auto it = om.end(); + for (auto v : {0, 1, 2, 3, 4, 5, 6, 7, 8}) { + sm.emplace(v, 2 * v); + std::pair pair{v, 2 * v}; + if (v % 2 == 0) { + if ((v / 2) % 2 == 0) { + it = om.insert(pair).first; + } else { + it = om.emplace(v, pair.second).first; + } + } else { + if ((v / 2) % 2 == 0) { + it = om.insert(std::move(pair)).first; + } else { + it = om.emplace(v, v * 2).first; + } + } + } + + EXPECT_TRUE(std::equal(om.begin(), om.end(), sm.begin(), sm.end())); + + it = std::next(om.begin(), 2); + om.erase(it); + sm.erase(std::next(sm.begin(), 2)); + + EXPECT_TRUE(om.size() == sm.size()); + 
EXPECT_TRUE(std::equal(om.begin(), om.end(), sm.begin(), sm.end())); +} + +TEST(FlatMap, ExistingKey) { + P4C::flat_map myMap{{1, "One"}, {2, "Two"}, {3, "Three"}}; + + EXPECT_EQ(get(myMap, 1), "One"); + EXPECT_EQ(get(myMap, 2), "Two"); + EXPECT_EQ(get(myMap, 3), "Three"); +} + +TEST(FlatMap, NonExistingKey) { + P4C::flat_map myMap{{1, "One"}, {2, "Two"}, {3, "Three"}}; + + EXPECT_EQ(get(myMap, 4), ""); +} + +} // namespace Test