diff --git a/configs/dbus/com.telekom_mms.oc_daemon.Daemon.conf b/configs/dbus/com.telekom_mms.oc_daemon.Daemon.conf index 0398fbd2..0daa096b 100644 --- a/configs/dbus/com.telekom_mms.oc_daemon.Daemon.conf +++ b/configs/dbus/com.telekom_mms.oc_daemon.Daemon.conf @@ -18,6 +18,10 @@ + + diff --git a/internal/client/client.go b/internal/client/client.go index 28ae158d..fad9f8a7 100644 --- a/internal/client/client.go +++ b/internal/client/client.go @@ -221,6 +221,24 @@ func getStatus() error { return printStatus(status) } +// dumpState dumps the internal state of the daemon. +func dumpState() error { + // create client + c, err := clientNewClient(config) + if err != nil { + return fmt.Errorf("error creating client: %w", err) + } + defer func() { _ = c.Close() }() + + // get status + state, err := c.DumpState() + if err != nil { + return fmt.Errorf("error getting status: %w", err) + } + fmt.Println(state) + return nil +} + // monitor subscribes to VPN status updates from the daemon and displays them. func monitor() error { // create client diff --git a/internal/client/client_test.go b/internal/client/client_test.go index 3a0a2229..c4ee5b61 100644 --- a/internal/client/client_test.go +++ b/internal/client/client_test.go @@ -14,6 +14,8 @@ import ( type testClient struct { querErr error status *vpnstatus.Status + dumpErr error + dumpSta string authErr error connErr error discErr error @@ -33,6 +35,7 @@ func (t *testClient) Subscribe() (chan *vpnstatus.Status, error) { return t.subs func (t *testClient) Authenticate() error { return t.authErr } func (t *testClient) Connect() error { return t.connErr } func (t *testClient) Disconnect() error { return t.discErr } +func (t *testClient) DumpState() (string, error) { return t.dumpSta, t.dumpErr } func (t *testClient) Close() error { return nil } // TestListServers tests listServers. @@ -205,6 +208,38 @@ func TestGetStatus(t *testing.T) { } } +// TestDumpState tests dumpState. +func TestDumpState(t *testing.T) { + defer func() { clientNewClient = client.NewClient }() + + // test with client error + clientNewClient = func(*client.Config) (client.Client, error) { + return nil, errors.New("test error") + } + + if err := dumpState(); err == nil { + t.Error("client error should return error") + } + + // test with dump state error + clientNewClient = func(*client.Config) (client.Client, error) { + return &testClient{dumpErr: errors.New("test error")}, nil + } + + if err := dumpState(); err == nil { + t.Error("dump state error should return error") + } + + // test without error + clientNewClient = func(*client.Config) (client.Client, error) { + return &testClient{dumpSta: "test state"}, nil + } + + if err := dumpState(); err != nil { + t.Error("dump state should not return error") + } +} + // TestMonitor tests monitor. func TestMonitor(t *testing.T) { defer func() { clientNewClient = client.NewClient }() diff --git a/internal/client/cmd.go b/internal/client/cmd.go index ce3da869..cea26f3f 100644 --- a/internal/client/cmd.go +++ b/internal/client/cmd.go @@ -243,6 +243,8 @@ func run(args []string) error { return reconnectVPN() case "status": return getStatus() + case "dumpstate": + return dumpState() case "monitor": return monitor() case "save": diff --git a/internal/client/cmd_test.go b/internal/client/cmd_test.go index 5f0b270a..c9e2eca4 100644 --- a/internal/client/cmd_test.go +++ b/internal/client/cmd_test.go @@ -126,6 +126,7 @@ func TestRun(t *testing.T) { "reconnect", "status", "monitor", + "dumpstate", } { if err := run([]string{"test", "-cert", "cert-file", diff --git a/internal/daemon/daemon.go b/internal/daemon/daemon.go index a2f3cd0e..7da02589 100644 --- a/internal/daemon/daemon.go +++ b/internal/daemon/daemon.go @@ -3,6 +3,7 @@ package daemon import ( "context" + "encoding/json" "fmt" "net" "net/netip" @@ -473,6 +474,33 @@ func (d *Daemon) handleClientRequest(request *api.Request) { } } +// dumpState returns the internal daemon state as json string. +func (d *Daemon) dumpState() string { + // define state type + type State struct { + TrafficPolicing *trafpol.State + VPNSetup *vpnsetup.State + } + + // collect internal state + state := State{} + if d.trafpol != nil { + state.TrafficPolicing = d.trafpol.GetState() + } + if d.vpnsetup != nil { + state.VPNSetup = d.vpnsetup.GetState() + } + + // convert to json + b, err := json.Marshal(state) + if err != nil { + log.WithError(err).Error("Daemon could not convert internal state to JSON") + return "" + } + + return string(b) +} + // handleDBusRequest handles a D-Bus API client request. func (d *Daemon) handleDBusRequest(request *dbusapi.Request) { defer request.Close() @@ -505,6 +533,12 @@ func (d *Daemon) handleDBusRequest(request *dbusapi.Request) { // diconnect VPN log.Info("Daemon got disconnect request from client") d.disconnectVPN() + + case dbusapi.RequestDumpState: + // dump state + state := d.dumpState() + log.WithField("state", state).Info("Daemon got dump state request from client") + request.Results = []any{state} } } diff --git a/internal/dbusapi/service.go b/internal/dbusapi/service.go index bb449fc9..fc5e8cdd 100644 --- a/internal/dbusapi/service.go +++ b/internal/dbusapi/service.go @@ -130,12 +130,14 @@ const ( const ( MethodConnect = Interface + ".Connect" MethodDisconnect = Interface + ".Disconnect" + MethodDumpState = Interface + ".DumpState" ) // Request Names. const ( RequestConnect = "Connect" RequestDisconnect = "Disconnect" + RequestDumpState = "DumpState" ) // Request is a D-Bus client request. @@ -212,6 +214,27 @@ func (d daemon) Disconnect(sender dbus.Sender) *dbus.Error { return nil } +// DumpState is the "DumpState" method of the D-Bus interface. +func (d daemon) DumpState(sender dbus.Sender) (string, *dbus.Error) { + log.WithField("sender", sender).Debug("Received D-Bus DumpState() call") + request := &Request{ + Name: RequestDumpState, + wait: make(chan struct{}), + done: d.done, + } + select { + case d.requests <- request: + case <-d.done: + return "", dbus.NewError(Interface+".DumpStateAborted", []any{"DumpState aborted"}) + } + + request.Wait() + if request.Error != nil { + return "", dbus.NewError(Interface+".DumpStateAborted", []any{request.Error.Error()}) + } + return request.Results[0].(string), nil +} + // propertyUpdate is an update of a property. type propertyUpdate struct { name string @@ -431,16 +454,18 @@ func (s *Service) Start() error { // set names of method arguments introMeths := introspect.Methods(meths) for _, m := range introMeths { - if m.Name != "Connect" { - continue + if m.Name == "Connect" { + m.Args[0].Name = "server" + m.Args[1].Name = "cookie" + m.Args[2].Name = "host" + m.Args[3].Name = "connect_url" + m.Args[4].Name = "fingerprint" + m.Args[5].Name = "resolve" } - m.Args[0].Name = "server" - m.Args[1].Name = "cookie" - m.Args[2].Name = "host" - m.Args[3].Name = "connect_url" - m.Args[4].Name = "fingerprint" - m.Args[5].Name = "resolve" + if m.Name == "DumpState" { + m.Args[0].Name = "state" + } } // set peer interface peerData := introspect.Interface{ diff --git a/internal/dbusapi/service_test.go b/internal/dbusapi/service_test.go index 41a9b5ce..32e4f852 100644 --- a/internal/dbusapi/service_test.go +++ b/internal/dbusapi/service_test.go @@ -167,6 +167,73 @@ func TestDaemonDisconnect(t *testing.T) { } } +// TestDaemonDumpStateErrors tests DumpState of daemon, errors. +func TestDaemonDumpStateErrors(t *testing.T) { + // create daemon + requests := make(chan *Request) + done := make(chan struct{}) + daemon := daemon{ + requests: requests, + done: done, + } + + // error when handling request + go func() { + r := <-requests + r.Error = errors.New("test error") + r.Close() + }() + if _, err := daemon.DumpState(""); err == nil { + t.Error("should return error") + } + + // closed daemon + close(done) + if _, err := daemon.DumpState(""); err == nil { + t.Error("should return error") + } +} + +// TestDaemonDumpState tests DumpState of daemon. +func TestDaemonDumpState(t *testing.T) { + // create daemon + requests := make(chan *Request) + done := make(chan struct{}) + daemon := daemon{ + requests: requests, + done: done, + } + + // run disconnect and get results + want := &Request{ + Name: RequestDumpState, + Results: []any{"test state"}, + done: done, + } + got := &Request{} + go func() { + r := <-requests + r.Results = append(r.Results, "test state") + got = r + r.Close() + }() + state, err := daemon.DumpState("sender") + if err != nil { + t.Error(err) + } + + // check results + if got.Name != want.Name || + !reflect.DeepEqual(got.Parameters, want.Parameters) || + !reflect.DeepEqual(got.Results, want.Results) || + got.Error != want.Error || + got.done != want.done || + state != "test state" { + // not equal + t.Errorf("got %v, want %v", got, want) + } +} + // testConn implements the dbusConn interface for testing. type testConn struct { reqNameReply dbus.RequestNameReply diff --git a/internal/dnsproxy/proxy.go b/internal/dnsproxy/proxy.go index be836fd8..0e72c465 100644 --- a/internal/dnsproxy/proxy.go +++ b/internal/dnsproxy/proxy.go @@ -9,6 +9,14 @@ import ( log "github.com/sirupsen/logrus" ) +// State is the internal state of the DNS Proxy. +type State struct { + Config *Config + Remotes map[string][]string + Watches []string + TempWatches []string +} + // Proxy is a DNS proxy. type Proxy struct { config *Config @@ -247,6 +255,17 @@ func (p *Proxy) SetWatches(watches []string) { } } +// GetState returns the internal state of the DNS Proxy. +func (p *Proxy) GetState() *State { + watches, tempWatches := p.watches.List() + return &State{ + Config: p.config, + Remotes: p.remotes.List(), + Watches: watches, + TempWatches: tempWatches, + } +} + // NewProxy returns a new Proxy that listens on address. func NewProxy(config *Config) *Proxy { var udp *dns.Server diff --git a/internal/dnsproxy/proxy_test.go b/internal/dnsproxy/proxy_test.go index 217b7ee2..5a580e80 100644 --- a/internal/dnsproxy/proxy_test.go +++ b/internal/dnsproxy/proxy_test.go @@ -4,6 +4,7 @@ import ( "errors" "net" "net/netip" + "reflect" "testing" "github.com/miekg/dns" @@ -272,6 +273,34 @@ func TestProxySetWatches(_ *testing.T) { p.SetWatches(watches) } +// TestProxyGetState tests GetState of Proxy. +func TestProxyGetState(t *testing.T) { + p := NewProxy(getTestConfig()) + + // set remotes + getRemotes := func() map[string][]string { + return map[string][]string{".": {"192.168.1.1"}} + } + p.SetRemotes(getRemotes()) + + // set watches + p.watches.Add("example.com.") + p.watches.AddTempCNAME("cname.example.com.", 300) + p.watches.AddTempDNAME("dname.example.com.", 300) + + // check state + want := &State{ + Config: getTestConfig(), + Remotes: getRemotes(), + Watches: []string{"example.com."}, + TempWatches: []string{"cname.example.com.", "dname.example.com."}, + } + got := p.GetState() + if !reflect.DeepEqual(got, want) { + t.Errorf("got %v, want %v", got, want) + } +} + // TestNewProxy tests NewProxy. func TestNewProxy(t *testing.T) { p := NewProxy(getTestConfig()) diff --git a/internal/dnsproxy/remotes.go b/internal/dnsproxy/remotes.go index 40b8838e..48096433 100644 --- a/internal/dnsproxy/remotes.go +++ b/internal/dnsproxy/remotes.go @@ -67,6 +67,18 @@ func (r *Remotes) Get(domain string) []string { return r.m["."] } +// List returns the remotes. +func (r *Remotes) List() map[string][]string { + r.RLock() + defer r.RUnlock() + + remotes := make(map[string][]string) + for k, v := range r.m { + remotes[k] = append(remotes[k], v...) + } + return remotes +} + // NewRemotes returns a new Remotes. func NewRemotes() *Remotes { return &Remotes{ diff --git a/internal/dnsproxy/watches.go b/internal/dnsproxy/watches.go index f5ff80f9..dd2d98e5 100644 --- a/internal/dnsproxy/watches.go +++ b/internal/dnsproxy/watches.go @@ -168,6 +168,23 @@ func (w *Watches) Contains(domain string) bool { return false } +// List returns the list of watches and temporary watches. +func (w *Watches) List() (watches, tempWatches []string) { + w.RLock() + defer w.RUnlock() + + for k := range w.m { + watches = append(watches, k) + } + for k := range w.c { + tempWatches = append(tempWatches, k) + } + for k := range w.d { + tempWatches = append(tempWatches, k) + } + return +} + // Close closes the watches func (w *Watches) Close() { close(w.done) diff --git a/internal/splitrt/addresses.go b/internal/splitrt/addresses.go index e009db88..e737ec3b 100644 --- a/internal/splitrt/addresses.go +++ b/internal/splitrt/addresses.go @@ -2,12 +2,14 @@ package splitrt import ( "net/netip" + "sync" "github.com/telekom-mms/oc-daemon/internal/addrmon" ) // Addresses is a set of addresses. type Addresses struct { + sync.Mutex m map[int][]*addrmon.Update } @@ -26,6 +28,9 @@ func (a *Addresses) contains(addr *addrmon.Update) bool { // Add adds address info in addr to addresses. func (a *Addresses) Add(addr *addrmon.Update) { + a.Lock() + defer a.Unlock() + if a.contains(addr) { return } @@ -34,6 +39,9 @@ func (a *Addresses) Add(addr *addrmon.Update) { // Remove removes address info in addr from addresses. func (a *Addresses) Remove(addr *addrmon.Update) { + a.Lock() + defer a.Unlock() + if !a.contains(addr) { return } @@ -52,12 +60,30 @@ func (a *Addresses) Remove(addr *addrmon.Update) { // Get returns the addresses of the device identified by index. func (a *Addresses) Get(index int) (addrs []netip.Prefix) { + a.Lock() + defer a.Unlock() + for _, x := range a.m[index] { addrs = append(addrs, x.Address) } return } +// List returns all addresses. +func (a *Addresses) List() []*addrmon.Update { + a.Lock() + defer a.Unlock() + + var addresses []*addrmon.Update + for _, v := range a.m { + for _, addr := range v { + address := *addr + addresses = append(addresses, &address) + } + } + return addresses +} + // NewAddresses returns new Addresses. func NewAddresses() *Addresses { return &Addresses{ diff --git a/internal/splitrt/devices.go b/internal/splitrt/devices.go index 64b1352f..3f5d2e61 100644 --- a/internal/splitrt/devices.go +++ b/internal/splitrt/devices.go @@ -1,21 +1,30 @@ package splitrt import ( + "sync" + "github.com/telekom-mms/oc-daemon/internal/devmon" ) // Devices is a set of devices. type Devices struct { + sync.Mutex m map[int]*devmon.Update } // Add adds device info in dev to devices. func (d *Devices) Add(dev *devmon.Update) { + d.Lock() + defer d.Unlock() + d.m[dev.Index] = dev } // Remove removes device info in dev from devices. func (d *Devices) Remove(dev *devmon.Update) { + d.Lock() + defer d.Unlock() + delete(d.m, dev.Index) } @@ -35,19 +44,41 @@ func (d *Devices) getType(match bool, typ string) (indexes []int) { // GetReal returns device indexes of all real devices. func (d *Devices) GetReal() []int { + d.Lock() + defer d.Unlock() + return d.getType(true, "device") } // GetVirtual returns device indexes of all virtual devices. func (d *Devices) GetVirtual() []int { + d.Lock() + defer d.Unlock() + return d.getType(false, "device") } // GetAll returns all device indexes. func (d *Devices) GetAll() []int { + d.Lock() + defer d.Unlock() + return d.getType(false, "") } +// List returns all devices. +func (d *Devices) List() []*devmon.Update { + d.Lock() + defer d.Unlock() + + var devices []*devmon.Update + for _, v := range d.m { + device := *v + devices = append(devices, &device) + } + return devices +} + // NewDevices returns new Devices. func NewDevices() *Devices { return &Devices{ diff --git a/internal/splitrt/excludes.go b/internal/splitrt/excludes.go index 63c3ad54..d1a1f789 100644 --- a/internal/splitrt/excludes.go +++ b/internal/splitrt/excludes.go @@ -197,6 +197,21 @@ func (e *Excludes) Stop() { log.Debug("SplitRouting stopped periodic cleanup of excludes") } +// List returns the list of static and dynamic excludes. +func (e *Excludes) List() (static, dynamic []string) { + e.Lock() + defer e.Unlock() + + for k := range e.s { + static = append(static, k) + } + for k := range e.d { + dynamic = append(dynamic, k.String()) + } + + return +} + // NewExcludes returns new split excludes. func NewExcludes() *Excludes { return &Excludes{ diff --git a/internal/splitrt/splitrt.go b/internal/splitrt/splitrt.go index f77820a0..bb06d2b3 100644 --- a/internal/splitrt/splitrt.go +++ b/internal/splitrt/splitrt.go @@ -5,6 +5,7 @@ import ( "context" "fmt" "net/netip" + "sync" log "github.com/sirupsen/logrus" "github.com/telekom-mms/oc-daemon/internal/addrmon" @@ -13,6 +14,39 @@ import ( "github.com/telekom-mms/oc-daemon/pkg/vpnconfig" ) +// State is the internal state. +type State struct { + Config *Config + VPNConfig *vpnconfig.Config + Devices []*devmon.Update + Addresses []*addrmon.Update + LocalExcludes []string + StaticExcludes []string + DynamicExcludes []string +} + +// locals are local excludes. +type locals struct { + sync.Mutex + l []netip.Prefix +} + +// get returns the local excludes, returned slice should not be modified. +func (l *locals) get() []netip.Prefix { + l.Lock() + defer l.Unlock() + + return l.l +} + +// set sets the local excludes. +func (l *locals) set(locals []netip.Prefix) { + l.Lock() + defer l.Unlock() + + l.l = locals +} + // SplitRouting is a split routing configuration. type SplitRouting struct { config *Config @@ -21,7 +55,7 @@ type SplitRouting struct { addrmon *addrmon.AddrMon devices *Devices addrs *Addresses - locals []netip.Prefix + locals locals excludes *Excludes dnsreps chan *dnsproxy.Report done chan struct{} @@ -151,21 +185,21 @@ func (s *SplitRouting) updateLocalNetworkExcludes(ctx context.Context) { // add new excludes for _, e := range excludes { - if !isIn(e, s.locals) { + if !isIn(e, s.locals.get()) { s.excludes.AddStatic(ctx, e) } } // remove old excludes - for _, l := range s.locals { + for _, l := range s.locals.get() { if !isIn(l, excludes) { s.excludes.RemoveStatic(ctx, l) } } // save local excludes - s.locals = excludes - log.WithField("locals", s.locals).Debug("SplitRouting updated exclude local networks") + s.locals.set(excludes) + log.WithField("locals", s.locals.get()).Debug("SplitRouting updated exclude local networks") } // handleDeviceUpdate handles a device update from the device monitor. @@ -270,6 +304,24 @@ func (s *SplitRouting) DNSReports() chan *dnsproxy.Report { return s.dnsreps } +// GetState returns the internal state. +func (s *SplitRouting) GetState() *State { + var locals []string + for _, l := range s.locals.get() { + locals = append(locals, l.String()) + } + static, dynamic := s.excludes.List() + return &State{ + Config: s.config, + VPNConfig: s.vpnconfig, + Devices: s.devices.List(), + Addresses: s.addrs.List(), + LocalExcludes: locals, + StaticExcludes: static, + DynamicExcludes: dynamic, + } +} + // NewSplitRouting returns a new SplitRouting. func NewSplitRouting(config *Config, vpnconfig *vpnconfig.Config) *SplitRouting { return &SplitRouting{ diff --git a/internal/splitrt/splitrt_test.go b/internal/splitrt/splitrt_test.go index 54f8e433..e7832af0 100644 --- a/internal/splitrt/splitrt_test.go +++ b/internal/splitrt/splitrt_test.go @@ -284,6 +284,55 @@ func TestSplitRoutingDNSReports(t *testing.T) { } } +// TestSplitRoutingGetState tests GetState of SplitRouting. +func TestSplitRoutingGetState(t *testing.T) { + s := NewSplitRouting(NewConfig(), vpnconfig.New()) + + // set devices + dev := &devmon.Update{ + Add: true, + Device: "test", + Type: "test", + Index: 1, + } + s.devices.Add(dev) + + // set addresses + addr := &addrmon.Update{ + Add: true, + Address: netip.MustParsePrefix("192.168.1.0/24"), + Index: 1, + } + s.addrs.Add(addr) + + // set local excludes + locals := []netip.Prefix{netip.MustParsePrefix("10.0.0.0/24")} + s.locals.set(locals) + + // set static excludes + static := netip.MustParsePrefix("10.1.0.0/24") + s.excludes.s[static.String()] = static + + // set dynamic excludes + dynamic := netip.MustParseAddr("10.2.0.1") + s.excludes.d[dynamic] = &dynExclude{} + + // get and check state + want := &State{ + Config: NewConfig(), + VPNConfig: vpnconfig.New(), + Devices: []*devmon.Update{dev}, + Addresses: []*addrmon.Update{addr}, + LocalExcludes: []string{"10.0.0.0/24"}, + StaticExcludes: []string{"10.1.0.0/24"}, + DynamicExcludes: []string{"10.2.0.1"}, + } + got := s.GetState() + if !reflect.DeepEqual(got, want) { + t.Errorf("got %v, want %v", got, want) + } +} + // TestNewSplitRouting tests NewSplitRouting. func TestNewSplitRouting(t *testing.T) { config := NewConfig() diff --git a/internal/trafpol/allowaddrs.go b/internal/trafpol/allowaddrs.go new file mode 100644 index 00000000..f8cc3550 --- /dev/null +++ b/internal/trafpol/allowaddrs.go @@ -0,0 +1,46 @@ +package trafpol + +import "net/netip" + +// AllowAddrs are allowed addresses. +type AllowAddrs struct { + m map[string]netip.Prefix +} + +// Add adds prefix to the allowed addresses. +func (a *AllowAddrs) Add(prefix netip.Prefix) bool { + s := prefix.String() + if _, ok := a.m[s]; ok { + // ip already in allowed addrs + return false + } + a.m[s] = prefix + return true +} + +// Remove removes prefix from the allowed addresses. +func (a *AllowAddrs) Remove(prefix netip.Prefix) bool { + s := prefix.String() + if _, ok := a.m[s]; !ok { + // ip not in allowed addrs + return false + } + delete(a.m, s) + return true +} + +// List returns a list of all allowed addresses. +func (a *AllowAddrs) List() []netip.Prefix { + var prefixes []netip.Prefix + for _, p := range a.m { + prefixes = append(prefixes, p) + } + return prefixes +} + +// NewAllowAddrs returns new AllowAddrs. +func NewAllowAddrs() *AllowAddrs { + return &AllowAddrs{ + m: make(map[string]netip.Prefix), + } +} diff --git a/internal/trafpol/allowdevs.go b/internal/trafpol/allowdevs.go index ea54fef6..a916450e 100644 --- a/internal/trafpol/allowdevs.go +++ b/internal/trafpol/allowdevs.go @@ -23,6 +23,15 @@ func (a *AllowDevs) Remove(ctx context.Context, device string) { } } +// List returns a slice of all allowed devices. +func (a *AllowDevs) List() []string { + var l []string + for _, v := range a.m { + l = append(l, v) + } + return l +} + // NewAllowDevs returns new allowDevs. func NewAllowDevs() *AllowDevs { return &AllowDevs{ diff --git a/internal/trafpol/allowdevs_test.go b/internal/trafpol/allowdevs_test.go index b50cc7c5..1582fd5a 100644 --- a/internal/trafpol/allowdevs_test.go +++ b/internal/trafpol/allowdevs_test.go @@ -3,6 +3,7 @@ package trafpol import ( "context" "reflect" + "slices" "testing" "github.com/telekom-mms/oc-daemon/internal/execs" @@ -66,6 +67,26 @@ func TestAllowDevsRemove(t *testing.T) { } } +// TestAllowDevsList tests List of AllowDevs. +func TestAllowDevsList(t *testing.T) { + a := NewAllowDevs() + ctx := context.Background() + + oldRunCmd := execs.RunCmd + execs.RunCmd = func(_ context.Context, _, _ string, _ ...string) ([]byte, []byte, error) { + return nil, nil, nil + } + defer func() { execs.RunCmd = oldRunCmd }() + + a.Add(ctx, "test") + + want := []string{"test"} + got := a.List() + if !slices.Equal(got, want) { + t.Errorf("got %v, want %v", got, want) + } +} + // TestNewAllowDevs tests NewAllowDevs. func TestNewAllowDevs(t *testing.T) { a := NewAllowDevs() diff --git a/internal/trafpol/allownames.go b/internal/trafpol/allownames.go new file mode 100644 index 00000000..47bc77e4 --- /dev/null +++ b/internal/trafpol/allownames.go @@ -0,0 +1,30 @@ +package trafpol + +import "net/netip" + +// AllowNames are allowed DNS names. +type AllowNames struct { + m map[string][]netip.Addr +} + +// Add adds and updates the allowed name and its IP addresses. +func (a *AllowNames) Add(name string, addrs []netip.Addr) { + a.m[name] = addrs +} + +// GetAll returns all allowed names with their IP addresses. +func (a *AllowNames) GetAll() map[string][]netip.Addr { + names := make(map[string][]netip.Addr) + for k, v := range a.m { + names[k] = append(names[k], v...) + + } + return names +} + +// NewAllowNames returns new AllowNames. +func NewAllowNames() *AllowNames { + return &AllowNames{ + m: make(map[string][]netip.Addr), + } +} diff --git a/internal/trafpol/trafpol.go b/internal/trafpol/trafpol.go index a904baeb..60c8d9f1 100644 --- a/internal/trafpol/trafpol.go +++ b/internal/trafpol/trafpol.go @@ -12,12 +12,29 @@ import ( "github.com/telekom-mms/oc-daemon/internal/dnsmon" ) -// trafPolAddrCmd is a TrafPol address command. -type trafPolAddrCmd struct { - add bool - ip netip.Addr - ok bool - done chan struct{} +// TrafPol command types. +const ( + trafPolCmdAddAddress uint8 = iota + 1 + trafPolCmdRemoveAddress + trafPolCmdGetState +) + +// State is the internal TrafPol state. +type State struct { + Config *Config + CaptivePortal bool + AllowedDevices []string + AllowedAddresses []netip.Prefix + AllowedNames map[string][]netip.Addr +} + +// trafPolCmd is a TrafPol command. +type trafPolCmd struct { + typ uint8 + ip netip.Addr + ok bool + state *State + done chan struct{} } // TrafPol is a traffic policing component. @@ -32,15 +49,15 @@ type TrafPol struct { // allowed devices, addresses, names allowDevs *AllowDevs - allowAddrs map[string]netip.Prefix - allowNames map[string][]netip.Addr + allowAddrs *AllowAddrs + allowNames *AllowNames // resolver for allowed names, channel for resolver updates resolver *Resolver resolvUp chan *ResolvedName // address commands channel - cmds chan *trafPolAddrCmd + cmds chan *trafPolCmd loopDone chan struct{} done chan struct{} @@ -101,13 +118,13 @@ func (t *TrafPol) getAllowedHostsIPs() []netip.Prefix { // - allowed names // - allowed addrs ipset := make(map[string]netip.Prefix) - for _, n := range t.allowNames { + for _, n := range t.allowNames.GetAll() { for _, ip := range n { prefix := netip.PrefixFrom(ip, ip.BitLen()) ipset[prefix.String()] = prefix } } - for _, a := range t.allowAddrs { + for _, a := range t.allowAddrs.List() { ipset[a.String()] = a } @@ -123,33 +140,28 @@ func (t *TrafPol) getAllowedHostsIPs() []netip.Prefix { // handleResolverUpdate handles a resolver update. func (t *TrafPol) handleResolverUpdate(ctx context.Context, update *ResolvedName) { // update allowed names - t.allowNames[update.Name] = update.IPs + t.allowNames.Add(update.Name, update.IPs) // set new filter rules setAllowedIPs(ctx, t.getAllowedHostsIPs()) } // handleAddressCommand handles an address command. -func (t *TrafPol) handleAddressCommand(ctx context.Context, cmd *trafPolAddrCmd) { - defer close(cmd.done) - +func (t *TrafPol) handleAddressCommand(ctx context.Context, cmd *trafPolCmd) { // convert to prefix prefix := netip.PrefixFrom(cmd.ip, cmd.ip.BitLen()) // update allowed addrs - s := prefix.String() - if cmd.add { - if _, ok := t.allowAddrs[s]; ok { + if cmd.typ == trafPolCmdAddAddress { + if ok := t.allowAddrs.Add(prefix); !ok { // ip already in allowed addrs return } - t.allowAddrs[s] = prefix } else { - if _, ok := t.allowAddrs[s]; !ok { + if ok := t.allowAddrs.Remove(prefix); !ok { // ip not in allowed addrs return } - delete(t.allowAddrs, s) } // set new filter rules @@ -159,6 +171,30 @@ func (t *TrafPol) handleAddressCommand(ctx context.Context, cmd *trafPolAddrCmd) cmd.ok = true } +// handleGetStateCommand handles a get state command. +func (t *TrafPol) handleGetStateCommand(cmd *trafPolCmd) { + // set state + cmd.state = &State{ + Config: t.config, + CaptivePortal: t.capPortal, + AllowedDevices: t.allowDevs.List(), + AllowedAddresses: t.allowAddrs.List(), + AllowedNames: t.allowNames.GetAll(), + } +} + +// handleCommand handles a command. +func (t *TrafPol) handleCommand(ctx context.Context, cmd *trafPolCmd) { + defer close(cmd.done) + + switch cmd.typ { + case trafPolCmdAddAddress, trafPolCmdRemoveAddress: + t.handleAddressCommand(ctx, cmd) + case trafPolCmdGetState: + t.handleGetStateCommand(cmd) + } +} + // start starts the traffic policing component. func (t *TrafPol) start(ctx context.Context) { defer close(t.loopDone) @@ -192,9 +228,9 @@ func (t *TrafPol) start(ctx context.Context) { t.handleResolverUpdate(ctx, u) case c := <-t.cmds: - // Address Command - log.WithField("command", c).Debug("TrafPol got address command") - t.handleAddressCommand(ctx, c) + // Command + log.WithField("command", c).Debug("TrafPol got command") + t.handleCommand(ctx, c) case <-t.done: // shutdown @@ -264,8 +300,8 @@ func (t *TrafPol) AddAllowedAddr(addr netip.Addr) (ok bool) { log.WithField("addr", addr). Debug("TrafPol adding IP to allowed addresses") - c := &trafPolAddrCmd{ - add: true, + c := &trafPolCmd{ + typ: trafPolCmdAddAddress, ip: addr, done: make(chan struct{}), } @@ -280,7 +316,8 @@ func (t *TrafPol) RemoveAllowedAddr(addr netip.Addr) (ok bool) { log.WithField("addr", addr). Debug("TrafPol removing IP from allowed addresses") - c := &trafPolAddrCmd{ + c := &trafPolCmd{ + typ: trafPolCmdRemoveAddress, ip: addr, done: make(chan struct{}), } @@ -290,6 +327,20 @@ func (t *TrafPol) RemoveAllowedAddr(addr netip.Addr) (ok bool) { return c.ok } +// GetState returns the internal TrafPol state. +func (t *TrafPol) GetState() *State { + log.Debug("TrafPol getting internal state") + + c := &trafPolCmd{ + typ: trafPolCmdGetState, + done: make(chan struct{}), + } + t.cmds <- c + <-c.done + + return c.state +} + // parseAllowedHosts parses the allowed hosts and returns IP addresses and DNS names func parseAllowedHosts(hosts []string) (addrs []netip.Prefix, names []string) { for _, h := range hosts { @@ -321,13 +372,13 @@ func NewTrafPol(config *Config) *TrafPol { a, n := parseAllowedHosts(hosts) // create allowed addrs and names - addrs := make(map[string]netip.Prefix) - names := make(map[string][]netip.Addr) + addrs := NewAllowAddrs() + names := NewAllowNames() for _, addr := range a { - addrs[addr.String()] = addr + addrs.Add(addr) } for _, name := range n { - names[name] = []netip.Addr{} + names.Add(name, []netip.Addr{}) } // create channel for resolver updates @@ -347,7 +398,7 @@ func NewTrafPol(config *Config) *TrafPol { resolver: NewResolver(config, n, resolvUp), resolvUp: resolvUp, - cmds: make(chan *trafPolAddrCmd), + cmds: make(chan *trafPolCmd), loopDone: make(chan struct{}), done: make(chan struct{}), diff --git a/internal/trafpol/trafpol_test.go b/internal/trafpol/trafpol_test.go index dfe6d911..108d9dc4 100644 --- a/internal/trafpol/trafpol_test.go +++ b/internal/trafpol/trafpol_test.go @@ -4,6 +4,7 @@ import ( "context" "net/netip" "reflect" + "slices" "sort" "sync" "testing" @@ -132,10 +133,10 @@ func TestTrafPolGetAllowedHostsIPs(t *testing.T) { tp := NewTrafPol(c) // add allowed names - tp.allowNames["example.com"] = []netip.Addr{ + tp.allowNames.Add("example.com", []netip.Addr{ netip.MustParseAddr("192.168.1.1"), netip.MustParseAddr("2001:DB8:1::1"), - } + }) // wanted IPs want := []netip.Prefix{} @@ -202,9 +203,9 @@ func TestTrafPolAddRemoveAllowedAddr(t *testing.T) { t.Errorf("address not added") } - want := prefix.String() - got := tp.allowAddrs[prefix.String()].String() - if got != want { + want := []netip.Prefix{prefix} + got := tp.allowAddrs.List() + if !slices.Equal(got, want) { t.Errorf("got %s, want %s", got, want) } @@ -218,9 +219,9 @@ func TestTrafPolAddRemoveAllowedAddr(t *testing.T) { t.Errorf("address not removed") } - want = netip.Prefix{}.String() - got = tp.allowAddrs[prefix.String()].String() - if got != want { + want = []netip.Prefix{} + got = tp.allowAddrs.List() + if !slices.Equal(got, want) { t.Errorf("got %s, want %s", got, want) } @@ -241,6 +242,30 @@ func TestTrafPolAddRemoveAllowedAddr(t *testing.T) { tp.Stop() } +// TestTrafPolGetState tests GetState of Trafpol. +func TestTrafPolGetState(t *testing.T) { + // set dummy low level function for devmon + oldRegisterLinkUpdates := devmon.RegisterLinkUpdates + devmon.RegisterLinkUpdates = func(*devmon.DevMon) (chan netlink.LinkUpdate, error) { + return nil, nil + } + defer func() { devmon.RegisterLinkUpdates = oldRegisterLinkUpdates }() + + // start trafpol + tp := NewTrafPol(NewConfig()) + if err := tp.Start(); err != nil { + t.Fatal(err) + } + + // check state + if tp.GetState() == nil { + t.Errorf("got invalid state") + } + + // stop trafpol + tp.Stop() +} + // TestNewTrafPol tests NewTrafPol. func TestNewTrafPol(t *testing.T) { c := NewConfig() diff --git a/internal/vpnsetup/vpnsetup.go b/internal/vpnsetup/vpnsetup.go index a8a60ecc..e0b7c5c6 100644 --- a/internal/vpnsetup/vpnsetup.go +++ b/internal/vpnsetup/vpnsetup.go @@ -19,12 +19,20 @@ import ( const ( commandSetup uint8 = iota commandTeardown + commandGetState ) +// State is the internal state of the VPN Setup. +type State struct { + SplitRouting *splitrt.State + DNSProxy *dnsproxy.State +} + // command is a VPNSetup command. type command struct { cmd uint8 vpnconf *vpnconfig.Config + state *State done chan struct{} } @@ -434,6 +442,18 @@ func (v *VPNSetup) teardown(ctx context.Context, vpnconf *vpnconfig.Config) { v.teardownDNS(ctx, vpnconf) } +// getState gets the internal state. +func (v *VPNSetup) getState(c *command) { + state := &State{} + if v.splitrt != nil { + state.SplitRouting = v.splitrt.GetState() + } + if v.dnsProxy != nil { + state.DNSProxy = v.dnsProxy.GetState() + } + c.state = state +} + // handleCommand handles a command. func (v *VPNSetup) handleCommand(ctx context.Context, c *command) { defer close(c.done) @@ -443,6 +463,8 @@ func (v *VPNSetup) handleCommand(ctx context.Context, c *command) { v.setup(ctx, c.vpnconf) case commandTeardown: v.teardown(ctx, c.vpnconf) + case commandGetState: + v.getState(c) } } @@ -519,6 +541,17 @@ func (v *VPNSetup) Teardown(vpnconfig *vpnconfig.Config) { <-c.done } +// GetState returns the internal state of the VPN config. +func (v *VPNSetup) GetState() *State { + c := &command{ + cmd: commandGetState, + done: make(chan struct{}), + } + v.cmds <- c + <-c.done + return c.state +} + // NewVPNSetup returns a new VPNSetup. func NewVPNSetup( dnsProxyConfig *dnsproxy.Config, diff --git a/internal/vpnsetup/vpnsetup_test.go b/internal/vpnsetup/vpnsetup_test.go index 2cd577bd..bd577431 100644 --- a/internal/vpnsetup/vpnsetup_test.go +++ b/internal/vpnsetup/vpnsetup_test.go @@ -378,6 +378,56 @@ func TestVPNSetupSetupTeardown(_ *testing.T) { v.Stop() } +// TestVPNSetupGetState tests GetState of VPNSetup. +func TestVPNSetupGetState(t *testing.T) { + // override functions + oldCmd := execs.RunCmd + execs.RunCmd = func(context.Context, string, string, ...string) ([]byte, []byte, error) { + return nil, nil, nil + } + defer func() { execs.RunCmd = oldCmd }() + + oldRegisterAddrUpdates := addrmon.RegisterAddrUpdates + addrmon.RegisterAddrUpdates = func(*addrmon.AddrMon) (chan netlink.AddrUpdate, error) { + return nil, nil + } + defer func() { addrmon.RegisterAddrUpdates = oldRegisterAddrUpdates }() + + oldRegisterLinkUpdates := devmon.RegisterLinkUpdates + devmon.RegisterLinkUpdates = func(*devmon.DevMon) (chan netlink.LinkUpdate, error) { + return nil, nil + } + defer func() { devmon.RegisterLinkUpdates = oldRegisterLinkUpdates }() + + // start vpn setup + v := NewVPNSetup(dnsproxy.NewConfig(), splitrt.NewConfig()) + v.Start() + + // without vpn config + got := v.GetState() + if got == nil || + got.SplitRouting != nil || + got.DNSProxy == nil { + t.Errorf("got invalid state: %v", got) + } + + // with vpn config + vpnconf := vpnconfig.New() + v.Setup(vpnconf) + + got = v.GetState() + if got == nil || + got.SplitRouting == nil || + got.DNSProxy == nil { + t.Errorf("got invalid state: %v", got) + } + + // teardown config + v.Teardown(vpnconf) + + v.Stop() +} + // TestNewVPNSetup tests NewVPNSetup. func TestNewVPNSetup(t *testing.T) { dnsConfig := dnsproxy.NewConfig() diff --git a/pkg/client/client.go b/pkg/client/client.go index 3d325287..926cdb0c 100644 --- a/pkg/client/client.go +++ b/pkg/client/client.go @@ -35,6 +35,8 @@ type Client interface { Connect() error Disconnect() error + DumpState() (string, error) + Close() error } @@ -576,6 +578,20 @@ func (d *DBusClient) Disconnect() error { return disconnect(d) } +// dumpState sends a dump state request to the daemon. +var dumpState = func(d *DBusClient) (string, error) { + // call dump state + state := "" + err := d.conn.Object(dbusapi.Interface, dbusapi.Path). + Call(dbusapi.MethodDumpState, 0).Store(&state) + return state, err +} + +// DumpState returns the internal state of the OC-Daemon as string. +func (d *DBusClient) DumpState() (string, error) { + return dumpState(d) +} + // Close closes the DBusClient. func (d *DBusClient) Close() error { var err error diff --git a/pkg/client/client_test.go b/pkg/client/client_test.go index 16f0078e..cbcf807e 100644 --- a/pkg/client/client_test.go +++ b/pkg/client/client_test.go @@ -423,6 +423,24 @@ func TestDBusClientDisconnect(t *testing.T) { } } +// TestDBusClientDumpState tests DumpState of DBusClient. +func TestDBusClientDumpState(t *testing.T) { + // clean up after tests + oldDumpState := dumpState + defer func() { dumpState = oldDumpState }() + + // create test client + client := &DBusClient{} + + dumpState = func(_ *DBusClient) (string, error) { + return "test state", nil + } + state, err := client.DumpState() + if err != nil || state != "test state" { + t.Error(err, state) + } +} + // testRWC is a reader writer closer for testing. type testRWC struct{}