Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
212 changes: 86 additions & 126 deletions netkat/packet_transformer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -81,15 +81,17 @@ PacketTransformerManager::GetNodeOrDie(
// probably going to cause terrible performance, and needs to be revisited.
absl::btree_map<int, PacketTransformerHandle>
PacketTransformerManager::GetMapAtValue(const DecisionNode& node, int value) {
if (node.modify_branch_by_field_match.contains(value))
return node.modify_branch_by_field_match.at(value);
if (auto it = node.modify_branch_by_field_match.find(value);
it != node.modify_branch_by_field_match.end()) {
return it->second;
}

absl::btree_map<int, PacketTransformerHandle> result =
node.default_branch_by_field_modification;
if (result.contains(value) || IsDeny(node.default_branch)) return result;

// Otherwise, add a mapping from `value` to the default branch, then return.
result[value] = node.default_branch;
if (result.find(value) == result.end() && !IsDeny(node.default_branch)) {
// Otherwise, add a mapping from `value` to the default branch, then return.
result[value] = node.default_branch;
}
return result;
}

Expand Down Expand Up @@ -351,26 +353,21 @@ PacketTransformerHandle PacketTransformerManager::Modification(
}

namespace {
absl::btree_map<int, PacketTransformerHandle> CombineModifyBranches(
const absl::btree_map<int, PacketTransformerHandle>& left,
void CombineModifyBranches(
absl::btree_map<int, PacketTransformerHandle>& result,
const absl::btree_map<int, PacketTransformerHandle>& right,
absl::AnyInvocable<PacketTransformerHandle(PacketTransformerHandle,
PacketTransformerHandle)>
combiner,
PacketTransformerHandle default_value) {
absl::btree_map<int, PacketTransformerHandle> result;
for (const auto& [value, branch] : left) {
if (right.contains(value)) {
result[value] = combiner(branch, right.at(value));
} else {
result[value] = combiner(branch, default_value);
}
}
for (const auto& [value, branch] : right) {
if (!result.contains(value))
auto it = result.find(value);
if (it != result.end()) {
it->second = combiner(it->second, branch);
} else {
result[value] = combiner(default_value, branch);
}
}
return result;
}

} // namespace
Expand Down Expand Up @@ -413,16 +410,16 @@ PacketTransformerHandle PacketTransformerManager::Sequence(DecisionNode left,
right_applied_to_left_modifications;
for (const auto& [value, branch] :
left.default_branch_by_field_modification) {
absl::btree_map<int, PacketTransformerHandle> right_at_value_with_sequence =
CombineModifyBranches(
{}, GetMapAtValue(right, value),
/*combiner=*/
[this](PacketTransformerHandle left,
PacketTransformerHandle right) {
return Sequence(left, right);
},
/*default_value=*/branch);
right_applied_to_left_modifications = CombineModifyBranches(
absl::btree_map<int, PacketTransformerHandle> right_at_value_with_sequence;
CombineModifyBranches(
right_at_value_with_sequence, GetMapAtValue(right, value),
/*combiner=*/
[this](PacketTransformerHandle left, PacketTransformerHandle right) {
return Sequence(left, right);
},
/*default_value=*/branch);

CombineModifyBranches(
right_applied_to_left_modifications, right_at_value_with_sequence,
/*combiner=*/
[this](PacketTransformerHandle left, PacketTransformerHandle right) {
Expand All @@ -431,74 +428,59 @@ PacketTransformerHandle PacketTransformerManager::Sequence(DecisionNode left,
/*default_value=*/Deny());
}

result_node.default_branch_by_field_modification = CombineModifyBranches(
CombineModifyBranches(
result_node.default_branch_by_field_modification,
right_applied_to_left_modifications,
CombineModifyBranches(
{}, right.default_branch_by_field_modification,
[this](PacketTransformerHandle left, PacketTransformerHandle right) {
return Sequence(left, right);
},
/*default_value=*/left.default_branch),
[this](PacketTransformerHandle left, PacketTransformerHandle right) {
return Union(left, right);
},
/*default_value=*/Deny());
Deny());

// Collect every value mapped in each node.
absl::flat_hash_set<int> all_possible_values;
all_possible_values.reserve(
left.modify_branch_by_field_match.size() +
right.modify_branch_by_field_match.size() +
left.default_branch_by_field_modification.size() +
right.default_branch_by_field_modification.size() +
right_applied_to_left_modifications.size());

absl::c_transform(
left.modify_branch_by_field_match,
std::inserter(all_possible_values, all_possible_values.end()),
[](auto pair) { return pair.first; });
absl::c_transform(
right.modify_branch_by_field_match,
std::inserter(all_possible_values, all_possible_values.end()),
[](auto pair) { return pair.first; });
absl::c_transform(
left.default_branch_by_field_modification,
std::inserter(all_possible_values, all_possible_values.end()),
[](auto pair) { return pair.first; });
absl::c_transform(
CombineModifyBranches(
result_node.default_branch_by_field_modification,
right.default_branch_by_field_modification,
std::inserter(all_possible_values, all_possible_values.end()),
[](auto pair) { return pair.first; });
absl::c_transform(
right_applied_to_left_modifications,
std::inserter(all_possible_values, all_possible_values.end()),
[](auto pair) { return pair.first; });
[this](PacketTransformerHandle left, PacketTransformerHandle right) {
return Sequence(left, right);
},
left.default_branch);

// For every value in mapped in each node, construct the proper new branch.
for (int value : all_possible_values) {
auto left_map_at_value = GetMapAtValue(left, value);
// An empty map is equivalent to a map with a single entry of
// <value, Deny>, but the latter is not always canonical. However, an empty
// map won't work correctly for the merges below (an in fact, the whole
// for-loop would be skipped), so we expand it here if necessary.
if (left_map_at_value.empty()) left_map_at_value[value] = Deny();

for (const auto& [left_value, left_spp] : left_map_at_value) {
result_node.modify_branch_by_field_match[value] = CombineModifyBranches(
result_node.modify_branch_by_field_match[value],
CombineModifyBranches(
{}, GetMapAtValue(right, left_value),
/*combiner=*/
[this](PacketTransformerHandle left,
PacketTransformerHandle right) {
return Sequence(left, right);
},
/*default_value=*/left_spp),
/*combiner=*/
for (auto const& [match_value, left_modify_map] :
left.modify_branch_by_field_match) {
auto& result_modify_map = result_node.modify_branch_by_field_match[match_value];
for (const auto& [left_value, left_spp] : left_modify_map) {
absl::btree_map<int, PacketTransformerHandle> right_at_value_with_sequence;
CombineModifyBranches(
right_at_value_with_sequence, GetMapAtValue(right, left_value),
[this](PacketTransformerHandle left, PacketTransformerHandle right) {
return Sequence(left, right);
},
left_spp);
CombineModifyBranches(
result_modify_map, right_at_value_with_sequence,
[this](PacketTransformerHandle left, PacketTransformerHandle right) {
return Union(left, right);
},
/*default_value=*/Deny());
Deny());
}
}

for (auto const& [match_value, right_modify_map] :
right.modify_branch_by_field_match) {
auto& result_modify_map = result_node.modify_branch_by_field_match[match_value];
auto left_map_at_value = GetMapAtValue(left, match_value);
for (const auto& [right_value, right_branch] : right_modify_map) {
if (auto it = left_map_at_value.find(right_value);
it != left_map_at_value.end()) {
auto& result_branch = result_modify_map[right_value];
result_branch =
Union(result_branch, Sequence(it->second, right_branch));
} else {
auto& result_branch = result_modify_map[right_value];
result_branch = Union(
result_branch,
Sequence(left.default_branch, right_branch));
}
}
}

Expand Down Expand Up @@ -545,57 +527,34 @@ PacketTransformerHandle PacketTransformerManager::Union(DecisionNode left,
DCHECK(left.field == right.field);
DecisionNode result_node{
.field = left.field,
.default_branch_by_field_modification = CombineModifyBranches(
left.default_branch_by_field_modification,
right.default_branch_by_field_modification,
/*combiner=*/
[this](PacketTransformerHandle left, PacketTransformerHandle right) {
return Union(left, right);
},
/*default_value=*/Deny()),
.default_branch_by_field_modification = left.default_branch_by_field_modification,
.default_branch = Union(left.default_branch, right.default_branch),
};

// Collect every value in mapped in each node.
absl::flat_hash_set<int> all_possible_values;
all_possible_values.reserve(
left.modify_branch_by_field_match.size() +
right.modify_branch_by_field_match.size() +
left.default_branch_by_field_modification.size() +
right.default_branch_by_field_modification.size());

absl::c_transform(
left.modify_branch_by_field_match,
std::inserter(all_possible_values, all_possible_values.end()),
[](auto pair) { return pair.first; });
absl::c_transform(
right.modify_branch_by_field_match,
std::inserter(all_possible_values, all_possible_values.end()),
[](auto pair) { return pair.first; });
absl::c_transform(
left.default_branch_by_field_modification,
std::inserter(all_possible_values, all_possible_values.end()),
[](auto pair) { return pair.first; });
absl::c_transform(
CombineModifyBranches(
result_node.default_branch_by_field_modification,
right.default_branch_by_field_modification,
std::inserter(all_possible_values, all_possible_values.end()),
[](auto pair) { return pair.first; });
/*combiner=*/
[this](PacketTransformerHandle left, PacketTransformerHandle right) {
return Union(left, right);
},
/*default_value=*/Deny());

// For every value in mapped in each node, construct the proper new branch.
// TODO(dilo): Would like to use absl::bind_front here instead of a lambda:
// absl::bind_front<PacketTransformerHandle(PacketTransformerHandle,
// PacketTransformerHandle)>(
// &PacketTransformerManager::Union, this),
for (int value : all_possible_values) {
result_node.modify_branch_by_field_match[value] = CombineModifyBranches(
GetMapAtValue(left, value), GetMapAtValue(right, value),
result_node.modify_branch_by_field_match = left.modify_branch_by_field_match;
for (auto const& [match_value, right_modify_map] :
right.modify_branch_by_field_match) {
auto& result_modify_map =
result_node.modify_branch_by_field_match[match_value];
CombineModifyBranches(
result_modify_map, right_modify_map,
/*combiner=*/
[this](PacketTransformerHandle left, PacketTransformerHandle right) {
return Union(left, right);
},
/*default_value=*/Deny());
}


return NodeToTransformer(std::move(result_node));
}

Expand Down Expand Up @@ -706,7 +665,7 @@ std::string PacketTransformerManager::ToString(
while (!work_list.empty()) {
PacketTransformerHandle transformer = work_list.front();
work_list.pop();
absl::StrAppend(&result, transformer);
absl::StrAppend(&result, transformer.ToString());

if (IsAccept(transformer) || IsDeny(transformer)) continue;

Expand Down Expand Up @@ -749,7 +708,7 @@ std::string PacketTransformerManager::ToDot(
const absl::btree_map<int, PacketTransformerHandle>& map) {
for (const auto& [new_value, branch] : map) {
absl::StrAppendFormat(
&result, " %d -> %d [label=\"%s==%s; %s:=%d\"]\n", node_index,
&result, " %d -> %d [label=\"%s==%s;\\n%s:=%d\"]\n", node_index,
branch.node_index_, field, old_value, field, new_value);
if (IsAccept(branch) || IsDeny(branch)) continue;
bool new_branch = visited.insert(branch).second;
Expand Down Expand Up @@ -785,6 +744,7 @@ std::string PacketTransformerManager::ToDot(
return result;
}


absl::Status PacketTransformerManager::CheckInternalInvariants() const {
// Invariant: Proper and sentinel node indices are disjoint.
RET_CHECK(nodes_.size() <= SentinelNodeIndex::kMinSentinel);
Expand Down
Loading