Skip to content

Commit

Permalink
refactor: Refactor datasource module
Browse files Browse the repository at this point in the history
  • Loading branch information
fangyinc committed Mar 18, 2024
1 parent 84bedee commit 15aa9de
Show file tree
Hide file tree
Showing 108 changed files with 1,187 additions and 1,059 deletions.
19 changes: 13 additions & 6 deletions .mypy.ini
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,6 @@ exclude = /tests/
[mypy-dbgpt.app.*]
follow_imports = skip

[mypy-dbgpt.datasource.*]
follow_imports = skip

# [mypy-dbgpt.storage.*]
# follow_imports = skip

[mypy-dbgpt.serve.*]
follow_imports = skip

Expand Down Expand Up @@ -74,3 +68,16 @@ ignore_missing_imports = True

[mypy-cryptography.*]
ignore_missing_imports = True

# Datasource
[mypy-pyspark.*]
ignore_missing_imports = True

[mypy-regex.*]
ignore_missing_imports = True

[mypy-sqlparse.*]
ignore_missing_imports = True

[mypy-clickhouse_connect.*]
ignore_missing_imports = True
10 changes: 3 additions & 7 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -48,9 +48,7 @@ fmt: setup ## Format Python code
$(VENV_BIN)/blackdoc examples
# TODO: Use flake8 to enforce Python style guide.
# https://flake8.pycqa.org/en/latest/
$(VENV_BIN)/flake8 dbgpt/core/
$(VENV_BIN)/flake8 dbgpt/rag/
$(VENV_BIN)/flake8 dbgpt/storage/
$(VENV_BIN)/flake8 dbgpt/core/ dbgpt/rag/ dbgpt/storage/ dbgpt/datasource/
# TODO: More package checks with flake8.

.PHONY: fmt-check
Expand All @@ -59,9 +57,7 @@ fmt-check: setup ## Check Python code formatting and style without making change
$(VENV_BIN)/isort --check-only --extend-skip="examples/notebook" examples
$(VENV_BIN)/black --check --extend-exclude="examples/notebook" .
$(VENV_BIN)/blackdoc --check dbgpt examples
$(VENV_BIN)/flake8 dbgpt/core/
$(VENV_BIN)/flake8 dbgpt/rag/
$(VENV_BIN)/flake8 dbgpt/storage/
$(VENV_BIN)/flake8 dbgpt/core/ dbgpt/rag/ dbgpt/storage/ dbgpt/datasource/

.PHONY: pre-commit
pre-commit: fmt-check test test-doc mypy ## Run formatting and unit tests before committing
Expand All @@ -77,7 +73,7 @@ test-doc: $(VENV)/.testenv ## Run doctests
.PHONY: mypy
mypy: $(VENV)/.testenv ## Run mypy checks
# https://github.com/python/mypy
$(VENV_BIN)/mypy --config-file .mypy.ini dbgpt/rag/
$(VENV_BIN)/mypy --config-file .mypy.ini dbgpt/rag/ dbgpt/datasource/
# rag depends on core and storage, so we not need to check it again.
# $(VENV_BIN)/mypy --config-file .mypy.ini dbgpt/storage/
# $(VENV_BIN)/mypy --config-file .mypy.ini dbgpt/core/
Expand Down
11 changes: 9 additions & 2 deletions dbgpt/_private/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from auto_gpt_plugin_template import AutoGPTPluginTemplate

from dbgpt.component import SystemApp
from dbgpt.datasource.manages import ConnectorManager


class Config(metaclass=Singleton):
Expand Down Expand Up @@ -185,8 +186,6 @@ def __init__(self) -> None:
os.getenv("NATIVE_SQL_CAN_RUN_WRITE", "True").lower() == "true"
)

self.LOCAL_DB_MANAGE = None

###dbgpt meta info database connection configuration
self.LOCAL_DB_HOST = os.getenv("LOCAL_DB_HOST")
self.LOCAL_DB_PATH = os.getenv("LOCAL_DB_PATH", "data/default_sqlite.db")
Expand Down Expand Up @@ -287,3 +286,11 @@ def __init__(self) -> None:
self.MODEL_CACHE_STORAGE_DISK_DIR: Optional[str] = os.getenv(
"MODEL_CACHE_STORAGE_DISK_DIR"
)

