Skip to content

Commit

Permalink
Added test case for #27
Browse files Browse the repository at this point in the history
  • Loading branch information
dgrr committed Jan 15, 2022
1 parent 30d1302 commit 1eb5ca1
Show file tree
Hide file tree
Showing 2 changed files with 95 additions and 6 deletions.
16 changes: 15 additions & 1 deletion serverConn.go
Original file line number Diff line number Diff line change
Expand Up @@ -264,14 +264,16 @@ func (sc *serverConn) handleStreams() {
streamPool.Put(strm)

if sc.debug {
sc.logger.Printf("Stream destroyed %d\n", strmID)
sc.logger.Printf("Stream destroyed %d. Open streams: %d\n", strmID, openStreams)
}
}

loop:
for {
select {
case <-sc.maxRequestTimer.C:
reqTimerArmed = false

deleteUntil := 0
for _, strm := range strms {
// the request is due if the startedAt time + maxRequestTime is in the past
Expand Down Expand Up @@ -308,6 +310,10 @@ loop:
when := strm.startedAt.Add(sc.maxRequestTime).Sub(time.Now())
// if the time is negative or zero it triggers imm
sc.maxRequestTimer.Reset(when)

if sc.debug {
sc.logger.Printf("Next request will timeout in %f seconds\n", when.Seconds())
}
}
}
case fr, ok := <-sc.reader:
Expand Down Expand Up @@ -380,9 +386,17 @@ loop:

sc.createStream(sc.c, fr.Type(), strm)

if sc.debug {
sc.logger.Printf("Stream %d created. Open streams: %d\n", strm.ID(), openStreams)
}

if !reqTimerArmed && sc.maxRequestTime > 0 {
reqTimerArmed = true
sc.maxRequestTimer.Reset(sc.maxRequestTime)

if sc.debug {
sc.logger.Printf("Next request will timeout in %f seconds\n", sc.maxRequestTime.Seconds())
}
}
}

Expand Down
85 changes: 80 additions & 5 deletions server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ func getConn(s *Server) (*Conn, net.Listener, error) {
return nc, ln, nc.doHandshake()
}

func makeHeaders(id uint32, enc *HPACK, endStream bool, hs map[string]string) *FrameHeader {
func makeHeaders(id uint32, enc *HPACK, endHeaders, endStream bool, hs map[string]string) *FrameHeader {
fr := AcquireFrameHeader()

fr.SetStream(id)
Expand All @@ -56,7 +56,7 @@ func makeHeaders(id uint32, enc *HPACK, endStream bool, hs map[string]string) *F

h.SetPadding(false)
h.SetEndStream(endStream)
h.SetEndHeaders(true)
h.SetEndHeaders(endHeaders)

return fr
}
Expand Down Expand Up @@ -89,21 +89,21 @@ func testIssue52(t *testing.T) {

msg := []byte("Hello world, how are you doing?")

h1 := makeHeaders(3, c.enc, false, map[string]string{
h1 := makeHeaders(3, c.enc, true, false, map[string]string{
string(StringAuthority): "localhost",
string(StringMethod): "POST",
string(StringPath): "/hello/world",
string(StringScheme): "https",
"Content-Length": strconv.Itoa(len(msg)),
})
h2 := makeHeaders(9, c.enc, false, map[string]string{
h2 := makeHeaders(9, c.enc, true, false, map[string]string{
string(StringAuthority): "localhost",
string(StringMethod): "POST",
string(StringPath): "/hello/world",
string(StringScheme): "https",
"Content-Length": strconv.Itoa(len(msg)),
})
h3 := makeHeaders(7, c.enc, true, map[string]string{
h3 := makeHeaders(7, c.enc, true, true, map[string]string{
string(StringAuthority): "localhost",
string(StringMethod): "GET",
string(StringPath): "/hello/world",
Expand Down Expand Up @@ -153,3 +153,78 @@ func testIssue52(t *testing.T) {
t.Fatalf("expected EOF, got %s", err)
}
}

func TestIssue27(t *testing.T) {
s := &Server{
s: &fasthttp.Server{
Handler: func(ctx *fasthttp.RequestCtx) {
io.WriteString(ctx, "Hello world")
},
ReadTimeout: time.Second * 1,
},
cnf: ServerConfig{
Debug: false,
},
}

c, ln, err := getConn(s)
if err != nil {
t.Fatal(err)
}
defer c.Close()
defer ln.Close()

msg := []byte("Hello world, how are you doing?")

h1 := makeHeaders(3, c.enc, true, false, map[string]string{
string(StringAuthority): "localhost",
string(StringMethod): "POST",
string(StringPath): "/hello/world",
string(StringScheme): "https",
"Content-Length": strconv.Itoa(len(msg)),
})
h2 := makeHeaders(5, c.enc, true, false, map[string]string{
string(StringAuthority): "localhost",
string(StringMethod): "POST",
string(StringPath): "/hello/world",
string(StringScheme): "https",
"Content-Length": strconv.Itoa(len(msg)),
})
h3 := makeHeaders(7, c.enc, false, false, map[string]string{
string(StringAuthority): "localhost",
string(StringMethod): "GET",
string(StringPath): "/hello/world",
string(StringScheme): "https",
"Content-Length": strconv.Itoa(len(msg)),
})

c.writeFrame(h1)
c.writeFrame(h2)

time.Sleep(time.Second)
c.writeFrame(h3)

id := uint32(3)

for i := 0; i < 3; i++ {
fr, err := c.readNext()
if err != nil {
t.Fatal(err)
}

if fr.Stream() != id {
t.Fatalf("Expecting update on stream %d, got %d", id, fr.Stream())
}

if fr.Type() != FrameResetStream {
t.Fatalf("Expecting Reset, got %s", fr.Type())
}

rst := fr.Body().(*RstStream)
if rst.Code() != StreamCanceled {
t.Fatalf("Expecting StreamCanceled, got %s", rst.Code())
}

id += 2
}
}

0 comments on commit 1eb5ca1

Please sign in to comment.