diff --git a/internal/cmdtmpl/command.go b/internal/cmdtmpl/command.go index 78e2397..4584801 100644 --- a/internal/cmdtmpl/command.go +++ b/internal/cmdtmpl/command.go @@ -5,10 +5,9 @@ import ( "bytes" "context" "fmt" + "os/exec" "strings" "text/template" - - "github.com/telekom-mms/oc-daemon/internal/execs" ) // Command consists of a command line to be executed and an optional Stdin to @@ -537,9 +536,24 @@ type Cmd struct { Stdin string } +// RunCmd runs the cmd with args and sets stdin to s, returns stdout and stderr. +var RunCmd = func(ctx context.Context, cmd string, s string, arg ...string) (stdout, stderr []byte, err error) { + c := exec.CommandContext(ctx, cmd, arg...) + if s != "" { + c.Stdin = bytes.NewBufferString(s) + } + var outbuf, errbuf bytes.Buffer + c.Stdout = &outbuf + c.Stderr = &errbuf + err = c.Run() + stdout = outbuf.Bytes() + stderr = errbuf.Bytes() + return +} + // Run runs the command. func (c *Cmd) Run(ctx context.Context) (stdout, stderr []byte, err error) { - return execs.RunCmd(ctx, c.Cmd, c.Stdin, c.Args...) + return RunCmd(ctx, c.Cmd, c.Stdin, c.Args...) } // GetCmds returns a list of Cmds ready to run. diff --git a/internal/cmdtmpl/command_test.go b/internal/cmdtmpl/command_test.go index 0be351c..0969a8c 100644 --- a/internal/cmdtmpl/command_test.go +++ b/internal/cmdtmpl/command_test.go @@ -2,6 +2,7 @@ package cmdtmpl import ( "context" + "path/filepath" "testing" "text/template" @@ -57,6 +58,39 @@ func TestGetCommandList(t *testing.T) { } } +// TestRunCmd tests RunCmd. +func TestRunCmd(t *testing.T) { + ctx := context.Background() + + // test not existing + dir := t.TempDir() + if _, _, err := RunCmd(ctx, filepath.Join(dir, "does/not/exist"), ""); err == nil { + t.Errorf("running not existing command should fail: %v", err) + } + + // test existing + if _, _, err := RunCmd(ctx, "echo", "", "this", "is", "a", "test"); err != nil { + t.Errorf("running echo failed: %v", err) + } + + // test with stdin + if _, _, err := RunCmd(ctx, "echo", "this is a test"); err != nil { + t.Errorf("running echo failed: %v", err) + } + + // test stdout + stdout, stderr, err := RunCmd(ctx, "cat", "this is a test") + if err != nil || string(stdout) != "this is a test" { + t.Errorf("running echo failed: %s, %s, %v", stdout, stderr, err) + } + + // test stderr and error + stdout, stderr, err = RunCmd(ctx, "cat", "", "does/not/exist") + if err == nil || string(stderr) != "cat: does/not/exist: No such file or directory\n" { + t.Errorf("running echo failed: %s, %s, %v", stdout, stderr, err) + } +} + // TestCmdRun tests Run of Cmd. func TestCmdRun(t *testing.T) { cmd := &Cmd{ diff --git a/internal/daemon/daemon.go b/internal/daemon/daemon.go index 0ea23fd..b7da133 100644 --- a/internal/daemon/daemon.go +++ b/internal/daemon/daemon.go @@ -18,7 +18,6 @@ import ( "github.com/telekom-mms/oc-daemon/internal/daemoncfg" "github.com/telekom-mms/oc-daemon/internal/dbusapi" "github.com/telekom-mms/oc-daemon/internal/dnsproxy" - "github.com/telekom-mms/oc-daemon/internal/execs" "github.com/telekom-mms/oc-daemon/internal/ocrunner" "github.com/telekom-mms/oc-daemon/internal/profilemon" "github.com/telekom-mms/oc-daemon/internal/sleepmon" @@ -941,9 +940,6 @@ func (d *Daemon) Start() error { // create context ctx := context.Background() - // set executables - execs.SetExecutables(d.config.Executables) - // cleanup after a failed shutdown d.cleanup(ctx) diff --git a/internal/execs/execs.go b/internal/execs/execs.go deleted file mode 100644 index b13f2d8..0000000 --- a/internal/execs/execs.go +++ /dev/null @@ -1,41 +0,0 @@ -// Package execs contains external executables. -package execs - -import ( - "bytes" - "context" - "os/exec" - - "github.com/telekom-mms/oc-daemon/internal/daemoncfg" -) - -// executables. -var ( - ip = daemoncfg.ExecutablesIP - sysctl = daemoncfg.ExecutablesSysctl - nft = daemoncfg.ExecutablesNft - resolvectl = daemoncfg.ExecutablesResolvectl -) - -// RunCmd runs the cmd with args and sets stdin to s, returns stdout and stderr. -var RunCmd = func(ctx context.Context, cmd string, s string, arg ...string) (stdout, stderr []byte, err error) { - c := exec.CommandContext(ctx, cmd, arg...) - if s != "" { - c.Stdin = bytes.NewBufferString(s) - } - var outbuf, errbuf bytes.Buffer - c.Stdout = &outbuf - c.Stderr = &errbuf - err = c.Run() - stdout = outbuf.Bytes() - stderr = errbuf.Bytes() - return -} - -// SetExecutables configures all executables from config. -func SetExecutables(config *daemoncfg.Executables) { - ip = config.IP - sysctl = config.Sysctl - nft = config.Nft - resolvectl = config.Resolvectl -} diff --git a/internal/execs/execs_test.go b/internal/execs/execs_test.go deleted file mode 100644 index 07d729e..0000000 --- a/internal/execs/execs_test.go +++ /dev/null @@ -1,63 +0,0 @@ -package execs - -import ( - "context" - "path/filepath" - "testing" - - "github.com/telekom-mms/oc-daemon/internal/daemoncfg" -) - -// TestRunCmd tests RunCmd. -func TestRunCmd(t *testing.T) { - ctx := context.Background() - - // test not existing - dir := t.TempDir() - if _, _, err := RunCmd(ctx, filepath.Join(dir, "does/not/exist"), ""); err == nil { - t.Errorf("running not existing command should fail: %v", err) - } - - // test existing - if _, _, err := RunCmd(ctx, "echo", "", "this", "is", "a", "test"); err != nil { - t.Errorf("running echo failed: %v", err) - } - - // test with stdin - if _, _, err := RunCmd(ctx, "echo", "this is a test"); err != nil { - t.Errorf("running echo failed: %v", err) - } - - // test stdout - stdout, stderr, err := RunCmd(ctx, "cat", "this is a test") - if err != nil || string(stdout) != "this is a test" { - t.Errorf("running echo failed: %s, %s, %v", stdout, stderr, err) - } - - // test stderr and error - stdout, stderr, err = RunCmd(ctx, "cat", "", "does/not/exist") - if err == nil || string(stderr) != "cat: does/not/exist: No such file or directory\n" { - t.Errorf("running echo failed: %s, %s, %v", stdout, stderr, err) - } -} - -// TestSetExecutables tests SetExecutables. -func TestSetExecutables(t *testing.T) { - old := daemoncfg.NewExecutables() - defer SetExecutables(old) - - config := &daemoncfg.Executables{ - IP: "/test/ip", - Sysctl: "/test/sysctl", - Nft: "/test/nft", - Resolvectl: "/test/resolvectl", - } - SetExecutables(config) - if ip != config.IP || - sysctl != config.Sysctl || - nft != config.Nft || - resolvectl != config.Resolvectl { - // executables not set properly - t.Errorf("executables incorrect") - } -} diff --git a/internal/trafpol/filter_test.go b/internal/trafpol/filter_test.go index 58d7569..ca48d18 100644 --- a/internal/trafpol/filter_test.go +++ b/internal/trafpol/filter_test.go @@ -6,17 +6,17 @@ import ( "net/netip" "testing" + "github.com/telekom-mms/oc-daemon/internal/cmdtmpl" "github.com/telekom-mms/oc-daemon/internal/daemoncfg" - "github.com/telekom-mms/oc-daemon/internal/execs" ) // TestFilterFunctionsErrors tests filter functions, errors. func TestFilterFunctionsErrors(_ *testing.T) { - oldRunCmd := execs.RunCmd - execs.RunCmd = func(context.Context, string, string, ...string) ([]byte, []byte, error) { + oldRunCmd := cmdtmpl.RunCmd + cmdtmpl.RunCmd = func(context.Context, string, string, ...string) ([]byte, []byte, error) { return nil, nil, errors.New("test error") } - defer func() { execs.RunCmd = oldRunCmd }() + defer func() { cmdtmpl.RunCmd = oldRunCmd }() ctx := context.Background() diff --git a/internal/trafpol/trafpol_test.go b/internal/trafpol/trafpol_test.go index dd013e0..ccac133 100644 --- a/internal/trafpol/trafpol_test.go +++ b/internal/trafpol/trafpol_test.go @@ -10,10 +10,10 @@ import ( "sync" "testing" + "github.com/telekom-mms/oc-daemon/internal/cmdtmpl" "github.com/telekom-mms/oc-daemon/internal/cpd" "github.com/telekom-mms/oc-daemon/internal/daemoncfg" "github.com/telekom-mms/oc-daemon/internal/devmon" - "github.com/telekom-mms/oc-daemon/internal/execs" "github.com/vishvananda/netlink" ) @@ -56,14 +56,16 @@ func TestTrafPolHandleCPDReport(t *testing.T) { var nftMutex sync.Mutex nftCmds := []string{} - oldRunCmd := execs.RunCmd - execs.RunCmd = func(_ context.Context, cmd string, stdin string, args ...string) ([]byte, []byte, error) { + oldRunCmd := cmdtmpl.RunCmd + cmdtmpl.RunCmd = func(_ context.Context, cmd string, stdin string, + args ...string) ([]byte, []byte, error) { + nftMutex.Lock() defer nftMutex.Unlock() nftCmds = append(nftCmds, cmd+" "+strings.Join(args, " ")+" "+stdin) return nil, nil, nil } - defer func() { execs.RunCmd = oldRunCmd }() + defer func() { cmdtmpl.RunCmd = oldRunCmd }() getNftCmds := func() []string { nftMutex.Lock() @@ -303,10 +305,14 @@ func TestCleanup(t *testing.T) { "nft -f - delete table inet oc-daemon-filter", } got := []string{} - execs.RunCmd = func(_ context.Context, cmd string, _ string, args ...string) ([]byte, []byte, error) { + + oldRunCmd := cmdtmpl.RunCmd + cmdtmpl.RunCmd = func(_ context.Context, cmd string, _ string, args ...string) ([]byte, []byte, error) { got = append(got, cmd+" "+strings.Join(args, " ")) return nil, nil, nil } + defer func() { cmdtmpl.RunCmd = oldRunCmd }() + Cleanup(context.Background(), daemoncfg.NewConfig()) if !reflect.DeepEqual(got, want) { t.Errorf("got %v, want %v", got, want) diff --git a/internal/vpnsetup/vpnsetup_test.go b/internal/vpnsetup/vpnsetup_test.go index b0e9a2a..775980f 100644 --- a/internal/vpnsetup/vpnsetup_test.go +++ b/internal/vpnsetup/vpnsetup_test.go @@ -10,10 +10,10 @@ import ( "time" "github.com/telekom-mms/oc-daemon/internal/addrmon" + "github.com/telekom-mms/oc-daemon/internal/cmdtmpl" "github.com/telekom-mms/oc-daemon/internal/daemoncfg" "github.com/telekom-mms/oc-daemon/internal/devmon" "github.com/telekom-mms/oc-daemon/internal/dnsproxy" - "github.com/telekom-mms/oc-daemon/internal/execs" "github.com/vishvananda/netlink" ) @@ -95,8 +95,8 @@ func TestVPNSetupCheckDNSDomain(t *testing.T) { // TestVPNSetupEnsureDNS tests ensureDNS of VPNSetup. func TestVPNSetupEnsureDNS(t *testing.T) { // clean up after tests - oldRunCmd := execs.RunCmd - defer func() { execs.RunCmd = oldRunCmd }() + oldRunCmd := cmdtmpl.RunCmd + defer func() { cmdtmpl.RunCmd = oldRunCmd }() // test settings conf := daemoncfg.NewConfig() @@ -107,7 +107,7 @@ func TestVPNSetupEnsureDNS(t *testing.T) { ctx := context.Background() // test resolvectl error - execs.RunCmd = func(context.Context, string, string, ...string) ([]byte, []byte, error) { + cmdtmpl.RunCmd = func(context.Context, string, string, ...string) ([]byte, []byte, error) { return nil, nil, errors.New("test error") } @@ -124,7 +124,7 @@ func TestVPNSetupEnsureDNS(t *testing.T) { []byte("header\nProtocols: +DefaultRoute\nDNS Servers: other\nDNS Domain: test.example.com ~.\n"), []byte("header\nProtocols: +DefaultRoute\nDNS Servers: 127.0.0.1:4253\nDNS Domain: other\n"), } { - execs.RunCmd = func(context.Context, string, string, ...string) ([]byte, []byte, error) { + cmdtmpl.RunCmd = func(context.Context, string, string, ...string) ([]byte, []byte, error) { return invalid, nil, nil } @@ -139,7 +139,7 @@ func TestVPNSetupEnsureDNS(t *testing.T) { []byte("header\n Protocols: +DefaultRoute \nother\n " + "DNS Servers: 127.0.0.1:4253 \n DNS Domain: test.example.com ~.\nother\n"), } { - execs.RunCmd = func(context.Context, string, string, ...string) ([]byte, []byte, error) { + cmdtmpl.RunCmd = func(context.Context, string, string, ...string) ([]byte, []byte, error) { return valid, nil, nil } @@ -159,11 +159,11 @@ func TestVPNSetupStartStop(_ *testing.T) { // TestVPNSetupSetupTeardown tests Setup and Teardown of VPNSetup. func TestVPNSetupSetupTeardown(_ *testing.T) { // override functions - oldCmd := execs.RunCmd - execs.RunCmd = func(context.Context, string, string, ...string) ([]byte, []byte, error) { + oldCmd := cmdtmpl.RunCmd + cmdtmpl.RunCmd = func(context.Context, string, string, ...string) ([]byte, []byte, error) { return nil, nil, nil } - defer func() { execs.RunCmd = oldCmd }() + defer func() { cmdtmpl.RunCmd = oldCmd }() oldRegisterAddrUpdates := addrmon.RegisterAddrUpdates addrmon.RegisterAddrUpdates = func(*addrmon.AddrMon) (chan netlink.AddrUpdate, error) { @@ -205,11 +205,11 @@ func TestVPNSetupSetupTeardown(_ *testing.T) { // TestVPNSetupGetState tests GetState of VPNSetup. func TestVPNSetupGetState(t *testing.T) { // override functions - oldCmd := execs.RunCmd - execs.RunCmd = func(context.Context, string, string, ...string) ([]byte, []byte, error) { + oldCmd := cmdtmpl.RunCmd + cmdtmpl.RunCmd = func(context.Context, string, string, ...string) ([]byte, []byte, error) { return nil, nil, nil } - defer func() { execs.RunCmd = oldCmd }() + defer func() { cmdtmpl.RunCmd = oldCmd }() oldRegisterAddrUpdates := addrmon.RegisterAddrUpdates addrmon.RegisterAddrUpdates = func(*addrmon.AddrMon) (chan netlink.AddrUpdate, error) { @@ -268,7 +268,7 @@ func TestNewVPNSetup(t *testing.T) { // TestCleanup tests Cleanup. func TestCleanup(t *testing.T) { got := []string{} - execs.RunCmd = func(_ context.Context, cmd string, s string, arg ...string) ([]byte, []byte, error) { + cmdtmpl.RunCmd = func(_ context.Context, cmd string, s string, arg ...string) ([]byte, []byte, error) { if s == "" { got = append(got, cmd+" "+strings.Join(arg, " ")) return nil, nil, nil