diff --git a/endpoints/OAI/router.py b/endpoints/OAI/router.py
index 47416c86..0e4f27be 100644
--- a/endpoints/OAI/router.py
+++ b/endpoints/OAI/router.py
@@ -72,6 +72,104 @@ async def check_model_container():
         raise HTTPException(400, error_message)
 
 
+# Completions endpoint
+@router.post(
+    "/v1/completions",
+    dependencies=[Depends(check_api_key), Depends(check_model_container)],
+)
+async def completion_request(
+    request: Request, data: CompletionRequest
+) -> CompletionResponse:
+    """
+    Generates a completion from a prompt.
+
+    If stream = true, this returns an SSE stream.
+    """
+
+    model_path = model.container.get_model_path()
+
+    if isinstance(data.prompt, list):
+        data.prompt = "\n".join(data.prompt)
+
+    disable_request_streaming = unwrap(
+        config.developer_config().get("disable_request_streaming"), False
+    )
+
+    # Set an empty JSON schema if the request wants a JSON response
+    if data.response_format.type == "json":
+        data.json_schema = {"type": "object"}
+
+    if data.stream and not disable_request_streaming:
+        return EventSourceResponse(
+            stream_generate_completion(data, request, model_path),
+            ping=maxsize,
+        )
+    else:
+        generate_task = asyncio.create_task(generate_completion(data, model_path))
+
+        response = await run_with_request_disconnect(
+            request,
+            generate_task,
+            disconnect_message="Completion generation cancelled by user.",
+        )
+        return response
+
+
+# Chat completions endpoint
+@router.post(
+    "/v1/chat/completions",
+    dependencies=[Depends(check_api_key), Depends(check_model_container)],
+)
+async def chat_completion_request(
+    request: Request, data: ChatCompletionRequest
+) -> ChatCompletionResponse:
+    """
+    Generates a chat completion from a prompt.
+
+    If stream = true, this returns an SSE stream.
+    """
+
+    if model.container.prompt_template is None:
+        error_message = handle_request_error(
+            "Chat completions are disabled because a prompt template is not set.",
+            exc_info=False,
+        ).error.message
+
+        raise HTTPException(422, error_message)
+
+    model_path = model.container.get_model_path()
+
+    if isinstance(data.messages, str):
+        prompt = data.messages
+    else:
+        prompt = format_prompt_with_template(data)
+
+    # Set an empty JSON schema if the request wants a JSON response
+    if data.response_format.type == "json":
+        data.json_schema = {"type": "object"}
+
+    disable_request_streaming = unwrap(
+        config.developer_config().get("disable_request_streaming"), False
+    )
+
+    if data.stream and not disable_request_streaming:
+        return EventSourceResponse(
+            stream_generate_chat_completion(prompt, data, request, model_path),
+            ping=maxsize,
+        )
+    else:
+        generate_task = asyncio.create_task(
+            generate_chat_completion(prompt, data, model_path)
+        )
+
+        response = await run_with_request_disconnect(
+            request,
+            generate_task,
+            disconnect_message="Chat completion generation cancelled by user.",
+        )
+        return response
+
+
 # Model list endpoint
 @router.get("/v1/models", dependencies=[Depends(check_api_key)])
 @router.get("/v1/model/list", dependencies=[Depends(check_api_key)])
@@ -192,99 +290,6 @@ async def unload_model():
     await model.unload_model(skip_wait=True)
 
 
