Skip to content

Commit

Permalink
dns-client: if /etc/resolv.conf modifies, update the internal list of…
Browse files Browse the repository at this point in the history
… resolvers (#291)

* dns-client: if /etc/resolv.conf modifies, update the internal list of resolvers

Only affects lwt and unix.
Previously, when you started an application, /etc/resolv.conf was read once at
startup. When you switch to a different network, and /etc/resolv.conf is updated
by DHCP, this is not reflected in the dns-client.

Now, the digest of /etc/resolv.conf is stored internally, and if it changes, it
is parsed again. The lwt implementation needed to be modified slightly since
connect_to_ns_via_tcp carried a nameserver list, and a different task may have
been woken up. Now, there are two functions and it should play out nicely.

Co-authored-by: Reynir Björnsson <[email protected]>
  • Loading branch information
hannesm and reynir authored Feb 4, 2022
1 parent 5b37ae1 commit 7763da2
Show file tree
Hide file tree
Showing 2 changed files with 213 additions and 95 deletions.
246 changes: 161 additions & 85 deletions lwt/client/dns_client_lwt.ml
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,14 @@ module Transport : Dns_client.S
type io_addr = [ `Plaintext of Ipaddr.t * int | `Tls of Tls.Config.client * Ipaddr.t * int ]
type +'a io = 'a Lwt.t
type stack = unit
type nameservers =
| Static of io_addr list
| Resolv_conf of {
mutable nameservers : io_addr list;
mutable digest : Digest.t option
}
type t = {
nameservers : io_addr list ;
nameservers : nameservers;
timeout_ns : int64 ;
(* TODO: avoid race, use a mvar instead of condition *)
mutable fd : [ `Plain of Lwt_unix.file_descr | `Tls of Tls_lwt.Unix.t ] option ;
Expand All @@ -26,6 +32,10 @@ module Transport : Dns_client.S
}
type context = t

let nameserver_ips = function
| Static nameservers -> nameservers
| Resolv_conf { nameservers; _ } -> nameservers

let read_file file =
try
let fh = open_in file in
Expand Down Expand Up @@ -103,35 +113,96 @@ module Transport : Dns_client.S
Lwt_condition.wait t.timer_condition >>= fun () ->
loop ()

let authenticator =
let authenticator_ref = ref None in
fun () ->
match !authenticator_ref with
| Some x -> x
| None -> match Ca_certs.authenticator () with
| Ok a -> authenticator_ref := Some a ; a
| Error `Msg m -> invalid_arg ("failed to load trust anchors: " ^ m)

let decode_resolv_conf data =
let ( let* ) = Result.bind in
let authenticator = authenticator () in
let* ns = Dns_resolvconf.parse data in
match
List.flatten
(List.map
(fun (`Nameserver ip) ->
let tls = Tls.Config.client ~authenticator ~ip () in
[ `Tls (tls, ip, 853) ; `Plaintext (ip, 53) ])
ns)
with
| [] -> Error (`Msg "no nameservers in resolv.conf")
| ns -> Ok ns

let resolv_conf () =
let ( let* ) = Result.bind in
let* data = read_file "/etc/resolv.conf" in
let* ns =
Result.map_error
(function `Msg msg ->
Log.warn (fun m -> m "error %s decoding resolv.conf %S" msg data);
`Msg msg)
(decode_resolv_conf data)
in
Ok (ns, Digest.string data)

let default_resolver () =
let authenticator = authenticator () in
let peer_name = Dns_client.default_resolver_hostname in
let tls_config = Tls.Config.client ~authenticator ~peer_name () in
List.flatten
(List.map (fun ip -> [
`Tls (tls_config, ip, 853); `Plaintext (ip, 53)
]) Dns_client.default_resolvers)

let maybe_resolv_conf t =
match t.nameservers with
| Static _ -> ()
| Resolv_conf resolv_conf ->
let needs_update =
match read_file "/etc/resolv.conf", resolv_conf.digest with
| Ok data, Some dgst ->
let dgst' = Digest.string data in
if Digest.equal dgst' dgst then
`No
else
`Data (data, dgst')
| Ok data, None ->
let digest = Digest.string data in
`Data (data, digest)
| Error _, None ->
`No
| Error `Msg msg, Some _ ->
Log.warn (fun m -> m "error reading /etc/resolv.conf: %s" msg);
`Default
in
match needs_update with
| `No -> ()
| `Default ->
resolv_conf.digest <- None;
resolv_conf.nameservers <- default_resolver ()
| `Data (data, dgst) ->
match decode_resolv_conf data with
| Ok ns ->
resolv_conf.digest <- Some dgst;
resolv_conf.nameservers <- ns
| Error `Msg msg ->
Log.warn (fun m -> m "error %s decoding resolv.conf: %S" msg data);
resolv_conf.digest <- None;
resolv_conf.nameservers <- default_resolver ()

let create ?nameservers ~timeout () =
let nameservers =
match nameservers with
| Some (`Udp, _) -> invalid_arg "UDP is not supported"
| Some (`Tcp, ns) -> ns
| Some (`Tcp, ns) -> Static ns
| None ->
let authenticator = match Ca_certs.authenticator () with
| Ok a -> a
| Error `Msg m -> invalid_arg ("failed to load trust anchors: " ^ m)
in
match
let ( let* ) = Result.bind in
let* data = read_file "/etc/resolv.conf" in
let* ns = Dns_resolvconf.parse data in
Ok (List.flatten
(List.map
(fun (`Nameserver ip) ->
let tls = Tls.Config.client ~authenticator ~ip () in
[ `Tls (tls, ip, 853) ; `Plaintext (ip, 53) ])
ns))
with
| Error _ | Ok [] ->
let peer_name = Dns_client.default_resolver_hostname in
let tls_config = Tls.Config.client ~authenticator ~peer_name () in
List.flatten
(List.map (fun ip -> [
`Tls (tls_config, ip, 853); `Plaintext (ip, 53)
]) Dns_client.default_resolvers)
| Ok ips -> ips
match resolv_conf () with
| Error _ -> Resolv_conf { nameservers = default_resolver (); digest = None }
| Ok (ips, digest) -> Resolv_conf { nameservers = ips; digest = Some digest }
in
let t = {
nameservers ;
Expand All @@ -146,7 +217,8 @@ module Transport : Dns_client.S
Lwt.async (fun () -> he_timer t);
t

let nameservers { nameservers ; _ } = `Tcp, nameservers
let nameservers { nameservers; _ } = `Tcp, nameserver_ips nameservers

let rng = Mirage_crypto_rng.generate ?g:None

let with_timeout timeout f =
Expand Down Expand Up @@ -254,74 +326,78 @@ module Transport : Dns_client.S
Ipaddr.compare ip addr = 0 && p = port)
ns

