Skip to content

Commit

Permalink
feat: Add worker that requests tasks from scheduler and executes tasks (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
sitaowang1998 authored Dec 18, 2024
1 parent 88f6232 commit 9169591
Show file tree
Hide file tree
Showing 16 changed files with 657 additions and 28 deletions.
1 change: 1 addition & 0 deletions src/spider/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ set(SPIDER_WORKER_SOURCES
worker/message_pipe.hpp
worker/WorkerClient.hpp
worker/WorkerClient.cpp
utils/StopToken.hpp
CACHE INTERNAL
"spider worker source files"
)
Expand Down
2 changes: 1 addition & 1 deletion src/spider/scheduler/scheduler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ auto cleanup_loop(
auto main(int argc, char** argv) -> int {
// Set up spdlog to write to stderr
// NOLINTNEXTLINE(misc-include-cleaner)
spdlog::set_pattern("[%Y-%m-%d %H:%M:%S.%e] [%^%l%$] [spider][scheduler] %v");
spdlog::set_pattern("[%Y-%m-%d %H:%M:%S.%e] [%^%l%$] [spider.scheduler] %v");
#ifndef NDEBUG
spdlog::set_level(spdlog::level::trace);
#endif
Expand Down
2 changes: 2 additions & 0 deletions src/spider/storage/MetadataStorage.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,8 @@ class MetadataStorage {
virtual auto add_task_instance(TaskInstance const& instance) -> StorageErr = 0;
virtual auto task_finish(TaskInstance const& instance, std::vector<TaskOutput> const& outputs)
-> StorageErr = 0;
virtual auto task_fail(TaskInstance const& instance, std::string const& error) -> StorageErr
= 0;
virtual auto get_task_timeout(std::vector<TaskInstance>* tasks) -> StorageErr = 0;
virtual auto get_child_tasks(boost::uuids::uuid id, std::vector<Task>* children) -> StorageErr
= 0;
Expand Down
67 changes: 63 additions & 4 deletions src/spider/storage/MysqlStorage.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#include <iomanip>
#include <memory>
#include <optional>
#include <regex>
#include <sstream>
#include <string>
#include <tuple>
Expand All @@ -27,6 +28,7 @@
#include <mariadb/conncpp/ResultSet.hpp>
#include <mariadb/conncpp/Statement.hpp>
#include <mariadb/conncpp/Types.hpp>
#include <spdlog/spdlog.h>

#include "../core/Data.hpp"
#include "../core/Driver.hpp"
Expand Down Expand Up @@ -241,11 +243,21 @@ auto string_to_task_state(std::string const& state) -> spider::core::TaskState {
} // namespace

auto MySqlMetadataStorage::connect(std::string const& url) -> StorageErr {
// Parse jdbc url
std::regex const url_regex(R"(jdbc:mariadb://[^?]+(\?user=([^&]*)(&password=([^&]*))?)?)");
std::smatch match;
if (false == std::regex_match(url, match, url_regex)) {
return StorageErr{StorageErrType::OtherErr, "Invalid url"};
}
bool const credential = match[2].matched && match[4].matched;
if (nullptr == m_conn) {
try {
sql::Driver* driver = sql::mariadb::get_driver_instance();
sql::Properties const properties;
m_conn = driver->connect(sql::SQLString(url), properties);
if (credential) {
m_conn = driver->connect(sql::SQLString(url), match[2].str(), match[4].str());
} else {
m_conn = driver->connect(sql::SQLString(url), sql::Properties{});
}
m_conn->setAutoCommit(false);
} catch (sql::SQLException& e) {
return StorageErr{StorageErrType::ConnectionErr, e.what()};
Expand Down Expand Up @@ -1073,6 +1085,43 @@ auto MySqlMetadataStorage::task_finish(
return StorageErr{};
}

/**
 * Records the failure of a task instance.
 *
 * Deletes the row for `instance` from `task_instances`. If no instance rows remain for the same
 * task afterwards, the task's state is set to 'fail'. Every statement, including the final commit,
 * runs inside the try block so that any sql::SQLException is logged, rolled back and reported
 * through the return value instead of escaping to the caller.
 *
 * @param instance The failed task instance; its `id` and `task_id` are used.
 * @param error Error message of the failure. Currently unused: it is not persisted.
 * @return A default-constructed StorageErr on success, or a StorageErr of type
 * StorageErrType::OtherErr carrying the exception message on failure.
 */
auto MySqlMetadataStorage::task_fail(TaskInstance const& instance, std::string const& /*error*/)
        -> StorageErr {
    try {
        // Remove task instance
        std::unique_ptr<sql::PreparedStatement> const statement(
                m_conn->prepareStatement("DELETE FROM `task_instances` WHERE `id` = ?")
        );
        sql::bytes instance_id_bytes = uuid_get_bytes(instance.id);
        statement->setBytes(1, &instance_id_bytes);
        statement->executeUpdate();

        // Get number of remaining instances
        std::unique_ptr<sql::PreparedStatement> const count_statement(m_conn->prepareStatement(
                "SELECT COUNT(*) FROM `task_instances` WHERE `task_id` = ?"
        ));
        sql::bytes task_id_bytes = uuid_get_bytes(instance.task_id);
        count_statement->setBytes(1, &task_id_bytes);
        std::unique_ptr<sql::ResultSet> const count_res{count_statement->executeQuery()};
        count_res->next();
        int32_t const count = count_res->getInt(1);
        if (count == 0) {
            // No other instance of this task is left, so mark the task itself as failed
            std::unique_ptr<sql::PreparedStatement> const task_statement(
                    m_conn->prepareStatement("UPDATE `tasks` SET `state` = 'fail' WHERE `id` = ?")
            );
            task_statement->setBytes(1, &task_id_bytes);
            task_statement->executeUpdate();
        }

        // Commit inside the try block: a failing commit must also be rolled back and reported
        m_conn->commit();
    } catch (sql::SQLException& e) {
        spdlog::error("Task fail error: {}", e.what());
        m_conn->rollback();
        return StorageErr{StorageErrType::OtherErr, e.what()};
    }
    return StorageErr{};
}

auto MySqlMetadataStorage::get_task_timeout(std::vector<TaskInstance>* tasks) -> StorageErr {
try {
std::unique_ptr<sql::Statement> statement(m_conn->createStatement());
Expand Down Expand Up @@ -1262,11 +1311,21 @@ auto MySqlMetadataStorage::set_scheduler_state(boost::uuids::uuid id, std::strin
}

auto MySqlDataStorage::connect(std::string const& url) -> StorageErr {
// Parse jdbc url
std::regex const url_regex(R"(jdbc:mariadb://[^?]+(\?user=([^&]*)(&password=([^&]*))?)?)");
std::smatch match;
if (false == std::regex_match(url, match, url_regex)) {
return StorageErr{StorageErrType::OtherErr, "Invalid url"};
}
bool const credential = match[2].matched && match[4].matched;
if (nullptr == m_conn) {
try {
sql::Driver* driver = sql::mariadb::get_driver_instance();
sql::Properties const properties;
m_conn = driver->connect(sql::SQLString(url), properties);
if (credential) {
m_conn = driver->connect(sql::SQLString(url), match[2].str(), match[4].str());
} else {
m_conn = driver->connect(sql::SQLString(url), sql::Properties{});
}
m_conn->setAutoCommit(false);
} catch (sql::SQLException& e) {
return StorageErr{StorageErrType::ConnectionErr, e.what()};
Expand Down
1 change: 1 addition & 0 deletions src/spider/storage/MysqlStorage.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ class MySqlMetadataStorage : public MetadataStorage {
auto add_task_instance(TaskInstance const& instance) -> StorageErr override;
auto task_finish(TaskInstance const& instance, std::vector<TaskOutput> const& outputs)
-> StorageErr override;
auto task_fail(TaskInstance const& instance, std::string const& error) -> StorageErr override;
auto get_task_timeout(std::vector<TaskInstance>* tasks) -> StorageErr override;
auto get_child_tasks(boost::uuids::uuid id, std::vector<Task>* children) -> StorageErr override;
auto get_parent_tasks(boost::uuids::uuid id, std::vector<Task>* tasks) -> StorageErr override;
Expand Down
99 changes: 96 additions & 3 deletions src/spider/worker/FunctionManager.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,14 @@
#include <tuple>
#include <type_traits>
#include <utility>
#include <vector>

#include <absl/container/flat_hash_map.h>
#include <fmt/format.h>
#include <spdlog/spdlog.h>

#include "../io/MsgPack.hpp" // IWYU pragma: keep
#include "../io/Serializer.hpp"
#include "TaskExecutorMessage.hpp"

// NOLINTBEGIN(cppcoreguidelines-macro-usage)
Expand Down Expand Up @@ -102,7 +105,7 @@ void create_error_buffer(
msgpack::sbuffer& buffer
);

template <class T>
template <Serializable T>
auto response_get_result(msgpack::sbuffer const& buffer) -> std::optional<T> {
// NOLINTBEGIN(cppcoreguidelines-pro-type-union-access,cppcoreguidelines-pro-bounds-pointer-arithmetic)
try {
Expand All @@ -119,14 +122,81 @@ auto response_get_result(msgpack::sbuffer const& buffer) -> std::optional<T> {
return std::nullopt;
}

return object.via.array.ptr[1].as<T>();
return std::make_optional(object.via.array.ptr[1].as<T>());
} catch (msgpack::type_error& e) {
return std::nullopt;
}
// NOLINTEND(cppcoreguidelines-pro-type-union-access,cppcoreguidelines-pro-bounds-pointer-arithmetic)
}

template <class T>
/**
 * Extracts multiple result values from a task executor result response.
 *
 * The response must be a msgpack array of exactly `sizeof...(Ts) + 1` elements whose first element
 * is worker::TaskExecutorResponseType::Result. The remaining elements are converted, in order,
 * into the elements of the returned tuple.
 *
 * @tparam Ts Types of the result values. Each must be default-constructible, because the tuple is
 * default-constructed before being filled.
 * @param buffer Serialized response.
 * @return The decoded values, or std::nullopt if the buffer cannot be unpacked, has the wrong
 * shape, is not a result response, or an element cannot be converted to its target type.
 */
template <Serializable... Ts>
requires(sizeof...(Ts) > 1)
auto response_get_result(msgpack::sbuffer const& buffer) -> std::optional<std::tuple<Ts...>> {
    // NOLINTBEGIN(cppcoreguidelines-pro-type-union-access,cppcoreguidelines-pro-bounds-pointer-arithmetic)
    try {
        msgpack::object_handle const handle = msgpack::unpack(buffer.data(), buffer.size());
        msgpack::object const object = handle.get();

        if (msgpack::type::ARRAY != object.type || sizeof...(Ts) + 1 != object.via.array.size) {
            return std::nullopt;
        }

        if (worker::TaskExecutorResponseType::Result
            != object.via.array.ptr[0].as<worker::TaskExecutorResponseType>())
        {
            return std::nullopt;
        }

        std::tuple<Ts...> result;
        // for_n is defined elsewhere in this header; it is used here to visit each tuple index
        // with a compile-time constant exposed as `cValue`. Element 0 of the array is the
        // response type, hence the `+ 1` offset.
        for_n<sizeof...(Ts)>([&](auto i) {
            object.via.array.ptr[i.cValue + 1].convert(std::get<i.cValue>(result));
        });
        // Move instead of copy: `result` is not used afterwards.
        return std::make_optional(std::move(result));
    } catch (msgpack::type_error const&) {
        // An element has a different msgpack type than requested.
        return std::nullopt;
    } catch (msgpack::unpack_error const&) {
        // The buffer is empty, truncated or otherwise not valid msgpack.
        return std::nullopt;
    }
    // NOLINTEND(cppcoreguidelines-pro-type-union-access,cppcoreguidelines-pro-bounds-pointer-arithmetic)
}

/**
 * Splits a task executor result response into one msgpack buffer per result value.
 *
 * The response must be a msgpack array of at least two elements whose first element is
 * worker::TaskExecutorResponseType::Result. Every following element is re-packed, in order, into
 * its own buffer.
 *
 * @param buffer Serialized response.
 * @return One buffer per result value, or std::nullopt (after logging an error) if the buffer
 * cannot be unpacked, has the wrong shape, or is not a result response.
 */
inline auto response_get_result_buffers(msgpack::sbuffer const& buffer
) -> std::optional<std::vector<msgpack::sbuffer>> {
    // NOLINTBEGIN(cppcoreguidelines-pro-type-union-access,cppcoreguidelines-pro-bounds-pointer-arithmetic)
    try {
        msgpack::object_handle const handle = msgpack::unpack(buffer.data(), buffer.size());
        msgpack::object const object = handle.get();

        if (msgpack::type::ARRAY != object.type || object.via.array.size < 2) {
            spdlog::error("Cannot split result into buffers: Wrong type");
            return std::nullopt;
        }

        // Decode the response type once and reuse it for both the check and the log message.
        auto const response_type
                = object.via.array.ptr[0].as<worker::TaskExecutorResponseType>();
        if (worker::TaskExecutorResponseType::Result != response_type) {
            spdlog::error(
                    "Cannot split result into buffers: Wrong response type {}",
                    static_cast<std::underlying_type_t<worker::TaskExecutorResponseType>>(
                            response_type
                    )
            );
            return std::nullopt;
        }

        std::vector<msgpack::sbuffer> result_buffers;
        result_buffers.reserve(object.via.array.size - 1);
        // Element 0 is the response type; the result values start at index 1.
        for (size_t i = 1; i < object.via.array.size; ++i) {
            msgpack::object const& obj = object.via.array.ptr[i];
            msgpack::pack(result_buffers.emplace_back(), obj);
        }
        return result_buffers;
    } catch (msgpack::type_error const& e) {
        spdlog::error("Cannot split result into buffers: {}", e.what());
        return std::nullopt;
    } catch (msgpack::unpack_error const& e) {
        // Thrown by msgpack::unpack when the buffer is empty, truncated or malformed.
        spdlog::error("Cannot split result into buffers: {}", e.what());
        return std::nullopt;
    }
    // NOLINTEND(cppcoreguidelines-pro-type-union-access,cppcoreguidelines-pro-bounds-pointer-arithmetic)
}

template <Serializable T>
auto create_result_response(T const& t) -> msgpack::sbuffer {
msgpack::sbuffer buffer;
msgpack::packer packer{buffer};
Expand All @@ -136,6 +206,16 @@ auto create_result_response(T const& t) -> msgpack::sbuffer {
return buffer;
}

/**
 * Creates a result response carrying several result values.
 *
 * The response is a msgpack array of `sizeof...(Values) + 1` elements: the response type
 * worker::TaskExecutorResponseType::Result followed by the tuple elements in order.
 *
 * @tparam Values Types of the result values. The same type may appear more than once.
 * @param t The result values.
 * @return Buffer holding the serialized response.
 */
template <Serializable... Values>
auto create_result_response(std::tuple<Values...> const& t) -> msgpack::sbuffer {
    msgpack::sbuffer buffer;
    msgpack::packer packer{buffer};
    packer.pack_array(sizeof...(Values) + 1);
    packer.pack(worker::TaskExecutorResponseType::Result);
    // Pack by position instead of by type: std::get<Type> does not compile when a type occurs
    // more than once in the tuple. The comma fold packs the elements left to right.
    std::apply([&packer](auto const&... values) { (packer.pack(values), ...); }, t);
    return buffer;
}

// NOLINTBEGIN(cppcoreguidelines-missing-std-forward)
template <class... Args>
auto create_args_buffers(Args&&... args) -> ArgsBuffer {
Expand All @@ -157,6 +237,19 @@ auto create_args_request(Args&&... args) -> msgpack::sbuffer {
return buffer;
}

/**
 * Creates an arguments request from argument buffers that are already serialized.
 *
 * The request is a two-element msgpack array: the request type
 * worker::TaskExecutorRequestType::Arguments followed by an array with one entry per buffer.
 *
 * NOTE(review): the array header announces `args_buffers.size()` elements, so this assumes each
 * buffer holds exactly one complete msgpack object; confirm against callers.
 *
 * @param args_buffers Serialized arguments, in call order.
 * @return Buffer holding the serialized request.
 */
inline auto create_args_request(std::vector<msgpack::sbuffer> const& args_buffers
) -> msgpack::sbuffer {
    msgpack::sbuffer request;
    msgpack::packer writer{request};

    // Envelope: request type plus the header of the argument array.
    writer.pack_array(2);
    writer.pack(worker::TaskExecutorRequestType::Arguments);
    writer.pack_array(args_buffers.size());

    // The argument payloads are appended verbatim after the header rather than re-packed.
    for (auto const& serialized_arg : args_buffers) {
        request.write(serialized_arg.data(), serialized_arg.size());
    }
    return request;
}

// NOLINTEND(cppcoreguidelines-missing-std-forward)

template <class F>
Expand Down
42 changes: 32 additions & 10 deletions src/spider/worker/TaskExecutor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#include <optional>
#include <string>
#include <tuple>
#include <vector>

#include <boost/process/v2/process.hpp>
#include <fmt/format.h>
Expand All @@ -18,7 +19,8 @@ namespace spider::worker {

auto TaskExecutor::completed() -> bool {
std::lock_guard const lock(m_state_mutex);
return TaskExecutorState::Succeed == m_state || TaskExecutorState::Error == m_state;
return TaskExecutorState::Succeed == m_state || TaskExecutorState::Error == m_state
|| TaskExecutorState::Cancelled == m_state;
}

auto TaskExecutor::waiting() -> bool {
Expand Down Expand Up @@ -48,7 +50,14 @@ void TaskExecutor::wait() {
m_result_buffer
);
}
return;
}
std::unique_lock lock(m_state_mutex);
m_complete_cv.wait(lock, [this] {
return TaskExecutorState::Succeed == m_state || TaskExecutorState::Error == m_state
|| TaskExecutorState::Cancelled == m_state;
});
lock.unlock();
}

void TaskExecutor::cancel() {
Expand Down Expand Up @@ -79,23 +88,32 @@ auto TaskExecutor::process_output_handler() -> boost::asio::awaitable<void> {
case TaskExecutorResponseType::Block:
break;
case TaskExecutorResponseType::Error: {
std::lock_guard const lock(m_state_mutex);
m_state = TaskExecutorState::Error;
m_result_buffer.write(response.data(), response.size());
{
std::lock_guard const lock(m_state_mutex);
m_state = TaskExecutorState::Error;
m_result_buffer.write(response.data(), response.size());
}
m_complete_cv.notify_all();
co_return;
}
case TaskExecutorResponseType::Ready:
break;
case TaskExecutorResponseType::Result: {
std::lock_guard const lock(m_state_mutex);
m_state = TaskExecutorState::Succeed;
m_result_buffer.write(response.data(), response.size());
{
std::lock_guard const lock(m_state_mutex);
m_state = TaskExecutorState::Succeed;
m_result_buffer.write(response.data(), response.size());
}
m_complete_cv.notify_all();
co_return;
}
case TaskExecutorResponseType::Cancel: {
std::lock_guard const lock(m_state_mutex);
m_state = TaskExecutorState::Cancelled;
m_result_buffer.write(response.data(), response.size());
{
std::lock_guard const lock(m_state_mutex);
m_state = TaskExecutorState::Cancelled;
m_result_buffer.write(response.data(), response.size());
}
m_complete_cv.notify_all();
co_return;
}
case TaskExecutorResponseType::Unknown:
Expand All @@ -106,6 +124,10 @@ auto TaskExecutor::process_output_handler() -> boost::asio::awaitable<void> {

// NOLINTEND(clang-analyzer-core.CallAndMessage)

/**
 * Splits the stored response in `m_result_buffer` into one buffer per result value by forwarding
 * to core::response_get_result_buffers.
 *
 * @return Whatever core::response_get_result_buffers returns for `m_result_buffer`: the result
 * buffers, or std::nullopt if the stored response cannot be split.
 */
auto TaskExecutor::get_result_buffers() const -> std::optional<std::vector<msgpack::sbuffer>> {
return core::response_get_result_buffers(m_result_buffer);
}

auto TaskExecutor::get_error() const -> std::tuple<core::FunctionInvokeError, std::string> {
return core::response_get_error(m_result_buffer)
.value_or(std::make_tuple(
Expand Down
Loading

0 comments on commit 9169591

Please sign in to comment.