Skip to content

Commit

Permalink
Adapt to happy-eyeballs 0.6 changes (#340)
Browse files Browse the repository at this point in the history
* Adapt to happy-eyeballs 0.6 changes

* further fixes:
- when a connection failed (i.e. timeout), cancel all connection attempts
- transport the success/failure via the connected_condition instead of recursing

* pass connect_timeout to happy_eyeballs creation

* address review from @reynir: filter the current attempt from being woken up/cancelled

---------

Co-authored-by: Robur <[email protected]>
  • Loading branch information
hannesm and robur-team authored Jun 15, 2023
1 parent 08a763b commit 2cde843
Show file tree
Hide file tree
Showing 2 changed files with 136 additions and 94 deletions.
122 changes: 74 additions & 48 deletions lwt/client/dns_client_lwt.ml
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,10 @@ module Transport : Dns_client.S
timeout_ns : int64 ;
(* TODO: avoid race, use a mvar instead of condition *)
mutable fd : [ `Plain of Lwt_unix.file_descr | `Tls of Tls_lwt.Unix.t ] option ;
mutable connected_condition : unit Lwt_condition.t option ;
mutable connected_condition : (unit, [ `Msg of string ]) result Lwt_condition.t option ;
mutable requests : (Cstruct.t * (Cstruct.t, [ `Msg of string ]) result Lwt_condition.t) IM.t ;
mutable he : Happy_eyeballs.t ;
mutable cancel_connecting : unit Lwt.u Happy_eyeballs.Waiter_map.t;
mutable cancel_connecting : (int * unit Lwt.u) list Happy_eyeballs.Waiter_map.t;
mutable waiters : ((Ipaddr.t * int) * Lwt_unix.file_descr, [ `Msg of string ]) result Lwt.u Happy_eyeballs.Waiter_map.t ;
timer_condition : unit Lwt_condition.t ;
}
Expand All @@ -51,65 +51,89 @@ module Transport : Dns_client.S

let clock = Mtime_clock.elapsed_ns

let he_timer_interval = Duration.of_ms 500
let he_timer_interval = Duration.of_ms 10

let close_socket fd =
Lwt.catch (fun () -> Lwt_unix.close fd) (fun _ -> Lwt.return_unit)

let try_connect ip port =
Lwt_unix.(getprotobyname "tcp" >|= fun x -> x.p_proto) >>= fun proto_number ->
let fam =
Ipaddr.(Lwt_unix.(match ip with V4 _ -> PF_INET | V6 _ -> PF_INET6))
in
let socket = Lwt_unix.socket fam Lwt_unix.SOCK_STREAM proto_number in
let open Lwt_result.Infix in
Lwt.catch (fun () ->
let addr = Lwt_unix.ADDR_INET (Ipaddr_unix.to_inet_addr ip, port) in
Lwt_result.ok (Lwt_unix.connect socket addr) >|= fun () ->
socket)
(fun e ->
Lwt_result.ok (close_socket socket) >>= fun () ->
let err =
Fmt.str "error %s connecting to nameserver %a:%d"
(Printexc.to_string e) Ipaddr.pp ip port
in
Lwt.return (Error (`Msg err)))

