Skip to content

Commit

Permalink
Re-factor WorkflowHandler.run_step() so user manually emits Event t…
Browse files Browse the repository at this point in the history
…o start next step in workflow (#16277)

* initial version of stepwise working

* remove stepwise deprecated test

* remove print

* better naming

* add docstring to WorkflowHandler

* revert Workflow.run_step

* fix check for when Workflow is done or if step raised error

* fix error handling

* fix error handling

* update docs

* delete Workflow.run_step
  • Loading branch information
nerdai authored Oct 1, 2024
1 parent f684acc commit 1024574
Show file tree
Hide file tree
Showing 5 changed files with 102 additions and 126 deletions.
12 changes: 9 additions & 3 deletions docs/docs/understanding/workflows/observability.md
Original file line number Diff line number Diff line change
Expand Up @@ -107,13 +107,19 @@ In a notebook environment it can be helpful to run a workflow step by step. You
w = ConcurrentFlow(timeout=10, verbose=True)
handler = w.run()

async for _ in handler.run_step():
# inspect context
while not handler.is_done():
# run_step returns the step's output event
ev = await handler.run_step()
# can make modifications to the results before dispatching the event
# val = ev.get("some_key")
# ev.set("some_key", new_val)
# can also inspect context
# val = await handler.ctx.get("key")
handler.ctx.send_event(ev)
continue

# get the result
result = await handler
result = handler.result()
```

You can call `run_step` multiple times to step through the workflow one step at a time.
Expand Down
8 changes: 8 additions & 0 deletions llama-index-core/llama_index/core/workflow/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,14 @@ def __init__(self, workflow: "Workflow", stepwise: bool = False) -> None:
self._tasks: Set[asyncio.Task] = set()
self._broker_log: List[Event] = []
self._step_flags: Dict[str, asyncio.Event] = {}
self._step_event_holding: Optional[Event] = None
self._step_lock: asyncio.Lock = asyncio.Lock()
self._step_condition: asyncio.Condition = asyncio.Condition(
lock=self._step_lock
)
self._step_event_written: asyncio.Condition = asyncio.Condition(
lock=self._step_lock
)
self._accepted_events: List[Tuple[str, str]] = []
self._retval: Any = None
# Streaming machinery
Expand Down
97 changes: 64 additions & 33 deletions llama-index-core/llama_index/core/workflow/handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,44 +31,75 @@ async def stream_events(self) -> AsyncGenerator[Event, None]:
if type(ev) is StopEvent:
break

async def run_step(self) -> Optional[Any]:
async def run_step(self) -> Optional[Event]:
"""Runs the next workflow step and returns the output Event.
If return is None, then the workflow is considered done.
Examples:
```python
handler = workflow.run(stepwise=True)
while not handler.is_done():
ev = await handler.run_step()
handler.ctx.send_event(ev)
result = handler.result()
print(result)
```
"""
# since event is sent before calling this method, we need to unblock the event loop
await asyncio.sleep(0)

if self.ctx and not self.ctx.stepwise:
raise ValueError("Stepwise context is required to run stepwise.")

if self.ctx:
# Unblock all pending steps
for flag in self.ctx._step_flags.values():
flag.set()

# Yield back control to the event loop to give an unblocked step
# the chance to run (we won't actually sleep here).
await asyncio.sleep(0)

# See if we're done, or if a step raised any error
we_done = False
exception_raised = None
for t in self.ctx._tasks:
# Check if we're done
if not t.done():
continue

we_done = True
e = t.exception()
if type(e) != WorkflowDone:
exception_raised = e

retval = None
if we_done:
# Remove any reference to the tasks
try:
# Unblock all pending steps
for flag in self.ctx._step_flags.values():
flag.set()

# Yield back control to the event loop to give an unblocked step
# the chance to run (we won't actually sleep here).
await asyncio.sleep(0)

# check if we're done, or if a step raised error
we_done = False
exception_raised = None
retval = None
for t in self.ctx._tasks:
t.cancel()
await asyncio.sleep(0)
retval = self.ctx.get_result()

self.set_result(retval)

if exception_raised:
raise exception_raised
# Check if we're done
if not t.done():
continue

we_done = True
e = t.exception()
if type(e) != WorkflowDone:
exception_raised = e

if we_done:
# Remove any reference to the tasks
for t in self.ctx._tasks:
t.cancel()
await asyncio.sleep(0)

if exception_raised:
raise exception_raised

self.set_result(self.ctx.get_result())
else: # continue with running next step
# notify unblocked task that we're ready to accept next event
async with self.ctx._step_condition:
self.ctx._step_condition.notify()

# Wait to be notified that the new_ev has been written
async with self.ctx._step_event_written:
await self.ctx._step_event_written.wait()
retval = self.ctx._step_event_holding
except Exception as e:
if not self.is_done(): # Avoid InvalidStateError edge case
self.set_exception(e)
raise
else:
raise ValueError("Context is not set!")

Expand Down
62 changes: 8 additions & 54 deletions llama-index-core/llama_index/core/workflow/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,7 @@ def _start(self, stepwise: bool = False, ctx: Optional[Context] = None) -> Conte
ctx._queues = {}
ctx._step_flags = {}
ctx._retval = None
ctx._step_event_holding = None

for name, step_func in self._get_steps().items():
ctx._queues[name] = asyncio.Queue()
Expand Down Expand Up @@ -258,7 +259,13 @@ async def _task(
elif isinstance(new_ev, InputRequiredEvent):
ctx.write_event_to_stream(new_ev)
else:
ctx.send_event(new_ev)
if stepwise:
async with ctx._step_condition:
await ctx._step_condition.wait()
ctx._step_event_holding = new_ev
ctx._step_event_written.notify() # shares same lock
else:
ctx.send_event(new_ev)

for _ in range(step_config.num_workers):
ctx._tasks.add(
Expand Down Expand Up @@ -351,59 +358,6 @@ async def _run_workflow() -> None:
asyncio.create_task(_run_workflow())
return result

@dispatcher.span
async def run_step(self, **kwargs: Any) -> Optional[Any]:
"""Runs the workflow stepwise until completion."""
warnings.warn(
"run_step() is deprecated, use `workflow.run(stepwise=True)` instead.\n"
"handler = workflow.run(stepwise=True)\n"
"while not handler.is_done():\n"
" result = await handler.run_step()\n"
" print(result)\n"
)

# Check if we need to start a new session
if self._stepwise_context is None:
self._validate()
self._stepwise_context = self._start(stepwise=True)
# Run the first step
self._stepwise_context.send_event(StartEvent(**kwargs))

# Unblock all pending steps
for flag in self._stepwise_context._step_flags.values():
flag.set()

# Yield back control to the event loop to give an unblocked step
# the chance to run (we won't actually sleep here).
await asyncio.sleep(0)

# See if we're done, or if a step raised any error
we_done = False
exception_raised = None
for t in self._stepwise_context._tasks:
# Check if we're done
if not t.done():
continue

we_done = True
e = t.exception()
if type(e) != WorkflowDone:
exception_raised = e

retval = None
if we_done:
# Remove any reference to the tasks
for t in self._stepwise_context._tasks:
t.cancel()
await asyncio.sleep(0)
retval = self._stepwise_context._retval
self._stepwise_context = None

if exception_raised:
raise exception_raised

return retval

def is_done(self) -> bool:
"""Checks if the workflow is done."""
return self._stepwise_context is None
Expand Down
49 changes: 13 additions & 36 deletions llama-index-core/tests/workflow/test_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,54 +37,31 @@ async def test_workflow_run(workflow):
assert result == "Workflow completed"


@pytest.mark.asyncio()
async def test_deprecated_workflow_run_step(workflow):
workflow._verbose = True

# First step
result = await workflow.run_step()
assert result is None
assert not workflow.is_done()

# Second step
result = await workflow.run_step()
assert result is None
assert not workflow.is_done()

# Final step
result = await workflow.run_step()
assert not workflow.is_done()
assert result is None

# Cleanup step
result = await workflow.run_step()
assert result == "Workflow completed"
assert workflow.is_done()


@pytest.mark.asyncio()
async def test_workflow_run_step(workflow):
handler = workflow.run(stepwise=True)

result = await handler.run_step()
assert result is None
event = await handler.run_step()
assert isinstance(event, OneTestEvent)
assert not handler.is_done()
handler.ctx.send_event(event)

result = await handler.run_step()
assert result is None
event = await handler.run_step()
assert isinstance(event, LastEvent)
assert not handler.is_done()
handler.ctx.send_event(event)

result = await handler.run_step()
assert result is None
event = await handler.run_step()
assert isinstance(event, StopEvent)
assert not handler.is_done()
handler.ctx.send_event(event)

result = await handler.run_step()
assert result is None
assert not handler.is_done()
event = await handler.run_step()
assert event is None

result = await handler.run_step()
assert result == "Workflow completed"
result = await handler
assert handler.is_done()
assert result == "Workflow completed"


@pytest.mark.asyncio()
Expand Down

0 comments on commit 1024574

Please sign in to comment.