Skip to content

Commit

Permalink
Cancel event processing task
Browse files Browse the repository at this point in the history
  • Loading branch information
ajhai committed Nov 4, 2024
1 parent 51a3601 commit 2b8f4f7
Showing 1 changed file with 35 additions and 7 deletions.
42 changes: 35 additions & 7 deletions llmstack/server/consumers.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,13 +128,19 @@ async def connect(self):
self.scope.get("user", None),
self._preview,
)
self._event_response_task = None
self._connected = True
await self.accept()

async def disconnect(self, close_code):
self._connected = False
if self._app_runner:
await self._app_runner.stop()

async def stop(self):
self._connected = False
if self._event_response_task:
self._event_response_task.cancel()
await self.close()

async def _respond_to_event(self, text_data):
Expand All @@ -151,6 +157,10 @@ async def _respond_to_event(self, text_data):
try:
response_iterator = self._app_runner.run(app_runner_request)
async for response in response_iterator:
# Check both cancellation and connection state
if asyncio.current_task().cancelled() or not self._connected:
break

if response.type == AppRunnerStreamingResponseType.OUTPUT_STREAM_CHUNK:
await self.send(text_data=response.model_dump_json())
elif response.type == AppRunnerStreamingResponseType.ERRORS:
Expand Down Expand Up @@ -356,7 +366,7 @@ async def _respond_to_event_old(self, text_data):
self._coordinator_ref.stop()

async def receive(self, text_data):
run_coro_in_new_loop(self._respond_to_event(text_data))
self._event_response_task = run_coro_in_new_loop(self._respond_to_event(text_data), name="respond_to_event")


class AssetStreamConsumer(AsyncWebsocketConsumer):
Expand Down Expand Up @@ -514,10 +524,17 @@ async def connect(self):
processor_slug="",
provider_slug="",
)
self._event_response_task = None
self._app_runner = None
self._connected = True
await self.accept()

async def disconnect(self, close_code):
pass
self._connected = False
if self._event_response_task:
self._event_response_task.cancel()
if self._app_runner:
await self._app_runner.stop()

async def _respond_to_event(self, text_data):
from llmstack.apps.apis import PlaygroundViewSet
Expand All @@ -544,12 +561,15 @@ async def _respond_to_event(self, text_data):
client_request_id=client_request_id, session_id=session_id, input=input_data
)

app_runner = await PlaygroundViewSet().get_app_runner_async(
self._app_runner = await PlaygroundViewSet().get_app_runner_async(
session_id, source, self.scope.get("user", None), input_data, config_data
)
try:
response_iterator = app_runner.run(app_runner_request)
response_iterator = self._app_runner.run(app_runner_request)
async for response in response_iterator:
if not self._connected:
break

if response.type == AppRunnerStreamingResponseType.OUTPUT_STREAM_CHUNK:
await self.send(text_data=response.model_dump_json())
elif response.type == AppRunnerStreamingResponseType.OUTPUT:
Expand All @@ -560,10 +580,10 @@ async def _respond_to_event(self, text_data):
)
except Exception as e:
logger.exception(f"Failed to run app: {e}")
await app_runner.stop()
await self._app_runner.stop()

async def receive(self, text_data):
run_coro_in_new_loop(self._respond_to_event(text_data))
self._event_response_task = run_coro_in_new_loop(self._respond_to_event(text_data))


class StoreAppConsumer(AppConsumer):
Expand Down Expand Up @@ -592,6 +612,8 @@ async def connect(self):
self._app_runner = await AppStoreAppViewSet().get_app_runner_async(
self._session_id, self._app_slug, self._source, self.scope.get("user", None)
)
self._connected = True
self._event_response_task = None
await self.accept()


Expand Down Expand Up @@ -625,15 +647,18 @@ async def connect(self):
},
)

self._connected = True
self._event_response_task = None
await self.accept()

async def disconnect(self, close_code):
self._connected = False
if self._app_runner:
await self._app_runner.stop()
await self.close(code=close_code)

async def receive(self, text_data):
run_coro_in_new_loop(self._respond_to_event(text_data))
self._event_response_task = run_coro_in_new_loop(self._respond_to_event(text_data))

async def _respond_to_event(self, text_data):
from llmstack.assets.stream import AssetStream
Expand All @@ -654,6 +679,9 @@ async def _respond_to_event(self, text_data):

# Iterate till we get the objrefs for input and output audio
async for response in response_iterator:
if not self._connected:
break

if response.type == AppRunnerStreamingResponseType.OUTPUT_STREAM_CHUNK:
deltas = response.data.deltas
if "agent_input_audio_stream" in deltas:
Expand Down

0 comments on commit 2b8f4f7

Please sign in to comment.