diff --git a/scaler/client/agent/client_agent.py b/scaler/client/agent/client_agent.py index f391353..5360f70 100644 --- a/scaler/client/agent/client_agent.py +++ b/scaler/client/agent/client_agent.py @@ -13,7 +13,6 @@ from scaler.client.serializer.mixins import Serializer from scaler.io.async_connector import AsyncConnector from scaler.protocol.python.message import ( - ClientClearRequest, ClientDisconnect, ClientHeartbeatEcho, ClientShutdownResponse, @@ -141,10 +140,6 @@ async def __on_receive_from_client(self, message: Message): await self._task_manager.on_cancel_graph_task(message) return - if isinstance(message, ClientClearRequest): - await self._object_manager.on_client_clear_request(message) - return - raise TypeError(f"Unknown {message=}") async def __on_receive_from_scheduler(self, message: Message): diff --git a/scaler/client/agent/object_manager.py b/scaler/client/agent/object_manager.py index 88f5e0e..bf47ea5 100644 --- a/scaler/client/agent/object_manager.py +++ b/scaler/client/agent/object_manager.py @@ -4,7 +4,6 @@ from scaler.io.async_connector import AsyncConnector from scaler.protocol.python.common import ObjectContent from scaler.protocol.python.message import ( - ClientClearRequest, ObjectInstruction, ObjectRequest, TaskResult, @@ -30,6 +29,8 @@ async def on_object_instruction(self, instruction: ObjectInstruction): await self.__send_object_creation(instruction) elif instruction.instruction_type == ObjectInstruction.ObjectInstructionType.Delete: await self.__delete_objects(instruction) + elif instruction.instruction_type == ObjectInstruction.ObjectInstructionType.Clear: + await self.clear_all_objects(clear_serializer=False) async def on_object_request(self, object_request: ObjectRequest): assert object_request.request_type == ObjectRequest.ObjectRequestType.Get @@ -43,9 +44,6 @@ def on_task_result(self, task_result: TaskResult): self._sent_object_ids.update(task_result.results) - async def on_client_clear_request(self, client_clear_request: ClientClearRequest): - await self.clear_all_objects(clear_serializer=False) - async def clear_all_objects(self, clear_serializer): cleared_object_ids = self._sent_object_ids.copy() diff --git a/scaler/client/object_buffer.py b/scaler/client/object_buffer.py index 5144d1d..fe04207 100644 --- a/scaler/client/object_buffer.py +++ b/scaler/client/object_buffer.py @@ -8,7 +8,7 @@ from scaler.io.sync_connector import SyncConnector from scaler.io.utility import chunk_to_list_of_bytes from scaler.protocol.python.common import ObjectContent -from scaler.protocol.python.message import ClientClearRequest, ObjectInstruction +from scaler.protocol.python.message import ObjectInstruction from scaler.utility.object_utility import generate_object_id, generate_serializer_object_id @@ -83,13 +83,19 @@ def commit_delete_objects(self): def clear(self): """ - remove all commited and pending objects. + remove all committed and pending objects. """ self._pending_delete_objects.clear() self._pending_objects.clear() - self._connector.send(ClientClearRequest.new_msg()) + self._connector.send( + ObjectInstruction.new_msg( + ObjectInstruction.ObjectInstructionType.Clear, + self._identity, + ObjectContent.new_msg(tuple()), + ) + ) def __construct_serializer(self) -> ObjectCache: serializer_bytes = cloudpickle.dumps(self._serializer, protocol=pickle.HIGHEST_PROTOCOL) diff --git a/scaler/protocol/capnp/message.capnp b/scaler/protocol/capnp/message.capnp index 8e4abb3..c619dd2 100644 --- a/scaler/protocol/capnp/message.capnp +++ b/scaler/protocol/capnp/message.capnp @@ -77,6 +77,7 @@ struct ObjectInstruction { enum ObjectInstructionType { create @0; delete @1; + clear @2; } } @@ -107,9 +108,6 @@ struct DisconnectResponse { worker @0 :Data; } -struct ClientClearRequest { -} - struct ClientDisconnect { disconnectType @0 :DisconnectType; @@ -203,11 +201,9 @@ struct Message { stateTask @19 :StateTask; stateGraphTask @20 :StateGraphTask; - clientClearRequest @21 :ClientClearRequest; - - clientDisconnect @22 :ClientDisconnect; - clientShutdownResponse @23 :ClientShutdownResponse; + clientDisconnect @21 :ClientDisconnect; + clientShutdownResponse @22 :ClientShutdownResponse; - processorInitialized @24 :ProcessorInitialized; + processorInitialized @23 :ProcessorInitialized; } } diff --git a/scaler/protocol/python/message.py b/scaler/protocol/python/message.py index 75f9bb3..6f8414a 100644 --- a/scaler/protocol/python/message.py +++ b/scaler/protocol/python/message.py @@ -290,6 +290,7 @@ class ObjectInstruction(Message): class ObjectInstructionType(enum.Enum): Create = _message.ObjectInstruction.ObjectInstructionType.create Delete = _message.ObjectInstruction.ObjectInstructionType.delete + Clear = _message.ObjectInstruction.ObjectInstructionType.clear def __init__(self, msg): super().__init__(msg) @@ -395,15 +396,6 @@ def new_msg(worker: bytes) -> "DisconnectResponse": return DisconnectResponse(_message.DisconnectResponse(worker=worker)) -class ClientClearRequest(Message): - def __init__(self, msg): - super().__init__(msg) - - @staticmethod - def new_msg() -> "ClientClearRequest": - return ClientClearRequest(_message.ClientClearRequest()) - - class ClientDisconnect(Message): class DisconnectType(enum.Enum): Disconnect = _message.ClientDisconnect.DisconnectType.disconnect @@ -646,7 +638,6 @@ def new_msg() -> "ProcessorInitialized": "stateWorker": StateWorker, "stateTask": StateTask, "stateGraphTask": StateGraphTask, - "clientClearRequest": ClientClearRequest, "clientDisconnect": ClientDisconnect, "clientShutdownResponse": ClientShutdownResponse, "processorInitialized": ProcessorInitialized, diff --git a/scaler/scheduler/object_manager.py b/scaler/scheduler/object_manager.py index 0907e28..5a7a555 100644 --- a/scaler/scheduler/object_manager.py +++ b/scaler/scheduler/object_manager.py @@ -61,7 +61,7 @@ async def on_object_instruction(self, source: bytes, instruction: ObjectInstruct logging.error( f"received unknown object response type instruction_type={instruction.instruction_type} from " - f"source={instruction.object_user}" + f"source={instruction.object_user!r}" ) async def on_object_request(self, source: bytes, request: ObjectRequest):