From 8e129f7134d492f2f479999d7d734932a3ac66c9 Mon Sep 17 00:00:00 2001
From: hwipl <33433250+hwipl@users.noreply.github.com>
Date: Mon, 26 Aug 2024 20:15:30 +0200
Subject: [PATCH] Add method "DumpState" to D-Bus API
Add the method "DumpState" to the D-Bus API that returns the internal
state of the OC-Daemon as a JSON string for debugging.
Signed-off-by: hwipl <33433250+hwipl@users.noreply.github.com>
---
.../com.telekom_mms.oc_daemon.Daemon.conf | 4 +
internal/client/client.go | 18 +++
internal/client/client_test.go | 35 ++++++
internal/client/cmd.go | 2 +
internal/client/cmd_test.go | 1 +
internal/daemon/daemon.go | 34 ++++++
internal/dbusapi/service.go | 41 +++++--
internal/dbusapi/service_test.go | 67 ++++++++++
internal/dnsproxy/proxy.go | 19 +++
internal/dnsproxy/proxy_test.go | 29 +++++
internal/dnsproxy/remotes.go | 12 ++
internal/dnsproxy/watches.go | 17 +++
internal/splitrt/addresses.go | 26 ++++
internal/splitrt/devices.go | 31 +++++
internal/splitrt/excludes.go | 15 +++
internal/splitrt/splitrt.go | 62 +++++++++-
internal/splitrt/splitrt_test.go | 49 ++++++++
internal/trafpol/allowaddrs.go | 46 +++++++
internal/trafpol/allowdevs.go | 9 ++
internal/trafpol/allowdevs_test.go | 21 ++++
internal/trafpol/allownames.go | 30 +++++
internal/trafpol/trafpol.go | 115 +++++++++++++-----
internal/trafpol/trafpol_test.go | 41 +++++--
internal/vpnsetup/vpnsetup.go | 33 +++++
internal/vpnsetup/vpnsetup_test.go | 50 ++++++++
pkg/client/client.go | 16 +++
pkg/client/client_test.go | 18 +++
27 files changed, 788 insertions(+), 53 deletions(-)
create mode 100644 internal/trafpol/allowaddrs.go
create mode 100644 internal/trafpol/allownames.go
diff --git a/configs/dbus/com.telekom_mms.oc_daemon.Daemon.conf b/configs/dbus/com.telekom_mms.oc_daemon.Daemon.conf
index 0398fbd2..0daa096b 100644
--- a/configs/dbus/com.telekom_mms.oc_daemon.Daemon.conf
+++ b/configs/dbus/com.telekom_mms.oc_daemon.Daemon.conf
@@ -18,6 +18,10 @@
+
+
diff --git a/internal/client/client.go b/internal/client/client.go
index 28ae158d..fad9f8a7 100644
--- a/internal/client/client.go
+++ b/internal/client/client.go
@@ -221,6 +221,24 @@ func getStatus() error {
return printStatus(status)
}
+// dumpState dumps the internal state of the daemon.
+func dumpState() error {
+ // create client
+ c, err := clientNewClient(config)
+ if err != nil {
+ return fmt.Errorf("error creating client: %w", err)
+ }
+ defer func() { _ = c.Close() }()
+
+ // get status
+ state, err := c.DumpState()
+ if err != nil {
+ return fmt.Errorf("error getting status: %w", err)
+ }
+ fmt.Println(state)
+ return nil
+}
+
// monitor subscribes to VPN status updates from the daemon and displays them.
func monitor() error {
// create client
diff --git a/internal/client/client_test.go b/internal/client/client_test.go
index 3a0a2229..c4ee5b61 100644
--- a/internal/client/client_test.go
+++ b/internal/client/client_test.go
@@ -14,6 +14,8 @@ import (
type testClient struct {
querErr error
status *vpnstatus.Status
+ dumpErr error
+ dumpSta string
authErr error
connErr error
discErr error
@@ -33,6 +35,7 @@ func (t *testClient) Subscribe() (chan *vpnstatus.Status, error) { return t.subs
func (t *testClient) Authenticate() error { return t.authErr }
func (t *testClient) Connect() error { return t.connErr }
func (t *testClient) Disconnect() error { return t.discErr }
+func (t *testClient) DumpState() (string, error) { return t.dumpSta, t.dumpErr }
func (t *testClient) Close() error { return nil }
// TestListServers tests listServers.
@@ -205,6 +208,38 @@ func TestGetStatus(t *testing.T) {
}
}
+// TestDumpState tests dumpState.
+func TestDumpState(t *testing.T) {
+ defer func() { clientNewClient = client.NewClient }()
+
+ // test with client error
+ clientNewClient = func(*client.Config) (client.Client, error) {
+ return nil, errors.New("test error")
+ }
+
+ if err := dumpState(); err == nil {
+ t.Error("client error should return error")
+ }
+
+ // test with dump state error
+ clientNewClient = func(*client.Config) (client.Client, error) {
+ return &testClient{dumpErr: errors.New("test error")}, nil
+ }
+
+ if err := dumpState(); err == nil {
+ t.Error("dump state error should return error")
+ }
+
+ // test without error
+ clientNewClient = func(*client.Config) (client.Client, error) {
+ return &testClient{dumpSta: "test state"}, nil
+ }
+
+ if err := dumpState(); err != nil {
+ t.Error("dump state should not return error")
+ }
+}
+
// TestMonitor tests monitor.
func TestMonitor(t *testing.T) {
defer func() { clientNewClient = client.NewClient }()
diff --git a/internal/client/cmd.go b/internal/client/cmd.go
index ce3da869..cea26f3f 100644
--- a/internal/client/cmd.go
+++ b/internal/client/cmd.go
@@ -243,6 +243,8 @@ func run(args []string) error {
return reconnectVPN()
case "status":
return getStatus()
+ case "dumpstate":
+ return dumpState()
case "monitor":
return monitor()
case "save":
diff --git a/internal/client/cmd_test.go b/internal/client/cmd_test.go
index 5f0b270a..c9e2eca4 100644
--- a/internal/client/cmd_test.go
+++ b/internal/client/cmd_test.go
@@ -126,6 +126,7 @@ func TestRun(t *testing.T) {
"reconnect",
"status",
"monitor",
+ "dumpstate",
} {
if err := run([]string{"test",
"-cert", "cert-file",
diff --git a/internal/daemon/daemon.go b/internal/daemon/daemon.go
index a2f3cd0e..7da02589 100644
--- a/internal/daemon/daemon.go
+++ b/internal/daemon/daemon.go
@@ -3,6 +3,7 @@ package daemon
import (
"context"
+ "encoding/json"
"fmt"
"net"
"net/netip"
@@ -473,6 +474,33 @@ func (d *Daemon) handleClientRequest(request *api.Request) {
}
}
+// dumpState returns the internal daemon state as json string.
+func (d *Daemon) dumpState() string {
+ // define state type
+ type State struct {
+ TrafficPolicing *trafpol.State
+ VPNSetup *vpnsetup.State
+ }
+
+ // collect internal state
+ state := State{}
+ if d.trafpol != nil {
+ state.TrafficPolicing = d.trafpol.GetState()
+ }
+ if d.vpnsetup != nil {
+ state.VPNSetup = d.vpnsetup.GetState()
+ }
+
+ // convert to json
+ b, err := json.Marshal(state)
+ if err != nil {
+ log.WithError(err).Error("Daemon could not convert internal state to JSON")
+ return ""
+ }
+
+ return string(b)
+}
+
// handleDBusRequest handles a D-Bus API client request.
func (d *Daemon) handleDBusRequest(request *dbusapi.Request) {
defer request.Close()
@@ -505,6 +533,12 @@ func (d *Daemon) handleDBusRequest(request *dbusapi.Request) {
// diconnect VPN
log.Info("Daemon got disconnect request from client")
d.disconnectVPN()
+
+ case dbusapi.RequestDumpState:
+ // dump state
+ state := d.dumpState()
+ log.WithField("state", state).Info("Daemon got dump state request from client")
+ request.Results = []any{state}
}
}
diff --git a/internal/dbusapi/service.go b/internal/dbusapi/service.go
index bb449fc9..fc5e8cdd 100644
--- a/internal/dbusapi/service.go
+++ b/internal/dbusapi/service.go
@@ -130,12 +130,14 @@ const (
const (
MethodConnect = Interface + ".Connect"
MethodDisconnect = Interface + ".Disconnect"
+ MethodDumpState = Interface + ".DumpState"
)
// Request Names.
const (
RequestConnect = "Connect"
RequestDisconnect = "Disconnect"
+ RequestDumpState = "DumpState"
)
// Request is a D-Bus client request.
@@ -212,6 +214,27 @@ func (d daemon) Disconnect(sender dbus.Sender) *dbus.Error {
return nil
}
+// DumpState is the "DumpState" method of the D-Bus interface.
+func (d daemon) DumpState(sender dbus.Sender) (string, *dbus.Error) {
+ log.WithField("sender", sender).Debug("Received D-Bus DumpState() call")
+ request := &Request{
+ Name: RequestDumpState,
+ wait: make(chan struct{}),
+ done: d.done,
+ }
+ select {
+ case d.requests <- request:
+ case <-d.done:
+ return "", dbus.NewError(Interface+".DumpStateAborted", []any{"DumpState aborted"})
+ }
+
+ request.Wait()
+ if request.Error != nil {
+ return "", dbus.NewError(Interface+".DumpStateAborted", []any{request.Error.Error()})
+ }
+ return request.Results[0].(string), nil
+}
+
// propertyUpdate is an update of a property.
type propertyUpdate struct {
name string
@@ -431,16 +454,18 @@ func (s *Service) Start() error {
// set names of method arguments
introMeths := introspect.Methods(meths)
for _, m := range introMeths {
- if m.Name != "Connect" {
- continue
+ if m.Name == "Connect" {
+ m.Args[0].Name = "server"
+ m.Args[1].Name = "cookie"
+ m.Args[2].Name = "host"
+ m.Args[3].Name = "connect_url"
+ m.Args[4].Name = "fingerprint"
+ m.Args[5].Name = "resolve"
}
- m.Args[0].Name = "server"
- m.Args[1].Name = "cookie"
- m.Args[2].Name = "host"
- m.Args[3].Name = "connect_url"
- m.Args[4].Name = "fingerprint"
- m.Args[5].Name = "resolve"
+ if m.Name == "DumpState" {
+ m.Args[0].Name = "state"
+ }
}
// set peer interface
peerData := introspect.Interface{
diff --git a/internal/dbusapi/service_test.go b/internal/dbusapi/service_test.go
index 41a9b5ce..32e4f852 100644
--- a/internal/dbusapi/service_test.go
+++ b/internal/dbusapi/service_test.go
@@ -167,6 +167,73 @@ func TestDaemonDisconnect(t *testing.T) {
}
}
+// TestDaemonDumpStateErrors tests DumpState of daemon, errors.
+func TestDaemonDumpStateErrors(t *testing.T) {
+ // create daemon
+ requests := make(chan *Request)
+ done := make(chan struct{})
+ daemon := daemon{
+ requests: requests,
+ done: done,
+ }
+
+ // error when handling request
+ go func() {
+ r := <-requests
+ r.Error = errors.New("test error")
+ r.Close()
+ }()
+ if _, err := daemon.DumpState(""); err == nil {
+ t.Error("should return error")
+ }
+
+ // closed daemon
+ close(done)
+ if _, err := daemon.DumpState(""); err == nil {
+ t.Error("should return error")
+ }
+}
+
+// TestDaemonDumpState tests DumpState of daemon.
+func TestDaemonDumpState(t *testing.T) {
+ // create daemon
+ requests := make(chan *Request)
+ done := make(chan struct{})
+ daemon := daemon{
+ requests: requests,
+ done: done,
+ }
+
+ // run disconnect and get results
+ want := &Request{
+ Name: RequestDumpState,
+ Results: []any{"test state"},
+ done: done,
+ }
+ got := &Request{}
+ go func() {
+ r := <-requests
+ r.Results = append(r.Results, "test state")
+ got = r
+ r.Close()
+ }()
+ state, err := daemon.DumpState("sender")
+ if err != nil {
+ t.Error(err)
+ }
+
+ // check results
+ if got.Name != want.Name ||
+ !reflect.DeepEqual(got.Parameters, want.Parameters) ||
+ !reflect.DeepEqual(got.Results, want.Results) ||
+ got.Error != want.Error ||
+ got.done != want.done ||
+ state != "test state" {
+ // not equal
+ t.Errorf("got %v, want %v", got, want)
+ }
+}
+
// testConn implements the dbusConn interface for testing.
type testConn struct {
reqNameReply dbus.RequestNameReply
diff --git a/internal/dnsproxy/proxy.go b/internal/dnsproxy/proxy.go
index be836fd8..0e72c465 100644
--- a/internal/dnsproxy/proxy.go
+++ b/internal/dnsproxy/proxy.go
@@ -9,6 +9,14 @@ import (
log "github.com/sirupsen/logrus"
)
+// State is the internal state of the DNS Proxy.
+type State struct {
+ Config *Config
+ Remotes map[string][]string
+ Watches []string
+ TempWatches []string
+}
+
// Proxy is a DNS proxy.
type Proxy struct {
config *Config
@@ -247,6 +255,17 @@ func (p *Proxy) SetWatches(watches []string) {
}
}
+// GetState returns the internal state of the DNS Proxy.
+func (p *Proxy) GetState() *State {
+ watches, tempWatches := p.watches.List()
+ return &State{
+ Config: p.config,
+ Remotes: p.remotes.List(),
+ Watches: watches,
+ TempWatches: tempWatches,
+ }
+}
+
// NewProxy returns a new Proxy that listens on address.
func NewProxy(config *Config) *Proxy {
var udp *dns.Server
diff --git a/internal/dnsproxy/proxy_test.go b/internal/dnsproxy/proxy_test.go
index 217b7ee2..5a580e80 100644
--- a/internal/dnsproxy/proxy_test.go
+++ b/internal/dnsproxy/proxy_test.go
@@ -4,6 +4,7 @@ import (
"errors"
"net"
"net/netip"
+ "reflect"
"testing"
"github.com/miekg/dns"
@@ -272,6 +273,34 @@ func TestProxySetWatches(_ *testing.T) {
p.SetWatches(watches)
}
+// TestProxyGetState tests GetState of Proxy.
+func TestProxyGetState(t *testing.T) {
+ p := NewProxy(getTestConfig())
+
+ // set remotes
+ getRemotes := func() map[string][]string {
+ return map[string][]string{".": {"192.168.1.1"}}
+ }
+ p.SetRemotes(getRemotes())
+
+ // set watches
+ p.watches.Add("example.com.")
+ p.watches.AddTempCNAME("cname.example.com.", 300)
+ p.watches.AddTempDNAME("dname.example.com.", 300)
+
+ // check state
+ want := &State{
+ Config: getTestConfig(),
+ Remotes: getRemotes(),
+ Watches: []string{"example.com."},
+ TempWatches: []string{"cname.example.com.", "dname.example.com."},
+ }
+ got := p.GetState()
+ if !reflect.DeepEqual(got, want) {
+ t.Errorf("got %v, want %v", got, want)
+ }
+}
+
// TestNewProxy tests NewProxy.
func TestNewProxy(t *testing.T) {
p := NewProxy(getTestConfig())
diff --git a/internal/dnsproxy/remotes.go b/internal/dnsproxy/remotes.go
index 40b8838e..48096433 100644
--- a/internal/dnsproxy/remotes.go
+++ b/internal/dnsproxy/remotes.go
@@ -67,6 +67,18 @@ func (r *Remotes) Get(domain string) []string {
return r.m["."]
}
+// List returns the remotes.
+func (r *Remotes) List() map[string][]string {
+ r.RLock()
+ defer r.RUnlock()
+
+ remotes := make(map[string][]string)
+ for k, v := range r.m {
+ remotes[k] = append(remotes[k], v...)
+ }
+ return remotes
+}
+
// NewRemotes returns a new Remotes.
func NewRemotes() *Remotes {
return &Remotes{
diff --git a/internal/dnsproxy/watches.go b/internal/dnsproxy/watches.go
index f5ff80f9..dd2d98e5 100644
--- a/internal/dnsproxy/watches.go
+++ b/internal/dnsproxy/watches.go
@@ -168,6 +168,23 @@ func (w *Watches) Contains(domain string) bool {
return false
}
+// List returns the list of watches and temporary watches.
+func (w *Watches) List() (watches, tempWatches []string) {
+ w.RLock()
+ defer w.RUnlock()
+
+ for k := range w.m {
+ watches = append(watches, k)
+ }
+ for k := range w.c {
+ tempWatches = append(tempWatches, k)
+ }
+ for k := range w.d {
+ tempWatches = append(tempWatches, k)
+ }
+ return
+}
+
// Close closes the watches
func (w *Watches) Close() {
close(w.done)
diff --git a/internal/splitrt/addresses.go b/internal/splitrt/addresses.go
index e009db88..e737ec3b 100644
--- a/internal/splitrt/addresses.go
+++ b/internal/splitrt/addresses.go
@@ -2,12 +2,14 @@ package splitrt
import (
"net/netip"
+ "sync"
"github.com/telekom-mms/oc-daemon/internal/addrmon"
)
// Addresses is a set of addresses.
type Addresses struct {
+ sync.Mutex
m map[int][]*addrmon.Update
}
@@ -26,6 +28,9 @@ func (a *Addresses) contains(addr *addrmon.Update) bool {
// Add adds address info in addr to addresses.
func (a *Addresses) Add(addr *addrmon.Update) {
+ a.Lock()
+ defer a.Unlock()
+
if a.contains(addr) {
return
}
@@ -34,6 +39,9 @@ func (a *Addresses) Add(addr *addrmon.Update) {
// Remove removes address info in addr from addresses.
func (a *Addresses) Remove(addr *addrmon.Update) {
+ a.Lock()
+ defer a.Unlock()
+
if !a.contains(addr) {
return
}
@@ -52,12 +60,30 @@ func (a *Addresses) Remove(addr *addrmon.Update) {
// Get returns the addresses of the device identified by index.
func (a *Addresses) Get(index int) (addrs []netip.Prefix) {
+ a.Lock()
+ defer a.Unlock()
+
for _, x := range a.m[index] {
addrs = append(addrs, x.Address)
}
return
}
+// List returns all addresses.
+func (a *Addresses) List() []*addrmon.Update {
+ a.Lock()
+ defer a.Unlock()
+
+ var addresses []*addrmon.Update
+ for _, v := range a.m {
+ for _, addr := range v {
+ address := *addr
+ addresses = append(addresses, &address)
+ }
+ }
+ return addresses
+}
+
// NewAddresses returns new Addresses.
func NewAddresses() *Addresses {
return &Addresses{
diff --git a/internal/splitrt/devices.go b/internal/splitrt/devices.go
index 64b1352f..3f5d2e61 100644
--- a/internal/splitrt/devices.go
+++ b/internal/splitrt/devices.go
@@ -1,21 +1,30 @@
package splitrt
import (
+ "sync"
+
"github.com/telekom-mms/oc-daemon/internal/devmon"
)
// Devices is a set of devices.
type Devices struct {
+ sync.Mutex
m map[int]*devmon.Update
}
// Add adds device info in dev to devices.
func (d *Devices) Add(dev *devmon.Update) {
+ d.Lock()
+ defer d.Unlock()
+
d.m[dev.Index] = dev
}
// Remove removes device info in dev from devices.
func (d *Devices) Remove(dev *devmon.Update) {
+ d.Lock()
+ defer d.Unlock()
+
delete(d.m, dev.Index)
}
@@ -35,19 +44,41 @@ func (d *Devices) getType(match bool, typ string) (indexes []int) {
// GetReal returns device indexes of all real devices.
func (d *Devices) GetReal() []int {
+ d.Lock()
+ defer d.Unlock()
+
return d.getType(true, "device")
}
// GetVirtual returns device indexes of all virtual devices.
func (d *Devices) GetVirtual() []int {
+ d.Lock()
+ defer d.Unlock()
+
return d.getType(false, "device")
}
// GetAll returns all device indexes.
func (d *Devices) GetAll() []int {
+ d.Lock()
+ defer d.Unlock()
+
return d.getType(false, "")
}
+// List returns all devices.
+func (d *Devices) List() []*devmon.Update {
+ d.Lock()
+ defer d.Unlock()
+
+ var devices []*devmon.Update
+ for _, v := range d.m {
+ device := *v
+ devices = append(devices, &device)
+ }
+ return devices
+}
+
// NewDevices returns new Devices.
func NewDevices() *Devices {
return &Devices{
diff --git a/internal/splitrt/excludes.go b/internal/splitrt/excludes.go
index 63c3ad54..d1a1f789 100644
--- a/internal/splitrt/excludes.go
+++ b/internal/splitrt/excludes.go
@@ -197,6 +197,21 @@ func (e *Excludes) Stop() {
log.Debug("SplitRouting stopped periodic cleanup of excludes")
}
+// List returns the list of static and dynamic excludes.
+func (e *Excludes) List() (static, dynamic []string) {
+ e.Lock()
+ defer e.Unlock()
+
+ for k := range e.s {
+ static = append(static, k)
+ }
+ for k := range e.d {
+ dynamic = append(dynamic, k.String())
+ }
+
+ return
+}
+
// NewExcludes returns new split excludes.
func NewExcludes() *Excludes {
return &Excludes{
diff --git a/internal/splitrt/splitrt.go b/internal/splitrt/splitrt.go
index f77820a0..bb06d2b3 100644
--- a/internal/splitrt/splitrt.go
+++ b/internal/splitrt/splitrt.go
@@ -5,6 +5,7 @@ import (
"context"
"fmt"
"net/netip"
+ "sync"
log "github.com/sirupsen/logrus"
"github.com/telekom-mms/oc-daemon/internal/addrmon"
@@ -13,6 +14,39 @@ import (
"github.com/telekom-mms/oc-daemon/pkg/vpnconfig"
)
+// State is the internal state.
+type State struct {
+ Config *Config
+ VPNConfig *vpnconfig.Config
+ Devices []*devmon.Update
+ Addresses []*addrmon.Update
+ LocalExcludes []string
+ StaticExcludes []string
+ DynamicExcludes []string
+}
+
+// locals are local excludes.
+type locals struct {
+ sync.Mutex
+ l []netip.Prefix
+}
+
+// get returns the local excludes, returned slice should not be modified.
+func (l *locals) get() []netip.Prefix {
+ l.Lock()
+ defer l.Unlock()
+
+ return l.l
+}
+
+// set sets the local excludes.
+func (l *locals) set(locals []netip.Prefix) {
+ l.Lock()
+ defer l.Unlock()
+
+ l.l = locals
+}
+
// SplitRouting is a split routing configuration.
type SplitRouting struct {
config *Config
@@ -21,7 +55,7 @@ type SplitRouting struct {
addrmon *addrmon.AddrMon
devices *Devices
addrs *Addresses
- locals []netip.Prefix
+ locals locals
excludes *Excludes
dnsreps chan *dnsproxy.Report
done chan struct{}
@@ -151,21 +185,21 @@ func (s *SplitRouting) updateLocalNetworkExcludes(ctx context.Context) {
// add new excludes
for _, e := range excludes {
- if !isIn(e, s.locals) {
+ if !isIn(e, s.locals.get()) {
s.excludes.AddStatic(ctx, e)
}
}
// remove old excludes
- for _, l := range s.locals {
+ for _, l := range s.locals.get() {
if !isIn(l, excludes) {
s.excludes.RemoveStatic(ctx, l)
}
}
// save local excludes
- s.locals = excludes
- log.WithField("locals", s.locals).Debug("SplitRouting updated exclude local networks")
+ s.locals.set(excludes)
+ log.WithField("locals", s.locals.get()).Debug("SplitRouting updated exclude local networks")
}
// handleDeviceUpdate handles a device update from the device monitor.
@@ -270,6 +304,24 @@ func (s *SplitRouting) DNSReports() chan *dnsproxy.Report {
return s.dnsreps
}
+// GetState returns the internal state.
+func (s *SplitRouting) GetState() *State {
+ var locals []string
+ for _, l := range s.locals.get() {
+ locals = append(locals, l.String())
+ }
+ static, dynamic := s.excludes.List()
+ return &State{
+ Config: s.config,
+ VPNConfig: s.vpnconfig,
+ Devices: s.devices.List(),
+ Addresses: s.addrs.List(),
+ LocalExcludes: locals,
+ StaticExcludes: static,
+ DynamicExcludes: dynamic,
+ }
+}
+
// NewSplitRouting returns a new SplitRouting.
func NewSplitRouting(config *Config, vpnconfig *vpnconfig.Config) *SplitRouting {
return &SplitRouting{
diff --git a/internal/splitrt/splitrt_test.go b/internal/splitrt/splitrt_test.go
index 54f8e433..e7832af0 100644
--- a/internal/splitrt/splitrt_test.go
+++ b/internal/splitrt/splitrt_test.go
@@ -284,6 +284,55 @@ func TestSplitRoutingDNSReports(t *testing.T) {
}
}
+// TestSplitRoutingGetState tests GetState of SplitRouting.
+func TestSplitRoutingGetState(t *testing.T) {
+ s := NewSplitRouting(NewConfig(), vpnconfig.New())
+
+ // set devices
+ dev := &devmon.Update{
+ Add: true,
+ Device: "test",
+ Type: "test",
+ Index: 1,
+ }
+ s.devices.Add(dev)
+
+ // set addresses
+ addr := &addrmon.Update{
+ Add: true,
+ Address: netip.MustParsePrefix("192.168.1.0/24"),
+ Index: 1,
+ }
+ s.addrs.Add(addr)
+
+ // set local excludes
+ locals := []netip.Prefix{netip.MustParsePrefix("10.0.0.0/24")}
+ s.locals.set(locals)
+
+ // set static excludes
+ static := netip.MustParsePrefix("10.1.0.0/24")
+ s.excludes.s[static.String()] = static
+
+ // set dynamic excludes
+ dynamic := netip.MustParseAddr("10.2.0.1")
+ s.excludes.d[dynamic] = &dynExclude{}
+
+ // get and check state
+ want := &State{
+ Config: NewConfig(),
+ VPNConfig: vpnconfig.New(),
+ Devices: []*devmon.Update{dev},
+ Addresses: []*addrmon.Update{addr},
+ LocalExcludes: []string{"10.0.0.0/24"},
+ StaticExcludes: []string{"10.1.0.0/24"},
+ DynamicExcludes: []string{"10.2.0.1"},
+ }
+ got := s.GetState()
+ if !reflect.DeepEqual(got, want) {
+ t.Errorf("got %v, want %v", got, want)
+ }
+}
+
// TestNewSplitRouting tests NewSplitRouting.
func TestNewSplitRouting(t *testing.T) {
config := NewConfig()
diff --git a/internal/trafpol/allowaddrs.go b/internal/trafpol/allowaddrs.go
new file mode 100644
index 00000000..f8cc3550
--- /dev/null
+++ b/internal/trafpol/allowaddrs.go
@@ -0,0 +1,46 @@
+package trafpol
+
+import "net/netip"
+
+// AllowAddrs are allowed addresses.
+type AllowAddrs struct {
+ m map[string]netip.Prefix
+}
+
+// Add adds prefix to the allowed addresses.
+func (a *AllowAddrs) Add(prefix netip.Prefix) bool {
+ s := prefix.String()
+ if _, ok := a.m[s]; ok {
+ // ip already in allowed addrs
+ return false
+ }
+ a.m[s] = prefix
+ return true
+}
+
+// Remove removes prefix from the allowed addresses.
+func (a *AllowAddrs) Remove(prefix netip.Prefix) bool {
+ s := prefix.String()
+ if _, ok := a.m[s]; !ok {
+ // ip not in allowed addrs
+ return false
+ }
+ delete(a.m, s)
+ return true
+}
+
+// List returns a list of all allowed addresses.
+func (a *AllowAddrs) List() []netip.Prefix {
+ var prefixes []netip.Prefix
+ for _, p := range a.m {
+ prefixes = append(prefixes, p)
+ }
+ return prefixes
+}
+
+// NewAllowAddrs returns new AllowAddrs.
+func NewAllowAddrs() *AllowAddrs {
+ return &AllowAddrs{
+ m: make(map[string]netip.Prefix),
+ }
+}
diff --git a/internal/trafpol/allowdevs.go b/internal/trafpol/allowdevs.go
index ea54fef6..a916450e 100644
--- a/internal/trafpol/allowdevs.go
+++ b/internal/trafpol/allowdevs.go
@@ -23,6 +23,15 @@ func (a *AllowDevs) Remove(ctx context.Context, device string) {
}
}
+// List returns a slice of all allowed devices.
+func (a *AllowDevs) List() []string {
+ var l []string
+ for _, v := range a.m {
+ l = append(l, v)
+ }
+ return l
+}
+
// NewAllowDevs returns new allowDevs.
func NewAllowDevs() *AllowDevs {
return &AllowDevs{
diff --git a/internal/trafpol/allowdevs_test.go b/internal/trafpol/allowdevs_test.go
index b50cc7c5..1582fd5a 100644
--- a/internal/trafpol/allowdevs_test.go
+++ b/internal/trafpol/allowdevs_test.go
@@ -3,6 +3,7 @@ package trafpol
import (
"context"
"reflect"
+ "slices"
"testing"
"github.com/telekom-mms/oc-daemon/internal/execs"
@@ -66,6 +67,26 @@ func TestAllowDevsRemove(t *testing.T) {
}
}
+// TestAllowDevsList tests List of AllowDevs.
+func TestAllowDevsList(t *testing.T) {
+ a := NewAllowDevs()
+ ctx := context.Background()
+
+ oldRunCmd := execs.RunCmd
+ execs.RunCmd = func(_ context.Context, _, _ string, _ ...string) ([]byte, []byte, error) {
+ return nil, nil, nil
+ }
+ defer func() { execs.RunCmd = oldRunCmd }()
+
+ a.Add(ctx, "test")
+
+ want := []string{"test"}
+ got := a.List()
+ if !slices.Equal(got, want) {
+ t.Errorf("got %v, want %v", got, want)
+ }
+}
+
// TestNewAllowDevs tests NewAllowDevs.
func TestNewAllowDevs(t *testing.T) {
a := NewAllowDevs()
diff --git a/internal/trafpol/allownames.go b/internal/trafpol/allownames.go
new file mode 100644
index 00000000..47bc77e4
--- /dev/null
+++ b/internal/trafpol/allownames.go
@@ -0,0 +1,30 @@
+package trafpol
+
+import "net/netip"
+
+// AllowNames are allowed DNS names.
+type AllowNames struct {
+ m map[string][]netip.Addr
+}
+
+// Add adds and updates the allowed name and its IP addresses.
+func (a *AllowNames) Add(name string, addrs []netip.Addr) {
+ a.m[name] = addrs
+}
+
+// GetAll returns all allowed names with their IP addresses.
+func (a *AllowNames) GetAll() map[string][]netip.Addr {
+ names := make(map[string][]netip.Addr)
+ for k, v := range a.m {
+ names[k] = append(names[k], v...)
+
+ }
+ return names
+}
+
+// NewAllowNames returns new AllowNames.
+func NewAllowNames() *AllowNames {
+ return &AllowNames{
+ m: make(map[string][]netip.Addr),
+ }
+}
diff --git a/internal/trafpol/trafpol.go b/internal/trafpol/trafpol.go
index a904baeb..60c8d9f1 100644
--- a/internal/trafpol/trafpol.go
+++ b/internal/trafpol/trafpol.go
@@ -12,12 +12,29 @@ import (
"github.com/telekom-mms/oc-daemon/internal/dnsmon"
)
-// trafPolAddrCmd is a TrafPol address command.
-type trafPolAddrCmd struct {
- add bool
- ip netip.Addr
- ok bool
- done chan struct{}
+// TrafPol command types.
+const (
+ trafPolCmdAddAddress uint8 = iota + 1
+ trafPolCmdRemoveAddress
+ trafPolCmdGetState
+)
+
+// State is the internal TrafPol state.
+type State struct {
+ Config *Config
+ CaptivePortal bool
+ AllowedDevices []string
+ AllowedAddresses []netip.Prefix
+ AllowedNames map[string][]netip.Addr
+}
+
+// trafPolCmd is a TrafPol command.
+type trafPolCmd struct {
+ typ uint8
+ ip netip.Addr
+ ok bool
+ state *State
+ done chan struct{}
}
// TrafPol is a traffic policing component.
@@ -32,15 +49,15 @@ type TrafPol struct {
// allowed devices, addresses, names
allowDevs *AllowDevs
- allowAddrs map[string]netip.Prefix
- allowNames map[string][]netip.Addr
+ allowAddrs *AllowAddrs
+ allowNames *AllowNames
// resolver for allowed names, channel for resolver updates
resolver *Resolver
resolvUp chan *ResolvedName
// address commands channel
- cmds chan *trafPolAddrCmd
+ cmds chan *trafPolCmd
loopDone chan struct{}
done chan struct{}
@@ -101,13 +118,13 @@ func (t *TrafPol) getAllowedHostsIPs() []netip.Prefix {
// - allowed names
// - allowed addrs
ipset := make(map[string]netip.Prefix)
- for _, n := range t.allowNames {
+ for _, n := range t.allowNames.GetAll() {
for _, ip := range n {
prefix := netip.PrefixFrom(ip, ip.BitLen())
ipset[prefix.String()] = prefix
}
}
- for _, a := range t.allowAddrs {
+ for _, a := range t.allowAddrs.List() {
ipset[a.String()] = a
}
@@ -123,33 +140,28 @@ func (t *TrafPol) getAllowedHostsIPs() []netip.Prefix {
// handleResolverUpdate handles a resolver update.
func (t *TrafPol) handleResolverUpdate(ctx context.Context, update *ResolvedName) {
// update allowed names
- t.allowNames[update.Name] = update.IPs
+ t.allowNames.Add(update.Name, update.IPs)
// set new filter rules
setAllowedIPs(ctx, t.getAllowedHostsIPs())
}
// handleAddressCommand handles an address command.
-func (t *TrafPol) handleAddressCommand(ctx context.Context, cmd *trafPolAddrCmd) {
- defer close(cmd.done)
-
+func (t *TrafPol) handleAddressCommand(ctx context.Context, cmd *trafPolCmd) {
// convert to prefix
prefix := netip.PrefixFrom(cmd.ip, cmd.ip.BitLen())
// update allowed addrs
- s := prefix.String()
- if cmd.add {
- if _, ok := t.allowAddrs[s]; ok {
+ if cmd.typ == trafPolCmdAddAddress {
+ if ok := t.allowAddrs.Add(prefix); !ok {
// ip already in allowed addrs
return
}
- t.allowAddrs[s] = prefix
} else {
- if _, ok := t.allowAddrs[s]; !ok {
+ if ok := t.allowAddrs.Remove(prefix); !ok {
// ip not in allowed addrs
return
}
- delete(t.allowAddrs, s)
}
// set new filter rules
@@ -159,6 +171,30 @@ func (t *TrafPol) handleAddressCommand(ctx context.Context, cmd *trafPolAddrCmd)
cmd.ok = true
}
+// handleGetStateCommand handles a get state command.
+func (t *TrafPol) handleGetStateCommand(cmd *trafPolCmd) {
+ // set state
+ cmd.state = &State{
+ Config: t.config,
+ CaptivePortal: t.capPortal,
+ AllowedDevices: t.allowDevs.List(),
+ AllowedAddresses: t.allowAddrs.List(),
+ AllowedNames: t.allowNames.GetAll(),
+ }
+}
+
+// handleCommand handles a command.
+func (t *TrafPol) handleCommand(ctx context.Context, cmd *trafPolCmd) {
+ defer close(cmd.done)
+
+ switch cmd.typ {
+ case trafPolCmdAddAddress, trafPolCmdRemoveAddress:
+ t.handleAddressCommand(ctx, cmd)
+ case trafPolCmdGetState:
+ t.handleGetStateCommand(cmd)
+ }
+}
+
// start starts the traffic policing component.
func (t *TrafPol) start(ctx context.Context) {
defer close(t.loopDone)
@@ -192,9 +228,9 @@ func (t *TrafPol) start(ctx context.Context) {
t.handleResolverUpdate(ctx, u)
case c := <-t.cmds:
- // Address Command
- log.WithField("command", c).Debug("TrafPol got address command")
- t.handleAddressCommand(ctx, c)
+ // Command
+ log.WithField("command", c).Debug("TrafPol got command")
+ t.handleCommand(ctx, c)
case <-t.done:
// shutdown
@@ -264,8 +300,8 @@ func (t *TrafPol) AddAllowedAddr(addr netip.Addr) (ok bool) {
log.WithField("addr", addr).
Debug("TrafPol adding IP to allowed addresses")
- c := &trafPolAddrCmd{
- add: true,
+ c := &trafPolCmd{
+ typ: trafPolCmdAddAddress,
ip: addr,
done: make(chan struct{}),
}
@@ -280,7 +316,8 @@ func (t *TrafPol) RemoveAllowedAddr(addr netip.Addr) (ok bool) {
log.WithField("addr", addr).
Debug("TrafPol removing IP from allowed addresses")
- c := &trafPolAddrCmd{
+ c := &trafPolCmd{
+ typ: trafPolCmdRemoveAddress,
ip: addr,
done: make(chan struct{}),
}
@@ -290,6 +327,20 @@ func (t *TrafPol) RemoveAllowedAddr(addr netip.Addr) (ok bool) {
return c.ok
}
+// GetState returns the internal TrafPol state.
+func (t *TrafPol) GetState() *State {
+ log.Debug("TrafPol getting internal state")
+
+ c := &trafPolCmd{
+ typ: trafPolCmdGetState,
+ done: make(chan struct{}),
+ }
+ t.cmds <- c
+ <-c.done
+
+ return c.state
+}
+
// parseAllowedHosts parses the allowed hosts and returns IP addresses and DNS names
func parseAllowedHosts(hosts []string) (addrs []netip.Prefix, names []string) {
for _, h := range hosts {
@@ -321,13 +372,13 @@ func NewTrafPol(config *Config) *TrafPol {
a, n := parseAllowedHosts(hosts)
// create allowed addrs and names
- addrs := make(map[string]netip.Prefix)
- names := make(map[string][]netip.Addr)
+ addrs := NewAllowAddrs()
+ names := NewAllowNames()
for _, addr := range a {
- addrs[addr.String()] = addr
+ addrs.Add(addr)
}
for _, name := range n {
- names[name] = []netip.Addr{}
+ names.Add(name, []netip.Addr{})
}
// create channel for resolver updates
@@ -347,7 +398,7 @@ func NewTrafPol(config *Config) *TrafPol {
resolver: NewResolver(config, n, resolvUp),
resolvUp: resolvUp,
- cmds: make(chan *trafPolAddrCmd),
+ cmds: make(chan *trafPolCmd),
loopDone: make(chan struct{}),
done: make(chan struct{}),
diff --git a/internal/trafpol/trafpol_test.go b/internal/trafpol/trafpol_test.go
index dfe6d911..108d9dc4 100644
--- a/internal/trafpol/trafpol_test.go
+++ b/internal/trafpol/trafpol_test.go
@@ -4,6 +4,7 @@ import (
"context"
"net/netip"
"reflect"
+ "slices"
"sort"
"sync"
"testing"
@@ -132,10 +133,10 @@ func TestTrafPolGetAllowedHostsIPs(t *testing.T) {
tp := NewTrafPol(c)
// add allowed names
- tp.allowNames["example.com"] = []netip.Addr{
+ tp.allowNames.Add("example.com", []netip.Addr{
netip.MustParseAddr("192.168.1.1"),
netip.MustParseAddr("2001:DB8:1::1"),
- }
+ })
// wanted IPs
want := []netip.Prefix{}
@@ -202,9 +203,9 @@ func TestTrafPolAddRemoveAllowedAddr(t *testing.T) {
t.Errorf("address not added")
}
- want := prefix.String()
- got := tp.allowAddrs[prefix.String()].String()
- if got != want {
+ want := []netip.Prefix{prefix}
+ got := tp.allowAddrs.List()
+ if !slices.Equal(got, want) {
t.Errorf("got %s, want %s", got, want)
}
@@ -218,9 +219,9 @@ func TestTrafPolAddRemoveAllowedAddr(t *testing.T) {
t.Errorf("address not removed")
}
- want = netip.Prefix{}.String()
- got = tp.allowAddrs[prefix.String()].String()
- if got != want {
+ want = []netip.Prefix{}
+ got = tp.allowAddrs.List()
+ if !slices.Equal(got, want) {
t.Errorf("got %s, want %s", got, want)
}
@@ -241,6 +242,30 @@ func TestTrafPolAddRemoveAllowedAddr(t *testing.T) {
tp.Stop()
}
+// TestTrafPolGetState tests GetState of Trafpol.
+func TestTrafPolGetState(t *testing.T) {
+ // set dummy low level function for devmon
+ oldRegisterLinkUpdates := devmon.RegisterLinkUpdates
+ devmon.RegisterLinkUpdates = func(*devmon.DevMon) (chan netlink.LinkUpdate, error) {
+ return nil, nil
+ }
+ defer func() { devmon.RegisterLinkUpdates = oldRegisterLinkUpdates }()
+
+ // start trafpol
+ tp := NewTrafPol(NewConfig())
+ if err := tp.Start(); err != nil {
+ t.Fatal(err)
+ }
+
+ // check state
+ if tp.GetState() == nil {
+ t.Errorf("got invalid state")
+ }
+
+ // stop trafpol
+ tp.Stop()
+}
+
// TestNewTrafPol tests NewTrafPol.
func TestNewTrafPol(t *testing.T) {
c := NewConfig()
diff --git a/internal/vpnsetup/vpnsetup.go b/internal/vpnsetup/vpnsetup.go
index a8a60ecc..e0b7c5c6 100644
--- a/internal/vpnsetup/vpnsetup.go
+++ b/internal/vpnsetup/vpnsetup.go
@@ -19,12 +19,20 @@ import (
const (
commandSetup uint8 = iota
commandTeardown
+ commandGetState
)
+// State is the internal state of the VPN Setup.
+type State struct {
+ SplitRouting *splitrt.State
+ DNSProxy *dnsproxy.State
+}
+
// command is a VPNSetup command.
type command struct {
cmd uint8
vpnconf *vpnconfig.Config
+ state *State
done chan struct{}
}
@@ -434,6 +442,18 @@ func (v *VPNSetup) teardown(ctx context.Context, vpnconf *vpnconfig.Config) {
v.teardownDNS(ctx, vpnconf)
}
+// getState gets the internal state.
+func (v *VPNSetup) getState(c *command) {
+ state := &State{}
+ if v.splitrt != nil {
+ state.SplitRouting = v.splitrt.GetState()
+ }
+ if v.dnsProxy != nil {
+ state.DNSProxy = v.dnsProxy.GetState()
+ }
+ c.state = state
+}
+
// handleCommand handles a command.
func (v *VPNSetup) handleCommand(ctx context.Context, c *command) {
defer close(c.done)
@@ -443,6 +463,8 @@ func (v *VPNSetup) handleCommand(ctx context.Context, c *command) {
v.setup(ctx, c.vpnconf)
case commandTeardown:
v.teardown(ctx, c.vpnconf)
+ case commandGetState:
+ v.getState(c)
}
}
@@ -519,6 +541,17 @@ func (v *VPNSetup) Teardown(vpnconfig *vpnconfig.Config) {
<-c.done
}
+// GetState returns the internal state of the VPN config.
+func (v *VPNSetup) GetState() *State {
+ c := &command{
+ cmd: commandGetState,
+ done: make(chan struct{}),
+ }
+ v.cmds <- c
+ <-c.done
+ return c.state
+}
+
// NewVPNSetup returns a new VPNSetup.
func NewVPNSetup(
dnsProxyConfig *dnsproxy.Config,
diff --git a/internal/vpnsetup/vpnsetup_test.go b/internal/vpnsetup/vpnsetup_test.go
index 2cd577bd..bd577431 100644
--- a/internal/vpnsetup/vpnsetup_test.go
+++ b/internal/vpnsetup/vpnsetup_test.go
@@ -378,6 +378,56 @@ func TestVPNSetupSetupTeardown(_ *testing.T) {
v.Stop()
}
+// TestVPNSetupGetState tests GetState of VPNSetup.
+func TestVPNSetupGetState(t *testing.T) {
+ // override functions
+ oldCmd := execs.RunCmd
+ execs.RunCmd = func(context.Context, string, string, ...string) ([]byte, []byte, error) {
+ return nil, nil, nil
+ }
+ defer func() { execs.RunCmd = oldCmd }()
+
+ oldRegisterAddrUpdates := addrmon.RegisterAddrUpdates
+ addrmon.RegisterAddrUpdates = func(*addrmon.AddrMon) (chan netlink.AddrUpdate, error) {
+ return nil, nil
+ }
+ defer func() { addrmon.RegisterAddrUpdates = oldRegisterAddrUpdates }()
+
+ oldRegisterLinkUpdates := devmon.RegisterLinkUpdates
+ devmon.RegisterLinkUpdates = func(*devmon.DevMon) (chan netlink.LinkUpdate, error) {
+ return nil, nil
+ }
+ defer func() { devmon.RegisterLinkUpdates = oldRegisterLinkUpdates }()
+
+ // start vpn setup
+ v := NewVPNSetup(dnsproxy.NewConfig(), splitrt.NewConfig())
+ v.Start()
+
+ // without vpn config
+ got := v.GetState()
+ if got == nil ||
+ got.SplitRouting != nil ||
+ got.DNSProxy == nil {
+ t.Errorf("got invalid state: %v", got)
+ }
+
+ // with vpn config
+ vpnconf := vpnconfig.New()
+ v.Setup(vpnconf)
+
+ got = v.GetState()
+ if got == nil ||
+ got.SplitRouting == nil ||
+ got.DNSProxy == nil {
+ t.Errorf("got invalid state: %v", got)
+ }
+
+ // teardown config
+ v.Teardown(vpnconf)
+
+ v.Stop()
+}
+
// TestNewVPNSetup tests NewVPNSetup.
func TestNewVPNSetup(t *testing.T) {
dnsConfig := dnsproxy.NewConfig()
diff --git a/pkg/client/client.go b/pkg/client/client.go
index 3d325287..926cdb0c 100644
--- a/pkg/client/client.go
+++ b/pkg/client/client.go
@@ -35,6 +35,8 @@ type Client interface {
Connect() error
Disconnect() error
+ DumpState() (string, error)
+
Close() error
}
@@ -576,6 +578,20 @@ func (d *DBusClient) Disconnect() error {
return disconnect(d)
}
+// dumpState sends a dump state request to the daemon.
+var dumpState = func(d *DBusClient) (string, error) {
+ // call dump state
+ state := ""
+ err := d.conn.Object(dbusapi.Interface, dbusapi.Path).
+ Call(dbusapi.MethodDumpState, 0).Store(&state)
+ return state, err
+}
+
+// DumpState returns the internal state of the OC-Daemon as string.
+func (d *DBusClient) DumpState() (string, error) {
+ return dumpState(d)
+}
+
// Close closes the DBusClient.
func (d *DBusClient) Close() error {
var err error
diff --git a/pkg/client/client_test.go b/pkg/client/client_test.go
index 16f0078e..cbcf807e 100644
--- a/pkg/client/client_test.go
+++ b/pkg/client/client_test.go
@@ -423,6 +423,24 @@ func TestDBusClientDisconnect(t *testing.T) {
}
}
+// TestDBusClientDumpState tests DumpState of DBusClient.
+func TestDBusClientDumpState(t *testing.T) {
+ // clean up after tests
+ oldDumpState := dumpState
+ defer func() { dumpState = oldDumpState }()
+
+ // create test client
+ client := &DBusClient{}
+
+ dumpState = func(_ *DBusClient) (string, error) {
+ return "test state", nil
+ }
+ state, err := client.DumpState()
+ if err != nil || state != "test state" {
+ t.Error(err, state)
+ }
+}
+
// testRWC is a reader writer closer for testing.
type testRWC struct{}