From f1f4f4adda78c24ac64217347630a8de72a04a3b Mon Sep 17 00:00:00 2001 From: FangYin Cheng Date: Tue, 12 Dec 2023 09:14:40 +0800 Subject: [PATCH] fix(core): Fix fschat and alembic log conflict (#919) Co-authored-by: chengfangyin2 --- dbgpt/model/cluster/apiserver/api.py | 69 ++++++++++++++++++++++++++-- dbgpt/util/__init__.py | 3 -- dbgpt/util/utils.py | 52 --------------------- pilot/meta_data/alembic.ini | 35 -------------- pilot/meta_data/alembic/env.py | 6 --- 5 files changed, 64 insertions(+), 101 deletions(-) diff --git a/dbgpt/model/cluster/apiserver/api.py b/dbgpt/model/cluster/apiserver/api.py index 89204d052..22263051a 100644 --- a/dbgpt/model/cluster/apiserver/api.py +++ b/dbgpt/model/cluster/apiserver/api.py @@ -13,7 +13,7 @@ from fastapi import Depends, HTTPException from fastapi.exceptions import RequestValidationError from fastapi.middleware.cors import CORSMiddleware -from fastapi.responses import StreamingResponse +from fastapi.responses import StreamingResponse, JSONResponse from fastapi.security.http import HTTPAuthorizationCredentials, HTTPBearer from pydantic import BaseSettings @@ -30,10 +30,7 @@ ModelPermission, UsageInfo, ) -from fastchat.protocol.api_protocol import ( - APIChatCompletionRequest, -) -from fastchat.serve.openai_api_server import create_error_response, check_requests +from fastchat.protocol.api_protocol import APIChatCompletionRequest, ErrorResponse from fastchat.constants import ErrorCode from dbgpt.component import BaseComponent, ComponentType, SystemApp @@ -85,6 +82,68 @@ async def check_api_key( return None +def create_error_response(code: int, message: str) -> JSONResponse: + """Copy from fastchat.serve.openai_api_server.create_error_response + + We can't use fastchat.serve.openai_api_server because it has too many dependencies. 
+ """ + return JSONResponse( + ErrorResponse(message=message, code=code).dict(), status_code=400 + ) + + +def check_requests(request) -> Optional[JSONResponse]: + """Copy from fastchat.serve.openai_api_server.check_requests + + We can't use fastchat.serve.openai_api_server because it has too many dependencies. + """ + # Check all params + if request.max_tokens is not None and request.max_tokens <= 0: + return create_error_response( + ErrorCode.PARAM_OUT_OF_RANGE, + f"{request.max_tokens} is less than the minimum of 1 - 'max_tokens'", + ) + if request.n is not None and request.n <= 0: + return create_error_response( + ErrorCode.PARAM_OUT_OF_RANGE, + f"{request.n} is less than the minimum of 1 - 'n'", + ) + if request.temperature is not None and request.temperature < 0: + return create_error_response( + ErrorCode.PARAM_OUT_OF_RANGE, + f"{request.temperature} is less than the minimum of 0 - 'temperature'", + ) + if request.temperature is not None and request.temperature > 2: + return create_error_response( + ErrorCode.PARAM_OUT_OF_RANGE, + f"{request.temperature} is greater than the maximum of 2 - 'temperature'", + ) + if request.top_p is not None and request.top_p < 0: + return create_error_response( + ErrorCode.PARAM_OUT_OF_RANGE, + f"{request.top_p} is less than the minimum of 0 - 'top_p'", + ) + if request.top_p is not None and request.top_p > 1: + return create_error_response( + ErrorCode.PARAM_OUT_OF_RANGE, + f"{request.top_p} is greater than the maximum of 1 - 'temperature'", + ) + if request.top_k is not None and (request.top_k > -1 and request.top_k < 1): + return create_error_response( + ErrorCode.PARAM_OUT_OF_RANGE, + f"{request.top_k} is out of Range. 
Either set top_k to -1 or >=1.", + ) + if request.stop is not None and ( + not isinstance(request.stop, str) and not isinstance(request.stop, list) + ): + return create_error_response( + ErrorCode.PARAM_OUT_OF_RANGE, + f"{request.stop} is not valid under any of the given schemas - 'stop'", + ) + + return None + + class APIServer(BaseComponent): name = ComponentType.MODEL_API_SERVER diff --git a/dbgpt/util/__init__.py b/dbgpt/util/__init__.py index e64aa1ff9..83798b7ad 100644 --- a/dbgpt/util/__init__.py +++ b/dbgpt/util/__init__.py @@ -1,8 +1,5 @@ from .utils import ( get_gpu_memory, - StreamToLogger, - disable_torch_init, - pretty_print_semaphore, server_error_msg, get_or_create_event_loop, ) diff --git a/dbgpt/util/utils.py b/dbgpt/util/utils.py index 4cfed28fe..23a78120f 100644 --- a/dbgpt/util/utils.py +++ b/dbgpt/util/utils.py @@ -111,58 +111,6 @@ def _build_logger(logger_name, logging_level=None, logger_filename: str = None): return logger -class StreamToLogger(object): - """ - Fake file-like stream object that redirects writes to a logger instance. - """ - - def __init__(self, logger, log_level=logging.INFO): - self.terminal = sys.stdout - self.logger = logger - self.log_level = log_level - self.linebuf = "" - - def __getattr__(self, attr): - return getattr(self.terminal, attr) - - def write(self, buf): - temp_linebuf = self.linebuf + buf - self.linebuf = "" - for line in temp_linebuf.splitlines(True): - # From the io.TextIOWrapper docs: - # On output, if newline is None, any '\n' characters written - # are translated to the system default line separator. - # By default sys.stdout.write() expects '\n' newlines and then - # translates them so this is still cross platform. 
- if line[-1] == "\n": - encoded_message = line.encode("utf-8", "ignore").decode("utf-8") - self.logger.log(self.log_level, encoded_message.rstrip()) - else: - self.linebuf += line - - def flush(self): - if self.linebuf != "": - encoded_message = self.linebuf.encode("utf-8", "ignore").decode("utf-8") - self.logger.log(self.log_level, encoded_message.rstrip()) - self.linebuf = "" - - -def disable_torch_init(): - """ - Disable the redundant torch default initialization to accelerate model creation. - """ - import torch - - setattr(torch.nn.Linear, "reset_parameters", lambda self: None) - setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None) - - -def pretty_print_semaphore(semaphore): - if semaphore is None: - return "None" - return f"Semaphore(value={semaphore._value}, locked={semaphore.locked()})" - - def get_or_create_event_loop() -> asyncio.BaseEventLoop: loop = None try: diff --git a/pilot/meta_data/alembic.ini b/pilot/meta_data/alembic.ini index 485504227..7315483fe 100644 --- a/pilot/meta_data/alembic.ini +++ b/pilot/meta_data/alembic.ini @@ -79,38 +79,3 @@ version_path_separator = os # Use os.pathsep. 
Default configuration used for ne # ruff.type = exec # ruff.executable = %(here)s/.venv/bin/ruff # ruff.options = --fix REVISION_SCRIPT_FILENAME - -# Logging configuration -[loggers] -keys = root,sqlalchemy,alembic - -[handlers] -keys = console - -[formatters] -keys = generic - -[logger_root] -level = WARN -handlers = console -qualname = - -[logger_sqlalchemy] -level = WARN -handlers = -qualname = sqlalchemy.engine - -[logger_alembic] -level = INFO -handlers = -qualname = alembic - -[handler_console] -class = StreamHandler -args = (sys.stderr,) -level = NOTSET -formatter = generic - -[formatter_generic] -format = %(levelname)-5.5s [%(name)s] %(message)s -datefmt = %H:%M:%S diff --git a/pilot/meta_data/alembic/env.py b/pilot/meta_data/alembic/env.py index 01de1cdcb..ef2e26a75 100644 --- a/pilot/meta_data/alembic/env.py +++ b/pilot/meta_data/alembic/env.py @@ -1,5 +1,3 @@ -from logging.config import fileConfig - from sqlalchemy import engine_from_config from sqlalchemy import pool @@ -11,10 +9,6 @@ # access to the values within the .ini file in use. config = context.config -# Interpret the config file for Python logging. -# This line sets up loggers basically. -if config.config_file_name is not None: - fileConfig(config.config_file_name) # add your model's MetaData object here # for 'autogenerate' support