Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: Stop duplicating code between the low level API and the SDK #44

Merged
merged 2 commits into from
Mar 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ class SessionService {
* @warning The data of these results will not be recoverable. Tasks which depend on these data will fail.
* @warning The tasks will not be processed by the client.
*/
void CleanupTasks(const std::set<std::string> &task_ids);
void CleanupTasks(std::vector<std::string> task_ids);
};
} // namespace Client
} // namespace Sdk
Expand Down
8 changes: 1 addition & 7 deletions ArmoniK.SDK.Client/private/SessionServiceImpl.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,12 +33,6 @@ namespace Internal {
*/
class SessionServiceImpl {
private:
/**
* @brief Requests the control plane to create results
* @param num Number of results to create
* @return List of result ids
*/
std::vector<std::string> generate_result_ids(size_t num);
/**
* @brief Session
*/
Expand Down Expand Up @@ -146,7 +140,7 @@ class SessionServiceImpl {
* @warning If the given task has not been processed, the behavior is undefined. The tasks will not be processed by
* the client.
*/
void CleanupTasks(const std::set<std::string> &task_ids);
void CleanupTasks(std::vector<std::string> task_ids);
};
} // namespace Internal
} // namespace Client
Expand Down
5 changes: 3 additions & 2 deletions ArmoniK.SDK.Client/src/SessionService.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#include "armonik/sdk/client/SessionService.h"
#include "SessionServiceImpl.h"
#include <string>
#include <utility>

