Skip to content

Commit

Permalink
chore: generalise database management
Browse files Browse the repository at this point in the history
Signed-off-by: guillaume <[email protected]>
  • Loading branch information
gruyaume committed Sep 9, 2024
1 parent 228ac1d commit 15e57ce
Show file tree
Hide file tree
Showing 13 changed files with 108 additions and 111 deletions.
6 changes: 3 additions & 3 deletions cmd/notary/main_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -111,8 +111,8 @@ func TestMain(m *testing.M) {
if err != nil {
log.Fatalf("couldn't create temp directory")
}
writeCertErr := os.WriteFile(testfolder+"/cert_test.pem", []byte(validCert), 0644)
writeKeyErr := os.WriteFile(testfolder+"/key_test.pem", []byte(validPK), 0644)
writeCertErr := os.WriteFile(testfolder+"/cert_test.pem", []byte(validCert), 0o644)
writeKeyErr := os.WriteFile(testfolder+"/key_test.pem", []byte(validPK), 0o644)
if writeCertErr != nil || writeKeyErr != nil {
log.Fatalf("couldn't create temp testing file")
}
Expand Down Expand Up @@ -145,7 +145,7 @@ func TestNotaryFail(t *testing.T) {
{"database not connectable", []string{"-config", "config.yaml"}, invalidDBConfig, "Couldn't connect to database:"},
}
for _, tc := range cases {
writeConfigErr := os.WriteFile("config.yaml", []byte(tc.ConfigYAML), 0644)
writeConfigErr := os.WriteFile("config.yaml", []byte(tc.ConfigYAML), 0o644)
if writeConfigErr != nil {
t.Errorf("Failed writing config file")
}
Expand Down
34 changes: 17 additions & 17 deletions internal/api/handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ import (
"strings"
"time"

"github.com/canonical/notary/internal/certdb"
"github.com/canonical/notary/internal/db"
metrics "github.com/canonical/notary/internal/metrics"
"github.com/canonical/notary/ui"
"github.com/golang-jwt/jwt"
Expand Down Expand Up @@ -159,7 +159,7 @@ func GetCertificateRequest(env *Environment) http.HandlerFunc {
id := r.PathValue("id")
cert, err := env.DB.RetrieveCSR(id)
if err != nil {
if errors.Is(err, certdb.ErrIdNotFound) {
if errors.Is(err, db.ErrIdNotFound) {
logErrorAndWriteResponse(err.Error(), http.StatusNotFound, w)
return
}
Expand All @@ -184,7 +184,7 @@ func DeleteCertificateRequest(env *Environment) http.HandlerFunc {
id := r.PathValue("id")
insertId, err := env.DB.DeleteCSR(id)
if err != nil {
if errors.Is(err, certdb.ErrIdNotFound) {
if errors.Is(err, db.ErrIdNotFound) {
logErrorAndWriteResponse(err.Error(), http.StatusNotFound, w)
return
}
Expand All @@ -210,7 +210,7 @@ func PostCertificate(env *Environment) http.HandlerFunc {
id := r.PathValue("id")
insertId, err := env.DB.UpdateCSR(id, string(cert))
if err != nil {
if errors.Is(err, certdb.ErrIdNotFound) ||
if errors.Is(err, db.ErrIdNotFound) ||
err.Error() == "certificate does not match CSR" ||
strings.Contains(err.Error(), "cert validation failed") {
logErrorAndWriteResponse(err.Error(), http.StatusBadRequest, w)
Expand Down Expand Up @@ -238,7 +238,7 @@ func RejectCertificate(env *Environment) http.HandlerFunc {
id := r.PathValue("id")
insertId, err := env.DB.UpdateCSR(id, "rejected")
if err != nil {
if errors.Is(err, certdb.ErrIdNotFound) {
if errors.Is(err, db.ErrIdNotFound) {
logErrorAndWriteResponse(err.Error(), http.StatusNotFound, w)
return
}
Expand Down Expand Up @@ -266,7 +266,7 @@ func DeleteCertificate(env *Environment) http.HandlerFunc {
id := r.PathValue("id")
insertId, err := env.DB.UpdateCSR(id, "")
if err != nil {
if errors.Is(err, certdb.ErrIdNotFound) {
if errors.Is(err, db.ErrIdNotFound) {
logErrorAndWriteResponse(err.Error(), http.StatusBadRequest, w)
return
}
Expand Down Expand Up @@ -314,7 +314,7 @@ func GetUserAccounts(env *Environment) http.HandlerFunc {
func GetUserAccount(env *Environment) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
id := r.PathValue("id")
var userAccount certdb.User
var userAccount db.User
var err error
if id == "me" {
claims, headerErr := getClaimsFromAuthorizationHeader(r.Header.Get("Authorization"), env.JWTSecret)
Expand All @@ -326,7 +326,7 @@ func GetUserAccount(env *Environment) http.HandlerFunc {
userAccount, err = env.DB.RetrieveUser(id)
}
if err != nil {
if errors.Is(err, certdb.ErrIdNotFound) {
if errors.Is(err, db.ErrIdNotFound) {
logErrorAndWriteResponse(err.Error(), http.StatusNotFound, w)
return
}
Expand All @@ -348,7 +348,7 @@ func GetUserAccount(env *Environment) http.HandlerFunc {
// PostUserAccount creates a new User Account, and returns the id of the created row
func PostUserAccount(env *Environment) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
var user certdb.User
var user db.User
if err := json.NewDecoder(r.Body).Decode(&user); err != nil {
logErrorAndWriteResponse("Invalid JSON format", http.StatusBadRequest, w)
return
Expand All @@ -357,7 +357,7 @@ func PostUserAccount(env *Environment) http.HandlerFunc {
logErrorAndWriteResponse("Username is required", http.StatusBadRequest, w)
return
}
var shouldGeneratePassword = user.Password == ""
shouldGeneratePassword := user.Password == ""
if shouldGeneratePassword {
generatedPassword, err := generatePassword()
if err != nil {
Expand All @@ -382,7 +382,7 @@ func PostUserAccount(env *Environment) http.HandlerFunc {

permission := "0"
if len(users) == 0 {
permission = "1" //if this is the first user it will be admin
permission = "1" // if this is the first user it will be admin
}
id, err := env.DB.CreateUser(user.Username, user.Password, permission)
if err != nil {
Expand Down Expand Up @@ -415,7 +415,7 @@ func DeleteUserAccount(env *Environment) http.HandlerFunc {
id := r.PathValue("id")
user, err := env.DB.RetrieveUser(id)
if err != nil {
if !errors.Is(err, certdb.ErrIdNotFound) {
if !errors.Is(err, db.ErrIdNotFound) {
logErrorAndWriteResponse(err.Error(), http.StatusInternalServerError, w)
return
}
Expand All @@ -426,7 +426,7 @@ func DeleteUserAccount(env *Environment) http.HandlerFunc {
}
insertId, err := env.DB.DeleteUser(id)
if err != nil {
if errors.Is(err, certdb.ErrIdNotFound) {
if errors.Is(err, db.ErrIdNotFound) {
logErrorAndWriteResponse(err.Error(), http.StatusNotFound, w)
return
}
Expand Down Expand Up @@ -454,7 +454,7 @@ func ChangeUserAccountPassword(env *Environment) http.HandlerFunc {
}
id = strconv.Itoa(userAccount.ID)
}
var user certdb.User
var user db.User
if err := json.NewDecoder(r.Body).Decode(&user); err != nil {
logErrorAndWriteResponse("Invalid JSON format", http.StatusBadRequest, w)
return
Expand All @@ -473,7 +473,7 @@ func ChangeUserAccountPassword(env *Environment) http.HandlerFunc {
}
ret, err := env.DB.UpdateUser(id, user.Password)
if err != nil {
if errors.Is(err, certdb.ErrIdNotFound) {
if errors.Is(err, db.ErrIdNotFound) {
logErrorAndWriteResponse(err.Error(), http.StatusNotFound, w)
return
}
Expand All @@ -489,7 +489,7 @@ func ChangeUserAccountPassword(env *Environment) http.HandlerFunc {

func Login(env *Environment) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
var userRequest certdb.User
var userRequest db.User
if err := json.NewDecoder(r.Body).Decode(&userRequest); err != nil {
logErrorAndWriteResponse("Invalid JSON format", http.StatusBadRequest, w)
return
Expand All @@ -505,7 +505,7 @@ func Login(env *Environment) http.HandlerFunc {
userAccount, err := env.DB.RetrieveUserByUsername(userRequest.Username)
if err != nil {
status := http.StatusInternalServerError
if errors.Is(err, certdb.ErrIdNotFound) {
if errors.Is(err, db.ErrIdNotFound) {
logErrorAndWriteResponse("The username or password is incorrect. Try again.", http.StatusUnauthorized, w)
return
}
Expand Down
11 changes: 5 additions & 6 deletions internal/api/handlers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ import (
"testing"

server "github.com/canonical/notary/internal/api"
"github.com/canonical/notary/internal/certdb"
"github.com/canonical/notary/internal/db"
"github.com/golang-jwt/jwt"
)

Expand Down Expand Up @@ -153,7 +153,7 @@ const (
)

func TestNotaryCertificatesHandlers(t *testing.T) {
testdb, err := certdb.NewCertificateRequestsRepository(":memory:", "CertificateRequests")
testdb, err := db.NewDatabase(":memory:")
if err != nil {
log.Fatalf("couldn't create test sqlite db: %s", err)
}
Expand Down Expand Up @@ -406,11 +406,10 @@ func TestNotaryCertificatesHandlers(t *testing.T) {
}
})
}

}

func TestNotaryUsersHandlers(t *testing.T) {
testdb, err := certdb.NewCertificateRequestsRepository(":memory:", "CertificateRequests")
testdb, err := db.NewDatabase(":memory:")
if err != nil {
log.Fatalf("couldn't create test sqlite db: %s", err)
}
Expand Down Expand Up @@ -573,7 +572,7 @@ func TestNotaryUsersHandlers(t *testing.T) {
}

func TestLogin(t *testing.T) {
testdb, err := certdb.NewCertificateRequestsRepository(":memory:", "CertificateRequests")
testdb, err := db.NewDatabase(":memory:")
if err != nil {
log.Fatalf("couldn't create test sqlite db: %s", err)
}
Expand Down Expand Up @@ -687,7 +686,7 @@ func TestLogin(t *testing.T) {
}

func TestAuthorization(t *testing.T) {
testdb, err := certdb.NewCertificateRequestsRepository(":memory:", "CertificateRequests")
testdb, err := db.NewDatabase(":memory:")
if err != nil {
log.Fatalf("couldn't create test sqlite db: %s", err)
}
Expand Down
6 changes: 3 additions & 3 deletions internal/api/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,11 @@ import (
"os/exec"
"time"

"github.com/canonical/notary/internal/certdb"
"github.com/canonical/notary/internal/db"
)

type Environment struct {
DB *certdb.CertificateRequestsRepository
DB *db.Database
SendPebbleNotifications bool
JWTSecret []byte
}
Expand All @@ -42,7 +42,7 @@ func NewServer(port int, cert []byte, key []byte, dbPath string, pebbleNotificat
if err != nil {
return nil, err
}
db, err := certdb.NewCertificateRequestsRepository(dbPath, "CertificateRequests")
db, err := db.NewDatabase(dbPath)
if err != nil {
log.Fatalf("Couldn't connect to database: %s", err)
}
Expand Down
4 changes: 2 additions & 2 deletions internal/api/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,8 +92,8 @@ func TestMain(m *testing.M) {
if err != nil {
log.Fatalf("couldn't create temp directory")
}
writeCertErr := os.WriteFile(testfolder+"/cert_test.pem", []byte(validCert), 0644)
writeKeyErr := os.WriteFile(testfolder+"/key_test.pem", []byte(validPK), 0644)
writeCertErr := os.WriteFile(testfolder+"/cert_test.pem", []byte(validCert), 0o644)
writeKeyErr := os.WriteFile(testfolder+"/key_test.pem", []byte(validPK), 0o644)
if writeCertErr != nil || writeKeyErr != nil {
log.Fatalf("couldn't create temp testing file")
}
Expand Down
2 changes: 1 addition & 1 deletion internal/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ func Validate(filePath string) (Config, error) {
if c.DBPath == "" {
return config, errors.Join(validationErr, errors.New("`db_path` is empty"))
}
dbfile, err := os.OpenFile(c.DBPath, os.O_CREATE|os.O_RDONLY, 0644)
dbfile, err := os.OpenFile(c.DBPath, os.O_CREATE|os.O_RDONLY, 0o644)
if err != nil {
return config, errors.Join(validationErr, err)
}
Expand Down
10 changes: 4 additions & 6 deletions internal/config/config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,8 @@ func TestMain(m *testing.M) {
if err != nil {
log.Fatalf("couldn't create temp directory")
}
writeCertErr := os.WriteFile(testfolder+"/cert_test.pem", []byte(validCert), 0644)
writeKeyErr := os.WriteFile(testfolder+"/key_test.pem", []byte(validPK), 0644)
writeCertErr := os.WriteFile(testfolder+"/cert_test.pem", []byte(validCert), 0o644)
writeKeyErr := os.WriteFile(testfolder+"/key_test.pem", []byte(validPK), 0o644)
if writeCertErr != nil || writeKeyErr != nil {
log.Fatalf("couldn't create temp testing file")
}
Expand All @@ -63,7 +63,7 @@ func TestMain(m *testing.M) {
}

func TestGoodConfigSuccess(t *testing.T) {
writeConfigErr := os.WriteFile("config.yaml", []byte(validConfig), 0644)
writeConfigErr := os.WriteFile("config.yaml", []byte(validConfig), 0o644)
if writeConfigErr != nil {
t.Fatalf("Error writing config file")
}
Expand All @@ -87,7 +87,6 @@ func TestGoodConfigSuccess(t *testing.T) {
if conf.Port != 8000 {
t.Fatalf("Port was not configured correctly")
}

}

func TestBadConfigFail(t *testing.T) {
Expand All @@ -105,7 +104,7 @@ func TestBadConfigFail(t *testing.T) {
}

for _, tc := range cases {
writeConfigErr := os.WriteFile("config.yaml", []byte(tc.ConfigYAML), 0644)
writeConfigErr := os.WriteFile("config.yaml", []byte(tc.ConfigYAML), 0o644)
if writeConfigErr != nil {
t.Errorf("Failed writing config file")
}
Expand All @@ -117,6 +116,5 @@ func TestBadConfigFail(t *testing.T) {
if !strings.Contains(err.Error(), tc.ExpectedError) {
t.Errorf("Expected error not found: %s", err)
}

}
}
Loading

0 comments on commit 15e57ce

Please sign in to comment.