diff --git a/internal/splitrt/excludes.go b/internal/splitrt/excludes.go index bf11eb7..7f96f57 100644 --- a/internal/splitrt/excludes.go +++ b/internal/splitrt/excludes.go @@ -1,10 +1,8 @@ package splitrt import ( - "context" "net/netip" "sync" - "time" log "github.com/sirupsen/logrus" "github.com/telekom-mms/oc-daemon/internal/daemoncfg" @@ -31,9 +29,10 @@ type Excludes struct { closed chan struct{} } -// setFilter resets the excludes in netfilter. -func (e *Excludes) setFilter(ctx context.Context) { - log.Debug("SplitRouting resetting excludes in netfilter") +// GetPrefixes returns static and dynamic split excludes as Prefixes. +func (e *Excludes) GetPrefixes() []netip.Prefix { + e.Lock() + defer e.Unlock() addresses := []netip.Prefix{} for _, v := range e.s { @@ -43,18 +42,18 @@ func (e *Excludes) setFilter(ctx context.Context) { prefix := netip.PrefixFrom(k, k.BitLen()) addresses = append(addresses, prefix) } - setExcludes(ctx, addresses) + + return addresses } // AddStatic adds a static entry to the split excludes. -func (e *Excludes) AddStatic(ctx context.Context, address netip.Prefix) { +func (e *Excludes) AddStatic(address netip.Prefix) bool { log.WithField("address", address).Debug("SplitRouting adding static exclude") e.Lock() defer e.Unlock() // make sure new prefix in address does not overlap with existing // prefixes in static excludes - removed := false for k, v := range e.s { if !v.Overlaps(address) { // no overlap @@ -63,30 +62,23 @@ func (e *Excludes) AddStatic(ctx context.Context, address netip.Prefix) { if v.Bits() <= address.Bits() { // new prefix is already in existing prefix, // do not add it - return + return false } // new prefix contains old prefix, remove old prefix delete(e.s, k) - removed = true } // add new prefix to static excludes key := address.String() e.s[key] = address - // add to netfilter - if removed { - // existing entries removed, we need to reset all excludes - e.setFilter(ctx) - return - } - // single new entry, add it - addExclude(ctx, address) + // update netfilter + return true } // AddDynamic adds a dynamic entry to the split excludes. -func (e *Excludes) AddDynamic(ctx context.Context, address netip.Prefix, ttl uint32) { +func (e *Excludes) AddDynamic(address netip.Prefix, ttl uint32) bool { log.WithFields(log.Fields{ "address": address, "ttl": ttl, @@ -94,7 +86,7 @@ func (e *Excludes) AddDynamic(ctx context.Context, address netip.Prefix, ttl uin if !address.IsSingleIP() { log.Error("SplitRouting error adding dynamic exclude with multiple IPs") - return + return false } a := address.Addr() @@ -104,7 +96,7 @@ func (e *Excludes) AddDynamic(ctx context.Context, address netip.Prefix, ttl uin // make sure new ip address is not in existing static excludes for _, v := range e.s { if v.Contains(a) { - return + return false } } @@ -113,7 +105,7 @@ func (e *Excludes) AddDynamic(ctx context.Context, address netip.Prefix, ttl uin if old != nil { old.ttl = ttl old.updated = true - return + return false } // create new entry in dynamic excludes @@ -122,21 +114,25 @@ func (e *Excludes) AddDynamic(ctx context.Context, address netip.Prefix, ttl uin updated: true, } - // add to netfilter - addExclude(ctx, address) + // update netfilter + return true } // RemoveStatic removes a static entry from the split excludes. -func (e *Excludes) RemoveStatic(ctx context.Context, address netip.Prefix) { +func (e *Excludes) RemoveStatic(address netip.Prefix) bool { e.Lock() defer e.Unlock() - delete(e.s, address.String()) - e.setFilter(ctx) + addr := address.String() + if _, ok := e.s[addr]; !ok { + return false + } + delete(e.s, addr) + return true } // cleanup cleans up the dynamic split excludes. -func (e *Excludes) cleanup(ctx context.Context) { +func (e *Excludes) cleanup() bool { e.Lock() defer e.Unlock() @@ -160,43 +156,7 @@ func (e *Excludes) cleanup(ctx context.Context) { } // if entries were changed, reset netfilter - if changed { - e.setFilter(ctx) - } -} - -// start starts periodic cleanup of the split excludes. -func (e *Excludes) start() { - defer close(e.closed) - - ctx := context.Background() - timer := time.NewTimer(excludesTimer * time.Second) - for { - select { - case <-timer.C: - e.cleanup(ctx) - timer.Reset(excludesTimer * time.Second) - - case <-e.done: - if !timer.Stop() { - <-timer.C - } - return - } - } -} - -// Start starts periodic cleanup of the split excludes. -func (e *Excludes) Start() { - log.Debug("SplitRouting starting periodic cleanup of excludes") - go e.start() -} - -// Stop stops periodic cleanup of the split excludes. -func (e *Excludes) Stop() { - close(e.done) - <-e.closed - log.Debug("SplitRouting stopped periodic cleanup of excludes") + return changed } // List returns the list of static and dynamic excludes. diff --git a/internal/splitrt/excludes_test.go b/internal/splitrt/excludes_test.go index 54cbd97..26e311d 100644 --- a/internal/splitrt/excludes_test.go +++ b/internal/splitrt/excludes_test.go @@ -1,14 +1,10 @@ package splitrt import ( - "context" - "errors" "net/netip" - "reflect" "testing" "github.com/telekom-mms/oc-daemon/internal/daemoncfg" - "github.com/telekom-mms/oc-daemon/internal/execs" ) // getTestExcludes returns excludes for testing. @@ -63,41 +59,27 @@ func getTestDynamicExcludes(t *testing.T) []netip.Prefix { // TestExcludesAddStatic tests AddStatic of Excludes. func TestExcludesAddStatic(t *testing.T) { - ctx := context.Background() e := NewExcludes(daemoncfg.NewConfig()) excludes := getTestStaticExcludes(t) - // set testing runNft function - got := []string{} - execs.RunCmd = func(_ context.Context, _ string, s string, _ ...string) ([]byte, []byte, error) { - got = append(got, s) - return nil, nil, nil - } - // test adding excludes - want := []string{ - "add element inet oc-daemon-routing excludes4 { 192.168.1.0/24 }", - "add element inet oc-daemon-routing excludes6 { 2001::/64 }", - } for _, exclude := range excludes { - e.AddStatic(ctx, exclude) - } - if !reflect.DeepEqual(got, want) { - t.Errorf("got %v, want %v", got, want) + if !e.AddStatic(exclude) { + t.Errorf("should add exclude %s", exclude) + } } // test adding excludes again, should not change nft commands for _, exclude := range excludes { - e.AddStatic(ctx, exclude) - } - if !reflect.DeepEqual(got, want) { - t.Errorf("got %v, want %v", got, want) + if e.AddStatic(exclude) { + t.Errorf("should not add exclude %s", exclude) + } } // test adding overlapping excludes e = NewExcludes(daemoncfg.NewConfig()) for _, exclude := range getTestStaticExcludesOverlap(t) { - e.AddStatic(ctx, exclude) + e.AddStatic(exclude) } for k := range e.s { if k != "192.168.1.0/24" && k != "2001:2001:2001:2000::/56" { @@ -108,219 +90,131 @@ func TestExcludesAddStatic(t *testing.T) { // TestExcludesAddDynamic tests AddDynamic of Excludes. func TestExcludesAddDynamic(t *testing.T) { - ctx := context.Background() e := NewExcludes(daemoncfg.NewConfig()) excludes := getTestDynamicExcludes(t) - // set testing runNft function - got := []string{} - execs.RunCmd = func(_ context.Context, _ string, s string, _ ...string) ([]byte, []byte, error) { - got = append(got, s) - return nil, nil, nil - } - // test adding excludes - want := []string{ - "add element inet oc-daemon-routing excludes4 { 192.168.1.1/32 }", - "add element inet oc-daemon-routing excludes6 { 2001::1/128 }", - "add element inet oc-daemon-routing excludes4 { 172.16.1.1/32 }", - } for _, exclude := range excludes { - e.AddDynamic(ctx, exclude, 300) - } - if !reflect.DeepEqual(got, want) { - t.Errorf("got %v, want %v", got, want) + if !e.AddDynamic(exclude, 300) { + t.Errorf("should add exclude %s", exclude) + } } // test adding excludes again, should not change nft commands for _, exclude := range excludes { - e.AddDynamic(ctx, exclude, 300) - } - if !reflect.DeepEqual(got, want) { - t.Errorf("got %v, want %v", got, want) + if e.AddDynamic(exclude, 300) { + t.Errorf("should not add exclude %s", exclude) + } } // test adding excludes with existing static excludes, // should only add new excludes + statics := getTestStaticExcludes(t) e = NewExcludes(daemoncfg.NewConfig()) - for _, exclude := range getTestStaticExcludes(t) { - e.AddStatic(ctx, exclude) - } - got = []string{} - want = []string{ - "add element inet oc-daemon-routing excludes4 { 172.16.1.1/32 }", + for _, exclude := range statics { + if !e.AddStatic(exclude) { + t.Errorf("should add exclude %s", exclude) + } } for _, exclude := range excludes { - e.AddDynamic(ctx, exclude, 300) - } - if !reflect.DeepEqual(got, want) { - t.Errorf("got %v, want %v", got, want) + add := true + for _, static := range statics { + if static.Overlaps(exclude) { + add = false + } + } + if add && !e.AddDynamic(exclude, 300) { + t.Errorf("should add exclude %s", exclude) + } + if !add && e.AddDynamic(exclude, 300) { + t.Errorf("should not add exclude %s", exclude) + } } // test adding invalid excludes (static as dynamic) e = NewExcludes(daemoncfg.NewConfig()) - got = []string{} - want = []string{} for _, exclude := range getTestStaticExcludes(t) { - e.AddDynamic(ctx, exclude, 300) - } - if !reflect.DeepEqual(got, want) { - t.Errorf("got %v, want %v", got, want) + if e.AddDynamic(exclude, 300) { + t.Errorf("should not add exclude %s", exclude) + } } } // TestExcludesRemoveStatic tests RemoveStatic of Excludes. func TestExcludesRemove(t *testing.T) { - ctx := context.Background() e := NewExcludes(daemoncfg.NewConfig()) excludes := getTestStaticExcludes(t) - // set testing runNft function - got := []string{} - oldRunCmd := execs.RunCmd - execs.RunCmd = func(_ context.Context, _ string, s string, _ ...string) ([]byte, []byte, error) { - got = append(got, s) - return nil, nil, nil - } - defer func() { execs.RunCmd = oldRunCmd }() - // test removing not existing excludes - want := []string{ - "flush set inet oc-daemon-routing excludes4\n" + - "flush set inet oc-daemon-routing excludes6\n", - "flush set inet oc-daemon-routing excludes4\n" + - "flush set inet oc-daemon-routing excludes6\n", - } for _, exclude := range excludes { - e.RemoveStatic(ctx, exclude) - } - if !reflect.DeepEqual(got, want) { - t.Errorf("got %v, want %v", got, want) + if e.RemoveStatic(exclude) { + t.Errorf("should not remove exclude %s", exclude) + } } // test removing static excludes - got = []string{} - want = []string{ - "add element inet oc-daemon-routing excludes4 { 192.168.1.0/24 }", - "add element inet oc-daemon-routing excludes6 { 2001::/64 }", - "flush set inet oc-daemon-routing excludes4\n" + - "flush set inet oc-daemon-routing excludes6\n" + - "add element inet oc-daemon-routing excludes6 { 2001::/64 }\n", - "flush set inet oc-daemon-routing excludes4\n" + - "flush set inet oc-daemon-routing excludes6\n", - } for _, exclude := range excludes { - e.AddStatic(ctx, exclude) - } - for _, exclude := range excludes { - e.RemoveStatic(ctx, exclude) - } - if !reflect.DeepEqual(got, want) { - t.Errorf("got %v, want %v", got, want) - } - - // test with nft error - got = []string{} - execs.RunCmd = func(_ context.Context, _ string, s string, _ ...string) ([]byte, []byte, error) { - got = append(got, s) - return nil, nil, errors.New("test error") - } - for _, exclude := range excludes { - e.AddStatic(ctx, exclude) + if !e.AddStatic(exclude) { + t.Fatalf("should add exclude %s", exclude) + } } for _, exclude := range excludes { - e.RemoveStatic(ctx, exclude) - } - if !reflect.DeepEqual(got, want) { - t.Errorf("got %v, want %v", got, want) + if !e.RemoveStatic(exclude) { + t.Errorf("should remove exclude %s", exclude) + } } // test removing with dynamic excludes - got = []string{} - want = []string{ - "add element inet oc-daemon-routing excludes4 { 192.168.1.0/24 }", - "add element inet oc-daemon-routing excludes6 { 2001::/64 }", - "add element inet oc-daemon-routing excludes4 { 172.16.1.1/32 }", - "flush set inet oc-daemon-routing excludes4\n" + - "flush set inet oc-daemon-routing excludes6\n" + - "add element inet oc-daemon-routing excludes6 { 2001::/64 }\n" + - "add element inet oc-daemon-routing excludes4 { 172.16.1.1/32 }\n", - "flush set inet oc-daemon-routing excludes4\n" + - "flush set inet oc-daemon-routing excludes6\n" + - "add element inet oc-daemon-routing excludes4 { 172.16.1.1/32 }\n", - } for _, exclude := range excludes { - e.AddStatic(ctx, exclude) + if !e.AddStatic(exclude) { + t.Fatalf("should add exclude %s", exclude) + } } for _, exclude := range getTestDynamicExcludes(t) { - e.AddDynamic(ctx, exclude, 300) + e.AddDynamic(exclude, 300) } for _, exclude := range excludes { - e.RemoveStatic(ctx, exclude) - } - if !reflect.DeepEqual(got, want) { - t.Errorf("got %v, want %v", got, want) + if !e.RemoveStatic(exclude) { + t.Errorf("should remove exclude %s", exclude) + } } } // TestExcludesCleanup tests cleanup of Excludes. func TestExcludesCleanup(t *testing.T) { - ctx := context.Background() e := NewExcludes(daemoncfg.NewConfig()) - // set testing runNft function - got := []string{} - execs.RunCmd = func(_ context.Context, _ string, s string, _ ...string) ([]byte, []byte, error) { - got = append(got, s) - return nil, nil, nil - } - // test without excludes - want := []string{} - e.cleanup(ctx) - if !reflect.DeepEqual(got, want) { - t.Errorf("got %v, want %v", got, want) + if e.cleanup() { + t.Error("should not remove excludes") } // test with dynamic excludes for _, exclude := range getTestDynamicExcludes(t) { - e.AddDynamic(ctx, exclude, excludesTimer) + if !e.AddDynamic(exclude, excludesTimer) { + t.Fatalf("should add exclude %s", exclude) + } } - got = []string{} for i := 0; i <= excludesTimer; i += excludesTimer { - want := []string{} - e.cleanup(ctx) - if !reflect.DeepEqual(got, want) { - t.Errorf("got %v, want %v", got, want) + if e.cleanup() { + t.Error("should not remove excludes") } } - want = []string{ - "flush set inet oc-daemon-routing excludes4\n" + - "flush set inet oc-daemon-routing excludes6\n", - } - e.cleanup(ctx) - if !reflect.DeepEqual(got, want) { - t.Errorf("got %v, want %v", got, want) + if !e.cleanup() { + t.Error("should remove excludes") } // test with static excludes for _, exclude := range getTestStaticExcludes(t) { - e.AddStatic(ctx, exclude) + if !e.AddStatic(exclude) { + t.Fatalf("should add exclude %s", exclude) + } } - got = []string{} - want = []string{} - e.cleanup(ctx) - if !reflect.DeepEqual(got, want) { - t.Errorf("got %v, want %v", got, want) + if e.cleanup() { + t.Error("should not remove excludes") } -} -// TestExcludesStartStop tests Start and Stop of Excludes. -func TestExcludesStartStop(_ *testing.T) { - e := NewExcludes(daemoncfg.NewConfig()) - e.Start() - e.Stop() } // TestNewExcludes tests NewExcludes. diff --git a/internal/splitrt/splitrt.go b/internal/splitrt/splitrt.go index b08a402..51c55dd 100644 --- a/internal/splitrt/splitrt.go +++ b/internal/splitrt/splitrt.go @@ -6,6 +6,7 @@ import ( "fmt" "net/netip" "sync" + "time" log "github.com/sirupsen/logrus" "github.com/telekom-mms/oc-daemon/internal/addrmon" @@ -80,14 +81,13 @@ func (s *SplitRouting) setupRouting(ctx context.Context) { rejectIPv4(ctx, s.config.VPNConfig.Device.Name) } - // add excludes - s.excludes.Start() - // add gateway to static excludes if s.config.VPNConfig.Gateway.IsValid() { gateway := netip.PrefixFrom(s.config.VPNConfig.Gateway, s.config.VPNConfig.Gateway.BitLen()) - s.excludes.AddStatic(ctx, gateway) + if s.excludes.AddStatic(gateway) { + setExcludes(ctx, s.excludes.GetPrefixes()) + } } // add static IPv4 excludes @@ -95,7 +95,9 @@ func (s *SplitRouting) setupRouting(ctx context.Context) { if e.String() == "0.0.0.0/32" { continue } - s.excludes.AddStatic(ctx, e) + if s.excludes.AddStatic(e) { + setExcludes(ctx, s.excludes.GetPrefixes()) + } } // add static IPv6 excludes @@ -104,7 +106,9 @@ func (s *SplitRouting) setupRouting(ctx context.Context) { if e.String() == "::/128" { continue } - s.excludes.AddStatic(ctx, e) + if s.excludes.AddStatic(e) { + setExcludes(ctx, s.excludes.GetPrefixes()) + } } // setup routing @@ -131,9 +135,6 @@ func (s *SplitRouting) teardownRouting(ctx context.Context) { s.config.VPNConfig.Device.Name, s.config.SplitRouting.RoutingTable) unsetRoutingRules(ctx) - - // remove excludes - s.excludes.Stop() } // excludeSettings returns whether local (virtual) networks should be excluded. @@ -184,14 +185,18 @@ func (s *SplitRouting) updateLocalNetworkExcludes(ctx context.Context) { // add new excludes for _, e := range excludes { if !isIn(e, s.locals.get()) { - s.excludes.AddStatic(ctx, e) + if s.excludes.AddStatic(e) { + setExcludes(ctx, s.excludes.GetPrefixes()) + } } } // remove old excludes for _, l := range s.locals.get() { if !isIn(l, excludes) { - s.excludes.RemoveStatic(ctx, l) + if s.excludes.RemoveStatic(l) { + setExcludes(ctx, s.excludes.GetPrefixes()) + } } } @@ -238,7 +243,10 @@ func (s *SplitRouting) handleDNSReport(ctx context.Context, r *dnsproxy.Report) defer r.Close() log.WithField("report", r).Debug("SplitRouting handling DNS report") - s.excludes.AddDynamic(ctx, netip.PrefixFrom(r.IP, r.IP.BitLen()), r.TTL) + exclude := netip.PrefixFrom(r.IP, r.IP.BitLen()) + if s.excludes.AddDynamic(exclude, r.TTL) { + addExclude(ctx, exclude) + } } // start starts split routing. @@ -249,6 +257,7 @@ func (s *SplitRouting) start(ctx context.Context) { defer s.addrmon.Stop() // main loop + timer := time.NewTimer(excludesTimer * time.Second) for { select { case u := <-s.devmon.Updates(): @@ -257,7 +266,13 @@ func (s *SplitRouting) start(ctx context.Context) { s.handleAddressUpdate(ctx, u) case r := <-s.dnsreps: s.handleDNSReport(ctx, r) + case <-timer.C: + s.excludes.cleanup() + timer.Reset(excludesTimer * time.Second) case <-s.done: + if !timer.Stop() { + <-timer.C + } return } } diff --git a/internal/splitrt/splitrt_test.go b/internal/splitrt/splitrt_test.go index 9ec11b8..2f6e04a 100644 --- a/internal/splitrt/splitrt_test.go +++ b/internal/splitrt/splitrt_test.go @@ -84,7 +84,9 @@ func TestSplitRoutingHandleAddressUpdate(t *testing.T) { // test adding want := []string{ - "add element inet oc-daemon-routing excludes4 { 192.168.1.1/32 }", + "flush set inet oc-daemon-routing excludes4\n" + + "flush set inet oc-daemon-routing excludes6\n" + + "add element inet oc-daemon-routing excludes4 { 192.168.1.1/32 }\n", } update := getTestAddrMonUpdate(t, "192.168.1.1/32") s.handleAddressUpdate(ctx, update) @@ -119,7 +121,9 @@ func TestSplitRoutingHandleAddressUpdate(t *testing.T) { // test adding want = []string{ - "add element inet oc-daemon-routing excludes4 { 192.168.1.1/32 }", + "flush set inet oc-daemon-routing excludes4\n" + + "flush set inet oc-daemon-routing excludes6\n" + + "add element inet oc-daemon-routing excludes4 { 192.168.1.1/32 }\n", } update = getTestAddrMonUpdate(t, "192.168.1.1/32") s.handleAddressUpdate(ctx, update)