diff --git a/cacheall.go b/cacheall.go index 86844f1..4857338 100644 --- a/cacheall.go +++ b/cacheall.go @@ -13,28 +13,24 @@ func cacheAllError(ctx Context) error { } code := http.StatusInternalServerError defer func() { - perr := recover() + err := checkError(ctx, recover()) - if wr.hasWritten || (ctx.Error() == nil && perr == nil) { + if wr.hasWritten || err == nil { return } - if httpCoder, ok := ctx.Error().(HTTPError); ok { + if httpCoder, ok := err.(HTTPError); ok { code = httpCoder.HTTPCode() } wr.WriteHeader(code) - if perr == nil { - perr = ctx.Error() - } - codec := codec.Module.Value(ctx) if codec == nil { - fmt.Fprintf(wr, "%v", perr) + fmt.Fprintf(wr, "%v", err) return } - _ = codec.EncodeResponse(ctx, perr) + _ = codec.EncodeResponse(ctx, err) }() ctx = ctx.WithResponseWriter(wr) @@ -42,3 +38,22 @@ func cacheAllError(ctx Context) error { return nil } + +func checkError(ctx Context, perr any) error { + if perr != nil { + return Error(http.StatusInternalServerError, fmt.Errorf("%v", perr)) + } + + for _, err := range []error{ctx.Err(), ctx.Error()} { + if err == nil { + continue + } + + if _, ok := err.(HTTPError); ok { + return err + } + return Error(http.StatusInternalServerError, err) + } + + return nil +} diff --git a/cacheall_test.go b/cacheall_test.go new file mode 100644 index 0000000..ce8a66c --- /dev/null +++ b/cacheall_test.go @@ -0,0 +1,125 @@ +package espresso_test + +import ( + "errors" + "fmt" + "io" + "net/http" + "net/http/httptest" + "sync/atomic" + "testing" + + "github.com/googollee/go-espresso" + "github.com/googollee/go-espresso/codec" + "github.com/googollee/go-espresso/module" +) + +func TestCacheAllMiddleware(t *testing.T) { + tests := []struct { + name string + providers []module.Provider + middlewares []espresso.HandleFunc + wantCode int + wantBody string + }{ + { + name: "MiddlewareError", + middlewares: []espresso.HandleFunc{func(ctx espresso.Context) error { + return errors.New("error") + }}, + wantCode: http.StatusInternalServerError, + wantBody: "error", + }, + { + name: "MiddlewareHTTPError", + middlewares: []espresso.HandleFunc{func(ctx espresso.Context) error { + return espresso.Error(http.StatusGatewayTimeout, errors.New("gateway timeout")) + }}, + wantCode: http.StatusGatewayTimeout, + wantBody: "gateway timeout", + }, + { + name: "MiddlewarePanic", + middlewares: []espresso.HandleFunc{func(ctx espresso.Context) error { + panic("panic") + }}, + wantCode: http.StatusInternalServerError, + wantBody: "panic", + }, + { + name: "MiddlewareErrorWithCodec", + providers: []module.Provider{codec.Provider}, + middlewares: []espresso.HandleFunc{func(ctx espresso.Context) error { + return errors.New("error") + }}, + wantCode: http.StatusInternalServerError, + wantBody: "{\"message\":\"error\"}\n", + }, + { + name: "MiddlewareHTTPErrorWithCodec", + providers: []module.Provider{codec.Provider}, + middlewares: []espresso.HandleFunc{func(ctx espresso.Context) error { + return espresso.Error(http.StatusGatewayTimeout, errors.New("gateway timeout")) + }}, + wantCode: http.StatusGatewayTimeout, + wantBody: "{\"message\":\"gateway timeout\"}\n", + }, + { + name: "MiddlewarePanicWithCodec", + providers: []module.Provider{codec.Provider}, + middlewares: []espresso.HandleFunc{func(ctx espresso.Context) error { + panic("panic") + }}, + wantCode: http.StatusInternalServerError, + wantBody: "{\"message\":\"panic\"}\n", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + espo := espresso.New() + espo.Use(tc.middlewares...) + espo.AddModule(tc.providers...) + + var called int32 + espo.HandleFunc(func(ctx espresso.Context) error { + atomic.AddInt32(&called, 1) + + if err := ctx.Endpoint(http.MethodGet, "/").End(); err != nil { + return err + } + + fmt.Fprint(ctx.ResponseWriter(), "ok") + + return nil + }) + + called = 0 + svr := httptest.NewServer(espo) + defer svr.Close() + + resp, err := http.Get(svr.URL) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + + if got := atomic.LoadInt32(&called); got != 0 { + t.Fatalf("handle func is called") + } + + if got, want := resp.StatusCode, tc.wantCode; got != want { + t.Fatalf("resp.Status = %d, want: %d", got, want) + } + + respBody, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatal(err) + } + + if got, want := string(respBody), tc.wantBody; got != want { + t.Errorf("resp.Body = %q, want: %q", got, want) + } + }) + } +} diff --git a/examples_test.go b/examples_test.go index 975bdcc..cfc7ea2 100644 --- a/examples_test.go +++ b/examples_test.go @@ -28,6 +28,7 @@ func ExampleEspresso() { } espo := espresso.New() + espo.AddModule(codec.Provider) espo.HandleFunc(func(ctx espresso.Context) error { var id int diff --git a/server.go b/server.go index 9149f99..fe97c48 100644 --- a/server.go +++ b/server.go @@ -3,7 +3,6 @@ package espresso import ( "net/http" - "github.com/googollee/go-espresso/codec" "github.com/googollee/go-espresso/module" ) @@ -22,7 +21,6 @@ func New() *Espresso { mux: ret.mux, } - ret.AddModule(codec.Module.ProvideWithFunc(codec.Default)) ret.Use(cacheAllError) return ret diff --git a/test.sh b/test.sh index 112a654..4be3059 100755 --- a/test.sh +++ b/test.sh @@ -1,3 +1,3 @@ #!/bin/sh -GODEBUG=httpmuxgo121=0 go test -v -race ./... +GODEBUG=httpmuxgo121=0 go test -v -race -cover ./...