diff --git a/ir/CMakeLists.txt b/ir/CMakeLists.txt index 242f6d104c..01c72f0375 100644 --- a/ir/CMakeLists.txt +++ b/ir/CMakeLists.txt @@ -30,6 +30,7 @@ set (IR_SRCS node.cpp pass_manager.cpp pass_utils.cpp + splitter.cpp type.cpp visitor.cpp write_context.cpp @@ -56,6 +57,7 @@ set (IR_HDRS nodemap.h pass_manager.h pass_utils.h + splitter.h vector.h visitor.h ) diff --git a/ir/splitter.cpp b/ir/splitter.cpp new file mode 100644 index 0000000000..49987003d3 --- /dev/null +++ b/ir/splitter.cpp @@ -0,0 +1,280 @@ +/* +Copyright 2025-present Altera Corporation. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +#include "splitter.h" + +#include +#include + +#include "frontends/common/resolveReferences/referenceMap.h" +#include "frontends/common/resolveReferences/resolveReferences.h" +#include "frontends/p4/typeMap.h" +#include "ir/ir-traversal.h" +#include "ir/visitor.h" + +namespace P4 { + +struct StatementSplitter : Inspector, ResolutionContext { + StatementSplitter( + std::function predicate, + P4::NameGenerator &nameGen, P4::TypeMap *typeMap, + absl::flat_hash_set &neededDecls) + : predicate(predicate), nameGen(nameGen), typeMap(typeMap), neededDecls(neededDecls) {} + + bool preorder(const IR::LoopStatement *) override { + BUG("Loops not supported in statement splitter, must be unrolled before"); + } + + bool preorder(const IR::Statement *stmt) override { + handleStmt(stmt); + return false; + } + + bool preorder(const IR::BlockStatement *bs) override { + if (handleStmt(bs)) { + // split on the bs itself + return false; + } + + for (size_t i = 0, sz = bs->components.size(); i < sz; i++) { + visit(bs->components[i], "vector"); + if (result.after) { + const auto [before, after, _] = result; // copy + auto *copy = bs->clone(); + copy->components.erase(copy->components.begin() + i, copy->components.end()); + if (before) { + copy->components.push_back(before); + } + result.before = filterDeclarations(copy); + copy = bs->clone(); + copy->components.erase(copy->components.begin(), copy->components.begin() + i); + collectNeededDeclarations(copy); + copy->components.replace(copy->components.begin(), after); + result.after = copy; + return false; // stop on first split point + } + } + return false; + } + + bool preorder(const IR::IfStatement *ifs) override { + if (handleStmt(ifs)) { + return false; // split on the if itself + } + + auto [results, anySplit] = splitBranches({ifs->ifTrue, ifs->ifFalse}); + if (!anySplit) { + return false; + } + + IR::ID condName{nameGen.newName("cond"), nullptr}; + const auto &si = ifs->srcInfo; + const auto *decl = new IR::Declaration_Variable(si, condName, IR::Type::Boolean::get()); + result.hoistedDeclarations.push_back(decl); + + const auto *condPE = new IR::PathExpression(si, new IR::Path(si, condName)); + const auto *asgn = new IR::AssignmentStatement(si, condPE, ifs->condition); + + auto *beforeIf = ifs->clone(); + beforeIf->condition = condPE->clone(); + beforeIf->ifTrue = results[0].before; + beforeIf->ifFalse = results[1].before; + result.before = new IR::BlockStatement(si, {asgn, beforeIf}); + + auto *afterIf = beforeIf->clone(); + afterIf->ifTrue = results[0].after; + afterIf->ifFalse = results[1].after; + result.after = afterIf; + + for (auto **trueBranch : {&beforeIf->ifTrue, &afterIf->ifTrue}) { + if (*trueBranch == nullptr) { + *trueBranch = new IR::BlockStatement(ifs->ifTrue->srcInfo); + } + } + return false; + } + + bool preorder(const IR::SwitchStatement *sw) override { + if (handleStmt(sw)) { + return false; // split on the switch itself + } + + std::vector branches; + for (const auto *case_ : sw->cases) { + branches.push_back(case_->statement); + } + auto [results, anySplit] = splitBranches(branches); + + if (!anySplit) { + return false; + } + + IR::ID selName{nameGen.newName("selector"), nullptr}; + const auto &si = sw->srcInfo; + const auto *selType = typeMap ? typeMap->getType(sw->expression) : nullptr; + selType = selType ? selType : sw->expression->type; + BUG_CHECK(selType && !selType->is(), + "Cannot split switch statement with unknown selector type %1%", sw->expression); + const auto *decl = new IR::Declaration_Variable(si, selName, selType); + result.hoistedDeclarations.push_back(decl); + + const auto *selPE = new IR::PathExpression(si, new IR::Path(si, selName)); + const auto *asgn = new IR::AssignmentStatement(si, selPE, sw->expression); + + // ensure we don't accidentally create fallthrough + for (size_t i = 0; i < branches.size(); ++i) { + for (const auto **val : {&results[i].before, &results[i].after}) { + if (!*val && branches[i]) { + *val = new IR::BlockStatement(branches[i]->srcInfo); + } + } + } + + auto *beforeSw = sw->clone(); + beforeSw->expression = selPE; + for (size_t i = 0; i < branches.size(); ++i) { + setCase(beforeSw, i, results[i].before); + } + result.before = new IR::BlockStatement(si, {asgn, beforeSw}); + + auto *afterSw = beforeSw->clone(); + for (size_t i = 0; i < branches.size(); ++i) { + setCase(afterSw, i, results[i].after); + } + result.after = afterSw; + return false; + } + + void end_apply(const IR::Node *root) override { + if (!result.before) { + result.before = root->checkedTo(); + } + } + + SplitResult result; + + private: + bool handleStmt(const IR::Statement *stmt) { + BUG_CHECK(result.before == nullptr && result.after == nullptr, + "More than one leaf statement found: %1% and %2%", + result.before ? result.before : result.after, stmt); + if (predicate(stmt, getChildContext())) { + result.after = stmt; + collectNeededDeclarations(stmt); + return true; + } + return false; + } + + void setCase(IR::SwitchStatement *sw, size_t i, const IR::Statement *value) { + // note that we can't go all the way to statement as it can be nullptr + modify(sw, &IR::SwitchStatement::cases, IR::Traversal::Index(i), + [value](IR::SwitchCase *case_) { + case_->statement = value; + return case_; + }); + } + + void takeHoisted(std::vector &decls) { + result.hoistedDeclarations.insert(result.hoistedDeclarations.end(), decls.begin(), + decls.end()); + decls.clear(); + } + + std::pair>, bool> splitBranches( + std::vector branches) { + std::vector> res; + bool anySplit = false; + res.reserve(branches.size()); + + for (const auto *branch : branches) { + if (!branch) { + res.emplace_back(); + continue; + } + visit(branch, "branch"); + anySplit = anySplit || result.after; + if (!result) { + result.before = branch; + } + res.emplace_back(std::move(result)); + result.clear(); + } + for (auto &[_, __, hoisted] : res) { + takeHoisted(hoisted); + } + return {res, anySplit}; + } + + void collectNeededDeclarations(const IR::Node *after) { + struct CollectNeededDecls : Inspector, ResolutionContext { + explicit CollectNeededDecls(absl::flat_hash_set &needed) + : needed(needed) {} + + void postorder(const IR::PathExpression *pe) override { + // using lower-level resolution to avoid emitting errors for things not found + if (!resolve(pe->path->name, ResolutionType::Any).empty()) { + needed.insert(pe->path->name); + } + } + + absl::flat_hash_set &needed; + }; + + after->apply(CollectNeededDecls(neededDecls), getChildContext()); + } + + template + const T *filterDeclarations(const T *node) { + struct FilterDecls : Transform { + FilterDecls(absl::flat_hash_set &needed, + std::vector &hoisted) + : needed(needed), hoisted(hoisted) {} + + const IR::Node *preorder(IR::Declaration_Variable *decl) override { + if (needed.contains(decl->name)) { + hoisted.push_back(decl); + return nullptr; + } + return decl; + } + + absl::flat_hash_set &needed; + std::vector &hoisted; + }; + + FilterDecls filter(neededDecls, result.hoistedDeclarations); + return node->apply(filter)->template checkedTo(); + } + + std::function predicate; + P4::NameGenerator &nameGen; + P4::TypeMap *typeMap; + absl::flat_hash_set &neededDecls; +}; + +SplitResult splitStatementBefore( + const IR::Statement *stat, + std::function predicate, + P4::NameGenerator &nameGen, P4::TypeMap *typeMap) { + absl::flat_hash_set neededDecls; + StatementSplitter split(predicate, nameGen, typeMap, neededDecls); + // no incoming context, declaration resolution will work only within the splitter + stat->apply(split); + return split.result; +} + +} // namespace P4 diff --git a/ir/splitter.h b/ir/splitter.h new file mode 100644 index 0000000000..be185f9b7b --- /dev/null +++ b/ir/splitter.h @@ -0,0 +1,127 @@ +/* +Copyright 2025-present Altera Corporation. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +#ifndef IR_SPLITTER_H_ +#define IR_SPLITTER_H_ + +#include "ir/ir.h" +#include "ir/visitor.h" + +namespace P4 { + +class NameGenerator; +class TypeMap; + +template +struct SplitResult { + const Node *before = nullptr; + const Node *after = nullptr; + std::vector hoistedDeclarations; + + /// @brief Returns true if any splitting occured and false otherwise. + explicit operator bool() const { return after; } + void clear() { + before = after = nullptr; + hoistedDeclarations.clear(); + } +}; + +/// @brief Split @p stat so that on every control-flow path all the statements up to the first one +/// matching @p predicate are in SplitResult::before, while the rest (starting from the matching +/// statemet) is in SplitResult::after. +/// +/// @pre @p stat must not contain P4::IR::LoopStatement (loops must be unrolled before). +/// @pre All variable declarations have unique names. +/// @pre No non-standard control flow blocks exist in the IR of @p stat (only if and switch). +/// +/// @note Fresh variables are introduced to save all if conditions and switch selectors. This is to +/// ensure the right branch is triggered even if code is inserted between the split points. +/// Furthemore, declarations that would become invisible in the "after section" are hoisted. No +/// other provisions are made to isolate side effects that may be inserted between the split code +/// fragments. +/// +/// @note No inspection is done for the called object of IR::MA::MethodCallStatement (except that @p +/// predicate is applied to them as for any other statement). Therefore, called functions/actions +/// are not recursively split. +/// +/// @code{.p4} +/// a = a + 4 +/// if (b > 5) { +/// bit<3> v = c + 2; +/// t1.apply(); +/// d = 8 + v; +/// } else { +/// if (e == 4) { +/// d = 1; +/// t2.apply(); +/// c = 2 +/// } +/// } +/// @endcode +/// +/// If we split this according to calls to table.apply, we get: +/// +/// Before: +/// @code{.p4} +/// a = a + 4 +/// cond_1 = b > 5; +/// if (cond_1) { +/// v = c + 2; +/// } else { +/// cond_2 = e == 4; +/// if (cond_2) { +/// d = 1; +/// } +/// } +/// @endcode +/// +/// After: +/// @code{.p4} +/// if (cond_1) { +/// t1.apply(); +/// d = 8 + v; +/// } else { +/// if (cond_2) { +/// t2.apply(); +/// c = 2 +/// } +/// } +/// @endcode +/// +/// Hoisted declarations: +/// @code{.p4} +/// bool cond_1; +/// bool cond_2; +/// bit<3> v; +/// @endcode +/// +/// +/// @param stat The statement to split. +/// @param predicate Predicate over statements. Returns true for split point. +/// @param nameGen A name generator valid for the P4 program containing @p stat. +/// @param typeMap A P4::TypeMap valid for the program, or nullptr. In case nullptr is passed, +/// types must be encoded in the IR, or no P4::IR::SwitchStatement can be split. +/// @return The after node is empty (`nullptr`) if no splitting occurred. +/// SplitResult::hoistedDeclarations will contain declarations that have to be made available for +/// the two statements to work. +SplitResult splitStatementBefore( + const IR::Statement *stat, + std::function predicate, + P4::NameGenerator &nameGen, P4::TypeMap *typeMap = nullptr); + +} // namespace P4 + +#endif // IR_SPLITTER_H_ diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index c504b4e57b..92facc3843 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -38,6 +38,7 @@ set (GTEST_UNITTEST_SOURCES gtest/hvec_map.cpp gtest/hvec_set.cpp gtest/indexed_vector.cpp + gtest/ir-splitter.cpp gtest/ir-traversal.cpp gtest/json_test.cpp gtest/map.cpp diff --git a/test/gtest/ir-splitter.cpp b/test/gtest/ir-splitter.cpp new file mode 100644 index 0000000000..0d8436660f --- /dev/null +++ b/test/gtest/ir-splitter.cpp @@ -0,0 +1,412 @@ +/* +Copyright 2025-present Altera Corporation. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +#include + +#include "frontends/common/parseInput.h" +#include "frontends/common/resolveReferences/referenceMap.h" +#include "frontends/p4/typeChecking/typeChecker.h" +#include "ir/ir.h" +#include "ir/splitter.h" + +using namespace P4::literals; + +namespace P4::Test { + +struct SplitterTest : public ::testing::Test { + SplitResult splitBefore( + const IR::Statement *stat, + std::function predicate) { + stat->apply(nameGen); + return splitStatementBefore(stat, predicate, nameGen, &typeMap); + } + + const IR::Statement *parse(std::string_view code, std::string_view decs = "") { + const auto program = absl::StrCat( + "extern void fn(); extern void f1(); extern void f2(); extern void f3(); ", + "extern void f4(); extern void f5(); extern void f6(); extern void bar(); ", + "control c() { bit<4> a; bit<4> b; bit<4> c; bit<4> d; bool bvar; ", decs, "apply {", + code, "} }"); + const auto *prog = P4::parseP4String(program, CompilerOptions::FrontendVersion::P4_16); + CHECK_NULL(prog); + P4::TypeInference ti(&typeMap, false, false, false); + prog = prog->apply(ti); + CHECK_NULL(prog); + const IR::BlockStatement *bs = nullptr; + P4::forAllMatching(prog, [&](const auto *cntr) { bs = cntr->body; }); + CHECK_NULL(bs); + return bs; + } + + const IR::PathExpression *pe(std::string_view name) { + return new IR::PathExpression(P4::cstring(name)); + } + + /// a = b + const IR::AssignmentStatement *asgn(std::string_view lhs, std::string_view rhs) { + return new IR::AssignmentStatement(pe(lhs), pe(rhs)); + } + + /// a = expr + const IR::AssignmentStatement *asgn(std::string_view lhs, const IR::Expression *expr) { + return new IR::AssignmentStatement(pe(lhs), expr); + } + + /// a == b + const IR::Equ *eq(std::string_view lhs, std::string_view rhs) { + return new IR::Equ(pe(lhs), pe(rhs)); + } + + const IR::MethodCallStatement *call(std::string_view fn) { + return new IR::MethodCallStatement(new IR::MethodCallExpression(pe(fn))); + } + + const IR::BlockStatement *blk(IR::IndexedVector stmts) { + return new IR::BlockStatement(std::move(stmts)); + } + + const IR::IfStatement *ifs(const IR::Expression *cond, const IR::Statement *tr, + const IR::Statement *fls = nullptr) { + return new IR::IfStatement(cond, tr, fls); + } + + MinimalNameGenerator nameGen; + TypeMap typeMap; +}; + +#define EXPECT_EQUIV(a, b) \ + EXPECT_TRUE(a->equiv(*b)) << "Actual:" << Log::indent << Log::endl \ + << a << Log::unindent << "\nExpected: " << Log::indent << Log::endl \ + << b << Log::unindent; + +template +bool predIs(const IR::Statement *stmt, const P4::Visitor::Context *) { + return stmt->is(); +} + +TEST_F(SplitterTest, SplitBsEmpty) { + const auto *bs = blk({}); + auto [before, after, decls] = splitBefore(bs, &predIs); + ASSERT_TRUE(before); + const auto *bbs = before->to(); + ASSERT_TRUE(bbs) << before; + EXPECT_EQ(bbs->components.size(), 0); + EXPECT_FALSE(after); + ASSERT_EQ(decls.size(), 0) << decls; +} + +TEST_F(SplitterTest, SplitBsSimple1) { + const auto *bs = blk({asgn("a", "b")}); + auto [before, after, decls] = splitBefore(bs, &predIs); + ASSERT_TRUE(before); + + const auto *bbs = before->to(); + ASSERT_TRUE(bbs) << before; + EXPECT_EQ(bbs->components.size(), 0) << bbs; + + const auto *abs = after->to(); + ASSERT_TRUE(abs) << after; + ASSERT_EQ(abs->components.size(), 1) << abs; + EXPECT_TRUE(abs->components.front()->is()) << abs; + + ASSERT_EQ(decls.size(), 0) << decls; +} + +TEST_F(SplitterTest, SplitBsSimple2) { + const auto *bs = blk({call("fn"), asgn("a", "b")}); + auto [before, after, decls] = splitBefore(bs, &predIs); + ASSERT_TRUE(before); + + const auto *bbs = before->to(); + ASSERT_TRUE(bbs) << before; + ASSERT_EQ(bbs->components.size(), 1) << before; + EXPECT_TRUE(bbs->components.front()->is()) << bbs; + + const auto *abs = after->to(); + ASSERT_TRUE(abs) << after; + ASSERT_EQ(abs->components.size(), 1) << abs; + EXPECT_TRUE(abs->components.front()->is()) << abs; + + ASSERT_EQ(decls.size(), 0) << decls; +} + +TEST_F(SplitterTest, SplitBsIfSingleBranch) { + const auto *bs = ifs(eq("a", "b"), blk({call("fn"), asgn("a", "b")})); + auto [before, after, decls] = splitBefore(bs, &predIs); + ASSERT_TRUE(before); + + const auto *bbs = before->to(); + ASSERT_TRUE(bbs) << before; + ASSERT_EQ(bbs->components.size(), 2) << before; + EXPECT_TRUE(bbs->components.at(0)->is()) << bbs; + const auto *bifs = bbs->components.at(1)->to(); + ASSERT_TRUE(bifs) << bbs; + ASSERT_TRUE(bifs->ifTrue) << bifs; + EXPECT_FALSE(bifs->ifFalse) << bifs; + + EXPECT_EQUIV(before, blk({asgn("cond", eq("a", "b")), ifs(pe("cond"), blk({call("fn")}))})); + + const auto *aifs = after->to(); + ASSERT_TRUE(aifs) << after; + ASSERT_TRUE(aifs->ifTrue) << aifs; + EXPECT_FALSE(aifs->ifFalse) << aifs; + + EXPECT_EQUIV(after, ifs(pe("cond"), blk({asgn("a", "b")}))); + + ASSERT_EQ(decls.size(), 1) << decls; + EXPECT_EQUIV(decls.front(), new IR::Declaration_Variable(IR::ID{"cond"_cs, nullptr}, + IR::Type::Boolean::get())); +} + +TEST_F(SplitterTest, SplitBsIfTwoBranches) { + const auto *bs = ifs(eq("a", "b"), blk({call("fn"), asgn("a", "b")}), blk({asgn("c", "d")})); + auto [before, after, decls] = splitBefore(bs, &predIs); + ASSERT_TRUE(before); + + EXPECT_EQUIV(before, + blk({asgn("cond", eq("a", "b")), ifs(pe("cond"), blk({call("fn")}), blk({}))})); + + EXPECT_EQUIV(after, ifs(pe("cond"), blk({asgn("a", "b")}), blk({asgn("c", "d")}))); + + ASSERT_EQ(decls.size(), 1) << decls; + EXPECT_EQUIV(decls.front(), new IR::Declaration_Variable(IR::ID{"cond"_cs, nullptr}, + IR::Type::Boolean::get())); +} + +TEST_F(SplitterTest, SplitBsIfNested) { + const auto *bs = parse(R"( +if (a == b) { + fn(); + if (bvar) { + f2(); + a = b; + c = b; + f3(); + } else { + f2(); + f4(); + } + f5(); +} else { + c = d; + bar(); +})"); + auto [before, after, decls] = splitBefore(bs, &predIs); + ASSERT_TRUE(before); + + EXPECT_EQUIV(before, parse(R"( +{ + cond_0 = a == b; + if (cond_0) { + fn(); + { + cond = bvar; + if (cond) { + f2(); + // cut + } else { + f2(); + f4(); + } + } + } else { + // cut + } +})", + "bool cond_0; bool cond;")); + + EXPECT_EQUIV(after, parse(R"( +if (cond_0) { + if (cond) { + // cut + a = b; + c = b; + f3(); + } + f5(); +} else { + // cut + c = d; + bar(); +})", + "bool cond_0; bool cond;")); + + ASSERT_EQ(decls.size(), 2) << decls; + EXPECT_EQUIV(decls.at(0), new IR::Declaration_Variable(IR::ID{"cond"_cs, nullptr}, + IR::Type::Boolean::get())); + EXPECT_EQUIV(decls.at(1), new IR::Declaration_Variable(IR::ID{"cond_0"_cs, nullptr}, + IR::Type::Boolean::get())); +} + +TEST_F(SplitterTest, IfSwitchNested) { + const auto *bs = parse(R"( +if (a == b) { + fn(); + switch (a) { + 0: + 1: { f2(); } // test we don't create fallthrough in "after" + 2: { f3(); a = b; f4(); } + 3: {} + // test non-block in IF + default: { if (a > 5) a = c; else f6(); } + } + f5(); +} else { + bar(); +})"); + auto [before, after, decls] = splitBefore(bs, &predIs); + ASSERT_TRUE(before); + + EXPECT_EQUIV(before, parse(R"( +{ + cond_0 = a == b; + if (cond_0) { + fn(); + { + selector = a; + switch (selector) { + 0: + 1: { f2(); } + 2: { f3(); } + 3: {} + default: { + { + cond = a > 5; + if (cond) {} + else f6(); + } + } + } + } + } else { + bar(); + } +})", + "bool cond_0; bool cond; bit<4> selector; ")); + + EXPECT_EQUIV(after, parse(R"( +if (cond_0) { + switch (selector) { + 0: + 1: {} + 2: { a = b; f4(); } + 3: {} + default: { if (cond) a = c; } + } + f5(); +} +)", + "bool cond; bool cond_0; bit<4> selector; ")); + + ASSERT_EQ(decls.size(), 3) << decls; + EXPECT_EQUIV(decls.at(0), new IR::Declaration_Variable(IR::ID{"cond"_cs, nullptr}, + IR::Type::Boolean::get())); + EXPECT_EQUIV(decls.at(1), new IR::Declaration_Variable(IR::ID{"selector"_cs, nullptr}, + IR::Type::Bits::get(4))); + EXPECT_EQUIV(decls.at(2), new IR::Declaration_Variable(IR::ID{"cond_0"_cs, nullptr}, + IR::Type::Boolean::get())); +} + +TEST_F(SplitterTest, HoistVarIf) { + const auto *bs = parse(R"( +if (a == b) { + bit<4> x; + bit<4> y; + y = c + 4; + x = y + 2; + // split + f1(); + x = c > a ? x + c : x + a; +})"); + auto [before, after, decls] = splitBefore(bs, &predIs); + ASSERT_TRUE(before); + + EXPECT_EQUIV(before, parse(R"( +{ + cond = a == b; + if (cond) { + bit<4> y; + y = c + 4; + x = y + 2; + // split + } +})", + "bool cond; bit<4> x; ")); + + EXPECT_EQUIV(after, parse(R"( +if (cond) { + f1(); + x = c > a ? x + c : x + a; +})", + "bool cond; bit<4> x; ")); + + ASSERT_EQ(decls.size(), 2) << decls; + EXPECT_EQUIV(decls.at(0), new IR::Declaration_Variable(IR::ID{"cond"_cs, nullptr}, + IR::Type::Boolean::get())); + EXPECT_EQUIV(decls.at(1), + new IR::Declaration_Variable(IR::ID{"x"_cs, nullptr}, IR::Type::Bits::get(4))); +} + +TEST_F(SplitterTest, HoistVarSwitch) { + const auto *bs = parse(R"( +switch (a) { + 0: { + bit<4> x; + bit<4> y; + y = c + 4; + x = y + 2; + // split + f1(); + x = c > a ? x + c : x + a; + } + 1: { a = b; } +})"); + auto [before, after, decls] = splitBefore(bs, &predIs); + ASSERT_TRUE(before); + + EXPECT_EQUIV(before, parse(R"( +{ + selector = a; + switch (selector) { + 0: { + bit<4> y; + y = c + 4; + x = y + 2; + // split + } + 1: { a = b; } + } +})", + "bit<4> selector; bit<4> x; ")); + + EXPECT_EQUIV(after, parse(R"( +switch (selector) { + 0: { + f1(); + x = c > a ? x + c : x + a; + } + 1: {} +})", + "bit<4> selector; bit<4> x; ")); + + ASSERT_EQ(decls.size(), 2) << decls; + EXPECT_EQUIV(decls.at(0), new IR::Declaration_Variable(IR::ID{"selector"_cs, nullptr}, + IR::Type::Bits::get(4))); + EXPECT_EQUIV(decls.at(1), + new IR::Declaration_Variable(IR::ID{"x"_cs, nullptr}, IR::Type::Bits::get(4))); +} + +} // namespace P4::Test