Commit

Enable context serialization in workflows (#16250)
logan-markewich authored Oct 2, 2024
1 parent f19b36e commit ba2cc90
Showing 6 changed files with 339 additions and 12 deletions.
91 changes: 90 additions & 1 deletion llama-index-core/llama_index/core/workflow/context.py
@@ -1,8 +1,10 @@
import asyncio
import json
import warnings
from collections import defaultdict
from typing import Dict, Any, Optional, List, Type, TYPE_CHECKING, Set, Tuple

from .context_serializers import BaseSerializer, JsonSerializer
from .decorators import StepConfig
from .events import Event
from .errors import WorkflowRuntimeError
@@ -22,8 +24,14 @@ class Context:
Both `set` and `get` operations on global data are governed by a lock, and considered coroutine-safe.
"""

def __init__(self, workflow: "Workflow", stepwise: bool = False) -> None:
def __init__(
self,
workflow: "Workflow",
stepwise: bool = False,
) -> None:
self.stepwise = stepwise
self.is_running = False

self._workflow = workflow
# Broker machinery
self._queues: Dict[str, asyncio.Queue] = {}
@@ -49,6 +57,87 @@ def __init__(self, workflow: "Workflow", stepwise: bool = False) -> None:
# Step-specific instance
self._events_buffer: Dict[Type[Event], List[Event]] = defaultdict(list)

def _serialize_queue(self, queue: asyncio.Queue, serializer: BaseSerializer) -> str:
queue_items = list(queue._queue) # type: ignore
queue_objs = [serializer.serialize(obj) for obj in queue_items]
return json.dumps(queue_objs) # type: ignore

def _deserialize_queue(
self, queue_str: str, serializer: BaseSerializer
) -> asyncio.Queue:
queue_objs = json.loads(queue_str)
queue = asyncio.Queue() # type: ignore
for obj in queue_objs:
event_obj = serializer.deserialize(obj)
queue.put_nowait(event_obj)
return queue

def _serialize_globals(self, serializer: BaseSerializer) -> Dict[str, Any]:
serialized_globals = {}
for key, value in self._globals.items():
try:
serialized_globals[key] = serializer.serialize(value)
except Exception as e:
raise ValueError(f"Failed to serialize value for key {key}: {e}")
return serialized_globals

def _deserialize_globals(
self, serialized_globals: Dict[str, Any], serializer: BaseSerializer
) -> Dict[str, Any]:
deserialized_globals = {}
for key, value in serialized_globals.items():
try:
deserialized_globals[key] = serializer.deserialize(value)
except Exception as e:
raise ValueError(f"Failed to deserialize value for key {key}: {e}")
return deserialized_globals

def to_dict(self, serializer: Optional[BaseSerializer] = None) -> Dict[str, Any]:
serializer = serializer or JsonSerializer()

return {
"globals": self._serialize_globals(serializer),
"streaming_queue": self._serialize_queue(self._streaming_queue, serializer),
"queues": {
k: self._serialize_queue(v, serializer) for k, v in self._queues.items()
},
"stepwise": self.stepwise,
"events_buffer": {
k: [serializer.serialize(ev) for ev in v]
for k, v in self._events_buffer.items()
},
"accepted_events": self._accepted_events,
"broker_log": [serializer.serialize(ev) for ev in self._broker_log],
"is_running": self.is_running,
}

@classmethod
def from_dict(
cls,
workflow: "Workflow",
data: Dict[str, Any],
serializer: Optional[BaseSerializer] = None,
) -> "Context":
serializer = serializer or JsonSerializer()

context = cls(workflow, stepwise=data["stepwise"])
context._globals = context._deserialize_globals(data["globals"], serializer)
context._queues = {
k: context._deserialize_queue(v, serializer)
for k, v in data["queues"].items()
}
context._streaming_queue = context._deserialize_queue(
data["streaming_queue"], serializer
)
context._events_buffer = {
k: [serializer.deserialize(ev) for ev in v]
for k, v in data["events_buffer"].items()
}
context._accepted_events = data["accepted_events"]
context._broker_log = [serializer.deserialize(ev) for ev in data["broker_log"]]
context.is_running = data["is_running"]
return context

async def set(self, key: str, value: Any, make_private: bool = False) -> None:
"""Store `value` into the Context under `key`.
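
Taken together, `to_dict` and `from_dict` make a context snapshot-able: globals, per-step queues, buffered events, the broker log, and the `is_running` flag all round-trip through JSON-friendly data. A minimal sketch of the round trip (the `EchoWorkflow` class is a toy defined here for illustration; the top-level re-exports and `run(ctx=...)` forwarding to `_start()` are assumptions, not shown in this diff):

import asyncio

from llama_index.core.workflow import StartEvent, StopEvent, Workflow, step
from llama_index.core.workflow.context import Context
from llama_index.core.workflow.context_serializers import JsonSerializer


class EchoWorkflow(Workflow):
    """A one-step workflow used only for illustration."""

    @step
    async def echo(self, ctx: Context, ev: StartEvent) -> StopEvent:
        await ctx.set("last_topic", ev.topic)  # stored in ctx._globals
        return StopEvent(result=f"echo: {ev.topic}")


async def main() -> None:
    workflow = EchoWorkflow()
    handler = workflow.run(topic="pickles")  # kwargs become StartEvent fields
    print(await handler)

    # Snapshot the finished run's context as plain data.
    ctx_dict = handler.ctx.to_dict(serializer=JsonSerializer())

    # Later (or in another process): rebuild the context and reuse it.
    restored = Context.from_dict(workflow, ctx_dict, serializer=JsonSerializer())
    handler = workflow.run(ctx=restored, topic="more pickles")
    print(await handler)


asyncio.run(main())
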
77 changes: 77 additions & 0 deletions llama-index-core/llama_index/core/workflow/context_serializers.py
@@ -0,0 +1,77 @@
import base64
import json
import pickle
from abc import ABC, abstractmethod
from typing import Any
from pydantic import BaseModel

from llama_index.core.schema import BaseComponent
from .utils import import_module_from_qualified_name, get_qualified_name


class BaseSerializer(ABC):
@abstractmethod
def serialize(self, value: Any) -> str:
...

@abstractmethod
def deserialize(self, value: str) -> Any:
...


class JsonSerializer(BaseSerializer):
def serialize(self, value: Any) -> str:
if isinstance(value, BaseComponent):
return json.dumps(
{
"__is_component": True,
"value": value.to_dict(),
"qualified_name": get_qualified_name(value),
}
)
elif isinstance(value, BaseModel):
return json.dumps(
{
"__is_pydantic": True,
"value": value.model_dump(),
"qualified_name": get_qualified_name(value),
}
)

return json.dumps(value)

def deserialize(self, value: str) -> Any:
data = json.loads(value)

if (
isinstance(data, dict)
and data.get("__is_pydantic")
and data.get("qualified_name")
):
module_class = import_module_from_qualified_name(data["qualified_name"])
return module_class.model_validate(data["value"])
elif (
isinstance(data, dict)
and data.get("__is_component")
and data.get("qualified_name")
):
module_class = import_module_from_qualified_name(data["qualified_name"])
return module_class.from_dict(data["value"])

return data


class JsonPickleSerializer(JsonSerializer):
def serialize(self, value: Any) -> str:
"""Serialize while prioritizing JSON, falling back to Pickle."""
try:
return super().serialize(value)
except Exception:
return base64.b64encode(pickle.dumps(value)).decode("utf-8")

def deserialize(self, value: str) -> Any:
"""Deserialize while prioritizing Pickle, falling back to JSON."""
try:
return pickle.loads(base64.b64decode(value))
except Exception:
return super().deserialize(value)
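
The practical difference between the two serializers shows up on values JSON cannot represent. A small sketch of both behaviors; note that since `JsonPickleSerializer.deserialize` tries `pickle.loads` first, it should only be fed payloads from a trusted source:

from llama_index.core.workflow.context_serializers import (
    JsonPickleSerializer,
    JsonSerializer,
)

json_ser = JsonSerializer()
pickle_ser = JsonPickleSerializer()

print(json_ser.serialize({"a": 1}))      # '{"a": 1}'
print(json_ser.deserialize('{"a": 1}'))  # {'a': 1}

try:
    json_ser.serialize({1, 2, 3})        # a set is not valid JSON
except TypeError as exc:
    print(f"JsonSerializer raised: {exc}")

payload = pickle_ser.serialize({1, 2, 3})  # falls back to base64-encoded pickle
print(pickle_ser.deserialize(payload))     # {1, 2, 3}
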
3 changes: 3 additions & 0 deletions llama-index-core/llama_index/core/workflow/handler.py
@@ -83,6 +83,9 @@ async def run_step(self) -> Optional[Event]:
t.cancel()
await asyncio.sleep(0)

# the context is no longer running
self.ctx.is_running = False

if exception_raised:
raise exception_raised

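
Given the `Optional[Event]` signature on `run_step`, a stepwise driver can treat a `None` return as completion, and since `ctx.is_running` is now cleared whenever stepping pauses the run, the paused context is safe to snapshot. A rough sketch under those assumptions, continuing the `EchoWorkflow` example (the exact stepwise driver loop is version-dependent and not shown in this diff):

# Rough sketch: run one step, then checkpoint the idle context.
handler = workflow.run(stepwise=True, topic="pickles")

ev = await handler.run_step()       # execute a single step
checkpoint = handler.ctx.to_dict()  # is_running is False while paused

# ... persist `checkpoint` somewhere, then keep stepping until run_step
# returns None and the final result is available on the handler.
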
13 changes: 13 additions & 0 deletions llama-index-core/llama_index/core/workflow/utils.py
@@ -1,4 +1,5 @@
import inspect
from importlib import import_module
from typing import (
get_args,
get_origin,
@@ -173,3 +174,15 @@ def is_free_function(qualname: str) -> bool:
return False
else:
return toks[-2] == "<locals>"


def get_qualified_name(value: Any) -> str:
"""Get the qualified name of a value."""
return value.__module__ + "." + value.__class__.__name__


def import_module_from_qualified_name(qualified_name: str) -> Any:
"""Import a module from a qualified name."""
module_path = qualified_name.rsplit(".", 1)
module = import_module(module_path[0])
return getattr(module, module_path[1])
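
These two helpers are inverses: the serializers record where a class lives with `get_qualified_name`, then re-import it on load with `import_module_from_qualified_name`. A small sketch of the round trip (assuming `StartEvent` is defined in `llama_index.core.workflow.events`):

from llama_index.core.workflow.events import StartEvent
from llama_index.core.workflow.utils import (
    get_qualified_name,
    import_module_from_qualified_name,
)

ev = StartEvent()
name = get_qualified_name(ev)  # __module__ resolves on the instance's class
print(name)                    # "llama_index.core.workflow.events.StartEvent"

cls = import_module_from_qualified_name(name)
assert cls is StartEvent
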
35 changes: 24 additions & 11 deletions llama-index-core/llama_index/core/workflow/workflow.py
@@ -59,15 +59,20 @@ def __init__(
"""Create an instance of the workflow.
Args:
timeout: number of seconds after which the workflow execution will be halted, raising a `WorkflowTimeoutError`
timeout:
Number of seconds after which the workflow execution will be halted, raising a `WorkflowTimeoutError`
exception. If set to `None`, the timeout will be disabled.
disable_validation: whether or not the workflow should be validated before running. In case the workflow is
disable_validation:
Whether or not the workflow should be validated before running. In case the workflow is
misconfigured, a call to `run` will raise a `WorkflowValidationError` exception explaining the details
of the problem.
verbose: whether or not the workflow should print additional informative messages during execution.
service_manager: The instance of the `ServiceManager` used to make nested workflows available to this
verbose:
Whether or not the workflow should print additional informative messages during execution.
service_manager:
The instance of the `ServiceManager` used to make nested workflows available to this
workflow instance. The default value is the best choice unless you're customizing the workflow runtime.
num_concurrent_runs: maximum number of .run() executions occurring simultaneously. If set to `None`, there
num_concurrent_runs:
Maximum number of .run() executions occurring simultaneously. If set to `None`, there
is no limit to this number.
"""
# Configuration
@@ -164,15 +169,16 @@ def _start(self, stepwise: bool = False, ctx: Optional[Context] = None) -> Context:
else:
# clean up the context from the previous run
ctx._tasks = set()
ctx._queues = {}
ctx._step_flags = {}
ctx._retval = None
ctx._step_event_holding = None
ctx._cancel_flag.clear()

for name, step_func in self._get_steps().items():
ctx._queues[name] = asyncio.Queue()
ctx._step_flags[name] = asyncio.Event()
if name not in ctx._queues:
ctx._queues[name] = asyncio.Queue()

if name not in ctx._step_flags:
ctx._step_flags[name] = asyncio.Event()

# At this point, step_func is guaranteed to have the `__step_config` attribute
step_config: StepConfig = getattr(step_func, "__step_config")
@@ -332,8 +338,12 @@ async def _run_workflow() -> None:
if self._sem:
await self._sem.acquire()
try:
# Send the first event
ctx.send_event(StartEvent(**kwargs))
if not ctx.is_running:
# Send the first event
ctx.send_event(StartEvent(**kwargs))

# the context is now running
ctx.is_running = True

done, unfinished = await asyncio.wait(
ctx._tasks,
@@ -358,6 +368,9 @@
# wait for cancelled tasks to cleanup
await asyncio.gather(*unfinished, return_exceptions=True)

# the context is no longer running
ctx.is_running = False

if exception_raised:
ctx.write_event_to_stream(StopEvent())
raise exception_raised
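
Two behaviors fall out of these changes. First, `_start` now reuses existing queues and step flags instead of recreating them, so a restored context keeps any pending events. Second, the `is_running` guard means a fresh `StartEvent` is only sent when the context is not already mid-run; a context from a completed run (where `is_running` was reset to `False`) starts a new run while keeping its globals. A sketch of the second behavior, continuing the `EchoWorkflow` example (again assuming `run()` accepts a `ctx` argument, and that `Context.get` is async like `set`):

# Sketch: two sequential runs sharing one context. is_running is False
# between runs, so each run sends its own StartEvent, while the context's
# globals carry over from run to run.
handler = workflow.run(topic="first")
print(await handler)

ctx = handler.ctx
handler = workflow.run(ctx=ctx, topic="second")
print(await handler)
print(await ctx.get("last_topic"))  # -> "second"
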