Skip to content

Commit

Permalink
fix(core): Fix fschat and alembic log conflict (#919)
Browse files Browse the repository at this point in the history
Co-authored-by: chengfangyin2 <[email protected]>
  • Loading branch information
fangyinc and chengfangyin2 authored Dec 12, 2023
1 parent cbba50a commit f1f4f4a
Show file tree
Hide file tree
Showing 5 changed files with 64 additions and 101 deletions.
69 changes: 64 additions & 5 deletions dbgpt/model/cluster/apiserver/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from fastapi import Depends, HTTPException
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from fastapi.responses import StreamingResponse, JSONResponse
from fastapi.security.http import HTTPAuthorizationCredentials, HTTPBearer

from pydantic import BaseSettings
Expand All @@ -30,10 +30,7 @@
ModelPermission,
UsageInfo,
)
from fastchat.protocol.api_protocol import (
APIChatCompletionRequest,
)
from fastchat.serve.openai_api_server import create_error_response, check_requests
from fastchat.protocol.api_protocol import APIChatCompletionRequest, ErrorResponse
from fastchat.constants import ErrorCode

from dbgpt.component import BaseComponent, ComponentType, SystemApp
Expand Down Expand Up @@ -85,6 +82,68 @@ async def check_api_key(
return None


def create_error_response(code: int, message: str) -> JSONResponse:
    """Build an OpenAI-style error response.

    Copied from ``fastchat.serve.openai_api_server.create_error_response``.
    We can't use ``fastchat.serve.openai_api_server`` directly because it has
    too many dependencies.

    Args:
        code: Application-specific error code (see ``fastchat.constants.ErrorCode``).
        message: Human-readable description of the error.

    Returns:
        A ``JSONResponse`` with HTTP status 400 whose body is a serialized
        ``ErrorResponse``.
    """
    return JSONResponse(
        ErrorResponse(message=message, code=code).dict(), status_code=400
    )


def check_requests(request) -> Optional[JSONResponse]:
    """Validate the common sampling parameters of a chat completion request.

    Copied from ``fastchat.serve.openai_api_server.check_requests``.
    We can't use ``fastchat.serve.openai_api_server`` directly because it has
    too many dependencies.

    Args:
        request: An ``APIChatCompletionRequest``-like object with the optional
            attributes ``max_tokens``, ``n``, ``temperature``, ``top_p``,
            ``top_k`` and ``stop``.

    Returns:
        An error ``JSONResponse`` describing the first out-of-range parameter,
        or ``None`` when every supplied parameter is valid.
    """
    # Check all params
    if request.max_tokens is not None and request.max_tokens <= 0:
        return create_error_response(
            ErrorCode.PARAM_OUT_OF_RANGE,
            f"{request.max_tokens} is less than the minimum of 1 - 'max_tokens'",
        )
    if request.n is not None and request.n <= 0:
        return create_error_response(
            ErrorCode.PARAM_OUT_OF_RANGE,
            f"{request.n} is less than the minimum of 1 - 'n'",
        )
    if request.temperature is not None and request.temperature < 0:
        return create_error_response(
            ErrorCode.PARAM_OUT_OF_RANGE,
            f"{request.temperature} is less than the minimum of 0 - 'temperature'",
        )
    if request.temperature is not None and request.temperature > 2:
        return create_error_response(
            ErrorCode.PARAM_OUT_OF_RANGE,
            f"{request.temperature} is greater than the maximum of 2 - 'temperature'",
        )
    if request.top_p is not None and request.top_p < 0:
        return create_error_response(
            ErrorCode.PARAM_OUT_OF_RANGE,
            f"{request.top_p} is less than the minimum of 0 - 'top_p'",
        )
    if request.top_p is not None and request.top_p > 1:
        return create_error_response(
            ErrorCode.PARAM_OUT_OF_RANGE,
            # Fixed: the message previously referred to 'temperature'.
            f"{request.top_p} is greater than the maximum of 1 - 'top_p'",
        )
    if request.top_k is not None and (request.top_k > -1 and request.top_k < 1):
        return create_error_response(
            ErrorCode.PARAM_OUT_OF_RANGE,
            f"{request.top_k} is out of Range. Either set top_k to -1 or >=1.",
        )
    if request.stop is not None and (
        not isinstance(request.stop, str) and not isinstance(request.stop, list)
    ):
        return create_error_response(
            ErrorCode.PARAM_OUT_OF_RANGE,
            f"{request.stop} is not valid under any of the given schemas - 'stop'",
        )

    return None


class APIServer(BaseComponent):
name = ComponentType.MODEL_API_SERVER

Expand Down
3 changes: 0 additions & 3 deletions dbgpt/util/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,5 @@
from .utils import (
get_gpu_memory,
StreamToLogger,
disable_torch_init,
pretty_print_semaphore,
server_error_msg,
get_or_create_event_loop,
)
52 changes: 0 additions & 52 deletions dbgpt/util/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,58 +111,6 @@ def _build_logger(logger_name, logging_level=None, logger_filename: str = None):
return logger


class StreamToLogger(object):
    """File-like stream adapter that forwards written text to a logger.

    Complete lines are emitted as log records immediately; a trailing
    partial line is buffered until the next ``write`` or an explicit
    ``flush``. Attributes not implemented here are delegated to the
    real ``sys.stdout``.
    """

    def __init__(self, logger, log_level=logging.INFO):
        self.terminal = sys.stdout
        self.logger = logger
        self.log_level = log_level
        self.linebuf = ""

    def __getattr__(self, attr):
        # Anything we don't implement falls through to the real stdout.
        return getattr(self.terminal, attr)

    def write(self, buf):
        pending = self.linebuf + buf
        self.linebuf = ""
        # splitlines(True) keeps the line terminators, letting us tell
        # finished lines apart from a trailing partial line.
        for piece in pending.splitlines(True):
            if piece.endswith("\n"):
                # Round-trip through UTF-8 to drop undecodable characters,
                # then log without the trailing newline.
                cleaned = piece.encode("utf-8", "ignore").decode("utf-8")
                self.logger.log(self.log_level, cleaned.rstrip())
            else:
                # Keep the incomplete final line for the next write.
                self.linebuf += piece

    def flush(self):
        # Emit whatever partial line is still buffered.
        if self.linebuf:
            cleaned = self.linebuf.encode("utf-8", "ignore").decode("utf-8")
            self.logger.log(self.log_level, cleaned.rstrip())
            self.linebuf = ""


def disable_torch_init():
    """Disable the redundant torch default initialization to accelerate model creation.

    Replaces ``reset_parameters`` on ``torch.nn.Linear`` and
    ``torch.nn.LayerNorm`` with no-ops, so constructing these modules skips
    random weight initialization (useful when a checkpoint is loaded
    immediately afterwards and would overwrite the weights anyway).
    """
    import torch

    torch.nn.Linear.reset_parameters = lambda self: None
    torch.nn.LayerNorm.reset_parameters = lambda self: None


def pretty_print_semaphore(semaphore):
    """Render a semaphore as a short human-readable string.

    Returns the literal string ``"None"`` when no semaphore is given;
    otherwise a summary including its internal counter and locked state.
    """
    if semaphore is None:
        return "None"
    # NOTE(review): reads the private ``_value`` counter, which is
    # implementation-specific to asyncio/threading semaphores.
    value = semaphore._value
    locked = semaphore.locked()
    return "Semaphore(value={}, locked={})".format(value, locked)


def get_or_create_event_loop() -> asyncio.BaseEventLoop:
loop = None
try:
Expand Down
35 changes: 0 additions & 35 deletions pilot/meta_data/alembic.ini
Original file line number Diff line number Diff line change
Expand Up @@ -79,38 +79,3 @@ version_path_separator = os # Use os.pathsep. Default configuration used for ne
# ruff.type = exec
# ruff.executable = %(here)s/.venv/bin/ruff
# ruff.options = --fix REVISION_SCRIPT_FILENAME

# Logging configuration
[loggers]
keys = root,sqlalchemy,alembic

[handlers]
keys = console

[formatters]
keys = generic

[logger_root]
level = WARN
handlers = console
qualname =

[logger_sqlalchemy]
level = WARN
handlers =
qualname = sqlalchemy.engine

[logger_alembic]
level = INFO
handlers =
qualname = alembic

[handler_console]
class = StreamHandler
args = (sys.stderr,)
level = NOTSET
formatter = generic

[formatter_generic]
format = %(levelname)-5.5s [%(name)s] %(message)s
datefmt = %H:%M:%S
6 changes: 0 additions & 6 deletions pilot/meta_data/alembic/env.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
from logging.config import fileConfig

from sqlalchemy import engine_from_config
from sqlalchemy import pool

Expand All @@ -11,10 +9,6 @@
# access to the values within the .ini file in use.
config = context.config

# Interpret the config file for Python logging.
# This line sets up loggers basically.
if config.config_file_name is not None:
fileConfig(config.config_file_name)

# add your model's MetaData object here
# for 'autogenerate' support
Expand Down

0 comments on commit f1f4f4a

Please sign in to comment.