From cbba50ab1b452d129bda153bdf52957b16e14d05 Mon Sep 17 00:00:00 2001 From: FangYin Cheng Date: Mon, 11 Dec 2023 18:33:54 +0800 Subject: [PATCH] feat(core): Support simple DB query for sdk (#917) Co-authored-by: chengfangyin2 --- dbgpt/_private/config.py | 16 -- dbgpt/app/dbgpt_server.py | 12 +- dbgpt/app/scene/chat_db/auto_execute/chat.py | 1 + dbgpt/core/__init__.py | 3 +- dbgpt/core/awel/task/task_impl.py | 2 +- dbgpt/core/interface/output_parser.py | 10 ++ dbgpt/core/interface/retriever.py | 23 +++ dbgpt/datasource/base.py | 70 ++++++-- dbgpt/datasource/db_conn_info.py | 2 +- .../datasource/manages/connection_manager.py | 7 +- dbgpt/datasource/operator/__init__.py | 0 .../operator/datasource_operator.py | 16 ++ dbgpt/datasource/rdbms/conn_sqlite.py | 118 +++++++++++++- dbgpt/rag/operator/__init__.py | 0 dbgpt/rag/operator/datasource.py | 14 ++ dbgpt/rag/summary/rdbms_db_summary.py | 83 +++++++--- examples/sdk/simple_sdk_llm_example.py | 14 +- examples/sdk/simple_sdk_llm_sql_example.py | 150 ++++++++++++++++++ 18 files changed, 467 insertions(+), 74 deletions(-) create mode 100644 dbgpt/core/interface/retriever.py create mode 100644 dbgpt/datasource/operator/__init__.py create mode 100644 dbgpt/datasource/operator/datasource_operator.py create mode 100644 dbgpt/rag/operator/__init__.py create mode 100644 dbgpt/rag/operator/datasource.py create mode 100644 examples/sdk/simple_sdk_llm_sql_example.py diff --git a/dbgpt/_private/config.py b/dbgpt/_private/config.py index 5d6c90fa0..adab63b1b 100644 --- a/dbgpt/_private/config.py +++ b/dbgpt/_private/config.py @@ -266,19 +266,3 @@ def __init__(self) -> None: self.MODEL_CACHE_STORAGE_DISK_DIR: str = os.getenv( "MODEL_CACHE_STORAGE_DISK_DIR" ) - - def set_debug_mode(self, value: bool) -> None: - """Set the debug mode value""" - self.debug_mode = value - - def set_templature(self, value: int) -> None: - """Set the temperature value.""" - self.temperature = value - - def set_speak_mode(self, value: bool) -> None: - """Set the speak mode value.""" - self.speak_mode = value - - def set_last_plugin_return(self, value: bool) -> None: - """Set the speak mode value.""" - self.last_plugin_return = value diff --git a/dbgpt/app/dbgpt_server.py b/dbgpt/app/dbgpt_server.py index 57c4a6206..40ce82872 100644 --- a/dbgpt/app/dbgpt_server.py +++ b/dbgpt/app/dbgpt_server.py @@ -5,13 +5,13 @@ ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) sys.path.append(ROOT_PATH) -from dbgpt._private.config import Config from dbgpt.configs.model_config import ( LLM_MODEL_CONFIG, EMBEDDING_MODEL_CONFIG, LOGDIR, ROOT_PATH, ) +from dbgpt._private.config import Config from dbgpt.component import SystemApp from dbgpt.app.base import ( @@ -30,7 +30,6 @@ from dbgpt.app.prompt.api import router as prompt_router from dbgpt.app.llm_manage.api import router as llm_manage_api - from dbgpt.app.openapi.api_v1.api_v1 import router as api_v1 from dbgpt.app.openapi.base import validation_exception_handler from dbgpt.app.openapi.api_v1.editor.api_editor_v1 import router as api_editor_route_v1 @@ -59,7 +58,7 @@ def swagger_monkey_patch(*args, **kwargs): *args, **kwargs, swagger_js_url="https://cdn.bootcdn.net/ajax/libs/swagger-ui/4.10.3/swagger-ui-bundle.js", - swagger_css_url="https://cdn.bootcdn.net/ajax/libs/swagger-ui/4.10.3/swagger-ui.css" + swagger_css_url="https://cdn.bootcdn.net/ajax/libs/swagger-ui/4.10.3/swagger-ui.css", ) @@ -79,13 +78,11 @@ def swagger_monkey_patch(*args, **kwargs): allow_headers=["*"], ) - app.include_router(api_v1, prefix="/api", tags=["Chat"]) app.include_router(api_editor_route_v1, prefix="/api", tags=["Editor"]) app.include_router(llm_manage_api, prefix="/api", tags=["LLM Manage"]) app.include_router(api_fb_v1, prefix="/api", tags=["FeedBack"]) - app.include_router(knowledge_router, tags=["Knowledge"]) app.include_router(prompt_router, tags=["Prompt"]) @@ -133,7 +130,8 @@ def initialize_app(param: WebServerParameters = None, args: List[str] = None): # Before start system_app.before_start() - + model_name = param.model_name or CFG.LLM_MODEL + param.model_name = model_name print(param) embedding_model_name = CFG.EMBEDDING_MODEL @@ -143,8 +141,6 @@ def initialize_app(param: WebServerParameters = None, args: List[str] = None): model_start_listener = _create_model_start_listener(system_app) initialize_components(param, system_app, embedding_model_name, embedding_model_path) - model_name = param.model_name or CFG.LLM_MODEL - model_path = CFG.LLM_MODEL_PATH or LLM_MODEL_CONFIG.get(model_name) if not param.light: print("Model Unified Deployment Mode!") diff --git a/dbgpt/app/scene/chat_db/auto_execute/chat.py b/dbgpt/app/scene/chat_db/auto_execute/chat.py index 6c0cdd187..13790795b 100644 --- a/dbgpt/app/scene/chat_db/auto_execute/chat.py +++ b/dbgpt/app/scene/chat_db/auto_execute/chat.py @@ -52,6 +52,7 @@ async def generate_input_values(self) -> Dict: except ImportError: raise ValueError("Could not import DBSummaryClient. ") client = DBSummaryClient(system_app=CFG.SYSTEM_APP) + table_infos = None try: with root_tracer.start_span("ChatWithDbAutoExecute.get_db_summary"): table_infos = await blocking_func_to_async( diff --git a/dbgpt/core/__init__.py b/dbgpt/core/__init__.py index 81631cec9..d2f64aafd 100644 --- a/dbgpt/core/__init__.py +++ b/dbgpt/core/__init__.py @@ -11,7 +11,7 @@ OnceConversation, ) from dbgpt.core.interface.prompt import PromptTemplate, PromptTemplateOperator -from dbgpt.core.interface.output_parser import BaseOutputParser +from dbgpt.core.interface.output_parser import BaseOutputParser, SQLOutputParser from dbgpt.core.interface.serialization import Serializable, Serializer from dbgpt.core.interface.cache import ( CacheKey, @@ -33,6 +33,7 @@ "PromptTemplate", "PromptTemplateOperator", "BaseOutputParser", + "SQLOutputParser", "Serializable", "Serializer", "CacheKey", diff --git a/dbgpt/core/awel/task/task_impl.py b/dbgpt/core/awel/task/task_impl.py index 9a81f738e..a7bf542d5 100644 --- a/dbgpt/core/awel/task/task_impl.py +++ b/dbgpt/core/awel/task/task_impl.py @@ -53,7 +53,7 @@ def new_output(self) -> TaskOutput[T]: @property def is_empty(self) -> bool: - return not self._data + return self._data is None async def _apply_func(self, func) -> Any: if asyncio.iscoroutinefunction(func): diff --git a/dbgpt/core/interface/output_parser.py b/dbgpt/core/interface/output_parser.py index af1a64a96..ffea7f094 100644 --- a/dbgpt/core/interface/output_parser.py +++ b/dbgpt/core/interface/output_parser.py @@ -251,3 +251,13 @@ def _parse_model_response(response: ResponseTye): else: raise ValueError(f"Unsupported response type {type(response)}") return resp_obj_ex + + +class SQLOutputParser(BaseOutputParser): + def __init__(self, is_stream_out: bool = False, **kwargs): + super().__init__(is_stream_out=is_stream_out, **kwargs) + + def parse_model_nostream_resp(self, response: ResponseTye, sep: str): + model_out_text = super().parse_model_nostream_resp(response, sep) + clean_str = super().parse_prompt_response(model_out_text) + return json.loads(clean_str, strict=True) diff --git a/dbgpt/core/interface/retriever.py b/dbgpt/core/interface/retriever.py new file mode 100644 index 000000000..ec6e6650d --- /dev/null +++ b/dbgpt/core/interface/retriever.py @@ -0,0 +1,23 @@ +from abc import abstractmethod +from dbgpt.core.awel import MapOperator +from dbgpt.core.awel.task.base import IN, OUT + + +class RetrieverOperator(MapOperator[IN, OUT]): + """The Abstract Retriever Operator.""" + + async def map(self, input_value: IN) -> OUT: + """Map input value to output value. + + Args: + input_value (IN): The input value. + + Returns: + OUT: The output value. + """ + # The retrieve function is blocking, so we need to wrap it in a blocking_func_to_async. + return await self.blocking_func_to_async(self.retrieve, input_value) + + @abstractmethod + def retrieve(self, input_value: IN) -> OUT: + """Retrieve data for input value.""" diff --git a/dbgpt/datasource/base.py b/dbgpt/datasource/base.py index c380a1455..58f5f131c 100644 --- a/dbgpt/datasource/base.py +++ b/dbgpt/datasource/base.py @@ -2,58 +2,104 @@ # -*- coding:utf-8 -*- """We need to design a base class. That other connector can Write with this""" -from abc import ABC, abstractmethod -from typing import Any, Iterable, List, Optional +from abc import ABC +from typing import Iterable, List, Optional class BaseConnect(ABC): - def get_connect(self, db_name: str): - pass - def get_table_names(self) -> Iterable[str]: + """Get all table names""" pass def get_table_info(self, table_names: Optional[List[str]] = None) -> str: + """Get table info about specified table. + + Returns: + str: Table information joined by '\n\n' + """ pass def get_index_info(self, table_names: Optional[List[str]] = None) -> str: + """Get index info about specified table. + + Args: + table_names (Optional[List[str]]): table names + """ pass def get_example_data(self, table: str, count: int = 3): + """Get example data about specified table. + + Not used now. + + Args: + table (str): table name + count (int): example data count + """ pass - def get_database_list(self): + def get_database_list(self) -> List[str]: + """Get database list. + + Returns: + List[str]: database list + """ pass def get_database_names(self): + """Get database names.""" pass def get_table_comments(self, db_name): + """Get table comments. + + Args: + db_name (str): database name + """ pass - def run(self, session, command: str, fetch: str = "all") -> List: + def run(self, command: str, fetch: str = "all") -> List: + """Execute sql command. + + Args: + command (str): sql command + fetch (str): fetch type + """ pass def run_to_df(self, command: str, fetch: str = "all"): + """Execute sql command and return dataframe.""" pass def get_users(self): - pass + """Get user info.""" + return [] def get_grants(self): - pass + """Get grant info.""" + return [] def get_collation(self): - pass + """Get collation.""" + return None - def get_charset(self): - pass + def get_charset(self) -> str: + """Get character_set of current database.""" + return "utf-8" def get_fields(self, table_name): + """Get column fields about specified table.""" pass def get_show_create_table(self, table_name): + """Get the creation table sql about specified table.""" pass def get_indexes(self, table_name): + """Get table indexes about specified table.""" pass + + @classmethod + def is_normal_type(cls) -> bool: + """Return whether the connector is a normal type.""" + return True diff --git a/dbgpt/datasource/db_conn_info.py b/dbgpt/datasource/db_conn_info.py index 7d112cef8..56daaae7c 100644 --- a/dbgpt/datasource/db_conn_info.py +++ b/dbgpt/datasource/db_conn_info.py @@ -1,4 +1,4 @@ -from dbgpt._private.pydantic import BaseModel, Field +from dbgpt._private.pydantic import BaseModel class DBConfig(BaseModel): diff --git a/dbgpt/datasource/manages/connection_manager.py b/dbgpt/datasource/manages/connection_manager.py index 597613016..5594b5207 100644 --- a/dbgpt/datasource/manages/connection_manager.py +++ b/dbgpt/datasource/manages/connection_manager.py @@ -1,3 +1,4 @@ +from typing import List, Type from dbgpt.datasource import ConnectConfigDao from dbgpt.storage.schema import DBType from dbgpt.component import SystemApp, ComponentType @@ -21,7 +22,7 @@ class ConnectManager: """db connect manager""" - def get_all_subclasses(self, cls): + def get_all_subclasses(self, cls: Type[BaseConnect]) -> List[Type[BaseConnect]]: subclasses = cls.__subclasses__() for subclass in subclasses: subclasses += self.get_all_subclasses(subclass) @@ -31,7 +32,7 @@ def get_all_completed_types(self): chat_classes = self.get_all_subclasses(BaseConnect) support_types = [] for cls in chat_classes: - if cls.db_type: + if cls.db_type and cls.is_normal_type(): support_types.append(DBType.of_db_type(cls.db_type)) return support_types @@ -39,7 +40,7 @@ def get_cls_by_dbtype(self, db_type): chat_classes = self.get_all_subclasses(BaseConnect) result = None for cls in chat_classes: - if cls.db_type == db_type: + if cls.db_type == db_type and cls.is_normal_type(): result = cls if not result: raise ValueError("Unsupported Db Type!" + db_type) diff --git a/dbgpt/datasource/operator/__init__.py b/dbgpt/datasource/operator/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/dbgpt/datasource/operator/datasource_operator.py b/dbgpt/datasource/operator/datasource_operator.py new file mode 100644 index 000000000..752465d18 --- /dev/null +++ b/dbgpt/datasource/operator/datasource_operator.py @@ -0,0 +1,16 @@ +from typing import Any +from dbgpt.core.awel import MapOperator +from dbgpt.core.awel.task.base import IN, OUT +from dbgpt.datasource.base import BaseConnect + + +class DatasourceOperator(MapOperator[str, Any]): + def __init__(self, connection: BaseConnect, **kwargs): + super().__init__(**kwargs) + self._connection = connection + + async def map(self, input_value: IN) -> OUT: + return await self.blocking_func_to_async(self.query, input_value) + + def query(self, input_value: str) -> Any: + return self._connection.run_to_df(input_value) diff --git a/dbgpt/datasource/rdbms/conn_sqlite.py b/dbgpt/datasource/rdbms/conn_sqlite.py index 350f034c3..cff76df94 100644 --- a/dbgpt/datasource/rdbms/conn_sqlite.py +++ b/dbgpt/datasource/rdbms/conn_sqlite.py @@ -4,9 +4,12 @@ import os from typing import Optional, Any, Iterable from sqlalchemy import create_engine, text - +import tempfile +import logging from dbgpt.datasource.rdbms.base import RDBMSDatabase +logger = logging.getLogger(__name__) + class SQLiteConnect(RDBMSDatabase): """Connect SQLite Database fetch MetaData @@ -127,3 +130,116 @@ def table_simple_info(self) -> Iterable[str]: results.append(f"{table_name}({','.join(table_colums)});") return results + + +class SQLiteTempConnect(SQLiteConnect): + """A temporary SQLite database connection. The database file will be deleted when the connection is closed.""" + + def __init__(self, engine, temp_file_path, *args, **kwargs): + super().__init__(engine, *args, **kwargs) + self.temp_file_path = temp_file_path + self._is_closed = False + + @classmethod + def create_temporary_db( + cls, engine_args: Optional[dict] = None, **kwargs: Any + ) -> "SQLiteTempConnect": + """Create a temporary SQLite database with a temporary file. + + Examples: + .. code-block:: python + with SQLiteTempConnect.create_temporary_db() as db: + db.run(db.session, "CREATE TABLE test (id INTEGER PRIMARY KEY);") + db.run(db.session, "insert into test(id) values (1)") + db.run(db.session, "insert into test(id) values (2)") + field_names, result = db.query_ex(db.session, "select * from test") + assert field_names == ["id"] + assert result == [(1,), (2,)] + + Args: + engine_args (Optional[dict]): SQLAlchemy engine arguments. + + Returns: + SQLiteTempConnect: A SQLiteTempConnect instance. + """ + _engine_args = engine_args or {} + _engine_args["connect_args"] = {"check_same_thread": False} + + temp_file = tempfile.NamedTemporaryFile(delete=False) + temp_file_path = temp_file.name + temp_file.close() + + engine = create_engine(f"sqlite:///{temp_file_path}", **_engine_args) + return cls(engine, temp_file_path, **kwargs) + + def close(self): + """Close the connection.""" + if not self._is_closed: + if self._engine: + self._engine.dispose() + try: + if os.path.exists(self.temp_file_path): + os.remove(self.temp_file_path) + except Exception as e: + logger.error(f"Error removing temporary database file: {e}") + self._is_closed = True + + def create_temp_tables(self, tables_info): + """Create temporary tables with data. + + Examples: + .. code-block:: python + tables_info = { + "test": { + "columns": { + "id": "INTEGER PRIMARY KEY", + "name": "TEXT", + "age": "INTEGER", + }, + "data": [ + (1, "Tom", 20), + (2, "Jack", 21), + (3, "Alice", 22), + ], + }, + } + with SQLiteTempConnect.create_temporary_db() as db: + db.create_temp_tables(tables_info) + field_names, result = db.query_ex(db.session, "select * from test") + assert field_names == ["id", "name", "age"] + assert result == [(1, "Tom", 20), (2, "Jack", 21), (3, "Alice", 22)] + + Args: + tables_info (dict): A dictionary of table information. + """ + for table_name, table_data in tables_info.items(): + columns = ", ".join( + [f"{col} {dtype}" for col, dtype in table_data["columns"].items()] + ) + create_sql = f"CREATE TABLE {table_name} ({columns});" + self.session.execute(text(create_sql)) + for row in table_data.get("data", []): + placeholders = ", ".join( + [":param" + str(index) for index, _ in enumerate(row)] + ) + insert_sql = f"INSERT INTO {table_name} VALUES ({placeholders});" + + param_dict = { + "param" + str(index): value for index, value in enumerate(row) + } + self.session.execute(text(insert_sql), param_dict) + self.session.commit() + self._sync_tables_from_db() + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self.close() + + def __del__(self): + self.close() + + @classmethod + def is_normal_type(cls) -> bool: + return False diff --git a/dbgpt/rag/operator/__init__.py b/dbgpt/rag/operator/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/dbgpt/rag/operator/datasource.py b/dbgpt/rag/operator/datasource.py new file mode 100644 index 000000000..c015eb33b --- /dev/null +++ b/dbgpt/rag/operator/datasource.py @@ -0,0 +1,14 @@ +from typing import Any +from dbgpt.core.interface.retriever import RetrieverOperator +from dbgpt.datasource.rdbms.base import RDBMSDatabase +from dbgpt.rag.summary.rdbms_db_summary import _parse_db_summary + + +class DatasourceRetrieverOperator(RetrieverOperator[Any, Any]): + def __init__(self, connection: RDBMSDatabase, **kwargs): + super().__init__(**kwargs) + self._connection = connection + + def retrieve(self, input_value: Any) -> Any: + summary = _parse_db_summary(self._connection) + return summary diff --git a/dbgpt/rag/summary/rdbms_db_summary.py b/dbgpt/rag/summary/rdbms_db_summary.py index fa5a9326d..373157e9b 100644 --- a/dbgpt/rag/summary/rdbms_db_summary.py +++ b/dbgpt/rag/summary/rdbms_db_summary.py @@ -1,5 +1,7 @@ +from typing import List from dbgpt._private.config import Config from dbgpt.rag.summary.db_summary import DBSummary +from dbgpt.datasource.rdbms.base import RDBMSDatabase CFG = Config() @@ -36,32 +38,63 @@ def get_table_summary(self, table_name): example: table_name(column1(column1 comment),column2(column2 comment),column3(column3 comment) and index keys, and table comment: {table_comment}) """ - columns = [] - for column in self.db._inspector.get_columns(table_name): - if column.get("comment"): - columns.append((f"{column['name']} ({column.get('comment')})")) - else: - columns.append(f"{column['name']}") - - column_str = ", ".join(columns) - index_keys = [] - for index_key in self.db._inspector.get_indexes(table_name): - key_str = ", ".join(index_key["column_names"]) - index_keys.append(f"{index_key['name']}(`{key_str}`) ") - table_str = self.summary_template.format( - table_name=table_name, columns=column_str - ) - if len(index_keys) > 0: - index_key_str = ", ".join(index_keys) - table_str += f", and index keys: {index_key_str}" - try: - comment = self.db._inspector.get_table_comment(table_name) - except Exception: - comment = dict(text=None) - if comment.get("text"): - table_str += f", and table comment: {comment.get('text')}" - return table_str + return _parse_table_summary(self.db, self.summary_template, table_name) def table_summaries(self): """Get table summaries.""" return self.table_info_summaries + + +def _parse_db_summary( + conn: RDBMSDatabase, summary_template: str = "{table_name}({columns})" +) -> List[str]: + """Get db summary for database. + + Args: + conn (RDBMSDatabase): database connection + summary_template (str): summary template + """ + tables = conn.get_table_names() + table_info_summaries = [ + _parse_table_summary(conn, summary_template, table_name) + for table_name in tables + ] + return table_info_summaries + + +def _parse_table_summary( + conn: RDBMSDatabase, summary_template: str, table_name: str +) -> str: + """Get table summary for table. + + Args: + conn (RDBMSDatabase): database connection + summary_template (str): summary template + table_name (str): table name + + Examples: + table_name(column1(column1 comment),column2(column2 comment),column3(column3 comment) and index keys, and table comment: {table_comment}) + """ + columns = [] + for column in conn._inspector.get_columns(table_name): + if column.get("comment"): + columns.append(f"{column['name']} ({column.get('comment')})") + else: + columns.append(f"{column['name']}") + + column_str = ", ".join(columns) + index_keys = [] + for index_key in conn._inspector.get_indexes(table_name): + key_str = ", ".join(index_key["column_names"]) + index_keys.append(f"{index_key['name']}(`{key_str}`) ") + table_str = summary_template.format(table_name=table_name, columns=column_str) + if len(index_keys) > 0: + index_key_str = ", ".join(index_keys) + table_str += f", and index keys: {index_key_str}" + try: + comment = conn._inspector.get_table_comment(table_name) + except Exception: + comment = dict(text=None) + if comment.get("text"): + table_str += f", and table comment: {comment.get('text')}" + return table_str diff --git a/examples/sdk/simple_sdk_llm_example.py b/examples/sdk/simple_sdk_llm_example.py index eee7f5e99..1cce17667 100644 --- a/examples/sdk/simple_sdk_llm_example.py +++ b/examples/sdk/simple_sdk_llm_example.py @@ -3,16 +3,18 @@ from dbgpt.core import BaseOutputParser, OpenAILLM, RequestBuildOperator, PromptTemplate with DAG("simple_sdk_llm_example_dag") as dag: - prompt = PromptTemplate.from_template( + prompt_task = PromptTemplate.from_template( "Write a SQL of {dialect} to query all data of {table_name}." ) - req_builder = RequestBuildOperator(model="gpt-3.5-turbo") - llm = OpenAILLM() - out_parser = BaseOutputParser() - prompt >> req_builder >> llm >> out_parser + model_pre_handle_task = RequestBuildOperator(model="gpt-3.5-turbo") + llm_task = OpenAILLM() + out_parse_task = BaseOutputParser() + prompt_task >> model_pre_handle_task >> llm_task >> out_parse_task if __name__ == "__main__": output = asyncio.run( - out_parser.call(call_data={"data": {"dialect": "mysql", "table_name": "user"}}) + out_parse_task.call( + call_data={"data": {"dialect": "mysql", "table_name": "user"}} + ) ) print(f"output: \n\n{output}") diff --git a/examples/sdk/simple_sdk_llm_sql_example.py b/examples/sdk/simple_sdk_llm_sql_example.py new file mode 100644 index 000000000..4aedf12c6 --- /dev/null +++ b/examples/sdk/simple_sdk_llm_sql_example.py @@ -0,0 +1,150 @@ +import asyncio +from typing import Dict, List +import json +from dbgpt.core.awel import ( + DAG, + InputOperator, + SimpleCallDataInputSource, + JoinOperator, + MapOperator, +) +from dbgpt.core import SQLOutputParser, OpenAILLM, RequestBuildOperator, PromptTemplate +from dbgpt.datasource.rdbms.conn_sqlite import SQLiteTempConnect +from dbgpt.datasource.operator.datasource_operator import DatasourceOperator +from dbgpt.rag.operator.datasource import DatasourceRetrieverOperator + + +def _create_temporary_connection(): + """Create a temporary database connection for testing.""" + connect = SQLiteTempConnect.create_temporary_db() + connect.create_temp_tables( + { + "user": { + "columns": { + "id": "INTEGER PRIMARY KEY", + "name": "TEXT", + "age": "INTEGER", + }, + "data": [ + (1, "Tom", 10), + (2, "Jerry", 16), + (3, "Jack", 18), + (4, "Alice", 20), + (5, "Bob", 22), + ], + } + } + ) + return connect + + +def _sql_prompt() -> str: + """This is a prompt template for SQL generation. + + Format of arguments: + {db_name}: database name + {table_info}: table structure information + {dialect}: database dialect + {top_k}: maximum number of results + {user_input}: user question + {response}: response format + + Returns: + str: prompt template + """ + return """Please answer the user's question based on the database selected by the user and some of the available table structure definitions of the database. +Database name: + {db_name} + +Table structure definition: + {table_info} + +Constraint: + 1.Please understand the user's intention based on the user's question, and use the given table structure definition to create a grammatically correct {dialect} sql. If sql is not required, answer the user's question directly.. + 2.Always limit the query to a maximum of {top_k} results unless the user specifies in the question the specific number of rows of data he wishes to obtain. + 3.You can only use the tables provided in the table structure information to generate sql. If you cannot generate sql based on the provided table structure, please say: "The table structure information provided is not enough to generate sql queries." It is prohibited to fabricate information at will. + 4.Please be careful not to mistake the relationship between tables and columns when generating SQL. + 5.Please check the correctness of the SQL and ensure that the query performance is optimized under correct conditions. + +User Question: + {user_input} +Please think step by step and respond according to the following JSON format: + {response} +Ensure the response is correct json and can be parsed by Python json.loads. +""" + + +def _join_func(query_dict: Dict, db_summary: List[str]): + """Join function for JoinOperator. + + Build the format arguments for the prompt template. + + Args: + query_dict (Dict): The query dict from DAG input. + db_summary (List[str]): The table structure information from DatasourceRetrieverOperator. + + Returns: + Dict: The query dict with the format arguments. + """ + default_response = { + "thoughts": "thoughts summary to say to user", + "sql": "SQL Query to run", + } + response = json.dumps(default_response, ensure_ascii=False, indent=4) + query_dict["table_info"] = db_summary + query_dict["response"] = response + return query_dict + + +class SQLResultOperator(JoinOperator[Dict]): + """Merge the SQL result and the model result.""" + + def __init__(self, **kwargs): + super().__init__(combine_function=self._combine_result, **kwargs) + + def _combine_result(self, sql_result_df, model_result: Dict) -> Dict: + model_result["data_df"] = sql_result_df + return model_result + + +with DAG("simple_sdk_llm_sql_example") as dag: + db_connection = _create_temporary_connection() + input_task = InputOperator(input_source=SimpleCallDataInputSource()) + retriever_task = DatasourceRetrieverOperator(connection=db_connection) + # Merge the input data and the table structure information. + prompt_input_task = JoinOperator(combine_function=_join_func) + prompt_task = PromptTemplate.from_template(_sql_prompt()) + model_pre_handle_task = RequestBuildOperator(model="gpt-3.5-turbo") + llm_task = OpenAILLM() + out_parse_task = SQLOutputParser() + sql_parse_task = MapOperator(map_function=lambda x: x["sql"]) + db_query_task = DatasourceOperator(connection=db_connection) + sql_result_task = SQLResultOperator() + input_task >> prompt_input_task + input_task >> retriever_task >> prompt_input_task + ( + prompt_input_task + >> prompt_task + >> model_pre_handle_task + >> llm_task + >> out_parse_task + >> sql_parse_task + >> db_query_task + >> sql_result_task + ) + out_parse_task >> sql_result_task + + +if __name__ == "__main__": + input_data = { + "data": { + "db_name": "test_db", + "dialect": "sqlite", + "top_k": 5, + "user_input": "What is the name and age of the user with age less than 18", + } + } + output = asyncio.run(sql_result_task.call(call_data=input_data)) + print(f"\nthoughts: {output.get('thoughts')}\n") + print(f"sql: {output.get('sql')}\n") + print(f"result data:\n{output.get('data_df')}")