Skip to content

Commit

Permalink
[fix] handler.stream_events() doesn't yield StopEvent (#16115)
Browse files Browse the repository at this point in the history
  • Loading branch information
nerdai authored Sep 20, 2024
1 parent d3f578c commit 85425cf
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 8 deletions.
5 changes: 3 additions & 2 deletions llama-index-core/llama_index/core/workflow/handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,12 @@ async def stream_events(self) -> AsyncGenerator[Event, None]:

while True:
ev = await self.ctx.streaming_queue.get()
if type(ev) is StopEvent:
break

yield ev

if type(ev) is StopEvent:
break

async def run_step(self) -> Optional[Any]:
if self.ctx and not self.ctx.stepwise:
raise ValueError("Stepwise context is required to run stepwise.")
Expand Down
2 changes: 0 additions & 2 deletions llama-index-core/llama_index/core/workflow/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,8 +327,6 @@ async def _run_workflow() -> None:
result.set_result(ctx._retval)
except Exception as e:
result.set_exception(e)
finally:
ctx.write_event_to_stream(StopEvent())

asyncio.create_task(_run_workflow())
return result
Expand Down
12 changes: 8 additions & 4 deletions llama-index-core/tests/workflow/test_streaming.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,8 @@ async def test_e2e():
r = wf.run()

async for ev in r.stream_events():
assert "msg" in ev
if not isinstance(ev, StopEvent):
assert "msg" in ev

await r

Expand Down Expand Up @@ -62,7 +63,8 @@ async def step(self, ctx: Context, ev: StartEvent) -> StopEvent:

# Make sure we don't block indefinitely here because the step raised
async for ev in r.stream_events():
assert ev.test_param == "foo"
if not isinstance(ev, StopEvent):
assert ev.test_param == "foo"

# Make sure the await actually caught the exception
with pytest.raises(ValueError, match="The step raised an error!"):
Expand Down Expand Up @@ -93,10 +95,12 @@ async def test_multiple_ongoing_streams():
stream_2 = wf.run()

async for ev in stream_1.stream_events():
assert "msg" in ev
if not isinstance(ev, StopEvent):
assert "msg" in ev

async for ev in stream_2.stream_events():
assert "msg" in ev
if not isinstance(ev, StopEvent):
assert "msg" in ev


@pytest.mark.asyncio()
Expand Down

0 comments on commit 85425cf

Please sign in to comment.