
Commit

Fix graph
sharpener6 authored and yzard committed Dec 16, 2024
1 parent 9d648a8 commit 3cec31c
Showing 7 changed files with 97 additions and 57 deletions.
1 change: 1 addition & 0 deletions scaler/client/agent/future_manager.py
@@ -86,6 +86,7 @@ def on_task_result(self, result: TaskResult):
if result.status == TaskStatus.Failed:
assert len(result.results) == 1
result_object_id = result.results[0]
print(f"received failed result object {result_object_id.hex()} for future={future}")
future.set_result_ready(result_object_id, profile_result)
self._object_id_to_future[result_object_id] = result.status, future
return
16 changes: 15 additions & 1 deletion scaler/client/client.py
Expand Up @@ -51,6 +51,8 @@ class Client:
def __init__(
self,
address: str,
event_loop: str = "buildin",
io_threads: int = 1,
profiling: bool = False,
timeout_seconds: int = DEFAULT_CLIENT_TIMEOUT_SECONDS,
heartbeat_interval_seconds: int = DEFAULT_HEARTBEAT_INTERVAL_SECONDS,
@@ -68,11 +70,16 @@ def __init__(
:param heartbeat_interval_seconds: Frequency of heartbeat to scheduler in seconds
:type heartbeat_interval_seconds: int
"""
self.__initialize__(address, profiling, timeout_seconds, heartbeat_interval_seconds, serializer)

self.__initialize__(
address, event_loop, io_threads, profiling, timeout_seconds, heartbeat_interval_seconds, serializer
)

def __initialize__(
self,
address: str,
event_loop: str,
io_threads: int,
profiling: bool,
timeout_seconds: int,
heartbeat_interval_seconds: int,
@@ -85,6 +92,8 @@ def __initialize__(

self._client_agent_address = ZMQConfig(ZMQType.inproc, host=f"scaler_client_{uuid.uuid4().hex}")
self._scheduler_address = ZMQConfig.from_string(address)
self._event_loop = event_loop
self._io_threads = io_threads
self._timeout_seconds = timeout_seconds
self._heartbeat_interval_seconds = heartbeat_interval_seconds

@@ -159,6 +168,8 @@ def fibonacci(client: Client, n: int):

return {
"address": self._scheduler_address.to_address(),
"event_loop": self._event_loop,
"io_threads": self._io_threads,
"profiling": self._profiling,
"timeout_seconds": self._timeout_seconds,
"heartbeat_interval_seconds": self._heartbeat_interval_seconds,
@@ -168,6 +179,8 @@ def __setstate__(self, state: dict) -> None:
# TODO: fix copy the serializer
self.__initialize__(
address=state["address"],
event_loop=state["event_loop"],
io_threads=state["io_threads"],
profiling=state["profiling"],
timeout_seconds=state["timeout_seconds"],
heartbeat_interval_seconds=state["heartbeat_interval_seconds"],
@@ -501,6 +514,7 @@ def __construct_graph(
compute_futures[key] = self._future_factory(
task=task_id_to_tasks[node_name_to_task_id[key]], is_delayed=not block, group_task_id=graph_task_id
)
print(f"future[{key}]: task_id={node_name_to_task_id[key].hex()}, future={compute_futures[key]}")

elif key in node_name_to_arguments:
argument, data = node_name_to_arguments[key]
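As a minimal sketch of how the two constructor parameters added above would be used — the scheduler address, the submitted function, and the comment on what `io_threads` feeds are assumptions for illustration, not taken from this commit:

```python
# Minimal sketch, assuming a scheduler is already listening at this made-up address.
# Only event_loop and io_threads are the parameters introduced in this commit;
# the rest follows the existing Client API.
from scaler import Client

with Client(
    address="tcp://127.0.0.1:2345",  # placeholder scheduler address
    event_loop="builtin",            # event-loop backend used by the client agent
    io_threads=1,                    # ZMQ I/O threads (assumed to size the client's ZMQ context)
) as client:
    future = client.submit(pow, 2, 10)
    assert future.result() == 1024
```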
4 changes: 3 additions & 1 deletion scaler/client/future.py
@@ -39,7 +39,9 @@ def profiling_info(self) -> ProfileResult:
return self._profiling_info

def set_result_ready(self, object_id: Optional[bytes], profile_result: Optional[ProfileResult] = None) -> None:
print("set_result_ready out")
with self._condition: # type: ignore[attr-defined]
print("set_result_ready in")
if self.done():
raise InvalidStateError(f"invalid future state: {self._state}")

@@ -61,7 +63,7 @@ def _set_result_or_exception(
self,
result: Optional[Any] = None,
exception: Optional[BaseException] = None,
profiling_info: Optional[ProfileResult] = None
profiling_info: Optional[ProfileResult] = None,
) -> None:
with self._condition: # type: ignore[attr-defined]
if self.cancelled():
3 changes: 2 additions & 1 deletion scaler/protocol/python/common.py
@@ -11,12 +11,13 @@ class TaskStatus(enum.Enum):
Success = _common.TaskStatus.success # if submit and task is done and get result
Failed = _common.TaskStatus.failed # if submit and task is failed on worker
Canceled = _common.TaskStatus.canceled # if submit and task is canceled
NotFound = _common.TaskStatus.notFound # if submit and task is not found in scheduler
WorkerDied = (
_common.TaskStatus.workerDied
) # if submit and worker died (only happened when scheduler keep_task=False)
NoWorker = _common.TaskStatus.noWorker # if submit and scheduler is full (not implemented yet)

NotFound = _common.TaskStatus.notFound # if task is not found in scheduler, it will send such status to client

# below are only used for monitoring channel, not sent to client
Inactive = _common.TaskStatus.inactive # task is scheduled but not allocate to worker
Running = _common.TaskStatus.running # task is running in worker
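Since the `NotFound` comment above was reworded to stress which statuses reach the client, here is a small illustrative classifier over the enum; it is a sketch built from the comments in this file, not library code:

```python
# Sketch only: groups the statuses defined above by how a client would see them.
from scaler.protocol.python.common import TaskStatus

def describe(status: TaskStatus) -> str:
    if status == TaskStatus.Success:
        return "done, result object available"
    if status in (TaskStatus.Failed, TaskStatus.WorkerDied, TaskStatus.NoWorker):
        return "terminal failure reported to the client"
    if status in (TaskStatus.Canceled, TaskStatus.NotFound):
        return "canceled, or no longer known to the scheduler"
    return "monitoring-only state (Inactive/Running), not sent to the client"
```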
104 changes: 55 additions & 49 deletions scaler/scheduler/graph_manager.py
@@ -24,7 +24,7 @@ class _NodeTaskState(enum.Enum):

class _GraphState(enum.Enum):
Running = enum.auto()
Canceling = enum.auto()
Aborting = enum.auto()


@dataclasses.dataclass
@@ -43,6 +43,7 @@ class _Graph:
client: bytes
status: _GraphState = dataclasses.field(default=_GraphState.Running)
running_task_ids: Set[bytes] = dataclasses.field(default_factory=set)
abort_result: Optional[TaskResult] = dataclasses.field(default=None)


class VanillaGraphTaskManager(GraphTaskManager, Looper, Reporter):
@@ -101,26 +102,24 @@ async def on_graph_task_cancel(self, client: bytes, graph_task_cancel: GraphTask
return

graph_task_id = self._task_id_to_graph_task_id[graph_task_cancel.task_id]
graph_info = self._graph_task_id_to_graph[graph_task_id]
if graph_info.status == _GraphState.Canceling:
return

await self.__cancel_one_graph(graph_task_id, TaskResult.new_msg(graph_task_cancel.task_id, TaskStatus.Canceled))
await self.__cancel_all_running_nodes(
graph_task_id, TaskResult.new_msg(graph_task_cancel.task_id, TaskStatus.Canceled)
)

async def on_graph_sub_task_done(self, result: TaskResult):
print(f"task done {result.task_id.hex()=} {result.status=}")
graph_task_id = self._task_id_to_graph_task_id[result.task_id]
graph_info = self._graph_task_id_to_graph[graph_task_id]
if graph_info.status == _GraphState.Canceling:
if result.status in {TaskStatus.Failed, TaskStatus.Canceled, TaskStatus.WorkerDied, TaskStatus.NoWorker}:
await self.__cancel_all_running_nodes(graph_task_id, result)
return

await self.__mark_node_done(result)

if result.status == TaskStatus.Success:
await self.__check_one_graph(graph_task_id)
graph_info = self._graph_task_id_to_graph[graph_task_id]
if graph_info.status == _GraphState.Aborting:
await self.__cancel_all_running_nodes(graph_task_id, result)
return

assert result.status != TaskStatus.Success
await self.__cancel_one_graph(graph_task_id, result)
await self.__mark_node_done(result)
await self.__check_one_graph(graph_task_id)

def is_graph_sub_task(self, task_id: bytes):
return task_id in self._task_id_to_graph_task_id
@@ -140,6 +139,7 @@ async def __add_new_graph(self, client: bytes, graph_task: GraphTask):
tasks = dict()
depended_task_id_to_task_id: ManyToManyDict[bytes, bytes] = ManyToManyDict()
for task in graph_task.graph:
print(task.task_id.hex())
self._task_id_to_graph_task_id[task.task_id] = graph_task.task_id
tasks[task.task_id] = _TaskInfo(_NodeTaskState.Inactive, task)

@@ -149,6 +149,8 @@ async def __add_new_graph(self, client: bytes, graph_task: GraphTask):

graph[task.task_id] = required_task_ids

print(f"graph[{graph_task.task_id.hex()}]: {task.task_id.hex()}: {[r.hex() for r in required_task_ids]}")

await self._binder_monitor.send(
StateGraphTask.new_msg(
graph_task.task_id,
@@ -173,13 +175,26 @@ async def __check_one_graph(self, graph_task_id: bytes):
async def __check_one_graph(self, graph_task_id: bytes):
graph_info = self._graph_task_id_to_graph[graph_task_id]
if not graph_info.sorter.is_active():
print("finish one graph")
await self.__finish_one_graph(graph_task_id, TaskResult.new_msg(graph_task_id, TaskStatus.Success))
return

ready_task_ids = graph_info.sorter.get_ready()
if not ready_task_ids:
return

if graph_info.abort_result is not None:
for task_id in ready_task_ids:
await self.__mark_node_done(
TaskResult.new_msg(
task_id=task_id,
status=graph_info.abort_result.status,
metadata=graph_info.abort_result.metadata,
results=graph_info.abort_result.results,
)
)
return

for task_id in ready_task_ids:
task_info = graph_info.tasks[task_id]
task_info.state = _NodeTaskState.Running
@@ -195,7 +210,13 @@ async def __check_one_graph(self, graph_task_id: bytes):

await self._task_manager.on_task_new(graph_info.client, task)

if graph_info.abort_result is not None:
await self.__check_one_graph(graph_task_id)

async def __mark_node_done(self, result: TaskResult):
if result.task_id not in self._task_id_to_graph_task_id:
return

graph_task_id = self._task_id_to_graph_task_id.pop(result.task_id)

graph_info = self._graph_task_id_to_graph[graph_task_id]
@@ -217,6 +238,8 @@ async def __mark_node_done(self, result: TaskResult):
raise ValueError(f"received unexpected task result {result}")

self.__clean_intermediate_result(graph_task_id, result.task_id)

print(f"mark node done: {result.status=} {result.task_id.hex()=}")
graph_info.sorter.done(result.task_id)

if result.task_id in graph_info.running_task_ids:
@@ -225,25 +248,20 @@ async def __mark_node_done(self, result: TaskResult):
if result.task_id in graph_info.target_task_ids:
await self._binder.send(graph_info.client, result)

async def __cancel_one_graph(self, graph_task_id: bytes, result: TaskResult):
graph_info = self._graph_task_id_to_graph[graph_task_id]
graph_info.status = _GraphState.Canceling

if not self.__is_graph_finished(graph_task_id):
await self.__clean_all_running_nodes(graph_task_id, result)
await self.__clean_all_inactive_nodes(graph_task_id, result)

await self.__finish_one_graph(
graph_task_id, TaskResult.new_msg(result.task_id, result.status, result.metadata, result.results)
)
async def __cancel_all_running_nodes(self, graph_task_id: bytes, result: TaskResult):
if self.__is_graph_finished(graph_task_id):
return

async def __clean_all_running_nodes(self, graph_task_id: bytes, result: TaskResult):
graph_info = self._graph_task_id_to_graph[graph_task_id]
if graph_info.status == _GraphState.Aborting:
return

running_task_ids = graph_info.running_task_ids.copy()
graph_info.abort_result = result
graph_info.status = _GraphState.Aborting

# cancel all running tasks
for task_id in running_task_ids:
task_cancels = list()
for task_id in graph_info.running_task_ids:
new_result_object_ids = []
for result_object_id in result.results:
new_result_object_id = uuid.uuid4().bytes
@@ -254,33 +272,21 @@ async def __clean_all_running_nodes(self, graph_task_id: bytes, result: TaskResu
self._object_manager.get_object_content(result_object_id),
)
new_result_object_ids.append(new_result_object_id)
task_cancels.append(TaskCancel.new_msg(task_id))

await self._task_manager.on_task_cancel(graph_info.client, TaskCancel.new_msg(task_id))
for task_cancel in task_cancels:
print(f"canceling running task: {task_cancel.task_id.hex()}")
await self._task_manager.on_task_cancel(graph_info.client, task_cancel)
await self.__mark_node_done(
TaskResult.new_msg(task_id, result.status, result.metadata, new_result_object_ids)
TaskResult.new_msg(
task_id=task_cancel.task_id, status=result.status, metadata=result.metadata, results=result.results
)
)

async def __clean_all_inactive_nodes(self, graph_task_id: bytes, result: TaskResult):
graph_info = self._graph_task_id_to_graph[graph_task_id]
while graph_info.sorter.is_active():
ready_task_ids = graph_info.sorter.get_ready()
for task_id in ready_task_ids:
new_result_object_ids = []
for result_object_id in result.results:
new_result_object_id = uuid.uuid4().bytes
self._object_manager.on_add_object(
graph_info.client,
new_result_object_id,
self._object_manager.get_object_name(result_object_id),
self._object_manager.get_object_content(result_object_id),
)
new_result_object_ids.append(new_result_object_id)

await self.__mark_node_done(
TaskResult.new_msg(task_id, result.status, result.metadata, new_result_object_ids)
)
await self.__check_one_graph(graph_task_id)

async def __finish_one_graph(self, graph_task_id: bytes, result: TaskResult):
print("finish graph")
self._client_manager.on_task_finish(graph_task_id)
info = self._graph_task_id_to_graph.pop(graph_task_id)
await self._binder.send(info.client, TaskResult.new_msg(graph_task_id, result.status, results=result.results))
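The substance of the graph_manager change above: a failed or canceled sub-task now switches the graph to `Aborting` and stores an `abort_result`, and `__check_one_graph` replays that result onto every node that becomes ready instead of dispatching it, so the graph drains without running further work. The sketch below mirrors that drain-on-abort loop with plain callables and `graphlib`; it is an illustration of the idea, not the scheduler's API.

```python
# Standalone sketch of the drain-on-abort flow, assuming plain callables instead of
# Task messages; graphlib.TopologicalSorter stands in for graph_info.sorter and
# abort_result mirrors _Graph.abort_result.
import graphlib
from typing import Callable, Dict, Hashable, Set

def run_graph(
    graph: Dict[Hashable, Set[Hashable]], work: Dict[Hashable, Callable[[], object]]
) -> Dict[Hashable, object]:
    sorter = graphlib.TopologicalSorter(graph)
    sorter.prepare()
    results: Dict[Hashable, object] = {}
    abort_result: object = None
    while sorter.is_active():
        for node in sorter.get_ready():
            if abort_result is not None:
                results[node] = abort_result  # replay the failure onto every remaining node
            else:
                try:
                    results[node] = work[node]()
                except Exception as exc:  # first failure flips the graph into "Aborting"
                    abort_result = exc
                    results[node] = exc
            sorter.done(node)
    return results
```

For example, `run_graph({"a": set(), "b": {"a"}, "c": {"b"}}, {"a": lambda: 1, "b": lambda: 1 / 0, "c": lambda: 3})` returns 1 for "a" and the same `ZeroDivisionError` for both "b" and "c"; the new `test_block_false` below appears to exercise this abort path from the client side.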
1 change: 1 addition & 0 deletions scaler/scheduler/worker_manager.py
@@ -70,6 +70,7 @@ async def on_task_cancel(self, task_cancel: TaskCancel):
async def on_task_result(self, task_result: TaskResult):
worker = self._allocator.remove_task(task_result.task_id)

print(f"on task result {task_result.task_id.hex()=}, {task_result.status=}")
if task_result.status in {TaskStatus.Canceled, TaskStatus.NotFound}:
if worker is not None:
# The worker canceled the task, but the scheduler still had it queued. Re-route the task to another
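The comment in the hunk above describes a race: the worker reported the task as canceled or not found, yet the scheduler still holds it queued, so it should be re-routed rather than surfaced as canceled. A tiny hedged helper expressing that decision — the function name and its use are illustrative only, not the project's API:

```python
# Illustrative only: the re-route condition described in the comment above.
from typing import Optional
from scaler.protocol.python.common import TaskStatus

def should_reroute(status: TaskStatus, worker: Optional[bytes]) -> bool:
    """True when the worker gave the task up but the scheduler still had it queued."""
    return status in {TaskStatus.Canceled, TaskStatus.NotFound} and worker is not None
```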
25 changes: 20 additions & 5 deletions tests/test_graph.py
@@ -2,11 +2,11 @@
import time
import unittest

from scaler import Client, SchedulerClusterCombo
from scaler import Client
from scaler.utility.graph.optimization import cull_graph
from scaler.utility.logging.scoped_logger import ScopedLogger
from scaler.utility.logging.utility import setup_logger
from tests.utility import get_available_tcp_port, logging_test_name
from tests.utility import logging_test_name


def inc(i):
@@ -25,11 +25,11 @@ class TestGraph(unittest.TestCase):
def setUp(self) -> None:
setup_logger()
logging_test_name(self)
self.address = f"tcp://127.0.0.1:{get_available_tcp_port()}"
self.cluster = SchedulerClusterCombo(address=self.address, n_workers=3, event_loop="builtin")
self.address = f"tcp://127.0.0.1:2345"
# self.cluster = SchedulerClusterCombo(address=self.address, n_workers=3, event_loop="builtin")

def tearDown(self) -> None:
self.cluster.shutdown()
# self.cluster.shutdown()
pass

def test_graph(self):
@@ -147,6 +147,21 @@ def func(a):

self.assertTrue(all(f.cancelled() for f in futures.values()))

def test_block_false(self):
def throw_error(*args):
time.sleep(1)
raise ValueError("throw error")

def func(*args):
time.sleep(2)
return 0

graph = {"a": 1, "b": 2, "c": 3, "d": (throw_error, "a"), "e": (func, "b"), "f": (func, "e")}

with Client(address=self.address) as client:
futures = client.get(graph, keys=["f", "e", "d"], block=False)
print(futures["e"].result())

def test_cull_graph(self):
graph = {
"a": (lambda *_: None,),
