diff --git a/tests/testdtls.nim b/tests/testdtls.nim index 871289a..2ddfc70 100644 --- a/tests/testdtls.nim +++ b/tests/testdtls.nim @@ -81,3 +81,40 @@ suite "DTLS": await allFutures(dtls1.stop(), dtls2.stop(), dtls3.stop()) await allFutures(stun1.stop(), stun2.stop(), stun3.stop()) await allFutures(udp1.close(), udp2.close(), udp3.close()) + + asyncTest "Two DTLS nodes connecting to each other, closing the created connections then re-connect the nodes": + # Related to https://github.com/vacp2p/nim-webrtc/pull/22 + let + localAddr1 = initTAddress("127.0.0.1:4444") + localAddr2 = initTAddress("127.0.0.1:5555") + udp1 = UdpTransport.new(localAddr1) + udp2 = UdpTransport.new(localAddr2) + stun1 = Stun.new(udp1) + stun2 = Stun.new(udp2) + dtls1 = Dtls.new(stun1) + dtls2 = Dtls.new(stun2) + var + serverConnFut = dtls1.accept() + clientConn = await dtls2.connect(localAddr1) + serverConn = await serverConnFut + + await serverConn.write(@[1'u8, 2, 3, 4]) + await clientConn.write(@[5'u8, 6, 7, 8]) + check (await serverConn.read()) == @[5'u8, 6, 7, 8] + check (await clientConn.read()) == @[1'u8, 2, 3, 4] + await allFutures(serverConn.close(), clientConn.close()) + check serverConn.isClosed() and clientConn.isClosed() + + serverConnFut = dtls1.accept() + clientConn = await dtls2.connect(localAddr1) + serverConn = await serverConnFut + + await serverConn.write(@[5'u8, 6, 7, 8]) + await clientConn.write(@[1'u8, 2, 3, 4]) + check (await serverConn.read()) == @[1'u8, 2, 3, 4] + check (await clientConn.read()) == @[5'u8, 6, 7, 8] + + await allFutures(serverConn.close(), clientConn.close()) + await allFutures(dtls1.stop(), dtls2.stop()) + await allFutures(stun1.stop(), stun2.stop()) + await allFutures(udp1.close(), udp2.close()) diff --git a/webrtc/dtls/dtls_connection.nim b/webrtc/dtls/dtls_connection.nim index cce1ca5..b75757f 100644 --- a/webrtc/dtls/dtls_connection.nim +++ b/webrtc/dtls/dtls_connection.nim @@ -21,6 +21,8 @@ logScope: const DtlsConnTracker* = "webrtc.dtls.conn" type + DtlsConnOnClose* = proc() {.raises: [], gcsafe.} + MbedTLSCtx = object ssl: mbedtls_ssl_context config: mbedtls_ssl_config @@ -34,7 +36,7 @@ type DtlsConn* = ref object # DtlsConn is a Dtls connection receiving and sending data using # the underlying Stun Connection - conn*: StunConn # The wrapper protocol Stun Connection + conn: StunConn # The wrapper protocol Stun Connection raddr: TransportAddress # Remote address dataRecv: seq[byte] # data received which will be read by SCTP dataToSend: seq[byte] @@ -43,7 +45,7 @@ type # Close connection management closed: bool - closeEvent: AsyncEvent + onClose: seq[DtlsConnOnClose] # Local and Remote certificate, needed by wrapped protocol DataChannel # and by libp2p @@ -53,6 +55,9 @@ type # Mbed-TLS contexts ctx: MbedTLSCtx +proc isClosed*(self: DtlsConn): bool = + return self.closed + proc getRemoteCertificateCallback( ctx: pointer, pcert: ptr mbedtls_x509_crt, state: cint, pflags: ptr uint32 ): cint {.cdecl.} = @@ -100,7 +105,6 @@ proc new*(T: type DtlsConn, conn: StunConn): T = var self = T(conn: conn) self.raddr = conn.raddr self.closed = false - self.closeEvent = newAsyncEvent() return self proc dtlsConnInit(self: DtlsConn) = @@ -160,10 +164,10 @@ proc connectInit*(self: DtlsConn, ctr_drbg: mbedtls_ctr_drbg_context) = except MbedTLSError as exc: raise newException(WebRtcError, "DTLS - Connect initialization: " & exc.msg, exc) -proc join*(self: DtlsConn) {.async: (raises: [CancelledError]).} = - ## Wait for the Dtls Connection to be closed +proc addOnClose*(self: DtlsConn, onCloseProc: DtlsConnOnClose) = + ## Adds a proc to be called when DtlsConn is closed ## - await self.closeEvent.wait() + self.onClose.add(onCloseProc) proc dtlsHandshake*( self: DtlsConn, isServer: bool @@ -217,7 +221,10 @@ proc close*(self: DtlsConn) {.async: (raises: [CancelledError, WebRtcError]).} = await self.conn.write(self.dataToSend) self.dataToSend = @[] untrackCounter(DtlsConnTracker) - self.closeEvent.fire() + await self.conn.close() + for onCloseProc in self.onClose: + onCloseProc() + self.onClose = @[] proc write*( self: DtlsConn, msg: seq[byte] diff --git a/webrtc/dtls/dtls_transport.nim b/webrtc/dtls/dtls_transport.nim index 273234b..9efd6d5 100644 --- a/webrtc/dtls/dtls_transport.nim +++ b/webrtc/dtls/dtls_transport.nim @@ -28,12 +28,8 @@ logScope: const DtlsTransportTracker* = "webrtc.dtls.transport" type - DtlsConnAndCleanup = object - connection: DtlsConn - cleanup: Future[void].Raising([]) - Dtls* = ref object of RootObj - connections: Table[TransportAddress, DtlsConnAndCleanup] + connections: Table[TransportAddress, DtlsConn] transport: Stun laddr: TransportAddress started: bool @@ -46,7 +42,7 @@ type proc new*(T: type Dtls, transport: Stun): T = var self = T( - connections: initTable[TransportAddress, DtlsConnAndCleanup](), + connections: initTable[TransportAddress, DtlsConn](), transport: transport, laddr: transport.laddr, started: true, @@ -72,10 +68,8 @@ proc stop*(self: Dtls) {.async: (raises: [CancelledError]).} = self.started = false let - allCloses = toSeq(self.connections.values()).mapIt(it.connection.close()) - allCleanup = toSeq(self.connections.values()).mapIt(it.cleanup) + allCloses = toSeq(self.connections.values()).mapIt(it.close()) await noCancel allFutures(allCloses) - await noCancel allFutures(allCleanup) untrackCounter(DtlsTransportTracker) proc localCertificate*(self: Dtls): seq[byte] = @@ -85,14 +79,11 @@ proc localCertificate*(self: Dtls): seq[byte] = proc localAddress*(self: Dtls): TransportAddress = self.laddr -proc cleanupDtlsConn(self: Dtls, conn: DtlsConn) {.async: (raises: []).} = - # Waiting for a connection to be closed to remove it from the table - try: - await conn.join() - except CancelledError as exc: - discard - - self.connections.del(conn.remoteAddress()) +proc addConnToTable(self: Dtls, conn: DtlsConn) = + proc cleanup() = + self.connections.del(conn.remoteAddress()) + self.connections[conn.remoteAddress()] = conn + conn.addOnClose(cleanup) proc accept*( self: Dtls @@ -114,8 +105,7 @@ proc accept*( self.ctr_drbg, self.serverPrivKey, self.serverCert, self.localCert ) await res.dtlsHandshake(true) - self.connections[raddr] = - DtlsConnAndCleanup(connection: res, cleanup: self.cleanupDtlsConn(res)) + self.addConnToTable(res) break except WebRtcError as exc: trace "Handshake fails, try accept another connection", raddr, error = exc.msg @@ -136,8 +126,7 @@ proc connect*( try: await res.dtlsHandshake(false) - self.connections[raddr] = - DtlsConnAndCleanup(connection: res, cleanup: self.cleanupDtlsConn(res)) + self.addConnToTable(res) except WebRtcError as exc: trace "Handshake fails", raddr, error = exc.msg self.connections.del(raddr) diff --git a/webrtc/stun/stun_connection.nim b/webrtc/stun/stun_connection.nim index 2f119da..9d661a2 100644 --- a/webrtc/stun/stun_connection.nim +++ b/webrtc/stun/stun_connection.nim @@ -28,6 +28,7 @@ type StunUsernameProvider* = proc(): string {.raises: [], gcsafe.} StunUsernameChecker* = proc(username: seq[byte]): bool {.raises: [], gcsafe.} StunPasswordProvider* = proc(username: seq[byte]): seq[byte] {.raises: [], gcsafe.} + StunConnOnClose* = proc() {.raises: [], gcsafe.} StunConn* = ref object udp*: UdpTransport # The wrapper protocol: UDP Transport @@ -37,8 +38,10 @@ type stunMsgs*: AsyncQueue[seq[byte]] # stun messages received and to be # processed by the stun message handler handlesFut*: Future[void] # Stun Message handler - closeEvent: AsyncEvent + + # Close connection management closed*: bool + onClose: seq[StunConnOnClose] # Is ice-controlling and iceTiebreaker, not fully implemented yet. iceControlling: bool @@ -201,7 +204,6 @@ proc new*( laddr: udp.laddr, raddr: raddr, closed: false, - closeEvent: newAsyncEvent(), dataRecv: newAsyncQueue[seq[byte]](StunMaxQueuingMessages), stunMsgs: newAsyncQueue[seq[byte]](StunMaxQueuingMessages), iceControlling: iceControlling, @@ -215,10 +217,10 @@ proc new*( trackCounter(StunConnectionTracker) return self -proc join*(self: StunConn) {.async: (raises: [CancelledError]).} = - ## Wait for the Stun Connection to be closed +proc addOnClose*(self: StunConn, onCloseProc: StunConnOnClose) = + ## Adds a proc to be called when StunConn is closed ## - await self.closeEvent.wait() + self.onClose.add(onCloseProc) proc close*(self: StunConn) {.async: (raises: []).} = ## Close a Stun Connection @@ -227,7 +229,9 @@ proc close*(self: StunConn) {.async: (raises: []).} = debug "Try to close an already closed StunConn" return await self.handlesFut.cancelAndWait() - self.closeEvent.fire() + for onCloseProc in self.onClose: + onCloseProc() + self.onClose = @[] self.closed = true untrackCounter(StunConnectionTracker) diff --git a/webrtc/stun/stun_transport.nim b/webrtc/stun/stun_transport.nim index 0fbb1ad..becc593 100644 --- a/webrtc/stun/stun_transport.nim +++ b/webrtc/stun/stun_transport.nim @@ -32,6 +32,12 @@ type rng: ref HmacDrbgContext +proc addConnToTable(self: Stun, conn: StunConn) = + proc cleanup() = + self.connections.del(conn.raddr) + self.connections[conn.raddr] = conn + conn.addOnClose(cleanup) + proc accept*(self: Stun): Future[StunConn] {.async: (raises: [CancelledError]).} = ## Accept a Stun Connection ## @@ -53,17 +59,9 @@ proc connect*( do: let res = StunConn.new(self.udp, raddr, false, self.usernameProvider, self.usernameChecker, self.passwordProvider, self.rng) - self.connections[raddr] = res + self.addConnToTable(res) return res -proc cleanupStunConn(self: Stun, conn: StunConn) {.async: (raises: []).} = - # Waiting for a connection to be closed to remove it from the table - try: - await conn.join() - self.connections.del(conn.raddr) - except CancelledError as exc: - warn "Error cleaning up Stun Connection", error=exc.msg - proc stunReadLoop(self: Stun) {.async: (raises: [CancelledError]).} = while true: let (buf, raddr) = await self.udp.read() @@ -71,9 +69,8 @@ proc stunReadLoop(self: Stun) {.async: (raises: [CancelledError]).} = if not self.connections.hasKey(raddr): stunConn = StunConn.new(self.udp, raddr, true, self.usernameProvider, self.usernameChecker, self.passwordProvider, self.rng) - self.connections[raddr] = stunConn + self.addConnToTable(stunConn) await self.pendingConn.addLast(stunConn) - asyncSpawn self.cleanupStunConn(stunConn) else: try: stunConn = self.connections[raddr]