Skip to content

Commit

Permalink
Add Lwt_io.establish_server_with_client_socket
Browse files Browse the repository at this point in the history
  • Loading branch information
aantron committed May 4, 2018
1 parent 37d3da1 commit e658860
Show file tree
Hide file tree
Showing 2 changed files with 115 additions and 50 deletions.
114 changes: 81 additions & 33 deletions src/unix/lwt_io.ml
Original file line number Diff line number Diff line change
Expand Up @@ -1565,7 +1565,6 @@ let shutdown_server_deprecated server =
let establish_server_generic
bind_function
?fd:preexisting_socket_for_listening
?(buffer_size = !default_buffer_size)
?(backlog = 5)
listening_address
connection_handler_callback =
Expand Down Expand Up @@ -1604,24 +1603,7 @@ let establish_server_generic
with Invalid_argument _ -> ()
end;

let close = lazy (close_socket client_socket) in
let input_channel =
of_fd
~buffer:(Lwt_bytes.create buffer_size)
~mode:input
~close:(fun () -> Lazy.force close)
client_socket
in
let output_channel =
of_fd
~buffer:(Lwt_bytes.create buffer_size)
~mode:output
~close:(fun () -> Lazy.force close)
client_socket
in

connection_handler_callback
client_address (input_channel, output_channel);
connection_handler_callback client_address client_socket;

accept_loop ()

Expand Down Expand Up @@ -1659,8 +1641,49 @@ let establish_server_generic

server, server_has_started

let establish_server_with_client_address
?fd ?buffer_size ?backlog ?(no_close = false) sockaddr f =
let establish_server_with_client_socket
?server_fd ?backlog ?(no_close = false) sockaddr f =
let handler client_address client_socket =
Lwt.async begin fun () ->
(* Not using Lwt.finalize here, to make sure that exceptions from [f]
reach !Lwt.async_exception_hook before exceptions from closing the
channels. *)
Lwt.catch
(fun () -> f client_address client_socket)
(fun exn ->
!Lwt.async_exception_hook exn;
Lwt.return_unit)

>>= fun () ->
if no_close then Lwt.return_unit
else
if Lwt_unix.state client_socket = Lwt_unix.Closed then
Lwt.return_unit
else
Lwt.catch
(fun () -> close_socket client_socket)
(fun exn ->
!Lwt.async_exception_hook exn;
Lwt.return_unit)
end
in

let server, server_started =
establish_server_generic
Lwt_unix.bind ?fd:server_fd ?backlog sockaddr handler
in
server_started >>= fun () ->
Lwt.return server

let establish_server_with_client_address_generic
bind_function
?fd
?(buffer_size = !default_buffer_size)
?backlog
?(no_close = false)
sockaddr
handler =

let best_effort_close channel =
(* First, check whether the channel is closed. f may have already tried to
close the channel, received an exception, and handled it somehow. If so,
Expand All @@ -1674,20 +1697,37 @@ let establish_server_with_client_address
Lwt.catch
(fun () -> close channel)
(fun exn ->
!Lwt.async_exception_hook exn;
Lwt.return_unit)
!Lwt.async_exception_hook exn;
Lwt.return_unit)
in

let handler addr ((input_channel, output_channel) as channels) =
let handler client_address client_socket =
Lwt.async (fun () ->
let close = lazy (close_socket client_socket) in
let input_channel =
of_fd
~buffer:(Lwt_bytes.create buffer_size)
~mode:input
~close:(fun () -> Lazy.force close)
client_socket
in
let output_channel =
of_fd
~buffer:(Lwt_bytes.create buffer_size)
~mode:output
~close:(fun () -> Lazy.force close)
client_socket
in

(* Not using Lwt.finalize here, to make sure that exceptions from [f]
reach !Lwt.async_exception_hook before exceptions from closing the
channels. *)
Lwt.catch
(fun () -> f addr channels)
(fun () ->
handler client_address (input_channel, output_channel))
(fun exn ->
!Lwt.async_exception_hook exn;
Lwt.return_unit)
!Lwt.async_exception_hook exn;
Lwt.return_unit)

>>= fun () ->
if no_close then Lwt.return_unit
Expand All @@ -1696,11 +1736,15 @@ let establish_server_with_client_address
best_effort_close output_channel)
in

let server, started =
establish_server_generic
Lwt_unix.bind ?fd ?buffer_size ?backlog sockaddr handler
establish_server_generic bind_function ?fd ?backlog sockaddr handler

let establish_server_with_client_address
?fd ?buffer_size ?backlog ?no_close sockaddr handler =
let server, server_started =
establish_server_with_client_address_generic
Lwt_unix.bind ?fd ?buffer_size ?backlog ?no_close sockaddr handler
in
started >>= fun () ->
server_started >>= fun () ->
Lwt.return server

