Skip to content

Commit

Permalink
Adjust chain rule implementation in fake packetfilter
Browse files Browse the repository at this point in the history
The purpose of ClearChain is to clear/flush all rules for a table chain.
The fake implementation was modified to do this, previously it deleted
the entire chain, ie it was equivalent to DeleteChain. To facilitate
this, the tableChains map, which just tracked the existence of chains,
was removed. The chainRules map can also be used for this purpose.

In addition, inserting/appending a rule now fails if the chain doesn't
exist to align with the behavior of the real iptables implementation.
This required adjustments to the globalnet unit tests.

Signed-off-by: Tom Pantelis <[email protected]>
  • Loading branch information
tpantelis committed Mar 8, 2024
1 parent d405b10 commit 3248d61
Show file tree
Hide file tree
Showing 9 changed files with 74 additions and 37 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -361,6 +361,7 @@ func newClusterGlobalEgressIPControllerTestDriver() *clusterGlobalEgressIPContro

BeforeEach(func() {
t.testDriverBase = newTestDriverBase()
t.testDriverBase.initChains()

var err error

Expand Down
18 changes: 18 additions & 0 deletions pkg/globalnet/controllers/controllers_suite_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,24 @@ func (t *testDriverBase) afterEach() {
t.controller.Stop()
}

func (t *testDriverBase) initChains() {
for _, chain := range []string{
constants.SmGlobalnetIngressChain,
constants.SmGlobalnetEgressChain,
constants.SmGlobalnetEgressChainForPods,
constants.SmGlobalnetEgressChainForHeadlessSvcPods,
constants.SmGlobalnetEgressChainForHeadlessSvcEPs,
constants.SmGlobalnetEgressChainForNamespace,
constants.SmGlobalnetEgressChainForCluster,
routeAgent.SmPostRoutingChain,
constants.SmGlobalnetMarkChain,
} {
Expect(t.pFilter.CreateChainIfNotExists(packetfilter.TableTypeNAT, &packetfilter.Chain{
Name: chain,
})).To(Succeed())
}
}

func (t *testDriverBase) verifyIPsReservedInPool(ips ...string) {
if t.pool == nil {
return
Expand Down
17 changes: 11 additions & 6 deletions pkg/globalnet/controllers/gateway_monitor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ var _ = Describe("Endpoint monitoring", func() {

It("should stop and restart the controllers", func() {
t.leaderElection.AwaitLeaseReleased()
t.awaitNoGlobalnetChains()
t.awaitGlobalnetChainsCleared()
t.ensureControllersStopped()

By("Recreating the Endpoint")
Expand Down Expand Up @@ -167,7 +167,7 @@ var _ = Describe("Endpoint monitoring", func() {

Expect(t.endpoints.Delete(context.TODO(), endpoint.Name, metav1.DeleteOptions{})).To(Succeed())

t.awaitNoGlobalnetChains()
t.awaitGlobalnetChainsCleared()
})
})
})
Expand Down Expand Up @@ -304,14 +304,19 @@ func (t *gatewayMonitorTestDriver) ensureControllersStopped() {
func (t *gatewayMonitorTestDriver) awaitGlobalnetChains() {
t.pFilter.AwaitChain(packetfilter.TableTypeNAT, constants.SmGlobalnetIngressChain)
t.pFilter.AwaitChain(packetfilter.TableTypeNAT, constants.SmGlobalnetEgressChain)
t.pFilter.AwaitChain(packetfilter.TableTypeNAT, constants.SmGlobalnetEgressChainForPods)
t.pFilter.AwaitChain(packetfilter.TableTypeNAT, constants.SmGlobalnetEgressChainForHeadlessSvcPods)
t.pFilter.AwaitChain(packetfilter.TableTypeNAT, constants.SmGlobalnetEgressChainForHeadlessSvcEPs)
t.pFilter.AwaitChain(packetfilter.TableTypeNAT, constants.SmGlobalnetEgressChainForNamespace)
t.pFilter.AwaitChain(packetfilter.TableTypeNAT, constants.SmGlobalnetEgressChainForCluster)
t.pFilter.AwaitChain(packetfilter.TableTypeNAT, routeAgent.SmPostRoutingChain)
t.pFilter.AwaitChain(packetfilter.TableTypeNAT, constants.SmGlobalnetMarkChain)
}

func (t *gatewayMonitorTestDriver) awaitNoGlobalnetChains() {
t.pFilter.AwaitNoChain(packetfilter.TableTypeNAT, constants.SmGlobalnetIngressChain)
t.pFilter.AwaitNoChain(packetfilter.TableTypeNAT, constants.SmGlobalnetEgressChain)
t.pFilter.AwaitNoChain(packetfilter.TableTypeNAT, constants.SmGlobalnetMarkChain)
func (t *gatewayMonitorTestDriver) awaitGlobalnetChainsCleared() {
t.pFilter.AwaitNoRules(packetfilter.TableTypeNAT, constants.SmGlobalnetIngressChain)
t.pFilter.AwaitNoRules(packetfilter.TableTypeNAT, constants.SmGlobalnetEgressChain)
t.pFilter.AwaitNoRules(packetfilter.TableTypeNAT, constants.SmGlobalnetMarkChain)
}

func newEndpointSpec(clusterID, hostname, subnet string) *submarinerv1.EndpointSpec {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -593,6 +593,7 @@ func newGlobalEgressIPControllerTestDriver() *globalEgressIPControllerTestDriver

BeforeEach(func() {
t.testDriverBase = newTestDriverBase()
t.testDriverBase.initChains()

var err error

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -505,6 +505,7 @@ func newGlobalIngressIPControllerDriver() *globalIngressIPControllerTestDriver {

BeforeEach(func() {
t.testDriverBase = newTestDriverBase()
t.testDriverBase.initChains()

var err error

Expand Down
1 change: 1 addition & 0 deletions pkg/globalnet/controllers/node_controller_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,7 @@ func newNodeControllerTestDriver() *nodeControllerTestDriver {

BeforeEach(func() {
t.testDriverBase = newTestDriverBase()
t.testDriverBase.initChains()

var err error

Expand Down
1 change: 1 addition & 0 deletions pkg/globalnet/controllers/service_controller_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,7 @@ func newServiceControllerTestDriver() *serviceControllerTestDriver {

BeforeEach(func() {
t.testDriverBase = newTestDriverBase()
t.testDriverBase.initChains()
})

JustBeforeEach(func() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -448,6 +448,7 @@ func newServiceExportControllerTestDriver() *serviceExportControllerTestDriver {

BeforeEach(func() {
t.testDriverBase = newTestDriverBase()
t.testDriverBase.initChains()
})

JustBeforeEach(func() {
Expand Down
70 changes: 39 additions & 31 deletions pkg/packetfilter/fake/packetfilter.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ package fake
import (
"encoding/json"
"fmt"
"strings"
"sync"

. "github.com/onsi/gomega"
Expand All @@ -32,7 +33,6 @@ import (
type PacketFilter struct {
mutex sync.Mutex
chainRules map[string]set.Set[string]
tableChains map[uint32]set.Set[string]
failOnAppendRuleMatchers []interface{}
failOnDeleteRuleMatchers []interface{}

Expand All @@ -45,9 +45,8 @@ type PacketFilter struct {

func New() *PacketFilter {
pf := &PacketFilter{
chainRules: map[string]set.Set[string]{},
tableChains: map[uint32]set.Set[string]{},
sets: map[string]set.Set[string]{},
chainRules: map[string]set.Set[string]{},
sets: map[string]set.Set[string]{},
}

packetfilter.SetNewDriverFn(func() (packetfilter.Driver, error) {
Expand Down Expand Up @@ -89,11 +88,13 @@ func (i *PacketFilter) ClearChain(table packetfilter.TableType, chain string) er
i.mutex.Lock()
defer i.mutex.Unlock()

chainSet := i.tableChains[uint32(table)]
if chainSet != nil {
chainSet.Delete(chain)
ruleSet := i.chainRules[chainKey(uint32(table), chain)]
if ruleSet == nil {
return fmt.Errorf("chain %q for table %q does not exist", chain, table)
}

ruleSet.Clear()

return nil
}

Expand Down Expand Up @@ -165,7 +166,7 @@ func (i *PacketFilter) delete(table packetfilter.TableType, chain, rulespec stri
return err
}

ruleSet := i.chainRules[fmt.Sprintf("%v/%s", table, chain)]
ruleSet := i.chainRules[chainKey(uint32(table), chain)]
if ruleSet != nil {
ruleSet.Delete(rulespec)
}
Expand All @@ -177,23 +178,26 @@ func (i *PacketFilter) deleteChain(table uint32, chain string) {
i.mutex.Lock()
defer i.mutex.Unlock()

chainSet := i.tableChains[table]
if chainSet != nil {
chainSet.Delete(chain)
}
delete(i.chainRules, chainKey(table, chain))
}

func (i *PacketFilter) addChainsFor(table uint32, chains ...string) {
i.mutex.Lock()
defer i.mutex.Unlock()

chainSet := i.tableChains[table]
if chainSet == nil {
chainSet = set.New[string]()
i.tableChains[table] = chainSet
for _, chain := range chains {
key := chainKey(table, chain)

ruleSet := i.chainRules[key]
if ruleSet == nil {
ruleSet = set.New[string]()
i.chainRules[key] = ruleSet
}
}
}

chainSet.Insert(chains...)
func chainKey(table uint32, chain string) string {
return fmt.Sprintf("%v/%s", table, chain)
}

func (i *PacketFilter) addRule(table packetfilter.TableType, chain, rulespec string) error {
Expand All @@ -205,10 +209,9 @@ func (i *PacketFilter) addRule(table packetfilter.TableType, chain, rulespec str
return err
}

ruleSet := i.chainRules[fmt.Sprintf("%v/%s", table, chain)]
ruleSet := i.chainRules[chainKey(uint32(table), chain)]
if ruleSet == nil {
ruleSet = set.New[string]()
i.chainRules[fmt.Sprintf("%v/%s", table, chain)] = ruleSet
return fmt.Errorf("chain %q for table %q does not exist", chain, table)
}

ruleSet.Insert(rulespec)
Expand All @@ -220,7 +223,7 @@ func (i *PacketFilter) listRules(table packetfilter.TableType, chain string) []s
i.mutex.Lock()
defer i.mutex.Unlock()

rules := i.chainRules[fmt.Sprintf("%v/%s", table, chain)]
rules := i.chainRules[chainKey(uint32(table), chain)]
if rules != nil {
return rules.UnsortedList()
}
Expand All @@ -232,24 +235,23 @@ func (i *PacketFilter) listChains(table packetfilter.TableType) []string {
i.mutex.Lock()
defer i.mutex.Unlock()

chains := i.tableChains[uint32(table)]
if chains != nil {
return chains.UnsortedList()
var chains []string
tableKey := chainKey(uint32(table), "")

for k := range i.chainRules {
if strings.HasPrefix(k, tableKey) {
chains = append(chains, k[len(tableKey):])
}
}

return []string{}
return chains
}

func (i *PacketFilter) chainExists(table uint32, chain string) (bool, error) {
i.mutex.Lock()
defer i.mutex.Unlock()

chainSet := i.tableChains[table]
if chainSet != nil {
return chainSet.Has(chain), nil
}

return false, nil
return i.chainRules[chainKey(table, chain)] != nil, nil
}

func matchRuleForError(matchers *[]interface{}, rulespec string) error {
Expand Down Expand Up @@ -290,6 +292,12 @@ func (i *PacketFilter) AwaitNoRule(table packetfilter.TableType, chain string, s
}, 5).ShouldNot(ContainElement(stringOrMatcher), "Rules for IP table %v, chain %q", table, chain)
}

func (i *PacketFilter) AwaitNoRules(table packetfilter.TableType, chain string) {
Eventually(func() []string {
return i.listRules(table, chain)
}, 5).Should(BeEmpty())
}

func (i *PacketFilter) AddFailOnAppendRuleMatcher(stringOrMatcher interface{}) {
i.mutex.Lock()
defer i.mutex.Unlock()
Expand Down

0 comments on commit 3248d61

Please sign in to comment.