From 027f412e5270ea59c582c15c24adc406fe059f02 Mon Sep 17 00:00:00 2001 From: hwipl <33433250+hwipl@users.noreply.github.com> Date: Tue, 3 Dec 2024 15:41:23 +0100 Subject: [PATCH] Add combined Daemon configuration Consolidate all configurations in a combined daemon configuration: - Add package daemoncfg with type Config - Move existing configurations into daemoncfg.Config: - Socket API Server - CPD - OC-Daemon - DNS Proxy - Executables - OpenConnect Runner - Split Routing - Traffic Policing - Add configurations to daemoncfg.Config for internal use: - Login Info - VPN Configuration Signed-off-by: hwipl <33433250+hwipl@users.noreply.github.com> --- internal/api/config.go | 63 -- internal/api/config_test.go | 40 -- internal/api/server.go | 5 +- internal/api/server_test.go | 16 +- internal/cpd/config.go | 62 -- internal/cpd/config_test.go | 44 -- internal/cpd/cpd.go | 5 +- internal/cpd/cpd_test.go | 30 +- internal/daemon/cmd.go | 9 +- internal/daemon/cmd_test.go | 4 +- internal/daemon/config.go | 98 --- internal/daemon/config_test.go | 228 ------- internal/daemon/daemon.go | 47 +- internal/daemon/daemon_test.go | 6 +- internal/daemoncfg/config.go | 906 ++++++++++++++++++++++++++++ internal/daemoncfg/config_test.go | 923 +++++++++++++++++++++++++++++ internal/dnsproxy/config.go | 39 -- internal/dnsproxy/config_test.go | 49 -- internal/dnsproxy/proxy.go | 7 +- internal/dnsproxy/proxy_test.go | 8 +- internal/execs/config.go | 55 -- internal/execs/config_test.go | 84 --- internal/execs/execs.go | 12 +- internal/execs/execs_test.go | 6 +- internal/ocrunner/config.go | 100 ---- internal/ocrunner/config_test.go | 82 --- internal/ocrunner/connect.go | 77 ++- internal/ocrunner/connect_test.go | 90 ++- internal/splitrt/config.go | 80 --- internal/splitrt/config_test.go | 99 ---- internal/splitrt/excludes.go | 5 +- internal/splitrt/excludes_test.go | 23 +- internal/splitrt/splitrt.go | 133 ++--- internal/splitrt/splitrt_test.go | 98 ++- internal/trafpol/config.go | 83 --- internal/trafpol/config_test.go | 36 -- internal/trafpol/resolver.go | 10 +- internal/trafpol/resolver_test.go | 8 +- internal/trafpol/trafpol.go | 19 +- internal/trafpol/trafpol_test.go | 30 +- internal/vpncscript/client_test.go | 3 +- internal/vpnsetup/vpnsetup.go | 179 +++--- internal/vpnsetup/vpnsetup_test.go | 112 ++-- tools/dnsproxy/main.go | 3 +- tools/ocrunner/main.go | 12 +- 45 files changed, 2307 insertions(+), 1721 deletions(-) delete mode 100644 internal/api/config.go delete mode 100644 internal/api/config_test.go delete mode 100644 internal/cpd/config.go delete mode 100644 internal/cpd/config_test.go delete mode 100644 internal/daemon/config.go delete mode 100644 internal/daemon/config_test.go create mode 100644 internal/daemoncfg/config.go create mode 100644 internal/daemoncfg/config_test.go delete mode 100644 internal/dnsproxy/config.go delete mode 100644 internal/dnsproxy/config_test.go delete mode 100644 internal/execs/config.go delete mode 100644 internal/execs/config_test.go delete mode 100644 internal/ocrunner/config.go delete mode 100644 internal/ocrunner/config_test.go delete mode 100644 internal/splitrt/config.go delete mode 100644 internal/splitrt/config_test.go delete mode 100644 internal/trafpol/config.go delete mode 100644 internal/trafpol/config_test.go diff --git a/internal/api/config.go b/internal/api/config.go deleted file mode 100644 index 3638aad5..00000000 --- a/internal/api/config.go +++ /dev/null @@ -1,63 +0,0 @@ -package api - -import ( - "strconv" - "time" -) - -var ( - // SocketFile is the unix socket file. - SocketFile = "/run/oc-daemon/daemon.sock" - - // SocketOwner is the owner of the socket file. - SocketOwner = "" - - // SocketGroup is the group of the socket file. - SocketGroup = "" - - // SocketPermissions are the file permissions of the socket file. - SocketPermissions = "0700" - - // RequestTimeout is the timeout for an entire request/response - // exchange initiated by a client. - RequestTimeout = 30 * time.Second -) - -// Config is a server configuration. -type Config struct { - SocketFile string - SocketOwner string - SocketGroup string - SocketPermissions string - RequestTimeout time.Duration -} - -// Valid returns whether server config is valid. -func (c *Config) Valid() bool { - if c == nil || - c.SocketFile == "" || - c.RequestTimeout < 0 { - return false - } - if c.SocketPermissions != "" { - perm, err := strconv.ParseUint(c.SocketPermissions, 8, 32) - if err != nil { - return false - } - if perm > 0777 { - return false - } - } - return true -} - -// NewConfig returns a new server configuration. -func NewConfig() *Config { - return &Config{ - SocketFile: SocketFile, - SocketOwner: SocketOwner, - SocketGroup: SocketGroup, - SocketPermissions: SocketPermissions, - RequestTimeout: RequestTimeout, - } -} diff --git a/internal/api/config_test.go b/internal/api/config_test.go deleted file mode 100644 index 8eee1da8..00000000 --- a/internal/api/config_test.go +++ /dev/null @@ -1,40 +0,0 @@ -package api - -import "testing" - -// TestConfigValid tests Valid of Config. -func TestConfigValid(t *testing.T) { - // test invalid - for _, invalid := range []*Config{ - nil, - {}, - {SocketFile: "test.sock", SocketPermissions: "invalid"}, - {SocketFile: "test.sock", SocketPermissions: "1234"}, - } { - want := false - got := invalid.Valid() - if got != want { - t.Errorf("got %t, want %t for %v", got, want, invalid) - } - } - - // test valid - for _, valid := range []*Config{ - NewConfig(), - {SocketFile: "test.sock", SocketPermissions: "777"}, - } { - want := true - got := valid.Valid() - if got != want { - t.Errorf("got %t, want %t for %v", got, want, valid) - } - } -} - -// TestNewConfig tests NewConfig. -func TestNewConfig(t *testing.T) { - sc := NewConfig() - if !sc.Valid() { - t.Errorf("config is not valid") - } -} diff --git a/internal/api/server.go b/internal/api/server.go index 4aab7694..84afbae5 100644 --- a/internal/api/server.go +++ b/internal/api/server.go @@ -10,6 +10,7 @@ import ( "time" log "github.com/sirupsen/logrus" + "github.com/telekom-mms/oc-daemon/internal/daemoncfg" ) const ( @@ -20,7 +21,7 @@ const ( // Server is a Daemon API server. type Server struct { - config *Config + config *daemoncfg.SocketServer listen net.Listener requests chan *Request shutdown chan struct{} @@ -239,7 +240,7 @@ func (s *Server) Requests() chan *Request { } // NewServer returns a new API server. -func NewServer(config *Config) *Server { +func NewServer(config *daemoncfg.SocketServer) *Server { return &Server{ config: config, requests: make(chan *Request), diff --git a/internal/api/server_test.go b/internal/api/server_test.go index 7900bd91..b4e8fb84 100644 --- a/internal/api/server_test.go +++ b/internal/api/server_test.go @@ -4,11 +4,13 @@ import ( "net" "reflect" "testing" + + "github.com/telekom-mms/oc-daemon/internal/daemoncfg" ) // TestServerHandleRequest tests handleRequest of Server. func TestServerHandleRequest(t *testing.T) { - config := NewConfig() + config := daemoncfg.NewSocketServer() server := NewServer(config) // connection closed @@ -77,7 +79,7 @@ func TestServerHandleRequest(t *testing.T) { // TestServerSetSocketOwner tests setSocketOwner of Server. func TestServerSetSocketOwner(_ *testing.T) { - config := NewConfig() + config := daemoncfg.NewSocketServer() server := NewServer(config) // no changes @@ -95,7 +97,7 @@ func TestServerSetSocketOwner(_ *testing.T) { // TestServerSetSocketGroup tests setSocketGroup of Server. func TestServerSetSocketGroup(_ *testing.T) { - config := NewConfig() + config := daemoncfg.NewSocketServer() server := NewServer(config) // no changes @@ -113,7 +115,7 @@ func TestServerSetSocketGroup(_ *testing.T) { // TestServerSetSocketPermissions tests setSocketPermissions of Server. func TestServerSetSocketPermissions(_ *testing.T) { - config := NewConfig() + config := daemoncfg.NewSocketServer() server := NewServer(config) // socket file does not exist @@ -131,7 +133,7 @@ func TestServerSetSocketPermissions(_ *testing.T) { // TestServerStartStop tests Start and Stop of Server. func TestServerStartStop(t *testing.T) { - config := NewConfig() + config := daemoncfg.NewSocketServer() config.SocketFile = "test.sock" server := NewServer(config) if err := server.Start(); err != nil { @@ -143,7 +145,7 @@ func TestServerStartStop(t *testing.T) { // TestServerRequests tests Requests of Server. func TestServerRequests(t *testing.T) { - config := NewConfig() + config := daemoncfg.NewSocketServer() config.SocketFile = "test.sock" server := NewServer(config) if server.Requests() != server.requests { @@ -153,7 +155,7 @@ func TestServerRequests(t *testing.T) { // TestNewServer tests NewServer. func TestNewServer(t *testing.T) { - config := NewConfig() + config := daemoncfg.NewSocketServer() server := NewServer(config) if server == nil || diff --git a/internal/cpd/config.go b/internal/cpd/config.go deleted file mode 100644 index 8ad69821..00000000 --- a/internal/cpd/config.go +++ /dev/null @@ -1,62 +0,0 @@ -package cpd - -import "time" - -var ( - // Host is the host address used for probing. - Host = "connectivity-check.ubuntu.com" - - // HTTPTimeout is the timeout for http requests in seconds. - HTTPTimeout = 5 * time.Second - - // ProbeCount is the number of probes to run. - ProbeCount = 3 - - // ProbeWait is the time between probes. - ProbeWait = time.Second - - // ProbeTimer is the probe timer in case of no detected portal - // in seconds. - ProbeTimer = 300 * time.Second - - // ProbeTimerDetected is the probe timer in case of a detected portal - // in seconds. - ProbeTimerDetected = 15 * time.Second -) - -// Config is the configuration of the captive portal detection. -type Config struct { - Host string - HTTPTimeout time.Duration - ProbeCount int - ProbeWait time.Duration - ProbeTimer time.Duration - ProbeTimerDetected time.Duration -} - -// Valid returns whether the captive portal detection configuration is valid. -func (c *Config) Valid() bool { - if c == nil || - c.Host == "" || - c.HTTPTimeout <= 0 || - c.ProbeCount <= 0 || - c.ProbeWait <= 0 || - c.ProbeTimer <= 0 || - c.ProbeTimerDetected <= 0 { - - return false - } - return true -} - -// NewConfig returns a new default configuration for captive portal detection. -func NewConfig() *Config { - return &Config{ - Host: Host, - HTTPTimeout: HTTPTimeout, - ProbeCount: ProbeCount, - ProbeWait: ProbeWait, - ProbeTimer: ProbeTimer, - ProbeTimerDetected: ProbeTimerDetected, - } -} diff --git a/internal/cpd/config_test.go b/internal/cpd/config_test.go deleted file mode 100644 index 7068ad42..00000000 --- a/internal/cpd/config_test.go +++ /dev/null @@ -1,44 +0,0 @@ -package cpd - -import ( - "testing" - "time" -) - -// TestConfigValid tests Valid of Config. -func TestConfigValid(t *testing.T) { - // test invalid - for _, invalid := range []*Config{ - nil, - {}, - } { - if invalid.Valid() { - t.Errorf("config should be invalid: %v", invalid) - } - } - - // test valid - for _, valid := range []*Config{ - NewConfig(), - { - Host: "some.host.example.com", - HTTPTimeout: 3 * time.Second, - ProbeCount: 5, - ProbeWait: 2 * time.Second, - ProbeTimer: 150 * time.Second, - ProbeTimerDetected: 10 * time.Second, - }, - } { - if !valid.Valid() { - t.Errorf("config should be valid: %v", valid) - } - } -} - -// TestNewConfig tests NewConfig. -func TestNewConfig(t *testing.T) { - c := NewConfig() - if !c.Valid() { - t.Errorf("new config should be valid") - } -} diff --git a/internal/cpd/cpd.go b/internal/cpd/cpd.go index 9cdb3494..9af8eb42 100644 --- a/internal/cpd/cpd.go +++ b/internal/cpd/cpd.go @@ -7,6 +7,7 @@ import ( "time" log "github.com/sirupsen/logrus" + "github.com/telekom-mms/oc-daemon/internal/daemoncfg" ) // Report is a captive portal detection report. @@ -17,7 +18,7 @@ type Report struct { // CPD is a captive portal detection instance. type CPD struct { - config *Config + config *daemoncfg.CPD reports chan *Report probes chan struct{} done chan struct{} @@ -211,7 +212,7 @@ func (c *CPD) Results() chan *Report { } // NewCPD returns a new CPD. -func NewCPD(config *Config) *CPD { +func NewCPD(config *daemoncfg.CPD) *CPD { return &CPD{ config: config, reports: make(chan *Report), diff --git a/internal/cpd/cpd_test.go b/internal/cpd/cpd_test.go index f6f7ebee..9fea7fd9 100644 --- a/internal/cpd/cpd_test.go +++ b/internal/cpd/cpd_test.go @@ -6,6 +6,8 @@ import ( "reflect" "testing" "time" + + "github.com/telekom-mms/oc-daemon/internal/daemoncfg" ) // TestCPDProbeCheck tests probe and check of CPD. @@ -16,7 +18,7 @@ func TestCPDProbeCheck(t *testing.T) { w.WriteHeader(http.StatusBadRequest) })) defer ts.Close() - c := NewCPD(NewConfig()) + c := NewCPD(daemoncfg.NewCPD()) c.config.Host = ts.Listener.Addr().String() c.config.ProbeWait = 0 close(c.done) @@ -29,7 +31,7 @@ func TestCPDProbeCheck(t *testing.T) { w.WriteHeader(http.StatusFound) })) defer ts.Close() - c := NewCPD(NewConfig()) + c := NewCPD(daemoncfg.NewCPD()) c.config.Host = ts.Listener.Addr().String() c.config.ProbeWait = 0 @@ -41,7 +43,7 @@ func TestCPDProbeCheck(t *testing.T) { // check with invalid server address t.Run("invalid server", func(t *testing.T) { - c := NewCPD(NewConfig()) + c := NewCPD(daemoncfg.NewCPD()) c.config.Host = "" c.config.ProbeWait = 0 @@ -58,7 +60,7 @@ func TestCPDProbeCheck(t *testing.T) { w.WriteHeader(http.StatusOK) })) defer ts.Close() - c := NewCPD(NewConfig()) + c := NewCPD(daemoncfg.NewCPD()) c.config.Host = ts.Listener.Addr().String() c.config.ProbeWait = 0 @@ -71,7 +73,7 @@ func TestCPDProbeCheck(t *testing.T) { // TestCPDHandleProbeRequest tests handleProbeRequest of CPD. func TestCPDHandleProbeRequest(t *testing.T) { - c := NewCPD(NewConfig()) + c := NewCPD(daemoncfg.NewCPD()) close(c.done) c.handleProbeRequest() @@ -88,7 +90,7 @@ func TestCPDHandleProbeRequest(t *testing.T) { // TestCPDHandleProbeReport tests handleProbeReport of CPD. func TestCPDHandleProbeReport(t *testing.T) { - c := NewCPD(NewConfig()) + c := NewCPD(daemoncfg.NewCPD()) // - send a probe request // - read report @@ -118,7 +120,7 @@ func TestCPDHandleTimer(t *testing.T) { false, true, } { - c := NewCPD(NewConfig()) + c := NewCPD(daemoncfg.NewCPD()) c.timer = time.NewTimer(0) c.detected = detected c.handleTimer() @@ -132,12 +134,12 @@ func TestCPDHandleTimer(t *testing.T) { // TestCPDStartStop tests Start and Stop of CPD. func TestCPDStartStop(t *testing.T) { // start and stop immediately - c := NewCPD(NewConfig()) + c := NewCPD(daemoncfg.NewCPD()) c.Start() c.Stop() // start and stop with timer event, probe result - conf := NewConfig() + conf := daemoncfg.NewCPD() conf.Host = "" conf.ProbeTimer = 0 conf.ProbeWait = 0 @@ -153,7 +155,7 @@ func TestCPDStartStop(t *testing.T) { // TestCPDHosts tests Hosts of CPD. func TestCPDHosts(t *testing.T) { - config := NewConfig() + config := daemoncfg.NewCPD() config.Host = "test" c := NewCPD(config) want := []string{"test"} @@ -171,7 +173,7 @@ func TestCPDProbe(t *testing.T) { w.WriteHeader(http.StatusNoContent) })) defer ts.Close() - c := NewCPD(NewConfig()) + c := NewCPD(daemoncfg.NewCPD()) c.config.Host = ts.Listener.Addr().String() c.config.ProbeWait = 0 c.Start() @@ -189,7 +191,7 @@ func TestCPDProbe(t *testing.T) { http.Redirect(w, r, "http://example.com", http.StatusFound) })) defer ts.Close() - c := NewCPD(NewConfig()) + c := NewCPD(daemoncfg.NewCPD()) c.config.Host = ts.Listener.Addr().String() c.config.ProbeWait = 0 c.Start() @@ -204,7 +206,7 @@ func TestCPDProbe(t *testing.T) { // TestCPDResults tests Results of CPD. func TestCPDResults(t *testing.T) { - c := NewCPD(NewConfig()) + c := NewCPD(daemoncfg.NewCPD()) want := c.reports got := c.Results() if got != want { @@ -214,7 +216,7 @@ func TestCPDResults(t *testing.T) { // TestNewCPD tests NewCPD. func TestNewCPD(t *testing.T) { - config := NewConfig() + config := daemoncfg.NewCPD() c := NewCPD(config) if !reflect.DeepEqual(c.config, config) { t.Errorf("got %v, want %v", c.config, config) diff --git a/internal/daemon/cmd.go b/internal/daemon/cmd.go index e96fb61d..2eedb632 100644 --- a/internal/daemon/cmd.go +++ b/internal/daemon/cmd.go @@ -8,6 +8,7 @@ import ( "path/filepath" log "github.com/sirupsen/logrus" + "github.com/telekom-mms/oc-daemon/internal/daemoncfg" ) var ( @@ -26,7 +27,7 @@ const ( var osMkdirAll = os.MkdirAll // prepareFolders prepares directories used by the daemon. -func prepareFolders(config *Config) error { +func prepareFolders(config *daemoncfg.Config) error { for _, file := range []string{ config.Config, config.SocketServer.SocketFile, @@ -56,7 +57,7 @@ func flagIsSet(flags *flag.FlagSet, name string) bool { // run is the main entry point for the daemon. func run(args []string) error { // parse command line arguments - defaults := NewConfig() + defaults := daemoncfg.NewConfig() flags := flag.NewFlagSet(args[0], flag.ContinueOnError) cfgFile := flags.String(argConfig, defaults.Config, "set config `file`") verbose := flags.Bool(argVerbose, defaults.Verbose, "enable verbose output") @@ -75,7 +76,7 @@ func run(args []string) error { log.WithField("version", Version).Info("Starting Daemon") // load config - config := NewConfig() + config := daemoncfg.NewConfig() if flagIsSet(flags, argConfig) { config.Config = *cfgFile } @@ -83,7 +84,7 @@ func run(args []string) error { log.WithError(err).Warn("Daemon could not load config, using default config") } if !config.Valid() { - config = NewConfig() + config = daemoncfg.NewConfig() log.Warn("Daemon loaded invalid config, using default config") } diff --git a/internal/daemon/cmd_test.go b/internal/daemon/cmd_test.go index 4dff827a..bfe6d967 100644 --- a/internal/daemon/cmd_test.go +++ b/internal/daemon/cmd_test.go @@ -8,13 +8,15 @@ import ( "os" "path/filepath" "testing" + + "github.com/telekom-mms/oc-daemon/internal/daemoncfg" ) // TestPrepareFolders tests prepareFolders. func TestPrepareFolders(t *testing.T) { // create temp dir and config dir := t.TempDir() - cfg := NewConfig() + cfg := daemoncfg.NewConfig() // set files: config, socket, xml-profile, pid file conf := filepath.Join(dir, "conf") diff --git a/internal/daemon/config.go b/internal/daemon/config.go deleted file mode 100644 index bdc5fc9c..00000000 --- a/internal/daemon/config.go +++ /dev/null @@ -1,98 +0,0 @@ -package daemon - -import ( - "encoding/json" - "os" - - "github.com/telekom-mms/oc-daemon/internal/api" - "github.com/telekom-mms/oc-daemon/internal/cpd" - "github.com/telekom-mms/oc-daemon/internal/dnsproxy" - "github.com/telekom-mms/oc-daemon/internal/execs" - "github.com/telekom-mms/oc-daemon/internal/ocrunner" - "github.com/telekom-mms/oc-daemon/internal/splitrt" - "github.com/telekom-mms/oc-daemon/internal/trafpol" - "github.com/telekom-mms/tnd/pkg/tnd" -) - -var ( - // configDir is the directory for the configuration. - configDir = "/var/lib/oc-daemon" - - // ConfigFile is the default config file. - ConfigFile = configDir + "/oc-daemon.json" - - // DefaultDNSServer is the default DNS server address, i.e., listen - // address of systemd-resolved. - DefaultDNSServer = "127.0.0.53:53" -) - -// Config is an OC-Daemon configuration. -type Config struct { - Config string `json:"-"` - Verbose bool - - SocketServer *api.Config - CPD *cpd.Config - DNSProxy *dnsproxy.Config - OpenConnect *ocrunner.Config - Executables *execs.Config - SplitRouting *splitrt.Config - TrafficPolicing *trafpol.Config - TND *tnd.Config -} - -// String returns the configuration as string. -func (c *Config) String() string { - b, _ := json.Marshal(c) - return string(b) -} - -// Valid returns whether config is valid. -func (c *Config) Valid() bool { - if c == nil || - !c.SocketServer.Valid() || - !c.CPD.Valid() || - !c.DNSProxy.Valid() || - !c.OpenConnect.Valid() || - !c.Executables.Valid() || - !c.SplitRouting.Valid() || - !c.TrafficPolicing.Valid() || - !c.TND.Valid() { - // invalid - return false - } - return true -} - -// Load loads the configuration from the config file. -func (c *Config) Load() error { - // read file contents - file, err := os.ReadFile(c.Config) - if err != nil { - return err - } - - // parse config - if err := json.Unmarshal(file, c); err != nil { - return err - } - - return nil -} - -// NewConfig returns a new Config. -func NewConfig() *Config { - return &Config{ - Config: ConfigFile, - Verbose: false, - - SocketServer: api.NewConfig(), - CPD: cpd.NewConfig(), - DNSProxy: dnsproxy.NewConfig(), - OpenConnect: ocrunner.NewConfig(), - Executables: execs.NewConfig(), - SplitRouting: splitrt.NewConfig(), - TrafficPolicing: trafpol.NewConfig(), - TND: tnd.NewConfig(), - } -} diff --git a/internal/daemon/config_test.go b/internal/daemon/config_test.go deleted file mode 100644 index 7a20ae9d..00000000 --- a/internal/daemon/config_test.go +++ /dev/null @@ -1,228 +0,0 @@ -package daemon - -import ( - "os" - "reflect" - "testing" - - "github.com/telekom-mms/oc-daemon/internal/api" - "github.com/telekom-mms/oc-daemon/internal/cpd" - "github.com/telekom-mms/oc-daemon/internal/dnsproxy" - "github.com/telekom-mms/oc-daemon/internal/execs" - "github.com/telekom-mms/oc-daemon/internal/ocrunner" - "github.com/telekom-mms/oc-daemon/internal/splitrt" - "github.com/telekom-mms/oc-daemon/internal/trafpol" - "github.com/telekom-mms/tnd/pkg/tnd" -) - -// TestConfigString tests String of Config. -func TestConfigString(t *testing.T) { - // test new config - c := NewConfig() - if c.String() == "" { - t.Errorf("string should not be empty: %s", c.String()) - } - - // test nil - c = nil - if c.String() != "null" { - t.Errorf("string should be null: %s", c.String()) - } -} - -// TestConfigValid tests Valid of Config. -func TestConfigValid(t *testing.T) { - // test invalid - for _, invalid := range []*Config{ - nil, - {}, - } { - want := false - got := invalid.Valid() - - if got != want { - t.Errorf("got %t, want %t for %v", got, want, invalid) - } - } - - // test valid - valid := NewConfig() - want := true - got := valid.Valid() - - if got != want { - t.Errorf("got %t, want %t for %v", got, want, valid) - } -} - -// TestConfigLoad tests Load of Config. -func TestConfigLoad(t *testing.T) { - config := NewConfig() - config.Config = "does not exist" - - // test invalid path - err := config.Load() - if err == nil { - t.Errorf("got != nil, want nil") - } - - // test empty config file - empty, err := os.CreateTemp("", "oc-daemon-config-test") - if err != nil { - t.Fatal(err) - } - defer func() { - _ = os.Remove(empty.Name()) - }() - - config = NewConfig() - config.Config = empty.Name() - err = config.Load() - if err == nil { - t.Errorf("got != nil, want nil") - } - - // test valid config file - // - complete config - // - partial config with defaults - for _, content := range []string{ - `{ - "Verbose": true, - "SocketServer": { - "SocketFile": "/run/oc-daemon/daemon.sock", - "SocketOwner": "", - "SocketGroup": "", - "SocketPermissions": "0700", - "RequestTimeout": 30000000000 - }, - "CPD": { - "Host": "connectivity-check.ubuntu.com", - "HTTPTimeout": 5000000000, - "ProbeCount": 3, - "ProbeTimer": 300000000000, - "ProbeTimerDetected": 15000000000 - }, - "DNSProxy": { - "Address": "127.0.0.1:4253", - "ListenUDP": true, - "ListenTCP": true - }, - "OpenConnect": { - "OpenConnect": "openconnect", - "XMLProfile": "/var/lib/oc-daemon/profile.xml", - "VPNCScript": "/usr/bin/oc-daemon-vpncscript", - "VPNDevice": "oc-daemon-tun0", - "PIDFile": "/run/oc-daemon/openconnect.pid", - "PIDOwner": "", - "PIDGroup": "", - "PIDPermissions": "0600" - }, - "Executables": { - "IP": "ip", - "Nft": "nft", - "Resolvectl": "resolvectl", - "Sysctl": "sysctl" - }, - "SplitRouting": { - "RoutingTable": "42111", - "RulePriority1": "2111", - "RulePriority2": "2112", - "FirewallMark": "42111" - }, - "TrafficPolicing": { - "AllowedHosts": ["connectivity-check.ubuntu.com", "detectportal.firefox.com", "www.gstatic.com", "clients3.google.com", "nmcheck.gnome.org"], - "PortalPorts": [80, 443], - "ResolveTimeout": 2000000000, - "ResolveTries": 3, - "ResolveTriesSleep": 1000000000, - "ResolveTTL": 300000000000 - }, - "TND": { - "WaitCheck": 1000000000, - "HTTPSTimeout": 5000000000, - "UntrustedTimer": 30000000000, - "TrustedTimer": 60000000000 - } -}`, - `{ - "Verbose": true -}`, - } { - - valid, err := os.CreateTemp("", "oc-daemon-config-test") - if err != nil { - t.Fatal(err) - } - defer func() { - _ = os.Remove(valid.Name()) - }() - - if _, err := valid.Write([]byte(content)); err != nil { - t.Fatal(err) - } - - config := NewConfig() - config.Config = valid.Name() - if err := config.Load(); err != nil { - t.Errorf("could not load valid config: %s", err) - } - - if !config.Valid() { - t.Errorf("config is not valid") - } - - want := &Config{ - Config: valid.Name(), - Verbose: true, - SocketServer: api.NewConfig(), - CPD: cpd.NewConfig(), - DNSProxy: dnsproxy.NewConfig(), - OpenConnect: ocrunner.NewConfig(), - Executables: execs.NewConfig(), - SplitRouting: splitrt.NewConfig(), - TrafficPolicing: trafpol.NewConfig(), - TND: tnd.NewConfig(), - } - if !reflect.DeepEqual(want.DNSProxy, config.DNSProxy) { - t.Errorf("got %v, want %v", config.DNSProxy, want.DNSProxy) - } - if !reflect.DeepEqual(want.OpenConnect, config.OpenConnect) { - t.Errorf("got %v, want %v", config.OpenConnect, want.OpenConnect) - } - if !reflect.DeepEqual(want.Executables, config.Executables) { - t.Errorf("got %v, want %v", config.Executables, want.Executables) - } - if !reflect.DeepEqual(want.SplitRouting, config.SplitRouting) { - t.Errorf("got %v, want %v", config.SplitRouting, want.SplitRouting) - } - if !reflect.DeepEqual(want.TrafficPolicing, config.TrafficPolicing) { - t.Errorf("got %v, want %v", config.TrafficPolicing, want.TrafficPolicing) - } - if !reflect.DeepEqual(want.TND, config.TND) { - t.Errorf("got %v, want %v", config.TND, want.TND) - } - if !reflect.DeepEqual(want, config) { - t.Errorf("got %v, want %v", config, want) - } - } -} - -// TestNewConfig tests NewConfig. -func TestNewConfig(t *testing.T) { - want := &Config{ - Config: "/var/lib/oc-daemon/oc-daemon.json", - Verbose: false, - SocketServer: api.NewConfig(), - CPD: cpd.NewConfig(), - DNSProxy: dnsproxy.NewConfig(), - OpenConnect: ocrunner.NewConfig(), - Executables: execs.NewConfig(), - SplitRouting: splitrt.NewConfig(), - TrafficPolicing: trafpol.NewConfig(), - TND: tnd.NewConfig(), - } - got := NewConfig() - if !reflect.DeepEqual(got, want) { - t.Errorf("got %v, want %v", got, want) - } -} diff --git a/internal/daemon/daemon.go b/internal/daemon/daemon.go index 75bc4ed0..a6337461 100644 --- a/internal/daemon/daemon.go +++ b/internal/daemon/daemon.go @@ -15,7 +15,9 @@ import ( log "github.com/sirupsen/logrus" "github.com/telekom-mms/oc-daemon/internal/api" + "github.com/telekom-mms/oc-daemon/internal/daemoncfg" "github.com/telekom-mms/oc-daemon/internal/dbusapi" + "github.com/telekom-mms/oc-daemon/internal/dnsproxy" "github.com/telekom-mms/oc-daemon/internal/execs" "github.com/telekom-mms/oc-daemon/internal/ocrunner" "github.com/telekom-mms/oc-daemon/internal/profilemon" @@ -32,7 +34,7 @@ import ( // Daemon is used to run the daemon. type Daemon struct { - config *Config + config *daemoncfg.Config server *api.Server dbus *dbusapi.Service @@ -323,13 +325,14 @@ func (d *Daemon) connectVPN(login *logininfo.LoginInfo) { d.serverIPAllowed = d.trafpol.AddAllowedAddr(d.serverIP) } - // connect using runner + // save login and connect using runner + d.config.LoginInfo = login env := []string{ "oc_daemon_token=" + d.token, "oc_daemon_socket_file=" + d.config.SocketServer.SocketFile, "oc_daemon_verbose=" + strconv.FormatBool(d.config.Verbose), } - d.runner.Connect(login, env) + d.runner.Connect(d.config.Copy(), env) } // disconnectVPN disconnects from the VPN. @@ -371,7 +374,8 @@ func (d *Daemon) updateVPNConfigUp(config *vpnconfig.Config) { // connecting, set up configuration log.Info("Daemon setting up vpn configuration") - d.vpnsetup.Setup(config) + d.config.VPNConfig = daemoncfg.GetVPNConfig(config) + d.vpnsetup.Setup(d.config.Copy()) // set traffic policing setting from Disable Always On VPN setting // in configuration @@ -380,11 +384,11 @@ func (d *Daemon) updateVPNConfigUp(config *vpnconfig.Config) { // save config d.setStatusVPNConfig(config) ip := "" - for _, addr := range []net.IP{config.IPv4.Address, config.IPv6.Address} { + for _, p := range []netip.Prefix{d.config.VPNConfig.IPv4, d.config.VPNConfig.IPv6} { // this assumes either a single IPv4 or a single IPv6 address // is configured on a vpn device - if addr != nil { - ip = addr.String() + if p.IsValid() { + ip = p.Addr().String() } } d.setStatusIP(ip) @@ -417,9 +421,13 @@ func (d *Daemon) updateVPNConfigDown() { // disconnecting, tear down configuration log.Info("Daemon tearing down vpn configuration") if d.status.VPNConfig != nil { - d.vpnsetup.Teardown(d.status.VPNConfig) + d.vpnsetup.Teardown(d.config) } + // remove login and VPN config + d.config.LoginInfo = &logininfo.LoginInfo{} + d.config.VPNConfig = &daemoncfg.VPNConfig{} + // save config d.setStatusVPNConfig(nil) d.setStatusConnectionState(vpnstatus.ConnectionStateDisconnected) @@ -477,12 +485,17 @@ func (d *Daemon) handleClientRequest(request *api.Request) { func (d *Daemon) dumpState() string { // define state type type State struct { + DaemonConfig *daemoncfg.Config TrafficPolicing *trafpol.State VPNSetup *vpnsetup.State } // collect internal state - state := State{} + c := d.config.Copy() + c.LoginInfo.Cookie = "HIDDEN" // hide cookie + state := State{ + DaemonConfig: c, + } if d.trafpol != nil { state.TrafficPolicing = d.trafpol.GetState() } @@ -638,7 +651,7 @@ func (d *Daemon) handleProfileUpdate() error { // cleanup cleans up after a failed shutdown. func (d *Daemon) cleanup(ctx context.Context) { ocrunner.CleanupConnect(d.config.OpenConnect) - vpnsetup.Cleanup(ctx, d.config.OpenConnect.VPNDevice, d.config.SplitRouting) + vpnsetup.Cleanup(ctx, d.config) trafpol.Cleanup(ctx) } @@ -767,9 +780,8 @@ func (d *Daemon) startTrafPol() error { return nil } log.Info("Daemon starting TrafPol") - c := trafpol.NewConfig() - c.AllowedHosts = append(c.AllowedHosts, d.getProfileAllowedHosts()...) - c.FirewallMark = d.config.SplitRouting.FirewallMark + c := d.config.Copy() + c.TrafficPolicing.AllowedHosts = append(c.TrafficPolicing.AllowedHosts, d.getProfileAllowedHosts()...) d.trafpol = trafpol.NewTrafPol(c) if err := d.trafpol.Start(); err != nil { return fmt.Errorf("Daemon could not start TrafPol: %w", err) @@ -777,7 +789,7 @@ func (d *Daemon) startTrafPol() error { // update trafpol status d.setStatusTrafPolState(vpnstatus.TrafPolStateActive) - d.setStatusAllowedHosts(c.AllowedHosts) + d.setStatusAllowedHosts(c.TrafficPolicing.AllowedHosts) if d.serverIP.IsValid() { // VPN connection active, allow server IP @@ -984,7 +996,7 @@ func (d *Daemon) Errors() chan error { } // NewDaemon returns a new Daemon. -func NewDaemon(config *Config) *Daemon { +func NewDaemon(config *daemoncfg.Config) *Daemon { return &Daemon{ config: config, @@ -993,10 +1005,9 @@ func NewDaemon(config *Config) *Daemon { sleepmon: sleepmon.NewSleepMon(), - vpnsetup: vpnsetup.NewVPNSetup(config.DNSProxy, - config.SplitRouting), + vpnsetup: vpnsetup.NewVPNSetup(dnsproxy.NewProxy(config.DNSProxy)), - runner: ocrunner.NewConnect(config.OpenConnect), + runner: ocrunner.NewConnect(), status: vpnstatus.New(), diff --git a/internal/daemon/daemon_test.go b/internal/daemon/daemon_test.go index aa9dc2d0..77c2311c 100644 --- a/internal/daemon/daemon_test.go +++ b/internal/daemon/daemon_test.go @@ -3,12 +3,14 @@ package daemon import ( "path/filepath" "testing" + + "github.com/telekom-mms/oc-daemon/internal/daemoncfg" ) // TestDaemonErrors tests Errors of Daemon. func TestDaemonErrors(t *testing.T) { // create daemon - c := NewConfig() + c := daemoncfg.NewConfig() c.OpenConnect.XMLProfile = filepath.Join(t.TempDir(), "does-not-exist") d := NewDaemon(c) @@ -20,7 +22,7 @@ func TestDaemonErrors(t *testing.T) { // TestNewDaemon tests NewDaemon. func TestNewDaemon(t *testing.T) { // create daemon - c := NewConfig() + c := daemoncfg.NewConfig() c.OpenConnect.XMLProfile = filepath.Join(t.TempDir(), "does-not-exist") d := NewDaemon(c) diff --git a/internal/daemoncfg/config.go b/internal/daemoncfg/config.go new file mode 100644 index 00000000..7b5faf5b --- /dev/null +++ b/internal/daemoncfg/config.go @@ -0,0 +1,906 @@ +// Package daemoncfg contains the internal daemon configuration. +package daemoncfg + +import ( + "encoding/json" + "errors" + "net/netip" + "os" + "os/exec" + "reflect" + "strconv" + "time" + + "github.com/telekom-mms/oc-daemon/pkg/logininfo" + "github.com/telekom-mms/oc-daemon/pkg/vpnconfig" + "github.com/telekom-mms/tnd/pkg/tnd" +) + +// Socket Server default values. +var ( + // SocketServerSocketFile is the unix socket file. + SocketServerSocketFile = "/run/oc-daemon/daemon.sock" + + // SocketServerSocketOwner is the owner of the socket file. + SocketServerSocketOwner = "" + + // SocketServerSocketGroup is the group of the socket file. + SocketServerSocketGroup = "" + + // SocketServerSocketPermissions are the file permissions of the socket file. + SocketServerSocketPermissions = "0700" + + // SocketServerRequestTimeout is the timeout for an entire request/response + // exchange initiated by a client. + SocketServerRequestTimeout = 30 * time.Second +) + +// SocketServer the socket server configuration. +type SocketServer struct { + SocketFile string + SocketOwner string + SocketGroup string + SocketPermissions string + RequestTimeout time.Duration +} + +// Copy returns a copy of the SocketServer configuration. +func (c *SocketServer) Copy() *SocketServer { + s := *c + return &s +} + +// Valid returns whether server config is valid. +func (c *SocketServer) Valid() bool { + if c == nil || + c.SocketFile == "" || + c.RequestTimeout < 0 { + return false + } + if c.SocketPermissions != "" { + perm, err := strconv.ParseUint(c.SocketPermissions, 8, 32) + if err != nil { + return false + } + if perm > 0777 { + return false + } + } + return true +} + +// NewSocketServer returns a new server configuration. +func NewSocketServer() *SocketServer { + return &SocketServer{ + SocketFile: SocketServerSocketFile, + SocketOwner: SocketServerSocketOwner, + SocketGroup: SocketServerSocketGroup, + SocketPermissions: SocketServerSocketPermissions, + RequestTimeout: SocketServerRequestTimeout, + } +} + +// CPD default values. +var ( + // CPDHost is the host address used for probing. + CPDHost = "connectivity-check.ubuntu.com" + + // CPDHTTPTimeout is the timeout for http requests in seconds. + CPDHTTPTimeout = 5 * time.Second + + // CPDProbeCount is the number of probes to run. + CPDProbeCount = 3 + + // CPDProbeWait is the time between probes. + CPDProbeWait = time.Second + + // CPDProbeTimer is the probe timer in case of no detected portal + // in seconds. + CPDProbeTimer = 300 * time.Second + + // CPDProbeTimerDetected is the probe timer in case of a detected portal + // in seconds. + CPDProbeTimerDetected = 15 * time.Second +) + +// CPD is the configuration of the captive portal detection. +type CPD struct { + Host string + HTTPTimeout time.Duration + ProbeCount int + ProbeWait time.Duration + ProbeTimer time.Duration + ProbeTimerDetected time.Duration +} + +// Copy returns a copy of the CPD configuration. +func (c *CPD) Copy() *CPD { + n := *c + return &n +} + +// Valid returns whether the captive portal detection configuration is valid. +func (c *CPD) Valid() bool { + if c == nil || + c.Host == "" || + c.HTTPTimeout <= 0 || + c.ProbeCount <= 0 || + c.ProbeWait <= 0 || + c.ProbeTimer <= 0 || + c.ProbeTimerDetected <= 0 { + + return false + } + return true +} + +// NewCPD returns a new default configuration for captive portal detection. +func NewCPD() *CPD { + return &CPD{ + Host: CPDHost, + HTTPTimeout: CPDHTTPTimeout, + ProbeCount: CPDProbeCount, + ProbeWait: CPDProbeWait, + ProbeTimer: CPDProbeTimer, + ProbeTimerDetected: CPDProbeTimerDetected, + } +} + +// DNSProxy default values. +var ( + // DNSProxyAddress is the default listen address of the DNS proxy. + DNSProxyAddress = "127.0.0.1:4253" + + // DNSProxyListenUDP specifies whether the DNS proxy listens on UDP. + DNSProxyListenUDP = true + + // DNSProxyListenTCP specifies whether the DNS proxy listens on TCP. + DNSProxyListenTCP = true +) + +// DNSProxy is the DNS proxy configuration. +type DNSProxy struct { + Address string + ListenUDP bool + ListenTCP bool +} + +// Copy returns a copy of the DNSProxy configuration. +func (c *DNSProxy) Copy() *DNSProxy { + d := *c + return &d +} + +// Valid returns whether the DNS proxy configuration is valid. +func (c *DNSProxy) Valid() bool { + if c == nil || + c.Address == "" || + (!c.ListenUDP && !c.ListenTCP) { + + return false + } + return true +} + +// NewDNSProxy returns a new DNS proxy configuration. +func NewDNSProxy() *DNSProxy { + return &DNSProxy{ + Address: DNSProxyAddress, + ListenUDP: DNSProxyListenUDP, + ListenTCP: DNSProxyListenTCP, + } +} + +// OpenConnect default values. +var ( + // OpenConnectOpenConnect is the default openconnect executable. + OpenConnectOpenConnect = "openconnect" + + // OpenConnectXMLProfile is the default AnyConnect Profile. + OpenConnectXMLProfile = "/var/lib/oc-daemon/profile.xml" + + // OpenConnectVPNCScript is the default vpnc-script. + OpenConnectVPNCScript = "/usr/bin/oc-daemon-vpncscript" + + // OpenConnectVPNDevice is the default vpn network device name. + OpenConnectVPNDevice = "oc-daemon-tun0" + + // OpenConnectPIDFile is the default file path of the PID file for openconnect. + OpenConnectPIDFile = "/run/oc-daemon/openconnect.pid" + + // OpenConnectPIDOwner is the default owner of the PID file. + OpenConnectPIDOwner = "" + + // OpenConnectPIDGroup is the default group of the PID file. + OpenConnectPIDGroup = "" + + // OpenConnectPIDPermissions are the default file permissions of the PID file. + OpenConnectPIDPermissions = "0600" + + // OpenConnectNoProxy specifies whether the no proxy flag is set in openconnect. + OpenConnectNoProxy = true + + // OpenConnectExtraEnv are extra environment variables used by openconnect. + OpenConnectExtraEnv = []string{} + + // OpenConnectExtraArgs are extra command line arguments used by openconnect. + OpenConnectExtraArgs = []string{} +) + +// OpenConnect is the configuration for the openconnect connection runner. +type OpenConnect struct { + OpenConnect string + + XMLProfile string + VPNCScript string + VPNDevice string + + PIDFile string + PIDOwner string + PIDGroup string + PIDPermissions string + + NoProxy bool + ExtraEnv []string + ExtraArgs []string +} + +// Copy returns a copy of the OpenConnect configuration. +func (c *OpenConnect) Copy() *OpenConnect { + openConnect := *c + openConnect.ExtraEnv = append(c.ExtraEnv[:0:0], c.ExtraEnv...) + openConnect.ExtraArgs = append(c.ExtraArgs[:0:0], c.ExtraArgs...) + + return &openConnect +} + +// Valid returns whether the openconnect configuration is valid. +func (c *OpenConnect) Valid() bool { + if c == nil || + c.OpenConnect == "" || + c.XMLProfile == "" || + c.VPNCScript == "" || + c.VPNDevice == "" || + c.PIDFile == "" || + c.PIDPermissions == "" { + + return false + } + if c.PIDPermissions != "" { + perm, err := strconv.ParseUint(c.PIDPermissions, 8, 32) + if err != nil { + return false + } + if perm > 0777 { + return false + } + } + return true +} + +// NewOpenConnect returns a new configuration for an openconnect connection runner. +func NewOpenConnect() *OpenConnect { + return &OpenConnect{ + OpenConnect: OpenConnectOpenConnect, + + XMLProfile: OpenConnectXMLProfile, + VPNCScript: OpenConnectVPNCScript, + VPNDevice: OpenConnectVPNDevice, + + PIDFile: OpenConnectPIDFile, + PIDOwner: OpenConnectPIDOwner, + PIDGroup: OpenConnectPIDGroup, + PIDPermissions: OpenConnectPIDPermissions, + + NoProxy: OpenConnectNoProxy, + ExtraEnv: append(OpenConnectExtraEnv[:0:0], OpenConnectExtraEnv...), + ExtraArgs: append(OpenConnectExtraArgs[:0:0], OpenConnectExtraArgs...), + } +} + +// Executables default values. +var ( + ExecutablesIP = "ip" + ExecutablesNft = "nft" + ExecutablesResolvectl = "resolvectl" + ExecutablesSysctl = "sysctl" +) + +// Executables is the executables configuration. +type Executables struct { + IP string + Nft string + Resolvectl string + Sysctl string +} + +// Copy returns a copy of the executables configuration. +func (c *Executables) Copy() *Executables { + e := *c + return &e +} + +// Valid returns whether config is valid. +func (c *Executables) Valid() bool { + if c == nil || + c.IP == "" || + c.Nft == "" || + c.Resolvectl == "" || + c.Sysctl == "" { + // invalid + return false + } + return true +} + +// CheckExecutables checks whether executables in config exist in the +// file system and are executable. +func (c *Executables) CheckExecutables() error { + for _, f := range []string{ + c.IP, c.Nft, c.Resolvectl, c.Sysctl, + } { + if _, err := exec.LookPath(f); err != nil { + return err + } + } + return nil +} + +// NewExecutables returns a new Executables configuration. +func NewExecutables() *Executables { + return &Executables{ + IP: ExecutablesIP, + Nft: ExecutablesNft, + Resolvectl: ExecutablesResolvectl, + Sysctl: ExecutablesSysctl, + } +} + +// SplitRouting default values. +var ( + // SplitRoutingRoutingTable is the routing table. + SplitRoutingRoutingTable = "42111" + + // SplitRoutingRulePriority1 is the first routing rule priority. It must be unique, + // higher than the local rule, lower than the main and default rules, + // lower than the second routing rule priority. + SplitRoutingRulePriority1 = "2111" + + // SplitRoutingRulePriority2 is the second routing rule priority. It must be unique, + // higher than the local rule, lower than the main and default rules, + // higher than the first routing rule priority. + SplitRoutingRulePriority2 = "2112" + + // SplitRoutingFirewallMark is the firewall mark used for split routing. + SplitRoutingFirewallMark = SplitRoutingRoutingTable +) + +// SplitRouting is the split routing configuration. +type SplitRouting struct { + RoutingTable string + RulePriority1 string + RulePriority2 string + FirewallMark string +} + +// Copy returns a copy of the SplitRouting configuration. +func (c *SplitRouting) Copy() *SplitRouting { + s := *c + return &s +} + +// Valid returns whether the split routing configuration is valid. +func (c *SplitRouting) Valid() bool { + if c == nil || + c.RoutingTable == "" || + c.RulePriority1 == "" || + c.RulePriority2 == "" || + c.FirewallMark == "" { + + return false + } + + // check routing table value: must be > 0, < 0xFFFFFFFF + rtTable, err := strconv.ParseUint(c.RoutingTable, 10, 32) + if err != nil || rtTable == 0 || rtTable >= 0xFFFFFFFF { + return false + } + + // check rule priority values: must be > 0, < 32766, prio1 < prio2 + prio1, err := strconv.ParseUint(c.RulePriority1, 10, 16) + if err != nil { + return false + } + prio2, err := strconv.ParseUint(c.RulePriority2, 10, 16) + if err != nil { + return false + } + if prio1 == 0 || prio2 == 0 || + prio1 >= 32766 || prio2 >= 32766 || + prio1 >= prio2 { + + return false + } + + // check fwmark value: must be 32 bit unsigned int + if _, err := strconv.ParseUint(c.FirewallMark, 10, 32); err != nil { + return false + } + + return true +} + +// NewSplitRouting returns a new split routing configuration. +func NewSplitRouting() *SplitRouting { + return &SplitRouting{ + RoutingTable: SplitRoutingRoutingTable, + RulePriority1: SplitRoutingRulePriority1, + RulePriority2: SplitRoutingRulePriority2, + FirewallMark: SplitRoutingFirewallMark, + } +} + +// Traffic Policing default values. +var ( + // AllowedHosts is the default list of allowed hosts, this is + // initialized with hosts for captive portal detection, e.g., + // used by browsers. + AllowedHosts = []string{ + "connectivity-check.ubuntu.com", // ubuntu + "detectportal.firefox.com", // firefox + "www.gstatic.com", // chrome + "clients3.google.com", // chromium + "nmcheck.gnome.org", // gnome + } + + // PortalPorts are the default ports that are allowed to register on a + // captive portal. + PortalPorts = []uint16{ + 80, + 443, + } + + // ResolveTimeout is the timeout for dns lookups. + ResolveTimeout = 2 * time.Second + + // ResolveTries is the number of tries for dns lookups. + ResolveTries = 3 + + // ResolveTriesSleep is the sleep time between retries. + ResolveTriesSleep = time.Second + + // ResolveTimer is the time for periodic resolve update checks, + // should be higher than tries * (timeout + sleep). + ResolveTimer = 30 * time.Second + + // ResolveTTL is the lifetime of resolved entries. + ResolveTTL = 300 * time.Second +) + +// TrafficPolicing is a TrafPol configuration. +type TrafficPolicing struct { + AllowedHosts []string + PortalPorts []uint16 + + ResolveTimeout time.Duration + ResolveTries int + ResolveTriesSleep time.Duration + ResolveTimer time.Duration + ResolveTTL time.Duration +} + +// Copy returns a copy of the TrafficPolicing configuration. +func (c *TrafficPolicing) Copy() *TrafficPolicing { + trafpol := *c + trafpol.AllowedHosts = append(c.AllowedHosts[:0:0], c.AllowedHosts...) + trafpol.PortalPorts = append(c.PortalPorts[:0:0], c.PortalPorts...) + + return &trafpol +} + +// Valid returns whether the TrafPol configuration is valid. +func (c *TrafficPolicing) Valid() bool { + if c == nil || + len(c.PortalPorts) == 0 || + c.ResolveTimeout < 0 || + c.ResolveTries < 1 || + c.ResolveTriesSleep < 0 || + c.ResolveTimer < 0 || + c.ResolveTTL < 0 { + + return false + } + return true +} + +// NewTrafficPolicing returns a new TrafPol configuration. +func NewTrafficPolicing() *TrafficPolicing { + return &TrafficPolicing{ + AllowedHosts: append(AllowedHosts[:0:0], AllowedHosts...), + PortalPorts: append(PortalPorts[:0:0], PortalPorts...), + + ResolveTimeout: ResolveTimeout, + ResolveTries: ResolveTries, + ResolveTriesSleep: ResolveTriesSleep, + ResolveTimer: ResolveTimer, + ResolveTTL: ResolveTTL, + } +} + +// VPNDevice is a VPN device configuration in VPNConfig. +type VPNDevice struct { + Name string + MTU int +} + +// Copy returns a copy of the VPN device. +func (d *VPNDevice) Copy() VPNDevice { + return VPNDevice{ + Name: d.Name, + MTU: d.MTU, + } +} + +// VPNDNS is a DNS configuration in VPNConfig. +type VPNDNS struct { + DefaultDomain string + ServersIPv4 []netip.Addr + ServersIPv6 []netip.Addr +} + +// Copy returns a copy of VPNDNS. +func (d *VPNDNS) Copy() VPNDNS { + return VPNDNS{ + DefaultDomain: d.DefaultDomain, + ServersIPv4: append(d.ServersIPv4[:0:0], d.ServersIPv4...), + ServersIPv6: append(d.ServersIPv6[:0:0], d.ServersIPv6...), + } +} + +// Remotes returns a map of DNS remotes from the DNS configuration that maps +// domain "." to the IPv4 and IPv6 DNS servers in the configuration including +// port number 53. +func (d *VPNDNS) Remotes() map[string][]string { + remotes := map[string][]string{} + for _, s := range d.ServersIPv4 { + server := s.String() + ":53" + remotes["."] = append(remotes["."], server) + } + for _, s := range d.ServersIPv6 { + server := "[" + s.String() + "]:53" + remotes["."] = append(remotes["."], server) + } + + return remotes +} + +// VPNSplit is a split routing configuration in VPNConfig. +type VPNSplit struct { + ExcludeIPv4 []netip.Prefix + ExcludeIPv6 []netip.Prefix + ExcludeDNS []string + + ExcludeVirtualSubnetsOnlyIPv4 bool +} + +// Copy returns a copy of VPN split. +func (s *VPNSplit) Copy() VPNSplit { + return VPNSplit{ + ExcludeIPv4: append(s.ExcludeIPv4[:0:0], s.ExcludeIPv4...), + ExcludeIPv6: append(s.ExcludeIPv6[:0:0], s.ExcludeIPv6...), + ExcludeDNS: append(s.ExcludeDNS[:0:0], s.ExcludeDNS...), + + ExcludeVirtualSubnetsOnlyIPv4: s.ExcludeVirtualSubnetsOnlyIPv4, + } +} + +// DNSExcludes returns a list of DNS-based split excludes from the +// split routing configuration. The list contains domain names including the +// trailing ".". +func (s *VPNSplit) DNSExcludes() []string { + excludes := make([]string, len(s.ExcludeDNS)) + for i, e := range s.ExcludeDNS { + excludes[i] = e + "." + } + + return excludes +} + +// VPNFlags are other configuration settings in VPNConfig. +type VPNFlags struct { + DisableAlwaysOnVPN bool +} + +// Copy returns a copy of VPN flags. +func (f *VPNFlags) Copy() VPNFlags { + return VPNFlags{ + DisableAlwaysOnVPN: f.DisableAlwaysOnVPN, + } +} + +// VPNConfig is a VPN configuration. +type VPNConfig struct { + Gateway netip.Addr + PID int + Timeout int + Device VPNDevice + IPv4 netip.Prefix + IPv6 netip.Prefix + DNS VPNDNS + Split VPNSplit + Flags VPNFlags +} + +// Copy returns a copy of the VPN configuration. +func (c *VPNConfig) Copy() *VPNConfig { + if c == nil { + return nil + } + return &VPNConfig{ + Gateway: c.Gateway, + PID: c.PID, + Timeout: c.Timeout, + Device: c.Device.Copy(), + IPv4: c.IPv4, + IPv6: c.IPv6, + DNS: c.DNS.Copy(), + Split: c.Split.Copy(), + Flags: c.Flags.Copy(), + } +} + +// Empty returns whether the VPN configuration is empty. +func (c *VPNConfig) Empty() bool { + empty := &VPNConfig{} + return reflect.DeepEqual(c, empty) +} + +// Valid returns whether the VPN configuration is valid. +func (c *VPNConfig) Valid() bool { + // an empty config is valid + if c.Empty() { + return true + } + + // check config entries + for _, invalid := range []bool{ + !c.Gateway.IsValid(), + c.Device.Name == "", + len(c.Device.Name) > 15, + c.Device.MTU < 68, + c.Device.MTU > 16384, + !c.IPv4.IsValid() && !c.IPv6.IsValid(), + len(c.DNS.ServersIPv4) == 0 && len(c.DNS.ServersIPv6) == 0, + } { + if invalid { + return false + } + } + + return true +} + +// GetVPNConfig converts vpnconf to VPNConfig. +func GetVPNConfig(vpnconf *vpnconfig.Config) *VPNConfig { + // convert gateway + gateway := netip.Addr{} + if g, ok := netip.AddrFromSlice(vpnconf.Gateway); ok { + gateway = g + } + + // convert ipv4 address + pre4 := netip.Prefix{} + if ipv4, ok := netip.AddrFromSlice(vpnconf.IPv4.Address.To4()); ok { + pre4len, _ := vpnconf.IPv4.Netmask.Size() + pre4 = netip.PrefixFrom(ipv4, pre4len) + } + + // convert ipv6 address + pre6 := netip.Prefix{} + if ipv6, ok := netip.AddrFromSlice(vpnconf.IPv6.Address); ok { + pre6len, _ := vpnconf.IPv6.Netmask.Size() + pre6 = netip.PrefixFrom(ipv6, pre6len) + } + + // convert ipv4 dns servers + var dns4 []netip.Addr + for _, a := range vpnconf.DNS.ServersIPv4 { + if d, ok := netip.AddrFromSlice(a.To4()); ok { + dns4 = append(dns4, d) + } + } + + // convert ipv6 dns servers + var dns6 []netip.Addr + for _, a := range vpnconf.DNS.ServersIPv6 { + if d, ok := netip.AddrFromSlice(a); ok { + dns6 = append(dns6, d) + } + } + + // convert ipv4 excludes + var excludes4 []netip.Prefix + for _, a := range vpnconf.Split.ExcludeIPv4 { + if ipv4, ok := netip.AddrFromSlice(a.IP.To4()); ok { + pre4len, _ := a.Mask.Size() + pre4 := netip.PrefixFrom(ipv4, pre4len) + excludes4 = append(excludes4, pre4) + } + } + + // convert ipv6 excludes + var excludes6 []netip.Prefix + for _, a := range vpnconf.Split.ExcludeIPv6 { + if ipv6, ok := netip.AddrFromSlice(a.IP); ok { + pre6len, _ := a.Mask.Size() + pre6 := netip.PrefixFrom(ipv6, pre6len) + excludes6 = append(excludes6, pre6) + } + } + + return &VPNConfig{ + Gateway: gateway, + PID: vpnconf.PID, + Timeout: vpnconf.Timeout, + Device: VPNDevice{ + Name: vpnconf.Device.Name, + MTU: vpnconf.Device.MTU, + }, + IPv4: pre4, + IPv6: pre6, + DNS: VPNDNS{ + DefaultDomain: vpnconf.DNS.DefaultDomain, + ServersIPv4: dns4, + ServersIPv6: dns6, + }, + Split: VPNSplit{ + ExcludeIPv4: excludes4, + ExcludeIPv6: excludes6, + ExcludeDNS: vpnconf.Split.ExcludeDNS, + + ExcludeVirtualSubnetsOnlyIPv4: vpnconf.Split.ExcludeVirtualSubnetsOnlyIPv4, + }, + Flags: VPNFlags{ + DisableAlwaysOnVPN: vpnconf.Flags.DisableAlwaysOnVPN, + }, + } +} + +// Config default values. +var ( + // configDir is the directory for the configuration. + configDir = "/var/lib/oc-daemon" + + // ConfigFile is the default config file. + ConfigFile = configDir + "/oc-daemon.json" +) + +// Config is an OC-Daemon configuration. +type Config struct { + Config string + Verbose bool + + SocketServer *SocketServer + CPD *CPD + DNSProxy *DNSProxy + OpenConnect *OpenConnect + Executables *Executables + SplitRouting *SplitRouting + TrafficPolicing *TrafficPolicing + TND *tnd.Config + + LoginInfo *logininfo.LoginInfo + VPNConfig *VPNConfig +} + +// Copy returns a copy of the configuration. +func (c *Config) Copy() *Config { + t := *c.TND + return &Config{ + Config: c.Config, + Verbose: c.Verbose, + + SocketServer: c.SocketServer.Copy(), + CPD: c.CPD.Copy(), + DNSProxy: c.DNSProxy.Copy(), + OpenConnect: c.OpenConnect.Copy(), + Executables: c.Executables.Copy(), + SplitRouting: c.SplitRouting.Copy(), + TrafficPolicing: c.TrafficPolicing.Copy(), + TND: &t, + + LoginInfo: c.LoginInfo.Copy(), + VPNConfig: c.VPNConfig.Copy(), + } +} + +// String returns the configuration as string. +func (c *Config) String() string { + b, _ := json.Marshal(c) + return string(b) +} + +// loginInfoEmpty returns whether l is empty. +func loginInfoEmpty(l *logininfo.LoginInfo) bool { + empty := &logininfo.LoginInfo{} + return reflect.DeepEqual(l, empty) +} + +// Valid returns whether config is valid. +func (c *Config) Valid() bool { + if c == nil || + !c.SocketServer.Valid() || + !c.CPD.Valid() || + !c.DNSProxy.Valid() || + !c.OpenConnect.Valid() || + !c.Executables.Valid() || + !c.SplitRouting.Valid() || + !c.TrafficPolicing.Valid() || + !c.TND.Valid() || + !loginInfoEmpty(c.LoginInfo) && !c.LoginInfo.Valid() || + !c.VPNConfig.Valid() { + // invalid + return false + } + return true +} + +// Load loads the configuration from the config file. +func (c *Config) Load() error { + // read file contents + file, err := os.ReadFile(c.Config) + if err != nil { + return err + } + + // save current values to detect not allowed config file entries + oldConfig := c.Config + c.Config = "" + oldLoginInfo := c.LoginInfo + c.LoginInfo = nil + oldVPNConfig := c.VPNConfig + c.VPNConfig = nil + + // parse config + if err := json.Unmarshal(file, c); err != nil { + return err + } + + // check not allowed entries in config file + if c.Config != "" { + return errors.New("configuration file must not include Config") + } + if c.LoginInfo != nil { + return errors.New("configuration file must not include LoginInfo") + } + if c.VPNConfig != nil { + return errors.New("configuration file must not include VPNConfig") + } + + // reset saved values + c.Config = oldConfig + c.LoginInfo = oldLoginInfo + c.VPNConfig = oldVPNConfig + + return nil +} + +// NewConfig returns a new Config. +func NewConfig() *Config { + return &Config{ + Config: ConfigFile, + Verbose: false, + + SocketServer: NewSocketServer(), + CPD: NewCPD(), + DNSProxy: NewDNSProxy(), + OpenConnect: NewOpenConnect(), + Executables: NewExecutables(), + SplitRouting: NewSplitRouting(), + TrafficPolicing: NewTrafficPolicing(), + TND: tnd.NewConfig(), + + LoginInfo: &logininfo.LoginInfo{}, + VPNConfig: &VPNConfig{}, + } +} diff --git a/internal/daemoncfg/config_test.go b/internal/daemoncfg/config_test.go new file mode 100644 index 00000000..5c436770 --- /dev/null +++ b/internal/daemoncfg/config_test.go @@ -0,0 +1,923 @@ +package daemoncfg + +import ( + "log" + "net" + "net/netip" + "os" + "path/filepath" + "reflect" + "strings" + "testing" + "time" + + "github.com/telekom-mms/oc-daemon/pkg/logininfo" + "github.com/telekom-mms/oc-daemon/pkg/vpnconfig" + "github.com/telekom-mms/tnd/pkg/tnd" +) + +// TestSocketServerValid tests Valid of SocketServer. +func TestSocketServerValid(t *testing.T) { + // test invalid + for _, invalid := range []*SocketServer{ + nil, + {}, + {SocketFile: "test.sock", SocketPermissions: "invalid"}, + {SocketFile: "test.sock", SocketPermissions: "1234"}, + } { + want := false + got := invalid.Valid() + if got != want { + t.Errorf("got %t, want %t for %v", got, want, invalid) + } + } + + // test valid + for _, valid := range []*SocketServer{ + NewSocketServer(), + {SocketFile: "test.sock", SocketPermissions: "777"}, + } { + want := true + got := valid.Valid() + if got != want { + t.Errorf("got %t, want %t for %v", got, want, valid) + } + } +} + +// TestNewSocketServer tests NewSocketServer. +func TestNewSocketServer(t *testing.T) { + sc := NewSocketServer() + if !sc.Valid() { + t.Errorf("config is not valid") + } +} + +// TestCPDValid tests Valid of CPD. +func TestCPDValid(t *testing.T) { + // test invalid + for _, invalid := range []*CPD{ + nil, + {}, + } { + if invalid.Valid() { + t.Errorf("config should be invalid: %v", invalid) + } + } + + // test valid + for _, valid := range []*CPD{ + NewCPD(), + { + Host: "some.host.example.com", + HTTPTimeout: 3 * time.Second, + ProbeCount: 5, + ProbeWait: 2 * time.Second, + ProbeTimer: 150 * time.Second, + ProbeTimerDetected: 10 * time.Second, + }, + } { + if !valid.Valid() { + t.Errorf("config should be valid: %v", valid) + } + } +} + +// TestNewCPD tests NewCPD. +func TestNewCPD(t *testing.T) { + c := NewCPD() + if !c.Valid() { + t.Errorf("new config should be valid") + } +} + +// TestDNSProxyValid tests Valid of DNSProxy. +func TestDNSProxyValid(t *testing.T) { + // test invalid + for _, invalid := range []*DNSProxy{ + nil, + {}, + } { + want := false + got := invalid.Valid() + + if got != want { + t.Errorf("got %t, want %t for %v", got, want, invalid) + } + } + + // test valid + for _, valid := range []*DNSProxy{ + NewDNSProxy(), + {Address: "127.0.0.1:4253", ListenUDP: true}, + {Address: "127.0.0.1:4253", ListenTCP: true}, + } { + want := true + got := valid.Valid() + + if got != want { + t.Errorf("got %t, want %t for %v", got, want, valid) + } + } +} + +// TestNewDNSProxy tests NewDNSProxy. +func TestNewDNSProxy(t *testing.T) { + want := &DNSProxy{ + Address: DNSProxyAddress, + ListenUDP: true, + ListenTCP: true, + } + got := NewDNSProxy() + if !reflect.DeepEqual(got, want) { + t.Errorf("got %v, want %v", got, want) + } +} + +// TestOpenConnectValid tests Valid of OpenConnect. +func TestOpenConnectValid(t *testing.T) { + // test invalid + for _, invalid := range []*OpenConnect{ + nil, + {}, + { + OpenConnect: "openconnect", + XMLProfile: "/test/profile", + VPNCScript: "/test/vpncscript", + VPNDevice: "test-device", + PIDFile: "/test/pid", + PIDPermissions: "invalid", + }, + { + OpenConnect: "openconnect", + XMLProfile: "/test/profile", + VPNCScript: "/test/vpncscript", + VPNDevice: "test-device", + PIDFile: "/test/pid", + PIDPermissions: "1234", + }, + } { + want := false + got := invalid.Valid() + + if got != want { + t.Errorf("got %t, want %t for %v", got, want, invalid) + } + } + + // test valid + for _, valid := range []*OpenConnect{ + NewOpenConnect(), + { + OpenConnect: "openconnect", + XMLProfile: "/test/profile", + VPNCScript: "/test/vpncscript", + VPNDevice: "test-device", + PIDFile: "/test/pid", + PIDPermissions: "777", + }, + } { + want := true + got := valid.Valid() + + if got != want { + t.Errorf("got %t, want %t for %v", got, want, valid) + } + } +} + +// TestNewOpenConnect tests NewOpenConnect. +func TestNewOpenConnect(t *testing.T) { + want := &OpenConnect{ + OpenConnect: OpenConnectOpenConnect, + + XMLProfile: OpenConnectXMLProfile, + VPNCScript: OpenConnectVPNCScript, + VPNDevice: OpenConnectVPNDevice, + + PIDFile: OpenConnectPIDFile, + PIDOwner: OpenConnectPIDOwner, + PIDGroup: OpenConnectPIDGroup, + PIDPermissions: OpenConnectPIDPermissions, + + NoProxy: OpenConnectNoProxy, + ExtraEnv: OpenConnectExtraEnv, + ExtraArgs: OpenConnectExtraArgs, + } + got := NewOpenConnect() + if !reflect.DeepEqual(got, want) { + t.Errorf("got %v, want %v", got, want) + } +} + +// TestExecutablesValid tests Valid of Executables. +func TestExecutablesValid(t *testing.T) { + // test invalid + for _, invalid := range []*Executables{ + nil, + {}, + } { + if invalid.Valid() { + t.Errorf("config should be invalid: %v", invalid) + } + } + + // test valid + for _, valid := range []*Executables{ + NewExecutables(), + {"/test/ip", "/test/nft", "/test/resolvectl", "/test/sysctl"}, + } { + if !valid.Valid() { + t.Errorf("config should be valid: %v", valid) + } + } +} + +// TestExecutablesCheckExecutables tests CheckExecutables of Executables. +func TestExecutablesCheckExecutables(t *testing.T) { + // create temporary dir for executables + dir, err := os.MkdirTemp("", "execs-test") + if err != nil { + t.Fatal(err) + } + defer func() { _ = os.RemoveAll(dir) }() + + // create executable file paths + ip := filepath.Join(dir, "ip") + nft := filepath.Join(dir, "nft") + resolvectl := filepath.Join(dir, "resolvectl") + sysctl := filepath.Join(dir, "sysctl") + + // create config with executables + c := &Executables{ + IP: ip, + Nft: nft, + Resolvectl: resolvectl, + Sysctl: sysctl, + } + + // test with not all files existing, create files in the process + for _, f := range []string{ + ip, nft, resolvectl, sysctl, + } { + // test + if got := c.CheckExecutables(); got == nil { + t.Errorf("got nil, want != nil") + } + + // create executable file + if err := os.WriteFile(f, []byte{}, 0777); err != nil { + t.Fatal(err) + } + } + + // test with all files existing + if got := c.CheckExecutables(); got != nil { + t.Errorf("got %v, want nil", got) + } +} + +// TestNewExecutables tests NewExecutables. +func TestNewExecutables(t *testing.T) { + want := &Executables{ExecutablesIP, ExecutablesNft, + ExecutablesResolvectl, ExecutablesSysctl} + got := NewExecutables() + if !reflect.DeepEqual(got, want) { + t.Errorf("got %v, want %v", got, want) + } +} + +// TestSplitRoutingValid tests Valid of SplitRouting. +func TestSplitRoutingValid(t *testing.T) { + // test invalid + for _, invalid := range []*SplitRouting{ + nil, + {}, + { + RoutingTable: "42111", + FirewallMark: "42111", + RulePriority1: "0", + RulePriority2: "1", + }, + { + RoutingTable: "42111", + FirewallMark: "42111", + RulePriority1: "32766", + RulePriority2: "32767", + }, + { + RoutingTable: "42111", + FirewallMark: "42111", + RulePriority1: "2111", + RulePriority2: "2111", + }, + { + RoutingTable: "42111", + FirewallMark: "42111", + RulePriority1: "2112", + RulePriority2: "2111", + }, + { + RoutingTable: "42111", + FirewallMark: "42111", + RulePriority1: "65537", + RulePriority2: "2111", + }, + { + RoutingTable: "42111", + FirewallMark: "42111", + RulePriority1: "2111", + RulePriority2: "65537", + }, + { + RoutingTable: "0", + FirewallMark: "42112", + RulePriority1: "2222", + RulePriority2: "2223", + }, + { + RoutingTable: "4294967295", + FirewallMark: "42112", + RulePriority1: "2222", + RulePriority2: "2223", + }, + { + RoutingTable: "42112", + FirewallMark: "4294967296", + RulePriority1: "2222", + RulePriority2: "2223", + }, + } { + want := false + got := invalid.Valid() + + if got != want { + t.Errorf("got %t, want %t for %v", got, want, invalid) + } + } + + // test valid + for _, valid := range []*SplitRouting{ + NewSplitRouting(), + { + RoutingTable: "42112", + FirewallMark: "42112", + RulePriority1: "2222", + RulePriority2: "2223", + }, + } { + want := true + got := valid.Valid() + + if got != want { + t.Errorf("got %t, want %t for %v", got, want, valid) + } + } +} + +// TestNewSplitRouting tests NewSplitRouting. +func TestNewSplitRouting(t *testing.T) { + c := NewSplitRouting() + if !c.Valid() { + t.Errorf("new config should be valid") + } +} + +// TestTrafficPolicingValid tests Valid of TrafficPolicing. +func TestTrafficPolicingValid(t *testing.T) { + // test invalid + for _, invalid := range []*TrafficPolicing{ + nil, + {}, + } { + want := false + got := invalid.Valid() + + if got != want { + t.Errorf("got %t, want %t for %v", got, want, invalid) + } + } + + // test valid + valid := NewTrafficPolicing() + want := true + got := valid.Valid() + + if got != want { + t.Errorf("got %t, want %t for %v", got, want, valid) + } +} + +// TestNewTrafficPolicing tests NewTrafficPolicing. +func TestNewTrafficPolicing(t *testing.T) { + c := NewTrafficPolicing() + if !c.Valid() { + t.Errorf("new config should be valid") + } +} + +// TestVPNDNSRemotes tests Remotes of VPNDNS. +func TestVPNDNSRemotes(t *testing.T) { + // test empty + c := &VPNConfig{} + if len(c.DNS.Remotes()) != 0 { + t.Errorf("got %d, want 0", len(c.DNS.Remotes())) + } + + // test ipv4 + for _, want := range [][]string{ + {"127.0.0.1:53"}, + {"127.0.0.1:53", "192.168.1.1:53"}, + {"127.0.0.1:53", "192.168.1.1:53", "10.0.0.1:53"}, + } { + c := &VPNConfig{} + for _, ip := range want { + ip = ip[:len(ip)-3] // remove port + c.DNS.ServersIPv4 = append(c.DNS.ServersIPv4, netip.MustParseAddr(ip)) + } + got := c.DNS.Remotes()["."] + if !reflect.DeepEqual(got, want) { + t.Errorf("got %v, want %v", got, want) + } + } + + // test ipv6 + for _, want := range [][]string{ + {"[::1]:53"}, + {"[::1]:53", "[2000::1]:53"}, + {"[::1]:53", "[2000::1]:53", "[2002::1]:53"}, + } { + c := &VPNConfig{} + for _, ip := range want { + ip = ip[1 : len(ip)-4] // remove port and brackets + c.DNS.ServersIPv6 = append(c.DNS.ServersIPv6, netip.MustParseAddr(ip)) + } + got := c.DNS.Remotes()["."] + log.Println(got) + if !reflect.DeepEqual(got, want) { + t.Errorf("got %v, want %v", got, want) + } + } + + // test both ipv4 and ipv6 + c = &VPNConfig{} + dns4 := "127.0.0.1" + dns6 := "::1" + c.DNS.ServersIPv4 = append(c.DNS.ServersIPv4, netip.MustParseAddr(dns4)) + c.DNS.ServersIPv6 = append(c.DNS.ServersIPv6, netip.MustParseAddr(dns6)) + + want := map[string][]string{ + ".": {dns4 + ":53", "[" + dns6 + "]:53"}, + } + got := c.DNS.Remotes() + if !reflect.DeepEqual(got, want) { + t.Errorf("got %v, want %v", got, want) + } +} + +// TestVPNSplitDNSExcludes tests DNSExcludes of VPNSplit. +func TestVPNSplitDNSExcludes(t *testing.T) { + // test empty + c := &VPNConfig{} + if len(c.Split.DNSExcludes()) != 0 { + t.Errorf("got %d, want 0", len(c.Split.DNSExcludes())) + } + + // test filled + c = &VPNConfig{} + want := []string{"example.com", "test.com"} + c.Split.ExcludeDNS = want + for i, got := range c.Split.DNSExcludes() { + want := want[i] + "." + if got != want { + t.Errorf("got %s, want %s", got, want) + } + } +} + +// TestVPNConfigCopy tests Copy of VPNConfig. +func TestVPNConfigCopy(t *testing.T) { + // test nil + if (*VPNConfig)(nil).Copy() != nil { + t.Error("copy of nil should be nil") + } + + // test with new config + want := &VPNConfig{} + got := want.Copy() + + if !reflect.DeepEqual(got, want) { + t.Errorf("got %v, want %v", got, want) + } + + // test modification after copy + c1 := &VPNConfig{} + c2 := c1.Copy() + c1.PID = 12345 + c1.Split.ExcludeIPv4 = append(c1.Split.ExcludeIPv4, netip.MustParsePrefix("192.168.1.0/24")) + + if reflect.DeepEqual(c1, c2) { + t.Error("copies should not be equal after modification") + } +} + +// getValidTestVPNConfig returns a valid VPNConfig for testing. +func getValidTestVPNConfig() *VPNConfig { + c := &VPNConfig{} + + c.Gateway = netip.MustParseAddr("192.168.0.1") + c.PID = 123456 + c.Timeout = 300 + c.Device.Name = "tun0" + c.Device.MTU = 1300 + c.IPv4 = netip.MustParsePrefix("192.168.0.123/24") + c.IPv6 = netip.MustParsePrefix("2001:42:42:42::1/64") + c.DNS.DefaultDomain = "mycompany.com" + c.DNS.ServersIPv4 = []netip.Addr{netip.MustParseAddr("192.168.0.53")} + c.DNS.ServersIPv6 = []netip.Addr{netip.MustParseAddr("2001:53:53:53::53")} + c.Split.ExcludeIPv4 = []netip.Prefix{ + netip.MustParsePrefix("0.0.0.0/32"), + netip.MustParsePrefix("10.0.0.0/24"), + } + c.Split.ExcludeIPv6 = []netip.Prefix{ + netip.MustParsePrefix("2001:2:3:4::1/128"), + netip.MustParsePrefix("2001:2:3:5::1/64"), + } + c.Split.ExcludeDNS = []string{"this.other.com", "that.other.com"} + c.Split.ExcludeVirtualSubnetsOnlyIPv4 = true + c.Flags.DisableAlwaysOnVPN = true + + return c +} + +// TestVPNConfigValid tests Valid of VPNConfig. +func TestVPNConfigValid(t *testing.T) { + // test invalid + for _, invalid := range []*VPNConfig{ + // only device name set, invalid device name + {Device: VPNDevice{Name: "this is too long for a device name"}}, + // only PID set + {PID: 123}, + } { + if invalid.Valid() { + t.Errorf("%v should not be valid", invalid) + } + } + + // test valid + for _, valid := range []*VPNConfig{ + // empty + {}, + // full valid config + getValidTestVPNConfig(), + } { + if !valid.Valid() { + t.Errorf("%v should be valid", valid) + } + } +} + +// TestGetVPNConfig tests GetVPNConfig. +func TestGetVPNConfig(t *testing.T) { + // create vpnconfig.Config + c := vpnconfig.New() + + c.Gateway = net.IPv4(192, 168, 0, 1) + c.PID = 123456 + c.Timeout = 300 + c.Device.Name = "tun0" + c.Device.MTU = 1300 + c.IPv4.Address = net.IPv4(192, 168, 0, 123) + c.IPv4.Netmask = net.IPv4Mask(255, 255, 255, 0) + c.IPv6.Address = net.ParseIP("2001:42:42:42::1") + c.IPv6.Netmask = net.CIDRMask(64, 128) + c.DNS.DefaultDomain = "mycompany.com" + c.DNS.ServersIPv4 = []net.IP{net.IPv4(192, 168, 0, 53)} + c.DNS.ServersIPv6 = []net.IP{net.ParseIP("2001:53:53:53::53")} + c.Split.ExcludeIPv4 = []*net.IPNet{ + { + IP: net.IPv4(0, 0, 0, 0), + Mask: net.IPv4Mask(255, 255, 255, 255), + }, + { + IP: net.IPv4(10, 0, 0, 0), + Mask: net.IPv4Mask(255, 255, 255, 0), + }, + } + c.Split.ExcludeIPv6 = []*net.IPNet{ + { + IP: net.ParseIP("2001:2:3:4::1"), + Mask: net.CIDRMask(128, 128), + }, + { + IP: net.ParseIP("2001:2:3:5::1"), + Mask: net.CIDRMask(64, 128), + }, + } + c.Split.ExcludeDNS = []string{"this.other.com", "that.other.com"} + c.Split.ExcludeVirtualSubnetsOnlyIPv4 = true + c.Flags.DisableAlwaysOnVPN = true + + // convert and check + got := GetVPNConfig(c) + if got.Gateway.Unmap().String() != "192.168.0.1" || + got.PID != c.PID || + got.Timeout != c.Timeout || + got.Device.Name != c.Device.Name || + got.Device.MTU != c.Device.MTU || + got.IPv4.String() != "192.168.0.123/24" || + got.IPv6.String() != "2001:42:42:42::1/64" || + got.DNS.DefaultDomain != c.DNS.DefaultDomain || + got.DNS.ServersIPv4[0].String() != "192.168.0.53" || + got.DNS.ServersIPv6[0].String() != "2001:53:53:53::53" || + got.Split.ExcludeIPv4[0].String() != "0.0.0.0/32" || + got.Split.ExcludeIPv4[1].String() != "10.0.0.0/24" || + got.Split.ExcludeIPv6[0].String() != "2001:2:3:4::1/128" || + got.Split.ExcludeIPv6[1].String() != "2001:2:3:5::1/64" || + got.Split.ExcludeDNS[0] != "this.other.com" || + got.Split.ExcludeDNS[1] != "that.other.com" || + got.Split.ExcludeVirtualSubnetsOnlyIPv4 != c.Split.ExcludeVirtualSubnetsOnlyIPv4 || + got.Flags.DisableAlwaysOnVPN != c.Flags.DisableAlwaysOnVPN { + t.Errorf("invalid conversion: %+v", got) + } +} + +// TestConfigCopy tests Copy of Config. +func TestConfigCopy(t *testing.T) { + // test with new config + want := NewConfig() + got := want.Copy() + + if !reflect.DeepEqual(got, want) { + t.Errorf("got %v, want %v", got, want) + } + + // test modification after copy + c1 := NewConfig() + c2 := c1.Copy() + c1.Verbose = !c2.Verbose + c1.LoginInfo.Cookie = "something else" + c1.VPNConfig.PID = 123456 + + if reflect.DeepEqual(c1, c2) { + t.Error("copies should not be equal after modification") + } +} + +// TestConfigString tests String of Config. +func TestConfigString(t *testing.T) { + // test new config + c := NewConfig() + if c.String() == "" { + t.Errorf("string should not be empty: %s", c.String()) + } + + // test nil + c = nil + if c.String() != "null" { + t.Errorf("string should be null: %s", c.String()) + } +} + +// TestConfigValid tests Valid of Config. +func TestConfigValid(t *testing.T) { + // test invalid + for _, invalid := range []*Config{ + nil, + {}, + } { + want := false + got := invalid.Valid() + + if got != want { + t.Errorf("got %t, want %t for %v", got, want, invalid) + } + } + + // test valid + valid := NewConfig() + want := true + got := valid.Valid() + + if got != want { + t.Errorf("got %t, want %t for %v", got, want, valid) + } +} + +// TestConfigLoad tests Load of Config. +func TestConfigLoad(t *testing.T) { + conf := NewConfig() + conf.Config = "does not exist" + + // test invalid path + err := conf.Load() + if err == nil { + t.Errorf("got != nil, want nil") + } + + // test empty config file + empty, err := os.CreateTemp("", "oc-daemon-config-test") + if err != nil { + t.Fatal(err) + } + defer func() { + _ = os.Remove(empty.Name()) + }() + + conf = NewConfig() + conf.Config = empty.Name() + err = conf.Load() + if err == nil { + t.Errorf("got != nil, want nil") + } + + // test invalid config files + // - with Config + // - with LoginInfo + // - with VPNConfig + for _, content := range []string{ + `{ + "Config": "should not be here", + "Verbose": true +}`, + `{ + "Verbose": true, + "LoginInfo": { + "Server": "192.168.1.1" + } +}`, + `{ + "Verbose": true, + "VPNConfig": { + "Gateway": "192.168.1.1" + } +}`, + } { + invalid, err := os.CreateTemp("", "oc-daemon-config-test") + if err != nil { + t.Fatal(err) + } + defer func() { + _ = os.Remove(invalid.Name()) + }() + + if _, err := invalid.Write([]byte(content)); err != nil { + t.Fatal(err) + } + + conf := NewConfig() + conf.Config = invalid.Name() + if err := conf.Load(); err == nil || + !strings.HasPrefix(err.Error(), "configuration file must not include") { + t.Errorf("should not load invalid config: %s", content) + } + } + + // test valid config file + // - complete config + // - partial config with defaults + for _, content := range []string{ + `{ + "Verbose": true, + "SocketServer": { + "SocketFile": "/run/oc-daemon/daemon.sock", + "SocketOwner": "", + "SocketGroup": "", + "SocketPermissions": "0700", + "RequestTimeout": 30000000000 + }, + "CPD": { + "Host": "connectivity-check.ubuntu.com", + "HTTPTimeout": 5000000000, + "ProbeCount": 3, + "ProbeTimer": 300000000000, + "ProbeTimerDetected": 15000000000 + }, + "DNSProxy": { + "Address": "127.0.0.1:4253", + "ListenUDP": true, + "ListenTCP": true + }, + "OpenConnect": { + "OpenConnect": "openconnect", + "XMLProfile": "/var/lib/oc-daemon/profile.xml", + "VPNCScript": "/usr/bin/oc-daemon-vpncscript", + "VPNDevice": "oc-daemon-tun0", + "PIDFile": "/run/oc-daemon/openconnect.pid", + "PIDOwner": "", + "PIDGroup": "", + "PIDPermissions": "0600" + }, + "Executables": { + "IP": "ip", + "Nft": "nft", + "Resolvectl": "resolvectl", + "Sysctl": "sysctl" + }, + "SplitRouting": { + "RoutingTable": "42111", + "RulePriority1": "2111", + "RulePriority2": "2112", + "FirewallMark": "42111" + }, + "TrafficPolicing": { + "AllowedHosts": ["connectivity-check.ubuntu.com", "detectportal.firefox.com", "www.gstatic.com", "clients3.google.com", "nmcheck.gnome.org"], + "PortalPorts": [80, 443], + "ResolveTimeout": 2000000000, + "ResolveTries": 3, + "ResolveTriesSleep": 1000000000, + "ResolveTTL": 300000000000 + }, + "TND": { + "WaitCheck": 1000000000, + "HTTPSTimeout": 5000000000, + "UntrustedTimer": 30000000000, + "TrustedTimer": 60000000000 + } +}`, + `{ + "Verbose": true +}`, + } { + + valid, err := os.CreateTemp("", "oc-daemon-config-test") + if err != nil { + t.Fatal(err) + } + defer func() { + _ = os.Remove(valid.Name()) + }() + + if _, err := valid.Write([]byte(content)); err != nil { + t.Fatal(err) + } + + conf := NewConfig() + conf.Config = valid.Name() + if err := conf.Load(); err != nil { + t.Errorf("could not load valid config: %s", err) + } + + if !conf.Valid() { + t.Errorf("config is not valid") + } + + want := &Config{ + Config: valid.Name(), + Verbose: true, + SocketServer: NewSocketServer(), + CPD: NewCPD(), + DNSProxy: NewDNSProxy(), + OpenConnect: NewOpenConnect(), + Executables: NewExecutables(), + SplitRouting: NewSplitRouting(), + TrafficPolicing: NewTrafficPolicing(), + TND: tnd.NewConfig(), + LoginInfo: &logininfo.LoginInfo{}, + VPNConfig: &VPNConfig{}, + } + if !reflect.DeepEqual(want.DNSProxy, conf.DNSProxy) { + t.Errorf("got %v, want %v", conf.DNSProxy, want.DNSProxy) + } + if !reflect.DeepEqual(want.OpenConnect, conf.OpenConnect) { + t.Errorf("got %v, want %v", conf.OpenConnect, want.OpenConnect) + } + if !reflect.DeepEqual(want.Executables, conf.Executables) { + t.Errorf("got %v, want %v", conf.Executables, want.Executables) + } + if !reflect.DeepEqual(want.SplitRouting, conf.SplitRouting) { + t.Errorf("got %v, want %v", conf.SplitRouting, want.SplitRouting) + } + if !reflect.DeepEqual(want.TrafficPolicing, conf.TrafficPolicing) { + t.Errorf("got %v, want %v", conf.TrafficPolicing, want.TrafficPolicing) + } + if !reflect.DeepEqual(want.TND, conf.TND) { + t.Errorf("got %v, want %v", conf.TND, want.TND) + } + if !reflect.DeepEqual(want, conf) { + t.Errorf("got %v, want %v", conf, want) + } + } +} + +// TestNewConfig tests NewConfig. +func TestNewConfig(t *testing.T) { + want := &Config{ + Config: "/var/lib/oc-daemon/oc-daemon.json", + Verbose: false, + SocketServer: NewSocketServer(), + CPD: NewCPD(), + DNSProxy: NewDNSProxy(), + OpenConnect: NewOpenConnect(), + Executables: NewExecutables(), + SplitRouting: NewSplitRouting(), + TrafficPolicing: NewTrafficPolicing(), + TND: tnd.NewConfig(), + LoginInfo: &logininfo.LoginInfo{}, + VPNConfig: &VPNConfig{}, + } + got := NewConfig() + if !reflect.DeepEqual(got, want) { + t.Errorf("got %v, want %v", got, want) + } +} diff --git a/internal/dnsproxy/config.go b/internal/dnsproxy/config.go deleted file mode 100644 index 879e069e..00000000 --- a/internal/dnsproxy/config.go +++ /dev/null @@ -1,39 +0,0 @@ -package dnsproxy - -var ( - // Address is the default listen address of the DNS proxy. - Address = "127.0.0.1:4253" - - // ListenUDP specifies whether the DNS proxy listens on UDP. - ListenUDP = true - - // ListenTCP specifies whether the DNS proxy listens on TCP. - ListenTCP = true -) - -// Config is a DNS proxy configuration. -type Config struct { - Address string - ListenUDP bool - ListenTCP bool -} - -// Valid returns whether the DNS proxy configuration is valid. -func (c *Config) Valid() bool { - if c == nil || - c.Address == "" || - (!c.ListenUDP && !c.ListenTCP) { - - return false - } - return true -} - -// NewConfig returns a new DNS proxy configuration. -func NewConfig() *Config { - return &Config{ - Address: Address, - ListenUDP: ListenUDP, - ListenTCP: ListenTCP, - } -} diff --git a/internal/dnsproxy/config_test.go b/internal/dnsproxy/config_test.go deleted file mode 100644 index cb3688a3..00000000 --- a/internal/dnsproxy/config_test.go +++ /dev/null @@ -1,49 +0,0 @@ -package dnsproxy - -import ( - "reflect" - "testing" -) - -// TestConfigValid tests Valid of Config. -func TestConfigValid(t *testing.T) { - // test invalid - for _, invalid := range []*Config{ - nil, - {}, - } { - want := false - got := invalid.Valid() - - if got != want { - t.Errorf("got %t, want %t for %v", got, want, invalid) - } - } - - // test valid - for _, valid := range []*Config{ - NewConfig(), - {Address: "127.0.0.1:4253", ListenUDP: true}, - {Address: "127.0.0.1:4253", ListenTCP: true}, - } { - want := true - got := valid.Valid() - - if got != want { - t.Errorf("got %t, want %t for %v", got, want, valid) - } - } -} - -// TestNewConfig tests NewConfig. -func TestNewConfig(t *testing.T) { - want := &Config{ - Address: Address, - ListenUDP: true, - ListenTCP: true, - } - got := NewConfig() - if !reflect.DeepEqual(got, want) { - t.Errorf("got %v, want %v", got, want) - } -} diff --git a/internal/dnsproxy/proxy.go b/internal/dnsproxy/proxy.go index 0e72c465..b0bf06b6 100644 --- a/internal/dnsproxy/proxy.go +++ b/internal/dnsproxy/proxy.go @@ -7,11 +7,11 @@ import ( "github.com/miekg/dns" log "github.com/sirupsen/logrus" + "github.com/telekom-mms/oc-daemon/internal/daemoncfg" ) // State is the internal state of the DNS Proxy. type State struct { - Config *Config Remotes map[string][]string Watches []string TempWatches []string @@ -19,7 +19,7 @@ type State struct { // Proxy is a DNS proxy. type Proxy struct { - config *Config + config *daemoncfg.DNSProxy udp *dns.Server tcp *dns.Server remotes *Remotes @@ -259,7 +259,6 @@ func (p *Proxy) SetWatches(watches []string) { func (p *Proxy) GetState() *State { watches, tempWatches := p.watches.List() return &State{ - Config: p.config, Remotes: p.remotes.List(), Watches: watches, TempWatches: tempWatches, @@ -267,7 +266,7 @@ func (p *Proxy) GetState() *State { } // NewProxy returns a new Proxy that listens on address. -func NewProxy(config *Config) *Proxy { +func NewProxy(config *daemoncfg.DNSProxy) *Proxy { var udp *dns.Server if config.ListenUDP { udp = &dns.Server{ diff --git a/internal/dnsproxy/proxy_test.go b/internal/dnsproxy/proxy_test.go index 5a580e80..e6b73aa9 100644 --- a/internal/dnsproxy/proxy_test.go +++ b/internal/dnsproxy/proxy_test.go @@ -9,11 +9,12 @@ import ( "github.com/miekg/dns" log "github.com/sirupsen/logrus" + "github.com/telekom-mms/oc-daemon/internal/daemoncfg" ) // getTestConfig returns a config for testing. -func getTestConfig() *Config { - return &Config{ +func getTestConfig() *daemoncfg.DNSProxy { + return &daemoncfg.DNSProxy{ Address: "127.0.0.1:4254", ListenUDP: true, ListenTCP: true, @@ -263,7 +264,7 @@ func TestProxySetRemotes(_ *testing.T) { // TestProxySetWatches tests SetWatches of Proxy. func TestProxySetWatches(_ *testing.T) { - config := &Config{ + config := &daemoncfg.DNSProxy{ Address: "127.0.0.1:4254", ListenUDP: true, ListenTCP: true, @@ -290,7 +291,6 @@ func TestProxyGetState(t *testing.T) { // check state want := &State{ - Config: getTestConfig(), Remotes: getRemotes(), Watches: []string{"example.com."}, TempWatches: []string{"cname.example.com.", "dname.example.com."}, diff --git a/internal/execs/config.go b/internal/execs/config.go deleted file mode 100644 index 8f86f71d..00000000 --- a/internal/execs/config.go +++ /dev/null @@ -1,55 +0,0 @@ -package execs - -import "os/exec" - -// default values. -var ( - IP = "ip" - Nft = "nft" - Resolvectl = "resolvectl" - Sysctl = "sysctl" -) - -// Config is executables configuration. -type Config struct { - IP string - Nft string - Resolvectl string - Sysctl string -} - -// Valid returns whether config is valid. -func (c *Config) Valid() bool { - if c == nil || - c.IP == "" || - c.Nft == "" || - c.Resolvectl == "" || - c.Sysctl == "" { - // invalid - return false - } - return true -} - -// CheckExecutables checks whether executables in config exist in the -// file system and are executable. -func (c *Config) CheckExecutables() error { - for _, f := range []string{ - c.IP, c.Nft, c.Resolvectl, c.Sysctl, - } { - if _, err := exec.LookPath(f); err != nil { - return err - } - } - return nil -} - -// NewConfig returns a new Config. -func NewConfig() *Config { - return &Config{ - IP: IP, - Nft: Nft, - Resolvectl: Resolvectl, - Sysctl: Sysctl, - } -} diff --git a/internal/execs/config_test.go b/internal/execs/config_test.go deleted file mode 100644 index e399deac..00000000 --- a/internal/execs/config_test.go +++ /dev/null @@ -1,84 +0,0 @@ -package execs - -import ( - "os" - "path/filepath" - "reflect" - "testing" -) - -// TestConfigValid tests Valid of Config. -func TestConfigValid(t *testing.T) { - // test invalid - for _, invalid := range []*Config{ - nil, - {}, - } { - if invalid.Valid() { - t.Errorf("config should be invalid: %v", invalid) - } - } - - // test valid - for _, valid := range []*Config{ - NewConfig(), - {"/test/ip", "/test/nft", "/test/resolvectl", "/test/sysctl"}, - } { - if !valid.Valid() { - t.Errorf("config should be valid: %v", valid) - } - } -} - -// TestConfigCheckExecutables tests CheckExecutables of Config. -func TestConfigCheckExecutables(t *testing.T) { - // create temporary dir for executables - dir, err := os.MkdirTemp("", "execs-test") - if err != nil { - t.Fatal(err) - } - defer func() { _ = os.RemoveAll(dir) }() - - // create executable file paths - ip := filepath.Join(dir, "ip") - nft := filepath.Join(dir, "nft") - resolvectl := filepath.Join(dir, "resolvectl") - sysctl := filepath.Join(dir, "sysctl") - - // create config with executables - c := &Config{ - IP: ip, - Nft: nft, - Resolvectl: resolvectl, - Sysctl: sysctl, - } - - // test with not all files existing, create files in the process - for _, f := range []string{ - ip, nft, resolvectl, sysctl, - } { - // test - if got := c.CheckExecutables(); got == nil { - t.Errorf("got nil, want != nil") - } - - // create executable file - if err := os.WriteFile(f, []byte{}, 0777); err != nil { - t.Fatal(err) - } - } - - // test with all files existing - if got := c.CheckExecutables(); got != nil { - t.Errorf("got %v, want nil", got) - } -} - -// TestNewConfig tests NewConfig. -func TestNewConfig(t *testing.T) { - want := &Config{IP, Nft, Resolvectl, Sysctl} - got := NewConfig() - if !reflect.DeepEqual(got, want) { - t.Errorf("got %v, want %v", got, want) - } -} diff --git a/internal/execs/execs.go b/internal/execs/execs.go index 44e1253b..1c035f8c 100644 --- a/internal/execs/execs.go +++ b/internal/execs/execs.go @@ -5,14 +5,16 @@ import ( "bytes" "context" "os/exec" + + "github.com/telekom-mms/oc-daemon/internal/daemoncfg" ) // executables. var ( - ip = IP - sysctl = Sysctl - nft = Nft - resolvectl = Resolvectl + ip = daemoncfg.ExecutablesIP + sysctl = daemoncfg.ExecutablesSysctl + nft = daemoncfg.ExecutablesNft + resolvectl = daemoncfg.ExecutablesResolvectl ) // RunCmd runs the cmd with args and sets stdin to s, returns stdout and stderr. @@ -87,7 +89,7 @@ func RunResolvectl(ctx context.Context, arg ...string) (stdout, stderr []byte, e } // SetExecutables configures all executables from config. -func SetExecutables(config *Config) { +func SetExecutables(config *daemoncfg.Executables) { ip = config.IP sysctl = config.Sysctl nft = config.Nft diff --git a/internal/execs/execs_test.go b/internal/execs/execs_test.go index 6310e1dd..99a2353f 100644 --- a/internal/execs/execs_test.go +++ b/internal/execs/execs_test.go @@ -6,6 +6,8 @@ import ( "reflect" "strings" "testing" + + "github.com/telekom-mms/oc-daemon/internal/daemoncfg" ) // TestRunCmd tests RunCmd. @@ -225,10 +227,10 @@ func TestRunResolvectl(t *testing.T) { // TestSetExecutables tests SetExecutables. func TestSetExecutables(t *testing.T) { - old := NewConfig() + old := daemoncfg.NewExecutables() defer SetExecutables(old) - config := &Config{ + config := &daemoncfg.Executables{ IP: "/test/ip", Sysctl: "/test/sysctl", Nft: "/test/nft", diff --git a/internal/ocrunner/config.go b/internal/ocrunner/config.go deleted file mode 100644 index 2b9d1832..00000000 --- a/internal/ocrunner/config.go +++ /dev/null @@ -1,100 +0,0 @@ -package ocrunner - -import "strconv" - -var ( - // OpenConnect is the default openconnect executable. - OpenConnect = "openconnect" - - // XMLProfile is the default AnyConnect Profile. - XMLProfile = "/var/lib/oc-daemon/profile.xml" - - // VPNCScript is the default vpnc-script. - VPNCScript = "/usr/bin/oc-daemon-vpncscript" - - // VPNDevice is the default vpn network device name. - VPNDevice = "oc-daemon-tun0" - - // PIDFile is the default file path of the PID file for openconnect. - PIDFile = "/run/oc-daemon/openconnect.pid" - - // PIDOwner is the default owner of the PID file. - PIDOwner = "" - - // PIDGroup is the default group of the PID file. - PIDGroup = "" - - // PIDPermissions are the default file permissions of the PID file. - PIDPermissions = "0600" - - // NoProxy specifies whether the no proxy flag is set in openconnect. - NoProxy = true - - // ExtraEnv are extra environment variables used by openconnect. - ExtraEnv = []string{} - - // ExtraArgs are extra command line arguments used by openconnect. - ExtraArgs = []string{} -) - -// Config is the configuration for an openconnect connection runner. -type Config struct { - OpenConnect string - - XMLProfile string - VPNCScript string - VPNDevice string - - PIDFile string - PIDOwner string - PIDGroup string - PIDPermissions string - - NoProxy bool - ExtraEnv []string - ExtraArgs []string -} - -// Valid returns whether the openconnect configuration is valid. -func (c *Config) Valid() bool { - if c == nil || - c.OpenConnect == "" || - c.XMLProfile == "" || - c.VPNCScript == "" || - c.VPNDevice == "" || - c.PIDFile == "" || - c.PIDPermissions == "" { - - return false - } - if c.PIDPermissions != "" { - perm, err := strconv.ParseUint(c.PIDPermissions, 8, 32) - if err != nil { - return false - } - if perm > 0777 { - return false - } - } - return true -} - -// NewConfig returns a new configuration for an openconnect connection runner. -func NewConfig() *Config { - return &Config{ - OpenConnect: OpenConnect, - - XMLProfile: XMLProfile, - VPNCScript: VPNCScript, - VPNDevice: VPNDevice, - - PIDFile: PIDFile, - PIDOwner: PIDOwner, - PIDGroup: PIDGroup, - PIDPermissions: PIDPermissions, - - NoProxy: NoProxy, - ExtraEnv: append(ExtraEnv[:0:0], ExtraEnv...), - ExtraArgs: append(ExtraArgs[:0:0], ExtraArgs...), - } -} diff --git a/internal/ocrunner/config_test.go b/internal/ocrunner/config_test.go deleted file mode 100644 index d1c4db81..00000000 --- a/internal/ocrunner/config_test.go +++ /dev/null @@ -1,82 +0,0 @@ -package ocrunner - -import ( - "reflect" - "testing" -) - -// TestConfigValid tests Valid of Config. -func TestConfigValid(t *testing.T) { - // test invalid - for _, invalid := range []*Config{ - nil, - {}, - { - OpenConnect: "openconnect", - XMLProfile: "/test/profile", - VPNCScript: "/test/vpncscript", - VPNDevice: "test-device", - PIDFile: "/test/pid", - PIDPermissions: "invalid", - }, - { - OpenConnect: "openconnect", - XMLProfile: "/test/profile", - VPNCScript: "/test/vpncscript", - VPNDevice: "test-device", - PIDFile: "/test/pid", - PIDPermissions: "1234", - }, - } { - want := false - got := invalid.Valid() - - if got != want { - t.Errorf("got %t, want %t for %v", got, want, invalid) - } - } - - // test valid - for _, valid := range []*Config{ - NewConfig(), - { - OpenConnect: "openconnect", - XMLProfile: "/test/profile", - VPNCScript: "/test/vpncscript", - VPNDevice: "test-device", - PIDFile: "/test/pid", - PIDPermissions: "777", - }, - } { - want := true - got := valid.Valid() - - if got != want { - t.Errorf("got %t, want %t for %v", got, want, valid) - } - } -} - -// TestNewConfig tests NewConfig. -func TestNewConfig(t *testing.T) { - want := &Config{ - OpenConnect: OpenConnect, - - XMLProfile: XMLProfile, - VPNCScript: VPNCScript, - VPNDevice: VPNDevice, - - PIDFile: PIDFile, - PIDOwner: PIDOwner, - PIDGroup: PIDGroup, - PIDPermissions: PIDPermissions, - - NoProxy: NoProxy, - ExtraEnv: ExtraEnv, - ExtraArgs: ExtraArgs, - } - got := NewConfig() - if !reflect.DeepEqual(got, want) { - t.Errorf("got %v, want %v", got, want) - } -} diff --git a/internal/ocrunner/connect.go b/internal/ocrunner/connect.go index b398bb7b..8023d48a 100644 --- a/internal/ocrunner/connect.go +++ b/internal/ocrunner/connect.go @@ -12,7 +12,7 @@ import ( "syscall" log "github.com/sirupsen/logrus" - "github.com/telekom-mms/oc-daemon/pkg/logininfo" + "github.com/telekom-mms/oc-daemon/internal/daemoncfg" ) // ConnectEvent is a connect runner event. @@ -24,8 +24,8 @@ type ConnectEvent struct { // PID is the process ID of the running openconnect process PID uint32 - // login info for connect - login *logininfo.LoginInfo + // config is the daemon configuration. + config *daemoncfg.Config // Env are extra environment variables set during execution env []string @@ -33,9 +33,6 @@ type ConnectEvent struct { // Connect is a openconnect connection runner. type Connect struct { - // connection runner configuration - config *Config - // openconnect command command *exec.Cmd @@ -74,13 +71,13 @@ func (c *Connect) sendEvent(event *ConnectEvent) { } // setPIDOwner sets the owner of the pid file. -func (c *Connect) setPIDOwner() { - if c.config.PIDOwner == "" { +func (c *Connect) setPIDOwner(config *daemoncfg.Config) { + if config.OpenConnect.PIDOwner == "" { // do not change owner return } - user, err := userLookup(c.config.PIDOwner) + user, err := userLookup(config.OpenConnect.PIDOwner) if err != nil { log.WithError(err).Error("OC-Runner could not get UID of pid file owner") return @@ -92,19 +89,19 @@ func (c *Connect) setPIDOwner() { return } - if err := osChown(c.config.PIDFile, uid, -1); err != nil { + if err := osChown(config.OpenConnect.PIDFile, uid, -1); err != nil { log.WithError(err).Error("OC-Runner could not change owner of pid file") } } // setPIDGroup sets the group of the pid file. -func (c *Connect) setPIDGroup() { - if c.config.PIDGroup == "" { +func (c *Connect) setPIDGroup(config *daemoncfg.Config) { + if config.OpenConnect.PIDGroup == "" { // do not change group return } - group, err := userLookupGroup(c.config.PIDGroup) + group, err := userLookupGroup(config.OpenConnect.PIDGroup) if err != nil { log.WithError(err).Error("OC-Runner could not get GID of pid file group") return @@ -116,13 +113,13 @@ func (c *Connect) setPIDGroup() { return } - if err := osChown(c.config.PIDFile, -1, gid); err != nil { + if err := osChown(config.OpenConnect.PIDFile, -1, gid); err != nil { log.WithError(err).Error("OC-Runner could not change group of pid file") } } // savePidFile saves the running command to pid file. -func (c *Connect) savePidFile() { +func (c *Connect) savePidFile(config *daemoncfg.Config) { if c.command == nil || c.command.Process == nil { return } @@ -131,22 +128,22 @@ func (c *Connect) savePidFile() { pid := fmt.Sprintf("%d\n", c.command.Process.Pid) // convert permissions - perm, err := strconv.ParseUint(c.config.PIDPermissions, 8, 32) + perm, err := strconv.ParseUint(config.OpenConnect.PIDPermissions, 8, 32) if err != nil { log.WithError(err).Error("OC-Runner could not convert permissions of pid file to uint") return } // write pid to file with permissions - err = osWriteFile(c.config.PIDFile, []byte(pid), os.FileMode(perm)) + err = osWriteFile(config.OpenConnect.PIDFile, []byte(pid), os.FileMode(perm)) if err != nil { log.WithError(err).Error("OC-Runner writing pid error") return } // set owner and group - c.setPIDOwner() - c.setPIDGroup() + c.setPIDOwner(config) + c.setPIDGroup(config) } // getPID returns the PID of the running command. @@ -171,12 +168,12 @@ func (c *Connect) handleConnect(e *ConnectEvent) { // // openconnect --cookie-on-stdin $HOST --servercert $FINGERPRINT // - serverCert := fmt.Sprintf("--servercert=%s", e.login.Fingerprint) - xmlConfig := fmt.Sprintf("--xmlconfig=%s", c.config.XMLProfile) - script := fmt.Sprintf("--script=%s", c.config.VPNCScript) - host := e.login.Host - if e.login.ConnectURL != "" { - host = e.login.ConnectURL + serverCert := fmt.Sprintf("--servercert=%s", e.config.LoginInfo.Fingerprint) + xmlConfig := fmt.Sprintf("--xmlconfig=%s", e.config.OpenConnect.XMLProfile) + script := fmt.Sprintf("--script=%s", e.config.OpenConnect.VPNCScript) + host := e.config.LoginInfo.Host + if e.config.LoginInfo.ConnectURL != "" { + host = e.config.LoginInfo.ConnectURL } parameters := []string{ xmlConfig, @@ -185,30 +182,30 @@ func (c *Connect) handleConnect(e *ConnectEvent) { host, serverCert, } - if c.config.NoProxy { + if e.config.OpenConnect.NoProxy { parameters = append(parameters, "--no-proxy") } - if e.login.Resolve != "" { - resolve := fmt.Sprintf("--resolve=%s", e.login.Resolve) + if e.config.LoginInfo.Resolve != "" { + resolve := fmt.Sprintf("--resolve=%s", e.config.LoginInfo.Resolve) parameters = append(parameters, resolve) } - if c.config.VPNDevice != "" { - device := fmt.Sprintf("--interface=%s", c.config.VPNDevice) + if e.config.OpenConnect.VPNDevice != "" { + device := fmt.Sprintf("--interface=%s", e.config.OpenConnect.VPNDevice) parameters = append(parameters, device) } - parameters = append(parameters, c.config.ExtraArgs...) - c.command = execCommand(c.config.OpenConnect, parameters...) + parameters = append(parameters, e.config.OpenConnect.ExtraArgs...) + c.command = execCommand(e.config.OpenConnect.OpenConnect, parameters...) // run command in own process group so it is not canceled by interrupt // signal sent to daemon c.command.SysProcAttr = &syscall.SysProcAttr{Setpgid: true} // run command, pass login info to stdin - b := bytes.NewBufferString(e.login.Cookie) + b := bytes.NewBufferString(e.config.LoginInfo.Cookie) c.command.Stdin = b c.command.Stdout = os.Stdout c.command.Stderr = os.Stderr - c.command.Env = append(os.Environ(), c.config.ExtraEnv...) + c.command.Env = append(os.Environ(), e.config.OpenConnect.ExtraEnv...) c.command.Env = append(c.command.Env, e.env...) if err := c.command.Start(); err != nil { @@ -219,7 +216,7 @@ func (c *Connect) handleConnect(e *ConnectEvent) { } // save pid and cmd line - c.savePidFile() + c.savePidFile(e.config) // signal connect to user c.sendEvent(&ConnectEvent{ @@ -305,10 +302,10 @@ func (c *Connect) Stop() { } // Connect connects the vpn by starting openconnect. -func (c *Connect) Connect(login *logininfo.LoginInfo, env []string) { +func (c *Connect) Connect(config *daemoncfg.Config, env []string) { e := &ConnectEvent{ Connect: true, - login: login, + config: config, env: env, } c.commands <- e @@ -326,10 +323,8 @@ func (c *Connect) Events() chan *ConnectEvent { } // NewConnect returns a new Connect. -func NewConnect(config *Config) *Connect { +func NewConnect() *Connect { return &Connect{ - config: config, - exits: make(chan struct{}), commands: make(chan *ConnectEvent), @@ -341,7 +336,7 @@ func NewConnect(config *Config) *Connect { } // CleanupConnect cleans up connect after a failed shutdown. -func CleanupConnect(config *Config) { +func CleanupConnect(config *daemoncfg.OpenConnect) { // get pid from file b, err := osReadFile(config.PIDFile) if err != nil { diff --git a/internal/ocrunner/connect_test.go b/internal/ocrunner/connect_test.go index 62e094ef..6f179e05 100644 --- a/internal/ocrunner/connect_test.go +++ b/internal/ocrunner/connect_test.go @@ -6,15 +6,15 @@ import ( "os" "os/exec" "os/user" - "reflect" "testing" + "github.com/telekom-mms/oc-daemon/internal/daemoncfg" "github.com/telekom-mms/oc-daemon/pkg/logininfo" ) // TestConnectStartStop tests Start and Stop of Connect. func TestConnectStartStop(_ *testing.T) { - c := NewConnect(NewConfig()) + c := NewConnect() c.Start() c.Stop() } @@ -29,12 +29,12 @@ func TestConnectSavePidFile(t *testing.T) { osChown = os.Chown }() - conf := NewConfig() - conf.PIDFile = t.TempDir() + "pidfile" + conf := daemoncfg.NewConfig() + conf.OpenConnect.PIDFile = t.TempDir() + "pidfile" // no process - c := NewConnect(conf) - c.savePidFile() + c := NewConnect() + c.savePidFile(conf) // with chown error userLookup = func(string) (*user.User, error) { @@ -47,13 +47,13 @@ func TestConnectSavePidFile(t *testing.T) { return errors.New("test error") } - conf.PIDFile = t.TempDir() + "pidfile" - conf.PIDOwner = "test" - conf.PIDGroup = "test" + conf.OpenConnect.PIDFile = t.TempDir() + "pidfile" + conf.OpenConnect.PIDOwner = "test" + conf.OpenConnect.PIDGroup = "test" - c = NewConnect(conf) + c = NewConnect() c.command = &exec.Cmd{Process: &os.Process{Pid: 123}} - c.savePidFile() + c.savePidFile(conf) // with invalid uid/gid userLookup = func(string) (*user.User, error) { @@ -63,9 +63,9 @@ func TestConnectSavePidFile(t *testing.T) { return &user.Group{Gid: "invalid"}, nil } - c = NewConnect(conf) + c = NewConnect() c.command = &exec.Cmd{Process: &os.Process{Pid: 123}} - c.savePidFile() + c.savePidFile(conf) // with user/group lookup error userLookup = func(string) (*user.User, error) { @@ -75,25 +75,25 @@ func TestConnectSavePidFile(t *testing.T) { return nil, errors.New("test error") } - c = NewConnect(conf) + c = NewConnect() c.command = &exec.Cmd{Process: &os.Process{Pid: 123}} - c.savePidFile() + c.savePidFile(conf) // with write error osWriteFile = func(string, []byte, fs.FileMode) error { return errors.New("test error") } - c = NewConnect(conf) + c = NewConnect() c.command = &exec.Cmd{Process: &os.Process{Pid: 123}} - c.savePidFile() + c.savePidFile(conf) // with invalid permissions - conf.PIDPermissions = "invalid" + conf.OpenConnect.PIDPermissions = "invalid" - c = NewConnect(conf) + c = NewConnect() c.command = &exec.Cmd{Process: &os.Process{Pid: 123}} - c.savePidFile() + c.savePidFile(conf) } // TestConnectConnect tests Connect of Connect. @@ -101,7 +101,9 @@ func TestConnectConnect(t *testing.T) { // clean up after tests defer func() { execCommand = exec.Command }() - login := &logininfo.LoginInfo{ + conf := daemoncfg.NewConfig() + conf.OpenConnect.PIDFile = t.TempDir() + "pidfile" + conf.LoginInfo = &logininfo.LoginInfo{ Server: "vpnserver.example.com", Cookie: "3311180634@13561856@1339425499@B315A0E29D16C6FD92EE...", Host: "10.0.0.1", @@ -109,17 +111,15 @@ func TestConnectConnect(t *testing.T) { Fingerprint: "469bb424ec8835944d30bc77c77e8fc1d8e23a42", Resolve: "vpnserver.example.com:10.0.0.1", } - conf := NewConfig() - conf.PIDFile = t.TempDir() + "pidfile" // test with exec error execCommand = func(string, ...string) *exec.Cmd { return exec.Command("") } - c := NewConnect(conf) + c := NewConnect() c.Start() - c.Connect(login, nil) + c.Connect(conf, nil) c.Stop() // test without exec error @@ -127,13 +127,13 @@ func TestConnectConnect(t *testing.T) { return exec.Command("sleep", "10") } - c = NewConnect(conf) + c = NewConnect() c.Start() - c.Connect(login, nil) + c.Connect(conf, nil) <-c.Events() // test double connect - c.Connect(login, nil) + c.Connect(conf, nil) c.Stop() } @@ -148,21 +148,21 @@ func TestConnectDisconnect(t *testing.T) { }() // without connection - c := NewConnect(NewConfig()) + c := NewConnect() c.Start() c.Disconnect() c.Stop() // with connection - conf := NewConfig() + conf := daemoncfg.NewOpenConnect() conf.PIDFile = t.TempDir() + "pidfile" execCommand = func(string, ...string) *exec.Cmd { return exec.Command("sleep", "10") } - c = NewConnect(NewConfig()) + c = NewConnect() c.Start() - c.Connect(&logininfo.LoginInfo{}, nil) + c.Connect(daemoncfg.NewConfig(), nil) <-c.Events() c.Disconnect() <-c.Events() @@ -172,14 +172,14 @@ func TestConnectDisconnect(t *testing.T) { processSignal = func(*os.Process, os.Signal) error { return errors.New("test error") } - c = NewConnect(NewConfig()) + c = NewConnect() c.command = &exec.Cmd{Process: &os.Process{}} c.handleDisconnect() } // TestConnectEvents tests Events of Connect. func TestConnectEvents(t *testing.T) { - c := NewConnect(NewConfig()) + c := NewConnect() want := c.events got := c.Events() @@ -190,14 +190,8 @@ func TestConnectEvents(t *testing.T) { // TestNewConnect tests NewConnect. func TestNewConnect(t *testing.T) { - config := NewConfig() - config.XMLProfile = "/some/profile/file" - config.VPNCScript = "/some/vpnc/script" - config.VPNDevice = "tun999" - c := NewConnect(config) - if !reflect.DeepEqual(c.config, config) { - t.Errorf("got %v, want %v", c.config, config) - } + c := NewConnect() + if c.exits == nil || c.commands == nil || c.done == nil || @@ -223,14 +217,14 @@ func TestCleanupConnect(_ *testing.T) { return nil, errors.New("test error") } - CleanupConnect(NewConfig()) + CleanupConnect(daemoncfg.NewOpenConnect()) // PID file contains garbage osReadFile = func(string) ([]byte, error) { return []byte("garbage"), nil } - CleanupConnect(NewConfig()) + CleanupConnect(daemoncfg.NewOpenConnect()) // cannot read process cmdline reads := 0 @@ -242,7 +236,7 @@ func TestCleanupConnect(_ *testing.T) { return []byte("123"), nil } - CleanupConnect(NewConfig()) + CleanupConnect(daemoncfg.NewOpenConnect()) // process cmdline does not contain openconnect (other process) reads = 0 @@ -254,7 +248,7 @@ func TestCleanupConnect(_ *testing.T) { return []byte("123"), nil } - CleanupConnect(NewConfig()) + CleanupConnect(daemoncfg.NewOpenConnect()) // cannot find process (process already terminated) reads = 0 @@ -269,7 +263,7 @@ func TestCleanupConnect(_ *testing.T) { return nil, errors.New("test error") } - CleanupConnect(NewConfig()) + CleanupConnect(daemoncfg.NewOpenConnect()) // stop process reads = 0 @@ -280,5 +274,5 @@ func TestCleanupConnect(_ *testing.T) { return nil } - CleanupConnect(NewConfig()) + CleanupConnect(daemoncfg.NewOpenConnect()) } diff --git a/internal/splitrt/config.go b/internal/splitrt/config.go deleted file mode 100644 index 7302b2a6..00000000 --- a/internal/splitrt/config.go +++ /dev/null @@ -1,80 +0,0 @@ -package splitrt - -import "strconv" - -var ( - // RoutingTable is the routing table. - RoutingTable = "42111" - - // RulePriority1 is the first routing rule priority. It must be unique, - // higher than the local rule, lower than the main and default rules, - // lower than the second routing rule priority. - RulePriority1 = "2111" - - // RulePriority2 is the second routing rule priority. It must be unique, - // higher than the local rule, lower than the main and default rules, - // higher than the first routing rule priority. - RulePriority2 = "2112" - - // FirewallMark is the firewall mark used for split routing. - FirewallMark = RoutingTable -) - -// Config is a split routing configuration. -type Config struct { - RoutingTable string - RulePriority1 string - RulePriority2 string - FirewallMark string -} - -// Valid returns whether the split routing configuration is valid. -func (c *Config) Valid() bool { - if c == nil || - c.RoutingTable == "" || - c.RulePriority1 == "" || - c.RulePriority2 == "" || - c.FirewallMark == "" { - - return false - } - - // check routing table value: must be > 0, < 0xFFFFFFFF - rtTable, err := strconv.ParseUint(c.RoutingTable, 10, 32) - if err != nil || rtTable == 0 || rtTable >= 0xFFFFFFFF { - return false - } - - // check rule priority values: must be > 0, < 32766, prio1 < prio2 - prio1, err := strconv.ParseUint(c.RulePriority1, 10, 16) - if err != nil { - return false - } - prio2, err := strconv.ParseUint(c.RulePriority2, 10, 16) - if err != nil { - return false - } - if prio1 == 0 || prio2 == 0 || - prio1 >= 32766 || prio2 >= 32766 || - prio1 >= prio2 { - - return false - } - - // check fwmark value: must be 32 bit unsigned int - if _, err := strconv.ParseUint(c.FirewallMark, 10, 32); err != nil { - return false - } - - return true -} - -// NewConfig returns a new split routing configuration. -func NewConfig() *Config { - return &Config{ - RoutingTable: RoutingTable, - RulePriority1: RulePriority1, - RulePriority2: RulePriority2, - FirewallMark: FirewallMark, - } -} diff --git a/internal/splitrt/config_test.go b/internal/splitrt/config_test.go deleted file mode 100644 index e82c93f5..00000000 --- a/internal/splitrt/config_test.go +++ /dev/null @@ -1,99 +0,0 @@ -package splitrt - -import "testing" - -// TestConfigValid tests Valid of Config. -func TestConfigValid(t *testing.T) { - // test invalid - for _, invalid := range []*Config{ - nil, - {}, - { - RoutingTable: "42111", - FirewallMark: "42111", - RulePriority1: "0", - RulePriority2: "1", - }, - { - RoutingTable: "42111", - FirewallMark: "42111", - RulePriority1: "32766", - RulePriority2: "32767", - }, - { - RoutingTable: "42111", - FirewallMark: "42111", - RulePriority1: "2111", - RulePriority2: "2111", - }, - { - RoutingTable: "42111", - FirewallMark: "42111", - RulePriority1: "2112", - RulePriority2: "2111", - }, - { - RoutingTable: "42111", - FirewallMark: "42111", - RulePriority1: "65537", - RulePriority2: "2111", - }, - { - RoutingTable: "42111", - FirewallMark: "42111", - RulePriority1: "2111", - RulePriority2: "65537", - }, - { - RoutingTable: "0", - FirewallMark: "42112", - RulePriority1: "2222", - RulePriority2: "2223", - }, - { - RoutingTable: "4294967295", - FirewallMark: "42112", - RulePriority1: "2222", - RulePriority2: "2223", - }, - { - RoutingTable: "42112", - FirewallMark: "4294967296", - RulePriority1: "2222", - RulePriority2: "2223", - }, - } { - want := false - got := invalid.Valid() - - if got != want { - t.Errorf("got %t, want %t for %v", got, want, invalid) - } - } - - // test valid - for _, valid := range []*Config{ - NewConfig(), - { - RoutingTable: "42112", - FirewallMark: "42112", - RulePriority1: "2222", - RulePriority2: "2223", - }, - } { - want := true - got := valid.Valid() - - if got != want { - t.Errorf("got %t, want %t for %v", got, want, valid) - } - } -} - -// TestNewConfig tests NewConfig. -func TestNewConfig(t *testing.T) { - c := NewConfig() - if !c.Valid() { - t.Errorf("new config should be valid") - } -} diff --git a/internal/splitrt/excludes.go b/internal/splitrt/excludes.go index d1a1f789..bf11eb7e 100644 --- a/internal/splitrt/excludes.go +++ b/internal/splitrt/excludes.go @@ -7,6 +7,7 @@ import ( "time" log "github.com/sirupsen/logrus" + "github.com/telekom-mms/oc-daemon/internal/daemoncfg" ) const ( @@ -23,6 +24,7 @@ type dynExclude struct { // Excludes contains split Excludes. type Excludes struct { sync.Mutex + conf *daemoncfg.Config s map[string]netip.Prefix d map[netip.Addr]*dynExclude done chan struct{} @@ -213,8 +215,9 @@ func (e *Excludes) List() (static, dynamic []string) { } // NewExcludes returns new split excludes. -func NewExcludes() *Excludes { +func NewExcludes(conf *daemoncfg.Config) *Excludes { return &Excludes{ + conf: conf, s: make(map[string]netip.Prefix), d: make(map[netip.Addr]*dynExclude), done: make(chan struct{}), diff --git a/internal/splitrt/excludes_test.go b/internal/splitrt/excludes_test.go index ee3a7f4a..54cbd974 100644 --- a/internal/splitrt/excludes_test.go +++ b/internal/splitrt/excludes_test.go @@ -7,6 +7,7 @@ import ( "reflect" "testing" + "github.com/telekom-mms/oc-daemon/internal/daemoncfg" "github.com/telekom-mms/oc-daemon/internal/execs" ) @@ -63,7 +64,7 @@ func getTestDynamicExcludes(t *testing.T) []netip.Prefix { // TestExcludesAddStatic tests AddStatic of Excludes. func TestExcludesAddStatic(t *testing.T) { ctx := context.Background() - e := NewExcludes() + e := NewExcludes(daemoncfg.NewConfig()) excludes := getTestStaticExcludes(t) // set testing runNft function @@ -94,7 +95,7 @@ func TestExcludesAddStatic(t *testing.T) { } // test adding overlapping excludes - e = NewExcludes() + e = NewExcludes(daemoncfg.NewConfig()) for _, exclude := range getTestStaticExcludesOverlap(t) { e.AddStatic(ctx, exclude) } @@ -108,7 +109,7 @@ func TestExcludesAddStatic(t *testing.T) { // TestExcludesAddDynamic tests AddDynamic of Excludes. func TestExcludesAddDynamic(t *testing.T) { ctx := context.Background() - e := NewExcludes() + e := NewExcludes(daemoncfg.NewConfig()) excludes := getTestDynamicExcludes(t) // set testing runNft function @@ -141,7 +142,7 @@ func TestExcludesAddDynamic(t *testing.T) { // test adding excludes with existing static excludes, // should only add new excludes - e = NewExcludes() + e = NewExcludes(daemoncfg.NewConfig()) for _, exclude := range getTestStaticExcludes(t) { e.AddStatic(ctx, exclude) } @@ -157,7 +158,7 @@ func TestExcludesAddDynamic(t *testing.T) { } // test adding invalid excludes (static as dynamic) - e = NewExcludes() + e = NewExcludes(daemoncfg.NewConfig()) got = []string{} want = []string{} for _, exclude := range getTestStaticExcludes(t) { @@ -171,7 +172,7 @@ func TestExcludesAddDynamic(t *testing.T) { // TestExcludesRemoveStatic tests RemoveStatic of Excludes. func TestExcludesRemove(t *testing.T) { ctx := context.Background() - e := NewExcludes() + e := NewExcludes(daemoncfg.NewConfig()) excludes := getTestStaticExcludes(t) // set testing runNft function @@ -265,7 +266,7 @@ func TestExcludesRemove(t *testing.T) { // TestExcludesCleanup tests cleanup of Excludes. func TestExcludesCleanup(t *testing.T) { ctx := context.Background() - e := NewExcludes() + e := NewExcludes(daemoncfg.NewConfig()) // set testing runNft function got := []string{} @@ -317,20 +318,22 @@ func TestExcludesCleanup(t *testing.T) { // TestExcludesStartStop tests Start and Stop of Excludes. func TestExcludesStartStop(_ *testing.T) { - e := NewExcludes() + e := NewExcludes(daemoncfg.NewConfig()) e.Start() e.Stop() } // TestNewExcludes tests NewExcludes. func TestNewExcludes(t *testing.T) { - e := NewExcludes() + conf := daemoncfg.NewConfig() + e := NewExcludes(conf) if e == nil || + e.conf != conf || e.s == nil || e.d == nil || e.done == nil || e.closed == nil { - t.Errorf("got nil, want != nil") + t.Errorf("invalid excludes") } } diff --git a/internal/splitrt/splitrt.go b/internal/splitrt/splitrt.go index bb06d2b3..b08a4022 100644 --- a/internal/splitrt/splitrt.go +++ b/internal/splitrt/splitrt.go @@ -9,15 +9,13 @@ import ( log "github.com/sirupsen/logrus" "github.com/telekom-mms/oc-daemon/internal/addrmon" + "github.com/telekom-mms/oc-daemon/internal/daemoncfg" "github.com/telekom-mms/oc-daemon/internal/devmon" "github.com/telekom-mms/oc-daemon/internal/dnsproxy" - "github.com/telekom-mms/oc-daemon/pkg/vpnconfig" ) // State is the internal state. type State struct { - Config *Config - VPNConfig *vpnconfig.Config Devices []*devmon.Update Addresses []*addrmon.Update LocalExcludes []string @@ -49,89 +47,89 @@ func (l *locals) set(locals []netip.Prefix) { // SplitRouting is a split routing configuration. type SplitRouting struct { - config *Config - vpnconfig *vpnconfig.Config - devmon *devmon.DevMon - addrmon *addrmon.AddrMon - devices *Devices - addrs *Addresses - locals locals - excludes *Excludes - dnsreps chan *dnsproxy.Report - done chan struct{} - closed chan struct{} + config *daemoncfg.Config + devmon *devmon.DevMon + addrmon *addrmon.AddrMon + devices *Devices + addrs *Addresses + locals locals + excludes *Excludes + dnsreps chan *dnsproxy.Report + done chan struct{} + closed chan struct{} } // setupRouting sets up routing using config. func (s *SplitRouting) setupRouting(ctx context.Context) { // prepare netfilter and excludes - setRoutingRules(ctx, s.config.FirewallMark) - - // convert to netip - pre4 := netip.Prefix{} - if ipv4, ok := netip.AddrFromSlice(s.vpnconfig.IPv4.Address.To4()); ok { - one4, _ := s.vpnconfig.IPv4.Netmask.Size() - pre4 = netip.PrefixFrom(ipv4, one4) - } - pre6 := netip.Prefix{} - if ipv6, ok := netip.AddrFromSlice(s.vpnconfig.IPv6.Address); ok { - one6, _ := s.vpnconfig.IPv6.Netmask.Size() - pre6 = netip.PrefixFrom(ipv6, one6) - } + setRoutingRules(ctx, s.config.SplitRouting.FirewallMark) // filter non-local traffic to vpn addresses - addLocalAddressesIPv4(ctx, s.vpnconfig.Device.Name, []netip.Prefix{pre4}) - addLocalAddressesIPv6(ctx, s.vpnconfig.Device.Name, []netip.Prefix{pre6}) + addLocalAddressesIPv4(ctx, + s.config.VPNConfig.Device.Name, + []netip.Prefix{s.config.VPNConfig.IPv4}) + addLocalAddressesIPv6(ctx, + s.config.VPNConfig.Device.Name, + []netip.Prefix{s.config.VPNConfig.IPv6}) // reject unsupported ip versions on vpn - if !pre6.IsValid() { - rejectIPv6(ctx, s.vpnconfig.Device.Name) + if !s.config.VPNConfig.IPv6.IsValid() { + rejectIPv6(ctx, s.config.VPNConfig.Device.Name) } - if !pre4.IsValid() { - rejectIPv4(ctx, s.vpnconfig.Device.Name) + if !s.config.VPNConfig.IPv4.IsValid() { + rejectIPv4(ctx, s.config.VPNConfig.Device.Name) } // add excludes s.excludes.Start() // add gateway to static excludes - if s.vpnconfig.Gateway != nil { - g := netip.MustParseAddr(s.vpnconfig.Gateway.String()) - gateway := netip.PrefixFrom(g, g.BitLen()) + if s.config.VPNConfig.Gateway.IsValid() { + gateway := netip.PrefixFrom(s.config.VPNConfig.Gateway, + s.config.VPNConfig.Gateway.BitLen()) s.excludes.AddStatic(ctx, gateway) } // add static IPv4 excludes - for _, e := range s.vpnconfig.Split.ExcludeIPv4 { + for _, e := range s.config.VPNConfig.Split.ExcludeIPv4 { if e.String() == "0.0.0.0/32" { continue } - p := netip.MustParsePrefix(e.String()) - s.excludes.AddStatic(ctx, p) + s.excludes.AddStatic(ctx, e) } // add static IPv6 excludes - for _, e := range s.vpnconfig.Split.ExcludeIPv6 { + for _, e := range s.config.VPNConfig.Split.ExcludeIPv6 { // TODO: does ::/128 exist? if e.String() == "::/128" { continue } - p := netip.MustParsePrefix(e.String()) - s.excludes.AddStatic(ctx, p) + s.excludes.AddStatic(ctx, e) } // setup routing - addDefaultRouteIPv4(ctx, s.vpnconfig.Device.Name, s.config.RoutingTable, - s.config.RulePriority1, s.config.FirewallMark, s.config.RulePriority2) - addDefaultRouteIPv6(ctx, s.vpnconfig.Device.Name, s.config.RoutingTable, - s.config.RulePriority1, s.config.FirewallMark, s.config.RulePriority2) - + addDefaultRouteIPv4(ctx, + s.config.VPNConfig.Device.Name, + s.config.SplitRouting.RoutingTable, + s.config.SplitRouting.RulePriority1, + s.config.SplitRouting.FirewallMark, + s.config.SplitRouting.RulePriority2) + addDefaultRouteIPv6(ctx, + s.config.VPNConfig.Device.Name, + s.config.SplitRouting.RoutingTable, + s.config.SplitRouting.RulePriority1, + s.config.SplitRouting.FirewallMark, + s.config.SplitRouting.RulePriority2) } // teardownRouting tears down the routing configuration. func (s *SplitRouting) teardownRouting(ctx context.Context) { - deleteDefaultRouteIPv4(ctx, s.vpnconfig.Device.Name, s.config.RoutingTable) - deleteDefaultRouteIPv6(ctx, s.vpnconfig.Device.Name, s.config.RoutingTable) + deleteDefaultRouteIPv4(ctx, + s.config.VPNConfig.Device.Name, + s.config.SplitRouting.RoutingTable) + deleteDefaultRouteIPv6(ctx, + s.config.VPNConfig.Device.Name, + s.config.SplitRouting.RoutingTable) unsetRoutingRules(ctx) // remove excludes @@ -140,12 +138,12 @@ func (s *SplitRouting) teardownRouting(ctx context.Context) { // excludeSettings returns whether local (virtual) networks should be excluded. func (s *SplitRouting) excludeLocalNetworks() (exclude bool, virtual bool) { - for _, e := range s.vpnconfig.Split.ExcludeIPv4 { + for _, e := range s.config.VPNConfig.Split.ExcludeIPv4 { if e.String() == "0.0.0.0/32" { exclude = true } } - if s.vpnconfig.Split.ExcludeVirtualSubnetsOnlyIPv4 { + if s.config.VPNConfig.Split.ExcludeVirtualSubnetsOnlyIPv4 { virtual = true } return @@ -211,7 +209,7 @@ func (s *SplitRouting) handleDeviceUpdate(ctx context.Context, u *devmon.Update) // skip loopback devices return } - if u.Device == s.vpnconfig.Device.Name { + if u.Device == s.config.VPNConfig.Device.Name { // skip vpn tunnel device, so we do not use it for // split excludes return @@ -312,8 +310,6 @@ func (s *SplitRouting) GetState() *State { } static, dynamic := s.excludes.List() return &State{ - Config: s.config, - VPNConfig: s.vpnconfig, Devices: s.devices.List(), Addresses: s.addrs.List(), LocalExcludes: locals, @@ -323,24 +319,25 @@ func (s *SplitRouting) GetState() *State { } // NewSplitRouting returns a new SplitRouting. -func NewSplitRouting(config *Config, vpnconfig *vpnconfig.Config) *SplitRouting { +func NewSplitRouting(config *daemoncfg.Config) *SplitRouting { return &SplitRouting{ - config: config, - vpnconfig: vpnconfig, - devmon: devmon.NewDevMon(), - addrmon: addrmon.NewAddrMon(), - devices: NewDevices(), - addrs: NewAddresses(), - excludes: NewExcludes(), - dnsreps: make(chan *dnsproxy.Report), - done: make(chan struct{}), - closed: make(chan struct{}), + config: config, + devmon: devmon.NewDevMon(), + addrmon: addrmon.NewAddrMon(), + devices: NewDevices(), + addrs: NewAddresses(), + excludes: NewExcludes(config), + dnsreps: make(chan *dnsproxy.Report), + done: make(chan struct{}), + closed: make(chan struct{}), } } // Cleanup cleans up old configuration after a failed shutdown. -func Cleanup(ctx context.Context, config *Config) { - cleanupRouting(ctx, config.RoutingTable, config.RulePriority1, - config.RulePriority2) +func Cleanup(ctx context.Context, config *daemoncfg.Config) { + cleanupRouting(ctx, + config.SplitRouting.RoutingTable, + config.SplitRouting.RulePriority1, + config.SplitRouting.RulePriority2) cleanupRoutingRules(ctx) } diff --git a/internal/splitrt/splitrt_test.go b/internal/splitrt/splitrt_test.go index e7832af0..9ec11b8a 100644 --- a/internal/splitrt/splitrt_test.go +++ b/internal/splitrt/splitrt_test.go @@ -3,24 +3,23 @@ package splitrt import ( "context" "errors" - "net" "net/netip" "reflect" "strings" "testing" "github.com/telekom-mms/oc-daemon/internal/addrmon" + "github.com/telekom-mms/oc-daemon/internal/daemoncfg" "github.com/telekom-mms/oc-daemon/internal/devmon" "github.com/telekom-mms/oc-daemon/internal/dnsproxy" "github.com/telekom-mms/oc-daemon/internal/execs" - "github.com/telekom-mms/oc-daemon/pkg/vpnconfig" "github.com/vishvananda/netlink" ) // TestSplitRoutingHandleDeviceUpdate tests handleDeviceUpdate of SplitRouting. func TestSplitRoutingHandleDeviceUpdate(t *testing.T) { ctx := context.Background() - s := NewSplitRouting(NewConfig(), vpnconfig.New()) + s := NewSplitRouting(daemoncfg.NewConfig()) want := []string{"nothing else"} got := []string{"nothing else"} @@ -56,7 +55,7 @@ func TestSplitRoutingHandleDeviceUpdate(t *testing.T) { // test adding vpn device update = getTestDevMonUpdate() - update.Device = s.vpnconfig.Device.Name + update.Device = s.config.VPNConfig.Device.Name s.handleDeviceUpdate(ctx, update) if !reflect.DeepEqual(got, want) { t.Errorf("got %v, want %v", got, want) @@ -68,14 +67,11 @@ func TestSplitRoutingHandleAddressUpdate(t *testing.T) { ctx := context.Background() // test with exclude - vpnconf := vpnconfig.New() - vpnconf.Split.ExcludeIPv4 = []*net.IPNet{ - { - IP: net.IPv4(0, 0, 0, 0), - Mask: net.CIDRMask(32, 32), - }, - } - s := NewSplitRouting(NewConfig(), vpnconf) + conf := daemoncfg.NewConfig() + conf.VPNConfig.Split.ExcludeIPv4 = []netip.Prefix{ + netip.MustParsePrefix("0.0.0.0/32"), + } + s := NewSplitRouting(conf) s.devices.Add(getTestDevMonUpdate()) got := []string{} @@ -109,15 +105,12 @@ func TestSplitRoutingHandleAddressUpdate(t *testing.T) { } // test with exclude and virtual - vpnconf = vpnconfig.New() - vpnconf.Split.ExcludeIPv4 = []*net.IPNet{ - { - IP: net.IPv4(0, 0, 0, 0), - Mask: net.CIDRMask(32, 32), - }, - } - vpnconf.Split.ExcludeVirtualSubnetsOnlyIPv4 = true - s = NewSplitRouting(NewConfig(), vpnconf) + conf = daemoncfg.NewConfig() + conf.VPNConfig.Split.ExcludeIPv4 = []netip.Prefix{ + netip.MustParsePrefix("0.0.0.0/32"), + } + conf.VPNConfig.Split.ExcludeVirtualSubnetsOnlyIPv4 = true + s = NewSplitRouting(conf) devUp := getTestDevMonUpdate() devUp.Type = "virtual" s.devices.Add(devUp) @@ -156,7 +149,7 @@ func TestSplitRoutingHandleAddressUpdate(t *testing.T) { // TestSplitRoutingHandleDNSReport tests handleDNSReport of SplitRouting. func TestSplitRoutingHandleDNSReport(t *testing.T) { ctx := context.Background() - s := NewSplitRouting(NewConfig(), vpnconfig.New()) + s := NewSplitRouting(daemoncfg.NewConfig()) got := []string{} oldRunCmd := execs.RunCmd @@ -207,52 +200,39 @@ func TestSplitRoutingStartStop(t *testing.T) { defer func() { devmon.RegisterLinkUpdates = oldRegisterLinkUpdates }() // test with new configs - s := NewSplitRouting(NewConfig(), vpnconfig.New()) + s := NewSplitRouting(daemoncfg.NewConfig()) if err := s.Start(); err != nil { t.Error(err) } s.Stop() // test with excludes - vpnconf := vpnconfig.New() - vpnconf.Split.ExcludeIPv4 = []*net.IPNet{ - { - IP: net.IPv4(0, 0, 0, 0), - Mask: net.CIDRMask(32, 32), - }, - { - IP: net.IPv4(192, 168, 1, 1), - Mask: net.CIDRMask(32, 32), - }, - } - vpnconf.Split.ExcludeIPv6 = []*net.IPNet{ - { - IP: net.ParseIP("::"), - Mask: net.CIDRMask(128, 128), - }, - { - IP: net.ParseIP("2000::1"), - Mask: net.CIDRMask(128, 128), - }, - } - s = NewSplitRouting(NewConfig(), vpnconf) + conf := daemoncfg.NewConfig() + conf.VPNConfig.Split.ExcludeIPv4 = []netip.Prefix{ + netip.MustParsePrefix("0.0.0.0/32"), + netip.MustParsePrefix("192.168.1.1/32"), + } + conf.VPNConfig.Split.ExcludeIPv6 = []netip.Prefix{ + netip.MustParsePrefix("::/128"), + netip.MustParsePrefix("2000::1/128"), + } + s = NewSplitRouting(conf) if err := s.Start(); err != nil { t.Error(err) } s.Stop() // test with vpn address - vpnconf = vpnconfig.New() - vpnconf.IPv4.Address = net.IPv4(192, 168, 1, 1) - vpnconf.IPv4.Netmask = net.CIDRMask(24, 32) - s = NewSplitRouting(NewConfig(), vpnconf) + conf = daemoncfg.NewConfig() + conf.VPNConfig.IPv4 = netip.MustParsePrefix("192.168.1.1/24") + s = NewSplitRouting(daemoncfg.NewConfig()) if err := s.Start(); err != nil { t.Error(err) } s.Stop() // test with events - s = NewSplitRouting(NewConfig(), vpnconfig.New()) + s = NewSplitRouting(daemoncfg.NewConfig()) if err := s.Start(); err != nil { t.Error(err) } @@ -267,7 +247,7 @@ func TestSplitRoutingStartStop(t *testing.T) { execs.RunCmd = func(context.Context, string, string, ...string) ([]byte, []byte, error) { return nil, nil, errors.New("test error") } - s = NewSplitRouting(NewConfig(), vpnconfig.New()) + s = NewSplitRouting(daemoncfg.NewConfig()) if err := s.Start(); err != nil { t.Error(err) } @@ -276,7 +256,7 @@ func TestSplitRoutingStartStop(t *testing.T) { // TestSplitRoutingDNSReports tests DNSReports of SplitRouting. func TestSplitRoutingDNSReports(t *testing.T) { - s := NewSplitRouting(NewConfig(), vpnconfig.New()) + s := NewSplitRouting(daemoncfg.NewConfig()) want := s.dnsreps got := s.DNSReports() if got != want { @@ -286,7 +266,7 @@ func TestSplitRoutingDNSReports(t *testing.T) { // TestSplitRoutingGetState tests GetState of SplitRouting. func TestSplitRoutingGetState(t *testing.T) { - s := NewSplitRouting(NewConfig(), vpnconfig.New()) + s := NewSplitRouting(daemoncfg.NewConfig()) // set devices dev := &devmon.Update{ @@ -319,8 +299,6 @@ func TestSplitRoutingGetState(t *testing.T) { // get and check state want := &State{ - Config: NewConfig(), - VPNConfig: vpnconfig.New(), Devices: []*devmon.Update{dev}, Addresses: []*addrmon.Update{addr}, LocalExcludes: []string{"10.0.0.0/24"}, @@ -335,15 +313,11 @@ func TestSplitRoutingGetState(t *testing.T) { // TestNewSplitRouting tests NewSplitRouting. func TestNewSplitRouting(t *testing.T) { - config := NewConfig() - vpnconf := vpnconfig.New() - s := NewSplitRouting(config, vpnconf) + config := daemoncfg.NewConfig() + s := NewSplitRouting(config) if s.config != config { t.Errorf("got %p, want %p", s.config, config) } - if s.vpnconfig != vpnconf { - t.Errorf("got %p, want %p", s.vpnconfig, vpnconf) - } if s.devmon == nil || s.addrmon == nil || s.devices == nil || @@ -372,7 +346,7 @@ func TestCleanup(t *testing.T) { } defer func() { execs.RunCmd = oldRunCmd }() - Cleanup(context.Background(), NewConfig()) + Cleanup(context.Background(), daemoncfg.NewConfig()) want := []string{ "ip -4 rule delete pref 2111", "ip -4 rule delete pref 2112", diff --git a/internal/trafpol/config.go b/internal/trafpol/config.go deleted file mode 100644 index 62d2213e..00000000 --- a/internal/trafpol/config.go +++ /dev/null @@ -1,83 +0,0 @@ -package trafpol - -import ( - "time" -) - -var ( - // AllowedHosts is the default list of allowed hosts, this is - // initialized with hosts for captive portal detection, e.g., - // used by browsers. - AllowedHosts = []string{ - "connectivity-check.ubuntu.com", // ubuntu - "detectportal.firefox.com", // firefox - "www.gstatic.com", // chrome - "clients3.google.com", // chromium - "nmcheck.gnome.org", // gnome - } - - // PortalPorts are the default ports that are allowed to register on a - // captive portal. - PortalPorts = []uint16{ - 80, - 443, - } - - // ResolveTimeout is the timeout for dns lookups. - ResolveTimeout = 2 * time.Second - - // ResolveTries is the number of tries for dns lookups. - ResolveTries = 3 - - // ResolveTriesSleep is the sleep time between retries. - ResolveTriesSleep = time.Second - - // ResolveTimer is the time for periodic resolve update checks, - // should be higher than tries * (timeout + sleep). - ResolveTimer = 30 * time.Second - - // ResolveTTL is the lifetime of resolved entries. - ResolveTTL = 300 * time.Second -) - -// Config is a TrafPol configuration. -type Config struct { - AllowedHosts []string - PortalPorts []uint16 - FirewallMark string `json:"-"` - - ResolveTimeout time.Duration - ResolveTries int - ResolveTriesSleep time.Duration - ResolveTimer time.Duration - ResolveTTL time.Duration -} - -// Valid returns whether the TrafPol configuration is valid. -func (c *Config) Valid() bool { - if c == nil || - len(c.PortalPorts) == 0 || - c.ResolveTimeout < 0 || - c.ResolveTries < 1 || - c.ResolveTriesSleep < 0 || - c.ResolveTimer < 0 || - c.ResolveTTL < 0 { - - return false - } - return true -} - -// NewConfig returns a new TrafPol configuration. -func NewConfig() *Config { - return &Config{ - AllowedHosts: append(AllowedHosts[:0:0], AllowedHosts...), - PortalPorts: append(PortalPorts[:0:0], PortalPorts...), - - ResolveTimeout: ResolveTimeout, - ResolveTries: ResolveTries, - ResolveTriesSleep: ResolveTriesSleep, - ResolveTimer: ResolveTimer, - ResolveTTL: ResolveTTL, - } -} diff --git a/internal/trafpol/config_test.go b/internal/trafpol/config_test.go deleted file mode 100644 index db9306dd..00000000 --- a/internal/trafpol/config_test.go +++ /dev/null @@ -1,36 +0,0 @@ -package trafpol - -import "testing" - -// TestConfigValid tests Valid of Config. -func TestConfigValid(t *testing.T) { - // test invalid - for _, invalid := range []*Config{ - nil, - {}, - } { - want := false - got := invalid.Valid() - - if got != want { - t.Errorf("got %t, want %t for %v", got, want, invalid) - } - } - - // test valid - valid := NewConfig() - want := true - got := valid.Valid() - - if got != want { - t.Errorf("got %t, want %t for %v", got, want, valid) - } -} - -// TestNewConfig tests NewConfig. -func TestNewConfig(t *testing.T) { - c := NewConfig() - if !c.Valid() { - t.Errorf("new config should be valid") - } -} diff --git a/internal/trafpol/resolver.go b/internal/trafpol/resolver.go index cabf58e1..f5d99693 100644 --- a/internal/trafpol/resolver.go +++ b/internal/trafpol/resolver.go @@ -8,6 +8,8 @@ import ( "sort" "sync" "time" + + "github.com/telekom-mms/oc-daemon/internal/daemoncfg" ) // ResolvedName is a resolved DNS name. @@ -18,7 +20,7 @@ type ResolvedName struct { } // sleepResolveTry is used to sleep before resolve (re)tries, can be canceled. -func (r *ResolvedName) sleepResolveTry(ctx context.Context, config *Config) { +func (r *ResolvedName) sleepResolveTry(ctx context.Context, config *daemoncfg.TrafficPolicing) { timer := time.NewTimer(config.ResolveTriesSleep) select { case <-timer.C: @@ -31,7 +33,7 @@ func (r *ResolvedName) sleepResolveTry(ctx context.Context, config *Config) { } // resolve resolves the DNS name to its IP addresses. -func (r *ResolvedName) resolve(ctx context.Context, config *Config, updates chan *ResolvedName) { +func (r *ResolvedName) resolve(ctx context.Context, config *daemoncfg.TrafficPolicing, updates chan *ResolvedName) { // try to resolve ip addresses of host resolver := &net.Resolver{} tries := 0 @@ -104,7 +106,7 @@ func (r *ResolvedName) resolve(ctx context.Context, config *Config, updates chan // Resolver is a DNS resolver that resolves names to their IP addresses. type Resolver struct { - config *Config + config *daemoncfg.TrafficPolicing names map[string]*ResolvedName updates chan *ResolvedName cmds chan struct{} @@ -233,7 +235,7 @@ func (r *Resolver) Resolve() { } // NewResolver returns a new Resolver. -func NewResolver(config *Config, names []string, updates chan *ResolvedName) *Resolver { +func NewResolver(config *daemoncfg.TrafficPolicing, names []string, updates chan *ResolvedName) *Resolver { n := make(map[string]*ResolvedName) for _, name := range names { n[name] = &ResolvedName{Name: name} diff --git a/internal/trafpol/resolver_test.go b/internal/trafpol/resolver_test.go index ee0ce6df..9aca0301 100644 --- a/internal/trafpol/resolver_test.go +++ b/internal/trafpol/resolver_test.go @@ -4,11 +4,13 @@ import ( "net/netip" "testing" "time" + + "github.com/telekom-mms/oc-daemon/internal/daemoncfg" ) // TestResolverStartStop tests Start and Stop of Resolver. func TestResolverStartStop(_ *testing.T) { - config := NewConfig() + config := daemoncfg.NewTrafficPolicing() names := []string{} updates := make(chan *ResolvedName) r := NewResolver(config, names, updates) @@ -18,7 +20,7 @@ func TestResolverStartStop(_ *testing.T) { // TestResolverResolve tests Resolve of Resolver. func TestResolverResolve(_ *testing.T) { - config := NewConfig() + config := daemoncfg.NewTrafficPolicing() config.ResolveTriesSleep = 0 config.ResolveTimer = 0 names := []string{"does not exist...", "example.com"} @@ -58,7 +60,7 @@ func TestResolverResolve(_ *testing.T) { // TestNewResolver tests NewResolver. func TestNewResolver(t *testing.T) { - config := NewConfig() + config := daemoncfg.NewTrafficPolicing() names := []string{"test.example.com"} updates := make(chan *ResolvedName) r := NewResolver(config, names, updates) diff --git a/internal/trafpol/trafpol.go b/internal/trafpol/trafpol.go index 60c8d9f1..2f9a707c 100644 --- a/internal/trafpol/trafpol.go +++ b/internal/trafpol/trafpol.go @@ -8,6 +8,7 @@ import ( log "github.com/sirupsen/logrus" "github.com/telekom-mms/oc-daemon/internal/cpd" + "github.com/telekom-mms/oc-daemon/internal/daemoncfg" "github.com/telekom-mms/oc-daemon/internal/devmon" "github.com/telekom-mms/oc-daemon/internal/dnsmon" ) @@ -21,7 +22,6 @@ const ( // State is the internal TrafPol state. type State struct { - Config *Config CaptivePortal bool AllowedDevices []string AllowedAddresses []netip.Prefix @@ -39,7 +39,7 @@ type trafPolCmd struct { // TrafPol is a traffic policing component. type TrafPol struct { - config *Config + config *daemoncfg.Config devmon *devmon.DevMon dnsmon *dnsmon.DNSMon cpd *cpd.CPD @@ -96,7 +96,7 @@ func (t *TrafPol) handleCPDReport(ctx context.Context, report *cpd.Report) { t.resolver.Resolve() // remove ports from allowed ports - removePortalPorts(ctx, t.config.PortalPorts) + removePortalPorts(ctx, t.config.TrafficPolicing.PortalPorts) t.capPortal = false log.WithField("capPortal", t.capPortal).Info("TrafPol changed CPD status") } @@ -105,7 +105,7 @@ func (t *TrafPol) handleCPDReport(ctx context.Context, report *cpd.Report) { // add ports to allowed ports if !t.capPortal { - addPortalPorts(ctx, t.config.PortalPorts) + addPortalPorts(ctx, t.config.TrafficPolicing.PortalPorts) t.capPortal = true log.WithField("capPortal", t.capPortal).Info("TrafPol changed CPD status") } @@ -175,7 +175,6 @@ func (t *TrafPol) handleAddressCommand(ctx context.Context, cmd *trafPolCmd) { func (t *TrafPol) handleGetStateCommand(cmd *trafPolCmd) { // set state cmd.state = &State{ - Config: t.config, CaptivePortal: t.capPortal, AllowedDevices: t.allowDevs.List(), AllowedAddresses: t.allowAddrs.List(), @@ -247,7 +246,7 @@ func (t *TrafPol) Start() error { ctx := context.Background() // set firewall config - setFilterRules(ctx, t.config.FirewallMark) + setFilterRules(ctx, t.config.SplitRouting.FirewallMark) // set filter rules setAllowedIPs(ctx, t.getAllowedHostsIPs()) @@ -363,12 +362,12 @@ func parseAllowedHosts(hosts []string) (addrs []netip.Prefix, names []string) { } // NewTrafPol returns a new traffic policing component. -func NewTrafPol(config *Config) *TrafPol { +func NewTrafPol(config *daemoncfg.Config) *TrafPol { // create cpd - c := cpd.NewCPD(cpd.NewConfig()) + c := cpd.NewCPD(daemoncfg.NewCPD()) // get allowed addrs and names - hosts := append(config.AllowedHosts, c.Hosts()...) + hosts := append(config.TrafficPolicing.AllowedHosts, c.Hosts()...) a, n := parseAllowedHosts(hosts) // create allowed addrs and names @@ -395,7 +394,7 @@ func NewTrafPol(config *Config) *TrafPol { allowAddrs: addrs, allowNames: names, - resolver: NewResolver(config, n, resolvUp), + resolver: NewResolver(config.TrafficPolicing, n, resolvUp), resolvUp: resolvUp, cmds: make(chan *trafPolCmd), diff --git a/internal/trafpol/trafpol_test.go b/internal/trafpol/trafpol_test.go index 108d9dc4..7a7cbd87 100644 --- a/internal/trafpol/trafpol_test.go +++ b/internal/trafpol/trafpol_test.go @@ -10,6 +10,7 @@ import ( "testing" "github.com/telekom-mms/oc-daemon/internal/cpd" + "github.com/telekom-mms/oc-daemon/internal/daemoncfg" "github.com/telekom-mms/oc-daemon/internal/devmon" "github.com/telekom-mms/oc-daemon/internal/execs" "github.com/vishvananda/netlink" @@ -17,7 +18,7 @@ import ( // TestTrafPolHandleDeviceUpdate tests handleDeviceUpdate of TrafPol. func TestTrafPolHandleDeviceUpdate(_ *testing.T) { - tp := NewTrafPol(NewConfig()) + tp := NewTrafPol(daemoncfg.NewConfig()) ctx := context.Background() // test adding @@ -33,7 +34,7 @@ func TestTrafPolHandleDeviceUpdate(_ *testing.T) { // TestTrafPolHandleDNSUpdate tests handleDNSUpdate of TrafPol. func TestTrafPolHandleDNSUpdate(_ *testing.T) { - tp := NewTrafPol(NewConfig()) + tp := NewTrafPol(daemoncfg.NewConfig()) tp.resolver.Start() defer tp.resolver.Stop() @@ -45,7 +46,7 @@ func TestTrafPolHandleDNSUpdate(_ *testing.T) { // TestTrafPolHandleCPDReport tests handleCPDReport of TrafPol. func TestTrafPolHandleCPDReport(t *testing.T) { - tp := NewTrafPol(NewConfig()) + tp := NewTrafPol(daemoncfg.NewConfig()) ctx := context.Background() tp.resolver.Start() @@ -113,7 +114,7 @@ func TestTrafPolStartEvents(t *testing.T) { } defer func() { devmon.RegisterLinkUpdates = oldRegisterLinkUpdates }() - tp := NewTrafPol(NewConfig()) + tp := NewTrafPol(daemoncfg.NewConfig()) if err := tp.Start(); err != nil { t.Fatal(err) } @@ -127,9 +128,9 @@ func TestTrafPolStartEvents(t *testing.T) { // TestTrafPolGetAllowedHostsIPs tests getAllowedHostsIPs of TrafPol. func TestTrafPolGetAllowedHostsIPs(t *testing.T) { // create trafpol with allowed addresses - c := NewConfig() - c.AllowedHosts = append(c.AllowedHosts, "192.168.2.0/24") - c.AllowedHosts = append(c.AllowedHosts, "2001:DB8:2::/64") + c := daemoncfg.NewConfig() + c.TrafficPolicing.AllowedHosts = append(c.TrafficPolicing.AllowedHosts, + "192.168.2.0/24", "2001:DB8:2::/64") tp := NewTrafPol(c) // add allowed names @@ -176,7 +177,7 @@ func TestTrafPolStartStop(t *testing.T) { } defer func() { devmon.RegisterLinkUpdates = oldRegisterLinkUpdates }() - tp := NewTrafPol(NewConfig()) + tp := NewTrafPol(daemoncfg.NewConfig()) if err := tp.Start(); err != nil { t.Fatal(err) } @@ -192,7 +193,7 @@ func TestTrafPolAddRemoveAllowedAddr(t *testing.T) { } defer func() { devmon.RegisterLinkUpdates = oldRegisterLinkUpdates }() - tp := NewTrafPol(NewConfig()) + tp := NewTrafPol(daemoncfg.NewConfig()) if err := tp.Start(); err != nil { t.Fatal(err) } @@ -252,7 +253,7 @@ func TestTrafPolGetState(t *testing.T) { defer func() { devmon.RegisterLinkUpdates = oldRegisterLinkUpdates }() // start trafpol - tp := NewTrafPol(NewConfig()) + tp := NewTrafPol(daemoncfg.NewConfig()) if err := tp.Start(); err != nil { t.Fatal(err) } @@ -268,11 +269,10 @@ func TestTrafPolGetState(t *testing.T) { // TestNewTrafPol tests NewTrafPol. func TestNewTrafPol(t *testing.T) { - c := NewConfig() - c.AllowedHosts = append(c.AllowedHosts, "192.168.1.1") - c.AllowedHosts = append(c.AllowedHosts, "192.168.2.0/24") - c.AllowedHosts = append(c.AllowedHosts, "2001:DB8:1::1") - c.AllowedHosts = append(c.AllowedHosts, "2001:DB8:2::/64 ") + c := daemoncfg.NewConfig() + c.TrafficPolicing.AllowedHosts = append(c.TrafficPolicing.AllowedHosts, + "192.168.1.1", "192.168.2.0/24", + "2001:DB8:1::1", "2001:DB8:2::/64") tp := NewTrafPol(c) if tp == nil || tp.devmon == nil || diff --git a/internal/vpncscript/client_test.go b/internal/vpncscript/client_test.go index af0c21b9..eed65ffb 100644 --- a/internal/vpncscript/client_test.go +++ b/internal/vpncscript/client_test.go @@ -7,13 +7,14 @@ import ( "github.com/telekom-mms/oc-daemon/internal/api" "github.com/telekom-mms/oc-daemon/internal/daemon" + "github.com/telekom-mms/oc-daemon/internal/daemoncfg" "github.com/telekom-mms/oc-daemon/pkg/vpnconfig" ) // TestRunClient tests runClient. func TestRunClient(t *testing.T) { sockfile := filepath.Join(t.TempDir(), "sockfile") - config := api.NewConfig() + config := daemoncfg.NewSocketServer() config.SocketFile = sockfile // without errors diff --git a/internal/vpnsetup/vpnsetup.go b/internal/vpnsetup/vpnsetup.go index e0b7c5c6..3588a1cd 100644 --- a/internal/vpnsetup/vpnsetup.go +++ b/internal/vpnsetup/vpnsetup.go @@ -9,10 +9,10 @@ import ( "time" log "github.com/sirupsen/logrus" + "github.com/telekom-mms/oc-daemon/internal/daemoncfg" "github.com/telekom-mms/oc-daemon/internal/dnsproxy" "github.com/telekom-mms/oc-daemon/internal/execs" "github.com/telekom-mms/oc-daemon/internal/splitrt" - "github.com/telekom-mms/oc-daemon/pkg/vpnconfig" ) // command types. @@ -30,20 +30,17 @@ type State struct { // command is a VPNSetup command. type command struct { - cmd uint8 - vpnconf *vpnconfig.Config - state *State - done chan struct{} + cmd uint8 + conf *daemoncfg.Config + state *State + done chan struct{} } // VPNSetup sets up the configuration of the vpn tunnel that belongs to the // current VPN connection. type VPNSetup struct { - splitrt *splitrt.SplitRouting - splitrtConf *splitrt.Config - - dnsProxy *dnsproxy.Proxy - dnsProxyConf *dnsproxy.Config + splitrt *splitrt.SplitRouting + dnsProxy *dnsproxy.Proxy ensureDone chan struct{} ensureClosed chan struct{} @@ -54,12 +51,14 @@ type VPNSetup struct { } // setupVPNDevice sets up the vpn device with config. -func setupVPNDevice(ctx context.Context, c *vpnconfig.Config) { +func setupVPNDevice(ctx context.Context, c *daemoncfg.Config) { // set mtu on device - mtu := strconv.Itoa(c.Device.MTU) - if stdout, stderr, err := execs.RunIPLink(ctx, "set", c.Device.Name, "mtu", mtu); err != nil { + mtu := strconv.Itoa(c.VPNConfig.Device.MTU) + if stdout, stderr, err := execs.RunIPLink( + ctx, "set", c.VPNConfig.Device.Name, "mtu", mtu); err != nil { + log.WithError(err).WithFields(log.Fields{ - "device": c.Device.Name, + "device": c.VPNConfig.Device.Name, "mtu": mtu, "stdout": string(stdout), "stderr": string(stderr), @@ -68,9 +67,11 @@ func setupVPNDevice(ctx context.Context, c *vpnconfig.Config) { } // set device up - if stdout, stderr, err := execs.RunIPLink(ctx, "set", c.Device.Name, "up"); err != nil { + if stdout, stderr, err := execs.RunIPLink( + ctx, "set", c.VPNConfig.Device.Name, "up"); err != nil { + log.WithError(err).WithFields(log.Fields{ - "device": c.Device.Name, + "device": c.VPNConfig.Device.Name, "stdout": string(stdout), "stderr": string(stderr), }).Error("Daemon could not set device up") @@ -79,9 +80,11 @@ func setupVPNDevice(ctx context.Context, c *vpnconfig.Config) { // set ipv4 and ipv6 addresses on device setupIP := func(a netip.Prefix) { - dev := c.Device.Name + dev := c.VPNConfig.Device.Name addr := a.String() - if stdout, stderr, err := execs.RunIPAddress(ctx, "add", addr, "dev", dev); err != nil { + if stdout, stderr, err := execs.RunIPAddress( + ctx, "add", addr, "dev", dev); err != nil { + log.WithError(err).WithFields(log.Fields{ "device": dev, "ip": addr, @@ -93,26 +96,22 @@ func setupVPNDevice(ctx context.Context, c *vpnconfig.Config) { } - if ipv4, ok := netip.AddrFromSlice(c.IPv4.Address.To4()); ok { - one4, _ := c.IPv4.Netmask.Size() - pre4 := netip.PrefixFrom(ipv4, one4) - - setupIP(pre4) + if c.VPNConfig.IPv4.IsValid() { + setupIP(c.VPNConfig.IPv4) } - if ipv6, ok := netip.AddrFromSlice(c.IPv6.Address); ok { - one6, _ := c.IPv6.Netmask.Size() - pre6 := netip.PrefixFrom(ipv6, one6) - - setupIP(pre6) + if c.VPNConfig.IPv6.IsValid() { + setupIP(c.VPNConfig.IPv6) } } // teardownVPNDevice tears down the configured vpn device. -func teardownVPNDevice(ctx context.Context, c *vpnconfig.Config) { +func teardownVPNDevice(ctx context.Context, c *daemoncfg.Config) { // set device down - if stdout, stderr, err := execs.RunIPLink(ctx, "set", c.Device.Name, "down"); err != nil { + if stdout, stderr, err := execs.RunIPLink( + ctx, "set", c.VPNConfig.Device.Name, "down"); err != nil { + log.WithError(err).WithFields(log.Fields{ - "device": c.Device.Name, + "device": c.VPNConfig.Device.Name, "stdout": string(stdout), "stderr": string(stderr), }).Error("Daemon could not set device down") @@ -122,11 +121,11 @@ func teardownVPNDevice(ctx context.Context, c *vpnconfig.Config) { } // setupRouting sets up routing using config. -func (v *VPNSetup) setupRouting(vpnconf *vpnconfig.Config) { +func (v *VPNSetup) setupRouting(config *daemoncfg.Config) { if v.splitrt != nil { return } - v.splitrt = splitrt.NewSplitRouting(v.splitrtConf, vpnconf) + v.splitrt = splitrt.NewSplitRouting(config) if err := v.splitrt.Start(); err != nil { log.WithError(err).Error("VPNSetup error setting split routing") } @@ -142,12 +141,14 @@ func (v *VPNSetup) teardownRouting() { } // setupDNSServer sets the DNS server. -func (v *VPNSetup) setupDNSServer(ctx context.Context, config *vpnconfig.Config) { - device := config.Device.Name - if stdout, stderr, err := execs.RunResolvectl(ctx, "dns", device, v.dnsProxyConf.Address); err != nil { +func (v *VPNSetup) setupDNSServer(ctx context.Context, config *daemoncfg.Config) { + device := config.VPNConfig.Device.Name + if stdout, stderr, err := execs.RunResolvectl( + ctx, "dns", device, config.DNSProxy.Address); err != nil { + log.WithError(err).WithFields(log.Fields{ "device": device, - "server": v.dnsProxyConf.Address, + "server": config.DNSProxy.Address, "stdout": string(stdout), "stderr": string(stderr), }).Error("VPNSetup error setting dns server") @@ -155,12 +156,14 @@ func (v *VPNSetup) setupDNSServer(ctx context.Context, config *vpnconfig.Config) } // setupDNSDomains sets the DNS domains. -func (v *VPNSetup) setupDNSDomains(ctx context.Context, config *vpnconfig.Config) { - device := config.Device.Name - if stdout, stderr, err := execs.RunResolvectl(ctx, "domain", device, config.DNS.DefaultDomain, "~."); err != nil { +func (v *VPNSetup) setupDNSDomains(ctx context.Context, config *daemoncfg.Config) { + device := config.VPNConfig.Device.Name + if stdout, stderr, err := execs.RunResolvectl( + ctx, "domain", device, config.VPNConfig.DNS.DefaultDomain, "~."); err != nil { + log.WithError(err).WithFields(log.Fields{ "device": device, - "domain": config.DNS.DefaultDomain, + "domain": config.VPNConfig.DNS.DefaultDomain, "stdout": string(stdout), "stderr": string(stderr), }).Error("VPNSetup error setting dns domains") @@ -168,9 +171,11 @@ func (v *VPNSetup) setupDNSDomains(ctx context.Context, config *vpnconfig.Config } // setupDNSDefaultRoute sets the DNS default route. -func (v *VPNSetup) setupDNSDefaultRoute(ctx context.Context, config *vpnconfig.Config) { - device := config.Device.Name - if stdout, stderr, err := execs.RunResolvectl(ctx, "default-route", device, "yes"); err != nil { +func (v *VPNSetup) setupDNSDefaultRoute(ctx context.Context, config *daemoncfg.Config) { + device := config.VPNConfig.Device.Name + if stdout, stderr, err := execs.RunResolvectl( + ctx, "default-route", device, "yes"); err != nil { + log.WithError(err).WithFields(log.Fields{ "device": device, "stdout": string(stdout), @@ -180,15 +185,15 @@ func (v *VPNSetup) setupDNSDefaultRoute(ctx context.Context, config *vpnconfig.C } // setupDNS sets up DNS using config. -func (v *VPNSetup) setupDNS(ctx context.Context, config *vpnconfig.Config) { +func (v *VPNSetup) setupDNS(ctx context.Context, config *daemoncfg.Config) { // configure dns proxy // set remotes - remotes := config.DNS.Remotes() + remotes := config.VPNConfig.DNS.Remotes() v.dnsProxy.SetRemotes(remotes) // set watches - excludes := config.Split.DNSExcludes() + excludes := config.VPNConfig.Split.DNSExcludes() log.WithField("excludes", excludes).Debug("Daemon setting DNS Split Excludes") v.dnsProxy.SetWatches(excludes) @@ -222,7 +227,7 @@ func (v *VPNSetup) setupDNS(ctx context.Context, config *vpnconfig.Config) { } // teardownDNS tears down the DNS configuration. -func (v *VPNSetup) teardownDNS(ctx context.Context, vpnconf *vpnconfig.Config) { +func (v *VPNSetup) teardownDNS(ctx context.Context, config *daemoncfg.Config) { // update dns proxy configuration // reset remotes @@ -235,9 +240,11 @@ func (v *VPNSetup) teardownDNS(ctx context.Context, vpnconf *vpnconfig.Config) { // update dns configuration of host // undo device dns configuration - if stdout, stderr, err := execs.RunResolvectl(ctx, "revert", vpnconf.Device.Name); err != nil { + if stdout, stderr, err := execs.RunResolvectl( + ctx, "revert", config.VPNConfig.Device.Name); err != nil { + log.WithError(err).WithFields(log.Fields{ - "device": vpnconf.Device.Name, + "device": config.VPNConfig.Device.Name, "stdout": string(stdout), "stderr": string(stderr), }).Error("VPNSetup error reverting dns configuration") @@ -274,9 +281,9 @@ func (v *VPNSetup) checkDNSProtocols(protocols []string) bool { } // checkDNSServers checks the configured DNS servers. -func (v *VPNSetup) checkDNSServers(servers []string) bool { +func (v *VPNSetup) checkDNSServers(conf *daemoncfg.Config, servers []string) bool { // check dns server ip - if len(servers) != 1 || servers[0] != v.dnsProxyConf.Address { + if len(servers) != 1 || servers[0] != conf.DNSProxy.Address { // server not correct return false } @@ -285,9 +292,9 @@ func (v *VPNSetup) checkDNSServers(servers []string) bool { } // checkDNSDomain checks the configured DNS domains. -func (v *VPNSetup) checkDNSDomain(config *vpnconfig.Config, domains []string) bool { +func (v *VPNSetup) checkDNSDomain(config *daemoncfg.Config, domains []string) bool { // get domains in config - inConfig := strings.Fields(config.DNS.DefaultDomain) + inConfig := strings.Fields(config.VPNConfig.DNS.DefaultDomain) inConfig = append(inConfig, "~.") // check domains in config @@ -309,11 +316,11 @@ func (v *VPNSetup) checkDNSDomain(config *vpnconfig.Config, domains []string) bo } // ensureDNS ensures the DNS config. -func (v *VPNSetup) ensureDNS(ctx context.Context, config *vpnconfig.Config) bool { +func (v *VPNSetup) ensureDNS(ctx context.Context, config *daemoncfg.Config) bool { log.Debug("VPNSetup checking DNS settings") // get dns settings - device := config.Device.Name + device := config.VPNConfig.Device.Name stdout, stderr, err := execs.RunResolvectl(ctx, "status", device, "--no-pager") if err != nil { log.WithError(err).WithFields(log.Fields{ @@ -343,7 +350,7 @@ func (v *VPNSetup) ensureDNS(ctx context.Context, config *vpnconfig.Config) bool protOK = v.checkDNSProtocols(f) case "DNS Servers": - srvOK = v.checkDNSServers(f) + srvOK = v.checkDNSServers(config, f) case "DNS Domain": domOK = v.checkDNSDomain(config, f) @@ -381,7 +388,7 @@ func (v *VPNSetup) ensureDNS(ctx context.Context, config *vpnconfig.Config) bool } // ensureConfig ensured that the VPN config is and stays active. -func (v *VPNSetup) ensureConfig(ctx context.Context, vpnconf *vpnconfig.Config) { +func (v *VPNSetup) ensureConfig(ctx context.Context, conf *daemoncfg.Config) { defer close(v.ensureClosed) timerInvalid := time.Second @@ -393,7 +400,7 @@ func (v *VPNSetup) ensureConfig(ctx context.Context, vpnconf *vpnconfig.Config) log.Debug("VPNSetup checking VPN configuration") // ensure DNS settings - if ok := v.ensureDNS(ctx, vpnconf); !ok { + if ok := v.ensureDNS(ctx, conf); !ok { timer = timerInvalid break } @@ -408,10 +415,10 @@ func (v *VPNSetup) ensureConfig(ctx context.Context, vpnconf *vpnconfig.Config) } // startEnsure starts ensuring the VPN config. -func (v *VPNSetup) startEnsure(ctx context.Context, vpnconf *vpnconfig.Config) { +func (v *VPNSetup) startEnsure(ctx context.Context, conf *daemoncfg.Config) { v.ensureDone = make(chan struct{}) v.ensureClosed = make(chan struct{}) - go v.ensureConfig(ctx, vpnconf) + go v.ensureConfig(ctx, conf) } // stopEnsure stops ensuring the VPN config. @@ -421,25 +428,25 @@ func (v *VPNSetup) stopEnsure() { } // setup sets up the vpn configuration. -func (v *VPNSetup) setup(ctx context.Context, vpnconf *vpnconfig.Config) { +func (v *VPNSetup) setup(ctx context.Context, conf *daemoncfg.Config) { // setup device, routing, dns - setupVPNDevice(ctx, vpnconf) - v.setupRouting(vpnconf) - v.setupDNS(ctx, vpnconf) + setupVPNDevice(ctx, conf) + v.setupRouting(conf) + v.setupDNS(ctx, conf) // ensure VPN config - v.startEnsure(ctx, vpnconf) + v.startEnsure(ctx, conf) } // teardown tears down the vpn configuration. -func (v *VPNSetup) teardown(ctx context.Context, vpnconf *vpnconfig.Config) { +func (v *VPNSetup) teardown(ctx context.Context, conf *daemoncfg.Config) { // stop ensuring VPN config v.stopEnsure() // tear down device, routing, dns - teardownVPNDevice(ctx, vpnconf) + teardownVPNDevice(ctx, conf) v.teardownRouting() - v.teardownDNS(ctx, vpnconf) + v.teardownDNS(ctx, conf) } // getState gets the internal state. @@ -460,9 +467,9 @@ func (v *VPNSetup) handleCommand(ctx context.Context, c *command) { switch c.cmd { case commandSetup: - v.setup(ctx, c.vpnconf) + v.setup(ctx, c.conf) case commandTeardown: - v.teardown(ctx, c.vpnconf) + v.teardown(ctx, c.conf) case commandGetState: v.getState(c) } @@ -520,22 +527,22 @@ func (v *VPNSetup) Stop() { } // Setup sets the VPN config up. -func (v *VPNSetup) Setup(vpnconfig *vpnconfig.Config) { +func (v *VPNSetup) Setup(conf *daemoncfg.Config) { c := &command{ - cmd: commandSetup, - vpnconf: vpnconfig, - done: make(chan struct{}), + cmd: commandSetup, + conf: conf, + done: make(chan struct{}), } v.cmds <- c <-c.done } // Teardown tears the VPN config down. -func (v *VPNSetup) Teardown(vpnconfig *vpnconfig.Config) { +func (v *VPNSetup) Teardown(conf *daemoncfg.Config) { c := &command{ - cmd: commandTeardown, - vpnconf: vpnconfig, - done: make(chan struct{}), + cmd: commandTeardown, + conf: conf, + done: make(chan struct{}), } v.cmds <- c <-c.done @@ -553,14 +560,9 @@ func (v *VPNSetup) GetState() *State { } // NewVPNSetup returns a new VPNSetup. -func NewVPNSetup( - dnsProxyConfig *dnsproxy.Config, - splitrtConfig *splitrt.Config, -) *VPNSetup { +func NewVPNSetup(dnsProxy *dnsproxy.Proxy) *VPNSetup { return &VPNSetup{ - dnsProxy: dnsproxy.NewProxy(dnsProxyConfig), - dnsProxyConf: dnsProxyConfig, - splitrtConf: splitrtConfig, + dnsProxy: dnsProxy, cmds: make(chan *command), done: make(chan struct{}), @@ -569,8 +571,9 @@ func NewVPNSetup( } // Cleanup cleans up the configuration after a failed shutdown. -func Cleanup(ctx context.Context, vpnDevice string, splitrtConfig *splitrt.Config) { +func Cleanup(ctx context.Context, config *daemoncfg.Config) { // dns, device, split routing + vpnDevice := config.OpenConnect.VPNDevice if _, _, err := execs.RunResolvectl(ctx, "revert", vpnDevice); err == nil { log.WithField("device", vpnDevice). Warn("VPNSetup cleaned up dns config") @@ -579,5 +582,5 @@ func Cleanup(ctx context.Context, vpnDevice string, splitrtConfig *splitrt.Confi log.WithField("device", vpnDevice). Warn("VPNSetup cleaned up vpn device") } - splitrt.Cleanup(ctx, splitrtConfig) + splitrt.Cleanup(ctx, config) } diff --git a/internal/vpnsetup/vpnsetup_test.go b/internal/vpnsetup/vpnsetup_test.go index bd577431..67b9525e 100644 --- a/internal/vpnsetup/vpnsetup_test.go +++ b/internal/vpnsetup/vpnsetup_test.go @@ -3,7 +3,6 @@ package vpnsetup import ( "context" "errors" - "net" "net/netip" "reflect" "strings" @@ -11,11 +10,10 @@ import ( "time" "github.com/telekom-mms/oc-daemon/internal/addrmon" + "github.com/telekom-mms/oc-daemon/internal/daemoncfg" "github.com/telekom-mms/oc-daemon/internal/devmon" "github.com/telekom-mms/oc-daemon/internal/dnsproxy" "github.com/telekom-mms/oc-daemon/internal/execs" - "github.com/telekom-mms/oc-daemon/internal/splitrt" - "github.com/telekom-mms/oc-daemon/pkg/vpnconfig" "github.com/vishvananda/netlink" ) @@ -25,13 +23,12 @@ func TestSetupVPNDevice(t *testing.T) { oldRunCmd := execs.RunCmd defer func() { execs.RunCmd = oldRunCmd }() - c := vpnconfig.New() - c.Device.Name = "tun0" - c.Device.MTU = 1300 - c.IPv4.Address = net.IPv4(192, 168, 0, 123) - c.IPv4.Netmask = net.IPv4Mask(255, 255, 255, 0) - c.IPv6.Address = net.ParseIP("2001::1") - c.IPv6.Netmask = net.CIDRMask(64, 128) + c := daemoncfg.NewConfig() + c.DNSProxy.Address = "127.0.0.1:4253" + c.VPNConfig.Device.Name = "tun0" + c.VPNConfig.Device.MTU = 1300 + c.VPNConfig.IPv4 = netip.MustParsePrefix("192.168.0.123/24") + c.VPNConfig.IPv6 = netip.MustParsePrefix("2001::1/64") // overwrite RunCmd want := []string{ @@ -86,8 +83,9 @@ func TestTeardownVPNDevice(t *testing.T) { oldRunCmd := execs.RunCmd defer func() { execs.RunCmd = oldRunCmd }() - c := vpnconfig.New() - c.Device.Name = "tun0" + c := daemoncfg.NewConfig() + c.DNSProxy.Address = "127.0.0.1:4253" + c.VPNConfig.Device.Name = "tun0" // overwrite RunCmd want := []string{ @@ -118,16 +116,16 @@ func TestVPNSetupSetupDNS(t *testing.T) { oldRunCmd := execs.RunCmd defer func() { execs.RunCmd = oldRunCmd }() - c := vpnconfig.New() - c.Device.Name = "tun0" - c.DNS.DefaultDomain = "mycompany.com" + c := daemoncfg.NewConfig() + c.VPNConfig.Device.Name = "tun0" + c.VPNConfig.DNS.DefaultDomain = "mycompany.com" got := []string{} execs.RunCmd = func(_ context.Context, _ string, _ string, arg ...string) ([]byte, []byte, error) { got = append(got, strings.Join(arg, " ")) return nil, nil, nil } - v := NewVPNSetup(dnsproxy.NewConfig(), splitrt.NewConfig()) + v := NewVPNSetup(dnsproxy.NewProxy(daemoncfg.NewDNSProxy())) v.setupDNS(context.Background(), c) want := []string{ @@ -160,8 +158,8 @@ func TestVPNSetupTeardownDNS(t *testing.T) { oldRunCmd := execs.RunCmd defer func() { execs.RunCmd = oldRunCmd }() - c := vpnconfig.New() - c.Device.Name = "tun0" + c := daemoncfg.NewConfig() + c.VPNConfig.Device.Name = "tun0" got := []string{} execs.RunCmd = func(_ context.Context, _ string, _ string, arg ...string) ([]byte, []byte, error) { @@ -169,7 +167,7 @@ func TestVPNSetupTeardownDNS(t *testing.T) { return nil, nil, nil } - v := NewVPNSetup(dnsproxy.NewConfig(), splitrt.NewConfig()) + v := NewVPNSetup(dnsproxy.NewProxy(daemoncfg.NewDNSProxy())) v.teardownDNS(context.Background(), c) want := []string{ @@ -196,7 +194,7 @@ func TestVPNSetupTeardownDNS(t *testing.T) { // TestVPNSetupCheckDNSProtocols tests checkDNSProtocols of VPNSetup. func TestVPNSetupCheckDNSProtocols(t *testing.T) { - v := NewVPNSetup(dnsproxy.NewConfig(), splitrt.NewConfig()) + v := NewVPNSetup(dnsproxy.NewProxy(daemoncfg.NewDNSProxy())) // test invalid for _, invalid := range [][]string{ @@ -221,48 +219,49 @@ func TestVPNSetupCheckDNSProtocols(t *testing.T) { // TestVPNSetupCheckDNSServers tests checkDNSServers of VPNSetup. func TestVPNSetupCheckDNSServers(t *testing.T) { - v := NewVPNSetup(dnsproxy.NewConfig(), splitrt.NewConfig()) + conf := daemoncfg.NewConfig() + v := NewVPNSetup(dnsproxy.NewProxy(conf.DNSProxy)) // test invalid for _, invalid := range [][]string{ {}, {"x", "y", "z"}, } { - if ok := v.checkDNSServers(invalid); ok { + if ok := v.checkDNSServers(conf, invalid); ok { t.Errorf("dns check should fail with %v", invalid) } } // test valid - proxy := []string{v.dnsProxyConf.Address} - if ok := v.checkDNSServers(proxy); !ok { + proxy := []string{conf.DNSProxy.Address} + if ok := v.checkDNSServers(conf, proxy); !ok { t.Errorf("dns check should not fail with %v", proxy) } } // TestVPNSetupCheckDNSDomain tests checkDNSDomain of VPNSetup. func TestVPNSetupCheckDNSDomain(t *testing.T) { - v := NewVPNSetup(dnsproxy.NewConfig(), splitrt.NewConfig()) - vpnconf := vpnconfig.New() - vpnconf.DNS.DefaultDomain = "test.example.com" + conf := daemoncfg.NewConfig() + conf.VPNConfig.DNS.DefaultDomain = "test.example.com" + v := NewVPNSetup(dnsproxy.NewProxy(conf.DNSProxy)) // test invalid for _, invalid := range [][]string{ {}, {"x", "y", "z"}, - {vpnconf.DNS.DefaultDomain}, + {conf.VPNConfig.DNS.DefaultDomain}, } { - if ok := v.checkDNSDomain(vpnconf, invalid); ok { + if ok := v.checkDNSDomain(conf, invalid); ok { t.Errorf("dns check should fail with %v", invalid) } } // test valid for _, valid := range [][]string{ - {"~.", vpnconf.DNS.DefaultDomain}, - {vpnconf.DNS.DefaultDomain, "~."}, + {"~.", conf.VPNConfig.DNS.DefaultDomain}, + {conf.VPNConfig.DNS.DefaultDomain, "~."}, } { - if ok := v.checkDNSDomain(vpnconf, valid); !ok { + if ok := v.checkDNSDomain(conf, valid); !ok { t.Errorf("dns check should not fail with %v", valid) } } @@ -274,20 +273,20 @@ func TestVPNSetupEnsureDNS(t *testing.T) { oldRunCmd := execs.RunCmd defer func() { execs.RunCmd = oldRunCmd }() - v := NewVPNSetup(dnsproxy.NewConfig(), splitrt.NewConfig()) - ctx := context.Background() - vpnconf := vpnconfig.New() - // test settings - v.dnsProxyConf.Address = "127.0.0.1:4253" - vpnconf.DNS.DefaultDomain = "test.example.com" + conf := daemoncfg.NewConfig() + conf.DNSProxy.Address = "127.0.0.1:4253" + conf.VPNConfig.DNS.DefaultDomain = "test.example.com" + + v := NewVPNSetup(dnsproxy.NewProxy(conf.DNSProxy)) + ctx := context.Background() // test resolvectl error execs.RunCmd = func(context.Context, string, string, ...string) ([]byte, []byte, error) { return nil, nil, errors.New("test error") } - if ok := v.ensureDNS(ctx, vpnconf); ok { + if ok := v.ensureDNS(ctx, conf); ok { t.Errorf("ensure dns should fail with resolvectl error") } @@ -304,7 +303,7 @@ func TestVPNSetupEnsureDNS(t *testing.T) { return invalid, nil, nil } - if ok := v.ensureDNS(ctx, vpnconf); ok { + if ok := v.ensureDNS(ctx, conf); ok { t.Errorf("ensure dns should fail with %v", invalid) } } @@ -319,7 +318,7 @@ func TestVPNSetupEnsureDNS(t *testing.T) { return valid, nil, nil } - if ok := v.ensureDNS(ctx, vpnconf); !ok { + if ok := v.ensureDNS(ctx, conf); !ok { t.Errorf("ensure dns should not fail with %v", valid) } } @@ -327,7 +326,7 @@ func TestVPNSetupEnsureDNS(t *testing.T) { // TestVPNSetupStartStop tests Start and Stop of VPNSetup. func TestVPNSetupStartStop(_ *testing.T) { - v := NewVPNSetup(dnsproxy.NewConfig(), splitrt.NewConfig()) + v := NewVPNSetup(dnsproxy.NewProxy(daemoncfg.NewDNSProxy())) v.Start() v.Stop() } @@ -354,12 +353,12 @@ func TestVPNSetupSetupTeardown(_ *testing.T) { defer func() { devmon.RegisterLinkUpdates = oldRegisterLinkUpdates }() // start vpn setup, prepare config - v := NewVPNSetup(dnsproxy.NewConfig(), splitrt.NewConfig()) + conf := daemoncfg.NewConfig() + v := NewVPNSetup(dnsproxy.NewProxy(conf.DNSProxy)) v.Start() - vpnconf := vpnconfig.New() // setup config - v.Setup(vpnconf) + v.Setup(conf) // send dns report while config is active report := dnsproxy.NewReport("example.com", netip.Addr{}, 300) @@ -369,7 +368,7 @@ func TestVPNSetupSetupTeardown(_ *testing.T) { time.Sleep(time.Second * 2) // teardown config - v.Teardown(vpnconf) + v.Teardown(conf) // send dns report while config is not active v.dnsProxy.Reports() <- dnsproxy.NewReport("example.com", netip.Addr{}, 300) @@ -400,7 +399,8 @@ func TestVPNSetupGetState(t *testing.T) { defer func() { devmon.RegisterLinkUpdates = oldRegisterLinkUpdates }() // start vpn setup - v := NewVPNSetup(dnsproxy.NewConfig(), splitrt.NewConfig()) + conf := daemoncfg.NewConfig() + v := NewVPNSetup(dnsproxy.NewProxy(conf.DNSProxy)) v.Start() // without vpn config @@ -412,8 +412,7 @@ func TestVPNSetupGetState(t *testing.T) { } // with vpn config - vpnconf := vpnconfig.New() - v.Setup(vpnconf) + v.Setup(conf) got = v.GetState() if got == nil || @@ -423,20 +422,17 @@ func TestVPNSetupGetState(t *testing.T) { } // teardown config - v.Teardown(vpnconf) + v.Teardown(conf) v.Stop() } // TestNewVPNSetup tests NewVPNSetup. func TestNewVPNSetup(t *testing.T) { - dnsConfig := dnsproxy.NewConfig() - splitrtConfig := splitrt.NewConfig() - - v := NewVPNSetup(dnsConfig, splitrtConfig) + dnsProxy := dnsproxy.NewProxy(daemoncfg.NewDNSProxy()) + v := NewVPNSetup(dnsProxy) if v == nil || - v.dnsProxyConf != dnsConfig || - v.splitrtConf != splitrtConfig || + v.dnsProxy != dnsProxy || v.cmds == nil || v.done == nil || v.closed == nil { @@ -455,7 +451,9 @@ func TestCleanup(t *testing.T) { got = append(got, cmd+" "+strings.Join(arg, " ")+" "+s) return nil, nil, nil } - Cleanup(context.Background(), "tun0", splitrt.NewConfig()) + cfg := daemoncfg.NewConfig() + cfg.OpenConnect.VPNDevice = "tun0" + Cleanup(context.Background(), cfg) want := []string{ "resolvectl revert tun0", "ip link delete tun0", diff --git a/tools/dnsproxy/main.go b/tools/dnsproxy/main.go index 2c59ae99..18e7bc0c 100644 --- a/tools/dnsproxy/main.go +++ b/tools/dnsproxy/main.go @@ -8,6 +8,7 @@ import ( "strings" log "github.com/sirupsen/logrus" + "github.com/telekom-mms/oc-daemon/internal/daemoncfg" "github.com/telekom-mms/oc-daemon/internal/dnsproxy" ) @@ -71,7 +72,7 @@ func parseCommandLine() { func main() { log.SetLevel(log.DebugLevel) parseCommandLine() - c := dnsproxy.NewConfig() + c := daemoncfg.NewDNSProxy() c.Address = address c.ListenUDP = true c.ListenTCP = true diff --git a/tools/ocrunner/main.go b/tools/ocrunner/main.go index 25d939ae..8c19d269 100644 --- a/tools/ocrunner/main.go +++ b/tools/ocrunner/main.go @@ -8,6 +8,7 @@ import ( "time" log "github.com/sirupsen/logrus" + "github.com/telekom-mms/oc-daemon/internal/daemoncfg" "github.com/telekom-mms/oc-daemon/internal/ocrunner" "github.com/telekom-mms/oc-daemon/pkg/client" ) @@ -57,10 +58,7 @@ func main() { } // connect client - ocrConf := ocrunner.NewConfig() - ocrConf.XMLProfile = *profile - ocrConf.VPNCScript = *script - c := ocrunner.NewConnect(ocrConf) + c := ocrunner.NewConnect() done := make(chan struct{}) go c.Start() go func() { @@ -73,7 +71,11 @@ func main() { done <- struct{}{} }() if *connect { - c.Connect(a.GetLogin(), []string{}) + dconf := daemoncfg.NewConfig() + dconf.OpenConnect.XMLProfile = *profile + dconf.OpenConnect.VPNCScript = *script + dconf.LoginInfo = a.GetLogin() + c.Connect(dconf, []string{}) } // disconnect client