-@router.get("/v1/templates", dependencies=[Depends(check_api_key)])
-@router.get("/v1/template/list", dependencies=[Depends(check_api_key)])
-async def get_templates() -> TemplateList:
-    templates = get_all_templates()
-    template_strings = [template.stem for template in templates]
-    return TemplateList(data=template_strings)
-
-
-@router.post(
-    "/v1/template/switch",
-    dependencies=[Depends(check_admin_key), Depends(check_model_container)],
-)
-async def switch_template(data: TemplateSwitchRequest):
-    """Switch the currently loaded template"""
-    if not data.name:
-        error_message = handle_request_error(
-            "New template name not found.",
-            exc_info=False,
-        ).error.message
-
-        raise HTTPException(400, error_message)
-
-    try:
-        model.container.prompt_template = PromptTemplate.from_file(data.name)
-    except FileNotFoundError as e:
-        error_message = handle_request_error(
-            f"The template name {data.name} doesn't exist. Check the spelling?",
-            exc_info=False,
-        ).error.message
-
-        raise HTTPException(400, error_message) from e
-
-
-@router.post(
-    "/v1/template/unload",
-    dependencies=[Depends(check_admin_key), Depends(check_model_container)],
-)
-async def unload_template():
-    """Unloads the currently selected template"""
-
-    model.container.prompt_template = None
-
-
-# Sampler override endpoints
-@router.get("/v1/sampling/overrides", dependencies=[Depends(check_api_key)])
-@router.get("/v1/sampling/override/list", dependencies=[Depends(check_api_key)])
-async def list_sampler_overrides() -> SamplerOverrideListResponse:
-    """API wrapper to list all currently applied sampler overrides"""
-
-    return SamplerOverrideListResponse(
-        presets=sampling.get_all_presets(), **sampling.overrides_container.model_dump()
-    )
-
-
-@router.post(
-    "/v1/sampling/override/switch",
-    dependencies=[Depends(check_admin_key)],
-)
-async def switch_sampler_override(data: SamplerOverrideSwitchRequest):
-    """Switch the currently loaded override preset"""
-
-    if data.preset:
-        try:
-            sampling.overrides_from_file(data.preset)
-        except FileNotFoundError as e:
-            error_message = handle_request_error(
-                f"Sampler override preset with name {data.preset} does not exist. "
-                + "Check the spelling?",
-                exc_info=False,
-            ).error.message
-
-            raise HTTPException(400, error_message) from e
-    elif data.overrides:
-        sampling.overrides_from_dict(data.overrides)
-    else:
-        error_message = handle_request_error(
-            "A sampler override preset or dictionary wasn't provided.",
-            exc_info=False,
-        ).error.message
-
-        raise HTTPException(400, error_message)
-
-
-@router.post(
-    "/v1/sampling/override/unload",
-    dependencies=[Depends(check_admin_key)],
-)
-async def unload_sampler_override():
-    """Unloads the currently selected override preset"""
-
-    sampling.overrides_from_dict({})
-
-
 @router.post("/v1/download", dependencies=[Depends(check_admin_key)])
 async def download_model(request: Request, data: DownloadRequest) -> DownloadResponse:
     """Downloads a model from HuggingFace."""
