Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Lwt_io.establish_server (TCP servers): expose client socket to connection-handling callback #586

Merged
merged 4 commits into from
Jun 20, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
231 changes: 167 additions & 64 deletions src/unix/lwt_io.ml
Original file line number Diff line number Diff line change
Expand Up @@ -1559,71 +1559,131 @@ let shutdown_server server = Lazy.force server.shutdown
let shutdown_server_deprecated server =
Lwt.async (fun () -> shutdown_server server)

let establish_server_base
bind ?fd ?(buffer_size = !default_buffer_size) ?(backlog=5) sockaddr f =
let sock =
match fd with
(* There are several variants of establish_server that have accumulated over the
years in Lwt_io. This is their underlying implementation. The functions
exposed in the API are various wrappers around this one. *)
let establish_server_generic
bind_function
?fd:preexisting_socket_for_listening
?(backlog = 5)
listening_address
connection_handler_callback =

let listening_socket =
match preexisting_socket_for_listening with
| None ->
Lwt_unix.socket (Unix.domain_of_sockaddr sockaddr) Unix.SOCK_STREAM 0
| Some fd ->
fd
Lwt_unix.socket
(Unix.domain_of_sockaddr listening_address) Unix.SOCK_STREAM 0
| Some socket ->
socket
in
Lwt_unix.setsockopt sock Unix.SO_REUSEADDR true;
Lwt_unix.setsockopt listening_socket Unix.SO_REUSEADDR true;

(* This promise gets resolved with `Should_stop when the user calls
Lwt_io.shutdown_server. This begins the shutdown procedure. *)
let should_stop, notify_should_stop =
Lwt.wait () in

(* Some time after Lwt_io.shutdown_server is called, this function
establish_server_generic will actually close the listening socket. At that
point, this promise is resolved. This ends the shutdown procedure. *)
let wait_until_listening_socket_closed, notify_listening_socket_closed =
Lwt.wait () in

let rec accept_loop () =
let try_to_accept =
Lwt_unix.accept listening_socket >|= fun x ->
`Accepted x
in

let abort_waiter, abort_wakener = Lwt.wait () in
let abort_waiter = abort_waiter >>= fun () -> Lwt.return `Shutdown in
(* Signals that the listening socket has been closed. *)
let closed_waiter, closed_wakener = Lwt.wait () in
let rec loop () =
Lwt.pick
[Lwt_unix.accept sock >|= (fun x -> `Accept x);
abort_waiter] >>= function
| `Accept(fd, addr) ->
(try Lwt_unix.set_close_on_exec fd with Invalid_argument _ -> ());
let close = lazy (close_socket fd) in
f addr
(of_fd ~buffer:(Lwt_bytes.create buffer_size) ~mode:input
~close:(fun () -> Lazy.force close) fd,
of_fd ~buffer:(Lwt_bytes.create buffer_size) ~mode:output
~close:(fun () -> Lazy.force close) fd);
loop ()
| `Shutdown ->
Lwt_unix.close sock >>= fun () ->
(match sockaddr with
| Unix.ADDR_UNIX path when path <> "" && path.[0] <> '\x00' ->
Unix.unlink path;
Lwt.return_unit
| _ ->
Lwt.return_unit) [@ocaml.warning "-4"] >>= fun () ->
Lwt.wakeup closed_wakener ();
Lwt.pick [try_to_accept; should_stop] >>= function
| `Accepted (client_socket, client_address) ->
begin
try Lwt_unix.set_close_on_exec client_socket
with Invalid_argument _ -> ()
end;

connection_handler_callback client_address client_socket;

accept_loop ()

| `Should_stop ->
Lwt_unix.close listening_socket >>= fun () ->

begin match listening_address with
| Unix.ADDR_UNIX path when path <> "" && path.[0] <> '\x00' ->
Unix.unlink path
| _ ->
()
end [@ocaml.warning "-4"];

Lwt.wakeup_later notify_listening_socket_closed ();
Lwt.return_unit
in

