Skip to content

Commit

Permalink
refactor: remove submitter service from parts of the client (#235)
Browse files Browse the repository at this point in the history
  • Loading branch information
aneojgurhem authored Dec 1, 2023
2 parents 9295b0a + 791c04c commit bd1ac96
Show file tree
Hide file tree
Showing 3 changed files with 250 additions and 140 deletions.
246 changes: 121 additions & 125 deletions Client/src/Common/Submitter/BaseClientSubmitter.cs
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,12 @@
using System.Threading.Tasks;

using ArmoniK.Api.Client.Submitter;
using ArmoniK.Api.Client;
using ArmoniK.Api.Common.Utils;
using ArmoniK.Api.gRPC.V1;
using ArmoniK.Api.gRPC.V1.Results;
using ArmoniK.Api.gRPC.V1.Sessions;
using ArmoniK.Api.gRPC.V1.SortDirection;
using ArmoniK.Api.gRPC.V1.Submitter;
using ArmoniK.Api.gRPC.V1.Tasks;
using ArmoniK.DevelopmentKit.Client.Common.Status;
Expand All @@ -41,6 +44,8 @@

using Microsoft.Extensions.Logging;

using CreateSessionRequest = ArmoniK.Api.gRPC.V1.Sessions.CreateSessionRequest;
using Filters = ArmoniK.Api.gRPC.V1.Tasks.Filters;
using TaskStatus = ArmoniK.Api.gRPC.V1.TaskStatus;

namespace ArmoniK.DevelopmentKit.Client.Common.Submitter;
Expand Down Expand Up @@ -120,21 +125,21 @@ private Session CreateSession(IEnumerable<string> partitionIds)
{
using var _ = Logger.LogFunction();
Logger.LogDebug("Creating Session... ");
var createSessionRequest = new CreateSessionRequest
{
DefaultTaskOption = TaskOptions,
PartitionIds =
{
partitionIds,
},
};
var session = ChannelPool.WithChannel(channel => new Api.gRPC.V1.Submitter.Submitter.SubmitterClient(channel).CreateSession(createSessionRequest));

using var channel = ChannelPool.GetChannel();
var sessionsClient = new Sessions.SessionsClient(channel);
var createSessionReply = sessionsClient.CreateSession(new CreateSessionRequest
{
DefaultTaskOption = TaskOptions,
PartitionIds =
{
partitionIds,
},
});
Logger.LogDebug("Session Created {SessionId}",
SessionId);
return new Session
{
Id = session.SessionId,
Id = createSessionReply.SessionId,
};
}

Expand All @@ -157,27 +162,53 @@ public TaskStatus GetTaskStatus(string taskId)
/// <param name="taskIds">The list of taskIds</param>
/// <returns></returns>
public IEnumerable<Tuple<string, TaskStatus>> GetTaskStatues(params string[] taskIds)
=> ChannelPool.WithChannel(channel => new Api.gRPC.V1.Submitter.Submitter.SubmitterClient(channel).GetTaskStatus(new GetTaskStatusRequest
{
TaskIds =
{
taskIds,
},
})
.IdStatuses.Select(x => Tuple.Create(x.TaskId,
x.Status)));
{
using var channel = ChannelPool.GetChannel();
var tasksClient = new Tasks.TasksClient(channel);
return tasksClient.ListTasks(new Filters
{
Or =
{
taskIds.Select(TasksClientExt.TaskIdFilter),
},
},
new ListTasksRequest.Types.Sort
{
Direction = SortDirection.Asc,
Field = new TaskField
{
TaskSummaryField = new TaskSummaryField
{
Field = TaskSummaryEnumField.TaskId,
},
},
})
.Select(task => new Tuple<string, TaskStatus>(task.Id,
task.Status));
}

/// <summary>
/// Return the taskOutput when error occurred
/// </summary>
/// <param name="taskId"></param>
/// <returns></returns>

// TODO: This function should not have Output as a return type because it is a gRPC type
public Output GetTaskOutputInfo(string taskId)
=> ChannelPool.WithChannel(channel => new Api.gRPC.V1.Submitter.Submitter.SubmitterClient(channel).TryGetTaskOutput(new TaskOutputRequest
{
TaskId = taskId,
Session = SessionId.Id,
}));
{
var getTaskResponse = ChannelPool.WithChannel(channel => new Tasks.TasksClient(channel).GetTask(new GetTaskRequest
{
TaskId = taskId,
}));
return new Output
{
Error = new Output.Types.Error
{
Details = getTaskResponse.Task.Output.Error,
},
};
}


/// <summary>
/// The method to submit several tasks with dependencies tasks. This task will wait for
Expand Down Expand Up @@ -374,8 +405,8 @@ public void WaitForTasksCompletion(IEnumerable<string> taskIds,
delayMs,
retry =>
{
using var channel = ChannelPool.GetChannel();
var submitterService = new Api.gRPC.V1.Submitter.Submitter.SubmitterClient(channel);
using var channel = ChannelPool.GetChannel();
var submitterService = new Api.gRPC.V1.Submitter.Submitter.SubmitterClient(channel);

if (retry > 1)
{
Expand All @@ -385,20 +416,20 @@ public void WaitForTasksCompletion(IEnumerable<string> taskIds,
}

var __ = submitterService.WaitForCompletion(new WaitRequest
{
Filter = new TaskFilter
{
Task = new TaskFilter.Types.IdsRequest
{
Ids =
{
Filter = new TaskFilter
{
Task = new TaskFilter.Types.IdsRequest
{
Ids =
{
taskIds,
},
},
},
StopOnFirstTaskCancellation = true,
StopOnFirstTaskError = true,
});
},
},
StopOnFirstTaskCancellation = true,
StopOnFirstTaskError = true,
});
},
true,
Logger,
Expand Down Expand Up @@ -432,22 +463,21 @@ public ResultStatusCollection GetResultStatus(IEnumerable<string> taskIds,
2000,
retry =>
{
using var channel = ChannelPool.GetChannel();
var submitterService = new Api.gRPC.V1.Submitter.Submitter.SubmitterClient(channel);

using var channel = ChannelPool.GetChannel();
var resultsClient = new Results.ResultsClient(channel);
Logger.LogDebug("Try {try} for {funcName}",
retry,
nameof(submitterService.GetResultStatus));
// TODO: replace with submitterService.TryGetResultStream() => Issue #
var resultStatusReply = submitterService.GetResultStatus(new GetResultStatusRequest
{
ResultIds =
{
result2TaskDic.Keys,
},
SessionId = SessionId.Id,
});
return resultStatusReply.IdStatuses;
nameof(resultsClient.GetResult));
var idStatusPair = result2TaskDic.Keys.Select(resultId =>
{
var status = resultsClient.GetResult(new GetResultRequest
{
ResultId = resultId,
})
.Result.Status;
return (resultId, status);
});
return idStatusPair;
},
true,
Logger,
Expand All @@ -460,11 +490,11 @@ public ResultStatusCollection GetResultStatus(IEnumerable<string> taskIds,

foreach (var idStatusPair in idStatus)
{
var resData = new ResultStatusData(idStatusPair.ResultId,
result2TaskDic[idStatusPair.ResultId],
idStatusPair.Status);
var resData = new ResultStatusData(idStatusPair.resultId,
result2TaskDic[idStatusPair.resultId],
idStatusPair.status);

switch (idStatusPair.Status)
switch (idStatusPair.status)
{
case ResultStatus.Notfound:
continue;
Expand All @@ -481,7 +511,7 @@ public ResultStatusCollection GetResultStatus(IEnumerable<string> taskIds,
break;
}

result2TaskDic.Remove(idStatusPair.ResultId);
result2TaskDic.Remove(idStatusPair.resultId);
}

var resultStatusList = new ResultStatusCollection(idsReady,
Expand Down Expand Up @@ -524,7 +554,6 @@ public ResultStatusCollection GetResultStatus(IEnumerable<string> taskIds,
typeof(IOException),
typeof(RpcException));


/// <summary>
/// Try to find the result of One task. If there no result, the function return byte[0]
/// </summary>
Expand All @@ -547,17 +576,17 @@ public byte[] GetResult(string taskId,


var resultRequest = new ResultRequest
{
ResultId = resultId,
Session = SessionId.Id,
};
{
ResultId = resultId,
Session = SessionId.Id,
};

Retry.WhileException(5,
2000,
retry =>
{
using var channel = ChannelPool.GetChannel();
var submitterService = new Api.gRPC.V1.Submitter.Submitter.SubmitterClient(channel);
using var channel = ChannelPool.GetChannel();
var submitterService = new Api.gRPC.V1.Submitter.Submitter.SubmitterClient(channel);

Logger.LogDebug("Try {try} for {funcName}",
retry,
Expand Down Expand Up @@ -603,7 +632,6 @@ public byte[] GetResult(string taskId,
}
}


/// <summary>
/// Retrieve results from control plane
/// </summary>
Expand All @@ -629,76 +657,44 @@ public IEnumerable<Tuple<string, byte[]>> GetResults(IEnumerable<string> taskIds
/// </summary>
/// <param name="resultRequest">Request specifying the result to fetch</param>
/// <param name="cancellationToken">The token used to cancel the operation.</param>
/// <returns>Returns the result or byte[0] if there no result or null if task is not yet ready</returns>
/// <returns>Returns the result if it is ready, null if task is not yet ready</returns>
/// <exception cref="Exception"></exception>
/// <exception cref="ArgumentOutOfRangeException"></exception>
// TODO: return a compound type to avoid having a nullable that holds the information and return an empty array.
// TODO: This function should not have an argument of type ResultRequest because it is a gRPC type
public async Task<byte[]?> TryGetResultAsync(ResultRequest resultRequest,
CancellationToken cancellationToken = default)
{
List<ReadOnlyMemory<byte>> chunks;
int len;

using var channel = ChannelPool.GetChannel();
var submitterService = new Api.gRPC.V1.Submitter.Submitter.SubmitterClient(channel);

using var channel = ChannelPool.GetChannel();
var resultsClient = new Results.ResultsClient(channel);
var getResultResponse = await resultsClient.GetResultAsync(new GetResultRequest
{
ResultId = resultRequest.ResultId,
},
null,
null,
cancellationToken)
.ConfigureAwait(false);
var result = getResultResponse.Result;
switch (result.Status)
{
using var streamingCall = submitterService.TryGetResultStream(resultRequest,
cancellationToken: cancellationToken);
chunks = new List<ReadOnlyMemory<byte>>();
len = 0;
var isPayloadComplete = false;

while (await streamingCall.ResponseStream.MoveNext(cancellationToken))
{
var reply = streamingCall.ResponseStream.Current;

switch (reply.TypeCase)
{
case ResultReply.TypeOneofCase.Result:
if (!reply.Result.DataComplete)
{
chunks.Add(reply.Result.Data.Memory);
len += reply.Result.Data.Memory.Length;
// In case we receive a chunk after the data complete message (corrupt stream)
isPayloadComplete = false;
}
else
{
isPayloadComplete = true;
}

break;
case ResultReply.TypeOneofCase.None:
return null;

case ResultReply.TypeOneofCase.Error:
throw new Exception($"Error in task {reply.Error.TaskId} {string.Join("Message is : ", reply.Error.Errors.Select(x => x.Detail))}");

case ResultReply.TypeOneofCase.NotCompletedTask:
return null;

default:
throw new InvalidOperationException("Got a reply with an unexpected message type.");
}
}

if (!isPayloadComplete)
case ResultStatus.Completed:
{
throw new ClientResultsException($"Result data is incomplete for id {resultRequest.ResultId}");
return await resultsClient.DownloadResultData(result.SessionId,
result.ResultId,
cancellationToken)
.ConfigureAwait(false);
}
case ResultStatus.Aborted:
throw new Exception($"Error while trying to get result {result.ResultId}. Result was aborted");
case ResultStatus.Notfound:
throw new Exception($"Error while trying to get result {result.ResultId}. Result was not found");
case ResultStatus.Created:
return null;
case ResultStatus.Unspecified:
default:
throw new ArgumentOutOfRangeException(nameof(result.Status));
}

var res = new byte[len];
var idx = 0;
foreach (var rm in chunks)
{
rm.CopyTo(res.AsMemory(idx,
rm.Length));
idx += rm.Length;
}

return res;
}

/// <summary>
Expand Down
Loading

0 comments on commit bd1ac96

Please sign in to comment.