Skip to content

Commit

Permalink
Graceful shutdown (#32)
Browse files Browse the repository at this point in the history
  • Loading branch information
nitely authored Dec 13, 2024
1 parent 3d1f154 commit 5af7276
Show file tree
Hide file tree
Showing 10 changed files with 230 additions and 65 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ tests/functional/tflowcontrol
tests/functional/tcancel
tests/functional/tcancelremote
tests/functional/tmisc
tests/functional/tgracefulclose
src/hyperx/client
src/hyperx/server
src/hyperx/clientserver
Expand Down
1 change: 1 addition & 0 deletions hyperx.nimble
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ task functest, "Func test":
exec "nim c -r -d:release tests/functional/tcancel.nim"
exec "nim c -r -d:release tests/functional/tcancelremote.nim"
exec "nim c -r -d:release tests/functional/tmisc.nim"
exec "nim c -r -d:release tests/functional/tgracefulclose.nim"

task funcserveinsec, "Func Serve Insecure":
exec "nim c -r -d:release tests/functional/tserverinsecure.nim"
Expand Down
3 changes: 2 additions & 1 deletion src/hyperx/client.nim
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,8 @@ export
ClientContext,
HyperxConnError,
HyperxStrmError,
HyperxError
HyperxError,
isGracefulClose

var sslContext {.threadvar.}: SslContext

Expand Down
116 changes: 76 additions & 40 deletions src/hyperx/clientserver.nim
Original file line number Diff line number Diff line change
Expand Up @@ -129,11 +129,12 @@ type
hostname*: string
port: Port
isConnected*: bool
isGracefulShutdown: bool
headersEnc, headersDec: DynHeaders
streams: Streams
recvMsgs: QueueAsync[Frame]
streamOpenedMsgs*: QueueAsync[Stream]
currStreamId, maxPeerStrmIdSeen: StreamId
currStreamId: StreamId
peerMaxConcurrentStreams: uint32
peerWindowSize: uint32
peerWindow: int32 # can be negative
Expand All @@ -159,13 +160,13 @@ proc newClient*(
hostname: hostname,
port: port,
isConnected: false,
isGracefulShutdown: false,
headersEnc: initDynHeaders(stgHeaderTableSize.int),
headersDec: initDynHeaders(stgHeaderTableSize.int),
streams: initStreams(),
currStreamId: 1.StreamId,
currStreamId: 0.StreamId,
recvMsgs: newQueue[Frame](10),
streamOpenedMsgs: newQueue[Stream](10),
maxPeerStrmIdSeen: 0.StreamId,
peerMaxConcurrentStreams: stgInitialMaxConcurrentStreams,
peerWindow: stgInitialWindowSize.int32,
peerWindowSize: stgInitialWindowSize,
Expand Down Expand Up @@ -209,12 +210,20 @@ func openMainStream(client: ClientContext): Stream {.raises: [StreamsClosedError
doAssert frmSidMain.StreamId notin client.streams
result = client.streams.open(frmSidMain.StreamId, client.peerWindowSize.int32)

func openStream(client: ClientContext): Stream {.raises: [StreamsClosedError].} =
func openStream(client: ClientContext): Stream {.raises: [StreamsClosedError, GracefulShutdownError].} =
# XXX some error if max sid is reached
# XXX error if maxStreams is reached
result = client.streams.open(client.currStreamId, client.peerWindowSize.int32)
# client uses odd numbers, and server even numbers
client.currStreamId += 2.StreamId
doAssert client.typ == ctClient
check not client.isGracefulShutdown, newGracefulShutdownError()
var sid = client.currStreamId.uint32
sid += (if sid == 0: 1 else: 2)
result = client.streams.open(StreamId sid, client.peerWindowSize.int32)
client.currStreamId = StreamId sid

func maxPeerStreamIdSeen(client: ClientContext): StreamId {.raises: [].} =
case client.typ
of ctClient: StreamId 0
of ctServer: client.currStreamId

when defined(hyperxStats):
func echoStats*(client: ClientContext) =
Expand Down Expand Up @@ -377,7 +386,6 @@ const serverHandshakeBlob = handshakeBlob(ctServer)
proc handshakeNaked(client: ClientContext) {.async.} =
doAssert client.isConnected
debugInfo "handshake"
# we need to do this before sending any other frame
let strm = client.openMainStream()
doAssert strm.id == frmSidMain.StreamId
check not client.sock.isClosed, newConnClosedError()
Expand Down Expand Up @@ -417,10 +425,6 @@ func doTransitionRecv(s: Stream, frm: Frame) {.raises: [ConnError, StrmError].}
raise newConnError(errStreamClosed)
raise newConnError(errProtocolError)
s.state = nextState
#if oldState == strmIdle:
# # XXX do this elsewhere not here
# # XXX close streams < s.id in idle state
# discard

proc readUntilEnd(client: ClientContext, frm: Frame) {.async.} =
## Read continuation frames until ``END_HEADERS`` flag is set
Expand Down Expand Up @@ -521,7 +525,7 @@ proc recvTask(client: ClientContext) {.async.} =
# XXX close queues
client.error = newConnError(err.code)
await client.sendSilently newGoAwayFrame(
client.maxPeerStrmIdSeen.int, err.code.int
client.maxPeerStreamIdSeen.int, err.code.int
)
#client.close()
raise err
Expand Down Expand Up @@ -613,10 +617,15 @@ proc consumeMainStream(client: ClientContext, frm: Frame) {.async.} =
if not strm.pingSig.isClosed:
strm.pingSig.trigger()
of frmtGoAway:
# XXX close streams lower than Last-Stream-ID
# XXX don't allow new streams creation
# the connection is still ok for streams lower than Last-Stream-ID
discard
client.isGracefulShutdown = true
client.error ?= newConnError frm.errorCode()
# streams are never created by ctServer,
# so there are no streams to close
if client.typ == ctClient:
let sid = frm.lastStreamId()
for strm in values client.streams:
if strm.id.uint32 > sid:
client.streams.close(strm.id)
else:
doAssert frm.typ notin connFrmAllowed
raise newConnError(errProtocolError)
Expand All @@ -641,19 +650,20 @@ proc recvDispatcherNaked(client: ClientContext) {.async.} =
await consumeMainStream(client, frm)
continue
check frm.typ in frmStreamAllowed, newConnError(errProtocolError)
check frm.sid.int mod 2 != 0, newConnError(errProtocolError)
if client.typ == ctServer and
frm.sid.StreamId > client.maxPeerStrmIdSeen and
frm.sid.int mod 2 != 0:
frm.sid.StreamId > client.currStreamId:
check client.streams.len <= stgServerMaxConcurrentStreams,
newConnError(errProtocolError)
client.maxPeerStrmIdSeen = frm.sid.StreamId
# we do not store idle streams, so no need to close them
let strm = client.streams.open(frm.sid.StreamId, client.peerWindowSize.int32)
await client.streamOpenedMsgs.put strm
if client.typ == ctClient and
frm.sid.StreamId > client.maxPeerStrmIdSeen and
frm.sid.int mod 2 == 0:
client.maxPeerStrmIdSeen = frm.sid.StreamId
if client.isGracefulShutdown:
await client.send newGoAwayFrame(
client.maxPeerStreamIdSeen.int, errNoError.int
)
else:
client.currStreamId = frm.sid.StreamId
# we do not store idle streams, so no need to close them
let strm = client.streams.open(frm.sid.StreamId, client.peerWindowSize.int32)
await client.streamOpenedMsgs.put strm
if frm.typ == frmtHeaders:
headers.setLen 0
client.hpackDecode(headers, frm.payload)
Expand All @@ -667,13 +677,18 @@ proc recvDispatcherNaked(client: ClientContext) {.async.} =
check frm.windowSizeInc > 0, newConnError(errProtocolError)
if frm.typ == frmtPushPromise:
check client.typ == ctClient, newConnError(errProtocolError)
# Process headers even if the stream
# does not exist
# Process headers even if the stream does not exist
if frm.sid.StreamId notin client.streams:
if frm.typ == frmtData:
client.windowPending -= frm.payloadLen.int
check frm.typ in {frmtRstStream, frmtWindowUpdate},
newConnError errStreamClosed
client.windowProcessed += frm.payloadLen.int
if client.windowProcessed > stgWindowSize.int div 2:
client.windowUpdateSig.trigger()
if client.typ == ctServer and
frm.sid.StreamId > client.currStreamId:
doAssert client.isGracefulShutdown
else:
check frm.typ in {frmtRstStream, frmtWindowUpdate},
newConnError errStreamClosed
debugInfo "stream not found " & $frm.sid.int
continue
var stream = client.streams.get frm.sid.StreamId
Expand Down Expand Up @@ -701,7 +716,7 @@ proc recvDispatcher(client: ClientContext) {.async.} =
if client.isConnected:
client.error = newConnError(err.code)
await client.sendSilently newGoAwayFrame(
client.maxPeerStrmIdSeen.int, err.code.int
client.maxPeerStreamIdSeen.int, err.code.int
)
raise err
except StrmError:
Expand Down Expand Up @@ -1023,7 +1038,7 @@ proc recvTask(strm: ClientStream) {.async.} =
if client.isConnected:
client.error = newConnError(err.code)
await client.sendSilently newGoAwayFrame(
client.maxPeerStrmIdSeen.int, err.code.int
client.maxPeerStreamIdSeen.int, err.code.int
)
raise err
except StrmError as err:
Expand Down Expand Up @@ -1222,16 +1237,19 @@ template with*(strm: ClientStream, body: untyped): untyped =
strm.windowEnd()
await failSilently(recvFut)

proc ping(strm: ClientStream) {.async.} =
# this is done for rst pings; only one stream ping
proc ping(client: ClientContext, strm: Stream) {.async.} =
# this is done for rst and go-away pings; only one stream ping
# will ever be in progress
if strm.stream.pingSig.len > 0:
await strm.stream.pingSig.waitFor()
if strm.pingSig.len > 0:
await strm.pingSig.waitFor()
else:
let sig = strm.stream.pingSig.waitFor()
await strm.client.send newPingFrame(strm.stream.id.uint32)
let sig = strm.pingSig.waitFor()
await client.send newPingFrame(strm.id.uint32)
await sig

proc ping(strm: ClientStream) {.async.} =
await strm.client.ping(strm.stream)

proc cancel*(strm: ClientStream, code: ErrorCode) {.async.} =
## This may never return until the stream/conn is closed.
## This can be called multiple times concurrently,
Expand All @@ -1246,6 +1264,24 @@ proc cancel*(strm: ClientStream, code: ErrorCode) {.async.} =
strm.stream.error ?= newStrmError(errStreamClosed)
strm.close()

proc gracefulClose*(client: ClientContext) {.async.} =
# returning early is ok
if client.isGracefulShutdown:
return
# fail silently because it's best effort,
# setting isGracefulShutdown is the only important thing
await failSilently client.send newGoAwayFrame(
int32.high, errNoError.int
)
await failSilently client.ping client.streams.get(StreamId 0)
client.isGracefulShutdown = true
await failSilently client.send newGoAwayFrame(
client.maxPeerStreamIdSeen.int, errNoError.int
)

proc isGracefulClose*(client: ClientContext): bool {.raises: [].} =
result = client.isGracefulShutdown

when defined(hyperxTest):
proc putRecvTestData*(client: ClientContext, data: seq[byte]) {.async.} =
await client.sock.putRecvData data
Expand Down
17 changes: 14 additions & 3 deletions src/hyperx/errors.nim
Original file line number Diff line number Diff line change
Expand Up @@ -56,21 +56,27 @@ type
ConnClosedError* = object of HyperxConnError
ConnError* = object of HyperxConnError
code*: ErrorCode
GracefulShutdownError* = ConnError
StrmError* = object of HyperxStrmError
typ*: HyperxErrTyp
code*: ErrorCode
QueueError* = object of HyperxError
QueueClosedError* = object of QueueError
QueueClosedError* = object of HyperxError

func newHyperxConnError*(msg: string): ref HyperxConnError {.raises: [].} =
result = (ref HyperxConnError)(msg: msg)

func newConnClosedError*(): ref ConnClosedError {.raises: [].} =
func newConnClosedError*: ref ConnClosedError {.raises: [].} =
result = (ref ConnClosedError)(msg: "Connection Closed")

func newConnError*(errCode: ErrorCode): ref ConnError {.raises: [].} =
result = (ref ConnError)(code: errCode, msg: "Connection Error: " & $errCode)

func newConnError*(errCode: uint32): ref ConnError {.raises: [].} =
result = (ref ConnError)(
code: errCode.toErrorCode,
msg: "Connection Error: " & $errCode.toErrorCode
)

func newStrmError*(errCode: ErrorCode, typ = hxLocalErr): ref StrmError {.raises: [].} =
let msg = case typ
of hxLocalErr: "Stream Error: " & $errCode
Expand All @@ -90,3 +96,8 @@ func newErrorOrDefault*(err, default: ref StrmError): ref StrmError {.raises: []
return newError(err)
else:
return default

func newGracefulShutdownError*(): ref GracefulShutdownError {.raises: [].} =
result = (ref GracefulShutdownError)(
code: errNoError, msg: "Connection Error: " & $errNoError
)
37 changes: 21 additions & 16 deletions src/hyperx/frame.nim
Original file line number Diff line number Diff line change
Expand Up @@ -360,12 +360,20 @@ func windowSizeInc*(frm: Frame): uint {.raises: [].} =
result.clearBit 31 # clear reserved byte

func errorCode*(frm: Frame): uint32 {.raises: [].} =
doAssert frm.typ == frmtRstStream
result = 0
result += frm.s[frmHeaderSize+0].uint32 shl 24
result += frm.s[frmHeaderSize+1].uint32 shl 16
result += frm.s[frmHeaderSize+2].uint32 shl 8
result += frm.s[frmHeaderSize+3].uint32
result = 0'u32
case frm.typ
of frmtRstStream:
result += frm.s[frmHeaderSize+0].uint32 shl 24
result += frm.s[frmHeaderSize+1].uint32 shl 16
result += frm.s[frmHeaderSize+2].uint32 shl 8
result += frm.s[frmHeaderSize+3].uint32
of frmtGoAway:
result += frm.s[frmHeaderSize+4].uint32 shl 24
result += frm.s[frmHeaderSize+5].uint32 shl 16
result += frm.s[frmHeaderSize+6].uint32 shl 8
result += frm.s[frmHeaderSize+7].uint32
else:
doAssert false

func pingData*(frm: Frame): uint32 {.raises: [].} =
# note we ignore the last 4 bytes
Expand All @@ -376,16 +384,13 @@ func pingData*(frm: Frame): uint32 {.raises: [].} =
result += frm.s[frmHeaderSize+2].uint32 shl 8
result += frm.s[frmHeaderSize+3].uint32

# XXX add padding field and padding as payload
#func setPadding*(frm: Frame, n: FrmPadding) =
# doAssert frm.typ in {frmtData, frmtHeaders, frmtPushPromise}

#func add*(frm: Frame, payload: openArray[byte]) =
# frm.s.add payload
# frm.setPayloadLen FrmPayloadLen(frm.rawLen-frmHeaderSize)

#template payload*(frm: Frame): untyped =
# toOpenArray(frm.s, frmHeaderSize, frm.s.len-1)
func lastStreamId*(frm: Frame): uint32 =
doAssert frm.typ == frmtGoAway
result = 0'u32
result += frm.s[frmHeaderSize+0].uint32 shl 24
result += frm.s[frmHeaderSize+1].uint32 shl 16
result += frm.s[frmHeaderSize+2].uint32 shl 8
result += frm.s[frmHeaderSize+3].uint32

func `$`*(frm: Frame): string {.raises: [].} =
result = ""
Expand Down
4 changes: 3 additions & 1 deletion src/hyperx/server.nim
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,9 @@ export
ClientContext,
HyperxConnError,
HyperxStrmError,
HyperxError
HyperxError,
gracefulClose,
isGracefulClose

var sslContext {.threadvar.}: SslContext

Expand Down
10 changes: 7 additions & 3 deletions src/hyperx/stream.nim
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import ./frame
import ./value
import ./signal
import ./errors
import ./utils

# Section 5.1
type
Expand Down Expand Up @@ -198,8 +199,12 @@ proc close*(stream: Stream) {.raises: [].} =
stream.peerWindowUpdateSig.close()
stream.pingSig.close()

type StreamsClosedError* = object of QueueClosedError

func newStreamsClosedError*(msg: string): ref StreamsClosedError {.raises: [].} =
result = (ref StreamsClosedError)(msg: msg)

type
StreamsClosedError* = object of HyperxError
Streams* = object
t: Table[StreamId, Stream]
isClosed: bool
Expand Down Expand Up @@ -232,8 +237,7 @@ func open*(
peerWindow: int32
): Stream {.raises: [StreamsClosedError].} =
doAssert sid notin s.t, $sid.int
if s.isClosed:
raise newException(StreamsClosedError, "Cannot open stream")
check not s.isClosed, newStreamsClosedError("Cannot open stream")
result = newStream(sid, peerWindow)
s.t[sid] = result

Expand Down
Loading

0 comments on commit 5af7276

Please sign in to comment.