let started, signal_started = Lwt.wait () in
Lwt.ignore_result begin
bind sock sockaddr >>= fun () ->
Lwt_unix.listen sock backlog;
Lwt.wakeup signal_started ();
loop ()
end;
let server =
{shutdown =
lazy begin
Lwt.wakeup_later notify_should_stop `Should_stop;
wait_until_listening_socket_closed
end}
in

let server = {shutdown = lazy (Lwt.wakeup abort_wakener (); closed_waiter)} in
(* Actually start the server. *)
let server_has_started =
bind_function listening_socket listening_address >>= fun () ->
Lwt_unix.listen listening_socket backlog;

server, started
Lwt.async accept_loop;

(* Old, deprecated version of [establish_server]. This function has to persist
for a while, in some form, until it is no longer exposed as
[Lwt_io.Versioned.establish_server_1]. *)
let establish_server_deprecated ?fd ?buffer_size ?backlog sockaddr f =
let blocking_bind fd addr =
Lwt.return (Lwt_unix.Versioned.bind_1 fd addr) [@ocaml.warning "-3"]
Lwt.return_unit
in
let f _addr c = f c in
establish_server_base blocking_bind ?fd ?buffer_size ?backlog sockaddr f
|> fst

let establish_server_with_client_address
?fd ?buffer_size ?backlog ?(no_close = false) sockaddr f =
server, server_has_started

let establish_server_with_client_socket
?server_fd ?backlog ?(no_close = false) sockaddr f =
let handler client_address client_socket =
Lwt.async begin fun () ->
(* Not using Lwt.finalize here, to make sure that exceptions from [f]
reach !Lwt.async_exception_hook before exceptions from closing the
channels. *)
Lwt.catch
(fun () -> f client_address client_socket)
(fun exn ->
!Lwt.async_exception_hook exn;
Lwt.return_unit)

>>= fun () ->
if no_close then Lwt.return_unit
else
if Lwt_unix.state client_socket = Lwt_unix.Closed then
Lwt.return_unit
else
Lwt.catch
(fun () -> close_socket client_socket)
(fun exn ->
!Lwt.async_exception_hook exn;
Lwt.return_unit)
end
in

let server, server_started =
establish_server_generic
Lwt_unix.bind ?fd:server_fd ?backlog sockaddr handler
in
server_started >>= fun () ->
Lwt.return server

let establish_server_with_client_address_generic
bind_function
?fd
?(buffer_size = !default_buffer_size)
?backlog
?(no_close = false)
sockaddr
handler =

let best_effort_close channel =
(* First, check whether the channel is closed. f may have already tried to
close the channel, received an exception, and handled it somehow. If so,
Expand All @@ -1637,20 +1697,37 @@ let establish_server_with_client_address
Lwt.catch
(fun () -> close channel)
(fun exn ->
!Lwt.async_exception_hook exn;
Lwt.return_unit)
!Lwt.async_exception_hook exn;
Lwt.return_unit)
in

let handler addr ((input_channel, output_channel) as channels) =
let handler client_address client_socket =
Lwt.async (fun () ->
let close = lazy (close_socket client_socket) in
let input_channel =
of_fd
~buffer:(Lwt_bytes.create buffer_size)
~mode:input
~close:(fun () -> Lazy.force close)
client_socket
in
let output_channel =
of_fd
~buffer:(Lwt_bytes.create buffer_size)
~mode:output
~close:(fun () -> Lazy.force close)
client_socket
in

(* Not using Lwt.finalize here, to make sure that exceptions from [f]
reach !Lwt.async_exception_hook before exceptions from closing the
channels. *)
Lwt.catch
(fun () -> f addr channels)
(fun () ->
handler client_address (input_channel, output_channel))
(fun exn ->
!Lwt.async_exception_hook exn;
Lwt.return_unit)
!Lwt.async_exception_hook exn;
Lwt.return_unit)

>>= fun () ->
if no_close then Lwt.return_unit
Expand All @@ -1659,18 +1736,44 @@ let establish_server_with_client_address
best_effort_close output_channel)
in

let server, started =
establish_server_base
Lwt_unix.bind ?fd ?buffer_size ?backlog sockaddr handler
establish_server_generic bind_function ?fd ?backlog sockaddr handler

let establish_server_with_client_address
?fd ?buffer_size ?backlog ?no_close sockaddr handler =
let server, server_started =
establish_server_with_client_address_generic
Lwt_unix.bind ?fd ?buffer_size ?backlog ?no_close sockaddr handler
in
started >>= fun () ->
server_started >>= fun () ->
Lwt.return server

let establish_server ?fd ?buffer_size ?backlog ?no_close sockaddr f =
let f _addr c = f c in
establish_server_with_client_address
?fd ?buffer_size ?backlog ?no_close sockaddr f

(* Old, deprecated version of [establish_server]. This function has to persist
for a while, in some form, until it is no longer exposed as
[Lwt_io.Versioned.establish_server_1]. *)
let establish_server_deprecated ?fd ?buffer_size ?backlog sockaddr f =
let blocking_bind fd addr =
Lwt.return (Lwt_unix.Versioned.bind_1 fd addr) [@ocaml.warning "-3"]
in
let f _addr c =
f c;
Lwt.return_unit
in

let server, server_started =
establish_server_with_client_address_generic
blocking_bind ?fd ?buffer_size ?backlog ~no_close:true sockaddr f
in

(* Poll for exceptions in server startup that may have occurred synchronously.
This emulates an old, deprecated behavior. *)
Lwt.ignore_result server_started;
server

let ignore_close ch =
ignore (close ch)

Expand Down
51 changes: 34 additions & 17 deletions src/unix/lwt_io.mli
Original file line number Diff line number Diff line change
Expand Up @@ -512,35 +512,33 @@ val with_close_connection :
type server
(** Type of a server *)

val establish_server_with_client_address :
?fd : Lwt_unix.file_descr ->
?buffer_size : int ->
?backlog : int ->
?no_close : bool ->
val establish_server_with_client_socket :
?server_fd:Lwt_unix.file_descr ->
?backlog:int ->
?no_close:bool ->
Unix.sockaddr ->
(Lwt_unix.sockaddr -> input_channel * output_channel -> unit Lwt.t) ->
(Lwt_unix.sockaddr -> Lwt_unix.file_descr -> unit Lwt.t) ->
server Lwt.t
(** [establish_server_with_client_address listen_address f] creates a server
(** [establish_server_with_client_socket listen_address f] creates a server
which listens for incoming connections on [listen_address]. When a client
makes a new connection, it is passed to [f]: more precisely, the server
calls
{[
f client_address (in_channel, out_channel)
f client_address client_socket
]}
where [client_address] is the address (peer name) of the new client, and
[in_channel] and [out_channel] are two channels wrapping the socket for
communicating with that client.
[client_socket] is the socket connected to the client.
The server does not block waiting for [f] to complete: it concurrently tries
to accept more client connections while [f] is handling the client.
When the promise returned by [f] completes (i.e., [f] is done handling the
client), [establish_server_with_client_address] automatically closes
[in_channel] and [out_channel]. This is a default behavior that is useful
for simple cases, but for a robust application you should explicitly close
these channels yourself, and handle any exceptions. If the channels are
client), [establish_server_with_client_socket] automatically closes
[client_socket]. This is a default behavior that is useful for simple cases,
but for a robust application you should explicitly close these channels
yourself, and handle any exceptions as appropriate. If the channels are
still open when [f] completes, and their automatic closing raises an
exception, [establish_server_with_client_address] treats it as an unhandled
exception reaching the top level of the application: it passes that
Expand All @@ -553,15 +551,34 @@ f client_address (in_channel, out_channel)
an exception), [establish_server_with_client_address] can do nothing with
that exception, except pass it to {!Lwt.async_exception_hook}.
[~fd] can be specified to use an existing file descriptor for listening.
Otherwise, a fresh socket is created internally.
[~server_fd] can be specified to use an existing file descriptor for
listening. Otherwise, a fresh socket is created internally. In either case,
[establish_server_with_client_socket] will internally assign
[listen_address] to the server socket.
[~backlog] is the argument passed to {!Lwt_unix.listen}.
The returned promise (a [server Lwt.t]) resolves when the server has just
started listening on [listen_address]: right after the internal call to
[listen], and right before the first internal call to [accept].
@since 4.1.0 *)

val establish_server_with_client_address :
?fd:Lwt_unix.file_descr ->
?buffer_size:int ->
?backlog:int ->
?no_close:bool ->
Unix.sockaddr ->
(Lwt_unix.sockaddr -> input_channel * output_channel -> unit Lwt.t) ->
server Lwt.t
(** Like {!Lwt_io.establish_server_with_client_socket}, but passes two buffered
channels to the connection handler [f]. These channels wrap the client
socket.
The channels are closed automatically when the promise returned by [f]
resolves. To avoid this behavior, pass [~no_close:true].
@since 3.1.0 *)

val shutdown_server : server -> unit Lwt.t
Expand Down Expand Up @@ -696,7 +713,7 @@ val establish_server :
[@@ocaml.deprecated
" Since Lwt 3.1.0, use Lwt_io.establish_server_with_client_address"]
(** Like [establish_server_with_client_address], but does not pass the client
address to the callback [f].
address or fd to the callback [f].
@deprecated Use {!establish_server_with_client_address}.
@since 3.0.0 *)
Expand Down