diff --git a/internal/api/handlers.go b/internal/api/handlers.go index 5b1fc5a..5764198 100644 --- a/internal/api/handlers.go +++ b/internal/api/handlers.go @@ -27,18 +27,24 @@ func NewGoCertRouter(env *Environment) http.Handler { apiV1Router.HandleFunc("POST /certificate_requests/{id}/certificate/reject", RejectCertificate(env)) apiV1Router.HandleFunc("DELETE /certificate_requests/{id}/certificate", DeleteCertificate(env)) - metricsHandler := metrics.NewPrometheusMetricsHandler() + m := metrics.NewMetricsSubsystem(env.DB) frontendHandler := newFrontendFileServer() router := http.NewServeMux() router.HandleFunc("/status", HealthCheck) - router.Handle("/metrics", metricsHandler) + router.Handle("/metrics", m.Handler) router.Handle("/api/v1/", http.StripPrefix("/api/v1", apiV1Router)) router.Handle("/", frontendHandler) - return logging(router) + ctx := middlewareContext{metrics: m} + middleware := createMiddlewareStack( + metricsMiddleware(&ctx), + loggingMiddleware(&ctx), + ) + return middleware(router) } +// newFrontendFileServer uses the embedded ui output files as the base for a file server func newFrontendFileServer() http.Handler { frontendFS, err := fs.Sub(ui.FrontendFS, "out") if err != nil { @@ -213,14 +219,6 @@ func DeleteCertificate(env *Environment) http.HandlerFunc { } } -// The logging middleware captures any http request coming through, and logs it -func logging(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - next.ServeHTTP(w, r) - log.Println(r.Method, r.URL.Path) - }) -} - // logErrorAndWriteResponse is a helper function that logs any error and writes it back as an http response func logErrorAndWriteResponse(msg string, status int, w http.ResponseWriter) { errMsg := fmt.Sprintf("error: %s", msg) diff --git a/internal/api/middleware.go b/internal/api/middleware.go new file mode 100644 index 0000000..8e8e3d1 --- /dev/null +++ b/internal/api/middleware.go @@ -0,0 +1,80 @@ +package server + +import ( + "log" + "net/http" + + "github.com/canonical/gocert/internal/metrics" + "github.com/prometheus/client_golang/prometheus/promhttp" +) + +type middleware func(http.Handler) http.Handler + +// The middlewareContext type helps middleware receive and pass along information through the middleware chain. +type middlewareContext struct { + responseStatusCode int + metrics *metrics.PrometheusMetrics +} + +// The responseWriterCloner struct wraps the http.ResponseWriter struct, and extracts the status +// code of the response writer for the middleware to read +type responseWriterCloner struct { + http.ResponseWriter + statusCode int +} + +// newResponseWriter returns a new ResponseWriterCloner struct +// it returns http.StatusOK by default because the http.ResponseWriter defaults to that header +// if the WriteHeader() function is never called. +func newResponseWriter(w http.ResponseWriter) *responseWriterCloner { + return &responseWriterCloner{w, http.StatusOK} +} + +// WriteHeader overrides the ResponseWriter method to duplicate the status code into the wrapper struct +func (rwc *responseWriterCloner) WriteHeader(code int) { + rwc.statusCode = code + rwc.ResponseWriter.WriteHeader(code) +} + +// createMiddlewareStack chains the given middleware functions to wrap the api. +// Each middleware functions calls next.ServeHTTP in order to resume the chain of execution. +// The order the middleware functions are given to createMiddlewareStack matters. +// Any code before next.ServeHTTP is called is executed in the given middleware's order. +// Any code after next.ServeHTTP is called is executed in the given middleware's reverse order. +func createMiddlewareStack(middleware ...middleware) middleware { + return func(next http.Handler) http.Handler { + for i := len(middleware) - 1; i >= 0; i-- { + mw := middleware[i] + next = mw(next) + } + return next + } +} + +// The Metrics middleware captures any request relevant to a metric and records it for prometheus. +func metricsMiddleware(ctx *middlewareContext) middleware { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + base := promhttp.InstrumentHandlerCounter( + &ctx.metrics.RequestsTotal, + promhttp.InstrumentHandlerDuration( + &ctx.metrics.RequestsDuration, + next, + ), + ) + base.ServeHTTP(w, r) + }) + } +} + +// The Logging middleware captures any http request coming through and the response status code, and logs it. +func loggingMiddleware(ctx *middlewareContext) middleware { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + clonedWwriter := newResponseWriter(w) + next.ServeHTTP(w, r) + log.Println(r.Method, r.URL.Path, clonedWwriter.statusCode, http.StatusText(clonedWwriter.statusCode)) + ctx.responseStatusCode = clonedWwriter.statusCode + }) + } +} diff --git a/internal/certdb/validation.go b/internal/certdb/validation.go index 9a487a4..ae191ae 100644 --- a/internal/certdb/validation.go +++ b/internal/certdb/validation.go @@ -22,6 +22,7 @@ func ValidateCertificateRequest(csr string) error { if err != nil { return err } + // TODO: We should validate the actual certificate request parameters here too. (Has the required fields etc) return nil } @@ -40,6 +41,7 @@ func ValidateCertificate(cert string) error { if err != nil { return err } + // TODO: We should validate the actual certificate parameters here too. (Has the required fields etc) return nil } diff --git a/internal/metrics/metrics.go b/internal/metrics/metrics.go index b55d107..a790d50 100644 --- a/internal/metrics/metrics.go +++ b/internal/metrics/metrics.go @@ -1,17 +1,224 @@ package metrics import ( + "crypto/x509" + "encoding/pem" + "errors" + "log" "net/http" + "time" + "github.com/canonical/gocert/internal/certdb" "github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus/collectors" "github.com/prometheus/client_golang/prometheus/promhttp" ) -// Returns an HTTP handler for Prometheus metrics. -func NewPrometheusMetricsHandler() http.Handler { - reg := prometheus.NewRegistry() - reg.MustRegister(collectors.NewGoCollector(), collectors.NewProcessCollector(collectors.ProcessCollectorOpts{})) - prometheusHandler := promhttp.HandlerFor(reg, promhttp.HandlerOpts{}) - return prometheusHandler +type PrometheusMetrics struct { + http.Handler + registry *prometheus.Registry + CertificateRequests prometheus.Gauge + OutstandingCertificateRequests prometheus.Gauge + Certificates prometheus.Gauge + CertificatesExpiringIn1Day prometheus.Gauge + CertificatesExpiringIn7Days prometheus.Gauge + CertificatesExpiringIn30Days prometheus.Gauge + CertificatesExpiringIn90Days prometheus.Gauge + ExpiredCertificates prometheus.Gauge + + RequestsTotal prometheus.CounterVec + RequestsDuration prometheus.HistogramVec +} + +// NewMetricsSubsystem returns the metrics endpoint HTTP handler and the Prometheus metrics collectors for the server and middleware. +func NewMetricsSubsystem(db *certdb.CertificateRequestsRepository) *PrometheusMetrics { + metricsBackend := newPrometheusMetrics() + metricsBackend.Handler = promhttp.HandlerFor(metricsBackend.registry, promhttp.HandlerOpts{}) + ticker := time.NewTicker(120 * time.Second) + go func() { + for ; ; <-ticker.C { + csrs, err := db.RetrieveAll() + if err != nil { + log.Println(errors.Join(errors.New("error generating metrics repository: "), err)) + panic(1) + } + metricsBackend.GenerateMetrics(csrs) + } + }() + return metricsBackend +} + +// newPrometheusMetrics reads the status of the database, calculates all of the values of the metrics, +// registers these metrics to the prometheus registry, and returns the registry and the metrics. +// The registry and metrics can be modified from this struct from anywhere in the codebase. +func newPrometheusMetrics() *PrometheusMetrics { + m := &PrometheusMetrics{ + registry: prometheus.NewRegistry(), + CertificateRequests: certificateRequestsMetric(), + OutstandingCertificateRequests: outstandingCertificateRequestsMetric(), + Certificates: certificatesMetric(), + ExpiredCertificates: expiredCertificatesMetric(), + CertificatesExpiringIn1Day: certificatesExpiringIn1DayMetric(), + CertificatesExpiringIn7Days: certificatesExpiringIn7DaysMetric(), + CertificatesExpiringIn30Days: certificatesExpiringIn30DaysMetric(), + CertificatesExpiringIn90Days: certificatesExpiringIn90DaysMetric(), + + RequestsTotal: requestsTotalMetric(), + RequestsDuration: requestDurationMetric(), + } + m.registry.MustRegister(m.CertificateRequests) + m.registry.MustRegister(m.OutstandingCertificateRequests) + m.registry.MustRegister(m.Certificates) + m.registry.MustRegister(m.ExpiredCertificates) + m.registry.MustRegister(m.CertificatesExpiringIn1Day) + m.registry.MustRegister(m.CertificatesExpiringIn7Days) + m.registry.MustRegister(m.CertificatesExpiringIn30Days) + m.registry.MustRegister(m.CertificatesExpiringIn90Days) + + m.registry.MustRegister(m.RequestsTotal) + m.registry.MustRegister(m.RequestsDuration) + + m.registry.MustRegister(collectors.NewGoCollector()) + m.registry.MustRegister(collectors.NewProcessCollector(collectors.ProcessCollectorOpts{})) + return m +} + +// GenerateMetrics receives the live list of csrs to calculate the most recent values for the metrics +// defined for prometheus +func (pm *PrometheusMetrics) GenerateMetrics(csrs []certdb.CertificateRequest) { + var csrCount float64 = float64(len(csrs)) + var outstandingCSRCount float64 + var certCount float64 + var expiredCertCount float64 + var expiringIn1DayCertCount float64 + var expiringIn7DaysCertCount float64 + var expiringIn30DaysCertCount float64 + var expiringIn90DaysCertCount float64 + for _, entry := range csrs { + if entry.Certificate == "" { + outstandingCSRCount += 1 + continue + } + if entry.Certificate == "rejected" { + continue + } + certCount += 1 + expiryDate := certificateExpiryDate(entry.Certificate) + daysRemaining := time.Until(expiryDate).Hours() / 24 + if daysRemaining < 0 { + expiredCertCount += 1 + } else { + if daysRemaining < 1 { + expiringIn1DayCertCount += 1 + } + if daysRemaining < 7 { + expiringIn7DaysCertCount += 1 + } + if daysRemaining < 30 { + expiringIn30DaysCertCount += 1 + } + if daysRemaining < 90 { + expiringIn90DaysCertCount += 1 + } + } + } + pm.CertificateRequests.Set(csrCount) + pm.OutstandingCertificateRequests.Set(outstandingCSRCount) + pm.Certificates.Set(certCount) + pm.ExpiredCertificates.Set(expiredCertCount) + pm.CertificatesExpiringIn1Day.Set(expiringIn1DayCertCount) + pm.CertificatesExpiringIn7Days.Set(expiringIn7DaysCertCount) + pm.CertificatesExpiringIn30Days.Set(expiringIn30DaysCertCount) + pm.CertificatesExpiringIn90Days.Set(expiringIn90DaysCertCount) +} + +func certificateRequestsMetric() prometheus.Gauge { + metric := prometheus.NewGauge(prometheus.GaugeOpts{ + Name: "certificate_requests", + Help: "Total number of certificate requests", + }) + return metric +} + +func outstandingCertificateRequestsMetric() prometheus.Gauge { + metric := prometheus.NewGauge(prometheus.GaugeOpts{ + Name: "outstanding_certificate_requests", + Help: "Number of outstanding certificate requests", + }) + return metric +} + +func certificatesMetric() prometheus.Gauge { + metric := prometheus.NewGauge(prometheus.GaugeOpts{ + Name: "certificates", + Help: "Total number of certificates provided to certificate requests", + }) + return metric +} + +func expiredCertificatesMetric() prometheus.Gauge { + metric := prometheus.NewGauge(prometheus.GaugeOpts{ + Name: "certificates_expired", + Help: "Number of expired certificates", + }) + return metric +} + +func certificatesExpiringIn1DayMetric() prometheus.Gauge { + metric := prometheus.NewGauge(prometheus.GaugeOpts{ + Name: "certificates_expiring_in_1_day", + Help: "Number of certificates expiring in less than 1 day", + }) + return metric +} + +func certificatesExpiringIn7DaysMetric() prometheus.Gauge { + metric := prometheus.NewGauge(prometheus.GaugeOpts{ + Name: "certificates_expiring_in_7_days", + Help: "Number of certificates expiring in less than 7 days", + }) + return metric +} +func certificatesExpiringIn30DaysMetric() prometheus.Gauge { + metric := prometheus.NewGauge(prometheus.GaugeOpts{ + Name: "certificates_expiring_in_30_days", + Help: "Number of certificates expiring in less than 30 days", + }) + return metric +} + +func certificatesExpiringIn90DaysMetric() prometheus.Gauge { + metric := prometheus.NewGauge(prometheus.GaugeOpts{ + Name: "certificates_expiring_in_90_days", + Help: "Number of certificates expiring in less than 90 days", + }) + return metric +} + +func requestsTotalMetric() prometheus.CounterVec { + metric := prometheus.NewCounterVec( + prometheus.CounterOpts{ + Name: "http_requests_total", + Help: "Tracks the number of HTTP requests.", + }, []string{"method", "code"}, + ) + return *metric +} + +func requestDurationMetric() prometheus.HistogramVec { + metric := prometheus.NewHistogramVec( + prometheus.HistogramOpts{ + Name: "http_request_duration_seconds", + Help: "Tracks the latencies for HTTP requests.", + Buckets: prometheus.ExponentialBuckets(0.1, 1.5, 5), + }, []string{"method", "code"}, + ) + return *metric +} + +func certificateExpiryDate(certString string) time.Time { + certBlock, _ := pem.Decode([]byte(certString)) + cert, _ := x509.ParseCertificate(certBlock.Bytes) + // TODO: cert.NotAfter can exist in a wrong cert. We should catch that at the db level validation + return cert.NotAfter } diff --git a/internal/metrics/metrics_test.go b/internal/metrics/metrics_test.go index 5885eb7..fbfc162 100644 --- a/internal/metrics/metrics_test.go +++ b/internal/metrics/metrics_test.go @@ -1,17 +1,31 @@ package metrics_test import ( + "bytes" + "crypto/rand" + "crypto/rsa" + "crypto/x509" + "encoding/pem" + "fmt" + "math/big" "net/http" "net/http/httptest" + "os" "strings" "testing" + "time" + "github.com/canonical/gocert/internal/certdb" metrics "github.com/canonical/gocert/internal/metrics" ) // TestPrometheusHandler tests that the Prometheus metrics handler responds correctly to an HTTP request. func TestPrometheusHandler(t *testing.T) { - handler := metrics.NewPrometheusMetricsHandler() + db, err := certdb.NewCertificateRequestsRepository(":memory:", "CertificateReq") + if err != nil { + t.Fatal(err) + } + m := metrics.NewMetricsSubsystem(db) request, err := http.NewRequest("GET", "/", nil) if err != nil { @@ -19,7 +33,7 @@ func TestPrometheusHandler(t *testing.T) { } recorder := httptest.NewRecorder() - handler.ServeHTTP(recorder, request) + m.Handler.ServeHTTP(recorder, request) if status := recorder.Code; status != http.StatusOK { t.Errorf("handler returned wrong status code: got %v want %v", status, http.StatusOK) @@ -30,4 +44,121 @@ func TestPrometheusHandler(t *testing.T) { if !strings.Contains(recorder.Body.String(), "go_goroutines") { t.Errorf("handler returned an empty body") } + err = db.Close() + if err != nil { + t.Fatal(err) + } +} + +// Generates a CSR and Certificate with the given days remaining +func generateCertPair(daysRemaining int) (string, string) { + NotAfterTime := time.Now().AddDate(0, 0, daysRemaining) + key, _ := rsa.GenerateKey(rand.Reader, 2048) + + csrTemplate := x509.CertificateRequest{} + certTemplate := x509.Certificate{ + SerialNumber: big.NewInt(1), + NotAfter: NotAfterTime, + } + + csrBytes, _ := x509.CreateCertificateRequest(rand.Reader, &csrTemplate, key) + certBytes, _ := x509.CreateCertificate(rand.Reader, &certTemplate, &certTemplate, &key.PublicKey, key) + + var buff bytes.Buffer + pem.Encode(&buff, &pem.Block{ //nolint:errcheck + Type: "CERTIFICATE REQUEST", + Bytes: csrBytes, + }) + csr := buff.String() + buff.Reset() + pem.Encode(&buff, &pem.Block{ //nolint:errcheck + Type: "CERTIFICATE", + Bytes: certBytes, + }) + cert := buff.String() + return csr, cert +} + +func initializeTestDB(t *testing.T, db *certdb.CertificateRequestsRepository) { + for i, v := range []int{5, 10, 32} { + csr, cert := generateCertPair(v) + _, err := db.Create(csr) + if err != nil { + t.Fatalf("couldn't create test csr:%s", err) + } + _, err = db.Update(fmt.Sprint(i+1), cert) + if err != nil { + t.Fatalf("couldn't create test cert:%s", err) + } + } +} + +// TestMetrics tests some of the metrics that we currently collect. +func TestMetrics(t *testing.T) { + f, err := os.CreateTemp("./","*.db") + fmt.Print(f.Name()) + if err != nil { + t.Fatal("couldn't create temp db file: "+ err.Error()) + } + defer f.Close() + defer os.Remove(f.Name()) + db, err := certdb.NewCertificateRequestsRepository(f.Name(), "CertificateReq") + if err != nil { + t.Fatal(err) + } + initializeTestDB(t, db) + m := metrics.NewMetricsSubsystem(db) + csrs, _ := db.RetrieveAll() + m.GenerateMetrics(csrs) + + request, _ := http.NewRequest("GET", "/", nil) + recorder := httptest.NewRecorder() + m.Handler.ServeHTTP(recorder, request) + + if status := recorder.Code; status != http.StatusOK { + t.Errorf("handler returned wrong status code: got %v want %v", status, http.StatusOK) + } + if recorder.Body.String() == "" { + t.Errorf("handler returned an empty body") + } + for _, line := range strings.Split(recorder.Body.String(), "\n") { + if strings.Contains(line, "outstanding_certificate_requests ") && !strings.HasPrefix(line, "#") { + if !strings.HasSuffix(line, "0") { + t.Errorf("outstanding_certificate_requests expected to receive 0") + } + } else if strings.Contains(line, "certificate_requests ") && !strings.HasPrefix(line, "#") { + if !strings.HasSuffix(line, "3") { + t.Errorf("certificate_requests expected to receive 3") + } + } else if strings.Contains(line, "certificates ") && !strings.HasPrefix(line, "#") { + if !strings.HasSuffix(line, "3") { + t.Errorf("certificates expected to receive 3") + } + } else if strings.Contains(line, "certificates_expired ") && !strings.HasPrefix(line, "#") { + if !strings.HasSuffix(line, "0") { + t.Errorf("certificates_expired expected to receive 0") + } + } else if strings.Contains(line, "certificates_expiring_in_1_day ") && !strings.HasPrefix(line, "#") { + if !strings.HasSuffix(line, "0") { + t.Errorf("certificates_expiring_in_1_day expected to receive 0") + } + } else if strings.Contains(line, "certificates_expiring_in_7_days ") && !strings.HasPrefix(line, "#") { + if !strings.HasSuffix(line, "1") { + t.Errorf("certificates_expiring_in_7_days expected to receive 1") + } + } else if strings.Contains(line, "certificates_expiring_in_30_days ") && !strings.HasPrefix(line, "#") { + if !strings.HasSuffix(line, "2") { + t.Errorf("certificates_expiring_in_30_days expected to receive 2") + } + } else if strings.Contains(line, "certificates_expiring_in_90_days ") && !strings.HasPrefix(line, "#") { + if !strings.HasSuffix(line, "3") { + t.Errorf("certificates_expiring_in_90_days expected to receive 3") + } + } + } + + err = db.Close() + if err != nil { + t.Fatal(err) + } }