let handle_one_action t = function
| Happy_eyeballs.Connect (host, id, (ip, port)) ->
| Happy_eyeballs.Connect (host, id, attempt, (ip, port)) ->
let cancelled, cancel = Lwt.task () in
t.cancel_connecting <- Happy_eyeballs.Waiter_map.add id cancel t.cancel_connecting;
Lwt_unix.(getprotobyname "tcp" >|= fun x -> x.p_proto) >>= fun proto_number ->
let fam =
Ipaddr.(Lwt_unix.(match ip with V4 _ -> PF_INET | V6 _ -> PF_INET6))
in
let socket = Lwt_unix.socket fam Lwt_unix.SOCK_STREAM proto_number in
let addr = Lwt_unix.ADDR_INET (Ipaddr_unix.to_inet_addr ip, port) in
Lwt.pick [
Lwt.try_bind
(fun () -> Lwt_unix.connect socket addr)
Lwt.return_ok
(fun e ->
let err =
Fmt.str "error %s connecting to nameserver %a:%d"
(Printexc.to_string e) Ipaddr.pp ip port
in
Lwt.return (Error (`Msg err)));
(cancelled >|= fun () -> Error (`Msg "cancelled"));
] >>= fun r ->
t.cancel_connecting <- Happy_eyeballs.Waiter_map.remove id t.cancel_connecting;
begin match r with
| Ok () ->
let entry = attempt, cancel in
t.cancel_connecting <-
Happy_eyeballs.Waiter_map.update id
(function None -> Some [ entry ] | Some c -> Some (entry :: c))
t.cancel_connecting;
let conn =
try_connect ip port >>= function
| Ok fd ->
let cancel_connecting, others =
Happy_eyeballs.Waiter_map.find_and_remove id t.cancel_connecting
in
t.cancel_connecting <- cancel_connecting;
List.iter (fun (att, w) -> if att <> attempt then Lwt.wakeup_later w ())
(Option.value ~default:[] others);
let waiters, r = Happy_eyeballs.Waiter_map.find_and_remove id t.waiters in
t.waiters <- waiters;
begin match r with
| Some waiter ->
Lwt.wakeup_later waiter (Ok ((ip, port), socket));
Lwt.wakeup_later waiter (Ok ((ip, port), fd));
Lwt.return_unit
| None -> close_socket socket
| None -> close_socket fd
end >|= fun () ->
Some (Happy_eyeballs.Connected (host, id, (ip, port)))
| Error `Msg err ->
close_socket socket >|= fun () ->
Some (Happy_eyeballs.Connection_failed (host, id, (ip, port), err))
end
t.cancel_connecting <-
Happy_eyeballs.Waiter_map.update id
(function None -> None | Some c ->
match List.filter (fun (att, _) -> not (att = attempt)) c with
| [] -> None
| c -> Some c)
t.cancel_connecting;
Lwt.return (Some (Happy_eyeballs.Connection_failed (host, id, (ip, port), err)))
in
Lwt.pick [ conn ; (cancelled >|= fun () -> None); ]
| Connect_failed (host, id, reason) ->
let cancel_connecting, others =
Happy_eyeballs.Waiter_map.find_and_remove id t.cancel_connecting
in
t.cancel_connecting <- cancel_connecting;
List.iter (fun (_, w) -> Lwt.wakeup_later w ()) (Option.value ~default:[] others);
let waiters, r = Happy_eyeballs.Waiter_map.find_and_remove id t.waiters in
t.waiters <- waiters;
begin match r with
| Some waiter ->
let host_or_ip v =
match Ipaddr.of_domain_name v with
| None -> Domain_name.to_string v
| Some ip -> Ipaddr.to_string ip
in
let err =
Fmt.str "connection to %a failed: %s" Domain_name.pp host reason
Fmt.str "connection to %s failed: %s" (host_or_ip host) reason
in
Lwt.wakeup_later waiter (Error (`Msg err))
| None -> ()
end;
Lwt.return None
| Connect_cancelled (_host, id) ->
(match Happy_eyeballs.Waiter_map.find_opt id t.cancel_connecting with
| None -> Lwt.return_none
| Some cancel -> Lwt.wakeup cancel (); Lwt.return_none)
| Resolve_a _ | Resolve_aaaa _ as a ->
Log.warn (fun m -> m "ignoring action %a" Happy_eyeballs.pp_action a);
Lwt.return None
Expand Down Expand Up @@ -234,7 +258,7 @@ module Transport : Dns_client.S
fd = None ;
connected_condition = None ;
requests = IM.empty ;
he = Happy_eyeballs.create (clock ()) ;
he = Happy_eyeballs.create ~connect_timeout:timeout (clock ()) ;
cancel_connecting = Happy_eyeballs.Waiter_map.empty ;
waiters = Happy_eyeballs.Waiter_map.empty ;
timer_condition = Lwt_condition.create () ;
Expand Down Expand Up @@ -363,12 +387,15 @@ module Transport : Dns_client.S
Lwt.async (fun () -> Lwt_list.iter_p (handle_action t) actions);
waiter >>= function
| Error `Msg msg ->
Lwt_condition.broadcast connected_condition ();
let err =
Error (`Msg (Fmt.str "error %s connecting to resolver %a"
msg
Fmt.(list ~sep:(any ", ") (pair ~sep:(any ":") Ipaddr.pp int))
(to_pairs (nameserver_ips t.nameservers))))
in
Lwt_condition.broadcast connected_condition err;
t.connected_condition <- None;
Lwt.return
(Error (`Msg (Fmt.str "error %s connecting to resolver %a"
msg Fmt.(list ~sep:(any ",") (pair ~sep:(any ":") Ipaddr.pp int))
(to_pairs (nameserver_ips t.nameservers)))))
Lwt.return err
| Ok (addr, socket) ->
let continue socket =
t.fd <- Some socket;
Expand All @@ -381,7 +408,7 @@ module Transport : Dns_client.S
| Error (`Msg msg) ->
Log.err (fun m -> m "error while connecting to resolver: %s" msg)
| Ok () -> ());
Lwt_condition.broadcast connected_condition ();
Lwt_condition.broadcast connected_condition (Ok ());
t.connected_condition <- None;
req_all socket t
in
Expand All @@ -404,18 +431,17 @@ module Transport : Dns_client.S
nameservers
in
if ns' = [] then begin
Lwt_condition.broadcast connected_condition ();
let err = Error (`Msg "no further nameservers configured") in
Lwt_condition.broadcast connected_condition err;
t.connected_condition <- None;
Lwt.return (Error (`Msg "no further nameservers configured"))
Lwt.return err
end else
connect_to_ns_list t connected_condition ns')

