Skip to content

Commit

Permalink
adapt to new happy eyeballs (#329)
Browse files Browse the repository at this point in the history
* adapt to new happy eyeballs: handle cancellation

Co-authored-by: Reynir Björnsson <[email protected]>
  • Loading branch information
hannesm and reynir authored Dec 2, 2022
1 parent b1ab0f5 commit 757d120
Show file tree
Hide file tree
Showing 3 changed files with 168 additions and 111 deletions.
2 changes: 1 addition & 1 deletion dns-client.opam
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ depends: [
"mirage-clock" {>= "3.0.0"}
"mtime" {>= "1.2.0"}
"mirage-crypto-rng" {>= "0.8.0"}
"happy-eyeballs" {>= "0.1.0"}
"happy-eyeballs" {>= "0.4.0"}
"alcotest" {with-test}
"tls" {>= "0.15.0"}
"tls-mirage" {>= "0.15.0"}
Expand Down
92 changes: 60 additions & 32 deletions lwt/client/dns_client_lwt.ml
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ module Transport : Dns_client.S
mutable connected_condition : unit Lwt_condition.t option ;
mutable requests : (Cstruct.t * (Cstruct.t, [ `Msg of string ]) result Lwt_condition.t) IM.t ;
mutable he : Happy_eyeballs.t ;
mutable cancel_connecting : unit Lwt.u Happy_eyeballs.Waiter_map.t;
mutable waiters : ((Ipaddr.t * int) * Lwt_unix.file_descr, [ `Msg of string ]) result Lwt.u Happy_eyeballs.Waiter_map.t ;
timer_condition : unit Lwt_condition.t ;
}
Expand Down Expand Up @@ -55,45 +56,71 @@ module Transport : Dns_client.S
let close_socket fd =
Lwt.catch (fun () -> Lwt_unix.close fd) (fun _ -> Lwt.return_unit)

let rec handle_action t action =
(match action with
| Happy_eyeballs.Connect (host, id, (ip, port)) ->
Lwt_unix.(getprotobyname "tcp" >|= fun x -> x.p_proto) >>= fun proto_number ->
let fam =
Ipaddr.(Lwt_unix.(match ip with V4 _ -> PF_INET | V6 _ -> PF_INET6))
in
let socket = Lwt_unix.socket fam Lwt_unix.SOCK_STREAM proto_number in
let addr = Lwt_unix.ADDR_INET (Ipaddr_unix.to_inet_addr ip, port) in
Lwt.catch (fun () ->
Lwt_unix.connect socket addr >>= fun () ->
let waiters, r = Happy_eyeballs.Waiter_map.find_and_remove id t.waiters in
t.waiters <- waiters;
begin match r with
| Some waiter -> Lwt.wakeup_later waiter (Ok ((ip, port), socket)); Lwt.return_unit
| None -> close_socket socket
end >|= fun () ->
Some (Happy_eyeballs.Connected (host, id, (ip, port))))
(fun e ->
Log.err (fun m -> m "connection to %a:%d failed: %s" Ipaddr.pp ip port
(Printexc.to_string e));
close_socket socket >|= fun () ->
Some (Happy_eyeballs.Connection_failed (host, id, (ip, port))))
| Connect_failed (_host, id) ->
let handle_one_action t = function
| Happy_eyeballs.Connect (host, id, (ip, port)) ->
let cancelled, cancel = Lwt.task () in
t.cancel_connecting <- Happy_eyeballs.Waiter_map.add id cancel t.cancel_connecting;
Lwt_unix.(getprotobyname "tcp" >|= fun x -> x.p_proto) >>= fun proto_number ->
let fam =
Ipaddr.(Lwt_unix.(match ip with V4 _ -> PF_INET | V6 _ -> PF_INET6))
in
let socket = Lwt_unix.socket fam Lwt_unix.SOCK_STREAM proto_number in
let addr = Lwt_unix.ADDR_INET (Ipaddr_unix.to_inet_addr ip, port) in
Lwt.pick [
Lwt.try_bind
(fun () -> Lwt_unix.connect socket addr)
Lwt.return_ok
(fun e ->
let err =
Fmt.str "error %s connecting to nameserver %a:%d"
(Printexc.to_string e) Ipaddr.pp ip port
in
Lwt.return (Error (`Msg err)));
(cancelled >|= fun () -> Error (`Msg "cancelled"));
] >>= fun r ->
t.cancel_connecting <- Happy_eyeballs.Waiter_map.remove id t.cancel_connecting;
begin match r with
| Ok () ->
let waiters, r = Happy_eyeballs.Waiter_map.find_and_remove id t.waiters in
t.waiters <- waiters;
begin match r with
| Some waiter ->
Lwt.wakeup_later waiter (Ok ((ip, port), socket));
Lwt.return_unit
| None -> close_socket socket
end >|= fun () ->
Some (Happy_eyeballs.Connected (host, id, (ip, port)))
| Error `Msg err ->
close_socket socket >|= fun () ->
Some (Happy_eyeballs.Connection_failed (host, id, (ip, port), err))
end
| Connect_failed (host, id, reason) ->
let waiters, r = Happy_eyeballs.Waiter_map.find_and_remove id t.waiters in
t.waiters <- waiters;
begin match r with
| Some waiter -> Lwt.wakeup_later waiter (Error (`Msg "connection failed"))
| Some waiter ->
let err =
Fmt.str "connection to %a failed: %s" Domain_name.pp host reason
in
Lwt.wakeup_later waiter (Error (`Msg err))
| None -> ()
end;
Lwt.return None
| a ->
| Connect_cancelled (_host, id) ->
(match Happy_eyeballs.Waiter_map.find_opt id t.cancel_connecting with
| None -> Lwt.return_none
| Some cancel -> Lwt.wakeup cancel (); Lwt.return_none)
| Resolve_a _ | Resolve_aaaa _ as a ->
Log.warn (fun m -> m "ignoring action %a" Happy_eyeballs.pp_action a);
Lwt.return None) >>= function
| None -> Lwt.return_unit
| Some event ->
let he, actions = Happy_eyeballs.event t.he (clock ()) event in
t.he <- he;
Lwt_list.iter_p (handle_action t) actions
Lwt.return None

let rec handle_action t action =
handle_one_action t action >>= function
| None -> Lwt.return_unit
| Some event ->
let he, actions = Happy_eyeballs.event t.he (clock ()) event in
t.he <- he;
Lwt_list.iter_p (handle_action t) actions

let handle_timer_actions t actions =
Lwt.async (fun () -> Lwt_list.iter_p (fun a -> handle_action t a) actions)
Expand Down Expand Up @@ -208,6 +235,7 @@ module Transport : Dns_client.S
connected_condition = None ;
requests = IM.empty ;
he = Happy_eyeballs.create (clock ()) ;
cancel_connecting = Happy_eyeballs.Waiter_map.empty ;
waiters = Happy_eyeballs.Waiter_map.empty ;
timer_condition = Lwt_condition.create () ;
} in
Expand Down
185 changes: 107 additions & 78 deletions mirage/client/dns_client_mirage.ml
Original file line number Diff line number Diff line change
Expand Up @@ -66,58 +66,58 @@ The format of a nameserver is:
let nameserver_of_string str =
let ( let* ) = Result.bind in
begin match String.split_on_char ':' str with
| "tls" :: rest ->
let str = String.concat ":" rest in
( match String.split_on_char '!' str with
| [ nameserver ] ->
let* ipaddr, port = Ipaddr.with_port_of_string ~default:853 nameserver in
let* authenticator = CA.authenticator () in
let tls = Tls.Config.client ~authenticator () in
Ok (`Tcp, `Tls (tls, ipaddr, port))
| nameserver :: opt_hostname :: authenticator ->
let* ipaddr, port = Ipaddr.with_port_of_string ~default:853 nameserver in
let peer_name, data =
match
let* dn = Domain_name.of_string opt_hostname in
Domain_name.host dn
with
| Ok hostname -> Some hostname, String.concat "!" authenticator
| Error _ -> None, String.concat "!" (opt_hostname :: authenticator)
in
let* authenticator =
if data = "" then
CA.authenticator ()
else
let* a = X509.Authenticator.of_string data in
Ok (a (fun () -> Some (Ptime.v (P.now_d_ps ()))))
in
let tls = Tls.Config.client ~authenticator ?peer_name () in
Ok (`Tcp, `Tls (tls, ipaddr, port))
| [] -> assert false )
| "tcp" :: nameserver ->
let str = String.concat ":" nameserver in
let* ipaddr, port = Ipaddr.with_port_of_string ~default:53 str in
Ok (`Tcp, `Plaintext (ipaddr, port))
| "udp" :: nameserver ->
let str = String.concat ":" nameserver in
let* ipaddr, port = Ipaddr.with_port_of_string ~default:53 str in
Ok (`Udp, `Plaintext (ipaddr, port))
| _ ->
Error (`Msg ("Unable to decode nameserver " ^ str))
end |> Result.map_error (function `Msg e -> `Msg (e ^ format))
| "tls" :: rest ->
let str = String.concat ":" rest in
( match String.split_on_char '!' str with
| [ nameserver ] ->
let* ipaddr, port = Ipaddr.with_port_of_string ~default:853 nameserver in
let* authenticator = CA.authenticator () in
let tls = Tls.Config.client ~authenticator () in
Ok (`Tcp, `Tls (tls, ipaddr, port))
| nameserver :: opt_hostname :: authenticator ->
let* ipaddr, port = Ipaddr.with_port_of_string ~default:853 nameserver in
let peer_name, data =
match
let* dn = Domain_name.of_string opt_hostname in
Domain_name.host dn
with
| Ok hostname -> Some hostname, String.concat "!" authenticator
| Error _ -> None, String.concat "!" (opt_hostname :: authenticator)
in
let* authenticator =
if data = "" then
CA.authenticator ()
else
let* a = X509.Authenticator.of_string data in
Ok (a (fun () -> Some (Ptime.v (P.now_d_ps ()))))
in
let tls = Tls.Config.client ~authenticator ?peer_name () in
Ok (`Tcp, `Tls (tls, ipaddr, port))
| [] -> assert false )
| "tcp" :: nameserver ->
let str = String.concat ":" nameserver in
let* ipaddr, port = Ipaddr.with_port_of_string ~default:53 str in
Ok (`Tcp, `Plaintext (ipaddr, port))
| "udp" :: nameserver ->
let str = String.concat ":" nameserver in
let* ipaddr, port = Ipaddr.with_port_of_string ~default:53 str in
Ok (`Udp, `Plaintext (ipaddr, port))
| _ ->
Error (`Msg ("Unable to decode nameserver " ^ str))
end |> Result.map_error (function `Msg e -> `Msg (e ^ format))

module Transport : Dns_client.S
with type stack = S.t
and type +'a io = 'a Lwt.t
and type io_addr = [
| `Plaintext of Ipaddr.t * int
| `Tls of Tls.Config.client * Ipaddr.t * int
] = struct
| `Plaintext of Ipaddr.t * int
| `Tls of Tls.Config.client * Ipaddr.t * int
] = struct
type stack = S.t
type io_addr = [
| `Plaintext of Ipaddr.t * int
| `Tls of Tls.Config.client * Ipaddr.t * int
]
| `Plaintext of Ipaddr.t * int
| `Tls of Tls.Config.client * Ipaddr.t * int
]
type +'a io = 'a Lwt.t
module IS = Set.Make(Int)
type t = {
Expand All @@ -130,6 +130,7 @@ The format of a nameserver is:
mutable connected_condition : unit Lwt_condition.t option ;
mutable requests : (Cstruct.t * (Cstruct.t, [ `Msg of string ]) result Lwt_condition.t) IM.t ;
mutable he : Happy_eyeballs.t ;
mutable cancel_connecting : unit Lwt.u Happy_eyeballs.Waiter_map.t ;
mutable waiters : ((Ipaddr.t * int) * S.TCP.flow, [ `Msg of string ]) result Lwt.u Happy_eyeballs.Waiter_map.t ;
timer_condition : unit Lwt_condition.t ;
}
Expand All @@ -138,40 +139,67 @@ The format of a nameserver is:
let clock = M.elapsed_ns
let he_timer_interval = Duration.of_ms 500

let handle_one_action t = function
| Happy_eyeballs.Connect (host, id, addr) ->
let cancelled, cancel = Lwt.task () in
t.cancel_connecting <- Happy_eyeballs.Waiter_map.add id cancel t.cancel_connecting;
Lwt.pick [
begin
S.TCP.create_connection (S.tcp t.stack) addr >>= function
| Error e ->
let err =
Fmt.str "error connecting to nameserver %a: %a"
Ipaddr.pp (fst addr) S.TCP.pp_error e
in
Lwt.return_error (`Msg err)
| Ok flow ->
Lwt.return_ok flow
end;
begin
cancelled >|= fun () -> Error (`Msg "cancelled")
end;
] >>= fun r ->
begin match r with
| Ok flow ->
let waiters, r = Happy_eyeballs.Waiter_map.find_and_remove id t.waiters in
t.waiters <- waiters;
begin match r with
| Some waiter ->
Lwt.wakeup_later waiter (Ok (addr, flow));
Lwt.return_unit
| None -> S.TCP.close flow
end >|= fun () ->
Some (Happy_eyeballs.Connected (host, id, addr))
| Error `Msg err ->
Lwt.return (Some (Happy_eyeballs.Connection_failed (host, id, addr, err)))
end
| Connect_failed (host, id, reason) ->
let waiters, r = Happy_eyeballs.Waiter_map.find_and_remove id t.waiters in
t.waiters <- waiters;
begin match r with
| Some waiter ->
let err =
Fmt.str "connection to %a failed: %s" Domain_name.pp host reason
in
Lwt.wakeup_later waiter (Error (`Msg err))
| None -> ()
end;
Lwt.return None
| Connect_cancelled (_host, id) ->
(match Happy_eyeballs.Waiter_map.find_opt id t.cancel_connecting with
| None -> Lwt.return_none
| Some cancel -> Lwt.wakeup cancel (); Lwt.return_none)
| Resolve_a _ | Resolve_aaaa _ as a ->
Log.warn (fun m -> m "ignoring action %a" Happy_eyeballs.pp_action a);
Lwt.return None

let rec handle_action t action =
(match action with
| Happy_eyeballs.Connect (host, id, addr) ->
begin
S.TCP.create_connection (S.tcp t.stack) addr >>= function
| Error e ->
Log.err (fun m -> m "error connecting to nameserver %a: %a"
Ipaddr.pp (fst addr) S.TCP.pp_error e) ;
Lwt.return (Some (Happy_eyeballs.Connection_failed (host, id, addr)))
| Ok flow ->
let waiters, r = Happy_eyeballs.Waiter_map.find_and_remove id t.waiters in
t.waiters <- waiters;
begin match r with
| Some waiter -> Lwt.wakeup_later waiter (Ok (addr, flow)); Lwt.return_unit
| None -> S.TCP.close flow
end >|= fun () ->
Some (Happy_eyeballs.Connected (host, id, addr))
end
| Connect_failed (_host, id) ->
let waiters, r = Happy_eyeballs.Waiter_map.find_and_remove id t.waiters in
t.waiters <- waiters;
begin match r with
| Some waiter -> Lwt.wakeup_later waiter (Error (`Msg "connection failed"))
| None -> ()
end;
Lwt.return None
| a ->
Log.warn (fun m -> m "ignoring action %a" Happy_eyeballs.pp_action a);
Lwt.return None) >>= function
| None -> Lwt.return_unit
| Some event ->
let he, actions = Happy_eyeballs.event t.he (clock ()) event in
t.he <- he;
Lwt_list.iter_p (handle_action t) actions
handle_one_action t action >>= function
| None -> Lwt.return_unit
| Some event ->
let he, actions = Happy_eyeballs.event t.he (clock ()) event in
t.he <- he;
Lwt_list.iter_p (handle_action t) actions

let handle_timer_actions t actions =
Lwt.async (fun () -> Lwt_list.iter_p (fun a -> handle_action t a) actions)
Expand Down Expand Up @@ -243,6 +271,7 @@ The format of a nameserver is:
connected_condition = None ;
requests = IM.empty ;
he = Happy_eyeballs.create (clock ()) ;
cancel_connecting = Happy_eyeballs.Waiter_map.empty ;
waiters = Happy_eyeballs.Waiter_map.empty ;
timer_condition = Lwt_condition.create () ;
} in
Expand Down

0 comments on commit 757d120

Please sign in to comment.