From 3fabf04f5c623b759d11e4990ffa0d2df9becf55 Mon Sep 17 00:00:00 2001 From: Hannes Mehnert Date: Wed, 14 Apr 2021 00:18:18 +0200 Subject: [PATCH] Remove custom error monad, use Result and Rresult instead --- lib/control.ml | 101 ----------------------------- lib/dune | 2 +- lib/engine.ml | 130 ++++++++++++++++++------------------- lib/engine.mli | 11 ++-- lib/handshake_client.ml | 102 ++++++++++++++--------------- lib/handshake_client.mli | 6 +- lib/handshake_client13.ml | 44 +++++++------ lib/handshake_common.ml | 120 +++++++++++++++++----------------- lib/handshake_server.ml | 132 +++++++++++++++++++------------------- lib/handshake_server.mli | 6 +- lib/handshake_server13.ml | 82 +++++++++++------------ lib/reader.ml | 25 ++++---- lib/reader.mli | 40 ++++++------ lib/state.ml | 4 -- lwt/tls_lwt.ml | 13 ++-- lwt/tls_lwt.mli | 2 +- mirage/tls_mirage.ml | 4 +- tests/feedback.ml | 6 +- tls.opam | 1 + 19 files changed, 366 insertions(+), 465 deletions(-) delete mode 100644 lib/control.ml diff --git a/lib/control.ml b/lib/control.ml deleted file mode 100644 index 6363a84e..00000000 --- a/lib/control.ml +++ /dev/null @@ -1,101 +0,0 @@ -(* - * Monad core - *) - -(* Generic monad core; we could maybe import it from somewhere else. *) -module type Monad = sig - type 'a t - val return : 'a -> 'a t - val bind : 'a t -> ('a -> 'b t) -> 'b t -end - -(* A larger monadic api over the core. *) -module type Monad_ext = sig - type 'a t - val return : 'a -> 'a t - val bind : 'a t -> ('a -> 'b t) -> 'b t - val (>>=) : 'a t -> ('a -> 'b t) -> 'b t - val (>|=) : 'a t -> ('a -> 'b) -> 'b t - val map : ('a -> 'b) -> 'a t -> 'b t - val sequence : 'a t list -> 'a list t - val sequence_ : 'a t list -> unit t - val mapM : ('a -> 'b t) -> 'a list -> 'b list t - val mapM_ : ('a -> 'b t) -> 'a list -> unit t - val foldM : ('a -> 'b -> 'a t) -> 'a -> 'b list -> 'a t -end - -module Monad_ext_make ( M : Monad ) : - Monad_ext with type 'a t = 'a M.t = -struct - type 'a t = 'a M.t - let return = M.return - let bind = M.bind - let (>>=) = M.bind - let map f a = a >>= fun x -> return (f x) - let (>|=) a f = map f a - let rec sequence = function - | [] -> return [] - | m::ms -> m >>= fun m' -> sequence ms >>= fun ms' -> return (m'::ms') - let rec sequence_ = function - | [] -> return () - | m::ms -> m >>= fun _ -> sequence_ ms - let rec mapM f = function - | [] -> return [] - | x::xs -> f x >>= fun x' -> mapM f xs >>= fun xs' -> return (x'::xs') - let rec mapM_ f = function - | [] -> return () - | x::xs -> f x >>= fun _ -> mapM_ f xs - let rec foldM f z = function - | [] -> return z - | x::xs -> f z x >>= fun z' -> foldM f z' xs -end - - -(* - * Concrete monads. - *) - -module Option = Monad_ext_make ( struct - type 'a t = 'a option - let return a = Some a - let bind a f = match a with - | None -> None - | Some x -> f x -end ) - -module type Or_error = sig - type err - type 'a t - val fail : err -> 'a t - val is_success : 'a t -> bool - val is_error : 'a t -> bool - include Monad_ext with type 'a t := 'a t - val guard : bool -> err -> unit t - val or_else : 'a t -> 'a -> 'a - val or_else_f : 'a t -> ('b -> 'a) -> 'b -> 'a -end - -module Or_error_make (M : sig type err end) : - Or_error with type err = M.err and type 'a t = ('a, M.err) result = -struct - type err = M.err - type 'a t = ('a, err) result - let fail e = Error e - let is_success = function - | Ok _ -> true - | Error _ -> false - let is_error = function - | Ok _ -> false - | Error _ -> true - include ( - Monad_ext_make ( struct - type nonrec 'a t = 'a t - let return a = Ok a - let bind a f = match a with - | Ok x -> f x - | Error e -> Error e - end ) : Monad_ext with type 'a t := 'a t) - let guard pred err = if pred then return () else fail err - let or_else m a = match m with Ok x -> x | _ -> a - let or_else_f m f b = match m with Ok x -> x | _ -> f b -end diff --git a/lib/dune b/lib/dune index cda630fb..1b93c602 100644 --- a/lib/dune +++ b/lib/dune @@ -1,5 +1,5 @@ (library (name tls) (public_name tls) - (libraries cstruct cstruct-sexp logs hkdf mirage-crypto mirage-crypto-rng mirage-crypto-pk x509 sexplib domain-name fmt mirage-crypto-ec) + (libraries cstruct cstruct-sexp logs hkdf mirage-crypto mirage-crypto-rng mirage-crypto-pk x509 sexplib domain-name fmt mirage-crypto-ec rresult) (preprocess (pps ppx_sexp_conv ppx_cstruct))) diff --git a/lib/engine.ml b/lib/engine.ml index e0b74f7a..7abbdc7d 100644 --- a/lib/engine.ml +++ b/lib/engine.ml @@ -1,6 +1,9 @@ open Core open State +open Rresult.R.Infix + +let guard p e = if p then Ok () else Error e type state = State.state @@ -84,15 +87,11 @@ let string_of_failure = function "authentication failure: " ^ s | f -> Sexplib.Sexp.to_string_hum (sexp_of_failure f) -type ret = [ - - | `Ok of [ `Ok of state | `Eof | `Alert of Packet.alert_type ] - * [ `Response of Cstruct.t option ] - * [ `Data of Cstruct.t option ] - - | `Fail of failure * [ `Response of Cstruct.t ] -] - +type ret = + ([ `Ok of state | `Eof | `Alert of Packet.alert_type ] + * [ `Response of Cstruct.t option ] + * [ `Data of Cstruct.t option ], + failure * [ `Response of Cstruct.t ]) result let (<+>) = Cstruct.append @@ -195,7 +194,7 @@ let verify_mac sequence mac mac_k ty ver decrypted = let ver = pair_of_tls_version ver in let hdr = Crypto.pseudo_header sequence ty ver (Cstruct.len body) in Crypto.mac mac mac_k hdr body in - guard (Cstruct.equal cmac mmac) (`Fatal `MACMismatch) >|= fun () -> + guard (Cstruct.equal cmac mmac) (`Fatal `MACMismatch) >>| fun () -> body @@ -212,7 +211,7 @@ let decrypt ?(trial = false) (version : tls_version) (st : crypto_state) ty buf (* defense against http://lasecwww.epfl.ch/memo/memo_ssl.shtml 1) in https://www.openssl.org/~bodo/tls-cbc.txt *) let mask_decrypt_failure seq mac mac_k = - compute_mac seq mac mac_k buf >>= fun _ -> fail (`Fatal `MACMismatch) + compute_mac seq mac mac_k buf >>= fun _ -> Error (`Fatal `MACMismatch) in let dec ctx = @@ -224,19 +223,19 @@ let decrypt ?(trial = false) (version : tls_version) (st : crypto_state) ty buf | None -> mask_decrypt_failure seq c.hmac c.hmac_secret | Some (dec, iv') -> - compute_mac seq c.hmac c.hmac_secret dec >|= fun msg -> + compute_mac seq c.hmac c.hmac_secret dec >>| fun msg -> (msg, iv') in ( match c.iv_mode with | Iv iv -> - dec iv buf >|= fun (msg, iv') -> + dec iv buf >>| fun (msg, iv') -> CBC { c with iv_mode = Iv iv' }, msg | Random_iv -> if Cstruct.len buf < Crypto.cbc_block c.cipher then - fail (`Fatal `MACUnderflow) + Error (`Fatal `MACUnderflow) else let iv, buf = Cstruct.split buf (Crypto.cbc_block c.cipher) in - dec iv buf >|= fun (msg, _) -> + dec iv buf >>| fun (msg, _) -> (CBC c, msg) ) | AEAD c -> @@ -249,12 +248,12 @@ let decrypt ?(trial = false) (version : tls_version) (st : crypto_state) ty buf and nonce = Crypto.aead_nonce c.nonce seq in (match Crypto.decrypt_aead ~adata ~cipher:c.cipher ~key:c.cipher_secret ~nonce buf with - | None -> fail (`Fatal `MACMismatch) - | Some x -> return (AEAD c, x)) + | None -> Error (`Fatal `MACMismatch) + | Some x -> Ok (AEAD c, x)) | _ -> let explicit_nonce_len = 8 in if Cstruct.len buf < explicit_nonce_len then - fail (`Fatal `MACUnderflow) + Error (`Fatal `MACUnderflow) else let explicit_nonce, buf = Cstruct.split buf explicit_nonce_len in let adata = @@ -263,8 +262,8 @@ let decrypt ?(trial = false) (version : tls_version) (st : crypto_state) ty buf and nonce = c.nonce <+> explicit_nonce in match Crypto.decrypt_aead ~cipher:c.cipher ~key:c.cipher_secret ~nonce ~adata buf with - | None -> fail (`Fatal `MACMismatch) - | Some x -> return (AEAD c, x) + | None -> Error (`Fatal `MACMismatch) + | Some x -> Ok (AEAD c, x) in match st, version with | None, _ when ty = Packet.APPLICATION_DATA -> @@ -276,22 +275,22 @@ let decrypt ?(trial = false) (version : tls_version) (st : crypto_state) ty buf the APP_DATA above cannot be decrypted or used, so we drop it. *) Ok (None, Cstruct.empty, Packet.APPLICATION_DATA) - | None, _ -> return (st, buf, ty) + | None, _ -> Ok (st, buf, ty) | Some ctx, `TLS_1_3 -> (match ty with - | Packet.CHANGE_CIPHER_SPEC -> return (st, buf, ty) + | Packet.CHANGE_CIPHER_SPEC -> Ok (st, buf, ty) | Packet.APPLICATION_DATA -> (match ctx.cipher_st with | AEAD c -> let nonce = Crypto.aead_nonce c.nonce ctx.sequence in let unpad x = let rec eat = function - | -1 -> fail (`Fatal `MissingContentType) + | -1 -> Error (`Fatal `MissingContentType) | idx -> match Cstruct.get_uint8 x idx with | 0 -> eat (pred idx) | n -> match Packet.int_to_content_type n with - | Some ct -> return (Cstruct.sub x 0 idx, ct) - | None -> fail (`Fatal `MACUnderflow) (* TODO better error? *) + | Some ct -> Ok (Cstruct.sub x 0 idx, ct) + | None -> Error (`Fatal `MACUnderflow) (* TODO better error? *) in eat (pred (Cstruct.len x)) in @@ -301,38 +300,38 @@ let decrypt ?(trial = false) (version : tls_version) (st : crypto_state) ty buf if trial then Ok (Some ctx, Cstruct.empty, Packet.APPLICATION_DATA) else - fail (`Fatal `MACMismatch) + Error (`Fatal `MACMismatch) | Some x -> - unpad x >|= fun (data, ty) -> + unpad x >>| fun (data, ty) -> (Some { ctx with sequence = Int64.succ ctx.sequence }, data, ty)) - | _ -> fail (`Fatal `InvalidMessage)) - | _ -> fail (`Fatal `InvalidMessage)) + | _ -> Error (`Fatal `InvalidMessage)) + | _ -> Error (`Fatal `InvalidMessage)) | Some ctx, _ -> dec ctx >>= fun (st', msg) -> let ctx' = { cipher_st = st' ; sequence = Int64.succ ctx.sequence } in - return (Some ctx', msg, ty) + Ok (Some ctx', msg, ty) (* party time *) -let rec separate_records : Cstruct.t -> ((tls_hdr * Cstruct.t) list * Cstruct.t) eff +let rec separate_records : Cstruct.t -> ((tls_hdr * Cstruct.t) list * Cstruct.t, failure) result = fun buf -> let open Reader in match parse_record buf with - | Ok (`Fragment b) -> return ([], b) + | Ok (`Fragment b) -> Ok ([], b) | Ok (`Record (packet, fragment)) -> - separate_records fragment >|= fun (tl, frag) -> + separate_records fragment >>| fun (tl, frag) -> (packet :: tl, frag) | Error (Overflow x) -> Tracing.cs ~tag:"buf-in" buf ; - fail (`Fatal (`RecordOverflow x)) + Error (`Fatal (`RecordOverflow x)) | Error (UnknownVersion v) -> Tracing.cs ~tag:"buf-in" buf ; - fail (`Fatal (`UnknownRecordVersion v)) + Error (`Fatal (`UnknownRecordVersion v)) | Error (UnknownContent c) -> Tracing.cs ~tag:"buf-in" buf ; - fail (`Fatal (`UnknownContentType c)) + Error (`Fatal (`UnknownContentType c)) | Error e -> Tracing.cs ~tag:"buf-in" buf ; - fail (`Fatal (`ReaderError e)) + Error (`Fatal (`ReaderError e)) let encrypt_records encryptor version records = @@ -368,8 +367,8 @@ module Alert = struct | CLOSE_NOTIFY -> `Eof | _ -> `Alert a_type in Tracing.sexpf ~tag:"alert-out" ~f:sexp_of_tls_alert (Packet.WARNING, Packet.CLOSE_NOTIFY) ; - return (err, [`Record close_notify]) - | Error re -> fail (`Fatal (`ReaderError re)) + Ok (err, [`Record close_notify]) + | Error re -> Error (`Fatal (`ReaderError re)) end let hs_can_handle_appdata s = @@ -396,9 +395,9 @@ let early_data s = let rec separate_handshakes buf = match Reader.parse_handshake_frame buf with - | None, rest -> return ([], rest) + | None, rest -> Ok ([], rest) | Some hs, rest -> - separate_handshakes rest >|= fun (rt, frag) -> + separate_handshakes rest >>| fun (rt, frag) -> (hs :: rt, frag) let handle_change_cipher_spec = function @@ -412,8 +411,8 @@ let handle_change_cipher_spec = function | Client13 (AwaitServerHello13 _) | Server13 AwaitClientHelloHRR13 | Server13 (AwaitClientCertificate13 _) - | Server13 (AwaitClientFinished13 _) -> (fun s _ -> return (s, [])) - | _ -> (fun _ _ -> fail (`Fatal `UnexpectedCCS)) + | Server13 (AwaitClientFinished13 _) -> (fun s _ -> Ok (s, [])) + | _ -> (fun _ _ -> Error (`Fatal `UnexpectedCCS)) and handle_handshake = function | Client cs -> Handshake_client.handle_handshake cs @@ -433,32 +432,33 @@ let handle_packet hs buf = function *) | Packet.ALERT -> - Alert.handle buf >|= fun (err, out) -> + Alert.handle buf >>| fun (err, out) -> (hs, out, None, err) | Packet.APPLICATION_DATA -> if hs_can_handle_appdata hs || (early_data hs && Cstruct.len hs.hs_fragment = 0) then (Tracing.cs ~tag:"application-data-in" buf; - return (hs, [], non_empty buf, `No_err)) + Ok (hs, [], non_empty buf, `No_err)) else - fail (`Fatal `CannotHandleApplicationDataYet) + Error (`Fatal `CannotHandleApplicationDataYet) | Packet.CHANGE_CIPHER_SPEC -> handle_change_cipher_spec hs.machina hs buf - >|= fun (hs, items) -> (hs, items, None, `No_err) + >>| fun (hs, items) -> (hs, items, None, `No_err) | Packet.HANDSHAKE -> separate_handshakes (hs.hs_fragment <+> buf) >>= fun (hss, hs_fragment) -> let hs = { hs with hs_fragment } in - foldM (fun (hs, items) raw -> - handle_handshake hs.machina hs raw - >|= fun (hs', items') -> (hs', items @ items')) - (hs, []) hss - >|= fun (hs, items) -> + List.fold_left (fun acc raw -> + acc >>= fun (hs, items) -> + handle_handshake hs.machina hs raw + >>| fun (hs', items') -> (hs', items @ items')) + (Ok (hs, [])) hss + >>| fun (hs, items) -> (hs, items, None, `No_err) - | Packet.HEARTBEAT -> fail (`Fatal `NoHeartbeat) + | Packet.HEARTBEAT -> Error (`Fatal `NoHeartbeat) let decrement_early_data hs ty buf = let bytes left cipher = @@ -472,7 +472,7 @@ let decrement_early_data hs ty buf = | _ -> `AES_128_GCM_SHA256 (* TODO assert and ensure that all early_data states have a cipher *) in - bytes hs.early_data_left cipher >|= fun early_data_left -> + bytes hs.early_data_left cipher >>| fun early_data_left -> { hs with early_data_left } else Ok hs @@ -484,9 +484,9 @@ let handle_raw_record state (hdr, buf as record : raw_record) = let hs = state.handshake in let version = hs.protocol_version in ( match hs.machina, version with - | Client (AwaitServerHello _), _ -> return () - | Server AwaitClientHello , _ -> return () - | Server13 AwaitClientHelloHRR13, _ -> return () + | Client (AwaitServerHello _), _ -> Ok () + | Server AwaitClientHello , _ -> Ok () + | Server13 AwaitClientHelloHRR13, _ -> Ok () | _ , `TLS_1_3 -> guard (hdr.version = `TLS_1_2) (`Fatal (`BadRecordVersion hdr.version)) | _ , v -> guard (version_eq hdr.version v) (`Fatal (`BadRecordVersion hdr.version)) ) >>= fun () -> @@ -500,7 +500,7 @@ let handle_raw_record state (hdr, buf as record : raw_record) = decrement_early_data hs ty buf >>= fun handshake -> Tracing.sexpf ~tag:"frame-in" ~f:sexp_of_record (ty, dec) ; handle_packet handshake dec ty - >|= fun (handshake, items, data, err) -> + >>| fun (handshake, items, data, err) -> let (encryptor, decryptor, encs) = List.fold_left (fun (enc, dec, es) -> function | `Change_enc enc' -> (Some enc', dec, es) @@ -534,19 +534,19 @@ let handle_tls state buf = Tracing.cs ~tag:"wire-in" buf ; let rec handle_records st = function - | [] -> return (st, [], None, `No_err) + | [] -> Ok (st, [], None, `No_err) | r::rs -> handle_raw_record st r >>= function | (st, raw_rs, data, `No_err) -> - handle_records st rs >|= fun (st', raw_rs', data', err') -> + handle_records st rs >>| fun (st', raw_rs', data', err') -> (st', raw_rs @ raw_rs', maybe_app data data', err') - | res -> return res + | res -> Ok res in match separate_records (state.fragment <+> buf) >>= fun (in_records, fragment) -> handle_records state in_records - >|= fun (state', out_records, data, err) -> + >>| fun (state', out_records, data, err) -> let version = state'.handshake.protocol_version in let resp = match out_records with | [] -> None @@ -569,7 +569,7 @@ let handle_tls state buf = (* Tracing.sexpf ~tag:"state-out" ~f:sexp_of_state state ; *) `Ok state in - `Ok (res, `Response resp, `Data data) + Ok (res, `Response resp, `Data data) | Error x -> let version = state.handshake.protocol_version in let alert = alert_of_failure x in @@ -578,7 +578,7 @@ let handle_tls state buf = let resp = assemble_records version enc in Tracing.sexpf ~tag:"fail-alert-out" ~f:sexp_of_tls_alert (Packet.FATAL, alert) ; Tracing.sexpf ~tag:"failure" ~f:sexp_of_failure x ; - `Fail (x, `Response resp) + Error (x, `Response resp) let send_records (st : state) records = let version = st.handshake.protocol_version in @@ -642,7 +642,7 @@ let reneg ?authenticator ?acceptable_cas ?cert st = | _ -> None let key_update ?(request = true) state = - Handshake_common.output_key_update ~request state >|= fun (state', out) -> + Handshake_common.output_key_update ~request state >>| fun (state', out) -> let _, outbuf = send_records state [out] in state', outbuf diff --git a/lib/engine.mli b/lib/engine.mli index e1559ac0..3c9581b5 100644 --- a/lib/engine.mli +++ b/lib/engine.mli @@ -151,12 +151,11 @@ val sexp_of_failure : failure -> Sexplib.Sexp.t {!state}, an end of file ([`Eof]), or an incoming ([`Alert]). Possibly some [`Response] to the other endpoint is needed, and potentially some [`Data] for the application was received. *) -type ret = [ - | `Ok of [ `Ok of state | `Eof | `Alert of Packet.alert_type ] - * [ `Response of Cstruct.t option ] - * [ `Data of Cstruct.t option ] - | `Fail of failure * [ `Response of Cstruct.t ] -] +type ret = + ([ `Ok of state | `Eof | `Alert of Packet.alert_type ] + * [ `Response of Cstruct.t option ] + * [ `Data of Cstruct.t option ], + failure * [ `Response of Cstruct.t ]) result (** [handle_tls state buffer] is [ret], depending on incoming [state] and [buffer], the result is the appropriate {!ret} *) diff --git a/lib/handshake_client.ml b/lib/handshake_client.ml index 544efcfa..3799867f 100644 --- a/lib/handshake_client.ml +++ b/lib/handshake_client.ml @@ -5,6 +5,8 @@ open State open Handshake_common open Config +open Rresult.R.Infix + let (<+>) = Cstruct.append let state_version state = match state.protocol_version with @@ -79,9 +81,9 @@ let common_server_hello_validation config reneg (sh : server_hello) (ch : client let validate_reneg data = match reneg, data with | Some (cvd, svd), Some x -> guard (Cstruct.equal (cvd <+> svd) x) (`Fatal `InvalidRenegotiation) - | Some _, None -> fail (`Fatal `NoSecureRenegotiation) + | Some _, None -> Error (`Fatal `NoSecureRenegotiation) | None, Some x -> guard (Cstruct.len x = 0) (`Fatal `InvalidRenegotiation) - | None, None -> return () + | None, None -> Ok () in guard (List.mem sh.ciphersuite config.ciphers) (`Error (`NoConfiguredCiphersuite [sh.ciphersuite])) >>= fun () -> @@ -89,7 +91,7 @@ let common_server_hello_validation config reneg (sh : server_hello) (ch : client server_exts_subset_of_client sh.extensions ch.extensions) (`Fatal `InvalidServerHello) >>= fun () -> (match get_alpn_protocol sh with - | None -> return () + | None -> Ok () | Some x -> guard (List.mem x config.alpn_protocols) (`Fatal `InvalidServerHello)) >>= fun () -> validate_reneg (get_secure_renegotiation sh.extensions) @@ -142,7 +144,7 @@ let answer_server_hello state (ch : client_hello) sh secrets raw log = guard (not (Cstruct.equal Packet.downgrade12 piece)) (`Fatal `Downgrade12) >>= fun () -> guard (not (Cstruct.equal Packet.downgrade11 piece)) (`Fatal `Downgrade11) else - return ()) >>= fun () -> + Ok ()) >>= fun () -> let epoch_matches (epoch : epoch_data) = epoch.ciphersuite = sh.ciphersuite && @@ -191,7 +193,7 @@ let answer_server_hello_renegotiate state session (ch : client_hello) sh raw log let validate_keyusage certificate kex = let usage = Ciphersuite.required_usage kex in match certificate with - | None -> fail (`Fatal `NoCertificateReceived) + | None -> Error (`Fatal `NoCertificateReceived) | Some cert -> guard (supports_key_usage ~not_present:true usage cert) (`Fatal `InvalidCertificateUsage) >>= fun () -> @@ -209,9 +211,9 @@ let answer_certificate_RSA state (session : session_data) cs raw log = { session with common_session_data } in ( match session.client_version with - | `TLS_1_3 -> return `TLS_1_2 - | #tls_before_13 as v -> return v - | x -> fail (`Fatal (`NoVersions [ x ])) (* TODO: get rid of this... *) + | `TLS_1_3 -> Ok `TLS_1_2 + | #tls_before_13 as v -> Ok v + | x -> Error (`Fatal (`NoVersions [ x ])) (* TODO: get rid of this... *) ) >>= fun version -> let ver = Writer.assemble_protocol_version version in let premaster = ver <+> Mirage_crypto_rng.generate 46 in @@ -223,13 +225,13 @@ let answer_certificate_RSA state (session : session_data) cs raw log = AwaitCertificateRequestOrServerHelloDone (session, kex, premaster, log @ [raw]) in - return ({ state with machina = Client machina }, []) - | _ -> fail (`Fatal `NotRSACertificate) + Ok ({ state with machina = Client machina }, []) + | _ -> Error (`Fatal `NotRSACertificate) let answer_certificate_DHE state (session : session_data) cs raw log = let cfg = state.config in validate_chain cfg.authenticator cs (host_name_opt cfg.peer_name) >>= fun (peer_certificate, received_certificates, peer_certificate_chain, trust_anchor) -> - validate_keyusage peer_certificate `FFDHE >|= fun () -> + validate_keyusage peer_certificate `FFDHE >>| fun () -> let session = let common_session_data = { session.common_session_data with received_certificates ; peer_certificate ; peer_certificate_chain ; trust_anchor } in { session with common_session_data } @@ -238,20 +240,20 @@ let answer_certificate_DHE state (session : session_data) cs raw log = ({ state with machina = Client machina }, []) let answer_server_key_exchange_DHE state (session : session_data) kex raw log = - let to_fatal r = match r with Ok cs -> return cs | Error er -> fail (`Fatal (`ReaderError er)) in + let to_fatal r = match r with Ok cs -> Ok cs | Error er -> Error (`Fatal (`ReaderError er)) in (if Ciphersuite.ecdhe session.ciphersuite then - to_fatal (Reader.parse_ec_parameters kex) >|= fun (g, share, raw, left) -> + to_fatal (Reader.parse_ec_parameters kex) >>| fun (g, share, raw, left) -> (`Ec g, share, raw, left) else let unpack_dh dh_params = match Crypto.dh_params_unpack dh_params with - | Ok data -> return data - | Error (`Msg m) -> fail (`Fatal (`ReaderError (Unknown m))) + | Ok data -> Ok data + | Error (`Msg m) -> Error (`Fatal (`ReaderError (Reader.Unknown m))) in to_fatal (Reader.parse_dh_parameters kex) >>= fun (dh_params, raw_dh_params, leftover) -> unpack_dh dh_params >>= fun (group, shared) -> guard (Mirage_crypto_pk.Dh.modulus_size group >= Config.min_dh_size) - (`Fatal `InsufficientDH) >|= fun () -> + (`Fatal `InsufficientDH) >>| fun () -> (`Finite_field group, shared, raw_dh_params, leftover) ) >>= fun (group, shared, raw_dh_params, leftover) -> @@ -264,34 +266,34 @@ let answer_server_key_exchange_DHE state (session : session_data) kex raw log = | `Finite_field g -> let secret, client_share = Mirage_crypto_pk.Dh.gen_key g in begin match Mirage_crypto_pk.Dh.shared secret shared with - | None -> fail (`Fatal `InvalidDH) - | Some pms -> return (pms, Writer.assemble_client_dh_key_exchange client_share) + | None -> Error (`Fatal `InvalidDH) + | Some pms -> Ok (pms, Writer.assemble_client_dh_key_exchange client_share) end | `Ec `P256 -> let secret, client_share = P256.Dh.gen_key ~rng in begin match P256.Dh.key_exchange secret shared with - | Error e -> fail (`Fatal (`BadECDH e)) - | Ok pms -> return (pms, Writer.assemble_client_ec_key_exchange client_share) + | Error e -> Error (`Fatal (`BadECDH e)) + | Ok pms -> Ok (pms, Writer.assemble_client_ec_key_exchange client_share) end | `Ec `P384 -> let secret, client_share = P384.Dh.gen_key ~rng in begin match P384.Dh.key_exchange secret shared with - | Error e -> fail (`Fatal (`BadECDH e)) - | Ok pms -> return (pms, Writer.assemble_client_ec_key_exchange client_share) + | Error e -> Error (`Fatal (`BadECDH e)) + | Ok pms -> Ok (pms, Writer.assemble_client_ec_key_exchange client_share) end | `Ec `P521 -> let secret, client_share = P521.Dh.gen_key ~rng in begin match P521.Dh.key_exchange secret shared with - | Error e -> fail (`Fatal (`BadECDH e)) - | Ok pms -> return (pms, Writer.assemble_client_ec_key_exchange client_share) + | Error e -> Error (`Fatal (`BadECDH e)) + | Ok pms -> Ok (pms, Writer.assemble_client_ec_key_exchange client_share) end | `Ec `X25519 -> let secret, client_share = X25519.gen_key ~rng in begin match X25519.key_exchange secret shared with - | Error _ -> fail (`Fatal `InvalidDH) - | Ok pms -> return (pms, Writer.assemble_client_ec_key_exchange client_share) + | Error _ -> Error (`Fatal `InvalidDH) + | Ok pms -> Ok (pms, Writer.assemble_client_ec_key_exchange client_share) end - ) >|= fun (pms, kex) -> + ) >>| fun (pms, kex) -> let machina = AwaitCertificateRequestOrServerHelloDone (session, kex, pms, log @ [raw]) @@ -303,14 +305,14 @@ let answer_certificate_request state (session : session_data) cr kex pms raw log ( match state.protocol_version with | `TLS_1_0 | `TLS_1_1 -> ( match Reader.parse_certificate_request cr with - | Ok (types, cas) -> return (types, None, cas) - | Error re -> fail (`Fatal (`ReaderError re)) ) + | Ok (types, cas) -> Ok (types, None, cas) + | Error re -> Error (`Fatal (`ReaderError re)) ) | `TLS_1_2 -> ( match Reader.parse_certificate_request_1_2 cr with - | Ok (types, sigalgs, cas) -> return (types, Some sigalgs, cas) - | Error re -> fail (`Fatal (`ReaderError re)) ) - | v -> fail (`Fatal (`BadRecordVersion (v :> tls_any_version))) (* never happens *) - ) >|= fun (_types, sigalgs, _cas) -> + | Ok (types, sigalgs, cas) -> Ok (types, Some sigalgs, cas) + | Error re -> Error (`Fatal (`ReaderError re)) ) + | v -> Error (`Fatal (`BadRecordVersion (v :> tls_any_version))) (* never happens *) + ) >>| fun (_types, sigalgs, _cas) -> (* TODO: respect _types and _cas, multiple client certificates *) let own_certificate, own_private_key = match cfg.own_certificates with @@ -342,7 +344,7 @@ let answer_server_hello_done state (session : session_data) sigalgs kex premaste let data = Cstruct.concat to_sign in let ver = state.protocol_version and my_sigalgs = state.config.signature_algorithms in - signature ver data sigalgs my_sigalgs p >|= fun (signature) -> + signature ver data sigalgs my_sigalgs p >>| fun (signature) -> let cert_verify = CertificateVerify signature in let ccert_verify = Writer.assemble_handshake cert_verify in ([ cert ; kex ; cert_verify ], @@ -351,10 +353,10 @@ let answer_server_hello_done state (session : session_data) sigalgs kex premaste | true, None -> let cert = Certificate (Writer.assemble_certificates []) in let ccert = Writer.assemble_handshake cert in - return ([cert ; kex], [ccert ; ckex], log @ [ raw ; ccert ; ckex ], None) + Ok ([cert ; kex], [ccert ; ckex], log @ [ raw ; ccert ; ckex ], None) | false, _ -> - return ([kex], [ckex], log @ [ raw ; ckex ], None) ) - >|= fun (msgs, raw_msgs, raws, cert_verify) -> + Ok ([kex], [ckex], log @ [ raw ; ckex ], None) ) + >>| fun (msgs, raw_msgs, raws, cert_verify) -> let to_fin = raws @ option [] (fun x -> [x]) cert_verify in @@ -397,7 +399,7 @@ let answer_server_finished state (session : session_data) client_verify fin log Handshake_crypto.finished (state_version state) session.ciphersuite session.common_session_data.master_secret "server finished" log in guard (Cstruct.equal computed fin) (`Fatal `BadFinished) >>= fun () -> - guard (Cstruct.len state.hs_fragment = 0) (`Fatal `HandshakeFragmentsNotEmpty) >|= fun () -> + guard (Cstruct.len state.hs_fragment = 0) (`Fatal `HandshakeFragmentsNotEmpty) >>| fun () -> let machina = Established and session = { session with renegotiation = (client_verify, computed) } in ({ state with machina = Client machina ; session = `TLS session :: state.session }, []) @@ -408,7 +410,7 @@ let answer_server_finished_resume state (session : session_data) fin raw log = (checksum "client finished" (log @ [raw]), checksum "server finished" log) in guard (Cstruct.equal server fin) (`Fatal `BadFinished) >>= fun () -> - guard (Cstruct.len state.hs_fragment = 0) (`Fatal `HandshakeFragmentsNotEmpty) >|= fun () -> + guard (Cstruct.len state.hs_fragment = 0) (`Fatal `HandshakeFragmentsNotEmpty) >>| fun () -> let machina = Established and session = { session with renegotiation = (client, server) } in @@ -431,30 +433,30 @@ let answer_hello_request state = match state.config.use_reneg, state.session with | true , `TLS x :: _ -> let ext = `SecureRenegotiation (fst x.renegotiation) in - return (produce_client_hello x state.config [ext]) - | true , _ -> fail (`Fatal `InvalidSession) (* I'm pretty sure this can be an assert false *) + Ok (produce_client_hello x state.config [ext]) + | true , _ -> Error (`Fatal `InvalidSession) (* I'm pretty sure this can be an assert false *) | false, _ -> let no_reneg = Writer.assemble_alert ~level:Packet.WARNING Packet.NO_RENEGOTIATION in Tracing.sexpf ~tag:"alert-out" ~f:sexp_of_tls_alert (Packet.WARNING, Packet.NO_RENEGOTIATION) ; - return (state, [`Record (Packet.ALERT, no_reneg)]) + Ok (state, [`Record (Packet.ALERT, no_reneg)]) let handle_change_cipher_spec cs state packet = match Reader.parse_change_cipher_spec packet, cs with | Ok (), AwaitServerChangeCipherSpec (session, server_ctx, client_verify, log) -> - guard (Cstruct.len state.hs_fragment = 0) (`Fatal `HandshakeFragmentsNotEmpty) >|= fun () -> + guard (Cstruct.len state.hs_fragment = 0) (`Fatal `HandshakeFragmentsNotEmpty) >>| fun () -> let machina = AwaitServerFinished (session, client_verify, log) in Tracing.cs ~tag:"change-cipher-spec-in" packet ; ({ state with machina = Client machina }, [`Change_dec server_ctx]) | Ok (), AwaitServerChangeCipherSpecResume (session, client_ctx, server_ctx, log) -> - guard (Cstruct.len state.hs_fragment = 0) (`Fatal `HandshakeFragmentsNotEmpty) >|= fun () -> + guard (Cstruct.len state.hs_fragment = 0) (`Fatal `HandshakeFragmentsNotEmpty) >>| fun () -> let ccs = change_cipher_spec in let machina = AwaitServerFinishedResume (session, log) in Tracing.cs ~tag:"change-cipher-spec-in" packet ; Tracing.cs ~tag:"change-cipher-spec-out" packet ; ({ state with machina = Client machina }, [`Record ccs ; `Change_enc client_ctx; `Change_dec server_ctx]) - | Error re, _ -> fail (`Fatal (`ReaderError re)) - | _ -> fail (`Fatal `UnexpectedCCS) + | Error re, _ -> Error (`Fatal (`ReaderError re)) + | _ -> Error (`Fatal `UnexpectedCCS) let handle_handshake cs hs buf = let open Reader in @@ -471,11 +473,11 @@ let handle_handshake cs hs buf = | AwaitCertificate_RSA (session, log), Certificate cs -> (match Reader.parse_certificates cs with | Ok cs -> answer_certificate_RSA hs session cs buf log - | Error re -> fail (`Fatal (`ReaderError re))) + | Error re -> Error (`Fatal (`ReaderError re))) | AwaitCertificate_DHE (session, log), Certificate cs -> (match Reader.parse_certificates cs with | Ok cs -> answer_certificate_DHE hs session cs buf log - | Error re -> fail (`Fatal (`ReaderError re))) + | Error re -> Error (`Fatal (`ReaderError re))) | AwaitServerKeyExchange_DHE (session, log), ServerKeyExchange kex -> answer_server_key_exchange_DHE hs session kex buf log | AwaitCertificateRequestOrServerHelloDone (session, kex, pms, log), CertificateRequest cr -> @@ -490,5 +492,5 @@ let handle_handshake cs hs buf = answer_server_finished_resume hs session fin buf log | Established, HelloRequest -> answer_hello_request hs - | _, hs -> fail (`Fatal (`UnexpectedHandshake hs)) ) - | Error re -> fail (`Fatal (`ReaderError re)) + | _, hs -> Error (`Fatal (`UnexpectedHandshake hs)) ) + | Error re -> Error (`Fatal (`ReaderError re)) diff --git a/lib/handshake_client.mli b/lib/handshake_client.mli index 72f36622..f10a0c61 100644 --- a/lib/handshake_client.mli +++ b/lib/handshake_client.mli @@ -2,6 +2,6 @@ open Core open State val default_client_hello : Config.config -> (client_hello * tls_version * (group * dh_secret) list) -val handle_change_cipher_spec : client_handshake_state -> handshake_state -> Cstruct.t -> handshake_return eff -val handle_handshake : client_handshake_state -> handshake_state -> Cstruct.t -> handshake_return eff -val answer_hello_request : handshake_state -> handshake_return eff +val handle_change_cipher_spec : client_handshake_state -> handshake_state -> Cstruct.t -> (handshake_return, failure) result +val handle_handshake : client_handshake_state -> handshake_state -> Cstruct.t -> (handshake_return, failure) result +val answer_hello_request : handshake_state -> (handshake_return, failure) result diff --git a/lib/handshake_client13.ml b/lib/handshake_client13.ml index ef2ee5ea..24eeeb80 100644 --- a/lib/handshake_client13.ml +++ b/lib/handshake_client13.ml @@ -5,10 +5,12 @@ open Core open Handshake_common open Config +open Rresult.R.Infix + let answer_server_hello state ch (sh : server_hello) secrets raw log = (* assume SH valid, version 1.3, extensions are subset *) match Ciphersuite.ciphersuite_to_ciphersuite13 sh.ciphersuite with - | None -> fail (`Fatal `InvalidServerHello) + | None -> Error (`Fatal `InvalidServerHello) | Some cipher -> guard (List.mem cipher (ciphers13 state.config)) (`Fatal `InvalidServerHello) >>= fun () -> guard (Cstruct.len state.hs_fragment = 0) (`Fatal `HandshakeFragmentsNotEmpty) >>= fun () -> @@ -16,10 +18,10 @@ let answer_server_hello state ch (sh : server_hello) secrets raw log = (* TODO: PSK *) (* TODO: early_secret elsewhere *) match map_find ~f:(function `KeyShare ks -> Some ks | _ -> None) sh.extensions with - | None -> fail (`Fatal `InvalidServerHello) + | None -> Error (`Fatal `InvalidServerHello) | Some (g, share) -> match List.find_opt (fun (g', _) -> g = g') secrets with - | None -> fail (`Fatal `InvalidServerHello) + | None -> Error (`Fatal `InvalidServerHello) | Some (_, secret) -> Handshake_crypto13.dh_shared secret share >>= fun shared -> let hlen = Mirage_crypto.Hash.digest_size (Ciphersuite.hash13 cipher) in @@ -27,9 +29,9 @@ let answer_server_hello state ch (sh : server_hello) secrets raw log = map_find ~f:(function `PreSharedKey idx -> Some idx | _ -> None) sh.extensions, state.config.Config.cached_ticket with - | None, _ | _, None -> return (Cstruct.create hlen, false) + | None, _ | _, None -> Ok (Cstruct.create hlen, false) | Some idx, Some (psk, _epoch) -> - guard (idx = 0) (`Fatal `InvalidServerHello) >|= fun () -> + guard (idx = 0) (`Fatal `InvalidServerHello) >>| fun () -> psk.secret, true) >>= fun (psk, resumed) -> let early_secret = Handshake_crypto13.(derive (empty cipher) psk) in let hs_secret = Handshake_crypto13.derive early_secret shared in @@ -82,7 +84,7 @@ let answer_hello_retry_request state (ch : client_hello) hrr _secrets raw log = let st = AwaitServerHello13 (new_ch, [secret], Cstruct.concat [ ch0_hdr ; ch0_data ; raw ; new_ch_raw ]) in Tracing.sexpf ~tag:"handshake-out" ~f:sexp_of_tls_handshake (ClientHello new_ch); - return ({ state with machina = Client13 st ; protocol_version = `TLS_1_3 }, [`Record (Packet.HANDSHAKE, new_ch_raw)]) + Ok ({ state with machina = Client13 st ; protocol_version = `TLS_1_3 }, [`Record (Packet.HANDSHAKE, new_ch_raw)]) let answer_encrypted_extensions state (session : session_data13) server_hs_secret client_hs_secret ee raw log = (* TODO we now know: - hostname - early_data (preserve this in session!!) *) @@ -98,7 +100,7 @@ let answer_encrypted_extensions state (session : session_data13) server_hs_secre else AwaitServerCertificateRequestOrCertificate13 (session, server_hs_secret, client_hs_secret, log <+> raw) in - return ({ state with machina = Client13 st }, []) + Ok ({ state with machina = Client13 st }, []) let answer_certificate state (session : session_data13) server_hs_secret client_hs_secret certs raw log = (* certificates are (cs, ext) list - ext being statusrequest or signed_cert_timestamp *) @@ -113,7 +115,7 @@ let answer_certificate state (session : session_data13) server_hs_secret client_ { session with common_session_data13 } in let st = AwaitServerCertificateVerify13 (session, server_hs_secret, client_hs_secret, log <+> raw) in - return ({ state with machina = Client13 st }, []) + Ok ({ state with machina = Client13 st }, []) let answer_certificate_verify (state : handshake_state) (session : session_data13) server_hs_secret client_hs_secret cv raw log = let tbs = Mirage_crypto.Hash.digest (Ciphersuite.hash13 session.ciphersuite13) log in @@ -122,7 +124,7 @@ let answer_certificate_verify (state : handshake_state) (session : session_data1 state.config.signature_algorithms cv tbs session.common_session_data13.peer_certificate >>= fun () -> let st = AwaitServerFinished13 (session, server_hs_secret, client_hs_secret, log <+> raw) in - return ({ state with machina = Client13 st }, []) + Ok ({ state with machina = Client13 st }, []) let answer_certificate_request (state : handshake_state) (session : session_data13) server_hs_secret client_hs_secret _extensions raw log = (* TODO respect extensions (sigalg, CA, OIDfilter)! *) @@ -131,7 +133,7 @@ let answer_certificate_request (state : handshake_state) (session : session_data { session with common_session_data13 } in let st = AwaitServerCertificate13 (session, server_hs_secret, client_hs_secret, log <+> raw) in - return ({ state with machina = Client13 st }, []) + Ok ({ state with machina = Client13 st }, []) let answer_finished state (session : session_data13) server_hs_secret client_hs_secret fin raw log = let hash = Ciphersuite.hash13 session.ciphersuite13 in @@ -157,17 +159,17 @@ let answer_finished state (session : session_data13) server_hs_secret client_hs_ let log = log <+> cert_raw in (match own_private_key with | None -> - return ([cert_raw], log) + Ok ([cert_raw], log) | Some priv -> (* TODO use sig_algs instead of None below as requested in server's certificaterequest *) let tbs = Mirage_crypto.Hash.digest hash log in signature `TLS_1_3 ~context_string:"TLS 1.3, client CertificateVerify" - tbs None state.config.Config.signature_algorithms priv >|= fun signed -> + tbs None state.config.Config.signature_algorithms priv >>| fun signed -> let cv = CertificateVerify signed in let cv_raw = Writer.assemble_handshake cv in ([ cert_raw ; cv_raw ], log <+> cv_raw)) else - return ([], log)) >|= fun (c_cv, log) -> + Ok ([], log)) >>| fun (c_cv, log) -> let myfin = Handshake_crypto13.finished hash client_hs_secret log in let mfin = Writer.assemble_handshake (Finished myfin) in @@ -203,7 +205,7 @@ let answer_session_ticket state st = let psk = { identifier = st.ticket ; obfuscation = st.age_add ; secret ; lifetime = st.lifetime ; early_data ; issued_at } in cache.ticket_granted psk epoch | _ -> ()); - return (state, []) + Ok (state, []) let handle_key_update state req = match state.session with @@ -227,7 +229,7 @@ let handle_key_update state req = let session = `TLS13 session' :: state.session in let state' = { state with machina = Server13 Established13 ; session } in Ok (state', `Change_dec server_ctx :: out) - | _ -> fail (`Fatal `InvalidSession) + | _ -> Error (`Fatal `InvalidSession) let handle_handshake cs hs buf = let open Reader in @@ -242,8 +244,8 @@ let handle_handshake cs hs buf = | AwaitServerCertificateRequestOrCertificate13 (sd, es, ss, log), CertificateRequest cr -> (match parse_certificate_request_1_3 cr with | Ok (None, exts) -> answer_certificate_request hs sd es ss exts buf log - | Ok (Some _, _) -> fail (`Fatal `InvalidMessage) (* during handshake, context must be empty! *) - | Error re -> fail (`Fatal (`ReaderError re))) + | Ok (Some _, _) -> Error (`Fatal `InvalidMessage) (* during handshake, context must be empty! *) + | Error re -> Error (`Fatal (`ReaderError re))) | AwaitServerCertificateRequestOrCertificate13 (sd, es, ss, log), Certificate cs | AwaitServerCertificate13 (sd, es, ss, log), Certificate cs -> (match parse_certificates_1_3 cs with @@ -251,15 +253,15 @@ let handle_handshake cs hs buf = (* during handshake, context must be empty! and we'll not get any new certificate from server *) guard (Cstruct.len con = 0) (`Fatal `InvalidMessage) >>= fun () -> answer_certificate hs sd es ss cs buf log - | Error re -> fail (`Fatal (`ReaderError re))) + | Error re -> Error (`Fatal (`ReaderError re))) | AwaitServerCertificateVerify13 (sd, es, ss, log), CertificateVerify cv -> answer_certificate_verify hs sd es ss cv buf log | AwaitServerFinished13 (sd, es, ss, log), Finished fin -> answer_finished hs sd es ss fin buf log | Established13, SessionTicket se -> answer_session_ticket hs se | Established13, CertificateRequest _ -> - fail (`Fatal (`UnexpectedHandshake handshake)) (* TODO send out C, CV, F *) + Error (`Fatal (`UnexpectedHandshake handshake)) (* TODO send out C, CV, F *) | Established13, KeyUpdate req -> handle_key_update hs req - | _, hs -> fail (`Fatal (`UnexpectedHandshake hs))) - | Error re -> fail (`Fatal (`ReaderError re)) + | _, hs -> Error (`Fatal (`UnexpectedHandshake hs))) + | Error re -> Error (`Fatal (`ReaderError re)) diff --git a/lib/handshake_common.ml b/lib/handshake_common.ml index 110cdcb9..8febdc6d 100644 --- a/lib/handshake_common.ml +++ b/lib/handshake_common.ml @@ -4,6 +4,10 @@ open State open Mirage_crypto +open Rresult.R.Infix + +let guard p e = if p then Ok () else Error e + let src = Logs.Src.create "handshake" ~doc:"TLS handshake" module Log = (val Logs.src_log src : Logs.LOG) @@ -55,10 +59,10 @@ let rec find_matching host certs = let agreed_cert certs ?f ?signature_algorithms hostname = let match_host ?default host certs = match find_matching host certs with - | Some x -> return x + | Some x -> Ok x | None -> match default with - | Some c -> return c - | None -> fail (`Error (`NoMatchingCertificateFound (Domain_name.to_string host))) + | Some c -> Ok c + | None -> Error (`Error (`NoMatchingCertificateFound (Domain_name.to_string host))) in let filter = function | ([], _) -> false (* cannot happen, TODO: adapt types to avoid this case *) @@ -73,28 +77,28 @@ let agreed_cert certs ?f ?signature_algorithms hostname = | Some s -> List.exists (pk_matches_sa (snd c)) s in match certs, hostname with - | `None, _ -> fail (`Error `NoCertificateConfigured) + | `None, _ -> Error (`Error `NoCertificateConfigured) | `Single c, _ -> - if filter c && filter_sigalg c then return c else fail (`Error `CouldntSelectCertificate) + if filter c && filter_sigalg c then Ok c else Error (`Error `CouldntSelectCertificate) | `Multiple_default (c, _), None -> - if filter c && filter_sigalg c then return c else fail (`Error `CouldntSelectCertificate) + if filter c && filter_sigalg c then Ok c else Error (`Error `CouldntSelectCertificate) | `Multiple_default (c, cs), Some h -> let default = if filter c && filter_sigalg c then Some c else None in begin match default, List.filter (fun c -> filter c && filter_sigalg c) cs with | Some d, cs -> match_host ~default:d h cs | None, c :: cs -> match_host ~default:c h (c::cs) - | None, [] -> fail (`Error `CouldntSelectCertificate) + | None, [] -> Error (`Error `CouldntSelectCertificate) end | `Multiple cs, None -> begin match List.filter (fun c -> filter c && filter_sigalg c) cs with - | cert :: _ -> return cert - | _ -> fail (`Error `CouldntSelectCertificate) + | cert :: _ -> Ok cert + | _ -> Error (`Error `CouldntSelectCertificate) end | `Multiple cs, Some h -> match List.filter (fun c -> filter c && filter_sigalg c) cs with - | [ cert ] -> return cert + | [ cert ] -> Ok cert | c :: cs -> match_host ~default:c h (c :: cs) - | [] -> fail (`Error `CouldntSelectCertificate) + | [] -> Error (`Error `CouldntSelectCertificate) let get_secure_renegotiation exts = map_find @@ -106,15 +110,15 @@ let get_alpn_protocols (ch : client_hello) = let alpn_protocol config ch = match config.Config.alpn_protocols, get_alpn_protocols ch with - | _, None | [], _ -> return None + | _, None | [], _ -> Ok None | configured, Some client -> match first_match client configured with - | Some proto -> return (Some proto) + | Some proto -> Ok (Some proto) | None -> (* RFC7301 Section 3.2: In the event that the server supports no protocols that the client advertises, then the server SHALL respond with a fatal "no_application_protocol" alert. *) - fail (`Fatal `NoApplicationProtocol) + Error (`Fatal `NoApplicationProtocol) let get_alpn_protocol (sh : server_hello) = map_find ~f:(function `ALPN protocol -> Some protocol | _ -> None) sh.extensions @@ -396,29 +400,29 @@ let signature version ?context_string data client_sig_algs signature_algorithms begin match private_key with | `RSA key -> let data = Hash.MD5.digest data <+> Hash.SHA1.digest data in - return (Mirage_crypto_pk.Rsa.PKCS1.sig_encode ~key data) + Ok (Mirage_crypto_pk.Rsa.PKCS1.sig_encode ~key data) | `P256 key -> let data = Hash.SHA1.digest data in - return (ecdsa_sig_to_cstruct (Mirage_crypto_ec.P256.Dsa.sign ~key data)) + Ok (ecdsa_sig_to_cstruct (Mirage_crypto_ec.P256.Dsa.sign ~key data)) | `P384 key -> let data = Hash.SHA1.digest data in - return (ecdsa_sig_to_cstruct (Mirage_crypto_ec.P384.Dsa.sign ~key data)) + Ok (ecdsa_sig_to_cstruct (Mirage_crypto_ec.P384.Dsa.sign ~key data)) | `P521 key -> let data = Hash.SHA1.digest data in - return (ecdsa_sig_to_cstruct (Mirage_crypto_ec.P521.Dsa.sign ~key data)) - | `ED25519 key -> return (Mirage_crypto_ec.Ed25519.sign ~key data) - | _ -> fail (`Error (`NoConfiguredSignatureAlgorithm [])) - end >|= fun signed -> + Ok (ecdsa_sig_to_cstruct (Mirage_crypto_ec.P521.Dsa.sign ~key data)) + | `ED25519 key -> Ok (Mirage_crypto_ec.Ed25519.sign ~key data) + | _ -> Error (`Error (`NoConfiguredSignatureAlgorithm [])) + end >>| fun signed -> Writer.assemble_digitally_signed signed | `TLS_1_2 -> let sig_alg ec = match client_sig_algs with - | None -> return (if ec then `ECDSA_SECP256R1_SHA1 else `RSA_PKCS1_SHA1) + | None -> Ok (if ec then `ECDSA_SECP256R1_SHA1 else `RSA_PKCS1_SHA1) | Some client_algos -> let f = if ec then (fun sa -> not (rsa_sigalg sa)) else rsa_sigalg in match first_match client_algos (List.filter f signature_algorithms) with - | None -> fail (`Error (`NoConfiguredSignatureAlgorithm client_algos)) - | Some sig_alg -> return sig_alg + | None -> Error (`Error (`NoConfiguredSignatureAlgorithm client_algos)) + | Some sig_alg -> Ok sig_alg in ( match private_key with | `RSA key -> @@ -431,29 +435,29 @@ let signature version ?context_string data client_sig_algs signature_algorithms | `PSS, hash_alg -> let module H = (val (Hash.module_of hash_alg)) in let module PSS = Mirage_crypto_pk.Rsa.PSS(H) in - return (sig_alg, PSS.sign ~key (`Message data)) + Ok (sig_alg, PSS.sign ~key (`Message data)) | `PKCS1, hash_alg -> let hash = Hash.digest hash_alg data in let cs = X509.Certificate.encode_pkcs1_digest_info (hash_alg, hash) in - return (sig_alg, Mirage_crypto_pk.Rsa.PKCS1.sig_encode ~key cs) - | _ -> fail (`Error (`NoConfiguredSignatureAlgorithm [])) + Ok (sig_alg, Mirage_crypto_pk.Rsa.PKCS1.sig_encode ~key cs) + | _ -> Error (`Error (`NoConfiguredSignatureAlgorithm [])) end | `P256 key -> - sig_alg true >|= fun sig_alg -> + sig_alg true >>| fun sig_alg -> let hash = Hash.digest (hash_of_signature_algorithm sig_alg) data in sig_alg, ecdsa_sig_to_cstruct (Mirage_crypto_ec.P256.Dsa.sign ~key hash) | `P384 key -> - sig_alg true >|= fun sig_alg -> + sig_alg true >>| fun sig_alg -> let hash = Hash.digest (hash_of_signature_algorithm sig_alg) data in sig_alg, ecdsa_sig_to_cstruct (Mirage_crypto_ec.P384.Dsa.sign ~key hash) | `P521 key -> - sig_alg true >|= fun sig_alg -> + sig_alg true >>| fun sig_alg -> let hash = Hash.digest (hash_of_signature_algorithm sig_alg) data in sig_alg, ecdsa_sig_to_cstruct (Mirage_crypto_ec.P521.Dsa.sign ~key hash) | `ED25519 key -> - sig_alg true >|= fun sig_alg -> + sig_alg true >>| fun sig_alg -> sig_alg, Mirage_crypto_ec.Ed25519.sign ~key data - | _ -> fail (`Error (`NoConfiguredSignatureAlgorithm [])) ) >|= fun (sig_alg, signature) -> + | _ -> Error (`Error (`NoConfiguredSignatureAlgorithm [])) ) >>| fun (sig_alg, signature) -> Writer.assemble_digitally_signed_1_2 sig_alg signature | `TLS_1_3 -> let to_sign = @@ -461,48 +465,48 @@ let signature version ?context_string data client_sig_algs signature_algorithms prefix <+> data in (match client_sig_algs with - | None -> fail (`Error (`NoConfiguredSignatureAlgorithm [])) + | None -> Error (`Error (`NoConfiguredSignatureAlgorithm [])) (* 8446 4.2.3 "client MUST send signatureAlgorithms" *) | Some client_algos -> let sa = List.filter tls13_sigalg signature_algorithms in let sa = List.filter (pk_matches_sa private_key) sa in match first_match client_algos sa with - | None -> fail (`Error (`NoConfiguredSignatureAlgorithm client_algos)) - | Some sig_alg -> return sig_alg) >>= fun sig_alg -> + | None -> Error (`Error (`NoConfiguredSignatureAlgorithm client_algos)) + | Some sig_alg -> Ok sig_alg) >>= fun sig_alg -> let hash_alg = hash_of_signature_algorithm sig_alg in (match signature_scheme_of_signature_algorithm sig_alg, private_key with | `PSS, `RSA key -> let module H = (val (Hash.module_of hash_alg)) in let module PSS = Mirage_crypto_pk.Rsa.PSS(H) in - return (PSS.sign ~key (`Message to_sign)) + Ok (PSS.sign ~key (`Message to_sign)) | `ECDSA, `P256 key -> let hash = Hash.digest hash_alg to_sign in - return (ecdsa_sig_to_cstruct (Mirage_crypto_ec.P256.Dsa.sign ~key hash)) + Ok (ecdsa_sig_to_cstruct (Mirage_crypto_ec.P256.Dsa.sign ~key hash)) | `ECDSA, `P384 key -> let hash = Hash.digest hash_alg to_sign in - return (ecdsa_sig_to_cstruct (Mirage_crypto_ec.P384.Dsa.sign ~key hash)) + Ok (ecdsa_sig_to_cstruct (Mirage_crypto_ec.P384.Dsa.sign ~key hash)) | `ECDSA, `P521 key -> let hash = Hash.digest hash_alg to_sign in - return (ecdsa_sig_to_cstruct (Mirage_crypto_ec.P521.Dsa.sign ~key hash)) + Ok (ecdsa_sig_to_cstruct (Mirage_crypto_ec.P521.Dsa.sign ~key hash)) | `EdDSA, `ED25519 key -> - return (Mirage_crypto_ec.Ed25519.sign ~key to_sign) - | _ -> fail (`Error (`NoConfiguredSignatureAlgorithm []))) >|= fun signature -> + Ok (Mirage_crypto_ec.Ed25519.sign ~key to_sign) + | _ -> Error (`Error (`NoConfiguredSignatureAlgorithm []))) >>| fun signature -> Writer.assemble_digitally_signed_1_2 sig_alg signature end with Mirage_crypto_pk.Rsa.Insufficient_key -> - fail (`Fatal `KeyTooSmall) + Error (`Fatal `KeyTooSmall) let peer_key = function - | None -> fail (`Fatal `NoCertificateReceived) - | Some cert -> return (X509.Certificate.public_key cert) + | None -> Error (`Fatal `NoCertificateReceived) + | Some cert -> Ok (X509.Certificate.public_key cert) let verify_digitally_signed version ?context_string sig_algs data signature_data certificate = peer_key certificate >>= fun pubkey -> let decode_pkcs1_signature key raw_signature = match Mirage_crypto_pk.Rsa.PKCS1.sig_decode ~key raw_signature with - | Some signature -> return signature - | None -> fail (`Fatal `SignatureVerificationFailed) + | Some signature -> Ok signature + | None -> Error (`Fatal `SignatureVerificationFailed) in match version with @@ -537,7 +541,7 @@ let verify_digitally_signed version ?context_string sig_algs data signature_data (`Fatal `SignatureVerificationFailed) | _ -> Error (`Fatal `UnsupportedSignatureScheme) end - | Error re -> fail (`Fatal (`ReaderError re)) ) + | Error re -> Error (`Fatal (`ReaderError re)) ) | `TLS_1_2 -> ( match Reader.parse_digitally_signed_1_2 data with | Ok (sig_alg, signature) -> @@ -555,7 +559,7 @@ let verify_digitally_signed version ?context_string sig_algs data signature_data | Ok (hash_algo', target) when hash_algo = hash_algo' -> let cs = Hash.digest hash_algo data in guard (Cstruct.equal target cs) (`Fatal `SignatureVerificationFailed) - | _ -> fail (`Fatal `HashAlgorithmMismatch) + | _ -> Error (`Fatal `HashAlgorithmMismatch) in decode_pkcs1_signature key signature >>= fun raw -> compare_hashes raw signature_data @@ -578,9 +582,9 @@ let verify_digitally_signed version ?context_string sig_algs data signature_data let msg = signature_data in guard (Mirage_crypto_ec.Ed25519.verify ~key signature ~msg) (`Fatal `SignatureVerificationFailed) - | _ -> fail (`Fatal `UnsupportedSignatureScheme) + | _ -> Error (`Fatal `UnsupportedSignatureScheme) end - | Error re -> fail (`Fatal (`ReaderError re)) ) + | Error re -> Error (`Fatal (`ReaderError re)) ) | `TLS_1_3 -> ( match Reader.parse_digitally_signed_1_2 data with | Ok (sig_alg, signature) -> @@ -615,15 +619,15 @@ let verify_digitally_signed version ?context_string sig_algs data signature_data guard (Mirage_crypto_ec.Ed25519.verify ~key signature ~msg:data) (`Fatal `SignatureVerificationFailed) | _ -> - fail (`Fatal `UnsupportedSignatureScheme) + Error (`Fatal `UnsupportedSignatureScheme) end - | Error re -> fail (`Fatal (`ReaderError re))) + | Error re -> Error (`Fatal (`ReaderError re))) let validate_chain authenticator certificates hostname = let authenticate authenticator host certificates = match authenticator ~host certificates with - | Error err -> fail (`Error (`AuthenticationFailure err)) - | Ok anchor -> return anchor + | Error err -> Error (`Error (`AuthenticationFailure err)) + | Ok anchor -> Ok anchor and key_size min cs = let check c = @@ -638,7 +642,7 @@ let validate_chain authenticator certificates hostname = let f cs = match X509.Certificate.decode_der cs with Ok c -> Some c | _ -> None in filter_map ~f certs in - guard (List.length certs = List.length certificates) (`Fatal `BadCertificateChain) >|= fun () -> + guard (List.length certs = List.length certificates) (`Fatal `BadCertificateChain) >>| fun () -> certificates in @@ -651,10 +655,10 @@ let validate_chain authenticator certificates hostname = | [] -> None in match authenticator with - | None -> return (server, certs, [], None) + | None -> Ok (server, certs, [], None) | Some authenticator -> authenticate authenticator hostname certs >>= fun anchor -> - key_size Config.min_rsa_key_size certs >|= fun () -> + key_size Config.min_rsa_key_size certs >>| fun () -> Utils.option (server, certs, [], None) (fun (chain, anchor) -> (server, certs, chain, Some anchor)) @@ -678,7 +682,7 @@ let output_key_update ~request state = in Ok ({ session with server_app_secret }, server_ctx) | _ -> Error (`Fatal `InvalidSession) - end >|= fun (session', encryptor) -> + end >>| fun (session', encryptor) -> let handshake = { hs with session = `TLS13 session' :: hs.session } in let ku = let p = diff --git a/lib/handshake_server.ml b/lib/handshake_server.ml index af14066a..4ea6da03 100644 --- a/lib/handshake_server.ml +++ b/lib/handshake_server.ml @@ -5,6 +5,8 @@ open State open Handshake_common open Config +open Rresult.R.Infix + let (<+>) = Cstruct.append let state_version state = match state.protocol_version with @@ -16,9 +18,9 @@ let hello_request state = let hr = HelloRequest in Tracing.sexpf ~tag:"handshake-out" ~f:sexp_of_tls_handshake hr ; let state = { state with machina = Server AwaitClientHelloRenegotiate } in - return (state, [`Record (Packet.HANDSHAKE, Writer.assemble_handshake hr)]) + Ok (state, [`Record (Packet.HANDSHAKE, Writer.assemble_handshake hr)]) else - fail (`Fatal `InvalidSession) + Error (`Fatal `InvalidSession) let answer_client_finished state (session : session_data) client_fin raw log = @@ -32,7 +34,7 @@ let answer_client_finished state (session : session_data) client_fin raw log = let fin = Finished server in let fin_raw = Writer.assemble_handshake fin in (* we really do not want to have any leftover handshake fragments *) - guard (Cstruct.len state.hs_fragment = 0) (`Fatal `HandshakeFragmentsNotEmpty) >|= fun () -> + guard (Cstruct.len state.hs_fragment = 0) (`Fatal `HandshakeFragmentsNotEmpty) >>| fun () -> let session = { session with renegotiation = (client, server) } and machina = Server Established in @@ -46,7 +48,7 @@ let answer_client_finished_resume state (session : session_data) server_verify c in guard (Cstruct.equal client_verify client_fin) (`Fatal `BadFinished) >>= fun () -> (* we really do not want to have any leftover handshake fragments *) - guard (Cstruct.len state.hs_fragment = 0) (`Fatal `HandshakeFragmentsNotEmpty) >|= fun () -> + guard (Cstruct.len state.hs_fragment = 0) (`Fatal `HandshakeFragmentsNotEmpty) >>| fun () -> let session = { session with renegotiation = (client_verify, server_verify) } and machina = Server Established in @@ -74,11 +76,11 @@ let establish_master_secret state (session : session_data) premastersecret raw l let private_key (session : session_data) = match session.common_session_data.own_private_key with - | Some priv -> return priv - | None -> fail (`Fatal `InvalidSession) (* TODO: assert false / ensure via typing in config *) + | Some priv -> Ok priv + | None -> Error (`Fatal `InvalidSession) (* TODO: assert false / ensure via typing in config *) let validate_certs certs authenticator (session : session_data) = - validate_chain authenticator certs None >|= fun (peer_certificate, received_certificates, peer_certificate_chain, trust_anchor) -> + validate_chain authenticator certs None >>| fun (peer_certificate, received_certificates, peer_certificate_chain, trust_anchor) -> let common_session_data = { session.common_session_data with received_certificates ; @@ -89,18 +91,18 @@ let validate_certs certs authenticator (session : session_data) = { session with common_session_data } let answer_client_certificate_RSA state (session : session_data) certs raw log = - validate_certs certs state.config.authenticator session >|= fun session -> + validate_certs certs state.config.authenticator session >>| fun session -> let machina = AwaitClientKeyExchange_RSA (session, log @ [raw]) in ({ state with machina = Server machina }, []) let answer_client_certificate_DHE state (session : session_data) dh_sent certs raw log = - validate_certs certs state.config.authenticator session >|= fun session -> + validate_certs certs state.config.authenticator session >>| fun session -> let machina = AwaitClientKeyExchange_DHE (session, dh_sent, log @ [raw]) in ({ state with machina = Server machina }, []) let answer_client_certificate_verify state (session : session_data) sctx cctx verify raw log = let sigdata = Cstruct.concat log in - verify_digitally_signed state.protocol_version state.config.signature_algorithms verify sigdata session.common_session_data.peer_certificate >|= fun () -> + verify_digitally_signed state.protocol_version state.config.signature_algorithms verify sigdata session.common_session_data.peer_certificate >>| fun () -> let machina = AwaitClientChangeCipherSpec (session, sctx, cctx, log @ [raw]) in ({ state with machina = Server machina }, []) @@ -128,43 +130,43 @@ let answer_client_key_exchange_RSA state (session : session_data) kex raw log = | None -> validate_premastersecret other | Some k -> validate_premastersecret k in - return (establish_master_secret state session pms raw log) - | _ -> fail (`Fatal `NotRSACertificate) + Ok (establish_master_secret state session pms raw log) + | _ -> Error (`Fatal `NotRSACertificate) let answer_client_key_exchange_DHE state session secret kex raw log = - let to_fatal r = match r with Ok cs -> return cs | Error er -> fail (`Fatal (`ReaderError er)) in + let to_fatal r = match r with Ok cs -> Ok cs | Error er -> Error (`Fatal (`ReaderError er)) in (let open Mirage_crypto_ec in match secret with | `P256 priv -> to_fatal (Reader.parse_client_ec_key_exchange kex) >>= fun share -> begin match P256.Dh.key_exchange priv share with - | Error e -> fail (`Fatal (`BadECDH e)) - | Ok shared -> return shared + | Error e -> Error (`Fatal (`BadECDH e)) + | Ok shared -> Ok shared end | `P384 priv -> to_fatal (Reader.parse_client_ec_key_exchange kex) >>= fun share -> begin match P384.Dh.key_exchange priv share with - | Error e -> fail (`Fatal (`BadECDH e)) - | Ok shared -> return shared + | Error e -> Error (`Fatal (`BadECDH e)) + | Ok shared -> Ok shared end | `P521 priv -> to_fatal (Reader.parse_client_ec_key_exchange kex) >>= fun share -> begin match P521.Dh.key_exchange priv share with - | Error e -> fail (`Fatal (`BadECDH e)) - | Ok shared -> return shared + | Error e -> Error (`Fatal (`BadECDH e)) + | Ok shared -> Ok shared end | `X25519 priv -> to_fatal (Reader.parse_client_ec_key_exchange kex) >>= fun share -> begin match X25519.key_exchange priv share with - | Error e -> fail (`Fatal (`BadECDH e)) - | Ok shared -> return shared + | Error e -> Error (`Fatal (`BadECDH e)) + | Ok shared -> Ok shared end | `Finite_field secret -> to_fatal (Reader.parse_client_dh_key_exchange kex) >>= fun share -> begin match Mirage_crypto_pk.Dh.shared secret share with - | None -> fail (`Fatal `InvalidDH) - | Some shared -> return shared - end) >|= fun pms -> + | None -> Error (`Fatal `InvalidDH) + | Some shared -> Ok shared + end) >>| fun pms -> establish_master_secret state session pms raw log let sig_algs (client_hello : client_hello) = @@ -278,21 +280,21 @@ let answer_client_hello_common state reneg ch raw = let signature_algorithms = sig_algs ch in (agreed_cert ~f ?signature_algorithms config.own_certificates host >>= function | (c::cs, priv) -> let cciphers = agreed_cipher c (ecc_group <> None) cciphers in - return (cciphers, c::cs, Some priv) - | ([], _) -> fail (`Fatal `InvalidSession) (* TODO: assert false / remove by types in config *) + Ok (cciphers, c::cs, Some priv) + | ([], _) -> Error (`Fatal `InvalidSession) (* TODO: assert false / remove by types in config *) ) >>= fun (cciphers, chain, priv) -> ( match first_match cciphers config.ciphers with - | Some x -> return x + | Some x -> Ok x | None -> match first_match cciphers Config.Ciphers.supported with - | Some _ -> fail (`Error (`NoConfiguredCiphersuite cciphers)) - | None -> fail (`Fatal (`InvalidClientHello (`NoSupportedCiphersuite ch.ciphersuites))) ) >>= fun cipher -> + | Some _ -> Error (`Error (`NoConfiguredCiphersuite cciphers)) + | None -> Error (`Fatal (`InvalidClientHello (`NoSupportedCiphersuite ch.ciphersuites))) ) >>= fun cipher -> let extended_ms = List.mem `ExtendedMasterSecret ch.extensions in Tracing.sexpf ~tag:"cipher" ~f:Ciphersuite.sexp_of_ciphersuite cipher ; - alpn_protocol config ch >|= fun alpn_protocol -> + alpn_protocol config ch >>| fun alpn_protocol -> let own_name = match host with None -> None | Some h -> Some (Domain_name.to_string h) in let group = @@ -334,7 +336,7 @@ let answer_client_hello_common state reneg ch raw = and cert_request version config (session : session_data) = let open Writer in match config.authenticator with - | None -> return ([], session) + | None -> Ok ([], session) | Some _ -> let cas = List.map X509.Distinguished_name.encode_der config.acceptable_cas @@ -343,16 +345,16 @@ let answer_client_hello_common state reneg ch raw = in (match version with | `TLS_1_0 | `TLS_1_1 -> - return (assemble_certificate_request certs cas) + Ok (assemble_certificate_request certs cas) | `TLS_1_2 -> - return (assemble_certificate_request_1_2 certs config.signature_algorithms cas) + Ok (assemble_certificate_request_1_2 certs config.signature_algorithms cas) | `TLS_1_3 -> (* TLS 1.3 handshakes are diverted in answer_client_hello, this will never be executed. for renegotiation, it is checked that the protocol version did not change from the previous epoch (in answer_client_hello_reneg, process_client_hello the guard (version = oldversion)) *) - fail (`Fatal (`BadRecordVersion (version :> tls_any_version)))) >|= fun data -> + Error (`Fatal (`BadRecordVersion (version :> tls_any_version)))) >>| fun data -> let certreq = CertificateRequest data in Tracing.sexpf ~tag:"handshake-out" ~f:sexp_of_tls_handshake certreq ; let common_session_data = { session.common_session_data with client_auth = true } in @@ -369,27 +371,27 @@ let answer_client_hello_common state reneg ch raw = let (secret, msg) = Mirage_crypto_pk.Dh.gen_key g in let dh_param = Crypto.dh_params_pack g msg in let dh_params = Writer.assemble_dh_parameters dh_param in - return (`Finite_field secret, dh_params) + Ok (`Finite_field secret, dh_params) | `P256 -> let secret, shared = P256.Dh.gen_key ~rng in let params = Writer.assemble_ec_parameters `P256 shared in - return (`P256 secret, params) + Ok (`P256 secret, params) | `P384 -> let secret, shared = P384.Dh.gen_key ~rng in let params = Writer.assemble_ec_parameters `P384 shared in - return (`P384 secret, params) + Ok (`P384 secret, params) | `P521 -> let secret, shared = P521.Dh.gen_key ~rng in let params = Writer.assemble_ec_parameters `P521 shared in - return (`P521 secret, params) + Ok (`P521 secret, params) | `X25519 -> let secret, shared = X25519.gen_key ~rng in let params = Writer.assemble_ec_parameters `X25519 shared in - return (`X25519 secret, params) + Ok (`X25519 secret, params) ) >>= fun (secret, written) -> let data = session.common_session_data.client_random <+> session.common_session_data.server_random <+> written in private_key session >>= fun priv -> - signature version data sig_algs config.signature_algorithms priv >|= fun sgn -> + signature version data sig_algs config.signature_algorithms priv >>| fun sgn -> let kex = ServerKeyExchange (written <+> sgn) in let hs = Writer.assemble_handshake kex in Tracing.sexpf ~tag:"handshake-out" ~f:sexp_of_tls_handshake kex ; @@ -414,7 +416,7 @@ let answer_client_hello_common state reneg ch raw = AwaitClientKeyExchange_DHE (session, dh, log) in Tracing.sexpf ~tag:"handshake-out" ~f:sexp_of_tls_handshake ServerHelloDone ; - return (outs, machina) + Ok (outs, machina) | `RSA -> let outs = sh :: certificates @ cert_req @ [ hello_done ] in let log = raw :: outs in @@ -425,7 +427,7 @@ let answer_client_hello_common state reneg ch raw = AwaitClientKeyExchange_RSA (session, log) in Tracing.sexpf ~tag:"handshake-out" ~f:sexp_of_tls_handshake ServerHelloDone ; - return (outs, machina) ) >|= fun (out_recs, machina) -> + Ok (outs, machina) ) >>| fun (out_recs, machina) -> ({ state with machina = Server machina }, [`Record (Packet.HANDSHAKE, Cstruct.concat out_recs)]) @@ -451,18 +453,18 @@ let agreed_version supported (client_hello : client_hello) = | Some v -> Some v) None client_versions with - | Some x -> return x + | Some x -> Ok x | None -> match supported_versions with - | [] -> fail (`Fatal (`NoVersions raw_client_versions)) - | _ -> fail (`Error (`NoConfiguredVersions supported_versions)) + | [] -> Error (`Fatal (`NoVersions raw_client_versions)) + | _ -> Error (`Error (`NoConfiguredVersions supported_versions)) let answer_client_hello state (ch : client_hello) raw = let ensure_reneg ciphers their_data = let reneg_cs = List.mem Packet.TLS_EMPTY_RENEGOTIATION_INFO_SCSV ciphers in match reneg_cs, their_data with | _, Some x -> guard (Cstruct.len x = 0) (`Fatal `InvalidRenegotiation) - | true, _ -> return () - | _ -> fail (`Fatal `NoSecureRenegotiation) + | true, _ -> Ok () + | _ -> Error (`Fatal `NoSecureRenegotiation) and resume (ch : client_hello) state = let epoch_matches (epoch : Core.epoch_data) version ciphers extensions = @@ -491,7 +493,7 @@ let answer_client_hello state (ch : client_hello) raw = let version = state_version state in let sh, session = server_hello state.config ch session version None in (* we really do not want to have any leftover handshake fragments *) - guard (Cstruct.len state.hs_fragment = 0) (`Fatal `HandshakeFragmentsNotEmpty) >|= fun () -> + guard (Cstruct.len state.hs_fragment = 0) (`Fatal `HandshakeFragmentsNotEmpty) >>| fun () -> let client_ctx, server_ctx = Handshake_crypto.initialise_crypto_ctx version session in @@ -516,8 +518,8 @@ let answer_client_hello state (ch : client_hello) raw = let process_client_hello config ch version = let cciphers = ch.ciphersuites in (match client_hello_valid version ch with - | Ok () -> return () - | Error e -> fail (`Fatal (`InvalidClientHello e))) >>= fun () -> + | Ok () -> Ok () + | Error e -> Error (`Fatal (`InvalidClientHello e))) >>= fun () -> guard (not (List.mem Packet.TLS_FALLBACK_SCSV cciphers) || version = max_protocol_version config.protocol_versions) (`Fatal `InappropriateFallback) >>= fun () -> @@ -542,17 +544,17 @@ let answer_client_hello_reneg state (ch : client_hello) raw = let ensure_reneg our_data their_data = match our_data, their_data with | (cvd, _), Some x -> guard (Cstruct.equal cvd x) (`Fatal `InvalidRenegotiation) - | _ -> fail (`Fatal `NoSecureRenegotiation) + | _ -> Error (`Fatal `NoSecureRenegotiation) in let process_client_hello config oldversion ours ch = (match client_hello_valid oldversion ch with - | Ok () -> return () - | Error x -> fail (`Fatal (`InvalidClientHello x))) >>= fun () -> + | Ok () -> Ok () + | Error x -> Error (`Fatal (`InvalidClientHello x))) >>= fun () -> agreed_version config.protocol_versions ch >>= fun version -> guard (version = oldversion) (`Fatal (`InvalidRenegotiationVersion version)) >>= fun () -> let theirs = get_secure_renegotiation ch.extensions in - ensure_reneg ours theirs >|= fun () -> + ensure_reneg ours theirs >>| fun () -> version in @@ -565,8 +567,8 @@ let answer_client_hello_reneg state (ch : client_hello) raw = | false, _ -> let no_reneg = Writer.assemble_alert ~level:Packet.WARNING Packet.NO_RENEGOTIATION in Tracing.sexpf ~tag:"alert-out" ~f:sexp_of_tls_alert (Packet.WARNING, Packet.NO_RENEGOTIATION) ; - return (state, [`Record (Packet.ALERT, no_reneg)]) - | true , _ -> fail (`Fatal `InvalidSession) (* I'm pretty sure this can be an assert false *) + Ok (state, [`Record (Packet.ALERT, no_reneg)]) + | true , _ -> Error (`Fatal `InvalidSession) (* I'm pretty sure this can be an assert false *) let handle_change_cipher_spec ss state packet = match Reader.parse_change_cipher_spec packet, ss with @@ -579,18 +581,18 @@ let handle_change_cipher_spec ss state packet = Tracing.cs ~tag:"change-cipher-spec-in" packet ; Tracing.cs ~tag:"change-cipher-spec-out" packet ; - return ({ state with machina = Server machina }, + Ok ({ state with machina = Server machina }, [`Record ccs; `Change_enc server_ctx; `Change_dec client_ctx]) | Ok (), AwaitClientChangeCipherSpecResume (session, client_ctx, server_verify, log) -> - guard (Cstruct.len state.hs_fragment = 0) (`Fatal `HandshakeFragmentsNotEmpty) >|= fun () -> + guard (Cstruct.len state.hs_fragment = 0) (`Fatal `HandshakeFragmentsNotEmpty) >>| fun () -> let machina = AwaitClientFinishedResume (session, server_verify, log) in Tracing.cs ~tag:"change-cipher-spec-in" packet ; ({ state with machina = Server machina }, [`Change_dec client_ctx]) - | Error er, _ -> fail (`Fatal (`ReaderError er)) - | _ -> fail (`Fatal `UnexpectedCCS) + | Error er, _ -> Error (`Fatal (`ReaderError er)) + | _ -> Error (`Fatal `UnexpectedCCS) let handle_handshake ss hs buf = match Reader.parse_handshake buf with @@ -602,15 +604,15 @@ let handle_handshake ss hs buf = | AwaitClientCertificate_RSA (session, log), Certificate cs -> (match Reader.parse_certificates cs with | Ok cs -> answer_client_certificate_RSA hs session cs buf log - | Error re -> fail (`Fatal (`ReaderError re))) + | Error re -> Error (`Fatal (`ReaderError re))) | AwaitClientCertificate_DHE (session, dh_sent, log), Certificate cs -> (match Reader.parse_certificates cs with | Ok cs -> answer_client_certificate_DHE hs session dh_sent cs buf log - | Error re -> fail (`Fatal (`ReaderError re))) + | Error re -> Error (`Fatal (`ReaderError re))) | AwaitClientKeyExchange_RSA (session, log), ClientKeyExchange cs -> (match Reader.parse_client_dh_key_exchange cs with | Ok kex -> answer_client_key_exchange_RSA hs session kex buf log - | Error re -> fail (`Fatal (`ReaderError re))) + | Error re -> Error (`Fatal (`ReaderError re))) | AwaitClientKeyExchange_DHE (session, dh_sent, log), ClientKeyExchange kex -> answer_client_key_exchange_DHE hs session dh_sent kex buf log | AwaitClientCertificateVerify (session, sctx, cctx, log), CertificateVerify ver -> @@ -623,5 +625,5 @@ let handle_handshake ss hs buf = answer_client_hello_reneg hs ch buf | AwaitClientHelloRenegotiate, ClientHello ch -> (* hello-request send, renegotiation *) answer_client_hello_reneg hs ch buf - | _, hs -> fail (`Fatal (`UnexpectedHandshake hs)) ) - | Error re -> fail (`Fatal (`ReaderError re)) + | _, hs -> Error (`Fatal (`UnexpectedHandshake hs)) ) + | Error re -> Error (`Fatal (`ReaderError re)) diff --git a/lib/handshake_server.mli b/lib/handshake_server.mli index ba63ba74..424cb953 100644 --- a/lib/handshake_server.mli +++ b/lib/handshake_server.mli @@ -1,6 +1,6 @@ open State -val hello_request : handshake_state -> handshake_return eff +val hello_request : handshake_state -> (handshake_return, failure) result -val handle_change_cipher_spec : server_handshake_state -> handshake_state -> Cstruct.t -> handshake_return eff -val handle_handshake : server_handshake_state -> handshake_state -> Cstruct.t -> handshake_return eff +val handle_change_cipher_spec : server_handshake_state -> handshake_state -> Cstruct.t -> (handshake_return, failure) result +val handle_handshake : server_handshake_state -> handshake_state -> Cstruct.t -> (handshake_return, failure) result diff --git a/lib/handshake_server13.ml b/lib/handshake_server13.ml index 080fa178..b8b225e3 100644 --- a/lib/handshake_server13.ml +++ b/lib/handshake_server13.ml @@ -6,14 +6,16 @@ open Handshake_common open Handshake_crypto13 +open Rresult.R.Infix + let answer_client_hello ~hrr state ch raw = (match client_hello_valid `TLS_1_3 ch with - | Error e -> fail (`Fatal (`InvalidClientHello e)) - | Ok () -> return () ) >>= fun () -> + | Error e -> Error (`Fatal (`InvalidClientHello e)) + | Ok () -> Ok () ) >>= fun () -> (if hrr && List.mem `EarlyDataIndication ch.extensions then - fail (`Fatal (`InvalidClientHello `Has0rttAfterHRR)) + Error (`Fatal (`InvalidClientHello `Has0rttAfterHRR)) else - return ()) >>= fun () -> + Ok ()) >>= fun () -> Tracing.sexpf ~tag:"version" ~f:sexp_of_tls_version `TLS_1_3 ; let ciphers = @@ -21,18 +23,18 @@ let answer_client_hello ~hrr state ch raw = in ( match map_find ~f:(function `SupportedGroups gs -> Some gs | _ -> None) ch.extensions with - | None -> fail (`Fatal (`InvalidClientHello `NoSupportedGroupExtension)) - | Some gs -> return (filter_map ~f:Core.named_group_to_group gs )) >>= fun groups -> + | None -> Error (`Fatal (`InvalidClientHello `NoSupportedGroupExtension)) + | Some gs -> Ok (filter_map ~f:Core.named_group_to_group gs )) >>= fun groups -> ( match map_find ~f:(function `KeyShare ks -> Some ks | _ -> None) ch.extensions with - | None -> fail (`Fatal (`InvalidClientHello `NoKeyShareExtension)) + | None -> Error (`Fatal (`InvalidClientHello `NoKeyShareExtension)) | Some ks -> - let f acc (g, ks) = - match Core.named_group_to_group g with - | None -> Ok acc - | Some g -> Ok ((g, ks) :: acc) - in - foldM f [] ks ) >>= fun keyshares -> + List.fold_left (fun acc (g, ks) -> + acc >>| fun acc -> + match Core.named_group_to_group g with + | None -> acc + | Some g -> ((g, ks) :: acc)) + (Ok []) ks ) >>= fun keyshares -> let base_server_hello ?epoch cipher extensions = let ciphersuite = (cipher :> Ciphersuite.ciphersuite) in @@ -63,15 +65,15 @@ let answer_client_hello ~hrr state ch raw = first_match keyshare_groups config.Config.groups, first_match ciphers (Config.ciphers13 config) with - | _, None -> fail (`Error (`NoConfiguredCiphersuite ciphers)) + | _, None -> Error (`Error (`NoConfiguredCiphersuite ciphers)) | None, Some cipher -> if hrr then (* avoid loops CH -> HRR -> CH -> HRR -> ... *) - fail (`Fatal `NoSupportedGroup) + Error (`Fatal `NoSupportedGroup) else (* no keyshare, looks whether there's a supported group ++ send back HRR *) begin match first_match groups config.Config.groups with - | None -> fail (`Fatal `NoSupportedGroup) + | None -> Error (`Fatal `NoSupportedGroup) | Some group -> let cookie = Mirage_crypto.Hash.digest (Ciphersuite.hash13 cipher) raw in let hrr = { retry_version = `TLS_1_3 ; ciphersuite = cipher ; sessionid = ch.sessionid ; selected_group = group ; extensions = [ `Cookie cookie ] } in @@ -81,7 +83,7 @@ let answer_client_hello ~hrr state ch raw = (* but the client wouldn't know until it received the HRR *) let early_data_left = if List.mem `EarlyDataIndication ch.extensions then config.Config.zero_rtt else 0l in let machina = Server13 AwaitClientHelloHRR13 in - return ({ state with early_data_left ; machina }, + Ok ({ state with early_data_left ; machina }, `Record (Packet.HANDSHAKE, hrr_raw) :: (match ch.sessionid with | None -> [] @@ -92,21 +94,21 @@ let answer_client_hello ~hrr state ch raw = Log.debug (fun m -> m "group %a" Sexplib.Sexp.pp_hum (Core.sexp_of_group group)) ; match List.mem group groups, keyshare group with - | false, _ | _, None -> fail (`Fatal `NoSupportedGroup) (* TODO: better error type? *) + | false, _ | _, None -> Error (`Fatal `NoSupportedGroup) (* TODO: better error type? *) | _, Some keyshare -> (* DHE - full handshake *) (if hrr then match map_find ~f:(function `Cookie c -> Some c | _ -> None) ch.extensions with - | None -> fail (`Fatal (`InvalidClientHello `NoCookie)) + | None -> Error (`Fatal (`InvalidClientHello `NoCookie)) | Some c -> (* log is: 254 00 00 length c :: HRR *) let hash_hdr = Writer.assemble_message_hash (Cstruct.len c) in let hrr = { retry_version = `TLS_1_3 ; ciphersuite = cipher ; sessionid = ch.sessionid ; selected_group = group ; extensions = [ `Cookie c ]} in let hs_buf = Writer.assemble_handshake (HelloRetryRequest hrr) in - return (Cstruct.concat [ hash_hdr ; c ; hs_buf ]) + Ok (Cstruct.concat [ hash_hdr ; c ; hs_buf ]) else - return Cstruct.empty) >>= fun log -> + Ok Cstruct.empty) >>= fun log -> let hostname = hostname ch in let hlen = Mirage_crypto.Hash.digest_size (Ciphersuite.hash13 cipher) in @@ -137,7 +139,7 @@ let answer_client_hello ~hrr state ch raw = | (idx, ((id, obf_age), binder))::_ -> (* need to verify binder, do the obf_age computations + checking, figure out whether the id is in our psk cache, and use the resumption secret as input - and return the idx *) + and Ok the idx *) let psk, old_epoch = match cache.Config.lookup id with | None -> assert false (* see above *) @@ -210,14 +212,14 @@ let answer_client_hello ~hrr state ch raw = let server_hs_secret, server_ctx, client_hs_secret, client_ctx = hs_ctx hs_secret log in ( match map_find ~f:(function `SignatureAlgorithms sa -> Some sa | _ -> None) ch.extensions with - | None -> fail (`Fatal (`InvalidClientHello `NoSignatureAlgorithmsExtension)) - | Some sa -> return sa ) >>= fun sigalgs -> + | None -> Error (`Fatal (`InvalidClientHello `NoSignatureAlgorithmsExtension)) + | Some sa -> Ok sa ) >>= fun sigalgs -> (* TODO respect certificate_signature_algs if present *) let f = supports_key_usage ~not_present:true `Digital_signature in (agreed_cert ~f ~signature_algorithms:sigalgs config.Config.own_certificates hostname >>= function - | (c::cs, priv) -> return (c::cs, priv) - | _ -> fail (`Fatal `InvalidSession)) >>= fun (chain, priv) -> + | (c::cs, priv) -> Ok (c::cs, priv) + | _ -> Error (`Fatal `InvalidSession)) >>= fun (chain, priv) -> alpn_protocol config ch >>= fun alpn_protocol -> let session = let own_name = match hostname with None -> None | Some x -> Some (Domain_name.to_string x) in @@ -242,7 +244,7 @@ let answer_client_hello ~hrr state ch raw = begin if session.resumed then - return ([], log, session) + Ok ([], log, session) else let out, log, session = match config.Config.authenticator with | None -> [], log, session @@ -270,7 +272,7 @@ let answer_client_hello ~hrr state ch raw = let tbs = Mirage_crypto.Hash.digest (Ciphersuite.hash13 cipher) log in signature `TLS_1_3 ~context_string:"TLS 1.3, server CertificateVerify" - tbs (Some sigalgs) config.Config.signature_algorithms priv >|= fun signed -> + tbs (Some sigalgs) config.Config.signature_algorithms priv >>| fun signed -> let cv = CertificateVerify signed in let cv_raw = Writer.assemble_handshake cv in Tracing.sexpf ~tag:"handshake-out" ~f:sexp_of_tls_handshake cv ; @@ -293,7 +295,7 @@ let answer_client_hello ~hrr state ch raw = in let session' = { session' with server_app_secret ; client_app_secret } in - guard (Cstruct.len state.hs_fragment = 0) (`Fatal `HandshakeFragmentsNotEmpty) >|= fun () -> + guard (Cstruct.len state.hs_fragment = 0) (`Fatal `HandshakeFragmentsNotEmpty) >>| fun () -> (* send sessionticket early *) (* TODO track the nonce across handshakes / newsessionticket messages (i.e. after post-handshake auth) - needs to be unique! *) @@ -348,8 +350,8 @@ let answer_client_hello ~hrr state ch raw = let answer_client_certificate state cert (sd : session_data13) client_fini dec_ctx st raw log = match Reader.parse_certificates_1_3 cert, state.config.Config.authenticator with - | Error re, _ -> fail (`Fatal (`ReaderError re)) - | Ok (_, []), None -> fail (`Fatal `InvalidSession) (* TODO this cannot happen *) + | Error re, _ -> Error (`Fatal (`ReaderError re)) + | Ok (_, []), None -> Error (`Fatal `InvalidSession) (* TODO this cannot happen *) | Ok (_ctx, []), Some auth -> begin match auth ~host:None [] with | Ok anchor -> @@ -361,13 +363,13 @@ let answer_client_certificate state cert (sd : session_data13) client_fini dec_c let sd = { sd with common_session_data13 } in let st = AwaitClientFinished13 (client_fini, dec_ctx, st, log <+> raw) in Ok ({ state with machina = Server13 st ; session = `TLS13 sd :: state.session }, []) - | Error e -> fail (`Error (`AuthenticationFailure e)) + | Error e -> Error (`Error (`AuthenticationFailure e)) end | Ok (_ctx, cert_exts), auth -> (* TODO what to do with ctx? send through authenticator? *) (* TODO what to do with extensions? *) let certs = List.map fst cert_exts in - validate_chain auth certs None >|= fun (peer_certificate, received_certificates, peer_certificate_chain, trust_anchor) -> + validate_chain auth certs None >>| fun (peer_certificate, received_certificates, peer_certificate_chain, trust_anchor) -> let sd' = let common_session_data13 = { sd.common_session_data13 with received_certificates ; @@ -385,7 +387,7 @@ let answer_client_certificate_verify state cv (sd : session_data13) client_fini verify_digitally_signed `TLS_1_3 ~context_string:"TLS 1.3, client CertificateVerify" state.config.Config.signature_algorithms cv tbs - sd.common_session_data13.peer_certificate >|= fun () -> + sd.common_session_data13.peer_certificate >>| fun () -> let st = AwaitClientFinished13 (client_fini, dec_ctx, st, log <+> raw) in ({ state with machina = Server13 st ; session = `TLS13 sd :: state.session }, []) @@ -395,7 +397,7 @@ let answer_client_finished state fin client_fini dec_ctx st raw log = let hash = Ciphersuite.hash13 session.ciphersuite13 in let data = finished hash client_fini log in guard (Cstruct.equal data fin) (`Fatal `BadFinished) >>= fun () -> - guard (Cstruct.len state.hs_fragment = 0) (`Fatal `HandshakeFragmentsNotEmpty) >|= fun () -> + guard (Cstruct.len state.hs_fragment = 0) (`Fatal `HandshakeFragmentsNotEmpty) >>| fun () -> let session' = match st, state.config.Config.ticket_cache with | None, _ | _, None -> session | Some st, Some cache -> @@ -410,7 +412,7 @@ let answer_client_finished state fin client_fini dec_ctx st raw log = in let state' = { state with machina = Server13 Established13 ; session = `TLS13 session' :: rest } in (state', [ `Change_dec dec_ctx ]) - | _ -> fail (`Fatal `InvalidSession) + | _ -> Error (`Fatal `InvalidSession) let handle_end_of_early_data state cf hs_ctx cc st buf log = let machina = AwaitClientFinished13 (cf, cc, st, log <+> buf) in @@ -419,7 +421,7 @@ let handle_end_of_early_data state cf hs_ctx cc st buf log = let session = `TLS13 { s1 with state = `Established } :: state.session in Ok ({ state with machina = Server13 machina ; session }, [ `Change_dec hs_ctx ]) | _ -> - fail (`Fatal `InvalidSession) + Error (`Fatal `InvalidSession) let handle_key_update state req = match state.session with @@ -443,7 +445,7 @@ let handle_key_update state req = let session = `TLS13 session' :: state.session in let state' = { state with machina = Server13 Established13 ; session } in Ok (state', `Change_dec client_ctx :: out) - | _ -> fail (`Fatal `InvalidSession) + | _ -> Error (`Fatal `InvalidSession) let handle_handshake cs hs buf = let open Reader in @@ -463,5 +465,5 @@ let handle_handshake cs hs buf = handle_end_of_early_data hs cf hs_c cc st buf log | Established13, KeyUpdate req -> handle_key_update hs req - | _, hs -> fail (`Fatal (`UnexpectedHandshake hs)) ) - | Error re -> fail (`Fatal (`ReaderError re)) + | _, hs -> Error (`Fatal (`UnexpectedHandshake hs)) ) + | Error re -> Error (`Fatal (`ReaderError re)) diff --git a/lib/reader.ml b/lib/reader.ml index fefadc3b..7ed0a4d9 100644 --- a/lib/reader.ml +++ b/lib/reader.ml @@ -14,9 +14,6 @@ type error = | UnknownContent of int [@@deriving sexp] -include Control.Or_error_make (struct type err = error end) -type nonrec 'a result = ('a, error) result - exception Reader_error of error let raise_unknown msg = raise (Reader_error (Unknown msg)) @@ -24,9 +21,9 @@ and raise_wrong_length msg = raise (Reader_error (WrongLength msg)) and raise_trailing_bytes msg = raise (Reader_error (TrailingBytes msg)) let catch f x = - try return (f x) with - | Reader_error err -> fail err - | Invalid_argument _ -> fail Underflow + try Ok (f x) with + | Reader_error err -> Error err + | Invalid_argument _ -> Error Underflow let parse_version_int buf = let major = get_uint8 buf 0 in @@ -58,7 +55,7 @@ let parse_any_version = catch parse_any_version_exn let parse_record buf = if len buf < 5 then - return (`Fragment buf) + Ok (`Fragment buf) else let typ = get_uint8 buf 0 and version = parse_version_int (shift buf 1) @@ -68,18 +65,18 @@ let parse_record buf = (* 2 ^ 14 + 2048 for TLSCiphertext 2 ^ 14 + 1024 for TLSCompressed 2 ^ 14 for TLSPlaintext *) - fail (Overflow x) - | x when 5 + x > len buf -> return (`Fragment buf) + Error (Overflow x) + | x when 5 + x > len buf -> Ok (`Fragment buf) | x -> match tls_any_version_of_pair version, int_to_content_type typ with - | None, _ -> fail (UnknownVersion version) - | _, None -> fail (UnknownContent typ) + | None, _ -> Error (UnknownVersion version) + | _, None -> Error (UnknownContent typ) | Some version, Some content_type -> let payload, rest = split ~start:5 buf x in - return (`Record (({ content_type ; version }, payload), rest)) + Ok (`Record (({ content_type ; version }, payload), rest)) let validate_alert (lvl, typ) = let open Packet in @@ -132,8 +129,8 @@ let parse_alert = catch @@ fun buf -> let parse_change_cipher_spec buf = match len buf, get_uint8 buf 0 with - | 1, 1 -> return () - | _ -> fail (Unknown "bad change cipher spec message") + | 1, 1 -> Ok () + | _ -> Error (Unknown "bad change cipher spec message") let rec parse_count_list parsef buf acc = function | 0 -> (List.rev acc, buf) diff --git a/lib/reader.mli b/lib/reader.mli index 828bbeb0..2cb811de 100644 --- a/lib/reader.mli +++ b/lib/reader.mli @@ -11,33 +11,31 @@ type error = val error_of_sexp : Sexplib.Sexp.t -> error val sexp_of_error : error -> Sexplib.Sexp.t -type nonrec 'a result = ('a, error) result - -val parse_version : Cstruct.t -> Core.tls_version result -val parse_any_version : Cstruct.t -> Core.tls_any_version result +val parse_version : Cstruct.t -> (Core.tls_version, error) result +val parse_any_version : Cstruct.t -> (Core.tls_any_version, error) result val parse_record : Cstruct.t -> - [ `Record of (Core.tls_hdr * Cstruct.t) * Cstruct.t - | `Fragment of Cstruct.t - ] result + ([ `Record of (Core.tls_hdr * Cstruct.t) * Cstruct.t + | `Fragment of Cstruct.t + ], error) result val parse_handshake_frame : Cstruct.t -> (Cstruct.t option * Cstruct.t) -val parse_handshake : Cstruct.t -> Core.tls_handshake result +val parse_handshake : Cstruct.t -> (Core.tls_handshake, error) result -val parse_alert : Cstruct.t -> Core.tls_alert result +val parse_alert : Cstruct.t -> (Core.tls_alert, error) result -val parse_change_cipher_spec : Cstruct.t -> unit result +val parse_change_cipher_spec : Cstruct.t -> (unit, error) result -val parse_certificate_request : Cstruct.t -> (Packet.client_certificate_type list * Cstruct.t list) result -val parse_certificate_request_1_2 : Cstruct.t -> (Packet.client_certificate_type list * Core.signature_algorithm list * Cstruct.t list) result -val parse_certificate_request_1_3 : Cstruct.t -> (Cstruct.t option * Core.certificate_request_extension list) result +val parse_certificate_request : Cstruct.t -> (Packet.client_certificate_type list * Cstruct.t list, error) result +val parse_certificate_request_1_2 : Cstruct.t -> (Packet.client_certificate_type list * Core.signature_algorithm list * Cstruct.t list, error) result +val parse_certificate_request_1_3 : Cstruct.t -> (Cstruct.t option * Core.certificate_request_extension list, error) result -val parse_certificates : Cstruct.t -> Cstruct.t list result -val parse_certificates_1_3 : Cstruct.t -> (Cstruct.t * (Cstruct.t * 'a list) list) result +val parse_certificates : Cstruct.t -> (Cstruct.t list, error) result +val parse_certificates_1_3 : Cstruct.t -> (Cstruct.t * (Cstruct.t * 'a list) list, error) result -val parse_client_dh_key_exchange : Cstruct.t -> Cstruct.t result -val parse_client_ec_key_exchange : Cstruct.t -> Cstruct.t result +val parse_client_dh_key_exchange : Cstruct.t -> (Cstruct.t, error) result +val parse_client_ec_key_exchange : Cstruct.t -> (Cstruct.t, error) result -val parse_dh_parameters : Cstruct.t -> (Core.dh_parameters * Cstruct.t * Cstruct.t) result -val parse_ec_parameters : Cstruct.t -> ([ `X25519 | `P256 | `P384 | `P521 ] * Cstruct.t * Cstruct.t * Cstruct.t) result -val parse_digitally_signed : Cstruct.t -> Cstruct.t result -val parse_digitally_signed_1_2 : Cstruct.t -> (Core.signature_algorithm * Cstruct.t) result +val parse_dh_parameters : Cstruct.t -> (Core.dh_parameters * Cstruct.t * Cstruct.t, error) result +val parse_ec_parameters : Cstruct.t -> ([ `X25519 | `P256 | `P384 | `P521 ] * Cstruct.t * Cstruct.t * Cstruct.t, error) result +val parse_digitally_signed : Cstruct.t -> (Cstruct.t, error) result +val parse_digitally_signed_1_2 : Cstruct.t -> (Core.signature_algorithm * Cstruct.t, error) result diff --git a/lib/state.ml b/lib/state.ml index 5d3809a7..f885e044 100644 --- a/lib/state.ml +++ b/lib/state.ml @@ -304,10 +304,6 @@ type failure = [ | `Fatal of fatal ] [@@deriving sexp] -(* Monadic control-flow core. *) -include Control.Or_error_make (struct type err = failure end) -type 'a eff = 'a t - let common_data_to_epoch common is_server peer_name = let own_random, peer_random = if is_server then diff --git a/lwt/tls_lwt.ml b/lwt/tls_lwt.ml index dae93e5e..9f162598 100644 --- a/lwt/tls_lwt.ml +++ b/lwt/tls_lwt.ml @@ -62,7 +62,7 @@ module Unix = struct let handle tls buf = match Tls.Engine.handle_tls tls buf with - | `Ok (state', `Response resp, `Data data) -> + | Ok (state', `Response resp, `Data data) -> let state' = match state' with | `Ok tls -> `Active tls | `Eof -> `Eof @@ -72,7 +72,7 @@ module Unix = struct safely (resp |> when_some (write_t t)) >|= fun () -> `Ok data - | `Fail (alert, `Response resp) -> + | Error (alert, `Response resp) -> t.state <- `Error (Tls_failure alert) ; write_t t resp >>= fun () -> read_react t in @@ -128,11 +128,10 @@ module Unix = struct * *) let rec drain_handshake t = let push_linger t mcs = - let open Tls.Utils.Cs in match (mcs, t.linger) with | (None, _) -> () | (scs, None) -> t.linger <- scs - | (Some cs, Some l) -> t.linger <- Some (l <+> cs) + | (Some cs, Some l) -> t.linger <- Some (Cstruct.append l cs) in match t.state with | `Active tls when not (Tls.Engine.handshake_in_progress tls) -> @@ -225,9 +224,9 @@ module Unix = struct match t.state with | `Active tls -> ( match Tls.Engine.epoch tls with | `InitialEpoch -> assert false (* can never occur! *) - | `Epoch data -> `Ok data ) - | `Eof -> `Error - | `Error _ -> `Error + | `Epoch data -> Ok data ) + | `Eof -> Error () + | `Error _ -> Error () end diff --git a/lwt/tls_lwt.mli b/lwt/tls_lwt.mli index a1a1567d..fe988d59 100644 --- a/lwt/tls_lwt.mli +++ b/lwt/tls_lwt.mli @@ -83,7 +83,7 @@ module Unix : sig (** [epoch t] returns [epoch], which contains information of the active session. *) - val epoch : t -> [ `Ok of Tls.Core.epoch_data | `Error ] + val epoch : t -> (Tls.Core.epoch_data, unit) result end (** {1 High-level API} *) diff --git a/mirage/tls_mirage.ml b/mirage/tls_mirage.ml index cafc79ce..fab1d4b6 100644 --- a/mirage/tls_mirage.ml +++ b/mirage/tls_mirage.ml @@ -57,7 +57,7 @@ module Make (F : Mirage_flow.S) = struct let handle tls buf = match Tls.Engine.handle_tls tls buf with - | `Ok (res, `Response resp, `Data data) -> + | Ok (res, `Response resp, `Data data) -> flow.state <- ( match res with | `Ok tls -> `Active tls | `Eof -> `Eof @@ -69,7 +69,7 @@ module Make (F : Mirage_flow.S) = struct | `Ok _ -> return_unit | _ -> FLOW.close flow.flow ) >>= fun () -> return @@ `Ok data - | `Fail (fail, `Response resp) -> + | Error (fail, `Response resp) -> let reason = tls_fail fail in flow.state <- reason ; FLOW.(write flow.flow resp >>= fun _ -> close flow.flow) >>= fun () -> return reason diff --git a/tests/feedback.ml b/tests/feedback.ml index d553c5b0..380f27a3 100644 --- a/tests/feedback.ml +++ b/tests/feedback.ml @@ -18,12 +18,12 @@ module Flow = struct | `S st -> (st, "server") | `C st -> (st, "client") in match Tls.Engine.handle_tls st msg with - | `Ok (`Ok st', `Response (Some ans), `Data appdata) -> + | Ok (`Ok st', `Response (Some ans), `Data appdata) -> (rewrap_st (state, st'), ans, appdata) - | `Fail (a, _) -> + | Error (a, _) -> failwith @@ Printf.sprintf "[%s] %s error: %s" tag descr (Sexplib.Sexp.to_string_hum (Tls.Engine.sexp_of_failure a)) - | `Ok _ -> failwith "decoded alert" + | Ok _ -> failwith "decoded alert" end let loop_chatter ~certificate ~loops ~size = diff --git a/tls.opam b/tls.opam index 87668730..d0b83289 100644 --- a/tls.opam +++ b/tls.opam @@ -28,6 +28,7 @@ depends: [ "x509" {>= "0.12.0"} "domain-name" {>= "0.3.0"} "fmt" + "rresult" "cstruct-unix" {with-test & >= "3.0.0"} "ounit2" {with-test & >= "2.2.0"} "lwt" {>= "3.0.0"}