diff --git a/BUILD.bazel b/BUILD.bazel index 03c002fc1256..814254f69eb0 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -116,12 +116,14 @@ ray_cc_library( "src/ray/rpc/grpc_server.cc", "src/ray/rpc/server_call.cc", "src/ray/rpc/rpc_chaos.cc", + "src/ray/rpc/retryable_grpc_client.cc", ], hdrs = glob([ "src/ray/rpc/rpc_chaos.h", "src/ray/rpc/client_call.h", "src/ray/rpc/common.h", "src/ray/rpc/grpc_client.h", + "src/ray/rpc/retryable_grpc_client.h", "src/ray/rpc/grpc_server.h", "src/ray/rpc/metrics_agent_client.h", "src/ray/rpc/server_call.h", diff --git a/python/ray/tests/test_streaming_generator_4.py b/python/ray/tests/test_streaming_generator_4.py index 0a961cfd6d6c..833ef6c54f74 100644 --- a/python/ray/tests/test_streaming_generator_4.py +++ b/python/ray/tests/test_streaming_generator_4.py @@ -54,6 +54,10 @@ def test_ray_datasetlike_mini_stress_test( "RAY_testing_asio_delay_us", "CoreWorkerService.grpc_server.ReportGeneratorItemReturns=10000:1000000", ) + m.setenv( + "RAY_testing_rpc_failure", + "CoreWorkerService.grpc_client.ReportGeneratorItemReturns=5", + ) cluster = ray_start_cluster cluster.add_node( num_cpus=1, diff --git a/src/mock/ray/raylet_client/raylet_client.h b/src/mock/ray/raylet_client/raylet_client.h index de30333bdf78..9fb1c5ea86ee 100644 --- a/src/mock/ray/raylet_client/raylet_client.h +++ b/src/mock/ray/raylet_client/raylet_client.h @@ -128,6 +128,11 @@ class MockRayletClientInterface : public RayletClientInterface { int64_t draining_deadline_timestamp_ms, const rpc::ClientCallback &callback), (override)); + MOCK_METHOD(void, + IsLocalWorkerDead, + (const WorkerID &worker_id, + const rpc::ClientCallback &callback), + (override)); }; } // namespace ray diff --git a/src/ray/common/ray_config_def.h b/src/ray/common/ray_config_def.h index b1a73f79cb1a..718e21b285cd 100644 --- a/src/ray/common/ray_config_def.h +++ b/src/ray/common/ray_config_def.h @@ -447,7 +447,7 @@ RAY_CONFIG(int32_t, gcs_grpc_initial_reconnect_backoff_ms, 100) RAY_CONFIG(uint64_t, gcs_grpc_max_request_queued_max_bytes, 1024UL * 1024 * 1024 * 5) /// The duration between two checks for grpc status. -RAY_CONFIG(int32_t, gcs_client_check_connection_status_interval_milliseconds, 1000) +RAY_CONFIG(int32_t, grpc_client_check_connection_status_interval_milliseconds, 1000) /// Due to the protocol drawback, raylet needs to refresh the message if /// no message is received for a while. @@ -693,6 +693,9 @@ RAY_CONFIG(int64_t, timeout_ms_task_wait_for_death_info, 1000) /// report the loads to raylet. RAY_CONFIG(int64_t, core_worker_internal_heartbeat_ms, 1000) +/// Timeout for core worker grpc server reconnection in seconds. +RAY_CONFIG(int32_t, core_worker_rpc_server_reconnect_timeout_s, 60) + /// Maximum amount of memory that will be used by running tasks' args. 
RAY_CONFIG(float, max_task_args_memory_fraction, 0.7) diff --git a/src/ray/core_worker/core_worker.cc b/src/ray/core_worker/core_worker.cc index 9a15e8702892..8ddfceba138d 100644 --- a/src/ray/core_worker/core_worker.cc +++ b/src/ray/core_worker/core_worker.cc @@ -294,7 +294,45 @@ CoreWorker::CoreWorker(const CoreWorkerOptions &options, const WorkerID &worker_ } core_worker_client_pool_ = - std::make_shared(*client_call_manager_); + std::make_shared([&](const rpc::Address &addr) { + return std::make_shared( + addr, + *client_call_manager_, + /*core_worker_unavailable_timeout_callback=*/[this, addr]() { + const NodeID node_id = NodeID::FromBinary(addr.raylet_id()); + const WorkerID worker_id = WorkerID::FromBinary(addr.worker_id()); + const rpc::GcsNodeInfo *node_info = + gcs_client_->Nodes().Get(node_id, /*filter_dead_nodes=*/true); + if (node_info == nullptr) { + RAY_LOG(INFO).WithField(worker_id).WithField(node_id) + << "Disconnect core worker client since its node is dead"; + io_service_.post( + [this, worker_id]() { + core_worker_client_pool_->Disconnect(worker_id); + }, + "CoreWorkerClientPool.Disconnect"); + + return; + } + + std::shared_ptr raylet_client = + std::make_shared( + rpc::NodeManagerWorkerClient::make( + node_info->node_manager_address(), + node_info->node_manager_port(), + *client_call_manager_)); + raylet_client->IsLocalWorkerDead( + worker_id, + [this, worker_id](const Status &status, + rpc::IsLocalWorkerDeadReply &&reply) { + if (status.ok() && reply.is_dead()) { + RAY_LOG(INFO).WithField(worker_id) + << "Disconnect core worker client since it is dead"; + core_worker_client_pool_->Disconnect(worker_id); + } + }); + }); + }); object_info_publisher_ = std::make_unique( /*channels=*/std::vector< @@ -743,14 +781,6 @@ void CoreWorker::Shutdown() { task_event_buffer_->Stop(); - if (gcs_client_) { - // We should disconnect gcs client first otherwise because it contains - // a blocking logic that can block the io service upon - // gcs shutdown. - // TODO(sang): Refactor GCS client to be more robust. - RAY_LOG(INFO) << "Disconnecting a GCS client."; - gcs_client_->Disconnect(); - } io_service_.stop(); RAY_LOG(INFO) << "Waiting for joining a core worker io thread. If it hangs here, there " "might be deadlock or a high load in the core worker io service."; @@ -763,7 +793,11 @@ void CoreWorker::Shutdown() { // Now that gcs_client is not used within io service, we can reset the pointer and clean // it up. - gcs_client_.reset(); + if (gcs_client_) { + RAY_LOG(INFO) << "Disconnecting a GCS client."; + gcs_client_->Disconnect(); + gcs_client_.reset(); + } RAY_LOG(INFO) << "Core worker ready to be deallocated."; } @@ -3311,13 +3345,13 @@ Status CoreWorker::ReportGeneratorItemReturns( if (status.ok()) { num_objects_consumed = reply.total_num_object_consumed(); } else { - // TODO(sang): Handle network error more gracefully. // If the request fails, we should just resume until task finishes without // backpressure. num_objects_consumed = waiter->TotalObjectGenerated(); RAY_LOG(WARNING).WithField(return_id) << "Failed to report streaming generator return " - "to the caller. The yield'ed ObjectRef may not be usable."; + "to the caller. The yield'ed ObjectRef may not be usable. 
" + << status; } waiter->HandleObjectReported(num_objects_consumed); }); diff --git a/src/ray/gcs/gcs_client/gcs_client.cc b/src/ray/gcs/gcs_client/gcs_client.cc index e46007f819ba..1305d5794553 100644 --- a/src/ray/gcs/gcs_client/gcs_client.cc +++ b/src/ray/gcs/gcs_client/gcs_client.cc @@ -176,7 +176,6 @@ Status GcsClient::FetchClusterId(int64_t timeout_ms) { Status s = gcs_rpc_client_->SyncGetClusterId(request, &reply, timeout_ms); if (!s.ok()) { RAY_LOG(WARNING) << "Failed to get cluster ID from GCS server: " << s; - gcs_rpc_client_->Shutdown(); gcs_rpc_client_.reset(); client_call_manager_.reset(); return s; @@ -189,7 +188,7 @@ Status GcsClient::FetchClusterId(int64_t timeout_ms) { void GcsClient::Disconnect() { if (gcs_rpc_client_) { - gcs_rpc_client_->Shutdown(); + gcs_rpc_client_.reset(); } } diff --git a/src/ray/gcs/gcs_server/gcs_server.cc b/src/ray/gcs/gcs_server/gcs_server.cc index c51c14bbb0dd..2b8b8c98bb23 100644 --- a/src/ray/gcs/gcs_server/gcs_server.cc +++ b/src/ray/gcs/gcs_server/gcs_server.cc @@ -411,7 +411,9 @@ void GcsServer::InitClusterTaskManager() { void GcsServer::InitGcsJobManager(const GcsInitData &gcs_init_data) { auto client_factory = [this](const rpc::Address &address) { - return std::make_shared(address, client_call_manager_); + return std::make_shared(address, client_call_manager_, []() { + RAY_LOG(FATAL) << "GCS doesn't call any retryable core worker grpc methods."; + }); }; RAY_CHECK(gcs_table_storage_ && gcs_publisher_); gcs_job_manager_ = std::make_unique(gcs_table_storage_, @@ -447,7 +449,9 @@ void GcsServer::InitGcsActorManager(const GcsInitData &gcs_init_data) { gcs_actor_manager_->OnActorCreationSuccess(std::move(actor), reply); }; auto client_factory = [this](const rpc::Address &address) { - return std::make_shared(address, client_call_manager_); + return std::make_shared(address, client_call_manager_, []() { + RAY_LOG(FATAL) << "GCS doesn't call any retryable core worker grpc methods."; + }); }; RAY_CHECK(gcs_resource_manager_ && cluster_task_manager_); @@ -464,18 +468,24 @@ void GcsServer::InitGcsActorManager(const GcsInitData &gcs_init_data) { [this](const NodeID &node_id, const rpc::ResourcesData &resources) { gcs_resource_manager_->UpdateNodeNormalTaskResources(node_id, resources); }); - gcs_actor_manager_ = std::make_shared( - std::move(scheduler), - gcs_table_storage_, - gcs_publisher_, - *runtime_env_manager_, - *function_manager_, - [this](const ActorID &actor_id) { - gcs_placement_group_manager_->CleanPlacementGroupIfNeededWhenActorDead(actor_id); - }, - [this](const rpc::Address &address) { - return std::make_shared(address, client_call_manager_); - }); + gcs_actor_manager_ = + std::make_shared( + std::move(scheduler), + gcs_table_storage_, + gcs_publisher_, + *runtime_env_manager_, + *function_manager_, + [this](const ActorID &actor_id) { + gcs_placement_group_manager_->CleanPlacementGroupIfNeededWhenActorDead( + actor_id); + }, + [this](const rpc::Address &address) { + return std::make_shared( + address, client_call_manager_, []() { + RAY_LOG(FATAL) + << "GCS doesn't call any retryable core worker grpc methods."; + }); + }); // Initialize by gcs tables data. 
gcs_actor_manager_->Initialize(gcs_init_data); diff --git a/src/ray/gcs/gcs_server/test/gcs_server_test_util.h b/src/ray/gcs/gcs_server/test/gcs_server_test_util.h index ce8c685f706d..727cfd28ee5b 100644 --- a/src/ray/gcs/gcs_server/test/gcs_server_test_util.h +++ b/src/ray/gcs/gcs_server/test/gcs_server_test_util.h @@ -321,6 +321,10 @@ struct GcsServerMocker { drain_raylet_callbacks.push_back(callback); }; + void IsLocalWorkerDead( + const WorkerID &worker_id, + const rpc::ClientCallback &callback) override{}; + void NotifyGCSRestart( const rpc::ClientCallback &callback) override{}; diff --git a/src/ray/protobuf/node_manager.proto b/src/ray/protobuf/node_manager.proto index 3b30f7e71b1a..5955df852a79 100644 --- a/src/ray/protobuf/node_manager.proto +++ b/src/ray/protobuf/node_manager.proto @@ -377,6 +377,16 @@ message PushMutableObjectReply { bool done = 1; } +message IsLocalWorkerDeadRequest { + // Binary worker id of the target worker. + bytes worker_id = 1; +} + +message IsLocalWorkerDeadReply { + // Whether the target worker is dead or not. + bool is_dead = 1; +} + // Service for inter-node-manager communication. service NodeManagerService { // Handle the case when GCS restarted. @@ -440,4 +450,5 @@ service NodeManagerService { rpc RegisterMutableObject(RegisterMutableObjectRequest) returns (RegisterMutableObjectReply); rpc PushMutableObject(PushMutableObjectRequest) returns (PushMutableObjectReply); + rpc IsLocalWorkerDead(IsLocalWorkerDeadRequest) returns (IsLocalWorkerDeadReply); } diff --git a/src/ray/raylet/node_manager.cc b/src/ray/raylet/node_manager.cc index 879edff0bb6c..3f345ed51daa 100644 --- a/src/ray/raylet/node_manager.cc +++ b/src/ray/raylet/node_manager.cc @@ -145,7 +145,11 @@ NodeManager::NodeManager( config.ray_debugger_external, /*get_time=*/[]() { return absl::GetCurrentTimeNanos() / 1e6; }), client_call_manager_(io_service), - worker_rpc_pool_(client_call_manager_), + worker_rpc_pool_([&](const rpc::Address &addr) { + return std::make_shared(addr, client_call_manager_, []() { + RAY_LOG(FATAL) << "Raylet doesn't call any retryable core worker grpc methods."; + }); + }), core_worker_subscriber_(std::make_unique( self_node_id_, /*channels=*/ @@ -1988,6 +1992,14 @@ void NodeManager::HandleReturnWorker(rpc::ReturnWorkerRequest request, send_reply_callback(status, nullptr, nullptr); } +void NodeManager::HandleIsLocalWorkerDead(rpc::IsLocalWorkerDeadRequest request, + rpc::IsLocalWorkerDeadReply *reply, + rpc::SendReplyCallback send_reply_callback) { + reply->set_is_dead(worker_pool_.GetRegisteredWorker( + WorkerID::FromBinary(request.worker_id())) == nullptr); + send_reply_callback(Status::OK(), nullptr, nullptr); +} + void NodeManager::HandleDrainRaylet(rpc::DrainRayletRequest request, rpc::DrainRayletReply *reply, rpc::SendReplyCallback send_reply_callback) { diff --git a/src/ray/raylet/node_manager.h b/src/ray/raylet/node_manager.h index cef5e66aa26f..6d128d322dae 100644 --- a/src/ray/raylet/node_manager.h +++ b/src/ray/raylet/node_manager.h @@ -552,6 +552,10 @@ class NodeManager : public rpc::NodeManagerServiceHandler, rpc::DrainRayletReply *reply, rpc::SendReplyCallback send_reply_callback) override; + void HandleIsLocalWorkerDead(rpc::IsLocalWorkerDeadRequest request, + rpc::IsLocalWorkerDeadReply *reply, + rpc::SendReplyCallback send_reply_callback) override; + /// Handle a `CancelWorkerLease` request. 
void HandleCancelWorkerLease(rpc::CancelWorkerLeaseRequest request, rpc::CancelWorkerLeaseReply *reply, diff --git a/src/ray/raylet/worker.cc b/src/ray/raylet/worker.cc index 82c7476b17fc..5798e7004c5c 100644 --- a/src/ray/raylet/worker.cc +++ b/src/ray/raylet/worker.cc @@ -117,7 +117,9 @@ void Worker::Connect(int port) { rpc::Address addr; addr.set_ip_address(ip_address_); addr.set_port(port_); - rpc_client_ = std::make_unique(addr, client_call_manager_); + rpc_client_ = std::make_unique(addr, client_call_manager_, []() { + RAY_LOG(FATAL) << "Raylet doesn't call any retryable core worker grpc methods."; + }); Connect(rpc_client_); } diff --git a/src/ray/raylet/worker_pool.cc b/src/ray/raylet/worker_pool.cc index 943ea89b24b5..2a7fb7031514 100644 --- a/src/ray/raylet/worker_pool.cc +++ b/src/ray/raylet/worker_pool.cc @@ -866,6 +866,20 @@ Status WorkerPool::RegisterDriver(const std::shared_ptr &driver return Status::OK(); } +std::shared_ptr WorkerPool::GetRegisteredWorker( + const WorkerID &worker_id) const { + for (const auto &entry : states_by_lang_) { + for (auto it = entry.second.registered_workers.begin(); + it != entry.second.registered_workers.end(); + it++) { + if ((*it)->WorkerId() == worker_id) { + return (*it); + } + } + } + return nullptr; +} + std::shared_ptr WorkerPool::GetRegisteredWorker( const std::shared_ptr &connection) const { for (const auto &entry : states_by_lang_) { diff --git a/src/ray/raylet/worker_pool.h b/src/ray/raylet/worker_pool.h index ef2e1e048635..2e12f3bf9981 100644 --- a/src/ray/raylet/worker_pool.h +++ b/src/ray/raylet/worker_pool.h @@ -294,6 +294,9 @@ class WorkerPool : public WorkerPoolInterface, public IOWorkerPoolInterface { std::shared_ptr GetRegisteredWorker( const std::shared_ptr &connection) const; + /// Get the registered worker by worker id or nullptr if not found. + std::shared_ptr GetRegisteredWorker(const WorkerID &worker_id) const; + /// Get the client connection's registered driver. /// /// \param The client connection owned by a registered driver. 
diff --git a/src/ray/raylet_client/raylet_client.cc b/src/ray/raylet_client/raylet_client.cc index 7911a3ce0a86..c957e68c0a0c 100644 --- a/src/ray/raylet_client/raylet_client.cc +++ b/src/ray/raylet_client/raylet_client.cc @@ -595,6 +595,14 @@ void raylet::RayletClient::DrainRaylet( grpc_client_->DrainRaylet(request, callback); } +void raylet::RayletClient::IsLocalWorkerDead( + const WorkerID &worker_id, + const rpc::ClientCallback &callback) { + rpc::IsLocalWorkerDeadRequest request; + request.set_worker_id(worker_id.Binary()); + grpc_client_->IsLocalWorkerDead(request, callback); +} + void raylet::RayletClient::GlobalGC( const rpc::ClientCallback &callback) { rpc::GlobalGCRequest request; diff --git a/src/ray/raylet_client/raylet_client.h b/src/ray/raylet_client/raylet_client.h index f40c97edf620..21663661b205 100644 --- a/src/ray/raylet_client/raylet_client.h +++ b/src/ray/raylet_client/raylet_client.h @@ -239,6 +239,10 @@ class RayletClientInterface : public PinObjectsInterface, int64_t deadline_timestamp_ms, const rpc::ClientCallback &callback) = 0; + virtual void IsLocalWorkerDead( + const WorkerID &worker_id, + const rpc::ClientCallback &callback) = 0; + virtual std::shared_ptr GetChannel() const = 0; }; @@ -523,6 +527,10 @@ class RayletClient : public RayletClientInterface { int64_t deadline_timestamp_ms, const rpc::ClientCallback &callback) override; + void IsLocalWorkerDead( + const WorkerID &worker_id, + const rpc::ClientCallback &callback) override; + void GetSystemConfig( const rpc::ClientCallback &callback) override; diff --git a/src/ray/rpc/gcs_server/gcs_rpc_client.h b/src/ray/rpc/gcs_server/gcs_rpc_client.h index a8fed2d20f27..920ff7ed63b4 100644 --- a/src/ray/rpc/gcs_server/gcs_rpc_client.h +++ b/src/ray/rpc/gcs_server/gcs_rpc_client.h @@ -21,41 +21,13 @@ #include "absl/container/btree_map.h" #include "ray/common/grpc_util.h" -#include "ray/common/network_util.h" -#include "ray/rpc/grpc_client.h" +#include "ray/rpc/retryable_grpc_client.h" #include "src/ray/protobuf/autoscaler.grpc.pb.h" #include "src/ray/protobuf/gcs_service.grpc.pb.h" namespace ray { namespace rpc { -class GcsRpcClient; - -/// \class Executor -/// Executor saves operation and support retries. -class Executor { - public: - Executor(std::function abort_callback) - : abort_callback_(std::move(abort_callback)) {} - - /// This function is used to execute the given operation. - /// - /// \param operation The operation to be executed. - void Execute(std::function operation) { - operation_ = std::move(operation); - operation_(); - } - - /// This function is used to retry the given operation. - void Retry() { operation_(); } - - void Abort(const ray::Status &status) { abort_callback_(status); } - - private: - std::function abort_callback_; - std::function operation_; -}; - /// Convenience macro to invoke VOID_GCS_RPC_CLIENT_METHOD_FULL with defaults. /// /// Creates a Sync and an Async method just like in VOID_GCS_RPC_CLIENT_METHOD_FULL, @@ -117,7 +89,7 @@ class Executor { NAMESPACE::METHOD##Reply, \ handle_payload_status>( \ &NAMESPACE::SERVICE::Stub::PrepareAsync##METHOD, \ - *grpc_client, \ + grpc_client, \ #NAMESPACE "::" #SERVICE ".grpc_client." 
#METHOD, \ request, \ callback, \ @@ -166,53 +138,71 @@ class GcsRpcClient { GcsRpcClient(const std::string &address, const int port, ClientCallManager &client_call_manager) - : gcs_address_(address), - gcs_port_(port), - io_context_(&client_call_manager.GetMainService()), - timer_(std::make_unique(*io_context_)) { + : gcs_address_(address), gcs_port_(port) { channel_ = CreateGcsChannel(address, port); // If not the reconnection will continue to work. auto deadline = std::chrono::system_clock::now() + std::chrono::seconds(::RayConfig::instance().gcs_rpc_server_connect_timeout_s()); if (!channel_->WaitForConnected(deadline)) { - RAY_LOG(ERROR) << "Failed to connect to GCS at address " << address << ":" << port - << " within " - << ::RayConfig::instance().gcs_rpc_server_connect_timeout_s() - << " seconds."; - gcs_is_down_ = true; - } else { - gcs_is_down_ = false; + RAY_LOG(WARNING) << "Failed to connect to GCS at address " << address << ":" << port + << " within " + << ::RayConfig::instance().gcs_rpc_server_connect_timeout_s() + << " seconds."; } job_info_grpc_client_ = - std::make_unique>(channel_, client_call_manager); + std::make_shared>(channel_, client_call_manager); actor_info_grpc_client_ = - std::make_unique>(channel_, client_call_manager); + std::make_shared>(channel_, client_call_manager); node_info_grpc_client_ = - std::make_unique>(channel_, client_call_manager); + std::make_shared>(channel_, client_call_manager); node_resource_info_grpc_client_ = - std::make_unique>(channel_, + std::make_shared>(channel_, client_call_manager); worker_info_grpc_client_ = - std::make_unique>(channel_, client_call_manager); + std::make_shared>(channel_, client_call_manager); placement_group_info_grpc_client_ = - std::make_unique>(channel_, + std::make_shared>(channel_, client_call_manager); internal_kv_grpc_client_ = - std::make_unique>(channel_, client_call_manager); - internal_pubsub_grpc_client_ = std::make_unique>( + std::make_shared>(channel_, client_call_manager); + internal_pubsub_grpc_client_ = std::make_shared>( channel_, client_call_manager); task_info_grpc_client_ = - std::make_unique>(channel_, client_call_manager); + std::make_shared>(channel_, client_call_manager); autoscaler_state_grpc_client_ = - std::make_unique>( + std::make_shared>( channel_, client_call_manager); runtime_env_grpc_client_ = - std::make_unique>(channel_, client_call_manager); - - SetupCheckTimer(); + std::make_shared>(channel_, client_call_manager); + + retryable_grpc_client_ = RetryableGrpcClient::Create( + channel_, + client_call_manager.GetMainService(), + /*max_pending_requests_bytes=*/ + ::RayConfig::instance().gcs_grpc_max_request_queued_max_bytes(), + /*check_channel_status_interval_milliseconds=*/ + ::RayConfig::instance() + .grpc_client_check_connection_status_interval_milliseconds(), + /*server_unavailable_timeout_seconds=*/ + ::RayConfig::instance().gcs_rpc_server_reconnect_timeout_s(), + /*server_unavailable_timeout_callback=*/ + []() { + RAY_LOG(ERROR) << "Failed to connect to GCS within " + << ::RayConfig::instance().gcs_rpc_server_reconnect_timeout_s() + << " seconds. " + << "GCS may have been killed. It's either GCS is terminated by " + "`ray stop` or " + << "is killed unexpectedly. If it is killed unexpectedly, " + << "see the log file gcs_server.out. " + << "https://docs.ray.io/en/master/ray-observability/user-guides/" + "configure-logging.html#logging-directory-structure. 
" + << "The program will terminate."; + std::_Exit(EXIT_FAILURE); + }, + /*server_name=*/"GCS"); } template void invoke_async_method( PrepareAsyncFunction prepare_async_function, - GrpcClient &grpc_client, + std::shared_ptr> grpc_client, const std::string &call_name, const Request &request, const ClientCallback &callback, const int64_t timeout_ms) { - auto executor = new Executor( - [callback](const ray::Status &status) { callback(status, Reply()); }); - auto operation_callback = [this, request, callback, executor, timeout_ms]( - const ray::Status &status, Reply &&reply) { - if (status.ok()) { - if constexpr (handle_payload_status) { - Status st = - (reply.status().code() == (int)StatusCode::OK) - ? Status() - : Status(StatusCode(reply.status().code()), reply.status().message()); - callback(st, std::move(reply)); - } else { - callback(status, std::move(reply)); - } - delete executor; - } else if (!IsGrpcRetryableStatus(status)) { - callback(status, std::move(reply)); - delete executor; - } else { - /* In case of GCS failure, we queue the request and these requests will be */ - /* executed once GCS is back. */ - gcs_is_down_ = true; - auto request_bytes = request.ByteSizeLong(); - if (pending_requests_bytes_ + request_bytes > - ::RayConfig::instance().gcs_grpc_max_request_queued_max_bytes()) { - RAY_LOG(WARNING) << "Pending queue for failed GCS request has reached the " - << "limit. Blocking the current thread until GCS is back"; - while (gcs_is_down_ && !shutdown_) { - CheckChannelStatus(false); - std::this_thread::sleep_for(std::chrono::milliseconds( - ::RayConfig::instance() - .gcs_client_check_connection_status_interval_milliseconds())); - } - if (shutdown_) { - callback(Status::Disconnected("GCS client has been disconnected."), - std::move(reply)); - delete executor; + retryable_grpc_client_->template CallMethod( + prepare_async_function, + std::move(grpc_client), + call_name, + request, + [callback](const Status &status, Reply &&reply) { + if (status.ok()) { + if constexpr (handle_payload_status) { + Status st = (reply.status().code() == (int)StatusCode::OK) + ? Status() + : Status(StatusCode(reply.status().code()), + reply.status().message()); + callback(st, std::move(reply)); + } else { + callback(status, std::move(reply)); + } } else { - executor->Retry(); + callback(status, std::move(reply)); } - } else { - pending_requests_bytes_ += request_bytes; - auto timeout = timeout_ms == -1 ? absl::InfiniteFuture() - : absl::Now() + absl::Milliseconds(timeout_ms); - pending_requests_.emplace(timeout, std::make_pair(executor, request_bytes)); - } - } - }; - auto operation = [prepare_async_function, - &grpc_client, - call_name, - request, - operation_callback, - timeout_ms]() { - grpc_client.template CallMethod( - prepare_async_function, request, operation_callback, call_name, timeout_ms); - }; - executor->Execute(std::move(operation)); + }, + timeout_ms); } /// Add job info to GCS Service. @@ -596,16 +549,6 @@ class GcsRpcClient { runtime_env_grpc_client_, /*method_timeout_ms*/ -1, ) - void Shutdown() { - if (!shutdown_.exchange(true)) { - // First call to shut down this GCS RPC client. 
- absl::MutexLock lock(&timer_mu_); - timer_->cancel(); - } else { - RAY_LOG(DEBUG) << "GCS RPC client has already shutdown."; - } - } - std::pair GetAddress() const { return std::make_pair(gcs_address_, gcs_port_); } @@ -613,121 +556,26 @@ class GcsRpcClient { std::shared_ptr GetChannel() const { return channel_; } private: - void SetupCheckTimer() { - auto duration = boost::posix_time::milliseconds( - ::RayConfig::instance() - .gcs_client_check_connection_status_interval_milliseconds()); - absl::MutexLock lock(&timer_mu_); - timer_->expires_from_now(duration); - timer_->async_wait([this](boost::system::error_code error) { - if (error == boost::system::errc::success) { - CheckChannelStatus(); - } - }); - } - - void CheckChannelStatus(bool reset_timer = true) { - if (shutdown_) { - return; - } - - auto status = channel_->GetState(false); - // https://grpc.github.io/grpc/core/md_doc_connectivity-semantics-and-api.html - // https://grpc.github.io/grpc/core/connectivity__state_8h_source.html - if (status != GRPC_CHANNEL_READY) { - RAY_LOG(DEBUG) << "GCS channel status: " << status; - } - - // We need to cleanup all the pending requests which are timeout. - auto now = absl::Now(); - while (!pending_requests_.empty()) { - auto iter = pending_requests_.begin(); - if (iter->first > now) { - break; - } - auto [executor, request_bytes] = iter->second; - executor->Abort( - ray::Status::TimedOut("Timed out while waiting for GCS to become available.")); - pending_requests_bytes_ -= request_bytes; - delete executor; - pending_requests_.erase(iter); - } - - switch (status) { - case GRPC_CHANNEL_TRANSIENT_FAILURE: - case GRPC_CHANNEL_CONNECTING: - if (!gcs_is_down_) { - gcs_is_down_ = true; - } else { - if (absl::ToInt64Seconds(absl::Now() - gcs_last_alive_time_) >= - ::RayConfig::instance().gcs_rpc_server_reconnect_timeout_s()) { - RAY_LOG(ERROR) << "Failed to connect to GCS within " - << ::RayConfig::instance().gcs_rpc_server_reconnect_timeout_s() - << " seconds. " - << "GCS may have been killed. It's either GCS is terminated by " - "`ray stop` or " - << "is killed unexpectedly. If it is killed unexpectedly, " - << "see the log file gcs_server.out. " - << "https://docs.ray.io/en/master/ray-observability/user-guides/" - "configure-logging.html#logging-directory-structure. " - << "The program will terminate."; - std::_Exit(EXIT_FAILURE); - } - } - break; - case GRPC_CHANNEL_SHUTDOWN: - RAY_CHECK(shutdown_) << "Channel shoud never go to this status."; - break; - case GRPC_CHANNEL_READY: - case GRPC_CHANNEL_IDLE: - gcs_last_alive_time_ = absl::Now(); - gcs_is_down_ = false; - // Retry the one queued. - while (!pending_requests_.empty()) { - pending_requests_.begin()->second.first->Retry(); - pending_requests_.erase(pending_requests_.begin()); - } - pending_requests_bytes_ = 0; - break; - default: - RAY_LOG(FATAL) << "Not covered status: " << status; - } - SetupCheckTimer(); - } - const std::string gcs_address_; const int64_t gcs_port_; - - instrumented_io_context *const io_context_; - - // Timer can be called from either the GCS RPC event loop, or the application's - // main thread. It needs to be protected by a mutex. - absl::Mutex timer_mu_; - const std::unique_ptr timer_; + std::shared_ptr channel_; + std::shared_ptr retryable_grpc_client_; /// The gRPC-generated stub. 
- std::unique_ptr> job_info_grpc_client_; - std::unique_ptr> actor_info_grpc_client_; - std::unique_ptr> node_info_grpc_client_; - std::unique_ptr> node_resource_info_grpc_client_; - std::unique_ptr> worker_info_grpc_client_; - std::unique_ptr> + std::shared_ptr> job_info_grpc_client_; + std::shared_ptr> actor_info_grpc_client_; + std::shared_ptr> node_info_grpc_client_; + std::shared_ptr> node_resource_info_grpc_client_; + std::shared_ptr> worker_info_grpc_client_; + std::shared_ptr> placement_group_info_grpc_client_; - std::unique_ptr> internal_kv_grpc_client_; - std::unique_ptr> internal_pubsub_grpc_client_; - std::unique_ptr> task_info_grpc_client_; - std::unique_ptr> runtime_env_grpc_client_; - std::unique_ptr> + std::shared_ptr> internal_kv_grpc_client_; + std::shared_ptr> internal_pubsub_grpc_client_; + std::shared_ptr> task_info_grpc_client_; + std::shared_ptr> runtime_env_grpc_client_; + std::shared_ptr> autoscaler_state_grpc_client_; - std::shared_ptr channel_; - bool gcs_is_down_ = false; - absl::Time gcs_last_alive_time_ = absl::Now(); - - std::atomic shutdown_ = false; - absl::btree_multimap> pending_requests_; - size_t pending_requests_bytes_ = 0; - friend class GcsClientReconnectionTest; FRIEND_TEST(GcsClientReconnectionTest, ReconnectionBackoff); }; diff --git a/src/ray/rpc/node_manager/node_manager_client.h b/src/ray/rpc/node_manager/node_manager_client.h index 95ca4846c3f5..77746485a6a9 100644 --- a/src/ray/rpc/node_manager/node_manager_client.h +++ b/src/ray/rpc/node_manager/node_manager_client.h @@ -126,6 +126,11 @@ class NodeManagerWorkerClient grpc_client_, /*method_timeout_ms*/ -1, ) + VOID_RPC_CLIENT_METHOD(NodeManagerService, + IsLocalWorkerDead, + grpc_client_, + /*method_timeout_ms*/ -1, ) + /// Cancel a pending worker lease request. VOID_RPC_CLIENT_METHOD(NodeManagerService, CancelWorkerLease, diff --git a/src/ray/rpc/node_manager/node_manager_server.h b/src/ray/rpc/node_manager/node_manager_server.h index bb11333ae35d..2a18b97a3493 100644 --- a/src/ray/rpc/node_manager/node_manager_server.h +++ b/src/ray/rpc/node_manager/node_manager_server.h @@ -45,6 +45,7 @@ namespace rpc { RAY_NODE_MANAGER_RPC_SERVICE_HANDLER(CancelResourceReserve) \ RAY_NODE_MANAGER_RPC_SERVICE_HANDLER(ReleaseUnusedBundles) \ RAY_NODE_MANAGER_RPC_SERVICE_HANDLER(GetSystemConfig) \ + RAY_NODE_MANAGER_RPC_SERVICE_HANDLER(IsLocalWorkerDead) \ RAY_NODE_MANAGER_RPC_SERVICE_HANDLER(ShutdownRaylet) \ RAY_NODE_MANAGER_RPC_SERVICE_HANDLER(DrainRaylet) \ RAY_NODE_MANAGER_RPC_SERVICE_HANDLER(GetTasksInfo) \ @@ -103,6 +104,10 @@ class NodeManagerServiceHandler { rpc::CancelWorkerLeaseReply *reply, rpc::SendReplyCallback send_reply_callback) = 0; + virtual void HandleIsLocalWorkerDead(rpc::IsLocalWorkerDeadRequest request, + rpc::IsLocalWorkerDeadReply *reply, + SendReplyCallback send_reply_callback) = 0; + virtual void HandlePrepareBundleResources( rpc::PrepareBundleResourcesRequest request, rpc::PrepareBundleResourcesReply *reply, diff --git a/src/ray/rpc/retryable_grpc_client.cc b/src/ray/rpc/retryable_grpc_client.cc new file mode 100644 index 000000000000..9cb3c3f2c221 --- /dev/null +++ b/src/ray/rpc/retryable_grpc_client.cc @@ -0,0 +1,157 @@ +// Copyright 2024 The Ray Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "ray/rpc/retryable_grpc_client.h" + +namespace ray { +namespace rpc { +RetryableGrpcClient::~RetryableGrpcClient() { + timer_->cancel(); + + // Fail the pending requests. + while (!pending_requests_.empty()) { + auto iter = pending_requests_.begin(); + // Make sure the callback is executed in the io context thread. + io_context_.post( + [request = std::move(iter->second)]() { + request->Fail(Status::Disconnected("GRPC client is shut down.")); + }, + "~RetryableGrpcClient"); + pending_requests_.erase(iter); + } + pending_requests_bytes_ = 0; +} + +void RetryableGrpcClient::SetupCheckTimer() { + auto duration = + boost::posix_time::milliseconds(check_channel_status_interval_milliseconds_); + timer_->expires_from_now(duration); + std::weak_ptr weak_self = weak_from_this(); + timer_->async_wait([weak_self](boost::system::error_code error) { + if (auto self = weak_self.lock(); self && (error == boost::system::errc::success)) { + self->CheckChannelStatus(); + } + }); +} + +void RetryableGrpcClient::CheckChannelStatus(bool reset_timer) { + // We need to clean up all the pending requests that have timed out. + const auto now = absl::Now(); + while (!pending_requests_.empty()) { + auto iter = pending_requests_.begin(); + if (iter->first > now) { + break; + } + iter->second->Fail(ray::Status::TimedOut(absl::StrFormat( + "Timed out while waiting for %s to become available.", server_name_))); + pending_requests_bytes_ -= iter->second->GetRequestBytes(); + pending_requests_.erase(iter); + } + + if (pending_requests_.empty()) { + server_unavailable_timeout_time_ = std::nullopt; + return; + } + + RAY_CHECK(server_unavailable_timeout_time_.has_value()); + + auto status = channel_->GetState(false); + // https://grpc.github.io/grpc/core/md_doc_connectivity-semantics-and-api.html + // https://grpc.github.io/grpc/core/connectivity__state_8h_source.html + if (status != GRPC_CHANNEL_READY) { + RAY_LOG(DEBUG) << "GRPC channel status: " << status; + } + + switch (status) { + case GRPC_CHANNEL_TRANSIENT_FAILURE: + case GRPC_CHANNEL_CONNECTING: { + if (server_unavailable_timeout_time_ < absl::Now()) { + RAY_LOG(WARNING) << server_name_ << " has been unavailable for more than " + << server_unavailable_timeout_seconds_ << " seconds"; + server_unavailable_timeout_callback_(); + // Reset the unavailable timeout. + server_unavailable_timeout_time_ = + absl::Now() + absl::Seconds(server_unavailable_timeout_seconds_); + } + + if (reset_timer) { + SetupCheckTimer(); + } + + break; + } + case GRPC_CHANNEL_SHUTDOWN: { + RAY_LOG(FATAL) << "Channel should never go to this status."; + break; + } + case GRPC_CHANNEL_READY: + case GRPC_CHANNEL_IDLE: { + server_unavailable_timeout_time_ = std::nullopt; + // Retry the ones queued. 
+ while (!pending_requests_.empty()) { + pending_requests_.begin()->second->CallMethod(); + pending_requests_.erase(pending_requests_.begin()); + } + pending_requests_bytes_ = 0; + break; + } + default: { + RAY_LOG(FATAL) << "Not covered status: " << status; + } + } +} + +void RetryableGrpcClient::Retry(std::shared_ptr request) { + // In case of transient network error, we queue the request and these requests + // will be executed once network is recovered. + auto request_bytes = request->GetRequestBytes(); + auto self = shared_from_this(); + if (pending_requests_bytes_ + request_bytes > max_pending_requests_bytes_) { + RAY_LOG(WARNING) << "Pending queue for failed request has reached the " + << "limit. Blocking the current thread until network is recovered"; + if (!server_unavailable_timeout_time_.has_value()) { + server_unavailable_timeout_time_ = + absl::Now() + absl::Seconds(server_unavailable_timeout_seconds_); + } + while (server_unavailable_timeout_time_.has_value()) { + std::this_thread::sleep_for( + std::chrono::milliseconds(check_channel_status_interval_milliseconds_)); + + if (self.use_count() == 2) { + // This means there are no external owners of this client + // and the client is considered shut down. + // The only two internal owners are caller of Retry and self. + break; + } + + CheckChannelStatus(false); + } + request->CallMethod(); + return; + } + + pending_requests_bytes_ += request_bytes; + auto timeout = request->GetTimeoutMs() == -1 + ? absl::InfiniteFuture() + : absl::Now() + absl::Milliseconds(request->GetTimeoutMs()); + pending_requests_.emplace(timeout, request); + if (!server_unavailable_timeout_time_.has_value()) { + // First request to retry. + server_unavailable_timeout_time_ = + absl::Now() + absl::Seconds(server_unavailable_timeout_seconds_); + SetupCheckTimer(); + } +} +} // namespace rpc +} // namespace ray diff --git a/src/ray/rpc/retryable_grpc_client.h b/src/ray/rpc/retryable_grpc_client.h new file mode 100644 index 000000000000..55eff746c174 --- /dev/null +++ b/src/ray/rpc/retryable_grpc_client.h @@ -0,0 +1,260 @@ +// Copyright 2024 The Ray Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include + +#include "absl/container/btree_map.h" +#include "absl/strings/str_format.h" +#include "ray/common/grpc_util.h" +#include "ray/rpc/client_call.h" +#include "ray/rpc/grpc_client.h" + +namespace ray { +namespace rpc { + +// Define a void retryable RPC client method. +#define VOID_RETRYABLE_RPC_CLIENT_METHOD( \ + retryable_rpc_client, SERVICE, METHOD, rpc_client, method_timeout_ms, SPECS) \ + void METHOD(const METHOD##Request &request, \ + const ClientCallback &callback) SPECS { \ + retryable_rpc_client->CallMethod( \ + &SERVICE::Stub::PrepareAsync##METHOD, \ + rpc_client, \ + #SERVICE ".grpc_client." #METHOD, \ + request, \ + callback, \ + method_timeout_ms); \ + } + +/** + * The client makes RPC calls through the provided underlying grpc client. + * If the call goes through, the user provided callback is invoked. 
+ * If the call fails due to transient network error, it is added to a retry queue. + * The client waits for the grpc channel reconnection to resend the requests. + * If the total number of request bytes in the queue exceeds max_pending_requests_bytes, + * the io context thread is blocked until some requests are resent. + * If a call's timeout_ms reaches during retry, its callback is called with + * Status::TimedOut. If the whole client does not reconnect within + * server_unavailable_timeout_seconds, server_unavailable_timeout_callback is invoked. + * + * When all callers of the client release the shared_ptr of the client, the client + * destructor is called and the client is shut down. + */ +class RetryableGrpcClient : public std::enable_shared_from_this { + private: + /** + * Represents a single retryable grpc request. + * The lifecycle is managed by shared_ptr and it's either in the callback of an ongoing + * call or the RetryableGrpcClient retry queue. + * + * Implementation wise, it uses std::function for type erasure so that it can represent + * any underlying grpc request without making this class a template. + */ + class RetryableGrpcRequest : public std::enable_shared_from_this { + public: + template + static std::shared_ptr Create( + std::weak_ptr weak_retryable_grpc_client, + PrepareAsyncFunction prepare_async_function, + std::shared_ptr> grpc_client, + const std::string &call_name, + const Request &request, + const ClientCallback &callback, + const int64_t timeout_ms); + + RetryableGrpcRequest(const RetryableGrpcRequest &) = delete; + RetryableGrpcRequest &operator=(const RetryableGrpcRequest &) = delete; + + /// This function is used to call the RPC method to send out the request. + void CallMethod() { executor_(shared_from_this()); } + + void Fail(const ray::Status &status) { failure_callback_(status); } + + size_t GetRequestBytes() const { return request_bytes_; } + + int64_t GetTimeoutMs() const { return timeout_ms_; } + + private: + RetryableGrpcRequest( + std::function request)> executor, + std::function failure_callback, + size_t request_bytes, + int64_t timeout_ms) + : executor_(std::move(executor)), + failure_callback_(std::move(failure_callback)), + request_bytes_(request_bytes), + timeout_ms_(timeout_ms) {} + + std::function request)> executor_; + std::function failure_callback_; + const size_t request_bytes_; + const int64_t timeout_ms_; + }; + + public: + static std::shared_ptr Create( + std::shared_ptr channel, + instrumented_io_context &io_context, + uint64_t max_pending_requests_bytes, + uint64_t check_channel_status_interval_milliseconds, + uint64_t server_unavailable_timeout_seconds, + std::function server_unavailable_timeout_callback, + const std::string &server_name) { + return std::shared_ptr( + new RetryableGrpcClient(channel, + io_context, + max_pending_requests_bytes, + check_channel_status_interval_milliseconds, + server_unavailable_timeout_seconds, + server_unavailable_timeout_callback, + server_name)); + } + + RetryableGrpcClient(const RetryableGrpcClient &) = delete; + RetryableGrpcClient &operator=(const RetryableGrpcClient &) = delete; + + template + void CallMethod(PrepareAsyncFunction prepare_async_function, + std::shared_ptr> grpc_client, + const std::string &call_name, + const Request &request, + const ClientCallback &callback, + const int64_t timeout_ms); + + void Retry(std::shared_ptr request); + + // Return the number of pending requests waiting for retry. 
+ size_t NumPendingRequests() const { return pending_requests_.size(); } + + ~RetryableGrpcClient(); + + private: + RetryableGrpcClient(std::shared_ptr channel, + instrumented_io_context &io_context, + uint64_t max_pending_requests_bytes, + uint64_t check_channel_status_interval_milliseconds, + uint64_t server_unavailable_timeout_seconds, + std::function server_unavailable_timeout_callback, + const std::string &server_name) + : io_context_(io_context), + timer_(std::make_unique(io_context)), + channel_(std::move(channel)), + max_pending_requests_bytes_(max_pending_requests_bytes), + check_channel_status_interval_milliseconds_( + check_channel_status_interval_milliseconds), + server_unavailable_timeout_seconds_(server_unavailable_timeout_seconds), + server_unavailable_timeout_callback_( + std::move(server_unavailable_timeout_callback)), + server_name_(server_name) {} + + void SetupCheckTimer(); + + void CheckChannelStatus(bool reset_timer = true); + + instrumented_io_context &io_context_; + const std::unique_ptr timer_; + + std::shared_ptr channel_; + + const uint64_t max_pending_requests_bytes_; + const uint64_t check_channel_status_interval_milliseconds_; + const uint64_t server_unavailable_timeout_seconds_; + std::function server_unavailable_timeout_callback_; + const std::string server_name_; + + // This is only set when there are pending requests and + // we need to check channel status. + // This is the time when the server will timeout for + // unavailability and server_unavailable_timeout_callback_ + // will be called. + std::optional server_unavailable_timeout_time_; + + // Key is when the request will timeout and value is the request. + // This is only accessed in the io context thread and the destructor so + // no mutex is needed. + absl::btree_multimap> + pending_requests_; + size_t pending_requests_bytes_ = 0; +}; + +template +void RetryableGrpcClient::CallMethod( + PrepareAsyncFunction prepare_async_function, + std::shared_ptr> grpc_client, + const std::string &call_name, + const Request &request, + const ClientCallback &callback, + const int64_t timeout_ms) { + RetryableGrpcRequest::Create(weak_from_this(), + prepare_async_function, + grpc_client, + call_name, + request, + callback, + timeout_ms) + ->CallMethod(); +} + +template +std::shared_ptr +RetryableGrpcClient::RetryableGrpcRequest::Create( + std::weak_ptr weak_retryable_grpc_client, + PrepareAsyncFunction prepare_async_function, + std::shared_ptr> grpc_client, + const std::string &call_name, + const Request &request, + const ClientCallback &callback, + const int64_t timeout_ms) { + RAY_CHECK(callback != nullptr); + RAY_CHECK(grpc_client.get() != nullptr); + + auto executor = [weak_retryable_grpc_client, + prepare_async_function, + grpc_client, + call_name, + request, + callback](std::shared_ptr + retryable_grpc_request) { + grpc_client->template CallMethod( + prepare_async_function, + request, + [weak_retryable_grpc_client, retryable_grpc_request, callback]( + const ray::Status &status, Reply &&reply) { + auto retryable_grpc_client = weak_retryable_grpc_client.lock(); + if (status.ok() || !IsGrpcRetryableStatus(status) || !retryable_grpc_client) { + callback(status, std::move(reply)); + return; + } + + retryable_grpc_client->Retry(retryable_grpc_request); + }, + call_name, + retryable_grpc_request->GetTimeoutMs()); + }; + + auto failure_callback = [callback](const ray::Status &status) { + callback(status, Reply()); + }; + + return std::shared_ptr( + new RetryableGrpcClient::RetryableGrpcRequest(std::move(executor), + 
std::move(failure_callback), + request.ByteSizeLong(), + timeout_ms)); +} + +} // namespace rpc +} // namespace ray diff --git a/src/ray/rpc/worker/core_worker_client.h b/src/ray/rpc/worker/core_worker_client.h index add45a82b24d..489d98bed66b 100644 --- a/src/ray/rpc/worker/core_worker_client.h +++ b/src/ray/rpc/worker/core_worker_client.h @@ -25,7 +25,7 @@ #include "absl/hash/hash.h" #include "ray/common/status.h" #include "ray/pubsub/subscriber.h" -#include "ray/rpc/grpc_client.h" +#include "ray/rpc/retryable_grpc_client.h" #include "ray/util/logging.h" #include "src/ray/protobuf/core_worker.grpc.pb.h" #include "src/ray/protobuf/core_worker.pb.h" @@ -76,7 +76,9 @@ class CoreWorkerClientInterface : public pubsub::SubscriberClientInterface { return empty_addr_; } - virtual bool IsChannelIdleAfterRPCs() const { return false; } + /// Returns true if the grpc channel is idle and there are no pending requests + /// after at least one RPC call is made. + virtual bool IsIdleAfterRPCs() const { return false; } /// Push an actor task directly from worker to worker. /// @@ -207,16 +209,33 @@ class CoreWorkerClient : public std::enable_shared_from_this, /// /// \param[in] address Address of the worker server. /// \param[in] client_call_manager The `ClientCallManager` used for managing requests. - CoreWorkerClient(const rpc::Address &address, ClientCallManager &client_call_manager) + CoreWorkerClient(const rpc::Address &address, + ClientCallManager &client_call_manager, + std::function core_worker_unavailable_timeout_callback) : addr_(address) { - grpc_client_ = std::make_unique>( + grpc_client_ = std::make_shared>( addr_.ip_address(), addr_.port(), client_call_manager); + + retryable_grpc_client_ = RetryableGrpcClient::Create( + grpc_client_->Channel(), + client_call_manager.GetMainService(), + /*max_pending_requests_bytes=*/ + std::numeric_limits::max(), + /*check_channel_status_interval_milliseconds=*/ + ::RayConfig::instance() + .grpc_client_check_connection_status_interval_milliseconds(), + /*server_unavailable_timeout_seconds=*/ + ::RayConfig::instance().core_worker_rpc_server_reconnect_timeout_s(), + /*server_unavailable_timeout_callback=*/ + core_worker_unavailable_timeout_callback, + /*server_name=*/"Core worker " + addr_.ip_address()); }; const rpc::Address &Addr() const override { return addr_; } - bool IsChannelIdleAfterRPCs() const override { - return grpc_client_->IsChannelIdleAfterRPCs(); + bool IsIdleAfterRPCs() const override { + return grpc_client_->IsChannelIdleAfterRPCs() && + (retryable_grpc_client_->NumPendingRequests() == 0); } VOID_RPC_CLIENT_METHOD(CoreWorkerService, @@ -279,11 +298,12 @@ class CoreWorkerClient : public std::enable_shared_from_this, /*method_timeout_ms*/ -1, override) - VOID_RPC_CLIENT_METHOD(CoreWorkerService, - ReportGeneratorItemReturns, - grpc_client_, - /*method_timeout_ms*/ -1, - override) + VOID_RETRYABLE_RPC_CLIENT_METHOD(retryable_grpc_client_, + CoreWorkerService, + ReportGeneratorItemReturns, + grpc_client_, + /*method_timeout_ms*/ -1, + override) VOID_RPC_CLIENT_METHOD(CoreWorkerService, RegisterMutableObjectReader, @@ -452,7 +472,9 @@ class CoreWorkerClient : public std::enable_shared_from_this, rpc::Address addr_; /// The RPC client. - std::unique_ptr> grpc_client_; + std::shared_ptr> grpc_client_; + + std::shared_ptr retryable_grpc_client_; /// Queue of requests to send. 
std::deque, ClientCallback>> diff --git a/src/ray/rpc/worker/core_worker_client_pool.cc b/src/ray/rpc/worker/core_worker_client_pool.cc index 19ee34497cc7..95a36ae15ed1 100644 --- a/src/ray/rpc/worker/core_worker_client_pool.cc +++ b/src/ray/rpc/worker/core_worker_client_pool.cc @@ -31,7 +31,7 @@ std::shared_ptr CoreWorkerClientPool::GetOrConnect( entry = *it->second; client_list_.erase(it->second); } else { - entry = CoreWorkerClientEntry(id, client_factory_(addr_proto)); + entry = CoreWorkerClientEntry(id, core_worker_client_factory_(addr_proto)); } client_list_.emplace_front(entry); client_map_[id] = client_list_.begin(); @@ -45,7 +45,7 @@ void CoreWorkerClientPool::RemoveIdleClients() { while (!client_list_.empty()) { auto id = client_list_.back().worker_id; // The last client in the list is the least recent accessed client. - if (client_list_.back().core_worker_client->IsChannelIdleAfterRPCs()) { + if (client_list_.back().core_worker_client->IsIdleAfterRPCs()) { client_map_.erase(id); client_list_.pop_back(); RAY_LOG(DEBUG) << "Remove idle client to worker " << id diff --git a/src/ray/rpc/worker/core_worker_client_pool.h b/src/ray/rpc/worker/core_worker_client_pool.h index 704ed82361b8..584171d2fbdf 100644 --- a/src/ray/rpc/worker/core_worker_client_pool.h +++ b/src/ray/rpc/worker/core_worker_client_pool.h @@ -27,13 +27,9 @@ class CoreWorkerClientPool { public: CoreWorkerClientPool() = delete; - /// Creates a CoreWorkerClientPool based on the low-level ClientCallManager. - CoreWorkerClientPool(rpc::ClientCallManager &ccm) - : client_factory_(defaultClientFactory(ccm)){}; - /// Creates a CoreWorkerClientPool by a given connection function. - CoreWorkerClientPool(CoreWorkerClientFactoryFn client_factory) - : client_factory_(client_factory){}; + explicit CoreWorkerClientPool(CoreWorkerClientFactoryFn client_factory) + : core_worker_client_factory_(std::move(client_factory)){}; /// Returns an open CoreWorkerClientInterface if one exists, and connect to one /// if it does not. The returned pointer is borrowed, and expected to be used @@ -53,15 +49,6 @@ class CoreWorkerClientPool { } private: - /// Provides the default client factory function. Providing this function to the - /// construtor aids migration but is ultimately a thing that should be - /// deprecated and brought internal to the pool, so this is our bridge. - CoreWorkerClientFactoryFn defaultClientFactory(rpc::ClientCallManager &ccm) const { - return [&](const rpc::Address &addr) { - return std::shared_ptr(new rpc::CoreWorkerClient(addr, ccm)); - }; - }; - /// Try to remove some idle clients to free memory. /// It doesn't go through the entire list and remove all idle clients. 
/// Instead, it tries to remove idle clients from the end of the list @@ -73,7 +60,7 @@ class CoreWorkerClientPool { /// This factory function does the connection to CoreWorkerClient, and is /// provided by the constructor (either the default implementation, above, or a /// provided one) - CoreWorkerClientFactoryFn client_factory_; + CoreWorkerClientFactoryFn core_worker_client_factory_; absl::Mutex mu_; @@ -82,7 +69,7 @@ class CoreWorkerClientPool { CoreWorkerClientEntry() {} CoreWorkerClientEntry(ray::WorkerID worker_id, std::shared_ptr core_worker_client) - : worker_id(worker_id), core_worker_client(core_worker_client) {} + : worker_id(worker_id), core_worker_client(std::move(core_worker_client)) {} ray::WorkerID worker_id; std::shared_ptr core_worker_client; diff --git a/src/ray/rpc/worker/test/core_worker_client_pool_test.cc b/src/ray/rpc/worker/test/core_worker_client_pool_test.cc index f85fc97f49dc..ffffde9ea852 100644 --- a/src/ray/rpc/worker/test/core_worker_client_pool_test.cc +++ b/src/ray/rpc/worker/test/core_worker_client_pool_test.cc @@ -21,9 +21,9 @@ namespace ray { namespace rpc { class MockCoreWorkerClient : public CoreWorkerClientInterface { public: - bool IsChannelIdleAfterRPCs() const override { return is_channel_idle_after_rpcs; } + bool IsIdleAfterRPCs() const override { return is_idle_after_rpcs; } - bool is_channel_idle_after_rpcs = false; + bool is_idle_after_rpcs = false; }; class CoreWorkerClientPoolTest : public ::testing::Test { @@ -55,7 +55,7 @@ TEST_F(CoreWorkerClientPoolTest, TestGC) { ASSERT_EQ(client_pool.Size(), 1); client2 = client_pool.GetOrConnect(address2); ASSERT_EQ(client_pool.Size(), 2); - static_cast(client1.get())->is_channel_idle_after_rpcs = true; + static_cast(client1.get())->is_idle_after_rpcs = true; // Client 1 will be removed since it's idle. ASSERT_EQ(client2.get(), client_pool.GetOrConnect(address2).get()); ASSERT_EQ(client_pool.Size(), 1);
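
For reference, below is a minimal sketch (not part of the patch) of the adoption pattern this diff wires into CoreWorkerClient: pair the plain GrpcClient with a RetryableGrpcClient created over the same channel, then declare retryable methods through VOID_RETRYABLE_RPC_CLIENT_METHOD. The class name ExampleRetryableCoreWorkerClient is hypothetical and exists only for illustration; the types, config knobs, and macro arguments mirror the ones introduced above.

#include <functional>
#include <limits>
#include <memory>
#include <string>

#include "ray/common/ray_config.h"
#include "ray/rpc/grpc_client.h"
#include "ray/rpc/retryable_grpc_client.h"
#include "src/ray/protobuf/core_worker.grpc.pb.h"

namespace ray {
namespace rpc {

// Hypothetical example class, not part of this PR.
class ExampleRetryableCoreWorkerClient {
 public:
  ExampleRetryableCoreWorkerClient(const rpc::Address &addr,
                                   ClientCallManager &client_call_manager,
                                   std::function<void()> unavailable_timeout_callback)
      : grpc_client_(std::make_shared<GrpcClient<CoreWorkerService>>(
            addr.ip_address(), addr.port(), client_call_manager)),
        retryable_grpc_client_(RetryableGrpcClient::Create(
            grpc_client_->Channel(),
            client_call_manager.GetMainService(),
            /*max_pending_requests_bytes=*/std::numeric_limits<uint64_t>::max(),
            /*check_channel_status_interval_milliseconds=*/
            ::RayConfig::instance()
                .grpc_client_check_connection_status_interval_milliseconds(),
            /*server_unavailable_timeout_seconds=*/
            ::RayConfig::instance().core_worker_rpc_server_reconnect_timeout_s(),
            /*server_unavailable_timeout_callback=*/
            std::move(unavailable_timeout_callback),
            /*server_name=*/"Core worker " + addr.ip_address())) {}

  // Calls issued through this method are queued and retried while the channel
  // is disconnected, instead of failing immediately on a transient gRPC error.
  VOID_RETRYABLE_RPC_CLIENT_METHOD(retryable_grpc_client_,
                                   CoreWorkerService,
                                   ReportGeneratorItemReturns,
                                   grpc_client_,
                                   /*method_timeout_ms*/ -1, )

 private:
  std::shared_ptr<GrpcClient<CoreWorkerService>> grpc_client_;
  std::shared_ptr<RetryableGrpcClient> retryable_grpc_client_;
};

}  // namespace rpc
}  // namespace ray

Because the in-flight request callbacks and the channel status timer only hold weak references back to the RetryableGrpcClient, dropping the last external shared_ptr (as CoreWorkerClientPool does when it evicts an idle client) fails any queued requests in the destructor and stops the retry loop cleanly.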