diff --git a/internal/api/handlers.go b/internal/api/handlers.go index 12a284f..7869e91 100644 --- a/internal/api/handlers.go +++ b/internal/api/handlers.go @@ -34,6 +34,7 @@ func NewGoCertRouter(env *Environment) http.Handler { apiV1Router.HandleFunc("GET /accounts/{id}", GetUserAccount(env)) apiV1Router.HandleFunc("GET /accounts", GetUserAccounts(env)) apiV1Router.HandleFunc("POST /accounts", PostUserAccount(env)) + apiV1Router.HandleFunc("POST /accounts/{id}", ChangeUserAccountPassword(env)) m := metrics.NewMetricsSubsystem(env.DB) frontendHandler := newFrontendFileServer() @@ -360,6 +361,34 @@ func PostUserAccount(env *Environment) http.HandlerFunc { } } +func ChangeUserAccountPassword(env *Environment) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + id := r.PathValue("id") + var user certdb.User + if err := json.NewDecoder(r.Body).Decode(&user); err != nil { + logErrorAndWriteResponse("Invalid JSON format", http.StatusBadRequest, w) + return + } + if user.Password == "" { + logErrorAndWriteResponse("Password is required", http.StatusBadRequest, w) + return + } + ret, err := env.DB.UpdateUser(id, user.Password) + if err != nil { + if errors.Is(err, certdb.ErrIdNotFound) { + logErrorAndWriteResponse(err.Error(), http.StatusNotFound, w) + return + } + logErrorAndWriteResponse(err.Error(), http.StatusInternalServerError, w) + return + } + w.WriteHeader(http.StatusOK) + if _, err := w.Write([]byte(strconv.FormatInt(ret, 10))); err != nil { + logErrorAndWriteResponse(err.Error(), http.StatusInternalServerError, w) + } + } +} + // logErrorAndWriteResponse is a helper function that logs any error and writes it back as an http response func logErrorAndWriteResponse(msg string, status int, w http.ResponseWriter) { errMsg := fmt.Sprintf("error: %s", msg) diff --git a/internal/api/handlers_test.go b/internal/api/handlers_test.go index 4e30dbb..abfe64e 100644 --- a/internal/api/handlers_test.go +++ b/internal/api/handlers_test.go @@ -101,10 +101,12 @@ const ( ) const ( - adminUser = `{"username": "testadmin", "password": "admin"}` - validUser = `{"username": "testuser", "password": "user"}` - noPasswordUser = `{"username": "nopass", "password": ""}` - invalidUser = `{"username": "", "password": ""}` + adminUser = `{"username": "testadmin", "password": "admin"}` + validUser = `{"username": "testuser", "password": "user"}` + invalidUser = `{"username": "", "password": ""}` + noPasswordUser = `{"username": "nopass", "password": ""}` + adminUserNewPassword = `{"id": 1, "password": "newpassword"}` + userNewInvalidPassword = `{"id": 1, "password": ""}` ) func TestGoCertCertificatesHandlers(t *testing.T) { @@ -268,7 +270,7 @@ func TestGoCertCertificatesHandlers(t *testing.T) { method: "POST", path: "/api/v1/certificate_requests/2/certificate", data: validCert2, - response: "4", + response: "1", status: http.StatusCreated, }, { @@ -284,7 +286,7 @@ func TestGoCertCertificatesHandlers(t *testing.T) { method: "POST", path: "/api/v1/certificate_requests/4/certificate/reject", data: "", - response: "4", + response: "1", status: http.StatusAccepted, }, { @@ -300,7 +302,7 @@ func TestGoCertCertificatesHandlers(t *testing.T) { method: "DELETE", path: "/api/v1/certificate_requests/2/certificate", data: "", - response: "4", + response: "1", status: http.StatusAccepted, }, { @@ -369,12 +371,6 @@ func TestGoCertUsersHandlers(t *testing.T) { ts := httptest.NewTLSServer(server.NewGoCertRouter(env)) defer ts.Close() - originalFunc := server.GeneratePassword - server.GeneratePassword = func(length int) (string, error) { - return "generatedPassword", nil - } - defer func() { server.GeneratePassword = originalFunc }() - client := ts.Client() testCases := []struct { @@ -414,7 +410,7 @@ func TestGoCertUsersHandlers(t *testing.T) { method: "POST", path: "/api/v1/accounts", data: noPasswordUser, - response: "{\"id\":3,\"password\":\"generatedPassword\"}", + response: "{\"id\":3,\"password\":", status: http.StatusCreated, }, { @@ -441,6 +437,30 @@ func TestGoCertUsersHandlers(t *testing.T) { response: "error: Username is required", status: http.StatusBadRequest, }, + { + desc: "Change password success", + method: "POST", + path: "/api/v1/accounts/1", + data: adminUserNewPassword, + response: "1", + status: http.StatusOK, + }, + { + desc: "Change password failure no user", + method: "POST", + path: "/api/v1/accounts/100", + data: adminUserNewPassword, + response: "id not found", + status: http.StatusNotFound, + }, + { + desc: "Change password failure", + method: "POST", + path: "/api/v1/accounts/1", + data: userNewInvalidPassword, + response: "Password is required", + status: http.StatusBadRequest, + }, } for _, tC := range testCases { t.Run(tC.desc, func(t *testing.T) { diff --git a/internal/certdb/certdb.go b/internal/certdb/certdb.go index 0d306b3..191e812 100644 --- a/internal/certdb/certdb.go +++ b/internal/certdb/certdb.go @@ -130,11 +130,11 @@ func (db *CertificateRequestsRepository) UpdateCSR(id string, cert string) (int6 if err != nil { return 0, err } - insertId, err := result.LastInsertId() + affectedRows, err := result.RowsAffected() if err != nil { return 0, err } - return insertId, nil + return affectedRows, nil } // DeleteCSR removes a CSR from the database alongside the certificate that may have been generated for it. @@ -219,11 +219,11 @@ func (db *CertificateRequestsRepository) UpdateUser(id, password string) (int64, if err != nil { return 0, err } - insertId, err := result.LastInsertId() + affectedRows, err := result.RowsAffected() if err != nil { return 0, err } - return insertId, nil + return affectedRows, nil } // DeleteUser removes a user from the table.