Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: stun and dtls close #22

Merged
merged 4 commits into from
Aug 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 37 additions & 0 deletions tests/testdtls.nim
Original file line number Diff line number Diff line change
Expand Up @@ -81,3 +81,40 @@ suite "DTLS":
await allFutures(dtls1.stop(), dtls2.stop(), dtls3.stop())
await allFutures(stun1.stop(), stun2.stop(), stun3.stop())
await allFutures(udp1.close(), udp2.close(), udp3.close())

asyncTest "Two DTLS nodes connecting to each other, closing the created connections then re-connect the nodes":
# Related to https://github.com/vacp2p/nim-webrtc/pull/22
let
localAddr1 = initTAddress("127.0.0.1:4444")
localAddr2 = initTAddress("127.0.0.1:5555")
udp1 = UdpTransport.new(localAddr1)
udp2 = UdpTransport.new(localAddr2)
stun1 = Stun.new(udp1)
stun2 = Stun.new(udp2)
dtls1 = Dtls.new(stun1)
dtls2 = Dtls.new(stun2)
var
serverConnFut = dtls1.accept()
clientConn = await dtls2.connect(localAddr1)
serverConn = await serverConnFut

await serverConn.write(@[1'u8, 2, 3, 4])
await clientConn.write(@[5'u8, 6, 7, 8])
check (await serverConn.read()) == @[5'u8, 6, 7, 8]
check (await clientConn.read()) == @[1'u8, 2, 3, 4]
await allFutures(serverConn.close(), clientConn.close())
check serverConn.isClosed() and clientConn.isClosed()

serverConnFut = dtls1.accept()
clientConn = await dtls2.connect(localAddr1)
serverConn = await serverConnFut

await serverConn.write(@[5'u8, 6, 7, 8])
await clientConn.write(@[1'u8, 2, 3, 4])
check (await serverConn.read()) == @[1'u8, 2, 3, 4]
check (await clientConn.read()) == @[5'u8, 6, 7, 8]

await allFutures(serverConn.close(), clientConn.close())
await allFutures(dtls1.stop(), dtls2.stop())
await allFutures(stun1.stop(), stun2.stop())
await allFutures(udp1.close(), udp2.close())
21 changes: 14 additions & 7 deletions webrtc/dtls/dtls_connection.nim
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ logScope:
const DtlsConnTracker* = "webrtc.dtls.conn"

type
DtlsConnOnClose* = proc() {.raises: [], gcsafe.}

MbedTLSCtx = object
ssl: mbedtls_ssl_context
config: mbedtls_ssl_config
Expand All @@ -34,7 +36,7 @@ type
DtlsConn* = ref object
# DtlsConn is a Dtls connection receiving and sending data using
# the underlying Stun Connection
conn*: StunConn # The wrapper protocol Stun Connection
conn: StunConn # The wrapper protocol Stun Connection
raddr: TransportAddress # Remote address
dataRecv: seq[byte] # data received which will be read by SCTP
dataToSend: seq[byte]
Expand All @@ -43,7 +45,7 @@ type

# Close connection management
closed: bool
closeEvent: AsyncEvent
onClose: seq[DtlsConnOnClose]

# Local and Remote certificate, needed by wrapped protocol DataChannel
# and by libp2p
Expand All @@ -53,6 +55,9 @@ type
# Mbed-TLS contexts
ctx: MbedTLSCtx

proc isClosed*(self: DtlsConn): bool =
return self.closed

proc getRemoteCertificateCallback(
ctx: pointer, pcert: ptr mbedtls_x509_crt, state: cint, pflags: ptr uint32
): cint {.cdecl.} =
Expand Down Expand Up @@ -100,7 +105,6 @@ proc new*(T: type DtlsConn, conn: StunConn): T =
var self = T(conn: conn)
self.raddr = conn.raddr
self.closed = false
self.closeEvent = newAsyncEvent()
return self

proc dtlsConnInit(self: DtlsConn) =
Expand Down Expand Up @@ -160,10 +164,10 @@ proc connectInit*(self: DtlsConn, ctr_drbg: mbedtls_ctr_drbg_context) =
except MbedTLSError as exc:
raise newException(WebRtcError, "DTLS - Connect initialization: " & exc.msg, exc)

proc join*(self: DtlsConn) {.async: (raises: [CancelledError]).} =
## Wait for the Dtls Connection to be closed
proc addOnClose*(self: DtlsConn, onCloseProc: DtlsConnOnClose) =
## Adds a proc to be called when DtlsConn is closed
##
await self.closeEvent.wait()
self.onClose.add(onCloseProc)

proc dtlsHandshake*(
self: DtlsConn, isServer: bool
Expand Down Expand Up @@ -217,7 +221,10 @@ proc close*(self: DtlsConn) {.async: (raises: [CancelledError, WebRtcError]).} =
await self.conn.write(self.dataToSend)
self.dataToSend = @[]
untrackCounter(DtlsConnTracker)
self.closeEvent.fire()
await self.conn.close()
for onCloseProc in self.onClose:
onCloseProc()
self.onClose = @[]

proc write*(
self: DtlsConn, msg: seq[byte]
Expand Down
31 changes: 10 additions & 21 deletions webrtc/dtls/dtls_transport.nim
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,8 @@ logScope:
const DtlsTransportTracker* = "webrtc.dtls.transport"

type
DtlsConnAndCleanup = object
connection: DtlsConn
cleanup: Future[void].Raising([])

Dtls* = ref object of RootObj
connections: Table[TransportAddress, DtlsConnAndCleanup]
connections: Table[TransportAddress, DtlsConn]
transport: Stun
laddr: TransportAddress
started: bool
Expand All @@ -46,7 +42,7 @@ type

proc new*(T: type Dtls, transport: Stun): T =
var self = T(
connections: initTable[TransportAddress, DtlsConnAndCleanup](),
connections: initTable[TransportAddress, DtlsConn](),
transport: transport,
laddr: transport.laddr,
started: true,
Expand All @@ -72,10 +68,8 @@ proc stop*(self: Dtls) {.async: (raises: [CancelledError]).} =

self.started = false
let
allCloses = toSeq(self.connections.values()).mapIt(it.connection.close())
allCleanup = toSeq(self.connections.values()).mapIt(it.cleanup)
allCloses = toSeq(self.connections.values()).mapIt(it.close())
await noCancel allFutures(allCloses)
await noCancel allFutures(allCleanup)
untrackCounter(DtlsTransportTracker)

proc localCertificate*(self: Dtls): seq[byte] =
Expand All @@ -85,14 +79,11 @@ proc localCertificate*(self: Dtls): seq[byte] =
proc localAddress*(self: Dtls): TransportAddress =
self.laddr

proc cleanupDtlsConn(self: Dtls, conn: DtlsConn) {.async: (raises: []).} =
# Waiting for a connection to be closed to remove it from the table
try:
await conn.join()
except CancelledError as exc:
discard

self.connections.del(conn.remoteAddress())
proc addConnToTable(self: Dtls, conn: DtlsConn) =
proc cleanup() =
self.connections.del(conn.remoteAddress())
self.connections[conn.remoteAddress()] = conn
conn.addOnClose(cleanup)

proc accept*(
self: Dtls
Expand All @@ -114,8 +105,7 @@ proc accept*(
self.ctr_drbg, self.serverPrivKey, self.serverCert, self.localCert
)
await res.dtlsHandshake(true)
self.connections[raddr] =
DtlsConnAndCleanup(connection: res, cleanup: self.cleanupDtlsConn(res))
self.addConnToTable(res)
break
except WebRtcError as exc:
trace "Handshake fails, try accept another connection", raddr, error = exc.msg
Expand All @@ -136,8 +126,7 @@ proc connect*(

try:
await res.dtlsHandshake(false)
self.connections[raddr] =
DtlsConnAndCleanup(connection: res, cleanup: self.cleanupDtlsConn(res))
self.addConnToTable(res)
except WebRtcError as exc:
trace "Handshake fails", raddr, error = exc.msg
self.connections.del(raddr)
Expand Down
16 changes: 10 additions & 6 deletions webrtc/stun/stun_connection.nim
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ type
StunUsernameProvider* = proc(): string {.raises: [], gcsafe.}
StunUsernameChecker* = proc(username: seq[byte]): bool {.raises: [], gcsafe.}
StunPasswordProvider* = proc(username: seq[byte]): seq[byte] {.raises: [], gcsafe.}
StunConnOnClose* = proc() {.raises: [], gcsafe.}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It can be only OnClose, it is clear from the context it is related to a Stun conn.


StunConn* = ref object
udp*: UdpTransport # The wrapper protocol: UDP Transport
Expand All @@ -37,8 +38,10 @@ type
stunMsgs*: AsyncQueue[seq[byte]] # stun messages received and to be
# processed by the stun message handler
handlesFut*: Future[void] # Stun Message handler
closeEvent: AsyncEvent

# Close connection management
closed*: bool
onClose: seq[StunConnOnClose]

# Is ice-controlling and iceTiebreaker, not fully implemented yet.
iceControlling: bool
Expand Down Expand Up @@ -201,7 +204,6 @@ proc new*(
laddr: udp.laddr,
raddr: raddr,
closed: false,
closeEvent: newAsyncEvent(),
dataRecv: newAsyncQueue[seq[byte]](StunMaxQueuingMessages),
stunMsgs: newAsyncQueue[seq[byte]](StunMaxQueuingMessages),
iceControlling: iceControlling,
Expand All @@ -215,10 +217,10 @@ proc new*(
trackCounter(StunConnectionTracker)
return self

proc join*(self: StunConn) {.async: (raises: [CancelledError]).} =
## Wait for the Stun Connection to be closed
proc addOnClose*(self: StunConn, onCloseProc: StunConnOnClose) =
## Adds a proc to be called when StunConn is closed
##
await self.closeEvent.wait()
self.onClose.add(onCloseProc)

proc close*(self: StunConn) {.async: (raises: []).} =
## Close a Stun Connection
Expand All @@ -227,7 +229,9 @@ proc close*(self: StunConn) {.async: (raises: []).} =
debug "Try to close an already closed StunConn"
return
await self.handlesFut.cancelAndWait()
self.closeEvent.fire()
for onCloseProc in self.onClose:
onCloseProc()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you can pass the conn as param.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't know if it's a good idea, the conn is useless at this point (because it's closed)

self.onClose = @[]
self.closed = true
untrackCounter(StunConnectionTracker)

Expand Down
19 changes: 8 additions & 11 deletions webrtc/stun/stun_transport.nim
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,12 @@ type

rng: ref HmacDrbgContext

proc addConnToTable(self: Stun, conn: StunConn) =
proc cleanup() =
self.connections.del(conn.raddr)
self.connections[conn.raddr] = conn
conn.addOnClose(cleanup)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Then cleanup doesn't need to be a closure and can be passed when the conn is created. addConnToTable won't be necessary.


proc accept*(self: Stun): Future[StunConn] {.async: (raises: [CancelledError]).} =
## Accept a Stun Connection
##
Expand All @@ -53,27 +59,18 @@ proc connect*(
do:
let res = StunConn.new(self.udp, raddr, false, self.usernameProvider,
self.usernameChecker, self.passwordProvider, self.rng)
self.connections[raddr] = res
self.addConnToTable(res)
return res

proc cleanupStunConn(self: Stun, conn: StunConn) {.async: (raises: []).} =
# Waiting for a connection to be closed to remove it from the table
try:
await conn.join()
self.connections.del(conn.raddr)
except CancelledError as exc:
warn "Error cleaning up Stun Connection", error=exc.msg

proc stunReadLoop(self: Stun) {.async: (raises: [CancelledError]).} =
while true:
let (buf, raddr) = await self.udp.read()
var stunConn: StunConn
if not self.connections.hasKey(raddr):
stunConn = StunConn.new(self.udp, raddr, true, self.usernameProvider,
self.usernameChecker, self.passwordProvider, self.rng)
self.connections[raddr] = stunConn
self.addConnToTable(stunConn)
await self.pendingConn.addLast(stunConn)
asyncSpawn self.cleanupStunConn(stunConn)
else:
try:
stunConn = self.connections[raddr]
Expand Down
Loading