diff --git a/netkat/packet_transformer.cc b/netkat/packet_transformer.cc index d57dd6d..bdbe19d 100644 --- a/netkat/packet_transformer.cc +++ b/netkat/packet_transformer.cc @@ -81,15 +81,17 @@ PacketTransformerManager::GetNodeOrDie( // probably going to cause terrible performance, and needs to be revisited. absl::btree_map PacketTransformerManager::GetMapAtValue(const DecisionNode& node, int value) { - if (node.modify_branch_by_field_match.contains(value)) - return node.modify_branch_by_field_match.at(value); + if (auto it = node.modify_branch_by_field_match.find(value); + it != node.modify_branch_by_field_match.end()) { + return it->second; + } absl::btree_map result = node.default_branch_by_field_modification; - if (result.contains(value) || IsDeny(node.default_branch)) return result; - - // Otherwise, add a mapping from `value` to the default branch, then return. - result[value] = node.default_branch; + if (result.find(value) == result.end() && !IsDeny(node.default_branch)) { + // Otherwise, add a mapping from `value` to the default branch, then return. + result[value] = node.default_branch; + } return result; } @@ -351,26 +353,21 @@ PacketTransformerHandle PacketTransformerManager::Modification( } namespace { -absl::btree_map CombineModifyBranches( - const absl::btree_map& left, +void CombineModifyBranches( + absl::btree_map& result, const absl::btree_map& right, absl::AnyInvocable combiner, PacketTransformerHandle default_value) { - absl::btree_map result; - for (const auto& [value, branch] : left) { - if (right.contains(value)) { - result[value] = combiner(branch, right.at(value)); - } else { - result[value] = combiner(branch, default_value); - } - } for (const auto& [value, branch] : right) { - if (!result.contains(value)) + auto it = result.find(value); + if (it != result.end()) { + it->second = combiner(it->second, branch); + } else { result[value] = combiner(default_value, branch); + } } - return result; } } // namespace @@ -413,16 +410,16 @@ PacketTransformerHandle PacketTransformerManager::Sequence(DecisionNode left, right_applied_to_left_modifications; for (const auto& [value, branch] : left.default_branch_by_field_modification) { - absl::btree_map right_at_value_with_sequence = - CombineModifyBranches( - {}, GetMapAtValue(right, value), - /*combiner=*/ - [this](PacketTransformerHandle left, - PacketTransformerHandle right) { - return Sequence(left, right); - }, - /*default_value=*/branch); - right_applied_to_left_modifications = CombineModifyBranches( + absl::btree_map right_at_value_with_sequence; + CombineModifyBranches( + right_at_value_with_sequence, GetMapAtValue(right, value), + /*combiner=*/ + [this](PacketTransformerHandle left, PacketTransformerHandle right) { + return Sequence(left, right); + }, + /*default_value=*/branch); + + CombineModifyBranches( right_applied_to_left_modifications, right_at_value_with_sequence, /*combiner=*/ [this](PacketTransformerHandle left, PacketTransformerHandle right) { @@ -431,74 +428,59 @@ PacketTransformerHandle PacketTransformerManager::Sequence(DecisionNode left, /*default_value=*/Deny()); } - result_node.default_branch_by_field_modification = CombineModifyBranches( + CombineModifyBranches( + result_node.default_branch_by_field_modification, right_applied_to_left_modifications, - CombineModifyBranches( - {}, right.default_branch_by_field_modification, - [this](PacketTransformerHandle left, PacketTransformerHandle right) { - return Sequence(left, right); - }, - /*default_value=*/left.default_branch), [this](PacketTransformerHandle left, PacketTransformerHandle right) { return Union(left, right); }, - /*default_value=*/Deny()); + Deny()); - // Collect every value mapped in each node. - absl::flat_hash_set all_possible_values; - all_possible_values.reserve( - left.modify_branch_by_field_match.size() + - right.modify_branch_by_field_match.size() + - left.default_branch_by_field_modification.size() + - right.default_branch_by_field_modification.size() + - right_applied_to_left_modifications.size()); - - absl::c_transform( - left.modify_branch_by_field_match, - std::inserter(all_possible_values, all_possible_values.end()), - [](auto pair) { return pair.first; }); - absl::c_transform( - right.modify_branch_by_field_match, - std::inserter(all_possible_values, all_possible_values.end()), - [](auto pair) { return pair.first; }); - absl::c_transform( - left.default_branch_by_field_modification, - std::inserter(all_possible_values, all_possible_values.end()), - [](auto pair) { return pair.first; }); - absl::c_transform( + CombineModifyBranches( + result_node.default_branch_by_field_modification, right.default_branch_by_field_modification, - std::inserter(all_possible_values, all_possible_values.end()), - [](auto pair) { return pair.first; }); - absl::c_transform( - right_applied_to_left_modifications, - std::inserter(all_possible_values, all_possible_values.end()), - [](auto pair) { return pair.first; }); + [this](PacketTransformerHandle left, PacketTransformerHandle right) { + return Sequence(left, right); + }, + left.default_branch); // For every value in mapped in each node, construct the proper new branch. - for (int value : all_possible_values) { - auto left_map_at_value = GetMapAtValue(left, value); - // An empty map is equivalent to a map with a single entry of - // , but the latter is not always canonical. However, an empty - // map won't work correctly for the merges below (an in fact, the whole - // for-loop would be skipped), so we expand it here if necessary. - if (left_map_at_value.empty()) left_map_at_value[value] = Deny(); - - for (const auto& [left_value, left_spp] : left_map_at_value) { - result_node.modify_branch_by_field_match[value] = CombineModifyBranches( - result_node.modify_branch_by_field_match[value], - CombineModifyBranches( - {}, GetMapAtValue(right, left_value), - /*combiner=*/ - [this](PacketTransformerHandle left, - PacketTransformerHandle right) { - return Sequence(left, right); - }, - /*default_value=*/left_spp), - /*combiner=*/ + for (auto const& [match_value, left_modify_map] : + left.modify_branch_by_field_match) { + auto& result_modify_map = result_node.modify_branch_by_field_match[match_value]; + for (const auto& [left_value, left_spp] : left_modify_map) { + absl::btree_map right_at_value_with_sequence; + CombineModifyBranches( + right_at_value_with_sequence, GetMapAtValue(right, left_value), + [this](PacketTransformerHandle left, PacketTransformerHandle right) { + return Sequence(left, right); + }, + left_spp); + CombineModifyBranches( + result_modify_map, right_at_value_with_sequence, [this](PacketTransformerHandle left, PacketTransformerHandle right) { return Union(left, right); }, - /*default_value=*/Deny()); + Deny()); + } + } + + for (auto const& [match_value, right_modify_map] : + right.modify_branch_by_field_match) { + auto& result_modify_map = result_node.modify_branch_by_field_match[match_value]; + auto left_map_at_value = GetMapAtValue(left, match_value); + for (const auto& [right_value, right_branch] : right_modify_map) { + if (auto it = left_map_at_value.find(right_value); + it != left_map_at_value.end()) { + auto& result_branch = result_modify_map[right_value]; + result_branch = + Union(result_branch, Sequence(it->second, right_branch)); + } else { + auto& result_branch = result_modify_map[right_value]; + result_branch = Union( + result_branch, + Sequence(left.default_branch, right_branch)); + } } } @@ -545,50 +527,26 @@ PacketTransformerHandle PacketTransformerManager::Union(DecisionNode left, DCHECK(left.field == right.field); DecisionNode result_node{ .field = left.field, - .default_branch_by_field_modification = CombineModifyBranches( - left.default_branch_by_field_modification, - right.default_branch_by_field_modification, - /*combiner=*/ - [this](PacketTransformerHandle left, PacketTransformerHandle right) { - return Union(left, right); - }, - /*default_value=*/Deny()), + .default_branch_by_field_modification = left.default_branch_by_field_modification, .default_branch = Union(left.default_branch, right.default_branch), }; - // Collect every value in mapped in each node. - absl::flat_hash_set all_possible_values; - all_possible_values.reserve( - left.modify_branch_by_field_match.size() + - right.modify_branch_by_field_match.size() + - left.default_branch_by_field_modification.size() + - right.default_branch_by_field_modification.size()); - - absl::c_transform( - left.modify_branch_by_field_match, - std::inserter(all_possible_values, all_possible_values.end()), - [](auto pair) { return pair.first; }); - absl::c_transform( - right.modify_branch_by_field_match, - std::inserter(all_possible_values, all_possible_values.end()), - [](auto pair) { return pair.first; }); - absl::c_transform( - left.default_branch_by_field_modification, - std::inserter(all_possible_values, all_possible_values.end()), - [](auto pair) { return pair.first; }); - absl::c_transform( + CombineModifyBranches( + result_node.default_branch_by_field_modification, right.default_branch_by_field_modification, - std::inserter(all_possible_values, all_possible_values.end()), - [](auto pair) { return pair.first; }); + /*combiner=*/ + [this](PacketTransformerHandle left, PacketTransformerHandle right) { + return Union(left, right); + }, + /*default_value=*/Deny()); - // For every value in mapped in each node, construct the proper new branch. - // TODO(dilo): Would like to use absl::bind_front here instead of a lambda: - // absl::bind_front( - // &PacketTransformerManager::Union, this), - for (int value : all_possible_values) { - result_node.modify_branch_by_field_match[value] = CombineModifyBranches( - GetMapAtValue(left, value), GetMapAtValue(right, value), + result_node.modify_branch_by_field_match = left.modify_branch_by_field_match; + for (auto const& [match_value, right_modify_map] : + right.modify_branch_by_field_match) { + auto& result_modify_map = + result_node.modify_branch_by_field_match[match_value]; + CombineModifyBranches( + result_modify_map, right_modify_map, /*combiner=*/ [this](PacketTransformerHandle left, PacketTransformerHandle right) { return Union(left, right); @@ -596,6 +554,7 @@ PacketTransformerHandle PacketTransformerManager::Union(DecisionNode left, /*default_value=*/Deny()); } + return NodeToTransformer(std::move(result_node)); } @@ -706,7 +665,7 @@ std::string PacketTransformerManager::ToString( while (!work_list.empty()) { PacketTransformerHandle transformer = work_list.front(); work_list.pop(); - absl::StrAppend(&result, transformer); + absl::StrAppend(&result, transformer.ToString()); if (IsAccept(transformer) || IsDeny(transformer)) continue; @@ -749,7 +708,7 @@ std::string PacketTransformerManager::ToDot( const absl::btree_map& map) { for (const auto& [new_value, branch] : map) { absl::StrAppendFormat( - &result, " %d -> %d [label=\"%s==%s; %s:=%d\"]\n", node_index, + &result, " %d -> %d [label=\"%s==%s;\\n%s:=%d\"]\n", node_index, branch.node_index_, field, old_value, field, new_value); if (IsAccept(branch) || IsDeny(branch)) continue; bool new_branch = visited.insert(branch).second; @@ -785,6 +744,7 @@ std::string PacketTransformerManager::ToDot( return result; } + absl::Status PacketTransformerManager::CheckInternalInvariants() const { // Invariant: Proper and sentinel node indices are disjoint. RET_CHECK(nodes_.size() <= SentinelNodeIndex::kMinSentinel);