From 2cde8439ac6f3f101025e1a2a1651ba882567d14 Mon Sep 17 00:00:00 2001 From: Hannes Mehnert Date: Thu, 15 Jun 2023 13:00:48 +0200 Subject: [PATCH] Adapt to happy-eyeballs 0.6 changes (#340) * Adapt to happy-eyeballs 0.6 changes * further fixes: - when a connection failed (i.e. timeout), cancel all connection attempts - transport the success/failure via the connected_condition instead of recursing * pass connect_timeout to happy_eyeballs creation * address review from @reynir: filter the current attempt from being woken up/cancelled --------- Co-authored-by: Robur --- lwt/client/dns_client_lwt.ml | 122 +++++++++++++++++------------ mirage/client/dns_client_mirage.ml | 108 ++++++++++++++----------- 2 files changed, 136 insertions(+), 94 deletions(-) diff --git a/lwt/client/dns_client_lwt.ml b/lwt/client/dns_client_lwt.ml index bf1770a5..878de838 100644 --- a/lwt/client/dns_client_lwt.ml +++ b/lwt/client/dns_client_lwt.ml @@ -24,10 +24,10 @@ module Transport : Dns_client.S timeout_ns : int64 ; (* TODO: avoid race, use a mvar instead of condition *) mutable fd : [ `Plain of Lwt_unix.file_descr | `Tls of Tls_lwt.Unix.t ] option ; - mutable connected_condition : unit Lwt_condition.t option ; + mutable connected_condition : (unit, [ `Msg of string ]) result Lwt_condition.t option ; mutable requests : (Cstruct.t * (Cstruct.t, [ `Msg of string ]) result Lwt_condition.t) IM.t ; mutable he : Happy_eyeballs.t ; - mutable cancel_connecting : unit Lwt.u Happy_eyeballs.Waiter_map.t; + mutable cancel_connecting : (int * unit Lwt.u) list Happy_eyeballs.Waiter_map.t; mutable waiters : ((Ipaddr.t * int) * Lwt_unix.file_descr, [ `Msg of string ]) result Lwt.u Happy_eyeballs.Waiter_map.t ; timer_condition : unit Lwt_condition.t ; } @@ -51,65 +51,89 @@ module Transport : Dns_client.S let clock = Mtime_clock.elapsed_ns - let he_timer_interval = Duration.of_ms 500 + let he_timer_interval = Duration.of_ms 10 let close_socket fd = Lwt.catch (fun () -> Lwt_unix.close fd) (fun _ -> Lwt.return_unit) + let try_connect ip port = + Lwt_unix.(getprotobyname "tcp" >|= fun x -> x.p_proto) >>= fun proto_number -> + let fam = + Ipaddr.(Lwt_unix.(match ip with V4 _ -> PF_INET | V6 _ -> PF_INET6)) + in + let socket = Lwt_unix.socket fam Lwt_unix.SOCK_STREAM proto_number in + let open Lwt_result.Infix in + Lwt.catch (fun () -> + let addr = Lwt_unix.ADDR_INET (Ipaddr_unix.to_inet_addr ip, port) in + Lwt_result.ok (Lwt_unix.connect socket addr) >|= fun () -> + socket) + (fun e -> + Lwt_result.ok (close_socket socket) >>= fun () -> + let err = + Fmt.str "error %s connecting to nameserver %a:%d" + (Printexc.to_string e) Ipaddr.pp ip port + in + Lwt.return (Error (`Msg err))) + let handle_one_action t = function - | Happy_eyeballs.Connect (host, id, (ip, port)) -> + | Happy_eyeballs.Connect (host, id, attempt, (ip, port)) -> let cancelled, cancel = Lwt.task () in - t.cancel_connecting <- Happy_eyeballs.Waiter_map.add id cancel t.cancel_connecting; - Lwt_unix.(getprotobyname "tcp" >|= fun x -> x.p_proto) >>= fun proto_number -> - let fam = - Ipaddr.(Lwt_unix.(match ip with V4 _ -> PF_INET | V6 _ -> PF_INET6)) - in - let socket = Lwt_unix.socket fam Lwt_unix.SOCK_STREAM proto_number in - let addr = Lwt_unix.ADDR_INET (Ipaddr_unix.to_inet_addr ip, port) in - Lwt.pick [ - Lwt.try_bind - (fun () -> Lwt_unix.connect socket addr) - Lwt.return_ok - (fun e -> - let err = - Fmt.str "error %s connecting to nameserver %a:%d" - (Printexc.to_string e) Ipaddr.pp ip port - in - Lwt.return (Error (`Msg err))); - (cancelled >|= fun () -> Error (`Msg "cancelled")); - ] >>= fun r -> - t.cancel_connecting <- Happy_eyeballs.Waiter_map.remove id t.cancel_connecting; - begin match r with - | Ok () -> + let entry = attempt, cancel in + t.cancel_connecting <- + Happy_eyeballs.Waiter_map.update id + (function None -> Some [ entry ] | Some c -> Some (entry :: c)) + t.cancel_connecting; + let conn = + try_connect ip port >>= function + | Ok fd -> + let cancel_connecting, others = + Happy_eyeballs.Waiter_map.find_and_remove id t.cancel_connecting + in + t.cancel_connecting <- cancel_connecting; + List.iter (fun (att, w) -> if att <> attempt then Lwt.wakeup_later w ()) + (Option.value ~default:[] others); let waiters, r = Happy_eyeballs.Waiter_map.find_and_remove id t.waiters in t.waiters <- waiters; begin match r with | Some waiter -> - Lwt.wakeup_later waiter (Ok ((ip, port), socket)); + Lwt.wakeup_later waiter (Ok ((ip, port), fd)); Lwt.return_unit - | None -> close_socket socket + | None -> close_socket fd end >|= fun () -> Some (Happy_eyeballs.Connected (host, id, (ip, port))) | Error `Msg err -> - close_socket socket >|= fun () -> - Some (Happy_eyeballs.Connection_failed (host, id, (ip, port), err)) - end + t.cancel_connecting <- + Happy_eyeballs.Waiter_map.update id + (function None -> None | Some c -> + match List.filter (fun (att, _) -> not (att = attempt)) c with + | [] -> None + | c -> Some c) + t.cancel_connecting; + Lwt.return (Some (Happy_eyeballs.Connection_failed (host, id, (ip, port), err))) + in + Lwt.pick [ conn ; (cancelled >|= fun () -> None); ] | Connect_failed (host, id, reason) -> + let cancel_connecting, others = + Happy_eyeballs.Waiter_map.find_and_remove id t.cancel_connecting + in + t.cancel_connecting <- cancel_connecting; + List.iter (fun (_, w) -> Lwt.wakeup_later w ()) (Option.value ~default:[] others); let waiters, r = Happy_eyeballs.Waiter_map.find_and_remove id t.waiters in t.waiters <- waiters; begin match r with | Some waiter -> + let host_or_ip v = + match Ipaddr.of_domain_name v with + | None -> Domain_name.to_string v + | Some ip -> Ipaddr.to_string ip + in let err = - Fmt.str "connection to %a failed: %s" Domain_name.pp host reason + Fmt.str "connection to %s failed: %s" (host_or_ip host) reason in Lwt.wakeup_later waiter (Error (`Msg err)) | None -> () end; Lwt.return None - | Connect_cancelled (_host, id) -> - (match Happy_eyeballs.Waiter_map.find_opt id t.cancel_connecting with - | None -> Lwt.return_none - | Some cancel -> Lwt.wakeup cancel (); Lwt.return_none) | Resolve_a _ | Resolve_aaaa _ as a -> Log.warn (fun m -> m "ignoring action %a" Happy_eyeballs.pp_action a); Lwt.return None @@ -234,7 +258,7 @@ module Transport : Dns_client.S fd = None ; connected_condition = None ; requests = IM.empty ; - he = Happy_eyeballs.create (clock ()) ; + he = Happy_eyeballs.create ~connect_timeout:timeout (clock ()) ; cancel_connecting = Happy_eyeballs.Waiter_map.empty ; waiters = Happy_eyeballs.Waiter_map.empty ; timer_condition = Lwt_condition.create () ; @@ -363,12 +387,15 @@ module Transport : Dns_client.S Lwt.async (fun () -> Lwt_list.iter_p (handle_action t) actions); waiter >>= function | Error `Msg msg -> - Lwt_condition.broadcast connected_condition (); + let err = + Error (`Msg (Fmt.str "error %s connecting to resolver %a" + msg + Fmt.(list ~sep:(any ", ") (pair ~sep:(any ":") Ipaddr.pp int)) + (to_pairs (nameserver_ips t.nameservers)))) + in + Lwt_condition.broadcast connected_condition err; t.connected_condition <- None; - Lwt.return - (Error (`Msg (Fmt.str "error %s connecting to resolver %a" - msg Fmt.(list ~sep:(any ",") (pair ~sep:(any ":") Ipaddr.pp int)) - (to_pairs (nameserver_ips t.nameservers))))) + Lwt.return err | Ok (addr, socket) -> let continue socket = t.fd <- Some socket; @@ -381,7 +408,7 @@ module Transport : Dns_client.S | Error (`Msg msg) -> Log.err (fun m -> m "error while connecting to resolver: %s" msg) | Ok () -> ()); - Lwt_condition.broadcast connected_condition (); + Lwt_condition.broadcast connected_condition (Ok ()); t.connected_condition <- None; req_all socket t in @@ -404,18 +431,17 @@ module Transport : Dns_client.S nameservers in if ns' = [] then begin - Lwt_condition.broadcast connected_condition (); + let err = Error (`Msg "no further nameservers configured") in + Lwt_condition.broadcast connected_condition err; t.connected_condition <- None; - Lwt.return (Error (`Msg "no further nameservers configured")) + Lwt.return err end else connect_to_ns_list t connected_condition ns') and connect_via_tcp_to_ns (t : t) = match t.fd, t.connected_condition with | Some _, _ -> Lwt.return (Ok ()) - | None, Some w -> - Lwt_condition.wait w >>= fun () -> - connect_via_tcp_to_ns t + | None, Some w -> Lwt_condition.wait w | None, None -> let connected_condition = Lwt_condition.create () in t.connected_condition <- Some connected_condition ; diff --git a/mirage/client/dns_client_mirage.ml b/mirage/client/dns_client_mirage.ml index 8e3286f3..de35af64 100644 --- a/mirage/client/dns_client_mirage.ml +++ b/mirage/client/dns_client_mirage.ml @@ -127,40 +127,42 @@ The format of a nameserver is: stack : stack ; mutable udp_ports : IS.t ; mutable flow : [`Plain of S.TCP.flow | `Tls of TLS.flow ] option ; - mutable connected_condition : unit Lwt_condition.t option ; + mutable connected_condition : (unit, [ `Msg of string ]) result Lwt_condition.t option ; mutable requests : (Cstruct.t * (Cstruct.t, [ `Msg of string ]) result Lwt_condition.t) IM.t ; mutable he : Happy_eyeballs.t ; - mutable cancel_connecting : unit Lwt.u Happy_eyeballs.Waiter_map.t ; + mutable cancel_connecting : (int * unit Lwt.u) list Happy_eyeballs.Waiter_map.t ; mutable waiters : ((Ipaddr.t * int) * S.TCP.flow, [ `Msg of string ]) result Lwt.u Happy_eyeballs.Waiter_map.t ; timer_condition : unit Lwt_condition.t ; } type context = t let clock = M.elapsed_ns - let he_timer_interval = Duration.of_ms 500 + let he_timer_interval = Duration.of_ms 10 + + let try_connect stack addr = + let open Lwt.Infix in + S.TCP.create_connection (S.tcp stack) addr >|= + Result.map_error + (fun err -> `Msg (Fmt.str "error connecting to nameserver %a:%u: %a" + Ipaddr.pp (fst addr) (snd addr) S.TCP.pp_error err)) let handle_one_action t = function - | Happy_eyeballs.Connect (host, id, addr) -> + | Happy_eyeballs.Connect (host, id, attempt, addr) -> let cancelled, cancel = Lwt.task () in - t.cancel_connecting <- Happy_eyeballs.Waiter_map.add id cancel t.cancel_connecting; - Lwt.pick [ - begin - S.TCP.create_connection (S.tcp t.stack) addr >>= function - | Error e -> - let err = - Fmt.str "error connecting to nameserver %a: %a" - Ipaddr.pp (fst addr) S.TCP.pp_error e - in - Lwt.return_error (`Msg err) - | Ok flow -> - Lwt.return_ok flow - end; - begin - cancelled >|= fun () -> Error (`Msg "cancelled") - end; - ] >>= fun r -> - begin match r with + let entry = attempt, cancel in + t.cancel_connecting <- + Happy_eyeballs.Waiter_map.update id + (function None -> Some [ entry ] | Some c -> Some (entry :: c)) + t.cancel_connecting; + let conn = + try_connect t.stack addr >>= function | Ok flow -> + let cancel_connecting, others = + Happy_eyeballs.Waiter_map.find_and_remove id t.cancel_connecting + in + t.cancel_connecting <- cancel_connecting; + List.iter (fun (att, u) -> if att <> attempt then Lwt.wakeup_later u ()) + (Option.value ~default:[] others); let waiters, r = Happy_eyeballs.Waiter_map.find_and_remove id t.waiters in t.waiters <- waiters; begin match r with @@ -170,10 +172,23 @@ The format of a nameserver is: | None -> S.TCP.close flow end >|= fun () -> Some (Happy_eyeballs.Connected (host, id, addr)) - | Error `Msg err -> - Lwt.return (Some (Happy_eyeballs.Connection_failed (host, id, addr, err))) - end + | Error `Msg msg -> + t.cancel_connecting <- + Happy_eyeballs.Waiter_map.update id + (function None -> None | Some c -> + match List.filter (fun (att, _) -> not (att = attempt)) c with + | [] -> None + | c -> Some c) + t.cancel_connecting; + Lwt.return (Some (Happy_eyeballs.Connection_failed (host, id, addr, msg))) + in + Lwt.pick [ conn ; (cancelled >|= fun () -> None) ] | Connect_failed (host, id, reason) -> + let cancel_connecting, others = + Happy_eyeballs.Waiter_map.find_and_remove id t.cancel_connecting + in + t.cancel_connecting <- cancel_connecting; + List.iter (fun (_, u) -> Lwt.wakeup_later u ()) (Option.value ~default:[] others); let waiters, r = Happy_eyeballs.Waiter_map.find_and_remove id t.waiters in t.waiters <- waiters; begin match r with @@ -185,10 +200,6 @@ The format of a nameserver is: | None -> () end; Lwt.return None - | Connect_cancelled (_host, id) -> - (match Happy_eyeballs.Waiter_map.find_opt id t.cancel_connecting with - | None -> Lwt.return_none - | Some cancel -> Lwt.wakeup cancel (); Lwt.return_none) | Resolve_a _ | Resolve_aaaa _ as a -> Log.warn (fun m -> m "ignoring action %a" Happy_eyeballs.pp_action a); Lwt.return None @@ -270,7 +281,7 @@ The format of a nameserver is: flow = None ; connected_condition = None ; requests = IM.empty ; - he = Happy_eyeballs.create (clock ()) ; + he = Happy_eyeballs.create ~connect_timeout:timeout (clock ()) ; cancel_connecting = Happy_eyeballs.Waiter_map.empty ; waiters = Happy_eyeballs.Waiter_map.empty ; timer_condition = Lwt_condition.create () ; @@ -394,10 +405,15 @@ The format of a nameserver is: Lwt.async (fun () -> Lwt_list.iter_p (handle_action t) actions); waiter >>= function | Error `Msg msg -> - Lwt_condition.broadcast connected_condition (); + let err = Error (`Msg (Fmt.str "error %s connecting to resolver %a" + msg + Fmt.(list ~sep:(any ", ") (pair ~sep:(any ":") Ipaddr.pp int)) + (to_pairs t.nameservers))) + in + Lwt_condition.broadcast connected_condition err; t.connected_condition <- None; Log.err (fun m -> m "error connecting to resolver %s" msg); - Lwt.return (Error (`Msg "connect failure")) + Lwt.return err | Ok (addr, flow) -> let continue flow = t.flow <- Some flow; @@ -410,7 +426,7 @@ The format of a nameserver is: | Ok () -> () else Lwt.return_unit); - Lwt_condition.broadcast connected_condition (); + Lwt_condition.broadcast connected_condition (Ok ()); t.connected_condition <- None; req_all flow t in @@ -423,8 +439,6 @@ The format of a nameserver is: | Error e -> Log.warn (fun m -> m "error establishing TLS connection to %a:%d: %a" Ipaddr.pp (fst addr) (snd addr) TLS.pp_write_error e); - Lwt_condition.broadcast connected_condition (); - t.connected_condition <- None; let ns' = List.filter (function | `Tls (_, ip, port) -> @@ -432,23 +446,25 @@ The format of a nameserver is: | _ -> true) nameservers in - if ns' = [] then - Lwt.return (Error (`Msg "no further nameservers configured")) - else + if ns' = [] then begin + let err = Error (`Msg "no further nameservers configured") in + Lwt_condition.broadcast connected_condition err; + t.connected_condition <- None; + Lwt.return err + end else connect_ns t ns' - let rec connect t = + let connect t = + let to_tcp = function + | Ok () -> Ok (`Tcp, t) + | Error `Msg msg -> Error (`Msg msg) + in match t.proto with | `Udp -> Lwt.return (Ok (`Udp, t)) | `Tcp -> match t.flow, t.connected_condition with | Some _, _ -> Lwt.return (Ok (`Tcp, t)) - | None, Some w -> - Lwt_condition.wait w >>= fun () -> - connect t - | None, None -> - connect_ns t t.nameservers >|= function - | Ok () -> Ok (`Tcp, t) - | Error `Msg msg -> Error (`Msg msg) + | None, Some w -> Lwt_condition.wait w >|= to_tcp + | None, None -> connect_ns t t.nameservers >|= to_tcp let close _f = (* ignoring this here *)