Skip to content

Commit

Permalink
refactor: Stop duplicating code between the low level API and the SDK (
Browse files Browse the repository at this point in the history
  • Loading branch information
ddiakiteaneo authored Mar 4, 2024
2 parents 9bd9847 + 2bd8cc0 commit be4140b
Show file tree
Hide file tree
Showing 4 changed files with 33 additions and 99 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ class SessionService {
* @warning The data of these results will not be recoverable. Tasks which depend on these data will fail.
* @warning The tasks will not be processed by the client.
*/
void CleanupTasks(const std::set<std::string> &task_ids);
void CleanupTasks(std::vector<std::string> task_ids);
};
} // namespace Client
} // namespace Sdk
Expand Down
8 changes: 1 addition & 7 deletions ArmoniK.SDK.Client/private/SessionServiceImpl.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,12 +33,6 @@ namespace Internal {
*/
class SessionServiceImpl {
private:
/**
* @brief Requests the control plane to create results
* @param num Number of results to create
* @return List of result ids
*/
std::vector<std::string> generate_result_ids(size_t num);
/**
* @brief Session
*/
Expand Down Expand Up @@ -146,7 +140,7 @@ class SessionServiceImpl {
* @warning If the given task has not been processed, the behavior is undefined. The tasks will not be processed by
* the client.
*/
void CleanupTasks(const std::set<std::string> &task_ids);
void CleanupTasks(std::vector<std::string> task_ids);
};
} // namespace Internal
} // namespace Client
Expand Down
5 changes: 3 additions & 2 deletions ArmoniK.SDK.Client/src/SessionService.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#include "armonik/sdk/client/SessionService.h"
#include "SessionServiceImpl.h"
#include <string>
#include <utility>