namespace ArmoniK {
namespace Sdk {
Expand Down Expand Up @@ -44,9 +45,9 @@ void SessionService::DropSession() {
ensure_valid();
impl->DropSession();
}
void SessionService::CleanupTasks(const std::set<std::string> &task_ids) {
void SessionService::CleanupTasks(std::vector<std::string> task_ids) {
ensure_valid();
impl->CleanupTasks(task_ids);
impl->CleanupTasks(std::move(task_ids));
}

SessionService::SessionService(SessionService &&) noexcept = default;
Expand Down
117 changes: 28 additions & 89 deletions ArmoniK.SDK.Client/src/SessionServiceImpl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,9 @@
#include <armonik/common/exceptions/ArmoniKTaskError.h>
#include <armonik/common/objects.pb.h>
#include <armonik/common/utils/GuuId.h>
#include <armonik/sdk/common/ArmoniKSdkException.h>
#include <armonik/sdk/common/Properties.h>
#include <armonik/sdk/common/TaskPayload.h>
#include <chrono>
#include <grpcpp/client_context.h>
#include <submitter_service.grpc.pb.h>
#include <thread>
#include <utility>
#include <vector>
Expand All @@ -26,46 +23,6 @@ namespace Sdk {
namespace Client {
namespace Internal {

std::vector<std::string> SessionServiceImpl::generate_result_ids(size_t num) {
armonik::api::grpc::v1::results::CreateResultsMetaDataRequest results_request;
armonik::api::grpc::v1::results::CreateResultsMetaDataResponse results_response;

// Creates the result creation requests
std::vector<armonik::api::grpc::v1::results::CreateResultsMetaDataRequest_ResultCreate> results_create;
results_create.reserve(num);
for (size_t i = 0; i < num; i++) {
armonik::api::grpc::v1::results::CreateResultsMetaDataRequest_ResultCreate result_create;
// Random name
*result_create.mutable_name() = armonik::api::common::utils::GuuId::generate_uuid();
results_create.push_back(std::move(result_create));
}

results_request.mutable_results()->Add(std::make_move_iterator(results_create.begin()),
std::make_move_iterator(results_create.end()));
*results_request.mutable_session_id() = session;

// Creates the results
auto status = channel_pool.WithChannel([&](auto &&channel) {
grpc::ClientContext context;
return armonik::api::grpc::v1::results::Results::NewStub(channel)->CreateResultsMetaData(&context, results_request,
&results_response);
});

if (!status.ok()) {
std::stringstream message;
message << "Error: " << status.error_code() << ": " << status.error_message()
<< ". details : " << status.error_details() << std::endl;
logger_.log(armonik::api::common::logger::Level::Error, "Could not create results for submit: ");
throw armonik::api::common::exceptions::ArmoniKApiException(message.str());
}
std::vector<std::string> result_ids;
result_ids.reserve(num);
// Get the result ids from the response
std::transform(results_response.mutable_results()->begin(), results_response.mutable_results()->end(),
std::back_inserter(result_ids), [](auto &res) { return std::move(*res.mutable_result_id()); });
return result_ids;
}

const std::string &SessionServiceImpl::getSession() const { return session; }

[[maybe_unused]] std::vector<std::string>
Expand All @@ -76,22 +33,22 @@ SessionServiceImpl::Submit(const std::vector<Common::TaskPayload> &task_requests
std::vector<armonik::api::common::TaskCreation> task_creations;
task_creations.reserve(task_requests.size());

for (size_t i = 0; i < task_requests.size(); i++) {
for (const auto &task_request : task_requests) {
armonik::api::grpc::v1::TaskRequest request;
// Serialize the request in an ArmoniK format
*request.mutable_payload() = task_requests[i].Serialize();
*request.mutable_payload() = task_request.Serialize();
// Set the data dependencies
request.mutable_data_dependencies()->Add(task_requests[i].data_dependencies.begin(),
task_requests[i].data_dependencies.end());
request.mutable_data_dependencies()->Add(task_request.data_dependencies.begin(),
task_request.data_dependencies.end());

armonik::api::common::TaskCreation creation{};

auto result_payload = channel_pool.WithChannel([&](auto channel) {
auto client = armonik::api::client::ResultsClient(armonik::api::grpc::v1::results::Results::NewStub(channel));
auto result = client.create_results_metadata(session, std::vector<std::string>{"result"})["result"];
auto payload = client.create_results(
session, std::vector<std::pair<std::string, std::string>>{
{task_requests[i].method_name, request.payload()}})[task_requests[i].method_name];
auto payload =
client.create_results(session, std::vector<std::pair<std::string, std::string>>{
{task_request.method_name, request.payload()}})[task_request.method_name];

return std::pair<std::string, std::string>{result, payload};
});
Expand All @@ -100,8 +57,8 @@ SessionServiceImpl::Submit(const std::vector<Common::TaskPayload> &task_requests
creation.payload_id = std::move(result_payload.second);
// One result per task
creation.expected_output_keys.push_back(std::move(result_payload.first));
creation.data_dependencies.insert(creation.data_dependencies.end(), task_requests[i].data_dependencies.begin(),
task_requests[i].data_dependencies.end());
creation.data_dependencies.insert(creation.data_dependencies.end(), task_request.data_dependencies.begin(),
task_request.data_dependencies.end());

task_creations.emplace_back(std::move(creation));
}
Expand Down Expand Up @@ -289,16 +246,9 @@ void SessionServiceImpl::DropSession() {
result_handlers.clear();
}
// Cancel the session
channel_pool.WithChannel([&](const std::shared_ptr<::grpc::Channel> &channel) {
::grpc::ClientContext context;
armonik::api::grpc::v1::sessions::CancelSessionRequest request;
armonik::api::grpc::v1::sessions::CancelSessionResponse response;
*request.mutable_session_id() = session;
auto status =
armonik::api::grpc::v1::sessions::Sessions::NewStub(channel)->CancelSession(&context, request, &response);
if (!status.ok()) {
throw ArmoniK::Sdk::Common::ArmoniKSdkException("Unable to cancel session " + status.error_message());
}
auto reply = channel_pool.WithChannel([&](const std::shared_ptr<::grpc::Channel> &channel) {
fdenefaneo marked this conversation as resolved.
Show resolved Hide resolved
return armonik::api::client::SessionsClient(armonik::api::grpc::v1::sessions::Sessions::NewStub(channel))
.cancel_session(session);
});

// Create the result filter for result.session_id == session
Expand Down Expand Up @@ -333,7 +283,7 @@ void SessionServiceImpl::DropSession() {
} while (page * page_size < total);
}

void SessionServiceImpl::CleanupTasks(const std::set<std::string> &task_ids) {
void SessionServiceImpl::CleanupTasks(std::vector<std::string> task_ids) {
// Remove the given tasks from the maps
{
std::lock_guard<std::mutex> _(maps_mutex);
Expand All @@ -351,51 +301,40 @@ void SessionServiceImpl::CleanupTasks(const std::set<std::string> &task_ids) {
// Cancel the given tasks
while (tasks_iterator != task_ids.end()) {
channel_pool.WithChannel([&](const std::shared_ptr<::grpc::Channel> &channel) {
auto stub = armonik::api::grpc::v1::tasks::Tasks::NewStub(channel);
::grpc::ClientContext context;
armonik::api::grpc::v1::tasks::CancelTasksRequest request;
std::vector<std::string> batched_ids;
for (size_t i = 0; i < batch_size && tasks_iterator != task_ids.end(); ++i) {
*request.mutable_task_ids()->Add() = *tasks_iterator;
batched_ids.push_back(*tasks_iterator);
fdenefaneo marked this conversation as resolved.
Show resolved Hide resolved
tasks_iterator++;
}
armonik::api::grpc::v1::tasks::CancelTasksResponse response;
auto status = stub->CancelTasks(&context, request, &response);
if (!status.ok()) {
throw ArmoniK::Sdk::Common::ArmoniKSdkException("Unable to cancel tasks " + status.error_message());
}
armonik::api::client::TasksClient(armonik::api::grpc::v1::tasks::Tasks::NewStub(channel))
.cancel_tasks(batched_ids);
});
}

tasks_iterator = task_ids.begin();

while (tasks_iterator != task_ids.end()) {
armonik::api::grpc::v1::tasks::GetResultIdsResponse response;
// List batch of results from the given tasks
channel_pool.WithChannel([&](const std::shared_ptr<::grpc::Channel> &channel) {
auto stub = armonik::api::grpc::v1::tasks::Tasks::NewStub(channel);

::grpc::ClientContext context;
armonik::api::grpc::v1::tasks::GetResultIdsRequest request;
auto map_results = channel_pool.WithChannel([&](const std::shared_ptr<::grpc::Channel> &channel) {
fdenefaneo marked this conversation as resolved.
Show resolved Hide resolved
std::vector<std::string> batched_ids;
for (size_t i = 0; i < batch_size && tasks_iterator != task_ids.end(); ++i) {
*request.mutable_task_id()->Add() = *tasks_iterator;
batched_ids.push_back(std::move(*tasks_iterator));
fdenefaneo marked this conversation as resolved.
Show resolved Hide resolved
tasks_iterator++;
}

auto status = stub->GetResultIds(&context, request, &response);
if (!status.ok()) {
throw ArmoniK::Sdk::Common::ArmoniKSdkException("Unable to list tasks resultIds " + status.error_message());
}
return armonik::api::client::TasksClient(armonik::api::grpc::v1::tasks::Tasks::NewStub(channel))
.get_result_ids(batched_ids);
});

// Delete results
channel_pool.WithChannel([&](const std::shared_ptr<::grpc::Channel> &channel) {
auto results = armonik::api::client::ResultsClient(armonik::api::grpc::v1::results::Results::NewStub(channel));
std::vector<std::string> resultids;
resultids.reserve(response.task_results_size());
for (auto &&tid_rids : response.task_results()) {
resultids.insert(resultids.end(), tid_rids.result_ids().begin(), tid_rids.result_ids().end());
std::vector<std::string> resultIds;
resultIds.reserve(map_results.size());
for (auto &&taskId_resultIds : map_results) {
resultIds.insert(resultIds.end(), std::make_move_iterator(taskId_resultIds.second.begin()),
std::make_move_iterator(taskId_resultIds.second.end()));
}
results.delete_results_data(session, resultids);
results.delete_results_data(session, resultIds);
});
}
}
Expand Down
Loading