Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

dns-client: if /etc/resolv.conf modifies, update the internal list of resolvers #291

Merged
merged 5 commits into from
Feb 4, 2022
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
234 changes: 149 additions & 85 deletions lwt/client/dns_client_lwt.ml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@ module Transport : Dns_client.S
type +'a io = 'a Lwt.t
type stack = unit
type t = {
nameservers : io_addr list ;
mutable nameservers : io_addr list ;
mutable resolv_conf : Digest.t option ;
timeout_ns : int64 ;
(* TODO: avoid race, use a mvar instead of condition *)
mutable fd : [ `Plain of Lwt_unix.file_descr | `Tls of Tls_lwt.Unix.t ] option ;
Expand Down Expand Up @@ -103,38 +104,97 @@ module Transport : Dns_client.S
Lwt_condition.wait t.timer_condition >>= fun () ->
loop ()

let authenticator =
let _authenticator = ref None in
fun () ->
match !_authenticator with
| Some x -> x
| None -> match Ca_certs.authenticator () with
| Ok a -> _authenticator := Some a ; a
| Error `Msg m -> invalid_arg ("failed to load trust anchors: " ^ m)

let decode_resolv_conf data =
let ( let* ) = Result.bind in
let authenticator = authenticator () in
let* ns = Dns_resolvconf.parse data in
match
List.flatten
(List.map
(fun (`Nameserver ip) ->
let tls = Tls.Config.client ~authenticator ~ip () in
[ `Tls (tls, ip, 853) ; `Plaintext (ip, 53) ])
ns)
with
| [] -> Error (`Msg "no nameservers in resolv.conf")
| ns -> Ok ns

let resolv_conf () =
let ( let* ) = Result.bind in
let* data = read_file "/etc/resolv.conf" in
let* ns =
Result.map_error
(function `Msg msg ->
Log.warn (fun m -> m "error %s decoding resolv.conf %S" msg data);
`Msg msg)
(decode_resolv_conf data)
in
Ok (ns, Digest.string data)

let default_resolver () =
let authenticator = authenticator () in
let peer_name = Dns_client.default_resolver_hostname in
let tls_config = Tls.Config.client ~authenticator ~peer_name () in
List.flatten
(List.map (fun ip -> [
`Tls (tls_config, ip, 853); `Plaintext (ip, 53)
]) Dns_client.default_resolvers)

let maybe_resolv_conf t =
let needs_update =
match read_file "/etc/resolv.conf", t.resolv_conf with
| Ok data, Some dgst ->
let dgst' = Digest.string data in
if Digest.equal dgst' dgst then
`No
else
`Data (data, dgst')
| Ok data, None ->
let digest = Digest.string data in
`Data (data, digest)
| Error _, None ->
`No
| Error `Msg msg, Some _ ->
Log.warn (fun m -> m "error reading /etc/resolv.conf: %s" msg);
`Default
in
match needs_update with
| `No -> ()
| `Default ->
t.resolv_conf <- None;
t.nameservers <- default_resolver ()
| `Data (data, dgst) ->
match decode_resolv_conf data with
| Ok ns ->
t.resolv_conf <- Some dgst;
t.nameservers <- ns
| Error `Msg msg ->
Log.warn (fun m -> m "error %s decoding resolv.conf: %S" msg data);
t.resolv_conf <- None;
t.nameservers <- default_resolver ()

let create ?nameservers ~timeout () =
let nameservers =
let nameservers, resolv_conf =
match nameservers with
| Some (`Udp, _) -> invalid_arg "UDP is not supported"
| Some (`Tcp, ns) -> ns
| Some (`Tcp, ns) -> ns, None
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Perhaps we should preserve whether the nameservers were passed explicitly and then not try to read resolv.conf.

| None ->
let authenticator = match Ca_certs.authenticator () with
| Ok a -> a
| Error `Msg m -> invalid_arg ("failed to load trust anchors: " ^ m)
in
match
let ( let* ) = Result.bind in
let* data = read_file "/etc/resolv.conf" in
let* ns = Dns_resolvconf.parse data in
Ok (List.flatten
(List.map
(fun (`Nameserver ip) ->
let tls = Tls.Config.client ~authenticator ~ip () in
[ `Tls (tls, ip, 853) ; `Plaintext (ip, 53) ])
ns))
with
| Error _ | Ok [] ->
let peer_name = Dns_client.default_resolver_hostname in
let tls_config = Tls.Config.client ~authenticator ~peer_name () in
List.flatten
(List.map (fun ip -> [
`Tls (tls_config, ip, 853); `Plaintext (ip, 53)
]) Dns_client.default_resolvers)
| Ok ips -> ips
match resolv_conf () with
| Error _ -> default_resolver (), None
| Ok (ips, digest) -> ips, Some digest
in
let t = {
nameservers ;
resolv_conf ;
timeout_ns = timeout ;
fd = None ;
connected_condition = None ;
Expand Down Expand Up @@ -254,74 +314,78 @@ module Transport : Dns_client.S
Ipaddr.compare ip addr = 0 && p = port)
ns

