Skip to content

Commit

Permalink
feat: add SetHandlers when fast fail for no valid host and invalid rP…
Browse files Browse the repository at this point in the history
…ath (#1057)
  • Loading branch information
kingcanfish authored Mar 10, 2024
1 parent 81f0c83 commit b7cbc9d
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 1 deletion.
57 changes: 56 additions & 1 deletion pkg/app/server/hertz_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,10 @@ func formatAsDate(t time.Time) string {
}

// copied from router
var default400Body = []byte("400 bad request")
var (
default400Body = []byte("400 bad request")
requiredHostBody = []byte("missing required Host header")
)

func TestServer_Use(t *testing.T) {
router := New()
Expand Down Expand Up @@ -284,6 +287,13 @@ func TestNotAbsolutePath(t *testing.T) {

func TestNotAbsolutePathWithRawPath(t *testing.T) {
engine := New(WithHostPorts("127.0.0.1:9991"), WithUseRawPath(true))
const (
MiddlewareKey = "middleware_key"
MiddlewareValue = "middleware_value"
)
engine.Use(func(c context.Context, ctx *app.RequestContext) {
ctx.Response.Header.Set(MiddlewareKey, MiddlewareValue)
})
engine.POST("/", func(c context.Context, ctx *app.RequestContext) {
})
engine.POST("/a", func(c context.Context, ctx *app.RequestContext) {
Expand All @@ -301,6 +311,8 @@ func TestNotAbsolutePathWithRawPath(t *testing.T) {
engine.ServeHTTP(context.Background(), ctx)
assert.DeepEqual(t, consts.StatusBadRequest, ctx.Response.StatusCode())
assert.DeepEqual(t, default400Body, ctx.Response.Body())
gh := ctx.Response.Header.Get(MiddlewareKey)
assert.DeepEqual(t, MiddlewareValue, gh)

s = "POST a?a=b HTTP/1.1\r\nHost: a.b.c\r\nContent-Length: 5\r\nContent-Type: foo/bar\r\n\r\nabcdef4343"
zr = mock.NewZeroCopyReader(s)
Expand All @@ -312,6 +324,49 @@ func TestNotAbsolutePathWithRawPath(t *testing.T) {
engine.ServeHTTP(context.Background(), ctx)
assert.DeepEqual(t, consts.StatusBadRequest, ctx.Response.StatusCode())
assert.DeepEqual(t, default400Body, ctx.Response.Body())
gh = ctx.Response.Header.Get(MiddlewareKey)
assert.DeepEqual(t, MiddlewareValue, gh)
}

func TestNotValidHost(t *testing.T) {
engine := New(WithHostPorts("127.0.0.1:9992"))
const (
MiddlewareKey = "middleware_key"
MiddlewareValue = "middleware_value"
)
engine.Use(func(c context.Context, ctx *app.RequestContext) {
ctx.Response.Header.Set(MiddlewareKey, MiddlewareValue)
})
engine.POST("/", func(c context.Context, ctx *app.RequestContext) {
})
engine.POST("/a", func(c context.Context, ctx *app.RequestContext) {
})

s := "POST ?a=b HTTP/1.1\r\nHost: \r\nContent-Length: 5\r\nContent-Type: foo/bar\r\n\r\nabcdef4343"
zr := mock.NewZeroCopyReader(s)

ctx := app.NewContext(0)
if err := req.Read(&ctx.Request, zr); err != nil {
t.Fatalf("unexpected error: %s", err)
}
engine.ServeHTTP(context.Background(), ctx)
assert.DeepEqual(t, consts.StatusBadRequest, ctx.Response.StatusCode())
assert.DeepEqual(t, requiredHostBody, ctx.Response.Body())
gh := ctx.Response.Header.Get(MiddlewareKey)
assert.DeepEqual(t, MiddlewareValue, gh)

s = "POST a?a=b HTTP/1.1\r\nContent-Length: 5\r\nContent-Type: foo/bar\r\n\r\nabcdef4343"
zr = mock.NewZeroCopyReader(s)

ctx = app.NewContext(0)
if err := req.Read(&ctx.Request, zr); err != nil {
t.Fatalf("unexpected error: %s", err)
}
engine.ServeHTTP(context.Background(), ctx)
assert.DeepEqual(t, consts.StatusBadRequest, ctx.Response.StatusCode())
assert.DeepEqual(t, requiredHostBody, ctx.Response.Body())
gh = ctx.Response.Header.Get(MiddlewareKey)
assert.DeepEqual(t, MiddlewareValue, gh)
}

func TestWithBasePath(t *testing.T) {
Expand Down
2 changes: 2 additions & 0 deletions pkg/route/engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -726,6 +726,7 @@ func (engine *Engine) ServeHTTP(c context.Context, ctx *app.RequestContext) {

// align with https://datatracker.ietf.org/doc/html/rfc2616#section-5.2
if len(ctx.Request.Host()) == 0 && ctx.Request.Header.IsHTTP11() && bytesconv.B2s(ctx.Request.Method()) != consts.MethodConnect {
ctx.SetHandlers(engine.Handlers)
serveError(c, ctx, consts.StatusBadRequest, requiredHostBody)
return
}
Expand All @@ -743,6 +744,7 @@ func (engine *Engine) ServeHTTP(c context.Context, ctx *app.RequestContext) {

// Follow RFC7230#section-5.3
if rPath == "" || rPath[0] != '/' {
ctx.SetHandlers(engine.Handlers)
serveError(c, ctx, consts.StatusBadRequest, default400Body)
return
}
Expand Down

0 comments on commit b7cbc9d

Please sign in to comment.