refactor: Refactor for core SDK #1092

Merged · 4 commits · Jan 21, 2024
Changes from 3 commits
13 changes: 13 additions & 0 deletions Makefile
@@ -81,6 +81,19 @@ clean: ## Clean up the environment
find . -type d -name '.pytest_cache' -delete
find . -type d -name '.coverage' -delete

.PHONY: clean-dist
clean-dist: ## Clean up the distribution
rm -rf dist/ *.egg-info build/

.PHONY: package
package: clean-dist ## Package the project for distribution
IS_DEV_MODE=false python setup.py sdist bdist_wheel

.PHONY: upload
upload: package ## Upload the package to PyPI
# upload to testpypi: twine upload --repository testpypi dist/*
twine upload dist/*

.PHONY: help
help: ## Display this help screen
@echo "Available commands:"
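Note: the new `package` target invokes setup.py with IS_DEV_MODE=false. A minimal sketch of how setup.py might read that flag (only the variable name comes from the Makefile above; the parsing logic is an assumption, not the PR's actual setup.py):

import os

# Assumption: setup.py toggles dev vs. release packaging on this environment
# variable; only the name IS_DEV_MODE is taken from the Makefile target above.
IS_DEV_MODE = os.getenv("IS_DEV_MODE", "true").lower() == "true"
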
12 changes: 9 additions & 3 deletions dbgpt/__init__.py
@@ -1,12 +1,18 @@
from dbgpt.component import BaseComponent, SystemApp

__ALL__ = ["SystemApp", "BaseComponent"]
"""DB-GPT: Next Generation Data Interaction Solution with LLMs.
"""
from dbgpt import _version # noqa: E402
from dbgpt.component import BaseComponent, SystemApp # noqa: F401

_CORE_LIBS = ["core", "rag", "model", "agent", "datasource", "vis", "storage", "train"]
_SERVE_LIBS = ["serve"]
_LIBS = _CORE_LIBS + _SERVE_LIBS


__version__ = _version.version

__ALL__ = ["__version__", "SystemApp", "BaseComponent"]


def __getattr__(name: str):
# Lazy load
import importlib
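Note: the hunk is cut off just after `import importlib`. Given the `_LIBS` list above, the module-level `__getattr__` (PEP 562) presumably continues along these lines (a sketch, not the verbatim PR body):

def __getattr__(name: str):
    # Lazy load: import dbgpt.<name> only on first attribute access.
    import importlib

    if name in _LIBS:
        return importlib.import_module("." + name, __name__)
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
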
1 change: 1 addition & 0 deletions dbgpt/_version.py
@@ -0,0 +1 @@
version = "0.4.7"
@@ -579,7 +579,7 @@ def _num_token_from_text(self, text: str, model: str = "gpt-3.5-turbo-0613"):
from dbgpt.agent.agents.agent import AgentContext
from dbgpt.agent.agents.user_proxy_agent import UserProxyAgent
from dbgpt.core.interface.llm import ModelMetadata
from dbgpt.model import OpenAILLMClient
from dbgpt.model.proxy import OpenAILLMClient

llm_client = OpenAILLMClient()
context: AgentContext = AgentContext(
4 changes: 2 additions & 2 deletions dbgpt/app/chat_adapter.py
@@ -9,7 +9,7 @@
from typing import Dict, List, Tuple

from dbgpt.core.interface.message import ModelMessage, ModelMessageRoleType
from dbgpt.model.conversation import Conversation, get_conv_template
from dbgpt.model.llm.conversation import Conversation, get_conv_template


class BaseChatAdpter:
@@ -21,7 +21,7 @@ def match(self, model_path: str):

def get_generate_stream_func(self, model_path: str):
"""Return the generate stream handler func"""
from dbgpt.model.inference import generate_stream
from dbgpt.model.llm.inference import generate_stream

return generate_stream

4 changes: 4 additions & 0 deletions dbgpt/component.py
@@ -1,3 +1,7 @@
"""Component module for dbgpt.

Manages the lifecycle and registration of components.
"""
from __future__ import annotations

import asyncio
18 changes: 9 additions & 9 deletions dbgpt/core/awel/trigger/http_trigger.py
@@ -1,12 +1,8 @@
"""Http trigger for AWEL."""
from __future__ import annotations

import logging
from enum import Enum
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Type, Union, cast

from starlette.requests import Request

from dbgpt._private.pydantic import BaseModel

from ..dag.base import DAG
@@ -15,9 +11,10 @@

if TYPE_CHECKING:
from fastapi import APIRouter
from starlette.requests import Request

RequestBody = Union[Type[Request], Type[BaseModel], Type[str]]
StreamingPredictFunc = Callable[[Union[Request, BaseModel]], bool]
RequestBody = Union[Type[Request], Type[BaseModel], Type[str]]
StreamingPredictFunc = Callable[[Union[Request, BaseModel]], bool]

logger = logging.getLogger(__name__)

@@ -32,9 +29,9 @@ def __init__(
self,
endpoint: str,
methods: Optional[Union[str, List[str]]] = "GET",
request_body: Optional[RequestBody] = None,
request_body: Optional["RequestBody"] = None,
streaming_response: bool = False,
streaming_predict_func: Optional[StreamingPredictFunc] = None,
streaming_predict_func: Optional["StreamingPredictFunc"] = None,
response_model: Optional[Type] = None,
response_headers: Optional[Dict[str, str]] = None,
response_media_type: Optional[str] = None,
@@ -69,6 +66,7 @@ def mount_to_router(self, router: "APIRouter") -> None:
router (APIRouter): The router to mount the trigger.
"""
from fastapi import Depends
from starlette.requests import Request

methods = [self._methods] if isinstance(self._methods, str) else self._methods

@@ -114,8 +112,10 @@ async def route_function(body=Depends(_request_body_dependency)):


async def _parse_request_body(
request: Request, request_body_cls: Optional[RequestBody]
request: "Request", request_body_cls: Optional["RequestBody"]
):
from starlette.requests import Request

if not request_body_cls:
return None
if request_body_cls == Request:
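Note: the change throughout this file is one idiom: starlette's Request is imported under TYPE_CHECKING for annotations only, with a local import at each runtime use site. A self-contained sketch of the same idiom (generic example, not AWEL code):

from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Visible to type checkers only; no runtime dependency on starlette here.
    from starlette.requests import Request


async def parse_body(request: "Request"):
    # Deferred import: starlette is loaded only when the handler actually runs.
    from starlette.requests import Request

    if isinstance(request, Request):
        return await request.json()
    return request
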
17 changes: 10 additions & 7 deletions dbgpt/model/__init__.py
@@ -1,9 +1,12 @@
from dbgpt.model.cluster.client import DefaultLLMClient
try:
from dbgpt.model.cluster.client import DefaultLLMClient
except ImportError as exc:
# logging.warning("Can't import dbgpt.model.DefaultLLMClient")
DefaultLLMClient = None

# from dbgpt.model.utils.chatgpt_utils import OpenAILLMClient
from dbgpt.model.proxy.llms.chatgpt import OpenAILLMClient

__ALL__ = [
"DefaultLLMClient",
"OpenAILLMClient",
]
_exports = []
if DefaultLLMClient:
_exports.append("DefaultLLMClient")

__ALL__ = _exports
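
Note: after this change, `from dbgpt.model import DefaultLLMClient` no longer raises when the cluster dependencies are missing; it yields None instead. A hypothetical consumer-side guard (the fallback import mirrors the OpenAILLMClient line above):

from dbgpt.model import DefaultLLMClient

if DefaultLLMClient is None:
    # Cluster extras unavailable; fall back to the proxy client.
    from dbgpt.model.proxy.llms.chatgpt import OpenAILLMClient as ClientClass
else:
    ClientClass = DefaultLLMClient
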
2 changes: 1 addition & 1 deletion dbgpt/model/loader.py → dbgpt/model/adapter/loader.py
@@ -137,7 +137,7 @@ def loader_with_params(
def huggingface_loader(llm_adapter: LLMModelAdapter, model_params: ModelParameters):
import torch

from dbgpt.model.compression import compress_module
from dbgpt.model.llm.compression import compress_module

device = model_params.device
max_memory = None
2 changes: 1 addition & 1 deletion dbgpt/model/adapter/old_adapter.py
@@ -19,7 +19,7 @@
from dbgpt.model.adapter.base import LLMModelAdapter
from dbgpt.model.adapter.template import ConversationAdapter, PromptType
from dbgpt.model.base import ModelType
from dbgpt.model.conversation import Conversation
from dbgpt.model.llm.conversation import Conversation
from dbgpt.model.parameter import (
LlamaCppModelParameters,
ModelParameters,
4 changes: 1 addition & 3 deletions dbgpt/model/base.py
@@ -1,13 +1,11 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import time
from dataclasses import asdict, dataclass
from datetime import datetime
from enum import Enum
from typing import Any, Dict, List, Optional, TypedDict
from typing import Dict, List, Optional

from dbgpt.util.model_utils import GPUInfo
from dbgpt.util.parameter_utils import ParameterDescription


2 changes: 1 addition & 1 deletion dbgpt/model/cluster/worker/default_worker.py
@@ -12,9 +12,9 @@
ModelOutput,
)
from dbgpt.model.adapter.base import LLMModelAdapter
from dbgpt.model.adapter.loader import ModelLoader, _get_model_real_path
from dbgpt.model.adapter.model_adapter import get_llm_model_adapter
from dbgpt.model.cluster.worker_base import ModelWorker
from dbgpt.model.loader import ModelLoader, _get_model_real_path
from dbgpt.model.parameter import ModelParameters
from dbgpt.util.model_utils import _clear_model_cache, _get_current_cuda_memory
from dbgpt.util.parameter_utils import EnvArgumentParser, _get_dict_from_obj
2 changes: 1 addition & 1 deletion dbgpt/model/cluster/worker/embedding_worker.py
@@ -3,9 +3,9 @@

from dbgpt.configs.model_config import get_device
from dbgpt.core import ModelMetadata
from dbgpt.model.adapter.loader import _get_model_real_path
from dbgpt.model.cluster.embedding.loader import EmbeddingLoader
from dbgpt.model.cluster.worker_base import ModelWorker
from dbgpt.model.loader import _get_model_real_path
from dbgpt.model.parameter import (
EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG,
BaseEmbeddingModelParameters,
13 changes: 4 additions & 9 deletions dbgpt/model/cluster/worker/manager.py
@@ -8,20 +8,15 @@
import time
from concurrent.futures import ThreadPoolExecutor
from dataclasses import asdict
from typing import Awaitable, Callable, Dict, Iterator, List
from typing import Awaitable, Callable, Iterator

from fastapi import APIRouter, FastAPI
from fastapi.responses import StreamingResponse

from dbgpt.component import SystemApp
from dbgpt.configs.model_config import LOGDIR
from dbgpt.core import ModelMetadata, ModelOutput
from dbgpt.model.base import (
ModelInstance,
WorkerApplyOutput,
WorkerApplyType,
WorkerSupportedModel,
)
from dbgpt.model.base import ModelInstance, WorkerApplyOutput, WorkerSupportedModel
from dbgpt.model.cluster.base import *
from dbgpt.model.cluster.manager_base import (
WorkerManager,
@@ -30,8 +25,8 @@
)
from dbgpt.model.cluster.registry import ModelRegistry
from dbgpt.model.cluster.worker_base import ModelWorker
from dbgpt.model.llm_utils import list_supported_models
from dbgpt.model.parameter import ModelParameters, ModelWorkerParameters, WorkerType
from dbgpt.model.parameter import ModelWorkerParameters, WorkerType
from dbgpt.model.utils.llm_utils import list_supported_models
from dbgpt.util.parameter_utils import (
EnvArgumentParser,
ParameterDescription,
File renamed without changes.
File renamed without changes.
2 changes: 1 addition & 1 deletion dbgpt/model/inference.py → dbgpt/model/llm/inference.py
@@ -18,7 +18,7 @@
TopPLogitsWarper,
)

from dbgpt.model.llm_utils import is_partial_stop, is_sentence_complete
from dbgpt.model.utils.llm_utils import is_partial_stop, is_sentence_complete


def prepare_logits_processor(
4 changes: 2 additions & 2 deletions dbgpt/model/operator/__init__.py
@@ -1,9 +1,9 @@
from dbgpt.model.operator.llm_operator import (
from dbgpt.model.operator.llm_operator import ( # noqa: F401
LLMOperator,
MixinLLMOperator,
StreamingLLMOperator,
)
from dbgpt.model.utils.chatgpt_utils import OpenAIStreamingOutputOperator
from dbgpt.model.utils.chatgpt_utils import OpenAIStreamingOutputOperator # noqa: F401

__ALL__ = [
"MixinLLMOperator",
36 changes: 17 additions & 19 deletions dbgpt/model/operator/llm_operator.py
@@ -6,7 +6,6 @@
from dbgpt.core import LLMClient
from dbgpt.core.awel import BaseOperator
from dbgpt.core.operator import BaseLLM, BaseLLMOperator, BaseStreamingLLMOperator
from dbgpt.model.cluster import WorkerManagerFactory

logger = logging.getLogger(__name__)

@@ -19,31 +18,30 @@ class MixinLLMOperator(BaseLLM, BaseOperator, ABC):

def __init__(self, default_client: Optional[LLMClient] = None, **kwargs):
super().__init__(default_client)
self._default_llm_client = default_client

@property
def llm_client(self) -> LLMClient:
if not self._llm_client:
worker_manager_factory: WorkerManagerFactory = (
self.system_app.get_component(
ComponentType.WORKER_MANAGER_FACTORY,
WorkerManagerFactory,
default_component=None,
)
)
if worker_manager_factory:
try:
from dbgpt.model.cluster import WorkerManagerFactory
from dbgpt.model.cluster.client import DefaultLLMClient

self._llm_client = DefaultLLMClient(worker_manager_factory.create())
else:
if self._default_llm_client is None:
from dbgpt.model.proxy.llms.chatgpt import OpenAILLMClient

self._default_llm_client = OpenAILLMClient()
logger.info(
f"Can't find worker manager factory, use default llm client {self._default_llm_client}."
worker_manager_factory: WorkerManagerFactory = (
self.system_app.get_component(
ComponentType.WORKER_MANAGER_FACTORY,
WorkerManagerFactory,
default_component=None,
)
)
self._llm_client = self._default_llm_client
if worker_manager_factory:
self._llm_client = DefaultLLMClient(worker_manager_factory.create())
except Exception as e:
logger.warning(f"Load worker manager failed: {e}.")
if not self._llm_client:
from dbgpt.model.proxy.llms.chatgpt import OpenAILLMClient

logger.info("Can't find worker manager factory, use OpenAILLMClient.")
self._llm_client = OpenAILLMClient()
return self._llm_client


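Note: the rewritten `llm_client` property wraps the whole worker-manager lookup in try/except and only constructs OpenAILLMClient if no cluster client was resolved. The control flow, reduced to a standalone sketch (function and parameter names here are illustrative, not part of the PR):

import logging
from typing import Callable, Optional, TypeVar

logger = logging.getLogger(__name__)
T = TypeVar("T")


def resolve_client(primary: Callable[[], Optional[T]], fallback: Callable[[], T]) -> T:
    # Prefer the cluster-backed client; on any failure, log and fall back.
    client: Optional[T] = None
    try:
        client = primary()
    except Exception as e:
        logger.warning(f"Load worker manager failed: {e}.")
    if client is None:
        logger.info("Can't find worker manager factory, use OpenAILLMClient.")
        client = fallback()
    return client
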
18 changes: 10 additions & 8 deletions dbgpt/model/parameter.py
@@ -6,11 +6,8 @@
from enum import Enum
from typing import Dict, Optional, Tuple, Union

from dbgpt.model.conversation import conv_templates
from dbgpt.util.parameter_utils import BaseParameters

suported_prompt_templates = ",".join(conv_templates.keys())


class WorkerType(str, Enum):
LLM = "llm"
@@ -299,7 +296,8 @@ class ModelParameters(BaseModelParameters):
prompt_template: Optional[str] = field(
default=None,
metadata={
"help": f"Prompt template. If None, the prompt template is automatically determined from model path, supported template: {suported_prompt_templates}"
"help": f"Prompt template. If None, the prompt template is automatically "
f"determined from model path"
},
)
max_context_size: Optional[int] = field(
@@ -450,7 +448,8 @@ class ProxyModelParameters(BaseModelParameters):
proxyllm_backend: Optional[str] = field(
default=None,
metadata={
"help": "The model name actually pass to current proxy server url, such as gpt-3.5-turbo, gpt-4, chatglm_pro, chatglm_std and so on"
"help": "The model name actually pass to current proxy server url, such "
"as gpt-3.5-turbo, gpt-4, chatglm_pro, chatglm_std and so on"
},
)
model_type: Optional[str] = field(
@@ -463,13 +462,15 @@
device: Optional[str] = field(
default=None,
metadata={
"help": "Device to run model. If None, the device is automatically determined"
"help": "Device to run model. If None, the device is automatically "
"determined"
},
)
prompt_template: Optional[str] = field(
default=None,
metadata={
"help": f"Prompt template. If None, the prompt template is automatically determined from model path, supported template: {suported_prompt_templates}"
"help": f"Prompt template. If None, the prompt template is automatically "
f"determined from model path"
},
)
max_context_size: Optional[int] = field(
@@ -478,7 +479,8 @@
llm_client_class: Optional[str] = field(
default=None,
metadata={
"help": "The class name of llm client, such as dbgpt.model.proxy.llms.proxy_model.ProxyModel"
"help": "The class name of llm client, such as "
"dbgpt.model.proxy.llms.proxy_model.ProxyModel"
},
)

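Note: the long help strings above are wrapped with implicit string-literal concatenation: adjacent literals are joined at compile time, so each split pair still forms a single message (hence the trailing space before each break). For example:

# Adjacent string literals concatenate into one value at compile time.
msg = (
    "Device to run model. If None, the device is automatically "
    "determined"
)
assert msg == "Device to run model. If None, the device is automatically determined"
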
@@ -37,8 +37,8 @@ def list_supported_models():
def _list_supported_models(
worker_type: str, model_config: Dict[str, str]
) -> List[SupportedModel]:
from dbgpt.model.adapter.loader import _get_model_real_path
from dbgpt.model.adapter.model_adapter import get_llm_model_adapter
from dbgpt.model.loader import _get_model_real_path

ret = []
for model_name, model_path in model_config.items():