diff --git a/.gitignore b/.gitignore index 4b46d81..475d1d2 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,3 @@ +coverage.out db.json lixie -localhost:8080 diff --git a/main.go b/main.go index 5546532..4cf27cf 100644 --- a/main.go +++ b/main.go @@ -14,8 +14,12 @@ import ( "fmt" "io/fs" "log" + "net" "net/http" "os" + "os/signal" + "strconv" + "sync" "time" "github.com/a-h/templ" @@ -36,20 +40,50 @@ func setupDatabase(config data.DatabaseConfig, path string) *data.Database { return db } +func newMux(db *data.Database) http.Handler { + mux := http.NewServeMux() + + // Configure the routes + // http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { + // Other options: StatusMovedPermanently, StatusFound + // http.Redirect(w, r, "/rule/edit", http.StatusSeeOther) + // }) + mux.HandleFunc("/", http.NotFound) + mainHandler := templ.Handler(MainPage(db)) + mux.Handle("/{$}", mainHandler) + + mux.Handle(topLevelLog.PathMatcher(), logListHandler(db)) + mux.Handle(topLevelLog.Path+"/{hash}/ham", logClassifyHandler(db, true)) + mux.Handle(topLevelLog.Path+"/{hash}/spam", logClassifyHandler(db, false)) + + mux.Handle(topLevelLogRule.PathMatcher(), logRuleListHandler(db)) + mux.Handle(topLevelLogRule.Path+"/edit", logRuleEditHandler(db)) + mux.Handle(topLevelLogRule.Path+"/{id}/delete", logRuleDeleteSpecificHandler(db)) + mux.Handle(topLevelLogRule.Path+"/{id}/edit", logRuleEditSpecificHandler(db)) + + // Static content + staticFS, err := fs.Sub(embedContent, "static") + if err != nil { + log.Panic(err) + } + + mux.Handle("/static/", http.StripPrefix("/static/", http.FileServer(http.FS(staticFS)))) + return mux +} + +// Most of this function is based on +// https://grafana.com/blog/2024/02/09/how-i-write-http-services-in-go-after-13-years/ +// // These might be also useful at some point // // getenv func(string) string, // stdin io.Reader, // stdout, stderr io.Writer, func run( - _ context.Context, + ctx context.Context, args []string) error { - // This would be relevant only if we handled our own context. - // However, http.ListenAndServe catches os.Interrupt so this - // is not necessary: - // - // ctx, cancel := signal.NotifyContext(ctx, os.Interrupt) - // defer cancel() + ctx, cancel := signal.NotifyContext(ctx, os.Interrupt) + defer cancel() // CLI flags := flag.NewFlagSet(args[0], flag.ExitOnError) @@ -63,40 +97,37 @@ func run( return err } - // Static content - staticFS, err := fs.Sub(embedContent, "static") - if err != nil { - log.Panic(err) - } - config := data.DatabaseConfig{LokiServer: *lokiServer, LokiSelector: *lokiSelector} db := setupDatabase(config, *dbPath) - // Configure the routes - // http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { - // Other options: StatusMovedPermanently, StatusFound - // http.Redirect(w, r, "/rule/edit", http.StatusSeeOther) - // }) - http.HandleFunc("/", http.NotFound) - mainHandler := templ.Handler(MainPage(db)) - http.Handle("/{$}", mainHandler) - - http.Handle(topLevelLog.PathMatcher(), logListHandler(db)) - http.Handle(topLevelLog.Path+"/{hash}/ham", logClassifyHandler(db, true)) - http.Handle(topLevelLog.Path+"/{hash}/spam", logClassifyHandler(db, false)) - - http.Handle(topLevelLogRule.PathMatcher(), logRuleListHandler(db)) - http.Handle(topLevelLogRule.Path+"/edit", logRuleEditHandler(db)) - http.Handle(topLevelLogRule.Path+"/{id}/delete", logRuleDeleteSpecificHandler(db)) - http.Handle(topLevelLogRule.Path+"/{id}/edit", logRuleEditSpecificHandler(db)) - - http.Handle("/static/", http.StripPrefix("/static/", http.FileServer(http.FS(staticFS)))) + mux := newMux(db) // Start the actual server - endpoint := fmt.Sprintf("%s:%d", *address, *port) - fmt.Printf("Listening on %s\n", endpoint) - return http.ListenAndServe(endpoint, nil) + httpServer := &http.Server{ + Addr: net.JoinHostPort(*address, strconv.Itoa(*port)), + Handler: mux, + } + go func() { + log.Printf("listening on %s\n", httpServer.Addr) + if err := httpServer.ListenAndServe(); err != nil && err != http.ErrServerClosed { + fmt.Fprintf(os.Stderr, "error listening: %v", err) + } + }() + + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + <-ctx.Done() + shutdownCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + if err := httpServer.Shutdown(shutdownCtx); err != nil { + fmt.Fprintf(os.Stderr, "error shutting down http server: %s\n", err) + } + }() + wg.Wait() + return nil } func main() { diff --git a/main_test.go b/main_test.go new file mode 100644 index 0000000..588c54a --- /dev/null +++ b/main_test.go @@ -0,0 +1,105 @@ +/* + * Author: Markus Stenberg + * + * Copyright (c) 2024 Markus Stenberg + * + * Created: Thu May 16 07:24:25 2024 mstenber + * Last modified: Thu May 16 08:43:58 2024 mstenber + * Edit time: 25 min + * + */ + +package main + +import ( + "context" + "errors" + "fmt" + "log" + "net/http" + "os" + "strconv" + "testing" + "time" + + "gotest.tools/v3/assert" +) + +func retrieveURL(ctx context.Context, url string) (*http.Response, error) { + client := http.Client{} + req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) + if err != nil { + return nil, fmt.Errorf("failed to create http request: %w", err) + } + return client.Do(req) +} + +func waitForURL(ctx context.Context, url string) error { + for { + resp, err := retrieveURL(ctx, url) + if err != nil { + if errors.Is(err, context.DeadlineExceeded) { + return err + } + fmt.Printf("Error making request: %v\n", err) + continue + } + resp.Body.Close() + if resp.StatusCode == http.StatusOK { + return nil + } + select { + case <-ctx.Done(): + return ctx.Err() + default: + time.Sleep(100 * time.Millisecond) + } + } +} + +func TestMain(t *testing.T) { + ctx := context.Background() + ctx, cancel := context.WithCancel(ctx) + t.Cleanup(cancel) + + f, err := os.CreateTemp("", "lixie-test-db-*.json") + if err != nil { + log.Fatal(err) + } + // TODO: Produce test data? + f.Close() + defer os.Remove(f.Name()) + port := 18080 + + // TODO: Produce some sort of Loki fake (or other way to + // ingest precanned input?) + go func() { + err := run(ctx, []string{"lixie", "-port", strconv.Itoa(port), "-db", f.Name(), "-loki-server", "http://localhost:3100"}) + if err != nil { + log.Panic(err) + } + }() + + ctx2, cancel2 := context.WithTimeout(ctx, 1*time.Second) + t.Cleanup(cancel2) + baseURL := fmt.Sprintf("http://localhost:%d", port) + err = waitForURL(ctx2, baseURL) + if err != nil { + log.Panic(err) + } + + t.Parallel() + for _, tli := range topLevelInfos { + t.Run(tli.Title, func(t *testing.T) { + ctx3, cancel3 := context.WithTimeout(ctx, 1*time.Second) + t.Cleanup(cancel3) + + resp, err := retrieveURL(ctx3, fmt.Sprintf("%s%s", baseURL, tli.Path)) + if err != nil { + log.Panic(err) + } + resp.Body.Close() + assert.Equal(t, resp.StatusCode, http.StatusOK) + }) + } +}