Skip to content

Commit

Permalink
Add middleware: cacheall.
Browse files Browse the repository at this point in the history
  • Loading branch information
googollee committed Apr 12, 2024
1 parent ab29eaf commit 3c7a619
Show file tree
Hide file tree
Showing 4 changed files with 51 additions and 17 deletions.
44 changes: 44 additions & 0 deletions cacheall.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
package espresso

import (
"fmt"
"net/http"

"github.com/googollee/go-espresso/codec"
)

func cacheAllError(ctx Context) error {
wr := &responseWriter{
ResponseWriter: ctx.ResponseWriter(),
}
code := http.StatusInternalServerError
defer func() {
perr := recover()

if wr.hasWritten || (ctx.Error() == nil && perr == nil) {
return
}

if httpCoder, ok := ctx.Error().(HTTPError); ok {
code = httpCoder.HTTPCode()
}
wr.WriteHeader(code)

if perr == nil {
perr = ctx.Error()
}

codec := codec.Module.Value(ctx)
if codec == nil {
fmt.Fprintf(wr, "%v", perr)
return
}

_ = codec.EncodeResponse(ctx, perr)
}()

ctx = ctx.WithResponseWriter(wr)
ctx.Next()

return nil
}
17 changes: 1 addition & 16 deletions router.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,28 +64,13 @@ func (g *router) register(ctx *buildtimeContext, fn HandleFunc) {

pattern := ctx.endpoint.Method + " " + path
g.mux.HandleFunc(pattern, func(w http.ResponseWriter, r *http.Request) {
wr := &responseWriter{
ResponseWriter: w,
}

ctx := &runtimeContext{
ctx: r.Context(),
endpoint: &endpoint,
request: r,
response: wr,
response: w,
}

ctx.Next()

if wr.hasWritten || ctx.err != nil {
return
}

code := http.StatusInternalServerError
if httpCoder, ok := ctx.err.(HTTPError); ok {
code = httpCoder.HTTPCode()
}
w.WriteHeader(code)
fmt.Fprint(w, ctx.err.Error())
})
}
6 changes: 5 additions & 1 deletion runtime.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,11 @@ func (c *runtimeContext) ResponseWriter() http.ResponseWriter {
}

func (c *runtimeContext) Next() {
c.endpoint.ChainFuncs[c.chainIndex](c)
index := c.chainIndex
c.chainIndex++
if err := c.endpoint.ChainFuncs[index](c); err != nil {
c.err = err
}
}

func (c *runtimeContext) Error() error {
Expand Down
1 change: 1 addition & 0 deletions server.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ func New() *Espresso {
}

ret.AddModule(codec.Module.ProvideWithFunc(codec.Default))
ret.Use(cacheAllError)

return ret
}
Expand Down

0 comments on commit 3c7a619

Please sign in to comment.