diff --git a/.env.template b/.env.template index 44aa2d710..2a281e698 100644 --- a/.env.template +++ b/.env.template @@ -277,6 +277,11 @@ DBGPT_LOG_LEVEL=INFO # ENCRYPT KEY - The key used to encrypt and decrypt the data # ENCRYPT_KEY=your_secret_key +#*******************************************************************# +#** File Server **# +#*******************************************************************# +## The local storage path of the file server, the default is pilot/data/file_server +# FILE_SERVER_LOCAL_STORAGE_PATH = #*******************************************************************# #** Application Config **# diff --git a/.mypy.ini b/.mypy.ini index e2c2bc3ab..52ae00c35 100644 --- a/.mypy.ini +++ b/.mypy.ini @@ -115,3 +115,6 @@ ignore_missing_imports = True [mypy-networkx.*] ignore_missing_imports = True + +[mypy-pypdf.*] +ignore_missing_imports = True diff --git a/assets/schema/dbgpt.sql b/assets/schema/dbgpt.sql index 0cdd7d17e..f0683d5a6 100644 --- a/assets/schema/dbgpt.sql +++ b/assets/schema/dbgpt.sql @@ -295,6 +295,50 @@ CREATE TABLE `dbgpt_serve_flow` ( KEY `ix_dbgpt_serve_flow_name` (`name`) ) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci; +-- dbgpt.dbgpt_serve_file definition +CREATE TABLE `dbgpt_serve_file` ( + `id` int NOT NULL AUTO_INCREMENT COMMENT 'Auto increment id', + `bucket` varchar(255) NOT NULL COMMENT 'Bucket name', + `file_id` varchar(255) NOT NULL COMMENT 'File id', + `file_name` varchar(256) NOT NULL COMMENT 'File name', + `file_size` int DEFAULT NULL COMMENT 'File size', + `storage_type` varchar(32) NOT NULL COMMENT 'Storage type', + `storage_path` varchar(512) NOT NULL COMMENT 'Storage path', + `uri` varchar(512) NOT NULL COMMENT 'File URI', + `custom_metadata` text DEFAULT NULL COMMENT 'Custom metadata, JSON format', + `file_hash` varchar(128) DEFAULT NULL COMMENT 'File hash', + `user_name` varchar(128) DEFAULT NULL COMMENT 'User name', + `sys_code` varchar(128) DEFAULT NULL COMMENT 'System code', + `gmt_created` datetime DEFAULT CURRENT_TIMESTAMP COMMENT 'Record creation time', + `gmt_modified` datetime DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT 'Record update time', + PRIMARY KEY (`id`), + UNIQUE KEY `uk_bucket_file_id` (`bucket`, `file_id`) +) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci; + +-- dbgpt.dbgpt_serve_variables definition +CREATE TABLE `dbgpt_serve_variables` ( + `id` int NOT NULL AUTO_INCREMENT COMMENT 'Auto increment id', + `key` varchar(128) NOT NULL COMMENT 'Variable key', + `name` varchar(128) DEFAULT NULL COMMENT 'Variable name', + `label` varchar(128) DEFAULT NULL COMMENT 'Variable label', + `value` text DEFAULT NULL COMMENT 'Variable value, JSON format', + `value_type` varchar(32) DEFAULT NULL COMMENT 'Variable value type(string, int, float, bool)', + `category` varchar(32) DEFAULT 'common' COMMENT 'Variable category(common or secret)', + `encryption_method` varchar(32) DEFAULT NULL COMMENT 'Variable encryption method(fernet, simple, rsa, aes)', + `salt` varchar(128) DEFAULT NULL COMMENT 'Variable salt', + `scope` varchar(32) DEFAULT 'global' COMMENT 'Variable scope(global,flow,app,agent,datasource,flow:uid, flow:dag_name,agent:agent_name) etc', + `scope_key` varchar(256) DEFAULT NULL COMMENT 'Variable scope key, default is empty, for scope is "flow:uid", the scope_key is uid of flow', + `enabled` int DEFAULT 1 COMMENT 'Variable enabled, 0: disabled, 1: enabled', + `user_name` varchar(128) DEFAULT NULL COMMENT 'User name', + `sys_code` 
varchar(128) DEFAULT NULL COMMENT 'System code', + `gmt_created` datetime DEFAULT CURRENT_TIMESTAMP COMMENT 'Record creation time', + `gmt_modified` datetime DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT 'Record update time', + PRIMARY KEY (`id`), + KEY `ix_dbgpt_serve_variables_key` (`key`), + KEY `ix_dbgpt_serve_variables_name` (`name`) +) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci; + + -- dbgpt.gpts_app definition CREATE TABLE `gpts_app` ( `id` int NOT NULL AUTO_INCREMENT COMMENT 'autoincrement id', diff --git a/assets/schema/upgrade/v0_6_0/upgrade_to_v0.6.0.sql b/assets/schema/upgrade/v0_6_0/upgrade_to_v0.6.0.sql new file mode 100644 index 000000000..fa345fabe --- /dev/null +++ b/assets/schema/upgrade/v0_6_0/upgrade_to_v0.6.0.sql @@ -0,0 +1,43 @@ +-- dbgpt.dbgpt_serve_file definition +CREATE TABLE `dbgpt_serve_file` ( + `id` int NOT NULL AUTO_INCREMENT COMMENT 'Auto increment id', + `bucket` varchar(255) NOT NULL COMMENT 'Bucket name', + `file_id` varchar(255) NOT NULL COMMENT 'File id', + `file_name` varchar(256) NOT NULL COMMENT 'File name', + `file_size` int DEFAULT NULL COMMENT 'File size', + `storage_type` varchar(32) NOT NULL COMMENT 'Storage type', + `storage_path` varchar(512) NOT NULL COMMENT 'Storage path', + `uri` varchar(512) NOT NULL COMMENT 'File URI', + `custom_metadata` text DEFAULT NULL COMMENT 'Custom metadata, JSON format', + `file_hash` varchar(128) DEFAULT NULL COMMENT 'File hash', + `user_name` varchar(128) DEFAULT NULL COMMENT 'User name', + `sys_code` varchar(128) DEFAULT NULL COMMENT 'System code', + `gmt_created` datetime DEFAULT CURRENT_TIMESTAMP COMMENT 'Record creation time', + `gmt_modified` datetime DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT 'Record update time', + PRIMARY KEY (`id`), + UNIQUE KEY `uk_bucket_file_id` (`bucket`, `file_id`) +) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci; + +-- dbgpt.dbgpt_serve_variables definition +CREATE TABLE `dbgpt_serve_variables` ( + `id` int NOT NULL AUTO_INCREMENT COMMENT 'Auto increment id', + `key` varchar(128) NOT NULL COMMENT 'Variable key', + `name` varchar(128) DEFAULT NULL COMMENT 'Variable name', + `label` varchar(128) DEFAULT NULL COMMENT 'Variable label', + `value` text DEFAULT NULL COMMENT 'Variable value, JSON format', + `value_type` varchar(32) DEFAULT NULL COMMENT 'Variable value type(string, int, float, bool)', + `category` varchar(32) DEFAULT 'common' COMMENT 'Variable category(common or secret)', + `encryption_method` varchar(32) DEFAULT NULL COMMENT 'Variable encryption method(fernet, simple, rsa, aes)', + `salt` varchar(128) DEFAULT NULL COMMENT 'Variable salt', + `scope` varchar(32) DEFAULT 'global' COMMENT 'Variable scope(global,flow,app,agent,datasource,flow:uid, flow:dag_name,agent:agent_name) etc', + `scope_key` varchar(256) DEFAULT NULL COMMENT 'Variable scope key, default is empty, for scope is "flow:uid", the scope_key is uid of flow', + `enabled` int DEFAULT 1 COMMENT 'Variable enabled, 0: disabled, 1: enabled', + `user_name` varchar(128) DEFAULT NULL COMMENT 'User name', + `sys_code` varchar(128) DEFAULT NULL COMMENT 'System code', + `gmt_created` datetime DEFAULT CURRENT_TIMESTAMP COMMENT 'Record creation time', + `gmt_modified` datetime DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT 'Record update time', + PRIMARY KEY (`id`), + KEY `ix_dbgpt_serve_variables_key` (`key`), + KEY `ix_dbgpt_serve_variables_name` (`name`) +) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8mb4 
COLLATE=utf8mb4_unicode_ci; + diff --git a/assets/schema/upgrade/v0_6_0/v0.5.10.sql b/assets/schema/upgrade/v0_6_0/v0.5.10.sql new file mode 100644 index 000000000..a70d8e643 --- /dev/null +++ b/assets/schema/upgrade/v0_6_0/v0.5.10.sql @@ -0,0 +1,419 @@ +-- Full SQL of v0.5.10, please not modify this file(It must be same as the file in the release package) + +CREATE +DATABASE IF NOT EXISTS dbgpt; +use dbgpt; + +-- For alembic migration tool +CREATE TABLE IF NOT EXISTS `alembic_version` +( + version_num VARCHAR(32) NOT NULL, + CONSTRAINT alembic_version_pkc PRIMARY KEY (version_num) +) DEFAULT CHARSET=utf8mb4 ; + +CREATE TABLE IF NOT EXISTS `knowledge_space` +( + `id` int NOT NULL AUTO_INCREMENT COMMENT 'auto increment id', + `name` varchar(100) NOT NULL COMMENT 'knowledge space name', + `vector_type` varchar(50) NOT NULL COMMENT 'vector type', + `domain_type` varchar(50) NOT NULL COMMENT 'domain type', + `desc` varchar(500) NOT NULL COMMENT 'description', + `owner` varchar(100) DEFAULT NULL COMMENT 'owner', + `context` TEXT DEFAULT NULL COMMENT 'context argument', + `gmt_created` TIMESTAMP DEFAULT CURRENT_TIMESTAMP COMMENT 'created time', + `gmt_modified` TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT 'update time', + PRIMARY KEY (`id`), + KEY `idx_name` (`name`) COMMENT 'index:idx_name' +) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8mb4 COMMENT='knowledge space table'; + +CREATE TABLE IF NOT EXISTS `knowledge_document` +( + `id` int NOT NULL AUTO_INCREMENT COMMENT 'auto increment id', + `doc_name` varchar(100) NOT NULL COMMENT 'document path name', + `doc_type` varchar(50) NOT NULL COMMENT 'doc type', + `space` varchar(50) NOT NULL COMMENT 'knowledge space', + `chunk_size` int NOT NULL COMMENT 'chunk size', + `last_sync` TIMESTAMP DEFAULT CURRENT_TIMESTAMP COMMENT 'last sync time', + `status` varchar(50) NOT NULL COMMENT 'status TODO,RUNNING,FAILED,FINISHED', + `content` LONGTEXT NOT NULL COMMENT 'knowledge embedding sync result', + `result` TEXT NULL COMMENT 'knowledge content', + `vector_ids` LONGTEXT NULL COMMENT 'vector_ids', + `summary` LONGTEXT NULL COMMENT 'knowledge summary', + `gmt_created` TIMESTAMP DEFAULT CURRENT_TIMESTAMP COMMENT 'created time', + `gmt_modified` TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT 'update time', + PRIMARY KEY (`id`), + KEY `idx_doc_name` (`doc_name`) COMMENT 'index:idx_doc_name' +) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8mb4 COMMENT='knowledge document table'; + +CREATE TABLE IF NOT EXISTS `document_chunk` +( + `id` int NOT NULL AUTO_INCREMENT COMMENT 'auto increment id', + `doc_name` varchar(100) NOT NULL COMMENT 'document path name', + `doc_type` varchar(50) NOT NULL COMMENT 'doc type', + `document_id` int NOT NULL COMMENT 'document parent id', + `content` longtext NOT NULL COMMENT 'chunk content', + `meta_info` varchar(200) NOT NULL COMMENT 'metadata info', + `gmt_created` timestamp NULL DEFAULT CURRENT_TIMESTAMP COMMENT 'created time', + `gmt_modified` timestamp NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT 'update time', + PRIMARY KEY (`id`), + KEY `idx_document_id` (`document_id`) COMMENT 'index:document_id' +) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8mb4 COMMENT='knowledge document chunk detail'; + + + +CREATE TABLE IF NOT EXISTS `connect_config` +( + `id` int NOT NULL AUTO_INCREMENT COMMENT 'autoincrement id', + `db_type` varchar(255) NOT NULL COMMENT 'db type', + `db_name` varchar(255) NOT NULL COMMENT 'db name', + `db_path` varchar(255) 
DEFAULT NULL COMMENT 'file db path', + `db_host` varchar(255) DEFAULT NULL COMMENT 'db connect host(not file db)', + `db_port` varchar(255) DEFAULT NULL COMMENT 'db cnnect port(not file db)', + `db_user` varchar(255) DEFAULT NULL COMMENT 'db user', + `db_pwd` varchar(255) DEFAULT NULL COMMENT 'db password', + `comment` text COMMENT 'db comment', + `sys_code` varchar(128) DEFAULT NULL COMMENT 'System code', + PRIMARY KEY (`id`), + UNIQUE KEY `uk_db` (`db_name`), + KEY `idx_q_db_type` (`db_type`) +) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8mb4 COMMENT 'Connection confi'; + +CREATE TABLE IF NOT EXISTS `chat_history` +( + `id` int NOT NULL AUTO_INCREMENT COMMENT 'autoincrement id', + `conv_uid` varchar(255) COLLATE utf8mb4_unicode_ci NOT NULL COMMENT 'Conversation record unique id', + `chat_mode` varchar(255) COLLATE utf8mb4_unicode_ci NOT NULL COMMENT 'Conversation scene mode', + `summary` longtext COLLATE utf8mb4_unicode_ci NOT NULL COMMENT 'Conversation record summary', + `user_name` varchar(255) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'interlocutor', + `messages` text COLLATE utf8mb4_unicode_ci COMMENT 'Conversation details', + `message_ids` text COLLATE utf8mb4_unicode_ci COMMENT 'Message id list, split by comma', + `sys_code` varchar(128) DEFAULT NULL COMMENT 'System code', + `gmt_created` timestamp NULL DEFAULT CURRENT_TIMESTAMP COMMENT 'created time', + `gmt_modified` timestamp NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT 'update time', + UNIQUE KEY `conv_uid` (`conv_uid`), + PRIMARY KEY (`id`) +) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci COMMENT 'Chat history'; + +CREATE TABLE IF NOT EXISTS `chat_history_message` +( + `id` int NOT NULL AUTO_INCREMENT COMMENT 'autoincrement id', + `conv_uid` varchar(255) COLLATE utf8mb4_unicode_ci NOT NULL COMMENT 'Conversation record unique id', + `index` int NOT NULL COMMENT 'Message index', + `round_index` int NOT NULL COMMENT 'Round of conversation', + `message_detail` text COLLATE utf8mb4_unicode_ci COMMENT 'Message details, json format', + `gmt_created` timestamp NULL DEFAULT CURRENT_TIMESTAMP COMMENT 'created time', + `gmt_modified` timestamp NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT 'update time', + UNIQUE KEY `message_uid_index` (`conv_uid`, `index`), + PRIMARY KEY (`id`) +) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci COMMENT 'Chat history message'; + +CREATE TABLE IF NOT EXISTS `chat_feed_back` +( + `id` bigint(20) NOT NULL AUTO_INCREMENT, + `conv_uid` varchar(128) DEFAULT NULL COMMENT 'Conversation ID', + `conv_index` int(4) DEFAULT NULL COMMENT 'Round of conversation', + `score` int(1) DEFAULT NULL COMMENT 'Score of user', + `ques_type` varchar(32) DEFAULT NULL COMMENT 'User question category', + `question` longtext DEFAULT NULL COMMENT 'User question', + `knowledge_space` varchar(128) DEFAULT NULL COMMENT 'Knowledge space name', + `messages` longtext DEFAULT NULL COMMENT 'The details of user feedback', + `user_name` varchar(128) DEFAULT NULL COMMENT 'User name', + `gmt_created` timestamp NULL DEFAULT CURRENT_TIMESTAMP COMMENT 'created time', + `gmt_modified` timestamp NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT 'update time', + PRIMARY KEY (`id`), + UNIQUE KEY `uk_conv` (`conv_uid`,`conv_index`), + KEY `idx_conv` (`conv_uid`,`conv_index`) +) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8mb4 COMMENT='User feedback table'; + + +CREATE TABLE IF NOT EXISTS `my_plugin` +( + 
`id` int NOT NULL AUTO_INCREMENT COMMENT 'autoincrement id', + `tenant` varchar(255) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'user tenant', + `user_code` varchar(255) COLLATE utf8mb4_unicode_ci NOT NULL COMMENT 'user code', + `user_name` varchar(255) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'user name', + `name` varchar(255) COLLATE utf8mb4_unicode_ci NOT NULL COMMENT 'plugin name', + `file_name` varchar(255) COLLATE utf8mb4_unicode_ci NOT NULL COMMENT 'plugin package file name', + `type` varchar(255) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'plugin type', + `version` varchar(255) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'plugin version', + `use_count` int DEFAULT NULL COMMENT 'plugin total use count', + `succ_count` int DEFAULT NULL COMMENT 'plugin total success count', + `sys_code` varchar(128) DEFAULT NULL COMMENT 'System code', + `gmt_created` TIMESTAMP DEFAULT CURRENT_TIMESTAMP COMMENT 'plugin install time', + PRIMARY KEY (`id`), + UNIQUE KEY `name` (`name`) +) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci COMMENT='User plugin table'; + +CREATE TABLE IF NOT EXISTS `plugin_hub` +( + `id` int NOT NULL AUTO_INCREMENT COMMENT 'autoincrement id', + `name` varchar(255) COLLATE utf8mb4_unicode_ci NOT NULL COMMENT 'plugin name', + `description` varchar(255) COLLATE utf8mb4_unicode_ci NOT NULL COMMENT 'plugin description', + `author` varchar(255) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'plugin author', + `email` varchar(255) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'plugin author email', + `type` varchar(255) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'plugin type', + `version` varchar(255) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'plugin version', + `storage_channel` varchar(255) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'plugin storage channel', + `storage_url` varchar(255) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'plugin download url', + `download_param` varchar(255) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'plugin download param', + `gmt_created` TIMESTAMP DEFAULT CURRENT_TIMESTAMP COMMENT 'plugin upload time', + `installed` int DEFAULT NULL COMMENT 'plugin already installed count', + PRIMARY KEY (`id`), + UNIQUE KEY `name` (`name`) +) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci COMMENT='Plugin Hub table'; + + +CREATE TABLE IF NOT EXISTS `prompt_manage` +( + `id` int(11) NOT NULL AUTO_INCREMENT, + `chat_scene` varchar(100) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'Chat scene', + `sub_chat_scene` varchar(100) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'Sub chat scene', + `prompt_type` varchar(100) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'Prompt type: common or private', + `prompt_name` varchar(256) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'prompt name', + `content` longtext COLLATE utf8mb4_unicode_ci COMMENT 'Prompt content', + `input_variables` varchar(1024) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'Prompt input variables(split by comma))', + `model` varchar(128) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'Prompt model name(we can use different models for different prompt)', + `prompt_language` varchar(32) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'Prompt language(eg:en, zh-cn)', + `prompt_format` varchar(32) COLLATE utf8mb4_unicode_ci DEFAULT 'f-string' COMMENT 'Prompt format(eg: f-string, jinja2)', + `prompt_desc` varchar(512) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'Prompt description', 
+ `user_name` varchar(128) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'User name', + `sys_code` varchar(128) DEFAULT NULL COMMENT 'System code', + `gmt_created` timestamp NULL DEFAULT CURRENT_TIMESTAMP COMMENT 'created time', + `gmt_modified` timestamp NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT 'update time', + PRIMARY KEY (`id`), + UNIQUE KEY `prompt_name_uiq` (`prompt_name`, `sys_code`, `prompt_language`, `model`), + KEY `gmt_created_idx` (`gmt_created`) +) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci COMMENT='Prompt management table'; + + CREATE TABLE IF NOT EXISTS `gpts_conversations` ( + `id` int(11) NOT NULL AUTO_INCREMENT COMMENT 'autoincrement id', + `conv_id` varchar(255) NOT NULL COMMENT 'The unique id of the conversation record', + `user_goal` text NOT NULL COMMENT 'User''s goals content', + `gpts_name` varchar(255) NOT NULL COMMENT 'The gpts name', + `state` varchar(255) DEFAULT NULL COMMENT 'The gpts state', + `max_auto_reply_round` int(11) NOT NULL COMMENT 'max auto reply round', + `auto_reply_count` int(11) NOT NULL COMMENT 'auto reply count', + `user_code` varchar(255) DEFAULT NULL COMMENT 'user code', + `sys_code` varchar(255) DEFAULT NULL COMMENT 'system app ', + `created_at` datetime DEFAULT NULL COMMENT 'create time', + `updated_at` datetime DEFAULT NULL COMMENT 'last update time', + `team_mode` varchar(255) NULL COMMENT 'agent team work mode', + + PRIMARY KEY (`id`), + UNIQUE KEY `uk_gpts_conversations` (`conv_id`), + KEY `idx_gpts_name` (`gpts_name`) +) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8mb4 COMMENT="gpt conversations"; + +CREATE TABLE IF NOT EXISTS `gpts_instance` ( + `id` int(11) NOT NULL AUTO_INCREMENT COMMENT 'autoincrement id', + `gpts_name` varchar(255) NOT NULL COMMENT 'Current AI assistant name', + `gpts_describe` varchar(2255) NOT NULL COMMENT 'Current AI assistant describe', + `resource_db` text COMMENT 'List of structured database names contained in the current gpts', + `resource_internet` text COMMENT 'Is it possible to retrieve information from the internet', + `resource_knowledge` text COMMENT 'List of unstructured database names contained in the current gpts', + `gpts_agents` varchar(1000) DEFAULT NULL COMMENT 'List of agents names contained in the current gpts', + `gpts_models` varchar(1000) DEFAULT NULL COMMENT 'List of llm model names contained in the current gpts', + `language` varchar(100) DEFAULT NULL COMMENT 'gpts language', + `user_code` varchar(255) NOT NULL COMMENT 'user code', + `sys_code` varchar(255) DEFAULT NULL COMMENT 'system app code', + `created_at` datetime DEFAULT NULL COMMENT 'create time', + `updated_at` datetime DEFAULT NULL COMMENT 'last update time', + `team_mode` varchar(255) NOT NULL COMMENT 'Team work mode', + `is_sustainable` tinyint(1) NOT NULL COMMENT 'Applications for sustainable dialogue', + PRIMARY KEY (`id`), + UNIQUE KEY `uk_gpts` (`gpts_name`) +) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8mb4 COMMENT="gpts instance"; + +CREATE TABLE `gpts_messages` ( + `id` int(11) NOT NULL AUTO_INCREMENT COMMENT 'autoincrement id', + `conv_id` varchar(255) NOT NULL COMMENT 'The unique id of the conversation record', + `sender` varchar(255) NOT NULL COMMENT 'Who speaking in the current conversation turn', + `receiver` varchar(255) NOT NULL COMMENT 'Who receive message in the current conversation turn', + `model_name` varchar(255) DEFAULT NULL COMMENT 'message generate model', + `rounds` int(11) NOT NULL COMMENT 'dialogue turns', + `content` text 
COMMENT 'Content of the speech', + `current_goal` text COMMENT 'The target corresponding to the current message', + `context` text COMMENT 'Current conversation context', + `review_info` text COMMENT 'Current conversation review info', + `action_report` text COMMENT 'Current conversation action report', + `role` varchar(255) DEFAULT NULL COMMENT 'The role of the current message content', + `created_at` datetime DEFAULT NULL COMMENT 'create time', + `updated_at` datetime DEFAULT NULL COMMENT 'last update time', + PRIMARY KEY (`id`), + KEY `idx_q_messages` (`conv_id`,`rounds`,`sender`) +) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8mb4 COMMENT="gpts message"; + + +CREATE TABLE `gpts_plans` ( + `id` int(11) NOT NULL AUTO_INCREMENT COMMENT 'autoincrement id', + `conv_id` varchar(255) NOT NULL COMMENT 'The unique id of the conversation record', + `sub_task_num` int(11) NOT NULL COMMENT 'Subtask number', + `sub_task_title` varchar(255) NOT NULL COMMENT 'subtask title', + `sub_task_content` text NOT NULL COMMENT 'subtask content', + `sub_task_agent` varchar(255) DEFAULT NULL COMMENT 'Available agents corresponding to subtasks', + `resource_name` varchar(255) DEFAULT NULL COMMENT 'resource name', + `rely` varchar(255) DEFAULT NULL COMMENT 'Subtask dependencies,like: 1,2,3', + `agent_model` varchar(255) DEFAULT NULL COMMENT 'LLM model used by subtask processing agents', + `retry_times` int(11) DEFAULT NULL COMMENT 'number of retries', + `max_retry_times` int(11) DEFAULT NULL COMMENT 'Maximum number of retries', + `state` varchar(255) DEFAULT NULL COMMENT 'subtask status', + `result` longtext COMMENT 'subtask result', + `created_at` datetime DEFAULT NULL COMMENT 'create time', + `updated_at` datetime DEFAULT NULL COMMENT 'last update time', + PRIMARY KEY (`id`), + UNIQUE KEY `uk_sub_task` (`conv_id`,`sub_task_num`) +) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8mb4 COMMENT="gpt plan"; + +-- dbgpt.dbgpt_serve_flow definition +CREATE TABLE `dbgpt_serve_flow` ( + `id` int NOT NULL AUTO_INCREMENT COMMENT 'Auto increment id', + `uid` varchar(128) NOT NULL COMMENT 'Unique id', + `dag_id` varchar(128) DEFAULT NULL COMMENT 'DAG id', + `name` varchar(128) DEFAULT NULL COMMENT 'Flow name', + `flow_data` text COMMENT 'Flow data, JSON format', + `user_name` varchar(128) DEFAULT NULL COMMENT 'User name', + `sys_code` varchar(128) DEFAULT NULL COMMENT 'System code', + `gmt_created` datetime DEFAULT NULL COMMENT 'Record creation time', + `gmt_modified` datetime DEFAULT NULL COMMENT 'Record update time', + `flow_category` varchar(64) DEFAULT NULL COMMENT 'Flow category', + `description` varchar(512) DEFAULT NULL COMMENT 'Flow description', + `state` varchar(32) DEFAULT NULL COMMENT 'Flow state', + `error_message` varchar(512) NULL comment 'Error message', + `source` varchar(64) DEFAULT NULL COMMENT 'Flow source', + `source_url` varchar(512) DEFAULT NULL COMMENT 'Flow source url', + `version` varchar(32) DEFAULT NULL COMMENT 'Flow version', + `define_type` varchar(32) null comment 'Flow define type(json or python)', + `label` varchar(128) DEFAULT NULL COMMENT 'Flow label', + `editable` int DEFAULT NULL COMMENT 'Editable, 0: editable, 1: not editable', + PRIMARY KEY (`id`), + UNIQUE KEY `uk_uid` (`uid`), + KEY `ix_dbgpt_serve_flow_sys_code` (`sys_code`), + KEY `ix_dbgpt_serve_flow_uid` (`uid`), + KEY `ix_dbgpt_serve_flow_dag_id` (`dag_id`), + KEY `ix_dbgpt_serve_flow_user_name` (`user_name`), + KEY `ix_dbgpt_serve_flow_name` (`name`) +) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8mb4 
COLLATE=utf8mb4_unicode_ci; + +-- dbgpt.gpts_app definition +CREATE TABLE `gpts_app` ( + `id` int NOT NULL AUTO_INCREMENT COMMENT 'autoincrement id', + `app_code` varchar(255) NOT NULL COMMENT 'Current AI assistant code', + `app_name` varchar(255) NOT NULL COMMENT 'Current AI assistant name', + `app_describe` varchar(2255) NOT NULL COMMENT 'Current AI assistant describe', + `language` varchar(100) NOT NULL COMMENT 'gpts language', + `team_mode` varchar(255) NOT NULL COMMENT 'Team work mode', + `team_context` text COMMENT 'The execution logic and team member content that teams with different working modes rely on', + `user_code` varchar(255) DEFAULT NULL COMMENT 'user code', + `sys_code` varchar(255) DEFAULT NULL COMMENT 'system app code', + `created_at` datetime DEFAULT NULL COMMENT 'create time', + `updated_at` datetime DEFAULT NULL COMMENT 'last update time', + `icon` varchar(1024) DEFAULT NULL COMMENT 'app icon, url', + PRIMARY KEY (`id`), + UNIQUE KEY `uk_gpts_app` (`app_name`) +) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci; + +CREATE TABLE `gpts_app_collection` ( + `id` int(11) NOT NULL AUTO_INCREMENT COMMENT 'autoincrement id', + `app_code` varchar(255) NOT NULL COMMENT 'Current AI assistant code', + `user_code` int(11) NOT NULL COMMENT 'user code', + `sys_code` varchar(255) NOT NULL COMMENT 'system app code', + `created_at` datetime DEFAULT NULL COMMENT 'create time', + `updated_at` datetime DEFAULT NULL COMMENT 'last update time', + PRIMARY KEY (`id`), + KEY `idx_app_code` (`app_code`), + KEY `idx_user_code` (`user_code`) +) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8mb4 COMMENT="gpt collections"; + +-- dbgpt.gpts_app_detail definition +CREATE TABLE `gpts_app_detail` ( + `id` int NOT NULL AUTO_INCREMENT COMMENT 'autoincrement id', + `app_code` varchar(255) NOT NULL COMMENT 'Current AI assistant code', + `app_name` varchar(255) NOT NULL COMMENT 'Current AI assistant name', + `agent_name` varchar(255) NOT NULL COMMENT ' Agent name', + `node_id` varchar(255) NOT NULL COMMENT 'Current AI assistant Agent Node id', + `resources` text COMMENT 'Agent bind resource', + `prompt_template` text COMMENT 'Agent bind template', + `llm_strategy` varchar(25) DEFAULT NULL COMMENT 'Agent use llm strategy', + `llm_strategy_value` text COMMENT 'Agent use llm strategy value', + `created_at` datetime DEFAULT NULL COMMENT 'create time', + `updated_at` datetime DEFAULT NULL COMMENT 'last update time', + PRIMARY KEY (`id`), + UNIQUE KEY `uk_gpts_app_agent_node` (`app_name`,`agent_name`,`node_id`) +) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci; + + +-- For deploy model cluster of DB-GPT(StorageModelRegistry) +CREATE TABLE IF NOT EXISTS `dbgpt_cluster_registry_instance` ( + `id` int(11) NOT NULL AUTO_INCREMENT COMMENT 'Auto increment id', + `model_name` varchar(128) NOT NULL COMMENT 'Model name', + `host` varchar(128) NOT NULL COMMENT 'Host of the model', + `port` int(11) NOT NULL COMMENT 'Port of the model', + `weight` float DEFAULT 1.0 COMMENT 'Weight of the model', + `check_healthy` tinyint(1) DEFAULT 1 COMMENT 'Whether to check the health of the model', + `healthy` tinyint(1) DEFAULT 0 COMMENT 'Whether the model is healthy', + `enabled` tinyint(1) DEFAULT 1 COMMENT 'Whether the model is enabled', + `prompt_template` varchar(128) DEFAULT NULL COMMENT 'Prompt template for the model instance', + `last_heartbeat` datetime DEFAULT NULL COMMENT 'Last heartbeat time of the model instance', + `user_name` varchar(128) DEFAULT NULL 
COMMENT 'User name', + `sys_code` varchar(128) DEFAULT NULL COMMENT 'System code', + `gmt_created` datetime DEFAULT CURRENT_TIMESTAMP COMMENT 'Record creation time', + `gmt_modified` datetime DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT 'Record update time', + PRIMARY KEY (`id`), + UNIQUE KEY `uk_model_instance` (`model_name`, `host`, `port`, `sys_code`) +) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8mb4 COMMENT='Cluster model instance table, for registering and managing model instances'; + + +CREATE +DATABASE IF NOT EXISTS EXAMPLE_1; +use EXAMPLE_1; +CREATE TABLE IF NOT EXISTS `users` +( + `id` int NOT NULL AUTO_INCREMENT, + `username` varchar(50) NOT NULL COMMENT '用户名', + `password` varchar(50) NOT NULL COMMENT '密码', + `email` varchar(50) NOT NULL COMMENT '邮箱', + `phone` varchar(20) DEFAULT NULL COMMENT '电话', + PRIMARY KEY (`id`), + KEY `idx_username` (`username`) COMMENT '索引:按用户名查询' +) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8mb4 COMMENT='聊天用户表'; + +INSERT INTO users (username, password, email, phone) +VALUES ('user_1', 'password_1', 'user_1@example.com', '12345678901'); +INSERT INTO users (username, password, email, phone) +VALUES ('user_2', 'password_2', 'user_2@example.com', '12345678902'); +INSERT INTO users (username, password, email, phone) +VALUES ('user_3', 'password_3', 'user_3@example.com', '12345678903'); +INSERT INTO users (username, password, email, phone) +VALUES ('user_4', 'password_4', 'user_4@example.com', '12345678904'); +INSERT INTO users (username, password, email, phone) +VALUES ('user_5', 'password_5', 'user_5@example.com', '12345678905'); +INSERT INTO users (username, password, email, phone) +VALUES ('user_6', 'password_6', 'user_6@example.com', '12345678906'); +INSERT INTO users (username, password, email, phone) +VALUES ('user_7', 'password_7', 'user_7@example.com', '12345678907'); +INSERT INTO users (username, password, email, phone) +VALUES ('user_8', 'password_8', 'user_8@example.com', '12345678908'); +INSERT INTO users (username, password, email, phone) +VALUES ('user_9', 'password_9', 'user_9@example.com', '12345678909'); +INSERT INTO users (username, password, email, phone) +VALUES ('user_10', 'password_10', 'user_10@example.com', '12345678900'); +INSERT INTO users (username, password, email, phone) +VALUES ('user_11', 'password_11', 'user_11@example.com', '12345678901'); +INSERT INTO users (username, password, email, phone) +VALUES ('user_12', 'password_12', 'user_12@example.com', '12345678902'); +INSERT INTO users (username, password, email, phone) +VALUES ('user_13', 'password_13', 'user_13@example.com', '12345678903'); +INSERT INTO users (username, password, email, phone) +VALUES ('user_14', 'password_14', 'user_14@example.com', '12345678904'); +INSERT INTO users (username, password, email, phone) +VALUES ('user_15', 'password_15', 'user_15@example.com', '12345678905'); +INSERT INTO users (username, password, email, phone) +VALUES ('user_16', 'password_16', 'user_16@example.com', '12345678906'); +INSERT INTO users (username, password, email, phone) +VALUES ('user_17', 'password_17', 'user_17@example.com', '12345678907'); +INSERT INTO users (username, password, email, phone) +VALUES ('user_18', 'password_18', 'user_18@example.com', '12345678908'); +INSERT INTO users (username, password, email, phone) +VALUES ('user_19', 'password_19', 'user_19@example.com', '12345678909'); +INSERT INTO users (username, password, email, phone) +VALUES ('user_20', 'password_20', 'user_20@example.com', '12345678900'); \ No newline at end 
of file diff --git a/dbgpt/_private/config.py b/dbgpt/_private/config.py index 2dbfac0f0..403570c5a 100644 --- a/dbgpt/_private/config.py +++ b/dbgpt/_private/config.py @@ -316,6 +316,17 @@ def __init__(self) -> None: # experimental financial report model configuration self.FIN_REPORT_MODEL = os.getenv("FIN_REPORT_MODEL", None) + # file server configuration + # The host of the current file server, if None, get the host automatically + self.FILE_SERVER_HOST = os.getenv("FILE_SERVER_HOST") + self.FILE_SERVER_LOCAL_STORAGE_PATH = os.getenv( + "FILE_SERVER_LOCAL_STORAGE_PATH" + ) + # multi-instance flag + self.WEBSERVER_MULTI_INSTANCE = ( + os.getenv("MULTI_INSTANCE", "False").lower() == "true" + ) + @property def local_db_manager(self) -> "ConnectorManager": from dbgpt.datasource.manages import ConnectorManager diff --git a/dbgpt/app/component_configs.py b/dbgpt/app/component_configs.py index 3ef08d4bc..a8a0f24d1 100644 --- a/dbgpt/app/component_configs.py +++ b/dbgpt/app/component_configs.py @@ -52,17 +52,17 @@ def initialize_components( param, system_app, embedding_model_name, embedding_model_path ) _initialize_rerank_model(param, system_app, rerank_model_name, rerank_model_path) - _initialize_model_cache(system_app) + _initialize_model_cache(system_app, param.port) _initialize_awel(system_app, param) # Initialize resource manager of agent _initialize_resource_manager(system_app) _initialize_agent(system_app) _initialize_openapi(system_app) # Register serve apps - register_serve_apps(system_app, CFG) + register_serve_apps(system_app, CFG, param.port) -def _initialize_model_cache(system_app: SystemApp): +def _initialize_model_cache(system_app: SystemApp, port: int): from dbgpt.storage.cache import initialize_cache if not CFG.MODEL_CACHE_ENABLE: @@ -72,6 +72,8 @@ def _initialize_model_cache(system_app: SystemApp): storage_type = CFG.MODEL_CACHE_STORAGE_TYPE or "disk" max_memory_mb = CFG.MODEL_CACHE_MAX_MEMORY_MB or 256 persist_dir = CFG.MODEL_CACHE_STORAGE_DISK_DIR or MODEL_DISK_CACHE_DIR + if CFG.WEBSERVER_MULTI_INSTANCE: + persist_dir = f"{persist_dir}_{port}" initialize_cache(system_app, storage_type, max_memory_mb, persist_dir) diff --git a/dbgpt/app/initialization/db_model_initialization.py b/dbgpt/app/initialization/db_model_initialization.py index b8808c400..969340c44 100644 --- a/dbgpt/app/initialization/db_model_initialization.py +++ b/dbgpt/app/initialization/db_model_initialization.py @@ -8,6 +8,7 @@ from dbgpt.model.cluster.registry_impl.db_storage import ModelInstanceEntity from dbgpt.serve.agent.db.my_plugin_db import MyPluginEntity from dbgpt.serve.agent.db.plugin_hub_db import PluginHubEntity +from dbgpt.serve.file.models.models import ServeEntity as FileServeEntity from dbgpt.serve.flow.models.models import ServeEntity as FlowServeEntity from dbgpt.serve.flow.models.models import VariablesEntity as FlowVariableEntity from dbgpt.serve.prompt.models.models import ServeEntity as PromptManageEntity @@ -19,6 +20,7 @@ _MODELS = [ PluginHubEntity, + FileServeEntity, MyPluginEntity, PromptManageEntity, KnowledgeSpaceEntity, diff --git a/dbgpt/app/initialization/serve_initialization.py b/dbgpt/app/initialization/serve_initialization.py index f0b9c9e42..5b29ce455 100644 --- a/dbgpt/app/initialization/serve_initialization.py +++ b/dbgpt/app/initialization/serve_initialization.py @@ -2,7 +2,7 @@ from dbgpt.component import SystemApp -def register_serve_apps(system_app: SystemApp, cfg: Config): +def register_serve_apps(system_app: SystemApp, cfg: Config, webserver_port: int): """Register 
serve apps""" system_app.config.set("dbgpt.app.global.language", cfg.LANGUAGE) if cfg.API_KEYS: @@ -47,6 +47,8 @@ def register_serve_apps(system_app: SystemApp, cfg: Config): # Register serve app system_app.register(FlowServe) + # ################################ AWEL Flow Serve Register End ######################################## + # ################################ Rag Serve Register Begin ###################################### from dbgpt.serve.rag.serve import ( @@ -57,6 +59,8 @@ def register_serve_apps(system_app: SystemApp, cfg: Config): # Register serve app system_app.register(RagServe) + # ################################ Rag Serve Register End ######################################## + # ################################ Datasource Serve Register Begin ###################################### from dbgpt.serve.datasource.serve import ( @@ -66,4 +70,34 @@ def register_serve_apps(system_app: SystemApp, cfg: Config): # Register serve app system_app.register(DatasourceServe) - # ################################ AWEL Flow Serve Register End ######################################## + + # ################################ Datasource Serve Register End ######################################## + + # ################################ File Serve Register Begin ###################################### + + from dbgpt.configs.model_config import FILE_SERVER_LOCAL_STORAGE_PATH + from dbgpt.serve.file.serve import ( + SERVE_CONFIG_KEY_PREFIX as FILE_SERVE_CONFIG_KEY_PREFIX, + ) + from dbgpt.serve.file.serve import Serve as FileServe + + local_storage_path = ( + cfg.FILE_SERVER_LOCAL_STORAGE_PATH or FILE_SERVER_LOCAL_STORAGE_PATH + ) + if cfg.WEBSERVER_MULTI_INSTANCE: + local_storage_path = f"{local_storage_path}_{webserver_port}" + # Set config + system_app.config.set( + f"{FILE_SERVE_CONFIG_KEY_PREFIX}local_storage_path", local_storage_path + ) + system_app.config.set( + f"{FILE_SERVE_CONFIG_KEY_PREFIX}file_server_port", webserver_port + ) + if cfg.FILE_SERVER_HOST: + system_app.config.set( + f"{FILE_SERVE_CONFIG_KEY_PREFIX}file_server_host", cfg.FILE_SERVER_HOST + ) + # Register serve app + system_app.register(FileServe) + + # ################################ File Serve Register End ######################################## diff --git a/dbgpt/component.py b/dbgpt/component.py index cb88a61ec..da3c5e753 100644 --- a/dbgpt/component.py +++ b/dbgpt/component.py @@ -90,6 +90,7 @@ class ComponentType(str, Enum): AGENT_MANAGER = "dbgpt_agent_manager" RESOURCE_MANAGER = "dbgpt_resource_manager" VARIABLES_PROVIDER = "dbgpt_variables_provider" + FILE_STORAGE_CLIENT = "dbgpt_file_storage_client" _EMPTY_DEFAULT_COMPONENT = "_EMPTY_DEFAULT_COMPONENT" diff --git a/dbgpt/configs/model_config.py b/dbgpt/configs/model_config.py index 4d02a2730..e4abac3e7 100644 --- a/dbgpt/configs/model_config.py +++ b/dbgpt/configs/model_config.py @@ -14,6 +14,7 @@ DATA_DIR = os.path.join(PILOT_PATH, "data") PLUGINS_DIR = os.path.join(ROOT_PATH, "plugins") MODEL_DISK_CACHE_DIR = os.path.join(DATA_DIR, "model_cache") +FILE_SERVER_LOCAL_STORAGE_PATH = os.path.join(DATA_DIR, "file_server") _DAG_DEFINITION_DIR = os.path.join(ROOT_PATH, "examples/awel") # Global language setting LOCALES_DIR = os.path.join(ROOT_PATH, "i18n/locales") diff --git a/dbgpt/core/awel/dag/dag_manager.py b/dbgpt/core/awel/dag/dag_manager.py index 91a49a166..15a07254a 100644 --- a/dbgpt/core/awel/dag/dag_manager.py +++ b/dbgpt/core/awel/dag/dag_manager.py @@ -197,7 +197,7 @@ def get_dag_metadata( return self._dag_metadata_map.get(dag.dag_id) -def 
_parse_metadata(dag: DAG): +def _parse_metadata(dag: DAG) -> DAGMetadata: from ..util.chat_util import _is_sse_output metadata = DAGMetadata() diff --git a/dbgpt/core/awel/flow/flow_factory.py b/dbgpt/core/awel/flow/flow_factory.py index 3f847c07c..e0d505aa5 100644 --- a/dbgpt/core/awel/flow/flow_factory.py +++ b/dbgpt/core/awel/flow/flow_factory.py @@ -4,7 +4,7 @@ import uuid from contextlib import suppress from enum import Enum -from typing import Any, Dict, List, Optional, Tuple, Type, Union, cast +from typing import Any, Dict, List, Literal, Optional, Tuple, Type, Union, cast from typing_extensions import Annotated @@ -166,6 +166,59 @@ class FlowData(BaseModel): viewport: FlowPositionData = Field(..., description="Viewport of the flow") +class VariablesRequest(BaseModel): + """Variable request model. + + For creating a new variable in the DB-GPT. + """ + + key: str = Field( + ..., + description="The key of the variable to create", + examples=["dbgpt.model.openai.api_key"], + ) + name: str = Field( + ..., + description="The name of the variable to create", + examples=["my_first_openai_key"], + ) + label: str = Field( + ..., + description="The label of the variable to create", + examples=["My First OpenAI Key"], + ) + value: Any = Field( + ..., description="The value of the variable to create", examples=["1234567890"] + ) + value_type: Literal["str", "int", "float", "bool"] = Field( + "str", + description="The type of the value of the variable to create", + examples=["str", "int", "float", "bool"], + ) + category: Literal["common", "secret"] = Field( + ..., + description="The category of the variable to create", + examples=["common"], + ) + scope: str = Field( + ..., + description="The scope of the variable to create", + examples=["global"], + ) + scope_key: Optional[str] = Field( + ..., + description="The scope key of the variable to create", + examples=["dbgpt"], + ) + enabled: Optional[bool] = Field( + True, + description="Whether the variable is enabled", + examples=[True], + ) + user_name: Optional[str] = Field(None, description="User name") + sys_code: Optional[str] = Field(None, description="System code") + + class State(str, Enum): """State of a flow panel.""" @@ -356,6 +409,12 @@ class FlowPanel(BaseModel): metadata: Optional[Union[DAGMetadata, Dict[str, Any]]] = Field( default=None, description="The metadata of the flow" ) + variables: Optional[List[VariablesRequest]] = Field( + default=None, description="The variables of the flow" + ) + authors: Optional[List[str]] = Field( + default=None, description="The authors of the flow" + ) @model_validator(mode="before") @classmethod diff --git a/dbgpt/core/awel/flow/ui.py b/dbgpt/core/awel/flow/ui.py index c763859b0..928755a20 100644 --- a/dbgpt/core/awel/flow/ui.py +++ b/dbgpt/core/awel/flow/ui.py @@ -367,7 +367,10 @@ class UIAttribute(UIComponent.UIAttribute): ) ui_type: Literal["upload"] = Field("upload", frozen=True) - + attr: Optional[UIAttribute] = Field( + None, + description="The attributes of the component", + ) max_file_size: Optional[int] = Field( None, description="The maximum size of the file, in bytes", @@ -387,8 +390,8 @@ class UIAttribute(UIComponent.UIAttribute): description="Whether to support drag and drop upload", ) action: Optional[str] = Field( - None, - description="The URL for the file upload", + "/api/v2/serve/file/files/dbgpt", + description="The URL for the file upload(default bucket is 'dbgpt')", ) diff --git a/dbgpt/core/awel/operators/common_operator.py b/dbgpt/core/awel/operators/common_operator.py index 
fc2dc098b..f8bc25370 100644 --- a/dbgpt/core/awel/operators/common_operator.py +++ b/dbgpt/core/awel/operators/common_operator.py @@ -334,7 +334,8 @@ def __init__(self, input_source: InputSource[OUT], **kwargs) -> None: async def _do_run(self, dag_ctx: DAGContext) -> TaskOutput[OUT]: curr_task_ctx: TaskContext[OUT] = dag_ctx.current_task_context task_output = await self._input_source.read(curr_task_ctx) - curr_task_ctx.set_task_output(task_output) + new_task_output: TaskOutput[OUT] = await task_output.map(self.map) + curr_task_ctx.set_task_output(new_task_output) return task_output @classmethod @@ -342,6 +343,10 @@ def dummy_input(cls, dummy_data: Any = SKIP_DATA, **kwargs) -> "InputOperator[OU """Create a dummy InputOperator with a given input value.""" return cls(input_source=InputSource.from_data(dummy_data), **kwargs) + async def map(self, input_data: OUT) -> OUT: + """Map the input data to a new value.""" + return input_data + class TriggerOperator(InputOperator[OUT], Generic[OUT]): """Operator node that triggers the DAG to run.""" diff --git a/dbgpt/core/awel/trigger/http_trigger.py b/dbgpt/core/awel/trigger/http_trigger.py index 22e025c13..8f0298297 100644 --- a/dbgpt/core/awel/trigger/http_trigger.py +++ b/dbgpt/core/awel/trigger/http_trigger.py @@ -87,7 +87,9 @@ class HttpTriggerMetadata(TriggerMetadata): path: str = Field(..., description="The path of the trigger") methods: List[str] = Field(..., description="The methods of the trigger") - + trigger_mode: str = Field( + default="command", description="The mode of the trigger, command or chat" + ) trigger_type: Optional[str] = Field( default="http", description="The type of the trigger" ) @@ -477,7 +479,9 @@ def mount_to_router( )(dynamic_route_function) logger.info(f"Mount http trigger success, path: {path}") - return HttpTriggerMetadata(path=path, methods=self._methods) + return HttpTriggerMetadata( + path=path, methods=self._methods, trigger_mode=self._trigger_mode() + ) def mount_to_app( self, app: "FastAPI", global_prefix: Optional[str] = None @@ -512,7 +516,9 @@ def mount_to_app( app.openapi_schema = None app.middleware_stack = None logger.info(f"Mount http trigger success, path: {path}") - return HttpTriggerMetadata(path=path, methods=self._methods) + return HttpTriggerMetadata( + path=path, methods=self._methods, trigger_mode=self._trigger_mode() + ) def remove_from_app( self, app: "FastAPI", global_prefix: Optional[str] = None @@ -537,6 +543,36 @@ def remove_from_app( # TODO, remove with path and methods del app_router.routes[i] + def _trigger_mode(self) -> str: + if ( + self._req_body + and isinstance(self._req_body, type) + and issubclass(self._req_body, CommonLLMHttpRequestBody) + ): + return "chat" + return "command" + + async def map(self, input_data: Any) -> Any: + """Map the input data. + + Do some transformation for the input data. + + Args: + input_data (Any): The input data from caller. + + Returns: + Any: The mapped data. 
+ """ + if not self._req_body or not input_data: + return await super().map(input_data) + if ( + isinstance(self._req_body, type) + and issubclass(self._req_body, BaseModel) + and isinstance(input_data, dict) + ): + return self._req_body(**input_data) + return await super().map(input_data) + def _create_route_func(self): from inspect import Parameter, Signature from typing import get_type_hints diff --git a/dbgpt/core/interface/file.py b/dbgpt/core/interface/file.py new file mode 100644 index 000000000..83a524510 --- /dev/null +++ b/dbgpt/core/interface/file.py @@ -0,0 +1,798 @@ +"""File storage interface.""" + +import dataclasses +import hashlib +import io +import os +import uuid +from abc import ABC, abstractmethod +from io import BytesIO +from typing import Any, BinaryIO, Dict, List, Optional, Tuple +from urllib.parse import parse_qs, urlencode, urlparse + +import requests + +from dbgpt.component import BaseComponent, ComponentType, SystemApp +from dbgpt.util.tracer import root_tracer, trace + +from .storage import ( + InMemoryStorage, + QuerySpec, + ResourceIdentifier, + StorageError, + StorageInterface, + StorageItem, +) + +_SCHEMA = "dbgpt-fs" + + +@dataclasses.dataclass +class FileMetadataIdentifier(ResourceIdentifier): + """File metadata identifier.""" + + file_id: str + bucket: str + + def to_dict(self) -> Dict: + """Convert the identifier to a dictionary.""" + return {"file_id": self.file_id, "bucket": self.bucket} + + @property + def str_identifier(self) -> str: + """Get the string identifier. + + Returns: + str: The string identifier + """ + return f"{self.bucket}/{self.file_id}" + + +@dataclasses.dataclass +class FileMetadata(StorageItem): + """File metadata for storage.""" + + file_id: str + bucket: str + file_name: str + file_size: int + storage_type: str + storage_path: str + uri: str + custom_metadata: Dict[str, Any] + file_hash: str + user_name: Optional[str] = None + sys_code: Optional[str] = None + _identifier: FileMetadataIdentifier = dataclasses.field(init=False) + + def __post_init__(self): + """Post init method.""" + self._identifier = FileMetadataIdentifier( + file_id=self.file_id, bucket=self.bucket + ) + custom_metadata = self.custom_metadata or {} + if not self.user_name: + self.user_name = custom_metadata.get("user_name") + if not self.sys_code: + self.sys_code = custom_metadata.get("sys_code") + + @property + def identifier(self) -> ResourceIdentifier: + """Get the resource identifier.""" + return self._identifier + + def merge(self, other: "StorageItem") -> None: + """Merge the metadata with another item.""" + if not isinstance(other, FileMetadata): + raise StorageError("Cannot merge different types of items") + self._from_object(other) + + def to_dict(self) -> Dict: + """Convert the metadata to a dictionary.""" + return { + "file_id": self.file_id, + "bucket": self.bucket, + "file_name": self.file_name, + "file_size": self.file_size, + "storage_type": self.storage_type, + "storage_path": self.storage_path, + "uri": self.uri, + "custom_metadata": self.custom_metadata, + "file_hash": self.file_hash, + } + + def _from_object(self, obj: "FileMetadata") -> None: + self.file_id = obj.file_id + self.bucket = obj.bucket + self.file_name = obj.file_name + self.file_size = obj.file_size + self.storage_type = obj.storage_type + self.storage_path = obj.storage_path + self.uri = obj.uri + self.custom_metadata = obj.custom_metadata + self.file_hash = obj.file_hash + self._identifier = obj._identifier + + +class FileStorageURI: + """File storage URI.""" + + def __init__( + 
self, + storage_type: str, + bucket: str, + file_id: str, + version: Optional[str] = None, + custom_params: Optional[Dict[str, Any]] = None, + ): + """Initialize the file storage URI.""" + self.scheme = _SCHEMA + self.storage_type = storage_type + self.bucket = bucket + self.file_id = file_id + self.version = version + self.custom_params = custom_params or {} + + @classmethod + def parse(cls, uri: str) -> "FileStorageURI": + """Parse the URI string.""" + parsed = urlparse(uri) + if parsed.scheme != _SCHEMA: + raise ValueError(f"Invalid URI scheme. Must be '{_SCHEMA}'") + path_parts = parsed.path.strip("/").split("/") + if len(path_parts) < 2: + raise ValueError("Invalid URI path. Must contain bucket and file ID") + storage_type = parsed.netloc + bucket = path_parts[0] + file_id = path_parts[1] + version = path_parts[2] if len(path_parts) > 2 else None + custom_params = parse_qs(parsed.query) + return cls(storage_type, bucket, file_id, version, custom_params) + + def __str__(self) -> str: + """Get the string representation of the URI.""" + base_uri = f"{self.scheme}://{self.storage_type}/{self.bucket}/{self.file_id}" + if self.version: + base_uri += f"/{self.version}" + if self.custom_params: + query_string = urlencode(self.custom_params, doseq=True) + base_uri += f"?{query_string}" + return base_uri + + +class StorageBackend(ABC): + """Storage backend interface.""" + + storage_type: str = "__base__" + + @abstractmethod + def save(self, bucket: str, file_id: str, file_data: BinaryIO) -> str: + """Save the file data to the storage backend. + + Args: + bucket (str): The bucket name + file_id (str): The file ID + file_data (BinaryIO): The file data + + Returns: + str: The storage path + """ + + @abstractmethod + def load(self, fm: FileMetadata) -> BinaryIO: + """Load the file data from the storage backend. + + Args: + fm (FileMetadata): The file metadata + + Returns: + BinaryIO: The file data + """ + + @abstractmethod + def delete(self, fm: FileMetadata) -> bool: + """Delete the file data from the storage backend. + + Args: + fm (FileMetadata): The file metadata + + Returns: + bool: True if the file was deleted, False otherwise + """ + + @property + @abstractmethod + def save_chunk_size(self) -> int: + """Get the save chunk size. 
+ + Returns: + int: The save chunk size + """ + + +class LocalFileStorage(StorageBackend): + """Local file storage backend.""" + + storage_type: str = "local" + + def __init__(self, base_path: str, save_chunk_size: int = 1024 * 1024): + """Initialize the local file storage backend.""" + self.base_path = base_path + self._save_chunk_size = save_chunk_size + os.makedirs(self.base_path, exist_ok=True) + + @property + def save_chunk_size(self) -> int: + """Get the save chunk size.""" + return self._save_chunk_size + + def save(self, bucket: str, file_id: str, file_data: BinaryIO) -> str: + """Save the file data to the local storage backend.""" + bucket_path = os.path.join(self.base_path, bucket) + os.makedirs(bucket_path, exist_ok=True) + file_path = os.path.join(bucket_path, file_id) + with open(file_path, "wb") as f: + while True: + chunk = file_data.read(self.save_chunk_size) + if not chunk: + break + f.write(chunk) + return file_path + + def load(self, fm: FileMetadata) -> BinaryIO: + """Load the file data from the local storage backend.""" + bucket_path = os.path.join(self.base_path, fm.bucket) + file_path = os.path.join(bucket_path, fm.file_id) + return open(file_path, "rb") # noqa: SIM115 + + def delete(self, fm: FileMetadata) -> bool: + """Delete the file data from the local storage backend.""" + bucket_path = os.path.join(self.base_path, fm.bucket) + file_path = os.path.join(bucket_path, fm.file_id) + if os.path.exists(file_path): + os.remove(file_path) + return True + return False + + +class FileStorageSystem: + """File storage system.""" + + def __init__( + self, + storage_backends: Dict[str, StorageBackend], + metadata_storage: Optional[StorageInterface[FileMetadata, Any]] = None, + check_hash: bool = True, + ): + """Initialize the file storage system.""" + metadata_storage = metadata_storage or InMemoryStorage() + self.storage_backends = storage_backends + self.metadata_storage = metadata_storage + self.check_hash = check_hash + self._save_chunk_size = min( + backend.save_chunk_size for backend in storage_backends.values() + ) + + def _calculate_file_hash(self, file_data: BinaryIO) -> str: + """Calculate the MD5 hash of the file data.""" + if not self.check_hash: + return "-1" + hasher = hashlib.md5() + file_data.seek(0) + while chunk := file_data.read(self._save_chunk_size): + hasher.update(chunk) + file_data.seek(0) + return hasher.hexdigest() + + @trace("file_storage_system.save_file") + def save_file( + self, + bucket: str, + file_name: str, + file_data: BinaryIO, + storage_type: str, + custom_metadata: Optional[Dict[str, Any]] = None, + ) -> str: + """Save the file data to the storage backend.""" + file_id = str(uuid.uuid4()) + backend = self.storage_backends.get(storage_type) + if not backend: + raise ValueError(f"Unsupported storage type: {storage_type}") + + with root_tracer.start_span( + "file_storage_system.save_file.backend_save", + metadata={ + "bucket": bucket, + "file_id": file_id, + "file_name": file_name, + "storage_type": storage_type, + }, + ): + storage_path = backend.save(bucket, file_id, file_data) + file_data.seek(0, 2) # Move to the end of the file + file_size = file_data.tell() # Get the file size + file_data.seek(0) # Reset file pointer + + with root_tracer.start_span( + "file_storage_system.save_file.calculate_hash", + ): + file_hash = self._calculate_file_hash(file_data) + uri = FileStorageURI( + storage_type, bucket, file_id, custom_params=custom_metadata + ) + + metadata = FileMetadata( + file_id=file_id, + bucket=bucket, + file_name=file_name, + 
file_size=file_size, + storage_type=storage_type, + storage_path=storage_path, + uri=str(uri), + custom_metadata=custom_metadata or {}, + file_hash=file_hash, + ) + + self.metadata_storage.save(metadata) + return str(uri) + + @trace("file_storage_system.get_file") + def get_file(self, uri: str) -> Tuple[BinaryIO, FileMetadata]: + """Get the file data from the storage backend.""" + parsed_uri = FileStorageURI.parse(uri) + metadata = self.metadata_storage.load( + FileMetadataIdentifier( + file_id=parsed_uri.file_id, bucket=parsed_uri.bucket + ), + FileMetadata, + ) + if not metadata: + raise FileNotFoundError(f"No metadata found for URI: {uri}") + + backend = self.storage_backends.get(metadata.storage_type) + if not backend: + raise ValueError(f"Unsupported storage type: {metadata.storage_type}") + + with root_tracer.start_span( + "file_storage_system.get_file.backend_load", + metadata={ + "bucket": metadata.bucket, + "file_id": metadata.file_id, + "file_name": metadata.file_name, + "storage_type": metadata.storage_type, + }, + ): + file_data = backend.load(metadata) + + with root_tracer.start_span( + "file_storage_system.get_file.verify_hash", + ): + calculated_hash = self._calculate_file_hash(file_data) + if calculated_hash != "-1" and calculated_hash != metadata.file_hash: + raise ValueError("File integrity check failed. Hash mismatch.") + + return file_data, metadata + + def get_file_metadata(self, bucket: str, file_id: str) -> Optional[FileMetadata]: + """Get the file metadata. + + Args: + bucket (str): The bucket name + file_id (str): The file ID + + Returns: + Optional[FileMetadata]: The file metadata + """ + fid = FileMetadataIdentifier(file_id=file_id, bucket=bucket) + return self.metadata_storage.load(fid, FileMetadata) + + def delete_file(self, uri: str) -> bool: + """Delete the file data from the storage backend. 
+ + Args: + uri (str): The file URI + + Returns: + bool: True if the file was deleted, False otherwise + """ + parsed_uri = FileStorageURI.parse(uri) + fid = FileMetadataIdentifier( + file_id=parsed_uri.file_id, bucket=parsed_uri.bucket + ) + metadata = self.metadata_storage.load(fid, FileMetadata) + if not metadata: + return False + + backend = self.storage_backends.get(metadata.storage_type) + if not backend: + raise ValueError(f"Unsupported storage type: {metadata.storage_type}") + + if backend.delete(metadata): + try: + self.metadata_storage.delete(fid) + return True + except Exception: + # If the metadata deletion fails, log the error and return False + return False + return False + + def list_files( + self, bucket: str, filters: Optional[Dict[str, Any]] = None + ) -> List[FileMetadata]: + """List the files in the bucket.""" + filters = filters or {} + filters["bucket"] = bucket + return self.metadata_storage.query(QuerySpec(conditions=filters), FileMetadata) + + +class FileStorageClient(BaseComponent): + """File storage client component.""" + + name = ComponentType.FILE_STORAGE_CLIENT.value + + def __init__( + self, + system_app: Optional[SystemApp] = None, + storage_system: Optional[FileStorageSystem] = None, + ): + """Initialize the file storage client.""" + super().__init__(system_app=system_app) + if not storage_system: + from pathlib import Path + + base_path = Path.home() / ".cache" / "dbgpt" / "files" + storage_system = FileStorageSystem( + { + LocalFileStorage.storage_type: LocalFileStorage( + base_path=str(base_path) + ) + } + ) + + self.system_app = system_app + self._storage_system = storage_system + + def init_app(self, system_app: SystemApp): + """Initialize the application.""" + self.system_app = system_app + + @property + def storage_system(self) -> FileStorageSystem: + """Get the file storage system.""" + if not self._storage_system: + raise ValueError("File storage system not initialized") + return self._storage_system + + def upload_file( + self, + bucket: str, + file_path: str, + storage_type: str, + custom_metadata: Optional[Dict[str, Any]] = None, + ) -> str: + """Upload a file to the storage system. + + Args: + bucket (str): The bucket name + file_path (str): The file path + storage_type (str): The storage type + custom_metadata (Dict[str, Any], optional): Custom metadata. Defaults to + None. + + Returns: + str: The file URI + """ + with open(file_path, "rb") as file: + return self.save_file( + bucket, os.path.basename(file_path), file, storage_type, custom_metadata + ) + + def save_file( + self, + bucket: str, + file_name: str, + file_data: BinaryIO, + storage_type: str, + custom_metadata: Optional[Dict[str, Any]] = None, + ) -> str: + """Save the file data to the storage system. + + Args: + bucket (str): The bucket name + file_name (str): The file name + file_data (BinaryIO): The file data + storage_type (str): The storage type + custom_metadata (Dict[str, Any], optional): Custom metadata. Defaults to + None. + + Returns: + str: The file URI + """ + return self.storage_system.save_file( + bucket, file_name, file_data, storage_type, custom_metadata + ) + + def download_file(self, uri: str, destination_path: str) -> None: + """Download a file from the storage system. 
+ + Args: + uri (str): The file URI + destination_path (str): The destination + + Raises: + FileNotFoundError: If the file is not found + """ + file_data, _ = self.storage_system.get_file(uri) + with open(destination_path, "wb") as f: + f.write(file_data.read()) + + def get_file(self, uri: str) -> Tuple[BinaryIO, FileMetadata]: + """Get the file data from the storage system. + + Args: + uri (str): The file URI + + Returns: + Tuple[BinaryIO, FileMetadata]: The file data and metadata + """ + return self.storage_system.get_file(uri) + + def get_file_by_id( + self, bucket: str, file_id: str + ) -> Tuple[BinaryIO, FileMetadata]: + """Get the file data from the storage system by ID. + + Args: + bucket (str): The bucket name + file_id (str): The file ID + + Returns: + Tuple[BinaryIO, FileMetadata]: The file data and metadata + """ + metadata = self.storage_system.get_file_metadata(bucket, file_id) + if not metadata: + raise FileNotFoundError(f"File {file_id} not found in bucket {bucket}") + return self.get_file(metadata.uri) + + def delete_file(self, uri: str) -> bool: + """Delete the file data from the storage system. + + Args: + uri (str): The file URI + + Returns: + bool: True if the file was deleted, False otherwise + """ + return self.storage_system.delete_file(uri) + + def delete_file_by_id(self, bucket: str, file_id: str) -> bool: + """Delete the file data from the storage system by ID. + + Args: + bucket (str): The bucket name + file_id (str): The file ID + + Returns: + bool: True if the file was deleted, False otherwise + """ + metadata = self.storage_system.get_file_metadata(bucket, file_id) + if not metadata: + raise FileNotFoundError(f"File {file_id} not found in bucket {bucket}") + return self.delete_file(metadata.uri) + + def list_files( + self, bucket: str, filters: Optional[Dict[str, Any]] = None + ) -> List[FileMetadata]: + """List the files in the bucket. + + Args: + bucket (str): The bucket name + filters (Dict[str, Any], optional): Filters. Defaults to None. 
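                For example (an assumption, not verified here: the metadata store
                applies the filters as equality conditions on ``FileMetadata``
                fields):

                .. code-block:: python

                    client.list_files("my-bucket", filters={"file_name": "report.pdf"})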
+ + Returns: + List[FileMetadata]: The list of file metadata + """ + return self.storage_system.list_files(bucket, filters) + + +class SimpleDistributedStorage(StorageBackend): + """Simple distributed storage backend.""" + + storage_type: str = "distributed" + + def __init__( + self, + node_address: str, + local_storage_path: str, + save_chunk_size: int = 1024 * 1024, + transfer_chunk_size: int = 1024 * 1024, + transfer_timeout: int = 360, + api_prefix: str = "/api/v2/serve/file/files", + ): + """Initialize the simple distributed storage backend.""" + self.node_address = node_address + self.local_storage_path = local_storage_path + os.makedirs(self.local_storage_path, exist_ok=True) + self._save_chunk_size = save_chunk_size + self._transfer_chunk_size = transfer_chunk_size + self._transfer_timeout = transfer_timeout + self._api_prefix = api_prefix + + @property + def save_chunk_size(self) -> int: + """Get the save chunk size.""" + return self._save_chunk_size + + def _get_file_path(self, bucket: str, file_id: str, node_address: str) -> str: + node_id = hashlib.md5(node_address.encode()).hexdigest() + return os.path.join(self.local_storage_path, bucket, f"{file_id}_{node_id}") + + def _parse_node_address(self, fm: FileMetadata) -> str: + storage_path = fm.storage_path + if not storage_path.startswith("distributed://"): + raise ValueError("Invalid storage path") + return storage_path.split("//")[1].split("/")[0] + + def save(self, bucket: str, file_id: str, file_data: BinaryIO) -> str: + """Save the file data to the distributed storage backend. + + Just save the file locally. + + Args: + bucket (str): The bucket name + file_id (str): The file ID + file_data (BinaryIO): The file data + + Returns: + str: The storage path + """ + file_path = self._get_file_path(bucket, file_id, self.node_address) + os.makedirs(os.path.dirname(file_path), exist_ok=True) + with open(file_path, "wb") as f: + while True: + chunk = file_data.read(self.save_chunk_size) + if not chunk: + break + f.write(chunk) + + return f"distributed://{self.node_address}/{bucket}/{file_id}" + + def load(self, fm: FileMetadata) -> BinaryIO: + """Load the file data from the distributed storage backend. + + If the file is stored on the local node, load it from the local storage. + + Args: + fm (FileMetadata): The file metadata + + Returns: + BinaryIO: The file data + """ + file_id = fm.file_id + bucket = fm.bucket + node_address = self._parse_node_address(fm) + file_path = self._get_file_path(bucket, file_id, node_address) + + # TODO: check if the file is cached in local storage + if node_address == self.node_address: + if os.path.exists(file_path): + return open(file_path, "rb") # noqa: SIM115 + else: + raise FileNotFoundError(f"File {file_id} not found on the local node") + else: + response = requests.get( + f"http://{node_address}{self._api_prefix}/{bucket}/{file_id}", + timeout=self._transfer_timeout, + stream=True, + ) + response.raise_for_status() + # TODO: cache the file in local storage + return StreamedBytesIO( + response.iter_content(chunk_size=self._transfer_chunk_size) + ) + + def delete(self, fm: FileMetadata) -> bool: + """Delete the file data from the distributed storage backend. + + If the file is stored on the local node, delete it from the local storage. + If the file is stored on a remote node, send a delete request to the remote + node. 
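        The owning node is taken from ``fm.storage_path``, which follows the
        format produced by ``save()``:
        ``distributed://<node_address>/<bucket>/<file_id>``.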
+ + Args: + fm (FileMetadata): The file metadata + + Returns: + bool: True if the file was deleted, False otherwise + """ + file_id = fm.file_id + bucket = fm.bucket + node_address = self._parse_node_address(fm) + file_path = self._get_file_path(bucket, file_id, node_address) + if node_address == self.node_address: + if os.path.exists(file_path): + os.remove(file_path) + return True + return False + else: + try: + response = requests.delete( + f"http://{node_address}{self._api_prefix}/{bucket}/{file_id}", + timeout=self._transfer_timeout, + ) + response.raise_for_status() + return True + except Exception: + return False + + +class StreamedBytesIO(io.BytesIO): + """A BytesIO subclass that can be used with streaming responses. + + Adapted from: https://gist.github.com/obskyr/b9d4b4223e7eaf4eedcd9defabb34f13 + """ + + def __init__(self, request_iterator): + """Initialize the StreamedBytesIO instance.""" + super().__init__() + self._bytes = BytesIO() + self._iterator = request_iterator + + def _load_all(self): + self._bytes.seek(0, io.SEEK_END) + for chunk in self._iterator: + self._bytes.write(chunk) + + def _load_until(self, goal_position): + current_position = self._bytes.seek(0, io.SEEK_END) + while current_position < goal_position: + try: + current_position += self._bytes.write(next(self._iterator)) + except StopIteration: + break + + def tell(self) -> int: + """Get the current position.""" + return self._bytes.tell() + + def read(self, size: Optional[int] = None) -> bytes: + """Read the data from the stream. + + Args: + size (Optional[int], optional): The number of bytes to read. Defaults to + None. + + Returns: + bytes: The read data + """ + left_off_at = self._bytes.tell() + if size is None: + self._load_all() + else: + goal_position = left_off_at + size + self._load_until(goal_position) + + self._bytes.seek(left_off_at) + return self._bytes.read(size) + + def seek(self, position: int, whence: int = io.SEEK_SET): + """Seek to a position in the stream. + + Args: + position (int): The position + whence (int, optional): The reference point. 
Defaults to io.SEEK + + Raises: + ValueError: If the reference point is invalid + """ + if whence == io.SEEK_END: + self._load_all() + else: + self._bytes.seek(position, whence) + + def __enter__(self): + """Enter the context manager.""" + return self + + def __exit__(self, ext_type, value, tb): + """Exit the context manager.""" + self._bytes.close() diff --git a/dbgpt/core/interface/tests/test_file.py b/dbgpt/core/interface/tests/test_file.py new file mode 100644 index 000000000..f6e462944 --- /dev/null +++ b/dbgpt/core/interface/tests/test_file.py @@ -0,0 +1,506 @@ +import hashlib +import io +import os +from unittest import mock + +import pytest + +from ..file import ( + FileMetadata, + FileMetadataIdentifier, + FileStorageClient, + FileStorageSystem, + InMemoryStorage, + LocalFileStorage, + SimpleDistributedStorage, +) + + +@pytest.fixture +def temp_test_file_dir(tmpdir): + return str(tmpdir) + + +@pytest.fixture +def temp_storage_path(tmpdir): + return str(tmpdir) + + +@pytest.fixture +def local_storage_backend(temp_storage_path): + return LocalFileStorage(temp_storage_path) + + +@pytest.fixture +def distributed_storage_backend(temp_storage_path): + node_address = "127.0.0.1:8000" + return SimpleDistributedStorage(node_address, temp_storage_path) + + +@pytest.fixture +def file_storage_system(local_storage_backend): + backends = {"local": local_storage_backend} + metadata_storage = InMemoryStorage() + return FileStorageSystem(backends, metadata_storage) + + +@pytest.fixture +def file_storage_client(file_storage_system): + return FileStorageClient(storage_system=file_storage_system) + + +@pytest.fixture +def sample_file_path(temp_test_file_dir): + file_path = os.path.join(temp_test_file_dir, "sample.txt") + with open(file_path, "wb") as f: + f.write(b"Sample file content") + return file_path + + +@pytest.fixture +def sample_file_data(): + return io.BytesIO(b"Sample file content for distributed storage") + + +def test_save_file(file_storage_client, sample_file_path): + bucket = "test-bucket" + uri = file_storage_client.upload_file( + bucket=bucket, file_path=sample_file_path, storage_type="local" + ) + assert uri.startswith("dbgpt-fs://local/test-bucket/") + assert os.path.exists(sample_file_path) + + +def test_get_file(file_storage_client, sample_file_path): + bucket = "test-bucket" + uri = file_storage_client.upload_file( + bucket=bucket, file_path=sample_file_path, storage_type="local" + ) + file_data, metadata = file_storage_client.storage_system.get_file(uri) + assert file_data.read() == b"Sample file content" + assert metadata.file_name == "sample.txt" + assert metadata.bucket == bucket + + +def test_delete_file(file_storage_client, sample_file_path): + bucket = "test-bucket" + uri = file_storage_client.upload_file( + bucket=bucket, file_path=sample_file_path, storage_type="local" + ) + assert len(file_storage_client.list_files(bucket=bucket)) == 1 + result = file_storage_client.delete_file(uri) + assert result is True + assert len(file_storage_client.list_files(bucket=bucket)) == 0 + + +def test_list_files(file_storage_client, sample_file_path): + bucket = "test-bucket" + uri1 = file_storage_client.upload_file( + bucket=bucket, file_path=sample_file_path, storage_type="local" + ) + files = file_storage_client.list_files(bucket=bucket) + assert len(files) == 1 + + +def test_save_file_unsupported_storage(file_storage_system, sample_file_path): + bucket = "test-bucket" + with pytest.raises(ValueError): + file_storage_system.save_file( + bucket=bucket, + file_name="unsupported.txt", + 
file_data=io.BytesIO(b"Unsupported storage"), + storage_type="unsupported", + ) + + +def test_get_file_not_found(file_storage_system): + with pytest.raises(FileNotFoundError): + file_storage_system.get_file("dbgpt-fs://local/test-bucket/nonexistent") + + +def test_delete_file_not_found(file_storage_system): + result = file_storage_system.delete_file("dbgpt-fs://local/test-bucket/nonexistent") + assert result is False + + +def test_metadata_management(file_storage_system): + bucket = "test-bucket" + file_id = "test_file" + metadata = file_storage_system.metadata_storage.save( + FileMetadata( + file_id=file_id, + bucket=bucket, + file_name="test.txt", + file_size=100, + storage_type="local", + storage_path="/path/to/test.txt", + uri="dbgpt-fs://local/test-bucket/test_file", + custom_metadata={"key": "value"}, + file_hash="hash", + ) + ) + + loaded_metadata = file_storage_system.metadata_storage.load( + FileMetadataIdentifier(file_id=file_id, bucket=bucket), FileMetadata + ) + assert loaded_metadata.file_name == "test.txt" + assert loaded_metadata.custom_metadata["key"] == "value" + assert loaded_metadata.bucket == bucket + + +def test_concurrent_save_and_delete(file_storage_client, sample_file_path): + bucket = "test-bucket" + + # Simulate concurrent file save and delete operations + def save_file(): + return file_storage_client.upload_file( + bucket=bucket, file_path=sample_file_path, storage_type="local" + ) + + def delete_file(uri): + return file_storage_client.delete_file(uri) + + uri = save_file() + + # Simulate concurrent operations + save_file() + delete_file(uri) + assert len(file_storage_client.list_files(bucket=bucket)) == 1 + + +def test_large_file_handling(file_storage_client, temp_storage_path): + bucket = "test-bucket" + large_file_path = os.path.join(temp_storage_path, "large_sample.bin") + with open(large_file_path, "wb") as f: + f.write(os.urandom(10 * 1024 * 1024)) # 10 MB file + + uri = file_storage_client.upload_file( + bucket=bucket, + file_path=large_file_path, + storage_type="local", + custom_metadata={"description": "Large file test"}, + ) + file_data, metadata = file_storage_client.storage_system.get_file(uri) + assert file_data.read() == open(large_file_path, "rb").read() + assert metadata.file_name == "large_sample.bin" + assert metadata.bucket == bucket + + +def test_file_hash_verification_success(file_storage_client, sample_file_path): + bucket = "test-bucket" + # Upload file and + uri = file_storage_client.upload_file( + bucket=bucket, file_path=sample_file_path, storage_type="local" + ) + + file_data, metadata = file_storage_client.storage_system.get_file(uri) + file_hash = metadata.file_hash + calculated_hash = file_storage_client.storage_system._calculate_file_hash(file_data) + + assert ( + file_hash == calculated_hash + ), "File hash should match after saving and loading" + + +def test_file_hash_verification_failure(file_storage_client, sample_file_path): + bucket = "test-bucket" + # Upload file and + uri = file_storage_client.upload_file( + bucket=bucket, file_path=sample_file_path, storage_type="local" + ) + + # Modify the file content manually to simulate file tampering + storage_system = file_storage_client.storage_system + metadata = storage_system.metadata_storage.load( + FileMetadataIdentifier(file_id=uri.split("/")[-1], bucket=bucket), FileMetadata + ) + with open(metadata.storage_path, "wb") as f: + f.write(b"Tampered content") + + # Get file should raise an exception due to hash mismatch + with pytest.raises(ValueError, match="File integrity check 
failed. Hash mismatch."): + storage_system.get_file(uri) + + +def test_file_isolation_across_buckets(file_storage_client, sample_file_path): + bucket1 = "bucket1" + bucket2 = "bucket2" + + # Upload the same file to two different buckets + uri1 = file_storage_client.upload_file( + bucket=bucket1, file_path=sample_file_path, storage_type="local" + ) + uri2 = file_storage_client.upload_file( + bucket=bucket2, file_path=sample_file_path, storage_type="local" + ) + + # Verify both URIs are different and point to different files + assert uri1 != uri2 + + file_data1, metadata1 = file_storage_client.storage_system.get_file(uri1) + file_data2, metadata2 = file_storage_client.storage_system.get_file(uri2) + + assert file_data1.read() == b"Sample file content" + assert file_data2.read() == b"Sample file content" + assert metadata1.bucket == bucket1 + assert metadata2.bucket == bucket2 + + +def test_list_files_in_specific_bucket(file_storage_client, sample_file_path): + bucket1 = "bucket1" + bucket2 = "bucket2" + + # Upload a file to both buckets + file_storage_client.upload_file( + bucket=bucket1, file_path=sample_file_path, storage_type="local" + ) + file_storage_client.upload_file( + bucket=bucket2, file_path=sample_file_path, storage_type="local" + ) + + # List files in bucket1 and bucket2 + files_in_bucket1 = file_storage_client.list_files(bucket=bucket1) + files_in_bucket2 = file_storage_client.list_files(bucket=bucket2) + + assert len(files_in_bucket1) == 1 + assert len(files_in_bucket2) == 1 + assert files_in_bucket1[0].bucket == bucket1 + assert files_in_bucket2[0].bucket == bucket2 + + +def test_delete_file_in_one_bucket_does_not_affect_other_bucket( + file_storage_client, sample_file_path +): + bucket1 = "bucket1" + bucket2 = "bucket2" + + # Upload the same file to two different buckets + uri1 = file_storage_client.upload_file( + bucket=bucket1, file_path=sample_file_path, storage_type="local" + ) + uri2 = file_storage_client.upload_file( + bucket=bucket2, file_path=sample_file_path, storage_type="local" + ) + + # Delete the file in bucket1 + file_storage_client.delete_file(uri1) + + # Check that the file in bucket1 is deleted + assert len(file_storage_client.list_files(bucket=bucket1)) == 0 + + # Check that the file in bucket2 is still there + assert len(file_storage_client.list_files(bucket=bucket2)) == 1 + file_data2, metadata2 = file_storage_client.storage_system.get_file(uri2) + assert file_data2.read() == b"Sample file content" + + +def test_file_hash_verification_in_different_buckets( + file_storage_client, sample_file_path +): + bucket1 = "bucket1" + bucket2 = "bucket2" + + # Upload the file to both buckets + uri1 = file_storage_client.upload_file( + bucket=bucket1, file_path=sample_file_path, storage_type="local" + ) + uri2 = file_storage_client.upload_file( + bucket=bucket2, file_path=sample_file_path, storage_type="local" + ) + + file_data1, metadata1 = file_storage_client.storage_system.get_file(uri1) + file_data2, metadata2 = file_storage_client.storage_system.get_file(uri2) + + # Verify that file hashes are the same for the same content + file_hash1 = file_storage_client.storage_system._calculate_file_hash(file_data1) + file_hash2 = file_storage_client.storage_system._calculate_file_hash(file_data2) + + assert file_hash1 == metadata1.file_hash + assert file_hash2 == metadata2.file_hash + assert file_hash1 == file_hash2 + + +def test_file_download_from_different_buckets( + file_storage_client, sample_file_path, temp_storage_path +): + bucket1 = "bucket1" + bucket2 = "bucket2" + 
+ # Upload the file to both buckets + uri1 = file_storage_client.upload_file( + bucket=bucket1, file_path=sample_file_path, storage_type="local" + ) + uri2 = file_storage_client.upload_file( + bucket=bucket2, file_path=sample_file_path, storage_type="local" + ) + + # Download files to different locations + download_path1 = os.path.join(temp_storage_path, "downloaded_bucket1.txt") + download_path2 = os.path.join(temp_storage_path, "downloaded_bucket2.txt") + + file_storage_client.download_file(uri1, download_path1) + file_storage_client.download_file(uri2, download_path2) + + # Verify contents of downloaded files + assert open(download_path1, "rb").read() == b"Sample file content" + assert open(download_path2, "rb").read() == b"Sample file content" + + +def test_delete_all_files_in_bucket(file_storage_client, sample_file_path): + bucket1 = "bucket1" + bucket2 = "bucket2" + + # Upload files to both buckets + file_storage_client.upload_file( + bucket=bucket1, file_path=sample_file_path, storage_type="local" + ) + file_storage_client.upload_file( + bucket=bucket2, file_path=sample_file_path, storage_type="local" + ) + + # Delete all files in bucket1 + for file in file_storage_client.list_files(bucket=bucket1): + file_storage_client.delete_file(file.uri) + + # Verify bucket1 is empty + assert len(file_storage_client.list_files(bucket=bucket1)) == 0 + + # Verify bucket2 still has files + assert len(file_storage_client.list_files(bucket=bucket2)) == 1 + + +def test_simple_distributed_storage_save_file( + distributed_storage_backend, sample_file_data, temp_storage_path +): + bucket = "test-bucket" + file_id = "test_file" + file_path = distributed_storage_backend.save(bucket, file_id, sample_file_data) + + expected_path = os.path.join( + temp_storage_path, + bucket, + f"{file_id}_{hashlib.md5('127.0.0.1:8000'.encode()).hexdigest()}", + ) + assert file_path == f"distributed://127.0.0.1:8000/{bucket}/{file_id}" + assert os.path.exists(expected_path) + + +def test_simple_distributed_storage_load_file_local( + distributed_storage_backend, sample_file_data +): + bucket = "test-bucket" + file_id = "test_file" + distributed_storage_backend.save(bucket, file_id, sample_file_data) + + metadata = FileMetadata( + file_id=file_id, + bucket=bucket, + file_name="test.txt", + file_size=len(sample_file_data.getvalue()), + storage_type="distributed", + storage_path=f"distributed://127.0.0.1:8000/{bucket}/{file_id}", + uri=f"distributed://127.0.0.1:8000/{bucket}/{file_id}", + custom_metadata={}, + file_hash="hash", + ) + + file_data = distributed_storage_backend.load(metadata) + assert file_data.read() == b"Sample file content for distributed storage" + + +@mock.patch("requests.get") +def test_simple_distributed_storage_load_file_remote( + mock_get, distributed_storage_backend, sample_file_data +): + bucket = "test-bucket" + file_id = "test_file" + remote_node_address = "127.0.0.2:8000" + + # Mock the response from remote node + mock_response = mock.Mock() + mock_response.iter_content = mock.Mock( + return_value=iter([b"Sample file content for distributed storage"]) + ) + mock_response.raise_for_status = mock.Mock(return_value=None) + mock_get.return_value = mock_response + + metadata = FileMetadata( + file_id=file_id, + bucket=bucket, + file_name="test.txt", + file_size=len(sample_file_data.getvalue()), + storage_type="distributed", + storage_path=f"distributed://{remote_node_address}/{bucket}/{file_id}", + uri=f"distributed://{remote_node_address}/{bucket}/{file_id}", + custom_metadata={}, + file_hash="hash", + ) 
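    # The metadata points at a remote node (127.0.0.2:8000), so load() should take
    # the remote branch and stream the bytes through the mocked HTTP GET rather
    # than reading from the local filesystem.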
+ + file_data = distributed_storage_backend.load(metadata) + assert file_data.read() == b"Sample file content for distributed storage" + mock_get.assert_called_once_with( + f"http://{remote_node_address}/api/v2/serve/file/files/{bucket}/{file_id}", + stream=True, + timeout=360, + ) + + +def test_simple_distributed_storage_delete_file_local( + distributed_storage_backend, sample_file_data, temp_storage_path +): + bucket = "test-bucket" + file_id = "test_file" + distributed_storage_backend.save(bucket, file_id, sample_file_data) + + metadata = FileMetadata( + file_id=file_id, + bucket=bucket, + file_name="test.txt", + file_size=len(sample_file_data.getvalue()), + storage_type="distributed", + storage_path=f"distributed://127.0.0.1:8000/{bucket}/{file_id}", + uri=f"distributed://127.0.0.1:8000/{bucket}/{file_id}", + custom_metadata={}, + file_hash="hash", + ) + + result = distributed_storage_backend.delete(metadata) + file_path = os.path.join( + temp_storage_path, + bucket, + f"{file_id}_{hashlib.md5('127.0.0.1:8000'.encode()).hexdigest()}", + ) + assert result is True + assert not os.path.exists(file_path) + + +@mock.patch("requests.delete") +def test_simple_distributed_storage_delete_file_remote( + mock_delete, distributed_storage_backend, sample_file_data +): + bucket = "test-bucket" + file_id = "test_file" + remote_node_address = "127.0.0.2:8000" + + mock_response = mock.Mock() + mock_response.raise_for_status = mock.Mock(return_value=None) + mock_delete.return_value = mock_response + + metadata = FileMetadata( + file_id=file_id, + bucket=bucket, + file_name="test.txt", + file_size=len(sample_file_data.getvalue()), + storage_type="distributed", + storage_path=f"distributed://{remote_node_address}/{bucket}/{file_id}", + uri=f"distributed://{remote_node_address}/{bucket}/{file_id}", + custom_metadata={}, + file_hash="hash", + ) + + result = distributed_storage_backend.delete(metadata) + assert result is True + mock_delete.assert_called_once_with( + f"http://{remote_node_address}/api/v2/serve/file/files/{bucket}/{file_id}", + timeout=360, + ) diff --git a/dbgpt/serve/file/__init__.py b/dbgpt/serve/file/__init__.py new file mode 100644 index 000000000..54a428180 --- /dev/null +++ b/dbgpt/serve/file/__init__.py @@ -0,0 +1,2 @@ +# This is an auto-generated __init__.py file +# generated by `dbgpt new serve file` diff --git a/dbgpt/serve/file/api/__init__.py b/dbgpt/serve/file/api/__init__.py new file mode 100644 index 000000000..54a428180 --- /dev/null +++ b/dbgpt/serve/file/api/__init__.py @@ -0,0 +1,2 @@ +# This is an auto-generated __init__.py file +# generated by `dbgpt new serve file` diff --git a/dbgpt/serve/file/api/endpoints.py b/dbgpt/serve/file/api/endpoints.py new file mode 100644 index 000000000..26bbb9673 --- /dev/null +++ b/dbgpt/serve/file/api/endpoints.py @@ -0,0 +1,169 @@ +import logging +from functools import cache +from typing import List, Optional +from urllib.parse import quote + +from fastapi import APIRouter, Depends, HTTPException, Query, UploadFile +from fastapi.security.http import HTTPAuthorizationCredentials, HTTPBearer +from starlette.responses import StreamingResponse + +from dbgpt.component import SystemApp +from dbgpt.serve.core import Result, blocking_func_to_async +from dbgpt.util import PaginationResult + +from ..config import APP_NAME, SERVE_APP_NAME, SERVE_SERVICE_COMPONENT_NAME, ServeConfig +from ..service.service import Service +from .schemas import ServeRequest, ServerResponse, UploadFileResponse + +router = APIRouter() +logger = 
logging.getLogger(__name__) + +# Add your API endpoints here + +global_system_app: Optional[SystemApp] = None + + +def get_service() -> Service: + """Get the service instance""" + return global_system_app.get_component(SERVE_SERVICE_COMPONENT_NAME, Service) + + +get_bearer_token = HTTPBearer(auto_error=False) + + +@cache +def _parse_api_keys(api_keys: str) -> List[str]: + """Parse the string api keys to a list + + Args: + api_keys (str): The string api keys + + Returns: + List[str]: The list of api keys + """ + if not api_keys: + return [] + return [key.strip() for key in api_keys.split(",")] + + +async def check_api_key( + auth: Optional[HTTPAuthorizationCredentials] = Depends(get_bearer_token), + service: Service = Depends(get_service), +) -> Optional[str]: + """Check the api key + + If the api key is not set, allow all. + + Your can pass the token in you request header like this: + + .. code-block:: python + + import requests + + client_api_key = "your_api_key" + headers = {"Authorization": "Bearer " + client_api_key} + res = requests.get("http://test/hello", headers=headers) + assert res.status_code == 200 + + """ + if service.config.api_keys: + api_keys = _parse_api_keys(service.config.api_keys) + if auth is None or (token := auth.credentials) not in api_keys: + raise HTTPException( + status_code=401, + detail={ + "error": { + "message": "", + "type": "invalid_request_error", + "param": None, + "code": "invalid_api_key", + } + }, + ) + return token + else: + # api_keys not set; allow all + return None + + +@router.get("/health") +async def health(): + """Health check endpoint""" + return {"status": "ok"} + + +@router.get("/test_auth", dependencies=[Depends(check_api_key)]) +async def test_auth(): + """Test auth endpoint""" + return {"status": "ok"} + + +@router.post( + "/files/{bucket}", + response_model=Result[List[UploadFileResponse]], + dependencies=[Depends(check_api_key)], +) +async def upload_files( + bucket: str, + files: List[UploadFile], + user_name: Optional[str] = Query(default=None, description="user name"), + sys_code: Optional[str] = Query(default=None, description="system code"), + service: Service = Depends(get_service), +) -> Result[List[UploadFileResponse]]: + """Upload files by a list of UploadFile.""" + logger.info(f"upload_files: bucket={bucket}, files={files}") + results = await blocking_func_to_async( + global_system_app, + service.upload_files, + bucket, + "distributed", + files, + user_name, + sys_code, + ) + return Result.succ(results) + + +@router.get("/files/{bucket}/{file_id}", dependencies=[Depends(check_api_key)]) +async def download_file( + bucket: str, file_id: str, service: Service = Depends(get_service) +): + """Download a file by file_id.""" + logger.info(f"download_file: bucket={bucket}, file_id={file_id}") + file_data, file_metadata = await blocking_func_to_async( + global_system_app, service.download_file, bucket, file_id + ) + file_name_encoded = quote(file_metadata.file_name) + + def file_iterator(raw_iter): + with raw_iter: + while chunk := raw_iter.read( + service.config.file_server_download_chunk_size + ): + yield chunk + + response = StreamingResponse( + file_iterator(file_data), media_type="application/octet-stream" + ) + response.headers[ + "Content-Disposition" + ] = f"attachment; filename={file_name_encoded}" + return response + + +@router.delete("/files/{bucket}/{file_id}", dependencies=[Depends(check_api_key)]) +async def delete_file( + bucket: str, file_id: str, service: Service = Depends(get_service) +): + """Delete a file by 
file_id.""" + await blocking_func_to_async( + global_system_app, service.delete_file, bucket, file_id + ) + return Result.succ(None) + + +def init_endpoints(system_app: SystemApp) -> None: + """Initialize the endpoints""" + global global_system_app + system_app.register(Service) + global_system_app = system_app diff --git a/dbgpt/serve/file/api/schemas.py b/dbgpt/serve/file/api/schemas.py new file mode 100644 index 000000000..911f71db3 --- /dev/null +++ b/dbgpt/serve/file/api/schemas.py @@ -0,0 +1,43 @@ +# Define your Pydantic schemas here +from typing import Any, Dict + +from dbgpt._private.pydantic import BaseModel, ConfigDict, Field, model_to_dict + +from ..config import SERVE_APP_NAME_HUMP + + +class ServeRequest(BaseModel): + """File request model""" + + # TODO define your own fields here + + model_config = ConfigDict(title=f"ServeRequest for {SERVE_APP_NAME_HUMP}") + + def to_dict(self, **kwargs) -> Dict[str, Any]: + """Convert the model to a dictionary""" + return model_to_dict(self, **kwargs) + + +class ServerResponse(BaseModel): + """File response model""" + + # TODO define your own fields here + + model_config = ConfigDict(title=f"ServerResponse for {SERVE_APP_NAME_HUMP}") + + def to_dict(self, **kwargs) -> Dict[str, Any]: + """Convert the model to a dictionary""" + return model_to_dict(self, **kwargs) + + +class UploadFileResponse(BaseModel): + """Upload file response model""" + + file_name: str = Field(..., title="The name of the uploaded file") + file_id: str = Field(..., title="The ID of the uploaded file") + bucket: str = Field(..., title="The bucket of the uploaded file") + uri: str = Field(..., title="The URI of the uploaded file") + + def to_dict(self, **kwargs) -> Dict[str, Any]: + """Convert the model to a dictionary""" + return model_to_dict(self, **kwargs) diff --git a/dbgpt/serve/file/config.py b/dbgpt/serve/file/config.py new file mode 100644 index 000000000..1ab1afede --- /dev/null +++ b/dbgpt/serve/file/config.py @@ -0,0 +1,68 @@ +from dataclasses import dataclass, field +from typing import Optional + +from dbgpt.serve.core import BaseServeConfig + +APP_NAME = "file" +SERVE_APP_NAME = "dbgpt_serve_file" +SERVE_APP_NAME_HUMP = "dbgpt_serve_File" +SERVE_CONFIG_KEY_PREFIX = "dbgpt.serve.file." 
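# Settings for this module are read from the application config under this prefix,
# e.g. "dbgpt.serve.file.api_keys" populates ServeConfig.api_keys via
# ServeConfig.from_app_config(system_app.config, SERVE_CONFIG_KEY_PREFIX).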
+SERVE_SERVICE_COMPONENT_NAME = f"{SERVE_APP_NAME}_service" +# Database table name +SERVER_APP_TABLE_NAME = "dbgpt_serve_file" + + +@dataclass +class ServeConfig(BaseServeConfig): + """Parameters for the serve command""" + + # TODO: add your own parameters here + api_keys: Optional[str] = field( + default=None, metadata={"help": "API keys for the endpoint, if None, allow all"} + ) + check_hash: Optional[bool] = field( + default=True, metadata={"help": "Check the hash of the file when downloading"} + ) + file_server_host: Optional[str] = field( + default=None, metadata={"help": "The host of the file server"} + ) + file_server_port: Optional[int] = field( + default=5670, metadata={"help": "The port of the file server"} + ) + file_server_download_chunk_size: Optional[int] = field( + default=1024 * 1024, + metadata={"help": "The chunk size when downloading the file"}, + ) + file_server_save_chunk_size: Optional[int] = field( + default=1024 * 1024, metadata={"help": "The chunk size when saving the file"} + ) + file_server_transfer_chunk_size: Optional[int] = field( + default=1024 * 1024, + metadata={"help": "The chunk size when transferring the file"}, + ) + file_server_transfer_timeout: Optional[int] = field( + default=360, metadata={"help": "The timeout when transferring the file"} + ) + local_storage_path: Optional[str] = field( + default=None, metadata={"help": "The local storage path"} + ) + + def get_node_address(self) -> str: + """Get the node address""" + file_server_host = self.file_server_host + if not file_server_host: + from dbgpt.util.net_utils import _get_ip_address + + file_server_host = _get_ip_address() + file_server_port = self.file_server_port or 5670 + return f"{file_server_host}:{file_server_port}" + + def get_local_storage_path(self) -> str: + """Get the local storage path""" + local_storage_path = self.local_storage_path + if not local_storage_path: + from pathlib import Path + + base_path = Path.home() / ".cache" / "dbgpt" / "files" + local_storage_path = str(base_path) + return local_storage_path diff --git a/dbgpt/serve/file/dependencies.py b/dbgpt/serve/file/dependencies.py new file mode 100644 index 000000000..8598ecd97 --- /dev/null +++ b/dbgpt/serve/file/dependencies.py @@ -0,0 +1 @@ +# Define your dependencies here diff --git a/dbgpt/serve/file/models/__init__.py b/dbgpt/serve/file/models/__init__.py new file mode 100644 index 000000000..54a428180 --- /dev/null +++ b/dbgpt/serve/file/models/__init__.py @@ -0,0 +1,2 @@ +# This is an auto-generated __init__.py file +# generated by `dbgpt new serve file` diff --git a/dbgpt/serve/file/models/file_adapter.py b/dbgpt/serve/file/models/file_adapter.py new file mode 100644 index 000000000..29ee831f4 --- /dev/null +++ b/dbgpt/serve/file/models/file_adapter.py @@ -0,0 +1,88 @@ +import json +from typing import Type + +from sqlalchemy.orm import Session + +from dbgpt.core.interface.file import FileMetadata, FileMetadataIdentifier +from dbgpt.core.interface.storage import StorageItemAdapter + +from .models import ServeEntity + + +class FileMetadataAdapter(StorageItemAdapter[FileMetadata, ServeEntity]): + """File metadata adapter. + + Convert between storage format and database model. 
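    On save, ``user_name`` and ``sys_code`` are lifted out of ``custom_metadata``
    into their own columns; on load they are merged back into ``custom_metadata``
    so callers see a single metadata dict.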
+ """ + + def to_storage_format(self, item: FileMetadata) -> ServeEntity: + """Convert to storage format.""" + custom_metadata = ( + {k: v for k, v in item.custom_metadata.items()} + if item.custom_metadata + else {} + ) + user_name = item.user_name or custom_metadata.get("user_name") + sys_code = item.sys_code or custom_metadata.get("sys_code") + if "user_name" in custom_metadata: + del custom_metadata["user_name"] + if "sys_code" in custom_metadata: + del custom_metadata["sys_code"] + custom_metadata_json = ( + json.dumps(custom_metadata, ensure_ascii=False) if custom_metadata else None + ) + return ServeEntity( + bucket=item.bucket, + file_id=item.file_id, + file_name=item.file_name, + file_size=item.file_size, + storage_type=item.storage_type, + storage_path=item.storage_path, + uri=item.uri, + custom_metadata=custom_metadata_json, + file_hash=item.file_hash, + user_name=user_name, + sys_code=sys_code, + ) + + def from_storage_format(self, model: ServeEntity) -> FileMetadata: + """Convert from storage format.""" + custom_metadata = ( + json.loads(model.custom_metadata) if model.custom_metadata else None + ) + if custom_metadata is None: + custom_metadata = {} + if model.user_name: + custom_metadata["user_name"] = model.user_name + if model.sys_code: + custom_metadata["sys_code"] = model.sys_code + + return FileMetadata( + bucket=model.bucket, + file_id=model.file_id, + file_name=model.file_name, + file_size=model.file_size, + storage_type=model.storage_type, + storage_path=model.storage_path, + uri=model.uri, + custom_metadata=custom_metadata, + file_hash=model.file_hash, + user_name=model.user_name, + sys_code=model.sys_code, + ) + + def get_query_for_identifier( + self, + storage_format: Type[ServeEntity], + resource_id: FileMetadataIdentifier, + **kwargs, + ): + """Get query for identifier.""" + session: Session = kwargs.get("session") + if session is None: + raise Exception("session is None") + return ( + session.query(storage_format) + .filter(storage_format.bucket == resource_id.bucket) + .filter(storage_format.file_id == resource_id.file_id) + ) diff --git a/dbgpt/serve/file/models/models.py b/dbgpt/serve/file/models/models.py new file mode 100644 index 000000000..fd816740d --- /dev/null +++ b/dbgpt/serve/file/models/models.py @@ -0,0 +1,90 @@ +"""This is an auto-generated model file +You can define your own models and DAOs here +""" + +from datetime import datetime +from typing import Any, Dict, Union + +from sqlalchemy import Column, DateTime, Index, Integer, String, Text, UniqueConstraint + +from dbgpt.storage.metadata import BaseDao, Model, db + +from ..api.schemas import ServeRequest, ServerResponse +from ..config import SERVER_APP_TABLE_NAME, ServeConfig + + +class ServeEntity(Model): + __tablename__ = SERVER_APP_TABLE_NAME + __table_args__ = (UniqueConstraint("bucket", "file_id", name="uk_bucket_file_id"),) + + id = Column(Integer, primary_key=True, comment="Auto increment id") + + bucket = Column(String(255), nullable=False, comment="Bucket name") + file_id = Column(String(255), nullable=False, comment="File id") + file_name = Column(String(256), nullable=False, comment="File name") + file_size = Column(Integer, nullable=True, comment="File size") + storage_type = Column(String(32), nullable=False, comment="Storage type") + storage_path = Column(String(512), nullable=False, comment="Storage path") + uri = Column(String(512), nullable=False, comment="File URI") + custom_metadata = Column( + Text, nullable=True, comment="Custom metadata, JSON format" + ) + file_hash = 
Column(String(128), nullable=True, comment="File hash") + user_name = Column(String(128), index=True, nullable=True, comment="User name") + sys_code = Column(String(128), index=True, nullable=True, comment="System code") + gmt_created = Column(DateTime, default=datetime.now, comment="Record creation time") + gmt_modified = Column(DateTime, default=datetime.now, comment="Record update time") + + def __repr__(self): + return ( + f"ServeEntity(id={self.id}, gmt_created='{self.gmt_created}', " + f"gmt_modified='{self.gmt_modified}')" + ) + + +class ServeDao(BaseDao[ServeEntity, ServeRequest, ServerResponse]): + """The DAO class for File""" + + def __init__(self, serve_config: ServeConfig): + super().__init__() + self._serve_config = serve_config + + def from_request(self, request: Union[ServeRequest, Dict[str, Any]]) -> ServeEntity: + """Convert the request to an entity + + Args: + request (Union[ServeRequest, Dict[str, Any]]): The request + + Returns: + T: The entity + """ + request_dict = ( + request.to_dict() if isinstance(request, ServeRequest) else request + ) + entity = ServeEntity(**request_dict) + # TODO implement your own logic here, transfer the request_dict to an entity + return entity + + def to_request(self, entity: ServeEntity) -> ServeRequest: + """Convert the entity to a request + + Args: + entity (T): The entity + + Returns: + REQ: The request + """ + # TODO implement your own logic here, transfer the entity to a request + return ServeRequest() + + def to_response(self, entity: ServeEntity) -> ServerResponse: + """Convert the entity to a response + + Args: + entity (T): The entity + + Returns: + RES: The response + """ + # TODO implement your own logic here, transfer the entity to a response + return ServerResponse() diff --git a/dbgpt/serve/file/serve.py b/dbgpt/serve/file/serve.py new file mode 100644 index 000000000..559509573 --- /dev/null +++ b/dbgpt/serve/file/serve.py @@ -0,0 +1,113 @@ +import logging +from typing import List, Optional, Union + +from sqlalchemy import URL + +from dbgpt.component import SystemApp +from dbgpt.core.interface.file import FileStorageClient +from dbgpt.serve.core import BaseServe +from dbgpt.storage.metadata import DatabaseManager + +from .api.endpoints import init_endpoints, router +from .config import ( + APP_NAME, + SERVE_APP_NAME, + SERVE_APP_NAME_HUMP, + SERVE_CONFIG_KEY_PREFIX, + ServeConfig, +) + +logger = logging.getLogger(__name__) + + +class Serve(BaseServe): + """Serve component for DB-GPT""" + + name = SERVE_APP_NAME + + def __init__( + self, + system_app: SystemApp, + api_prefix: Optional[str] = f"/api/v2/serve/{APP_NAME}", + api_tags: Optional[List[str]] = None, + db_url_or_db: Union[str, URL, DatabaseManager] = None, + try_create_tables: Optional[bool] = False, + ): + if api_tags is None: + api_tags = [SERVE_APP_NAME_HUMP] + super().__init__( + system_app, api_prefix, api_tags, db_url_or_db, try_create_tables + ) + self._db_manager: Optional[DatabaseManager] = None + + self._db_manager: Optional[DatabaseManager] = None + self._file_storage_client: Optional[FileStorageClient] = None + self._serve_config: Optional[ServeConfig] = None + + def init_app(self, system_app: SystemApp): + if self._app_has_initiated: + return + self._system_app = system_app + self._system_app.app.include_router( + router, prefix=self._api_prefix, tags=self._api_tags + ) + init_endpoints(self._system_app) + self._app_has_initiated = True + + def on_init(self): + """Called when init the application. + + You can do some initialization here. 
You can't get other components here because they may be not initialized yet + """ + # import your own module here to ensure the module is loaded before the application starts + from .models.models import ServeEntity + + def before_start(self): + """Called before the start of the application.""" + from dbgpt.core.interface.file import ( + FileStorageSystem, + SimpleDistributedStorage, + ) + from dbgpt.storage.metadata.db_storage import SQLAlchemyStorage + from dbgpt.util.serialization.json_serialization import JsonSerializer + + from .models.file_adapter import FileMetadataAdapter + from .models.models import ServeEntity + + self._serve_config = ServeConfig.from_app_config( + self._system_app.config, SERVE_CONFIG_KEY_PREFIX + ) + + self._db_manager = self.create_or_get_db_manager() + serializer = JsonSerializer() + storage = SQLAlchemyStorage( + self._db_manager, + ServeEntity, + FileMetadataAdapter(), + serializer, + ) + simple_distributed_storage = SimpleDistributedStorage( + node_address=self._serve_config.get_node_address(), + local_storage_path=self._serve_config.get_local_storage_path(), + save_chunk_size=self._serve_config.file_server_save_chunk_size, + transfer_chunk_size=self._serve_config.file_server_transfer_chunk_size, + transfer_timeout=self._serve_config.file_server_transfer_timeout, + ) + storage_backends = { + simple_distributed_storage.storage_type: simple_distributed_storage, + } + fs = FileStorageSystem( + storage_backends, + metadata_storage=storage, + check_hash=self._serve_config.check_hash, + ) + self._file_storage_client = FileStorageClient( + system_app=self._system_app, storage_system=fs + ) + + @property + def file_storage_client(self) -> FileStorageClient: + """Returns the file storage client.""" + if not self._file_storage_client: + raise ValueError("File storage client is not initialized") + return self._file_storage_client diff --git a/dbgpt/serve/file/service/__init__.py b/dbgpt/serve/file/service/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/dbgpt/serve/file/service/service.py b/dbgpt/serve/file/service/service.py new file mode 100644 index 000000000..13e8b6225 --- /dev/null +++ b/dbgpt/serve/file/service/service.py @@ -0,0 +1,119 @@ +import logging +from typing import BinaryIO, List, Optional, Tuple + +from fastapi import UploadFile + +from dbgpt.component import BaseComponent, SystemApp +from dbgpt.core.interface.file import FileMetadata, FileStorageClient, FileStorageURI +from dbgpt.serve.core import BaseService +from dbgpt.storage.metadata import BaseDao +from dbgpt.util.pagination_utils import PaginationResult +from dbgpt.util.tracer import root_tracer, trace + +from ..api.schemas import ServeRequest, ServerResponse, UploadFileResponse +from ..config import SERVE_CONFIG_KEY_PREFIX, SERVE_SERVICE_COMPONENT_NAME, ServeConfig +from ..models.models import ServeDao, ServeEntity + +logger = logging.getLogger(__name__) + + +class Service(BaseService[ServeEntity, ServeRequest, ServerResponse]): + """The service class for File""" + + name = SERVE_SERVICE_COMPONENT_NAME + + def __init__(self, system_app: SystemApp, dao: Optional[ServeDao] = None): + self._system_app = None + self._serve_config: ServeConfig = None + self._dao: ServeDao = dao + super().__init__(system_app) + + def init_app(self, system_app: SystemApp) -> None: + """Initialize the service + + Args: + system_app (SystemApp): The system app + """ + super().init_app(system_app) + self._serve_config = ServeConfig.from_app_config( + system_app.config, SERVE_CONFIG_KEY_PREFIX + 
) + self._dao = self._dao or ServeDao(self._serve_config) + self._system_app = system_app + + @property + def dao(self) -> BaseDao[ServeEntity, ServeRequest, ServerResponse]: + """Returns the internal DAO.""" + return self._dao + + @property + def config(self) -> ServeConfig: + """Returns the internal ServeConfig.""" + return self._serve_config + + @property + def file_storage_client(self) -> FileStorageClient: + """Returns the internal FileStorageClient. + + Returns: + FileStorageClient: The internal FileStorageClient + """ + file_storage_client = FileStorageClient.get_instance( + self._system_app, default_component=None + ) + if file_storage_client: + return file_storage_client + else: + from ..serve import Serve + + file_storage_client = Serve.get_instance( + self._system_app + ).file_storage_client + self._system_app.register_instance(file_storage_client) + return file_storage_client + + @trace("upload_files") + def upload_files( + self, + bucket: str, + storage_type: str, + files: List[UploadFile], + user_name: Optional[str] = None, + sys_code: Optional[str] = None, + ) -> List[UploadFileResponse]: + """Upload files by a list of UploadFile.""" + results = [] + for file in files: + file_name = file.filename + logger.info(f"Uploading file {file_name} to bucket {bucket}") + custom_metadata = { + "user_name": user_name, + "sys_code": sys_code, + } + uri = self.file_storage_client.save_file( + bucket, + file_name, + file_data=file.file, + storage_type=storage_type, + custom_metadata=custom_metadata, + ) + parsed_uri = FileStorageURI.parse(uri) + logger.info(f"Uploaded file {file_name} to bucket {bucket}, uri={uri}") + results.append( + UploadFileResponse( + file_name=file_name, + file_id=parsed_uri.file_id, + bucket=bucket, + uri=uri, + ) + ) + return results + + @trace("download_file") + def download_file(self, bucket: str, file_id: str) -> Tuple[BinaryIO, FileMetadata]: + """Download a file by file_id.""" + return self.file_storage_client.get_file_by_id(bucket, file_id) + + def delete_file(self, bucket: str, file_id: str) -> None: + """Delete a file by file_id.""" + self.file_storage_client.delete_file_by_id(bucket, file_id) diff --git a/dbgpt/serve/file/tests/__init__.py b/dbgpt/serve/file/tests/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/dbgpt/serve/file/tests/test_endpoints.py b/dbgpt/serve/file/tests/test_endpoints.py new file mode 100644 index 000000000..ba7b4f0cd --- /dev/null +++ b/dbgpt/serve/file/tests/test_endpoints.py @@ -0,0 +1,124 @@ +import pytest +from fastapi import FastAPI +from httpx import AsyncClient + +from dbgpt.component import SystemApp +from dbgpt.serve.core.tests.conftest import asystem_app, client +from dbgpt.storage.metadata import db +from dbgpt.util import PaginationResult + +from ..api.endpoints import init_endpoints, router +from ..api.schemas import ServeRequest, ServerResponse +from ..config import SERVE_CONFIG_KEY_PREFIX + + +@pytest.fixture(autouse=True) +def setup_and_teardown(): + db.init_db("sqlite:///:memory:") + db.create_all() + + yield + + +def client_init_caller(app: FastAPI, system_app: SystemApp): + app.include_router(router) + init_endpoints(system_app) + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "client, asystem_app, has_auth", + [ + ( + { + "app_caller": client_init_caller, + "client_api_key": "test_token1", + }, + { + "app_config": { + f"{SERVE_CONFIG_KEY_PREFIX}api_keys": "test_token1,test_token2" + } + }, + True, + ), + ( + { + "app_caller": client_init_caller, + "client_api_key": "error_token", + 
}, + { + "app_config": { + f"{SERVE_CONFIG_KEY_PREFIX}api_keys": "test_token1,test_token2" + } + }, + False, + ), + ], + indirect=["client", "asystem_app"], +) +async def test_api_health(client: AsyncClient, asystem_app, has_auth: bool): + response = await client.get("/test_auth") + if has_auth: + assert response.status_code == 200 + assert response.json() == {"status": "ok"} + else: + assert response.status_code == 401 + assert response.json() == { + "detail": { + "error": { + "message": "", + "type": "invalid_request_error", + "param": None, + "code": "invalid_api_key", + } + } + } + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "client", [{"app_caller": client_init_caller}], indirect=["client"] +) +async def test_api_health(client: AsyncClient): + response = await client.get("/health") + assert response.status_code == 200 + assert response.json() == {"status": "ok"} + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "client", [{"app_caller": client_init_caller}], indirect=["client"] +) +async def test_api_create(client: AsyncClient): + # TODO: add your test case + pass + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "client", [{"app_caller": client_init_caller}], indirect=["client"] +) +async def test_api_update(client: AsyncClient): + # TODO: implement your test case + pass + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "client", [{"app_caller": client_init_caller}], indirect=["client"] +) +async def test_api_query(client: AsyncClient): + # TODO: implement your test case + pass + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "client", [{"app_caller": client_init_caller}], indirect=["client"] +) +async def test_api_query_by_page(client: AsyncClient): + # TODO: implement your test case + pass + + +# Add more test cases according to your own logic diff --git a/dbgpt/serve/file/tests/test_models.py b/dbgpt/serve/file/tests/test_models.py new file mode 100644 index 000000000..8b66e9f97 --- /dev/null +++ b/dbgpt/serve/file/tests/test_models.py @@ -0,0 +1,99 @@ +import pytest + +from dbgpt.storage.metadata import db + +from ..api.schemas import ServeRequest, ServerResponse +from ..config import ServeConfig +from ..models.models import ServeDao, ServeEntity + + +@pytest.fixture(autouse=True) +def setup_and_teardown(): + db.init_db("sqlite:///:memory:") + db.create_all() + + yield + + +@pytest.fixture +def server_config(): + # TODO : build your server config + return ServeConfig() + + +@pytest.fixture +def dao(server_config): + return ServeDao(server_config) + + +@pytest.fixture +def default_entity_dict(): + # TODO: build your default entity dict + return {} + + +def test_table_exist(): + assert ServeEntity.__tablename__ in db.metadata.tables + + +def test_entity_create(default_entity_dict): + # TODO: implement your test case + pass + + +def test_entity_unique_key(default_entity_dict): + # TODO: implement your test case + pass + + +def test_entity_get(default_entity_dict): + # TODO: implement your test case + pass + + +def test_entity_update(default_entity_dict): + # TODO: implement your test case + pass + + +def test_entity_delete(default_entity_dict): + # TODO: implement your test case + pass + + +def test_entity_all(): + # TODO: implement your test case + pass + + +def test_dao_create(dao, default_entity_dict): + # TODO: implement your test case + pass + + +def test_dao_get_one(dao, default_entity_dict): + # TODO: implement your test case + pass + + +def test_get_dao_get_list(dao): + # TODO: implement your test case + pass + + +def test_dao_update(dao, 
default_entity_dict): + # TODO: implement your test case + pass + + +def test_dao_delete(dao, default_entity_dict): + # TODO: implement your test case + pass + + +def test_dao_get_list_page(dao): + # TODO: implement your test case + pass + + +# Add more test cases according to your own logic diff --git a/dbgpt/serve/file/tests/test_service.py b/dbgpt/serve/file/tests/test_service.py new file mode 100644 index 000000000..00177924d --- /dev/null +++ b/dbgpt/serve/file/tests/test_service.py @@ -0,0 +1,78 @@ +from typing import List + +import pytest + +from dbgpt.component import SystemApp +from dbgpt.serve.core.tests.conftest import system_app +from dbgpt.storage.metadata import db + +from ..api.schemas import ServeRequest, ServerResponse +from ..models.models import ServeEntity +from ..service.service import Service + + +@pytest.fixture(autouse=True) +def setup_and_teardown(): + db.init_db("sqlite:///:memory:") + db.create_all() + yield + + +@pytest.fixture +def service(system_app: SystemApp): + instance = Service(system_app) + instance.init_app(system_app) + return instance + + +@pytest.fixture +def default_entity_dict(): + # TODO: build your default entity dict + return {} + + +@pytest.mark.parametrize( + "system_app", + [{"app_config": {"DEBUG": True, "dbgpt.serve.test_key": "hello"}}], + indirect=True, +) +def test_config_exists(service: Service): + system_app: SystemApp = service._system_app + assert system_app.config.get("DEBUG") is True + assert system_app.config.get("dbgpt.serve.test_key") == "hello" + assert service.config is not None + + +def test_service_create(service: Service, default_entity_dict): + # TODO: implement your test case + # eg. entity: ServerResponse = service.create(ServeRequest(**default_entity_dict)) + # ... + pass + + +def test_service_update(service: Service, default_entity_dict): + # TODO: implement your test case + pass + + +def test_service_get(service: Service, default_entity_dict): + # TODO: implement your test case + pass + + +def test_service_delete(service: Service, default_entity_dict): + # TODO: implement your test case + pass + + +def test_service_get_list(service: Service): + # TODO: implement your test case + pass + + +def test_service_get_list_by_page(service: Service): + # TODO: implement your test case + pass + + +# Add more test cases according to your own logic diff --git a/dbgpt/serve/flow/api/endpoints.py b/dbgpt/serve/flow/api/endpoints.py index 4b28641e8..4174502a5 100644 --- a/dbgpt/serve/flow/api/endpoints.py +++ b/dbgpt/serve/flow/api/endpoints.py @@ -1,9 +1,11 @@ +import io import json from functools import cache -from typing import Dict, List, Optional, Union +from typing import Dict, List, Literal, Optional, Union -from fastapi import APIRouter, Depends, HTTPException, Query, Request +from fastapi import APIRouter, Depends, File, HTTPException, Query, Request, UploadFile from fastapi.security.http import HTTPAuthorizationCredentials, HTTPBearer +from starlette.responses import JSONResponse, StreamingResponse from dbgpt.component import SystemApp from dbgpt.core.awel.flow import ResourceMetadata, ViewMetadata @@ -14,6 +16,7 @@ from ..service.service import Service from ..service.variables_service import VariablesService from .schemas import ( + FlowDebugRequest, RefreshNodeRequest, ServeRequest, ServerResponse, @@ -322,10 +325,116 @@ async def update_variables( return Result.succ(res) -@router.post("/flow/debug") -async def debug(): - """Debug the flow.""" - # TODO: Implement the debug endpoint +@router.post("/flow/debug", 
dependencies=[Depends(check_api_key)]) +async def debug_flow( + flow_debug_request: FlowDebugRequest, service: Service = Depends(get_service) +): + """Run the flow in debug mode.""" + # Return the no-incremental stream by default + stream_iter = service.debug_flow(flow_debug_request, default_incremental=False) + + headers = { + "Content-Type": "text/event-stream", + "Cache-Control": "no-cache", + "Connection": "keep-alive", + "Transfer-Encoding": "chunked", + } + return StreamingResponse( + service._wrapper_chat_stream_flow_str(stream_iter), + headers=headers, + media_type="text/event-stream", + ) + + +@router.get("/flow/export/{uid}", dependencies=[Depends(check_api_key)]) +async def export_flow( + uid: str, + export_type: Literal["json", "dbgpts"] = Query( + "json", description="export type(json or dbgpts)" + ), + format: Literal["file", "json"] = Query( + "file", description="response format(file or json)" + ), + file_name: Optional[str] = Query(default=None, description="file name to export"), + user_name: Optional[str] = Query(default=None, description="user name"), + sys_code: Optional[str] = Query(default=None, description="system code"), + service: Service = Depends(get_service), +): + """Export the flow to a file.""" + flow = service.get({"uid": uid, "user_name": user_name, "sys_code": sys_code}) + if not flow: + raise HTTPException(status_code=404, detail=f"Flow {uid} not found") + package_name = flow.name.replace("_", "-") + file_name = file_name or package_name + if export_type == "json": + flow_dict = {"flow": flow.to_dict()} + if format == "json": + return JSONResponse(content=flow_dict) + else: + # Return the json file + return StreamingResponse( + io.BytesIO(json.dumps(flow_dict, ensure_ascii=False).encode("utf-8")), + media_type="application/file", + headers={ + "Content-Disposition": f"attachment;filename={file_name}.json" + }, + ) + + elif export_type == "dbgpts": + from ..service.share_utils import _generate_dbgpts_zip + + if format == "json": + raise HTTPException( + status_code=400, detail="json response is not supported for dbgpts" + ) + + zip_buffer = await blocking_func_to_async( + global_system_app, _generate_dbgpts_zip, package_name, flow + ) + return StreamingResponse( + zip_buffer, + media_type="application/x-zip-compressed", + headers={"Content-Disposition": f"attachment;filename={file_name}.zip"}, + ) + + +@router.post( + "/flow/import", + response_model=Result[ServerResponse], + dependencies=[Depends(check_api_key)], +) +async def import_flow( + file: UploadFile = File(...), + save_flow: bool = Query( + False, description="Whether to save the flow after importing" + ), + service: Service = Depends(get_service), +): + """Import the flow from a file.""" + filename = file.filename + file_extension = filename.split(".")[-1].lower() + if file_extension == "json": + # Handle json file + json_content = await file.read() + json_dict = json.loads(json_content) + if "flow" not in json_dict: + raise HTTPException( + status_code=400, detail="invalid json file, missing 'flow' key" + ) + flow = ServeRequest.parse_obj(json_dict["flow"]) + elif file_extension == "zip": + from ..service.share_utils import _parse_flow_from_zip_file + + # Handle zip file + flow = await _parse_flow_from_zip_file(file, global_system_app) + else: + raise HTTPException( + status_code=400, detail=f"invalid file extension {file_extension}" + ) + if save_flow: + return Result.succ(service.create_and_save_dag(flow)) + else: + return Result.succ(flow) def init_endpoints(system_app: SystemApp) -> None: 
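The export and import endpoints above are designed to round-trip: the JSON document produced by the export (which wraps the flow under a "flow" key) can be posted straight back to /flow/import. A rough usage sketch, not part of the patch; the base URL, serve prefix, API key, and flow uid are assumptions to adjust for your deployment:

    import json

    import requests

    # Assumed mount point of the flow serve router; adjust to your deployment.
    FLOW_API = "http://localhost:5670/api/v2/serve/awel"
    # Only needed when api_keys is configured for the flow serve module.
    HEADERS = {"Authorization": "Bearer your_api_key"}

    # Export a flow definition as a plain JSON document.
    flow_dict = requests.get(
        f"{FLOW_API}/flow/export/some-flow-uid",
        params={"export_type": "json", "format": "json"},
        headers=HEADERS,
    ).json()

    # Re-import it; save_flow=false only parses and returns the flow without saving.
    resp = requests.post(
        f"{FLOW_API}/flow/import",
        params={"save_flow": "false"},
        files={"file": ("flow.json", json.dumps(flow_dict), "application/json")},
        headers=HEADERS,
    )
    print(resp.json())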
diff --git a/dbgpt/serve/flow/api/schemas.py b/dbgpt/serve/flow/api/schemas.py
index 537996fe7..cf82de982 100644
--- a/dbgpt/serve/flow/api/schemas.py
+++ b/dbgpt/serve/flow/api/schemas.py
@@ -2,7 +2,7 @@
 
 from dbgpt._private.pydantic import BaseModel, ConfigDict, Field
 from dbgpt.core.awel import CommonLLMHttpRequestBody
-from dbgpt.core.awel.flow.flow_factory import FlowPanel
+from dbgpt.core.awel.flow.flow_factory import FlowPanel, VariablesRequest
 from dbgpt.core.awel.util.parameter_util import RefreshOptionRequest
 
 from ..config import SERVE_APP_NAME_HUMP
@@ -18,59 +18,6 @@ class ServerResponse(FlowPanel):
     model_config = ConfigDict(title=f"ServerResponse for {SERVE_APP_NAME_HUMP}")
 
 
-class VariablesRequest(BaseModel):
-    """Variable request model.
-
-    For creating a new variable in the DB-GPT.
-    """
-
-    key: str = Field(
-        ...,
-        description="The key of the variable to create",
-        examples=["dbgpt.model.openai.api_key"],
-    )
-    name: str = Field(
-        ...,
-        description="The name of the variable to create",
-        examples=["my_first_openai_key"],
-    )
-    label: str = Field(
-        ...,
-        description="The label of the variable to create",
-        examples=["My First OpenAI Key"],
-    )
-    value: Any = Field(
-        ..., description="The value of the variable to create", examples=["1234567890"]
-    )
-    value_type: Literal["str", "int", "float", "bool"] = Field(
-        "str",
-        description="The type of the value of the variable to create",
-        examples=["str", "int", "float", "bool"],
-    )
-    category: Literal["common", "secret"] = Field(
-        ...,
-        description="The category of the variable to create",
-        examples=["common"],
-    )
-    scope: str = Field(
-        ...,
-        description="The scope of the variable to create",
-        examples=["global"],
-    )
-    scope_key: Optional[str] = Field(
-        ...,
-        description="The scope key of the variable to create",
-        examples=["dbgpt"],
-    )
-    enabled: Optional[bool] = Field(
-        True,
-        description="Whether the variable is enabled",
-        examples=[True],
-    )
-    user_name: Optional[str] = Field(None, description="User name")
-    sys_code: Optional[str] = Field(None, description="System code")
-
-
 class VariablesResponse(VariablesRequest):
     """Variable response model."""
 
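
Reviewer note: `VariablesRequest` is not dropped; it now lives in the core flow factory and the serve schemas re-import it from there. Downstream code that imported it from the serve layer should switch to the core path, roughly as below (the old import is shown only for contrast):

    # Before (removed in this diff):
    # from dbgpt.serve.flow.api.schemas import VariablesRequest
    # After: import from the core flow factory, as schemas.py now does.
    from dbgpt.core.awel.flow.flow_factory import VariablesRequest
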
diff --git a/dbgpt/serve/flow/service/service.py b/dbgpt/serve/flow/service/service.py
index 83b79847f..3cdb136eb 100644
--- a/dbgpt/serve/flow/service/service.py
+++ b/dbgpt/serve/flow/service/service.py
@@ -8,7 +8,6 @@
 from dbgpt._private.pydantic import model_to_json
 from dbgpt.component import SystemApp
 from dbgpt.core.awel import DAG, BaseOperator, CommonLLMHttpRequestBody
-from dbgpt.core.awel.dag.dag_manager import DAGManager
 from dbgpt.core.awel.flow.flow_factory import (
     FlowCategory,
     FlowFactory,
@@ -33,7 +32,7 @@
 from dbgpt.util.dbgpts.loader import DBGPTsLoader
 from dbgpt.util.pagination_utils import PaginationResult
 
-from ..api.schemas import ServeRequest, ServerResponse
+from ..api.schemas import FlowDebugRequest, ServeRequest, ServerResponse
 from ..config import SERVE_CONFIG_KEY_PREFIX, SERVE_SERVICE_COMPONENT_NAME, ServeConfig
 from ..models.models import ServeDao, ServeEntity
 
@@ -146,7 +145,9 @@ def create_and_save_dag(
             raise ValueError(
                 f"Create DAG {request.name} error, define_type: {request.define_type}, error: {str(e)}"
             ) from e
-        res = self.dao.create(request)
+        self.dao.create(request)
+        # Query from database
+        res = self.get({"uid": request.uid})
         state = request.state
 
         try:
@@ -563,3 +564,61 @@ def _parse_flow_category(self, dag: DAG) -> FlowCategory:
                 return FlowCategory.CHAT_FLOW
         except Exception:
             return FlowCategory.COMMON
+
+    async def debug_flow(
+        self, request: FlowDebugRequest, default_incremental: Optional[bool] = None
+    ) -> AsyncIterator[ModelOutput]:
+        """Debug the flow.
+
+        Args:
+            request (FlowDebugRequest): The request
+            default_incremental (Optional[bool]): The default incremental configuration
+
+        Returns:
+            AsyncIterator[ModelOutput]: The output
+        """
+        from dbgpt.core.awel.dag.dag_manager import DAGMetadata, _parse_metadata
+
+        dag = self._flow_factory.build(request.flow)
+        leaf_nodes = dag.leaf_nodes
+        if len(leaf_nodes) != 1:
+            raise ValueError("Chat Flow only supports one leaf node in the DAG")
+        task = cast(BaseOperator, leaf_nodes[0])
+        dag_metadata = _parse_metadata(dag)
+        # TODO: Run task with variables
+        variables = request.variables
+        dag_request = request.request
+
+        if isinstance(request.request, CommonLLMHttpRequestBody):
+            incremental = request.request.incremental
+        elif isinstance(request.request, dict):
+            incremental = request.request.get("incremental", False)
+        else:
+            raise ValueError("Invalid request type")
+
+        if default_incremental is not None:
+            incremental = default_incremental
+
+        try:
+            async for output in safe_chat_stream_with_dag_task(
+                task, dag_request, incremental
+            ):
+                yield output
+        except HTTPException as e:
+            yield ModelOutput(error_code=1, text=e.detail, incremental=incremental)
+        except Exception as e:
+            yield ModelOutput(error_code=1, text=str(e), incremental=incremental)
+
+    async def _wrapper_chat_stream_flow_str(
+        self, stream_iter: AsyncIterator[ModelOutput]
+    ) -> AsyncIterator[str]:
+        """Wrap the model output stream as Server-Sent Events lines."""
+        async for output in stream_iter:
+            text = output.text
+            if text:
+                text = text.replace("\n", "\\n")
+            if output.error_code != 0:
+                yield f"data:[SERVER_ERROR]{text}\n\n"
+                break
+            else:
+                yield f"data:{text}\n\n"
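
Reviewer note: `_wrapper_chat_stream_flow_str` turns the model output stream into Server-Sent-Events-style `data:` lines, escaping raw newlines as `\\n` and prefixing errors with `[SERVER_ERROR]`. Below is a sketch of a client consuming the `/flow/debug` stream; the URL, API key, and request body shape are assumptions (see `FlowDebugRequest` in `api/schemas.py` for the actual model).

    # Hypothetical consumer of the debug stream; all endpoint details are placeholders.
    import requests

    flow_panel = {}  # placeholder: the flow definition to debug
    body = {"flow": flow_panel, "request": {"messages": "hello", "incremental": False}}

    resp = requests.post(
        "http://localhost:5670/api/v2/serve/awel/flow/debug",  # assumed route prefix
        headers={"Authorization": "Bearer dbgpt"},  # assumed API key
        json=body,
        stream=True,
    )
    for line in resp.iter_lines(decode_unicode=True):
        if not line or not line.startswith("data:"):
            continue
        payload = line[len("data:"):]
        if payload.startswith("[SERVER_ERROR]"):
            raise RuntimeError(payload)
        # Undo the newline escaping applied by _wrapper_chat_stream_flow_str
        print(payload.replace("\\n", "\n"))
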
diff --git a/dbgpt/serve/flow/service/share_utils.py b/dbgpt/serve/flow/service/share_utils.py
new file mode 100644
index 000000000..99ba222a9
--- /dev/null
+++ b/dbgpt/serve/flow/service/share_utils.py
@@ -0,0 +1,121 @@
+import io
+import json
+import os
+import tempfile
+import zipfile
+
+import aiofiles
+import tomlkit
+from fastapi import UploadFile
+
+from dbgpt.component import SystemApp
+from dbgpt.serve.core import blocking_func_to_async
+
+from ..api.schemas import ServeRequest
+
+
+def _generate_dbgpts_zip(package_name: str, flow: ServeRequest) -> io.BytesIO:
+    """Generate an in-memory dbgpts zip package for the given flow."""
+    zip_buffer = io.BytesIO()
+    flow_name = flow.name
+    flow_label = flow.label
+    flow_description = flow.description
+    dag_json = json.dumps(flow.flow_data.dict(), indent=4, ensure_ascii=False)
+    with zipfile.ZipFile(zip_buffer, "a", zipfile.ZIP_DEFLATED, False) as zip_file:
+        manifest = f"include dbgpts.toml\ninclude {flow_name}/definition/*.json"
+        readme = f"# {flow_label}\n\n{flow_description}"
+        zip_file.writestr(f"{package_name}/MANIFEST.in", manifest)
+        zip_file.writestr(f"{package_name}/README.md", readme)
+        zip_file.writestr(
+            f"{package_name}/{flow_name}/__init__.py",
+            "",
+        )
+        zip_file.writestr(
+            f"{package_name}/{flow_name}/definition/flow_definition.json",
+            dag_json,
+        )
+        dbgpts_toml = tomlkit.document()
+        # Add flow information
+        dbgpts_flow_toml = tomlkit.document()
+        dbgpts_flow_toml.add("label", flow_label)
+        name_with_comment = tomlkit.string(flow_name)
+        name_with_comment.comment("A unique name for all dbgpts")
+        dbgpts_flow_toml.add("name", name_with_comment)
+
+        dbgpts_flow_toml.add("version", "0.1.0")
+        dbgpts_flow_toml.add(
+            "description",
+            flow_description,
+        )
+        dbgpts_flow_toml.add("authors", [])
+
+        definition_type_with_comment = tomlkit.string("json")
+        definition_type_with_comment.comment("How to define the flow, python or json")
+        dbgpts_flow_toml.add("definition_type", definition_type_with_comment)
+
+        dbgpts_toml.add("flow", dbgpts_flow_toml)
+
+        # Add python and json config
+        python_config = tomlkit.table()
+        dbgpts_toml.add("python_config", python_config)
+
+        json_config = tomlkit.table()
+        json_config.add("file_path", "definition/flow_definition.json")
+        json_config.comment("Json config")
+
+        dbgpts_toml.add("json_config", json_config)
+
+        # Transform to string
+        toml_string = tomlkit.dumps(dbgpts_toml)
+        zip_file.writestr(f"{package_name}/dbgpts.toml", toml_string)
+
+        pyproject_toml = tomlkit.document()
+
+        # Add [tool.poetry] section
+        tool_poetry_toml = tomlkit.table()
+        tool_poetry_toml.add("name", package_name)
+        tool_poetry_toml.add("version", "0.1.0")
+        tool_poetry_toml.add("description", "A dbgpts package")
+        tool_poetry_toml.add("authors", [])
+        tool_poetry_toml.add("readme", "README.md")
+        pyproject_toml["tool"] = tomlkit.table()
+        pyproject_toml["tool"]["poetry"] = tool_poetry_toml
+
+        # Add [tool.poetry.dependencies] section
+        dependencies = tomlkit.table()
+        dependencies.add("python", "^3.10")
+        pyproject_toml["tool"]["poetry"]["dependencies"] = dependencies
+
+        # Add [build-system] section
+        build_system = tomlkit.table()
+        build_system.add("requires", ["poetry-core"])
+        build_system.add("build-backend", "poetry.core.masonry.api")
+        pyproject_toml["build-system"] = build_system
+
+        # Transform to string
+        pyproject_toml_string = tomlkit.dumps(pyproject_toml)
+        zip_file.writestr(f"{package_name}/pyproject.toml", pyproject_toml_string)
+    zip_buffer.seek(0)
+    return zip_buffer
+
+
+async def _parse_flow_from_zip_file(
+    file: UploadFile, sys_app: SystemApp
+) -> ServeRequest:
+    from dbgpt.util.dbgpts.loader import _load_flow_package_from_zip_path
+
+    filename = file.filename
+    if not filename.endswith(".zip"):
+        raise ValueError("Uploaded file must be a ZIP file")
+
+    with tempfile.TemporaryDirectory() as temp_dir:
+        zip_path = os.path.join(temp_dir, filename)
+
+        # Save uploaded file to temporary directory
+        async with aiofiles.open(zip_path, "wb") as out_file:
+            while content := await file.read(1024 * 64):  # Read in chunks of 64KB
+                await out_file.write(content)
+        flow = await blocking_func_to_async(
+            sys_app, _load_flow_package_from_zip_path, zip_path
+        )
+    return flow
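
Reviewer note: the in-memory zip produced by `_generate_dbgpts_zip` can be inspected with the standard `zipfile` module, which is a quick way to verify the package layout. A minimal sketch, assuming `flow` is a `ServeRequest` with `name`, `label`, `description`, and `flow_data` populated (e.g. fetched via the serve `Service`):

    import zipfile

    zip_buffer = _generate_dbgpts_zip("awel-flow-example", flow)  # package name is a placeholder
    with zipfile.ZipFile(zip_buffer) as zf:
        print(zf.namelist())
        # Expected entries:
        #   awel-flow-example/MANIFEST.in
        #   awel-flow-example/README.md
        #   awel-flow-example/<flow_name>/__init__.py
        #   awel-flow-example/<flow_name>/definition/flow_definition.json
        #   awel-flow-example/dbgpts.toml
        #   awel-flow-example/pyproject.toml
    zip_buffer.seek(0)  # rewind if the buffer is going to be streamed afterwards
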
diff --git a/dbgpt/storage/vector_store/milvus_store.py b/dbgpt/storage/vector_store/milvus_store.py
index b8b036770..f0984d20e 100644
--- a/dbgpt/storage/vector_store/milvus_store.py
+++ b/dbgpt/storage/vector_store/milvus_store.py
@@ -35,7 +35,7 @@
             "The uri of milvus store, if not set, will use the default "
             "uri."
         ),
         optional=True,
-        default="localhost",
+        default=None,
     ),
     Parameter.build_from(
         _("Port"),
@@ -98,8 +98,8 @@ class MilvusVectorConfig(VectorStoreConfig):
 
     model_config = ConfigDict(arbitrary_types_allowed=True)
 
-    uri: str = Field(
-        default="localhost",
+    uri: Optional[str] = Field(
+        default=None,
         description="The uri of milvus store, if not set, will use the default uri.",
     )
     port: str = Field(
diff --git a/dbgpt/util/dbgpts/loader.py b/dbgpt/util/dbgpts/loader.py
index 8545ad067..4151546e9 100644
--- a/dbgpt/util/dbgpts/loader.py
+++ b/dbgpt/util/dbgpts/loader.py
@@ -320,14 +320,19 @@ def _load_package_from_path(path: str):
     return parsed_packages
 
 
-def _load_flow_package_from_path(name: str, path: str = INSTALL_DIR) -> FlowPackage:
+def _load_flow_package_from_path(
+    name: str, path: str = INSTALL_DIR, filter_by_name: bool = True
+) -> FlowPackage:
     raw_packages = _load_installed_package(path)
     new_name = name.replace("_", "-")
-    packages = [p for p in raw_packages if p.package == name or p.name == name]
-    if not packages:
-        packages = [
-            p for p in raw_packages if p.package == new_name or p.name == new_name
-        ]
+    if filter_by_name:
+        packages = [p for p in raw_packages if p.package == name or p.name == name]
+        if not packages:
+            packages = [
+                p for p in raw_packages if p.package == new_name or p.name == new_name
+            ]
+    else:
+        packages = raw_packages
     if not packages:
         raise ValueError(f"Can't find the package {name} or {new_name}")
     flow_package = _parse_package_metadata(packages[0])
@@ -336,6 +341,35 @@ def _load_flow_package_from_path(name: str, path: str = INSTALL_DIR) -> FlowPack
     return cast(FlowPackage, flow_package)
 
 
+def _load_flow_package_from_zip_path(zip_path: str) -> FlowPanel:
+    import tempfile
+    import zipfile
+
+    with tempfile.TemporaryDirectory() as temp_dir:
+        with zipfile.ZipFile(zip_path, "r") as zip_ref:
+            zip_ref.extractall(temp_dir)
+        package_names = os.listdir(temp_dir)
+        if not package_names:
+            raise ValueError("No package found in the zip file")
+        if len(package_names) > 1:
+            raise ValueError("Only support one package in the zip file")
+        package_name = package_names[0]
+        with open(
+            Path(temp_dir) / package_name / INSTALL_METADATA_FILE, mode="w+"
+        ) as f:
+            # Write the metadata
+            import tomlkit
+
+            install_metadata = {
+                "name": package_name,
+                "repo": "local/dbgpts",
+            }
+            tomlkit.dump(install_metadata, f)
+
+        package = _load_flow_package_from_path("", path=temp_dir, filter_by_name=False)
+        return _flow_package_to_flow_panel(package)
+
+
 def _flow_package_to_flow_panel(package: FlowPackage) -> FlowPanel:
     dict_value = {
         "name": package.name,
@@ -345,6 +379,7 @@ def _flow_package_to_flow_panel(package: FlowPackage) -> FlowPanel:
         "description": package.description,
         "source": package.repo,
         "define_type": "json",
+        "authors": package.authors,
     }
     if isinstance(package, FlowJsonPackage):
         dict_value["flow_data"] = package.read_definition_json()
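
Reviewer note: `_load_flow_package_from_zip_path` makes it possible to turn a dbgpts zip on disk straight into a `FlowPanel` without installing it first, which is what `_parse_flow_from_zip_file` relies on. A small sketch (the path is a placeholder):

    from dbgpt.util.dbgpts.loader import _load_flow_package_from_zip_path

    panel = _load_flow_package_from_zip_path("/tmp/awel-flow-example.zip")  # placeholder path
    print(panel.name, panel.define_type)  # define_type is "json" for zip imports
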
diff --git a/examples/awel/awel_flow_ui_components.py b/examples/awel/awel_flow_ui_components.py
index cba0c14df..7fa2dc236 100644
--- a/examples/awel/awel_flow_ui_components.py
+++ b/examples/awel/awel_flow_ui_components.py
@@ -2,7 +2,7 @@
 import json
 import logging
-from typing import List, Optional
+from typing import Any, Dict, List, Optional
 
 from dbgpt.core.awel import MapOperator
 from dbgpt.core.awel.flow import (
@@ -15,6 +15,7 @@
     ViewMetadata,
     ui,
 )
+from dbgpt.core.interface.file import FileStorageClient
 from dbgpt.core.interface.variables import (
     BUILTIN_VARIABLES_CORE_EMBEDDINGS,
     BUILTIN_VARIABLES_CORE_FLOW_NODES,
@@ -787,6 +788,109 @@ async def map(self, user_name: str) -> str:
         )
 
 
+class ExampleFlowUploadOperator(MapOperator[str, str]):
+    """An example flow operator that includes an upload as parameter."""
+
+    metadata = ViewMetadata(
+        label="Example Flow Upload",
+        name="example_flow_upload",
+        category=OperatorCategory.EXAMPLE,
+        description="An example flow operator that includes an upload as parameter.",
+        parameters=[
+            Parameter.build_from(
+                "Single File Selector",
+                "file",
+                type=str,
+                optional=True,
+                default=None,
+                placeholder="Select the file",
+                description="The file you want to upload.",
+                ui=ui.UIUpload(
+                    max_file_size=1024 * 1024 * 100,
+                    up_event="after_select",
+                    attr=ui.UIUpload.UIAttribute(max_count=1),
+                ),
+            ),
+            Parameter.build_from(
+                "Multiple Files Selector",
+                "multiple_files",
+                type=str,
+                is_list=True,
+                optional=True,
+                default=None,
+                placeholder="Select the multiple files",
+                description="The multiple files you want to upload.",
+                ui=ui.UIUpload(
+                    max_file_size=1024 * 1024 * 100,
+                    up_event="button_click",
+                    attr=ui.UIUpload.UIAttribute(max_count=5),
+                ),
+            ),
+        ],
+        inputs=[
+            IOField.build_from(
+                "User Name",
+                "user_name",
+                str,
+                description="The name of the user.",
+            )
+        ],
+        outputs=[
+            IOField.build_from(
+                "File",
+                "file",
+                str,
+                description="User's uploaded file.",
+            )
+        ],
+    )
+
+    def __init__(
+        self,
+        file: Optional[str] = None,
+        multiple_files: Optional[List[str]] = None,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        self.file = file
+        self.multiple_files = multiple_files or []
+
+    async def map(self, user_name: str) -> str:
+        """Map the user name to the file."""
+        fsc = FileStorageClient.get_instance(self.system_app)
+        files_metadata = await self.blocking_func_to_async(
+            self._parse_files_metadata, fsc
+        )
+        files_metadata_str = json.dumps(files_metadata, ensure_ascii=False)
+        return "Your name is %s, and your files are %s." % (
+            user_name,
+            files_metadata_str,
+        )
+
+    def _parse_files_metadata(self, fsc: FileStorageClient) -> List[Dict[str, Any]]:
+        """Parse the files metadata."""
+        if not self.file:
+            raise ValueError("The file is not uploaded.")
+        if not self.multiple_files:
+            raise ValueError("The multiple files are not uploaded.")
+        files = [self.file] + self.multiple_files
+        results = []
+        for file in files:
+            _, metadata = fsc.get_file(file)
+            results.append(
+                {
+                    "bucket": metadata.bucket,
+                    "file_id": metadata.file_id,
+                    "file_size": metadata.file_size,
+                    "storage_type": metadata.storage_type,
+                    "uri": metadata.uri,
+                    "file_hash": metadata.file_hash,
+                }
+            )
+        return results
+
+
 class ExampleFlowVariablesOperator(MapOperator[str, str]):
     """An example flow operator that includes a variables option."""