Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: stun and dtls close #22

Merged
merged 4 commits into from
Aug 30, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 35 additions & 0 deletions tests/testdtls.nim
Original file line number Diff line number Diff line change
Expand Up @@ -81,3 +81,38 @@ suite "DTLS":
await allFutures(dtls1.stop(), dtls2.stop(), dtls3.stop())
await allFutures(stun1.stop(), stun2.stop(), stun3.stop())
await allFutures(udp1.close(), udp2.close(), udp3.close())

asyncTest "Two DTLS nodes connecting to each other, disconnecting and reconnecting":
let
localAddr1 = initTAddress("127.0.0.1:4444")
localAddr2 = initTAddress("127.0.0.1:5555")
udp1 = UdpTransport.new(localAddr1)
udp2 = UdpTransport.new(localAddr2)
stun1 = Stun.new(udp1)
stun2 = Stun.new(udp2)
dtls1 = Dtls.new(stun1)
dtls2 = Dtls.new(stun2)
var
conn1Fut = dtls1.accept()
conn2 = await dtls2.connect(localAddr1)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe better to use names like client and server?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

conn1 = await conn1Fut

await conn1.write(@[1'u8, 2, 3, 4])
await conn2.write(@[5'u8, 6, 7, 8])
check (await conn1.read()) == @[5'u8, 6, 7, 8]
check (await conn2.read()) == @[1'u8, 2, 3, 4]
await allFutures(conn1.close(), conn2.close())

conn1Fut = dtls1.accept()
conn2 = await dtls2.connect(localAddr1)
conn1 = await conn1Fut

await conn1.write(@[5'u8, 6, 7, 8])
await conn2.write(@[1'u8, 2, 3, 4])
check (await conn1.read()) == @[1'u8, 2, 3, 4]
check (await conn2.read()) == @[5'u8, 6, 7, 8]

await allFutures(conn1.close(), conn2.close())
await allFutures(dtls1.stop(), dtls2.stop())
await allFutures(stun1.stop(), stun2.stop())
await allFutures(udp1.close(), udp2.close())
7 changes: 7 additions & 0 deletions webrtc/dtls/dtls_connection.nim
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ logScope:
const DtlsConnTracker* = "webrtc.dtls.conn"

type
DtlsConnCleanup* = proc() {.raises: [], gcsafe.}

MbedTLSCtx = object
ssl: mbedtls_ssl_context
config: mbedtls_ssl_config
Expand All @@ -44,6 +46,7 @@ type
# Close connection management
closed: bool
closeEvent: AsyncEvent
cleanup*: DtlsConnCleanup

# Local and Remote certificate, needed by wrapped protocol DataChannel
# and by libp2p
Expand Down Expand Up @@ -217,6 +220,10 @@ proc close*(self: DtlsConn) {.async: (raises: [CancelledError, WebRtcError]).} =
await self.conn.write(self.dataToSend)
self.dataToSend = @[]
untrackCounter(DtlsConnTracker)
await self.conn.close()
if not self.cleanup.isNil():
self.cleanup()
self.cleanup = nil
self.closeEvent.fire()

proc write*(
Expand Down
31 changes: 10 additions & 21 deletions webrtc/dtls/dtls_transport.nim
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,8 @@ logScope:
const DtlsTransportTracker* = "webrtc.dtls.transport"

type
DtlsConnAndCleanup = object
connection: DtlsConn
cleanup: Future[void].Raising([])

Dtls* = ref object of RootObj
connections: Table[TransportAddress, DtlsConnAndCleanup]
connections: Table[TransportAddress, DtlsConn]
transport: Stun
laddr: TransportAddress
started: bool
Expand All @@ -46,7 +42,7 @@ type

proc new*(T: type Dtls, transport: Stun): T =
var self = T(
connections: initTable[TransportAddress, DtlsConnAndCleanup](),
connections: initTable[TransportAddress, DtlsConn](),
transport: transport,
laddr: transport.laddr,
started: true,
Expand All @@ -72,10 +68,8 @@ proc stop*(self: Dtls) {.async: (raises: [CancelledError]).} =

self.started = false
let
allCloses = toSeq(self.connections.values()).mapIt(it.connection.close())
allCleanup = toSeq(self.connections.values()).mapIt(it.cleanup)
allCloses = toSeq(self.connections.values()).mapIt(it.close())
await noCancel allFutures(allCloses)
await noCancel allFutures(allCleanup)
untrackCounter(DtlsTransportTracker)

proc localCertificate*(self: Dtls): seq[byte] =
Expand All @@ -85,14 +79,11 @@ proc localCertificate*(self: Dtls): seq[byte] =
proc localAddress*(self: Dtls): TransportAddress =
self.laddr

proc cleanupDtlsConn(self: Dtls, conn: DtlsConn) {.async: (raises: []).} =
# Waiting for a connection to be closed to remove it from the table
try:
await conn.join()
except CancelledError as exc:
discard

self.connections.del(conn.remoteAddress())
proc addConnToTable(self: Dtls, conn: DtlsConn) =
proc cleanup() =
self.connections.del(conn.remoteAddress())
self.connections[conn.remoteAddress()] = conn
conn.cleanup = cleanup

proc accept*(
self: Dtls
Expand All @@ -114,8 +105,7 @@ proc accept*(
self.ctr_drbg, self.serverPrivKey, self.serverCert, self.localCert
)
await res.dtlsHandshake(true)
self.connections[raddr] =
DtlsConnAndCleanup(connection: res, cleanup: self.cleanupDtlsConn(res))
self.addConnToTable(res)
break
except WebRtcError as exc:
trace "Handshake fails, try accept another connection", raddr, error = exc.msg
Expand All @@ -136,8 +126,7 @@ proc connect*(

try:
await res.dtlsHandshake(false)
self.connections[raddr] =
DtlsConnAndCleanup(connection: res, cleanup: self.cleanupDtlsConn(res))
self.addConnToTable(res)
except WebRtcError as exc:
trace "Handshake fails", raddr, error = exc.msg
self.connections.del(raddr)
Expand Down
7 changes: 7 additions & 0 deletions webrtc/stun/stun_connection.nim
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ type
StunUsernameProvider* = proc(): string {.raises: [], gcsafe.}
StunUsernameChecker* = proc(username: seq[byte]): bool {.raises: [], gcsafe.}
StunPasswordProvider* = proc(username: seq[byte]): seq[byte] {.raises: [], gcsafe.}
StunConnCleanup* = proc() {.raises: [], gcsafe.}

StunConn* = ref object
udp*: UdpTransport # The wrapper protocol: UDP Transport
Expand All @@ -37,8 +38,11 @@ type
stunMsgs*: AsyncQueue[seq[byte]] # stun messages received and to be
# processed by the stun message handler
handlesFut*: Future[void] # Stun Message handler

# Close connection management
closeEvent: AsyncEvent
closed*: bool
cleanup*: StunConnCleanup

# Is ice-controlling and iceTiebreaker, not fully implemented yet.
iceControlling: bool
Expand Down Expand Up @@ -227,6 +231,9 @@ proc close*(self: StunConn) {.async: (raises: []).} =
debug "Try to close an already closed StunConn"
return
await self.handlesFut.cancelAndWait()
if not self.cleanup.isNil():
self.cleanup()
self.cleanup = nil
self.closeEvent.fire()
self.closed = true
untrackCounter(StunConnectionTracker)
Expand Down
19 changes: 8 additions & 11 deletions webrtc/stun/stun_transport.nim
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,12 @@ type

rng: ref HmacDrbgContext

proc addConnToTable(self: Stun, conn: StunConn) =
proc cleanup() =
self.connections.del(conn.raddr)
self.connections[conn.raddr] = conn
conn.cleanup = cleanup

proc accept*(self: Stun): Future[StunConn] {.async: (raises: [CancelledError]).} =
## Accept a Stun Connection
##
Expand All @@ -53,27 +59,18 @@ proc connect*(
do:
let res = StunConn.new(self.udp, raddr, false, self.usernameProvider,
self.usernameChecker, self.passwordProvider, self.rng)
self.connections[raddr] = res
self.addConnToTable(res)
return res

proc cleanupStunConn(self: Stun, conn: StunConn) {.async: (raises: []).} =
# Waiting for a connection to be closed to remove it from the table
try:
await conn.join()
self.connections.del(conn.raddr)
except CancelledError as exc:
warn "Error cleaning up Stun Connection", error=exc.msg

proc stunReadLoop(self: Stun) {.async: (raises: [CancelledError]).} =
while true:
let (buf, raddr) = await self.udp.read()
var stunConn: StunConn
if not self.connections.hasKey(raddr):
stunConn = StunConn.new(self.udp, raddr, true, self.usernameProvider,
self.usernameChecker, self.passwordProvider, self.rng)
self.connections[raddr] = stunConn
self.addConnToTable(stunConn)
await self.pendingConn.addLast(stunConn)
asyncSpawn self.cleanupStunConn(stunConn)
else:
try:
stunConn = self.connections[raddr]
Expand Down
Loading