let establish_server ?fd ?buffer_size ?backlog ?no_close sockaddr f =
Expand All @@ -1715,10 +1759,14 @@ let establish_server_deprecated ?fd ?buffer_size ?backlog sockaddr f =
let blocking_bind fd addr =
Lwt.return (Lwt_unix.Versioned.bind_1 fd addr) [@ocaml.warning "-3"]
in
let f _addr c = f c in
let f _addr c =
f c;
Lwt.return_unit
in

let server, server_started =
establish_server_generic blocking_bind ?fd ?buffer_size ?backlog sockaddr f
establish_server_with_client_address_generic
blocking_bind ?fd ?buffer_size ?backlog ~no_close:true sockaddr f
in

(* Poll for exceptions in server startup that may have occurred synchronously.
Expand Down
51 changes: 34 additions & 17 deletions src/unix/lwt_io.mli
Original file line number Diff line number Diff line change
Expand Up @@ -512,35 +512,33 @@ val with_close_connection :
type server
(** Type of a server *)

val establish_server_with_client_address :
?fd : Lwt_unix.file_descr ->
?buffer_size : int ->
?backlog : int ->
?no_close : bool ->
val establish_server_with_client_socket :
?server_fd:Lwt_unix.file_descr ->
?backlog:int ->
?no_close:bool ->
Unix.sockaddr ->
(Lwt_unix.sockaddr -> input_channel * output_channel -> unit Lwt.t) ->
(Lwt_unix.sockaddr -> Lwt_unix.file_descr -> unit Lwt.t) ->
server Lwt.t
(** [establish_server_with_client_address listen_address f] creates a server
(** [establish_server_with_client_socket listen_address f] creates a server
which listens for incoming connections on [listen_address]. When a client
makes a new connection, it is passed to [f]: more precisely, the server
calls
{[
f client_address (in_channel, out_channel)
f client_address client_socket
]}
where [client_address] is the address (peer name) of the new client, and
[in_channel] and [out_channel] are two channels wrapping the socket for
communicating with that client.
[client_socket] is the socket connected to the client.
The server does not block waiting for [f] to complete: it concurrently tries
to accept more client connections while [f] is handling the client.
When the promise returned by [f] completes (i.e., [f] is done handling the
client), [establish_server_with_client_address] automatically closes
[in_channel] and [out_channel]. This is a default behavior that is useful
for simple cases, but for a robust application you should explicitly close
these channels yourself, and handle any exceptions. If the channels are
client), [establish_server_with_client_socket] automatically closes
[client_socket]. This is a default behavior that is useful for simple cases,
but for a robust application you should explicitly close these channels
yourself, and handle any exceptions as appropriate. If the channels are
still open when [f] completes, and their automatic closing raises an
exception, [establish_server_with_client_address] treats it as an unhandled
exception reaching the top level of the application: it passes that
Expand All @@ -553,15 +551,34 @@ f client_address (in_channel, out_channel)
an exception), [establish_server_with_client_address] can do nothing with
that exception, except pass it to {!Lwt.async_exception_hook}.
[~fd] can be specified to use an existing file descriptor for listening.
Otherwise, a fresh socket is created internally.
[~server_fd] can be specified to use an existing file descriptor for
listening. Otherwise, a fresh socket is created internally. In either case,
[establish_server_with_client_socket] will internally assign
[listen_address] to the server socket.
[~backlog] is the argument passed to {!Lwt_unix.listen}.
The returned promise (a [server Lwt.t]) resolves when the server has just
started listening on [listen_address]: right after the internal call to
[listen], and right before the first internal call to [accept].
@since 4.1.0 *)

val establish_server_with_client_address :
?fd:Lwt_unix.file_descr ->
?buffer_size:int ->
?backlog:int ->
?no_close:bool ->
Unix.sockaddr ->
(Lwt_unix.sockaddr -> input_channel * output_channel -> unit Lwt.t) ->
server Lwt.t
(** Like {!Lwt_io.establish_server_with_client_socket}, but passes two buffered
channels to the connection handler [f]. These channels wrap the client
socket.
The channels are closed automatically when the promise returned by [f]
resolves. To avoid this behavior, pass [~no_close:true].
@since 3.1.0 *)

val shutdown_server : server -> unit Lwt.t
Expand Down Expand Up @@ -696,7 +713,7 @@ val establish_server :
[@@ocaml.deprecated
" Since Lwt 3.1.0, use Lwt_io.establish_server_with_client_address"]
(** Like [establish_server_with_client_address], but does not pass the client
address to the callback [f].
address or fd to the callback [f].
@deprecated Use {!establish_server_with_client_address}.
@since 3.0.0 *)
Expand Down

0 comments on commit e658860

Please sign in to comment.