From 916959102ee4029be0adc1de0fa844abd63d46fe Mon Sep 17 00:00:00 2001 From: sitaowang1998 Date: Wed, 18 Dec 2024 02:03:10 -0500 Subject: [PATCH] feat: Add worker that request tasks from scheduler and execute tasks (#38) --- src/spider/CMakeLists.txt | 1 + src/spider/scheduler/scheduler.cpp | 2 +- src/spider/storage/MetadataStorage.hpp | 2 + src/spider/storage/MysqlStorage.cpp | 67 +++- src/spider/storage/MysqlStorage.hpp | 1 + src/spider/worker/FunctionManager.hpp | 99 +++++- src/spider/worker/TaskExecutor.cpp | 42 ++- src/spider/worker/TaskExecutor.hpp | 48 ++- src/spider/worker/WorkerClient.cpp | 6 +- src/spider/worker/WorkerClient.hpp | 2 - src/spider/worker/task_executor.cpp | 8 + src/spider/worker/worker.cpp | 384 ++++++++++++++++++++++- tests/scheduler/test-SchedulerServer.cpp | 12 + tests/worker/test-FunctionManager.cpp | 7 +- tests/worker/test-MessagePipe.cpp | 2 +- tests/worker/worker-test.cpp | 2 + 16 files changed, 657 insertions(+), 28 deletions(-) diff --git a/src/spider/CMakeLists.txt b/src/spider/CMakeLists.txt index 5271c6a..4504bd9 100644 --- a/src/spider/CMakeLists.txt +++ b/src/spider/CMakeLists.txt @@ -51,6 +51,7 @@ set(SPIDER_WORKER_SOURCES worker/message_pipe.hpp worker/WorkerClient.hpp worker/WorkerClient.cpp + utils/StopToken.hpp CACHE INTERNAL "spider worker source files" ) diff --git a/src/spider/scheduler/scheduler.cpp b/src/spider/scheduler/scheduler.cpp index e72b791..216d64e 100644 --- a/src/spider/scheduler/scheduler.cpp +++ b/src/spider/scheduler/scheduler.cpp @@ -128,7 +128,7 @@ auto cleanup_loop( auto main(int argc, char** argv) -> int { // Set up spdlog to write to stderr // NOLINTNEXTLINE(misc-include-cleaner) - spdlog::set_pattern("[%Y-%m-%d %H:%M:%S.%e] [%^%l%$] [spider][scheduler] %v"); + spdlog::set_pattern("[%Y-%m-%d %H:%M:%S.%e] [%^%l%$] [spider.scheduler] %v"); #ifndef NDEBUG spdlog::set_level(spdlog::level::trace); #endif diff --git a/src/spider/storage/MetadataStorage.hpp b/src/spider/storage/MetadataStorage.hpp index d0abd86..adb43a5 100644 --- a/src/spider/storage/MetadataStorage.hpp +++ b/src/spider/storage/MetadataStorage.hpp @@ -50,6 +50,8 @@ class MetadataStorage { virtual auto add_task_instance(TaskInstance const& instance) -> StorageErr = 0; virtual auto task_finish(TaskInstance const& instance, std::vector const& outputs) -> StorageErr = 0; + virtual auto task_fail(TaskInstance const& instance, std::string const& error) -> StorageErr + = 0; virtual auto get_task_timeout(std::vector* tasks) -> StorageErr = 0; virtual auto get_child_tasks(boost::uuids::uuid id, std::vector* children) -> StorageErr = 0; diff --git a/src/spider/storage/MysqlStorage.cpp b/src/spider/storage/MysqlStorage.cpp index 62129a9..a9efbe4 100644 --- a/src/spider/storage/MysqlStorage.cpp +++ b/src/spider/storage/MysqlStorage.cpp @@ -8,6 +8,7 @@ #include #include #include +#include #include #include #include @@ -27,6 +28,7 @@ #include #include #include +#include #include "../core/Data.hpp" #include "../core/Driver.hpp" @@ -241,11 +243,21 @@ auto string_to_task_state(std::string const& state) -> spider::core::TaskState { } // namespace auto MySqlMetadataStorage::connect(std::string const& url) -> StorageErr { + // Parse jdbc url + std::regex const url_regex(R"(jdbc:mariadb://[^?]+(\?user=([^&]*)(&password=([^&]*))?)?)"); + std::smatch match; + if (false == std::regex_match(url, match, url_regex)) { + return StorageErr{StorageErrType::OtherErr, "Invalid url"}; + } + bool const credential = match[2].matched && match[4].matched; if (nullptr == m_conn) { try { sql::Driver* driver = sql::mariadb::get_driver_instance(); - sql::Properties const properties; - m_conn = driver->connect(sql::SQLString(url), properties); + if (credential) { + m_conn = driver->connect(sql::SQLString(url), match[2].str(), match[4].str()); + } else { + m_conn = driver->connect(sql::SQLString(url), sql::Properties{}); + } m_conn->setAutoCommit(false); } catch (sql::SQLException& e) { return StorageErr{StorageErrType::ConnectionErr, e.what()}; @@ -1073,6 +1085,43 @@ auto MySqlMetadataStorage::task_finish( return StorageErr{}; } +auto MySqlMetadataStorage::task_fail(TaskInstance const& instance, std::string const& /*error*/) + -> StorageErr { + try { + // Remove task instance + std::unique_ptr const statement( + m_conn->prepareStatement("DELETE FROM `task_instances` WHERE `id` = ?") + ); + sql::bytes instance_id_bytes = uuid_get_bytes(instance.id); + statement->setBytes(1, &instance_id_bytes); + statement->executeUpdate(); + + // Get number of remaining instances + std::unique_ptr const count_statement(m_conn->prepareStatement( + "SELECT COUNT(*) FROM `task_instances` WHERE `task_id` = ?" + )); + sql::bytes task_id_bytes = uuid_get_bytes(instance.task_id); + count_statement->setBytes(1, &task_id_bytes); + std::unique_ptr const count_res{count_statement->executeQuery()}; + count_res->next(); + int32_t const count = count_res->getInt(1); + if (count == 0) { + // Set the task fail if the last task instance fails + std::unique_ptr const task_statement( + m_conn->prepareStatement("UPDATE `tasks` SET `state` = 'fail' WHERE `id` = ?") + ); + task_statement->setBytes(1, &task_id_bytes); + task_statement->executeUpdate(); + } + } catch (sql::SQLException& e) { + spdlog::error("Task fail error: {}", e.what()); + m_conn->rollback(); + return StorageErr{StorageErrType::OtherErr, e.what()}; + } + m_conn->commit(); + return StorageErr{}; +} + auto MySqlMetadataStorage::get_task_timeout(std::vector* tasks) -> StorageErr { try { std::unique_ptr statement(m_conn->createStatement()); @@ -1262,11 +1311,21 @@ auto MySqlMetadataStorage::set_scheduler_state(boost::uuids::uuid id, std::strin } auto MySqlDataStorage::connect(std::string const& url) -> StorageErr { + // Parse jdbc url + std::regex const url_regex(R"(jdbc:mariadb://[^?]+(\?user=([^&]*)(&password=([^&]*))?)?)"); + std::smatch match; + if (false == std::regex_match(url, match, url_regex)) { + return StorageErr{StorageErrType::OtherErr, "Invalid url"}; + } + bool const credential = match[2].matched && match[4].matched; if (nullptr == m_conn) { try { sql::Driver* driver = sql::mariadb::get_driver_instance(); - sql::Properties const properties; - m_conn = driver->connect(sql::SQLString(url), properties); + if (credential) { + m_conn = driver->connect(sql::SQLString(url), match[2].str(), match[4].str()); + } else { + m_conn = driver->connect(sql::SQLString(url), sql::Properties{}); + } m_conn->setAutoCommit(false); } catch (sql::SQLException& e) { return StorageErr{StorageErrType::ConnectionErr, e.what()}; diff --git a/src/spider/storage/MysqlStorage.hpp b/src/spider/storage/MysqlStorage.hpp index 8be9d8c..79244a7 100644 --- a/src/spider/storage/MysqlStorage.hpp +++ b/src/spider/storage/MysqlStorage.hpp @@ -54,6 +54,7 @@ class MySqlMetadataStorage : public MetadataStorage { auto add_task_instance(TaskInstance const& instance) -> StorageErr override; auto task_finish(TaskInstance const& instance, std::vector const& outputs) -> StorageErr override; + auto task_fail(TaskInstance const& instance, std::string const& error) -> StorageErr override; auto get_task_timeout(std::vector* tasks) -> StorageErr override; auto get_child_tasks(boost::uuids::uuid id, std::vector* children) -> StorageErr override; auto get_parent_tasks(boost::uuids::uuid id, std::vector* tasks) -> StorageErr override; diff --git a/src/spider/worker/FunctionManager.hpp b/src/spider/worker/FunctionManager.hpp index 7818bd6..6d012e4 100644 --- a/src/spider/worker/FunctionManager.hpp +++ b/src/spider/worker/FunctionManager.hpp @@ -11,11 +11,14 @@ #include #include #include +#include #include #include +#include #include "../io/MsgPack.hpp" // IWYU pragma: keep +#include "../io/Serializer.hpp" #include "TaskExecutorMessage.hpp" // NOLINTBEGIN(cppcoreguidelines-macro-usage) @@ -102,7 +105,7 @@ void create_error_buffer( msgpack::sbuffer& buffer ); -template +template auto response_get_result(msgpack::sbuffer const& buffer) -> std::optional { // NOLINTBEGIN(cppcoreguidelines-pro-type-union-access,cppcoreguidelines-pro-bounds-pointer-arithmetic) try { @@ -119,14 +122,81 @@ auto response_get_result(msgpack::sbuffer const& buffer) -> std::optional { return std::nullopt; } - return object.via.array.ptr[1].as(); + return std::make_optional(object.via.array.ptr[1].as()); } catch (msgpack::type_error& e) { return std::nullopt; } // NOLINTEND(cppcoreguidelines-pro-type-union-access,cppcoreguidelines-pro-bounds-pointer-arithmetic) } -template +template +requires(sizeof...(Ts) > 1) +auto response_get_result(msgpack::sbuffer const& buffer) -> std::optional> { + // NOLINTBEGIN(cppcoreguidelines-pro-type-union-access,cppcoreguidelines-pro-bounds-pointer-arithmetic) + try { + msgpack::object_handle const handle = msgpack::unpack(buffer.data(), buffer.size()); + msgpack::object const object = handle.get(); + + if (msgpack::type::ARRAY != object.type || sizeof...(Ts) + 1 != object.via.array.size) { + return std::nullopt; + } + + if (worker::TaskExecutorResponseType::Result + != object.via.array.ptr[0].as()) + { + return std::nullopt; + } + + std::tuple result; + for_n([&](auto i) { + object.via.array.ptr[i.cValue + 1].convert(std::get(result)); + }); + return std::make_optional(result); + } catch (msgpack::type_error& e) { + return std::nullopt; + } + // NOLINTEND(cppcoreguidelines-pro-type-union-access,cppcoreguidelines-pro-bounds-pointer-arithmetic) +} + +inline auto response_get_result_buffers(msgpack::sbuffer const& buffer +) -> std::optional> { + // NOLINTBEGIN(cppcoreguidelines-pro-type-union-access,cppcoreguidelines-pro-bounds-pointer-arithmetic) + try { + std::vector result_buffers; + msgpack::object_handle const handle = msgpack::unpack(buffer.data(), buffer.size()); + msgpack::object const object = handle.get(); + + if (msgpack::type::ARRAY != object.type || object.via.array.size < 2) { + spdlog::error("Cannot split result into buffers: Wrong type"); + return std::nullopt; + } + + if (worker::TaskExecutorResponseType::Result + != object.via.array.ptr[0].as()) + { + spdlog::error( + "Cannot split result into buffers: Wrong response type {}", + static_cast>( + object.via.array.ptr[0].as() + ) + ); + return std::nullopt; + } + + for (size_t i = 1; i < object.via.array.size; ++i) { + msgpack::object const& obj = object.via.array.ptr[i]; + result_buffers.emplace_back(); + msgpack::pack(result_buffers.back(), obj); + } + return result_buffers; + } catch (msgpack::type_error& e) { + spdlog::error("Cannot split result into buffers: {}", e.what()); + return std::nullopt; + } + // NOLINTEND(cppcoreguidelines-pro-type-union-access,cppcoreguidelines-pro-bounds-pointer-arithmetic) +} + +template auto create_result_response(T const& t) -> msgpack::sbuffer { msgpack::sbuffer buffer; msgpack::packer packer{buffer}; @@ -136,6 +206,16 @@ auto create_result_response(T const& t) -> msgpack::sbuffer { return buffer; } +template +auto create_result_response(std::tuple const& t) -> msgpack::sbuffer { + msgpack::sbuffer buffer; + msgpack::packer packer{buffer}; + packer.pack_array(sizeof...(Values) + 1); + packer.pack(worker::TaskExecutorResponseType::Result); + (..., packer.pack(std::get(t))); + return buffer; +} + // NOLINTBEGIN(cppcoreguidelines-missing-std-forward) template auto create_args_buffers(Args&&... args) -> ArgsBuffer { @@ -157,6 +237,19 @@ auto create_args_request(Args&&... args) -> msgpack::sbuffer { return buffer; } +inline auto create_args_request(std::vector const& args_buffers +) -> msgpack::sbuffer { + msgpack::sbuffer buffer; + msgpack::packer packer{buffer}; + packer.pack_array(2); + packer.pack(worker::TaskExecutorRequestType::Arguments); + packer.pack_array(args_buffers.size()); + for (msgpack::sbuffer const& args_buffer : args_buffers) { + buffer.write(args_buffer.data(), args_buffer.size()); + } + return buffer; +} + // NOLINTEND(cppcoreguidelines-missing-std-forward) template diff --git a/src/spider/worker/TaskExecutor.cpp b/src/spider/worker/TaskExecutor.cpp index d32819d..aabc1f8 100644 --- a/src/spider/worker/TaskExecutor.cpp +++ b/src/spider/worker/TaskExecutor.cpp @@ -4,6 +4,7 @@ #include #include #include +#include #include #include @@ -18,7 +19,8 @@ namespace spider::worker { auto TaskExecutor::completed() -> bool { std::lock_guard const lock(m_state_mutex); - return TaskExecutorState::Succeed == m_state || TaskExecutorState::Error == m_state; + return TaskExecutorState::Succeed == m_state || TaskExecutorState::Error == m_state + || TaskExecutorState::Cancelled == m_state; } auto TaskExecutor::waiting() -> bool { @@ -48,7 +50,14 @@ void TaskExecutor::wait() { m_result_buffer ); } + return; } + std::unique_lock lock(m_state_mutex); + m_complete_cv.wait(lock, [this] { + return TaskExecutorState::Succeed == m_state || TaskExecutorState::Error == m_state + || TaskExecutorState::Cancelled == m_state; + }); + lock.unlock(); } void TaskExecutor::cancel() { @@ -79,23 +88,32 @@ auto TaskExecutor::process_output_handler() -> boost::asio::awaitable { case TaskExecutorResponseType::Block: break; case TaskExecutorResponseType::Error: { - std::lock_guard const lock(m_state_mutex); - m_state = TaskExecutorState::Error; - m_result_buffer.write(response.data(), response.size()); + { + std::lock_guard const lock(m_state_mutex); + m_state = TaskExecutorState::Error; + m_result_buffer.write(response.data(), response.size()); + } + m_complete_cv.notify_all(); co_return; } case TaskExecutorResponseType::Ready: break; case TaskExecutorResponseType::Result: { - std::lock_guard const lock(m_state_mutex); - m_state = TaskExecutorState::Succeed; - m_result_buffer.write(response.data(), response.size()); + { + std::lock_guard const lock(m_state_mutex); + m_state = TaskExecutorState::Succeed; + m_result_buffer.write(response.data(), response.size()); + } + m_complete_cv.notify_all(); co_return; } case TaskExecutorResponseType::Cancel: { - std::lock_guard const lock(m_state_mutex); - m_state = TaskExecutorState::Cancelled; - m_result_buffer.write(response.data(), response.size()); + { + std::lock_guard const lock(m_state_mutex); + m_state = TaskExecutorState::Cancelled; + m_result_buffer.write(response.data(), response.size()); + } + m_complete_cv.notify_all(); co_return; } case TaskExecutorResponseType::Unknown: @@ -106,6 +124,10 @@ auto TaskExecutor::process_output_handler() -> boost::asio::awaitable { // NOLINTEND(clang-analyzer-core.CallAndMessage) +auto TaskExecutor::get_result_buffers() const -> std::optional> { + return core::response_get_result_buffers(m_result_buffer); +} + auto TaskExecutor::get_error() const -> std::tuple { return core::response_get_error(m_result_buffer) .value_or(std::make_tuple( diff --git a/src/spider/worker/TaskExecutor.hpp b/src/spider/worker/TaskExecutor.hpp index 1c9a4dd..25fc93b 100644 --- a/src/spider/worker/TaskExecutor.hpp +++ b/src/spider/worker/TaskExecutor.hpp @@ -1,6 +1,7 @@ #ifndef SPIDER_WORKER_TASKEXECUTOR_HPP #define SPIDER_WORKER_TASKEXECUTOR_HPP +#include #include #include #include @@ -34,7 +35,7 @@ enum class TaskExecutorState : std::uint8_t { class TaskExecutor { public: template - explicit TaskExecutor( + TaskExecutor( boost::asio::io_context& context, std::string const& func_name, std::vector const& libs, @@ -67,7 +68,45 @@ class TaskExecutor { boost::asio::co_spawn(context, process_output_handler(), boost::asio::detached); // Send args - msgpack::sbuffer args_request = core::create_args_request(std::forward(args)...); + msgpack::sbuffer const args_request + = core::create_args_request(std::forward(args)...); + send_message(m_write_pipe, args_request); + } + + TaskExecutor( + boost::asio::io_context& context, + std::string const& func_name, + std::vector const& libs, + absl::flat_hash_map< + boost::process::v2::environment::key, + boost::process::v2::environment::value> const& environment, + std::vector const& args_buffers + ) + : m_read_pipe(context), + m_write_pipe(context) { + std::vector process_args{"--func", func_name, "--libs"}; + process_args.insert(process_args.end(), libs.begin(), libs.end()); + boost::filesystem::path const exe = boost::process::v2::environment::find_executable( + "spider_task_executor", + environment + ); + m_process = std::make_unique( + context, + exe, + process_args, + boost::process::v2::process_stdio{ + .in = m_write_pipe, + .out = m_read_pipe, + .err = {/*stderr to default*/} + }, + boost::process::v2::process_environment{environment} + ); + + // Set up handler for output file + boost::asio::co_spawn(context, process_output_handler(), boost::asio::detached); + + // Send args + msgpack::sbuffer const args_request = core::create_args_request(args_buffers); send_message(m_write_pipe, args_request); } @@ -87,16 +126,19 @@ class TaskExecutor { void cancel(); template - auto get_result() -> std::optional { + auto get_result() const -> std::optional { return core::response_get_result(m_result_buffer); } + [[nodiscard]] auto get_result_buffers() const -> std::optional>; + [[nodiscard]] auto get_error() const -> std::tuple; private: auto process_output_handler() -> boost::asio::awaitable; std::mutex m_state_mutex; + std::condition_variable m_complete_cv; TaskExecutorState m_state = TaskExecutorState::Running; // Use `std::unique_ptr` to work around requirement of default constructor diff --git a/src/spider/worker/WorkerClient.cpp b/src/spider/worker/WorkerClient.cpp index c894604..0a3b6be 100644 --- a/src/spider/worker/WorkerClient.cpp +++ b/src/spider/worker/WorkerClient.cpp @@ -70,7 +70,8 @@ auto WorkerClient::get_next_task() -> std::optional { ); try { // Create socket to scheduler - boost::asio::ip::tcp::socket socket(m_context); + boost::asio::io_context context; + boost::asio::ip::tcp::socket socket(context); boost::asio::connect(socket, endpoints); scheduler::ScheduleTaskRequest const request{m_worker_id, m_worker_addr}; @@ -92,6 +93,9 @@ auto WorkerClient::get_next_task() -> std::optional { = msgpack::unpack(response_buffer.data(), response_buffer.size()); response_handle.get().convert(response); + if (!response.has_task_id()) { + return std::nullopt; + } return response.get_task_id(); } catch (boost::system::system_error const& e) { return std::nullopt; diff --git a/src/spider/worker/WorkerClient.hpp b/src/spider/worker/WorkerClient.hpp index 3b8a9b7..0ccb007 100644 --- a/src/spider/worker/WorkerClient.hpp +++ b/src/spider/worker/WorkerClient.hpp @@ -41,8 +41,6 @@ class WorkerClient { boost::uuids::uuid m_worker_id; std::string m_worker_addr; - boost::asio::io_context m_context; - std::shared_ptr m_data_store; std::shared_ptr m_metadata_store; }; diff --git a/src/spider/worker/task_executor.cpp b/src/spider/worker/task_executor.cpp index 796576d..7901808 100644 --- a/src/spider/worker/task_executor.cpp +++ b/src/spider/worker/task_executor.cpp @@ -57,6 +57,10 @@ auto main(int const argc, char** argv) -> int { // Set up spdlog to write to stderr // NOLINTNEXTLINE(misc-include-cleaner) spdlog::set_default_logger(spdlog::stderr_color_mt("stderr")); + spdlog::set_pattern("[%Y-%m-%d %H:%M:%S.%e] [%^%l%$] [spider.executor] %v"); +#ifndef NDEBUG + spdlog::set_level(spdlog::level::trace); +#endif boost::program_options::variables_map const args = parse_arg(argc, argv); @@ -82,6 +86,8 @@ auto main(int const argc, char** argv) -> int { return cCmdArgParseErr; } + spdlog::debug("Function to run: {}", func_name); + try { // Set up asio boost::asio::io_context context; @@ -105,6 +111,7 @@ auto main(int const argc, char** argv) -> int { msgpack::sbuffer args_buffer; msgpack::packer packer{args_buffer}; packer.pack(args_object); + spdlog::debug("Args buffer parsed"); // Run function spider::core::Function const* function @@ -120,6 +127,7 @@ auto main(int const argc, char** argv) -> int { return cResultSendErr; } msgpack::sbuffer const result_buffer = (*function)(args_buffer); + spdlog::debug("Function executed"); // Write result buffer to stdout spider::worker::send_message(out, result_buffer); diff --git a/src/spider/worker/worker.cpp b/src/spider/worker/worker.cpp index dfadebc..c80a0fe 100644 --- a/src/spider/worker/worker.cpp +++ b/src/spider/worker/worker.cpp @@ -1,4 +1,386 @@ -auto main(int /*argc*/, char** /*argv*/) -> int { +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include // IWYU pragma: keep +#include + +#include "../core/Data.hpp" +#include "../core/Driver.hpp" +#include "../core/Error.hpp" +#include "../core/Task.hpp" +#include "../io/BoostAsio.hpp" // IWYU pragma: keep +#include "../io/MsgPack.hpp" // IWYU pragma: keep +#include "../io/Serializer.hpp" // IWYU pragma: keep +#include "../storage/DataStorage.hpp" +#include "../storage/MetadataStorage.hpp" +#include "../storage/MysqlStorage.hpp" +#include "../utils/StopToken.hpp" +#include "TaskExecutor.hpp" +#include "WorkerClient.hpp" + +constexpr int cCmdArgParseErr = 1; +constexpr int cWorkerAddrErr = 2; +constexpr int cStorageConnectionErr = 3; +constexpr int cStorageErr = 4; +constexpr int cTaskErr = 5; + +constexpr int cRetryCount = 5; + +namespace { +auto parse_args(int const argc, char** argv) -> boost::program_options::variables_map { + boost::program_options::options_description desc; + desc.add_options()("help", "spider scheduler"); + desc.add_options()( + "storage_url", + boost::program_options::value(), + "storage server url" + ); + desc.add_options()( + "libs", + boost::program_options::value>(), + "dynamic libraries that include the spider tasks" + ); + + boost::program_options::variables_map variables; + boost::program_options::store( + // NOLINTNEXTLINE(misc-include-cleaner) + boost::program_options::parse_command_line(argc, argv, desc), + variables + ); + boost::program_options::notify(variables); + return variables; +} + +auto get_environment_variable() -> absl::flat_hash_map< + boost::process::v2::environment::key, + boost::process::v2::environment::value> { + boost::filesystem::path const executable_dir = boost::dll::program_location().parent_path(); + + // NOLINTNEXTLINE(concurrency-mt-unsafe) + char const* path_env_str = std::getenv("PATH"); + std::string path_env = nullptr == path_env_str ? "" : path_env_str; + path_env.append(":"); + path_env.append(executable_dir.string()); + + absl::flat_hash_map< + boost::process::v2::environment::key, + boost::process::v2::environment::value> + environment_variables; + + environment_variables.emplace("PATH", path_env); + + return environment_variables; +} + +auto heartbeat_loop( + std::shared_ptr const& metadata_store, + spider::core::Driver const& driver, + spider::core::StopToken& stop_token +) -> void { + int fail_count = 0; + while (!stop_token.stop_requested()) { + std::this_thread::sleep_for(std::chrono::seconds(1)); + spdlog::debug("Updating heartbeat"); + spider::core::StorageErr const err = metadata_store->update_heartbeat(driver.get_id()); + if (!err.success()) { + spdlog::error("Failed to update scheduler heartbeat: {}", err.description); + fail_count++; + } else { + fail_count = 0; + } + if (fail_count >= cRetryCount - 1) { + stop_token.request_stop(); + break; + } + } +} + +constexpr int cFetchTaskTimeout = 100; + +auto fetch_task(spider::worker::WorkerClient& client) -> boost::uuids::uuid { + spdlog::debug("Fetching task"); + while (true) { + std::optional const optional_task_id = client.get_next_task(); + if (optional_task_id.has_value()) { + return optional_task_id.value(); + } + std::this_thread::sleep_for(std::chrono::milliseconds(cFetchTaskTimeout)); + } +} + +auto get_args_buffers(spider::core::Task const& task +) -> std::optional> { + std::vector args_buffers; + for (spider::core::TaskInput const& input : task.get_inputs()) { + std::optional const optional_value = input.get_value(); + if (optional_value.has_value()) { + std::string const& value = optional_value.value(); + args_buffers.emplace_back(); + args_buffers.back().write(value.data(), value.size()); + continue; + } + std::optional const optional_data_id = input.get_data_id(); + if (optional_data_id.has_value()) { + boost::uuids::uuid const data_id = optional_data_id.value(); + args_buffers.emplace_back(); + msgpack::pack(args_buffers.back(), data_id); + continue; + } + spdlog::error( + "Task {} {} input has no value or data id", + task.get_function_name(), + boost::uuids::to_string(task.get_id()) + ); + return std::nullopt; + } + return args_buffers; +} + +auto parse_outputs( + spider::core::Task const& task, + std::vector const& result_buffers +) -> std::optional> { + std::vector outputs; + outputs.reserve(task.get_num_outputs()); + for (size_t i = 0; i < task.get_num_outputs(); ++i) { + std::string const type = task.get_output(i).get_type(); + if (type == typeid(spider::core::Data).name()) { + try { + msgpack::object_handle const handle + = msgpack::unpack(result_buffers[i].data(), result_buffers[i].size()); + msgpack::object const obj = handle.get(); + boost::uuids::uuid data_id; + obj.convert(data_id); + outputs.emplace_back(data_id, type); + } catch (std::runtime_error const& e) { + spdlog::error( + "Task {} failed to parse result as data id", + task.get_function_name() + ); + return std::nullopt; + } + } else { + msgpack::sbuffer const& buffer = result_buffers[i]; + std::string const value{buffer.data(), buffer.size()}; + outputs.emplace_back(value, type); + } + } + return outputs; +} + +// NOLINTBEGIN(clang-analyzer-unix.BlockInCriticalSection) +auto task_loop( + std::shared_ptr const& metadata_store, + spider::worker::WorkerClient& client, + std::vector const& libs, + absl::flat_hash_map< + boost::process::v2::environment::key, + boost::process::v2::environment::value> const& environment, + spider::core::StopToken const& stop_token +) -> void { + while (!stop_token.stop_requested()) { + boost::asio::io_context context; + boost::uuids::uuid const task_id = fetch_task(client); + spdlog::debug("Fetched task {}", boost::uuids::to_string(task_id)); + // Fetch task detail from metadata storage + spider::core::Task task{""}; + spider::core::StorageErr err = metadata_store->get_task(task_id, &task); + if (!err.success()) { + spdlog::error("Failed to fetch task detail: {}", err.description); + continue; + } + + // Update task status to running + err = metadata_store->set_task_state(task_id, spider::core::TaskState::Running); + if (!err.success()) { + spdlog::error("Failed to update task status to running: {}", err.description); + continue; + } + spider::core::TaskInstance const instance{task_id}; + spdlog::debug("Adding task instance"); + err = metadata_store->add_task_instance(instance); + if (!err.success()) { + spdlog::error("Failed to add task instance: {}", err.description); + continue; + } + + // Set up arguments + std::optional> const optional_args_buffers + = get_args_buffers(task); + if (!optional_args_buffers.has_value()) { + continue; + } + std::vector const& args_buffers = optional_args_buffers.value(); + + // Execute task + spider::worker::TaskExecutor + executor{context, task.get_function_name(), libs, environment, args_buffers}; + + context.run(); + executor.wait(); + + if (!executor.succeed()) { + spdlog::warn("Task {} failed", task.get_function_name()); + metadata_store->task_fail( + instance, + fmt::format("Task {} failed", task.get_function_name()) + ); + continue; + } + + // Parse result + std::optional> const optional_result_buffers + = executor.get_result_buffers(); + if (!optional_result_buffers.has_value()) { + spdlog::error("Task {} failed to parse result into buffers", task.get_function_name()); + metadata_store->task_fail( + instance, + fmt::format( + "Task {} failed to parse result into buffers", + task.get_function_name() + ) + ); + continue; + } + std::vector const& result_buffers = optional_result_buffers.value(); + std::optional> const optional_outputs + = parse_outputs(task, result_buffers); + if (!optional_outputs.has_value()) { + metadata_store->task_fail( + instance, + fmt::format( + "Task {} failed to parse result into TaskOutput", + task.get_function_name() + ) + ); + continue; + } + std::vector const& outputs = optional_outputs.value(); + // Submit result + spdlog::debug("Submitting result for task {}", boost::uuids::to_string(task_id)); + err = metadata_store->task_finish(instance, outputs); + if (!err.success()) { + spdlog::error("Submit task {} fails: {}", task.get_function_name(), err.description); + } + } +} + +// NOLINTEND(clang-analyzer-unix.BlockInCriticalSection) + +} // namespace + +// NOLINTNEXTLINE(bugprone-exception-escape) +auto main(int argc, char** argv) -> int { + // Set up spdlog to write to stderr + // NOLINTNEXTLINE(misc-include-cleaner) + spdlog::set_pattern("[%Y-%m-%d %H:%M:%S.%e] [%^%l%$] [spider.worker] %v"); +#ifndef NDEBUG + spdlog::set_level(spdlog::level::trace); +#endif + + boost::program_options::variables_map const args = parse_args(argc, argv); + + std::string storage_url; + std::vector libs; + try { + if (!args.contains("storage_url") || !args.contains("libs")) { + spdlog::error("Error: missing required arguments"); + return cCmdArgParseErr; + } + storage_url = args["storage_url"].as(); + libs = args["libs"].as>(); + } catch (boost::bad_any_cast const& e) { + spdlog::error("Error: {}", e.what()); + return cCmdArgParseErr; + } catch (boost::program_options::error const& e) { + spdlog::error("Error: {}", e.what()); + return cCmdArgParseErr; + } + + // Create storage + std::shared_ptr const metadata_store + = std::make_shared(); + spider::core::StorageErr err = metadata_store->connect(storage_url); + if (!err.success()) { + spdlog::error("Cannot connect to metadata storage: {}", err.description); + return cStorageConnectionErr; + } + std::shared_ptr const data_store + = std::make_shared(); + err = data_store->connect(storage_url); + if (!err.success()) { + spdlog::error("Cannot connect to data storage: {}", err.description); + return cStorageConnectionErr; + } + std::optional const optional_worker_addr = spider::core::get_address(); + if (!optional_worker_addr.has_value()) { + spdlog::error("Failed to get worker address"); + return cWorkerAddrErr; + } + std::string const& worker_addr = optional_worker_addr.value(); + + boost::uuids::random_generator gen; + boost::uuids::uuid const worker_id = gen(); + spider::core::Driver driver{worker_id, worker_addr}; + err = metadata_store->add_driver(driver); + if (!err.success()) { + spdlog::error("Cannot add driver to metadata storage: {}", err.description); + return cStorageErr; + } + + spider::core::StopToken stop_token; + + // Start client + spider::worker::WorkerClient client{worker_id, worker_addr, data_store, metadata_store}; + + absl::flat_hash_map< + boost::process::v2::environment::key, + boost::process::v2::environment::value> const environment_variables + = get_environment_variable(); + + // Start a thread that periodically updates the scheduler's heartbeat + std::thread heartbeat_thread{ + heartbeat_loop, + std::cref(metadata_store), + std::ref(driver), + std::ref(stop_token) + }; + + // Start a thread that processes tasks + std::thread task_thread{ + task_loop, + std::cref(metadata_store), + std::ref(client), + std::cref(libs), + std::cref(environment_variables), + std::cref(stop_token), + }; + + heartbeat_thread.join(); + task_thread.join(); + return 0; } diff --git a/tests/scheduler/test-SchedulerServer.cpp b/tests/scheduler/test-SchedulerServer.cpp index 775fbae..c2bafcb 100644 --- a/tests/scheduler/test-SchedulerServer.cpp +++ b/tests/scheduler/test-SchedulerServer.cpp @@ -1,6 +1,8 @@ // NOLINTBEGIN(cert-err58-cpp,cppcoreguidelines-avoid-do-while,readability-function-cognitive-complexity,cppcoreguidelines-avoid-non-const-global-variables,cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays,clang-analyzer-optin.core.EnumCastOutOfRange) +#include #include #include +#include #include #include #include @@ -25,6 +27,9 @@ #include "../storage/StorageTestHelper.hpp" namespace { + +constexpr int cServerWarmupTime = 5; + TEMPLATE_LIST_TEST_CASE( "Scheduler server test", "[scheduler][server][storage]", @@ -51,6 +56,8 @@ TEMPLATE_LIST_TEST_CASE( // Pause and resume server server.pause(); server.resume(); + // Sleep for a while to let the server start + std::this_thread::sleep_for(std::chrono::milliseconds(cServerWarmupTime)); // Create client socket boost::asio::io_context context; @@ -75,6 +82,11 @@ TEMPLATE_LIST_TEST_CASE( msgpack::pack(req_buffer, req); REQUIRE(spider::core::send_message(socket, req_buffer)); + // Pause and resume server + server.pause(); + server.resume(); + std::this_thread::sleep_for(std::chrono::milliseconds(cServerWarmupTime)); + // Get response should succeed and get child task std::optional const& res_buffer = spider::core::receive_message(socket); REQUIRE(metadata_store->remove_job(job_id).success()); diff --git a/tests/worker/test-FunctionManager.cpp b/tests/worker/test-FunctionManager.cpp index 5df5db3..34e0416 100644 --- a/tests/worker/test-FunctionManager.cpp +++ b/tests/worker/test-FunctionManager.cpp @@ -31,8 +31,11 @@ TEST_CASE("Register and run function with POD inputs", "[core]") { // Run function with two ints should succeed spider::core::ArgsBuffer const args_buffers = spider::core::create_args_buffers(2, 3); + constexpr int cExpected = 2 + 3; msgpack::sbuffer const result = (*function)(args_buffers); - REQUIRE(5 == spider::core::response_get_result(result).value_or(0)); + msgpack::sbuffer buffer{}; + msgpack::pack(buffer, cExpected); + REQUIRE(cExpected == spider::core::response_get_result(result).value_or(0)); // Run function with wrong number of inputs should fail spider::core::ArgsBuffer wrong_args_buffers = spider::core::create_args_buffers(1); @@ -64,7 +67,7 @@ TEST_CASE("Register and run function with tuple return", "[core]") { spider::core::ArgsBuffer const args_buffers = spider::core::create_args_buffers("test", 3); msgpack::sbuffer const result = (*function)(args_buffers); REQUIRE(std::make_tuple("test", 3) - == spider::core::response_get_result>(result).value_or( + == spider::core::response_get_result(result).value_or( std::make_tuple("", 0) )); } diff --git a/tests/worker/test-MessagePipe.cpp b/tests/worker/test-MessagePipe.cpp index 5c370ac..e810b2c 100644 --- a/tests/worker/test-MessagePipe.cpp +++ b/tests/worker/test-MessagePipe.cpp @@ -44,7 +44,7 @@ TEST_CASE("pipe message response", "[worker]") { REQUIRE(spider::worker::TaskExecutorResponseType::Result == spider::worker::get_response_type(response_buffer)); std::optional> const parse_response - = spider::core::response_get_result>(response_buffer); + = spider::core::response_get_result(response_buffer); REQUIRE(parse_response.has_value()); if (parse_response.has_value()) { std::tuple result = parse_response.value(); diff --git a/tests/worker/worker-test.cpp b/tests/worker/worker-test.cpp index c02453d..04d97cf 100644 --- a/tests/worker/worker-test.cpp +++ b/tests/worker/worker-test.cpp @@ -1,9 +1,11 @@ +#include #include #include "../../src/spider/worker/FunctionManager.hpp" namespace { auto sum_test(int const x, int const y) -> int { + std::cerr << x << " + " << y << " = " << x + y << "\n"; return x + y; }