diff --git a/azure-pipelines.yml b/azure-pipelines.yml index 9006cf6e50cd8..13ea7874e1ac9 100644 --- a/azure-pipelines.yml +++ b/azure-pipelines.yml @@ -115,7 +115,7 @@ jobs: displayName: 'Install dependencies (i386 Linux)' condition: and(eq(variables['Agent.OS'], 'Linux'), eq(variables['CPU'], 'i386')) - - bash: brew install boehmgc make sfml + - bash: brew install boehmgc make sfml openssl displayName: 'Install dependencies (OSX)' condition: eq(variables['Agent.OS'], 'Darwin') diff --git a/lib/pure/net.nim b/lib/pure/net.nim index 343cdc9b1b77d..23f75c4a5ee67 100644 --- a/lib/pure/net.nim +++ b/lib/pure/net.nim @@ -120,7 +120,6 @@ when defineSsl: SslContext* = ref object context*: SslCtx referencedData: HashSet[int] - extraInternal: SslContextExtraInternal SslAcceptResult* = enum AcceptNoClient = 0, AcceptNoHandshake, AcceptSuccess @@ -128,13 +127,9 @@ when defineSsl: SslHandshakeType* = enum handshakeAsClient, handshakeAsServer - SslClientGetPskFunc* = proc(hint: string): tuple[identity: string, psk: string] + SslClientGetPskFunc* = proc(hint: string): tuple[identity: string, psk: string]{.nimcall.} - SslServerGetPskFunc* = proc(identity: string): string - - SslContextExtraInternal = ref object of RootRef - serverGetPskFunc: SslServerGetPskFunc - clientGetPskFunc: SslClientGetPskFunc + SslServerGetPskFunc* = proc(identity: string): string{.nimcall.} else: type @@ -611,6 +606,9 @@ when defineSsl: when not defined(openssl10) and not defined(libressl): let sslVersion = getOpenSSLVersion() if sslVersion >= 0x010101000 and not sslVersion == 0x020000000: + # XXX always false! + # XXX however, setting non-1.3 ciphers with ciphersuites will + # XXX cause an error, cipherList needs to be split into 1.3 and non-1.3 # In OpenSSL >= 1.1.1, TLSv1.3 cipher suites can only be configured via # this API. if newCTX.SSL_CTX_set_ciphersuites(cipherList) != 1: @@ -658,11 +656,7 @@ when defineSsl: if not found: raise newException(IOError, "No SSL/TLS CA certificates found.") - result = SslContext(context: newCTX, referencedData: initHashSet[int](), - extraInternal: new(SslContextExtraInternal)) - - proc getExtraInternal(ctx: SslContext): SslContextExtraInternal = - return ctx.extraInternal + result = SslContext(context: newCTX, referencedData: initHashSet[int]()) proc destroyContext*(ctx: SslContext) = ## Free memory referenced by SslContext. @@ -678,56 +672,51 @@ when defineSsl: ## Sets the identity hint passed to server. ## ## Only used in PSK ciphersuites. - if ctx.context.SSL_CTX_use_psk_identity_hint(hint) <= 0: + if ctx.context.SSL_CTX_use_psk_identity_hint(hint.cstring) <= 0: raiseSSLError() - proc clientGetPskFunc*(ctx: SslContext): SslClientGetPskFunc = - return ctx.getExtraInternal().clientGetPskFunc - - proc pskClientCallback(ssl: SslPtr; hint: cstring; identity: cstring; - max_identity_len: cuint; psk: ptr cuchar; - max_psk_len: cuint): cuint {.cdecl.} = - let ctx = SslContext(context: ssl.SSL_get_SSL_CTX) - let hintString = if hint == nil: "" else: $hint - let (identityString, pskString) = (ctx.clientGetPskFunc)(hintString) - if pskString.len.cuint > max_psk_len: - return 0 - if identityString.len.cuint >= max_identity_len: - return 0 - copyMem(identity, identityString.cstring, identityString.len + 1) # with the last zero byte - copyMem(psk, pskString.cstring, pskString.len) - - return pskString.len.cuint - - proc `clientGetPskFunc=`*(ctx: SslContext, fun: SslClientGetPskFunc) = - ## Sets function that returns the client identity and the PSK based on identity - ## hint from the server. - ## - ## Only used in PSK ciphersuites. - ctx.getExtraInternal().clientGetPskFunc = fun - ctx.context.SSL_CTX_set_psk_client_callback( - if fun == nil: nil else: pskClientCallback) - - proc serverGetPskFunc*(ctx: SslContext): SslServerGetPskFunc = - return ctx.getExtraInternal().serverGetPskFunc + template genpskServerCallback(pskfunc: SslServerGetPskFunc): auto = + proc pskServerCallback(ssl: SslCtx; identity: cstring; psk: ptr cuchar; + max_psk_len: cint): cuint {.cdecl.} = + let pskString = pskfunc($identity) + if pskString.len.cint > max_psk_len: + return 0 + copyMem(psk, pskString.cstring, pskString.len) + return pskString.len.cuint - proc pskServerCallback(ssl: SslCtx; identity: cstring; psk: ptr cuchar; - max_psk_len: cint): cuint {.cdecl.} = - let ctx = SslContext(context: ssl.SSL_get_SSL_CTX) - let pskString = (ctx.serverGetPskFunc)($identity) - if pskString.len.cint > max_psk_len: - return 0 - copyMem(psk, pskString.cstring, pskString.len) + pskServerCallback - return pskString.len.cuint - - proc `serverGetPskFunc=`*(ctx: SslContext, fun: SslServerGetPskFunc) = + proc `serverGetPskFunc=`*(ctx: SslContext, fun: static SslServerGetPskFunc) = ## Sets function that returns PSK based on the client identity. + ## Call with nil to remove the callback ## ## Only used in PSK ciphersuites. - ctx.getExtraInternal().serverGetPskFunc = fun - ctx.context.SSL_CTX_set_psk_server_callback(if fun == nil: nil - else: pskServerCallback) + ctx.context.SSL_CTX_set_psk_server_callback( + when fun.isNil: nil else: genpskServerCallback(fun)) + + template genpskClientCallback(pskfunc: SslClientGetPskFunc): auto = + proc pskClientCallback(ssl: SslPtr; hint: cstring; identity: cstring; + max_identity_len: cuint; psk: ptr cuchar; + max_psk_len: cuint): cuint {.cdecl.} = + let (identityString, pskString) = pskfunc($hint) + if pskString.len.cuint > max_psk_len: + return 0 + if identityString.len.cuint >= max_identity_len: + return 0 + copyMem(identity, identityString.cstring, identityString.len + 1) # with the last zero byte + copyMem(psk, pskString.cstring, pskString.len) + return pskString.len.cuint + + pskClientCallback + + proc `clientGetPskFunc=`*(ctx: SslContext, fun: static SslClientGetPskFunc) = + ## Sets function that returns the client identity and the PSK based on identity + ## hint from the server. + ## Call with nil to remove the callback. + ## + ## Only used in PSK ciphersuites. + ctx.context.SSL_CTX_set_psk_client_callback( + when fun.isNil: nil else: genpskClientCallback(fun)) proc getPskIdentity*(socket: Socket): string = ## Gets the PSK identity provided by the client. @@ -754,7 +743,6 @@ when defineSsl: socket.sslNoShutdown = false if socket.sslHandle == nil: raiseSSLError() - if SSL_set_fd(socket.sslHandle, socket.fd) != 1: raiseSSLError() diff --git a/tests/stdlib/tnetpsk.nim b/tests/stdlib/tnetpsk.nim new file mode 100644 index 0000000000000..67b5396f318e0 --- /dev/null +++ b/tests/stdlib/tnetpsk.nim @@ -0,0 +1,77 @@ +discard """ + joinable:false + batchable:false + matrix: "--threads:on -d:ssl" + targets: "c cpp" + timeout:5 +""" +import std/net +from std/openssl import SSL_CTX_ctrl + +when defined(osx): + {.passl:"-Wl,-rpath,/usr/local/opt/openssl/lib".} + +# using channels_builtin +var serverChannel: Channel[Port] + +proc clientFunc(identityHint: string): tuple[identity: string, psk: string] = + doAssert identityHint == "bartholomew" + return ("aethelfridda", "aethelfridda-loves-" & identityHint) + +proc client(p: Port){.thread.}= + let context = newContext(cipherList = "PSK-AES256-CBC-SHA") + defer: context.destroyContext() + + # turn off tls1_3 to force connection over psk + doAssert context.context.SSL_CTX_ctrl(124, 0x0303, nil) > 0 # SSL_CTX_set_max_proto_version(TLS1_2) + context.clientGetPskFunc = clientFunc + + let sock = newSocket() + defer: sock.close() + + sock.connect("localhost", p) + context.wrapConnectedSocket(sock, handshakeAsClient) + + sock.send("hello from aethelfridda\r\l") + doAssert sock.recvLine() == "goodbye from bartholomew" + +proc server(){.thread.}= + let context = newContext(cipherList="PSK-AES256-CBC-SHA") + context.pskIdentityHint = "bartholomew" + context.serverGetPskFunc = proc(identity: string): string = identity & "-loves-bartholomew" + context.sessionIdContext= "anything" + + let sock = newSocket() + defer: + sock.close() + context.destroyContext() + sock.bindAddr(Port(0)) + let (_, port) = sock.getLocalAddr() + serverChannel.send(port) + sock.listen() + var client = new(Socket) + sock.accept(client) + sock.setSockOpt(OptReuseAddr, true) + context.wrapConnectedSocket(client, handshakeAsServer) + doAssert client.getPskIdentity() == "aethelfridda" + doAssert recvLine(client) == "hello from aethelfridda" + client.send("goodbye from bartholomew\r\l") + +proc main()= + var + srv:Thread[void] + cli:Thread[Port] + serverChannel.open() + defer: serverChannel.close() + + createThread(srv,server) + + # wait for server to bind a port + let port = serverChannel.recv() + + createThread(cli, client, port) + + joinThread(srv) + joinThread(cli) + +main()