From 9c7b3b8194931ab6dffd012b73c719a9dc698e8e Mon Sep 17 00:00:00 2001
From: yihong0618
Date: Tue, 19 Dec 2023 18:04:14 +0800
Subject: [PATCH] feat: add gemini support

Signed-off-by: yihong0618
---
 README.zh.md                     |   1 +
 dbgpt/_private/config.py         |   8 +
 dbgpt/config.py                  | 277 +++++++++++++++++++++++++++++++
 dbgpt/configs/model_config.py    |   1 +
 dbgpt/model/llm_out/proxy_llm.py |   2 +
 dbgpt/model/proxy/llms/gemini.py |  67 ++++++++
 dbgpt/model/proxy/llms/zhipu.py  |  35 +---
 7 files changed, 358 insertions(+), 33 deletions(-)
 create mode 100644 dbgpt/config.py
 create mode 100644 dbgpt/model/proxy/llms/gemini.py

diff --git a/README.zh.md b/README.zh.md
index 0ba4c40af..afe2189e0 100644
--- a/README.zh.md
+++ b/README.zh.md
@@ -122,6 +122,7 @@ DB-GPT是一个开源的数据库领域大模型框架。目的是构建大模
   - [x] [智谱·ChatGLM](http://open.bigmodel.cn/)
   - [x] [讯飞·星火](https://xinghuo.xfyun.cn/)
   - [x] [Google·Bard](https://bard.google.com/)
+  - [x] [Google·Gemini](https://makersuite.google.com/app/apikey)
 
 - **隐私安全**
 
diff --git a/dbgpt/_private/config.py b/dbgpt/_private/config.py
index 05e08b3de..83ce4eecc 100644
--- a/dbgpt/_private/config.py
+++ b/dbgpt/_private/config.py
@@ -95,6 +95,14 @@ def __init__(self) -> None:
             os.environ["bc_proxyllm_proxy_api_secret"] = self.bc_proxy_api_secret
             os.environ["bc_proxyllm_proxyllm_backend"] = self.bc_model_version
 
+        # gemini proxy
+        self.gemini_proxy_api_key = os.getenv("GEMINI_PROXY_API_KEY")
+        if self.gemini_proxy_api_key:
+            os.environ["gemini_proxyllm_proxy_api_key"] = self.gemini_proxy_api_key
+            os.environ["gemini_proxyllm_proxyllm_backend"] = os.getenv(
+                "GEMINI_MODEL_VERSION", "gemini-pro"
+            )
+
         self.proxy_server_url = os.getenv("PROXY_SERVER_URL")
 
         self.elevenlabs_api_key = os.getenv("ELEVENLABS_API_KEY")
diff --git a/dbgpt/config.py b/dbgpt/config.py
new file mode 100644
index 000000000..9fe531727
--- /dev/null
+++ b/dbgpt/config.py
@@ -0,0 +1,277 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+from __future__ import annotations
+
+import os
+from typing import List, Optional, TYPE_CHECKING
+
+from dbgpt.util.singleton import Singleton
+
+if TYPE_CHECKING:
+    from auto_gpt_plugin_template import AutoGPTPluginTemplate
+    from dbgpt.component import SystemApp
+
+
+class Config(metaclass=Singleton):
+    """Configuration class that stores state flags for access by different scripts"""
+
+    def __init__(self) -> None:
+        """Initialize the Config class"""
+
+        self.NEW_SERVER_MODE = False
+        self.SERVER_LIGHT_MODE = False
+
+        # Gradio language version: en, zh
+        self.LANGUAGE = os.getenv("LANGUAGE", "en")
+        self.WEB_SERVER_PORT = int(os.getenv("WEB_SERVER_PORT", 7860))
+
+        self.debug_mode = False
+        self.skip_reprompt = False
+        self.temperature = float(os.getenv("TEMPERATURE", 0.7))
+
+        # self.NUM_GPUS = int(os.getenv("NUM_GPUS", 1))
+
+        self.execute_local_commands = (
+            os.getenv("EXECUTE_LOCAL_COMMANDS", "False").lower() == "true"
+        )
+        # User agent header to use when making HTTP requests.
+        # Some websites might completely deny a request with an error code if
+        # no user agent is found.
+        self.user_agent = os.getenv(
+            "USER_AGENT",
+            "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_4) AppleWebKit/537.36"
+            " (KHTML, like Gecko) Chrome/83.0.4103.97 Safari/537.36",
+        )
+
+        # This is a proxy server, just for test_py. We will remove this later.
+        self.proxy_api_key = os.getenv("PROXY_API_KEY")
+        self.bard_proxy_api_key = os.getenv("BARD_PROXY_API_KEY")
+
+        # In order to be compatible with the new and old model parameter design
+        if self.bard_proxy_api_key:
+            os.environ["bard_proxyllm_proxy_api_key"] = self.bard_proxy_api_key
+
+        # tongyi
+        self.tongyi_proxy_api_key = os.getenv("TONGYI_PROXY_API_KEY")
+        if self.tongyi_proxy_api_key:
+            os.environ["tongyi_proxyllm_proxy_api_key"] = self.tongyi_proxy_api_key
+
+        # zhipu
+        self.zhipu_proxy_api_key = os.getenv("ZHIPU_PROXY_API_KEY")
+        if self.zhipu_proxy_api_key:
+            os.environ["zhipu_proxyllm_proxy_api_key"] = self.zhipu_proxy_api_key
+            os.environ["zhipu_proxyllm_proxyllm_backend"] = os.getenv(
+                "ZHIPU_MODEL_VERSION"
+            )
+
+        # wenxin
+        self.wenxin_proxy_api_key = os.getenv("WEN_XIN_API_KEY")
+        self.wenxin_proxy_api_secret = os.getenv("WEN_XIN_API_SECRET")
+        self.wenxin_model_version = os.getenv("WEN_XIN_MODEL_VERSION")
+        if self.wenxin_proxy_api_key and self.wenxin_proxy_api_secret:
+            os.environ["wenxin_proxyllm_proxy_api_key"] = self.wenxin_proxy_api_key
+            os.environ[
+                "wenxin_proxyllm_proxy_api_secret"
+            ] = self.wenxin_proxy_api_secret
+            os.environ["wenxin_proxyllm_proxyllm_backend"] = self.wenxin_model_version
+
+        # xunfei spark
+        self.spark_api_version = os.getenv("XUNFEI_SPARK_API_VERSION")
+        self.spark_proxy_api_key = os.getenv("XUNFEI_SPARK_API_KEY")
+        self.spark_proxy_api_secret = os.getenv("XUNFEI_SPARK_API_SECRET")
+        self.spark_proxy_api_appid = os.getenv("XUNFEI_SPARK_APPID")
+        if self.spark_proxy_api_key and self.spark_proxy_api_secret:
+            os.environ["spark_proxyllm_proxy_api_key"] = self.spark_proxy_api_key
+            os.environ["spark_proxyllm_proxy_api_secret"] = self.spark_proxy_api_secret
+            os.environ["spark_proxyllm_proxyllm_backend"] = self.spark_api_version
+            os.environ["spark_proxyllm_proxy_api_app_id"] = self.spark_proxy_api_appid
+
+        # baichuan proxy
+        self.bc_proxy_api_key = os.getenv("BAICHUAN_PROXY_API_KEY")
+        self.bc_proxy_api_secret = os.getenv("BAICHUAN_PROXY_API_SECRET")
+        self.bc_model_version = os.getenv("BAICHUN_MODEL_NAME")
+        if self.bc_proxy_api_key and self.bc_proxy_api_secret:
+            os.environ["bc_proxyllm_proxy_api_key"] = self.bc_proxy_api_key
+            os.environ["bc_proxyllm_proxy_api_secret"] = self.bc_proxy_api_secret
+            os.environ["bc_proxyllm_proxyllm_backend"] = self.bc_model_version
+
+        # gemini proxy
+        self.gemini_proxy_api_key = os.getenv("GEMINI_PROXY_API_KEY")
+        if self.gemini_proxy_api_key:
+            os.environ["gemini_proxyllm_proxy_api_key"] = self.gemini_proxy_api_key
+            os.environ["gemini_proxyllm_proxyllm_backend"] = os.getenv(
+                "GEMINI_MODEL_VERSION", "gemini-pro"
+            )
+
+        self.proxy_server_url = os.getenv("PROXY_SERVER_URL")
+
+        self.elevenlabs_api_key = os.getenv("ELEVENLABS_API_KEY")
+        self.elevenlabs_voice_1_id = os.getenv("ELEVENLABS_VOICE_1_ID")
+        self.elevenlabs_voice_2_id = os.getenv("ELEVENLABS_VOICE_2_ID")
+
+        self.use_mac_os_tts = False
+        self.use_mac_os_tts = os.getenv("USE_MAC_OS_TTS")
+
+        self.authorise_key = os.getenv("AUTHORISE_COMMAND_KEY", "y")
+        self.exit_key = os.getenv("EXIT_KEY", "n")
+        self.image_provider = os.getenv("IMAGE_PROVIDER", True)
+        self.image_size = int(os.getenv("IMAGE_SIZE", 256))
+
+        self.huggingface_api_token = os.getenv("HUGGINGFACE_API_TOKEN")
+        self.image_provider = os.getenv("IMAGE_PROVIDER")
+        self.image_size = int(os.getenv("IMAGE_SIZE", 256))
+        self.huggingface_image_model = os.getenv(
+            "HUGGINGFACE_IMAGE_MODEL", "CompVis/stable-diffusion-v1-4"
+        )
+        self.huggingface_audio_to_text_model = os.getenv(
+            "HUGGINGFACE_AUDIO_TO_TEXT_MODEL"
+        )
+        self.speak_mode = False
+
+        from dbgpt.core._private.prompt_registry import PromptTemplateRegistry
+
+        self.prompt_template_registry = PromptTemplateRegistry()
+        ### Related configuration of built-in commands
+        self.command_registry = []
+
+        ### Related configuration of display commands
+        self.command_disply = []
+
+        disabled_command_categories = os.getenv("DISABLED_COMMAND_CATEGORIES")
+        if disabled_command_categories:
+            self.disabled_command_categories = disabled_command_categories.split(",")
+        else:
+            self.disabled_command_categories = []
+
+        self.execute_local_commands = (
+            os.getenv("EXECUTE_LOCAL_COMMANDS", "False").lower() == "true"
+        )
+        ### Message history storage directory
+        self.message_dir = os.getenv("MESSAGE_HISTORY_DIR", "../../message")
+
+        ### Plugin configuration parameters that control plugin loading and use
+
+        self.plugins: List["AutoGPTPluginTemplate"] = []
+        self.plugins_openai = []
+        self.plugins_auto_load = os.getenv("AUTO_LOAD_PLUGIN", "True").lower() == "true"
+
+        self.plugins_git_branch = os.getenv("PLUGINS_GIT_BRANCH", "plugin_dashboard")
+
+        plugins_allowlist = os.getenv("ALLOWLISTED_PLUGINS")
+        if plugins_allowlist:
+            self.plugins_allowlist = plugins_allowlist.split(",")
+        else:
+            self.plugins_allowlist = []
+
+        plugins_denylist = os.getenv("DENYLISTED_PLUGINS")
+        if plugins_denylist:
+            self.plugins_denylist = plugins_denylist.split(",")
+        else:
+            self.plugins_denylist = []
+        ### Native SQL Execution Capability Control Configuration
+        self.NATIVE_SQL_CAN_RUN_DDL = (
+            os.getenv("NATIVE_SQL_CAN_RUN_DDL", "True").lower() == "true"
+        )
+        self.NATIVE_SQL_CAN_RUN_WRITE = (
+            os.getenv("NATIVE_SQL_CAN_RUN_WRITE", "True").lower() == "true"
+        )
+
+        self.LOCAL_DB_MANAGE = None
+
+        ### dbgpt meta info database connection configuration
+        self.LOCAL_DB_HOST = os.getenv("LOCAL_DB_HOST")
+        self.LOCAL_DB_PATH = os.getenv("LOCAL_DB_PATH", "data/default_sqlite.db")
+        self.LOCAL_DB_TYPE = os.getenv("LOCAL_DB_TYPE", "sqlite")
+        if self.LOCAL_DB_HOST is None and self.LOCAL_DB_PATH == "":
+            self.LOCAL_DB_HOST = "127.0.0.1"
+
+        self.LOCAL_DB_NAME = os.getenv("LOCAL_DB_NAME", "dbgpt")
+        self.LOCAL_DB_PORT = int(os.getenv("LOCAL_DB_PORT", 3306))
+        self.LOCAL_DB_USER = os.getenv("LOCAL_DB_USER", "root")
+        self.LOCAL_DB_PASSWORD = os.getenv("LOCAL_DB_PASSWORD", "aa123456")
+        self.LOCAL_DB_POOL_SIZE = int(os.getenv("LOCAL_DB_POOL_SIZE", 10))
+        self.LOCAL_DB_POOL_OVERFLOW = int(os.getenv("LOCAL_DB_POOL_OVERFLOW", 20))
+
+        self.CHAT_HISTORY_STORE_TYPE = os.getenv("CHAT_HISTORY_STORE_TYPE", "db")
+
+        ### LLM Model Service Configuration
+        self.LLM_MODEL = os.getenv("LLM_MODEL", "vicuna-13b-v1.5")
+        self.LLM_MODEL_PATH = os.getenv("LLM_MODEL_PATH")
+
+        ### Proxy llm backend; this configuration is only valid when "LLM_MODEL=proxyllm".
+        ### When we use the REST API provided by a deployment framework like FastChat as a proxyllm, "PROXYLLM_BACKEND" is the model it actually deploys.
+        ### We use "PROXYLLM_BACKEND" to load the prompt of the corresponding scene.
+        self.PROXYLLM_BACKEND = None
+        if self.LLM_MODEL == "proxyllm":
+            self.PROXYLLM_BACKEND = os.getenv("PROXYLLM_BACKEND")
+
+        self.LIMIT_MODEL_CONCURRENCY = int(os.getenv("LIMIT_MODEL_CONCURRENCY", 5))
+        self.MAX_POSITION_EMBEDDINGS = int(os.getenv("MAX_POSITION_EMBEDDINGS", 4096))
+        self.MODEL_PORT = os.getenv("MODEL_PORT", 8000)
+        self.MODEL_SERVER = os.getenv(
+            "MODEL_SERVER", "http://127.0.0.1" + ":" + str(self.MODEL_PORT)
+        )
+
+        ### Vector Store Configuration
+        self.VECTOR_STORE_TYPE = os.getenv("VECTOR_STORE_TYPE", "Chroma")
+        self.MILVUS_URL = os.getenv("MILVUS_URL", "127.0.0.1")
+        self.MILVUS_PORT = os.getenv("MILVUS_PORT", "19530")
+        self.MILVUS_USERNAME = os.getenv("MILVUS_USERNAME", None)
+        self.MILVUS_PASSWORD = os.getenv("MILVUS_PASSWORD", None)
+
+        # QLoRA
+        self.QLoRA = os.getenv("QUANTIZE_QLORA", "True")
+        self.IS_LOAD_8BIT = os.getenv("QUANTIZE_8bit", "True").lower() == "true"
+        self.IS_LOAD_4BIT = os.getenv("QUANTIZE_4bit", "False").lower() == "true"
+        if self.IS_LOAD_8BIT and self.IS_LOAD_4BIT:
+            self.IS_LOAD_8BIT = False
+        # In order to be compatible with the new and old model parameter design
+        os.environ["load_8bit"] = str(self.IS_LOAD_8BIT)
+        os.environ["load_4bit"] = str(self.IS_LOAD_4BIT)
+
+        ### EMBEDDING Configuration
+        self.EMBEDDING_MODEL = os.getenv("EMBEDDING_MODEL", "text2vec")
+        self.KNOWLEDGE_CHUNK_SIZE = int(os.getenv("KNOWLEDGE_CHUNK_SIZE", 100))
+        self.KNOWLEDGE_CHUNK_OVERLAP = int(os.getenv("KNOWLEDGE_CHUNK_OVERLAP", 50))
+        self.KNOWLEDGE_SEARCH_TOP_SIZE = int(os.getenv("KNOWLEDGE_SEARCH_TOP_SIZE", 5))
+        # default recall similarity score, between 0 and 1
+        self.KNOWLEDGE_SEARCH_RECALL_SCORE = float(
+            os.getenv("KNOWLEDGE_SEARCH_RECALL_SCORE", 0.3)
+        )
+        self.KNOWLEDGE_SEARCH_MAX_TOKEN = int(
+            os.getenv("KNOWLEDGE_SEARCH_MAX_TOKEN", 2000)
+        )
+        # Whether to enable Chat Knowledge Search Rewrite Mode
+        self.KNOWLEDGE_SEARCH_REWRITE = (
+            os.getenv("KNOWLEDGE_SEARCH_REWRITE", "False").lower() == "true"
+        )
+        # Control whether to display the source document of knowledge on the front end.
+        self.KNOWLEDGE_CHAT_SHOW_RELATIONS = (
+            os.getenv("KNOWLEDGE_CHAT_SHOW_RELATIONS", "False").lower() == "true"
+        )
+
+        ### SUMMARY_CONFIG Configuration
+        self.SUMMARY_CONFIG = os.getenv("SUMMARY_CONFIG", "FAST")
+
+        self.MAX_GPU_MEMORY = os.getenv("MAX_GPU_MEMORY", None)
+
+        ### Log level
+        self.DBGPT_LOG_LEVEL = os.getenv("DBGPT_LOG_LEVEL", "INFO")
+
+        self.SYSTEM_APP: Optional["SystemApp"] = None
+
+        ### Temporary configuration
+        self.USE_FASTCHAT: bool = os.getenv("USE_FASTCHAT", "True").lower() == "true"
+
+        self.MODEL_CACHE_ENABLE: bool = (
+            os.getenv("MODEL_CACHE_ENABLE", "True").lower() == "true"
+        )
+        self.MODEL_CACHE_STORAGE_TYPE: str = os.getenv(
+            "MODEL_CACHE_STORAGE_TYPE", "disk"
+        )
+        self.MODEL_CACHE_MAX_MEMORY_MB: int = int(
+            os.getenv("MODEL_CACHE_MAX_MEMORY_MB", 256)
+        )
+        self.MODEL_CACHE_STORAGE_DISK_DIR: str = os.getenv(
+            "MODEL_CACHE_STORAGE_DISK_DIR"
+        )
diff --git a/dbgpt/configs/model_config.py b/dbgpt/configs/model_config.py
index ccbe4e3d4..05a0636b8 100644
--- a/dbgpt/configs/model_config.py
+++ b/dbgpt/configs/model_config.py
@@ -60,6 +60,7 @@ def get_device() -> str:
     "wenxin_proxyllm": "wenxin_proxyllm",
     "tongyi_proxyllm": "tongyi_proxyllm",
     "zhipu_proxyllm": "zhipu_proxyllm",
+    "gemini_proxyllm": "gemini_proxyllm",
     "bc_proxyllm": "bc_proxyllm",
     "spark_proxyllm": "spark_proxyllm",
     "llama-2-7b": os.path.join(MODEL_PATH, "Llama-2-7b-chat-hf"),
diff --git a/dbgpt/model/llm_out/proxy_llm.py b/dbgpt/model/llm_out/proxy_llm.py
index 89470e838..14b2c3177 100644
--- a/dbgpt/model/llm_out/proxy_llm.py
+++ b/dbgpt/model/llm_out/proxy_llm.py
@@ -8,6 +8,7 @@
 from dbgpt.model.proxy.llms.wenxin import wenxin_generate_stream
 from dbgpt.model.proxy.llms.tongyi import tongyi_generate_stream
 from dbgpt.model.proxy.llms.zhipu import zhipu_generate_stream
+from dbgpt.model.proxy.llms.gemini import gemini_generate_stream
 from dbgpt.model.proxy.llms.baichuan import baichuan_generate_stream
 from dbgpt.model.proxy.llms.spark import spark_generate_stream
 from dbgpt.model.proxy.llms.proxy_model import ProxyModel
@@ -25,6 +26,7 @@ def proxyllm_generate_stream(
     "wenxin_proxyllm": wenxin_generate_stream,
     "tongyi_proxyllm": tongyi_generate_stream,
     "zhipu_proxyllm": zhipu_generate_stream,
+    "gemini_proxyllm": gemini_generate_stream,
     "bc_proxyllm": baichuan_generate_stream,
     "spark_proxyllm": spark_generate_stream,
 }
diff --git a/dbgpt/model/proxy/llms/gemini.py b/dbgpt/model/proxy/llms/gemini.py
new file mode 100644
index 000000000..0a97d8236
--- /dev/null
+++ b/dbgpt/model/proxy/llms/gemini.py
@@ -0,0 +1,67 @@
+from typing import List
+
+from dbgpt.model.proxy.llms.proxy_model import ProxyModel
+from dbgpt.core.interface.message import ModelMessage, ModelMessageRoleType
+
+GEMINI_DEFAULT_MODEL = "gemini-pro"
+
+# Global history so it is easy to support conversation history
+# TODO: refactor the history handling in the future
+history = []
+
+
+def gemini_generate_stream(
+    model: ProxyModel, tokenizer, params, device, context_len=2048
+):
+    """Google Gemini, see: https://makersuite.google.com/app/apikey"""
+    model_params = model.get_params()
+    print(f"Model: {model}, model_params: {model_params}")
+    global history
+
+    # TODO: should proxy models use a unified config?
+    proxy_api_key = model_params.proxy_api_key
+    proxyllm_backend = model_params.proxyllm_backend or GEMINI_DEFAULT_MODEL
+
+    generation_config = {
+        "temperature": 0.7,
+        "top_p": 1,
+        "top_k": 1,
+        "max_output_tokens": 2048,
+    }
+
+    safety_settings = [
+        {"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_MEDIUM_AND_ABOVE"},
+        {
+            "category": "HARM_CATEGORY_HATE_SPEECH",
+            "threshold": "BLOCK_MEDIUM_AND_ABOVE",
+        },
+        {
+            "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
+            "threshold": "BLOCK_MEDIUM_AND_ABOVE",
+        },
+        {
+            "category": "HARM_CATEGORY_DANGEROUS_CONTENT",
+            "threshold": "BLOCK_MEDIUM_AND_ABOVE",
+        },
+    ]
+
+    import google.generativeai as genai
+
+    genai.configure(api_key=proxy_api_key)
+    model = genai.GenerativeModel(
+        model_name=proxyllm_backend,
+        generation_config=generation_config,
+        safety_settings=safety_settings,
+    )
+    messages = params["messages"][0].content
+    chat = model.start_chat(history=history)
+    response = chat.send_message(messages, stream=True)
+    text = ""
+    for chunk in response:
+        text += chunk.text
+        yield text
+    # Only keep roughly the last five rounds (ten messages) of history
+    if len(history) > 10:
+        history = chat.history[2:]
+    else:
+        history = chat.history
diff --git a/dbgpt/model/proxy/llms/zhipu.py b/dbgpt/model/proxy/llms/zhipu.py
index 66b0b3dc6..90f1d3d2b 100644
--- a/dbgpt/model/proxy/llms/zhipu.py
+++ b/dbgpt/model/proxy/llms/zhipu.py
@@ -6,7 +6,7 @@
 CHATGLM_DEFAULT_MODEL = "chatglm_pro"
 
 
-def __convert_2_wenxin_messages(messages: List[ModelMessage]):
+def __convert_2_zhipu_messages(messages: List[ModelMessage]):
     chat_round = 0
     wenxin_messages = []
 
@@ -57,38 +57,7 @@ def zhipu_generate_stream(
     zhipuai.api_key = proxy_api_key
     messages: List[ModelMessage] = params["messages"]
 
-    # Add history conversation
-    # system = ""
-    # if len(messages) > 1 and messages[0].role == ModelMessageRoleType.SYSTEM:
-    #     role_define = messages.pop(0)
-    #     system = role_define.content
-    # else:
-    #     message = messages.pop(0)
-    #     if message.role == ModelMessageRoleType.HUMAN:
-    #         history.append({"role": "user", "content": message.content})
-    # for message in messages:
-    #     if message.role == ModelMessageRoleType.SYSTEM:
-    #         history.append({"role": "user", "content": message.content})
-    #     # elif message.role == ModelMessageRoleType.HUMAN:
-    #     #     history.append({"role": "user", "content": message.content})
-    #     elif message.role == ModelMessageRoleType.AI:
-    #         history.append({"role": "assistant", "content": message.content})
-    #     else:
-    #         pass
-    #
-    # # temp_his = history[::-1]
-    # temp_his = history
-    # last_user_input = None
-    # for m in temp_his:
-    #     if m["role"] == "user":
-    #         last_user_input = m
-    #         break
-    #
-    # if last_user_input:
-    #     history.remove(last_user_input)
-    #     history.append(last_user_input)
-
-    history, systems = __convert_2_wenxin_messages(messages)
+    history, systems = __convert_2_zhipu_messages(messages)
     res = zhipuai.model_api.sse_invoke(
         model=proxyllm_backend,
         prompt=history,
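
Usage sketch (not part of the patch): with this change applied, the Gemini proxy is driven
entirely by environment variables. A minimal `.env` fragment based on the key and model
names this patch registers; the API key value is a placeholder:

    LLM_MODEL=gemini_proxyllm
    GEMINI_PROXY_API_KEY=your-gemini-api-key
    # Optional; gemini_generate_stream falls back to "gemini-pro" when unset
    GEMINI_MODEL_VERSION=gemini-pro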
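
To exercise the same call pattern outside DB-GPT, here is a minimal standalone sketch of
the google-generativeai usage in gemini.py above (assumes `pip install google-generativeai`;
the API key string and prompt are placeholders):

    import google.generativeai as genai

    # Configure the client; use a real key in practice, this value is a placeholder.
    genai.configure(api_key="YOUR_GEMINI_API_KEY")
    model = genai.GenerativeModel(model_name="gemini-pro")
    chat = model.start_chat(history=[])
    # stream=True returns an iterable of partial responses, which is what
    # gemini_generate_stream accumulates and re-yields.
    response = chat.send_message("Say hello in one sentence.", stream=True)
    text = ""
    for chunk in response:
        text += chunk.text
    print(text)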