diff --git a/assets/schema/dbgpt.sql b/assets/schema/dbgpt.sql index 4621faf79..1b8315402 100644 --- a/assets/schema/dbgpt.sql +++ b/assets/schema/dbgpt.sql @@ -278,6 +278,7 @@ CREATE TABLE `dbgpt_serve_flow` ( `flow_category` varchar(64) DEFAULT NULL COMMENT 'Flow category', `description` varchar(512) DEFAULT NULL COMMENT 'Flow description', `state` varchar(32) DEFAULT NULL COMMENT 'Flow state', + `error_message` varchar(512) NULL comment 'Error message', `source` varchar(64) DEFAULT NULL COMMENT 'Flow source', `source_url` varchar(512) DEFAULT NULL COMMENT 'Flow source url', `version` varchar(32) DEFAULT NULL COMMENT 'Flow version', diff --git a/assets/schema/upgrade/v0_5_2/upgrade_to_v0.5.2.sql b/assets/schema/upgrade/v0_5_2/upgrade_to_v0.5.2.sql new file mode 100644 index 000000000..e69de29bb diff --git a/assets/schema/upgrade/v0_5_2/v0.5.1.sql b/assets/schema/upgrade/v0_5_2/v0.5.1.sql new file mode 100644 index 000000000..5096ee185 --- /dev/null +++ b/assets/schema/upgrade/v0_5_2/v0.5.1.sql @@ -0,0 +1,395 @@ +-- Full SQL of v0.5.1, please not modify this file(It must be same as the file in the release package) + +CREATE +DATABASE IF NOT EXISTS dbgpt; +use dbgpt; + +-- For alembic migration tool +CREATE TABLE IF NOT EXISTS `alembic_version` +( + version_num VARCHAR(32) NOT NULL, + CONSTRAINT alembic_version_pkc PRIMARY KEY (version_num) +) DEFAULT CHARSET=utf8mb4 ; + +CREATE TABLE IF NOT EXISTS `knowledge_space` +( + `id` int NOT NULL AUTO_INCREMENT COMMENT 'auto increment id', + `name` varchar(100) NOT NULL COMMENT 'knowledge space name', + `vector_type` varchar(50) NOT NULL COMMENT 'vector type', + `desc` varchar(500) NOT NULL COMMENT 'description', + `owner` varchar(100) DEFAULT NULL COMMENT 'owner', + `context` TEXT DEFAULT NULL COMMENT 'context argument', + `gmt_created` TIMESTAMP DEFAULT CURRENT_TIMESTAMP COMMENT 'created time', + `gmt_modified` TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT 'update time', + PRIMARY KEY (`id`), + KEY `idx_name` (`name`) COMMENT 'index:idx_name' +) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8mb4 COMMENT='knowledge space table'; + +CREATE TABLE IF NOT EXISTS `knowledge_document` +( + `id` int NOT NULL AUTO_INCREMENT COMMENT 'auto increment id', + `doc_name` varchar(100) NOT NULL COMMENT 'document path name', + `doc_type` varchar(50) NOT NULL COMMENT 'doc type', + `space` varchar(50) NOT NULL COMMENT 'knowledge space', + `chunk_size` int NOT NULL COMMENT 'chunk size', + `last_sync` TIMESTAMP DEFAULT CURRENT_TIMESTAMP COMMENT 'last sync time', + `status` varchar(50) NOT NULL COMMENT 'status TODO,RUNNING,FAILED,FINISHED', + `content` LONGTEXT NOT NULL COMMENT 'knowledge embedding sync result', + `result` TEXT NULL COMMENT 'knowledge content', + `vector_ids` LONGTEXT NULL COMMENT 'vector_ids', + `summary` LONGTEXT NULL COMMENT 'knowledge summary', + `gmt_created` TIMESTAMP DEFAULT CURRENT_TIMESTAMP COMMENT 'created time', + `gmt_modified` TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT 'update time', + PRIMARY KEY (`id`), + KEY `idx_doc_name` (`doc_name`) COMMENT 'index:idx_doc_name' +) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8mb4 COMMENT='knowledge document table'; + +CREATE TABLE IF NOT EXISTS `document_chunk` +( + `id` int NOT NULL AUTO_INCREMENT COMMENT 'auto increment id', + `doc_name` varchar(100) NOT NULL COMMENT 'document path name', + `doc_type` varchar(50) NOT NULL COMMENT 'doc type', + `document_id` int NOT NULL COMMENT 'document parent id', + `content` longtext NOT NULL 
COMMENT 'chunk content', + `meta_info` varchar(200) NOT NULL COMMENT 'metadata info', + `gmt_created` timestamp NULL DEFAULT CURRENT_TIMESTAMP COMMENT 'created time', + `gmt_modified` timestamp NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT 'update time', + PRIMARY KEY (`id`), + KEY `idx_document_id` (`document_id`) COMMENT 'index:document_id' +) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8mb4 COMMENT='knowledge document chunk detail'; + + + +CREATE TABLE IF NOT EXISTS `connect_config` +( + `id` int NOT NULL AUTO_INCREMENT COMMENT 'autoincrement id', + `db_type` varchar(255) NOT NULL COMMENT 'db type', + `db_name` varchar(255) NOT NULL COMMENT 'db name', + `db_path` varchar(255) DEFAULT NULL COMMENT 'file db path', + `db_host` varchar(255) DEFAULT NULL COMMENT 'db connect host(not file db)', + `db_port` varchar(255) DEFAULT NULL COMMENT 'db cnnect port(not file db)', + `db_user` varchar(255) DEFAULT NULL COMMENT 'db user', + `db_pwd` varchar(255) DEFAULT NULL COMMENT 'db password', + `comment` text COMMENT 'db comment', + `sys_code` varchar(128) DEFAULT NULL COMMENT 'System code', + PRIMARY KEY (`id`), + UNIQUE KEY `uk_db` (`db_name`), + KEY `idx_q_db_type` (`db_type`) +) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8mb4 COMMENT 'Connection confi'; + +CREATE TABLE IF NOT EXISTS `chat_history` +( + `id` int NOT NULL AUTO_INCREMENT COMMENT 'autoincrement id', + `conv_uid` varchar(255) COLLATE utf8mb4_unicode_ci NOT NULL COMMENT 'Conversation record unique id', + `chat_mode` varchar(255) COLLATE utf8mb4_unicode_ci NOT NULL COMMENT 'Conversation scene mode', + `summary` varchar(255) COLLATE utf8mb4_unicode_ci NOT NULL COMMENT 'Conversation record summary', + `user_name` varchar(255) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'interlocutor', + `messages` text COLLATE utf8mb4_unicode_ci COMMENT 'Conversation details', + `message_ids` text COLLATE utf8mb4_unicode_ci COMMENT 'Message id list, split by comma', + `sys_code` varchar(128) DEFAULT NULL COMMENT 'System code', + `gmt_created` timestamp NULL DEFAULT CURRENT_TIMESTAMP COMMENT 'created time', + `gmt_modified` timestamp NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT 'update time', + UNIQUE KEY `conv_uid` (`conv_uid`), + PRIMARY KEY (`id`) +) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci COMMENT 'Chat history'; + +CREATE TABLE IF NOT EXISTS `chat_history_message` +( + `id` int NOT NULL AUTO_INCREMENT COMMENT 'autoincrement id', + `conv_uid` varchar(255) COLLATE utf8mb4_unicode_ci NOT NULL COMMENT 'Conversation record unique id', + `index` int NOT NULL COMMENT 'Message index', + `round_index` int NOT NULL COMMENT 'Round of conversation', + `message_detail` text COLLATE utf8mb4_unicode_ci COMMENT 'Message details, json format', + `gmt_created` timestamp NULL DEFAULT CURRENT_TIMESTAMP COMMENT 'created time', + `gmt_modified` timestamp NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT 'update time', + UNIQUE KEY `message_uid_index` (`conv_uid`, `index`), + PRIMARY KEY (`id`) +) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci COMMENT 'Chat history message'; + +CREATE TABLE IF NOT EXISTS `chat_feed_back` +( + `id` bigint(20) NOT NULL AUTO_INCREMENT, + `conv_uid` varchar(128) DEFAULT NULL COMMENT 'Conversation ID', + `conv_index` int(4) DEFAULT NULL COMMENT 'Round of conversation', + `score` int(1) DEFAULT NULL COMMENT 'Score of user', + `ques_type` varchar(32) DEFAULT NULL COMMENT 'User question category', 
+ `question` longtext DEFAULT NULL COMMENT 'User question', + `knowledge_space` varchar(128) DEFAULT NULL COMMENT 'Knowledge space name', + `messages` longtext DEFAULT NULL COMMENT 'The details of user feedback', + `user_name` varchar(128) DEFAULT NULL COMMENT 'User name', + `gmt_created` timestamp NULL DEFAULT CURRENT_TIMESTAMP COMMENT 'created time', + `gmt_modified` timestamp NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT 'update time', + PRIMARY KEY (`id`), + UNIQUE KEY `uk_conv` (`conv_uid`,`conv_index`), + KEY `idx_conv` (`conv_uid`,`conv_index`) +) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8mb4 COMMENT='User feedback table'; + + +CREATE TABLE IF NOT EXISTS `my_plugin` +( + `id` int NOT NULL AUTO_INCREMENT COMMENT 'autoincrement id', + `tenant` varchar(255) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'user tenant', + `user_code` varchar(255) COLLATE utf8mb4_unicode_ci NOT NULL COMMENT 'user code', + `user_name` varchar(255) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'user name', + `name` varchar(255) COLLATE utf8mb4_unicode_ci NOT NULL COMMENT 'plugin name', + `file_name` varchar(255) COLLATE utf8mb4_unicode_ci NOT NULL COMMENT 'plugin package file name', + `type` varchar(255) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'plugin type', + `version` varchar(255) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'plugin version', + `use_count` int DEFAULT NULL COMMENT 'plugin total use count', + `succ_count` int DEFAULT NULL COMMENT 'plugin total success count', + `sys_code` varchar(128) DEFAULT NULL COMMENT 'System code', + `gmt_created` TIMESTAMP DEFAULT CURRENT_TIMESTAMP COMMENT 'plugin install time', + PRIMARY KEY (`id`), + UNIQUE KEY `name` (`name`) +) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci COMMENT='User plugin table'; + +CREATE TABLE IF NOT EXISTS `plugin_hub` +( + `id` int NOT NULL AUTO_INCREMENT COMMENT 'autoincrement id', + `name` varchar(255) COLLATE utf8mb4_unicode_ci NOT NULL COMMENT 'plugin name', + `description` varchar(255) COLLATE utf8mb4_unicode_ci NOT NULL COMMENT 'plugin description', + `author` varchar(255) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'plugin author', + `email` varchar(255) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'plugin author email', + `type` varchar(255) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'plugin type', + `version` varchar(255) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'plugin version', + `storage_channel` varchar(255) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'plugin storage channel', + `storage_url` varchar(255) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'plugin download url', + `download_param` varchar(255) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'plugin download param', + `gmt_created` TIMESTAMP DEFAULT CURRENT_TIMESTAMP COMMENT 'plugin upload time', + `installed` int DEFAULT NULL COMMENT 'plugin already installed count', + PRIMARY KEY (`id`), + UNIQUE KEY `name` (`name`) +) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci COMMENT='Plugin Hub table'; + + +CREATE TABLE IF NOT EXISTS `prompt_manage` +( + `id` int(11) NOT NULL AUTO_INCREMENT, + `chat_scene` varchar(100) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'Chat scene', + `sub_chat_scene` varchar(100) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'Sub chat scene', + `prompt_type` varchar(100) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'Prompt type: common or private', + `prompt_name` varchar(256) COLLATE 
utf8mb4_unicode_ci DEFAULT NULL COMMENT 'prompt name', + `content` longtext COLLATE utf8mb4_unicode_ci COMMENT 'Prompt content', + `input_variables` varchar(1024) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'Prompt input variables(split by comma))', + `model` varchar(128) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'Prompt model name(we can use different models for different prompt)', + `prompt_language` varchar(32) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'Prompt language(eg:en, zh-cn)', + `prompt_format` varchar(32) COLLATE utf8mb4_unicode_ci DEFAULT 'f-string' COMMENT 'Prompt format(eg: f-string, jinja2)', + `prompt_desc` varchar(512) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'Prompt description', + `user_name` varchar(128) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'User name', + `sys_code` varchar(128) DEFAULT NULL COMMENT 'System code', + `gmt_created` timestamp NULL DEFAULT CURRENT_TIMESTAMP COMMENT 'created time', + `gmt_modified` timestamp NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT 'update time', + PRIMARY KEY (`id`), + UNIQUE KEY `prompt_name_uiq` (`prompt_name`, `sys_code`, `prompt_language`, `model`), + KEY `gmt_created_idx` (`gmt_created`) +) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci COMMENT='Prompt management table'; + + CREATE TABLE IF NOT EXISTS `gpts_conversations` ( + `id` int(11) NOT NULL AUTO_INCREMENT COMMENT 'autoincrement id', + `conv_id` varchar(255) NOT NULL COMMENT 'The unique id of the conversation record', + `user_goal` text NOT NULL COMMENT 'User''s goals content', + `gpts_name` varchar(255) NOT NULL COMMENT 'The gpts name', + `state` varchar(255) DEFAULT NULL COMMENT 'The gpts state', + `max_auto_reply_round` int(11) NOT NULL COMMENT 'max auto reply round', + `auto_reply_count` int(11) NOT NULL COMMENT 'auto reply count', + `user_code` varchar(255) DEFAULT NULL COMMENT 'user code', + `sys_code` varchar(255) DEFAULT NULL COMMENT 'system app ', + `created_at` datetime DEFAULT NULL COMMENT 'create time', + `updated_at` datetime DEFAULT NULL COMMENT 'last update time', + `team_mode` varchar(255) NULL COMMENT 'agent team work mode', + + PRIMARY KEY (`id`), + UNIQUE KEY `uk_gpts_conversations` (`conv_id`), + KEY `idx_gpts_name` (`gpts_name`) +) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8mb4 COMMENT="gpt conversations"; + +CREATE TABLE IF NOT EXISTS `gpts_instance` ( + `id` int(11) NOT NULL AUTO_INCREMENT COMMENT 'autoincrement id', + `gpts_name` varchar(255) NOT NULL COMMENT 'Current AI assistant name', + `gpts_describe` varchar(2255) NOT NULL COMMENT 'Current AI assistant describe', + `resource_db` text COMMENT 'List of structured database names contained in the current gpts', + `resource_internet` text COMMENT 'Is it possible to retrieve information from the internet', + `resource_knowledge` text COMMENT 'List of unstructured database names contained in the current gpts', + `gpts_agents` varchar(1000) DEFAULT NULL COMMENT 'List of agents names contained in the current gpts', + `gpts_models` varchar(1000) DEFAULT NULL COMMENT 'List of llm model names contained in the current gpts', + `language` varchar(100) DEFAULT NULL COMMENT 'gpts language', + `user_code` varchar(255) NOT NULL COMMENT 'user code', + `sys_code` varchar(255) DEFAULT NULL COMMENT 'system app code', + `created_at` datetime DEFAULT NULL COMMENT 'create time', + `updated_at` datetime DEFAULT NULL COMMENT 'last update time', + `team_mode` varchar(255) NOT NULL COMMENT 'Team work mode', + `is_sustainable` 
tinyint(1) NOT NULL COMMENT 'Applications for sustainable dialogue', + PRIMARY KEY (`id`), + UNIQUE KEY `uk_gpts` (`gpts_name`) +) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8mb4 COMMENT="gpts instance"; + +CREATE TABLE `gpts_messages` ( + `id` int(11) NOT NULL AUTO_INCREMENT COMMENT 'autoincrement id', + `conv_id` varchar(255) NOT NULL COMMENT 'The unique id of the conversation record', + `sender` varchar(255) NOT NULL COMMENT 'Who speaking in the current conversation turn', + `receiver` varchar(255) NOT NULL COMMENT 'Who receive message in the current conversation turn', + `model_name` varchar(255) DEFAULT NULL COMMENT 'message generate model', + `rounds` int(11) NOT NULL COMMENT 'dialogue turns', + `content` text COMMENT 'Content of the speech', + `current_goal` text COMMENT 'The target corresponding to the current message', + `context` text COMMENT 'Current conversation context', + `review_info` text COMMENT 'Current conversation review info', + `action_report` text COMMENT 'Current conversation action report', + `role` varchar(255) DEFAULT NULL COMMENT 'The role of the current message content', + `created_at` datetime DEFAULT NULL COMMENT 'create time', + `updated_at` datetime DEFAULT NULL COMMENT 'last update time', + PRIMARY KEY (`id`), + KEY `idx_q_messages` (`conv_id`,`rounds`,`sender`) +) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8mb4 COMMENT="gpts message"; + + +CREATE TABLE `gpts_plans` ( + `id` int(11) NOT NULL AUTO_INCREMENT COMMENT 'autoincrement id', + `conv_id` varchar(255) NOT NULL COMMENT 'The unique id of the conversation record', + `sub_task_num` int(11) NOT NULL COMMENT 'Subtask number', + `sub_task_title` varchar(255) NOT NULL COMMENT 'subtask title', + `sub_task_content` text NOT NULL COMMENT 'subtask content', + `sub_task_agent` varchar(255) DEFAULT NULL COMMENT 'Available agents corresponding to subtasks', + `resource_name` varchar(255) DEFAULT NULL COMMENT 'resource name', + `rely` varchar(255) DEFAULT NULL COMMENT 'Subtask dependencies,like: 1,2,3', + `agent_model` varchar(255) DEFAULT NULL COMMENT 'LLM model used by subtask processing agents', + `retry_times` int(11) DEFAULT NULL COMMENT 'number of retries', + `max_retry_times` int(11) DEFAULT NULL COMMENT 'Maximum number of retries', + `state` varchar(255) DEFAULT NULL COMMENT 'subtask status', + `result` longtext COMMENT 'subtask result', + `created_at` datetime DEFAULT NULL COMMENT 'create time', + `updated_at` datetime DEFAULT NULL COMMENT 'last update time', + PRIMARY KEY (`id`), + UNIQUE KEY `uk_sub_task` (`conv_id`,`sub_task_num`) +) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8mb4 COMMENT="gpt plan"; + +-- dbgpt.dbgpt_serve_flow definition +CREATE TABLE `dbgpt_serve_flow` ( + `id` int NOT NULL AUTO_INCREMENT COMMENT 'Auto increment id', + `uid` varchar(128) NOT NULL COMMENT 'Unique id', + `dag_id` varchar(128) DEFAULT NULL COMMENT 'DAG id', + `name` varchar(128) DEFAULT NULL COMMENT 'Flow name', + `flow_data` text COMMENT 'Flow data, JSON format', + `user_name` varchar(128) DEFAULT NULL COMMENT 'User name', + `sys_code` varchar(128) DEFAULT NULL COMMENT 'System code', + `gmt_created` datetime DEFAULT NULL COMMENT 'Record creation time', + `gmt_modified` datetime DEFAULT NULL COMMENT 'Record update time', + `flow_category` varchar(64) DEFAULT NULL COMMENT 'Flow category', + `description` varchar(512) DEFAULT NULL COMMENT 'Flow description', + `state` varchar(32) DEFAULT NULL COMMENT 'Flow state', + `error_message` varchar(512) NULL comment 'Error message', + `source` varchar(64) 
DEFAULT NULL COMMENT 'Flow source', + `source_url` varchar(512) DEFAULT NULL COMMENT 'Flow source url', + `version` varchar(32) DEFAULT NULL COMMENT 'Flow version', + `label` varchar(128) DEFAULT NULL COMMENT 'Flow label', + `editable` int DEFAULT NULL COMMENT 'Editable, 0: editable, 1: not editable', + PRIMARY KEY (`id`), + UNIQUE KEY `uk_uid` (`uid`), + KEY `ix_dbgpt_serve_flow_sys_code` (`sys_code`), + KEY `ix_dbgpt_serve_flow_uid` (`uid`), + KEY `ix_dbgpt_serve_flow_dag_id` (`dag_id`), + KEY `ix_dbgpt_serve_flow_user_name` (`user_name`), + KEY `ix_dbgpt_serve_flow_name` (`name`) +) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci; + +-- dbgpt.gpts_app definition +CREATE TABLE `gpts_app` ( + `id` int NOT NULL AUTO_INCREMENT COMMENT 'autoincrement id', + `app_code` varchar(255) NOT NULL COMMENT 'Current AI assistant code', + `app_name` varchar(255) NOT NULL COMMENT 'Current AI assistant name', + `app_describe` varchar(2255) NOT NULL COMMENT 'Current AI assistant describe', + `language` varchar(100) NOT NULL COMMENT 'gpts language', + `team_mode` varchar(255) NOT NULL COMMENT 'Team work mode', + `team_context` text COMMENT 'The execution logic and team member content that teams with different working modes rely on', + `user_code` varchar(255) DEFAULT NULL COMMENT 'user code', + `sys_code` varchar(255) DEFAULT NULL COMMENT 'system app code', + `created_at` datetime DEFAULT NULL COMMENT 'create time', + `updated_at` datetime DEFAULT NULL COMMENT 'last update time', + `icon` varchar(1024) DEFAULT NULL COMMENT 'app icon, url', + PRIMARY KEY (`id`), + UNIQUE KEY `uk_gpts_app` (`app_name`) +) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci; + +CREATE TABLE `gpts_app_collection` ( + `id` int(11) NOT NULL AUTO_INCREMENT COMMENT 'autoincrement id', + `app_code` varchar(255) NOT NULL COMMENT 'Current AI assistant code', + `user_code` int(11) NOT NULL COMMENT 'user code', + `sys_code` varchar(255) NOT NULL COMMENT 'system app code', + `created_at` datetime DEFAULT NULL COMMENT 'create time', + `updated_at` datetime DEFAULT NULL COMMENT 'last update time', + PRIMARY KEY (`id`), + KEY `idx_app_code` (`app_code`), + KEY `idx_user_code` (`user_code`) +) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8mb4 COMMENT="gpt collections"; + +-- dbgpt.gpts_app_detail definition +CREATE TABLE `gpts_app_detail` ( + `id` int NOT NULL AUTO_INCREMENT COMMENT 'autoincrement id', + `app_code` varchar(255) NOT NULL COMMENT 'Current AI assistant code', + `app_name` varchar(255) NOT NULL COMMENT 'Current AI assistant name', + `agent_name` varchar(255) NOT NULL COMMENT ' Agent name', + `node_id` varchar(255) NOT NULL COMMENT 'Current AI assistant Agent Node id', + `resources` text COMMENT 'Agent bind resource', + `prompt_template` text COMMENT 'Agent bind template', + `llm_strategy` varchar(25) DEFAULT NULL COMMENT 'Agent use llm strategy', + `llm_strategy_value` text COMMENT 'Agent use llm strategy value', + `created_at` datetime DEFAULT NULL COMMENT 'create time', + `updated_at` datetime DEFAULT NULL COMMENT 'last update time', + PRIMARY KEY (`id`), + UNIQUE KEY `uk_gpts_app_agent_node` (`app_name`,`agent_name`,`node_id`) +) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci; + +CREATE +DATABASE IF NOT EXISTS EXAMPLE_1; +use EXAMPLE_1; +CREATE TABLE IF NOT EXISTS `users` +( + `id` int NOT NULL AUTO_INCREMENT, + `username` varchar(50) NOT NULL COMMENT '用户名', + `password` varchar(50) NOT NULL COMMENT '密码', + `email` 
varchar(50) NOT NULL COMMENT '邮箱', + `phone` varchar(20) DEFAULT NULL COMMENT '电话', + PRIMARY KEY (`id`), + KEY `idx_username` (`username`) COMMENT '索引:按用户名查询' +) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8mb4 COMMENT='聊天用户表'; + +INSERT INTO users (username, password, email, phone) +VALUES ('user_1', 'password_1', 'user_1@example.com', '12345678901'); +INSERT INTO users (username, password, email, phone) +VALUES ('user_2', 'password_2', 'user_2@example.com', '12345678902'); +INSERT INTO users (username, password, email, phone) +VALUES ('user_3', 'password_3', 'user_3@example.com', '12345678903'); +INSERT INTO users (username, password, email, phone) +VALUES ('user_4', 'password_4', 'user_4@example.com', '12345678904'); +INSERT INTO users (username, password, email, phone) +VALUES ('user_5', 'password_5', 'user_5@example.com', '12345678905'); +INSERT INTO users (username, password, email, phone) +VALUES ('user_6', 'password_6', 'user_6@example.com', '12345678906'); +INSERT INTO users (username, password, email, phone) +VALUES ('user_7', 'password_7', 'user_7@example.com', '12345678907'); +INSERT INTO users (username, password, email, phone) +VALUES ('user_8', 'password_8', 'user_8@example.com', '12345678908'); +INSERT INTO users (username, password, email, phone) +VALUES ('user_9', 'password_9', 'user_9@example.com', '12345678909'); +INSERT INTO users (username, password, email, phone) +VALUES ('user_10', 'password_10', 'user_10@example.com', '12345678900'); +INSERT INTO users (username, password, email, phone) +VALUES ('user_11', 'password_11', 'user_11@example.com', '12345678901'); +INSERT INTO users (username, password, email, phone) +VALUES ('user_12', 'password_12', 'user_12@example.com', '12345678902'); +INSERT INTO users (username, password, email, phone) +VALUES ('user_13', 'password_13', 'user_13@example.com', '12345678903'); +INSERT INTO users (username, password, email, phone) +VALUES ('user_14', 'password_14', 'user_14@example.com', '12345678904'); +INSERT INTO users (username, password, email, phone) +VALUES ('user_15', 'password_15', 'user_15@example.com', '12345678905'); +INSERT INTO users (username, password, email, phone) +VALUES ('user_16', 'password_16', 'user_16@example.com', '12345678906'); +INSERT INTO users (username, password, email, phone) +VALUES ('user_17', 'password_17', 'user_17@example.com', '12345678907'); +INSERT INTO users (username, password, email, phone) +VALUES ('user_18', 'password_18', 'user_18@example.com', '12345678908'); +INSERT INTO users (username, password, email, phone) +VALUES ('user_19', 'password_19', 'user_19@example.com', '12345678909'); +INSERT INTO users (username, password, email, phone) +VALUES ('user_20', 'password_20', 'user_20@example.com', '12345678900'); \ No newline at end of file diff --git a/dbgpt/configs/model_config.py b/dbgpt/configs/model_config.py index d6d60b8b2..d55097bd8 100644 --- a/dbgpt/configs/model_config.py +++ b/dbgpt/configs/model_config.py @@ -167,6 +167,8 @@ def get_device() -> str: # https://huggingface.co/BAAI/bge-large-zh "bge-large-zh": os.path.join(MODEL_PATH, "bge-large-zh"), "bge-base-zh": os.path.join(MODEL_PATH, "bge-base-zh"), + # https://huggingface.co/BAAI/bge-m3, beg need normalize_embeddings=True + "bge-m3": os.path.join(MODEL_PATH, "bge-m3"), "gte-large-zh": os.path.join(MODEL_PATH, "gte-large-zh"), "gte-base-zh": os.path.join(MODEL_PATH, "gte-base-zh"), "sentence-transforms": os.path.join(MODEL_PATH, "all-MiniLM-L6-v2"), diff --git a/dbgpt/core/__init__.py b/dbgpt/core/__init__.py index 
fe955bffb..04c069d94 100644 --- a/dbgpt/core/__init__.py +++ b/dbgpt/core/__init__.py @@ -7,6 +7,7 @@ CachePolicy, CacheValue, ) +from dbgpt.core.interface.embeddings import Embeddings # noqa: F401 from dbgpt.core.interface.llm import ( # noqa: F401 DefaultMessageConverter, LLMClient, @@ -103,4 +104,5 @@ "DefaultStorageItemAdapter", "QuerySpec", "StorageError", + "Embeddings", ] diff --git a/dbgpt/core/awel/__init__.py b/dbgpt/core/awel/__init__.py index c13729d7d..06bce9312 100644 --- a/dbgpt/core/awel/__init__.py +++ b/dbgpt/core/awel/__init__.py @@ -55,6 +55,7 @@ CommonLLMHttpResponseBody, HttpTrigger, ) +from .trigger.iterator_trigger import IteratorTrigger _request_http_trigger_available = False try: @@ -100,6 +101,7 @@ "TransformStreamAbsOperator", "Trigger", "HttpTrigger", + "IteratorTrigger", "CommonLLMHTTPRequestContext", "CommonLLMHttpResponseBody", "CommonLLMHttpRequestBody", diff --git a/dbgpt/core/awel/operators/common_operator.py b/dbgpt/core/awel/operators/common_operator.py index f95db095d..25cc058fa 100644 --- a/dbgpt/core/awel/operators/common_operator.py +++ b/dbgpt/core/awel/operators/common_operator.py @@ -277,7 +277,7 @@ async def _do_run(self, dag_ctx: DAGContext) -> TaskOutput[OUT]: return task_output -class TriggerOperator(InputOperator, Generic[OUT]): +class TriggerOperator(InputOperator[OUT], Generic[OUT]): """Operator node that triggers the DAG to run.""" def __init__(self, **kwargs) -> None: diff --git a/dbgpt/core/awel/runner/local_runner.py b/dbgpt/core/awel/runner/local_runner.py index 8ded92ffc..b22949375 100644 --- a/dbgpt/core/awel/runner/local_runner.py +++ b/dbgpt/core/awel/runner/local_runner.py @@ -60,8 +60,8 @@ async def execute_workflow( streaming_call=streaming_call, node_name_to_ids=job_manager._node_name_to_ids, ) - if node.dag: - self._running_dag_ctx[node.dag.dag_id] = dag_ctx + # if node.dag: + # self._running_dag_ctx[node.dag.dag_id] = dag_ctx logger.info( f"Begin run workflow from end operator, id: {node.node_id}, runner: {self}" ) @@ -76,8 +76,8 @@ async def execute_workflow( if not streaming_call and node.dag: # streaming call not work for dag end await node.dag._after_dag_end() - if node.dag: - del self._running_dag_ctx[node.dag.dag_id] + # if node.dag: + # del self._running_dag_ctx[node.dag.dag_id] return dag_ctx async def _execute_node( diff --git a/dbgpt/core/awel/task/base.py b/dbgpt/core/awel/task/base.py index a0befd5f2..188fa99dd 100644 --- a/dbgpt/core/awel/task/base.py +++ b/dbgpt/core/awel/task/base.py @@ -3,11 +3,13 @@ from enum import Enum from typing import ( Any, + AsyncIterable, AsyncIterator, Awaitable, Callable, Dict, Generic, + Iterable, List, Optional, TypeVar, @@ -421,3 +423,40 @@ async def read(self, task_ctx: TaskContext) -> TaskOutput[T]: Returns: TaskOutput[T]: The output object read from current source """ + + @classmethod + def from_data(cls, data: T) -> "InputSource[T]": + """Create an InputSource from data. + + Args: + data (T): The data to create the InputSource from. + + Returns: + InputSource[T]: The InputSource created from the data. + """ + from .task_impl import SimpleInputSource + + return SimpleInputSource(data, streaming=False) + + @classmethod + def from_iterable( + cls, iterable: Union[AsyncIterable[T], Iterable[T]] + ) -> "InputSource[T]": + """Create an InputSource from an iterable. + + Args: + iterable (List[T]): The iterable to create the InputSource from. + + Returns: + InputSource[T]: The InputSource created from the iterable. 
+ """ + from .task_impl import SimpleInputSource + + return SimpleInputSource(iterable, streaming=True) + + @classmethod + def from_callable(cls) -> "InputSource[T]": + """Create an InputSource from a callable.""" + from .task_impl import SimpleCallDataInputSource + + return SimpleCallDataInputSource() diff --git a/dbgpt/core/awel/task/task_impl.py b/dbgpt/core/awel/task/task_impl.py index 2a79629aa..cfafbeb09 100644 --- a/dbgpt/core/awel/task/task_impl.py +++ b/dbgpt/core/awel/task/task_impl.py @@ -261,13 +261,42 @@ def _is_async_iterator(obj): ) +def _is_async_iterable(obj): + return hasattr(obj, "__aiter__") and callable(getattr(obj, "__aiter__", None)) + + +def _is_iterator(obj): + return ( + hasattr(obj, "__iter__") + and callable(getattr(obj, "__iter__", None)) + and hasattr(obj, "__next__") + and callable(getattr(obj, "__next__", None)) + ) + + +def _is_iterable(obj): + return hasattr(obj, "__iter__") and callable(getattr(obj, "__iter__", None)) + + +async def _to_async_iterator(obj) -> AsyncIterator: + if _is_async_iterable(obj): + async for item in obj: + yield item + elif _is_iterable(obj): + for item in obj: + yield item + else: + raise ValueError(f"Can not convert {obj} to AsyncIterator") + + class BaseInputSource(InputSource, ABC): """The base class of InputSource.""" - def __init__(self) -> None: + def __init__(self, streaming: Optional[bool] = None) -> None: """Create a BaseInputSource.""" super().__init__() self._is_read = False + self._streaming_data = streaming @abstractmethod def _read_data(self, task_ctx: TaskContext) -> Any: @@ -286,10 +315,15 @@ async def read(self, task_ctx: TaskContext) -> TaskOutput: ValueError: If the input source is a stream and has been read. """ data = self._read_data(task_ctx) - if _is_async_iterator(data): + if self._streaming_data is None: + streaming_data = _is_async_iterator(data) or _is_iterator(data) + else: + streaming_data = self._streaming_data + if streaming_data: if self._is_read: raise ValueError(f"Input iterator {data} has been read!") - output: TaskOutput = SimpleStreamTaskOutput(data) + it_data = _to_async_iterator(data) + output: TaskOutput = SimpleStreamTaskOutput(it_data) else: output = SimpleTaskOutput(data) self._is_read = True @@ -299,13 +333,13 @@ async def read(self, task_ctx: TaskContext) -> TaskOutput: class SimpleInputSource(BaseInputSource): """The default implementation of InputSource.""" - def __init__(self, data: Any) -> None: + def __init__(self, data: Any, streaming: Optional[bool] = None) -> None: """Create a SimpleInputSource. Args: data (Any): The input data. 
""" - super().__init__() + super().__init__(streaming=streaming) self._data = data def _read_data(self, task_ctx: TaskContext) -> Any: diff --git a/dbgpt/core/awel/tests/trigger/__init__.py b/dbgpt/core/awel/tests/trigger/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/dbgpt/core/awel/tests/trigger/test_iterator_trigger.py b/dbgpt/core/awel/tests/trigger/test_iterator_trigger.py new file mode 100644 index 000000000..b6ffc1435 --- /dev/null +++ b/dbgpt/core/awel/tests/trigger/test_iterator_trigger.py @@ -0,0 +1,118 @@ +from typing import AsyncIterator + +import pytest + +from dbgpt.core.awel import ( + DAG, + InputSource, + MapOperator, + StreamifyAbsOperator, + TransformStreamAbsOperator, +) +from dbgpt.core.awel.trigger.iterator_trigger import IteratorTrigger + + +class NumberProducerOperator(StreamifyAbsOperator[int, int]): + """Create a stream of numbers from 0 to `n-1`""" + + async def streamify(self, n: int) -> AsyncIterator[int]: + for i in range(n): + yield i + + +class MyStreamingOperator(TransformStreamAbsOperator[int, int]): + async def transform_stream(self, data: AsyncIterator[int]) -> AsyncIterator[int]: + async for i in data: + yield i * i + + +async def _check_stream_results(stream_results, expected_len): + assert len(stream_results) == expected_len + for _, result in stream_results: + i = 0 + async for num in result: + assert num == i * i + i += 1 + + +@pytest.mark.asyncio +async def test_single_data(): + with DAG("test_single_data"): + trigger_task = IteratorTrigger(data=2) + task = MapOperator(lambda x: x * x) + trigger_task >> task + results = await trigger_task.trigger() + assert len(results) == 1 + assert results[0][1] == 4 + + with DAG("test_single_data_stream"): + trigger_task = IteratorTrigger(data=2, streaming_call=True) + number_task = NumberProducerOperator() + task = MyStreamingOperator() + trigger_task >> number_task >> task + stream_results = await trigger_task.trigger() + await _check_stream_results(stream_results, 1) + + +@pytest.mark.asyncio +async def test_list_data(): + with DAG("test_list_data"): + trigger_task = IteratorTrigger(data=[0, 1, 2, 3]) + task = MapOperator(lambda x: x * x) + trigger_task >> task + results = await trigger_task.trigger() + assert len(results) == 4 + assert results == [(0, 0), (1, 1), (2, 4), (3, 9)] + + with DAG("test_list_data_stream"): + trigger_task = IteratorTrigger(data=[0, 1, 2, 3], streaming_call=True) + number_task = NumberProducerOperator() + task = MyStreamingOperator() + trigger_task >> number_task >> task + stream_results = await trigger_task.trigger() + await _check_stream_results(stream_results, 4) + + +@pytest.mark.asyncio +async def test_async_iterator_data(): + async def async_iter(): + for i in range(4): + yield i + + with DAG("test_async_iterator_data"): + trigger_task = IteratorTrigger(data=async_iter()) + task = MapOperator(lambda x: x * x) + trigger_task >> task + results = await trigger_task.trigger() + assert len(results) == 4 + assert results == [(0, 0), (1, 1), (2, 4), (3, 9)] + + with DAG("test_async_iterator_data_stream"): + trigger_task = IteratorTrigger(data=async_iter(), streaming_call=True) + number_task = NumberProducerOperator() + task = MyStreamingOperator() + trigger_task >> number_task >> task + stream_results = await trigger_task.trigger() + await _check_stream_results(stream_results, 4) + + +@pytest.mark.asyncio +async def test_input_source_data(): + with DAG("test_input_source_data"): + trigger_task = IteratorTrigger(data=InputSource.from_iterable([0, 1, 2, 3])) + 
task = MapOperator(lambda x: x * x) + trigger_task >> task + results = await trigger_task.trigger() + assert len(results) == 4 + assert results == [(0, 0), (1, 1), (2, 4), (3, 9)] + + with DAG("test_input_source_data_stream"): + trigger_task = IteratorTrigger( + data=InputSource.from_iterable([0, 1, 2, 3]), + streaming_call=True, + ) + number_task = NumberProducerOperator() + task = MyStreamingOperator() + trigger_task >> number_task >> task + stream_results = await trigger_task.trigger() + await _check_stream_results(stream_results, 4) diff --git a/dbgpt/core/awel/trigger/base.py b/dbgpt/core/awel/trigger/base.py index 309153eed..9b59cd94e 100644 --- a/dbgpt/core/awel/trigger/base.py +++ b/dbgpt/core/awel/trigger/base.py @@ -2,16 +2,18 @@ from __future__ import annotations from abc import ABC, abstractmethod +from typing import Any, Generic from ..operators.common_operator import TriggerOperator +from ..task.base import OUT -class Trigger(TriggerOperator, ABC): +class Trigger(TriggerOperator[OUT], ABC, Generic[OUT]): """Base class for all trigger classes. Now only support http trigger. """ @abstractmethod - async def trigger(self) -> None: + async def trigger(self, **kwargs) -> Any: """Trigger the workflow or a specific operation in the workflow.""" diff --git a/dbgpt/core/awel/trigger/http_trigger.py b/dbgpt/core/awel/trigger/http_trigger.py index f8c39dcb4..adbbec1ed 100644 --- a/dbgpt/core/awel/trigger/http_trigger.py +++ b/dbgpt/core/awel/trigger/http_trigger.py @@ -397,9 +397,9 @@ def __init__( self._end_node: Optional[BaseOperator] = None self._register_to_app = register_to_app - async def trigger(self) -> None: + async def trigger(self, **kwargs) -> Any: """Trigger the DAG. Not used in HttpTrigger.""" - pass + raise NotImplementedError("HttpTrigger does not support trigger directly") def register_to_app(self) -> bool: """Register the trigger to a FastAPI app. diff --git a/dbgpt/core/awel/trigger/iterator_trigger.py b/dbgpt/core/awel/trigger/iterator_trigger.py new file mode 100644 index 000000000..15cbab637 --- /dev/null +++ b/dbgpt/core/awel/trigger/iterator_trigger.py @@ -0,0 +1,148 @@ +"""Trigger for iterator data.""" + +import asyncio +from typing import Any, AsyncIterator, Iterator, List, Optional, Tuple, Union, cast + +from ..operators.base import BaseOperator +from ..task.base import InputSource, TaskState +from ..task.task_impl import DefaultTaskContext, _is_async_iterator, _is_iterable +from .base import Trigger + +IterDataType = Union[InputSource, Iterator, AsyncIterator, Any] + + +async def _to_async_iterator(iter_data: IterDataType, task_id: str) -> AsyncIterator: + """Convert iter_data to an async iterator.""" + if _is_async_iterator(iter_data): + async for item in iter_data: # type: ignore + yield item + elif _is_iterable(iter_data): + for item in iter_data: # type: ignore + yield item + elif isinstance(iter_data, InputSource): + task_ctx: DefaultTaskContext[Any] = DefaultTaskContext( + task_id, TaskState.RUNNING, None + ) + data = await iter_data.read(task_ctx) + if data.is_stream: + async for item in data.output_stream: + yield item + else: + yield data.output + else: + yield iter_data + + +class IteratorTrigger(Trigger): + """Trigger for iterator data. + + Trigger the dag with iterator data. + Return the list of results of the leaf nodes in the dag. + The times of dag running is the length of the iterator data. 
+ """ + + def __init__( + self, + data: IterDataType, + parallel_num: int = 1, + streaming_call: bool = False, + **kwargs + ): + """Create a IteratorTrigger. + + Args: + data (IterDataType): The iterator data. + parallel_num (int, optional): The parallel number of the dag running. + Defaults to 1. + streaming_call (bool, optional): Whether the dag is a streaming call. + Defaults to False. + """ + self._iter_data = data + self._parallel_num = parallel_num + self._streaming_call = streaming_call + super().__init__(**kwargs) + + async def trigger( + self, parallel_num: Optional[int] = None, **kwargs + ) -> List[Tuple[Any, Any]]: + """Trigger the dag with iterator data. + + If the dag is a streaming call, return the list of async iterator. + + Examples: + .. code-block:: python + + import asyncio + from dbgpt.core.awel import DAG, IteratorTrigger, MapOperator + + with DAG("test_dag") as dag: + trigger_task = IteratorTrigger([0, 1, 2, 3]) + task = MapOperator(lambda x: x * x) + trigger_task >> task + results = asyncio.run(trigger_task.trigger()) + # Fist element of the tuple is the input data, the second element is the + # output data of the leaf node. + assert results == [(0, 0), (1, 1), (2, 4), (3, 9)] + + .. code-block:: python + + import asyncio + from datasets import Dataset + from dbgpt.core.awel import ( + DAG, + IteratorTrigger, + MapOperator, + InputSource, + ) + + data_samples = { + "question": ["What is 1+1?", "What is 7*7?"], + "answer": [2, 49], + } + dataset = Dataset.from_dict(data_samples) + with DAG("test_dag_stream") as dag: + trigger_task = IteratorTrigger(InputSource.from_iterable(dataset)) + task = MapOperator(lambda x: x["answer"]) + trigger_task >> task + results = asyncio.run(trigger_task.trigger()) + assert results == [ + ({"question": "What is 1+1?", "answer": 2}, 2), + ({"question": "What is 7*7?", "answer": 49}, 49), + ] + Args: + parallel_num (Optional[int], optional): The parallel number of the dag + running. Defaults to None. + + Returns: + List[Tuple[Any, Any]]: The list of results of the leaf nodes in the dag. + The first element of the tuple is the input data, the second element is + the output data of the leaf node. + """ + dag = self.dag + if not dag: + raise ValueError("DAG is not set for IteratorTrigger") + leaf_nodes = dag.leaf_nodes + if len(leaf_nodes) != 1: + raise ValueError("IteratorTrigger just support one leaf node in dag") + end_node = cast(BaseOperator, leaf_nodes[0]) + streaming_call = self._streaming_call + semaphore = asyncio.Semaphore(parallel_num or self._parallel_num) + task_id = self.node_id + + async def call_stream(call_data: Any): + async for out in await end_node.call_stream(call_data): + yield out + + async def run_node(call_data: Any): + async with semaphore: + if streaming_call: + task_output = call_stream(call_data) + else: + task_output = await end_node.call(call_data) + return call_data, task_output + + tasks = [] + async for data in _to_async_iterator(self._iter_data, task_id): + tasks.append(run_node(data)) + results = await asyncio.gather(*tasks) + return results diff --git a/dbgpt/core/interface/embeddings.py b/dbgpt/core/interface/embeddings.py new file mode 100644 index 000000000..e3756fd1c --- /dev/null +++ b/dbgpt/core/interface/embeddings.py @@ -0,0 +1,32 @@ +"""Interface for embedding models.""" +import asyncio +from abc import ABC, abstractmethod +from typing import List + + +class Embeddings(ABC): + """Interface for embedding models. + + Refer to `Langchain Embeddings `_. 
+ """ + + @abstractmethod + def embed_documents(self, texts: List[str]) -> List[List[float]]: + """Embed search docs.""" + + @abstractmethod + def embed_query(self, text: str) -> List[float]: + """Embed query text.""" + + async def aembed_documents(self, texts: List[str]) -> List[List[float]]: + """Asynchronous Embed search docs.""" + return await asyncio.get_running_loop().run_in_executor( + None, self.embed_documents, texts + ) + + async def aembed_query(self, text: str) -> List[float]: + """Asynchronous Embed query text.""" + return await asyncio.get_running_loop().run_in_executor( + None, self.embed_query, text + ) diff --git a/dbgpt/core/interface/evaluation.py b/dbgpt/core/interface/evaluation.py new file mode 100644 index 000000000..b4e53867c --- /dev/null +++ b/dbgpt/core/interface/evaluation.py @@ -0,0 +1,253 @@ +"""Evaluation module.""" +import asyncio +import string +from abc import ABC, abstractmethod +from typing import ( + TYPE_CHECKING, + Any, + AsyncIterator, + Callable, + Generic, + Iterator, + List, + Optional, + Sequence, + TypeVar, + Union, +) + +from dbgpt._private.pydantic import BaseModel, Field +from dbgpt.util.similarity_util import calculate_cosine_similarity + +from .embeddings import Embeddings +from .llm import LLMClient + +if TYPE_CHECKING: + from dbgpt.core.awel.task.base import InputSource + +QueryType = Union[str, Any] +PredictionType = Union[str, Any] +ContextType = Union[str, Sequence[str], Any] +DatasetType = Union["InputSource", Iterator, AsyncIterator] + + +class BaseEvaluationResult(BaseModel): + """Base evaluation result.""" + + prediction: Optional[PredictionType] = Field( + None, + description="Prediction data(including the output of LLM, the data from " + "retrieval, etc.)", + ) + contexts: Optional[ContextType] = Field(None, description="Context data") + score: Optional[float] = Field(None, description="Score for the prediction") + passing: Optional[bool] = Field( + None, description="Binary evaluation result (passing or not)" + ) + metric_name: Optional[str] = Field(None, description="Name of the metric") + + +class EvaluationResult(BaseEvaluationResult): + """Evaluation result. + + Output of an BaseEvaluator. + """ + + query: Optional[QueryType] = Field(None, description="Query data") + raw_dataset: Optional[Any] = Field(None, description="Raw dataset") + + +Q = TypeVar("Q") +P = TypeVar("P") +C = TypeVar("C") + + +class EvaluationMetric(ABC, Generic[P, C]): + """Base class for evaluation metric.""" + + @property + def name(self) -> str: + """Name of the metric.""" + return self.__class__.__name__ + + async def compute( + self, + prediction: P, + contexts: Optional[Sequence[C]] = None, + ) -> BaseEvaluationResult: + """Compute the evaluation metric. + + Args: + prediction(P): The prediction data. + contexts(Optional[Sequence[C]]): The context data. + + Returns: + BaseEvaluationResult: The evaluation result. + """ + return await asyncio.get_running_loop().run_in_executor( + None, self.sync_compute, prediction, contexts + ) + + def sync_compute( + self, + prediction: P, + contexts: Optional[Sequence[C]] = None, + ) -> BaseEvaluationResult: + """Compute the evaluation metric. + + Args: + prediction(P): The prediction data. + contexts(Optional[Sequence[C]]): The context data. + + Returns: + BaseEvaluationResult: The evaluation result. 
+ """ + raise NotImplementedError("sync_compute is not implemented") + + +class FunctionMetric(EvaluationMetric[P, C], Generic[P, C]): + """Evaluation metric based on a function.""" + + def __init__( + self, + name: str, + func: Callable[ + [P, Optional[Sequence[C]]], + BaseEvaluationResult, + ], + ): + """Create a FunctionMetric. + + Args: + name(str): The name of the metric. + func(Callable[[P, Optional[Sequence[C]]], BaseEvaluationResult]): + The function to use for evaluation. + """ + self._name = name + self.func = func + + @property + def name(self) -> str: + """Name of the metric.""" + return self._name + + async def compute( + self, + prediction: P, + context: Optional[Sequence[C]] = None, + ) -> BaseEvaluationResult: + """Compute the evaluation metric.""" + return self.func(prediction, context) + + +class ExactMatchMetric(EvaluationMetric[str, str]): + """Exact match metric. + + Just support string prediction and context. + """ + + def __init__(self, ignore_case: bool = False, ignore_punctuation: bool = False): + """Create an ExactMatchMetric.""" + self._ignore_case = ignore_case + self._ignore_punctuation = ignore_punctuation + + async def compute( + self, + prediction: str, + contexts: Optional[Sequence[str]] = None, + ) -> BaseEvaluationResult: + """Compute the evaluation metric.""" + if self._ignore_case: + prediction = prediction.lower() + if contexts: + contexts = [c.lower() for c in contexts] + if self._ignore_punctuation: + prediction = prediction.translate(str.maketrans("", "", string.punctuation)) + if contexts: + contexts = [ + c.translate(str.maketrans("", "", string.punctuation)) + for c in contexts + ] + score = 0 if not contexts else float(prediction in contexts) + return BaseEvaluationResult( + prediction=prediction, + contexts=contexts, + score=score, + ) + + +class SimilarityMetric(EvaluationMetric[str, str]): + """Similarity metric. + + Calculate the cosine similarity between a prediction and a list of contexts. + """ + + def __init__(self, embeddings: Embeddings): + """Create a SimilarityMetric with embeddings.""" + self._embeddings = embeddings + + def sync_compute( + self, + prediction: str, + contexts: Optional[Sequence[str]] = None, + ) -> BaseEvaluationResult: + """Compute the evaluation metric.""" + if not contexts: + return BaseEvaluationResult( + prediction=prediction, + contexts=contexts, + score=0.0, + ) + try: + import numpy as np + except ImportError: + raise ImportError("numpy is required for SimilarityMetric") + + similarity: np.ndarray = calculate_cosine_similarity( + self._embeddings, prediction, contexts + ) + return BaseEvaluationResult( + prediction=prediction, + contexts=contexts, + score=float(similarity.mean()), + ) + + +class Evaluator(ABC): + """Base Evaluator class.""" + + def __init__( + self, + llm_client: Optional[LLMClient] = None, + ): + """Create an Evaluator.""" + self.llm_client = llm_client + + @abstractmethod + async def evaluate( + self, + dataset: DatasetType, + metrics: Optional[List[EvaluationMetric]] = None, + query_key: str = "query", + contexts_key: str = "contexts", + prediction_key: str = "prediction", + parallel_num: int = 1, + **kwargs + ) -> List[List[EvaluationResult]]: + """Run evaluation with a dataset and metrics. + + Args: + dataset(DatasetType): The dataset to evaluate. + metrics(Optional[List[EvaluationMetric]]): The metrics to use for + evaluation. + query_key(str): The key for query in the dataset. + contexts_key(str): The key for contexts in the dataset. 
+ prediction_key(str): The key for prediction in the dataset. + parallel_num(int): The number of parallel tasks. + kwargs: Additional arguments. + + Returns: + List[List[EvaluationResult]]: The evaluation results, the length of the + result equals to the length of the dataset. The first element in the + list is the list of evaluation results for metrics. + """ diff --git a/dbgpt/rag/chunk_manager.py b/dbgpt/rag/chunk_manager.py index 550fa6d0e..094876a66 100644 --- a/dbgpt/rag/chunk_manager.py +++ b/dbgpt/rag/chunk_manager.py @@ -5,9 +5,10 @@ from pydantic import BaseModel, Field -from dbgpt.rag.chunk import Chunk +from dbgpt.rag.chunk import Chunk, Document from dbgpt.rag.extractor.base import Extractor from dbgpt.rag.knowledge.base import ChunkStrategy, Knowledge +from dbgpt.rag.text_splitter import TextSplitter class SplitterType(Enum): @@ -81,14 +82,14 @@ def __init__( self._text_splitter = self._chunk_parameters.text_splitter self._splitter_type = self._chunk_parameters.splitter_type - def split(self, documents) -> List[Chunk]: + def split(self, documents: List[Document]) -> List[Chunk]: """Split a document into chunks.""" text_splitter = self._select_text_splitter() if SplitterType.LANGCHAIN == self._splitter_type: documents = text_splitter.split_documents(documents) return [Chunk.langchain2chunk(document) for document in documents] elif SplitterType.LLAMA_INDEX == self._splitter_type: - nodes = text_splitter.split_text(documents) + nodes = text_splitter.split_documents(documents) return [Chunk.llamaindex2chunk(node) for node in nodes] else: return text_splitter.split_documents(documents) @@ -106,7 +107,7 @@ def chunk_parameters(self) -> ChunkParameters: def set_text_splitter( self, - text_splitter, + text_splitter: TextSplitter, splitter_type: SplitterType = SplitterType.LANGCHAIN, ) -> None: """Add text splitter.""" @@ -115,13 +116,13 @@ def set_text_splitter( def get_text_splitter( self, - ) -> Any: + ) -> TextSplitter: """Return text splitter.""" return self._select_text_splitter() def _select_text_splitter( self, - ): + ) -> TextSplitter: """Select text splitter by chunk strategy.""" if self._text_splitter: return self._text_splitter diff --git a/dbgpt/rag/embedding/embeddings.py b/dbgpt/rag/embedding/embeddings.py index f02623b5f..71ac90041 100644 --- a/dbgpt/rag/embedding/embeddings.py +++ b/dbgpt/rag/embedding/embeddings.py @@ -1,13 +1,13 @@ """Embedding implementations.""" -import asyncio -from abc import ABC, abstractmethod + from typing import Any, Dict, List, Optional import aiohttp import requests from dbgpt._private.pydantic import BaseModel, Extra, Field +from dbgpt.core import Embeddings DEFAULT_MODEL_NAME = "sentence-transformers/all-mpnet-base-v2" DEFAULT_INSTRUCT_MODEL = "hkunlp/instructor-large" @@ -22,34 +22,6 @@ DEFAULT_QUERY_BGE_INSTRUCTION_ZH = "为这个句子生成表示以用于检索相关文章:" -class Embeddings(ABC): - """Interface for embedding models. - - Refer to `Langchain Embeddings `_. 
- """ - - @abstractmethod - def embed_documents(self, texts: List[str]) -> List[List[float]]: - """Embed search docs.""" - - @abstractmethod - def embed_query(self, text: str) -> List[float]: - """Embed query text.""" - - async def aembed_documents(self, texts: List[str]) -> List[List[float]]: - """Asynchronous Embed search docs.""" - return await asyncio.get_running_loop().run_in_executor( - None, self.embed_documents, texts - ) - - async def aembed_query(self, text: str) -> List[float]: - """Asynchronous Embed query text.""" - return await asyncio.get_running_loop().run_in_executor( - None, self.embed_query, text - ) - - class HuggingFaceEmbeddings(BaseModel, Embeddings): """HuggingFace sentence_transformers embedding models. diff --git a/dbgpt/rag/evaluation/__init__.py b/dbgpt/rag/evaluation/__init__.py new file mode 100644 index 000000000..9206207a2 --- /dev/null +++ b/dbgpt/rag/evaluation/__init__.py @@ -0,0 +1,13 @@ +"""Module for evaluation of RAG.""" + +from .retriever import ( # noqa: F401 + RetrieverEvaluationMetric, + RetrieverEvaluator, + RetrieverSimilarityMetric, +) + +__ALL__ = [ + "RetrieverEvaluator", + "RetrieverSimilarityMetric", + "RetrieverEvaluationMetric", +] diff --git a/dbgpt/rag/evaluation/retriever.py b/dbgpt/rag/evaluation/retriever.py new file mode 100644 index 000000000..fabf0aa70 --- /dev/null +++ b/dbgpt/rag/evaluation/retriever.py @@ -0,0 +1,171 @@ +"""Evaluation for retriever.""" +from abc import ABC +from typing import Any, Dict, List, Optional, Sequence, Type + +from dbgpt.core import Embeddings, LLMClient +from dbgpt.core.interface.evaluation import ( + BaseEvaluationResult, + DatasetType, + EvaluationMetric, + EvaluationResult, + Evaluator, +) +from dbgpt.core.interface.operators.retriever import RetrieverOperator +from dbgpt.util.similarity_util import calculate_cosine_similarity + +from ..operators.evaluation import RetrieverEvaluatorOperator + + +class RetrieverEvaluationMetric(EvaluationMetric[List[str], str], ABC): + """Evaluation metric for retriever. + + The prediction is a list of str(content from chunks) and the context is a string. + """ + + +class RetrieverSimilarityMetric(RetrieverEvaluationMetric): + """Similarity metric for retriever.""" + + def __init__(self, embeddings: Embeddings): + """Create a SimilarityMetric with embeddings.""" + self._embeddings = embeddings + + def sync_compute( + self, + prediction: List[str], + contexts: Optional[Sequence[str]] = None, + ) -> BaseEvaluationResult: + """Compute the evaluation metric. + + Args: + prediction(List[str]): The retrieved chunks from the retriever. + contexts(Sequence[str]): The contexts from dataset. + + Returns: + BaseEvaluationResult: The evaluation result. + The score is the mean of the cosine similarity between the prediction + and the contexts. + """ + if not prediction or not contexts: + return BaseEvaluationResult( + prediction=prediction, + contexts=contexts, + score=0.0, + ) + try: + import numpy as np + except ImportError: + raise ImportError("numpy is required for RelevancySimilarityMetric") + + similarity: np.ndarray = calculate_cosine_similarity( + self._embeddings, contexts[0], prediction + ) + return BaseEvaluationResult( + prediction=prediction, + contexts=contexts, + score=float(similarity.mean()), + ) + + +class RetrieverEvaluator(Evaluator): + """Evaluator for relevancy. + + Examples: + .. 
code-block:: python + + import os + import asyncio + from dbgpt.rag.operators import ( + EmbeddingRetrieverOperator, + RetrieverEvaluatorOperator, + ) + from dbgpt.rag.evaluation import ( + RetrieverEvaluator, + RetrieverSimilarityMetric, + ) + from dbgpt.rag.embedding import DefaultEmbeddingFactory + from dbgpt.storage.vector_store.chroma_store import ChromaVectorConfig + from dbgpt.storage.vector_store.connector import VectorStoreConnector + from dbgpt.configs.model_config import MODEL_PATH, PILOT_PATH + + embeddings = DefaultEmbeddingFactory( + default_model_name=os.path.join(MODEL_PATH, "text2vec-large-chinese"), + ).create() + vector_connector = VectorStoreConnector.from_default( + "Chroma", + vector_store_config=ChromaVectorConfig( + name="my_test_schema", + persist_path=os.path.join(PILOT_PATH, "data"), + ), + embedding_fn=embeddings, + ) + + dataset = [ + { + "query": "what is awel talk about", + "contexts": [ + "Through the AWEL API, you can focus on the development" + " of business logic for LLMs applications without paying " + "attention to cumbersome model and environment details." + ], + }, + ] + evaluator = RetrieverEvaluator( + operator_cls=EmbeddingRetrieverOperator, + embeddings=embeddings, + operator_kwargs={ + "top_k": 5, + "vector_store_connector": vector_connector, + }, + ) + results = asyncio.run(evaluator.evaluate(dataset)) + """ + + def __init__( + self, + operator_cls: Type[RetrieverOperator], + llm_client: Optional[LLMClient] = None, + embeddings: Optional[Embeddings] = None, + operator_kwargs: Optional[Dict] = None, + ): + """Create a new RetrieverEvaluator.""" + if not operator_kwargs: + operator_kwargs = {} + self._operator_cls = operator_cls + self._operator_kwargs: Dict[str, Any] = operator_kwargs + self.embeddings = embeddings + super().__init__(llm_client=llm_client) + + async def evaluate( + self, + dataset: DatasetType, + metrics: Optional[List[EvaluationMetric]] = None, + query_key: str = "query", + contexts_key: str = "contexts", + prediction_key: str = "prediction", + parallel_num: int = 1, + **kwargs + ) -> List[List[EvaluationResult]]: + """Evaluate the dataset.""" + from dbgpt.core.awel import DAG, IteratorTrigger, MapOperator + + if not metrics: + if not self.embeddings: + raise ValueError("embeddings are required for SimilarityMetric") + metrics = [RetrieverSimilarityMetric(self.embeddings)] + + with DAG("relevancy_evaluation_dag"): + input_task = IteratorTrigger(dataset) + query_task: MapOperator = MapOperator(lambda x: x[query_key]) + retriever_task = self._operator_cls(**self._operator_kwargs) + retriever_eva_task = RetrieverEvaluatorOperator( + evaluation_metrics=metrics, llm_client=self.llm_client + ) + input_task >> query_task + query_task >> retriever_eva_task + query_task >> retriever_task >> retriever_eva_task + input_task >> MapOperator(lambda x: x[contexts_key]) >> retriever_eva_task + input_task >> retriever_eva_task + + results = await input_task.trigger(parallel_num=parallel_num) + return [item for _, item in results] diff --git a/dbgpt/rag/knowledge/base.py b/dbgpt/rag/knowledge/base.py index 21187c61b..e59138fff 100644 --- a/dbgpt/rag/knowledge/base.py +++ b/dbgpt/rag/knowledge/base.py @@ -154,7 +154,7 @@ def __init__( self._type = knowledge_type self._data_loader = data_loader - def load(self): + def load(self) -> List[Document]: """Load knowledge from data_loader.""" documents = self._load() return self._postprocess(documents) @@ -174,7 +174,7 @@ def _postprocess(self, docs: List[Document]) -> List[Document]: return docs 
     @abstractmethod
-    def _load(self):
+    def _load(self) -> List[Document]:
         """Preprocess knowledge from data_loader."""
 
     @classmethod
diff --git a/dbgpt/rag/operators/__init__.py b/dbgpt/rag/operators/__init__.py
index 41bafcbbb..72a3f1a9a 100644
--- a/dbgpt/rag/operators/__init__.py
+++ b/dbgpt/rag/operators/__init__.py
@@ -3,6 +3,7 @@
 from .datasource import DatasourceRetrieverOperator  # noqa: F401
 from .db_schema import DBSchemaRetrieverOperator  # noqa: F401
 from .embedding import EmbeddingRetrieverOperator  # noqa: F401
+from .evaluation import RetrieverEvaluatorOperator  # noqa: F401
 from .knowledge import KnowledgeOperator  # noqa: F401
 from .rerank import RerankOperator  # noqa: F401
 from .rewrite import QueryRewriteOperator  # noqa: F401
@@ -16,4 +17,5 @@
     "RerankOperator",
     "QueryRewriteOperator",
     "SummaryAssemblerOperator",
+    "RetrieverEvaluatorOperator",
 ]
diff --git a/dbgpt/rag/operators/embedding.py b/dbgpt/rag/operators/embedding.py
index a5e2096ab..d21401be2 100644
--- a/dbgpt/rag/operators/embedding.py
+++ b/dbgpt/rag/operators/embedding.py
@@ -1,16 +1,17 @@
 """Embedding retriever operator."""
 from functools import reduce
-from typing import Any, Optional
+from typing import List, Optional, Union
 
 from dbgpt.core.interface.operators.retriever import RetrieverOperator
+from dbgpt.rag.chunk import Chunk
 from dbgpt.rag.retriever.embedding import EmbeddingRetriever
 from dbgpt.rag.retriever.rerank import Ranker
 from dbgpt.rag.retriever.rewrite import QueryRewrite
 from dbgpt.storage.vector_store.connector import VectorStoreConnector
 
 
-class EmbeddingRetrieverOperator(RetrieverOperator[Any, Any]):
+class EmbeddingRetrieverOperator(RetrieverOperator[Union[str, List[str]], List[Chunk]]):
     """The Embedding Retriever Operator."""
 
     def __init__(
@@ -32,7 +33,7 @@ def __init__(
             rerank=rerank,
         )
 
-    def retrieve(self, query: Any) -> Any:
+    def retrieve(self, query: Union[str, List[str]]) -> List[Chunk]:
         """Retrieve the candidates."""
         if isinstance(query, str):
             return self._retriever.retrieve_with_scores(query, self._score_threshold)
diff --git a/dbgpt/rag/operators/evaluation.py b/dbgpt/rag/operators/evaluation.py
new file mode 100644
index 000000000..81c5ac52c
--- /dev/null
+++ b/dbgpt/rag/operators/evaluation.py
@@ -0,0 +1,61 @@
+"""Evaluation operators."""
+import asyncio
+from typing import Any, List, Optional
+
+from dbgpt.core.awel import JoinOperator
+from dbgpt.core.interface.evaluation import EvaluationMetric, EvaluationResult
+from dbgpt.core.interface.llm import LLMClient
+
+from ..chunk import Chunk
+
+
+class RetrieverEvaluatorOperator(JoinOperator[List[EvaluationResult]]):
+    """Evaluator for retriever."""
+
+    def __init__(
+        self,
+        evaluation_metrics: List[EvaluationMetric],
+        llm_client: Optional[LLMClient] = None,
+        **kwargs,
+    ):
+        """Create a new RetrieverEvaluatorOperator."""
+        self.llm_client = llm_client
+        self.evaluation_metrics = evaluation_metrics
+        super().__init__(combine_function=self._do_evaluation, **kwargs)
+
+    async def _do_evaluation(
+        self,
+        query: str,
+        prediction: List[Chunk],
+        contexts: List[str],
+        raw_dataset: Any = None,
+    ) -> List[EvaluationResult]:
+        """Run evaluation.
+
+        Args:
+            query(str): The query string.
+            prediction(List[Chunk]): The retrieved chunks from the retriever.
+            contexts(List[str]): The contexts from dataset.
+            raw_dataset(Any): The raw data (single row) from the dataset.
+ """ + if isinstance(contexts, str): + contexts = [contexts] + prediction_strs = [chunk.content for chunk in prediction] + tasks = [] + for metric in self.evaluation_metrics: + tasks.append(metric.compute(prediction_strs, contexts)) + task_results = await asyncio.gather(*tasks) + results = [] + for result, metric in zip(task_results, self.evaluation_metrics): + results.append( + EvaluationResult( + query=query, + prediction=prediction, + score=result.score, + contexts=contexts, + passing=result.passing, + raw_dataset=raw_dataset, + metric_name=metric.name, + ) + ) + return results diff --git a/dbgpt/serve/agent/team/layout/agent_operator.py b/dbgpt/serve/agent/team/layout/agent_operator.py index 5a4c7c207..86f80c81a 100644 --- a/dbgpt/serve/agent/team/layout/agent_operator.py +++ b/dbgpt/serve/agent/team/layout/agent_operator.py @@ -256,6 +256,6 @@ def __init__( """Initialize a HttpTrigger.""" super().__init__(**kwargs) - async def trigger(self) -> None: + async def trigger(self, **kwargs) -> None: """Trigger the DAG. Not used in HttpTrigger.""" - pass + raise NotImplementedError("Dummy trigger does not support trigger.") diff --git a/dbgpt/serve/rag/assembler/base.py b/dbgpt/serve/rag/assembler/base.py index 31aa3ce25..72501f5b7 100644 --- a/dbgpt/serve/rag/assembler/base.py +++ b/dbgpt/serve/rag/assembler/base.py @@ -44,8 +44,10 @@ def __init__( with root_tracer.start_span("BaseAssembler.load_knowledge", metadata=metadata): self.load_knowledge(self._knowledge) - def load_knowledge(self, knowledge) -> None: + def load_knowledge(self, knowledge: Optional[Knowledge] = None) -> None: """Load knowledge Pipeline.""" + if not knowledge: + raise ValueError("knowledge must be provided.") with root_tracer.start_span("BaseAssembler.knowledge.load"): documents = knowledge.load() with root_tracer.start_span("BaseAssembler.chunk_manager.split"): @@ -56,8 +58,12 @@ def as_retriever(self, **kwargs: Any) -> BaseRetriever: """Return a retriever.""" @abstractmethod - def persist(self, chunks: List[Chunk]) -> None: - """Persist chunks.""" + def persist(self) -> List[str]: + """Persist chunks. + + Returns: + List[str]: List of persisted chunk ids. + """ def get_chunks(self) -> List[Chunk]: """Return chunks.""" diff --git a/dbgpt/serve/rag/assembler/db_schema.py b/dbgpt/serve/rag/assembler/db_schema.py index f7935dcb0..eed45ba34 100644 --- a/dbgpt/serve/rag/assembler/db_schema.py +++ b/dbgpt/serve/rag/assembler/db_schema.py @@ -129,7 +129,11 @@ def get_chunks(self) -> List[Chunk]: return self._chunks def persist(self) -> List[str]: - """Persist chunks into vector store.""" + """Persist chunks into vector store. + + Returns: + List[str]: List of chunk ids. 
+ """ return self._vector_store_connector.load_document(self._chunks) def _extract_info(self, chunks) -> List[Chunk]: diff --git a/dbgpt/serve/rag/assembler/embedding.py b/dbgpt/serve/rag/assembler/embedding.py index bc536147b..b43803a42 100644 --- a/dbgpt/serve/rag/assembler/embedding.py +++ b/dbgpt/serve/rag/assembler/embedding.py @@ -29,7 +29,7 @@ class EmbeddingAssembler(BaseAssembler): def __init__( self, - knowledge: Knowledge = None, + knowledge: Knowledge, chunk_parameters: Optional[ChunkParameters] = None, embedding_model: Optional[str] = None, embedding_factory: Optional[EmbeddingFactory] = None, @@ -69,7 +69,7 @@ def __init__( @classmethod def load_from_knowledge( cls, - knowledge: Knowledge = None, + knowledge: Knowledge, chunk_parameters: Optional[ChunkParameters] = None, embedding_model: Optional[str] = None, embedding_factory: Optional[EmbeddingFactory] = None, @@ -99,7 +99,11 @@ def load_from_knowledge( ) def persist(self) -> List[str]: - """Persist chunks into vector store.""" + """Persist chunks into vector store. + + Returns: + List[str]: List of chunk ids. + """ return self._vector_store_connector.load_document(self._chunks) def _extract_info(self, chunks) -> List[Chunk]: diff --git a/dbgpt/serve/rag/assembler/summary.py b/dbgpt/serve/rag/assembler/summary.py index 927fa024e..5ff3fa8a3 100644 --- a/dbgpt/serve/rag/assembler/summary.py +++ b/dbgpt/serve/rag/assembler/summary.py @@ -32,7 +32,7 @@ class SummaryAssembler(BaseAssembler): def __init__( self, - knowledge: Knowledge = None, + knowledge: Knowledge, chunk_parameters: Optional[ChunkParameters] = None, model_name: Optional[str] = None, llm_client: Optional[LLMClient] = None, @@ -69,7 +69,7 @@ def __init__( @classmethod def load_from_knowledge( cls, - knowledge: Knowledge = None, + knowledge: Knowledge, chunk_parameters: Optional[ChunkParameters] = None, model_name: Optional[str] = None, llm_client: Optional[LLMClient] = None, @@ -104,6 +104,7 @@ async def generate_summary(self) -> str: def persist(self) -> List[str]: """Persist chunks into store.""" + raise NotImplementedError def _extract_info(self, chunks) -> List[Chunk]: """Extract info from chunks.""" diff --git a/dbgpt/util/similarity_util.py b/dbgpt/util/similarity_util.py new file mode 100644 index 000000000..1f382dd27 --- /dev/null +++ b/dbgpt/util/similarity_util.py @@ -0,0 +1,37 @@ +"""Utility functions for calculating similarity.""" +from typing import TYPE_CHECKING, Any, Sequence + +if TYPE_CHECKING: + from dbgpt.core.interface.embeddings import Embeddings + + +def calculate_cosine_similarity( + embeddings: "Embeddings", prediction: str, contexts: Sequence[str] +) -> Any: + """Calculate the cosine similarity between a prediction and a list of contexts. + + Args: + embeddings(Embeddings): The embeddings to use. + prediction(str): The prediction. + contexts(Sequence[str]): The contexts. + + Returns: + numpy.ndarray: The cosine similarity. 
+ """ + try: + import numpy as np + except ImportError: + raise ImportError("numpy is required for SimilarityMetric") + prediction_vec = np.asarray(embeddings.embed_query(prediction)).reshape(1, -1) + context_list = list(contexts) + context_list_vec = np.asarray(embeddings.embed_documents(context_list)).reshape( + len(contexts), -1 + ) + # cos(a,b) = dot(a,b) / (norm(a) * norm(b)) + dot = np.dot(context_list_vec, prediction_vec.T).reshape( + -1, + ) + norm = np.linalg.norm(context_list_vec, axis=1) * np.linalg.norm( + prediction_vec, axis=1 + ) + return dot / norm diff --git a/examples/rag/retriever_evaluation_example.py b/examples/rag/retriever_evaluation_example.py new file mode 100644 index 000000000..92c2386ad --- /dev/null +++ b/examples/rag/retriever_evaluation_example.py @@ -0,0 +1,82 @@ +import asyncio +import os +from typing import Optional + +from dbgpt.configs.model_config import MODEL_PATH, PILOT_PATH, ROOT_PATH +from dbgpt.core import Embeddings +from dbgpt.rag.chunk_manager import ChunkParameters +from dbgpt.rag.embedding import DefaultEmbeddingFactory +from dbgpt.rag.evaluation import RetrieverEvaluator +from dbgpt.rag.knowledge import KnowledgeFactory +from dbgpt.rag.operators import EmbeddingRetrieverOperator +from dbgpt.serve.rag.assembler.embedding import EmbeddingAssembler +from dbgpt.storage.vector_store.chroma_store import ChromaVectorConfig +from dbgpt.storage.vector_store.connector import VectorStoreConnector + + +def _create_embeddings( + model_name: Optional[str] = "text2vec-large-chinese", +) -> Embeddings: + """Create embeddings.""" + return DefaultEmbeddingFactory( + default_model_name=os.path.join(MODEL_PATH, model_name), + ).create() + + +def _create_vector_connector( + embeddings: Embeddings, space_name: str = "retriever_evaluation_example" +) -> VectorStoreConnector: + """Create vector connector.""" + return VectorStoreConnector.from_default( + "Chroma", + vector_store_config=ChromaVectorConfig( + name=space_name, + persist_path=os.path.join(PILOT_PATH, "data"), + ), + embedding_fn=embeddings, + ) + + +async def main(): + file_path = os.path.join(ROOT_PATH, "docs/docs/awel/awel.md") + knowledge = KnowledgeFactory.from_file_path(file_path) + embeddings = _create_embeddings() + vector_connector = _create_vector_connector(embeddings) + chunk_parameters = ChunkParameters(chunk_strategy="CHUNK_BY_SIZE") + # get embedding assembler + assembler = EmbeddingAssembler.load_from_knowledge( + knowledge=knowledge, + chunk_parameters=chunk_parameters, + vector_store_connector=vector_connector, + ) + assembler.persist() + + dataset = [ + { + "query": "what is awel talk about", + "contexts": [ + "Through the AWEL API, you can focus on the development" + " of business logic for LLMs applications without paying " + "attention to cumbersome model and environment details." 
+            ],
+        },
+    ]
+    evaluator = RetrieverEvaluator(
+        operator_cls=EmbeddingRetrieverOperator,
+        embeddings=embeddings,
+        operator_kwargs={
+            "top_k": 5,
+            "vector_store_connector": vector_connector,
+        },
+    )
+    results = await evaluator.evaluate(dataset)
+    for result in results:
+        for metric in result:
+            print("Metric:", metric.metric_name)
+            print("Question:", metric.query)
+            print("Score:", metric.score)
+    print(f"Results:\n{results}")
+
+
+if __name__ == "__main__":
+    asyncio.run(main())
diff --git a/examples/rag/simple_rag_embedding_example.py b/examples/rag/simple_rag_embedding_example.py
index 56e5f959f..358263c60 100644
--- a/examples/rag/simple_rag_embedding_example.py
+++ b/examples/rag/simple_rag_embedding_example.py
@@ -7,7 +7,7 @@
     curl --location --request POST 'http://127.0.0.1:5555/api/v1/awel/trigger/examples/rag/embedding' \
     --header 'Content-Type: application/json' \
     --data-raw '{
-        "url": "https://docs.dbgpt.site/docs/awel"
+        "url": "https://docs.dbgpt.site/docs/latest/awel/"
     }'
 """
diff --git a/scripts/setup_autodl_env.sh b/scripts/setup_autodl_env.sh
index 236c1d35c..e4af3068e 100644
--- a/scripts/setup_autodl_env.sh
+++ b/scripts/setup_autodl_env.sh
@@ -57,6 +57,8 @@ clean_local_data() {
    rm -rf /root/DB-GPT/pilot/message
    rm -f /root/DB-GPT/logs/*
    rm -f /root/DB-GPT/logsDbChatOutputParser.log
+    rm -rf /root/DB-GPT/pilot/meta_data/alembic/versions/*
+    rm -rf /root/DB-GPT/pilot/meta_data/*.db
 }
 
 usage() {
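A note on dbgpt/rag/evaluation/retriever.py above: evaluate() accepts a custom metrics list, so the embedding-based RetrieverSimilarityMetric is not the only option. The sketch below is a minimal illustration of a custom metric built only on the interfaces shown in this patch (sync_compute and BaseEvaluationResult with prediction, contexts and score). The ContextHitMetric name and its substring-hit scoring are hypothetical, and the sketch assumes the EvaluationMetric base class needs no constructor arguments, as RetrieverSimilarityMetric suggests by not calling super().__init__().

.. code-block:: python

    from typing import List, Optional, Sequence

    from dbgpt.core.interface.evaluation import BaseEvaluationResult
    from dbgpt.rag.evaluation import RetrieverEvaluationMetric


    class ContextHitMetric(RetrieverEvaluationMetric):
        """Hypothetical metric: fraction of contexts found inside any retrieved chunk."""

        def sync_compute(
            self,
            prediction: List[str],
            contexts: Optional[Sequence[str]] = None,
        ) -> BaseEvaluationResult:
            if not prediction or not contexts:
                return BaseEvaluationResult(
                    prediction=prediction, contexts=contexts, score=0.0
                )
            # A context counts as a hit when its text appears in at least one chunk.
            hits = sum(1 for ctx in contexts if any(ctx in chunk for chunk in prediction))
            return BaseEvaluationResult(
                prediction=prediction,
                contexts=contexts,
                score=hits / len(contexts),
            )


    # Used together with (or instead of) the default similarity metric:
    # results = await evaluator.evaluate(dataset, metrics=[ContextHitMetric()])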
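On the DAG that evaluate() builds: RetrieverEvaluatorOperator is a JoinOperator, so each upstream edge feeds one positional argument of _do_evaluation (query, prediction, contexts, raw_dataset) in the order the edges are connected. Below is a stripped-down sketch of that wiring with a toy combine function; it only assumes the AWEL behavior already used in this patch (DAG as a context manager, IteratorTrigger over a list, MapOperator with a lambda, JoinOperator(combine_function=...), and trigger(parallel_num=...) returning (context, result) pairs). The names and the toy dataset are illustrative.

.. code-block:: python

    import asyncio

    from dbgpt.core.awel import DAG, IteratorTrigger, JoinOperator, MapOperator


    async def combine(query, contexts):
        # Arguments arrive in the order the upstream tasks are joined below, the same
        # way _do_evaluation receives query, prediction, contexts and raw_dataset.
        return {"query": query, "contexts": contexts}


    with DAG("join_order_demo_dag"):
        input_task = IteratorTrigger([{"query": "q1", "contexts": ["c1", "c2"]}])
        query_task = MapOperator(lambda x: x["query"])
        contexts_task = MapOperator(lambda x: x["contexts"])
        join_task = JoinOperator(combine_function=combine)
        input_task >> query_task >> join_task
        input_task >> contexts_task >> join_task


    async def run():
        results = await input_task.trigger(parallel_num=1)
        # Same unpacking as RetrieverEvaluator.evaluate(): drop the task context.
        return [item for _, item in results]


    print(asyncio.run(run()))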
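calculate_cosine_similarity in dbgpt/util/similarity_util.py only calls embed_query() and embed_documents(), so it can be sanity-checked without a real model. The FakeEmbeddings stub below is illustrative and not part of DB-GPT; with it, the dot/(norm*norm) formula from the comment in the patch yields exactly 1.0 for a parallel vector and 0.0 for an orthogonal one.

.. code-block:: python

    from typing import List

    from dbgpt.util.similarity_util import calculate_cosine_similarity


    class FakeEmbeddings:
        """Illustrative stub mapping known texts to fixed 2-d vectors."""

        _vectors = {
            "query": [1.0, 0.0],
            "same direction": [2.0, 0.0],
            "orthogonal": [0.0, 3.0],
        }

        def embed_query(self, text: str) -> List[float]:
            return self._vectors[text]

        def embed_documents(self, texts: List[str]) -> List[List[float]]:
            return [self._vectors[t] for t in texts]


    similarity = calculate_cosine_similarity(
        FakeEmbeddings(), "query", ["same direction", "orthogonal"]
    )
    print(similarity)                # array([1., 0.])
    print(float(similarity.mean()))  # 0.5, the mean that becomes the metric score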
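Finally, a small usage note on the assembler change: persist() now returns the ids of the chunks written to the vector store, and SummaryAssembler.persist() raises NotImplementedError. A sketch continuing from the setup in examples/rag/retriever_evaluation_example.py (the printed message is illustrative):

.. code-block:: python

    # `assembler` is the EmbeddingAssembler built in retriever_evaluation_example.py.
    chunk_ids = assembler.persist()
    # The ids come from VectorStoreConnector.load_document(); keeping them lets the
    # caller reference the persisted chunks later (any further use of the ids is an
    # assumption, since this patch only shows load_document()).
    print(f"Persisted {len(chunk_ids)} chunks, first ids: {chunk_ids[:3]}")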