diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 73752ca..19e995c 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -56,4 +56,3 @@ jobs: extra-trusted-public-keys = ocaml.nix-cache.com-1:/xI2h2+56rwFfKyyFVbkJSeGqSIYMC/Je+7XXqGKDIY= - name: "Run nix-build" run: nix-build ./nix/ci/test.nix --argstr ocamlVersion ${{ matrix.setup.ocamlVersion }} - diff --git a/src/ssl.ml b/src/ssl.ml index 56d1f26..53ac299 100644 --- a/src/ssl.ml +++ b/src/ssl.ml @@ -87,9 +87,6 @@ type verify_error = | Error_v_keyusage_no_certsign | Error_v_application_verification -type bigarray = - (char, Bigarray.int8_unsigned_elt, Bigarray.c_layout) Bigarray.Array1.t - external get_error_string : unit -> string = "ocaml_ssl_get_error_string" (** Kept for backwards compatibility *) @@ -211,9 +208,35 @@ type context_type = | Server_context | Both_context -external create_context : +module Modes = struct + type t = int + + (* value taken from openssl/ssl.h *) + let no_mode = 0x000 + let enable_partial_write = 0x001 (* SSL_MODE_ENABLE_PARTIAL_WRITE *) + (*let accept_moving_write_buffer = 0x002: is always set because of GC*) + let auto_retry = 0x004 (* SSL_MODE_AUTO_RETRY *) + let no_auto_chain = 0x008 (* SSL_MODE_RELEASE_BUFFERS *) + let release_buffers = 0x010 (* SSL_MODE_RELEASE_BUFFERS *) + let send_clienthello_time = 0x020 (* SSL_MODE_SEND_CLIENTHELLO_TIME *) + let send_serverhello_time = 0x040 (* SSL_MODE_SEND_SERVERHELLO_TIME *) + let send_fallback_scsv = 0x080 (* SSL_MODE_SEND_FALLBACK_SCSV *) + let async = 0x100 (* SSL_MODE_ASYNC *) + + let (lor) = (lor) + let (land) = (land) + let lnot = lnot + let subset a b = a land (lnot b) = no_mode +end + +external set_mode : context -> Modes.t -> unit = "ocaml_ssl_set_mode" +external clear_mode : context -> Modes.t -> unit = "ocaml_ssl_clear_mode" +external get_mode : context -> Modes.t = "ocaml_ssl_get_mode" + +external raw_create_context : protocol -> context_type + -> Modes.t -> context = "ocaml_ssl_create_context" @@ -454,9 +477,13 @@ external set_hostflags : external set_host : socket -> string -> unit = "ocaml_ssl_set1_host" external set_ip : socket -> string -> unit = "ocaml_ssl_set1_ip" +type bigarray = + (char, Bigarray.int8_unsigned_elt, Bigarray.c_layout) Bigarray.Array1.t + (* Here is the signature of the base communication functions that are implemented below in two versions *) module type Ssl_base = sig + val create_context : ?modes:Modes.t -> protocol -> context_type -> context val connect : socket -> unit val accept : socket -> unit val ssl_shutdown : socket -> bool @@ -471,6 +498,9 @@ end (* Provide the base implementation communication functions that release the OCaml runtime lock, allowing multiple systhreads to execute concurrently. *) module Runtime_unlock_base = struct + let create_context ?(modes = Modes.auto_retry) protocol ctype = + raw_create_context protocol ctype modes + external connect : socket -> unit = "ocaml_ssl_connect" external accept : socket -> unit = "ocaml_ssl_accept" external write : socket -> Bytes.t -> int -> int -> int = "ocaml_ssl_write" @@ -507,6 +537,10 @@ end (* Same as above, but doesn't release the lock. *) module Runtime_lock_base = struct + let create_context ?(modes = Modes.(async lor enable_partial_write)) + protocol ctype = + raw_create_context protocol ctype modes + external get_error : socket -> int -> ssl_error = "ocaml_ssl_get_error_code" [@@noalloc] @@ -559,6 +593,9 @@ module Runtime_lock_base = struct = "ocaml_ssl_write_blocking" [@@noalloc] + (** Allow SSL_write(..., n) to return r with 0 < r < n (i.e. report success + when just a single record has been written *) + let write socket buffer start length = if start < 0 then invalid_arg "Ssl.write: start negative"; if length < 0 then invalid_arg "Ssl.write: length negative"; diff --git a/src/ssl.mli b/src/ssl.mli index f8c24f1..5df3b36 100644 --- a/src/ssl.mli +++ b/src/ssl.mli @@ -89,9 +89,6 @@ type ssl_error = (** See https://www.openssl.org/docs/manmaster/man3/SSL_CTX_set_verify.html *) -type bigarray = - (char, Bigarray.int8_unsigned_elt, Bigarray.c_layout) Bigarray.Array1.t - exception Method_error (** The SSL method could not be initialized. *) @@ -311,8 +308,67 @@ type context_type = | Server_context (** Server connections. *) | Both_context (** Client and server connections. *) -val create_context : protocol -> context_type -> context -(** Create a context. *) +module Modes : sig + (** set of mode *) + type t = int + + val no_mode : t + + (** Allow SSL_write(..., n) to return r with 0 < r < n (i.e. report success + when just a single record has been written *) + val enable_partial_write : t + + (** Never bother the application with retries if the transport is blocking *) + val auto_retry : t + + (** Don't attempt to automatically build certificate chain *) + val no_auto_chain : t + + (** Save RAM by releasing read and write buffers when they're empty. (SSL3 and + TLS only.) Released buffers are freed. *) + val release_buffers : t + + (** Send the current time in the Random fields of the ClientHello and + ServerHello records for compatibility with hypothetical implementations + that require it. *) + val send_clienthello_time : t + val send_serverhello_time : t + + (** Send TLS_FALLBACK_SCSV in the ClientHello. To be set only by + applications that reconnect with a downgraded protocol version; see + draft-ietf-tls-downgrade-scsv-00 for details. DO NOT ENABLE THIS if your + application attempts a normal handshake. Only use this in explicit + fallback retries, following the guidance in + draft-ietf-tls-downgrade-scsv-00. *) + val send_fallback_scsv : t + + (** Support Asynchronous operation *) + val async : t + + (** put togther two sets of modes *) + val ( lor ) : t -> t -> t + + (** conjunction of modes *) + val ( land ) : t -> t -> t + + (** negation of modes *) + val lnot : t -> t + + (** subset on modes*) + val subset : t -> t -> bool +end + +(** Set the given modes in a context (does not clear preset modes) *) +val set_mode : context -> Modes.t -> unit + +(** Clear the given modes in a context *) +val clear_mode : context -> Modes.t -> unit + +(** Get the current mode of a context *) +val get_mode : context -> Modes.t + +val create_context : ?modes:Modes.t -> protocol -> context_type -> context +(** Create a context. Default modes is Modes.(auto_retry) *) val set_min_protocol_version : context -> protocol -> unit (** [set_min_protocol_version ctx proto] sets the minimum supported protocol @@ -571,6 +627,9 @@ val flush : socket -> unit val read : socket -> Bytes.t -> int -> int -> int (** [read sock buf off len] receives data from a connected SSL socket. *) +type bigarray = + (char, Bigarray.int8_unsigned_elt, Bigarray.c_layout) Bigarray.Array1.t + val read_into_bigarray : socket -> bigarray -> int -> int -> int (** [read_into_bigarray sock ba off len] receives data from a connected SSL socket. This function releases the runtime while the read takes place. *) @@ -614,6 +673,10 @@ val output_int : socket -> int -> unit i.e. handling of `EWOULDBLOCK`, `EGAIN`, etc. Additionally, the functions in this module don't perform a copy of application data buffers. *) module Runtime_lock : sig + val create_context : ?modes:Modes.t -> protocol -> context_type -> context + (** same as create_context above, but the default modes are + [Modes.(async lor enable_partial_write] *) + val connect : socket -> unit (** Connect an SSL socket. *) diff --git a/src/ssl_stubs.c b/src/ssl_stubs.c index 180cdaf..6a4ddf2 100644 --- a/src/ssl_stubs.c +++ b/src/ssl_stubs.c @@ -540,8 +540,28 @@ static void set_protocol(SSL_CTX *ssl_context, int protocol) { } } -CAMLprim value ocaml_ssl_create_context(value protocol, value type) { - CAMLparam2(protocol, type); +CAMLprim void ocaml_ssl_set_mode(value ctx, value modes) { + CAMLparam1(ctx); + SSL_CTX_set_mode(Ctx_val(ctx), + SSL_MODE_ACCEPT_MOVING_WRITE_BUFFER | Int_val(modes)); + CAMLreturn0; +} + +CAMLprim void ocaml_ssl_clear_mode(value ctx, value modes) { + CAMLparam1(ctx); + SSL_CTX_clear_mode(Ctx_val(ctx), + ~SSL_MODE_ACCEPT_MOVING_WRITE_BUFFER & Int_val(modes)); + CAMLreturn0; +} + +CAMLprim value ocaml_ssl_get_mode(value ctx, value modes) { + CAMLparam1(ctx); + long r = SSL_CTX_get_mode(Ctx_val(ctx)); + CAMLreturn(Val_int(~SSL_MODE_ACCEPT_MOVING_WRITE_BUFFER & r)); +} + +CAMLprim value ocaml_ssl_create_context(value protocol, value type, value modes) { + CAMLparam3(protocol, type, modes); CAMLlocal1(block); SSL_CTX *ctx; const SSL_METHOD *method = get_method(Int_val(type)); @@ -558,7 +578,7 @@ CAMLprim value ocaml_ssl_create_context(value protocol, value type) { a write retry (since the GC may need to move it). In blocking mode, hide SSL_ERROR_WANT_(READ|WRITE) from us. */ SSL_CTX_set_mode(ctx, - SSL_MODE_ACCEPT_MOVING_WRITE_BUFFER | SSL_MODE_AUTO_RETRY); + SSL_MODE_ACCEPT_MOVING_WRITE_BUFFER | Int_val(modes)); caml_acquire_runtime_system(); block = caml_alloc_custom(&ctx_ops, sizeof(SSL_CTX *), 0, 1); diff --git a/tests/dune b/tests/dune index 26d43f5..d8b3890 100644 --- a/tests/dune +++ b/tests/dune @@ -3,6 +3,11 @@ (modules util) (libraries ssl threads str alcotest)) +(library + (name util_rlock) + (modules util_rlock) + (libraries ssl threads str alcotest)) + (test (name ssl_test) (modules ssl_test) @@ -43,3 +48,9 @@ (modules ssl_io) (libraries ssl alcotest util) (deps ca.pem ca.key server.key server.pem)) + +(test + (name ssl_rlock_io) + (modules ssl_rlock_io) + (libraries ssl alcotest util_rlock) + (deps ca.pem ca.key server.key server.pem)) diff --git a/tests/ssl_rlock_io.ml b/tests/ssl_rlock_io.ml new file mode 100644 index 0000000..0b1b07e --- /dev/null +++ b/tests/ssl_rlock_io.ml @@ -0,0 +1,126 @@ +open Alcotest + +module Ssl = struct + include Ssl + include Ssl.Runtime_lock +end + +module Util = Util_rlock + +let test_verify () = + let addr = Unix.ADDR_INET (Unix.inet_addr_of_string "127.0.0.1", 2342) in + Util.server_thread addr None |> ignore; + + let context = Ssl.create_context TLSv1_3 Client_context in + let ssl = Ssl.open_connection_with_context context addr in + let verify_result = + try + Ssl.verify ssl; + "" + with + | e -> Printexc.to_string e + in + let rec fn () = + try + Ssl.shutdown_connection ssl; + with + Ssl.(Connection_error(Error_want_write|Error_want_read| + Error_want_accept|Error_want_connect|Error_zero_return)) -> + fn () + in + fn (); + check + bool + "no verify errors" + true + (Str.search_forward + (Str.regexp_string "error:00:000000:lib(0)") + verify_result + 0 + > 0) + +let test_set_host () = + let addr = Unix.ADDR_INET (Unix.inet_addr_of_string "127.0.0.1", 2343) in + let pid = Util.server_thread addr None in + + let context = Ssl.create_context TLSv1_3 Client_context in + let domain = Unix.domain_of_sockaddr addr in + let sock = Unix.socket domain Unix.SOCK_STREAM 0 in + let ssl = Ssl.embed_socket sock context in + Ssl.set_host ssl "localhost"; + Unix.connect sock addr; + Unix.set_nonblock sock; + let rec fn () = + try + Ssl.connect ssl; + with + Ssl.(Connection_error(Error_want_write|Error_want_read| + Error_want_accept|Error_want_connect|Error_zero_return)) -> + fn () + in fn (); + + let verify_result = + try + Ssl.verify ssl; + "" + with + | e -> Printexc.to_string e + in + let rec fn () = + try + Ssl.shutdown_connection ssl; + with + Ssl.(Connection_error(Error_want_write|Error_want_read| + Error_want_accept|Error_want_connect|Error_zero_return)) -> + fn () + in + fn (); + check + bool + "no verify errors" + true + (Str.search_forward + (Str.regexp_string "error:00:000000:lib(0)") + verify_result + 0 + > 0); + Unix.kill pid Sys.sigint; + Unix.waitpid [] pid |> ignore + + +let test_read_write () = + let addr = Unix.ADDR_INET (Unix.inet_addr_of_string "127.0.0.1", 2344) in + let pid = Util.server_thread addr (Some (fun _ -> "received")) in + + let context = Ssl.create_context TLSv1_3 Client_context in + let ssl = Ssl.open_connection_with_context context addr in + Unix.set_nonblock (Ssl.file_descr_of_socket ssl); + let send_msg = "send" in + let write_buf = Bytes.create (String.length send_msg) in + let rec fn () = + try Ssl.write ssl write_buf 0 4 |> ignore; + with Ssl.(Write_error(Error_want_write|Error_want_read| + Error_want_accept|Error_want_connect|Error_zero_return)) -> + fn () + in fn (); + let read_buf = Bytes.create 8 in + let rec fn () = + try Ssl.read ssl read_buf 0 8 |> ignore; + with Ssl.(Read_error(Error_want_write|Error_want_read| + Error_want_accept|Error_want_connect|Error_zero_return)) -> + fn () + in fn (); + Ssl.shutdown_connection ssl; + check string "received message" "received" (Bytes.to_string read_buf); + Unix.kill pid Sys.sigint; + Unix.waitpid [] pid |> ignore + +let () = + run + "Ssl io functions with Ssl.Runtime_lock and non blocking socket" + [ ( "IO" + , [ test_case "Verify" `Quick test_verify + ; test_case "Set host" `Quick test_set_host + ; test_case "Read write" `Quick test_read_write + ] ) + ] diff --git a/tests/util_rlock.ml b/tests/util_rlock.ml new file mode 100644 index 0000000..e524bea --- /dev/null +++ b/tests/util_rlock.ml @@ -0,0 +1,98 @@ +module Ssl = struct + include Ssl + + let[@ocaml.alert "-deprecated"] get_error_string = get_error_string +end + +open Ssl +open Ssl.Runtime_lock + +type server_args = + { address : Unix.sockaddr + ; parser : (string -> string) option + } + +let server_rw_loop ssl parser_func = + let rw_loop = ref true in + while !rw_loop do + try + let read_buf = Bytes.create 256 in + let read_bytes = read ssl read_buf 0 256 in + if read_bytes > 0 + then ( + let input = Bytes.to_string read_buf in + let response = parser_func input in + Ssl.write_substring ssl response 0 (String.length response) |> ignore; + Ssl.close_notify ssl |> ignore; + rw_loop := false) + with + | Read_error(Error_want_read|Error_want_accept| + Error_want_connect|Error_want_write|Error_zero_return) -> + () + | Read_error _ -> rw_loop := false + done + +let server_init args = + try + (* Server initialization *) + let socket = Unix.socket Unix.PF_INET Unix.SOCK_STREAM 0 in + Unix.setsockopt socket Unix.SO_REUSEADDR true; + Unix.bind socket args.address; + let context = create_context TLSv1_3 Server_context in + use_certificate context "server.pem" "server.key"; + Ssl.set_context_alpn_select_callback context (fun client_protos -> + List.find_opt (fun opt -> opt = "http/1.1") client_protos); + (* Signal ready and listen for connection *) + Unix.listen socket 1; + Some (socket, context) + with + | exn -> + Printexc.to_string exn |> print_endline; + None + +let server_listen args = + match server_init args with + | None -> + Thread.exit () [@warning "-3"] + | Some (socket, context) -> + let _ = Unix.select [socket] [] [] (-1.0) in + let listen = Unix.accept socket in + Unix.set_nonblock (fst listen); + let ssl = embed_socket (fst listen) context in + let rec fn () = + try + accept ssl; + (* Exit right away unless we need to rw *) + (match args.parser with + | Some parser_func -> server_rw_loop ssl parser_func + | None -> + (); + shutdown ssl; + exit 0) + with + Accept_error(Error_want_read|Error_want_write + |Error_want_connect|Error_want_accept|Error_zero_return) -> + fn () + in + fn () + +let server_thread addr parser = + let args = { address = addr; parser } in + let pid = Unix.fork () in + if pid = 0 then + server_listen args + else + Unix.sleep 1; pid + +let check_ssl_no_error err = + Str.string_partial_match (Str.regexp_string "error:00000000:lib(0)") err 0 + +let[@ocaml.alert "-deprecated"] pp_protocol ppf = function + | SSLv23 -> Format.fprintf ppf "SSLv23" + | SSLv3 -> Format.fprintf ppf "SSLv3" + | TLSv1 -> Format.fprintf ppf "TLSv1" + | TLSv1_1 -> Format.fprintf ppf "TLSv1_1" + | TLSv1_2 -> Format.fprintf ppf "TLSv1_2" + | TLSv1_3 -> Format.fprintf ppf "TLSv1_3" + +let protocol_testable = Alcotest.testable pp_protocol (fun r1 r2 -> r1 == r2)