From 3248d61fd6e20a6cb1d877ce75e30ba2a226faa6 Mon Sep 17 00:00:00 2001 From: Tom Pantelis Date: Tue, 5 Mar 2024 13:55:45 -0500 Subject: [PATCH] Adjust chain rule implementation in fake packetfilter The purpose of ClearChain is to clear/flush all rules for a table chain. The fake implementation was modified to do this, previously it deleted the entire chain, ie it was equivalent to DeleteChain. To facilitate this, the tableChains map, which just tracked the existence of chains, was removed. The chainRules map can also be used for this purpose. In addition, inserting/appending a rule now fails if the chain doesn't exist to align with the behavior of the real iptables implementation. This required adjustments to the globalnet unit tests. Signed-off-by: Tom Pantelis --- .../cluster_egressip_controller_test.go | 1 + .../controllers/controllers_suite_test.go | 18 +++++ .../controllers/gateway_monitor_test.go | 17 +++-- .../global_egressip_controller_test.go | 1 + .../global_ingressip_controller_test.go | 1 + .../controllers/node_controller_test.go | 1 + .../controllers/service_controller_test.go | 1 + .../service_export_controller_test.go | 1 + pkg/packetfilter/fake/packetfilter.go | 70 +++++++++++-------- 9 files changed, 74 insertions(+), 37 deletions(-) diff --git a/pkg/globalnet/controllers/cluster_egressip_controller_test.go b/pkg/globalnet/controllers/cluster_egressip_controller_test.go index 3655e7627..cc16f3caf 100644 --- a/pkg/globalnet/controllers/cluster_egressip_controller_test.go +++ b/pkg/globalnet/controllers/cluster_egressip_controller_test.go @@ -361,6 +361,7 @@ func newClusterGlobalEgressIPControllerTestDriver() *clusterGlobalEgressIPContro BeforeEach(func() { t.testDriverBase = newTestDriverBase() + t.testDriverBase.initChains() var err error diff --git a/pkg/globalnet/controllers/controllers_suite_test.go b/pkg/globalnet/controllers/controllers_suite_test.go index e1350bcbf..8d02985fe 100644 --- a/pkg/globalnet/controllers/controllers_suite_test.go +++ b/pkg/globalnet/controllers/controllers_suite_test.go @@ -142,6 +142,24 @@ func (t *testDriverBase) afterEach() { t.controller.Stop() } +func (t *testDriverBase) initChains() { + for _, chain := range []string{ + constants.SmGlobalnetIngressChain, + constants.SmGlobalnetEgressChain, + constants.SmGlobalnetEgressChainForPods, + constants.SmGlobalnetEgressChainForHeadlessSvcPods, + constants.SmGlobalnetEgressChainForHeadlessSvcEPs, + constants.SmGlobalnetEgressChainForNamespace, + constants.SmGlobalnetEgressChainForCluster, + routeAgent.SmPostRoutingChain, + constants.SmGlobalnetMarkChain, + } { + Expect(t.pFilter.CreateChainIfNotExists(packetfilter.TableTypeNAT, &packetfilter.Chain{ + Name: chain, + })).To(Succeed()) + } +} + func (t *testDriverBase) verifyIPsReservedInPool(ips ...string) { if t.pool == nil { return diff --git a/pkg/globalnet/controllers/gateway_monitor_test.go b/pkg/globalnet/controllers/gateway_monitor_test.go index e8e9876d1..cdcd5a990 100644 --- a/pkg/globalnet/controllers/gateway_monitor_test.go +++ b/pkg/globalnet/controllers/gateway_monitor_test.go @@ -88,7 +88,7 @@ var _ = Describe("Endpoint monitoring", func() { It("should stop and restart the controllers", func() { t.leaderElection.AwaitLeaseReleased() - t.awaitNoGlobalnetChains() + t.awaitGlobalnetChainsCleared() t.ensureControllersStopped() By("Recreating the Endpoint") @@ -167,7 +167,7 @@ var _ = Describe("Endpoint monitoring", func() { Expect(t.endpoints.Delete(context.TODO(), endpoint.Name, metav1.DeleteOptions{})).To(Succeed()) - t.awaitNoGlobalnetChains() + t.awaitGlobalnetChainsCleared() }) }) }) @@ -304,14 +304,19 @@ func (t *gatewayMonitorTestDriver) ensureControllersStopped() { func (t *gatewayMonitorTestDriver) awaitGlobalnetChains() { t.pFilter.AwaitChain(packetfilter.TableTypeNAT, constants.SmGlobalnetIngressChain) t.pFilter.AwaitChain(packetfilter.TableTypeNAT, constants.SmGlobalnetEgressChain) + t.pFilter.AwaitChain(packetfilter.TableTypeNAT, constants.SmGlobalnetEgressChainForPods) + t.pFilter.AwaitChain(packetfilter.TableTypeNAT, constants.SmGlobalnetEgressChainForHeadlessSvcPods) + t.pFilter.AwaitChain(packetfilter.TableTypeNAT, constants.SmGlobalnetEgressChainForHeadlessSvcEPs) + t.pFilter.AwaitChain(packetfilter.TableTypeNAT, constants.SmGlobalnetEgressChainForNamespace) + t.pFilter.AwaitChain(packetfilter.TableTypeNAT, constants.SmGlobalnetEgressChainForCluster) t.pFilter.AwaitChain(packetfilter.TableTypeNAT, routeAgent.SmPostRoutingChain) t.pFilter.AwaitChain(packetfilter.TableTypeNAT, constants.SmGlobalnetMarkChain) } -func (t *gatewayMonitorTestDriver) awaitNoGlobalnetChains() { - t.pFilter.AwaitNoChain(packetfilter.TableTypeNAT, constants.SmGlobalnetIngressChain) - t.pFilter.AwaitNoChain(packetfilter.TableTypeNAT, constants.SmGlobalnetEgressChain) - t.pFilter.AwaitNoChain(packetfilter.TableTypeNAT, constants.SmGlobalnetMarkChain) +func (t *gatewayMonitorTestDriver) awaitGlobalnetChainsCleared() { + t.pFilter.AwaitNoRules(packetfilter.TableTypeNAT, constants.SmGlobalnetIngressChain) + t.pFilter.AwaitNoRules(packetfilter.TableTypeNAT, constants.SmGlobalnetEgressChain) + t.pFilter.AwaitNoRules(packetfilter.TableTypeNAT, constants.SmGlobalnetMarkChain) } func newEndpointSpec(clusterID, hostname, subnet string) *submarinerv1.EndpointSpec { diff --git a/pkg/globalnet/controllers/global_egressip_controller_test.go b/pkg/globalnet/controllers/global_egressip_controller_test.go index 26db84d67..9c6b05db2 100644 --- a/pkg/globalnet/controllers/global_egressip_controller_test.go +++ b/pkg/globalnet/controllers/global_egressip_controller_test.go @@ -593,6 +593,7 @@ func newGlobalEgressIPControllerTestDriver() *globalEgressIPControllerTestDriver BeforeEach(func() { t.testDriverBase = newTestDriverBase() + t.testDriverBase.initChains() var err error diff --git a/pkg/globalnet/controllers/global_ingressip_controller_test.go b/pkg/globalnet/controllers/global_ingressip_controller_test.go index f094e7647..be0610dd9 100644 --- a/pkg/globalnet/controllers/global_ingressip_controller_test.go +++ b/pkg/globalnet/controllers/global_ingressip_controller_test.go @@ -505,6 +505,7 @@ func newGlobalIngressIPControllerDriver() *globalIngressIPControllerTestDriver { BeforeEach(func() { t.testDriverBase = newTestDriverBase() + t.testDriverBase.initChains() var err error diff --git a/pkg/globalnet/controllers/node_controller_test.go b/pkg/globalnet/controllers/node_controller_test.go index 081401e78..df60e42aa 100644 --- a/pkg/globalnet/controllers/node_controller_test.go +++ b/pkg/globalnet/controllers/node_controller_test.go @@ -174,6 +174,7 @@ func newNodeControllerTestDriver() *nodeControllerTestDriver { BeforeEach(func() { t.testDriverBase = newTestDriverBase() + t.testDriverBase.initChains() var err error diff --git a/pkg/globalnet/controllers/service_controller_test.go b/pkg/globalnet/controllers/service_controller_test.go index a80e5c675..8f7392a07 100644 --- a/pkg/globalnet/controllers/service_controller_test.go +++ b/pkg/globalnet/controllers/service_controller_test.go @@ -176,6 +176,7 @@ func newServiceControllerTestDriver() *serviceControllerTestDriver { BeforeEach(func() { t.testDriverBase = newTestDriverBase() + t.testDriverBase.initChains() }) JustBeforeEach(func() { diff --git a/pkg/globalnet/controllers/service_export_controller_test.go b/pkg/globalnet/controllers/service_export_controller_test.go index 4a6dc0322..47cbfdf2e 100644 --- a/pkg/globalnet/controllers/service_export_controller_test.go +++ b/pkg/globalnet/controllers/service_export_controller_test.go @@ -448,6 +448,7 @@ func newServiceExportControllerTestDriver() *serviceExportControllerTestDriver { BeforeEach(func() { t.testDriverBase = newTestDriverBase() + t.testDriverBase.initChains() }) JustBeforeEach(func() { diff --git a/pkg/packetfilter/fake/packetfilter.go b/pkg/packetfilter/fake/packetfilter.go index 39ea08aac..b94c0b243 100644 --- a/pkg/packetfilter/fake/packetfilter.go +++ b/pkg/packetfilter/fake/packetfilter.go @@ -21,6 +21,7 @@ package fake import ( "encoding/json" "fmt" + "strings" "sync" . "github.com/onsi/gomega" @@ -32,7 +33,6 @@ import ( type PacketFilter struct { mutex sync.Mutex chainRules map[string]set.Set[string] - tableChains map[uint32]set.Set[string] failOnAppendRuleMatchers []interface{} failOnDeleteRuleMatchers []interface{} @@ -45,9 +45,8 @@ type PacketFilter struct { func New() *PacketFilter { pf := &PacketFilter{ - chainRules: map[string]set.Set[string]{}, - tableChains: map[uint32]set.Set[string]{}, - sets: map[string]set.Set[string]{}, + chainRules: map[string]set.Set[string]{}, + sets: map[string]set.Set[string]{}, } packetfilter.SetNewDriverFn(func() (packetfilter.Driver, error) { @@ -89,11 +88,13 @@ func (i *PacketFilter) ClearChain(table packetfilter.TableType, chain string) er i.mutex.Lock() defer i.mutex.Unlock() - chainSet := i.tableChains[uint32(table)] - if chainSet != nil { - chainSet.Delete(chain) + ruleSet := i.chainRules[chainKey(uint32(table), chain)] + if ruleSet == nil { + return fmt.Errorf("chain %q for table %q does not exist", chain, table) } + ruleSet.Clear() + return nil } @@ -165,7 +166,7 @@ func (i *PacketFilter) delete(table packetfilter.TableType, chain, rulespec stri return err } - ruleSet := i.chainRules[fmt.Sprintf("%v/%s", table, chain)] + ruleSet := i.chainRules[chainKey(uint32(table), chain)] if ruleSet != nil { ruleSet.Delete(rulespec) } @@ -177,23 +178,26 @@ func (i *PacketFilter) deleteChain(table uint32, chain string) { i.mutex.Lock() defer i.mutex.Unlock() - chainSet := i.tableChains[table] - if chainSet != nil { - chainSet.Delete(chain) - } + delete(i.chainRules, chainKey(table, chain)) } func (i *PacketFilter) addChainsFor(table uint32, chains ...string) { i.mutex.Lock() defer i.mutex.Unlock() - chainSet := i.tableChains[table] - if chainSet == nil { - chainSet = set.New[string]() - i.tableChains[table] = chainSet + for _, chain := range chains { + key := chainKey(table, chain) + + ruleSet := i.chainRules[key] + if ruleSet == nil { + ruleSet = set.New[string]() + i.chainRules[key] = ruleSet + } } +} - chainSet.Insert(chains...) +func chainKey(table uint32, chain string) string { + return fmt.Sprintf("%v/%s", table, chain) } func (i *PacketFilter) addRule(table packetfilter.TableType, chain, rulespec string) error { @@ -205,10 +209,9 @@ func (i *PacketFilter) addRule(table packetfilter.TableType, chain, rulespec str return err } - ruleSet := i.chainRules[fmt.Sprintf("%v/%s", table, chain)] + ruleSet := i.chainRules[chainKey(uint32(table), chain)] if ruleSet == nil { - ruleSet = set.New[string]() - i.chainRules[fmt.Sprintf("%v/%s", table, chain)] = ruleSet + return fmt.Errorf("chain %q for table %q does not exist", chain, table) } ruleSet.Insert(rulespec) @@ -220,7 +223,7 @@ func (i *PacketFilter) listRules(table packetfilter.TableType, chain string) []s i.mutex.Lock() defer i.mutex.Unlock() - rules := i.chainRules[fmt.Sprintf("%v/%s", table, chain)] + rules := i.chainRules[chainKey(uint32(table), chain)] if rules != nil { return rules.UnsortedList() } @@ -232,24 +235,23 @@ func (i *PacketFilter) listChains(table packetfilter.TableType) []string { i.mutex.Lock() defer i.mutex.Unlock() - chains := i.tableChains[uint32(table)] - if chains != nil { - return chains.UnsortedList() + var chains []string + tableKey := chainKey(uint32(table), "") + + for k := range i.chainRules { + if strings.HasPrefix(k, tableKey) { + chains = append(chains, k[len(tableKey):]) + } } - return []string{} + return chains } func (i *PacketFilter) chainExists(table uint32, chain string) (bool, error) { i.mutex.Lock() defer i.mutex.Unlock() - chainSet := i.tableChains[table] - if chainSet != nil { - return chainSet.Has(chain), nil - } - - return false, nil + return i.chainRules[chainKey(table, chain)] != nil, nil } func matchRuleForError(matchers *[]interface{}, rulespec string) error { @@ -290,6 +292,12 @@ func (i *PacketFilter) AwaitNoRule(table packetfilter.TableType, chain string, s }, 5).ShouldNot(ContainElement(stringOrMatcher), "Rules for IP table %v, chain %q", table, chain) } +func (i *PacketFilter) AwaitNoRules(table packetfilter.TableType, chain string) { + Eventually(func() []string { + return i.listRules(table, chain) + }, 5).Should(BeEmpty()) +} + func (i *PacketFilter) AddFailOnAppendRuleMatcher(stringOrMatcher interface{}) { i.mutex.Lock() defer i.mutex.Unlock()