diff --git a/tests/runalltests.nim b/tests/runalltests.nim index 764ff3c..3a49806 100644 --- a/tests/runalltests.nim +++ b/tests/runalltests.nim @@ -10,3 +10,4 @@ {.used.} import teststun +import testdtls diff --git a/tests/testdtls.nim b/tests/testdtls.nim new file mode 100644 index 0000000..871289a --- /dev/null +++ b/tests/testdtls.nim @@ -0,0 +1,83 @@ +# Nim-WebRTC +# Copyright (c) 2024 Status Research & Development GmbH +# Licensed under either of +# * Apache License, version 2.0, ([LICENSE-APACHE](LICENSE-APACHE)) +# * MIT license ([LICENSE-MIT](LICENSE-MIT)) +# at your option. +# This file may not be copied, modified, or distributed except according to +# those terms. + +{.used.} + +import chronos +import ../webrtc/udp_transport +import ../webrtc/stun/stun_transport +import ../webrtc/dtls/dtls_transport +import ../webrtc/dtls/dtls_connection +import ./asyncunit + +suite "DTLS": + teardown: + checkLeaks() + + asyncTest "Two DTLS nodes connecting to each other, then sending/receiving data": + let + localAddr1 = initTAddress("127.0.0.1:4444") + localAddr2 = initTAddress("127.0.0.1:5555") + udp1 = UdpTransport.new(localAddr1) + udp2 = UdpTransport.new(localAddr2) + stun1 = Stun.new(udp1) + stun2 = Stun.new(udp2) + dtls1 = Dtls.new(stun1) + dtls2 = Dtls.new(stun2) + conn1Fut = dtls1.accept() + conn2 = await dtls2.connect(localAddr1) + conn1 = await conn1Fut + + await conn1.write(@[1'u8, 2, 3, 4]) + let seq1 = await conn2.read() + check seq1 == @[1'u8, 2, 3, 4] + + await conn2.write(@[5'u8, 6, 7, 8]) + let seq2 = await conn1.read() + check seq2 == @[5'u8, 6, 7, 8] + await allFutures(conn1.close(), conn2.close()) + await allFutures(dtls1.stop(), dtls2.stop()) + await allFutures(stun1.stop(), stun2.stop()) + await allFutures(udp1.close(), udp2.close()) + + asyncTest "Two DTLS nodes connecting to the same DTLS server, sending/receiving data": + let + localAddr1 = initTAddress("127.0.0.1:4444") + localAddr2 = initTAddress("127.0.0.1:5555") + localAddr3 = initTAddress("127.0.0.1:6666") + udp1 = UdpTransport.new(localAddr1) + udp2 = UdpTransport.new(localAddr2) + udp3 = UdpTransport.new(localAddr3) + stun1 = Stun.new(udp1) + stun2 = Stun.new(udp2) + stun3 = Stun.new(udp3) + dtls1 = Dtls.new(stun1) + dtls2 = Dtls.new(stun2) + dtls3 = Dtls.new(stun3) + servConn1Fut = dtls1.accept() + servConn2Fut = dtls1.accept() + clientConn1 = await dtls2.connect(localAddr1) + clientConn2 = await dtls3.connect(localAddr1) + servConn1 = await servConn1Fut + servConn2 = await servConn2Fut + + await servConn1.write(@[1'u8, 2, 3, 4]) + await servConn2.write(@[5'u8, 6, 7, 8]) + await clientConn1.write(@[9'u8, 10, 11, 12]) + await clientConn2.write(@[13'u8, 14, 15, 16]) + check: + (await clientConn1.read()) == @[1'u8, 2, 3, 4] + (await clientConn2.read()) == @[5'u8, 6, 7, 8] + (await servConn1.read()) == @[9'u8, 10, 11, 12] + (await servConn2.read()) == @[13'u8, 14, 15, 16] + await allFutures(servConn1.close(), servConn2.close()) + await allFutures(clientConn1.close(), clientConn2.close()) + await allFutures(dtls1.stop(), dtls2.stop(), dtls3.stop()) + await allFutures(stun1.stop(), stun2.stop(), stun3.stop()) + await allFutures(udp1.close(), udp2.close(), udp3.close()) diff --git a/tests/teststun.nim b/tests/teststun.nim index e2506be..01f8ba9 100644 --- a/tests/teststun.nim +++ b/tests/teststun.nim @@ -55,7 +55,7 @@ suite "Stun message encoding/decoding": decoded == msg messageIntegrity.attributeType == AttrMessageIntegrity.uint16 fingerprint.attributeType == AttrFingerprint.uint16 - conn.close() + await conn.close() await udp.close() asyncTest "Get BindingResponse from BindingRequest + encode & decode": @@ -82,7 +82,7 @@ suite "Stun message encoding/decoding": bindingResponse == decoded messageIntegrity.attributeType == AttrMessageIntegrity.uint16 fingerprint.attributeType == AttrFingerprint.uint16 - conn.close() + await conn.close() await udp.close() suite "Stun checkForError": @@ -114,7 +114,7 @@ suite "Stun checkForError": check: errorMissUsername.getAttribute(ErrorCode).get().getErrorCode() == ECBadRequest - conn.close() + await conn.close() await udp.close() asyncTest "checkForError: UsernameChecker returns false": @@ -136,5 +136,5 @@ suite "Stun checkForError": check: error.getAttribute(ErrorCode).get().getErrorCode() == ECUnauthorized - conn.close() + await conn.close() await udp.close() diff --git a/webrtc.nimble b/webrtc.nimble index ed35db0..da9e66f 100644 --- a/webrtc.nimble +++ b/webrtc.nimble @@ -17,12 +17,15 @@ let lang = getEnv("NIMLANG", "c") # Which backend (c/cpp/js) let flags = getEnv("NIMFLAGS", "") # Extra flags for the compiler let verbose = getEnv("V", "") notin ["", "0"] -let cfg = +var cfg = " --styleCheck:usages --styleCheck:error" & (if verbose: "" else: " --verbosity:0 --hints:off") & " --skipParentCfg --skipUserCfg -f" & " --threads:on --opt:speed" +when defined(windows): + cfg = cfg & " --clib:ws2_32" + import hashes proc runTest(filename: string) = diff --git a/webrtc/dtls/dtls_connection.nim b/webrtc/dtls/dtls_connection.nim new file mode 100644 index 0000000..ca342db --- /dev/null +++ b/webrtc/dtls/dtls_connection.nim @@ -0,0 +1,280 @@ +# Nim-WebRTC +# Copyright (c) 2024 Status Research & Development GmbH +# Licensed under either of +# * Apache License, version 2.0, ([LICENSE-APACHE](LICENSE-APACHE)) +# * MIT license ([LICENSE-MIT](LICENSE-MIT)) +# at your option. +# This file may not be copied, modified, or distributed except according to +# those terms. + +import chronos, chronicles +import + mbedtls/[ + ssl, ssl_cookie, ssl_cache, pk, md, ctr_drbg, rsa, x509, x509_crt, bignum, error, + net_sockets, timing, + ] +import ../errors, ../stun/[stun_connection], ./dtls_utils + +logScope: + topics = "webrtc dtls_conn" + +const DtlsConnTracker* = "webrtc.dtls.conn" + +type + MbedTLSCtx = object + ssl: mbedtls_ssl_context + config: mbedtls_ssl_config + cookie: mbedtls_ssl_cookie_ctx + cache: mbedtls_ssl_cache_context + timer: mbedtls_timing_delay_context + pkey: mbedtls_pk_context + srvcert: mbedtls_x509_crt + ctr_drbg: mbedtls_ctr_drbg_context + + DtlsConn* = ref object + # DtlsConn is a Dtls connection receiving and sending data using + # the underlying Stun Connection + conn*: StunConn # The wrapper protocol Stun Connection + raddr: TransportAddress # Remote address + dataRecv: seq[byte] # data received which will be read by SCTP + dataToSend: seq[byte] + # This sequence is set by synchronous Mbed-TLS `dtlsSend` callbacks + # and sent, if set, once the synchronous functions ends + + # Close connection management + closed: bool + closeEvent: AsyncEvent + + # Local and Remote certificate, needed by wrapped protocol DataChannel + # and by libp2p + localCert: seq[byte] + remoteCert: seq[byte] + + # Mbed-TLS contexts + ctx: MbedTLSCtx + +proc getRemoteCertificateCallback( + ctx: pointer, pcert: ptr mbedtls_x509_crt, state: cint, pflags: ptr uint32 +): cint {.cdecl.} = + # getRemoteCertificateCallback is the procedure called by mbedtls when + # receiving the remote certificate. It's usually used to verify the validity + # of the certificate, we don't do it. We use this procedure to store the remot + # certificate as it's mandatory to have it for the Prologue of the Noise + # protocol, aswell as the localCertificate. + var self = cast[DtlsConn](ctx) + let cert = pcert[] + + self.remoteCert = newSeq[byte](cert.raw.len) + copyMem(addr self.remoteCert[0], cert.raw.p, cert.raw.len) + return 0 + +proc dtlsSend(ctx: pointer, buf: ptr byte, len: uint): cint {.cdecl.} = + # dtlsSend is the procedure called by mbedtls when data needs to be sent. + # As the StunConn's write proc is asynchronous and dtlsSend cannot be async, + # we store the message to be sent and it after the end of the function + # (see write or dtlsHanshake for example). + var self = cast[DtlsConn](ctx) + self.dataToSend = newSeq[byte](len) + if len > 0: + copyMem(addr self.dataToSend[0], buf, len) + trace "dtls send", len + result = len.cint + +proc dtlsRecv(ctx: pointer, buf: ptr byte, len: uint): cint {.cdecl.} = + # dtlsRecv is the procedure called by mbedtls when data needs to be received. + # As we cannot asynchronously await for data to be received, we use a data received + # queue. If this queue is empty, we return `MBEDTLS_ERR_SSL_WANT_READ` for us to await + # when the mbedtls proc resumed (see read or dtlsHandshake for example) + let self = cast[DtlsConn](ctx) + if self.dataRecv.len() == 0: + return MBEDTLS_ERR_SSL_WANT_READ + + copyMem(buf, addr self.dataRecv[0], self.dataRecv.len()) + result = self.dataRecv.len().cint + self.dataRecv = @[] + trace "dtls receive", len, result + +proc new*(T: type DtlsConn, conn: StunConn): T = + ## Initialize a Dtls Connection + ## + var self = T(conn: conn) + self.raddr = conn.raddr + self.closed = false + self.closeEvent = newAsyncEvent() + return self + +proc dtlsConnInit(self: DtlsConn) = + mb_ssl_init(self.ctx.ssl) + mb_ssl_config_init(self.ctx.config) + mb_ssl_conf_rng(self.ctx.config, mbedtls_ctr_drbg_random, self.ctx.ctr_drbg) + mb_ssl_conf_read_timeout(self.ctx.config, 10000) # in milliseconds + mb_ssl_conf_ca_chain(self.ctx.config, self.ctx.srvcert.next, nil) + mb_ssl_set_timer_cb(self.ctx.ssl, self.ctx.timer) + mb_ssl_set_verify(self.ctx.ssl, getRemoteCertificateCallback, self) + mb_ssl_set_bio(self.ctx.ssl, cast[pointer](self), dtlsSend, dtlsRecv, nil) + +proc acceptInit*( + self: DtlsConn, + ctr_drbg: mbedtls_ctr_drbg_context, + pkey: mbedtls_pk_context, + srvcert: mbedtls_x509_crt, + localCert: seq[byte], +) = + try: + self.ctx.ctr_drbg = ctr_drbg + self.ctx.pkey = pkey + self.ctx.srvcert = srvcert + self.localCert = localCert + + self.dtlsConnInit() + mb_ssl_cookie_init(self.ctx.cookie) + mb_ssl_cache_init(self.ctx.cache) + mb_ssl_config_defaults( + self.ctx.config, MBEDTLS_SSL_IS_SERVER, MBEDTLS_SSL_TRANSPORT_DATAGRAM, + MBEDTLS_SSL_PRESET_DEFAULT, + ) + mb_ssl_conf_own_cert(self.ctx.config, self.ctx.srvcert, self.ctx.pkey) + mb_ssl_cookie_setup(self.ctx.cookie, mbedtls_ctr_drbg_random, self.ctx.ctr_drbg) + mb_ssl_conf_dtls_cookies(self.ctx.config, addr self.ctx.cookie) + mb_ssl_setup(self.ctx.ssl, self.ctx.config) + mb_ssl_session_reset(self.ctx.ssl) + mb_ssl_conf_authmode(self.ctx.config, MBEDTLS_SSL_VERIFY_OPTIONAL) + except MbedTLSError as exc: + raise newException(WebRtcError, "DTLS - Accept initialization: " & exc.msg, exc) + +proc connectInit*(self: DtlsConn, ctr_drbg: mbedtls_ctr_drbg_context) = + try: + self.ctx.ctr_drbg = ctr_drbg + self.ctx.pkey = self.ctx.ctr_drbg.generateKey() + self.ctx.srvcert = self.ctx.ctr_drbg.generateCertificate(self.ctx.pkey) + self.localCert = newSeq[byte](self.ctx.srvcert.raw.len) + copyMem(addr self.localCert[0], self.ctx.srvcert.raw.p, self.ctx.srvcert.raw.len) + + self.dtlsConnInit() + mb_ssl_config_defaults( + self.ctx.config, MBEDTLS_SSL_IS_CLIENT, MBEDTLS_SSL_TRANSPORT_DATAGRAM, + MBEDTLS_SSL_PRESET_DEFAULT, + ) + mb_ssl_setup(self.ctx.ssl, self.ctx.config) + mb_ssl_conf_authmode(self.ctx.config, MBEDTLS_SSL_VERIFY_OPTIONAL) + except MbedTLSError as exc: + raise newException(WebRtcError, "DTLS - Connect initialization: " & exc.msg, exc) + +proc join*(self: DtlsConn) {.async: (raises: [CancelledError]).} = + ## Wait for the Dtls Connection to be closed + ## + await self.closeEvent.wait() + +proc dtlsHandshake*( + self: DtlsConn, isServer: bool +) {.async: (raises: [CancelledError, WebRtcError]).} = + var shouldRead = isServer + try: + while self.ctx.ssl.private_state != MBEDTLS_SSL_HANDSHAKE_OVER: + if shouldRead: + if isServer: + case self.raddr.family + of AddressFamily.IPv4: + mb_ssl_set_client_transport_id(self.ctx.ssl, self.raddr.address_v4) + of AddressFamily.IPv6: + mb_ssl_set_client_transport_id(self.ctx.ssl, self.raddr.address_v6) + else: + raiseAssert("Remote address must be IPv4 or IPv6") + let (data, _) = await self.conn.read() + self.dataRecv = data + self.dataToSend = @[] + let res = mb_ssl_handshake_step(self.ctx.ssl) + if self.dataToSend.len() > 0: + await self.conn.write(self.dataToSend) + self.dataToSend = @[] + shouldRead = false + if res == MBEDTLS_ERR_SSL_WANT_WRITE: + continue + elif res == MBEDTLS_ERR_SSL_WANT_READ: + shouldRead = true + continue + elif res == MBEDTLS_ERR_SSL_HELLO_VERIFY_REQUIRED: + mb_ssl_session_reset(self.ctx.ssl) + shouldRead = isServer + continue + elif res != 0: + raise newException(WebRtcError, "DTLS - " & $(res.mbedtls_high_level_strerr())) + except MbedTLSError as exc: + trace "Dtls handshake error", errorMsg = exc.msg + raise newException(WebRtcError, "DTLS - Handshake error", exc) + trackCounter(DtlsConnTracker) + +proc close*(self: DtlsConn) {.async: (raises: [CancelledError, WebRtcError]).} = + ## Close a Dtls Connection + ## + if self.closed: + debug "Try to close an already closed DtlsConn" + return + self.closed = true + self.dataToSend = @[] + let x = mbedtls_ssl_close_notify(addr self.ctx.ssl) + if self.dataToSend.len() > 0: + await self.conn.write(self.dataToSend) + self.dataToSend = @[] + untrackCounter(DtlsConnTracker) + self.closeEvent.fire() + +proc write*(self: DtlsConn, msg: seq[byte]) {.async.} = + ## Write a message using mbedtls_ssl_write + ## + # Mbed-TLS will wrap the message properly and call `dtlsSend` callback. + # `dtlsSend` will store the message to be sent on the higher Stun connection. + if self.closed: + debug "Try to write on an already closed DtlsConn" + return + var buf = msg + try: + self.dataToSend = @[] + let write = mb_ssl_write(self.ctx.ssl, buf) + if self.dataToSend.len() > 0: + await self.conn.write(self.dataToSend) + self.dataToSend = @[] + trace "Dtls write", msgLen = msg.len(), actuallyWrote = write + except MbedTLSError as exc: + trace "Dtls write error", errorMsg = exc.msg + raise exc + +proc read*(self: DtlsConn): Future[seq[byte]] {.async.} = + ## Read the next received message by StunConn. + ## Uncypher it using mbedtls_ssl_read. + ## + # First we read the StunConn using the asynchronous `StunConn.read` procedure. + # When we received data, we stored it in `DtlsConn.dataRecv` and call `dtlsRecv` + # callback using mbedtls in order to decypher it. + if self.closed: + debug "Try to read on an already closed DtlsConn" + return + var res = newSeq[byte](8192) + while true: + let (data, _) = await self.conn.read() + self.dataRecv = data + let length = + mbedtls_ssl_read(addr self.ctx.ssl, cast[ptr byte](addr res[0]), res.len().uint) + if length == MBEDTLS_ERR_SSL_WANT_READ: + continue + if length < 0: + raise newException( + WebRtcError, "DTLS - " & $(length.cint.mbedtls_high_level_strerr()) + ) + res.setLen(length) + return res + +proc remoteCertificate*(conn: DtlsConn): seq[byte] = + ## Get the remote certificate + ## + conn.remoteCert + +proc localCertificate*(conn: DtlsConn): seq[byte] = + ## Get the local certificate + ## + conn.localCert + +proc remoteAddress*(conn: DtlsConn): TransportAddress = + ## Get the remote address + ## + conn.raddr diff --git a/webrtc/dtls/dtls_transport.nim b/webrtc/dtls/dtls_transport.nim new file mode 100644 index 0000000..273234b --- /dev/null +++ b/webrtc/dtls/dtls_transport.nim @@ -0,0 +1,146 @@ +# Nim-WebRTC +# Copyright (c) 2024 Status Research & Development GmbH +# Licensed under either of +# * Apache License, version 2.0, ([LICENSE-APACHE](LICENSE-APACHE)) +# * MIT license ([LICENSE-MIT](LICENSE-MIT)) +# at your option. +# This file may not be copied, modified, or distributed except according to +# those terms. + +import deques, tables, sequtils +import + chronos, + chronicles, + mbedtls/[ + ssl, ssl_cookie, ssl_cache, pk, md, entropy, ctr_drbg, rsa, x509, x509_crt, bignum, + error, net_sockets, timing, + ] +import + ./[dtls_utils, dtls_connection], ../errors, ../stun/[stun_connection, stun_transport] + +logScope: + topics = "webrtc dtls" + +# Implementation of a DTLS client and a DTLS Server by using the Mbed-TLS library. +# Multiple things here are unintuitive partly because of the callbacks +# used by Mbed-TLS which cannot be async. + +const DtlsTransportTracker* = "webrtc.dtls.transport" + +type + DtlsConnAndCleanup = object + connection: DtlsConn + cleanup: Future[void].Raising([]) + + Dtls* = ref object of RootObj + connections: Table[TransportAddress, DtlsConnAndCleanup] + transport: Stun + laddr: TransportAddress + started: bool + ctr_drbg: mbedtls_ctr_drbg_context + entropy: mbedtls_entropy_context + + serverPrivKey: mbedtls_pk_context + serverCert: mbedtls_x509_crt + localCert: seq[byte] + +proc new*(T: type Dtls, transport: Stun): T = + var self = T( + connections: initTable[TransportAddress, DtlsConnAndCleanup](), + transport: transport, + laddr: transport.laddr, + started: true, + ) + + mb_ctr_drbg_init(self.ctr_drbg) + mb_entropy_init(self.entropy) + mb_ctr_drbg_seed(self.ctr_drbg, mbedtls_entropy_func, self.entropy, nil, 0) + + self.serverPrivKey = self.ctr_drbg.generateKey() + self.serverCert = self.ctr_drbg.generateCertificate(self.serverPrivKey) + self.localCert = newSeq[byte](self.serverCert.raw.len) + copyMem(addr self.localCert[0], self.serverCert.raw.p, self.serverCert.raw.len) + trackCounter(DtlsTransportTracker) + return self + +proc stop*(self: Dtls) {.async: (raises: [CancelledError]).} = + ## Stop the Dtls transport. Stop every opened connections. + ## + if not self.started: + warn "Already stopped" + return + + self.started = false + let + allCloses = toSeq(self.connections.values()).mapIt(it.connection.close()) + allCleanup = toSeq(self.connections.values()).mapIt(it.cleanup) + await noCancel allFutures(allCloses) + await noCancel allFutures(allCleanup) + untrackCounter(DtlsTransportTracker) + +proc localCertificate*(self: Dtls): seq[byte] = + ## Local certificate getter + self.localCert + +proc localAddress*(self: Dtls): TransportAddress = + self.laddr + +proc cleanupDtlsConn(self: Dtls, conn: DtlsConn) {.async: (raises: []).} = + # Waiting for a connection to be closed to remove it from the table + try: + await conn.join() + except CancelledError as exc: + discard + + self.connections.del(conn.remoteAddress()) + +proc accept*( + self: Dtls +): Future[DtlsConn] {.async: (raises: [CancelledError, WebRtcError]).} = + ## Accept a Dtls Connection + ## + if not self.started: + raise newException(WebRtcError, "DTLS - Dtls transport not started") + var res: DtlsConn + + while true: + let + stunConn = await self.transport.accept() + raddr = stunConn.raddr + if raddr.family == AddressFamily.IPv4 or raddr.family == AddressFamily.IPv6: + try: + res = DtlsConn.new(stunConn) + res.acceptInit( + self.ctr_drbg, self.serverPrivKey, self.serverCert, self.localCert + ) + await res.dtlsHandshake(true) + self.connections[raddr] = + DtlsConnAndCleanup(connection: res, cleanup: self.cleanupDtlsConn(res)) + break + except WebRtcError as exc: + trace "Handshake fails, try accept another connection", raddr, error = exc.msg + self.connections.del(raddr) + return res + +proc connect*( + self: Dtls, raddr: TransportAddress +): Future[DtlsConn] {.async: (raises: [CancelledError, WebRtcError]).} = + ## Connect to a remote address, creating a Dtls Connection + ## + if not self.started: + raise newException(WebRtcError, "DTLS - Dtls transport not started") + if raddr.family != AddressFamily.IPv4 and raddr.family != AddressFamily.IPv6: + raise newException(WebRtcError, "DTLS - Can only connect to IP address") + var res = DtlsConn.new(await self.transport.connect(raddr)) + res.connectInit(self.ctr_drbg) + + try: + await res.dtlsHandshake(false) + self.connections[raddr] = + DtlsConnAndCleanup(connection: res, cleanup: self.cleanupDtlsConn(res)) + except WebRtcError as exc: + trace "Handshake fails", raddr, error = exc.msg + self.connections.del(raddr) + raise exc + + return res diff --git a/webrtc/dtls/dtls_utils.nim b/webrtc/dtls/dtls_utils.nim new file mode 100644 index 0000000..734eae9 --- /dev/null +++ b/webrtc/dtls/dtls_utils.nim @@ -0,0 +1,81 @@ +# Nim-WebRTC +# Copyright (c) 2024 Status Research & Development GmbH +# Licensed under either of +# * Apache License, version 2.0, ([LICENSE-APACHE](LICENSE-APACHE)) +# * MIT license ([LICENSE-MIT](LICENSE-MIT)) +# at your option. +# This file may not be copied, modified, or distributed except according to +# those terms. + +import std/times +import ../errors + +import mbedtls/[pk, rsa, ctr_drbg, x509_crt, bignum, md, error] + +# This sequence is used for debugging. +const mb_ssl_states* = + @[ + "MBEDTLS_SSL_HELLO_REQUEST", "MBEDTLS_SSL_CLIENT_HELLO", "MBEDTLS_SSL_SERVER_HELLO", + "MBEDTLS_SSL_SERVER_CERTIFICATE", "MBEDTLS_SSL_SERVER_KEY_EXCHANGE", + "MBEDTLS_SSL_CERTIFICATE_REQUEST", "MBEDTLS_SSL_SERVER_HELLO_DONE", + "MBEDTLS_SSL_CLIENT_CERTIFICATE", "MBEDTLS_SSL_CLIENT_KEY_EXCHANGE", + "MBEDTLS_SSL_CERTIFICATE_VERIFY", "MBEDTLS_SSL_CLIENT_CHANGE_CIPHER_SPEC", + "MBEDTLS_SSL_CLIENT_FINISHED", "MBEDTLS_SSL_SERVER_CHANGE_CIPHER_SPEC", + "MBEDTLS_SSL_SERVER_FINISHED", "MBEDTLS_SSL_FLUSH_BUFFERS", + "MBEDTLS_SSL_HANDSHAKE_WRAPUP", "MBEDTLS_SSL_NEW_SESSION_TICKET", + "MBEDTLS_SSL_SERVER_HELLO_VERIFY_REQUEST_SENT", "MBEDTLS_SSL_HELLO_RETRY_REQUEST", + "MBEDTLS_SSL_ENCRYPTED_EXTENSIONS", "MBEDTLS_SSL_END_OF_EARLY_DATA", + "MBEDTLS_SSL_CLIENT_CERTIFICATE_VERIFY", + "MBEDTLS_SSL_CLIENT_CCS_AFTER_SERVER_FINISHED", + "MBEDTLS_SSL_CLIENT_CCS_BEFORE_2ND_CLIENT_HELLO", + "MBEDTLS_SSL_SERVER_CCS_AFTER_SERVER_HELLO", + "MBEDTLS_SSL_CLIENT_CCS_AFTER_CLIENT_HELLO", + "MBEDTLS_SSL_SERVER_CCS_AFTER_HELLO_RETRY_REQUEST", "MBEDTLS_SSL_HANDSHAKE_OVER", + "MBEDTLS_SSL_TLS1_3_NEW_SESSION_TICKET", + "MBEDTLS_SSL_TLS1_3_NEW_SESSION_TICKET_FLUSH", + ] + +template generateKey*(random: mbedtls_ctr_drbg_context): mbedtls_pk_context = + var res: mbedtls_pk_context + mb_pk_init(res) + discard mbedtls_pk_setup(addr res, mbedtls_pk_info_from_type(MBEDTLS_PK_RSA)) + mb_rsa_gen_key(mb_pk_rsa(res), mbedtls_ctr_drbg_random, random, 2048, 65537) + let x = mb_pk_rsa(res) + res + +template generateCertificate*( + random: mbedtls_ctr_drbg_context, issuer_key: mbedtls_pk_context +): mbedtls_x509_crt = + let + name = "C=FR,O=Status,CN=webrtc" + time_format = + try: + initTimeFormat("YYYYMMddHHmmss") + except TimeFormatParseError as exc: + raise newException(WebRtcError, "DTLS - " & exc.msg, exc) + time_from = times.now().format(time_format) + time_to = (times.now() + times.years(1)).format(time_format) + + var write_cert: mbedtls_x509write_cert + var serial_mpi: mbedtls_mpi + mb_x509write_crt_init(write_cert) + mb_x509write_crt_set_md_alg(write_cert, MBEDTLS_MD_SHA256) + mb_x509write_crt_set_subject_key(write_cert, issuer_key) + mb_x509write_crt_set_issuer_key(write_cert, issuer_key) + mb_x509write_crt_set_subject_name(write_cert, name) + mb_x509write_crt_set_issuer_name(write_cert, name) + mb_x509write_crt_set_validity(write_cert, time_from, time_to) + mb_x509write_crt_set_basic_constraints(write_cert, 0, -1) + mb_x509write_crt_set_subject_key_identifier(write_cert) + mb_x509write_crt_set_authority_key_identifier(write_cert) + mb_mpi_init(serial_mpi) + let serial_hex = mb_mpi_read_string(serial_mpi, 16) + mb_x509write_crt_set_serial(write_cert, serial_mpi) + let buf = + try: + mb_x509write_crt_pem(write_cert, 2048, mbedtls_ctr_drbg_random, random) + except MbedTLSError as exc: + raise newException(WebRtcError, "DTLS - " & exc.msg, exc) + var res: mbedtls_x509_crt + mb_x509_crt_parse(res, buf) + res diff --git a/webrtc/stun/stun_connection.nim b/webrtc/stun/stun_connection.nim index ef355ca..2f119da 100644 --- a/webrtc/stun/stun_connection.nim +++ b/webrtc/stun/stun_connection.nim @@ -220,14 +220,14 @@ proc join*(self: StunConn) {.async: (raises: [CancelledError]).} = ## await self.closeEvent.wait() -proc close*(self: StunConn) = +proc close*(self: StunConn) {.async: (raises: []).} = ## Close a Stun Connection ## if self.closed: debug "Try to close an already closed StunConn" return + await self.handlesFut.cancelAndWait() self.closeEvent.fire() - self.handlesFut.cancelSoon() self.closed = true untrackCounter(StunConnectionTracker) diff --git a/webrtc/stun/stun_transport.nim b/webrtc/stun/stun_transport.nim index b61b17f..0fbb1ad 100644 --- a/webrtc/stun/stun_transport.nim +++ b/webrtc/stun/stun_transport.nim @@ -7,7 +7,7 @@ # This file may not be copied, modified, or distributed except according to # those terms. -import tables +import tables, sequtils import chronos, chronicles, bearssl import stun_connection, stun_message, ../udp_transport @@ -22,8 +22,9 @@ type Stun* = ref object connections: Table[TransportAddress, StunConn] pendingConn: AsyncQueue[StunConn] - readingLoop: Future[void] + readingLoop: Future[void].Raising([CancelledError]) udp: UdpTransport + laddr*: TransportAddress usernameProvider: StunUsernameProvider usernameChecker: StunUsernameChecker @@ -84,12 +85,14 @@ proc stunReadLoop(self: Stun) {.async: (raises: [CancelledError]).} = else: await stunConn.dataRecv.addLast(buf) -proc stop(self: Stun) = +proc stop*(self: Stun) {.async: (raises: []).} = ## Stop the Stun transport and close all the connections ## - for conn in self.connections.values(): - conn.close() - self.readingLoop.cancelSoon() + try: + await allFutures(toSeq(self.connections.values()).mapIt(it.close())) + except CancelledError as exc: + discard + await self.readingLoop.cancelAndWait() untrackCounter(StunTransportTracker) proc defaultUsernameProvider(): string = "" @@ -108,12 +111,13 @@ proc new*( ## var self = T( udp: udp, + laddr: udp.laddr, usernameProvider: usernameProvider, usernameChecker: usernameChecker, passwordProvider: passwordProvider, rng: rng ) - self.readingLoop = stunReadLoop() + self.readingLoop = self.stunReadLoop() self.pendingConn = newAsyncQueue[StunConn](StunMaxPendingConnections) trackCounter(StunTransportTracker) return self