Skip to content

Commit

Permalink
OAI: Switch to background task for disconnect checks
Browse files Browse the repository at this point in the history
Waiting for request disconnect takes some extra time and allows
generation chunks to pile up, resulting in large payloads being sent
at once not making up a smooth stream.

Use the polling method in non-streaming requests by creating a background
task and then check if the task is done, signifying that the request
has been disconnected.

Signed-off-by: kingbri <[email protected]>
  • Loading branch information
bdashore3 committed May 26, 2024
1 parent 660f9b8 commit d710a1b
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 2 deletions.
6 changes: 5 additions & 1 deletion endpoints/OAI/utils/chat_completion.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
get_generator_error,
handle_request_disconnect,
handle_request_error,
request_disconnect_loop,
)
from common.utils import unwrap
from endpoints.OAI.types.chat_completion import (
Expand Down Expand Up @@ -204,10 +205,13 @@ async def stream_generate_chat_completion(
new_generation = model.container.generate_gen(
prompt, abort_event, **data.to_gen_params()
)
# Create a background task to avoid blocking the loop
disconnect_task = asyncio.create_task(request_disconnect_loop(request))

async for generation in new_generation:
# Sometimes this fires, and sometimes a CancelledError will fire
# Keep both implementations in to avoid the headache
if await request.is_disconnected():
if disconnect_task.done():
abort_event.set()
handle_request_disconnect("Completion generation cancelled by user.")

Expand Down
7 changes: 6 additions & 1 deletion endpoints/OAI/utils/completion.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
get_generator_error,
handle_request_disconnect,
handle_request_error,
request_disconnect_loop,
)
from common.utils import unwrap
from endpoints.OAI.types.completion import (
Expand Down Expand Up @@ -72,10 +73,14 @@ async def stream_generate_completion(
new_generation = model.container.generate_gen(
data.prompt, abort_event, **data.to_gen_params()
)

# Create a background task to avoid blocking the loop
disconnect_task = asyncio.create_task(request_disconnect_loop(request))

async for generation in new_generation:
# Sometimes this fires, and sometimes a CancelledError will fire
# Keep both implementations in to avoid the headache
if await request.is_disconnected():
if disconnect_task.done():
abort_event.set()
handle_request_disconnect("Completion generation cancelled by user.")

Expand Down

0 comments on commit d710a1b

Please sign in to comment.