Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions netkat/evaluator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,15 @@ absl::flat_hash_set<Packet> Evaluate(const PolicyProto& policy,
} while (last_size != result.size());
return result;
}
case PolicyProto::kDifferenceOp: {
absl::flat_hash_set<Packet> result =
Evaluate(policy.difference_op().left(), packet);
for (const Packet& packet :
Evaluate(policy.difference_op().right(), packet)) {
if (result.contains(packet)) result.erase(packet);
}
return result;
}
case PolicyProto::POLICY_NOT_SET:
// Unset policy is treated as Deny.
return {};
Expand Down
13 changes: 13 additions & 0 deletions netkat/evaluator_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,19 @@ void UnionCombines(Packet packet, PolicyProto left, PolicyProto right) {
}
FUZZ_TEST(EvaluatePolicyProtoTest, UnionCombines);

void DifferenceRemoves(Packet packet, PolicyProto left, PolicyProto right) {
absl::flat_hash_set<Packet> expected_packets = Evaluate(left, packet);
for (const Packet& packet : Evaluate(right, packet)) {
if (expected_packets.contains(packet)) {
expected_packets.erase(packet);
}
}

EXPECT_THAT(Evaluate(DifferenceProto(left, right), packet),
ContainerEq(expected_packets));
}
FUZZ_TEST(EvaluatePolicyProtoTest, DifferenceRemoves);

void SequenceSequences(Packet packet, PolicyProto left, PolicyProto right) {
absl::flat_hash_set<Packet> expected_packets =
Evaluate(right, Evaluate(left, packet));
Expand Down
15 changes: 15 additions & 0 deletions netkat/frontend.cc
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,16 @@ absl::Status RecursivelyCheckIsValid(const PolicyProto& policy_proto) {
RecursivelyCheckIsValid(policy_proto.iterate_op().iterable()))
<< "PolicyProto::Iterate::policy is invalid: ";
return absl::OkStatus();
case PolicyProto::kDifferenceOp:
RETURN_IF_ERROR(
RecursivelyCheckIsValid(policy_proto.difference_op().left()))
.SetPrepend()
<< "PolicyProto::DifferenceOp::left is invalid: ";
RETURN_IF_ERROR(
RecursivelyCheckIsValid(policy_proto.difference_op().right()))
.SetPrepend()
<< "PolicyProto::DifferenceOp::right is invalid: ";
return absl::OkStatus();
case PolicyProto::POLICY_NOT_SET:
return absl::InvalidArgumentError("Unset Policy case is invalid");
}
Expand Down Expand Up @@ -179,6 +189,11 @@ Policy Filter(Predicate predicate) {
return Policy(FilterProto(std::move(predicate).ToProto()));
}

Policy Difference(Policy left, Policy right) {
return Policy(
DifferenceProto(std::move(left).ToProto(), std::move(right).ToProto()));
}

Policy Policy::Accept() { return Filter(Predicate::True()); }

