diff --git a/internal/cmdtmpl/command.go b/internal/cmdtmpl/command.go index 1547d8c..3ebdb20 100644 --- a/internal/cmdtmpl/command.go +++ b/internal/cmdtmpl/command.go @@ -326,6 +326,9 @@ table inet oc-daemon-routing { {{end -}} ` +// LoadedTemplates are the templates loaded from file. +var LoadedTemplates string + // defaultTemplate is the parsed default template for the command lists. var defaultTemplate = template.Must(template.New("Template").Parse(DefaultTemplate)) @@ -546,6 +549,25 @@ add element inet oc-daemon-routing excludes4 { {{.}} } }, } +// LoadTemplates loads the templates from file. +func LoadTemplates(file string) error { + // read file contents + f, err := os.ReadFile(file) + if err != nil { + return err + } + + // parse file contents + s := string(f) + t, err := template.New("Template").Parse(s) + if err == nil { + // save loaded templates + LoadedTemplates = s + defaultTemplate = t + } + return err +} + // LoadCommandLists loads the command lists from file. func LoadCommandLists(file string) error { // read file contents diff --git a/internal/cmdtmpl/command_test.go b/internal/cmdtmpl/command_test.go index ff0ad91..77acd38 100644 --- a/internal/cmdtmpl/command_test.go +++ b/internal/cmdtmpl/command_test.go @@ -20,6 +20,46 @@ func TestExecuteTemplateParseError(t *testing.T) { } } +// TestLoadTemplates tests LoadTemplates. +func TestLoadTemplates(t *testing.T) { + dir := t.TempDir() + + // not existing file + if err := LoadTemplates(filepath.Join(dir, "does not exist")); err == nil { + t.Errorf("not existing file should return error") + } + + // invalid template file + f := filepath.Join(dir, "command-lists.tmpl") + if err := os.WriteFile(f, []byte("{{ invalid template"), 0600); err != nil { + t.Fatal(err) + } + if err := LoadTemplates(f); err == nil { + t.Errorf("invalid template file should return error") + } + + // valid template file + oldTemplates := defaultTemplate + defer func() { + // cleanup after test + defaultTemplate = oldTemplates + LoadedTemplates = "" + }() + valid := "valid template" + if err := os.WriteFile(f, []byte(valid), 0600); err != nil { + t.Fatal(err) + } + if err := LoadTemplates(f); err != nil { + t.Errorf("valid template file returned error: %s", err) + } + if defaultTemplate == oldTemplates { + t.Error("load did not update templates") + } + if LoadedTemplates != valid { + t.Errorf("unexpected loaded templates: %s", LoadedTemplates) + } +} + // TestLoadCommandLists tests LoadCommandLists. func TestLoadCommandLists(t *testing.T) { dir := t.TempDir() diff --git a/internal/daemon/cmd.go b/internal/daemon/cmd.go index 94af20a..faf6d9c 100644 --- a/internal/daemon/cmd.go +++ b/internal/daemon/cmd.go @@ -90,6 +90,12 @@ func run(args []string) error { } // load command lists + if err := cmdtmpl.LoadTemplates(config.CommandLists.TemplatesFile); err != nil { + log.WithError(err).Debug("Daemon did not load command templates, using defaults") + } else { + log.WithField("file", config.CommandLists.TemplatesFile). + Info("Daemon loaded command templates from file") + } if err := cmdtmpl.LoadCommandLists(config.CommandLists.ListsFile); err != nil { log.WithError(err).Debug("Daemon did not load command lists, using defaults") } else { diff --git a/internal/daemon/cmd_test.go b/internal/daemon/cmd_test.go index ca94f6f..edbd448 100644 --- a/internal/daemon/cmd_test.go +++ b/internal/daemon/cmd_test.go @@ -84,6 +84,30 @@ func TestRun(t *testing.T) { t.Errorf("start should return error") } + // not existing command list templates file + tmplsFile := filepath.Join(dir, "tmpls") + tmplsConf := fmt.Sprintf(`{ + "CommandLists": { + "TemplatesFile": "%s" + } +} + `, tmplsFile) + if err := os.WriteFile(cfg, []byte(tmplsConf), 0600); err != nil { + t.Fatal(err) + } + + if err := run([]string{"test", "-verbose", "-config", cfg}); err == nil { + t.Errorf("start should return error") + } + + // empty command list templates file + if err := os.WriteFile(tmplsFile, []byte(""), 0600); err != nil { + t.Fatal(err) + } + if err := run([]string{"test", "-verbose", "-config", cfg}); err == nil { + t.Errorf("start should return error") + } + // not existing command lists file cmdListsFile := filepath.Join(dir, "cmd-lists") cmdListsConf := fmt.Sprintf(`{ diff --git a/internal/daemon/daemon.go b/internal/daemon/daemon.go index 2132fa4..5c5e530 100644 --- a/internal/daemon/daemon.go +++ b/internal/daemon/daemon.go @@ -528,18 +528,20 @@ func (d *Daemon) handleClientRequest(request *api.Request) { func (d *Daemon) dumpState() string { // define state type type State struct { - DaemonConfig *daemoncfg.Config - TrafficPolicing *trafpol.State - VPNSetup *vpnsetup.State - CommandLists map[string]*cmdtmpl.CommandList + DaemonConfig *daemoncfg.Config + TrafficPolicing *trafpol.State + VPNSetup *vpnsetup.State + CommandLists map[string]*cmdtmpl.CommandList + CommandTemplates string } // collect internal state c := d.config.Copy() c.LoginInfo.Cookie = "HIDDEN" // hide cookie state := State{ - DaemonConfig: c, - CommandLists: cmdtmpl.CommandLists, + DaemonConfig: c, + CommandLists: cmdtmpl.CommandLists, + CommandTemplates: cmdtmpl.LoadedTemplates, } if d.trafpol != nil { state.TrafficPolicing = d.trafpol.GetState() diff --git a/internal/daemoncfg/config.go b/internal/daemoncfg/config.go index 6fbcff1..7d3834c 100644 --- a/internal/daemoncfg/config.go +++ b/internal/daemoncfg/config.go @@ -529,12 +529,14 @@ func NewTrafficPolicing() *TrafficPolicing { // Command lists default values var ( - CommandListsListsFile = configDir + "/command-lists.json" + CommandListsListsFile = configDir + "/command-lists.json" + CommandListsTemplatesFile = configDir + "/command-lists.tmpl" ) // CommandLists is the command lists configuration. type CommandLists struct { - ListsFile string + ListsFile string + TemplatesFile string } // Copy returns a copy of the command lists configuration. @@ -546,7 +548,8 @@ func (c *CommandLists) Copy() *CommandLists { // Valid returns whether the command lists configuration is valid. func (c *CommandLists) Valid() bool { if c == nil || - c.ListsFile == "" { + c.ListsFile == "" || + c.TemplatesFile == "" { return false } @@ -556,7 +559,8 @@ func (c *CommandLists) Valid() bool { // NewCommandLists returns a new command lists configuration. func NewCommandLists() *CommandLists { return &CommandLists{ - ListsFile: CommandListsListsFile, + ListsFile: CommandListsListsFile, + TemplatesFile: CommandListsTemplatesFile, } }