Skip to content

Commit

Permalink
Added hello-world-grade test for main and refactored it to be bit mor…
Browse files Browse the repository at this point in the history
…e testable
  • Loading branch information
fingon committed May 16, 2024
1 parent f31273d commit 42bb47c
Show file tree
Hide file tree
Showing 3 changed files with 172 additions and 36 deletions.
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
coverage.out
db.json
lixie
localhost:8080
101 changes: 66 additions & 35 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,12 @@ import (
"fmt"
"io/fs"
"log"
"net"
"net/http"
"os"
"os/signal"
"strconv"
"sync"
"time"

"github.com/a-h/templ"
Expand All @@ -36,20 +40,50 @@ func setupDatabase(config data.DatabaseConfig, path string) *data.Database {
return db
}

func newMux(db *data.Database) http.Handler {
mux := http.NewServeMux()

// Configure the routes
// http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
// Other options: StatusMovedPermanently, StatusFound
// http.Redirect(w, r, "/rule/edit", http.StatusSeeOther)
// })
mux.HandleFunc("/", http.NotFound)
mainHandler := templ.Handler(MainPage(db))
mux.Handle("/{$}", mainHandler)

mux.Handle(topLevelLog.PathMatcher(), logListHandler(db))
mux.Handle(topLevelLog.Path+"/{hash}/ham", logClassifyHandler(db, true))
mux.Handle(topLevelLog.Path+"/{hash}/spam", logClassifyHandler(db, false))

mux.Handle(topLevelLogRule.PathMatcher(), logRuleListHandler(db))
mux.Handle(topLevelLogRule.Path+"/edit", logRuleEditHandler(db))
mux.Handle(topLevelLogRule.Path+"/{id}/delete", logRuleDeleteSpecificHandler(db))
mux.Handle(topLevelLogRule.Path+"/{id}/edit", logRuleEditSpecificHandler(db))

// Static content
staticFS, err := fs.Sub(embedContent, "static")
if err != nil {
log.Panic(err)
}

mux.Handle("/static/", http.StripPrefix("/static/", http.FileServer(http.FS(staticFS))))
return mux
}

// Most of this function is based on
// https://grafana.com/blog/2024/02/09/how-i-write-http-services-in-go-after-13-years/
//
// These might be also useful at some point
//
// getenv func(string) string,
// stdin io.Reader,
// stdout, stderr io.Writer,
func run(
_ context.Context,
ctx context.Context,
args []string) error {
// This would be relevant only if we handled our own context.
// However, http.ListenAndServe catches os.Interrupt so this
// is not necessary:
//
// ctx, cancel := signal.NotifyContext(ctx, os.Interrupt)
// defer cancel()
ctx, cancel := signal.NotifyContext(ctx, os.Interrupt)
defer cancel()

// CLI
flags := flag.NewFlagSet(args[0], flag.ExitOnError)
Expand All @@ -63,40 +97,37 @@ func run(
return err
}

// Static content
staticFS, err := fs.Sub(embedContent, "static")
if err != nil {
log.Panic(err)
}

config := data.DatabaseConfig{LokiServer: *lokiServer,
LokiSelector: *lokiSelector}
db := setupDatabase(config, *dbPath)

// Configure the routes
// http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
// Other options: StatusMovedPermanently, StatusFound
// http.Redirect(w, r, "/rule/edit", http.StatusSeeOther)
// })
http.HandleFunc("/", http.NotFound)
mainHandler := templ.Handler(MainPage(db))
http.Handle("/{$}", mainHandler)

http.Handle(topLevelLog.PathMatcher(), logListHandler(db))
http.Handle(topLevelLog.Path+"/{hash}/ham", logClassifyHandler(db, true))
http.Handle(topLevelLog.Path+"/{hash}/spam", logClassifyHandler(db, false))

http.Handle(topLevelLogRule.PathMatcher(), logRuleListHandler(db))
http.Handle(topLevelLogRule.Path+"/edit", logRuleEditHandler(db))
http.Handle(topLevelLogRule.Path+"/{id}/delete", logRuleDeleteSpecificHandler(db))
http.Handle(topLevelLogRule.Path+"/{id}/edit", logRuleEditSpecificHandler(db))

http.Handle("/static/", http.StripPrefix("/static/", http.FileServer(http.FS(staticFS))))
mux := newMux(db)

// Start the actual server
endpoint := fmt.Sprintf("%s:%d", *address, *port)
fmt.Printf("Listening on %s\n", endpoint)
return http.ListenAndServe(endpoint, nil)
httpServer := &http.Server{
Addr: net.JoinHostPort(*address, strconv.Itoa(*port)),
Handler: mux,
}
go func() {
log.Printf("listening on %s\n", httpServer.Addr)
if err := httpServer.ListenAndServe(); err != nil && err != http.ErrServerClosed {
fmt.Fprintf(os.Stderr, "error listening: %v", err)
}
}()

var wg sync.WaitGroup
wg.Add(1)
go func() {
defer wg.Done()
<-ctx.Done()
shutdownCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
if err := httpServer.Shutdown(shutdownCtx); err != nil {
fmt.Fprintf(os.Stderr, "error shutting down http server: %s\n", err)
}
}()
wg.Wait()
return nil
}

func main() {
Expand Down
105 changes: 105 additions & 0 deletions main_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
/*
* Author: Markus Stenberg <[email protected]>
*
* Copyright (c) 2024 Markus Stenberg
*
* Created: Thu May 16 07:24:25 2024 mstenber
* Last modified: Thu May 16 08:43:58 2024 mstenber
* Edit time: 25 min
*
*/

package main

import (
"context"
"errors"
"fmt"
"log"
"net/http"
"os"
"strconv"
"testing"
"time"

"gotest.tools/v3/assert"
)

func retrieveURL(ctx context.Context, url string) (*http.Response, error) {
client := http.Client{}
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
if err != nil {
return nil, fmt.Errorf("failed to create http request: %w", err)
}
return client.Do(req)
}

func waitForURL(ctx context.Context, url string) error {
for {
resp, err := retrieveURL(ctx, url)
if err != nil {
if errors.Is(err, context.DeadlineExceeded) {
return err
}
fmt.Printf("Error making request: %v\n", err)
continue
}
resp.Body.Close()
if resp.StatusCode == http.StatusOK {
return nil
}
select {
case <-ctx.Done():
return ctx.Err()
default:
time.Sleep(100 * time.Millisecond)
}
}
}

func TestMain(t *testing.T) {
ctx := context.Background()
ctx, cancel := context.WithCancel(ctx)
t.Cleanup(cancel)

f, err := os.CreateTemp("", "lixie-test-db-*.json")
if err != nil {
log.Fatal(err)
}
// TODO: Produce test data?
f.Close()
defer os.Remove(f.Name())
port := 18080

// TODO: Produce some sort of Loki fake (or other way to
// ingest precanned input?)
go func() {
err := run(ctx, []string{"lixie", "-port", strconv.Itoa(port), "-db", f.Name(), "-loki-server", "http://localhost:3100"})
if err != nil {
log.Panic(err)
}
}()

ctx2, cancel2 := context.WithTimeout(ctx, 1*time.Second)
t.Cleanup(cancel2)
baseURL := fmt.Sprintf("http://localhost:%d", port)
err = waitForURL(ctx2, baseURL)
if err != nil {
log.Panic(err)
}

t.Parallel()
for _, tli := range topLevelInfos {
t.Run(tli.Title, func(t *testing.T) {
ctx3, cancel3 := context.WithTimeout(ctx, 1*time.Second)
t.Cleanup(cancel3)

resp, err := retrieveURL(ctx3, fmt.Sprintf("%s%s", baseURL, tli.Path))
if err != nil {
log.Panic(err)
}
resp.Body.Close()
assert.Equal(t, resp.StatusCode, http.StatusOK)
})
}
}

0 comments on commit 42bb47c

Please sign in to comment.