diff --git a/netkat/evaluator.cc b/netkat/evaluator.cc index 438eddc..a42e053 100644 --- a/netkat/evaluator.cc +++ b/netkat/evaluator.cc @@ -98,6 +98,15 @@ absl::flat_hash_set Evaluate(const PolicyProto& policy, } while (last_size != result.size()); return result; } + case PolicyProto::kDifferenceOp: { + absl::flat_hash_set result = + Evaluate(policy.difference_op().left(), packet); + for (const Packet& packet : + Evaluate(policy.difference_op().right(), packet)) { + if (result.contains(packet)) result.erase(packet); + } + return result; + } case PolicyProto::POLICY_NOT_SET: // Unset policy is treated as Deny. return {}; diff --git a/netkat/evaluator_test.cc b/netkat/evaluator_test.cc index b00fd65..b341bb0 100644 --- a/netkat/evaluator_test.cc +++ b/netkat/evaluator_test.cc @@ -260,6 +260,19 @@ void UnionCombines(Packet packet, PolicyProto left, PolicyProto right) { } FUZZ_TEST(EvaluatePolicyProtoTest, UnionCombines); +void DifferenceRemoves(Packet packet, PolicyProto left, PolicyProto right) { + absl::flat_hash_set expected_packets = Evaluate(left, packet); + for (const Packet& packet : Evaluate(right, packet)) { + if (expected_packets.contains(packet)) { + expected_packets.erase(packet); + } + } + + EXPECT_THAT(Evaluate(DifferenceProto(left, right), packet), + ContainerEq(expected_packets)); +} +FUZZ_TEST(EvaluatePolicyProtoTest, DifferenceRemoves); + void SequenceSequences(Packet packet, PolicyProto left, PolicyProto right) { absl::flat_hash_set expected_packets = Evaluate(right, Evaluate(left, packet)); diff --git a/netkat/frontend.cc b/netkat/frontend.cc index dd2b50e..e51ad91 100644 --- a/netkat/frontend.cc +++ b/netkat/frontend.cc @@ -126,6 +126,16 @@ absl::Status RecursivelyCheckIsValid(const PolicyProto& policy_proto) { RecursivelyCheckIsValid(policy_proto.iterate_op().iterable())) << "PolicyProto::Iterate::policy is invalid: "; return absl::OkStatus(); + case PolicyProto::kDifferenceOp: + RETURN_IF_ERROR( + RecursivelyCheckIsValid(policy_proto.difference_op().left())) + .SetPrepend() + << "PolicyProto::DifferenceOp::left is invalid: "; + RETURN_IF_ERROR( + RecursivelyCheckIsValid(policy_proto.difference_op().right())) + .SetPrepend() + << "PolicyProto::DifferenceOp::right is invalid: "; + return absl::OkStatus(); case PolicyProto::POLICY_NOT_SET: return absl::InvalidArgumentError("Unset Policy case is invalid"); } @@ -179,6 +189,11 @@ Policy Filter(Predicate predicate) { return Policy(FilterProto(std::move(predicate).ToProto())); } +Policy Difference(Policy left, Policy right) { + return Policy( + DifferenceProto(std::move(left).ToProto(), std::move(right).ToProto())); +} + Policy Policy::Accept() { return Filter(Predicate::True()); } Policy Policy::Deny() { return Filter(Predicate::False()); } diff --git a/netkat/frontend.h b/netkat/frontend.h index 7657536..87ebcd9 100644 --- a/netkat/frontend.h +++ b/netkat/frontend.h @@ -201,6 +201,7 @@ class Policy { friend Policy Sequence(std::vector); friend Policy Union(std::vector); friend Policy Iterate(Policy); + friend Policy Difference(Policy, Policy); friend Policy Record(); // Policies that conceptually represent a program that should accept or @@ -254,7 +255,7 @@ Policy Modify(absl::string_view field, int new_value); Policy Sequence(std::vector policies); // Allows callers to Sequence policies without wrapping them in a list. Prefer -// this overload when reasonble. For example, instead of +// this overload when reasonable. For example, instead of // // Sequence({p0, p1, p2, p3}) // diff --git a/netkat/frontend_test.cc b/netkat/frontend_test.cc index 566808d..4f320f6 100644 --- a/netkat/frontend_test.cc +++ b/netkat/frontend_test.cc @@ -170,6 +170,9 @@ void ExpectFromProtoToFailWithInvalidPolicyProto(PolicyProto policy_proto) { case PolicyProto::kIterateOp: policy_proto.mutable_iterate_op()->clear_iterable(); break; + case PolicyProto::kDifferenceOp: + policy_proto.mutable_difference_op()->clear_left(); + break; // Unset policy is invalid. case PolicyProto::POLICY_NOT_SET: break; @@ -209,6 +212,14 @@ void IterateToProtoIsCorrect(Policy policy) { FUZZ_TEST(FrontEndTest, IterateToProtoIsCorrect) .WithDomains(/*policy=*/AtomicDupFreePolicyDomain()); +void DifferenceToProtoIsCorrect(Policy left, Policy right) { + EXPECT_THAT(Difference(left, right).ToProto(), + EqualsProto(DifferenceProto(left.ToProto(), right.ToProto()))); +} +FUZZ_TEST(FrontEndTest, DifferenceToProtoIsCorrect) + .WithDomains(/*policy=*/AtomicDupFreePolicyDomain(), + /*policy=*/AtomicDupFreePolicyDomain()); + TEST(FrontEndTest, SequenceWithNoElementsIsAccept) { EXPECT_THAT(Sequence().ToProto(), EqualsProto(AcceptProto())); } diff --git a/netkat/netkat.proto b/netkat/netkat.proto index ddd53cd..ccc5779 100644 --- a/netkat/netkat.proto +++ b/netkat/netkat.proto @@ -116,6 +116,7 @@ message PolicyProto { Sequence sequence_op = 4; Union union_op = 5; Iterate iterate_op = 6; + Difference difference_op = 7; } // Sets the field to the given value. @@ -148,6 +149,12 @@ message PolicyProto { PolicyProto iterable = 1; } + // Represents the difference of two policies, i.e. a - b. + message Difference { + PolicyProto left = 1; + PolicyProto right = 2; + } + // Records the packet, at the given point, into the history. Referred to as // "dup" in the literature. message Record {} diff --git a/netkat/netkat_proto_constructors.cc b/netkat/netkat_proto_constructors.cc index aba1bcf..33d9664 100644 --- a/netkat/netkat_proto_constructors.cc +++ b/netkat/netkat_proto_constructors.cc @@ -113,6 +113,13 @@ PolicyProto IterateProto(PolicyProto iterable) { return policy; } +PolicyProto DifferenceProto(PolicyProto left, PolicyProto right) { + PolicyProto policy; + *policy.mutable_difference_op()->mutable_left() = std::move(left); + *policy.mutable_difference_op()->mutable_right() = std::move(right); + return policy; +} + // -- Derived Policy constructors ---------------------------------------------- PolicyProto DenyProto() { return FilterProto(FalseProto()); } @@ -165,6 +172,10 @@ std::string AsShorthandString(PolicyProto policy) { case PolicyProto::kIterateOp: return absl::StrFormat("(%s)*", AsShorthandString(policy.iterate_op().iterable())); + case PolicyProto::kDifferenceOp: + return absl::StrFormat("(%s - %s)", + AsShorthandString(policy.difference_op().left()), + AsShorthandString(policy.difference_op().right())); case PolicyProto::POLICY_NOT_SET: return "deny"; } diff --git a/netkat/netkat_proto_constructors.h b/netkat/netkat_proto_constructors.h index 3a0332a..c3c188b 100644 --- a/netkat/netkat_proto_constructors.h +++ b/netkat/netkat_proto_constructors.h @@ -47,6 +47,7 @@ PolicyProto RecordProto(); PolicyProto SequenceProto(PolicyProto left, PolicyProto right); PolicyProto UnionProto(PolicyProto left, PolicyProto right); PolicyProto IterateProto(PolicyProto iterable); +PolicyProto DifferenceProto(PolicyProto left, PolicyProto right); // -- Derived Policy constructors ---------------------------------------------- @@ -63,6 +64,7 @@ PolicyProto AcceptProto(); // Policy Sequence -> ';' // Policy Or -> '+' // Iterate -> '*' +// Difference -> '-' // Record -> 'record' // Match -> '@field==value' // Modify -> '@field:=value' diff --git a/netkat/netkat_proto_constructors_test.cc b/netkat/netkat_proto_constructors_test.cc index b5724fd..34609bf 100644 --- a/netkat/netkat_proto_constructors_test.cc +++ b/netkat/netkat_proto_constructors_test.cc @@ -127,6 +127,15 @@ void IterateProtoReturnsIterate(PolicyProto iterable) { } FUZZ_TEST(PolicyProtoTest, IterateProtoReturnsIterate); +void DifferenceProtoReturnsDifference(PolicyProto left, PolicyProto right) { + PolicyProto expected_policy; + *expected_policy.mutable_difference_op()->mutable_left() = left; + *expected_policy.mutable_difference_op()->mutable_right() = right; + + EXPECT_THAT(DifferenceProto(left, right), EqualsProto(expected_policy)); +} +FUZZ_TEST(PolicyProtoTest, DifferenceProtoReturnsDifference); + // -- Derived Policy tests ----------------------------------------------------- TEST(PolicyProtoTest, DenyProtoFiltersOnFalse) { diff --git a/netkat/netkat_test.cc b/netkat/netkat_test.cc index 02ed4cf..8477b69 100644 --- a/netkat/netkat_test.cc +++ b/netkat/netkat_test.cc @@ -22,11 +22,11 @@ namespace netkat { namespace { // Sanity fuzz test to show that the FuzzTest library works. -void DummmyFuzzTest(PredicateProto pred, PolicyProto pol) { +void DummyFuzzTest(PredicateProto pred, PolicyProto pol) { LOG_EVERY_N_SEC(INFO, 1) << "pred = " << pred; LOG_EVERY_N_SEC(INFO, 1) << "pol = " << pol; } -FUZZ_TEST(NetkatProtoTest, DummmyFuzzTest); +FUZZ_TEST(NetkatProtoTest, DummyFuzzTest); // Ensures that the protobuf C++ compiler does not add underscores to the // generated code for sub messages and oneof fields of `PredicateProto`. @@ -104,6 +104,11 @@ TEST(NetkatProtoTest, PolicyOneOfFieldNamesDontRequireUnderscores) { LOG(INFO) << "iterate: " << iter; break; } + case PolicyProto::kDifferenceOp: { + const PolicyProto::Difference& difference_op = policy.difference_op(); + LOG(INFO) << "difference: " << difference_op; + break; + } case PolicyProto::POLICY_NOT_SET: break; } diff --git a/netkat/packet_transformer.cc b/netkat/packet_transformer.cc index cc6b241..01c7fa7 100644 --- a/netkat/packet_transformer.cc +++ b/netkat/packet_transformer.cc @@ -288,6 +288,9 @@ PacketTransformerHandle PacketTransformerManager::Compile( Compile(policy.union_op().right())); case PolicyProto::kIterateOp: return Iterate(Compile(policy.iterate_op().iterable())); + case PolicyProto::kDifferenceOp: + return Difference(Compile(policy.difference_op().left()), + Compile(policy.difference_op().right())); case PolicyProto::POLICY_NOT_SET: // By convention, uninitialized policies must be treated like the Deny // policy. @@ -623,6 +626,119 @@ PacketTransformerHandle PacketTransformerManager::Union( return Union(GetNodeOrDie(left), GetNodeOrDie(right)); } +PacketTransformerHandle PacketTransformerManager::Difference( + DecisionNode left, DecisionNode right) { + // left.field > right.field: Expand the left node, reducing to the inductive + // case. + if (left.field > right.field) { + PacketFieldHandle first_field = right.field; + return Difference( + DecisionNode{ + .field = first_field, + .default_branch = NodeToTransformer(std::move(left)), + }, + std::move(right)); + } + + // left.field < right.field: Expand the right node, reducing to the inductive + // case. + if (left.field < right.field) { + PacketFieldHandle first_field = left.field; + return Difference(std::move(left), + DecisionNode{ + .field = first_field, + .default_branch = NodeToTransformer(std::move(right)), + }); + } + + // left.field == right.field: branch on shared field. + DCHECK(left.field == right.field); + DecisionNode result_node{ + .field = left.field, + .default_branch_by_field_modification = CombineModifyBranches( + left.default_branch_by_field_modification, + right.default_branch_by_field_modification, + /*combiner=*/ + [this](PacketTransformerHandle left, PacketTransformerHandle right) { + return Difference(left, right); + }, + /*default_value=*/Deny()), + .default_branch = Difference(left.default_branch, right.default_branch), + }; + + // Collect every value in mapped in each node. + absl::flat_hash_set all_possible_values; + all_possible_values.reserve( + left.modify_branch_by_field_match.size() + + right.modify_branch_by_field_match.size() + + left.default_branch_by_field_modification.size() + + right.default_branch_by_field_modification.size()); + + absl::c_transform( + left.modify_branch_by_field_match, + std::inserter(all_possible_values, all_possible_values.end()), + [](auto pair) { return pair.first; }); + absl::c_transform( + right.modify_branch_by_field_match, + std::inserter(all_possible_values, all_possible_values.end()), + [](auto pair) { return pair.first; }); + absl::c_transform( + left.default_branch_by_field_modification, + std::inserter(all_possible_values, all_possible_values.end()), + [](auto pair) { return pair.first; }); + absl::c_transform( + right.default_branch_by_field_modification, + std::inserter(all_possible_values, all_possible_values.end()), + [](auto pair) { return pair.first; }); + + // For every value in mapped in each node, construct the proper new branch. + // TODO(dilo): Would like to use absl::bind_front here instead of a lambda: + // absl::bind_front( + // &PacketTransformerManager::Difference, this), + for (int value : all_possible_values) { + result_node.modify_branch_by_field_match[value] = CombineModifyBranches( + GetMapAtValue(left, value), GetMapAtValue(right, value), + /*combiner=*/ + [this](PacketTransformerHandle left, PacketTransformerHandle right) { + return Difference(left, right); + }, + /*default_value=*/Deny()); + } + + return NodeToTransformer(std::move(result_node)); +} + +PacketTransformerHandle PacketTransformerManager::Difference( + PacketTransformerHandle left, PacketTransformerHandle right) { + // Base cases. + if (left == right) return Deny(); + if (IsDeny(left)) return Deny(); + if (IsDeny(right)) return left; + + // If either node is accept, then expand it before merging. + if (IsAccept(left)) { + const DecisionNode& right_node = GetNodeOrDie(right); + return Difference( + DecisionNode{ + .field = right_node.field, + .default_branch = Accept(), + }, + right_node); + } + + if (IsAccept(right)) { + const DecisionNode& left_node = GetNodeOrDie(left); + return Difference(left_node, DecisionNode{ + .field = left_node.field, + .default_branch = Accept(), + }); + } + + // If neither node is accept or deny, then difference the nodes directly. + return Difference(GetNodeOrDie(left), GetNodeOrDie(right)); +} + PacketTransformerHandle PacketTransformerManager::Iterate( PacketTransformerHandle iterable) { PacketTransformerHandle previous_approximation = Accept(); diff --git a/netkat/packet_transformer.h b/netkat/packet_transformer.h index 9983e89..1571357 100644 --- a/netkat/packet_transformer.h +++ b/netkat/packet_transformer.h @@ -252,7 +252,7 @@ class PacketTransformerManager { // Returns the transformer that describes the packets produced by the `left` // transformer, but not the `right` transformer. PacketTransformerHandle Difference(PacketTransformerHandle left, - PacketTransformerHandle right) = delete; + PacketTransformerHandle right); // Returns the transformer that describes the packets produced by the `left` // transformer or the `right` transformer, but not both. @@ -385,6 +385,7 @@ class PacketTransformerManager { // copies of the nodes? PacketTransformerHandle Union(DecisionNode left, DecisionNode right); PacketTransformerHandle Sequence(DecisionNode left, DecisionNode right); + PacketTransformerHandle Difference(DecisionNode left, DecisionNode right); // Internal helper function to get a map of possible modification values to // branches for a given input value at `node`. diff --git a/netkat/packet_transformer_test.cc b/netkat/packet_transformer_test.cc index c3e6d8d..dc27b35 100644 --- a/netkat/packet_transformer_test.cc +++ b/netkat/packet_transformer_test.cc @@ -179,6 +179,13 @@ void IterateCompilesToIterate(PolicyProto iterable) { } FUZZ_TEST(PacketTransformerManagerTest, IterateCompilesToIterate); +void DifferenceCompilesToDifference(PolicyProto left, PolicyProto right) { + EXPECT_EQ( + Manager().Compile(DifferenceProto(left, right)), + Manager().Difference(Manager().Compile(left), Manager().Compile(right))); +} +FUZZ_TEST(PacketTransformerManagerTest, DifferenceCompilesToDifference); + /*--- Kleene algebra axioms and equivalences ---------------------------------*/ void UnionIsAssociative(PolicyProto a, PolicyProto b, PolicyProto c) { @@ -273,6 +280,32 @@ void IterateIsLeastFixedPoint(PolicyProto p, PolicyProto q, PolicyProto r) { } FUZZ_TEST(PacketTransformerManagerTest, IterateIsLeastFixedPoint); +void DifferenceOfPolicyAndDenyIsIdentity(PolicyProto policy) { + EXPECT_EQ(Manager().Compile(DifferenceProto(policy, DenyProto())), + Manager().Compile(policy)); +} +FUZZ_TEST(PacketTransformerManagerTest, DifferenceOfPolicyAndDenyIsIdentity); + +void DifferenceOfDenyAndPolicyIsAlwaysDeny(PolicyProto policy) { + EXPECT_EQ(Manager().Compile(DifferenceProto(DenyProto(), policy)), + Manager().Compile(DenyProto())); +} +FUZZ_TEST(PacketTransformerManagerTest, DifferenceOfDenyAndPolicyIsAlwaysDeny); + +void DifferenceOfPolicyAndSelfIsAlwaysDeny(PolicyProto policy) { + EXPECT_EQ(Manager().Compile(DifferenceProto(policy, policy)), + Manager().Deny()); +} +FUZZ_TEST(PacketTransformerManagerTest, DifferenceOfPolicyAndSelfIsAlwaysDeny); + +void DifferenceIsRightDistributiveForUnion(PolicyProto a, PolicyProto b, + PolicyProto c) { + EXPECT_EQ(Manager().Compile(DifferenceProto(UnionProto(a, b), c)), + Manager().Compile( + UnionProto(DifferenceProto(a, c), DifferenceProto(b, c)))); +} +FUZZ_TEST(PacketTransformerManagerTest, DifferenceIsRightDistributiveForUnion); + /*--- Tests with concrete protos ---------------------------------------------*/ TEST(PacketTransformerManagerTest, KatchPaperFig5) { @@ -505,6 +538,34 @@ TEST(PacketTransformerManagerTest, SimpleSequenceAndUnionRunTest2) { << Manager().ToString(sequenced_transformer4); } +TEST(PacketTransformerManagerTest, DifferenceBetweenModifyAndModifyIsCorrect) { + PacketTransformerHandle modify_f_42 = Manager().Modification("f", 42); + PacketTransformerHandle modify_g_26 = Manager().Modification("g", 26); + PacketTransformerHandle diff_transformer = + Manager().Difference(modify_f_42, modify_g_26); + + Packet packet_without_fields; + Packet packet_f42 = {{"f", 42}}; + EXPECT_THAT(Manager().Run(diff_transformer, packet_without_fields), + UnorderedElementsAre(packet_f42)); + EXPECT_THAT(Manager().Run(diff_transformer, packet_f42), + UnorderedElementsAre(packet_f42)); + + Packet packet_g26 = {{"g", 26}}; + Packet packet_f42_g26 = {{"f", 42}, {"g", 26}}; + EXPECT_THAT(Manager().Run(diff_transformer, packet_g26), + UnorderedElementsAre(packet_f42_g26)); + EXPECT_THAT(Manager().Run(diff_transformer, packet_f42_g26), IsEmpty()); + + Packet packet_f24_g26 = {{"f", 24}, {"g", 26}}; + EXPECT_THAT(Manager().Run(diff_transformer, packet_f24_g26), + UnorderedElementsAre(packet_f42_g26)); + + Packet packet_f42_g62 = {{"f", 42}, {"g", 62}}; + EXPECT_THAT(Manager().Run(diff_transformer, packet_f42_g62), + UnorderedElementsAre(packet_f42_g62)); +} + TEST(PacketTransformerManagerTest, PushThroughModifyIsCorrect) { PacketSetManager& packet_set_manager = Manager().GetPacketSetManager(); PacketSetHandle f_24 = packet_set_manager.Match("f", 24); diff --git a/netkat/table.cc b/netkat/table.cc index 9cefd49..460c2bb 100644 --- a/netkat/table.cc +++ b/netkat/table.cc @@ -50,6 +50,10 @@ absl::Status VerifyActionHasNoPredicate(const Policy& action) { stack.push_back(&policy->union_op().left()); stack.push_back(&policy->union_op().right()); break; + case PolicyProto::PolicyCase::kDifferenceOp: + stack.push_back(&policy->difference_op().left()); + stack.push_back(&policy->difference_op().right()); + break; case PolicyProto::PolicyCase::kSequenceOp: stack.push_back(&policy->sequence_op().left()); stack.push_back(&policy->sequence_op().right());