Skip to content

Commit

Permalink
serialization fix
Browse files Browse the repository at this point in the history
Signed-off-by: Andrew Feldman <[email protected]>
  • Loading branch information
afeldman-nm committed Dec 12, 2024
1 parent f1a689c commit e962aa7
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 4 deletions.
4 changes: 2 additions & 2 deletions vllm/v1/engine/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from vllm.v1.executor.gpu_executor import GPUExecutor
from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.request import Request, RequestStatus
from vllm.v1.serial_utils import PickleEncoder
from vllm.v1.serial_utils import PickleEncoder, custom_enc_hook
from vllm.version import __version__ as VLLM_VERSION

logger = init_logger(__name__)
Expand Down Expand Up @@ -517,7 +517,7 @@ def process_output_socket(self, output_path: str):
"""Output socket IO thread."""

# Msgpack serialization encoding.
encoder = msgpack.Encoder()
encoder = msgpack.Encoder(enc_hook=custom_enc_hook)
# Reuse send buffer.
buffer = bytearray()

Expand Down
5 changes: 3 additions & 2 deletions vllm/v1/engine/core_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
EngineCoreProfile, EngineCoreRequest,
EngineCoreRequestType)
from vllm.v1.engine.core import EngineCore, EngineCoreProc
from vllm.v1.serial_utils import PickleEncoder
from vllm.v1.serial_utils import PickleEncoder, custom_ext_hook

logger = init_logger(__name__)

Expand Down Expand Up @@ -124,7 +124,8 @@ def __init__(
):
# Serialization setup.
self.encoder = PickleEncoder()
self.decoder = msgspec.msgpack.Decoder(EngineCoreOutputs)
self.decoder = msgspec.msgpack.Decoder(EngineCoreOutputs,
ext_hook=custom_ext_hook)

# ZMQ setup.
self.ctx = (zmq.asyncio.Context() if asyncio_mode else zmq.Context())
Expand Down
28 changes: 28 additions & 0 deletions vllm/v1/serial_utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,11 @@
import pickle
from typing import Any

import numpy as np
from msgspec import msgpack

CUSTOM_TYPE_CODE_PICKLE = 1
pickle_types = (np.ndarray, )


class PickleEncoder:
Expand All @@ -8,3 +15,24 @@ def encode(self, obj):

def decode(self, data):
return pickle.loads(data)


def custom_enc_hook(obj: Any) -> Any:
if isinstance(obj, pickle_types):
# Return an `Ext` object so msgspec serializes it as an extension type.
return msgpack.Ext(CUSTOM_TYPE_CODE_PICKLE, pickle.dumps(obj))
else:
# Raise a NotImplementedError for other types
raise NotImplementedError(
f"Objects of type {type(obj)} are not supported")


def custom_ext_hook(code: int, data: memoryview) -> Any:
if code == CUSTOM_TYPE_CODE_PICKLE:
# This extension type represents a complex number, decode the data
# buffer accordingly.
return pickle.loads(data)
else:
# Raise a NotImplementedError for other extension type codes
raise NotImplementedError(
f"Extension type code {code} is not supported")

0 comments on commit e962aa7

Please sign in to comment.