From 56c1b6eae798901018fbdf0674c5d1fe1f2ad630 Mon Sep 17 00:00:00 2001 From: Anton Bachin Date: Sun, 29 Apr 2018 01:02:50 -0500 Subject: [PATCH] Add Lwt_io.establish_server_with_client_socket --- src/unix/lwt_io.ml | 20 +++++++++++++------- src/unix/lwt_io.mli | 39 ++++++++++++++++++++++++++++----------- 2 files changed, 41 insertions(+), 18 deletions(-) diff --git a/src/unix/lwt_io.ml b/src/unix/lwt_io.ml index 2ecba5774a..b8a5cf508c 100644 --- a/src/unix/lwt_io.ml +++ b/src/unix/lwt_io.ml @@ -1621,7 +1621,7 @@ let establish_server_generic in connection_handler_callback - client_address (input_channel, output_channel); + client_address client_socket (input_channel, output_channel); accept_loop () @@ -1659,8 +1659,8 @@ let establish_server_generic server, server_has_started -let establish_server_with_client_address - ?fd ?buffer_size ?backlog ?(no_close = false) sockaddr f = +let establish_server_with_client_socket + ?server_fd ?buffer_size ?backlog ?(no_close = false) sockaddr f = let best_effort_close channel = (* First, check whether the channel is closed. f may have already tried to close the channel, received an exception, and handled it somehow. If so, @@ -1678,13 +1678,13 @@ let establish_server_with_client_address Lwt.return_unit) in - let handler addr ((input_channel, output_channel) as channels) = + let handler addr socket ((input_channel, output_channel) as channels) = Lwt.async (fun () -> (* Not using Lwt.finalize here, to make sure that exceptions from [f] reach !Lwt.async_exception_hook before exceptions from closing the channels. *) Lwt.catch - (fun () -> f addr channels) + (fun () -> f addr socket channels) (fun exn -> !Lwt.async_exception_hook exn; Lwt.return_unit) @@ -1698,11 +1698,17 @@ let establish_server_with_client_address let server, started = establish_server_generic - Lwt_unix.bind ?fd ?buffer_size ?backlog sockaddr handler + Lwt_unix.bind ?fd:server_fd ?buffer_size ?backlog sockaddr handler in started >>= fun () -> Lwt.return server +let establish_server_with_client_address + ?fd ?buffer_size ?backlog ?no_close sockaddr handler = + let handler addr _socket c = handler addr c in + establish_server_with_client_socket + ?server_fd:fd ?buffer_size ?backlog ?no_close sockaddr handler + let establish_server ?fd ?buffer_size ?backlog ?no_close sockaddr f = let f _addr c = f c in establish_server_with_client_address @@ -1715,7 +1721,7 @@ let establish_server_deprecated ?fd ?buffer_size ?backlog sockaddr f = let blocking_bind fd addr = Lwt.return (Lwt_unix.Versioned.bind_1 fd addr) [@ocaml.warning "-3"] in - let f _addr c = f c in + let f _addr _socket c = f c in let server, server_started = establish_server_generic blocking_bind ?fd ?buffer_size ?backlog sockaddr f diff --git a/src/unix/lwt_io.mli b/src/unix/lwt_io.mli index 233d642991..657b045c2f 100644 --- a/src/unix/lwt_io.mli +++ b/src/unix/lwt_io.mli @@ -512,32 +512,36 @@ val with_close_connection : type server (** Type of a server *) -val establish_server_with_client_address : - ?fd : Lwt_unix.file_descr -> +val establish_server_with_client_socket : + ?server_fd : Lwt_unix.file_descr -> ?buffer_size : int -> ?backlog : int -> ?no_close : bool -> Unix.sockaddr -> - (Lwt_unix.sockaddr -> input_channel * output_channel -> unit Lwt.t) -> + (Lwt_unix.sockaddr -> + Lwt_unix.file_descr -> + input_channel * output_channel -> + unit Lwt.t) -> server Lwt.t -(** [establish_server_with_client_address listen_address f] creates a server +(** [establish_server_with_client_socket listen_address f] creates a server which listens for incoming connections on [listen_address]. When a client makes a new connection, it is passed to [f]: more precisely, the server calls {[ -f client_address (in_channel, out_channel) +f client_address client_socket (in_channel, out_channel) ]} - where [client_address] is the address (peer name) of the new client, and - [in_channel] and [out_channel] are two channels wrapping the socket for + where [client_address] is the address (peer name) of the new client, + [client_socket] is the socket connected to the client, and [in_channel] and + [out_channel] are two buffered channels wrapping the socket for communicating with that client. The server does not block waiting for [f] to complete: it concurrently tries to accept more client connections while [f] is handling the client. When the promise returned by [f] completes (i.e., [f] is done handling the - client), [establish_server_with_client_address] automatically closes + client), [establish_server_with_client_socket] automatically closes [in_channel] and [out_channel]. This is a default behavior that is useful for simple cases, but for a robust application you should explicitly close these channels yourself, and handle any exceptions. If the channels are @@ -553,8 +557,10 @@ f client_address (in_channel, out_channel) an exception), [establish_server_with_client_address] can do nothing with that exception, except pass it to {!Lwt.async_exception_hook}. - [~fd] can be specified to use an existing file descriptor for listening. - Otherwise, a fresh socket is created internally. + [~server_fd] can be specified to use an existing file descriptor for + listening. Otherwise, a fresh socket is created internally. In either case, + [establish_server_with_client_socket] will internally assign + [listen_address] to the server socket. [~backlog] is the argument passed to {!Lwt_unix.listen}. @@ -686,6 +692,17 @@ val set_default_buffer_size : int -> unit (** {2 Deprecated} *) +val establish_server_with_client_address : + ?fd : Lwt_unix.file_descr -> + ?buffer_size : int -> + ?backlog : int -> + ?no_close : bool -> + Unix.sockaddr -> + (Lwt_unix.sockaddr -> input_channel * output_channel -> unit Lwt.t) -> + server Lwt.t +(** Like [establish_server_with_client_fd], but does not pass the client fd to + the callback [f]. *) + val establish_server : ?fd : Lwt_unix.file_descr -> ?buffer_size : int -> @@ -696,7 +713,7 @@ val establish_server : [@@ocaml.deprecated " Since Lwt 3.1.0, use Lwt_io.establish_server_with_client_address"] (** Like [establish_server_with_client_address], but does not pass the client - address to the callback [f]. + address or fd to the callback [f]. @deprecated Use {!establish_server_with_client_address}. @since 3.0.0 *)