From 12234ae2583ef1f7a5f155908cc385d3c736c693 Mon Sep 17 00:00:00 2001
From: yihong
Date: Sat, 23 Dec 2023 11:10:42 +0800
Subject: [PATCH] feat: add gemini support (#953)

Signed-off-by: yihong0618
Signed-off-by: Fangyin Cheng
Co-authored-by: Fangyin Cheng
---
 README.zh.md                               |   1 +
 dbgpt/_private/config.py                   |  10 +-
 dbgpt/configs/model_config.py              |   1 +
 dbgpt/core/interface/message.py            |  62 ++++++++++--
 dbgpt/core/interface/tests/test_message.py |  65 ++++++++++++
 dbgpt/model/llm_out/proxy_llm.py           |   2 +
 dbgpt/model/proxy/llms/gemini.py           | 109 +++++++++++++++++++++
 dbgpt/model/proxy/llms/zhipu.py            |  35 +------
 8 files changed, 243 insertions(+), 42 deletions(-)
 create mode 100644 dbgpt/model/proxy/llms/gemini.py

diff --git a/README.zh.md b/README.zh.md
index 61cbc7def..857e9a241 100644
--- a/README.zh.md
+++ b/README.zh.md
@@ -123,6 +123,7 @@ DB-GPT是一个开源的数据库领域大模型框架。目的是构建大模
   - [x] [智谱·ChatGLM](http://open.bigmodel.cn/)
   - [x] [讯飞·星火](https://xinghuo.xfyun.cn/)
   - [x] [Google·Bard](https://bard.google.com/)
+  - [x] [Google·Gemini](https://makersuite.google.com/app/apikey)
 
 - **隐私安全**
 
diff --git a/dbgpt/_private/config.py b/dbgpt/_private/config.py
index 05e08b3de..9fe531727 100644
--- a/dbgpt/_private/config.py
+++ b/dbgpt/_private/config.py
@@ -61,7 +61,7 @@ def __init__(self) -> None:
         if self.zhipu_proxy_api_key:
             os.environ["zhipu_proxyllm_proxy_api_key"] = self.zhipu_proxy_api_key
             os.environ["zhipu_proxyllm_proxyllm_backend"] = os.getenv(
-                "ZHIPU_MODEL_VERSION", "chatglm_pro"
+                "ZHIPU_MODEL_VERSION"
             )
 
         # wenxin
@@ -95,6 +95,14 @@ def __init__(self) -> None:
             os.environ["bc_proxyllm_proxy_api_secret"] = self.bc_proxy_api_secret
             os.environ["bc_proxyllm_proxyllm_backend"] = self.bc_model_version
 
+        # gemini proxy
+        self.gemini_proxy_api_key = os.getenv("GEMINI_PROXY_API_KEY")
+        if self.gemini_proxy_api_key:
+            os.environ["gemini_proxyllm_proxy_api_key"] = self.gemini_proxy_api_key
+            os.environ["gemini_proxyllm_proxyllm_backend"] = os.getenv(
+                "GEMINI_MODEL_VERSION", "gemini-pro"
+            )
+
         self.proxy_server_url = os.getenv("PROXY_SERVER_URL")
 
         self.elevenlabs_api_key = os.getenv("ELEVENLABS_API_KEY")
diff --git a/dbgpt/configs/model_config.py b/dbgpt/configs/model_config.py
index c9f123677..136c57e43 100644
--- a/dbgpt/configs/model_config.py
+++ b/dbgpt/configs/model_config.py
@@ -60,6 +60,7 @@ def get_device() -> str:
     "wenxin_proxyllm": "wenxin_proxyllm",
     "tongyi_proxyllm": "tongyi_proxyllm",
     "zhipu_proxyllm": "zhipu_proxyllm",
+    "gemini_proxyllm": "gemini_proxyllm",
     "bc_proxyllm": "bc_proxyllm",
    "spark_proxyllm": "spark_proxyllm",
     "llama-2-7b": os.path.join(MODEL_PATH, "Llama-2-7b-chat-hf"),
diff --git a/dbgpt/core/interface/message.py b/dbgpt/core/interface/message.py
index 2b1439c6d..2493ebb53 100755
--- a/dbgpt/core/interface/message.py
+++ b/dbgpt/core/interface/message.py
@@ -202,19 +202,65 @@ def _messages_from_dict(messages: List[Dict]) -> List[BaseMessage]:
     return [_message_from_dict(m) for m in messages]
 
 
-def _parse_model_messages(
+def parse_model_messages(
     messages: List[ModelMessage],
 ) -> Tuple[str, List[str], List[List[str]]]:
     """
-    Parameters:
-        messages: List of message from base chat.
+    Parse model messages to extract the user prompt, system messages, and the
+    conversation history.
+
+    This function analyzes a list of ModelMessage objects, identifying the role
+    of each message (e.g., human, system, ai) and categorizing them accordingly.
+    The last message is expected to be from the user (human) and is treated as
+    the current user prompt.
+    System messages are extracted separately, and the conversation history is
+    compiled into pairs of human and AI messages.
+
+    Args:
+        messages (List[ModelMessage]): List of messages from a chat conversation.
+
     Returns:
-        A tuple contains user prompt, system message list and history message list
-        str: user prompt
-        List[str]: system messages
-        List[List[str]]: history message of user and assistant
+        tuple: A tuple containing the user prompt, list of system messages, and
+            the conversation history. The conversation history is a list of
+            message pairs, each containing a user message and the corresponding
+            AI response.
+
+    Examples:
+        .. code-block:: python
+
+            # Example 1: Single round of conversation
+            messages = [
+                ModelMessage(role="human", content="Hello"),
+                ModelMessage(role="ai", content="Hi there!"),
+                ModelMessage(role="human", content="How are you?"),
+            ]
+            user_prompt, system_messages, history = parse_model_messages(messages)
+            # user_prompt: "How are you?"
+            # system_messages: []
+            # history: [["Hello", "Hi there!"]]
+
+            # Example 2: Conversation with system messages
+            messages = [
+                ModelMessage(role="system", content="System initializing..."),
+                ModelMessage(role="human", content="Is it sunny today?"),
+                ModelMessage(role="ai", content="Yes, it's sunny."),
+                ModelMessage(role="human", content="Great!"),
+            ]
+            user_prompt, system_messages, history = parse_model_messages(messages)
+            # user_prompt: "Great!"
+            # system_messages: ["System initializing..."]
+            # history: [["Is it sunny today?", "Yes, it's sunny."]]
+
+            # Example 3: Multiple rounds with a system message
+            messages = [
+                ModelMessage(role="human", content="Hi"),
+                ModelMessage(role="ai", content="Hello!"),
+                ModelMessage(role="system", content="Error 404"),
+                ModelMessage(role="human", content="What's the error?"),
+                ModelMessage(role="ai", content="Just a joke."),
+                ModelMessage(role="human", content="Funny!"),
+            ]
+            user_prompt, system_messages, history = parse_model_messages(messages)
+            # user_prompt: "Funny!"
+            # system_messages: ["Error 404"]
+            # history: [["Hi", "Hello!"], ["What's the error?", "Just a joke."]]
     """
-    user_prompt = ""
+
     system_messages: List[str] = []
     history_messages: List[List[str]] = [[]]
diff --git a/dbgpt/core/interface/tests/test_message.py b/dbgpt/core/interface/tests/test_message.py
index 41f5f36c5..0650b1f67 100755
--- a/dbgpt/core/interface/tests/test_message.py
+++ b/dbgpt/core/interface/tests/test_message.py
@@ -324,6 +324,71 @@ def test_load_from_storage(storage_conversation, in_memory_storage):
     assert isinstance(new_conversation.messages[1], AIMessage)
 
 
+def test_parse_model_messages_no_history_messages():
+    messages = [
+        ModelMessage(role=ModelMessageRoleType.HUMAN, content="Hello"),
+    ]
+    user_prompt, system_messages, history_messages = parse_model_messages(messages)
+    assert user_prompt == "Hello"
+    assert system_messages == []
+    assert history_messages == []
+
+
+def test_parse_model_messages_single_round_conversation():
+    messages = [
+        ModelMessage(role=ModelMessageRoleType.HUMAN, content="Hello"),
+        ModelMessage(role=ModelMessageRoleType.AI, content="Hi there!"),
+        ModelMessage(role=ModelMessageRoleType.HUMAN, content="Hello again"),
+    ]
+    user_prompt, system_messages, history_messages = parse_model_messages(messages)
+    assert user_prompt == "Hello again"
+    assert system_messages == []
+    assert history_messages == [["Hello", "Hi there!"]]
+
+
+def test_parse_model_messages_two_round_conversation_with_system_message():
+    messages = [
+        ModelMessage(
+            role=ModelMessageRoleType.SYSTEM, content="System initializing..."
+        ),
+        ModelMessage(role=ModelMessageRoleType.HUMAN, content="How's the weather?"),
+        ModelMessage(role=ModelMessageRoleType.AI, content="It's sunny!"),
+        ModelMessage(role=ModelMessageRoleType.HUMAN, content="Great to hear!"),
+    ]
+    user_prompt, system_messages, history_messages = parse_model_messages(messages)
+    assert user_prompt == "Great to hear!"
+    assert system_messages == ["System initializing..."]
+    assert history_messages == [["How's the weather?", "It's sunny!"]]
+
+
+def test_parse_model_messages_three_round_conversation():
+    messages = [
+        ModelMessage(role=ModelMessageRoleType.HUMAN, content="Hi"),
+        ModelMessage(role=ModelMessageRoleType.AI, content="Hello!"),
+        ModelMessage(role=ModelMessageRoleType.HUMAN, content="What's up?"),
+        ModelMessage(role=ModelMessageRoleType.AI, content="Not much, you?"),
+        ModelMessage(role=ModelMessageRoleType.HUMAN, content="Same here."),
+    ]
+    user_prompt, system_messages, history_messages = parse_model_messages(messages)
+    assert user_prompt == "Same here."
+    assert system_messages == []
+    assert history_messages == [["Hi", "Hello!"], ["What's up?", "Not much, you?"]]
+
+
+def test_parse_model_messages_multiple_system_messages():
+    messages = [
+        ModelMessage(role=ModelMessageRoleType.SYSTEM, content="System start"),
+        ModelMessage(role=ModelMessageRoleType.HUMAN, content="Hey"),
+        ModelMessage(role=ModelMessageRoleType.AI, content="Hello!"),
+        ModelMessage(role=ModelMessageRoleType.SYSTEM, content="System check"),
+        ModelMessage(role=ModelMessageRoleType.HUMAN, content="How are you?"),
+    ]
+    user_prompt, system_messages, history_messages = parse_model_messages(messages)
+    assert user_prompt == "How are you?"
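+    # System messages may appear anywhere in the conversation; they are
+    # collected in order and excluded from the human/AI history pairs.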
+    assert system_messages == ["System start", "System check"]
+    assert history_messages == [["Hey", "Hello!"]]
+
+
 def test_to_openai_messages(
     human_model_message, ai_model_message, system_model_message
 ):
diff --git a/dbgpt/model/llm_out/proxy_llm.py b/dbgpt/model/llm_out/proxy_llm.py
index 89470e838..14b2c3177 100644
--- a/dbgpt/model/llm_out/proxy_llm.py
+++ b/dbgpt/model/llm_out/proxy_llm.py
@@ -8,6 +8,7 @@
 from dbgpt.model.proxy.llms.wenxin import wenxin_generate_stream
 from dbgpt.model.proxy.llms.tongyi import tongyi_generate_stream
 from dbgpt.model.proxy.llms.zhipu import zhipu_generate_stream
+from dbgpt.model.proxy.llms.gemini import gemini_generate_stream
 from dbgpt.model.proxy.llms.baichuan import baichuan_generate_stream
 from dbgpt.model.proxy.llms.spark import spark_generate_stream
 from dbgpt.model.proxy.llms.proxy_model import ProxyModel
@@ -25,6 +26,7 @@ def proxyllm_generate_stream(
         "wenxin_proxyllm": wenxin_generate_stream,
         "tongyi_proxyllm": tongyi_generate_stream,
         "zhipu_proxyllm": zhipu_generate_stream,
+        "gemini_proxyllm": gemini_generate_stream,
         "bc_proxyllm": baichuan_generate_stream,
         "spark_proxyllm": spark_generate_stream,
     }
diff --git a/dbgpt/model/proxy/llms/gemini.py b/dbgpt/model/proxy/llms/gemini.py
new file mode 100644
index 000000000..53122ce87
--- /dev/null
+++ b/dbgpt/model/proxy/llms/gemini.py
@@ -0,0 +1,109 @@
+from typing import List, Tuple, Dict, Any
+
+from dbgpt.model.proxy.llms.proxy_model import ProxyModel
+from dbgpt.core.interface.message import ModelMessage, parse_model_messages
+
+GEMINI_DEFAULT_MODEL = "gemini-pro"
+
+
+def gemini_generate_stream(
+    model: ProxyModel, tokenizer, params, device, context_len=2048
+):
+    """Google Gemini, see: https://makersuite.google.com/app/apikey"""
+    model_params = model.get_params()
+    print(f"Model: {model}, model_params: {model_params}")
+
+    # TODO proxy model use unified config?
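+    # The API key and backend come from the gemini_proxyllm_* environment
+    # variables set in dbgpt/_private/config.py (GEMINI_PROXY_API_KEY and
+    # GEMINI_MODEL_VERSION, the latter defaulting to "gemini-pro").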
+    proxy_api_key = model_params.proxy_api_key
+    proxyllm_backend = model_params.proxyllm_backend or GEMINI_DEFAULT_MODEL
+
+    generation_config = {
+        "temperature": 0.7,
+        "top_p": 1,
+        "top_k": 1,
+        "max_output_tokens": 2048,
+    }
+
+    safety_settings = [
+        {"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_MEDIUM_AND_ABOVE"},
+        {
+            "category": "HARM_CATEGORY_HATE_SPEECH",
+            "threshold": "BLOCK_MEDIUM_AND_ABOVE",
+        },
+        {
+            "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
+            "threshold": "BLOCK_MEDIUM_AND_ABOVE",
+        },
+        {
+            "category": "HARM_CATEGORY_DANGEROUS_CONTENT",
+            "threshold": "BLOCK_MEDIUM_AND_ABOVE",
+        },
+    ]
+
+    import google.generativeai as genai
+
+    if model_params.proxy_api_base:
+        from google.api_core import client_options
+
+        client_opts = client_options.ClientOptions(
+            api_endpoint=model_params.proxy_api_base
+        )
+        genai.configure(
+            api_key=proxy_api_key, transport="rest", client_options=client_opts
+        )
+    else:
+        genai.configure(api_key=proxy_api_key)
+    gemini_model = genai.GenerativeModel(
+        model_name=proxyllm_backend,
+        generation_config=generation_config,
+        safety_settings=safety_settings,
+    )
+    messages: List[ModelMessage] = params["messages"]
+    user_prompt, gemini_hist = _transform_to_gemini_messages(messages)
+    chat = gemini_model.start_chat(history=gemini_hist)
+    response = chat.send_message(user_prompt, stream=True)
+    text = ""
+    for chunk in response:
+        text += chunk.text
+        yield text
+
+
+def _transform_to_gemini_messages(
+    messages: List[ModelMessage],
+) -> Tuple[str, List[Dict[str, Any]]]:
+    """Transform messages to gemini format
+
+    See https://github.com/GoogleCloudPlatform/generative-ai/blob/main/gemini/getting-started/intro_gemini_python.ipynb
+
+    Args:
+        messages (List[ModelMessage]): messages
+
+    Returns:
+        Tuple[str, List[Dict[str, Any]]]: user_prompt, gemini_hist
+
+    Examples:
+        .. code-block:: python
+
+            messages = [
+                ModelMessage(role="human", content="Hello"),
+                ModelMessage(role="ai", content="Hi there!"),
+                ModelMessage(role="human", content="How are you?"),
+            ]
+            user_prompt, gemini_hist = _transform_to_gemini_messages(messages)
+            assert user_prompt == "How are you?"
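+            # the history pairs become alternating "user"/"model" turns: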
+            assert gemini_hist == [
+                {"role": "user", "parts": {"text": "Hello"}},
+                {"role": "model", "parts": {"text": "Hi there!"}}
+            ]
+    """
+    user_prompt, system_messages, history_messages = parse_model_messages(messages)
+    if system_messages:
+        user_prompt = "".join(system_messages) + "\n" + user_prompt
+    gemini_hist = []
+    if history_messages:
+        for user_message, model_message in history_messages:
+            gemini_hist.append({"role": "user", "parts": {"text": user_message}})
+            gemini_hist.append({"role": "model", "parts": {"text": model_message}})
+    return user_prompt, gemini_hist
diff --git a/dbgpt/model/proxy/llms/zhipu.py b/dbgpt/model/proxy/llms/zhipu.py
index 66b0b3dc6..90f1d3d2b 100644
--- a/dbgpt/model/proxy/llms/zhipu.py
+++ b/dbgpt/model/proxy/llms/zhipu.py
@@ -6,7 +6,7 @@
 CHATGLM_DEFAULT_MODEL = "chatglm_pro"
 
 
-def __convert_2_wenxin_messages(messages: List[ModelMessage]):
+def __convert_2_zhipu_messages(messages: List[ModelMessage]):
     chat_round = 0
     wenxin_messages = []
 
@@ -57,38 +57,7 @@ def zhipu_generate_stream(
     zhipuai.api_key = proxy_api_key
     messages: List[ModelMessage] = params["messages"]
 
-    # Add history conversation
-    # system = ""
-    # if len(messages) > 1 and messages[0].role == ModelMessageRoleType.SYSTEM:
-    #     role_define = messages.pop(0)
-    #     system = role_define.content
-    # else:
-    #     message = messages.pop(0)
-    #     if message.role == ModelMessageRoleType.HUMAN:
-    #         history.append({"role": "user", "content": message.content})
-    # for message in messages:
-    #     if message.role == ModelMessageRoleType.SYSTEM:
-    #         history.append({"role": "user", "content": message.content})
-    #     # elif message.role == ModelMessageRoleType.HUMAN:
-    #     #     history.append({"role": "user", "content": message.content})
-    #     elif message.role == ModelMessageRoleType.AI:
-    #         history.append({"role": "assistant", "content": message.content})
-    #     else:
-    #         pass
-    #
-    # # temp_his = history[::-1]
-    # temp_his = history
-    # last_user_input = None
-    # for m in temp_his:
-    #     if m["role"] == "user":
-    #         last_user_input = m
-    #         break
-    #
-    # if last_user_input:
-    #     history.remove(last_user_input)
-    #     history.append(last_user_input)
-
-    history, systems = __convert_2_wenxin_messages(messages)
+    history, systems = __convert_2_zhipu_messages(messages)
     res = zhipuai.model_api.sse_invoke(
         model=proxyllm_backend,
         prompt=history,
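
Usage sketch (a minimal example, assuming the google-generativeai package is
installed and GEMINI_PROXY_API_KEY is set): with the configured model name
resolving to the new "gemini_proxyllm" entry, proxyllm_generate_stream
dispatches to gemini_generate_stream. The now-public parse_model_messages
helper can also be called directly; the expected values below mirror the
docstring examples:

    from dbgpt.core.interface.message import ModelMessage, parse_model_messages

    messages = [
        ModelMessage(role="system", content="You are a helpful assistant."),
        ModelMessage(role="human", content="Hello"),
        ModelMessage(role="ai", content="Hi there!"),
        ModelMessage(role="human", content="How are you?"),
    ]
    user_prompt, system_messages, history = parse_model_messages(messages)
    # user_prompt == "How are you?"
    # system_messages == ["You are a helpful assistant."]
    # history == [["Hello", "Hi there!"]]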