[Core] Retryable grpc client #47981

Open · wants to merge 30 commits into master
2 changes: 2 additions & 0 deletions BUILD.bazel
@@ -116,12 +116,14 @@ ray_cc_library(
"src/ray/rpc/grpc_server.cc",
"src/ray/rpc/server_call.cc",
"src/ray/rpc/rpc_chaos.cc",
"src/ray/rpc/retryable_grpc_client.cc",
],
hdrs = glob([
"src/ray/rpc/rpc_chaos.h",
"src/ray/rpc/client_call.h",
"src/ray/rpc/common.h",
"src/ray/rpc/grpc_client.h",
"src/ray/rpc/retryable_grpc_client.h",
"src/ray/rpc/grpc_server.h",
"src/ray/rpc/metrics_agent_client.h",
"src/ray/rpc/server_call.h",
4 changes: 4 additions & 0 deletions python/ray/tests/test_streaming_generator_4.py
@@ -54,6 +54,10 @@ def test_ray_datasetlike_mini_stress_test(
"RAY_testing_asio_delay_us",
"CoreWorkerService.grpc_server.ReportGeneratorItemReturns=10000:1000000",
)
+ m.setenv(
+ "RAY_testing_rpc_failure",
+ "CoreWorkerService.grpc_client.ReportGeneratorItemReturns=5",
+ )
cluster = ray_start_cluster
cluster.add_node(
num_cpus=1,
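The new environment variable drives the RAY_testing_rpc_failure fault-injection hook, which this PR extends to the client side: the spec names a service, the grpc_client direction, a method, and a count, here apparently injecting up to five failures into ReportGeneratorItemReturns so the new retry path actually gets exercised. A minimal C++ sketch of how such a spec could be parsed and consulted; the real logic lives in src/ray/rpc/rpc_chaos.cc, and the class and method names below are invented for illustration:

```cpp
#include <cstddef>
#include <map>
#include <string>

// Hypothetical sketch: parse "Service.grpc_client.Method=N" specs (comma
// separated) and fail each matching call up to N times.
class RpcFailureInjector {
 public:
  explicit RpcFailureInjector(const std::string &spec) {
    std::size_t start = 0;
    while (start < spec.size()) {
      const std::size_t comma = spec.find(',', start);
      const std::string entry = spec.substr(start, comma - start);
      const std::size_t eq = entry.find('=');
      if (eq != std::string::npos) {
        remaining_failures_[entry.substr(0, eq)] = std::stoi(entry.substr(eq + 1));
      }
      if (comma == std::string::npos) break;
      start = comma + 1;
    }
  }

  // Returns true if this call should be failed artificially.
  bool ShouldFail(const std::string &method_key) {
    auto it = remaining_failures_.find(method_key);
    if (it == remaining_failures_.end() || it->second <= 0) {
      return false;
    }
    --it->second;
    return true;
  }

 private:
  std::map<std::string, int> remaining_failures_;
};
```

With the spec above, ShouldFail("CoreWorkerService.grpc_client.ReportGeneratorItemReturns") would return true five times and then let calls through.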
5 changes: 5 additions & 0 deletions src/mock/ray/raylet_client/raylet_client.h
@@ -128,6 +128,11 @@ class MockRayletClientInterface : public RayletClientInterface {
int64_t draining_deadline_timestamp_ms,
const rpc::ClientCallback<rpc::DrainRayletReply> &callback),
(override));
+ MOCK_METHOD(void,
+ IsLocalWorkerDead,
+ (const WorkerID &worker_id,
+ const rpc::ClientCallback<rpc::IsLocalWorkerDeadReply> &callback),
+ (override));
};

} // namespace ray
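The mock mirrors the interface method added in raylet_client.h further down. Inside a gtest TEST body, a test could stub it along these lines (a hedged sketch; the ClientCallback takes the RPC status plus the reply by rvalue, matching the signature above):

```cpp
using ::testing::_;
using ::testing::Invoke;

MockRayletClientInterface mock_client;
EXPECT_CALL(mock_client, IsLocalWorkerDead(_, _))
    .WillOnce(Invoke(
        [](const WorkerID &,
           const rpc::ClientCallback<rpc::IsLocalWorkerDeadReply> &callback) {
          rpc::IsLocalWorkerDeadReply reply;
          reply.set_is_dead(true);  // pretend the worker already exited
          callback(Status::OK(), std::move(reply));
        }));
```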
5 changes: 4 additions & 1 deletion src/ray/common/ray_config_def.h
@@ -447,7 +447,7 @@ RAY_CONFIG(int32_t, gcs_grpc_initial_reconnect_backoff_ms, 100)
RAY_CONFIG(uint64_t, gcs_grpc_max_request_queued_max_bytes, 1024UL * 1024 * 1024 * 5)

/// The duration between two checks for grpc status.
- RAY_CONFIG(int32_t, gcs_client_check_connection_status_interval_milliseconds, 1000)
+ RAY_CONFIG(int32_t, grpc_client_check_connection_status_interval_milliseconds, 1000)

/// Due to the protocol drawback, raylet needs to refresh the message if
/// no message is received for a while.
@@ -693,6 +693,9 @@ RAY_CONFIG(int64_t, timeout_ms_task_wait_for_death_info, 1000)
/// report the loads to raylet.
RAY_CONFIG(int64_t, core_worker_internal_heartbeat_ms, 1000)

+ /// Timeout for core worker grpc server reconnection in seconds.
+ RAY_CONFIG(int32_t, core_worker_rpc_server_reconnect_timeout_s, 60)

/// Maximum amount of memory that will be used by running tasks' args.
RAY_CONFIG(float, max_task_args_memory_fraction, 0.7)

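Two knobs now govern client-side failure handling: the check interval, renamed from gcs_client_... to grpc_client_... since it now applies to all retryable grpc clients, and the new reconnect timeout after which a core worker peer is probed rather than retried forever. A sketch of how the two values could interact; the class and its methods are invented for illustration, and the actual wiring lives in the new src/ray/rpc/retryable_grpc_client.cc:

```cpp
#include <chrono>
#include <functional>
#include <optional>

// Sketch: poll the channel once per check interval; once the peer has been
// unreachable past the reconnect timeout, fire the unavailable callback
// instead of retrying forever.
class ServerUnavailableWatcher {
 public:
  ServerUnavailableWatcher(std::chrono::seconds reconnect_timeout,
                           std::function<void()> on_unavailable_timeout)
      : reconnect_timeout_(reconnect_timeout),
        on_unavailable_timeout_(std::move(on_unavailable_timeout)) {}

  // Invoked once per check interval with the current channel state.
  void OnCheck(bool server_reachable) {
    const auto now = std::chrono::steady_clock::now();
    if (server_reachable) {
      unreachable_since_.reset();  // healthy again; restart the clock
      return;
    }
    if (!unreachable_since_) {
      unreachable_since_ = now;
    } else if (now - *unreachable_since_ >= reconnect_timeout_) {
      on_unavailable_timeout_();  // e.g. probe whether the peer worker is dead
      unreachable_since_.reset();
    }
  }

 private:
  std::chrono::seconds reconnect_timeout_;
  std::function<void()> on_unavailable_timeout_;
  std::optional<std::chrono::steady_clock::time_point> unreachable_since_;
};
```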
58 changes: 46 additions & 12 deletions src/ray/core_worker/core_worker.cc
@@ -294,7 +294,45 @@ CoreWorker::CoreWorker(const CoreWorkerOptions &options, const WorkerID &worker_
}

core_worker_client_pool_ =
- std::make_shared<rpc::CoreWorkerClientPool>(*client_call_manager_);
+ std::make_shared<rpc::CoreWorkerClientPool>([&](const rpc::Address &addr) {
+ return std::make_shared<rpc::CoreWorkerClient>(
+ addr,
+ *client_call_manager_,
+ /*core_worker_unavailable_timeout_callback=*/[this, addr]() {
+ const NodeID node_id = NodeID::FromBinary(addr.raylet_id());
+ const WorkerID worker_id = WorkerID::FromBinary(addr.worker_id());
+ const rpc::GcsNodeInfo *node_info =
+ gcs_client_->Nodes().Get(node_id, /*filter_dead_nodes=*/true);
+ if (node_info == nullptr) {
+ RAY_LOG(INFO).WithField(worker_id).WithField(node_id)
+ << "Disconnect core worker client since its node is dead";
+ io_service_.post(
+ [this, worker_id]() {
+ core_worker_client_pool_->Disconnect(worker_id);
+ },
+ "CoreWorkerClientPool.Disconnect");
+
+ return;
+ }
+
+ std::shared_ptr<raylet::RayletClient> raylet_client =
+ std::make_shared<raylet::RayletClient>(
+ rpc::NodeManagerWorkerClient::make(
+ node_info->node_manager_address(),
+ node_info->node_manager_port(),
+ *client_call_manager_));
+ raylet_client->IsLocalWorkerDead(
+ worker_id,
+ [this, worker_id](const Status &status,
+ rpc::IsLocalWorkerDeadReply &&reply) {
+ if (status.ok() && reply.is_dead()) {
+ RAY_LOG(INFO).WithField(worker_id)
+ << "Disconnect core worker client since it is dead";
+ core_worker_client_pool_->Disconnect(worker_id);
+ }
+ });
+ });
+ });

object_info_publisher_ = std::make_unique<pubsub::Publisher>(
/*channels=*/std::vector<
@@ -743,14 +781,6 @@ void CoreWorker::Shutdown() {

task_event_buffer_->Stop();

- if (gcs_client_) {
- // We should disconnect gcs client first otherwise because it contains
- // a blocking logic that can block the io service upon
- // gcs shutdown.
- // TODO(sang): Refactor GCS client to be more robust.
- RAY_LOG(INFO) << "Disconnecting a GCS client.";
- gcs_client_->Disconnect();
- }
io_service_.stop();
RAY_LOG(INFO) << "Waiting for joining a core worker io thread. If it hangs here, there "
"might be deadlock or a high load in the core worker io service.";
@@ -763,7 +793,11 @@ void CoreWorker::Shutdown() {

// Now that gcs_client is not used within io service, we can reset the pointer and clean
// it up.
- gcs_client_.reset();
+ if (gcs_client_) {
+ RAY_LOG(INFO) << "Disconnecting a GCS client.";
+ gcs_client_->Disconnect();
+ gcs_client_.reset();
+ }

RAY_LOG(INFO) << "Core worker ready to be deallocated.";
}
@@ -3311,13 +3345,13 @@ Status CoreWorker::ReportGeneratorItemReturns(
if (status.ok()) {
num_objects_consumed = reply.total_num_object_consumed();
} else {
- // TODO(sang): Handle network error more gracefully.
// If the request fails, we should just resume until task finishes without
// backpressure.
num_objects_consumed = waiter->TotalObjectGenerated();
RAY_LOG(WARNING).WithField(return_id)
<< "Failed to report streaming generator return "
"to the caller. The yield'ed ObjectRef may not be usable.";
"to the caller. The yield'ed ObjectRef may not be usable. "
<< status;
}
waiter->HandleObjectReported(num_objects_consumed);
});
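The factory above attaches an unavailable-timeout callback to every CoreWorkerClient, and its decision flow is easy to lose in the nested lambdas. An illustrative distillation of the same logic (the hooks are parameters here; in the diff they are a GCS node lookup and the new IsLocalWorkerDead RPC):

```cpp
#include <functional>

// Illustrative distillation of the callback above: disconnect only when the
// peer is provably gone; otherwise keep retrying.
void OnPeerUnavailableTooLong(
    bool node_is_dead,  // from gcs_client_->Nodes().Get() with filter_dead_nodes
    const std::function<void(std::function<void(bool)>)> &ask_raylet_is_worker_dead,
    const std::function<void()> &disconnect) {
  if (node_is_dead) {
    disconnect();  // node gone => worker gone
    return;
  }
  // Node is alive: ask its raylet whether this particular worker still exists.
  ask_raylet_is_worker_dead([disconnect](bool is_dead) {
    if (is_dead) {
      disconnect();  // worker exited; stop retrying its RPCs
    }
    // If the worker is alive, do nothing: the client keeps retrying.
  });
}
```

Separately, note the Shutdown() reordering in this file: the GCS client is now disconnected only after io_service_ has been stopped and its thread joined, so its potentially blocking teardown can no longer stall work still queued on the io service.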
3 changes: 1 addition & 2 deletions src/ray/gcs/gcs_client/gcs_client.cc
@@ -176,7 +176,6 @@ Status GcsClient::FetchClusterId(int64_t timeout_ms) {
Status s = gcs_rpc_client_->SyncGetClusterId(request, &reply, timeout_ms);
if (!s.ok()) {
RAY_LOG(WARNING) << "Failed to get cluster ID from GCS server: " << s;
- gcs_rpc_client_->Shutdown();
gcs_rpc_client_.reset();
client_call_manager_.reset();
return s;
@@ -189,7 +188,7 @@

void GcsClient::Disconnect() {
if (gcs_rpc_client_) {
- gcs_rpc_client_->Shutdown();
+ gcs_rpc_client_.reset();
}
}

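FetchClusterId and Disconnect now simply drop the gcs_rpc_client_ reference instead of calling Shutdown() first (the old FetchClusterId error path did both). A sketch of the ownership pattern this presumably leans on, assuming the client's destructor performs the shutdown:

```cpp
#include <memory>

// Sketch: teardown via RAII. If the destructor runs Shutdown(), resetting the
// last shared_ptr is equivalent to the old explicit call and cannot be
// accidentally skipped on an error path.
class GcsRpcClientLike {
 public:
  ~GcsRpcClientLike() { Shutdown(); }
  void Shutdown() {
    if (shut_down_) return;  // idempotent
    shut_down_ = true;
    // cancel timers, drain pending calls, close the channel ...
  }

 private:
  bool shut_down_ = false;
};

void Disconnect(std::shared_ptr<GcsRpcClientLike> &client) {
  client.reset();  // destructor handles the rest if this was the last reference
}
```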
38 changes: 24 additions & 14 deletions src/ray/gcs/gcs_server/gcs_server.cc
@@ -411,7 +411,9 @@ void GcsServer::InitClusterTaskManager() {

void GcsServer::InitGcsJobManager(const GcsInitData &gcs_init_data) {
auto client_factory = [this](const rpc::Address &address) {
- return std::make_shared<rpc::CoreWorkerClient>(address, client_call_manager_);
+ return std::make_shared<rpc::CoreWorkerClient>(address, client_call_manager_, []() {
+ RAY_LOG(FATAL) << "GCS doesn't call any retryable core worker grpc methods.";
+ });
};
RAY_CHECK(gcs_table_storage_ && gcs_publisher_);
gcs_job_manager_ = std::make_unique<GcsJobManager>(gcs_table_storage_,
@@ -447,7 +449,9 @@ void GcsServer::InitGcsActorManager(const GcsInitData &gcs_init_data) {
gcs_actor_manager_->OnActorCreationSuccess(std::move(actor), reply);
};
auto client_factory = [this](const rpc::Address &address) {
- return std::make_shared<rpc::CoreWorkerClient>(address, client_call_manager_);
+ return std::make_shared<rpc::CoreWorkerClient>(address, client_call_manager_, []() {
+ RAY_LOG(FATAL) << "GCS doesn't call any retryable core worker grpc methods.";
+ });
};

RAY_CHECK(gcs_resource_manager_ && cluster_task_manager_);
@@ -464,18 +468,24 @@
[this](const NodeID &node_id, const rpc::ResourcesData &resources) {
gcs_resource_manager_->UpdateNodeNormalTaskResources(node_id, resources);
});
- gcs_actor_manager_ = std::make_shared<GcsActorManager>(
- std::move(scheduler),
- gcs_table_storage_,
- gcs_publisher_,
- *runtime_env_manager_,
- *function_manager_,
- [this](const ActorID &actor_id) {
- gcs_placement_group_manager_->CleanPlacementGroupIfNeededWhenActorDead(actor_id);
- },
- [this](const rpc::Address &address) {
- return std::make_shared<rpc::CoreWorkerClient>(address, client_call_manager_);
- });
+ gcs_actor_manager_ =
+ std::make_shared<GcsActorManager>(
+ std::move(scheduler),
+ gcs_table_storage_,
+ gcs_publisher_,
+ *runtime_env_manager_,
+ *function_manager_,
+ [this](const ActorID &actor_id) {
+ gcs_placement_group_manager_->CleanPlacementGroupIfNeededWhenActorDead(
+ actor_id);
+ },
+ [this](const rpc::Address &address) {
+ return std::make_shared<rpc::CoreWorkerClient>(
+ address, client_call_manager_, []() {
+ RAY_LOG(FATAL)
+ << "GCS doesn't call any retryable core worker grpc methods.";
+ });
+ });

// Initialize by gcs tables data.
gcs_actor_manager_->Initialize(gcs_init_data);
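Each CoreWorkerClient creator must now supply the unavailable-timeout callback, and the GCS call sites assert via RAY_LOG(FATAL) that they never issue retryable core worker calls, so the callback should be unreachable there. The same pattern appears in node_manager.cc and worker.cc below. A simplified sketch of the pool-plus-factory shape (illustrative types, not Ray's actual CoreWorkerClientPool API):

```cpp
#include <functional>
#include <memory>
#include <string>
#include <unordered_map>

// Sketch: the pool owns a factory, so each call site decides what
// "unavailable for too long" means for the clients it creates.
template <typename Client>
class ClientPoolLike {
 public:
  using Factory = std::function<std::shared_ptr<Client>(const std::string &addr)>;
  explicit ClientPoolLike(Factory factory) : factory_(std::move(factory)) {}

  std::shared_ptr<Client> GetOrConnect(const std::string &addr) {
    auto &client = clients_[addr];
    if (client == nullptr) {
      client = factory_(addr);  // unavailable callback baked in at construction
    }
    return client;
  }

  void Disconnect(const std::string &addr) { clients_.erase(addr); }

 private:
  Factory factory_;
  std::unordered_map<std::string, std::shared_ptr<Client>> clients_;
};
```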
4 changes: 4 additions & 0 deletions src/ray/gcs/gcs_server/test/gcs_server_test_util.h
@@ -321,6 +321,10 @@ struct GcsServerMocker {
drain_raylet_callbacks.push_back(callback);
};

+ void IsLocalWorkerDead(
+ const WorkerID &worker_id,
+ const rpc::ClientCallback<rpc::IsLocalWorkerDeadReply> &callback) override{};

void NotifyGCSRestart(
const rpc::ClientCallback<rpc::NotifyGCSRestartReply> &callback) override{};

11 changes: 11 additions & 0 deletions src/ray/protobuf/node_manager.proto
@@ -377,6 +377,16 @@ message PushMutableObjectReply {
bool done = 1;
}

+ message IsLocalWorkerDeadRequest {
+ // Binary worker id of the target worker.
+ bytes worker_id = 1;
+ }
+
+ message IsLocalWorkerDeadReply {
+ // Whether the target worker is dead or not.
+ bool is_dead = 1;
+ }

// Service for inter-node-manager communication.
service NodeManagerService {
// Handle the case when GCS restarted.
@@ -440,4 +450,5 @@
rpc RegisterMutableObject(RegisterMutableObjectRequest)
returns (RegisterMutableObjectReply);
rpc PushMutableObject(PushMutableObjectRequest) returns (PushMutableObjectReply);
+ rpc IsLocalWorkerDead(IsLocalWorkerDeadRequest) returns (IsLocalWorkerDeadReply);
}
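For reference, the message pair and service method above produce the usual generated stub; a raw synchronous call would look roughly like this (Ray's own clients go through the async rpc::GrpcClient wrapper instead; the include path and ray::rpc namespace are assumed from Ray's conventions):

```cpp
#include <string>

#include <grpcpp/grpcpp.h>

#include "src/ray/protobuf/node_manager.grpc.pb.h"

// Sketch: direct synchronous use of the generated stub, for illustration only.
bool IsWorkerDead(ray::rpc::NodeManagerService::Stub &stub,
                  const std::string &worker_id_binary) {
  grpc::ClientContext context;
  ray::rpc::IsLocalWorkerDeadRequest request;
  request.set_worker_id(worker_id_binary);  // bytes field: binary WorkerID
  ray::rpc::IsLocalWorkerDeadReply reply;
  const grpc::Status status = stub.IsLocalWorkerDead(&context, request, &reply);
  return status.ok() && reply.is_dead();
}
```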
14 changes: 13 additions & 1 deletion src/ray/raylet/node_manager.cc
@@ -145,7 +145,11 @@ NodeManager::NodeManager(
config.ray_debugger_external,
/*get_time=*/[]() { return absl::GetCurrentTimeNanos() / 1e6; }),
client_call_manager_(io_service),
- worker_rpc_pool_(client_call_manager_),
+ worker_rpc_pool_([&](const rpc::Address &addr) {
+ return std::make_shared<rpc::CoreWorkerClient>(addr, client_call_manager_, []() {
+ RAY_LOG(FATAL) << "Raylet doesn't call any retryable core worker grpc methods.";
+ });
+ }),
core_worker_subscriber_(std::make_unique<pubsub::Subscriber>(
self_node_id_,
/*channels=*/
@@ -1988,6 +1992,14 @@ void NodeManager::HandleReturnWorker(rpc::ReturnWorkerRequest request,
send_reply_callback(status, nullptr, nullptr);
}

+ void NodeManager::HandleIsLocalWorkerDead(rpc::IsLocalWorkerDeadRequest request,
+ rpc::IsLocalWorkerDeadReply *reply,
+ rpc::SendReplyCallback send_reply_callback) {
+ reply->set_is_dead(worker_pool_.GetRegisteredWorker(
+ WorkerID::FromBinary(request.worker_id())) == nullptr);
+ send_reply_callback(Status::OK(), nullptr, nullptr);
+ }

void NodeManager::HandleDrainRaylet(rpc::DrainRayletRequest request,
rpc::DrainRayletReply *reply,
rpc::SendReplyCallback send_reply_callback) {
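Note the handler's contract: is_dead means no worker with that id is currently registered, so an id this raylet has never seen also reads as dead. Callers such as the core worker callback earlier in the diff treat that as license to drop the cached connection. A self-contained restatement:

```cpp
#include <string>
#include <unordered_set>

// Restatement of HandleIsLocalWorkerDead's contract (assumed from the diff):
// dead == not currently registered with this raylet's worker pool.
bool IsLocalWorkerDead(const std::unordered_set<std::string> &registered_worker_ids,
                       const std::string &worker_id_binary) {
  return registered_worker_ids.count(worker_id_binary) == 0;
}
```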
4 changes: 4 additions & 0 deletions src/ray/raylet/node_manager.h
@@ -552,6 +552,10 @@ class NodeManager : public rpc::NodeManagerServiceHandler,
rpc::DrainRayletReply *reply,
rpc::SendReplyCallback send_reply_callback) override;

+ void HandleIsLocalWorkerDead(rpc::IsLocalWorkerDeadRequest request,
+ rpc::IsLocalWorkerDeadReply *reply,
+ rpc::SendReplyCallback send_reply_callback) override;

/// Handle a `CancelWorkerLease` request.
void HandleCancelWorkerLease(rpc::CancelWorkerLeaseRequest request,
rpc::CancelWorkerLeaseReply *reply,
4 changes: 3 additions & 1 deletion src/ray/raylet/worker.cc
@@ -117,7 +117,9 @@ void Worker::Connect(int port) {
rpc::Address addr;
addr.set_ip_address(ip_address_);
addr.set_port(port_);
- rpc_client_ = std::make_unique<rpc::CoreWorkerClient>(addr, client_call_manager_);
+ rpc_client_ = std::make_unique<rpc::CoreWorkerClient>(addr, client_call_manager_, []() {
+ RAY_LOG(FATAL) << "Raylet doesn't call any retryable core worker grpc methods.";
+ });
Connect(rpc_client_);
}

14 changes: 14 additions & 0 deletions src/ray/raylet/worker_pool.cc
@@ -866,6 +866,20 @@ Status WorkerPool::RegisterDriver(const std::shared_ptr<WorkerInterface> &driver
return Status::OK();
}

+ std::shared_ptr<WorkerInterface> WorkerPool::GetRegisteredWorker(
+ const WorkerID &worker_id) const {
+ for (const auto &entry : states_by_lang_) {
+ for (auto it = entry.second.registered_workers.begin();
+ it != entry.second.registered_workers.end();
+ it++) {
+ if ((*it)->WorkerId() == worker_id) {
+ return (*it);
+ }
+ }
+ }
+ return nullptr;
+ }

std::shared_ptr<WorkerInterface> WorkerPool::GetRegisteredWorker(
const std::shared_ptr<ClientConnection> &connection) const {
for (const auto &entry : states_by_lang_) {
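The new GetRegisteredWorker(worker_id) overload scans every registered worker across all language states, which is fine at the expected call rate of roughly one probe per long-unavailable peer. If lookups ever became hot, an id-keyed index would make them O(1); a hypothetical sketch, with WorkerInterface and a hashable WorkerID assumed from the surrounding code:

```cpp
#include <memory>
#include <unordered_map>

// Hypothetical alternative to the linear scan above: keep an id-keyed index
// alongside registered_workers and maintain it on register/disconnect.
class WorkerIndex {
 public:
  void Register(const std::shared_ptr<WorkerInterface> &worker) {
    by_id_[worker->WorkerId()] = worker;
  }
  void Unregister(const WorkerID &worker_id) { by_id_.erase(worker_id); }
  std::shared_ptr<WorkerInterface> Get(const WorkerID &worker_id) const {
    auto it = by_id_.find(worker_id);
    return it == by_id_.end() ? nullptr : it->second;
  }

 private:
  std::unordered_map<WorkerID, std::shared_ptr<WorkerInterface>> by_id_;
};
```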
3 changes: 3 additions & 0 deletions src/ray/raylet/worker_pool.h
@@ -294,6 +294,9 @@ class WorkerPool : public WorkerPoolInterface, public IOWorkerPoolInterface {
std::shared_ptr<WorkerInterface> GetRegisteredWorker(
const std::shared_ptr<ClientConnection> &connection) const;

+ /// Get the registered worker by worker id or nullptr if not found.
+ std::shared_ptr<WorkerInterface> GetRegisteredWorker(const WorkerID &worker_id) const;

/// Get the client connection's registered driver.
///
/// \param The client connection owned by a registered driver.
8 changes: 8 additions & 0 deletions src/ray/raylet_client/raylet_client.cc
@@ -595,6 +595,14 @@ void raylet::RayletClient::DrainRaylet(
grpc_client_->DrainRaylet(request, callback);
}

+ void raylet::RayletClient::IsLocalWorkerDead(
+ const WorkerID &worker_id,
+ const rpc::ClientCallback<rpc::IsLocalWorkerDeadReply> &callback) {
+ rpc::IsLocalWorkerDeadRequest request;
+ request.set_worker_id(worker_id.Binary());
+ grpc_client_->IsLocalWorkerDead(request, callback);
+ }

void raylet::RayletClient::GlobalGC(
const rpc::ClientCallback<rpc::GlobalGCReply> &callback) {
rpc::GlobalGCRequest request;
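Call sites use the new client method asynchronously, as the core worker's unavailable-timeout callback does earlier in this diff. The shape of a typical call (a sketch reusing the diff's types):

```cpp
// Sketch: probing a peer worker through its raylet. The callback receives the
// RPC status plus the reply; only a successful reply with is_dead=true is
// treated as proof of death.
raylet_client->IsLocalWorkerDead(
    worker_id,
    [worker_id](const Status &status, rpc::IsLocalWorkerDeadReply &&reply) {
      if (status.ok() && reply.is_dead()) {
        // Safe to disconnect any cached CoreWorkerClient for worker_id.
      }
    });
```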
8 changes: 8 additions & 0 deletions src/ray/raylet_client/raylet_client.h
@@ -239,6 +239,10 @@ class RayletClientInterface : public PinObjectsInterface,
int64_t deadline_timestamp_ms,
const rpc::ClientCallback<rpc::DrainRayletReply> &callback) = 0;

+ virtual void IsLocalWorkerDead(
+ const WorkerID &worker_id,
+ const rpc::ClientCallback<rpc::IsLocalWorkerDeadReply> &callback) = 0;

virtual std::shared_ptr<grpc::Channel> GetChannel() const = 0;
};

@@ -523,6 +527,10 @@ class RayletClient : public RayletClientInterface {
int64_t deadline_timestamp_ms,
const rpc::ClientCallback<rpc::DrainRayletReply> &callback) override;

+ void IsLocalWorkerDead(
+ const WorkerID &worker_id,
+ const rpc::ClientCallback<rpc::IsLocalWorkerDeadReply> &callback) override;

void GetSystemConfig(
const rpc::ClientCallback<rpc::GetSystemConfigReply> &callback) override;