@property
def local_db_manager(self) -> "ConnectorManager":
from dbgpt.datasource.manages import ConnectorManager

if not self.SYSTEM_APP:
raise ValueError("SYSTEM_APP is not set")
return ConnectorManager.get_instance(self.SYSTEM_APP)
7 changes: 0 additions & 7 deletions dbgpt/app/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,15 +70,8 @@ def server_init(param: "WebServerParameters", system_app: SystemApp):


def _create_model_start_listener(system_app: SystemApp):
from dbgpt.datasource.manages.connection_manager import ConnectManager

cfg = Config()

def startup_event(wh):
# init connect manage
print("begin run _add_app_startup_event")
conn_manage = ConnectManager(system_app)
cfg.LOCAL_DB_MANAGE = conn_manage
async_db_summary(system_app)

return startup_event
Expand Down
2 changes: 2 additions & 0 deletions dbgpt/app/component_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ def initialize_components(
from dbgpt.app.initialization.embedding_component import _initialize_embedding_model
from dbgpt.app.initialization.scheduler import DefaultScheduler
from dbgpt.app.initialization.serve_initialization import register_serve_apps
from dbgpt.datasource.manages.connector_manager import ConnectorManager
from dbgpt.model.cluster.controller.controller import controller

# Register global default executor factory first
Expand All @@ -31,6 +32,7 @@ def initialize_components(
)
system_app.register(DefaultScheduler)
system_app.register_instance(controller)
system_app.register(ConnectorManager)

from dbgpt.serve.agent.hub.controller import module_plugin

Expand Down
2 changes: 1 addition & 1 deletion dbgpt/app/knowledge/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,8 @@
from dbgpt.app.knowledge.space_db import KnowledgeSpaceDao, KnowledgeSpaceEntity
from dbgpt.component import ComponentType
from dbgpt.configs.model_config import EMBEDDING_MODEL_CONFIG
from dbgpt.core import Chunk
from dbgpt.model import DefaultLLMClient
from dbgpt.rag.chunk import Chunk
from dbgpt.rag.chunk_manager import ChunkParameters
from dbgpt.rag.embedding.embedding_factory import EmbeddingFactory
from dbgpt.rag.knowledge.base import ChunkStrategy, KnowledgeType
Expand Down
16 changes: 8 additions & 8 deletions dbgpt/app/openapi/api_v1/api_v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def __new_conversation(chat_mode, user_name: str, sys_code: str) -> Conversation


def get_db_list():
dbs = CFG.LOCAL_DB_MANAGE.get_db_list()
dbs = CFG.local_db_manager.get_db_list()
db_params = []
for item in dbs:
params: dict = {}
Expand All @@ -85,7 +85,7 @@ def plugins_select_info():


def get_db_list_info():
dbs = CFG.LOCAL_DB_MANAGE.get_db_list()
dbs = CFG.local_db_manager.get_db_list()
params: dict = {}
for item in dbs:
comment = item["comment"]
Expand Down Expand Up @@ -147,22 +147,22 @@ def get_executor() -> Executor:

@router.get("/v1/chat/db/list", response_model=Result[DBConfig])
async def db_connect_list():
return Result.succ(CFG.LOCAL_DB_MANAGE.get_db_list())
return Result.succ(CFG.local_db_manager.get_db_list())


@router.post("/v1/chat/db/add", response_model=Result[bool])
async def db_connect_add(db_config: DBConfig = Body()):
return Result.succ(CFG.LOCAL_DB_MANAGE.add_db(db_config))
return Result.succ(CFG.local_db_manager.add_db(db_config))


@router.post("/v1/chat/db/edit", response_model=Result[bool])
async def db_connect_edit(db_config: DBConfig = Body()):
return Result.succ(CFG.LOCAL_DB_MANAGE.edit_db(db_config))
return Result.succ(CFG.local_db_manager.edit_db(db_config))


@router.post("/v1/chat/db/delete", response_model=Result[bool])
async def db_connect_delete(db_name: str = None):
return Result.succ(CFG.LOCAL_DB_MANAGE.delete_db(db_name))
return Result.succ(CFG.local_db_manager.delete_db(db_name))


async def async_db_summary_embedding(db_name, db_type):
Expand All @@ -174,7 +174,7 @@ async def async_db_summary_embedding(db_name, db_type):
async def test_connect(db_config: DBConfig = Body()):
try:
# TODO Change the synchronous call to the asynchronous call
CFG.LOCAL_DB_MANAGE.test_connect(db_config)
CFG.local_db_manager.test_connect(db_config)
return Result.succ(True)
except Exception as e:
return Result.failed(code="E1001", msg=str(e))
Expand All @@ -189,7 +189,7 @@ async def db_summary(db_name: str, db_type: str):

@router.get("/v1/chat/db/support/type", response_model=Result[DbTypeInfo])
async def db_support_types():
support_types = CFG.LOCAL_DB_MANAGE.get_all_completed_types()
support_types = CFG.local_db_manager.get_all_completed_types()
db_type_infos = []
for type in support_types:
db_type_infos.append(
Expand Down
10 changes: 5 additions & 5 deletions dbgpt/app/openapi/api_v1/editor/api_editor_v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ async def get_editor_tables(
db_name: str, page_index: int, page_size: int, search_str: str = ""
):
logger.info(f"get_editor_tables:{db_name},{page_index},{page_size},{search_str}")
db_conn = CFG.LOCAL_DB_MANAGE.get_connect(db_name)
db_conn = CFG.local_db_manager.get_connector(db_name)
tables = db_conn.get_table_names()
db_node: DataNode = DataNode(title=db_name, key=db_name, type="db")
for table in tables:
Expand Down Expand Up @@ -95,7 +95,7 @@ async def editor_sql_run(run_param: dict = Body()):
sql = run_param["sql"]
if not db_name and not sql:
return Result.failed(msg="SQL run param error!")
conn = CFG.LOCAL_DB_MANAGE.get_connect(db_name)
conn = CFG.local_db_manager.get_connector(db_name)

try:
start_time = time.time() * 1000
Expand Down Expand Up @@ -125,7 +125,7 @@ async def sql_editor_submit(
):
logger.info(f"sql_editor_submit:{sql_edit_context.__dict__}")

conn = CFG.LOCAL_DB_MANAGE.get_connect(sql_edit_context.db_name)
conn = CFG.local_db_manager.get_connector(sql_edit_context.db_name)
try:
editor_service.sql_editor_submit_and_save(sql_edit_context, conn)
return Result.succ(None)
Expand Down Expand Up @@ -168,7 +168,7 @@ async def editor_chart_run(run_param: dict = Body()):
return Result.failed("SQL run param error!")
try:
dashboard_data_loader: DashboardDataLoader = DashboardDataLoader()
db_conn = CFG.LOCAL_DB_MANAGE.get_connect(db_name)
db_conn = CFG.local_db_manager.get_connector(db_name)
colunms, sql_result = db_conn.query_ex(sql)
field_names, chart_values = dashboard_data_loader.get_chart_values_by_data(
colunms, sql_result, sql
Expand Down Expand Up @@ -204,7 +204,7 @@ async def chart_editor_submit(chart_edit_context: ChatChartEditContext = Body())
history_messages: List[Dict] = history_mem.get_messages()
if history_messages:
dashboard_data_loader: DashboardDataLoader = DashboardDataLoader()
db_conn = CFG.LOCAL_DB_MANAGE.get_connect(chart_edit_context.db_name)
db_conn = CFG.local_db_manager.get_connector(chart_edit_context.db_name)

edit_round = max(history_messages, key=lambda x: x["chat_order"])
if edit_round:
Expand Down
6 changes: 3 additions & 3 deletions dbgpt/app/openapi/api_v1/editor/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from dbgpt.serve.conversation.serve import Serve as ConversationServe

if TYPE_CHECKING:
from dbgpt.datasource.base import BaseConnect
from dbgpt.datasource.base import BaseConnector

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -86,7 +86,7 @@ def get_editor_sql_by_round(
return None

def sql_editor_submit_and_save(
self, sql_edit_context: ChatSqlEditContext, connection: BaseConnect
self, sql_edit_context: ChatSqlEditContext, connection: BaseConnector
):
storage_conv: StorageConversation = self.get_storage_conv(
sql_edit_context.conv_uid
Expand Down Expand Up @@ -169,7 +169,7 @@ def get_editor_chart_info(
filter(lambda x: x["chart_name"] == chart_title, charts)
)[0]

conn = cfg.LOCAL_DB_MANAGE.get_connect(db_name)
conn = cfg.local_db_manager.get_connector(db_name)
detail: ChartDetail = ChartDetail(
chart_uid=find_chart["chart_uid"],
chart_type=find_chart["chart_type"],
Expand Down
2 changes: 1 addition & 1 deletion dbgpt/app/scene/chat_dashboard/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def __init__(self, chat_param: Dict):
self.db_name = self.db_name
self.report_name = chat_param.get("report_name", "report")

self.database = CFG.LOCAL_DB_MANAGE.get_connect(self.db_name)
self.database = CFG.local_db_manager.get_connector(self.db_name)

self.top_k: int = 5
self.dashboard_template = self.__load_dashboard_template(self.report_name)
Expand Down
2 changes: 1 addition & 1 deletion dbgpt/app/scene/chat_dashboard/data_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,5 +57,5 @@ def get_chart_values_by_data(self, field_names, datas, chart_sql: str):

def get_chart_values_by_db(self, db_name: str, chart_sql: str):
logger.info(f"get_chart_values_by_db:{db_name},{chart_sql}")
db_conn = CFG.LOCAL_DB_MANAGE.get_connect(db_name)
db_conn = CFG.local_db_manager.get_connector(db_name)
return self.get_chart_values_by_conn(db_conn, chart_sql)
2 changes: 1 addition & 1 deletion dbgpt/app/scene/chat_db/auto_execute/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def __init__(self, chat_param: Dict):
with root_tracer.start_span(
"ChatWithDbAutoExecute.get_connect", metadata={"db_name": self.db_name}
):
self.database = CFG.LOCAL_DB_MANAGE.get_connect(self.db_name)
self.database = CFG.local_db_manager.get_connector(self.db_name)

self.top_k: int = 50
self.api_call = ApiCall(display_registry=CFG.command_display)
Expand Down
2 changes: 1 addition & 1 deletion dbgpt/app/scene/chat_db/professional_qa/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def __init__(self, chat_param: Dict):
super().__init__(chat_param=chat_param)

if self.db_name:
self.database = CFG.LOCAL_DB_MANAGE.get_connect(self.db_name)
self.database = CFG.local_db_manager.get_connector(self.db_name)
self.tables = self.database.get_table_names()

self.top_k = (
Expand Down
1 change: 1 addition & 0 deletions dbgpt/component.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ class ComponentType(str, Enum):
AWEL_TRIGGER_MANAGER = "dbgpt_awel_trigger_manager"
AWEL_DAG_MANAGER = "dbgpt_awel_dag_manager"
UNIFIED_METADATA_DB_MANAGER_FACTORY = "dbgpt_unified_metadata_db_manager_factory"
CONNECTOR_MANAGER = "dbgpt_connector_manager"


_EMPTY_DEFAULT_COMPONENT = "_EMPTY_DEFAULT_COMPONENT"
Expand Down
3 changes: 3 additions & 0 deletions dbgpt/core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
CacheValue,
)
from dbgpt.core.interface.embeddings import Embeddings # noqa: F401
from dbgpt.core.interface.knowledge import Chunk, Document # noqa: F401
from dbgpt.core.interface.llm import ( # noqa: F401
DefaultMessageConverter,
LLMClient,
Expand Down Expand Up @@ -105,4 +106,6 @@
"QuerySpec",
"StorageError",
"Embeddings",
"Chunk",
"Document",
]
File renamed without changes.
7 changes: 6 additions & 1 deletion dbgpt/datasource/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,6 @@
from .manages.connect_config_db import ConnectConfigDao, ConnectConfigEntity
"""Module to define the data source connectors."""

from .base import BaseConnector # noqa: F401
from .rdbms.base import RDBMSConnector # noqa: F401

__ALL__ = ["BaseConnector", "RDBMSConnector"]
Loading

0 comments on commit 15aa9de

Please sign in to comment.