Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

dns-client: if /etc/resolv.conf modifies, update the internal list of resolvers #291

Merged
merged 5 commits into from
Feb 4, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
246 changes: 161 additions & 85 deletions lwt/client/dns_client_lwt.ml
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,14 @@ module Transport : Dns_client.S
type io_addr = [ `Plaintext of Ipaddr.t * int | `Tls of Tls.Config.client * Ipaddr.t * int ]
type +'a io = 'a Lwt.t
type stack = unit
type nameservers =
| Static of io_addr list
| Resolv_conf of {
mutable nameservers : io_addr list;
mutable digest : Digest.t option
}
type t = {
nameservers : io_addr list ;
nameservers : nameservers;
timeout_ns : int64 ;
(* TODO: avoid race, use a mvar instead of condition *)
mutable fd : [ `Plain of Lwt_unix.file_descr | `Tls of Tls_lwt.Unix.t ] option ;
Expand All @@ -26,6 +32,10 @@ module Transport : Dns_client.S
}
type context = t

let nameserver_ips = function
| Static nameservers -> nameservers
| Resolv_conf { nameservers; _ } -> nameservers

let read_file file =
try
let fh = open_in file in
Expand Down Expand Up @@ -103,35 +113,96 @@ module Transport : Dns_client.S
Lwt_condition.wait t.timer_condition >>= fun () ->
loop ()

let authenticator =
let authenticator_ref = ref None in
fun () ->
match !authenticator_ref with
| Some x -> x
| None -> match Ca_certs.authenticator () with
| Ok a -> authenticator_ref := Some a ; a
| Error `Msg m -> invalid_arg ("failed to load trust anchors: " ^ m)

let decode_resolv_conf data =
let ( let* ) = Result.bind in
let authenticator = authenticator () in
let* ns = Dns_resolvconf.parse data in
match
List.flatten
(List.map
(fun (`Nameserver ip) ->
let tls = Tls.Config.client ~authenticator ~ip () in
[ `Tls (tls, ip, 853) ; `Plaintext (ip, 53) ])
ns)
with
| [] -> Error (`Msg "no nameservers in resolv.conf")
| ns -> Ok ns

let resolv_conf () =
let ( let* ) = Result.bind in
let* data = read_file "/etc/resolv.conf" in
let* ns =
Result.map_error
(function `Msg msg ->
Log.warn (fun m -> m "error %s decoding resolv.conf %S" msg data);
`Msg msg)
(decode_resolv_conf data)
in
Ok (ns, Digest.string data)

let default_resolver () =
let authenticator = authenticator () in
let peer_name = Dns_client.default_resolver_hostname in
let tls_config = Tls.Config.client ~authenticator ~peer_name () in
List.flatten
(List.map (fun ip -> [
`Tls (tls_config, ip, 853); `Plaintext (ip, 53)
]) Dns_client.default_resolvers)

let maybe_resolv_conf t =
match t.nameservers with
| Static _ -> ()
| Resolv_conf resolv_conf ->
let needs_update =
match read_file "/etc/resolv.conf", resolv_conf.digest with
| Ok data, Some dgst ->
let dgst' = Digest.string data in
if Digest.equal dgst' dgst then
`No
else
`Data (data, dgst')
| Ok data, None ->
let digest = Digest.string data in
`Data (data, digest)
| Error _, None ->
`No
| Error `Msg msg, Some _ ->
Log.warn (fun m -> m "error reading /etc/resolv.conf: %s" msg);
`Default
in
match needs_update with
| `No -> ()
| `Default ->
resolv_conf.digest <- None;
resolv_conf.nameservers <- default_resolver ()
| `Data (data, dgst) ->
match decode_resolv_conf data with
| Ok ns ->
resolv_conf.digest <- Some dgst;
resolv_conf.nameservers <- ns
| Error `Msg msg ->
Log.warn (fun m -> m "error %s decoding resolv.conf: %S" msg data);
resolv_conf.digest <- None;
resolv_conf.nameservers <- default_resolver ()

let create ?nameservers ~timeout () =
let nameservers =
match nameservers with
| Some (`Udp, _) -> invalid_arg "UDP is not supported"
| Some (`Tcp, ns) -> ns
| Some (`Tcp, ns) -> Static ns
| None ->
let authenticator = match Ca_certs.authenticator () with
| Ok a -> a
| Error `Msg m -> invalid_arg ("failed to load trust anchors: " ^ m)
in
match
let ( let* ) = Result.bind in
let* data = read_file "/etc/resolv.conf" in
let* ns = Dns_resolvconf.parse data in
Ok (List.flatten
(List.map
(fun (`Nameserver ip) ->
let tls = Tls.Config.client ~authenticator ~ip () in
[ `Tls (tls, ip, 853) ; `Plaintext (ip, 53) ])
ns))
with
| Error _ | Ok [] ->
let peer_name = Dns_client.default_resolver_hostname in
let tls_config = Tls.Config.client ~authenticator ~peer_name () in
List.flatten
(List.map (fun ip -> [
`Tls (tls_config, ip, 853); `Plaintext (ip, 53)
]) Dns_client.default_resolvers)
| Ok ips -> ips
match resolv_conf () with
| Error _ -> Resolv_conf { nameservers = default_resolver (); digest = None }
| Ok (ips, digest) -> Resolv_conf { nameservers = ips; digest = Some digest }
in
let t = {
nameservers ;
Expand All @@ -146,7 +217,8 @@ module Transport : Dns_client.S
Lwt.async (fun () -> he_timer t);
t

let nameservers { nameservers ; _ } = `Tcp, nameservers
let nameservers { nameservers; _ } = `Tcp, nameserver_ips nameservers

let rng = Mirage_crypto_rng.generate ?g:None

let with_timeout timeout f =
Expand Down Expand Up @@ -254,74 +326,78 @@ module Transport : Dns_client.S
Ipaddr.compare ip addr = 0 && p = port)
ns

