From f8aea8cc51116ff27fe29fd170d4f23148e9f709 Mon Sep 17 00:00:00 2001 From: Barry Xu Date: Fri, 4 Oct 2024 18:12:00 +0800 Subject: [PATCH] Implement callback support of async_send_request for service generic client (#2614) Signed-off-by: Barry Xu --- rclcpp/include/rclcpp/generic_client.hpp | 106 ++++++++++++++++++++- rclcpp/src/rclcpp/generic_client.cpp | 19 ++++ rclcpp/test/rclcpp/test_generic_client.cpp | 69 ++++++++++++++ 3 files changed, 189 insertions(+), 5 deletions(-) diff --git a/rclcpp/include/rclcpp/generic_client.hpp b/rclcpp/include/rclcpp/generic_client.hpp index d6073decfc..cff0eb25a1 100644 --- a/rclcpp/include/rclcpp/generic_client.hpp +++ b/rclcpp/include/rclcpp/generic_client.hpp @@ -19,6 +19,7 @@ #include #include #include +#include #include #include @@ -46,6 +47,8 @@ class GenericClient : public ClientBase using Future = std::future; using SharedFuture = std::shared_future; + using CallbackType = std::function; + RCLCPP_SMART_PTR_DEFINITIONS(GenericClient) /// A convenient GenericClient::Future and request id pair. @@ -76,6 +79,20 @@ class GenericClient : public ClientBase ~FutureAndRequestId() = default; }; + /// A convenient GenericClient::SharedFuture and request id pair. + /** + * Public members: + * - future: a std::shared_future. + * - request_id: the request id associated with the future. + * + * All the other methods are equivalent to the ones std::shared_future provides. + */ + struct SharedFutureAndRequestId + : detail::FutureAndRequestId> + { + using detail::FutureAndRequestId>::FutureAndRequestId; + }; + GenericClient( rclcpp::node_interfaces::NodeBaseInterface * node_base, rclcpp::node_interfaces::NodeGraphInterface::SharedPtr node_graph, @@ -106,16 +123,16 @@ class GenericClient : public ClientBase * If the future never completes, * e.g. the call to Executor::spin_until_future_complete() times out, * GenericClient::remove_pending_request() must be called to clean the client internal state. - * Not doing so will make the `Client` instance to use more memory each time a response is not - * received from the service server. + * Not doing so will make the `GenericClient` instance to use more memory each time a response is + * not received from the service server. * * ```cpp - * auto future = client->async_send_request(my_request); + * auto future = generic_client->async_send_request(my_request); * if ( * rclcpp::FutureReturnCode::TIMEOUT == * executor->spin_until_future_complete(future, timeout)) * { - * client->remove_pending_request(future); + * generic_client->remove_pending_request(future); * // handle timeout * } else { * handle_response(future.get()); @@ -129,6 +146,45 @@ class GenericClient : public ClientBase FutureAndRequestId async_send_request(const Request request); + /// Send a request to the service server and schedule a callback in the executor. + /** + * Similar to the previous overload, but a callback will automatically be called when a response + * is received. + * + * If the callback is never called, because we never got a reply for the service server, + * remove_pending_request() has to be called with the returned request id or + * prune_pending_requests(). + * Not doing so will make the `GenericClient` instance use more memory each time a response is not + * received from the service server. + * In this case, it's convenient to setup a timer to cleanup the pending requests. + * + * \param[in] request request to be send. + * \param[in] cb callback that will be called when we get a response for this request. + * \return the request id representing the request just sent. + */ + template< + typename CallbackT, + typename std::enable_if< + rclcpp::function_traits::same_arguments< + CallbackT, + CallbackType + >::value + >::type * = nullptr + > + SharedFutureAndRequestId + async_send_request(const Request request, CallbackT && cb) + { + Promise promise; + auto shared_future = promise.get_future().share(); + auto req_id = async_send_request_impl( + request, + std::make_tuple( + CallbackType{std::forward(cb)}, + shared_future, + std::move(promise))); + return SharedFutureAndRequestId{std::move(shared_future), req_id}; + } + /// Clean all pending requests older than a time_point. /** * \param[in] time_point Requests that were sent before this point are going to be removed. @@ -149,15 +205,52 @@ class GenericClient : public ClientBase pruned_requests); } + /// Clean all pending requests. + /** + * \return number of pending requests that were removed. + */ RCLCPP_PUBLIC size_t prune_pending_requests(); + /// Cleanup a pending request. + /** + * This notifies the client that we have waited long enough for a response from the server + * to come, we have given up and we are not waiting for a response anymore. + * + * Not calling this will make the client start using more memory for each request + * that never got a reply from the server. + * + * \param[in] request_id request id returned by async_send_request(). + * \return true when a pending request was removed, false if not (e.g. a response was received). + */ RCLCPP_PUBLIC bool remove_pending_request( int64_t request_id); + /// Cleanup a pending request. + /** + * Convenient overload, same as: + * + * `GenericClient::remove_pending_request(this, future.request_id)`. + */ + RCLCPP_PUBLIC + bool + remove_pending_request( + const FutureAndRequestId & future); + + /// Cleanup a pending request. + /** + * Convenient overload, same as: + * + * `GenericClient::remove_pending_request(this, future.request_id)`. + */ + RCLCPP_PUBLIC + bool + remove_pending_request( + const SharedFutureAndRequestId & future); + /// Take the next response for this client. /** * \sa ClientBase::take_type_erased_response(). @@ -179,9 +272,12 @@ class GenericClient : public ClientBase } protected: + using CallbackTypeValueVariant = std::tuple; using CallbackInfoVariant = std::variant< - std::promise>; // Use variant for extension + std::promise, + CallbackTypeValueVariant>; // Use variant for extension + RCLCPP_PUBLIC int64_t async_send_request_impl( const Request request, diff --git a/rclcpp/src/rclcpp/generic_client.cpp b/rclcpp/src/rclcpp/generic_client.cpp index 987975d803..0ac9a86e15 100644 --- a/rclcpp/src/rclcpp/generic_client.cpp +++ b/rclcpp/src/rclcpp/generic_client.cpp @@ -109,6 +109,13 @@ GenericClient::handle_response( if (std::holds_alternative(value)) { auto & promise = std::get(value); promise.set_value(std::move(response)); + } else if (std::holds_alternative(value)) { + auto & inner = std::get(value); + const auto & callback = std::get(inner); + auto & promise = std::get(inner); + auto & future = std::get(inner); + promise.set_value(std::move(response)); + callback(std::move(future)); } } @@ -128,6 +135,18 @@ GenericClient::remove_pending_request(int64_t request_id) return pending_requests_.erase(request_id) != 0u; } +bool +GenericClient::remove_pending_request(const FutureAndRequestId & future) +{ + return this->remove_pending_request(future.request_id); +} + +bool +GenericClient::remove_pending_request(const SharedFutureAndRequestId & future) +{ + return this->remove_pending_request(future.request_id); +} + std::optional GenericClient::get_and_erase_pending_request(int64_t request_number) { diff --git a/rclcpp/test/rclcpp/test_generic_client.cpp b/rclcpp/test/rclcpp/test_generic_client.cpp index 496b21ab63..433348220b 100644 --- a/rclcpp/test/rclcpp/test_generic_client.cpp +++ b/rclcpp/test/rclcpp/test_generic_client.cpp @@ -14,6 +14,8 @@ #include +#include +#include #include #include #include @@ -28,6 +30,7 @@ #include "../mocking_utils/patch.hpp" #include "test_msgs/srv/empty.hpp" +#include "test_msgs/srv/basic_types.hpp" using namespace std::chrono_literals; @@ -228,3 +231,69 @@ TEST_F(TestGenericClientSub, construction_and_destruction) { }, rclcpp::exceptions::InvalidServiceNameError); } } + +TEST_F(TestGenericClientSub, async_send_request_with_request) { + const std::string service_name = "test_service"; + int64_t expected_change = 1111; + + auto client = node->create_generic_client(service_name, "test_msgs/srv/BasicTypes"); + + auto callback = [&expected_change]( + const test_msgs::srv::BasicTypes::Request::SharedPtr request, + test_msgs::srv::BasicTypes::Response::SharedPtr response) { + response->int64_value = request->int64_value + expected_change; + }; + + auto service = + node->create_service(service_name, std::move(callback)); + + ASSERT_TRUE(client->wait_for_service(std::chrono::seconds(5))); + ASSERT_TRUE(client->service_is_ready()); + + test_msgs::srv::BasicTypes::Request request; + request.int64_value = 12345678; + + auto future = client->async_send_request(static_cast(&request)); + rclcpp::spin_until_future_complete( + node->get_node_base_interface(), future, std::chrono::seconds(5)); + ASSERT_TRUE(future.valid()); + auto get_untyped_response = future.get(); + auto typed_response = + static_cast(get_untyped_response.get()); + EXPECT_EQ(typed_response->int64_value, (request.int64_value + expected_change)); +} + +TEST_F(TestGenericClientSub, async_send_request_with_request_and_callback) { + const std::string service_name = "test_service"; + int64_t expected_change = 2222; + + auto client = node->create_generic_client(service_name, "test_msgs/srv/BasicTypes"); + + auto server_callback = [&expected_change]( + const test_msgs::srv::BasicTypes::Request::SharedPtr request, + test_msgs::srv::BasicTypes::Response::SharedPtr response) { + response->int64_value = request->int64_value + expected_change; + }; + + auto service = + node->create_service(service_name, std::move(server_callback)); + + ASSERT_TRUE(client->wait_for_service(std::chrono::seconds(5))); + ASSERT_TRUE(client->service_is_ready()); + + test_msgs::srv::BasicTypes::Request request; + request.int64_value = 12345678; + + auto client_callback = [&request, &expected_change]( + rclcpp::GenericClient::SharedFuture future) { + auto untyped_response = future.get(); + auto typed_response = + static_cast(untyped_response.get()); + EXPECT_EQ(typed_response->int64_value, (request.int64_value + expected_change)); + }; + + auto future = + client->async_send_request(static_cast(&request), client_callback); + rclcpp::spin_until_future_complete( + node->get_node_base_interface(), future, std::chrono::seconds(5)); +}