Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: Refactor datasource module #1309

Merged
merged 1 commit into from
Mar 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 13 additions & 6 deletions .mypy.ini
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,6 @@ exclude = /tests/
[mypy-dbgpt.app.*]
follow_imports = skip

[mypy-dbgpt.datasource.*]
follow_imports = skip

# [mypy-dbgpt.storage.*]
# follow_imports = skip

[mypy-dbgpt.serve.*]
follow_imports = skip

Expand Down Expand Up @@ -74,3 +68,16 @@ ignore_missing_imports = True

[mypy-cryptography.*]
ignore_missing_imports = True

# Datasource
[mypy-pyspark.*]
ignore_missing_imports = True

[mypy-regex.*]
ignore_missing_imports = True

[mypy-sqlparse.*]
ignore_missing_imports = True

[mypy-clickhouse_connect.*]
ignore_missing_imports = True
10 changes: 3 additions & 7 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -48,9 +48,7 @@ fmt: setup ## Format Python code
$(VENV_BIN)/blackdoc examples
# TODO: Use flake8 to enforce Python style guide.
# https://flake8.pycqa.org/en/latest/
$(VENV_BIN)/flake8 dbgpt/core/
$(VENV_BIN)/flake8 dbgpt/rag/
$(VENV_BIN)/flake8 dbgpt/storage/
$(VENV_BIN)/flake8 dbgpt/core/ dbgpt/rag/ dbgpt/storage/ dbgpt/datasource/
# TODO: More package checks with flake8.

.PHONY: fmt-check
Expand All @@ -59,9 +57,7 @@ fmt-check: setup ## Check Python code formatting and style without making change
$(VENV_BIN)/isort --check-only --extend-skip="examples/notebook" examples
$(VENV_BIN)/black --check --extend-exclude="examples/notebook" .
$(VENV_BIN)/blackdoc --check dbgpt examples
$(VENV_BIN)/flake8 dbgpt/core/
$(VENV_BIN)/flake8 dbgpt/rag/
$(VENV_BIN)/flake8 dbgpt/storage/
$(VENV_BIN)/flake8 dbgpt/core/ dbgpt/rag/ dbgpt/storage/ dbgpt/datasource/

.PHONY: pre-commit
pre-commit: fmt-check test test-doc mypy ## Run formatting and unit tests before committing
Expand All @@ -77,7 +73,7 @@ test-doc: $(VENV)/.testenv ## Run doctests
.PHONY: mypy
mypy: $(VENV)/.testenv ## Run mypy checks
# https://github.com/python/mypy
$(VENV_BIN)/mypy --config-file .mypy.ini dbgpt/rag/
$(VENV_BIN)/mypy --config-file .mypy.ini dbgpt/rag/ dbgpt/datasource/
# rag depends on core and storage, so we do not need to check them again.
# $(VENV_BIN)/mypy --config-file .mypy.ini dbgpt/storage/
# $(VENV_BIN)/mypy --config-file .mypy.ini dbgpt/core/
Expand Down
11 changes: 9 additions & 2 deletions dbgpt/_private/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from auto_gpt_plugin_template import AutoGPTPluginTemplate

from dbgpt.component import SystemApp
from dbgpt.datasource.manages import ConnectorManager


class Config(metaclass=Singleton):
Expand Down Expand Up @@ -185,8 +186,6 @@ def __init__(self) -> None:
os.getenv("NATIVE_SQL_CAN_RUN_WRITE", "True").lower() == "true"
)

self.LOCAL_DB_MANAGE = None

###dbgpt meta info database connection configuration
self.LOCAL_DB_HOST = os.getenv("LOCAL_DB_HOST")
self.LOCAL_DB_PATH = os.getenv("LOCAL_DB_PATH", "data/default_sqlite.db")
Expand Down Expand Up @@ -287,3 +286,11 @@ def __init__(self) -> None:
self.MODEL_CACHE_STORAGE_DISK_DIR: Optional[str] = os.getenv(
"MODEL_CACHE_STORAGE_DISK_DIR"
)