namespace ArmoniK {
namespace Sdk {
Expand Down Expand Up @@ -44,9 +45,9 @@ void SessionService::DropSession() {
ensure_valid();
impl->DropSession();
}
void SessionService::CleanupTasks(const std::set<std::string> &task_ids) {
void SessionService::CleanupTasks(std::vector<std::string> task_ids) {
ensure_valid();
impl->CleanupTasks(task_ids);
impl->CleanupTasks(std::move(task_ids));
}

SessionService::SessionService(SessionService &&) noexcept = default;
Expand Down
117 changes: 28 additions & 89 deletions ArmoniK.SDK.Client/src/SessionServiceImpl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,9 @@
#include <armonik/common/exceptions/ArmoniKTaskError.h>
#include <armonik/common/objects.pb.h>
#include <armonik/common/utils/GuuId.h>
#include <armonik/sdk/common/ArmoniKSdkException.h>
#include <armonik/sdk/common/Properties.h>
#include <armonik/sdk/common/TaskPayload.h>
#include <chrono>
#include <grpcpp/client_context.h>
#include <submitter_service.grpc.pb.h>
#include <thread>
#include <utility>
#include <vector>
Expand All @@ -26,46 +23,6 @@ namespace Sdk {
namespace Client {
namespace Internal {

std::vector<std::string> SessionServiceImpl::generate_result_ids(size_t num) {
armonik::api::grpc::v1::results::CreateResultsMetaDataRequest results_request;
armonik::api::grpc::v1::results::CreateResultsMetaDataResponse results_response;

// Creates the result creation requests
std::vector<armonik::api::grpc::v1::results::CreateResultsMetaDataRequest_ResultCreate> results_create;
results_create.reserve(num);
for (size_t i = 0; i < num; i++) {
armonik::api::grpc::v1::results::CreateResultsMetaDataRequest_ResultCreate result_create;
// Random name
*result_create.mutable_name() = armonik::api::common::utils::GuuId::generate_uuid();
results_create.push_back(std::move(result_create));
}

results_request.mutable_results()->Add(std::make_move_iterator(results_create.begin()),
std::make_move_iterator(results_create.end()));
*results_request.mutable_session_id() = session;

// Creates the results
auto status = channel_pool.WithChannel([&](auto &&channel) {
grpc::ClientContext context;
return armonik::api::grpc::v1::results::Results::NewStub(channel)->CreateResultsMetaData(&context, results_request,
&results_response);
});

if (!status.ok()) {
std::stringstream message;
message << "Error: " << status.error_code() << ": " << status.error_message()
<< ". details : " << status.error_details() << std::endl;
logger_.log(armonik::api::common::logger::Level::Error, "Could not create results for submit: ");
throw armonik::api::common::exceptions::ArmoniKApiException(message.str());
}
std::vector<std::string> result_ids;
result_ids.reserve(num);
// Get the result ids from the response
std::transform(results_response.mutable_results()->begin(), results_response.mutable_results()->end(),
std::back_inserter(result_ids), [](auto &res) { return std::move(*res.mutable_result_id()); });
return result_ids;
}

const std::string &SessionServiceImpl::getSession() const { return session; }

[[maybe_unused]] std::vector<std::string>
Expand All @@ -76,22 +33,22 @@ SessionServiceImpl::Submit(const std::vector<Common::TaskPayload> &task_requests
std::vector<armonik::api::common::TaskCreation> task_creations;
task_creations.reserve(task_requests.size());

for (size_t i = 0; i < task_requests.size(); i++) {
for (const auto &task_request : task_requests) {
armonik::api::grpc::v1::TaskRequest request;
// Serialize the request in an ArmoniK format
*request.mutable_payload() = task_requests[i].Serialize();
*request.mutable_payload() = task_request.Serialize();
// Set the data dependencies
request.mutable_data_dependencies()->Add(task_requests[i].data_dependencies.begin(),
task_requests[i].data_dependencies.end());
request.mutable_data_dependencies()->Add(task_request.data_dependencies.begin(),
task_request.data_dependencies.end());

armonik::api::common::TaskCreation creation{};

auto result_payload = channel_pool.WithChannel([&](auto channel) {
auto client = armonik::api::client::ResultsClient(armonik::api::grpc::v1::results::Results::NewStub(channel));
auto result = client.create_results_metadata(session, std::vector<std::string>{"result"})["result"];
auto payload = client.create_results(
session, std::vector<std::pair<std::string, std::string>>{
{task_requests[i].method_name, request.payload()}})[task_requests[i].method_name];
auto payload =
client.create_results(session, std::vector<std::pair<std::string, std::string>>{
{task_request.method_name, request.payload()}})[task_request.method_name];

return std::pair<std::string, std::string>{result, payload};
});
Expand All @@ -100,8 +57,8 @@ SessionServiceImpl::Submit(const std::vector<Common::TaskPayload> &task_requests
creation.payload_id = std::move(result_payload.second);
// One result per task
creation.expected_output_keys.push_back(std::move(result_payload.first));
creation.data_dependencies.insert(creation.data_dependencies.end(), task_requests[i].data_dependencies.begin(),
task_requests[i].data_dependencies.end());
creation.data_dependencies.insert(creation.data_dependencies.end(), task_request.data_dependencies.begin(),
task_request.data_dependencies.end());

task_creations.emplace_back(std::move(creation));
}
Expand Down Expand Up @@ -289,16 +246,9 @@ void SessionServiceImpl::DropSession() {
result_handlers.clear();
}
// Cancel the session
channel_pool.WithChannel([&](const std::shared_ptr<::grpc::Channel> &channel) {
::grpc::ClientContext context;
armonik::api::grpc::v1::sessions::CancelSessionRequest request;
armonik::api::grpc::v1::sessions::CancelSessionResponse response;
*request.mutable_session_id() = session;
auto status =
armonik::api::grpc::v1::sessions::Sessions::NewStub(channel)->CancelSession(&context, request, &response);
if (!status.ok()) {
throw ArmoniK::Sdk::Common::ArmoniKSdkException("Unable to cancel session " + status.error_message());
}
auto reply = channel_pool.WithChannel([&](const std::shared_ptr<::grpc::Channel> &channel) {
return armonik::api::client::SessionsClient(armonik::api::grpc::v1::sessions::Sessions::NewStub(channel))
.cancel_session(session);
});

// Create the result filter for result.session_id == session
Expand Down Expand Up @@ -333,7 +283,7 @@ void SessionServiceImpl::DropSession() {
} while (page * page_size < total);
}

void SessionServiceImpl::CleanupTasks(const std::set<std::string> &task_ids) {
void SessionServiceImpl::CleanupTasks(std::vector<std::string> task_ids) {
// Remove the given tasks from the maps
{
std::lock_guard<std::mutex> _(maps_mutex);
Expand All @@ -351,51 +301,40 @@ void SessionServiceImpl::CleanupTasks(const std::set<std::string> &task_ids) {
// Cancel the given tasks
while (tasks_iterator != task_ids.end()) {
channel_pool.WithChannel([&](const std::shared_ptr<::grpc::Channel> &channel) {
auto stub = armonik::api::grpc::v1::tasks::Tasks::NewStub(channel);
::grpc::ClientContext context;
armonik::api::grpc::v1::tasks::CancelTasksRequest request;
std::vector<std::string> batched_ids;
for (size_t i = 0; i < batch_size && tasks_iterator != task_ids.end(); ++i) {
*request.mutable_task_ids()->Add() = *tasks_iterator;
batched_ids.push_back(*tasks_iterator);
tasks_iterator++;
}
armonik::api::grpc::v1::tasks::CancelTasksResponse response;
auto status = stub->CancelTasks(&context, request, &response);
if (!status.ok()) {
throw ArmoniK::Sdk::Common::ArmoniKSdkException("Unable to cancel tasks " + status.error_message());
}
armonik::api::client::TasksClient(armonik::api::grpc::v1::tasks::Tasks::NewStub(channel))
.cancel_tasks(batched_ids);
});
}

tasks_iterator = task_ids.begin();

while (tasks_iterator != task_ids.end()) {
armonik::api::grpc::v1::tasks::GetResultIdsResponse response;
// List batch of results from the given tasks
channel_pool.WithChannel([&](const std::shared_ptr<::grpc::Channel> &channel) {
auto stub = armonik::api::grpc::v1::tasks::Tasks::NewStub(channel);

::grpc::ClientContext context;
armonik::api::grpc::v1::tasks::GetResultIdsRequest request;
auto map_results = channel_pool.WithChannel([&](const std::shared_ptr<::grpc::Channel> &channel) {
std::vector<std::string> batched_ids;
for (size_t i = 0; i < batch_size && tasks_iterator != task_ids.end(); ++i) {
*request.mutable_task_id()->Add() = *tasks_iterator;
batched_ids.push_back(std::move(*tasks_iterator));
tasks_iterator++;
}

auto status = stub->GetResultIds(&context, request, &response);
if (!status.ok()) {
throw ArmoniK::Sdk::Common::ArmoniKSdkException("Unable to list tasks resultIds " + status.error_message());
}
return armonik::api::client::TasksClient(armonik::api::grpc::v1::tasks::Tasks::NewStub(channel))
.get_result_ids(batched_ids);
});

// Delete results
channel_pool.WithChannel([&](const std::shared_ptr<::grpc::Channel> &channel) {
auto results = armonik::api::client::ResultsClient(armonik::api::grpc::v1::results::Results::NewStub(channel));
std::vector<std::string> resultids;
resultids.reserve(response.task_results_size());
for (auto &&tid_rids : response.task_results()) {
resultids.insert(resultids.end(), tid_rids.result_ids().begin(), tid_rids.result_ids().end());
std::vector<std::string> resultIds;
resultIds.reserve(map_results.size());
for (auto &&taskId_resultIds : map_results) {
resultIds.insert(resultIds.end(), std::make_move_iterator(taskId_resultIds.second.begin()),
std::make_move_iterator(taskId_resultIds.second.end()));
}
results.delete_results_data(session, resultids);
results.delete_results_data(session, resultIds);
});
}
}
Expand Down

0 comments on commit be4140b

Please sign in to comment.