diff --git a/internal/api/handlers.go b/internal/api/handlers.go index 9d94940..d66b402 100644 --- a/internal/api/handlers.go +++ b/internal/api/handlers.go @@ -27,16 +27,16 @@ func NewGoCertRouter(env *Environment) http.Handler { apiV1Router.HandleFunc("POST /certificate_requests/{id}/certificate/reject", RejectCertificate(env)) apiV1Router.HandleFunc("DELETE /certificate_requests/{id}/certificate", DeleteCertificate(env)) - metricsHandler := metrics.NewPrometheusMetricsHandler() + m := metrics.NewMetricsSubsystem(env.DB) frontendHandler := newFrontendFileServer() router := http.NewServeMux() router.HandleFunc("/status", HealthCheck) - router.Handle("/metrics", metricsHandler) + router.Handle("/metrics", m.Handler) router.Handle("/api/v1/", http.StripPrefix("/api/v1", apiV1Router)) router.Handle("/", frontendHandler) - ctx := Context{} + ctx := MiddlewareContext{metrics: m} middleware := createMiddlewareStack( Metrics(&ctx), Logging(&ctx), @@ -44,21 +44,6 @@ func NewGoCertRouter(env *Environment) http.Handler { return middleware(router) } -// createMiddlewareStack chains given middleware for the server. -// Each middleware functions calls next.ServeHTTP in order to resume the chain of execution. -// The order these functions are given to createMiddlewareStack matters. -// The functions will run the code before next.ServeHTTP in order. -// The functions will run the code after next.ServeHTTP in reverse order. -func createMiddlewareStack(middleware ...Middleware) Middleware { - return func(next http.Handler) http.Handler { - for i := len(middleware) - 1; i >= 0; i-- { - mw := middleware[i] - next = mw(next) - } - return next - } -} - // newFrontendFileServer uses the embedded ui output files as the base for a file server func newFrontendFileServer() http.Handler { frontendFS, err := fs.Sub(ui.FrontendFS, "out") diff --git a/internal/api/middleware.go b/internal/api/middleware.go index 4ec6b69..e5c9bae 100644 --- a/internal/api/middleware.go +++ b/internal/api/middleware.go @@ -3,13 +3,17 @@ package server import ( "log" "net/http" + "strings" + + "github.com/canonical/gocert/internal/metrics" ) type Middleware func(http.Handler) http.Handler -// The Context type helps middleware pass along information through the chain. -type Context struct { +// The MiddlewareContext type helps middleware pass along information through the chain. +type MiddlewareContext struct { responseStatusCode int + metrics *metrics.PrometheusMetrics } // The ResponseWriterCloner struct implements the http.ResponseWriter class, and copies the status @@ -30,20 +34,41 @@ func (rwc *ResponseWriterCloner) WriteHeader(code int) { rwc.ResponseWriter.WriteHeader(code) } +// createMiddlewareStack chains given middleware for the server. +// Each middleware functions calls next.ServeHTTP in order to resume the chain of execution. +// The order these functions are given to createMiddlewareStack matters. +// The functions will run the code before next.ServeHTTP in order. +// The functions will run the code after next.ServeHTTP in reverse order. +func createMiddlewareStack(middleware ...Middleware) Middleware { + return func(next http.Handler) http.Handler { + for i := len(middleware) - 1; i >= 0; i-- { + mw := middleware[i] + next = mw(next) + } + return next + } +} + // The Metrics middleware captures any request relevant to a metric and records it for prometheus. -func Metrics(ctx *Context) Middleware { +func Metrics(ctx *MiddlewareContext) Middleware { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { next.ServeHTTP(w, r) - if ctx.responseStatusCode != 200 { + if ctx.responseStatusCode/100 != 2 { return } + if r.Method == "POST" && r.URL.Path == "/api/v1/certificate_requests" { + ctx.metrics.CertificateRequests.Inc() + } + if r.Method == "DELETE" && strings.HasPrefix(r.URL.Path, "/api/v1/certificate_requests") { + ctx.metrics.CertificateRequests.Dec() + } }) } } // The logging middleware captures any http request coming through, and logs it. -func Logging(ctx *Context) Middleware { +func Logging(ctx *MiddlewareContext) Middleware { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { clonedWwriter := NewResponseWriter(w) diff --git a/internal/certdb/validation.go b/internal/certdb/validation.go index 9a487a4..ae191ae 100644 --- a/internal/certdb/validation.go +++ b/internal/certdb/validation.go @@ -22,6 +22,7 @@ func ValidateCertificateRequest(csr string) error { if err != nil { return err } + // TODO: We should validate the actual certificate request parameters here too. (Has the required fields etc) return nil } @@ -40,6 +41,7 @@ func ValidateCertificate(cert string) error { if err != nil { return err } + // TODO: We should validate the actual certificate parameters here too. (Has the required fields etc) return nil } diff --git a/internal/metrics/metrics.go b/internal/metrics/metrics.go index b55d107..f8cccb4 100644 --- a/internal/metrics/metrics.go +++ b/internal/metrics/metrics.go @@ -1,17 +1,201 @@ package metrics import ( + "crypto/x509" + "encoding/pem" + "errors" + "log" "net/http" + "time" + "github.com/canonical/gocert/internal/certdb" "github.com/prometheus/client_golang/prometheus" - "github.com/prometheus/client_golang/prometheus/collectors" "github.com/prometheus/client_golang/prometheus/promhttp" ) +type PrometheusMetrics struct { + http.Handler + registry *prometheus.Registry + CertificateRequests prometheus.Gauge + OutstandingCertificateRequests prometheus.Gauge + Certificates prometheus.Gauge + CertificatesExpiringIn1Day prometheus.Gauge + CertificatesExpiringIn7Days prometheus.Gauge + CertificatesExpiringIn30Days prometheus.Gauge + CertificatesExpiringIn90Days prometheus.Gauge + ExpiredCertificates prometheus.Gauge +} + // Returns an HTTP handler for Prometheus metrics. -func NewPrometheusMetricsHandler() http.Handler { - reg := prometheus.NewRegistry() - reg.MustRegister(collectors.NewGoCollector(), collectors.NewProcessCollector(collectors.ProcessCollectorOpts{})) - prometheusHandler := promhttp.HandlerFor(reg, promhttp.HandlerOpts{}) - return prometheusHandler +func NewMetricsSubsystem(db *certdb.CertificateRequestsRepository) *PrometheusMetrics { + metricsBackend, err := newPrometheusMetrics(db) + if err != nil { + log.Println(errors.Join(errors.New("error generating metrics repository: "), err)) + } + metricsBackend.Handler = promhttp.HandlerFor(metricsBackend.registry, promhttp.HandlerOpts{}) + return metricsBackend +} + +// Returns the metrics that prometheus needs to collect +func newPrometheusMetrics(db *certdb.CertificateRequestsRepository) (*PrometheusMetrics, error) { + csrs, err := db.RetrieveAll() + if err != nil { + return nil, errors.Join(errors.New("could not retrieve certs for metrics: "), err) + } + m := &PrometheusMetrics{ + registry: prometheus.NewRegistry(), + CertificateRequests: certificateRequestsMetric(), + OutstandingCertificateRequests: outstandingCertificateRequestsMetric(), + Certificates: certificatesMetric(), + ExpiredCertificates: expiredCertificatesMetric(), + CertificatesExpiringIn1Day: certificatesExpiringIn1DayMetric(), + CertificatesExpiringIn7Days: certificatesExpiringIn7DaysMetric(), + CertificatesExpiringIn30Days: certificatesExpiringIn30DaysMetric(), + CertificatesExpiringIn90Days: certificatesExpiringIn90DaysMetric(), + } + m.registry.MustRegister(m.CertificateRequests) + m.registry.MustRegister(m.OutstandingCertificateRequests) + m.registry.MustRegister(m.Certificates) + m.registry.MustRegister(m.ExpiredCertificates) + m.registry.MustRegister(m.CertificatesExpiringIn1Day) + m.registry.MustRegister(m.CertificatesExpiringIn7Days) + m.registry.MustRegister(m.CertificatesExpiringIn30Days) + m.registry.MustRegister(m.CertificatesExpiringIn90Days) + + m.generateMetrics(csrs) + // m.registry.MustRegister(collectors.NewGoCollector()) + // m.registry.MustRegister(collectors.NewProcessCollector(collectors.ProcessCollectorOpts{})) + return m, nil +} + +func (pm *PrometheusMetrics) generateMetrics(csrs []certdb.CertificateRequest) { + // TODO: This can run every 24 hours also to make sure we update the expiring in X day metrics. + var csrCount int = len(csrs) + var outstandingCSRCount int + var certCount int + var expiredCertCount int + var expiringIn1DayCertCount int + var expiringIn7DaysCertCount int + var expiringIn30DaysCertCount int + var expiringIn90DaysCertCount int + for _, entry := range csrs { + if entry.Certificate == "" { + outstandingCSRCount += 1 + } + if entry.Certificate != "" && entry.Certificate != "rejected" { + certCount += 1 + } + expiryDate := certificateExpiryDate(entry.Certificate) + daysRemaining := time.Until(expiryDate).Hours() / 24 + if daysRemaining < 0 { + expiredCertCount += 1 + } else { + if daysRemaining < 1 { + expiringIn1DayCertCount += 1 + } + if daysRemaining < 7 { + expiringIn7DaysCertCount += 1 + } + if daysRemaining < 30 { + expiringIn30DaysCertCount += 1 + } + if daysRemaining < 90 { + expiringIn90DaysCertCount += 1 + } + } + } + pm.CertificateRequests.Set(float64(csrCount)) + pm.OutstandingCertificateRequests.Set(float64(outstandingCSRCount)) + pm.Certificates.Set(float64(certCount)) + pm.ExpiredCertificates.Set(float64(expiredCertCount)) + pm.CertificatesExpiringIn1Day.Set(float64(expiringIn1DayCertCount)) + pm.CertificatesExpiringIn7Days.Set(float64(expiringIn7DaysCertCount)) + pm.CertificatesExpiringIn30Days.Set(float64(expiringIn30DaysCertCount)) + pm.CertificatesExpiringIn90Days.Set(float64(expiringIn90DaysCertCount)) +} + +func certificateRequestsMetric() prometheus.Gauge { + metric := prometheus.NewGauge(prometheus.GaugeOpts{ + Namespace: "TODO", + Subsystem: "TODO", + Name: "certificate_requests", + Help: "Total number of certificate requests", + }) + return metric +} + +func outstandingCertificateRequestsMetric() prometheus.Gauge { + metric := prometheus.NewGauge(prometheus.GaugeOpts{ + Namespace: "TODO", + Subsystem: "TODO", + Name: "outstanding_certificate_requests", + Help: "Number of outstanding certificate requests", + }) + return metric +} + +func certificatesMetric() prometheus.Gauge { + metric := prometheus.NewGauge(prometheus.GaugeOpts{ + Namespace: "TODO", + Subsystem: "TODO", + Name: "certificates", + Help: "Total number of certificates provided to certificate requests", + }) + return metric +} + +func expiredCertificatesMetric() prometheus.Gauge { + metric := prometheus.NewGauge(prometheus.GaugeOpts{ + Namespace: "TODO", + Subsystem: "TODO", + Name: "certificates_expired", + Help: "Number of expired certificates", + }) + return metric +} + +func certificatesExpiringIn1DayMetric() prometheus.Gauge { + metric := prometheus.NewGauge(prometheus.GaugeOpts{ + Namespace: "TODO", + Subsystem: "TODO", + Name: "certificates_expiring_in_1_day", + Help: "Number of certificates expiring in less than 1 day", + }) + return metric +} + +func certificatesExpiringIn7DaysMetric() prometheus.Gauge { + metric := prometheus.NewGauge(prometheus.GaugeOpts{ + Namespace: "TODO", + Subsystem: "TODO", + Name: "certificates_expiring_in_7_days", + Help: "Number of certificates expiring in less than 7 days", + }) + return metric +} +func certificatesExpiringIn30DaysMetric() prometheus.Gauge { + metric := prometheus.NewGauge(prometheus.GaugeOpts{ + Namespace: "TODO", + Subsystem: "TODO", + Name: "certificates_expiring_in_30_days", + Help: "Number of certificates expiring in less than 30 days", + }) + return metric +} + +func certificatesExpiringIn90DaysMetric() prometheus.Gauge { + metric := prometheus.NewGauge(prometheus.GaugeOpts{ + Namespace: "TODO", + Subsystem: "TODO", + Name: "certificates_expiring_in_90_days", + Help: "Number of certificates expiring in less than 90 days", + }) + return metric +} + +func certificateExpiryDate(certString string) time.Time { + certBlock, _ := pem.Decode([]byte(certString)) + cert, _ := x509.ParseCertificate(certBlock.Bytes) + // TODO: cert.NotAfter can exist in a wrong cert. We should catch that at the db level validation + return cert.NotAfter } diff --git a/internal/metrics/metrics_test.go b/internal/metrics/metrics_test.go index 5885eb7..b1c4275 100644 --- a/internal/metrics/metrics_test.go +++ b/internal/metrics/metrics_test.go @@ -1,17 +1,23 @@ package metrics_test import ( + "log" "net/http" "net/http/httptest" "strings" "testing" + "github.com/canonical/gocert/internal/certdb" metrics "github.com/canonical/gocert/internal/metrics" ) // TestPrometheusHandler tests that the Prometheus metrics handler responds correctly to an HTTP request. func TestPrometheusHandler(t *testing.T) { - handler := metrics.NewPrometheusMetricsHandler() + db, err := certdb.NewCertificateRequestsRepository(":memory:", "CertificateReq") + if err != nil { + log.Fatalln(err) + } + m := metrics.NewMetricsSubsystem(db) request, err := http.NewRequest("GET", "/", nil) if err != nil { @@ -19,7 +25,7 @@ func TestPrometheusHandler(t *testing.T) { } recorder := httptest.NewRecorder() - handler.ServeHTTP(recorder, request) + m.Handler.ServeHTTP(recorder, request) if status := recorder.Code; status != http.StatusOK { t.Errorf("handler returned wrong status code: got %v want %v", status, http.StatusOK) @@ -30,4 +36,8 @@ func TestPrometheusHandler(t *testing.T) { if !strings.Contains(recorder.Body.String(), "go_goroutines") { t.Errorf("handler returned an empty body") } + err = db.Close() + if err != nil { + log.Fatalln(err) + } }