diff --git a/BUILD.bazel b/BUILD.bazel index b7a0014a77907..d13850f41a230 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -1850,6 +1850,18 @@ ray_cc_test( ], ) +ray_cc_test( + name = "virtual_cluster_manager_test", + size = "small", + srcs = ["src/ray/raylet/virtual_cluster_manager_test.cc"], + tags = ["team:core"], + deps = [ + ":ray_mock", + ":raylet_lib", + "@com_google_googletest//:gtest_main", + ], +) + ray_cc_library( name = "gcs_table_storage_lib", srcs = glob( diff --git a/src/ray/raylet/node_manager.cc b/src/ray/raylet/node_manager.cc index 3efe9a3d6712f..38f179037138c 100644 --- a/src/ray/raylet/node_manager.cc +++ b/src/ray/raylet/node_manager.cc @@ -405,6 +405,8 @@ NodeManager::NodeManager( mutable_object_provider_ = std::make_unique( *store_client_, absl::bind_front(&NodeManager::CreateRayletClient, this)); + + virtual_cluster_manager_ = std::make_shared(); } std::shared_ptr NodeManager::CreateRayletClient( @@ -505,8 +507,8 @@ ray::Status NodeManager::RegisterGcs() { // Subscribe to all virtual clusrter update notification. const auto virtual_cluster_update_notification_handler = [this](const VirtualClusterID &virtual_cluster_id, - const rpc::VirtualClusterTableData &virtual_cluster_data) { - // TODO(Shanly): To be implemented. + rpc::VirtualClusterTableData &&virtual_cluster_data) { + virtual_cluster_manager_->UpdateVirtualCluster(std::move(virtual_cluster_data)); }; RAY_RETURN_NOT_OK(gcs_client_->VirtualCluster().AsyncSubscribeAll( virtual_cluster_update_notification_handler, [](const ray::Status &status) { diff --git a/src/ray/raylet/node_manager.h b/src/ray/raylet/node_manager.h index ecfac284aa7dd..4e90f44d25b77 100644 --- a/src/ray/raylet/node_manager.h +++ b/src/ray/raylet/node_manager.h @@ -41,6 +41,7 @@ #include "ray/common/asio/instrumented_io_context.h" #include "ray/common/bundle_spec.h" #include "ray/raylet/placement_group_resource_manager.h" +#include "ray/raylet/virtual_cluster_manager.h" #include "ray/raylet/worker_killing_policy.h" #include "ray/core_worker/experimental_mutable_object_provider.h" // clang-format on @@ -894,6 +895,9 @@ class NodeManager : public rpc::NodeManagerServiceHandler, std::unique_ptr memory_monitor_; std::unique_ptr mutable_object_provider_; + + /// The virtual cluster manager. + std::shared_ptr virtual_cluster_manager_; }; } // namespace raylet diff --git a/src/ray/raylet/virtual_cluster_manager.cc b/src/ray/raylet/virtual_cluster_manager.cc new file mode 100644 index 0000000000000..2318ccadb42fe --- /dev/null +++ b/src/ray/raylet/virtual_cluster_manager.cc @@ -0,0 +1,70 @@ +// Copyright 2017 The Ray Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License + +#include "ray/raylet/virtual_cluster_manager.h" + +namespace ray { + +namespace raylet { + +//////////////////////// VirtualClusterManager //////////////////////// +bool VirtualClusterManager::UpdateVirtualCluster( + rpc::VirtualClusterTableData virtual_cluster_data) { + RAY_LOG(INFO) << "Virtual cluster updated: " << virtual_cluster_data.id(); + if (virtual_cluster_data.mode() != rpc::AllocationMode::MIXED) { + RAY_LOG(WARNING) << "The virtual cluster mode is not MIXED, ignore it."; + return false; + } + + const auto &virtual_cluster_id = virtual_cluster_data.id(); + auto it = virtual_clusters_.find(virtual_cluster_id); + if (it == virtual_clusters_.end()) { + virtual_clusters_[virtual_cluster_id] = std::move(virtual_cluster_data); + } else { + if (it->second.revision() > virtual_cluster_data.revision()) { + RAY_LOG(WARNING) + << "The revision of the received virtual cluster is outdated, ignore it."; + return false; + } + + if (virtual_cluster_data.is_removed()) { + virtual_clusters_.erase(it); + return true; + } + + it->second = std::move(virtual_cluster_data); + } + return true; +} + +bool VirtualClusterManager::ContainsVirtualCluster( + const std::string &virtual_cluster_id) const { + return virtual_clusters_.find(virtual_cluster_id) != virtual_clusters_.end(); +} + +bool VirtualClusterManager::ContainsNodeInstance(const std::string &virtual_cluster_id, + const NodeID &node_id) const { + auto it = virtual_clusters_.find(virtual_cluster_id); + if (it == virtual_clusters_.end()) { + return false; + } + const auto &virtual_cluster_data = it->second; + RAY_CHECK(virtual_cluster_data.mode() == rpc::AllocationMode::MIXED); + + const auto &node_instances = virtual_cluster_data.node_instances(); + return node_instances.find(node_id.Hex()) != node_instances.end(); +} + +} // namespace raylet +} // namespace ray diff --git a/src/ray/raylet/virtual_cluster_manager.h b/src/ray/raylet/virtual_cluster_manager.h new file mode 100644 index 0000000000000..59f0de3ff558a --- /dev/null +++ b/src/ray/raylet/virtual_cluster_manager.h @@ -0,0 +1,53 @@ +// Copyright 2017 The Ray Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "ray/common/id.h" +#include "src/ray/protobuf/gcs_service.pb.h" + +namespace ray { + +namespace raylet { + +class VirtualClusterManager { + public: + VirtualClusterManager() = default; + + /// Update the virtual cluster. + /// + /// \param virtual_cluster_data The virtual cluster data. + bool UpdateVirtualCluster(rpc::VirtualClusterTableData virtual_cluster_data); + + /// Check if the virtual cluster exists. + /// + /// \param virtual_cluster_id The virtual cluster id. + /// \return Whether the virtual cluster exists. + bool ContainsVirtualCluster(const std::string &virtual_cluster_id) const; + + /// Check if the virtual cluster contains the node instance. + /// + /// \param virtual_cluster_id The virtual cluster id. + /// \param node_id The node instance id. + /// \return Whether the virtual cluster contains the node instance. + bool ContainsNodeInstance(const std::string &virtual_cluster_id, + const NodeID &node_id) const; + + private: + /// The virtual clusters. + absl::flat_hash_map virtual_clusters_; +}; + +} // namespace raylet +} // end namespace ray diff --git a/src/ray/raylet/virtual_cluster_manager_test.cc b/src/ray/raylet/virtual_cluster_manager_test.cc new file mode 100644 index 0000000000000..526cc040ca09f --- /dev/null +++ b/src/ray/raylet/virtual_cluster_manager_test.cc @@ -0,0 +1,85 @@ +// Copyright 2017 The Ray Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License + +#include "ray/raylet/virtual_cluster_manager.h" + +#include "absl/container/flat_hash_set.h" +#include "gtest/gtest.h" + +namespace ray { + +namespace raylet { + +class VirtualClusterManagerTest : public ::testing::Test {}; + +TEST_F(VirtualClusterManagerTest, UpdateVirtualCluster) { + VirtualClusterManager virtual_cluster_manager; + + std::string virtual_cluster_id_0 = "virtual_cluster_id_0"; + ASSERT_FALSE(virtual_cluster_manager.ContainsVirtualCluster("virtual_cluster_id")); + + rpc::VirtualClusterTableData virtual_cluster_data; + virtual_cluster_data.set_id(virtual_cluster_id_0); + virtual_cluster_data.set_mode(rpc::AllocationMode::EXCLUSIVE); + virtual_cluster_data.set_revision(100); + for (size_t i = 0; i < 100; ++i) { + auto node_id = NodeID::FromRandom(); + virtual_cluster_data.mutable_node_instances()->insert( + {node_id.Hex(), ray::rpc::NodeInstance()}); + } + ASSERT_FALSE(virtual_cluster_manager.UpdateVirtualCluster(virtual_cluster_data)); + ASSERT_FALSE(virtual_cluster_manager.ContainsVirtualCluster(virtual_cluster_id_0)); + + virtual_cluster_data.set_mode(rpc::AllocationMode::MIXED); + ASSERT_TRUE(virtual_cluster_manager.UpdateVirtualCluster(virtual_cluster_data)); + ASSERT_TRUE(virtual_cluster_manager.ContainsVirtualCluster(virtual_cluster_id_0)); + + virtual_cluster_data.set_revision(50); + ASSERT_FALSE(virtual_cluster_manager.UpdateVirtualCluster(virtual_cluster_data)); + + virtual_cluster_data.set_revision(150); + ASSERT_TRUE(virtual_cluster_manager.UpdateVirtualCluster(virtual_cluster_data)); + + virtual_cluster_data.set_is_removed(true); + ASSERT_TRUE(virtual_cluster_manager.UpdateVirtualCluster(virtual_cluster_data)); + ASSERT_FALSE(virtual_cluster_manager.ContainsVirtualCluster(virtual_cluster_id_0)); +} + +TEST_F(VirtualClusterManagerTest, TestContainsNodeInstance) { + VirtualClusterManager virtual_cluster_manager; + std::string virtual_cluster_id_0 = "virtual_cluster_id_0"; + + rpc::VirtualClusterTableData virtual_cluster_data; + virtual_cluster_data.set_id(virtual_cluster_id_0); + virtual_cluster_data.set_mode(rpc::AllocationMode::MIXED); + virtual_cluster_data.set_revision(100); + absl::flat_hash_set node_ids; + for (size_t i = 0; i < 100; ++i) { + auto node_id = NodeID::FromRandom(); + node_ids.emplace(node_id); + + virtual_cluster_data.mutable_node_instances()->insert( + {node_id.Hex(), ray::rpc::NodeInstance()}); + } + ASSERT_TRUE(virtual_cluster_manager.UpdateVirtualCluster(virtual_cluster_data)); + ASSERT_TRUE(virtual_cluster_manager.ContainsVirtualCluster(virtual_cluster_id_0)); + + for (const auto &node_id : node_ids) { + ASSERT_TRUE( + virtual_cluster_manager.ContainsNodeInstance(virtual_cluster_id_0, node_id)); + } +} + +} // namespace raylet +} // namespace ray