diff --git a/lwt/client/dns_client_lwt.ml b/lwt/client/dns_client_lwt.ml index b36a967da..c0745b503 100644 --- a/lwt/client/dns_client_lwt.ml +++ b/lwt/client/dns_client_lwt.ml @@ -13,8 +13,14 @@ module Transport : Dns_client.S type io_addr = [ `Plaintext of Ipaddr.t * int | `Tls of Tls.Config.client * Ipaddr.t * int ] type +'a io = 'a Lwt.t type stack = unit + type nameservers = + | Static of io_addr list + | Resolv_conf of { + mutable nameservers : io_addr list; + mutable digest : Digest.t option + } type t = { - nameservers : io_addr list ; + nameservers : nameservers; timeout_ns : int64 ; (* TODO: avoid race, use a mvar instead of condition *) mutable fd : [ `Plain of Lwt_unix.file_descr | `Tls of Tls_lwt.Unix.t ] option ; @@ -26,6 +32,10 @@ module Transport : Dns_client.S } type context = t + let nameserver_ips = function + | Static nameservers -> nameservers + | Resolv_conf { nameservers; _ } -> nameservers + let read_file file = try let fh = open_in file in @@ -103,35 +113,96 @@ module Transport : Dns_client.S Lwt_condition.wait t.timer_condition >>= fun () -> loop () + let authenticator = + let authenticator_ref = ref None in + fun () -> + match !authenticator_ref with + | Some x -> x + | None -> match Ca_certs.authenticator () with + | Ok a -> authenticator_ref := Some a ; a + | Error `Msg m -> invalid_arg ("failed to load trust anchors: " ^ m) + + let decode_resolv_conf data = + let ( let* ) = Result.bind in + let authenticator = authenticator () in + let* ns = Dns_resolvconf.parse data in + match + List.flatten + (List.map + (fun (`Nameserver ip) -> + let tls = Tls.Config.client ~authenticator ~ip () in + [ `Tls (tls, ip, 853) ; `Plaintext (ip, 53) ]) + ns) + with + | [] -> Error (`Msg "no nameservers in resolv.conf") + | ns -> Ok ns + + let resolv_conf () = + let ( let* ) = Result.bind in + let* data = read_file "/etc/resolv.conf" in + let* ns = + Result.map_error + (function `Msg msg -> + Log.warn (fun m -> m "error %s decoding resolv.conf %S" msg data); + `Msg msg) + (decode_resolv_conf data) + in + Ok (ns, Digest.string data) + + let default_resolver () = + let authenticator = authenticator () in + let peer_name = Dns_client.default_resolver_hostname in + let tls_config = Tls.Config.client ~authenticator ~peer_name () in + List.flatten + (List.map (fun ip -> [ + `Tls (tls_config, ip, 853); `Plaintext (ip, 53) + ]) Dns_client.default_resolvers) + + let maybe_resolv_conf t = + match t.nameservers with + | Static _ -> () + | Resolv_conf resolv_conf -> + let needs_update = + match read_file "/etc/resolv.conf", resolv_conf.digest with + | Ok data, Some dgst -> + let dgst' = Digest.string data in + if Digest.equal dgst' dgst then + `No + else + `Data (data, dgst') + | Ok data, None -> + let digest = Digest.string data in + `Data (data, digest) + | Error _, None -> + `No + | Error `Msg msg, Some _ -> + Log.warn (fun m -> m "error reading /etc/resolv.conf: %s" msg); + `Default + in + match needs_update with + | `No -> () + | `Default -> + resolv_conf.digest <- None; + resolv_conf.nameservers <- default_resolver () + | `Data (data, dgst) -> + match decode_resolv_conf data with + | Ok ns -> + resolv_conf.digest <- Some dgst; + resolv_conf.nameservers <- ns + | Error `Msg msg -> + Log.warn (fun m -> m "error %s decoding resolv.conf: %S" msg data); + resolv_conf.digest <- None; + resolv_conf.nameservers <- default_resolver () + let create ?nameservers ~timeout () = let nameservers = match nameservers with | Some (`Udp, _) -> invalid_arg "UDP is not supported" - | Some (`Tcp, ns) -> ns + | Some (`Tcp, ns) -> Static ns | None -> - let authenticator = match Ca_certs.authenticator () with - | Ok a -> a - | Error `Msg m -> invalid_arg ("failed to load trust anchors: " ^ m) - in - match - let ( let* ) = Result.bind in - let* data = read_file "/etc/resolv.conf" in - let* ns = Dns_resolvconf.parse data in - Ok (List.flatten - (List.map - (fun (`Nameserver ip) -> - let tls = Tls.Config.client ~authenticator ~ip () in - [ `Tls (tls, ip, 853) ; `Plaintext (ip, 53) ]) - ns)) - with - | Error _ | Ok [] -> - let peer_name = Dns_client.default_resolver_hostname in - let tls_config = Tls.Config.client ~authenticator ~peer_name () in - List.flatten - (List.map (fun ip -> [ - `Tls (tls_config, ip, 853); `Plaintext (ip, 53) - ]) Dns_client.default_resolvers) - | Ok ips -> ips + match resolv_conf () with + | Error _ -> Resolv_conf { nameservers = default_resolver (); digest = None } + | Ok (ips, digest) -> Resolv_conf { nameservers = ips; digest = Some digest } in let t = { nameservers ; @@ -146,7 +217,8 @@ module Transport : Dns_client.S Lwt.async (fun () -> he_timer t); t - let nameservers { nameservers ; _ } = `Tcp, nameservers + let nameservers { nameservers; _ } = `Tcp, nameserver_ips nameservers + let rng = Mirage_crypto_rng.generate ?g:None let with_timeout timeout f = @@ -254,74 +326,78 @@ module Transport : Dns_client.S Ipaddr.compare ip addr = 0 && p = port) ns - let rec connect_via_tcp_to_ns (t : t) nameservers = + let rec connect_to_ns_list (t : t) connected_condition nameservers = + let waiter, notify = Lwt.task () in + let waiters, id = Happy_eyeballs.Waiter_map.register notify t.waiters in + t.waiters <- waiters; + let ns = to_pairs nameservers in + let he, actions = Happy_eyeballs.connect_ip t.he (clock ()) ~id ns in + t.he <- he; + Lwt_condition.signal t.timer_condition (); + Lwt.async (fun () -> Lwt_list.iter_p (handle_action t) actions); + waiter >>= function + | Error `Msg msg -> + Lwt_condition.broadcast connected_condition (); + t.connected_condition <- None; + Lwt.return + (Error (`Msg (Fmt.str "error %s connecting to resolver %a" + msg Fmt.(list ~sep:(any ",") (pair ~sep:(any ":") Ipaddr.pp int)) + (to_pairs (nameserver_ips t.nameservers))))) + | Ok (addr, socket) -> + let continue socket = + t.fd <- Some socket; + Lwt.async (fun () -> + read_loop t socket >>= fun () -> + if IM.is_empty t.requests then + Lwt.return_unit + else + connect_via_tcp_to_ns t >|= function + | Error (`Msg msg) -> + Log.err (fun m -> m "error while connecting to resolver: %s" msg) + | Ok () -> ()); + Lwt_condition.broadcast connected_condition (); + t.connected_condition <- None; + req_all socket t + in + let config = find_ns (nameserver_ips t.nameservers) addr in + match config with + | `Plaintext _ -> continue (`Plain socket) + | `Tls (tls_cfg, _, _) -> + Lwt.catch (fun () -> + Tls_lwt.Unix.client_of_fd tls_cfg socket >>= fun f -> + continue (`Tls f)) + (fun e -> + Log.warn (fun m -> m "TLS handshake with %a:%d failed: %s" + Ipaddr.pp (fst addr) (snd addr) (Printexc.to_string e)); + let ns' = + List.filter + (function + | `Tls (_, ip, port) -> + not (Ipaddr.compare ip (fst addr) = 0 && port = snd addr) + | _ -> true) + nameservers + in + if ns' = [] then begin + Lwt_condition.broadcast connected_condition (); + t.connected_condition <- None; + Lwt.return (Error (`Msg "no further nameservers configured")) + end else + connect_to_ns_list t connected_condition ns') + + and connect_via_tcp_to_ns (t : t) = match t.fd, t.connected_condition with | Some _, _ -> Lwt.return (Ok ()) | None, Some w -> Lwt_condition.wait w >>= fun () -> - connect_via_tcp_to_ns t nameservers + connect_via_tcp_to_ns t | None, None -> let connected_condition = Lwt_condition.create () in t.connected_condition <- Some connected_condition ; - let waiter, notify = Lwt.task () in - let waiters, id = Happy_eyeballs.Waiter_map.register notify t.waiters in - t.waiters <- waiters; - let ns = to_pairs nameservers in - let he, actions = Happy_eyeballs.connect_ip t.he (clock ()) ~id ns in - t.he <- he; - Lwt_condition.signal t.timer_condition (); - Lwt.async (fun () -> Lwt_list.iter_p (handle_action t) actions); - waiter >>= function - | Error `Msg msg -> - Lwt_condition.broadcast connected_condition (); - t.connected_condition <- None; - Lwt.return - (Error (`Msg (Fmt.str "error %s connecting to resolver %a" - msg Fmt.(list ~sep:(any ",") (pair ~sep:(any ":") Ipaddr.pp int)) - (to_pairs t.nameservers)))) - | Ok (addr, socket) -> - let continue socket = - t.fd <- Some socket; - Lwt.async (fun () -> - read_loop t socket >>= fun () -> - if IM.is_empty t.requests then - Lwt.return_unit - else - connect_via_tcp_to_ns t t.nameservers >|= function - | Error (`Msg msg) -> - Log.err (fun m -> m "error while connecting to resolver: %s" msg) - | Ok () -> ()); - Lwt_condition.broadcast connected_condition (); - t.connected_condition <- None; - req_all socket t - in - let config = find_ns t.nameservers addr in - match config with - | `Plaintext _ -> continue (`Plain socket) - | `Tls (tls_cfg, _, _) -> - Lwt.catch (fun () -> - Tls_lwt.Unix.client_of_fd tls_cfg socket >>= fun f -> - continue (`Tls f)) - (fun e -> - Log.warn (fun m -> m "TLS handshake with %a:%d failed: %s" - Ipaddr.pp (fst addr) (snd addr) (Printexc.to_string e)); - Lwt_condition.broadcast connected_condition (); - t.connected_condition <- None; - let ns' = - List.filter - (function - | `Tls (_, ip, port) -> - not (Ipaddr.compare ip (fst addr) = 0 && port = snd addr) - | _ -> true) - nameservers - in - if ns' = [] then - Lwt.return (Error (`Msg "no further nameservers configured")) - else - connect_via_tcp_to_ns t ns') + maybe_resolv_conf t; + connect_to_ns_list t connected_condition (nameserver_ips t.nameservers) let connect t = - connect_via_tcp_to_ns t t.nameservers >|= function + connect_via_tcp_to_ns t >|= function | Ok () -> Ok t | Error `Msg msg -> Error (`Msg msg) end diff --git a/unix/client/dns_client_unix.ml b/unix/client/dns_client_unix.ml index d8511130c..c6e0c5980 100644 --- a/unix/client/dns_client_unix.ml +++ b/unix/client/dns_client_unix.ml @@ -10,9 +10,15 @@ module Transport : Dns_client.S = struct type io_addr = Ipaddr.t * int type stack = unit + type nameservers = + | Static of io_addr list + | Resolv_conf of { + mutable nameservers : io_addr list; + mutable digest : Digest.t option + } type t = { protocol : Dns.proto ; - nameservers : io_addr list ; + nameservers : nameservers ; timeout_ns : int64 ; } type context = { @@ -34,25 +40,60 @@ module Transport : Dns_client.S Error (`Msg ("Error reading file: " ^ file)) with _ -> Error (`Msg ("Error opening file " ^ file)) + let decode_resolv_conf data = + match Dns_resolvconf.parse data with + | Ok [] -> Error (`Msg "empty nameservers from resolv.conf") + | Ok ips -> Ok ips + | Error _ as e -> e + + let default_resolvers () = + List.map (fun ip -> ip, 53) Dns_client.default_resolvers + + let maybe_resolv_conf t = + match t.nameservers with + | Static _ -> () + | Resolv_conf resolv_conf -> + let decode_update data dgst = + match decode_resolv_conf data with + | Ok ips -> + resolv_conf.digest <- Some dgst; + resolv_conf.nameservers <- List.map (function `Nameserver ip -> (ip, 53)) ips + | Error _ -> + resolv_conf.digest <- None; + resolv_conf.nameservers <- default_resolvers () + in + match read_file "/etc/resolv.conf", resolv_conf.digest with + | Ok data, Some d -> + let digest = Digest.string data in + if Digest.equal digest d then () else decode_update data digest + | Ok data, None -> decode_update data (Digest.string data) + | Error _, None -> () + | Error _, Some _ -> + resolv_conf.digest <- None; + resolv_conf.nameservers <- default_resolvers () + let create ?nameservers ~timeout () = let protocol, nameservers = match nameservers with - | Some ns -> ns + | Some (proto, ns) -> (proto, Static ns) | None -> - let ips = + let ips, digest = match - Result.bind - (read_file "/etc/resolv.conf") - (fun data -> Dns_resolvconf.parse data) + let ( let* ) = Result.bind in + let* data = read_file "/etc/resolv.conf" in + let* ips = decode_resolv_conf data in + Ok (ips, Digest.string data) with - | Error _ | Ok [] -> List.map (fun ip -> ip, 53) Dns_client.default_resolvers - | Ok ips -> List.map (function `Nameserver ip -> (ip, 53)) ips + | Error _ -> default_resolvers (), None + | Ok (ips, digest) -> + List.map (function `Nameserver ip -> (ip, 53)) ips, Some digest in - `Tcp, ips + (`Tcp, Resolv_conf { nameservers = ips; digest }) in { protocol ; nameservers ; timeout_ns = timeout } - let nameservers { protocol ; nameservers ; _ } = protocol, nameservers + let nameservers { protocol ; nameservers = Static nameservers | Resolv_conf { nameservers; _ } ; _ } = + protocol, nameservers let clock = Mtime_clock.elapsed_ns let rng = Mirage_crypto_rng.generate ?g:None @@ -74,6 +115,7 @@ module Transport : Dns_client.S (* there is no connect timeouts, just a request timeout (unix: receive timeout) *) let connect t = + maybe_resolv_conf t; match nameservers t with | _, [] -> Error (`Msg "empty nameserver list") | proto, (server, port) :: _ ->