Policy Policy::Deny() { return Filter(Predicate::False()); }
Expand Down
3 changes: 2 additions & 1 deletion netkat/frontend.h
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,7 @@ class Policy {
friend Policy Sequence(std::vector<Policy>);
friend Policy Union(std::vector<Policy>);
friend Policy Iterate(Policy);
friend Policy Difference(Policy, Policy);
friend Policy Record();

// Policies that conceptually represent a program that should accept or
Expand Down Expand Up @@ -254,7 +255,7 @@ Policy Modify(absl::string_view field, int new_value);
Policy Sequence(std::vector<Policy> policies);

// Allows callers to Sequence policies without wrapping them in a list. Prefer
// this overload when reasonble. For example, instead of
// this overload when reasonable. For example, instead of
//
// Sequence({p0, p1, p2, p3})
//
Expand Down
11 changes: 11 additions & 0 deletions netkat/frontend_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,9 @@ void ExpectFromProtoToFailWithInvalidPolicyProto(PolicyProto policy_proto) {
case PolicyProto::kIterateOp:
policy_proto.mutable_iterate_op()->clear_iterable();
break;
case PolicyProto::kDifferenceOp:
policy_proto.mutable_difference_op()->clear_left();
break;
// Unset policy is invalid.
case PolicyProto::POLICY_NOT_SET:
break;
Expand Down Expand Up @@ -209,6 +212,14 @@ void IterateToProtoIsCorrect(Policy policy) {
FUZZ_TEST(FrontEndTest, IterateToProtoIsCorrect)
.WithDomains(/*policy=*/AtomicDupFreePolicyDomain());

void DifferenceToProtoIsCorrect(Policy left, Policy right) {
EXPECT_THAT(Difference(left, right).ToProto(),
EqualsProto(DifferenceProto(left.ToProto(), right.ToProto())));
}
FUZZ_TEST(FrontEndTest, DifferenceToProtoIsCorrect)
.WithDomains(/*policy=*/AtomicDupFreePolicyDomain(),
/*policy=*/AtomicDupFreePolicyDomain());

TEST(FrontEndTest, SequenceWithNoElementsIsAccept) {
EXPECT_THAT(Sequence().ToProto(), EqualsProto(AcceptProto()));
}
Expand Down
7 changes: 7 additions & 0 deletions netkat/netkat.proto
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@ message PolicyProto {
Sequence sequence_op = 4;
Union union_op = 5;
Iterate iterate_op = 6;
Difference difference_op = 7;
}

// Sets the field to the given value.
Expand Down Expand Up @@ -148,6 +149,12 @@ message PolicyProto {
PolicyProto iterable = 1;
}

// Represents the difference of two policies, i.e. a - b.
message Difference {
PolicyProto left = 1;
PolicyProto right = 2;
}

// Records the packet, at the given point, into the history. Referred to as
// "dup" in the literature.
message Record {}
Expand Down
11 changes: 11 additions & 0 deletions netkat/netkat_proto_constructors.cc
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,13 @@ PolicyProto IterateProto(PolicyProto iterable) {
return policy;
}

PolicyProto DifferenceProto(PolicyProto left, PolicyProto right) {
PolicyProto policy;
*policy.mutable_difference_op()->mutable_left() = std::move(left);
*policy.mutable_difference_op()->mutable_right() = std::move(right);
return policy;
}

// -- Derived Policy constructors ----------------------------------------------

PolicyProto DenyProto() { return FilterProto(FalseProto()); }
Expand Down Expand Up @@ -165,6 +172,10 @@ std::string AsShorthandString(PolicyProto policy) {
case PolicyProto::kIterateOp:
return absl::StrFormat("(%s)*",
AsShorthandString(policy.iterate_op().iterable()));
case PolicyProto::kDifferenceOp:
return absl::StrFormat("(%s - %s)",
AsShorthandString(policy.difference_op().left()),
AsShorthandString(policy.difference_op().right()));
case PolicyProto::POLICY_NOT_SET:
return "deny";
}
Expand Down
2 changes: 2 additions & 0 deletions netkat/netkat_proto_constructors.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ PolicyProto RecordProto();
PolicyProto SequenceProto(PolicyProto left, PolicyProto right);
PolicyProto UnionProto(PolicyProto left, PolicyProto right);
PolicyProto IterateProto(PolicyProto iterable);
PolicyProto DifferenceProto(PolicyProto left, PolicyProto right);

// -- Derived Policy constructors ----------------------------------------------

Expand All @@ -63,6 +64,7 @@ PolicyProto AcceptProto();
// Policy Sequence -> ';'
// Policy Or -> '+'
// Iterate -> '*'
// Difference -> '-'
// Record -> 'record'
// Match -> '@field==value'
// Modify -> '@field:=value'
Expand Down
9 changes: 9 additions & 0 deletions netkat/netkat_proto_constructors_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,15 @@ void IterateProtoReturnsIterate(PolicyProto iterable) {
}
FUZZ_TEST(PolicyProtoTest, IterateProtoReturnsIterate);

void DifferenceProtoReturnsDifference(PolicyProto left, PolicyProto right) {
PolicyProto expected_policy;
*expected_policy.mutable_difference_op()->mutable_left() = left;
*expected_policy.mutable_difference_op()->mutable_right() = right;

EXPECT_THAT(DifferenceProto(left, right), EqualsProto(expected_policy));
}
FUZZ_TEST(PolicyProtoTest, DifferenceProtoReturnsDifference);

// -- Derived Policy tests -----------------------------------------------------

TEST(PolicyProtoTest, DenyProtoFiltersOnFalse) {
Expand Down
9 changes: 7 additions & 2 deletions netkat/netkat_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,11 @@ namespace netkat {
namespace {

// Sanity fuzz test to show that the FuzzTest library works.
void DummmyFuzzTest(PredicateProto pred, PolicyProto pol) {
void DummyFuzzTest(PredicateProto pred, PolicyProto pol) {
LOG_EVERY_N_SEC(INFO, 1) << "pred = " << pred;
LOG_EVERY_N_SEC(INFO, 1) << "pol = " << pol;
}
FUZZ_TEST(NetkatProtoTest, DummmyFuzzTest);
FUZZ_TEST(NetkatProtoTest, DummyFuzzTest);

// Ensures that the protobuf C++ compiler does not add underscores to the
// generated code for sub messages and oneof fields of `PredicateProto`.
Expand Down Expand Up @@ -104,6 +104,11 @@ TEST(NetkatProtoTest, PolicyOneOfFieldNamesDontRequireUnderscores) {
LOG(INFO) << "iterate: " << iter;
break;
}
case PolicyProto::kDifferenceOp: {
const PolicyProto::Difference& difference_op = policy.difference_op();
LOG(INFO) << "difference: " << difference_op;
break;
}
case PolicyProto::POLICY_NOT_SET:
break;
}
Expand Down
116 changes: 116 additions & 0 deletions netkat/packet_transformer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -288,6 +288,9 @@ PacketTransformerHandle PacketTransformerManager::Compile(
Compile(policy.union_op().right()));
case PolicyProto::kIterateOp:
return Iterate(Compile(policy.iterate_op().iterable()));
case PolicyProto::kDifferenceOp:
return Difference(Compile(policy.difference_op().left()),
Compile(policy.difference_op().right()));
case PolicyProto::POLICY_NOT_SET:
// By convention, uninitialized policies must be treated like the Deny
// policy.
Expand Down Expand Up @@ -623,6 +626,119 @@ PacketTransformerHandle PacketTransformerManager::Union(
return Union(GetNodeOrDie(left), GetNodeOrDie(right));
}

PacketTransformerHandle PacketTransformerManager::Difference(
DecisionNode left, DecisionNode right) {
// left.field > right.field: Expand the left node, reducing to the inductive
// case.
if (left.field > right.field) {
PacketFieldHandle first_field = right.field;
return Difference(
DecisionNode{
.field = first_field,
.default_branch = NodeToTransformer(std::move(left)),
},
std::move(right));
}

// left.field < right.field: Expand the right node, reducing to the inductive
// case.
if (left.field < right.field) {
PacketFieldHandle first_field = left.field;
return Difference(std::move(left),
DecisionNode{
.field = first_field,
.default_branch = NodeToTransformer(std::move(right)),
});
}

// left.field == right.field: branch on shared field.
DCHECK(left.field == right.field);
DecisionNode result_node{
.field = left.field,
.default_branch_by_field_modification = CombineModifyBranches(
left.default_branch_by_field_modification,
right.default_branch_by_field_modification,
/*combiner=*/
[this](PacketTransformerHandle left, PacketTransformerHandle right) {
return Difference(left, right);
},
/*default_value=*/Deny()),
.default_branch = Difference(left.default_branch, right.default_branch),
};

// Collect every value in mapped in each node.
absl::flat_hash_set<int> all_possible_values;
all_possible_values.reserve(
left.modify_branch_by_field_match.size() +
right.modify_branch_by_field_match.size() +
left.default_branch_by_field_modification.size() +
right.default_branch_by_field_modification.size());

absl::c_transform(
left.modify_branch_by_field_match,
std::inserter(all_possible_values, all_possible_values.end()),
[](auto pair) { return pair.first; });
absl::c_transform(
right.modify_branch_by_field_match,
std::inserter(all_possible_values, all_possible_values.end()),
[](auto pair) { return pair.first; });
absl::c_transform(
left.default_branch_by_field_modification,
std::inserter(all_possible_values, all_possible_values.end()),
[](auto pair) { return pair.first; });
absl::c_transform(
right.default_branch_by_field_modification,
std::inserter(all_possible_values, all_possible_values.end()),
[](auto pair) { return pair.first; });

// For every value in mapped in each node, construct the proper new branch.
// TODO(dilo): Would like to use absl::bind_front here instead of a lambda:
// absl::bind_front<PacketTransformerHandle(PacketTransformerHandle,
// PacketTransformerHandle)>(
// &PacketTransformerManager::Difference, this),
for (int value : all_possible_values) {
result_node.modify_branch_by_field_match[value] = CombineModifyBranches(
GetMapAtValue(left, value), GetMapAtValue(right, value),
/*combiner=*/
[this](PacketTransformerHandle left, PacketTransformerHandle right) {
return Difference(left, right);
},
/*default_value=*/Deny());
}

return NodeToTransformer(std::move(result_node));
}

PacketTransformerHandle PacketTransformerManager::Difference(
PacketTransformerHandle left, PacketTransformerHandle right) {
// Base cases.
if (left == right) return Deny();
if (IsDeny(left)) return Deny();
if (IsDeny(right)) return left;

// If either node is accept, then expand it before merging.
if (IsAccept(left)) {
const DecisionNode& right_node = GetNodeOrDie(right);
return Difference(
DecisionNode{
.field = right_node.field,
.default_branch = Accept(),
},
right_node);
}

if (IsAccept(right)) {
const DecisionNode& left_node = GetNodeOrDie(left);
return Difference(left_node, DecisionNode{
.field = left_node.field,
.default_branch = Accept(),
});
}

// If neither node is accept or deny, then difference the nodes directly.
return Difference(GetNodeOrDie(left), GetNodeOrDie(right));
}

PacketTransformerHandle PacketTransformerManager::Iterate(
PacketTransformerHandle iterable) {
PacketTransformerHandle previous_approximation = Accept();
Expand Down
3 changes: 2 additions & 1 deletion netkat/packet_transformer.h
Original file line number Diff line number Diff line change
Expand Up @@ -252,7 +252,7 @@ class PacketTransformerManager {
// Returns the transformer that describes the packets produced by the `left`
// transformer, but not the `right` transformer.
PacketTransformerHandle Difference(PacketTransformerHandle left,
PacketTransformerHandle right) = delete;
PacketTransformerHandle right);

// Returns the transformer that describes the packets produced by the `left`
// transformer or the `right` transformer, but not both.
Expand Down Expand Up @@ -385,6 +385,7 @@ class PacketTransformerManager {
// copies of the nodes?
PacketTransformerHandle Union(DecisionNode left, DecisionNode right);
PacketTransformerHandle Sequence(DecisionNode left, DecisionNode right);
PacketTransformerHandle Difference(DecisionNode left, DecisionNode right);

// Internal helper function to get a map of possible modification values to
// branches for a given input value at `node`.
Expand Down
Loading