Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add worker that requests tasks from scheduler and executes tasks #38

Merged
merged 4 commits into from
Dec 18, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/spider/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ set(SPIDER_WORKER_SOURCES
worker/message_pipe.hpp
worker/WorkerClient.hpp
worker/WorkerClient.cpp
utils/StopToken.hpp
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion

Consider moving shared utilities to core component

StopToken.hpp is used by both worker and scheduler components, suggesting it's a core utility. Consider moving it to the SPIDER_CORE_HEADERS list to avoid duplication and better reflect its shared nature.

-    utils/StopToken.hpp

Add to SPIDER_CORE_HEADERS:

 set(SPIDER_CORE_HEADERS
     core/Error.hpp
     core/Data.hpp
+    utils/StopToken.hpp
     ...
 )

Committable suggestion skipped: line range outside the PR's diff.

CACHE INTERNAL
"spider worker source files"
)
Expand Down
2 changes: 2 additions & 0 deletions src/spider/storage/MetadataStorage.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,8 @@ class MetadataStorage {
virtual auto add_task_instance(TaskInstance const& instance) -> StorageErr = 0;
virtual auto task_finish(TaskInstance const& instance, std::vector<TaskOutput> const& outputs)
-> StorageErr = 0;
virtual auto task_fail(TaskInstance const& instance, std::string const& error) -> StorageErr
= 0;
virtual auto get_task_timeout(std::vector<TaskInstance>* tasks) -> StorageErr = 0;
virtual auto get_child_tasks(boost::uuids::uuid id, std::vector<Task>* children) -> StorageErr
= 0;
Expand Down
67 changes: 63 additions & 4 deletions src/spider/storage/MysqlStorage.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#include <iomanip>
#include <memory>
#include <optional>
#include <regex>
#include <sstream>
#include <string>
#include <tuple>
Expand All @@ -27,6 +28,7 @@
#include <mariadb/conncpp/ResultSet.hpp>
#include <mariadb/conncpp/Statement.hpp>
#include <mariadb/conncpp/Types.hpp>
#include <spdlog/spdlog.h>

#include "../core/Data.hpp"
#include "../core/Driver.hpp"
Expand Down Expand Up @@ -241,11 +243,21 @@ auto string_to_task_state(std::string const& state) -> spider::core::TaskState {
} // namespace

auto MySqlMetadataStorage::connect(std::string const& url) -> StorageErr {
// Parse jdbc url
std::regex const url_regex(R"(jdbc:mariadb://[^?]+(\?user=([^&]*)(&password=([^&]*))?)?)");
std::smatch match;
if (false == std::regex_match(url, match, url_regex)) {
return StorageErr{StorageErrType::OtherErr, "Invalid url"};
}
bool const credential = match[2].matched && match[4].matched;
if (nullptr == m_conn) {
try {
sql::Driver* driver = sql::mariadb::get_driver_instance();
sql::Properties const properties;
m_conn = driver->connect(sql::SQLString(url), properties);
if (credential) {
m_conn = driver->connect(sql::SQLString(url), match[2].str(), match[4].str());
} else {
m_conn = driver->connect(sql::SQLString(url), sql::Properties{});
}
m_conn->setAutoCommit(false);
} catch (sql::SQLException& e) {
return StorageErr{StorageErrType::ConnectionErr, e.what()};
Expand Down Expand Up @@ -1073,6 +1085,43 @@ auto MySqlMetadataStorage::task_finish(
return StorageErr{};
}

/**
 * Records that a task instance has failed.
 *
 * Deletes the row of `instance` from `task_instances`, then counts the rows still registered for
 * `instance.task_id`. If none are left, the matching row in `tasks` has its state set to 'fail'.
 * When no `sql::SQLException` is thrown the work is committed; otherwise the exception message is
 * logged, the connection is rolled back and an error is returned.
 *
 * @param instance The failed task instance; its `id` and `task_id` are read.
 * @param error Not used by this implementation (the parameter name is commented out).
 * @return `StorageErr{}` if every statement executed, or a `StorageErr` of type
 * `StorageErrType::OtherErr` carrying the exception message.
 */
auto MySqlMetadataStorage::task_fail(TaskInstance const& instance, std::string const& /*error*/)
        -> StorageErr {
    try {
        // Drop the failed instance first so that it is not included in the count below.
        sql::bytes failed_instance_id = uuid_get_bytes(instance.id);
        std::unique_ptr<sql::PreparedStatement> const delete_instance{
                m_conn->prepareStatement("DELETE FROM `task_instances` WHERE `id` = ?")
        };
        delete_instance->setBytes(1, &failed_instance_id);
        delete_instance->executeUpdate();

        // Count the instances that are still registered for the same task.
        sql::bytes owning_task_id = uuid_get_bytes(instance.task_id);
        std::unique_ptr<sql::PreparedStatement> const count_remaining{m_conn->prepareStatement(
                "SELECT COUNT(*) FROM `task_instances` WHERE `task_id` = ?"
        )};
        count_remaining->setBytes(1, &owning_task_id);
        std::unique_ptr<sql::ResultSet> const remaining{count_remaining->executeQuery()};
        remaining->next();
        bool const no_instance_left = 0 == remaining->getInt(1);

        if (no_instance_left) {
            // The last instance of the task failed, so the task itself is marked as failed.
            std::unique_ptr<sql::PreparedStatement> const mark_task_failed{
                    m_conn->prepareStatement("UPDATE `tasks` SET `state` = 'fail' WHERE `id` = ?")
            };
            mark_task_failed->setBytes(1, &owning_task_id);
            mark_task_failed->executeUpdate();
        }
    } catch (sql::SQLException& e) {
        spdlog::error("Task fail error: {}", e.what());
        m_conn->rollback();
        return StorageErr{StorageErrType::OtherErr, e.what()};
    }
    m_conn->commit();
    return StorageErr{};
}

auto MySqlMetadataStorage::get_task_timeout(std::vector<TaskInstance>* tasks) -> StorageErr {
try {
std::unique_ptr<sql::Statement> statement(m_conn->createStatement());
Expand Down Expand Up @@ -1262,11 +1311,21 @@ auto MySqlMetadataStorage::set_scheduler_state(boost::uuids::uuid id, std::strin
}

auto MySqlDataStorage::connect(std::string const& url) -> StorageErr {
// Parse jdbc url
std::regex const url_regex(R"(jdbc:mariadb://[^?]+(\?user=([^&]*)(&password=([^&]*))?)?)");
std::smatch match;
if (false == std::regex_match(url, match, url_regex)) {
return StorageErr{StorageErrType::OtherErr, "Invalid url"};
}
bool const credential = match[2].matched && match[4].matched;
if (nullptr == m_conn) {
try {
sql::Driver* driver = sql::mariadb::get_driver_instance();
sql::Properties const properties;
m_conn = driver->connect(sql::SQLString(url), properties);
if (credential) {
m_conn = driver->connect(sql::SQLString(url), match[2].str(), match[4].str());
} else {
m_conn = driver->connect(sql::SQLString(url), sql::Properties{});
}
m_conn->setAutoCommit(false);
} catch (sql::SQLException& e) {
return StorageErr{StorageErrType::ConnectionErr, e.what()};
Expand Down
1 change: 1 addition & 0 deletions src/spider/storage/MysqlStorage.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ class MySqlMetadataStorage : public MetadataStorage {
auto add_task_instance(TaskInstance const& instance) -> StorageErr override;
auto task_finish(TaskInstance const& instance, std::vector<TaskOutput> const& outputs)
-> StorageErr override;
auto task_fail(TaskInstance const& instance, std::string const& error) -> StorageErr override;
auto get_task_timeout(std::vector<TaskInstance>* tasks) -> StorageErr override;
auto get_child_tasks(boost::uuids::uuid id, std::vector<Task>* children) -> StorageErr override;
auto get_parent_tasks(boost::uuids::uuid id, std::vector<Task>* tasks) -> StorageErr override;
Expand Down
99 changes: 96 additions & 3 deletions src/spider/worker/FunctionManager.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,14 @@
#include <tuple>
#include <type_traits>
#include <utility>
#include <vector>

#include <absl/container/flat_hash_map.h>
#include <fmt/format.h>
#include <spdlog/spdlog.h>

#include "../io/MsgPack.hpp" // IWYU pragma: keep
#include "../io/Serializer.hpp"
#include "TaskExecutorMessage.hpp"

// NOLINTBEGIN(cppcoreguidelines-macro-usage)
Expand Down Expand Up @@ -102,7 +105,7 @@ void create_error_buffer(
msgpack::sbuffer& buffer
);

template <class T>
template <Serializable T>
auto response_get_result(msgpack::sbuffer const& buffer) -> std::optional<T> {
// NOLINTBEGIN(cppcoreguidelines-pro-type-union-access,cppcoreguidelines-pro-bounds-pointer-arithmetic)
try {
Expand All @@ -119,14 +122,81 @@ auto response_get_result(msgpack::sbuffer const& buffer) -> std::optional<T> {
return std::nullopt;
}

return object.via.array.ptr[1].as<T>();
return std::make_optional(object.via.array.ptr[1].as<T>());
} catch (msgpack::type_error& e) {
return std::nullopt;
}
// NOLINTEND(cppcoreguidelines-pro-type-union-access,cppcoreguidelines-pro-bounds-pointer-arithmetic)
}

template <class T>
/**
 * Extracts a multi-value result from a task executor response.
 *
 * The buffer must unpack to a msgpack array of `sizeof...(Ts) + 1` elements whose first element is
 * `worker::TaskExecutorResponseType::Result`; the remaining elements are converted, in order, into
 * the elements of the returned tuple.
 *
 * @tparam Ts Types of the result values. More than one type is required by the constraint, and
 * the tuple of them is default-constructed before conversion.
 * @param buffer Serialized response.
 * @return The decoded values, or `std::nullopt` if the buffer cannot be unpacked, is not an array
 * of the expected size, is not a `Result` response, or an element fails to convert.
 */
template <Serializable... Ts>
requires(sizeof...(Ts) > 1)
auto response_get_result(msgpack::sbuffer const& buffer) -> std::optional<std::tuple<Ts...>> {
    // NOLINTBEGIN(cppcoreguidelines-pro-type-union-access,cppcoreguidelines-pro-bounds-pointer-arithmetic)
    try {
        msgpack::object_handle const handle = msgpack::unpack(buffer.data(), buffer.size());
        msgpack::object const object = handle.get();

        if (msgpack::type::ARRAY != object.type || sizeof...(Ts) + 1 != object.via.array.size) {
            return std::nullopt;
        }

        if (worker::TaskExecutorResponseType::Result
            != object.via.array.ptr[0].as<worker::TaskExecutorResponseType>())
        {
            return std::nullopt;
        }

        std::tuple<Ts...> result;
        for_n<sizeof...(Ts)>([&](auto i) {
            // Element 0 is the response type, so value `i` lives at index `i + 1`.
            object.via.array.ptr[i.cValue + 1].convert(std::get<i.cValue>(result));
        });
        // Move rather than copy the decoded values into the optional.
        return std::make_optional(std::move(result));
    } catch (msgpack::type_error const&) {
        return std::nullopt;
    } catch (msgpack::unpack_error const&) {
        // `msgpack::unpack` reports empty, truncated or malformed input with `unpack_error`,
        // which is unrelated to `type_error`. Treat it like any other undecodable response.
        return std::nullopt;
    }
    // NOLINTEND(cppcoreguidelines-pro-type-union-access,cppcoreguidelines-pro-bounds-pointer-arithmetic)
}

/**
 * Splits a result response into one serialized buffer per result value.
 *
 * The buffer must unpack to a msgpack array of at least two elements whose first element is
 * `worker::TaskExecutorResponseType::Result`. Every following element is packed again into its own
 * `msgpack::sbuffer`, preserving order.
 *
 * @param buffer Serialized response.
 * @return One buffer per result value, or `std::nullopt` (after logging the reason) if the buffer
 * cannot be unpacked, has the wrong shape, or is not a `Result` response.
 */
inline auto response_get_result_buffers(msgpack::sbuffer const& buffer
) -> std::optional<std::vector<msgpack::sbuffer>> {
    // NOLINTBEGIN(cppcoreguidelines-pro-type-union-access,cppcoreguidelines-pro-bounds-pointer-arithmetic)
    try {
        msgpack::object_handle const handle = msgpack::unpack(buffer.data(), buffer.size());
        msgpack::object const object = handle.get();

        if (msgpack::type::ARRAY != object.type || object.via.array.size < 2) {
            spdlog::error("Cannot split result into buffers: Wrong type");
            return std::nullopt;
        }

        // Decode the response type once; it is needed for both the check and the log message.
        auto const response_type = object.via.array.ptr[0].as<worker::TaskExecutorResponseType>();
        if (worker::TaskExecutorResponseType::Result != response_type) {
            spdlog::error(
                    "Cannot split result into buffers: Wrong response type {}",
                    static_cast<std::underlying_type_t<worker::TaskExecutorResponseType>>(
                            response_type
                    )
            );
            return std::nullopt;
        }

        std::vector<msgpack::sbuffer> result_buffers;
        // Element 0 is the response type; every other element becomes one output buffer.
        result_buffers.reserve(object.via.array.size - 1);
        for (size_t i = 1; i < object.via.array.size; ++i) {
            msgpack::object const& obj = object.via.array.ptr[i];
            result_buffers.emplace_back();
            msgpack::pack(result_buffers.back(), obj);
        }
        return result_buffers;
    } catch (msgpack::type_error const& e) {
        spdlog::error("Cannot split result into buffers: {}", e.what());
        return std::nullopt;
    } catch (msgpack::unpack_error const& e) {
        // `msgpack::unpack` reports empty, truncated or malformed input with `unpack_error`,
        // which is unrelated to `type_error`. Report it the same way instead of propagating.
        spdlog::error("Cannot split result into buffers: {}", e.what());
        return std::nullopt;
    }
    // NOLINTEND(cppcoreguidelines-pro-type-union-access,cppcoreguidelines-pro-bounds-pointer-arithmetic)
}

template <Serializable T>
auto create_result_response(T const& t) -> msgpack::sbuffer {
msgpack::sbuffer buffer;
msgpack::packer packer{buffer};
Expand All @@ -136,6 +206,16 @@ auto create_result_response(T const& t) -> msgpack::sbuffer {
return buffer;
}

/**
 * Builds a result response that carries every element of a tuple.
 *
 * The response is a msgpack array of `sizeof...(Values) + 1` elements:
 * `worker::TaskExecutorResponseType::Result` followed by the tuple elements in order.
 *
 * @tparam Values Types of the tuple elements. The same type may appear more than once.
 * @param t The values to return.
 * @return The buffer containing the serialized response.
 */
template <Serializable... Values>
auto create_result_response(std::tuple<Values...> const& t) -> msgpack::sbuffer {
    msgpack::sbuffer buffer;
    msgpack::packer packer{buffer};
    packer.pack_array(sizeof...(Values) + 1);
    packer.pack(worker::TaskExecutorResponseType::Result);
    // Visit the elements by position. Type-based `std::get<Values>(t)` does not compile when a
    // type occurs more than once in the tuple (e.g. `std::tuple<int, int>`).
    std::apply([&](auto const&... values) { (..., packer.pack(values)); }, t);
    return buffer;
}

// NOLINTBEGIN(cppcoreguidelines-missing-std-forward)
template <class... Args>
auto create_args_buffers(Args&&... args) -> ArgsBuffer {
Expand All @@ -157,6 +237,19 @@ auto create_args_request(Args&&... args) -> msgpack::sbuffer {
return buffer;
}

/**
 * Builds an arguments request from argument buffers that are already serialized.
 *
 * The request is a two-element msgpack array: `worker::TaskExecutorRequestType::Arguments`
 * followed by an array header sized to `args_buffers.size()`. The bytes of every argument buffer
 * are then appended unchanged, in order.
 *
 * @param args_buffers Serialized arguments. NOTE(review): each buffer presumably holds exactly one
 * msgpack-encoded value so that the array header stays consistent; confirm against callers.
 * @return The buffer containing the serialized request.
 */
inline auto create_args_request(std::vector<msgpack::sbuffer> const& args_buffers
) -> msgpack::sbuffer {
    msgpack::sbuffer request;
    msgpack::packer request_packer{request};

    // Envelope: [request type, [arguments...]].
    request_packer.pack_array(2);
    request_packer.pack(worker::TaskExecutorRequestType::Arguments);
    request_packer.pack_array(args_buffers.size());

    // The arguments are spliced in as raw bytes rather than being packed again.
    for (auto const& serialized_arg : args_buffers) {
        request.write(serialized_arg.data(), serialized_arg.size());
    }
    return request;
}

// NOLINTEND(cppcoreguidelines-missing-std-forward)

template <class F>
Expand Down
42 changes: 32 additions & 10 deletions src/spider/worker/TaskExecutor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#include <optional>
#include <string>
#include <tuple>
#include <vector>

#include <boost/process/v2/process.hpp>
#include <fmt/format.h>
Expand All @@ -18,7 +19,8 @@ namespace spider::worker {

auto TaskExecutor::completed() -> bool {
std::lock_guard const lock(m_state_mutex);
return TaskExecutorState::Succeed == m_state || TaskExecutorState::Error == m_state;
return TaskExecutorState::Succeed == m_state || TaskExecutorState::Error == m_state
|| TaskExecutorState::Cancelled == m_state;
}

auto TaskExecutor::waiting() -> bool {
Expand Down Expand Up @@ -48,7 +50,14 @@ void TaskExecutor::wait() {
m_result_buffer
);
}
return;
}
std::unique_lock lock(m_state_mutex);
m_complete_cv.wait(lock, [this] {
return TaskExecutorState::Succeed == m_state || TaskExecutorState::Error == m_state
|| TaskExecutorState::Cancelled == m_state;
});
lock.unlock();
}

void TaskExecutor::cancel() {
Expand Down Expand Up @@ -79,23 +88,32 @@ auto TaskExecutor::process_output_handler() -> boost::asio::awaitable<void> {
case TaskExecutorResponseType::Block:
break;
case TaskExecutorResponseType::Error: {
std::lock_guard const lock(m_state_mutex);
m_state = TaskExecutorState::Error;
m_result_buffer.write(response.data(), response.size());
{
std::lock_guard const lock(m_state_mutex);
m_state = TaskExecutorState::Error;
m_result_buffer.write(response.data(), response.size());
}
m_complete_cv.notify_all();
co_return;
}
case TaskExecutorResponseType::Ready:
break;
case TaskExecutorResponseType::Result: {
std::lock_guard const lock(m_state_mutex);
m_state = TaskExecutorState::Succeed;
m_result_buffer.write(response.data(), response.size());
{
std::lock_guard const lock(m_state_mutex);
m_state = TaskExecutorState::Succeed;
m_result_buffer.write(response.data(), response.size());
}
m_complete_cv.notify_all();
co_return;
}
case TaskExecutorResponseType::Cancel: {
std::lock_guard const lock(m_state_mutex);
m_state = TaskExecutorState::Cancelled;
m_result_buffer.write(response.data(), response.size());
{
std::lock_guard const lock(m_state_mutex);
m_state = TaskExecutorState::Cancelled;
m_result_buffer.write(response.data(), response.size());
}
m_complete_cv.notify_all();
co_return;
}
case TaskExecutorResponseType::Unknown:
Expand All @@ -106,6 +124,10 @@ auto TaskExecutor::process_output_handler() -> boost::asio::awaitable<void> {

// NOLINTEND(clang-analyzer-core.CallAndMessage)

/**
 * Splits the stored result response into separate buffers by passing `m_result_buffer` to
 * `core::response_get_result_buffers`.
 *
 * @return Whatever `core::response_get_result_buffers` returns for `m_result_buffer`.
 */
auto TaskExecutor::get_result_buffers() const -> std::optional<std::vector<msgpack::sbuffer>> {
    std::optional<std::vector<msgpack::sbuffer>> buffers
            = core::response_get_result_buffers(m_result_buffer);
    return buffers;
}

auto TaskExecutor::get_error() const -> std::tuple<core::FunctionInvokeError, std::string> {
return core::response_get_error(m_result_buffer)
.value_or(std::make_tuple(
Expand Down
Loading
Loading