Skip to content

Commit

Permalink
Add an ObjectContentType enum to identify object types (i.e. serializer, function or object).
Browse files Browse the repository at this point in the history

Signed-off-by: rafa-be <[email protected]>
  • Loading branch information
rafa-be committed Jan 23, 2025
1 parent 760ad77 commit e18dca0
Show file tree
Hide file tree
Showing 11 changed files with 89 additions and 17 deletions.
2 changes: 1 addition & 1 deletion scaler/about.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "1.8.16"
__version__ = "1.8.17"
5 changes: 3 additions & 2 deletions scaler/client/agent/object_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,11 +69,11 @@ async def __send_object_creation(self, instruction: ObjectInstruction):
if not new_object_ids:
return

if b"serializer" in instruction.object_content.object_names:
if ObjectContent.ObjectContentType.Serializer in instruction.object_content.object_types:
if self._sent_serializer_id is not None:
raise ValueError("trying to send multiple serializers.")

serializer_index = instruction.object_content.object_names.index(b"serializer")
serializer_index = instruction.object_content.object_types.index(ObjectContent.ObjectContentType.Serializer)
self._sent_serializer_id = instruction.object_content.object_ids[serializer_index]

new_object_content = ObjectContent.new_msg(
Expand All @@ -82,6 +82,7 @@ async def __send_object_creation(self, instruction: ObjectInstruction):
lambda object_pack: object_pack[0] in new_object_ids,
zip(
instruction.object_content.object_ids,
instruction.object_content.object_types,
instruction.object_content.object_names,
instruction.object_content.object_bytes,
),
Expand Down
19 changes: 16 additions & 3 deletions scaler/client/object_buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
@dataclasses.dataclass
class ObjectCache:
# Client-side record of one object pending/sent to the scheduler.
# object_id: unique identifier generated from the client identity (and payload hash for functions/objects).
object_id: bytes
# object_type: discriminates serializer vs. function vs. plain object payloads.
object_type: ObjectContent.ObjectContentType
# object_name: display name (e.g. b"serializer", a function __name__, or a generated placeholder).
object_name: bytes
# object_bytes: serialized payload split into chunks (see chunk_to_list_of_bytes).
object_bytes: List[bytes]

Expand Down Expand Up @@ -54,7 +55,8 @@ def commit_send_objects(self):
return

objects_to_send = [
(obj_cache.object_id, obj_cache.object_name, obj_cache.object_bytes) for obj_cache in self._pending_objects
(obj_cache.object_id, obj_cache.object_type, obj_cache.object_name, obj_cache.object_bytes)
for obj_cache in self._pending_objects
]

self._connector.send(
Expand Down Expand Up @@ -100,13 +102,19 @@ def clear(self):
def __construct_serializer(self) -> ObjectCache:
serializer_bytes = cloudpickle.dumps(self._serializer, protocol=pickle.HIGHEST_PROTOCOL)
object_id = generate_serializer_object_id(self._identity)
return ObjectCache(object_id, b"serializer", chunk_to_list_of_bytes(serializer_bytes))
return ObjectCache(
object_id,
ObjectContent.ObjectContentType.Serializer,
b"serializer",
chunk_to_list_of_bytes(serializer_bytes)
)

def __construct_function(self, fn: Callable) -> ObjectCache:
function_bytes = self._serializer.serialize(fn)
object_id = generate_object_id(self._identity, function_bytes)
function_cache = ObjectCache(
object_id,
ObjectContent.ObjectContentType.Function,
getattr(fn, "__name__", f"<func {object_id.hex()[:6]}>").encode(),
chunk_to_list_of_bytes(function_bytes),
)
Expand All @@ -116,4 +124,9 @@ def __construct_object(self, obj: Any, name: Optional[str] = None) -> ObjectCach
object_payload = self._serializer.serialize(obj)
object_id = generate_object_id(self._identity, object_payload)
name_bytes = name.encode() if name else f"<obj {object_id.hex()[-6:]}>".encode()
return ObjectCache(object_id, name_bytes, chunk_to_list_of_bytes(object_payload))
return ObjectCache(
object_id,
ObjectContent.ObjectContentType.Object,
name_bytes,
chunk_to_list_of_bytes(object_payload)
)
11 changes: 9 additions & 2 deletions scaler/protocol/capnp/common.capnp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,13 @@ enum TaskStatus {

struct ObjectContent {
objectIds @0 :List(Data);
objectNames @1 :List(Data);
objectBytes @2 :List(List(Data));
objectTypes @1 :List(ObjectContentType);
objectNames @2 :List(Data);
objectBytes @3 :List(List(Data));

# Discriminates what kind of payload an object entry carries.
# Parallel to the objectIds/objectNames/objectBytes lists above: entry i of
# objectTypes describes entry i of each sibling list.
enum ObjectContentType {
serializer @0;  # the client's pickled serializer instance
function @1;    # a serialized callable
object @2;      # any other serialized value (arguments, results, ...)
}
}
19 changes: 18 additions & 1 deletion scaler/protocol/python/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,26 @@ class TaskStatus(enum.Enum):

@dataclasses.dataclass
class ObjectContent(Message):
class ObjectContentType(enum.Enum):
"""Python-side mirror of the capnp ``ObjectContent.ObjectContentType`` enum.

Members carry the capnp enumerant's *string* name as their value so they
can be assigned directly into capnp list fields (see FIXME below).
"""
# FIXME: Pycapnp does not support assignment of raw enum values when the enum is itself declared within a list.
# However, assigning the enum's string value works.
# See https://github.com/capnproto/pycapnp/issues/374

Serializer = "serializer"
Function = "function"
Object = "object"

def __init__(self, msg):
super().__init__(msg)

@property
def object_ids(self) -> Tuple[bytes, ...]:
return tuple(self._msg.objectIds)

@property
def object_types(self) -> Tuple[ObjectContentType, ...]:
return tuple(ObjectContent.ObjectContentType(object_type._as_str()) for object_type in self._msg.objectTypes)

@property
def object_names(self) -> Tuple[bytes, ...]:
return tuple(self._msg.objectNames)
Expand All @@ -43,12 +56,16 @@ def object_bytes(self) -> Tuple[List[bytes], ...]:
@staticmethod
def new_msg(
object_ids: Tuple[bytes, ...],
object_types: Tuple[ObjectContentType, ...] = tuple(),
object_names: Tuple[bytes, ...] = tuple(),
object_bytes: Tuple[List[bytes], ...] = tuple(),
) -> "ObjectContent":
return ObjectContent(
_common.ObjectContent(
objectIds=list(object_ids), objectNames=list(object_names), objectBytes=tuple(object_bytes)
objectIds=list(object_ids),
objectTypes=[object_type.value for object_type in object_types],
objectNames=list(object_names),
objectBytes=tuple(object_bytes),
)
)

Expand Down
4 changes: 3 additions & 1 deletion scaler/scheduler/graph_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from scaler.io.async_binder import AsyncBinder
from scaler.io.async_connector import AsyncConnector
from scaler.protocol.python.common import TaskStatus
from scaler.protocol.python.common import ObjectContent, TaskStatus
from scaler.protocol.python.message import GraphTask, GraphTaskCancel, StateGraphTask, Task, TaskCancel, TaskResult
from scaler.scheduler.mixins import ClientManager, GraphTaskManager, ObjectManager, TaskManager
from scaler.utility.graph.topological_sorter import TopologicalSorter
Expand Down Expand Up @@ -250,6 +250,7 @@ async def __clean_all_running_nodes(self, graph_task_id: bytes, result: TaskResu
self._object_manager.on_add_object(
graph_info.client,
new_result_object_id,
ObjectContent.ObjectContentType.Object,
self._object_manager.get_object_name(result_object_id),
self._object_manager.get_object_content(result_object_id),
)
Expand All @@ -271,6 +272,7 @@ async def __clean_all_inactive_nodes(self, graph_task_id: bytes, result: TaskRes
self._object_manager.on_add_object(
graph_info.client,
new_result_object_id,
ObjectContent.ObjectContentType.Object,
self._object_manager.get_object_name(result_object_id),
self._object_manager.get_object_content(result_object_id),
)
Expand Down
10 changes: 9 additions & 1 deletion scaler/scheduler/mixins.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import abc
from typing import List, Optional, Set

from scaler.protocol.python.common import ObjectContent
from scaler.protocol.python.message import (
ClientDisconnect,
ClientHeartbeat,
Expand All @@ -27,7 +28,14 @@ async def on_object_request(self, source: bytes, request: ObjectRequest):
raise NotImplementedError()

@abc.abstractmethod
def on_add_object(self, object_user: bytes, object_id: bytes, object_name: bytes, object_bytes: List[bytes]):
def on_add_object(
self,
object_user: bytes,
object_id: bytes,
object_type: ObjectContent.ObjectContentType,
object_name: bytes,
object_bytes: List[bytes]
):
raise NotImplementedError()

@abc.abstractmethod
Expand Down
27 changes: 22 additions & 5 deletions scaler/scheduler/object_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
class _ObjectCreation(ObjectUsage):
object_id: bytes
object_creator: bytes
object_type: ObjectContent.ObjectContentType
object_name: bytes
object_bytes: List[bytes]

Expand Down Expand Up @@ -71,11 +72,19 @@ async def on_object_request(self, source: bytes, request: ObjectRequest):

logging.error(f"received unknown object request type {request=} from {source=!r}")

def on_add_object(self, object_user: bytes, object_id: bytes, object_name: bytes, object_bytes: List[bytes]):
creation = _ObjectCreation(object_id, object_user, object_name, object_bytes)
def on_add_object(
self,
object_user: bytes,
object_id: bytes,
object_type: ObjectContent.ObjectContentType,
object_name: bytes,
object_bytes: List[bytes]
):
creation = _ObjectCreation(object_id, object_user, object_type, object_name, object_bytes)
logging.debug(
f"add object cache "
f"object_name={creation.object_name!r}, "
f"object_type={creation.object_type}, "
f"object_id={creation.object_id.hex()}, "
f"size={format_bytes(len(creation.object_bytes))}"
)
Expand Down Expand Up @@ -140,12 +149,13 @@ def __on_object_create(self, source: bytes, instruction: ObjectInstruction):
logging.error(f"received object creation from {source!r} for unknown client {instruction.object_user!r}")
return

for object_id, object_name, object_bytes in zip(
for object_id, object_type, object_name, object_bytes in zip(
instruction.object_content.object_ids,
instruction.object_content.object_types,
instruction.object_content.object_names,
instruction.object_content.object_bytes,
):
self.on_add_object(instruction.object_user, object_id, object_name, object_bytes)
self.on_add_object(instruction.object_user, object_id, object_type, object_name, object_bytes)

def __finished_object_storage(self, creation: _ObjectCreation):
logging.debug(
Expand All @@ -158,6 +168,7 @@ def __finished_object_storage(self, creation: _ObjectCreation):

def __construct_response(self, request: ObjectRequest) -> ObjectResponse:
object_ids = []
object_types = []
object_names = []
object_bytes = []
for object_id in request.object_ids:
Expand All @@ -166,10 +177,16 @@ def __construct_response(self, request: ObjectRequest) -> ObjectResponse:

object_info = self._object_storage.get_object(object_id)
object_ids.append(object_info.object_id)
object_types.append(object_info.object_type)
object_names.append(object_info.object_name)
object_bytes.append(object_info.object_bytes)

return ObjectResponse.new_msg(
ObjectResponse.ObjectResponseType.Content,
ObjectContent.new_msg(tuple(request.object_ids), tuple(object_names), tuple(object_bytes)),
ObjectContent.new_msg(
tuple(request.object_ids),
tuple(object_types),
tuple(object_names),
tuple(object_bytes)
),
)
1 change: 1 addition & 0 deletions scaler/worker/agent/processor/processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,6 +273,7 @@ def __send_result(self, source: bytes, task_id: bytes, status: TaskStatus, resul
source,
ObjectContent.new_msg(
(result_object_id,),
(ObjectContent.ObjectContentType.Object,),
(f"<res {result_object_id.hex()[:6]}>".encode(),),
(chunk_to_list_of_bytes(result_bytes),),
),
Expand Down
7 changes: 6 additions & 1 deletion scaler/worker/agent/processor_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,12 @@ async def on_failing_processor(self, processor_id: bytes, process_status: str):
ObjectInstruction.new_msg(
ObjectInstruction.ObjectInstructionType.Create,
source,
ObjectContent.new_msg((result_object_id,), (b"",), (result_object_bytes,)),
ObjectContent.new_msg(
(result_object_id,),
(ObjectContent.ObjectContentType.Object,),
(b"",),
(result_object_bytes,)
),
)
)

Expand Down
1 change: 1 addition & 0 deletions tests/test_worker_object_tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ def test_object_tracker(self) -> None:
b"client",
ObjectContent.new_msg(
(b"object_1", b"object_2", b"object_3"),
tuple([ObjectContent.ObjectContentType.Object] * 3),
(b"name_1", b"name_2", b"name_3"),
([b"content_1"], [b"content_2"], [b"content_3"]),
),
Expand Down

0 comments on commit e18dca0

Please sign in to comment.