let rec connect_via_tcp_to_ns (t : t) nameservers =
let rec connect_to_ns_list (t : t) connected_condition nameservers =
let waiter, notify = Lwt.task () in
let waiters, id = Happy_eyeballs.Waiter_map.register notify t.waiters in
t.waiters <- waiters;
let ns = to_pairs nameservers in
let he, actions = Happy_eyeballs.connect_ip t.he (clock ()) ~id ns in
t.he <- he;
Lwt_condition.signal t.timer_condition ();
Lwt.async (fun () -> Lwt_list.iter_p (handle_action t) actions);
waiter >>= function
| Error `Msg msg ->
Lwt_condition.broadcast connected_condition ();
t.connected_condition <- None;
Lwt.return
(Error (`Msg (Fmt.str "error %s connecting to resolver %a"
msg Fmt.(list ~sep:(any ",") (pair ~sep:(any ":") Ipaddr.pp int))
(to_pairs (nameserver_ips t.nameservers)))))
| Ok (addr, socket) ->
let continue socket =
t.fd <- Some socket;
Lwt.async (fun () ->
read_loop t socket >>= fun () ->
if IM.is_empty t.requests then
Lwt.return_unit
else
connect_via_tcp_to_ns t >|= function
| Error (`Msg msg) ->
Log.err (fun m -> m "error while connecting to resolver: %s" msg)
| Ok () -> ());
Lwt_condition.broadcast connected_condition ();
t.connected_condition <- None;
req_all socket t
in
let config = find_ns (nameserver_ips t.nameservers) addr in
match config with
| `Plaintext _ -> continue (`Plain socket)
| `Tls (tls_cfg, _, _) ->
Lwt.catch (fun () ->
Tls_lwt.Unix.client_of_fd tls_cfg socket >>= fun f ->
continue (`Tls f))
(fun e ->
Log.warn (fun m -> m "TLS handshake with %a:%d failed: %s"
Ipaddr.pp (fst addr) (snd addr) (Printexc.to_string e));
let ns' =
List.filter
(function
| `Tls (_, ip, port) ->
not (Ipaddr.compare ip (fst addr) = 0 && port = snd addr)
| _ -> true)
nameservers
in
if ns' = [] then begin
Lwt_condition.broadcast connected_condition ();
t.connected_condition <- None;
Lwt.return (Error (`Msg "no further nameservers configured"))
end else
connect_to_ns_list t connected_condition ns')

and connect_via_tcp_to_ns (t : t) =
match t.fd, t.connected_condition with
| Some _, _ -> Lwt.return (Ok ())
| None, Some w ->
Lwt_condition.wait w >>= fun () ->
connect_via_tcp_to_ns t nameservers
connect_via_tcp_to_ns t
| None, None ->
let connected_condition = Lwt_condition.create () in
t.connected_condition <- Some connected_condition ;
let waiter, notify = Lwt.task () in
let waiters, id = Happy_eyeballs.Waiter_map.register notify t.waiters in
t.waiters <- waiters;
let ns = to_pairs nameservers in
let he, actions = Happy_eyeballs.connect_ip t.he (clock ()) ~id ns in
t.he <- he;
Lwt_condition.signal t.timer_condition ();
Lwt.async (fun () -> Lwt_list.iter_p (handle_action t) actions);
waiter >>= function
| Error `Msg msg ->
Lwt_condition.broadcast connected_condition ();
t.connected_condition <- None;
Lwt.return
(Error (`Msg (Fmt.str "error %s connecting to resolver %a"
msg Fmt.(list ~sep:(any ",") (pair ~sep:(any ":") Ipaddr.pp int))
(to_pairs t.nameservers))))
| Ok (addr, socket) ->
let continue socket =
t.fd <- Some socket;
Lwt.async (fun () ->
read_loop t socket >>= fun () ->
if IM.is_empty t.requests then
Lwt.return_unit
else
connect_via_tcp_to_ns t t.nameservers >|= function
| Error (`Msg msg) ->
Log.err (fun m -> m "error while connecting to resolver: %s" msg)
| Ok () -> ());
Lwt_condition.broadcast connected_condition ();
t.connected_condition <- None;
req_all socket t
in
let config = find_ns t.nameservers addr in
match config with
| `Plaintext _ -> continue (`Plain socket)
| `Tls (tls_cfg, _, _) ->
Lwt.catch (fun () ->
Tls_lwt.Unix.client_of_fd tls_cfg socket >>= fun f ->
continue (`Tls f))
(fun e ->
Log.warn (fun m -> m "TLS handshake with %a:%d failed: %s"
Ipaddr.pp (fst addr) (snd addr) (Printexc.to_string e));
Lwt_condition.broadcast connected_condition ();
t.connected_condition <- None;
let ns' =
List.filter
(function
| `Tls (_, ip, port) ->
not (Ipaddr.compare ip (fst addr) = 0 && port = snd addr)
| _ -> true)
nameservers
in
if ns' = [] then
Lwt.return (Error (`Msg "no further nameservers configured"))
else
connect_via_tcp_to_ns t ns')
maybe_resolv_conf t;
connect_to_ns_list t connected_condition (nameserver_ips t.nameservers)

let connect t =
connect_via_tcp_to_ns t t.nameservers >|= function
connect_via_tcp_to_ns t >|= function
| Ok () -> Ok t
| Error `Msg msg -> Error (`Msg msg)
end
Expand Down
62 changes: 52 additions & 10 deletions unix/client/dns_client_unix.ml
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,15 @@ module Transport : Dns_client.S
= struct
type io_addr = Ipaddr.t * int
type stack = unit
type nameservers =
| Static of io_addr list
| Resolv_conf of {
mutable nameservers : io_addr list;
mutable digest : Digest.t option
}
type t = {
protocol : Dns.proto ;
nameservers : io_addr list ;
nameservers : nameservers ;
timeout_ns : int64 ;
}
type context = {
Expand All @@ -34,25 +40,60 @@ module Transport : Dns_client.S
Error (`Msg ("Error reading file: " ^ file))
with _ -> Error (`Msg ("Error opening file " ^ file))

let decode_resolv_conf data =
match Dns_resolvconf.parse data with
| Ok [] -> Error (`Msg "empty nameservers from resolv.conf")
| Ok ips -> Ok ips
| Error _ as e -> e

let default_resolvers () =
List.map (fun ip -> ip, 53) Dns_client.default_resolvers

let maybe_resolv_conf t =
match t.nameservers with
| Static _ -> ()
| Resolv_conf resolv_conf ->
let decode_update data dgst =
match decode_resolv_conf data with
| Ok ips ->
resolv_conf.digest <- Some dgst;
resolv_conf.nameservers <- List.map (function `Nameserver ip -> (ip, 53)) ips
| Error _ ->
resolv_conf.digest <- None;
resolv_conf.nameservers <- default_resolvers ()
in
match read_file "/etc/resolv.conf", resolv_conf.digest with
| Ok data, Some d ->
let digest = Digest.string data in
if Digest.equal digest d then () else decode_update data digest
| Ok data, None -> decode_update data (Digest.string data)
| Error _, None -> ()
| Error _, Some _ ->
resolv_conf.digest <- None;
resolv_conf.nameservers <- default_resolvers ()

let create ?nameservers ~timeout () =
let protocol, nameservers =
match nameservers with
| Some ns -> ns
| Some (proto, ns) -> (proto, Static ns)
| None ->
let ips =
let ips, digest =
match
Result.bind
(read_file "/etc/resolv.conf")
(fun data -> Dns_resolvconf.parse data)
let ( let* ) = Result.bind in
let* data = read_file "/etc/resolv.conf" in
let* ips = decode_resolv_conf data in
Ok (ips, Digest.string data)
with
| Error _ | Ok [] -> List.map (fun ip -> ip, 53) Dns_client.default_resolvers
| Ok ips -> List.map (function `Nameserver ip -> (ip, 53)) ips
| Error _ -> default_resolvers (), None
| Ok (ips, digest) ->
List.map (function `Nameserver ip -> (ip, 53)) ips, Some digest
in
`Tcp, ips
(`Tcp, Resolv_conf { nameservers = ips; digest })
in
{ protocol ; nameservers ; timeout_ns = timeout }

let nameservers { protocol ; nameservers ; _ } = protocol, nameservers
let nameservers { protocol ; nameservers = Static nameservers | Resolv_conf { nameservers; _ } ; _ } =
protocol, nameservers
let clock = Mtime_clock.elapsed_ns
let rng = Mirage_crypto_rng.generate ?g:None

Expand All @@ -74,6 +115,7 @@ module Transport : Dns_client.S

(* there is no connect timeouts, just a request timeout (unix: receive timeout) *)
let connect t =
maybe_resolv_conf t;
match nameservers t with
| _, [] -> Error (`Msg "empty nameserver list")
| proto, (server, port) :: _ ->
Expand Down

0 comments on commit 7763da2

Please sign in to comment.