and connect_via_tcp_to_ns (t : t) =
match t.fd, t.connected_condition with
| Some _, _ -> Lwt.return (Ok ())
| None, Some w ->
Lwt_condition.wait w >>= fun () ->
connect_via_tcp_to_ns t
| None, Some w -> Lwt_condition.wait w
| None, None ->
let connected_condition = Lwt_condition.create () in
t.connected_condition <- Some connected_condition ;
Expand Down
108 changes: 62 additions & 46 deletions mirage/client/dns_client_mirage.ml
Original file line number Diff line number Diff line change
Expand Up @@ -127,40 +127,42 @@ The format of a nameserver is:
stack : stack ;
mutable udp_ports : IS.t ;
mutable flow : [`Plain of S.TCP.flow | `Tls of TLS.flow ] option ;
mutable connected_condition : unit Lwt_condition.t option ;
mutable connected_condition : (unit, [ `Msg of string ]) result Lwt_condition.t option ;
mutable requests : (Cstruct.t * (Cstruct.t, [ `Msg of string ]) result Lwt_condition.t) IM.t ;
mutable he : Happy_eyeballs.t ;
mutable cancel_connecting : unit Lwt.u Happy_eyeballs.Waiter_map.t ;
mutable cancel_connecting : (int * unit Lwt.u) list Happy_eyeballs.Waiter_map.t ;
mutable waiters : ((Ipaddr.t * int) * S.TCP.flow, [ `Msg of string ]) result Lwt.u Happy_eyeballs.Waiter_map.t ;
timer_condition : unit Lwt_condition.t ;
}
type context = t

let clock = M.elapsed_ns
let he_timer_interval = Duration.of_ms 500
let he_timer_interval = Duration.of_ms 10

let try_connect stack addr =
let open Lwt.Infix in
S.TCP.create_connection (S.tcp stack) addr >|=
Result.map_error
(fun err -> `Msg (Fmt.str "error connecting to nameserver %a:%u: %a"
Ipaddr.pp (fst addr) (snd addr) S.TCP.pp_error err))

let handle_one_action t = function
| Happy_eyeballs.Connect (host, id, addr) ->
| Happy_eyeballs.Connect (host, id, attempt, addr) ->
let cancelled, cancel = Lwt.task () in
t.cancel_connecting <- Happy_eyeballs.Waiter_map.add id cancel t.cancel_connecting;
Lwt.pick [
begin
S.TCP.create_connection (S.tcp t.stack) addr >>= function
| Error e ->
let err =
Fmt.str "error connecting to nameserver %a: %a"
Ipaddr.pp (fst addr) S.TCP.pp_error e
in
Lwt.return_error (`Msg err)
| Ok flow ->
Lwt.return_ok flow
end;
begin
cancelled >|= fun () -> Error (`Msg "cancelled")
end;
] >>= fun r ->
begin match r with
let entry = attempt, cancel in
t.cancel_connecting <-
Happy_eyeballs.Waiter_map.update id
(function None -> Some [ entry ] | Some c -> Some (entry :: c))
t.cancel_connecting;
let conn =
try_connect t.stack addr >>= function
| Ok flow ->
let cancel_connecting, others =
Happy_eyeballs.Waiter_map.find_and_remove id t.cancel_connecting
in
t.cancel_connecting <- cancel_connecting;
List.iter (fun (att, u) -> if att <> attempt then Lwt.wakeup_later u ())
(Option.value ~default:[] others);
let waiters, r = Happy_eyeballs.Waiter_map.find_and_remove id t.waiters in
t.waiters <- waiters;
begin match r with
Expand All @@ -170,10 +172,23 @@ The format of a nameserver is:
| None -> S.TCP.close flow
end >|= fun () ->
Some (Happy_eyeballs.Connected (host, id, addr))
| Error `Msg err ->
Lwt.return (Some (Happy_eyeballs.Connection_failed (host, id, addr, err)))
end
| Error `Msg msg ->
t.cancel_connecting <-
Happy_eyeballs.Waiter_map.update id
(function None -> None | Some c ->
match List.filter (fun (att, _) -> not (att = attempt)) c with
| [] -> None
| c -> Some c)
t.cancel_connecting;
Lwt.return (Some (Happy_eyeballs.Connection_failed (host, id, addr, msg)))
in
Lwt.pick [ conn ; (cancelled >|= fun () -> None) ]
| Connect_failed (host, id, reason) ->
let cancel_connecting, others =
Happy_eyeballs.Waiter_map.find_and_remove id t.cancel_connecting
in
t.cancel_connecting <- cancel_connecting;
List.iter (fun (_, u) -> Lwt.wakeup_later u ()) (Option.value ~default:[] others);
let waiters, r = Happy_eyeballs.Waiter_map.find_and_remove id t.waiters in
t.waiters <- waiters;
begin match r with
Expand All @@ -185,10 +200,6 @@ The format of a nameserver is:
| None -> ()
end;
Lwt.return None
| Connect_cancelled (_host, id) ->
(match Happy_eyeballs.Waiter_map.find_opt id t.cancel_connecting with
| None -> Lwt.return_none
| Some cancel -> Lwt.wakeup cancel (); Lwt.return_none)
| Resolve_a _ | Resolve_aaaa _ as a ->
Log.warn (fun m -> m "ignoring action %a" Happy_eyeballs.pp_action a);
Lwt.return None
Expand Down Expand Up @@ -270,7 +281,7 @@ The format of a nameserver is:
flow = None ;
connected_condition = None ;
requests = IM.empty ;
he = Happy_eyeballs.create (clock ()) ;
he = Happy_eyeballs.create ~connect_timeout:timeout (clock ()) ;
cancel_connecting = Happy_eyeballs.Waiter_map.empty ;
waiters = Happy_eyeballs.Waiter_map.empty ;
timer_condition = Lwt_condition.create () ;
Expand Down Expand Up @@ -394,10 +405,15 @@ The format of a nameserver is:
Lwt.async (fun () -> Lwt_list.iter_p (handle_action t) actions);
waiter >>= function
| Error `Msg msg ->
Lwt_condition.broadcast connected_condition ();
let err = Error (`Msg (Fmt.str "error %s connecting to resolver %a"
msg
Fmt.(list ~sep:(any ", ") (pair ~sep:(any ":") Ipaddr.pp int))
(to_pairs t.nameservers)))
in
Lwt_condition.broadcast connected_condition err;
t.connected_condition <- None;
Log.err (fun m -> m "error connecting to resolver %s" msg);
Lwt.return (Error (`Msg "connect failure"))
Lwt.return err
| Ok (addr, flow) ->
let continue flow =
t.flow <- Some flow;
Expand All @@ -410,7 +426,7 @@ The format of a nameserver is:
| Ok () -> ()
else
Lwt.return_unit);
Lwt_condition.broadcast connected_condition ();
Lwt_condition.broadcast connected_condition (Ok ());
t.connected_condition <- None;
req_all flow t
in
Expand All @@ -423,32 +439,32 @@ The format of a nameserver is:
| Error e ->
Log.warn (fun m -> m "error establishing TLS connection to %a:%d: %a"
Ipaddr.pp (fst addr) (snd addr) TLS.pp_write_error e);
Lwt_condition.broadcast connected_condition ();
t.connected_condition <- None;
let ns' =
List.filter (function
| `Tls (_, ip, port) ->
not (Ipaddr.compare ip (fst addr) = 0 && port = snd addr)
| _ -> true)
nameservers
in
if ns' = [] then
Lwt.return (Error (`Msg "no further nameservers configured"))
else
if ns' = [] then begin
let err = Error (`Msg "no further nameservers configured") in
Lwt_condition.broadcast connected_condition err;
t.connected_condition <- None;
Lwt.return err
end else
connect_ns t ns'

let rec connect t =
let connect t =
let to_tcp = function
| Ok () -> Ok (`Tcp, t)
| Error `Msg msg -> Error (`Msg msg)
in
match t.proto with
| `Udp -> Lwt.return (Ok (`Udp, t))
| `Tcp -> match t.flow, t.connected_condition with
| Some _, _ -> Lwt.return (Ok (`Tcp, t))
| None, Some w ->
Lwt_condition.wait w >>= fun () ->
connect t
| None, None ->
connect_ns t t.nameservers >|= function
| Ok () -> Ok (`Tcp, t)
| Error `Msg msg -> Error (`Msg msg)
| None, Some w -> Lwt_condition.wait w >|= to_tcp
| None, None -> connect_ns t t.nameservers >|= to_tcp

let close _f =
(* ignoring this here *)
Expand Down

0 comments on commit 2cde843

Please sign in to comment.