let rec connect_via_tcp_to_ns (t : t) nameservers =
let rec connect_to_ns_list (t : t) connected_condition nameservers =
let waiter, notify = Lwt.task () in
let waiters, id = Happy_eyeballs.Waiter_map.register notify t.waiters in
t.waiters <- waiters;
let ns = to_pairs nameservers in
let he, actions = Happy_eyeballs.connect_ip t.he (clock ()) ~id ns in
t.he <- he;
Lwt_condition.signal t.timer_condition ();
Lwt.async (fun () -> Lwt_list.iter_p (handle_action t) actions);
waiter >>= function
| Error `Msg msg ->
Lwt_condition.broadcast connected_condition ();
t.connected_condition <- None;
Lwt.return
(Error (`Msg (Fmt.str "error %s connecting to resolver %a"
msg Fmt.(list ~sep:(any ",") (pair ~sep:(any ":") Ipaddr.pp int))
(to_pairs t.nameservers))))
| Ok (addr, socket) ->
let continue socket =
t.fd <- Some socket;
Lwt.async (fun () ->
read_loop t socket >>= fun () ->
if IM.is_empty t.requests then
Lwt.return_unit
else
connect_via_tcp_to_ns t >|= function
| Error (`Msg msg) ->
Log.err (fun m -> m "error while connecting to resolver: %s" msg)
| Ok () -> ());
Lwt_condition.broadcast connected_condition ();
t.connected_condition <- None;
req_all socket t
in
let config = find_ns t.nameservers addr in
match config with
| `Plaintext _ -> continue (`Plain socket)
| `Tls (tls_cfg, _, _) ->
Lwt.catch (fun () ->
Tls_lwt.Unix.client_of_fd tls_cfg socket >>= fun f ->
continue (`Tls f))
(fun e ->
Log.warn (fun m -> m "TLS handshake with %a:%d failed: %s"
Ipaddr.pp (fst addr) (snd addr) (Printexc.to_string e));
let ns' =
List.filter
(function
| `Tls (_, ip, port) ->
not (Ipaddr.compare ip (fst addr) = 0 && port = snd addr)
| _ -> true)
nameservers
in
if ns' = [] then begin
Lwt_condition.broadcast connected_condition ();
t.connected_condition <- None;
Lwt.return (Error (`Msg "no further nameservers configured"))
end else
connect_to_ns_list t connected_condition ns')

and connect_via_tcp_to_ns (t : t) =
match t.fd, t.connected_condition with
| Some _, _ -> Lwt.return (Ok ())
| None, Some w ->
Lwt_condition.wait w >>= fun () ->
connect_via_tcp_to_ns t nameservers
connect_via_tcp_to_ns t
| None, None ->
let connected_condition = Lwt_condition.create () in
t.connected_condition <- Some connected_condition ;
let waiter, notify = Lwt.task () in
let waiters, id = Happy_eyeballs.Waiter_map.register notify t.waiters in
t.waiters <- waiters;
let ns = to_pairs nameservers in
let he, actions = Happy_eyeballs.connect_ip t.he (clock ()) ~id ns in
t.he <- he;
Lwt_condition.signal t.timer_condition ();
Lwt.async (fun () -> Lwt_list.iter_p (handle_action t) actions);
waiter >>= function
| Error `Msg msg ->
Lwt_condition.broadcast connected_condition ();
t.connected_condition <- None;
Lwt.return
(Error (`Msg (Fmt.str "error %s connecting to resolver %a"
msg Fmt.(list ~sep:(any ",") (pair ~sep:(any ":") Ipaddr.pp int))
(to_pairs t.nameservers))))
| Ok (addr, socket) ->
let continue socket =
t.fd <- Some socket;
Lwt.async (fun () ->
read_loop t socket >>= fun () ->
if IM.is_empty t.requests then
Lwt.return_unit
else
connect_via_tcp_to_ns t t.nameservers >|= function
| Error (`Msg msg) ->
Log.err (fun m -> m "error while connecting to resolver: %s" msg)
| Ok () -> ());
Lwt_condition.broadcast connected_condition ();
t.connected_condition <- None;
req_all socket t
in
let config = find_ns t.nameservers addr in
match config with
| `Plaintext _ -> continue (`Plain socket)
| `Tls (tls_cfg, _, _) ->
Lwt.catch (fun () ->
Tls_lwt.Unix.client_of_fd tls_cfg socket >>= fun f ->
continue (`Tls f))
(fun e ->
Log.warn (fun m -> m "TLS handshake with %a:%d failed: %s"
Ipaddr.pp (fst addr) (snd addr) (Printexc.to_string e));
Lwt_condition.broadcast connected_condition ();
t.connected_condition <- None;
let ns' =
List.filter
(function
| `Tls (_, ip, port) ->
not (Ipaddr.compare ip (fst addr) = 0 && port = snd addr)
| _ -> true)
nameservers
in
if ns' = [] then
Lwt.return (Error (`Msg "no further nameservers configured"))
else
connect_via_tcp_to_ns t ns')
maybe_resolv_conf t;
connect_to_ns_list t connected_condition t.nameservers

let connect t =
connect_via_tcp_to_ns t t.nameservers >|= function
connect_via_tcp_to_ns t >|= function
| Ok () -> Ok t
| Error `Msg msg -> Error (`Msg msg)
end
Expand Down
55 changes: 44 additions & 11 deletions unix/client/dns_client_unix.ml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@ module Transport : Dns_client.S
type stack = unit
type t = {
protocol : Dns.proto ;
nameservers : io_addr list ;
mutable resolv_conf : Digest.t option ;
mutable nameservers : io_addr list ;
timeout_ns : int64 ;
}
type context = {
Expand All @@ -34,23 +35,54 @@ module Transport : Dns_client.S
Error (`Msg ("Error reading file: " ^ file))
with _ -> Error (`Msg ("Error opening file " ^ file))

let decode_resolv_conf data =
match Dns_resolvconf.parse data with
| Ok [] -> Error (`Msg "empty nameservers from resolv.conf")
| Ok ips -> Ok ips
| Error _ as e -> e

let default_resolvers () =
List.map (fun ip -> ip, 53) Dns_client.default_resolvers

let maybe_resolv_conf t =
let decode_update data dgst =
match decode_resolv_conf data with
| Ok ips ->
t.resolv_conf <- Some dgst;
t.nameservers <- List.map (function `Nameserver ip -> (ip, 53)) ips
| Error _ ->
t.resolv_conf <- None;
t.nameservers <- default_resolvers ()
in
match read_file "/etc/resolv.conf", t.resolv_conf with
| Error _, None -> ()
| Error _, Some _ ->
t.resolv_conf <- None;
t.nameservers <- default_resolvers ()
| Ok data, None -> decode_update data (Digest.string data)
| Ok data, Some d ->
let digest = Digest.string data in
if Digest.equal digest d then () else decode_update data digest

let create ?nameservers ~timeout () =
let protocol, nameservers =
let (protocol, nameservers), resolv_conf =
match nameservers with
| Some ns -> ns
| Some ns -> ns, None
| None ->
let ips =
let ips, digest =
match
Result.bind
(read_file "/etc/resolv.conf")
(fun data -> Dns_resolvconf.parse data)
let ( let* ) = Result.bind in
let* data = read_file "/etc/resolv.conf" in
let* ips = decode_resolv_conf data in
Ok (ips, Digest.string data)
with
| Error _ | Ok [] -> List.map (fun ip -> ip, 53) Dns_client.default_resolvers
| Ok ips -> List.map (function `Nameserver ip -> (ip, 53)) ips
| Error _ -> default_resolvers (), None
| Ok (ips, digest) ->
List.map (function `Nameserver ip -> (ip, 53)) ips, Some digest
in
`Tcp, ips
(`Tcp, ips), digest
in
{ protocol ; nameservers ; timeout_ns = timeout }
{ protocol ; resolv_conf ; nameservers ; timeout_ns = timeout }

let nameservers { protocol ; nameservers ; _ } = protocol, nameservers
let clock = Mtime_clock.elapsed_ns
Expand All @@ -74,6 +106,7 @@ module Transport : Dns_client.S

(* there is no connect timeouts, just a request timeout (unix: receive timeout) *)
let connect t =
maybe_resolv_conf t;
match nameservers t with
| _, [] -> Error (`Msg "empty nameserver list")
| proto, (server, port) :: _ ->
Expand Down