let rec connect_via_tcp_to_ns (t : t) nameservers =
let rec connect_to_ns_list (t : t) connected_condition nameservers =
let waiter, notify = Lwt.task () in
let waiters, id = Happy_eyeballs.Waiter_map.register notify t.waiters in
t.waiters <- waiters;
let ns = to_pairs nameservers in
let he, actions = Happy_eyeballs.connect_ip t.he (clock ()) ~id ns in
t.he <- he;
Lwt_condition.signal t.timer_condition ();
Lwt.async (fun () -> Lwt_list.iter_p (handle_action t) actions);
waiter >>= function
| Error `Msg msg ->
Lwt_condition.broadcast connected_condition ();
t.connected_condition <- None;
Lwt.return
(Error (`Msg (Fmt.str "error %s connecting to resolver %a"
msg Fmt.(list ~sep:(any ",") (pair ~sep:(any ":") Ipaddr.pp int))
(to_pairs (nameserver_ips t.nameservers)))))
| Ok (addr, socket) ->
let continue socket =
t.fd <- Some socket;
Lwt.async (fun () ->
read_loop t socket >>= fun () ->
if IM.is_empty t.requests then
Lwt.return_unit
else
connect_via_tcp_to_ns t >|= function
| Error (`Msg msg) ->
Log.err (fun m -> m "error while connecting to resolver: %s" msg)
| Ok () -> ());
Lwt_condition.broadcast connected_condition ();
t.connected_condition <- None;
req_all socket t
in
let config = find_ns (nameserver_ips t.nameservers) addr in
match config with
| `Plaintext _ -> continue (`Plain socket)
| `Tls (tls_cfg, _, _) ->
Lwt.catch (fun () ->
Tls_lwt.Unix.client_of_fd tls_cfg socket >>= fun f ->
continue (`Tls f))
(fun e ->
Log.warn (fun m -> m "TLS handshake with %a:%d failed: %s"
Ipaddr.pp (fst addr) (snd addr) (Printexc.to_string e));
let ns' =
List.filter
(function
| `Tls (_, ip, port) ->
not (Ipaddr.compare ip (fst addr) = 0 && port = snd addr)
| _ -> true)
nameservers
in
if ns' = [] then begin
Lwt_condition.broadcast connected_condition ();
t.connected_condition <- None;
Lwt.return (Error (`Msg "no further nameservers configured"))
end else
connect_to_ns_list t connected_condition ns')

and connect_via_tcp_to_ns (t : t) =
match t.fd, t.connected_condition with
| Some _, _ -> Lwt.return (Ok ())
| None, Some w ->
Lwt_condition.wait w >>= fun () ->
connect_via_tcp_to_ns t nameservers
connect_via_tcp_to_ns t
| None, None ->
let connected_condition = Lwt_condition.create () in
t.connected_condition <- Some connected_condition ;
let waiter, notify = Lwt.task () in
let waiters, id = Happy_eyeballs.Waiter_map.register notify t.waiters in
t.waiters <- waiters;
let ns = to_pairs nameservers in
let he, actions = Happy_eyeballs.connect_ip t.he (clock ()) ~id ns in
t.he <- he;
Lwt_condition.signal t.timer_condition ();
Lwt.async (fun () -> Lwt_list.iter_p (handle_action t) actions);
waiter >>= function
| Error `Msg msg ->
Lwt_condition.broadcast connected_condition ();
t.connected_condition <- None;
Lwt.return
(Error (`Msg (Fmt.str "error %s connecting to resolver %a"
msg Fmt.(list ~sep:(any ",") (pair ~sep:(any ":") Ipaddr.pp int))
(to_pairs t.nameservers))))
| Ok (addr, socket) ->
let continue socket =
t.fd <- Some socket;
Lwt.async (fun () ->
read_loop t socket >>= fun () ->
if IM.is_empty t.requests then
Lwt.return_unit
else
connect_via_tcp_to_ns t t.nameservers >|= function
| Error (`Msg msg) ->
Log.err (fun m -> m "error while connecting to resolver: %s" msg)
| Ok () -> ());
Lwt_condition.broadcast connected_condition ();
t.connected_condition <- None;
req_all socket t
in
let config = find_ns t.nameservers addr in
match config with
| `Plaintext _ -> continue (`Plain socket)
| `Tls (tls_cfg, _, _) ->
Lwt.catch (fun () ->
Tls_lwt.Unix.client_of_fd tls_cfg socket >>= fun f ->
continue (`Tls f))
(fun e ->
Log.warn (fun m -> m "TLS handshake with %a:%d failed: %s"
Ipaddr.pp (fst addr) (snd addr) (Printexc.to_string e));
Lwt_condition.broadcast connected_condition ();
t.connected_condition <- None;
let ns' =
List.filter
(function
| `Tls (_, ip, port) ->
not (Ipaddr.compare ip (fst addr) = 0 && port = snd addr)
| _ -> true)
nameservers
in
if ns' = [] then
Lwt.return (Error (`Msg "no further nameservers configured"))
else
connect_via_tcp_to_ns t ns')
maybe_resolv_conf t;
connect_to_ns_list t connected_condition (nameserver_ips t.nameservers)

let connect t =
connect_via_tcp_to_ns t t.nameservers >|= function
connect_via_tcp_to_ns t >|= function
| Ok () -> Ok t
| Error `Msg msg -> Error (`Msg msg)
end
Expand Down
62 changes: 52 additions & 10 deletions unix/client/dns_client_unix.ml
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,15 @@ module Transport : Dns_client.S
= struct
type io_addr = Ipaddr.t * int
type stack = unit
type nameservers =
| Static of io_addr list
| Resolv_conf of {
mutable nameservers : io_addr list;
mutable digest : Digest.t option
}
type t = {
protocol : Dns.proto ;
nameservers : io_addr list ;
nameservers : nameservers ;
timeout_ns : int64 ;
}
type context = {
Expand All @@ -34,25 +40,60 @@ module Transport : Dns_client.S
Error (`Msg ("Error reading file: " ^ file))
with _ -> Error (`Msg ("Error opening file " ^ file))

let decode_resolv_conf data =
match Dns_resolvconf.parse data with
| Ok [] -> Error (`Msg "empty nameservers from resolv.conf")
| Ok ips -> Ok ips
| Error _ as e -> e

let default_resolvers () =
List.map (fun ip -> ip, 53) Dns_client.default_resolvers

let maybe_resolv_conf t =
match t.nameservers with
| Static _ -> ()
| Resolv_conf resolv_conf ->
let decode_update data dgst =
match decode_resolv_conf data with
| Ok ips ->
resolv_conf.digest <- Some dgst;
resolv_conf.nameservers <- List.map (function `Nameserver ip -> (ip, 53)) ips
| Error _ ->
resolv_conf.digest <- None;
resolv_conf.nameservers <- default_resolvers ()
in
match read_file "/etc/resolv.conf", resolv_conf.digest with
| Ok data, Some d ->
let digest = Digest.string data in
if Digest.equal digest d then () else decode_update data digest
| Ok data, None -> decode_update data (Digest.string data)
| Error _, None -> ()
| Error _, Some _ ->
resolv_conf.digest <- None;
resolv_conf.nameservers <- default_resolvers ()

let create ?nameservers ~timeout () =
let protocol, nameservers =
match nameservers with
| Some ns -> ns
| Some (proto, ns) -> (proto, Static ns)
| None ->
let ips =
let ips, digest =
match
Result.bind
(read_file "/etc/resolv.conf")
(fun data -> Dns_resolvconf.parse data)
let ( let* ) = Result.bind in
let* data = read_file "/etc/resolv.conf" in
let* ips = decode_resolv_conf data in
Ok (ips, Digest.string data)
with
| Error _ | Ok [] -> List.map (fun ip -> ip, 53) Dns_client.default_resolvers
| Ok ips -> List.map (function `Nameserver ip -> (ip, 53)) ips
| Error _ -> default_resolvers (), None
| Ok (ips, digest) ->
List.map (function `Nameserver ip -> (ip, 53)) ips, Some digest
in
`Tcp, ips
(`Tcp, Resolv_conf { nameservers = ips; digest })
in
{ protocol ; nameservers ; timeout_ns = timeout }

let nameservers { protocol ; nameservers ; _ } = protocol, nameservers
let nameservers { protocol ; nameservers = Static nameservers | Resolv_conf { nameservers; _ } ; _ } =
protocol, nameservers
let clock = Mtime_clock.elapsed_ns
let rng = Mirage_crypto_rng.generate ?g:None

Expand All @@ -74,6 +115,7 @@ module Transport : Dns_client.S

(* there is no connect timeouts, just a request timeout (unix: receive timeout) *)
let connect t =
maybe_resolv_conf t;
match nameservers t with
| _, [] -> Error (`Msg "empty nameserver list")
| proto, (server, port) :: _ ->
Expand Down