From 58590a6c571b272fe829dbd521b05f1df19b9ab9 Mon Sep 17 00:00:00 2001 From: kingbri Date: Wed, 7 Feb 2024 21:08:21 -0500 Subject: [PATCH] Config: Add option to force streaming off Many APIs automatically ask for request streaming without giving the user the option to turn it off. Therefore, give the user more freedom by giving a server-side kill switch. Signed-off-by: kingbri --- common/args.py | 5 +++++ config_sample.yml | 4 ++++ main.py | 12 ++++++++++-- 3 files changed, 19 insertions(+), 2 deletions(-) diff --git a/common/args.py b/common/args.py index 736b3607..a0745aee 100644 --- a/common/args.py +++ b/common/args.py @@ -135,3 +135,8 @@ def add_developer_args(parser: argparse.ArgumentParser): developer_group.add_argument( "--unsafe-launch", type=str_to_bool, help="Skip Exllamav2 version check" ) + developer_group.add_argument( + "--disable-request-streaming", + type=str_to_bool, + help="Disables API request streaming", + ) diff --git a/config_sample.yml b/config_sample.yml index c50111c1..23b71917 100644 --- a/config_sample.yml +++ b/config_sample.yml @@ -42,6 +42,10 @@ developer: # WARNING: Don't set this unless you know what you're doing! #unsafe_launch: False + # Disable all request streaming (default: False) + # A kill switch for turning off SSE in the API server + #disable_request_streaming: False + # Options for model overrides and loading model: # Overrides the directory to look for models (default: models) diff --git a/main.py b/main.py index c47a9fc0..eaf045a5 100644 --- a/main.py +++ b/main.py @@ -449,7 +449,11 @@ async def generate_completion(request: Request, data: CompletionRequest): if isinstance(data.prompt, list): data.prompt = "\n".join(data.prompt) - if data.stream: + disable_request_streaming = unwrap( + get_developer_config().get("disable_request_streaming"), False + ) + + if data.stream and not disable_request_streaming: async def generator(): """Generator for the generation process.""" @@ -531,7 +535,11 @@ async def generate_chat_completion(request: Request, data: ChatCompletionRequest f"TemplateError: {str(exc)}", ) from exc - if data.stream: + disable_request_streaming = unwrap( + get_developer_config().get("disable_request_streaming"), False + ) + + if data.stream and not disable_request_streaming: const_id = f"chatcmpl-{uuid4().hex}" async def generator():