@property
def local_db_manager(self) -> "ConnectorManager":
    """Return the connector manager bound to the current system app.

    The manager is looked up through ``ConnectorManager.get_instance`` using
    ``self.SYSTEM_APP``.

    Raises:
        ValueError: If ``SYSTEM_APP`` has not been set (is falsy).
    """
    # Imported at call time rather than at module load.
    # NOTE(review): presumably this avoids a circular import - confirm.
    from dbgpt.datasource.manages import ConnectorManager

    system_app = self.SYSTEM_APP
    if system_app:
        return ConnectorManager.get_instance(system_app)
    raise ValueError("SYSTEM_APP is not set")
7 changes: 0 additions & 7 deletions dbgpt/app/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,15 +70,8 @@ def server_init(param: "WebServerParameters", system_app: SystemApp):


def _create_model_start_listener(system_app: SystemApp):
from dbgpt.datasource.manages.connection_manager import ConnectManager

cfg = Config()

def startup_event(wh):
# init connect manage
print("begin run _add_app_startup_event")
conn_manage = ConnectManager(system_app)
cfg.LOCAL_DB_MANAGE = conn_manage
async_db_summary(system_app)

return startup_event
Expand Down
2 changes: 2 additions & 0 deletions dbgpt/app/component_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ def initialize_components(
from dbgpt.app.initialization.embedding_component import _initialize_embedding_model
from dbgpt.app.initialization.scheduler import DefaultScheduler
from dbgpt.app.initialization.serve_initialization import register_serve_apps
from dbgpt.datasource.manages.connector_manager import ConnectorManager
from dbgpt.model.cluster.controller.controller import controller

# Register global default executor factory first
Expand All @@ -31,6 +32,7 @@ def initialize_components(
)
system_app.register(DefaultScheduler)
system_app.register_instance(controller)
system_app.register(ConnectorManager)

from dbgpt.serve.agent.hub.controller import module_plugin

Expand Down
2 changes: 1 addition & 1 deletion dbgpt/app/knowledge/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,8 @@
from dbgpt.app.knowledge.space_db import KnowledgeSpaceDao, KnowledgeSpaceEntity
from dbgpt.component import ComponentType
from dbgpt.configs.model_config import EMBEDDING_MODEL_CONFIG
from dbgpt.core import Chunk
from dbgpt.model import DefaultLLMClient
from dbgpt.rag.chunk import Chunk
from dbgpt.rag.chunk_manager import ChunkParameters
from dbgpt.rag.embedding.embedding_factory import EmbeddingFactory
from dbgpt.rag.knowledge.base import ChunkStrategy, KnowledgeType
Expand Down
16 changes: 8 additions & 8 deletions dbgpt/app/openapi/api_v1/api_v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def __new_conversation(chat_mode, user_name: str, sys_code: str) -> Conversation


def get_db_list():
dbs = CFG.LOCAL_DB_MANAGE.get_db_list()
dbs = CFG.local_db_manager.get_db_list()
db_params = []
for item in dbs:
params: dict = {}
Expand All @@ -85,7 +85,7 @@ def plugins_select_info():


def get_db_list_info():
dbs = CFG.LOCAL_DB_MANAGE.get_db_list()
dbs = CFG.local_db_manager.get_db_list()
params: dict = {}
for item in dbs:
comment = item["comment"]
Expand Down Expand Up @@ -147,22 +147,22 @@ def get_executor() -> Executor:

@router.get("/v1/chat/db/list", response_model=Result[DBConfig])
async def db_connect_list():
return Result.succ(CFG.LOCAL_DB_MANAGE.get_db_list())
return Result.succ(CFG.local_db_manager.get_db_list())


@router.post("/v1/chat/db/add", response_model=Result[bool])
async def db_connect_add(db_config: DBConfig = Body()):
return Result.succ(CFG.LOCAL_DB_MANAGE.add_db(db_config))
return Result.succ(CFG.local_db_manager.add_db(db_config))


@router.post("/v1/chat/db/edit", response_model=Result[bool])
async def db_connect_edit(db_config: DBConfig = Body()):
return Result.succ(CFG.LOCAL_DB_MANAGE.edit_db(db_config))
return Result.succ(CFG.local_db_manager.edit_db(db_config))


@router.post("/v1/chat/db/delete", response_model=Result[bool])
async def db_connect_delete(db_name: str = None):
return Result.succ(CFG.LOCAL_DB_MANAGE.delete_db(db_name))
return Result.succ(CFG.local_db_manager.delete_db(db_name))


async def async_db_summary_embedding(db_name, db_type):
Expand All @@ -174,7 +174,7 @@ async def async_db_summary_embedding(db_name, db_type):
async def test_connect(db_config: DBConfig = Body()):
try:
# TODO Change the synchronous call to the asynchronous call
CFG.LOCAL_DB_MANAGE.test_connect(db_config)
CFG.local_db_manager.test_connect(db_config)
return Result.succ(True)
except Exception as e:
return Result.failed(code="E1001", msg=str(e))
Expand All @@ -189,7 +189,7 @@ async def db_summary(db_name: str, db_type: str):

@router.get("/v1/chat/db/support/type", response_model=Result[DbTypeInfo])
async def db_support_types():
support_types = CFG.LOCAL_DB_MANAGE.get_all_completed_types()
support_types = CFG.local_db_manager.get_all_completed_types()
db_type_infos = []
for type in support_types:
db_type_infos.append(
Expand Down
10 changes: 5 additions & 5 deletions dbgpt/app/openapi/api_v1/editor/api_editor_v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ async def get_editor_tables(
db_name: str, page_index: int, page_size: int, search_str: str = ""
):
logger.info(f"get_editor_tables:{db_name},{page_index},{page_size},{search_str}")
db_conn = CFG.LOCAL_DB_MANAGE.get_connect(db_name)
db_conn = CFG.local_db_manager.get_connector(db_name)
tables = db_conn.get_table_names()
db_node: DataNode = DataNode(title=db_name, key=db_name, type="db")
for table in tables:
Expand Down Expand Up @@ -95,7 +95,7 @@ async def editor_sql_run(run_param: dict = Body()):
sql = run_param["sql"]
if not db_name and not sql:
return Result.failed(msg="SQL run param error!")
conn = CFG.LOCAL_DB_MANAGE.get_connect(db_name)
conn = CFG.local_db_manager.get_connector(db_name)

try:
start_time = time.time() * 1000
Expand Down Expand Up @@ -125,7 +125,7 @@ async def sql_editor_submit(
):
logger.info(f"sql_editor_submit:{sql_edit_context.__dict__}")

conn = CFG.LOCAL_DB_MANAGE.get_connect(sql_edit_context.db_name)
conn = CFG.local_db_manager.get_connector(sql_edit_context.db_name)
try:
editor_service.sql_editor_submit_and_save(sql_edit_context, conn)
return Result.succ(None)
Expand Down Expand Up @@ -168,7 +168,7 @@ async def editor_chart_run(run_param: dict = Body()):
return Result.failed("SQL run param error!")
try:
dashboard_data_loader: DashboardDataLoader = DashboardDataLoader()
db_conn = CFG.LOCAL_DB_MANAGE.get_connect(db_name)
db_conn = CFG.local_db_manager.get_connector(db_name)
colunms, sql_result = db_conn.query_ex(sql)
field_names, chart_values = dashboard_data_loader.get_chart_values_by_data(
colunms, sql_result, sql
Expand Down Expand Up @@ -204,7 +204,7 @@ async def chart_editor_submit(chart_edit_context: ChatChartEditContext = Body())
history_messages: List[Dict] = history_mem.get_messages()
if history_messages:
dashboard_data_loader: DashboardDataLoader = DashboardDataLoader()
db_conn = CFG.LOCAL_DB_MANAGE.get_connect(chart_edit_context.db_name)
db_conn = CFG.local_db_manager.get_connector(chart_edit_context.db_name)

edit_round = max(history_messages, key=lambda x: x["chat_order"])
if edit_round:
Expand Down
6 changes: 3 additions & 3 deletions dbgpt/app/openapi/api_v1/editor/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from dbgpt.serve.conversation.serve import Serve as ConversationServe

if TYPE_CHECKING:
from dbgpt.datasource.base import BaseConnect
from dbgpt.datasource.base import BaseConnector

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -86,7 +86,7 @@ def get_editor_sql_by_round(
return None

def sql_editor_submit_and_save(
self, sql_edit_context: ChatSqlEditContext, connection: BaseConnect
self, sql_edit_context: ChatSqlEditContext, connection: BaseConnector
):
storage_conv: StorageConversation = self.get_storage_conv(
sql_edit_context.conv_uid
Expand Down Expand Up @@ -169,7 +169,7 @@ def get_editor_chart_info(
filter(lambda x: x["chart_name"] == chart_title, charts)
)[0]

conn = cfg.LOCAL_DB_MANAGE.get_connect(db_name)
conn = cfg.local_db_manager.get_connector(db_name)
detail: ChartDetail = ChartDetail(
chart_uid=find_chart["chart_uid"],
chart_type=find_chart["chart_type"],
Expand Down
2 changes: 1 addition & 1 deletion dbgpt/app/scene/chat_dashboard/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def __init__(self, chat_param: Dict):
self.db_name = self.db_name
self.report_name = chat_param.get("report_name", "report")

self.database = CFG.LOCAL_DB_MANAGE.get_connect(self.db_name)
self.database = CFG.local_db_manager.get_connector(self.db_name)

self.top_k: int = 5
self.dashboard_template = self.__load_dashboard_template(self.report_name)
Expand Down
2 changes: 1 addition & 1 deletion dbgpt/app/scene/chat_dashboard/data_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,5 +57,5 @@ def get_chart_values_by_data(self, field_names, datas, chart_sql: str):

def get_chart_values_by_db(self, db_name: str, chart_sql: str):
logger.info(f"get_chart_values_by_db:{db_name},{chart_sql}")
db_conn = CFG.LOCAL_DB_MANAGE.get_connect(db_name)
db_conn = CFG.local_db_manager.get_connector(db_name)
return self.get_chart_values_by_conn(db_conn, chart_sql)
2 changes: 1 addition & 1 deletion dbgpt/app/scene/chat_db/auto_execute/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def __init__(self, chat_param: Dict):
with root_tracer.start_span(
"ChatWithDbAutoExecute.get_connect", metadata={"db_name": self.db_name}
):
self.database = CFG.LOCAL_DB_MANAGE.get_connect(self.db_name)
self.database = CFG.local_db_manager.get_connector(self.db_name)

self.top_k: int = 50
self.api_call = ApiCall(display_registry=CFG.command_display)
Expand Down
2 changes: 1 addition & 1 deletion dbgpt/app/scene/chat_db/professional_qa/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def __init__(self, chat_param: Dict):
super().__init__(chat_param=chat_param)

if self.db_name:
self.database = CFG.LOCAL_DB_MANAGE.get_connect(self.db_name)
self.database = CFG.local_db_manager.get_connector(self.db_name)
self.tables = self.database.get_table_names()

self.top_k = (
Expand Down
1 change: 1 addition & 0 deletions dbgpt/component.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ class ComponentType(str, Enum):
AWEL_TRIGGER_MANAGER = "dbgpt_awel_trigger_manager"
AWEL_DAG_MANAGER = "dbgpt_awel_dag_manager"
UNIFIED_METADATA_DB_MANAGER_FACTORY = "dbgpt_unified_metadata_db_manager_factory"
CONNECTOR_MANAGER = "dbgpt_connector_manager"


_EMPTY_DEFAULT_COMPONENT = "_EMPTY_DEFAULT_COMPONENT"
Expand Down
3 changes: 3 additions & 0 deletions dbgpt/core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
CacheValue,
)
from dbgpt.core.interface.embeddings import Embeddings # noqa: F401
from dbgpt.core.interface.knowledge import Chunk, Document # noqa: F401
from dbgpt.core.interface.llm import ( # noqa: F401
DefaultMessageConverter,
LLMClient,
Expand Down Expand Up @@ -105,4 +106,6 @@
"QuerySpec",
"StorageError",
"Embeddings",
"Chunk",
"Document",
]
File renamed without changes.
7 changes: 6 additions & 1 deletion dbgpt/datasource/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,6 @@
from .manages.connect_config_db import ConnectConfigDao, ConnectConfigEntity
"""Module to define the data source connectors."""

from .base import BaseConnector # noqa: F401
from .rdbms.base import RDBMSConnector # noqa: F401

# The export list must be spelled in lowercase: Python only consults
# ``__all__`` for ``from package import *``; ``__ALL__`` is just an ordinary
# (ignored) module attribute.
__all__ = ["BaseConnector", "RDBMSConnector"]
Loading
Loading