@@ -452,99 +457,94 @@ async def get_key_permission(
         raise HTTPException(400, error_message) from exc
 
 
-# Completions endpoint
+@router.get("/v1/templates", dependencies=[Depends(check_api_key)])
+@router.get("/v1/template/list", dependencies=[Depends(check_api_key)])
+async def get_templates() -> TemplateList:
+    templates = get_all_templates()
+    template_strings = [template.stem for template in templates]
+    return TemplateList(data=template_strings)
+
+
 @router.post(
-    "/v1/completions",
-    dependencies=[Depends(check_api_key), Depends(check_model_container)],
+    "/v1/template/switch",
+    dependencies=[Depends(check_admin_key), Depends(check_model_container)],
 )
-async def completion_request(
-    request: Request, data: CompletionRequest
-) -> CompletionResponse:
-    """
-    Generates a completion from a prompt.
+async def switch_template(data: TemplateSwitchRequest):
+    """Switch the currently loaded template"""
+    if not data.name:
+        error_message = handle_request_error(
+            "New template name not found.",
+            exc_info=False,
+        ).error.message
 
-    If stream = true, this returns an SSE stream.
-    """
+        raise HTTPException(400, error_message)
 
-    model_path = model.container.get_model_path()
+    try:
+        model.container.prompt_template = PromptTemplate.from_file(data.name)
+    except FileNotFoundError as e:
+        error_message = handle_request_error(
+            f"The template name {data.name} doesn't exist. Check the spelling?",
+            exc_info=False,
+        ).error.message
 
-    if isinstance(data.prompt, list):
-        data.prompt = "\n".join(data.prompt)
+        raise HTTPException(400, error_message) from e
 
-    disable_request_streaming = unwrap(
-        config.developer_config().get("disable_request_streaming"), False
-    )
-
-    # Set an empty JSON schema if the request wants a JSON response
-    if data.response_format.type == "json":
-        data.json_schema = {"type": "object"}
+
+@router.post(
+    "/v1/template/unload",
+    dependencies=[Depends(check_admin_key), Depends(check_model_container)],
+)
+async def unload_template():
+    """Unloads the currently selected template"""
 
-    if data.stream and not disable_request_streaming:
-        return EventSourceResponse(
-            stream_generate_completion(data, request, model_path),
-            ping=maxsize,
-        )
-    else:
-        generate_task = asyncio.create_task(generate_completion(data, model_path))
+    model.container.prompt_template = None
-
-        response = await run_with_request_disconnect(
-            request,
-            generate_task,
-            disconnect_message="Completion generation cancelled by user.",
-        )
-        return response
+
+
+# Sampler override endpoints
+@router.get("/v1/sampling/overrides", dependencies=[Depends(check_api_key)])
+@router.get("/v1/sampling/override/list", dependencies=[Depends(check_api_key)])
+async def list_sampler_overrides() -> SamplerOverrideListResponse:
+    """API wrapper to list all currently applied sampler overrides"""
+
+    return SamplerOverrideListResponse(
+        presets=sampling.get_all_presets(), **sampling.overrides_container.model_dump()
+    )
 
 
-# Chat completions endpoint
 @router.post(
-    "/v1/chat/completions",
-    dependencies=[Depends(check_api_key), Depends(check_model_container)],
+    "/v1/sampling/override/switch",
+    dependencies=[Depends(check_admin_key)],
 )
-async def chat_completion_request(
-    request: Request, data: ChatCompletionRequest
-) -> ChatCompletionResponse:
-    """
-    Generates a chat completion from a prompt.
+async def switch_sampler_override(data: SamplerOverrideSwitchRequest):
+    """Switch the currently loaded override preset"""
 
-    If stream = true, this returns an SSE stream.
-    """
+
+    if data.preset:
+        try:
+            sampling.overrides_from_file(data.preset)
+        except FileNotFoundError as e:
+            error_message = handle_request_error(
+                f"Sampler override preset with name {data.preset} does not exist. "
+                + "Check the spelling?",
+                exc_info=False,
+            ).error.message
 
-    if model.container.prompt_template is None:
+            raise HTTPException(400, error_message) from e
+    elif data.overrides:
+        sampling.overrides_from_dict(data.overrides)
+    else:
         error_message = handle_request_error(
-            "Chat completions are disabled because a prompt template is not set.",
+            "A sampler override preset or dictionary wasn't provided.",
             exc_info=False,
         ).error.message
 
-        raise HTTPException(422, error_message)
-
-    model_path = model.container.get_model_path()
-
-    if isinstance(data.messages, str):
-        prompt = data.messages
-    else:
-        prompt = format_prompt_with_template(data)
-
-    # Set an empty JSON schema if the request wants a JSON response
-    if data.response_format.type == "json":
-        data.json_schema = {"type": "object"}
+        raise HTTPException(400, error_message)
 
-    disable_request_streaming = unwrap(
-        config.developer_config().get("disable_request_streaming"), False
-    )
-
-    if data.stream and not disable_request_streaming:
-        return EventSourceResponse(
-            stream_generate_chat_completion(prompt, data, request, model_path),
-            ping=maxsize,
-        )
-    else:
-        generate_task = asyncio.create_task(
-            generate_chat_completion(prompt, data, model_path)
-        )
+
+@router.post(
+    "/v1/sampling/override/unload",
+    dependencies=[Depends(check_admin_key)],
+)
+async def unload_sampler_override():
+    """Unloads the currently selected override preset"""
 
-        response = await run_with_request_disconnect(
-            request,
-            generate_task,
-            disconnect_message="Chat completion generation cancelled by user.",
-        )
-        return response
+    sampling.overrides_from_dict({})