From e6db72b4e3364a1506ef15896b47823eaf35282f Mon Sep 17 00:00:00 2001 From: Guillaume Belanger Date: Fri, 14 Jun 2024 15:30:33 -0400 Subject: [PATCH] fix: add key missing in config validation (#24) --- cmd/gocert/main.go | 10 ++- cmd/gocert/main_test.go | 2 +- internal/api/server.go | 69 ++----------------- internal/api/server_test.go | 53 ++------------ internal/config/config.go | 73 ++++++++++++++++++++ internal/config/config_test.go | 122 +++++++++++++++++++++++++++++++++ 6 files changed, 213 insertions(+), 116 deletions(-) create mode 100644 internal/config/config.go create mode 100644 internal/config/config_test.go diff --git a/cmd/gocert/main.go b/cmd/gocert/main.go index 5c43364..115b665 100644 --- a/cmd/gocert/main.go +++ b/cmd/gocert/main.go @@ -6,17 +6,21 @@ import ( "os" server "github.com/canonical/gocert/internal/api" + "github.com/canonical/gocert/internal/config" ) func main() { log.SetOutput(os.Stderr) configFilePtr := flag.String("config", "", "The config file to be provided to the server") flag.Parse() - if *configFilePtr == "" { - log.Fatalf("Providing a valid config file is required.") + log.Fatalf("Providing a config file is required.") + } + conf, err := config.Validate(*configFilePtr) + if err != nil { + log.Fatalf("Couldn't validate config file: %s", err) } - srv, err := server.NewServer(*configFilePtr) + srv, err := server.NewServer(conf.Port, conf.Cert, conf.Key, conf.DBPath, conf.PebbleNotificationsEnabled) if err != nil { log.Fatalf("Couldn't create server: %s", err) } diff --git a/cmd/gocert/main_test.go b/cmd/gocert/main_test.go index 58d37d7..3b62a2e 100644 --- a/cmd/gocert/main_test.go +++ b/cmd/gocert/main_test.go @@ -140,7 +140,7 @@ func TestGoCertFail(t *testing.T) { ConfigYAML string ExpectedOutput string }{ - {"flags not set", []string{}, validConfig, "Providing a valid config file is required."}, + {"flags not set", []string{}, validConfig, "Providing a config file is required."}, {"config file not valid", []string{"-config", "config.yaml"}, invalidConfig, "config file validation failed:"}, {"database not connectable", []string{"-config", "config.yaml"}, invalidDBConfig, "Couldn't connect to database:"}, } diff --git a/internal/api/server.go b/internal/api/server.go index 5910d76..ee2ca8b 100644 --- a/internal/api/server.go +++ b/internal/api/server.go @@ -7,72 +7,17 @@ import ( "fmt" "log" "net/http" - "os" "os/exec" "time" "github.com/canonical/gocert/internal/certdb" - "gopkg.in/yaml.v3" ) -type ConfigYAML struct { - KeyPath string `yaml:"key_path"` - CertPath string `yaml:"cert_path"` - DBPath string `yaml:"db_path"` - Port int `yaml:"port"` - Pebblenotificationsenabled bool `yaml:"pebble_notifications"` -} - -type Config struct { - Key []byte - Cert []byte - DBPath string - Port int - PebbleNotificationsEnabled bool -} - type Environment struct { DB *certdb.CertificateRequestsRepository SendPebbleNotifications bool } -// validateConfigFile opens and processes the given yaml file, and catches errors in the process -func validateConfigFile(filePath string) (Config, error) { - validationErr := errors.New("config file validation failed: ") - config := Config{} - configYaml, err := os.ReadFile(filePath) - if err != nil { - return config, errors.Join(validationErr, err) - } - c := ConfigYAML{} - if err := yaml.Unmarshal(configYaml, &c); err != nil { - return config, errors.Join(validationErr, err) - } - cert, err := os.ReadFile(c.CertPath) - if err != nil { - return config, errors.Join(validationErr, err) - } - key, err := os.ReadFile(c.KeyPath) - if err != nil { - return config, errors.Join(validationErr, err) - } - dbfile, err := os.OpenFile(c.DBPath, os.O_CREATE|os.O_RDONLY, 0644) - if err != nil { - return config, errors.Join(validationErr, err) - } - err = dbfile.Close() - if err != nil { - return config, errors.Join(validationErr, err) - } - - config.Cert = cert - config.Key = key - config.DBPath = c.DBPath - config.Port = c.Port - config.PebbleNotificationsEnabled = c.Pebblenotificationsenabled - return config, nil -} - func SendPebbleNotification(key, request_id string) error { cmd := exec.Command("pebble", "notify", key, fmt.Sprintf("request_id=%s", request_id)) if err := cmd.Run(); err != nil { @@ -82,27 +27,23 @@ func SendPebbleNotification(key, request_id string) error { } // NewServer creates an environment and an http server with handlers that Go can start listening to -func NewServer(configFile string) (*http.Server, error) { - config, err := validateConfigFile(configFile) - if err != nil { - return nil, err - } - serverCerts, err := tls.X509KeyPair(config.Cert, config.Key) +func NewServer(port int, cert []byte, key []byte, dbPath string, pebbleNotificationsEnabled bool) (*http.Server, error) { + serverCerts, err := tls.X509KeyPair(cert, key) if err != nil { return nil, err } - db, err := certdb.NewCertificateRequestsRepository(config.DBPath, "CertificateRequests") + db, err := certdb.NewCertificateRequestsRepository(dbPath, "CertificateRequests") if err != nil { log.Fatalf("Couldn't connect to database: %s", err) } env := &Environment{} env.DB = db - env.SendPebbleNotifications = config.PebbleNotificationsEnabled + env.SendPebbleNotifications = pebbleNotificationsEnabled router := NewGoCertRouter(env) s := &http.Server{ - Addr: fmt.Sprintf(":%d", config.Port), + Addr: fmt.Sprintf(":%d", port), ReadTimeout: 10 * time.Second, WriteTimeout: 10 * time.Second, diff --git a/internal/api/server_test.go b/internal/api/server_test.go index 8a6caea..535ac71 100644 --- a/internal/api/server_test.go +++ b/internal/api/server_test.go @@ -85,24 +85,6 @@ Q53tuiWQeoxNOjHiWstBPELxGbW6447JyVVbNYGUk+VFU7okzA6sRTJ/5Ysda4Sf auNQc2hruhr/2plhFUYoZHPzGz7d5zUGKymhCoS8BsFVtD0WDL4srdtY/W2Us7TD D7DC34n8CH9+avz9sCRwxpjxKnYW/BeyK0c4n9uZpjI8N4sOVqy6yWBUseww -----END RSA PRIVATE KEY-----` - validConfig = `key_path: "./key_test.pem" -cert_path: "./cert_test.pem" -db_path: "./certs.db" -port: 8000` - wrongCertConfig = `key_path: "./key_test.pem" -cert_path: "./cert_test_wrong.pem" -db_path: "./certs.db" -port: 8000` - wrongKeyConfig = `key_path: "./key_test_wrong.pem" -cert_path: "./cert_test.pem" -db_path: "./certs.db" -port: 8000` - invalidYAMLConfig = `wrong: fields -every: where` - invalidFileConfig = `key_path: "./nokeyfile.pem" -cert_path: "./nocertfile.pem" -db_path: "./certs.db" -port: 8000` ) func TestMain(m *testing.M) { @@ -131,11 +113,7 @@ func TestMain(m *testing.M) { } func TestNewServerSuccess(t *testing.T) { - writeConfigErr := os.WriteFile("config.yaml", []byte(validConfig), 0644) - if writeConfigErr != nil { - log.Fatalf("Error writing config file") - } - s, err := server.NewServer("config.yaml") + s, err := server.NewServer(8000, []byte(validCert), []byte(validPK), "certs.db", false) if err != nil { t.Errorf("Error occured: %s", err) } @@ -144,30 +122,9 @@ func TestNewServerSuccess(t *testing.T) { } } -func TestNewServerFail(t *testing.T) { - testCases := []struct { - desc string - config string - }{ - { - desc: "wrong certificate", - config: wrongCertConfig, - }, - { - desc: "wrong key", - config: wrongKeyConfig, - }, - } - for _, tC := range testCases { - writeConfigErr := os.WriteFile("config.yaml", []byte(tC.config), 0644) - if writeConfigErr != nil { - log.Fatalf("Error writing config file") - } - t.Run(tC.desc, func(t *testing.T) { - _, err := server.NewServer("config.yaml") - if err == nil { - t.Errorf("Expected error") - } - }) +func TestInvalidKeyFailure(t *testing.T) { + _, err := server.NewServer(8000, []byte(validCert), []byte{}, "certs.db", false) + if err == nil { + t.Errorf("No error was thrown for invalid key") } } diff --git a/internal/config/config.go b/internal/config/config.go new file mode 100644 index 0000000..0adc8a0 --- /dev/null +++ b/internal/config/config.go @@ -0,0 +1,73 @@ +package config + +import ( + "errors" + "os" + + "gopkg.in/yaml.v3" +) + +type ConfigYAML struct { + KeyPath string `yaml:"key_path"` + CertPath string `yaml:"cert_path"` + DBPath string `yaml:"db_path"` + Port int `yaml:"port"` + Pebblenotificationsenabled bool `yaml:"pebble_notifications"` +} + +type Config struct { + Key []byte + Cert []byte + DBPath string + Port int + PebbleNotificationsEnabled bool +} + +// Validate opens and processes the given yaml file, and catches errors in the process +func Validate(filePath string) (Config, error) { + validationErr := errors.New("config file validation failed: ") + config := Config{} + configYaml, err := os.ReadFile(filePath) + if err != nil { + return config, errors.Join(validationErr, err) + } + c := ConfigYAML{} + if err := yaml.Unmarshal(configYaml, &c); err != nil { + return config, errors.Join(validationErr, err) + } + if c.CertPath == "" { + return config, errors.Join(validationErr, errors.New("`cert_path` is empty")) + } + cert, err := os.ReadFile(c.CertPath) + if err != nil { + return config, errors.Join(validationErr, err) + } + if c.KeyPath == "" { + return config, errors.Join(validationErr, errors.New("`key_path` is empty")) + } + key, err := os.ReadFile(c.KeyPath) + if err != nil { + return config, errors.Join(validationErr, err) + } + if c.DBPath == "" { + return config, errors.Join(validationErr, errors.New("`db_path` is empty")) + } + dbfile, err := os.OpenFile(c.DBPath, os.O_CREATE|os.O_RDONLY, 0644) + if err != nil { + return config, errors.Join(validationErr, err) + } + err = dbfile.Close() + if err != nil { + return config, errors.Join(validationErr, err) + } + if c.Port == 0 { + return config, errors.Join(validationErr, errors.New("`port` is empty")) + } + + config.Cert = cert + config.Key = key + config.DBPath = c.DBPath + config.Port = c.Port + config.PebbleNotificationsEnabled = c.Pebblenotificationsenabled + return config, nil +} diff --git a/internal/config/config_test.go b/internal/config/config_test.go new file mode 100644 index 0000000..0aecd23 --- /dev/null +++ b/internal/config/config_test.go @@ -0,0 +1,122 @@ +package config_test + +import ( + "log" + "os" + "strings" + "testing" + + "github.com/canonical/gocert/internal/config" +) + +const ( + validCert = `Whatever cert content` + validPK = `Whatever key content` + validConfig = `key_path: "./key_test.pem" +cert_path: "./cert_test.pem" +db_path: "./certs.db" +port: 8000` + noCertPathConfig = `key_path: "./key_test.pem" +db_path: "./certs.db" +port: 8000` + noKeyPathConfig = `cert_path: "./cert_test.pem" +db_path: "./certs.db" +port: 8000` + noDBPathConfig = `key_path: "./key_test.pem" +cert_path: "./cert_test.pem" +port: 8000` + wrongCertPathConfig = `key_path: "./key_test.pem" +cert_path: "./cert_test_wrong.pem" +db_path: "./certs.db" +port: 8000` + wrongKeyPathConfig = `key_path: "./key_test_wrong.pem" +cert_path: "./cert_test.pem" +db_path: "./certs.db" +port: 8000` + invalidYAMLConfig = `just_an=invalid +yaml.here` +) + +func TestMain(m *testing.M) { + testfolder, err := os.MkdirTemp("./", "configtest-") + if err != nil { + log.Fatalf("couldn't create temp directory") + } + writeCertErr := os.WriteFile(testfolder+"/cert_test.pem", []byte(validCert), 0644) + writeKeyErr := os.WriteFile(testfolder+"/key_test.pem", []byte(validPK), 0644) + if writeCertErr != nil || writeKeyErr != nil { + log.Fatalf("couldn't create temp testing file") + } + if err := os.Chdir(testfolder); err != nil { + log.Fatalf("couldn't enter testing directory") + } + + exitval := m.Run() + + if err := os.Chdir("../"); err != nil { + log.Fatalf("couldn't change back to parent directory") + } + if err := os.RemoveAll(testfolder); err != nil { + log.Fatalf("couldn't remove temp testing directory") + } + os.Exit(exitval) +} + +func TestGoodConfigSuccess(t *testing.T) { + writeConfigErr := os.WriteFile("config.yaml", []byte(validConfig), 0644) + if writeConfigErr != nil { + t.Fatalf("Error writing config file") + } + conf, err := config.Validate("config.yaml") + if err != nil { + t.Fatalf("Error occured: %s", err) + } + + if conf.Cert == nil { + t.Fatalf("No certificates were configured for server") + } + + if conf.Key == nil { + t.Fatalf("No key was configured for server") + } + + if conf.DBPath == "" { + t.Fatalf("No database path was configured for server") + } + + if conf.Port != 8000 { + t.Fatalf("Port was not configured correctly") + } + +} + +func TestBadConfigFail(t *testing.T) { + cases := []struct { + Name string + ConfigYAML string + ExpectedError string + }{ + {"no cert path", noCertPathConfig, "`cert_path` is empty"}, + {"no key path", noKeyPathConfig, "`key_path` is empty"}, + {"no db path", noDBPathConfig, "`db_path` is empty"}, + {"wrong cert path", wrongCertPathConfig, "no such file or directory"}, + {"wrong key path", wrongKeyPathConfig, "no such file or directory"}, + {"invalid yaml", invalidYAMLConfig, "unmarshal errors"}, + } + + for _, tc := range cases { + writeConfigErr := os.WriteFile("config.yaml", []byte(tc.ConfigYAML), 0644) + if writeConfigErr != nil { + t.Errorf("Failed writing config file") + } + _, err := config.Validate("config.yaml") + if err == nil { + t.Errorf("Expected error, got nil") + } + + if !strings.Contains(err.Error(), tc.ExpectedError) { + t.Errorf("Expected error not found: %s", err) + } + + } +}