From 0cdc77abb230ad4e3441084cbf051c01be087659 Mon Sep 17 00:00:00 2001 From: Fangyin Cheng Date: Fri, 29 Dec 2023 12:01:31 +0800 Subject: [PATCH] feat(model): Proxy model support count token (#996) --- dbgpt/model/cluster/worker/default_worker.py | 10 +- dbgpt/model/cluster/worker/manager.py | 4 +- dbgpt/model/proxy/llms/chatgpt.py | 2 +- dbgpt/model/proxy/llms/proxy_model.py | 27 +++ dbgpt/model/utils/chatgpt_utils.py | 9 +- dbgpt/model/utils/token_utils.py | 80 +++++++ dbgpt/serve/conversation/tests/test_models.py | 17 +- dbgpt/serve/prompt/tests/test_models.py | 142 ++++++------ dbgpt/serve/prompt/tests/test_service.py | 6 +- .../tests/test_models.py | 17 +- dbgpt/storage/chat_history/chat_history_db.py | 12 +- dbgpt/storage/metadata/_base_dao.py | 9 +- dbgpt/storage/metadata/db_manager.py | 207 +++++++++--------- dbgpt/storage/metadata/tests/test_base_dao.py | 4 +- .../storage/metadata/tests/test_db_manager.py | 64 +++--- examples/awel/simple_llm_client_example.py | 2 +- 16 files changed, 365 insertions(+), 247 deletions(-) create mode 100644 dbgpt/model/utils/token_utils.py diff --git a/dbgpt/model/cluster/worker/default_worker.py b/dbgpt/model/cluster/worker/default_worker.py index a42858bb4..26f2b40ae 100644 --- a/dbgpt/model/cluster/worker/default_worker.py +++ b/dbgpt/model/cluster/worker/default_worker.py @@ -189,7 +189,7 @@ def generate(self, params: Dict) -> ModelOutput: return output def count_token(self, prompt: str) -> int: - return _try_to_count_token(prompt, self.tokenizer) + return _try_to_count_token(prompt, self.tokenizer, self.model) async def async_count_token(self, prompt: str) -> int: # TODO if we deploy the model by vllm, it can't work, we should run transformer _try_to_count_token to async @@ -454,12 +454,13 @@ def _new_metrics_from_model_output( return metrics -def _try_to_count_token(prompt: str, tokenizer) -> int: +def _try_to_count_token(prompt: str, tokenizer, model) -> int: """Try to count token of prompt Args: prompt (str): prompt tokenizer ([type]): tokenizer + model ([type]): model Returns: int: token count, if error return -1 @@ -467,6 +468,11 @@ def _try_to_count_token(prompt: str, tokenizer) -> int: TODO: More implementation """ try: + from dbgpt.model.proxy.llms.proxy_model import ProxyModel + + if isinstance(model, ProxyModel): + return model.count_token(prompt) + # Only support huggingface model now return len(tokenizer(prompt).input_ids[0]) except Exception as e: logger.warning(f"Count token error, detail: {e}, return -1") diff --git a/dbgpt/model/cluster/worker/manager.py b/dbgpt/model/cluster/worker/manager.py index 83b39370e..8a8af3b35 100644 --- a/dbgpt/model/cluster/worker/manager.py +++ b/dbgpt/model/cluster/worker/manager.py @@ -197,7 +197,7 @@ def add_worker( return True else: # TODO Update worker - logger.warn(f"Instance {worker_key} exist") + logger.warning(f"Instance {worker_key} exist") return False def _remove_worker(self, worker_params: ModelWorkerParameters) -> None: @@ -229,7 +229,7 @@ async def model_startup(self, startup_req: WorkerStartupRequest): ) if not success: msg = f"Add worker {model_name}@{worker_type}, worker instances is exist" - logger.warn(f"{msg}, worker_params: {worker_params}") + logger.warning(f"{msg}, worker_params: {worker_params}") self._remove_worker(worker_params) raise Exception(msg) supported_types = WorkerType.values() diff --git a/dbgpt/model/proxy/llms/chatgpt.py b/dbgpt/model/proxy/llms/chatgpt.py index 1b98e8a8a..b57ae0486 100755 --- a/dbgpt/model/proxy/llms/chatgpt.py +++ b/dbgpt/model/proxy/llms/chatgpt.py @@ -92,11 +92,11 @@ def _initialize_openai_v1(params: ProxyModelParameters): def __convert_2_gpt_messages(messages: List[ModelMessage]): - chat_round = 0 gpt_messages = [] last_usr_message = "" system_messages = [] + # TODO: We can't change message order in low level for message in messages: if message.role == ModelMessageRoleType.HUMAN or message.role == "user": last_usr_message = message.content diff --git a/dbgpt/model/proxy/llms/proxy_model.py b/dbgpt/model/proxy/llms/proxy_model.py index 5d5a3feb4..4e55ec3ea 100644 --- a/dbgpt/model/proxy/llms/proxy_model.py +++ b/dbgpt/model/proxy/llms/proxy_model.py @@ -1,9 +1,36 @@ +from __future__ import annotations + +from typing import Union, List, Optional, TYPE_CHECKING +import logging from dbgpt.model.parameter import ProxyModelParameters +from dbgpt.model.utils.token_utils import ProxyTokenizerWrapper + +if TYPE_CHECKING: + from dbgpt.core.interface.message import ModelMessage, BaseMessage + +logger = logging.getLogger(__name__) class ProxyModel: def __init__(self, model_params: ProxyModelParameters) -> None: self._model_params = model_params + self._tokenizer = ProxyTokenizerWrapper() def get_params(self) -> ProxyModelParameters: return self._model_params + + def count_token( + self, + messages: Union[str, BaseMessage, ModelMessage, List[ModelMessage]], + model_name: Optional[int] = None, + ) -> int: + """Count token of given messages + + Args: + messages (Union[str, BaseMessage, ModelMessage, List[ModelMessage]]): messages to count token + model_name (Optional[int], optional): model name. Defaults to None. + + Returns: + int: token count, -1 if failed + """ + return self._tokenizer.count_token(messages, model_name) diff --git a/dbgpt/model/utils/chatgpt_utils.py b/dbgpt/model/utils/chatgpt_utils.py index 02333b2b1..4e9cfc353 100644 --- a/dbgpt/model/utils/chatgpt_utils.py +++ b/dbgpt/model/utils/chatgpt_utils.py @@ -25,6 +25,7 @@ from dbgpt.model.cluster.client import DefaultLLMClient from dbgpt.model.cluster import WorkerManagerFactory from dbgpt._private.pydantic import model_to_json +from dbgpt.model.utils.token_utils import ProxyTokenizerWrapper if TYPE_CHECKING: import httpx @@ -152,6 +153,7 @@ def __init__( self._context_length = context_length self._client = openai_client self._openai_kwargs = openai_kwargs or {} + self._tokenizer = ProxyTokenizerWrapper() @property def client(self) -> ClientType: @@ -238,10 +240,11 @@ async def get_context_length(self) -> int: async def count_token(self, model: str, prompt: str) -> int: """Count the number of tokens in a given prompt. - TODO: Get the real number of tokens from the openai api or tiktoken package + Args: + model (str): The model name. + prompt (str): The prompt. """ - - raise NotImplementedError() + return self._tokenizer.count_token(prompt, model) class OpenAIStreamingOperator(TransformStreamAbsOperator[ModelOutput, str]): diff --git a/dbgpt/model/utils/token_utils.py b/dbgpt/model/utils/token_utils.py new file mode 100644 index 000000000..281ed5eed --- /dev/null +++ b/dbgpt/model/utils/token_utils.py @@ -0,0 +1,80 @@ +from __future__ import annotations + +from typing import Union, List, Optional, TYPE_CHECKING +import logging + +if TYPE_CHECKING: + from dbgpt.core.interface.message import ModelMessage, BaseMessage + +logger = logging.getLogger(__name__) + + +class ProxyTokenizerWrapper: + def __init__(self) -> None: + self._support_encoding = True + self._encoding_model = None + + def count_token( + self, + messages: Union[str, BaseMessage, ModelMessage, List[ModelMessage]], + model_name: Optional[str] = None, + ) -> int: + """Count token of given messages + + Args: + messages (Union[str, BaseMessage, ModelMessage, List[ModelMessage]]): messages to count token + model_name (Optional[str], optional): model name. Defaults to None. + + Returns: + int: token count, -1 if failed + """ + if not self._support_encoding: + logger.warning( + "model does not support encoding model, can't count token, returning -1" + ) + return -1 + encoding = self._get_or_create_encoding_model(model_name) + cnt = 0 + if isinstance(messages, str): + cnt = len(encoding.encode(messages, disallowed_special=())) + elif isinstance(messages, BaseMessage): + cnt = len(encoding.encode(messages.content, disallowed_special=())) + elif isinstance(messages, ModelMessage): + cnt = len(encoding.encode(messages.content, disallowed_special=())) + elif isinstance(messages, list): + for message in messages: + cnt += len(encoding.encode(message.content, disallowed_special=())) + else: + logger.warning( + "unsupported type of messages, can't count token, returning -1" + ) + return -1 + return cnt + + def _get_or_create_encoding_model(self, model_name: Optional[str] = None): + """Get or create encoding model for given model name + More detail see: https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb + """ + if self._encoding_model: + return self._encoding_model + try: + import tiktoken + + logger.info( + "tiktoken installed, using it to count tokens, tiktoken will download tokenizer from network, " + "also you can download it and put it in the directory of environment variable TIKTOKEN_CACHE_DIR" + ) + except ImportError: + self._support_encoding = False + logger.warn("tiktoken not installed, cannot count tokens, returning -1") + return -1 + try: + if not model_name: + model_name = "gpt-3.5-turbo" + self._encoding_model = tiktoken.model.encoding_for_model(model_name) + except KeyError: + logger.warning( + f"{model_name}'s tokenizer not found, using cl100k_base encoding." + ) + self._encoding_model = tiktoken.get_encoding("cl100k_base") + return self._encoding_model diff --git a/dbgpt/serve/conversation/tests/test_models.py b/dbgpt/serve/conversation/tests/test_models.py index 218307f1a..1d111644d 100644 --- a/dbgpt/serve/conversation/tests/test_models.py +++ b/dbgpt/serve/conversation/tests/test_models.py @@ -1,5 +1,3 @@ -from typing import List - import pytest from dbgpt.storage.metadata import db @@ -39,11 +37,9 @@ def test_table_exist(): def test_entity_create(default_entity_dict): - entity: ServeEntity = ServeEntity.create(**default_entity_dict) - # TODO: implement your test case with db.session() as session: - db_entity: ServeEntity = session.query(ServeEntity).get(entity.id) - assert db_entity.id == entity.id + entity = ServeEntity(**default_entity_dict) + session.add(entity) def test_entity_unique_key(default_entity_dict): @@ -52,10 +48,8 @@ def test_entity_unique_key(default_entity_dict): def test_entity_get(default_entity_dict): - entity: ServeEntity = ServeEntity.create(**default_entity_dict) - db_entity: ServeEntity = ServeEntity.get(entity.id) - assert db_entity.id == entity.id # TODO: implement your test case + pass def test_entity_update(default_entity_dict): @@ -65,10 +59,7 @@ def test_entity_update(default_entity_dict): def test_entity_delete(default_entity_dict): # TODO: implement your test case - entity: ServeEntity = ServeEntity.create(**default_entity_dict) - entity.delete() - db_entity: ServeEntity = ServeEntity.get(entity.id) - assert db_entity is None + pass def test_entity_all(): diff --git a/dbgpt/serve/prompt/tests/test_models.py b/dbgpt/serve/prompt/tests/test_models.py index 744412ec5..cc9988216 100644 --- a/dbgpt/serve/prompt/tests/test_models.py +++ b/dbgpt/serve/prompt/tests/test_models.py @@ -47,9 +47,11 @@ def test_table_exist(): def test_entity_create(default_entity_dict): - entity: ServeEntity = ServeEntity.create(**default_entity_dict) with db.session() as session: - db_entity: ServeEntity = session.query(ServeEntity).get(entity.id) + entity: ServeEntity = ServeEntity(**default_entity_dict) + session.add(entity) + session.commit() + db_entity: ServeEntity = session.get(ServeEntity, entity.id) assert db_entity.id == entity.id assert db_entity.chat_scene == "chat_data" assert db_entity.sub_chat_scene == "excel" @@ -63,78 +65,96 @@ def test_entity_create(default_entity_dict): def test_entity_unique_key(default_entity_dict): - ServeEntity.create(**default_entity_dict) + with db.session() as session: + entity = ServeEntity(**default_entity_dict) + session.add(entity) with pytest.raises(Exception): - ServeEntity.create( - **{ - "prompt_name": "my_prompt_1", - "sys_code": "dbgpt", - "prompt_language": "zh", - "model": "vicuna-13b-v1.5", - } - ) + with db.session() as session: + entity = ServeEntity( + **{ + "prompt_name": "my_prompt_1", + "sys_code": "dbgpt", + "prompt_language": "zh", + "model": "vicuna-13b-v1.5", + } + ) + session.add(entity) def test_entity_get(default_entity_dict): - entity: ServeEntity = ServeEntity.create(**default_entity_dict) - db_entity: ServeEntity = ServeEntity.get(entity.id) - assert db_entity.id == entity.id - assert db_entity.chat_scene == "chat_data" - assert db_entity.sub_chat_scene == "excel" - assert db_entity.prompt_type == "common" - assert db_entity.prompt_name == "my_prompt_1" - assert db_entity.content == "Write a qsort function in python." - assert db_entity.user_name == "zhangsan" - assert db_entity.sys_code == "dbgpt" - assert db_entity.gmt_created is not None - assert db_entity.gmt_modified is not None + with db.session() as session: + entity = ServeEntity(**default_entity_dict) + session.add(entity) + session.commit() + db_entity: ServeEntity = session.get(ServeEntity, entity.id) + assert db_entity.id == entity.id + assert db_entity.chat_scene == "chat_data" + assert db_entity.sub_chat_scene == "excel" + assert db_entity.prompt_type == "common" + assert db_entity.prompt_name == "my_prompt_1" + assert db_entity.content == "Write a qsort function in python." + assert db_entity.user_name == "zhangsan" + assert db_entity.sys_code == "dbgpt" + assert db_entity.gmt_created is not None + assert db_entity.gmt_modified is not None def test_entity_update(default_entity_dict): - entity: ServeEntity = ServeEntity.create(**default_entity_dict) - entity.update(prompt_name="my_prompt_2") - db_entity: ServeEntity = ServeEntity.get(entity.id) - assert db_entity.id == entity.id - assert db_entity.chat_scene == "chat_data" - assert db_entity.sub_chat_scene == "excel" - assert db_entity.prompt_type == "common" - assert db_entity.prompt_name == "my_prompt_2" - assert db_entity.content == "Write a qsort function in python." - assert db_entity.user_name == "zhangsan" - assert db_entity.sys_code == "dbgpt" - assert db_entity.gmt_created is not None - assert db_entity.gmt_modified is not None + with db.session() as session: + entity = ServeEntity(**default_entity_dict) + session.add(entity) + session.commit() + entity.prompt_name = "my_prompt_2" + session.merge(entity) + db_entity: ServeEntity = session.get(ServeEntity, entity.id) + assert db_entity.id == entity.id + assert db_entity.chat_scene == "chat_data" + assert db_entity.sub_chat_scene == "excel" + assert db_entity.prompt_type == "common" + assert db_entity.prompt_name == "my_prompt_2" + assert db_entity.content == "Write a qsort function in python." + assert db_entity.user_name == "zhangsan" + assert db_entity.sys_code == "dbgpt" + assert db_entity.gmt_created is not None + assert db_entity.gmt_modified is not None def test_entity_delete(default_entity_dict): - entity: ServeEntity = ServeEntity.create(**default_entity_dict) - entity.delete() - db_entity: ServeEntity = ServeEntity.get(entity.id) - assert db_entity is None + with db.session() as session: + entity = ServeEntity(**default_entity_dict) + session.add(entity) + session.commit() + session.delete(entity) + session.commit() + db_entity: ServeEntity = session.get(ServeEntity, entity.id) + assert db_entity is None def test_entity_all(): - for i in range(10): - ServeEntity.create( - chat_scene="chat_data", - sub_chat_scene="excel", - prompt_type="common", - prompt_name=f"my_prompt_{i}", - content="Write a qsort function in python.", - user_name="zhangsan", - sys_code="dbgpt", - ) - entities = ServeEntity.all() - assert len(entities) == 10 - for entity in entities: - assert entity.chat_scene == "chat_data" - assert entity.sub_chat_scene == "excel" - assert entity.prompt_type == "common" - assert entity.content == "Write a qsort function in python." - assert entity.user_name == "zhangsan" - assert entity.sys_code == "dbgpt" - assert entity.gmt_created is not None - assert entity.gmt_modified is not None + with db.session() as session: + for i in range(10): + entity = ServeEntity( + chat_scene="chat_data", + sub_chat_scene="excel", + prompt_type="common", + prompt_name=f"my_prompt_{i}", + content="Write a qsort function in python.", + user_name="zhangsan", + sys_code="dbgpt", + ) + session.add(entity) + with db.session() as session: + entities = session.query(ServeEntity).all() + assert len(entities) == 10 + for entity in entities: + assert entity.chat_scene == "chat_data" + assert entity.sub_chat_scene == "excel" + assert entity.prompt_type == "common" + assert entity.content == "Write a qsort function in python." + assert entity.user_name == "zhangsan" + assert entity.sys_code == "dbgpt" + assert entity.gmt_created is not None + assert entity.gmt_modified is not None def test_dao_create(dao, default_entity_dict): diff --git a/dbgpt/serve/prompt/tests/test_service.py b/dbgpt/serve/prompt/tests/test_service.py index 3992e89cb..98e896b00 100644 --- a/dbgpt/serve/prompt/tests/test_service.py +++ b/dbgpt/serve/prompt/tests/test_service.py @@ -75,7 +75,7 @@ def test_config_default_user(service: Service): def test_service_create(service: Service, default_entity_dict): entity: ServerResponse = service.create(ServeRequest(**default_entity_dict)) with db.session() as session: - db_entity: ServeEntity = session.query(ServeEntity).get(entity.id) + db_entity: ServeEntity = session.get(ServeEntity, entity.id) assert db_entity.id == entity.id assert db_entity.chat_scene == "chat_data" assert db_entity.sub_chat_scene == "excel" @@ -92,7 +92,7 @@ def test_service_update(service: Service, default_entity_dict): service.create(ServeRequest(**default_entity_dict)) entity: ServerResponse = service.update(ServeRequest(**default_entity_dict)) with db.session() as session: - db_entity: ServeEntity = session.query(ServeEntity).get(entity.id) + db_entity: ServeEntity = session.get(ServeEntity, entity.id) assert db_entity.id == entity.id assert db_entity.chat_scene == "chat_data" assert db_entity.sub_chat_scene == "excel" @@ -109,7 +109,7 @@ def test_service_get(service: Service, default_entity_dict): service.create(ServeRequest(**default_entity_dict)) entity: ServerResponse = service.get(ServeRequest(**default_entity_dict)) with db.session() as session: - db_entity: ServeEntity = session.query(ServeEntity).get(entity.id) + db_entity: ServeEntity = session.get(ServeEntity, entity.id) assert db_entity.id == entity.id assert db_entity.chat_scene == "chat_data" assert db_entity.sub_chat_scene == "excel" diff --git a/dbgpt/serve/utils/_template_files/default_serve_template/tests/test_models.py b/dbgpt/serve/utils/_template_files/default_serve_template/tests/test_models.py index 218307f1a..1d111644d 100644 --- a/dbgpt/serve/utils/_template_files/default_serve_template/tests/test_models.py +++ b/dbgpt/serve/utils/_template_files/default_serve_template/tests/test_models.py @@ -1,5 +1,3 @@ -from typing import List - import pytest from dbgpt.storage.metadata import db @@ -39,11 +37,9 @@ def test_table_exist(): def test_entity_create(default_entity_dict): - entity: ServeEntity = ServeEntity.create(**default_entity_dict) - # TODO: implement your test case with db.session() as session: - db_entity: ServeEntity = session.query(ServeEntity).get(entity.id) - assert db_entity.id == entity.id + entity = ServeEntity(**default_entity_dict) + session.add(entity) def test_entity_unique_key(default_entity_dict): @@ -52,10 +48,8 @@ def test_entity_unique_key(default_entity_dict): def test_entity_get(default_entity_dict): - entity: ServeEntity = ServeEntity.create(**default_entity_dict) - db_entity: ServeEntity = ServeEntity.get(entity.id) - assert db_entity.id == entity.id # TODO: implement your test case + pass def test_entity_update(default_entity_dict): @@ -65,10 +59,7 @@ def test_entity_update(default_entity_dict): def test_entity_delete(default_entity_dict): # TODO: implement your test case - entity: ServeEntity = ServeEntity.create(**default_entity_dict) - entity.delete() - db_entity: ServeEntity = ServeEntity.get(entity.id) - assert db_entity is None + pass def test_entity_all(): diff --git a/dbgpt/storage/chat_history/chat_history_db.py b/dbgpt/storage/chat_history/chat_history_db.py index 080ee7492..029abadbf 100644 --- a/dbgpt/storage/chat_history/chat_history_db.py +++ b/dbgpt/storage/chat_history/chat_history_db.py @@ -105,12 +105,6 @@ def raw_delete(self, conv_uid: int): chat_history = chat_history.filter(ChatHistoryEntity.conv_uid == conv_uid) chat_history.delete() - def get_by_uid(self, conv_uid: str) -> ChatHistoryEntity: - # return ChatHistoryEntity.query.filter_by(conv_uid=conv_uid).first() - - session = self.get_raw_session() - chat_history = session.query(ChatHistoryEntity) - chat_history = chat_history.filter(ChatHistoryEntity.conv_uid == conv_uid) - result = chat_history.first() - session.close() - return result + def get_by_uid(self, conv_uid: str) -> Optional[ChatHistoryEntity]: + with self.session(commit=False) as session: + return session.query(ChatHistoryEntity).filter_by(conv_uid=conv_uid).first() diff --git a/dbgpt/storage/metadata/_base_dao.py b/dbgpt/storage/metadata/_base_dao.py index 70f25197d..659e06b2e 100644 --- a/dbgpt/storage/metadata/_base_dao.py +++ b/dbgpt/storage/metadata/_base_dao.py @@ -51,7 +51,9 @@ def get_raw_session(self) -> Session: Example: + .. code-block:: python + user = User(name="Edward Snowden") session = self.get_raw_session() session.add(user) @@ -61,7 +63,7 @@ def get_raw_session(self) -> Session: return self._db_manager._session() @contextmanager - def session(self) -> Session: + def session(self, commit: Optional[bool] = True) -> Session: """Provide a transactional scope around a series of operations. If raise an exception, the session will be roll back automatically, otherwise it will be committed. @@ -71,13 +73,16 @@ def session(self) -> Session: with self.session() as session: session.query(User).filter(User.name == 'Edward Snowden').first() + Args: + commit (Optional[bool], optional): Whether to commit the session. Defaults to True. + Returns: Session: A session object. Raises: Exception: Any exception will be raised. """ - with self._db_manager.session() as session: + with self._db_manager.session(commit=commit) as session: yield session def from_request(self, request: QUERY_SPEC) -> T: diff --git a/dbgpt/storage/metadata/db_manager.py b/dbgpt/storage/metadata/db_manager.py index be8f0b1ec..eddd8f4e7 100644 --- a/dbgpt/storage/metadata/db_manager.py +++ b/dbgpt/storage/metadata/db_manager.py @@ -1,8 +1,15 @@ from __future__ import annotations -import abc from contextlib import contextmanager -from typing import TypeVar, Generic, Union, Dict, Optional, Type, Iterator, List +from typing import ( + TypeVar, + Generic, + Union, + Dict, + Optional, + Type, + ClassVar, +) import logging from sqlalchemy import create_engine, URL, Engine from sqlalchemy import orm, inspect, MetaData @@ -13,8 +20,6 @@ declarative_base, DeclarativeMeta, ) -from sqlalchemy.orm.session import _PKIdentityArgument -from sqlalchemy.orm.exc import UnmappedClassError from sqlalchemy.pool import QueuePool from dbgpt.util.string_utils import _to_str @@ -27,16 +32,10 @@ class _QueryObject: """The query object.""" - def __init__(self, db_manager: "DatabaseManager"): - self._db_manager = db_manager - - def __get__(self, obj, type): - try: - mapper = orm.class_mapper(type) - if mapper: - return type.query_class(mapper, session=self._db_manager._session()) - except UnmappedClassError: - return None + def __get__(self, obj: Union[_Model, None], model_cls: type[_Model]): + return model_cls.query_class( + model_cls, session=model_cls.__db_manager__._session() + ) class BaseQuery(orm.Query): @@ -46,7 +45,9 @@ def paginate_query( """Paginate the query. Example: + .. code-block:: python + from dbgpt.storage.metadata import db, Model class User(Model): __tablename__ = "user" @@ -58,10 +59,6 @@ class User(Model): pagination = session.query(User).paginate_query(page=1, page_size=10) print(pagination) - # Or you can use the query object - with db.session() as session: - pagination = User.query.paginate_query(page=1, page_size=10) - print(pagination) Args: page (Optional[int], optional): The page number. Defaults to 1. @@ -86,26 +83,12 @@ class User(Model): class _Model: - """Base class for SQLAlchemy declarative base model. - - With this class, we can use the query object to query the database. - - Examples: - .. code-block:: python - from dbgpt.storage.metadata import db, Model - class User(Model): - __tablename__ = "user" - id = Column(Integer, primary_key=True) - name = Column(String(50)) - fullname = Column(String(50)) + """Base class for SQLAlchemy declarative base model.""" - with db.session() as session: - # User is an instance of _Model, and we can use the query object to query the database. - User.query.filter(User.name == "test").all() - """ + __db_manager__: ClassVar[DatabaseManager] + query_class = BaseQuery - query_class = None - query: Optional[BaseQuery] = None + # query: Optional[BaseQuery] = _QueryObject() def __repr__(self): identity = inspect(self).identity @@ -120,7 +103,9 @@ class DatabaseManager: """The database manager. Examples: + .. code-block:: python + from urllib.parse import quote_plus as urlquote, quote from dbgpt.storage.metadata import DatabaseManager, create_model db = DatabaseManager() @@ -141,21 +126,25 @@ class User(Model): session.add(User(name="test", fullname="test")) # db will commit the session automatically default. # session.commit() - print(User.query.filter(User.name == "test").all()) + assert session.query(User).filter(User.name == "test").first().name == "test" + + # More usage: - # Use CURDMixin APIs to create, update, delete, query the database. with db.session() as session: - User.create(**{"name": "test1", "fullname": "test1"}) - User.create(**{"name": "test2", "fullname": "test1"}) - users = User.all() + session.add(User(name="test1", fullname="test1")) + session.add(User(name="test2", fullname="test1")) + users = session.query(User).all() print(users) user = users[0] - user.update(**{"name": "test1_1111"}) + user.name = "test1_1111" + session.merge(user) + user2 = users[1] # Update user2 by save user2.name = "test2_1111" - user2.save() + session.merge(user2) + session.commit() # Delete user2 user2.delete() """ @@ -189,28 +178,65 @@ def is_initialized(self) -> bool: return self._engine is not None and self._session is not None @contextmanager - def session(self) -> Session: + def session(self, commit: Optional[bool] = True) -> Session: """Get the session with context manager. - If raise any exception, the session will roll back automatically, otherwise, the session will commit automatically. + This context manager handles the lifecycle of a SQLAlchemy session. + It automatically commits or rolls back transactions based on + the execution and handles session closure. - Example: - >>> with db.session() as session: - >>> session.query(...) + The `commit` parameter controls whether the session should commit + changes at the end of the block. This is useful for separating + read and write operations. - Returns: - Session: The session. + Examples: + + .. code-block:: python + + # For write operations (insert, update, delete): + with db.session() as session: + user = User(name="John Doe") + session.add(user) + # session.commit() is called automatically + + # For read-only operations: + with db.session(commit=False) as session: + user = session.query(User).filter_by(name="John Doe").first() + # session.commit() is NOT called, as it's unnecessary for read operations + + Args: + commit (Optional[bool], optional): Whether to commit the session. + If True (default), the session will commit changes at the end + of the block. Use False for read-only operations or when manual + control over commit is needed. Defaults to True. + + Yields: + Session: The SQLAlchemy session object. Raises: - RuntimeError: The database manager is not initialized. - Exception: Any exception. + RuntimeError: Raised if the database manager is not initialized. + Exception: Propagates any exception that occurred within the block. + + Important Notes: + - DetachedInstanceError: This error occurs when trying to access or + modify an instance that has been detached from its session. + DetachedInstanceError can occur in scenarios where the session is + closed, and further interaction with the ORM object is attempted, + especially when accessing lazy-loaded attributes. To avoid this: + a. Ensure required attributes are loaded before session closure. + b. Avoid closing the session before all necessary interactions + with the ORM object are complete. + c. Re-bind the instance to a new session if further interaction + is required after the session is closed. + """ if not self.is_initialized: raise RuntimeError("The database manager is not initialized.") session = self._session() try: yield session - session.commit() + if commit: + session.commit() except: session.rollback() raise @@ -223,7 +249,7 @@ def _make_declarative_base( """Make the declarative base. Args: - base (DeclarativeMeta): The base class. + model (DeclarativeMeta): The base class. Returns: DeclarativeMeta: The declarative base. @@ -232,7 +258,8 @@ def _make_declarative_base( model = declarative_base(cls=model, name="Model") if not getattr(model, "query_class", None): model.query_class = self.Query - model.query = _QueryObject(self) + # model.query = _QueryObject() + model.__db_manager__ = self return model def init_db( @@ -242,6 +269,7 @@ def init_db( base: Optional[DeclarativeMeta] = None, query_class=BaseQuery, override_query_class: Optional[bool] = False, + session_options: Optional[Dict] = None, ): """Initialize the database manager. @@ -251,18 +279,26 @@ def init_db( base (Optional[DeclarativeMeta]): The base class. Defaults to None. query_class (BaseQuery, optional): The query class. Defaults to BaseQuery. override_query_class (Optional[bool], optional): Whether to override the query class. Defaults to False. + session_options (Optional[Dict], optional): The session options. Defaults to None. """ + if session_options is None: + session_options = {} self._db_url = db_url if query_class is not None: self.Query = query_class if base is not None: self._base = base - if not hasattr(base, "query") or override_query_class: - base.query = _QueryObject(self) + # if not hasattr(base, "query") or override_query_class: + # base.query = _QueryObject() if not getattr(base, "query_class", None) or override_query_class: base.query_class = self.Query + if not hasattr(base, "__db_manager__") or override_query_class: + base.__db_manager__ = self self._engine = create_engine(db_url, **(engine_args or {})) - session_factory = sessionmaker(bind=self._engine) + + session_options.setdefault("class_", Session) + session_options.setdefault("query_cls", self.Query) + session_factory = sessionmaker(bind=self._engine, **session_options) self._session = scoped_session(session_factory) self._base.metadata.bind = self._engine @@ -397,35 +433,12 @@ class BaseCRUDMixin(Generic[T]): __abstract__ = True @classmethod - def create(cls: Type[T], **kwargs) -> T: - instance = cls(**kwargs) - return instance.save() - - @classmethod - def all(cls: Type[T]) -> List[T]: - return cls.query.all() - - @classmethod - def get(cls: Type[T], ident: _PKIdentityArgument) -> Optional[T]: - """Get a record by its primary key identifier.""" - - def update(self: T, commit: Optional[bool] = True, **kwargs) -> T: - """Update specific fields of a record.""" - for attr, value in kwargs.items(): - setattr(self, attr, value) - return commit and self.save() or self - - @abc.abstractmethod - def save(self: T, commit: Optional[bool] = True) -> T: - """Save the record.""" - - @abc.abstractmethod - def delete(self: T, commit: Optional[bool] = True) -> None: - """Remove the record from the database.""" + def db(cls) -> DatabaseManager: + """Get the database manager.""" + return cls.__db_manager__ class BaseModel(BaseCRUDMixin[T], _Model, Generic[T]): - """The base model class that includes CRUD convenience methods.""" __abstract__ = True @@ -438,28 +451,14 @@ class CRUDMixin(BaseCRUDMixin[T], Generic[T]): _db_manager: DatabaseManager = db_manager @classmethod - def set_db_manager(cls, db_manager: DatabaseManager): + def set_db(cls, db_manager: DatabaseManager): # TODO: It is hard to replace to user DB Connection cls._db_manager = db_manager @classmethod - def get(cls: Type[T], ident: _PKIdentityArgument) -> Optional[T]: - """Get a record by its primary key identifier.""" - return cls._db_manager._session().get(cls, ident) - - def save(self: T, commit: Optional[bool] = True) -> T: - """Save the record.""" - session = self._db_manager._session() - session.add(self) - if commit: - session.commit() - return self - - def delete(self: T, commit: Optional[bool] = True) -> None: - """Remove the record from the database.""" - session = self._db_manager._session() - session.delete(self) - return commit and session.commit() + def db(cls) -> DatabaseManager: + """Get the database manager.""" + return cls._db_manager class _NewModel(CRUDMixin[T], db_manager.Model, Generic[T]): """Base model class that includes CRUD convenience methods.""" @@ -478,6 +477,7 @@ def initialize_db( engine_args: Optional[Dict] = None, base: Optional[DeclarativeMeta] = None, try_to_create_db: Optional[bool] = False, + session_options: Optional[Dict] = None, ) -> DatabaseManager: """Initialize the database manager. @@ -487,10 +487,11 @@ def initialize_db( engine_args (Optional[Dict], optional): The engine arguments. Defaults to None. base (Optional[DeclarativeMeta]): The base class. Defaults to None. try_to_create_db (Optional[bool], optional): Whether to try to create the database. Defaults to False. + session_options (Optional[Dict], optional): The session options. Defaults to None. Returns: DatabaseManager: The database manager. """ - db.init_db(db_url, engine_args, base) + db.init_db(db_url, engine_args, base, session_options=session_options) if try_to_create_db: try: db.create_all() diff --git a/dbgpt/storage/metadata/tests/test_base_dao.py b/dbgpt/storage/metadata/tests/test_base_dao.py index a537188f2..9728fd482 100644 --- a/dbgpt/storage/metadata/tests/test_base_dao.py +++ b/dbgpt/storage/metadata/tests/test_base_dao.py @@ -100,7 +100,7 @@ def test_update_user(db: DatabaseManager, User: Type[BaseModel], user_dao, user_ # Verify that the user is updated in the database with db.session() as session: - user = session.query(User).get(created_user_response.id) + user = session.get(User, created_user_response.id) assert user.age == 35 @@ -121,7 +121,7 @@ def test_update_user_partial( # Verify that the user is updated in the database with db.session() as session: - user = session.query(User).get(created_user_response.id) + user = session.get(User, created_user_response.id) assert user.age == user_req.age assert user.password == "newpassword" diff --git a/dbgpt/storage/metadata/tests/test_db_manager.py b/dbgpt/storage/metadata/tests/test_db_manager.py index a6ad24caa..aa67787fa 100644 --- a/dbgpt/storage/metadata/tests/test_db_manager.py +++ b/dbgpt/storage/metadata/tests/test_db_manager.py @@ -53,11 +53,10 @@ class User(Model): # Create with db.session() as session: - user = User.create(name="John Doe") + user = User(name="John Doe") session.add(user) - session.commit() - # Read + # # Read with db.session() as session: user = session.query(User).filter_by(name="John Doe").first() assert user is not None @@ -65,12 +64,20 @@ class User(Model): # Update with db.session() as session: user = session.query(User).filter_by(name="John Doe").first() - user.update(name="Jane Doe") + user.name = "Mike Doe" + session.merge(user) + with db.session() as session: + user = session.query(User).filter_by(name="Mike Doe").first() + assert user is not None + session.query(User).filter(User.name == "John Doe").first() is None + # + # # Delete + with db.session() as session: + user = session.query(User).filter_by(name="Mike Doe").first() + session.delete(user) - # Delete with db.session() as session: - user = session.query(User).filter_by(name="Jane Doe").first() - user.delete() + assert len(session.query(User).all()) == 0 def test_crud_mixins(db: DatabaseManager, Model: Type[BaseModel]): @@ -80,20 +87,7 @@ class User(Model): name = Column(String(50)) db.create_all() - - # Create - user = User.create(name="John Doe") - assert User.get(user.id) is not None - users = User.all() - assert len(users) == 1 - - # Update - user.update(name="Bob Doe") - assert User.get(user.id).name == "Bob Doe" - - user = User.get(user.id) - user.delete() - assert User.get(user.id) is None + User.db() == db def test_pagination_query(db: DatabaseManager, Model: Type[BaseModel]): @@ -108,11 +102,10 @@ class User(Model): for i in range(30): user = User(name=f"User {i}") session.add(user) - session.commit() - - users_page_1 = User.query.paginate_query(page=1, per_page=10) - assert len(users_page_1.items) == 10 - assert users_page_1.total_pages == 3 + with db.session() as session: + users_page_1 = session.query(User).paginate_query(page=1, per_page=10) + assert len(users_page_1.items) == 10 + assert users_page_1.total_pages == 3 def test_invalid_pagination(db: DatabaseManager, Model: Type[BaseModel]): @@ -124,9 +117,11 @@ class User(Model): db.create_all() with pytest.raises(ValueError): - User.query.paginate_query(page=0, per_page=10) + with db.session() as session: + session.query(User).paginate_query(page=0, per_page=10) with pytest.raises(ValueError): - User.query.paginate_query(page=1, per_page=-1) + with db.session() as session: + session.query(User).paginate_query(page=1, per_page=-1) def test_set_model_db_manager(db: DatabaseManager, Model: Type[BaseModel]): @@ -142,14 +137,19 @@ class User(Model): new_db = DatabaseManager.build_from( f"sqlite:///{filename}", base=Model, override_query_class=True ) - Model.set_db_manager(new_db) + Model.set_db(new_db) new_db.create_all() db.create_all() assert list(new_db.metadata.tables.keys())[0] == "user" - User.create(**{"name": "John Doe"}) + with new_db.session() as session: + user = User(name="John Doe") + session.add(user) with new_db.session() as session: assert session.query(User).filter_by(name="John Doe").first() is not None with db.session() as session: assert session.query(User).filter_by(name="John Doe").first() is None - assert len(User.query.all()) == 1 - assert User.query.filter(User.name == "John Doe").first().name == "John Doe" + with new_db.session() as session: + session.query(User).all() == 1 + session.query(User).filter( + User.name == "John Doe" + ).first().name == "John Doe" diff --git a/examples/awel/simple_llm_client_example.py b/examples/awel/simple_llm_client_example.py index 5751221c9..ee7f4292a 100644 --- a/examples/awel/simple_llm_client_example.py +++ b/examples/awel/simple_llm_client_example.py @@ -7,7 +7,7 @@ Call with non-streaming response. .. code-block:: shell - DBGPT_SERVER="http://127.0.0.1:5000" + DBGPT_SERVER="http://127.0.0.1:5555" curl -X POST $DBGPT_SERVER/api/v1/awel/trigger/examples/simple_client/chat/completions \ -H "Content-Type: application/json" -d '{ "model": "proxyllm",