feat: Batching for events
dbrasseur-aneo committed Aug 26, 2024
1 parent 1605ae1 commit 4f3b30a
Showing 4 changed files with 181 additions and 36 deletions.
104 changes: 70 additions & 34 deletions packages/python/src/armonik/client/events.py
@@ -1,9 +1,11 @@
from __future__ import annotations
from typing import Callable, cast, Iterable, List, Optional, Union

import concurrent.futures
from concurrent.futures.thread import ThreadPoolExecutor
from typing import Callable, cast, Iterable, List, Optional, Union, Collection

from grpc import Channel, RpcError

from .results import ArmoniKResults
from ..common import (
EventTypes,
NewTaskEvent,
@@ -14,6 +16,7 @@
ResultStatus,
Event,
Result,
batched,
)
from ..common.filter import Filter
from ..protogen.client.events_service_pb2_grpc import EventsStub
@@ -38,7 +41,6 @@ def __init__(self, grpc_channel: Channel):
grpc_channel: gRPC channel to use
"""
self._client = EventsStub(grpc_channel)
self._results_client = ArmoniKResults(grpc_channel)

def get_events(
self,
@@ -91,48 +93,82 @@ def get_events(
break

def wait_for_result_availability(
self, result_ids: Union[str, List[str]], session_id: str
self,
result_ids: Union[str, List[str]],
session_id: str,
bucket_size: int = 100,
parallelism: int = 1,
) -> None:
"""Wait until the given results are ready, i.e. their status updates to COMPLETED.
Args:
result_ids: The IDs of the results.
session_id: The ID of the session.
bucket_size: Maximum number of result IDs waited on per batch; one event filter is built per batch.
parallelism: Number of threads used to wait on batches concurrently; batches are processed sequentially when equal to 1.
Raises:
RuntimeError: If the result status is ABORTED.
"""
if isinstance(result_ids, str):
result_ids = [result_ids]
result_ids = set(result_ids)
if len(result_ids) == 0:
return
results_not_found = set(result_ids)

results_filter = Result.result_id == result_ids[0]
for result_id in result_ids[1:]:
results_filter = results_filter | (Result.result_id == result_id)

def handler(_, _2, event: Event) -> bool:
event = cast(Union[NewResultEvent, ResultStatusUpdateEvent], event)
if event.result_id in results_not_found:
if event.status == ResultStatus.COMPLETED:
results_not_found.remove(event.result_id)
if not results_not_found:
return True
elif event.status == ResultStatus.ABORTED:
raise RuntimeError(f"Result {event.result_id} has been aborted.")
return False

while results_not_found:

if parallelism > 1:
pool = ThreadPoolExecutor(max_workers=parallelism)
try:
self.get_events(
session_id,
[EventTypes.RESULT_STATUS_UPDATE, EventTypes.NEW_RESULT],
[handler],
None,
results_filter,
)
except RpcError:
pass
else:
break
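# Parallel path: submit one _wait_all call per batch of at most bucket_size result IDs;
# if any batch fails, cancel the remaining futures and re-raise its exception.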
futures = [
pool.submit(_wait_all, self, session_id, batch)
for batch in batched(result_ids, bucket_size)
]
for i, future in enumerate(concurrent.futures.as_completed(futures)):
exp = future.exception()
if exp is not None:
for f in futures:
f.cancel()
raise exp
finally:
pool.shutdown(wait=False)
else:
for batch in batched(result_ids, bucket_size):
_wait_all(self, session_id, batch)


def _wait_all(event_client: ArmoniKEvents, session_id: str, results: Collection[str]):
if len(results) == 0:
return
results_filter = None
for result_id in results:
results_filter = (
Result.result_id == result_id
if results_filter is None
else (results_filter | (Result.result_id == result_id))
)

not_found = set(results)

def handler(_, _2, event: Event) -> bool:
event = cast(Union[NewResultEvent, ResultStatusUpdateEvent], event)
if event.result_id in not_found:
if event.status == ResultStatus.COMPLETED:
not_found.remove(event.result_id)
if not not_found:
return True
elif event.status == ResultStatus.ABORTED:
raise RuntimeError(f"Result {event.result_id} has been aborted.")
return False

while not_found:
try:
event_client.get_events(
session_id,
[EventTypes.RESULT_STATUS_UPDATE, EventTypes.NEW_RESULT],
[handler],
None,
results_filter,
)
except RpcError:
pass
else:
break
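The refactored wait_for_result_availability splits the requested result IDs into buckets of at most bucket_size, builds one event filter per bucket, and, when parallelism > 1, waits on the buckets from a thread pool. A minimal usage sketch of the new parameters follows; the endpoint, session ID, and result IDs are placeholders (not taken from this commit) and assume a deployed ArmoniK cluster where those results already exist.

import grpc

from armonik.client import ArmoniKEvents

# Placeholder values for illustration only.
endpoint = "localhost:5001"
session_id = "my-session-id"
result_ids = [f"result-{i}" for i in range(1000)]

with grpc.insecure_channel(endpoint) as channel:
    events = ArmoniKEvents(channel)
    # Wait on all results: at most 100 IDs per event filter,
    # with up to 10 buckets waited on concurrently.
    events.wait_for_result_availability(
        result_ids,
        session_id,
        bucket_size=100,
        parallelism=10,
    )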
2 changes: 1 addition & 1 deletion packages/python/src/armonik/client/tasks.py
@@ -231,7 +231,7 @@ def submit_tasks(
Task(
id=t.task_id,
session_id=session_id,
expected_output_ids=list(t.expected_output_keys),
expected_output_ids=list(t.expected_output_ids),
data_dependencies=list(t.data_dependencies),
payload_id=t.payload_id,
)
2 changes: 1 addition & 1 deletion packages/python/src/armonik/common/helpers.py
@@ -130,7 +130,7 @@ def batched(iterable: Iterable[T], n: int) -> Iterable[List[T]]:
batch.append(c)
if len(batch) == n:
yield batch
batch.clear()
batch = []
c = next(it, sentinel)
if len(batch) > 0:
yield batch
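The one-line change in batched (batch.clear() replaced by batch = []) fixes an aliasing bug: the generator used to yield the same list object on every iteration and then empty it, so any caller that keeps the yielded batches, as the new thread-pool path does when submitting one _wait_all per batch, would see them mutated. A simplified reconstruction of the old behaviour (the helper's surrounding code is abbreviated here, not copied from the repository):

from typing import Iterable, Iterator, List, TypeVar

T = TypeVar("T")

def batched_buggy(iterable: Iterable[T], n: int) -> Iterator[List[T]]:
    # Simplified stand-in for the pre-fix helper: the same list is yielded
    # on every iteration and then cleared in place.
    batch: List[T] = []
    for item in iterable:
        batch.append(item)
        if len(batch) == n:
            yield batch
            batch.clear()  # old behaviour: also empties the batch the caller just received
    if batch:
        yield batch

print(list(batched_buggy(range(6), 2)))  # [[], [], []] instead of [[0, 1], [2, 3], [4, 5]]

Rebinding with batch = [] leaves each yielded list untouched, which is exactly what the fix does.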
109 changes: 109 additions & 0 deletions packages/python/tests/test_wait_availability.py
@@ -0,0 +1,109 @@
from datetime import timedelta
import time
from threading import Thread

import grpc
import pytest
from armonik.client import ArmoniKTasks, ArmoniKResults, ArmoniKSessions, ArmoniKEvents
from armonik.common import TaskOptions, TaskDefinition

endpoint = ""


def wait_and_unpause(session_id: str):
time.sleep(1)
with grpc.insecure_channel(endpoint) as channel:
ArmoniKSessions(channel).resume_session(session_id)
print("Session resumed")


class TestWaitAvailability:
def test_wait_availability(self):
pytest.skip()
n_tasks = 10000
with grpc.insecure_channel(endpoint) as channel:
task_client = ArmoniKTasks(channel)
result_client = ArmoniKResults(channel)
session_client = ArmoniKSessions(channel)
events_client = ArmoniKEvents(channel)
session_id = session_client.create_session(TaskOptions(timedelta(seconds=60), 1, 1, ""))
print(f"Created session {session_id}")
session_client.pause_session(session_id)
payload_ids = list(
r.result_id
for r in result_client.create_results(
{str(r): str(r).encode() for r in range(n_tasks)}, session_id
).values()
)
print(f"Submitted payloads {len(payload_ids)}")
result_ids = list(
r.result_id
for r in result_client.create_results_metadata(
[str(r) for r in range(n_tasks)], session_id
).values()
)
print(f"Submitted results {len(result_ids)}")
tasks = task_client.submit_tasks(
session_id,
[
TaskDefinition(payload_id=p, expected_output_ids=[r])
for p, r in zip(payload_ids, result_ids)
],
)
print(f"Submitted tasks {len(tasks)}")
t = Thread(target=wait_and_unpause, args=(session_id,))
start = time.time()
t.start()
print("Waiting on results")
events_client.wait_for_result_availability(result_ids, session_id, bucket_size=100)
end = time.time()
print(end - start)
session_client.close_session(session_id)
session_client.purge_session(session_id)
session_client.delete_session(session_id)

def test_wait_availability2(self):
pytest.skip()
n_tasks = 10000
with grpc.insecure_channel(endpoint) as channel:
task_client = ArmoniKTasks(channel)
result_client = ArmoniKResults(channel)
session_client = ArmoniKSessions(channel)
events_client = ArmoniKEvents(channel)
session_id = session_client.create_session(TaskOptions(timedelta(seconds=60), 1, 1, ""))
print(f"Created session {session_id}")
session_client.pause_session(session_id)
payload_ids = list(
r.result_id
for r in result_client.create_results(
{str(r): str(r).encode() for r in range(n_tasks)}, session_id
).values()
)
print(f"Submitted payloads {len(payload_ids)}")
result_ids = list(
r.result_id
for r in result_client.create_results_metadata(
[str(r) for r in range(n_tasks)], session_id
).values()
)
print(f"Submitted results {len(result_ids)}")
tasks = task_client.submit_tasks(
session_id,
[
TaskDefinition(payload_id=p, expected_output_ids=[r])
for p, r in zip(payload_ids, result_ids)
],
)
print(f"Submitted tasks {len(tasks)}")
t = Thread(target=wait_and_unpause, args=(session_id,))
start = time.time()
t.start()
print("Waiting on results")
events_client.wait_for_result_availability(
result_ids, session_id, bucket_size=100, parallelism=10
)
end = time.time()
print(end - start)
session_client.close_session(session_id)
session_client.purge_session(session_id)
session_client.delete_session(session_id)
