Skip to content

Commit

Permalink
Config: Add option to force streaming off
Browse files Browse the repository at this point in the history
Many APIs automatically ask for request streaming without giving
the user the option to turn it off. Therefore, give the user more
freedom by giving a server-side kill switch.

Signed-off-by: kingbri <[email protected]>
  • Loading branch information
bdashore3 committed Feb 8, 2024
1 parent d0027bc commit 58590a6
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 2 deletions.
5 changes: 5 additions & 0 deletions common/args.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,3 +135,8 @@ def add_developer_args(parser: argparse.ArgumentParser):
developer_group.add_argument(
"--unsafe-launch", type=str_to_bool, help="Skip Exllamav2 version check"
)
developer_group.add_argument(
"--disable-request-streaming",
type=str_to_bool,
help="Disables API request streaming",
)
4 changes: 4 additions & 0 deletions config_sample.yml
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,10 @@ developer:
# WARNING: Don't set this unless you know what you're doing!
#unsafe_launch: False

# Disable all request streaming (default: False)
# A kill switch for turning off SSE in the API server
#disable_request_streaming: False

# Options for model overrides and loading
model:
# Overrides the directory to look for models (default: models)
Expand Down
12 changes: 10 additions & 2 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -449,7 +449,11 @@ async def generate_completion(request: Request, data: CompletionRequest):
if isinstance(data.prompt, list):
data.prompt = "\n".join(data.prompt)

if data.stream:
disable_request_streaming = unwrap(
get_developer_config().get("disable_request_streaming"), False
)

if data.stream and not disable_request_streaming:

async def generator():
"""Generator for the generation process."""
Expand Down Expand Up @@ -531,7 +535,11 @@ async def generate_chat_completion(request: Request, data: ChatCompletionRequest
f"TemplateError: {str(exc)}",
) from exc

if data.stream:
disable_request_streaming = unwrap(
get_developer_config().get("disable_request_streaming"), False
)

if data.stream and not disable_request_streaming:
const_id = f"chatcmpl-{uuid4().hex}"

async def generator():
Expand Down

0 comments on commit 58590a6

Please sign in to comment.