diff --git a/.github/workflows/pylint.yml b/.github/workflows/pylint.yml index 383e65cd0..e82f1b9d4 100644 --- a/.github/workflows/pylint.yml +++ b/.github/workflows/pylint.yml @@ -1,6 +1,15 @@ name: Pylint -on: [push] +on: + push: + branches: [ make_ci_happy ] + pull_request: + branches: [ main ] + workflow_dispatch: + +concurrency: + group: ${{ github.event.number || github.run_id }} + cancel-in-progress: true jobs: build: @@ -17,7 +26,7 @@ jobs: - name: Install dependencies run: | python -m pip install --upgrade pip - pip install pylint - - name: Analysing the code with pylint + pip install -U black isort + - name: check the code lint run: | - pylint $(git ls-files '*.py') + black . --check diff --git a/examples/app.py b/examples/app.py index 7cb3aad7f..31b6e0f26 100644 --- a/examples/app.py +++ b/examples/app.py @@ -2,24 +2,28 @@ # -*- coding:utf-8 -*- import gradio as gr -from langchain.agents import ( - load_tools, - initialize_agent, - AgentType -) -from pilot.model.vicuna_llm import VicunaRequestLLM, VicunaEmbeddingLLM -from llama_index import LLMPredictor, LangchainEmbedding, ServiceContext +from langchain.agents import AgentType, initialize_agent, load_tools from langchain.embeddings.huggingface import HuggingFaceEmbeddings -from llama_index import Document, GPTSimpleVectorIndex +from llama_index import ( + Document, + GPTSimpleVectorIndex, + LangchainEmbedding, + LLMPredictor, + ServiceContext, +) + +from pilot.model.vicuna_llm import VicunaEmbeddingLLM, VicunaRequestLLM + def agent_demo(): llm = VicunaRequestLLM() - tools = load_tools(['python_repl'], llm=llm) - agent = initialize_agent(tools, llm, agent=AgentType.CHAT_ZERO_SHOT_REACT_DESCRIPTION, verbose=True) - agent.run( - "Write a SQL script that Query 'select count(1)!'" + tools = load_tools(["python_repl"], llm=llm) + agent = initialize_agent( + tools, llm, agent=AgentType.CHAT_ZERO_SHOT_REACT_DESCRIPTION, verbose=True ) + agent.run("Write a SQL script that Query 'select count(1)!'") + def knowledged_qa_demo(text_list): llm_predictor = LLMPredictor(llm=VicunaRequestLLM()) @@ -27,27 +31,34 @@ def knowledged_qa_demo(text_list): embed_model = LangchainEmbedding(hfemb) documents = [Document(t) for t in text_list] - service_context = ServiceContext.from_defaults(llm_predictor=llm_predictor, embed_model=embed_model) - index = GPTSimpleVectorIndex.from_documents(documents, service_context=service_context) + service_context = ServiceContext.from_defaults( + llm_predictor=llm_predictor, embed_model=embed_model + ) + index = GPTSimpleVectorIndex.from_documents( + documents, service_context=service_context + ) return index def get_answer(q): - base_knowledge = """ """ + base_knowledge = """ """ text_list = [base_knowledge] index = knowledged_qa_demo(text_list) response = index.query(q) return response.response + def get_similar(q): from pilot.vector_store.extract_tovec import knownledge_tovec, knownledge_tovec_st + docsearch = knownledge_tovec_st("./datasets/plan.md") docs = docsearch.similarity_search_with_score(q, k=1) for doc in docs: - dc, s = doc + dc, s = doc print(s) - yield dc.page_content + yield dc.page_content + if __name__ == "__main__": # agent_demo() @@ -58,8 +69,7 @@ def get_similar(q): text_input = gr.TextArea() text_output = gr.TextArea() text_button = gr.Button() - + text_button.click(get_similar, inputs=text_input, outputs=text_output) demo.queue(concurrency_count=3).launch(server_name="0.0.0.0") - diff --git a/examples/embdserver.py b/examples/embdserver.py index 32eca1291..ae0dfcae8 100644 --- a/examples/embdserver.py +++ b/examples/embdserver.py @@ -1,30 +1,29 @@ #!/usr/bin/env python3 # -*- coding:utf-8 -*- -import requests import json -import time -import uuid import os import sys from urllib.parse import urljoin + import gradio as gr +import requests ROOT_PATH = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) sys.path.append(ROOT_PATH) -from pilot.configs.config import Config -from pilot.conversation import conv_qa_prompt_template, conv_templates from langchain.prompts import PromptTemplate +from pilot.configs.config import Config +from pilot.conversation import conv_qa_prompt_template, conv_templates llmstream_stream_path = "generate_stream" CFG = Config() -def generate(query): +def generate(query): template_name = "conv_one_shot" state = conv_templates[template_name].copy() @@ -47,7 +46,7 @@ def generate(query): "prompt": prompt, "temperature": 1.0, "max_new_tokens": 1024, - "stop": "###" + "stop": "###", } response = requests.post( @@ -57,19 +56,18 @@ def generate(query): skip_echo_len = len(params["prompt"]) + 1 - params["prompt"].count("") * 3 for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"): - if chunk: data = json.loads(chunk.decode()) if data["error_code"] == 0: - if "vicuna" in CFG.LLM_MODEL: output = data["text"][skip_echo_len:].strip() else: output = data["text"].strip() state.messages[-1][-1] = output + "▌" - yield(output) - + yield (output) + + if __name__ == "__main__": print(CFG.LLM_MODEL) with gr.Blocks() as demo: @@ -78,10 +76,7 @@ def generate(query): text_input = gr.TextArea() text_output = gr.TextArea() text_button = gr.Button("提交") - text_button.click(generate, inputs=text_input, outputs=text_output) - demo.queue(concurrency_count=3).launch(server_name="0.0.0.0") - - \ No newline at end of file + demo.queue(concurrency_count=3).launch(server_name="0.0.0.0") diff --git a/examples/gpt_index.py b/examples/gpt_index.py index 29c0a3fe0..2a0841a24 100644 --- a/examples/gpt_index.py +++ b/examples/gpt_index.py @@ -1,19 +1,19 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- -import os import logging import sys -from llama_index import SimpleDirectoryReader, GPTSimpleVectorIndex +from llama_index import GPTSimpleVectorIndex, SimpleDirectoryReader + logging.basicConfig(stream=sys.stdout, level=logging.INFO) logging.getLogger().addHandler(logging.StreamHandler(stream=sys.stdout)) # read the document of data dir documents = SimpleDirectoryReader("data").load_data() -# split the document to chunk, max token size=500, convert chunk to vector +# split the document to chunk, max token size=500, convert chunk to vector index = GPTSimpleVectorIndex(documents) # save index -index.save_to_disk("index.json") \ No newline at end of file +index.save_to_disk("index.json") diff --git a/examples/gradio_test.py b/examples/gradio_test.py index f39a1ca9e..593c6c1f4 100644 --- a/examples/gradio_test.py +++ b/examples/gradio_test.py @@ -3,17 +3,19 @@ import gradio as gr + def change_tab(): return gr.Tabs.update(selected=1) + with gr.Blocks() as demo: with gr.Tabs() as tabs: with gr.TabItem("Train", id=0): t = gr.Textbox() with gr.TabItem("Inference", id=1): i = gr.Image() - + btn = gr.Button() btn.click(change_tab, None, tabs) -demo.launch() \ No newline at end of file +demo.launch() diff --git a/examples/knowledge_embedding/csv_embedding_test.py b/examples/knowledge_embedding/csv_embedding_test.py index d796596c6..3f08422f7 100644 --- a/examples/knowledge_embedding/csv_embedding_test.py +++ b/examples/knowledge_embedding/csv_embedding_test.py @@ -1,5 +1,3 @@ - - from pilot.source_embedding.csv_embedding import CSVEmbedding # path = "/Users/chenketing/Downloads/share_ireserve双写数据异常2.xlsx" @@ -8,6 +6,13 @@ vector_store_path = "your_path/" -pdf_embedding = CSVEmbedding(file_path=path, model_name=model_name, vector_store_config={"vector_store_name": "url", "vector_store_path": "vector_store_path"}) +pdf_embedding = CSVEmbedding( + file_path=path, + model_name=model_name, + vector_store_config={ + "vector_store_name": "url", + "vector_store_path": "vector_store_path", + }, +) pdf_embedding.source_embedding() -print("success") \ No newline at end of file +print("success") diff --git a/examples/knowledge_embedding/pdf_embedding_test.py b/examples/knowledge_embedding/pdf_embedding_test.py index 6c3f3588e..660b811ee 100644 --- a/examples/knowledge_embedding/pdf_embedding_test.py +++ b/examples/knowledge_embedding/pdf_embedding_test.py @@ -6,6 +6,13 @@ vector_store_path = "your_path/" -pdf_embedding = PDFEmbedding(file_path=path, model_name=model_name, vector_store_config={"vector_store_name": "ob-pdf", "vector_store_path": vector_store_path}) +pdf_embedding = PDFEmbedding( + file_path=path, + model_name=model_name, + vector_store_config={ + "vector_store_name": "ob-pdf", + "vector_store_path": vector_store_path, + }, +) pdf_embedding.source_embedding() -print("success") \ No newline at end of file +print("success") diff --git a/examples/knowledge_embedding/url_embedding_test.py b/examples/knowledge_embedding/url_embedding_test.py index 5db7f998d..aeb353c89 100644 --- a/examples/knowledge_embedding/url_embedding_test.py +++ b/examples/knowledge_embedding/url_embedding_test.py @@ -5,6 +5,13 @@ vector_store_path = "your_path" -pdf_embedding = URLEmbedding(file_path=path, model_name=model_name, vector_store_config={"vector_store_name": "url", "vector_store_path": "vector_store_path"}) +pdf_embedding = URLEmbedding( + file_path=path, + model_name=model_name, + vector_store_config={ + "vector_store_name": "url", + "vector_store_path": "vector_store_path", + }, +) pdf_embedding.source_embedding() -print("success") \ No newline at end of file +print("success") diff --git a/examples/t5_example.py b/examples/t5_example.py index a63c9f961..ab2b7f2e3 100644 --- a/examples/t5_example.py +++ b/examples/t5_example.py @@ -1,19 +1,28 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- -from llama_index import SimpleDirectoryReader, LangchainEmbedding, GPTListIndex, GPTSimpleVectorIndex, PromptHelper -from langchain.embeddings.huggingface import HuggingFaceEmbeddings -from llama_index import LLMPredictor import torch +from langchain.embeddings.huggingface import HuggingFaceEmbeddings from langchain.llms.base import LLM +from llama_index import ( + GPTListIndex, + GPTSimpleVectorIndex, + LangchainEmbedding, + LLMPredictor, + PromptHelper, + SimpleDirectoryReader, +) from transformers import pipeline class FlanLLM(LLM): model_name = "google/flan-t5-large" - pipeline = pipeline("text2text-generation", model=model_name, device=0, model_kwargs={ - "torch_dtype": torch.bfloat16 - }) + pipeline = pipeline( + "text2text-generation", + model=model_name, + device=0, + model_kwargs={"torch_dtype": torch.bfloat16}, + ) def _call(self, prompt, stop=None): return self.pipeline(prompt, max_length=9999)[0]["generated_text"] @@ -24,6 +33,7 @@ def _identifying_params(self): def _llm_type(self): return "custome" + llm_predictor = LLMPredictor(llm=FlanLLM()) hfemb = HuggingFaceEmbeddings() embed_model = LangchainEmbedding(hfemb) @@ -214,9 +224,10 @@ def _llm_type(self): 回答: nlj也是左表的表是驱动表,这个要了解下计划执行方面的基本原理,取左表的一行数据,再遍历右表,一旦满足连接条件,就可以返回数据 anti/semi只是因为not exists/exist的语义只是返回左表数据,改成anti join是一种计划优化,连接的方式比子查询更优 -""" +""" from llama_index import Document + text_list = [text1] documents = [Document(t) for t in text_list] @@ -226,12 +237,18 @@ def _llm_type(self): max_chunk_overlap = 20 prompt_helper = PromptHelper(max_input_size, num_output, max_chunk_overlap) -index = GPTListIndex(documents, embed_model=embed_model, llm_predictor=llm_predictor, prompt_helper=prompt_helper) +index = GPTListIndex( + documents, + embed_model=embed_model, + llm_predictor=llm_predictor, + prompt_helper=prompt_helper, +) index.save_to_disk("index.json") if __name__ == "__main__": import logging + logging.getLogger().setLevel(logging.CRITICAL) for d in documents: print(d) diff --git a/pilot/__init__.py b/pilot/__init__.py index b1d1cd3d2..f44b2e809 100644 --- a/pilot/__init__.py +++ b/pilot/__init__.py @@ -1,6 +1,3 @@ -from pilot.source_embedding import (SourceEmbedding, register) +from pilot.source_embedding import SourceEmbedding, register -__all__ = [ - "SourceEmbedding", - "register" -] +__all__ = ["SourceEmbedding", "register"] diff --git a/pilot/agent/agent.py b/pilot/agent/agent.py index 3463790bc..8d8220b4a 100644 --- a/pilot/agent/agent.py +++ b/pilot/agent/agent.py @@ -3,10 +3,10 @@ class Agent: - """Agent class for interacting with DB-GPT - - Attributes: + """Agent class for interacting with DB-GPT + + Attributes: """ - + def __init__(self) -> None: - pass \ No newline at end of file + pass diff --git a/pilot/agent/agent_manager.py b/pilot/agent/agent_manager.py index 89754bd1c..31b55eb65 100644 --- a/pilot/agent/agent_manager.py +++ b/pilot/agent/agent_manager.py @@ -4,10 +4,8 @@ from __future__ import annotations from pilot.configs.config import Config -from pilot.singleton import Singleton -from pilot.configs.config import Config -from typing import List from pilot.model.base import Message +from pilot.singleton import Singleton class AgentManager(metaclass=Singleton): @@ -17,6 +15,7 @@ def __init__(self): self.next_key = 0 self.agents = {} # key, (task, full_message_history, model) self.cfg = Config() + """Agent manager for managing DB-GPT agents In order to compatible auto gpt plugins, we use the same template with it. @@ -28,7 +27,7 @@ def __init__(self): def __init__(self) -> None: self.next_key = 0 - self.agents = {} #TODO need to define + self.agents = {} # TODO need to define self.cfg = Config() # Create new GPT agent @@ -46,7 +45,6 @@ def create_agent(self, task: str, prompt: str, model: str) -> tuple[int, str]: The key of the new agent """ - def message_agent(self, key: str | int, message: str) -> str: """Send a message to an agent and return its response @@ -58,7 +56,6 @@ def message_agent(self, key: str | int, message: str) -> str: The agent's response """ - def list_agents(self) -> list[tuple[str | int, str]]: """Return a list of all agents diff --git a/pilot/agent/json_fix_llm.py b/pilot/agent/json_fix_llm.py index 3ca8f85b0..327881a78 100644 --- a/pilot/agent/json_fix_llm.py +++ b/pilot/agent/json_fix_llm.py @@ -1,18 +1,22 @@ - +import contextlib import json from typing import Any, Dict -import contextlib + from colorama import Fore from regex import regex from pilot.configs.config import Config +from pilot.json_utils.json_fix_general import ( + add_quotes_to_property_names, + balance_braces, + fix_invalid_escape, +) from pilot.logs import logger from pilot.speech import say_text -from pilot.json_utils.json_fix_general import fix_invalid_escape,add_quotes_to_property_names,balance_braces - CFG = Config() + def fix_and_parse_json( json_to_load: str, try_to_fix_with_gpt: bool = True ) -> Dict[Any, Any]: @@ -48,7 +52,7 @@ def fix_and_parse_json( maybe_fixed_json = maybe_fixed_json[: last_brace_index + 1] return json.loads(maybe_fixed_json) except (json.JSONDecodeError, ValueError) as e: - logger.error("参数解析错误", e) + logger.error("参数解析错误", e) def fix_json_using_multiple_techniques(assistant_reply: str) -> Dict[Any, Any]: diff --git a/pilot/chain/audio.py b/pilot/chain/audio.py index 8b197119c..c53f601b3 100644 --- a/pilot/chain/audio.py +++ b/pilot/chain/audio.py @@ -1,2 +1,2 @@ #!/usr/bin/env python3 -# -*- coding:utf-8 -*- \ No newline at end of file +# -*- coding:utf-8 -*- diff --git a/pilot/chain/visual.py b/pilot/chain/visual.py index 1f776fc63..56fafa58b 100644 --- a/pilot/chain/visual.py +++ b/pilot/chain/visual.py @@ -1,2 +1,2 @@ #!/usr/bin/env python3 -# -*- coding: utf-8 -*- \ No newline at end of file +# -*- coding: utf-8 -*- diff --git a/pilot/commands/command.py b/pilot/commands/command.py index 134e93e1d..0200ef6cd 100644 --- a/pilot/commands/command.py +++ b/pilot/commands/command.py @@ -1,15 +1,14 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- -from pilot.prompts.generator import PromptGenerator -from typing import Dict, List, NoReturn, Union -from pilot.configs.config import Config - -from pilot.speech import say_text +import json +from typing import Dict from pilot.agent.json_fix_llm import fix_json_using_multiple_techniques from pilot.commands.exception_not_commands import NotCommands -import json +from pilot.configs.config import Config +from pilot.prompts.generator import PromptGenerator +from pilot.speech import say_text def _resolve_pathlike_command_args(command_args): @@ -25,9 +24,9 @@ def _resolve_pathlike_command_args(command_args): def execute_ai_response_json( - prompt: PromptGenerator, - ai_response: str, - user_input: str = None, + prompt: PromptGenerator, + ai_response: str, + user_input: str = None, ) -> str: """ @@ -52,18 +51,14 @@ def execute_ai_response_json( arguments = _resolve_pathlike_command_args(arguments) # Execute command if command_name is not None and command_name.lower().startswith("error"): - result = ( - f"Command {command_name} threw the following error: {arguments}" - ) + result = f"Command {command_name} threw the following error: {arguments}" elif command_name == "human_feedback": result = f"Human feedback: {user_input}" else: for plugin in cfg.plugins: if not plugin.can_handle_pre_command(): continue - command_name, arguments = plugin.pre_command( - command_name, arguments - ) + command_name, arguments = plugin.pre_command(command_name, arguments) command_result = execute_command( command_name, arguments, @@ -74,9 +69,9 @@ def execute_ai_response_json( def execute_command( - command_name: str, - arguments, - prompt: PromptGenerator, + command_name: str, + arguments, + prompt: PromptGenerator, ): """Execute the command and return the result @@ -102,13 +97,15 @@ def execute_command( else: for command in prompt.commands: if ( - command_name == command["label"].lower() - or command_name == command["name"].lower() + command_name == command["label"].lower() + or command_name == command["name"].lower() ): try: # 删除非定义参数 - diff_ags = list(set(arguments.keys()).difference(set(command['args'].keys()))) - for arg_name in diff_ags: + diff_ags = list( + set(arguments.keys()).difference(set(command["args"].keys())) + ) + for arg_name in diff_ags: del arguments[arg_name] print(str(arguments)) return command["function"](**arguments) diff --git a/pilot/commands/commands_load.py b/pilot/commands/commands_load.py index b173eb03a..a6fad3db2 100644 --- a/pilot/commands/commands_load.py +++ b/pilot/commands/commands_load.py @@ -1,19 +1,21 @@ +from typing import Optional + from pilot.configs.config import Config from pilot.prompts.generator import PromptGenerator -from typing import Any, Optional, Type from pilot.prompts.prompt import build_default_prompt_generator class CommandsLoad: """ - Load Plugins Commands Info , help build system prompt! + Load Plugins Commands Info , help build system prompt! """ - def __init__(self)->None: + def __init__(self) -> None: self.command_registry = None - - def getCommandInfos(self, prompt_generator: Optional[PromptGenerator] = None)-> str: + def getCommandInfos( + self, prompt_generator: Optional[PromptGenerator] = None + ) -> str: cfg = Config() if prompt_generator is None: prompt_generator = build_default_prompt_generator() @@ -24,4 +26,4 @@ def getCommandInfos(self, prompt_generator: Optional[PromptGenerator] = None)-> self.prompt_generator = prompt_generator command_infos = "" command_infos += f"\n\n{prompt_generator.commands()}" - return command_infos \ No newline at end of file + return command_infos diff --git a/pilot/commands/exception_not_commands.py b/pilot/commands/exception_not_commands.py index 88c4d8f1d..7d92f05c0 100644 --- a/pilot/commands/exception_not_commands.py +++ b/pilot/commands/exception_not_commands.py @@ -1,5 +1,4 @@ - class NotCommands(Exception): def __init__(self, message): super().__init__(message) - self.message = message \ No newline at end of file + self.message = message diff --git a/pilot/commands/image_gen.py b/pilot/commands/image_gen.py index 25a6c80fd..d6492e2d9 100644 --- a/pilot/commands/image_gen.py +++ b/pilot/commands/image_gen.py @@ -25,7 +25,7 @@ def generate_image(prompt: str, size: int = 256) -> str: str: The filename of the image """ filename = f"{CFG.workspace_path}/{str(uuid.uuid4())}.jpg" - + # HuggingFace if CFG.image_provider == "huggingface": return generate_image_with_hf(prompt, filename) @@ -72,6 +72,7 @@ def generate_image_with_hf(prompt: str, filename: str) -> str: return f"Saved to disk:{filename}" + def generate_image_with_sd_webui( prompt: str, filename: str, diff --git a/pilot/configs/ai_config.py b/pilot/configs/ai_config.py index d774454e4..ed9b4e2f8 100644 --- a/pilot/configs/ai_config.py +++ b/pilot/configs/ai_config.py @@ -7,13 +7,13 @@ import os import platform from pathlib import Path -from typing import Any, Optional, Type +from typing import Optional import distro import yaml -from pilot.prompts.generator import PromptGenerator from pilot.configs.config import Config +from pilot.prompts.generator import PromptGenerator from pilot.prompts.prompt import build_default_prompt_generator # Soon this will go in a folder where it remembers more stuff about the run(s) @@ -88,7 +88,7 @@ def load(config_file: str = SAVE_FILE) -> "AIConfig": for goal in config_params.get("ai_goals", []) ] api_budget = config_params.get("api_budget", 0.0) - # type: Type[AIConfig] + # type is Type[AIConfig] return AIConfig(ai_name, ai_role, ai_goals, api_budget) def save(self, config_file: str = SAVE_FILE) -> None: @@ -133,8 +133,6 @@ def construct_full_prompt( "" ) - - cfg = Config() if prompt_generator is None: prompt_generator = build_default_prompt_generator() diff --git a/pilot/configs/config.py b/pilot/configs/config.py index e9ec2bd48..d5e403598 100644 --- a/pilot/configs/config.py +++ b/pilot/configs/config.py @@ -2,15 +2,17 @@ # -*- coding: utf-8 -*- import os -import nltk from typing import List +import nltk from auto_gpt_plugin_template import AutoGPTPluginTemplate + from pilot.singleton import Singleton class Config(metaclass=Singleton): """Configuration class to store the state of bools for different scripts access""" + def __init__(self) -> None: """Initialize the Config class""" @@ -18,7 +20,6 @@ def __init__(self) -> None: self.skip_reprompt = False self.temperature = float(os.getenv("TEMPERATURE", 0.7)) - self.execute_local_commands = ( os.getenv("EXECUTE_LOCAL_COMMANDS", "False") == "True" ) @@ -45,7 +46,6 @@ def __init__(self) -> None: self.milvus_collection = os.getenv("MILVUS_COLLECTION", "dbgpt") self.milvus_secure = os.getenv("MILVUS_SECURE") == "True" - self.authorise_key = os.getenv("AUTHORISE_COMMAND_KEY", "y") self.exit_key = os.getenv("EXIT_KEY", "n") self.image_provider = os.getenv("IMAGE_PROVIDER", True) @@ -62,7 +62,6 @@ def __init__(self) -> None: ) self.speak_mode = False - ### Related configuration of built-in commands self.command_registry = [] @@ -76,7 +75,6 @@ def __init__(self) -> None: os.getenv("EXECUTE_LOCAL_COMMANDS", "False") == "True" ) - ### The associated configuration parameters of the plug-in control the loading and use of the plug-in self.plugins_dir = os.getenv("PLUGINS_DIR", "../../plugins") self.plugins: List[AutoGPTPluginTemplate] = [] @@ -94,35 +92,35 @@ def __init__(self) -> None: else: self.plugins_denylist = [] - ### Local database connection configuration - self.LOCAL_DB_HOST = os.getenv("LOCAL_DB_HOST", "127.0.0.1") - self.LOCAL_DB_PORT = int(os.getenv("LOCAL_DB_PORT", 3306)) - self.LOCAL_DB_USER = os.getenv("LOCAL_DB_USER", "root") - self.LOCAL_DB_PASSWORD = os.getenv("LOCAL_DB_PASSWORD", "aa123456") + self.LOCAL_DB_HOST = os.getenv("LOCAL_DB_HOST", "127.0.0.1") + self.LOCAL_DB_PORT = int(os.getenv("LOCAL_DB_PORT", 3306)) + self.LOCAL_DB_USER = os.getenv("LOCAL_DB_USER", "root") + self.LOCAL_DB_PASSWORD = os.getenv("LOCAL_DB_PASSWORD", "aa123456") ### LLM Model Service Configuration - self.LLM_MODEL = os.getenv("LLM_MODEL", "vicuna-13b") - self.LIMIT_MODEL_CONCURRENCY = int(os.getenv("LIMIT_MODEL_CONCURRENCY", 5)) - self.MAX_POSITION_EMBEDDINGS = int(os.getenv("MAX_POSITION_EMBEDDINGS", 4096)) - self.MODEL_PORT = os.getenv("MODEL_PORT", 8000) - self.MODEL_SERVER = os.getenv("MODEL_SERVER", "http://127.0.0.1" + ":" + str(self.MODEL_PORT)) + self.LLM_MODEL = os.getenv("LLM_MODEL", "vicuna-13b") + self.LIMIT_MODEL_CONCURRENCY = int(os.getenv("LIMIT_MODEL_CONCURRENCY", 5)) + self.MAX_POSITION_EMBEDDINGS = int(os.getenv("MAX_POSITION_EMBEDDINGS", 4096)) + self.MODEL_PORT = os.getenv("MODEL_PORT", 8000) + self.MODEL_SERVER = os.getenv( + "MODEL_SERVER", "http://127.0.0.1" + ":" + str(self.MODEL_PORT) + ) self.ISLOAD_8BIT = os.getenv("ISLOAD_8BIT", "True") == "True" ### Vector Store Configuration - self.VECTOR_STORE_TYPE = os.getenv("VECTOR_STORE_TYPE", "Chroma") - self.MILVUS_URL = os.getenv("MILVUS_URL", "127.0.0.1") - self.MILVUS_PORT = os.getenv("MILVUS_PORT", "19530") - self.MILVUS_USERNAME = os.getenv("MILVUS_USERNAME", None) - self.MILVUS_PASSWORD = os.getenv("MILVUS_PASSWORD", None) - + self.VECTOR_STORE_TYPE = os.getenv("VECTOR_STORE_TYPE", "Chroma") + self.MILVUS_URL = os.getenv("MILVUS_URL", "127.0.0.1") + self.MILVUS_PORT = os.getenv("MILVUS_PORT", "19530") + self.MILVUS_USERNAME = os.getenv("MILVUS_USERNAME", None) + self.MILVUS_PASSWORD = os.getenv("MILVUS_PASSWORD", None) def set_debug_mode(self, value: bool) -> None: """Set the debug mode value""" self.debug_mode = value def set_plugins(self, value: list) -> None: - """Set the plugins value. """ + """Set the plugins value.""" self.plugins = value def set_templature(self, value: int) -> None: @@ -135,4 +133,4 @@ def set_speak_mode(self, value: bool) -> None: def set_last_plugin_return(self, value: bool) -> None: """Set the speak mode value.""" - self.last_plugin_return = value \ No newline at end of file + self.last_plugin_return = value diff --git a/pilot/configs/model_config.py b/pilot/configs/model_config.py index ebd8513e4..9d8c930fd 100644 --- a/pilot/configs/model_config.py +++ b/pilot/configs/model_config.py @@ -1,10 +1,10 @@ #!/usr/bin/env python3 # -*- coding:utf-8 -*- -import torch import os -import nltk +import nltk +import torch ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) MODEL_PATH = os.path.join(ROOT_PATH, "models") @@ -16,7 +16,13 @@ nltk.data.path = [os.path.join(PILOT_PATH, "nltk_data")] + nltk.data.path -DEVICE = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu" +DEVICE = ( + "cuda" + if torch.cuda.is_available() + else "mps" + if torch.backends.mps.is_available() + else "cpu" +) LLM_MODEL_CONFIG = { "flan-t5-base": os.path.join(MODEL_PATH, "flan-t5-base"), "vicuna-13b": os.path.join(MODEL_PATH, "vicuna-13b"), @@ -28,7 +34,7 @@ "chatglm-6b-int4": os.path.join(MODEL_PATH, "chatglm-6b-int4"), "chatglm-6b": os.path.join(MODEL_PATH, "chatglm-6b"), "text2vec-base": os.path.join(MODEL_PATH, "text2vec-base-chinese"), - "sentence-transforms": os.path.join(MODEL_PATH, "all-MiniLM-L6-v2") + "sentence-transforms": os.path.join(MODEL_PATH, "all-MiniLM-L6-v2"), } @@ -46,5 +52,7 @@ VECTOR_SEARCH_TOP_K = 10 VS_ROOT_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "vs_store") -KNOWLEDGE_UPLOAD_ROOT_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "data") -KNOWLEDGE_CHUNK_SPLIT_SIZE = 100 \ No newline at end of file +KNOWLEDGE_UPLOAD_ROOT_PATH = os.path.join( + os.path.dirname(os.path.dirname(__file__)), "data" +) +KNOWLEDGE_CHUNK_SPLIT_SIZE = 100 diff --git a/pilot/connections/base.py b/pilot/connections/base.py index 318ce17a2..ec41f9273 100644 --- a/pilot/connections/base.py +++ b/pilot/connections/base.py @@ -3,6 +3,6 @@ """We need to design a base class. That other connector can Write with this""" + class BaseConnection: pass - diff --git a/pilot/connections/clickhouse.py b/pilot/connections/clickhouse.py index 23f2660f9..7ea244276 100644 --- a/pilot/connections/clickhouse.py +++ b/pilot/connections/clickhouse.py @@ -4,4 +4,5 @@ class ClickHouseConnector: """ClickHouseConnector""" - pass \ No newline at end of file + + pass diff --git a/pilot/connections/es.py b/pilot/connections/es.py index 819d85ecf..3810c7619 100644 --- a/pilot/connections/es.py +++ b/pilot/connections/es.py @@ -4,4 +4,5 @@ class ElasticSearchConnector: """ElasticSearchConnector""" - pass \ No newline at end of file + + pass diff --git a/pilot/connections/mongo.py b/pilot/connections/mongo.py index b66aefdb3..27f7a610c 100644 --- a/pilot/connections/mongo.py +++ b/pilot/connections/mongo.py @@ -1,6 +1,8 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- + class MongoConnector: """MongoConnector is a class which connect to mongo and chat with LLM""" - pass \ No newline at end of file + + pass diff --git a/pilot/connections/mysql.py b/pilot/connections/mysql.py index 83da27ec3..a4595c603 100644 --- a/pilot/connections/mysql.py +++ b/pilot/connections/mysql.py @@ -3,27 +3,27 @@ import pymysql + class MySQLOperator: - """Connect MySQL Database fetch MetaData For LLM Prompt - Args: + """Connect MySQL Database fetch MetaData For LLM Prompt + Args: - Usage: + Usage: """ default_db = ["information_schema", "performance_schema", "sys", "mysql"] + def __init__(self, user, password, host="localhost", port=3306) -> None: - self.conn = pymysql.connect( host=host, user=user, port=port, passwd=password, charset="utf8mb4", - cursorclass=pymysql.cursors.DictCursor + cursorclass=pymysql.cursors.DictCursor, ) def get_schema(self, schema_name): - with self.conn.cursor() as cursor: _sql = f""" select concat(table_name, "(" , group_concat(column_name), ")") as schema_info from information_schema.COLUMNS where table_schema="{schema_name}" group by TABLE_NAME; @@ -31,7 +31,7 @@ def get_schema(self, schema_name): cursor.execute(_sql) results = cursor.fetchall() return results - + def get_index(self, schema_name): pass @@ -43,10 +43,10 @@ def get_db_list(self): cursor.execute(_sql) results = cursor.fetchall() - dbs = [d["Database"] for d in results if d["Database"] not in self.default_db] + dbs = [ + d["Database"] for d in results if d["Database"] not in self.default_db + ] return dbs def get_meta(self, schema_name): pass - - diff --git a/pilot/connections/oracle.py b/pilot/connections/oracle.py index 4ce4e742a..6af8aa0a8 100644 --- a/pilot/connections/oracle.py +++ b/pilot/connections/oracle.py @@ -1,6 +1,8 @@ #!/usr/bin/env python3 # -*- coding:utf-8 -*- + class OracleConnector: """OracleConnector""" - pass \ No newline at end of file + + pass diff --git a/pilot/connections/postgres.py b/pilot/connections/postgres.py index 3e1df00ab..48a225293 100644 --- a/pilot/connections/postgres.py +++ b/pilot/connections/postgres.py @@ -2,7 +2,7 @@ # -*- coding: utf-8 -*- - class PostgresConnector: """PostgresConnector is a class which Connector to chat with LLM""" - pass \ No newline at end of file + + pass diff --git a/pilot/connections/redis.py b/pilot/connections/redis.py index ac00ade63..2502562a7 100644 --- a/pilot/connections/redis.py +++ b/pilot/connections/redis.py @@ -4,4 +4,5 @@ class RedisConnector: """RedisConnector""" - pass \ No newline at end of file + + pass diff --git a/pilot/conversation.py b/pilot/conversation.py index 0470bc720..d674e901a 100644 --- a/pilot/conversation.py +++ b/pilot/conversation.py @@ -2,31 +2,34 @@ # -*- coding:utf-8 -*- import dataclasses -from enum import auto, Enum -from typing import List, Any +from enum import Enum, auto +from typing import Any, List + from pilot.configs.config import Config CFG = Config() DB_SETTINGS = { "user": CFG.LOCAL_DB_USER, - "password": CFG.LOCAL_DB_PASSWORD, + "password": CFG.LOCAL_DB_PASSWORD, "host": CFG.LOCAL_DB_HOST, - "port": CFG.LOCAL_DB_PORT + "port": CFG.LOCAL_DB_PORT, } ROLE_USER = "USER" ROLE_ASSISTANT = "Assistant" + class SeparatorStyle(Enum): SINGLE = auto() TWO = auto() THREE = auto() FOUR = auto() -@ dataclasses.dataclass + +@dataclasses.dataclass class Conversation: - """This class keeps all conversation history. """ + """This class keeps all conversation history.""" system: str roles: List[str] @@ -67,7 +70,7 @@ def append_message(self, role, message): def to_gradio_chatbot(self): ret = [] - for i, (role, msg) in enumerate(self.messages[self.offset:]): + for i, (role, msg) in enumerate(self.messages[self.offset :]): if i % 2 == 0: ret.append([msg, None]) else: @@ -95,15 +98,14 @@ def dict(self): "offset": self.offset, "sep": self.sep, "sep2": self.sep2, - "conv_id": self.conv_id + "conv_id": self.conv_id, } def gen_sqlgen_conversation(dbname): from pilot.connections.mysql import MySQLOperator - mo = MySQLOperator( - **(DB_SETTINGS) - ) + + mo = MySQLOperator(**(DB_SETTINGS)) message = "" @@ -115,7 +117,7 @@ def gen_sqlgen_conversation(dbname): conv_one_shot = Conversation( system="A chat between a curious user and an artificial intelligence assistant, who very familiar with database related knowledge. " - "The assistant gives helpful, detailed, professional and polite answers to the user's questions. ", + "The assistant gives helpful, detailed, professional and polite answers to the user's questions. ", roles=("USER", "Assistant"), messages=( ( @@ -136,20 +138,19 @@ def gen_sqlgen_conversation(dbname): "whereas PostgreSQL is known for its robustness and reliability.\n" "5. Licensing: MySQL is licensed under the GPL (General Public License), which means that it is free and open-source software, " "whereas PostgreSQL is licensed under the PostgreSQL License, which is also free and open-source but with different terms.\n" - "Ultimately, the choice between MySQL and PostgreSQL depends on the specific needs and requirements of your application. " "Both are excellent database management systems, and choosing the right one " - "for your project requires careful consideration of your application's requirements, performance needs, and scalability." + "for your project requires careful consideration of your application's requirements, performance needs, and scalability.", ), ), offset=2, sep_style=SeparatorStyle.SINGLE, - sep="###" + sep="###", ) conv_vicuna_v1 = Conversation( system="A chat between a curious user and an artificial intelligence assistant. who very familiar with database related knowledge. " - "The assistant gives helpful, detailed, professional and polite answers to the user's questions. ", + "The assistant gives helpful, detailed, professional and polite answers to the user's questions. ", roles=("USER", "ASSISTANT"), messages=(), offset=0, @@ -160,7 +161,7 @@ def gen_sqlgen_conversation(dbname): auto_dbgpt_one_shot = Conversation( system="You are DB-GPT, an AI designed to answer questions about HackerNews by query `hackerbews` database in MySQL. " - "Your decisions must always be made independently without seeking user assistance. Play to your strengths as an LLM and pursue simple strategies with no legal complications.", + "Your decisions must always be made independently without seeking user assistance. Play to your strengths as an LLM and pursue simple strategies with no legal complications.", roles=("USER", "ASSISTANT"), messages=( ( @@ -203,7 +204,7 @@ def gen_sqlgen_conversation(dbname): } } } - """ + """, ), ( "ASSISTANT", @@ -223,8 +224,8 @@ def gen_sqlgen_conversation(dbname): } } } - """ - ) + """, + ), ), offset=0, sep_style=SeparatorStyle.SINGLE, @@ -233,7 +234,7 @@ def gen_sqlgen_conversation(dbname): auto_dbgpt_without_shot = Conversation( system="You are DB-GPT, an AI designed to answer questions about users by query `users` database in MySQL. " - "Your decisions must always be made independently without seeking user assistance. Play to your strengths as an LLM and pursue simple strategies with no legal complications.", + "Your decisions must always be made independently without seeking user assistance. Play to your strengths as an LLM and pursue simple strategies with no legal complications.", roles=("USER", "ASSISTANT"), messages=(), offset=0, @@ -259,9 +260,9 @@ def gen_sqlgen_conversation(dbname): # """ default_conversation = conv_one_shot -conversation_sql_mode ={ +conversation_sql_mode = { "auto_execute_ai_response": "直接执行结果", - "dont_execute_ai_response": "不直接执行结果" + "dont_execute_ai_response": "不直接执行结果", } conversation_types = { @@ -273,7 +274,7 @@ def gen_sqlgen_conversation(dbname): conv_templates = { "conv_one_shot": conv_one_shot, "vicuna_v1": conv_vicuna_v1, - "auto_dbgpt_one_shot": auto_dbgpt_one_shot + "auto_dbgpt_one_shot": auto_dbgpt_one_shot, } if __name__ == "__main__": diff --git a/pilot/json_utils/json_fix_general.py b/pilot/json_utils/json_fix_general.py index eecf83568..e24d02bbf 100644 --- a/pilot/json_utils/json_fix_general.py +++ b/pilot/json_utils/json_fix_general.py @@ -8,8 +8,8 @@ from typing import Optional from pilot.configs.config import Config -from pilot.logs import logger from pilot.json_utils.utilities import extract_char_position +from pilot.logs import logger CFG = Config() diff --git a/pilot/logs.py b/pilot/logs.py index b5a1fad82..52d25b5fd 100644 --- a/pilot/logs.py +++ b/pilot/logs.py @@ -84,7 +84,7 @@ def __init__(self): self.chat_plugins = [] def typewriter_log( - self, title="", title_color="", content="", speak_text=False, level=logging.INFO + self, title="", title_color="", content="", speak_text=False, level=logging.INFO ): if speak_text and self.speak_mode: say_text(f"{title}. {content}") @@ -103,26 +103,26 @@ def typewriter_log( ) def debug( - self, - message, - title="", - title_color="", + self, + message, + title="", + title_color="", ): self._log(title, title_color, message, logging.DEBUG) def info( - self, - message, - title="", - title_color="", + self, + message, + title="", + title_color="", ): self._log(title, title_color, message, logging.INFO) def warn( - self, - message, - title="", - title_color="", + self, + message, + title="", + title_color="", ): self._log(title, title_color, message, logging.WARN) @@ -130,11 +130,11 @@ def error(self, title, message=""): self._log(title, Fore.RED, message, logging.ERROR) def _log( - self, - title: str = "", - title_color: str = "", - message: str = "", - level=logging.INFO, + self, + title: str = "", + title_color: str = "", + message: str = "", + level=logging.INFO, ): if message: if isinstance(message, list): @@ -178,10 +178,12 @@ def get_log_directory(self): log_dir = os.path.join(this_files_dir_path, "../logs") return os.path.abspath(log_dir) + """ Output stream to console using simulated typing """ + class TypingConsoleHandler(logging.StreamHandler): def emit(self, record): min_typing_speed = 0.05 @@ -203,6 +205,7 @@ def emit(self, record): except Exception: self.handleError(record) + class ConsoleHandler(logging.StreamHandler): def emit(self, record) -> None: msg = self.format(record) @@ -221,10 +224,10 @@ class DbGptFormatter(logging.Formatter): def format(self, record: LogRecord) -> str: if hasattr(record, "color"): record.title_color = ( - getattr(record, "color") - + getattr(record, "title", "") - + " " - + Style.RESET_ALL + getattr(record, "color") + + getattr(record, "title", "") + + " " + + Style.RESET_ALL ) else: record.title_color = getattr(record, "title", "") @@ -248,9 +251,9 @@ def remove_color_codes(s: str) -> str: def print_assistant_thoughts( - ai_name: object, - assistant_reply_json_valid: object, - speak_mode: bool = False, + ai_name: object, + assistant_reply_json_valid: object, + speak_mode: bool = False, ) -> None: assistant_thoughts_reasoning = None assistant_thoughts_plan = None diff --git a/pilot/model/adapter.py b/pilot/model/adapter.py index be8980726..83fad3d5f 100644 --- a/pilot/model/adapter.py +++ b/pilot/model/adapter.py @@ -1,19 +1,16 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- -from typing import List from functools import cache +from typing import List -from transformers import ( - AutoTokenizer, - AutoModelForCausalLM, - AutoModel -) +from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer from pilot.configs.model_config import DEVICE + class BaseLLMAdaper: """The Base class for multi model, in our project. - We will support those model, which performance resemble ChatGPT """ + We will support those model, which performance resemble ChatGPT""" def match(self, model_path: str): return True @@ -28,6 +25,7 @@ def loader(self, model_path: str, from_pretrained_kwargs: dict): llm_model_adapters: List[BaseLLMAdaper] = [] + # Register llm models to adapters, by this we can use multi models. def register_llm_model_adapters(cls): """Register a llm model adapter.""" @@ -39,28 +37,30 @@ def get_llm_model_adapter(model_path: str) -> BaseLLMAdaper: for adapter in llm_model_adapters: if adapter.match(model_path): return adapter - + raise ValueError(f"Invalid model adapter for {model_path}") # TODO support cpu? for practise we support gpt4all or chatglm-6b-int4? + class VicunaLLMAdapater(BaseLLMAdaper): - """Vicuna Adapter """ + """Vicuna Adapter""" + def match(self, model_path: str): - return "vicuna" in model_path + return "vicuna" in model_path def loader(self, model_path: str, from_pretrained_kwagrs: dict): tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False) model = AutoModelForCausalLM.from_pretrained( - model_path, - low_cpu_mem_usage=True, - **from_pretrained_kwagrs + model_path, low_cpu_mem_usage=True, **from_pretrained_kwagrs ) return model, tokenizer + class ChatGLMAdapater(BaseLLMAdaper): """LLM Adatpter for THUDM/chatglm-6b""" + def match(self, model_path: str): return "chatglm" in model_path @@ -73,37 +73,49 @@ def loader(self, model_path: str, from_pretrained_kwargs: dict): ).float() return model, tokenizer else: - model = AutoModel.from_pretrained( - model_path, trust_remote_code=True, **from_pretrained_kwargs - ).half().cuda() + model = ( + AutoModel.from_pretrained( + model_path, trust_remote_code=True, **from_pretrained_kwargs + ) + .half() + .cuda() + ) return model, tokenizer - + + class CodeGenAdapter(BaseLLMAdaper): pass + class StarCoderAdapter(BaseLLMAdaper): pass + class T5CodeAdapter(BaseLLMAdaper): pass + class KoalaLLMAdapter(BaseLLMAdaper): - """Koala LLM Adapter which Based LLaMA """ + """Koala LLM Adapter which Based LLaMA""" + def match(self, model_path: str): return "koala" in model_path - + class RWKV4LLMAdapter(BaseLLMAdaper): - """LLM Adapter for RwKv4 """ + """LLM Adapter for RwKv4""" + def match(self, model_path: str): return "RWKV-4" in model_path - + def loader(self, model_path: str, from_pretrained_kwargs: dict): # TODO pass + class GPT4AllAdapter(BaseLLMAdaper): """A light version for someone who want practise LLM use laptop.""" + def match(self, model_path: str): return "gpt4all" in model_path @@ -112,4 +124,4 @@ def match(self, model_path: str): register_llm_model_adapters(ChatGLMAdapater) # TODO Default support vicuna, other model need to tests and Evaluate -register_llm_model_adapters(BaseLLMAdaper) \ No newline at end of file +register_llm_model_adapters(BaseLLMAdaper) diff --git a/pilot/model/base.py b/pilot/model/base.py index 8199198eb..ba8190ea3 100644 --- a/pilot/model/base.py +++ b/pilot/model/base.py @@ -1,11 +1,11 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- -from typing import List, TypedDict +from typing import TypedDict + class Message(TypedDict): - """LLM Message object containing usually like (role: content) """ + """LLM Message object containing usually like (role: content)""" role: str content: str - diff --git a/pilot/model/chatglm_llm.py b/pilot/model/chatglm_llm.py index 0f8b74efa..4d72af072 100644 --- a/pilot/model/chatglm_llm.py +++ b/pilot/model/chatglm_llm.py @@ -1,14 +1,16 @@ #!/usr/bin/env python3 # -*- coding:utf-8 -*- -import torch +import torch + +from pilot.conversation import ROLE_ASSISTANT, ROLE_USER -from pilot.conversation import ROLE_USER, ROLE_ASSISTANT @torch.inference_mode() -def chatglm_generate_stream(model, tokenizer, params, device, context_len=2048, stream_interval=2): - - """Generate text using chatglm model's chat api """ +def chatglm_generate_stream( + model, tokenizer, params, device, context_len=2048, stream_interval=2 +): + """Generate text using chatglm model's chat api""" prompt = params["prompt"] temperature = float(params.get("temperature", 1.0)) top_p = float(params.get("top_p", 1.0)) @@ -19,31 +21,38 @@ def chatglm_generate_stream(model, tokenizer, params, device, context_len=2048, "do_sample": True if temperature > 1e-5 else False, "top_p": top_p, "repetition_penalty": 1.0, - "logits_processor": None + "logits_processor": None, } if temperature > 1e-5: generate_kwargs["temperature"] = temperature # TODO, Fix this - hist = [] + hist = [] messages = prompt.split(stop) - # Add history chat to hist for model. + # Add history chat to hist for model. for i in range(1, len(messages) - 2, 2): - hist.append((messages[i].split(ROLE_USER + ":")[1], messages[i+1].split(ROLE_ASSISTANT + ":")[1])) + hist.append( + ( + messages[i].split(ROLE_USER + ":")[1], + messages[i + 1].split(ROLE_ASSISTANT + ":")[1], + ) + ) query = messages[-2].split(ROLE_USER + ":")[1] print("Query Message: ", query) output = "" i = 0 - for i, (response, new_hist) in enumerate(model.stream_chat(tokenizer, query, hist, **generate_kwargs)): + for i, (response, new_hist) in enumerate( + model.stream_chat(tokenizer, query, hist, **generate_kwargs) + ): if echo: output = query + " " + response else: output = response - + yield output - yield output \ No newline at end of file + yield output diff --git a/pilot/model/compression.py b/pilot/model/compression.py index 9c8c25d08..83b681c6f 100644 --- a/pilot/model/compression.py +++ b/pilot/model/compression.py @@ -3,14 +3,15 @@ import dataclasses import torch -from torch import Tensor import torch.nn as nn +from torch import Tensor from torch.nn import functional as F @dataclasses.dataclass class CompressionConfig: """Group-wise quantization.""" + num_bits: int group_size: int group_dim: int @@ -19,7 +20,8 @@ class CompressionConfig: default_compression_config = CompressionConfig( - num_bits=8, group_size=256, group_dim=1, symmetric=True, enabled=True) + num_bits=8, group_size=256, group_dim=1, symmetric=True, enabled=True +) class CLinear(nn.Module): @@ -40,8 +42,11 @@ def compress_module(module, target_device): for attr_str in dir(module): target_attr = getattr(module, attr_str) if type(target_attr) == torch.nn.Linear: - setattr(module, attr_str, - CLinear(target_attr.weight, target_attr.bias, target_device)) + setattr( + module, + attr_str, + CLinear(target_attr.weight, target_attr.bias, target_device), + ) for name, child in module.named_children(): compress_module(child, target_device) @@ -52,22 +57,31 @@ def compress(tensor, config): return tensor group_size, num_bits, group_dim, symmetric = ( - config.group_size, config.num_bits, config.group_dim, config.symmetric) + config.group_size, + config.num_bits, + config.group_dim, + config.symmetric, + ) assert num_bits <= 8 original_shape = tensor.shape num_groups = (original_shape[group_dim] + group_size - 1) // group_size - new_shape = (original_shape[:group_dim] + (num_groups, group_size) + - original_shape[group_dim+1:]) + new_shape = ( + original_shape[:group_dim] + + (num_groups, group_size) + + original_shape[group_dim + 1 :] + ) # Pad pad_len = (group_size - original_shape[group_dim] % group_size) % group_size if pad_len != 0: - pad_shape = original_shape[:group_dim] + (pad_len,) + original_shape[group_dim+1:] - tensor = torch.cat([ - tensor, - torch.zeros(pad_shape, dtype=tensor.dtype, device=tensor.device)], - dim=group_dim) + pad_shape = ( + original_shape[:group_dim] + (pad_len,) + original_shape[group_dim + 1 :] + ) + tensor = torch.cat( + [tensor, torch.zeros(pad_shape, dtype=tensor.dtype, device=tensor.device)], + dim=group_dim, + ) data = tensor.view(new_shape) # Quantize @@ -78,7 +92,7 @@ def compress(tensor, config): data = data.clamp_(-B, B).round_().to(torch.int8) return data, scale, original_shape else: - B = 2 ** num_bits - 1 + B = 2**num_bits - 1 mn = torch.min(data, dim=group_dim + 1, keepdim=True)[0] mx = torch.max(data, dim=group_dim + 1, keepdim=True)[0] @@ -96,7 +110,11 @@ def decompress(packed_data, config): return packed_data group_size, num_bits, group_dim, symmetric = ( - config.group_size, config.num_bits, config.group_dim, config.symmetric) + config.group_size, + config.num_bits, + config.group_dim, + config.symmetric, + ) # Dequantize if symmetric: @@ -111,9 +129,10 @@ def decompress(packed_data, config): pad_len = (group_size - original_shape[group_dim] % group_size) % group_size if pad_len: padded_original_shape = ( - original_shape[:group_dim] + - (original_shape[group_dim] + pad_len,) + - original_shape[group_dim+1:]) + original_shape[:group_dim] + + (original_shape[group_dim] + pad_len,) + + original_shape[group_dim + 1 :] + ) data = data.reshape(padded_original_shape) indices = [slice(0, x) for x in original_shape] return data[indices].contiguous() diff --git a/pilot/model/inference.py b/pilot/model/inference.py index a677c0339..042f9954b 100644 --- a/pilot/model/inference.py +++ b/pilot/model/inference.py @@ -3,11 +3,12 @@ import torch -@torch.inference_mode() -def generate_stream(model, tokenizer, params, device, - context_len=4096, stream_interval=2): - """Fork from fastchat: https://github.com/lm-sys/FastChat/blob/main/fastchat/serve/inference.py """ +@torch.inference_mode() +def generate_stream( + model, tokenizer, params, device, context_len=4096, stream_interval=2 +): + """Fork from fastchat: https://github.com/lm-sys/FastChat/blob/main/fastchat/serve/inference.py""" prompt = params["prompt"] l_prompt = len(prompt) temperature = float(params.get("temperature", 1.0)) @@ -22,17 +23,19 @@ def generate_stream(model, tokenizer, params, device, for i in range(max_new_tokens): if i == 0: - out = model( - torch.as_tensor([input_ids], device=device), use_cache=True) + out = model(torch.as_tensor([input_ids], device=device), use_cache=True) logits = out.logits past_key_values = out.past_key_values else: attention_mask = torch.ones( - 1, past_key_values[0][0].shape[-2] + 1, device=device) - out = model(input_ids=torch.as_tensor([[token]], device=device), - use_cache=True, - attention_mask=attention_mask, - past_key_values=past_key_values) + 1, past_key_values[0][0].shape[-2] + 1, device=device + ) + out = model( + input_ids=torch.as_tensor([[token]], device=device), + use_cache=True, + attention_mask=attention_mask, + past_key_values=past_key_values, + ) logits = out.logits past_key_values = out.past_key_values @@ -68,9 +71,12 @@ def generate_stream(model, tokenizer, params, device, del past_key_values + @torch.inference_mode() -def generate_output(model, tokenizer, params, device, context_len=4096, stream_interval=2): - """Fork from fastchat: https://github.com/lm-sys/FastChat/blob/main/fastchat/serve/inference.py """ +def generate_output( + model, tokenizer, params, device, context_len=4096, stream_interval=2 +): + """Fork from fastchat: https://github.com/lm-sys/FastChat/blob/main/fastchat/serve/inference.py""" prompt = params["prompt"] l_prompt = len(prompt) @@ -78,7 +84,6 @@ def generate_output(model, tokenizer, params, device, context_len=4096, stream_i max_new_tokens = int(params.get("max_new_tokens", 2048)) stop_str = params.get("stop", None) - input_ids = tokenizer(prompt).input_ids output_ids = list(input_ids) @@ -87,17 +92,19 @@ def generate_output(model, tokenizer, params, device, context_len=4096, stream_i for i in range(max_new_tokens): if i == 0: - out = model( - torch.as_tensor([input_ids], device=device), use_cache=True) + out = model(torch.as_tensor([input_ids], device=device), use_cache=True) logits = out.logits past_key_values = out.past_key_values else: attention_mask = torch.ones( - 1, past_key_values[0][0].shape[-2] + 1, device=device) - out = model(input_ids=torch.as_tensor([[token]], device=device), - use_cache=True, - attention_mask=attention_mask, - past_key_values=past_key_values) + 1, past_key_values[0][0].shape[-2] + 1, device=device + ) + out = model( + input_ids=torch.as_tensor([[token]], device=device), + use_cache=True, + attention_mask=attention_mask, + past_key_values=past_key_values, + ) logits = out.logits past_key_values = out.past_key_values @@ -120,7 +127,6 @@ def generate_output(model, tokenizer, params, device, context_len=4096, stream_i else: stopped = False - if i % stream_interval == 0 or i == max_new_tokens - 1 or stopped: output = tokenizer.decode(output_ids, skip_special_tokens=True) pos = output.rfind(stop_str, l_prompt) @@ -133,8 +139,11 @@ def generate_output(model, tokenizer, params, device, context_len=4096, stream_i break del past_key_values + @torch.inference_mode() -def generate_output_ex(model, tokenizer, params, device, context_len=2048, stream_interval=2): +def generate_output_ex( + model, tokenizer, params, device, context_len=2048, stream_interval=2 +): prompt = params["prompt"] temperature = float(params.get("temperature", 1.0)) max_new_tokens = int(params.get("max_new_tokens", 2048)) @@ -161,20 +170,20 @@ def generate_output_ex(model, tokenizer, params, device, context_len=2048, strea for i in range(max_new_tokens): if i == 0: - out = model( - torch.as_tensor([input_ids], device=device), use_cache=True) + out = model(torch.as_tensor([input_ids], device=device), use_cache=True) logits = out.logits past_key_values = out.past_key_values else: - out = model(input_ids=torch.as_tensor([[token]], device=device), - use_cache=True, - past_key_values=past_key_values) + out = model( + input_ids=torch.as_tensor([[token]], device=device), + use_cache=True, + past_key_values=past_key_values, + ) logits = out.logits past_key_values = out.past_key_values last_token_logits = logits[0][-1] - if temperature < 1e-4: token = int(torch.argmax(last_token_logits)) else: @@ -188,7 +197,6 @@ def generate_output_ex(model, tokenizer, params, device, context_len=2048, strea else: stopped = False - output = tokenizer.decode(output_ids, skip_special_tokens=True) # print("Partial output:", output) for stop_str in stop_strings: @@ -211,7 +219,7 @@ def generate_output_ex(model, tokenizer, params, device, context_len=2048, strea del past_key_values if pos != -1: return output[:pos] - return output + return output @torch.inference_mode() diff --git a/pilot/model/llm/base.py b/pilot/model/llm/base.py index 435cc0d5f..581837347 100644 --- a/pilot/model/llm/base.py +++ b/pilot/model/llm/base.py @@ -1,12 +1,12 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- -from dataclasses import dataclass, field -from typing import List, TypedDict +from dataclasses import dataclass +from typing import TypedDict class Message(TypedDict): - """Vicuna Message object containing a role and the message content """ + """Vicuna Message object containing a role and the message content""" role: str content: str @@ -18,12 +18,15 @@ class ModelInfo: Would be lovely to eventually get this directly from APIs """ + name: str max_tokens: int + @dataclass class LLMResponse: """Standard response struct for a response from a LLM model.""" + model_info = ModelInfo @@ -31,4 +34,4 @@ class LLMResponse: class ChatModelResponse(LLMResponse): """Standard response struct for a response from an LLM model.""" - content: str = None \ No newline at end of file + content: str = None diff --git a/pilot/model/llm/llm_utils.py b/pilot/model/llm/llm_utils.py index a68860ee6..e2bf631cc 100644 --- a/pilot/model/llm/llm_utils.py +++ b/pilot/model/llm/llm_utils.py @@ -2,35 +2,39 @@ # -*- coding: utf-8 -*- import abc -import time import functools -from typing import List, Optional -from pilot.model.llm.base import Message -from pilot.conversation import conv_templates, Conversation, conv_one_shot, auto_dbgpt_one_shot +import time +from typing import Optional + from pilot.configs.config import Config +from pilot.conversation import ( + Conversation, + auto_dbgpt_one_shot, + conv_one_shot, + conv_templates, +) +from pilot.model.llm.base import Message # TODO Rewrite this def retry_stream_api( - num_retries: int = 10, - backoff_base: float = 2.0, - warn_user: bool = True -): + num_retries: int = 10, backoff_base: float = 2.0, warn_user: bool = True +): """Retry an Vicuna Server call. - Args: - num_retries int: Number of retries. Defaults to 10. - backoff_base float: Base for exponential backoff. Defaults to 2. - warn_user bool: Whether to warn the user. Defaults to True. + Args: + num_retries int: Number of retries. Defaults to 10. + backoff_base float: Base for exponential backoff. Defaults to 2. + warn_user bool: Whether to warn the user. Defaults to True. """ retry_limit_msg = f"Error: Reached rate limit, passing..." - backoff_msg = (f"Error: API Bad gateway. Waiting {{backoff}} seconds...") + backoff_msg = f"Error: API Bad gateway. Waiting {{backoff}} seconds..." def _wrapper(func): @functools.wraps(func) def _wrapped(*args, **kwargs): user_warned = not warn_user - num_attempts = num_retries + 1 # +1 for the first attempt + num_attempts = num_retries + 1 # +1 for the first attempt for attempt in range(1, num_attempts + 1): try: return func(*args, **kwargs) @@ -39,10 +43,13 @@ def _wrapped(*args, **kwargs): raise backoff = backoff_base ** (attempt + 2) - time.sleep(backoff) + time.sleep(backoff) + return _wrapped + return _wrapper + # Overly simple abstraction util we create something better # simple retry mechanism when getting a rate error or a bad gateway def create_chat_competion( @@ -52,15 +59,15 @@ def create_chat_competion( max_new_tokens: Optional[int] = None, ) -> str: """Create a chat completion using the Vicuna-13b - - Args: - messages(List[Message]): The messages to send to the chat completion - model (str, optional): The model to use. Default to None. - temperature (float, optional): The temperature to use. Defaults to 0.7. - max_tokens (int, optional): The max tokens to use. Defaults to None. - - Returns: - str: The response from the chat completion + + Args: + messages(List[Message]): The messages to send to the chat completion + model (str, optional): The model to use. Default to None. + temperature (float, optional): The temperature to use. Defaults to 0.7. + max_tokens (int, optional): The max tokens to use. Defaults to None. + + Returns: + str: The response from the chat completion """ cfg = Config() if temperature is None: @@ -77,7 +84,7 @@ class ChatIO(abc.ABC): @abc.abstractmethod def prompt_for_input(self, role: str) -> str: """Prompt for input from a role.""" - + @abc.abstractmethod def prompt_for_output(self, role: str) -> str: """Prompt for output from a role.""" @@ -105,4 +112,3 @@ def stream_output(self, output_stream, skip_echo_len: int): print(" ".join(outputs[pre:]), flush=True) return " ".join(outputs) - diff --git a/pilot/model/llm/monkey_patch.py b/pilot/model/llm/monkey_patch.py index a50481281..f3656a159 100644 --- a/pilot/model/llm/monkey_patch.py +++ b/pilot/model/llm/monkey_patch.py @@ -5,8 +5,8 @@ from typing import Optional, Tuple import torch -from torch import nn import transformers +from torch import nn def rotate_half(x): @@ -116,8 +116,8 @@ def replace_llama_attn_with_non_inplace_operations(): """Avoid bugs in mps backend by not using in-place operations.""" transformers.models.llama.modeling_llama.LlamaAttention.forward = forward -import transformers +import transformers def replace_llama_attn_with_non_inplace_operations(): diff --git a/pilot/model/llm_utils.py b/pilot/model/llm_utils.py index 196246118..118d45f97 100644 --- a/pilot/model/llm_utils.py +++ b/pilot/model/llm_utils.py @@ -2,31 +2,33 @@ # -*- coding:utf-8 -*- from typing import List, Optional -from pilot.model.base import Message + from pilot.configs.config import Config +from pilot.model.base import Message from pilot.server.llmserver import generate_output + def create_chat_completion( - messages: List[Message], # type: ignore + messages: List[Message], # type: ignore model: Optional[str] = None, temperature: float = None, max_tokens: Optional[int] = None, ) -> str: - """Create a chat completion using the vicuna local model - - Args: - messages(List[Message]): The messages to send to the chat completion - model (str, optional): The model to use. Defaults to None. - temperature (float, optional): The temperature to use. Defaults to 0.7. - max_tokens (int, optional): The max tokens to use. Defaults to None - - Returns: - str: The response from chat completion + """Create a chat completion using the vicuna local model + + Args: + messages(List[Message]): The messages to send to the chat completion + model (str, optional): The model to use. Defaults to None. + temperature (float, optional): The temperature to use. Defaults to 0.7. + max_tokens (int, optional): The max tokens to use. Defaults to None + + Returns: + str: The response from chat completion """ cfg = Config() if temperature is None: temperature = cfg.temperature - + for plugin in cfg.plugins: if plugin.can_handle_chat_completion( messages=messages, diff --git a/pilot/model/loader.py b/pilot/model/loader.py index bd31bae0a..a228acbf7 100644 --- a/pilot/model/loader.py +++ b/pilot/model/loader.py @@ -1,16 +1,19 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- -import torch import sys import warnings -from pilot.singleton import Singleton from typing import Optional -from pilot.model.compression import compress_module -from pilot.model.adapter import get_llm_model_adapter -from pilot.utils import get_gpu_memory + +import torch + from pilot.configs.model_config import DEVICE +from pilot.model.adapter import get_llm_model_adapter +from pilot.model.compression import compress_module from pilot.model.llm.monkey_patch import replace_llama_attn_with_non_inplace_operations +from pilot.singleton import Singleton +from pilot.utils import get_gpu_memory + def raise_warning_for_incompatible_cpu_offloading_configuration( device: str, load_8bit: bool, cpu_offloading: bool @@ -40,27 +43,31 @@ def raise_warning_for_incompatible_cpu_offloading_configuration( class ModelLoader(metaclass=Singleton): """Model loader is a class for model load - + Args: model_path - TODO: multi model support. + TODO: multi model support. """ kwargs = {} - def __init__(self, - model_path) -> None: - + def __init__(self, model_path) -> None: self.device = DEVICE - self.model_path = model_path + self.model_path = model_path self.kwargs = { "torch_dtype": torch.float16, "device_map": "auto", } # TODO multi gpu support - def loader(self, num_gpus, load_8bit=False, debug=False, cpu_offloading=False, max_gpu_memory: Optional[str]=None): - + def loader( + self, + num_gpus, + load_8bit=False, + debug=False, + cpu_offloading=False, + max_gpu_memory: Optional[str] = None, + ): if self.device == "cpu": kwargs = {"torch_dtype": torch.float32} @@ -72,7 +79,7 @@ def loader(self, num_gpus, load_8bit=False, debug=False, cpu_offloading=False, m kwargs["device_map"] = "auto" if max_gpu_memory is None: kwargs["device_map"] = "sequential" - + available_gpu_memory = get_gpu_memory(num_gpus) kwargs["max_memory"] = { i: str(int(available_gpu_memory[i] * 0.85)) + "GiB" @@ -99,13 +106,14 @@ def loader(self, num_gpus, load_8bit=False, debug=False, cpu_offloading=False, m "8-bit quantization is not supported for multi-gpu inference" ) else: - compress_module(model, self.device) + compress_module(model, self.device) - if (self.device == "cuda" and num_gpus == 1 and not cpu_offloading) or self.device == "mps": + if ( + self.device == "cuda" and num_gpus == 1 and not cpu_offloading + ) or self.device == "mps": model.to(self.device) if debug: print(model) return model, tokenizer - diff --git a/pilot/model/vicuna_llm.py b/pilot/model/vicuna_llm.py index 63788a619..b38249a98 100644 --- a/pilot/model/vicuna_llm.py +++ b/pilot/model/vicuna_llm.py @@ -2,25 +2,34 @@ # -*- coding:utf-8 -*- import json -import requests +from typing import Any, List, Mapping, Optional from urllib.parse import urljoin + +import requests from langchain.embeddings.base import Embeddings -from pydantic import BaseModel -from typing import Any, Mapping, Optional, List from langchain.llms.base import LLM +from pydantic import BaseModel + from pilot.configs.config import Config CFG = Config() -class VicunaLLM(LLM): + +class VicunaLLM(LLM): vicuna_generate_path = "generate_stream" - def _call(self, prompt: str, temperature: float, max_new_tokens: int, stop: Optional[List[str]] = None) -> str: + def _call( + self, + prompt: str, + temperature: float, + max_new_tokens: int, + stop: Optional[List[str]] = None, + ) -> str: params = { "prompt": prompt, "temperature": temperature, "max_new_tokens": max_new_tokens, - "stop": stop + "stop": stop, } response = requests.post( url=urljoin(CFG.MODEL_SERVER, self.vicuna_generate_path), @@ -41,10 +50,9 @@ def _llm_type(self) -> str: def _identifying_params(self) -> Mapping[str, Any]: return {} - + class VicunaEmbeddingLLM(BaseModel, Embeddings): - vicuna_embedding_path = "embedding" def _call(self, prompt: str) -> str: @@ -53,15 +61,13 @@ def _call(self, prompt: str) -> str: response = requests.post( url=urljoin(CFG.MODEL_SERVER, self.vicuna_embedding_path), - json={ - "prompt": p - } + json={"prompt": p}, ) response.raise_for_status() return response.json()["response"] def embed_documents(self, texts: List[str]) -> List[List[float]]: - """ Call out to Vicuna's server embedding endpoint for embedding search docs. + """Call out to Vicuna's server embedding endpoint for embedding search docs. Args: texts: The list of text to embed @@ -73,17 +79,15 @@ def embed_documents(self, texts: List[str]) -> List[List[float]]: for text in texts: response = self.embed_query(text) results.append(response) - return results - + return results def embed_query(self, text: str) -> List[float]: - """ Call out to Vicuna's server embedding endpoint for embedding query text. - - Args: + """Call out to Vicuna's server embedding endpoint for embedding query text. + + Args: text: The text to embed. Returns: Embedding for the text """ embedding = self._call(text) return embedding - diff --git a/pilot/plugins.py b/pilot/plugins.py index 28f33a5a4..f1cd1c962 100644 --- a/pilot/plugins.py +++ b/pilot/plugins.py @@ -1,11 +1,10 @@ """加载组件""" -import importlib import json import os import zipfile from pathlib import Path -from typing import List, Optional, Tuple +from typing import List from urllib.parse import urlparse from zipimport import zipimporter @@ -15,6 +14,7 @@ from pilot.configs.config import Config from pilot.logs import logger + def inspect_zip_for_modules(zip_path: str, debug: bool = False) -> list[str]: """ Loader zip plugin file. Native support Auto_gpt_plugin @@ -36,6 +36,7 @@ def inspect_zip_for_modules(zip_path: str, debug: bool = False) -> list[str]: logger.debug(f"Module '__init__.py' not found in the zipfile @ {zip_path}.") return result + def write_dict_to_json_file(data: dict, file_path: str) -> None: """ Write a dictionary to a JSON file. @@ -46,6 +47,7 @@ def write_dict_to_json_file(data: dict, file_path: str) -> None: with open(file_path, "w") as file: json.dump(data, file, indent=4) + def create_directory_if_not_exists(directory_path: str) -> bool: """ Create a directory if it does not exist. @@ -66,6 +68,7 @@ def create_directory_if_not_exists(directory_path: str) -> bool: logger.info(f"Directory {directory_path} already exists") return True + def scan_plugins(cfg: Config, debug: bool = False) -> List[AutoGPTPluginTemplate]: """Scan the plugins directory for plugins and loads them. diff --git a/pilot/prompts/auto_mode_prompt.py b/pilot/prompts/auto_mode_prompt.py index 86f707783..b47d24a76 100644 --- a/pilot/prompts/auto_mode_prompt.py +++ b/pilot/prompts/auto_mode_prompt.py @@ -1,19 +1,21 @@ -from pilot.prompts.generator import PromptGenerator -from typing import Any, Optional, Type -import os import platform -from pathlib import Path +from typing import Optional import distro import yaml + from pilot.configs.config import Config -from pilot.prompts.prompt import build_default_prompt_generator, DEFAULT_PROMPT_OHTER, DEFAULT_TRIGGERING_PROMPT +from pilot.prompts.generator import PromptGenerator +from pilot.prompts.prompt import ( + DEFAULT_PROMPT_OHTER, + DEFAULT_TRIGGERING_PROMPT, + build_default_prompt_generator, +) class AutoModePrompt: - """ + """ """ - """ def __init__( self, ai_goals: list | None = None, @@ -36,23 +38,21 @@ def __init__( self.command_registry = None def construct_follow_up_prompt( - self, - user_input:[str], - last_auto_return: str = None, - prompt_generator: Optional[PromptGenerator] = None - )-> str: + self, + user_input: [str], + last_auto_return: str = None, + prompt_generator: Optional[PromptGenerator] = None, + ) -> str: """ - Build complete prompt information based on subsequent dialogue information entered by the user - Args: - self: - prompt_generator: + Build complete prompt information based on subsequent dialogue information entered by the user + Args: + self: + prompt_generator: - Returns: + Returns: - """ - prompt_start = ( - DEFAULT_PROMPT_OHTER - ) + """ + prompt_start = DEFAULT_PROMPT_OHTER if prompt_generator is None: prompt_generator = build_default_prompt_generator() prompt_generator.goals = user_input @@ -64,12 +64,13 @@ def construct_follow_up_prompt( continue prompt_generator = plugin.post_prompt(prompt_generator) - full_prompt = f"{prompt_start}\n\nGOALS:\n\n" - if not self.ai_goals : + if not self.ai_goals: self.ai_goals = user_input for i, goal in enumerate(self.ai_goals): - full_prompt += f"{i+1}.According to the provided Schema information, {goal}\n" + full_prompt += ( + f"{i+1}.According to the provided Schema information, {goal}\n" + ) # if last_auto_return == None: # full_prompt += f"{cfg.last_plugin_return}\n\n" # else: @@ -82,10 +83,10 @@ def construct_follow_up_prompt( return full_prompt def construct_first_prompt( - self, - fisrt_message: [str]=[], - db_schemes: str=None, - prompt_generator: Optional[PromptGenerator] = None + self, + fisrt_message: [str] = [], + db_schemes: str = None, + prompt_generator: Optional[PromptGenerator] = None, ) -> str: """ Build complete prompt information based on the initial dialogue information entered by the user @@ -125,16 +126,18 @@ def construct_first_prompt( # Construct full prompt full_prompt = f"{prompt_start}\n\nGOALS:\n\n" - if not self.ai_goals : + if not self.ai_goals: self.ai_goals = fisrt_message for i, goal in enumerate(self.ai_goals): - full_prompt += f"{i+1}.According to the provided Schema information,{goal}\n" - if db_schemes: - full_prompt += f"\nSchema:\n\n" + full_prompt += ( + f"{i+1}.According to the provided Schema information,{goal}\n" + ) + if db_schemes: + full_prompt += f"\nSchema:\n\n" full_prompt += f"{db_schemes}" # if self.api_budget > 0.0: # full_prompt += f"\nIt takes money to let you run. Your API budget is ${self.api_budget:.3f}" self.prompt_generator = prompt_generator full_prompt += f"\n\n{prompt_generator.generate_prompt_string()}" - return full_prompt \ No newline at end of file + return full_prompt diff --git a/pilot/prompts/generator.py b/pilot/prompts/generator.py index fc45f9512..c470ff5a5 100644 --- a/pilot/prompts/generator.py +++ b/pilot/prompts/generator.py @@ -149,7 +149,7 @@ def generate_prompt_string(self) -> str: f"Resources:\n{self._generate_numbered_list(self.resources)}\n\n" "Performance Evaluation:\n" f"{self._generate_numbered_list(self.performance_evaluation)}\n\n" - "You should only respond in JSON format as described below and ensure the" + "You should only respond in JSON format as described below and ensure the" "response can be parsed by Python json.loads \nResponse" f" Format: \n{formatted_response_format}" ) diff --git a/pilot/prompts/prompt.py b/pilot/prompts/prompt.py index 8d050adf4..d46b69ad5 100644 --- a/pilot/prompts/prompt.py +++ b/pilot/prompts/prompt.py @@ -1,17 +1,14 @@ - from pilot.configs.config import Config from pilot.prompts.generator import PromptGenerator - CFG = Config() DEFAULT_TRIGGERING_PROMPT = ( "Determine which next command to use, and respond using the format specified above" ) -DEFAULT_PROMPT_OHTER = ( - "Previous response was excellent. Please response according to the requirements based on the new goal" -) +DEFAULT_PROMPT_OHTER = "Previous response was excellent. Please response according to the requirements based on the new goal" + def build_default_prompt_generator() -> PromptGenerator: """ @@ -36,17 +33,15 @@ def build_default_prompt_generator() -> PromptGenerator: ) # prompt_generator.add_constraint("No user assistance") - prompt_generator.add_constraint( - 'Only output one correct JSON response at a time' - ) + prompt_generator.add_constraint("Only output one correct JSON response at a time") prompt_generator.add_constraint( 'Exclusively use the commands listed in double quotes e.g. "command name"' ) prompt_generator.add_constraint( - 'If there is SQL in the args parameter, ensure to use the database and table definitions in Schema, and ensure that the fields and table names are in the definition' + "If there is SQL in the args parameter, ensure to use the database and table definitions in Schema, and ensure that the fields and table names are in the definition" ) prompt_generator.add_constraint( - 'The generated command args need to comply with the definition of the command' + "The generated command args need to comply with the definition of the command" ) # Add resources to the PromptGenerator object diff --git a/pilot/pturning/lora/finetune.py b/pilot/pturning/lora/finetune.py index 91ec07d0a..c661e4405 100644 --- a/pilot/pturning/lora/finetune.py +++ b/pilot/pturning/lora/finetune.py @@ -1,26 +1,24 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- -import os import json -import transformers -from transformers import LlamaTokenizer, LlamaForCausalLM +import os -from typing import List +import pandas as pd +import torch +import transformers +from datasets import load_dataset from peft import ( LoraConfig, get_peft_model, get_peft_model_state_dict, prepare_model_for_int8_training, ) +from transformers import LlamaForCausalLM, LlamaTokenizer -import torch -from datasets import load_dataset -import pandas as pd from pilot.configs.config import Config - - from pilot.configs.model_config import DATA_DIR, LLM_MODEL_CONFIG + device = "cuda" if torch.cuda.is_available() else "cpu" CUTOFF_LEN = 50 @@ -28,6 +26,7 @@ CFG = Config() + def sentiment_score_to_name(score: float): if score > 0: return "Positive" @@ -40,16 +39,18 @@ def sentiment_score_to_name(score: float): { "instruction": "Detect the sentiment of the tweet.", "input": row_dict["Tweet"], - "output": sentiment_score_to_name(row_dict["New_Sentiment_State"]) - } + "output": sentiment_score_to_name(row_dict["New_Sentiment_State"]), + } for row_dict in df.to_dict(orient="records") ] with open(os.path.join(DATA_DIR, "alpaca-bitcoin-sentiment-dataset.json"), "w") as f: - json.dump(dataset_data, f) + json.dump(dataset_data, f) -data = load_dataset("json", data_files=os.path.join(DATA_DIR, "alpaca-bitcoin-sentiment-dataset.json")) +data = load_dataset( + "json", data_files=os.path.join(DATA_DIR, "alpaca-bitcoin-sentiment-dataset.json") +) print(data["train"]) BASE_MODEL = LLM_MODEL_CONFIG[CFG.LLM_MODEL] @@ -57,13 +58,14 @@ def sentiment_score_to_name(score: float): BASE_MODEL, torch_dtype=torch.float16, device_map="auto", - offload_folder=os.path.join(DATA_DIR, "vicuna-lora") -) + offload_folder=os.path.join(DATA_DIR, "vicuna-lora"), +) tokenizer = LlamaTokenizer.from_pretrained(BASE_MODEL) -tokenizer.pad_token_id = (0) +tokenizer.pad_token_id = 0 tokenizer.padding_side = "left" + def generate_prompt(data_point): return f"""Blow is an instruction that describes a task, paired with an input that provide future context. Write a response that appropriately completes the request. #noqa: @@ -76,6 +78,7 @@ def generate_prompt(data_point): {data_point["output"]} """ + def tokenize(prompt, add_eos_token=True): result = tokenizer( prompt, @@ -85,30 +88,29 @@ def tokenize(prompt, add_eos_token=True): return_tensors=None, ) - if (result["input_ids"][-1] != tokenizer.eos_token_id and len(result["input_ids"]) < CUTOFF_LEN and add_eos_token): + if ( + result["input_ids"][-1] != tokenizer.eos_token_id + and len(result["input_ids"]) < CUTOFF_LEN + and add_eos_token + ): result["input_ids"].append(tokenizer.eos_token_id) result["attention_mask"].append(1) result["labels"] = result["input_ids"].copy() return result + def generate_and_tokenize_prompt(data_point): full_prompt = generate_prompt(data_point) tokenized_full_prompt = tokenize(full_prompt) return tokenized_full_prompt -train_val = data["train"].train_test_split( - test_size=200, shuffle=True, seed=42 -) +train_val = data["train"].train_test_split(test_size=200, shuffle=True, seed=42) -train_data = ( - train_val["train"].map(generate_and_tokenize_prompt) -) +train_data = train_val["train"].map(generate_and_tokenize_prompt) -val_data = ( - train_val["test"].map(generate_and_tokenize_prompt) -) +val_data = train_val["test"].map(generate_and_tokenize_prompt) # Training LORA_R = 8 @@ -129,7 +131,7 @@ def generate_and_tokenize_prompt(data_point): # We can now prepare model for training model = prepare_model_for_int8_training(model) config = LoraConfig( - r = LORA_R, + r=LORA_R, lora_alpha=LORA_ALPHA, target_modules=LORA_TARGET_MODULES, lora_dropout=LORA_DROPOUT, @@ -156,7 +158,7 @@ def generate_and_tokenize_prompt(data_point): output_dir=OUTPUT_DIR, save_total_limit=3, load_best_model_at_end=True, - report_to="tensorboard" + report_to="tensorboard", ) data_collector = transformers.DataCollatorForSeq2Seq( @@ -168,15 +170,13 @@ def generate_and_tokenize_prompt(data_point): train_dataset=train_data, eval_dataset=val_data, args=training_arguments, - data_collector=data_collector + data_collector=data_collector, ) model.config.use_cache = False old_state_dict = model.state_dict model.state_dict = ( - lambda self, *_, **__: get_peft_model_state_dict( - self, old_state_dict() - ) + lambda self, *_, **__: get_peft_model_state_dict(self, old_state_dict()) ).__get__(model, type(model)) trainer.train() diff --git a/pilot/server/chat_adapter.py b/pilot/server/chat_adapter.py index 805cacb3d..b7e102be3 100644 --- a/pilot/server/chat_adapter.py +++ b/pilot/server/chat_adapter.py @@ -1,10 +1,12 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- -from typing import List from functools import cache +from typing import List + from pilot.model.inference import generate_stream + class BaseChatAdpter: """The Base class for chat with llm models. it will match the model, and fetch output from model""" @@ -15,7 +17,7 @@ def match(self, model_path: str): def get_generate_stream_func(self): """Return the generate stream handler func""" pass - + llm_model_chat_adapters: List[BaseChatAdpter] = [] @@ -31,13 +33,14 @@ def get_llm_chat_adapter(model_path: str) -> BaseChatAdpter: for adapter in llm_model_chat_adapters: if adapter.match(model_path): return adapter - + raise ValueError(f"Invalid model for chat adapter {model_path}") class VicunaChatAdapter(BaseChatAdpter): - """ Model chat Adapter for vicuna""" + """Model chat Adapter for vicuna""" + def match(self, model_path: str): return "vicuna" in model_path @@ -46,37 +49,42 @@ def get_generate_stream_func(self): class ChatGLMChatAdapter(BaseChatAdpter): - """ Model chat Adapter for ChatGLM""" + """Model chat Adapter for ChatGLM""" + def match(self, model_path: str): return "chatglm" in model_path def get_generate_stream_func(self): from pilot.model.chatglm_llm import chatglm_generate_stream + return chatglm_generate_stream class CodeT5ChatAdapter(BaseChatAdpter): - """ Model chat adapter for CodeT5 """ + """Model chat adapter for CodeT5""" + def match(self, model_path: str): return "codet5" in model_path - + def get_generate_stream_func(self): # TODO pass + class CodeGenChatAdapter(BaseChatAdpter): - - """ Model chat adapter for CodeGen """ + + """Model chat adapter for CodeGen""" + def match(self, model_path: str): return "codegen" in model_path - + def get_generate_stream_func(self): - # TODO + # TODO pass register_llm_model_chat_adapter(VicunaChatAdapter) register_llm_model_chat_adapter(ChatGLMChatAdapter) -register_llm_model_chat_adapter(BaseChatAdpter) \ No newline at end of file +register_llm_model_chat_adapter(BaseChatAdpter) diff --git a/pilot/server/gradio_css.py b/pilot/server/gradio_css.py index 97706df3f..b0a3892ae 100644 --- a/pilot/server/gradio_css.py +++ b/pilot/server/gradio_css.py @@ -1,8 +1,7 @@ #!/usr/bin/env python3 # -*- coding:utf-8 -*- -code_highlight_css = ( -""" +code_highlight_css = """ #chatbot .hll { background-color: #ffffcc } #chatbot .c { color: #408080; font-style: italic } #chatbot .err { border: 1px solid #FF0000 } @@ -71,6 +70,5 @@ #chatbot .vi { color: #19177C } #chatbot .vm { color: #19177C } #chatbot .il { color: #666666 } -""") -#.highlight { background: #f8f8f8; } - +""" +# .highlight { background: #f8f8f8; } diff --git a/pilot/server/gradio_patch.py b/pilot/server/gradio_patch.py index a915760eb..ca3974cbd 100644 --- a/pilot/server/gradio_patch.py +++ b/pilot/server/gradio_patch.py @@ -49,7 +49,7 @@ def __init__( warnings.warn( "The 'color_map' parameter has been deprecated.", ) - #self.md = utils.get_markdown_parser() + # self.md = utils.get_markdown_parser() self.md = Markdown(extras=["fenced-code-blocks", "tables", "break-on-newline"]) self.select: EventListenerMethod """ @@ -112,7 +112,7 @@ def _process_chat_messages( ): # This happens for previously processed messages return chat_message elif isinstance(chat_message, str): - #return self.md.render(chat_message) + # return self.md.render(chat_message) return str(self.md.convert(chat_message)) else: raise ValueError(f"Invalid message for Chatbot component: {chat_message}") @@ -141,9 +141,10 @@ def postprocess( ), f"Expected a list of lists of length 2 or list of tuples of length 2. Received: {message_pair}" processed_messages.append( ( - #self._process_chat_messages(message_pair[0]), - '
' + - message_pair[0] + "", + # self._process_chat_messages(message_pair[0]), + '
' + + message_pair[0] + + "", self._process_chat_messages(message_pair[1]), ) ) @@ -163,5 +164,3 @@ def style(self, height: int | None = None, **kwargs): **kwargs, ) return self - - diff --git a/pilot/server/llmserver.py b/pilot/server/llmserver.py index bc227d518..ac6b7cac9 100644 --- a/pilot/server/llmserver.py +++ b/pilot/server/llmserver.py @@ -1,13 +1,13 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- -import os -import uvicorn import asyncio import json +import os import sys -from typing import Optional, List -from fastapi import FastAPI, Request, BackgroundTasks + +import uvicorn +from fastapi import BackgroundTasks, FastAPI, Request from fastapi.responses import StreamingResponse from pydantic import BaseModel @@ -17,28 +17,26 @@ ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) sys.path.append(ROOT_PATH) -from pilot.model.inference import generate_stream -from pilot.model.inference import generate_output, get_embeddings - -from pilot.model.loader import ModelLoader +from pilot.configs.config import Config from pilot.configs.model_config import * -from pilot.configs.config import Config +from pilot.model.inference import generate_output, generate_stream, get_embeddings +from pilot.model.loader import ModelLoader from pilot.server.chat_adapter import get_llm_chat_adapter - CFG = Config() -class ModelWorker: +class ModelWorker: def __init__(self, model_path, model_name, device, num_gpus=1): - if model_path.endswith("/"): model_path = model_path[:-1] self.model_name = model_name or model_path.split("/")[-1] self.device = device self.ml = ModelLoader(model_path=model_path) - self.model, self.tokenizer = self.ml.loader(num_gpus, load_8bit=ISLOAD_8BIT, debug=ISDEBUG) + self.model, self.tokenizer = self.ml.loader( + num_gpus, load_8bit=ISLOAD_8BIT, debug=ISDEBUG + ) if hasattr(self.model.config, "max_sequence_length"): self.context_len = self.model.config.max_sequence_length @@ -47,24 +45,28 @@ def __init__(self, model_path, model_name, device, num_gpus=1): else: self.context_len = 2048 - + self.llm_chat_adapter = get_llm_chat_adapter(model_path) - self.generate_stream_func = self.llm_chat_adapter.get_generate_stream_func() + self.generate_stream_func = self.llm_chat_adapter.get_generate_stream_func() def get_queue_length(self): - if model_semaphore is None or model_semaphore._value is None or model_semaphore._waiters is None: + if ( + model_semaphore is None + or model_semaphore._value is None + or model_semaphore._waiters is None + ): return 0 else: - CFG.LIMIT_MODEL_CONCURRENCY - model_semaphore._value + len(model_semaphore._waiters) + ( + CFG.LIMIT_MODEL_CONCURRENCY + - model_semaphore._value + + len(model_semaphore._waiters) + ) def generate_stream_gate(self, params): try: for output in self.generate_stream_func( - self.model, - self.tokenizer, - params, - DEVICE, - CFG.MAX_POSITION_EMBEDDINGS + self.model, self.tokenizer, params, DEVICE, CFG.MAX_POSITION_EMBEDDINGS ): print("output: ", output) ret = { @@ -74,17 +76,16 @@ def generate_stream_gate(self, params): yield json.dumps(ret).encode() + b"\0" except torch.cuda.CudaError: - ret = { - "text": "**GPU OutOfMemory, Please Refresh.**", - "error_code": 0 - } + ret = {"text": "**GPU OutOfMemory, Please Refresh.**", "error_code": 0} yield json.dumps(ret).encode() + b"\0" def get_embeddings(self, prompt): return get_embeddings(self.model, self.tokenizer, prompt) + app = FastAPI() + class PromptRequest(BaseModel): prompt: str temperature: float @@ -92,6 +93,7 @@ class PromptRequest(BaseModel): model: str stop: str = None + class StreamRequest(BaseModel): model: str prompt: str @@ -99,9 +101,11 @@ class StreamRequest(BaseModel): max_new_tokens: int stop: str + class EmbeddingRequest(BaseModel): prompt: str + def release_model_semaphore(): model_semaphore.release() @@ -114,23 +118,24 @@ async def api_generate_stream(request: Request): if model_semaphore is None: model_semaphore = asyncio.Semaphore(CFG.LIMIT_MODEL_CONCURRENCY) - await model_semaphore.acquire() + await model_semaphore.acquire() generator = worker.generate_stream_gate(params) background_tasks = BackgroundTasks() background_tasks.add_task(release_model_semaphore) return StreamingResponse(generator, background=background_tasks) + @app.post("/generate") def generate(prompt_request: PromptRequest): params = { "prompt": prompt_request.prompt, "temperature": prompt_request.temperature, "max_new_tokens": prompt_request.max_new_tokens, - "stop": prompt_request.stop + "stop": prompt_request.stop, } - response = [] + response = [] rsp_str = "" output = worker.generate_stream_gate(params) for rsp in output: @@ -140,7 +145,7 @@ def generate(prompt_request: PromptRequest): response.append(rsp_str) return {"response": rsp_str} - + @app.post("/embedding") def embeddings(prompt_request: EmbeddingRequest): @@ -151,16 +156,11 @@ def embeddings(prompt_request: EmbeddingRequest): if __name__ == "__main__": - model_path = LLM_MODEL_CONFIG[CFG.LLM_MODEL] print(model_path, DEVICE) - - + worker = ModelWorker( - model_path=model_path, - model_name=CFG.LLM_MODEL, - device=DEVICE, - num_gpus=1 + model_path=model_path, model_name=CFG.LLM_MODEL, device=DEVICE, num_gpus=1 ) - uvicorn.run(app, host="0.0.0.0", port=CFG.MODEL_PORT, log_level="info") \ No newline at end of file + uvicorn.run(app, host="0.0.0.0", port=CFG.MODEL_PORT, log_level="info") diff --git a/pilot/server/vectordb_qa.py b/pilot/server/vectordb_qa.py index 71a9b881d..6bf0b4688 100644 --- a/pilot/server/vectordb_qa.py +++ b/pilot/server/vectordb_qa.py @@ -1,29 +1,30 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- -from pilot.vector_store.file_loader import KnownLedge2Vector from langchain.prompts import PromptTemplate -from pilot.conversation import conv_qa_prompt_template + from pilot.configs.model_config import VECTOR_SEARCH_TOP_K +from pilot.conversation import conv_qa_prompt_template from pilot.model.vicuna_llm import VicunaLLM +from pilot.vector_store.file_loader import KnownLedge2Vector -class KnownLedgeBaseQA: +class KnownLedgeBaseQA: def __init__(self) -> None: k2v = KnownLedge2Vector() self.vector_store = k2v.init_vector_store() self.llm = VicunaLLM() - + def get_similar_answer(self, query): - prompt = PromptTemplate( - template=conv_qa_prompt_template, - input_variables=["context", "question"] + template=conv_qa_prompt_template, input_variables=["context", "question"] ) - retriever = self.vector_store.as_retriever(search_kwargs={"k": VECTOR_SEARCH_TOP_K}) + retriever = self.vector_store.as_retriever( + search_kwargs={"k": VECTOR_SEARCH_TOP_K} + ) docs = retriever.get_relevant_documents(query=query) - context = [d.page_content for d in docs] + context = [d.page_content for d in docs] result = prompt.format(context="\n".join(context), question=query) return result diff --git a/pilot/server/webserver.py b/pilot/server/webserver.py index 1ac32ab26..15d360ec7 100644 --- a/pilot/server/webserver.py +++ b/pilot/server/webserver.py @@ -2,58 +2,55 @@ # -*- coding: utf-8 -*- import argparse +import datetime +import json import os import shutil -import uuid -import json import sys import time -import gradio as gr -import datetime -import requests +import uuid from urllib.parse import urljoin +import gradio as gr +import requests from langchain import PromptTemplate - ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) sys.path.append(ROOT_PATH) -from pilot.configs.model_config import KNOWLEDGE_UPLOAD_ROOT_PATH, LLM_MODEL_CONFIG, VECTOR_SEARCH_TOP_K -from pilot.server.vectordb_qa import KnownLedgeBaseQA -from pilot.connections.mysql import MySQLOperator -from pilot.source_embedding.knowledge_embedding import KnowledgeEmbedding -from pilot.vector_store.extract_tovec import get_vector_storelist, load_knownledge_from_doc, knownledge_tovec_st - -from pilot.configs.model_config import LOGDIR, DATASETS_DIR - -from pilot.plugins import scan_plugins -from pilot.configs.config import Config +from pilot.commands.command import execute_ai_response_json from pilot.commands.command_mange import CommandRegistry -from pilot.prompts.auto_mode_prompt import AutoModePrompt -from pilot.prompts.generator import PromptGenerator - from pilot.commands.exception_not_commands import NotCommands - - - +from pilot.configs.config import Config +from pilot.configs.model_config import ( + DATASETS_DIR, + KNOWLEDGE_UPLOAD_ROOT_PATH, + LLM_MODEL_CONFIG, + LOGDIR, + VECTOR_SEARCH_TOP_K, +) +from pilot.connections.mysql import MySQLOperator from pilot.conversation import ( - default_conversation, + SeparatorStyle, + conv_qa_prompt_template, conv_templates, - conversation_types, conversation_sql_mode, - SeparatorStyle, conv_qa_prompt_template -) - -from pilot.utils import ( - build_logger, - server_error_msg, + conversation_types, + default_conversation, ) - +from pilot.plugins import scan_plugins +from pilot.prompts.auto_mode_prompt import AutoModePrompt +from pilot.prompts.generator import PromptGenerator from pilot.server.gradio_css import code_highlight_css from pilot.server.gradio_patch import Chatbot as grChatbot - -from pilot.commands.command import execute_ai_response_json +from pilot.server.vectordb_qa import KnownLedgeBaseQA +from pilot.source_embedding.knowledge_embedding import KnowledgeEmbedding +from pilot.utils import build_logger, server_error_msg +from pilot.vector_store.extract_tovec import ( + get_vector_storelist, + knownledge_tovec_st, + load_knownledge_from_doc, +) logger = build_logger("webserver", LOGDIR + "webserver.log") headers = {"User-Agent": "dbgpt Client"} @@ -70,19 +67,19 @@ vector_store_client = None vector_store_name = {"vs_name": ""} -priority = { - "vicuna-13b": "aaa" -} +priority = {"vicuna-13b": "aaa"} # 加载插件 -CFG= Config() +CFG = Config() DB_SETTINGS = { "user": CFG.LOCAL_DB_USER, - "password": CFG.LOCAL_DB_PASSWORD, + "password": CFG.LOCAL_DB_PASSWORD, "host": CFG.LOCAL_DB_HOST, - "port": CFG.LOCAL_DB_PORT + "port": CFG.LOCAL_DB_PORT, } + + def get_simlar(q): docsearch = knownledge_tovec_st(os.path.join(DATASETS_DIR, "plan.md")) docs = docsearch.similarity_search_with_score(q, k=1) @@ -92,9 +89,7 @@ def get_simlar(q): def gen_sqlgen_conversation(dbname): - mo = MySQLOperator( - **DB_SETTINGS - ) + mo = MySQLOperator(**DB_SETTINGS) message = "" @@ -132,13 +127,15 @@ def load_demo(url_params, request: gr.Request): gr.Dropdown.update(choices=dbs) state = default_conversation.copy() - return (state, - dropdown_update, - gr.Chatbot.update(visible=True), - gr.Textbox.update(visible=True), - gr.Button.update(visible=True), - gr.Row.update(visible=True), - gr.Accordion.update(visible=True)) + return ( + state, + dropdown_update, + gr.Chatbot.update(visible=True), + gr.Textbox.update(visible=True), + gr.Button.update(visible=True), + gr.Row.update(visible=True), + gr.Accordion.update(visible=True), + ) def get_conv_log_filename(): @@ -185,7 +182,9 @@ def post_process_code(code): return code -def http_bot(state, mode, sql_mode, db_selector, temperature, max_new_tokens, request: gr.Request): +def http_bot( + state, mode, sql_mode, db_selector, temperature, max_new_tokens, request: gr.Request +): if sql_mode == conversation_sql_mode["auto_execute_ai_response"]: print("AUTO DB-GPT模式.") if sql_mode == conversation_sql_mode["dont_execute_ai_response"]: @@ -212,12 +211,13 @@ def http_bot(state, mode, sql_mode, db_selector, temperature, max_new_tokens, re # 第一轮对话需要加入提示Prompt if sql_mode == conversation_sql_mode["auto_execute_ai_response"]: # autogpt模式的第一轮对话需要 构建专属prompt - system_prompt = auto_prompt.construct_first_prompt(fisrt_message=[query], - db_schemes=gen_sqlgen_conversation(dbname)) + system_prompt = auto_prompt.construct_first_prompt( + fisrt_message=[query], db_schemes=gen_sqlgen_conversation(dbname) + ) logger.info("[TEST]:" + system_prompt) template_name = "auto_dbgpt_one_shot" new_state = conv_templates[template_name].copy() - new_state.append_message(role='USER', message=system_prompt) + new_state.append_message(role="USER", message=system_prompt) # new_state.append_message(new_state.roles[0], query) new_state.append_message(new_state.roles[1], None) else: @@ -226,7 +226,9 @@ def http_bot(state, mode, sql_mode, db_selector, temperature, max_new_tokens, re # prompt 中添加上下文提示, 根据已有知识对话, 上下文提示是否也应该放在第一轮, 还是每一轮都添加上下文? # 如果用户侧的问题跨度很大, 应该每一轮都加提示。 if db_selector: - new_state.append_message(new_state.roles[0], gen_sqlgen_conversation(dbname) + query) + new_state.append_message( + new_state.roles[0], gen_sqlgen_conversation(dbname) + query + ) new_state.append_message(new_state.roles[1], None) else: new_state.append_message(new_state.roles[0], query) @@ -244,7 +246,9 @@ def http_bot(state, mode, sql_mode, db_selector, temperature, max_new_tokens, re # prompt 中添加上下文提示, 根据已有知识对话, 上下文提示是否也应该放在第一轮, 还是每一轮都添加上下文? # 如果用户侧的问题跨度很大, 应该每一轮都加提示。 if db_selector: - new_state.append_message(new_state.roles[0], gen_sqlgen_conversation(dbname) + query) + new_state.append_message( + new_state.roles[0], gen_sqlgen_conversation(dbname) + query + ) new_state.append_message(new_state.roles[1], None) else: new_state.append_message(new_state.roles[0], query) @@ -268,17 +272,22 @@ def http_bot(state, mode, sql_mode, db_selector, temperature, max_new_tokens, re if mode == conversation_types["custome"] and not db_selector: print("vector store name: ", vector_store_name["vs_name"]) - vector_store_config = {"vector_store_name": vector_store_name["vs_name"], "text_field": "content", - "vector_store_path": KNOWLEDGE_UPLOAD_ROOT_PATH} - knowledge_embedding_client = KnowledgeEmbedding(file_path="", model_name=LLM_MODEL_CONFIG["text2vec"], - local_persist=False, - vector_store_config=vector_store_config) + vector_store_config = { + "vector_store_name": vector_store_name["vs_name"], + "text_field": "content", + "vector_store_path": KNOWLEDGE_UPLOAD_ROOT_PATH, + } + knowledge_embedding_client = KnowledgeEmbedding( + file_path="", + model_name=LLM_MODEL_CONFIG["text2vec"], + local_persist=False, + vector_store_config=vector_store_config, + ) query = state.messages[-2][1] docs = knowledge_embedding_client.similar_search(query, VECTOR_SEARCH_TOP_K) context = [d.page_content for d in docs] prompt_template = PromptTemplate( - template=conv_qa_prompt_template, - input_variables=["context", "question"] + template=conv_qa_prompt_template, input_variables=["context", "question"] ) result = prompt_template.format(context="\n".join(context), question=query) state.messages[-2][1] = result @@ -290,7 +299,7 @@ def http_bot(state, mode, sql_mode, db_selector, temperature, max_new_tokens, re context = context[:2000] prompt_template = PromptTemplate( template=conv_qa_prompt_template, - input_variables=["context", "question"] + input_variables=["context", "question"], ) result = prompt_template.format(context="\n".join(context), question=query) state.messages[-2][1] = result @@ -311,8 +320,12 @@ def http_bot(state, mode, sql_mode, db_selector, temperature, max_new_tokens, re logger.info(f"Requert: \n{payload}") if sql_mode == conversation_sql_mode["auto_execute_ai_response"]: - response = requests.post(urljoin(CFG.MODEL_SERVER, "generate"), - headers=headers, json=payload, timeout=120) + response = requests.post( + urljoin(CFG.MODEL_SERVER, "generate"), + headers=headers, + json=payload, + timeout=120, + ) print(response.json()) print(str(response)) @@ -321,17 +334,17 @@ def http_bot(state, mode, sql_mode, db_selector, temperature, max_new_tokens, re text = text.rstrip() respObj = json.loads(text) - xx = respObj['response'] - xx = xx.strip(b'\x00'.decode()) + xx = respObj["response"] + xx = xx.strip(b"\x00".decode()) respObj_ex = json.loads(xx) - if respObj_ex['error_code'] == 0: + if respObj_ex["error_code"] == 0: ai_response = None - all_text = respObj_ex['text'] + all_text = respObj_ex["text"] ### 解析返回文本,获取AI回复部分 tmpResp = all_text.split(state.sep) last_index = -1 for i in range(len(tmpResp)): - if tmpResp[i].find('ASSISTANT:') != -1: + if tmpResp[i].find("ASSISTANT:") != -1: last_index = i ai_response = tmpResp[last_index] ai_response = ai_response.replace("ASSISTANT:", "") @@ -343,14 +356,20 @@ def http_bot(state, mode, sql_mode, db_selector, temperature, max_new_tokens, re state.messages[-1][-1] = "ASSISTANT未能正确回复,回复结果为:\n" + all_text yield (state, state.to_gradio_chatbot()) + (no_change_btn,) * 5 else: - plugin_resp = execute_ai_response_json(auto_prompt.prompt_generator, ai_response) + plugin_resp = execute_ai_response_json( + auto_prompt.prompt_generator, ai_response + ) cfg.set_last_plugin_return(plugin_resp) print(plugin_resp) - state.messages[-1][-1] = "Model推理信息:\n" + ai_response + "\n\nDB-GPT执行结果:\n" + plugin_resp + state.messages[-1][-1] = ( + "Model推理信息:\n" + ai_response + "\n\nDB-GPT执行结果:\n" + plugin_resp + ) yield (state, state.to_gradio_chatbot()) + (no_change_btn,) * 5 except NotCommands as e: print("命令执行:" + e.message) - state.messages[-1][-1] = "命令执行:" + e.message + "\n模型输出:\n" + str(ai_response) + state.messages[-1][-1] = ( + "命令执行:" + e.message + "\n模型输出:\n" + str(ai_response) + ) yield (state, state.to_gradio_chatbot()) + (no_change_btn,) * 5 else: # 流式输出 @@ -359,8 +378,13 @@ def http_bot(state, mode, sql_mode, db_selector, temperature, max_new_tokens, re try: # Stream output - response = requests.post(urljoin(CFG.MODEL_SERVER, "generate_stream"), - headers=headers, json=payload, stream=True, timeout=20) + response = requests.post( + urljoin(CFG.MODEL_SERVER, "generate_stream"), + headers=headers, + json=payload, + stream=True, + timeout=20, + ) for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"): if chunk: data = json.loads(chunk.decode()) @@ -368,7 +392,6 @@ def http_bot(state, mode, sql_mode, db_selector, temperature, max_new_tokens, re """ TODO Multi mode output handler, rewrite this for multi model, use adapter mode. """ if data["error_code"] == 0: - if "vicuna" in CFG.LLM_MODEL: output = data["text"][skip_echo_len:].strip() else: @@ -381,12 +404,23 @@ def http_bot(state, mode, sql_mode, db_selector, temperature, max_new_tokens, re output = data["text"] + f" (error_code: {data['error_code']})" state.messages[-1][-1] = output yield (state, state.to_gradio_chatbot()) + ( - disable_btn, disable_btn, disable_btn, enable_btn, enable_btn) + disable_btn, + disable_btn, + disable_btn, + enable_btn, + enable_btn, + ) return except requests.exceptions.RequestException as e: state.messages[-1][-1] = server_error_msg + f" (error_code: 4)" - yield (state, state.to_gradio_chatbot()) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn) + yield (state, state.to_gradio_chatbot()) + ( + disable_btn, + disable_btn, + disable_btn, + enable_btn, + enable_btn, + ) return state.messages[-1][-1] = state.messages[-1][-1][:-1] @@ -410,8 +444,8 @@ def http_bot(state, mode, sql_mode, db_selector, temperature, max_new_tokens, re block_css = ( - code_highlight_css - + """ + code_highlight_css + + """ pre { white-space: pre-wrap; /* Since CSS 2.1 */ white-space: -moz-pre-wrap; /* Mozilla, since 1999 */ @@ -487,7 +521,8 @@ def build_single_model_ui(): choices=dbs, value=dbs[0] if len(models) > 0 else "", interactive=True, - show_label=True).style(container=False) + show_label=True, + ).style(container=False) sql_mode = gr.Radio(["直接执行结果", "不执行结果"], show_label=False, value="不执行结果") sql_vs_setting = gr.Markdown("自动执行模式下, DB-GPT可以具备执行SQL、从网络读取知识自动化存储学习的能力") @@ -495,7 +530,9 @@ def build_single_model_ui(): tab_qa = gr.TabItem("知识问答", elem_id="QA") with tab_qa: - mode = gr.Radio(["LLM原生对话", "默认知识库对话", "新增知识库对话"], show_label=False, value="LLM原生对话") + mode = gr.Radio( + ["LLM原生对话", "默认知识库对话", "新增知识库对话"], show_label=False, value="LLM原生对话" + ) vs_setting = gr.Accordion("配置知识库", open=False) mode.change(fn=change_mode, inputs=mode, outputs=vs_setting) with vs_setting: @@ -504,19 +541,22 @@ def build_single_model_ui(): with gr.Column() as doc2vec: gr.Markdown("向知识库中添加文件") with gr.Tab("上传文件"): - files = gr.File(label="添加文件", - file_types=[".txt", ".md", ".docx", ".pdf"], - file_count="multiple", - allow_flagged_uploads=True, - show_label=False - ) + files = gr.File( + label="添加文件", + file_types=[".txt", ".md", ".docx", ".pdf"], + file_count="multiple", + allow_flagged_uploads=True, + show_label=False, + ) load_file_button = gr.Button("上传并加载到知识库") with gr.Tab("上传文件夹"): - folder_files = gr.File(label="添加文件夹", - accept_multiple_files=True, - file_count="directory", - show_label=False) + folder_files = gr.File( + label="添加文件夹", + accept_multiple_files=True, + file_count="directory", + show_label=False, + ) load_folder_button = gr.Button("上传并加载到知识库") with gr.Blocks(): @@ -557,28 +597,32 @@ def build_single_model_ui(): ).then( http_bot, [state, mode, sql_mode, db_selector, temperature, max_output_tokens], - [state, chatbot] + btn_list + [state, chatbot] + btn_list, + ) + vs_add.click( + fn=save_vs_name, show_progress=True, inputs=[vs_name], outputs=[vs_name] + ) + load_file_button.click( + fn=knowledge_embedding_store, + show_progress=True, + inputs=[vs_name, files], + outputs=[vs_name], + ) + load_folder_button.click( + fn=knowledge_embedding_store, + show_progress=True, + inputs=[vs_name, folder_files], + outputs=[vs_name], ) - vs_add.click(fn=save_vs_name, show_progress=True, - inputs=[vs_name], - outputs=[vs_name]) - load_file_button.click(fn=knowledge_embedding_store, - show_progress=True, - inputs=[vs_name, files], - outputs=[vs_name]) - load_folder_button.click(fn=knowledge_embedding_store, - show_progress=True, - inputs=[vs_name, folder_files], - outputs=[vs_name]) return state, chatbot, textbox, send_btn, button_row, parameter_row def build_webdemo(): with gr.Blocks( - title="数据库智能助手", - # theme=gr.themes.Base(), - theme=gr.themes.Default(), - css=block_css, + title="数据库智能助手", + # theme=gr.themes.Base(), + theme=gr.themes.Default(), + css=block_css, ) as demo: url_params = gr.JSON(visible=False) ( @@ -613,26 +657,31 @@ def save_vs_name(vs_name): vector_store_name["vs_name"] = vs_name return vs_name + def knowledge_embedding_store(vs_id, files): # vs_path = os.path.join(VS_ROOT_PATH, vs_id) if not os.path.exists(os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, vs_id)): os.makedirs(os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, vs_id)) for file in files: filename = os.path.split(file.name)[-1] - shutil.move(file.name, os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, vs_id, filename)) + shutil.move( + file.name, os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, vs_id, filename) + ) knowledge_embedding_client = KnowledgeEmbedding( file_path=os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, vs_id, filename), model_name=LLM_MODEL_CONFIG["text2vec"], local_persist=False, vector_store_config={ "vector_store_name": vector_store_name["vs_name"], - "vector_store_path": KNOWLEDGE_UPLOAD_ROOT_PATH}) + "vector_store_path": KNOWLEDGE_UPLOAD_ROOT_PATH, + }, + ) knowledge_embedding_client.knowledge_embedding() - logger.info("knowledge embedding success") return os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, vs_id, vs_id + ".vectordb") + if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--host", type=str, default="0.0.0.0") @@ -671,5 +720,8 @@ def knowledge_embedding_store(vs_id, files): demo.queue( concurrency_count=args.concurrency_count, status_update_rate=10, api_open=False ).launch( - server_name=args.host, server_port=args.port, share=args.share, max_threads=200, + server_name=args.host, + server_port=args.port, + share=args.share, + max_threads=200, ) diff --git a/pilot/singleton.py b/pilot/singleton.py index 8a9d6e2fa..6fd65132e 100644 --- a/pilot/singleton.py +++ b/pilot/singleton.py @@ -5,10 +5,12 @@ import abc from typing import Any + class Singleton(abc.ABCMeta, type): - """ Singleton metaclass for ensuring only one instance of a class""" + """Singleton metaclass for ensuring only one instance of a class""" _instances = {} + def __call__(cls, *args: Any, **kwargs: Any) -> Any: """Call method for the singleton metaclass""" if cls not in cls._instances: @@ -18,4 +20,5 @@ def __call__(cls, *args: Any, **kwargs: Any) -> Any: class AbstractSingleton(abc.ABC, metaclass=Singleton): """Abstract singleton class for ensuring only one instance of a class""" - pass \ No newline at end of file + + pass diff --git a/pilot/source_embedding/__init__.py b/pilot/source_embedding/__init__.py index 9d1e74a31..464ff11b1 100644 --- a/pilot/source_embedding/__init__.py +++ b/pilot/source_embedding/__init__.py @@ -1,8 +1,3 @@ -from pilot.source_embedding.source_embedding import SourceEmbedding -from pilot.source_embedding.source_embedding import register +from pilot.source_embedding.source_embedding import SourceEmbedding, register - -__all__ = [ - "SourceEmbedding", - "register" -] \ No newline at end of file +__all__ = ["SourceEmbedding", "register"] diff --git a/pilot/source_embedding/chn_document_splitter.py b/pilot/source_embedding/chn_document_splitter.py index 10a77aeca..5bf06ea8c 100644 --- a/pilot/source_embedding/chn_document_splitter.py +++ b/pilot/source_embedding/chn_document_splitter.py @@ -1,5 +1,6 @@ import re from typing import List + from langchain.text_splitter import CharacterTextSplitter @@ -12,32 +13,43 @@ def __init__(self, pdf: bool = False, sentence_size: int = None, **kwargs): def split_text(self, text: str) -> List[str]: if self.pdf: text = re.sub(r"\n{3,}", r"\n", text) - text = re.sub('\s', " ", text) + text = re.sub("\s", " ", text) text = re.sub("\n\n", "", text) - text = re.sub(r'([;;.!?。!?\?])([^”’])', r"\1\n\2", text) + text = re.sub(r"([;;.!?。!?\?])([^”’])", r"\1\n\2", text) text = re.sub(r'(\.{6})([^"’”」』])', r"\1\n\2", text) text = re.sub(r'(\…{2})([^"’”」』])', r"\1\n\2", text) - text = re.sub(r'([;;!?。!?\?]["’”」』]{0,2})([^;;!?,。!?\?])', r'\1\n\2', text) + text = re.sub(r'([;;!?。!?\?]["’”」』]{0,2})([^;;!?,。!?\?])', r"\1\n\2", text) text = text.rstrip() ls = [i for i in text.split("\n") if i] for ele in ls: if len(ele) > self.sentence_size: - ele1 = re.sub(r'([,,.]["’”」』]{0,2})([^,,.])', r'\1\n\2', ele) + ele1 = re.sub(r'([,,.]["’”」』]{0,2})([^,,.])', r"\1\n\2", ele) ele1_ls = ele1.split("\n") for ele_ele1 in ele1_ls: if len(ele_ele1) > self.sentence_size: - ele_ele2 = re.sub(r'([\n]{1,}| {2,}["’”」』]{0,2})([^\s])', r'\1\n\2', ele_ele1) + ele_ele2 = re.sub( + r'([\n]{1,}| {2,}["’”」』]{0,2})([^\s])', r"\1\n\2", ele_ele1 + ) ele2_ls = ele_ele2.split("\n") for ele_ele2 in ele2_ls: if len(ele_ele2) > self.sentence_size: - ele_ele3 = re.sub('( ["’”」』]{0,2})([^ ])', r'\1\n\2', ele_ele2) + ele_ele3 = re.sub( + '( ["’”」』]{0,2})([^ ])', r"\1\n\2", ele_ele2 + ) ele2_id = ele2_ls.index(ele_ele2) - ele2_ls = ele2_ls[:ele2_id] + [i for i in ele_ele3.split("\n") if i] + ele2_ls[ - ele2_id + 1:] + ele2_ls = ( + ele2_ls[:ele2_id] + + [i for i in ele_ele3.split("\n") if i] + + ele2_ls[ele2_id + 1 :] + ) ele_id = ele1_ls.index(ele_ele1) - ele1_ls = ele1_ls[:ele_id] + [i for i in ele2_ls if i] + ele1_ls[ele_id + 1:] + ele1_ls = ( + ele1_ls[:ele_id] + + [i for i in ele2_ls if i] + + ele1_ls[ele_id + 1 :] + ) id = ls.index(ele) - ls = ls[:id] + [i for i in ele1_ls if i] + ls[id + 1:] + ls = ls[:id] + [i for i in ele1_ls if i] + ls[id + 1 :] return ls diff --git a/pilot/source_embedding/csv_embedding.py b/pilot/source_embedding/csv_embedding.py index 2f3b7ed06..8b2e25ff3 100644 --- a/pilot/source_embedding/csv_embedding.py +++ b/pilot/source_embedding/csv_embedding.py @@ -1,14 +1,21 @@ -from typing import List, Optional, Dict -from pilot.source_embedding import SourceEmbedding, register +from typing import Dict, List, Optional from langchain.document_loaders import CSVLoader from langchain.schema import Document +from pilot.source_embedding import SourceEmbedding, register + class CSVEmbedding(SourceEmbedding): """csv embedding for read csv document.""" - def __init__(self, file_path, model_name, vector_store_config, embedding_args: Optional[Dict] = None): + def __init__( + self, + file_path, + model_name, + vector_store_config, + embedding_args: Optional[Dict] = None, + ): """Initialize with csv path.""" super().__init__(file_path, model_name, vector_store_config) self.file_path = file_path @@ -29,6 +36,3 @@ def data_process(self, documents: List[Document]): documents[i].page_content = d.page_content.replace("\n", "") i += 1 return documents - - - diff --git a/pilot/source_embedding/knowledge_embedding.py b/pilot/source_embedding/knowledge_embedding.py index 2f313a35a..33f35f826 100644 --- a/pilot/source_embedding/knowledge_embedding.py +++ b/pilot/source_embedding/knowledge_embedding.py @@ -1,7 +1,8 @@ import os +import markdown from bs4 import BeautifulSoup -from langchain.document_loaders import TextLoader, markdown, PyPDFLoader +from langchain.document_loaders import PyPDFLoader, TextLoader, markdown from langchain.embeddings import HuggingFaceEmbeddings from pilot.configs.config import Config @@ -10,12 +11,11 @@ from pilot.source_embedding.csv_embedding import CSVEmbedding from pilot.source_embedding.markdown_embedding import MarkdownEmbedding from pilot.source_embedding.pdf_embedding import PDFEmbedding -import markdown - from pilot.vector_store.connector import VectorStoreConnector CFG = Config() + class KnowledgeEmbedding: def __init__(self, file_path, model_name, vector_store_config, local_persist=True): """Initialize with Loader url, model_name, vector_store_config""" @@ -37,16 +37,30 @@ def knowledge_embedding_batch(self): def init_knowledge_embedding(self): if self.file_path.endswith(".pdf"): - embedding = PDFEmbedding(file_path=self.file_path, model_name=self.model_name, - vector_store_config=self.vector_store_config) + embedding = PDFEmbedding( + file_path=self.file_path, + model_name=self.model_name, + vector_store_config=self.vector_store_config, + ) elif self.file_path.endswith(".md"): - embedding = MarkdownEmbedding(file_path=self.file_path, model_name=self.model_name, vector_store_config=self.vector_store_config) + embedding = MarkdownEmbedding( + file_path=self.file_path, + model_name=self.model_name, + vector_store_config=self.vector_store_config, + ) elif self.file_path.endswith(".csv"): - embedding = CSVEmbedding(file_path=self.file_path, model_name=self.model_name, - vector_store_config=self.vector_store_config) + embedding = CSVEmbedding( + file_path=self.file_path, + model_name=self.model_name, + vector_store_config=self.vector_store_config, + ) elif self.file_type == "default": - embedding = MarkdownEmbedding(file_path=self.file_path, model_name=self.model_name, vector_store_config=self.vector_store_config) + embedding = MarkdownEmbedding( + file_path=self.file_path, + model_name=self.model_name, + vector_store_config=self.vector_store_config, + ) return embedding @@ -55,7 +69,9 @@ def similar_search(self, text, topk): def knowledge_persist_initialization(self, append_mode): documents = self._load_knownlege(self.file_path) - self.vector_client = VectorStoreConnector(CFG.VECTOR_STORE_TYPE, self.vector_store_config) + self.vector_client = VectorStoreConnector( + CFG.VECTOR_STORE_TYPE, self.vector_store_config + ) self.vector_client.load_document(documents) return self.vector_client @@ -67,7 +83,9 @@ def _load_knownlege(self, path): docs = self._load_file(filename) new_docs = [] for doc in docs: - doc.metadata = {"source": doc.metadata["source"].replace(DATASETS_DIR, "")} + doc.metadata = { + "source": doc.metadata["source"].replace(DATASETS_DIR, "") + } print("doc is embedding...", doc.metadata) new_docs.append(doc) docments += new_docs @@ -76,27 +94,33 @@ def _load_knownlege(self, path): def _load_file(self, filename): if filename.lower().endswith(".md"): loader = TextLoader(filename) - text_splitter = CHNDocumentSplitter(pdf=True, sentence_size=KNOWLEDGE_CHUNK_SPLIT_SIZE) + text_splitter = CHNDocumentSplitter( + pdf=True, sentence_size=KNOWLEDGE_CHUNK_SPLIT_SIZE + ) docs = loader.load_and_split(text_splitter) i = 0 for d in docs: content = markdown.markdown(d.page_content) - soup = BeautifulSoup(content, 'html.parser') - for tag in soup(['!doctype', 'meta', 'i.fa']): + soup = BeautifulSoup(content, "html.parser") + for tag in soup(["!doctype", "meta", "i.fa"]): tag.extract() docs[i].page_content = soup.get_text() docs[i].page_content = docs[i].page_content.replace("\n", " ") i += 1 elif filename.lower().endswith(".pdf"): loader = PyPDFLoader(filename) - textsplitter = CHNDocumentSplitter(pdf=True, sentence_size=KNOWLEDGE_CHUNK_SPLIT_SIZE) + textsplitter = CHNDocumentSplitter( + pdf=True, sentence_size=KNOWLEDGE_CHUNK_SPLIT_SIZE + ) docs = loader.load_and_split(textsplitter) i = 0 for d in docs: - docs[i].page_content = d.page_content.replace("\n", " ").replace("�", "") + docs[i].page_content = d.page_content.replace("\n", " ").replace( + "�", "" + ) i += 1 else: loader = TextLoader(filename) text_splitor = CHNDocumentSplitter(sentence_size=KNOWLEDGE_CHUNK_SPLIT_SIZE) docs = loader.load_and_split(text_splitor) - return docs \ No newline at end of file + return docs diff --git a/pilot/source_embedding/markdown_embedding.py b/pilot/source_embedding/markdown_embedding.py index 834226f75..3db6cdbf5 100644 --- a/pilot/source_embedding/markdown_embedding.py +++ b/pilot/source_embedding/markdown_embedding.py @@ -3,12 +3,12 @@ import os from typing import List +import markdown from bs4 import BeautifulSoup from langchain.document_loaders import TextLoader from langchain.schema import Document -import markdown -from pilot.configs.model_config import KNOWLEDGE_CHUNK_SPLIT_SIZE +from pilot.configs.model_config import KNOWLEDGE_CHUNK_SPLIT_SIZE from pilot.source_embedding import SourceEmbedding, register from pilot.source_embedding.chn_document_splitter import CHNDocumentSplitter @@ -27,7 +27,9 @@ def __init__(self, file_path, model_name, vector_store_config): def read(self): """Load from markdown path.""" loader = TextLoader(self.file_path) - text_splitter = CHNDocumentSplitter(pdf=True, sentence_size=KNOWLEDGE_CHUNK_SPLIT_SIZE) + text_splitter = CHNDocumentSplitter( + pdf=True, sentence_size=KNOWLEDGE_CHUNK_SPLIT_SIZE + ) return loader.load_and_split(text_splitter) @register @@ -44,7 +46,9 @@ def read_batch(self): # 更新metadata数据 new_docs = [] for doc in docs: - doc.metadata = {"source": doc.metadata["source"].replace(self.file_path, "")} + doc.metadata = { + "source": doc.metadata["source"].replace(self.file_path, "") + } print("doc is embedding ... ", doc.metadata) new_docs.append(doc) docments += new_docs @@ -55,13 +59,10 @@ def data_process(self, documents: List[Document]): i = 0 for d in documents: content = markdown.markdown(d.page_content) - soup = BeautifulSoup(content, 'html.parser') - for tag in soup(['!doctype', 'meta', 'i.fa']): + soup = BeautifulSoup(content, "html.parser") + for tag in soup(["!doctype", "meta", "i.fa"]): tag.extract() documents[i].page_content = soup.get_text() documents[i].page_content = documents[i].page_content.replace("\n", " ") i += 1 return documents - - - diff --git a/pilot/source_embedding/pdf_embedding.py b/pilot/source_embedding/pdf_embedding.py index 75d17c4c6..c76cf65d2 100644 --- a/pilot/source_embedding/pdf_embedding.py +++ b/pilot/source_embedding/pdf_embedding.py @@ -4,8 +4,8 @@ from langchain.document_loaders import PyPDFLoader from langchain.schema import Document -from pilot.configs.model_config import KNOWLEDGE_CHUNK_SPLIT_SIZE +from pilot.configs.model_config import KNOWLEDGE_CHUNK_SPLIT_SIZE from pilot.source_embedding import SourceEmbedding, register from pilot.source_embedding.chn_document_splitter import CHNDocumentSplitter @@ -25,7 +25,9 @@ def read(self): """Load from pdf path.""" # loader = UnstructuredPaddlePDFLoader(self.file_path) loader = PyPDFLoader(self.file_path) - textsplitter = CHNDocumentSplitter(pdf=True, sentence_size=KNOWLEDGE_CHUNK_SPLIT_SIZE) + textsplitter = CHNDocumentSplitter( + pdf=True, sentence_size=KNOWLEDGE_CHUNK_SPLIT_SIZE + ) return loader.load_and_split(textsplitter) @register @@ -35,6 +37,3 @@ def data_process(self, documents: List[Document]): documents[i].page_content = d.page_content.replace("\n", "") i += 1 return documents - - - diff --git a/pilot/source_embedding/pdf_loader.py b/pilot/source_embedding/pdf_loader.py index aa7cf4da5..80888631f 100644 --- a/pilot/source_embedding/pdf_loader.py +++ b/pilot/source_embedding/pdf_loader.py @@ -1,10 +1,10 @@ """Loader that loads image files.""" +import os from typing import List +import fitz from langchain.document_loaders.unstructured import UnstructuredFileLoader from paddleocr import PaddleOCR -import os -import fitz class UnstructuredPaddlePDFLoader(UnstructuredFileLoader): @@ -19,9 +19,8 @@ def pdf_ocr_txt(filepath, dir_path="tmp_files"): ocr = PaddleOCR(lang="ch", use_gpu=False, show_log=False) doc = fitz.open(filepath) txt_file_path = os.path.join(full_dir_path, "%s.txt" % (filename)) - img_name = os.path.join(full_dir_path, '.tmp.png') - with open(txt_file_path, 'w', encoding='utf-8') as fout: - + img_name = os.path.join(full_dir_path, ".tmp.png") + with open(txt_file_path, "w", encoding="utf-8") as fout: for i in range(doc.page_count): page = doc[i] text = page.get_text("") @@ -42,11 +41,14 @@ def pdf_ocr_txt(filepath, dir_path="tmp_files"): txt_file_path = pdf_ocr_txt(self.file_path) from unstructured.partition.text import partition_text + return partition_text(filename=txt_file_path, **self.unstructured_kwargs) if __name__ == "__main__": - filepath = os.path.join(os.path.dirname(os.path.dirname(__file__)), "content", "samples", "test.pdf") + filepath = os.path.join( + os.path.dirname(os.path.dirname(__file__)), "content", "samples", "test.pdf" + ) loader = UnstructuredPaddlePDFLoader(filepath, mode="elements") docs = loader.load() for doc in docs: diff --git a/pilot/source_embedding/search_milvus.py b/pilot/source_embedding/search_milvus.py index ec0aa6813..aa02c1f61 100644 --- a/pilot/source_embedding/search_milvus.py +++ b/pilot/source_embedding/search_milvus.py @@ -58,4 +58,4 @@ # # docs, # # embedding=embeddings, # # connection_args={"host": "127.0.0.1", "port": "19530", "alias": "default"} -# # ) \ No newline at end of file +# # ) diff --git a/pilot/source_embedding/source_embedding.py b/pilot/source_embedding/source_embedding.py index a84282009..acbf82a73 100644 --- a/pilot/source_embedding/source_embedding.py +++ b/pilot/source_embedding/source_embedding.py @@ -1,9 +1,9 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- from abc import ABC, abstractmethod +from typing import Dict, List, Optional from langchain.embeddings import HuggingFaceEmbeddings -from typing import List, Optional, Dict from pilot.configs.config import Config from pilot.vector_store.connector import VectorStoreConnector @@ -23,7 +23,13 @@ class SourceEmbedding(ABC): Implementations should implement the method """ - def __init__(self, file_path, model_name, vector_store_config, embedding_args: Optional[Dict] = None): + def __init__( + self, + file_path, + model_name, + vector_store_config, + embedding_args: Optional[Dict] = None, + ): """Initialize with Loader url, model_name, vector_store_config""" self.file_path = file_path self.model_name = model_name @@ -32,12 +38,15 @@ def __init__(self, file_path, model_name, vector_store_config, embedding_args: O self.embeddings = HuggingFaceEmbeddings(model_name=self.model_name) vector_store_config["embeddings"] = self.embeddings - self.vector_client = VectorStoreConnector(CFG.VECTOR_STORE_TYPE, vector_store_config) + self.vector_client = VectorStoreConnector( + CFG.VECTOR_STORE_TYPE, vector_store_config + ) @abstractmethod @register def read(self) -> List[ABC]: """read datasource into document objects.""" + @register def data_process(self, text): """pre process data.""" @@ -63,25 +72,25 @@ def similar_search(self, doc, topk): return self.vector_client.similar_search(doc, topk) def source_embedding(self): - if 'read' in registered_methods: + if "read" in registered_methods: text = self.read() - if 'data_process' in registered_methods: + if "data_process" in registered_methods: text = self.data_process(text) - if 'text_split' in registered_methods: + if "text_split" in registered_methods: self.text_split(text) - if 'text_to_vector' in registered_methods: + if "text_to_vector" in registered_methods: self.text_to_vector(text) - if 'index_to_store' in registered_methods: + if "index_to_store" in registered_methods: self.index_to_store(text) def batch_embedding(self): - if 'read_batch' in registered_methods: + if "read_batch" in registered_methods: text = self.read_batch() - if 'data_process' in registered_methods: + if "data_process" in registered_methods: text = self.data_process(text) - if 'text_split' in registered_methods: + if "text_split" in registered_methods: self.text_split(text) - if 'text_to_vector' in registered_methods: + if "text_to_vector" in registered_methods: self.text_to_vector(text) - if 'index_to_store' in registered_methods: + if "index_to_store" in registered_methods: self.index_to_store(text) diff --git a/pilot/source_embedding/url_embedding.py b/pilot/source_embedding/url_embedding.py index 68fbdd5e4..59eef19e7 100644 --- a/pilot/source_embedding/url_embedding.py +++ b/pilot/source_embedding/url_embedding.py @@ -1,13 +1,11 @@ from typing import List -from langchain.text_splitter import CharacterTextSplitter - -from pilot.source_embedding import SourceEmbedding, register - from bs4 import BeautifulSoup from langchain.document_loaders import WebBaseLoader from langchain.schema import Document +from langchain.text_splitter import CharacterTextSplitter +from pilot.source_embedding import SourceEmbedding, register class URLEmbedding(SourceEmbedding): @@ -23,7 +21,9 @@ def __init__(self, file_path, model_name, vector_store_config): def read(self): """Load from url path.""" loader = WebBaseLoader(web_path=self.file_path) - text_splitor = CharacterTextSplitter(chunk_size=1000, chunk_overlap=20, length_function=len) + text_splitor = CharacterTextSplitter( + chunk_size=1000, chunk_overlap=20, length_function=len + ) return loader.load_and_split(text_splitor) @register @@ -31,12 +31,9 @@ def data_process(self, documents: List[Document]): i = 0 for d in documents: content = d.page_content.replace("\n", "") - soup = BeautifulSoup(content, 'html.parser') - for tag in soup(['!doctype', 'meta']): + soup = BeautifulSoup(content, "html.parser") + for tag in soup(["!doctype", "meta"]): tag.extract() documents[i].page_content = soup.get_text() i += 1 return documents - - - diff --git a/pilot/utils.py b/pilot/utils.py index 607b83251..41e42fd55 100644 --- a/pilot/utils.py +++ b/pilot/utils.py @@ -1,27 +1,28 @@ #!/usr/bin/env python3 # -*- coding:utf-8 -*- -import torch - -import datetime import logging import logging.handlers import os import sys import requests +import torch from pilot.configs.model_config import LOGDIR -server_error_msg = "**NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**" +server_error_msg = ( + "**NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**" +) handler = None + def get_gpu_memory(max_gpus=None): gpu_memory = [] num_gpus = ( torch.cuda.device_count() - if max_gpus is None + if max_gpus is None else min(max_gpus, torch.cuda.device_count()) ) @@ -29,14 +30,13 @@ def get_gpu_memory(max_gpus=None): with torch.cuda.device(gpu_id): device = torch.cuda.current_device() gpu_properties = torch.cuda.get_device_properties(device) - total_memory = gpu_properties.total_memory / (1024 ** 3) - allocated_memory = torch.cuda.memory_allocated() / (1024 ** 3) + total_memory = gpu_properties.total_memory / (1024**3) + allocated_memory = torch.cuda.memory_allocated() / (1024**3) available_memory = total_memory - allocated_memory gpu_memory.append(available_memory) return gpu_memory - def build_logger(logger_name, logger_filename): global handler @@ -47,7 +47,7 @@ def build_logger(logger_name, logger_filename): # Set the format of root handlers if not logging.getLogger().handlers: - logging.basicConfig(level=logging.INFO, encoding='utf-8') + logging.basicConfig(level=logging.INFO, encoding="utf-8") logging.getLogger().handlers[0].setFormatter(formatter) # Redirect stdout and stderr to loggers @@ -70,7 +70,8 @@ def build_logger(logger_name, logger_filename): os.makedirs(LOGDIR, exist_ok=True) filename = os.path.join(LOGDIR, logger_filename) handler = logging.handlers.TimedRotatingFileHandler( - filename, when='D', utc=True) + filename, when="D", utc=True + ) handler.setFormatter(formatter) for name, item in logging.root.manager.loggerDict.items(): @@ -84,35 +85,36 @@ class StreamToLogger(object): """ Fake file-like stream object that redirects writes to a logger instance. """ + def __init__(self, logger, log_level=logging.INFO): self.terminal = sys.stdout self.logger = logger self.log_level = log_level - self.linebuf = '' + self.linebuf = "" def __getattr__(self, attr): return getattr(self.terminal, attr) def write(self, buf): temp_linebuf = self.linebuf + buf - self.linebuf = '' + self.linebuf = "" for line in temp_linebuf.splitlines(True): # From the io.TextIOWrapper docs: # On output, if newline is None, any '\n' characters written # are translated to the system default line separator. # By default sys.stdout.write() expects '\n' newlines and then # translates them so this is still cross platform. - if line[-1] == '\n': - encoded_message = line.encode('utf-8', 'ignore').decode('utf-8') + if line[-1] == "\n": + encoded_message = line.encode("utf-8", "ignore").decode("utf-8") self.logger.log(self.log_level, encoded_message.rstrip()) else: self.linebuf += line def flush(self): - if self.linebuf != '': - encoded_message = self.linebuf.encode('utf-8', 'ignore').decode('utf-8') + if self.linebuf != "": + encoded_message = self.linebuf.encode("utf-8", "ignore").decode("utf-8") self.logger.log(self.log_level, encoded_message.rstrip()) - self.linebuf = '' + self.linebuf = "" def disable_torch_init(): @@ -120,6 +122,7 @@ def disable_torch_init(): Disable the redundant torch default initialization to accelerate model creation. """ import torch + setattr(torch.nn.Linear, "reset_parameters", lambda self: None) setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None) @@ -128,4 +131,3 @@ def pretty_print_semaphore(semaphore): if semaphore is None: return "None" return f"Semaphore(value={semaphore._value}, locked={semaphore.locked()})" - diff --git a/pilot/vector_store/chroma_store.py b/pilot/vector_store/chroma_store.py index 9a91659f1..1ec9e8b04 100644 --- a/pilot/vector_store/chroma_store.py +++ b/pilot/vector_store/chroma_store.py @@ -13,9 +13,12 @@ class ChromaStore(VectorStoreBase): def __init__(self, ctx: {}) -> None: self.ctx = ctx self.embeddings = ctx["embeddings"] - self.persist_dir = os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, - ctx["vector_store_name"] + ".vectordb") - self.vector_store_client = Chroma(persist_directory=self.persist_dir, embedding_function=self.embeddings) + self.persist_dir = os.path.join( + KNOWLEDGE_UPLOAD_ROOT_PATH, ctx["vector_store_name"] + ".vectordb" + ) + self.vector_store_client = Chroma( + persist_directory=self.persist_dir, embedding_function=self.embeddings + ) def similar_search(self, text, topk) -> None: logger.info("ChromaStore similar search") @@ -27,4 +30,3 @@ def load_document(self, documents): metadatas = [doc.metadata for doc in documents] self.vector_store_client.add_texts(texts=texts, metadatas=metadatas) self.vector_store_client.persist() - diff --git a/pilot/vector_store/connector.py b/pilot/vector_store/connector.py index 003415712..06fad00f2 100644 --- a/pilot/vector_store/connector.py +++ b/pilot/vector_store/connector.py @@ -1,15 +1,12 @@ from pilot.vector_store.chroma_store import ChromaStore from pilot.vector_store.milvus_store import MilvusStore -connector = { - "Chroma": ChromaStore, - "Milvus": MilvusStore - } +connector = {"Chroma": ChromaStore, "Milvus": MilvusStore} class VectorStoreConnector: - """ vector store connector, can connect different vector db provided load document api and similar search api - """ + """vector store connector, can connect different vector db provided load document api and similar search api""" + def __init__(self, vector_store_type, ctx: {}) -> None: self.ctx = ctx self.connector_class = connector[vector_store_type] diff --git a/pilot/vector_store/extract_tovec.py b/pilot/vector_store/extract_tovec.py index c6b83d467..1032876cf 100644 --- a/pilot/vector_store/extract_tovec.py +++ b/pilot/vector_store/extract_tovec.py @@ -3,14 +3,16 @@ import os +from langchain.embeddings import HuggingFaceEmbeddings from langchain.text_splitter import CharacterTextSplitter from langchain.vectorstores import Chroma + +from pilot.configs.model_config import DATASETS_DIR, VECTORE_PATH from pilot.model.vicuna_llm import VicunaEmbeddingLLM -from pilot.configs.model_config import VECTORE_PATH, DATASETS_DIR -from langchain.embeddings import HuggingFaceEmbeddings embeddings = VicunaEmbeddingLLM() + def knownledge_tovec(filename): with open(filename, "r") as f: knownledge = f.read() @@ -22,48 +24,64 @@ def knownledge_tovec(filename): ) return docsearch + def knownledge_tovec_st(filename): - """ Use sentence transformers to embedding the document. - https://github.com/UKPLab/sentence-transformers + """Use sentence transformers to embedding the document. + https://github.com/UKPLab/sentence-transformers """ from pilot.configs.model_config import LLM_MODEL_CONFIG - embeddings = HuggingFaceEmbeddings(model_name=LLM_MODEL_CONFIG["sentence-transforms"]) + + embeddings = HuggingFaceEmbeddings( + model_name=LLM_MODEL_CONFIG["sentence-transforms"] + ) with open(filename, "r") as f: knownledge = f.read() - + text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0) texts = text_splitter.split_text(knownledge) - docsearch = Chroma.from_texts(texts, embeddings, metadatas=[{"source": str(i)} for i in range(len(texts))]) + docsearch = Chroma.from_texts( + texts, embeddings, metadatas=[{"source": str(i)} for i in range(len(texts))] + ) return docsearch def load_knownledge_from_doc(): """Loader Knownledge from current datasets - # TODO if the vector store is exists, just use it. + # TODO if the vector store is exists, just use it. """ if not os.path.exists(DATASETS_DIR): - print("Not Exists Local DataSets, We will answers the Question use model default.") + print( + "Not Exists Local DataSets, We will answers the Question use model default." + ) from pilot.configs.model_config import LLM_MODEL_CONFIG - embeddings = HuggingFaceEmbeddings(model_name=LLM_MODEL_CONFIG["sentence-transforms"]) + + embeddings = HuggingFaceEmbeddings( + model_name=LLM_MODEL_CONFIG["sentence-transforms"] + ) files = os.listdir(DATASETS_DIR) for file in files: - if not os.path.isdir(file): + if not os.path.isdir(file): filename = os.path.join(DATASETS_DIR, file) with open(filename, "r") as f: - knownledge = f.read() + knownledge = f.read() text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_owerlap=0) texts = text_splitter.split_text(knownledge) - docsearch = Chroma.from_texts(texts, embeddings, metadatas=[{"source": str(i)} for i in range(len(texts))], - persist_directory=os.path.join(VECTORE_PATH, ".vectore")) + docsearch = Chroma.from_texts( + texts, + embeddings, + metadatas=[{"source": str(i)} for i in range(len(texts))], + persist_directory=os.path.join(VECTORE_PATH, ".vectore"), + ) return docsearch + def get_vector_storelist(): if not os.path.exists(VECTORE_PATH): return [] - return os.listdir(VECTORE_PATH) \ No newline at end of file + return os.listdir(VECTORE_PATH) diff --git a/pilot/vector_store/file_loader.py b/pilot/vector_store/file_loader.py index 279d5343c..c42eda7a6 100644 --- a/pilot/vector_store/file_loader.py +++ b/pilot/vector_store/file_loader.py @@ -2,56 +2,70 @@ # -*- coding: utf-8 -*- import os -import copy -from typing import Optional, List, Dict -from langchain.prompts import PromptTemplate -from langchain.vectorstores import Chroma -from langchain.text_splitter import CharacterTextSplitter -from langchain.document_loaders import UnstructuredFileLoader, UnstructuredPDFLoader, TextLoader + from langchain.chains import VectorDBQA +from langchain.document_loaders import ( + TextLoader, + UnstructuredFileLoader, + UnstructuredPDFLoader, +) from langchain.embeddings import HuggingFaceEmbeddings -from pilot.configs.model_config import VECTORE_PATH, DATASETS_DIR, LLM_MODEL_CONFIG, VECTOR_SEARCH_TOP_K +from langchain.prompts import PromptTemplate +from langchain.text_splitter import CharacterTextSplitter +from langchain.vectorstores import Chroma + +from pilot.configs.model_config import ( + DATASETS_DIR, + LLM_MODEL_CONFIG, + VECTOR_SEARCH_TOP_K, + VECTORE_PATH, +) class KnownLedge2Vector: - """KnownLedge2Vector class is order to load document to vector + """KnownLedge2Vector class is order to load document to vector and persist to vector store. - - Args: + + Args: - model_name Usage: k2v = KnownLedge2Vector() - persist_dir = os.path.join(VECTORE_PATH, ".vectordb") + persist_dir = os.path.join(VECTORE_PATH, ".vectordb") print(persist_dir) for s, dc in k2v.query("what is oceanbase?"): print(s, dc.page_content, dc.metadata) """ - embeddings: object = None + + embeddings: object = None model_name = LLM_MODEL_CONFIG["sentence-transforms"] top_k: int = VECTOR_SEARCH_TOP_K def __init__(self, model_name=None) -> None: if not model_name: # use default embedding model - self.embeddings = HuggingFaceEmbeddings(model_name=self.model_name) - + self.embeddings = HuggingFaceEmbeddings(model_name=self.model_name) + def init_vector_store(self): persist_dir = os.path.join(VECTORE_PATH, ".vectordb") print("Vector store Persist address is: ", persist_dir) if os.path.exists(persist_dir): # Loader from local file. print("Loader data from local persist vector file...") - vector_store = Chroma(persist_directory=persist_dir, embedding_function=self.embeddings) + vector_store = Chroma( + persist_directory=persist_dir, embedding_function=self.embeddings + ) # vector_store.add_documents(documents=documents) else: documents = self.load_knownlege() - # reinit - vector_store = Chroma.from_documents(documents=documents, - embedding=self.embeddings, - persist_directory=persist_dir) + # reinit + vector_store = Chroma.from_documents( + documents=documents, + embedding=self.embeddings, + persist_directory=persist_dir, + ) vector_store.persist() return vector_store @@ -62,9 +76,11 @@ def load_knownlege(self): filename = os.path.join(root, file) docs = self._load_file(filename) # update metadata. - new_docs = [] + new_docs = [] for doc in docs: - doc.metadata = {"source": doc.metadata["source"].replace(DATASETS_DIR, "")} + doc.metadata = { + "source": doc.metadata["source"].replace(DATASETS_DIR, "") + } print("Documents to vector running, please wait...", doc.metadata) new_docs.append(doc) docments += new_docs @@ -73,7 +89,7 @@ def load_knownlege(self): def _load_file(self, filename): # Loader file if filename.lower().endswith(".pdf"): - loader = UnstructuredFileLoader(filename) + loader = UnstructuredFileLoader(filename) text_splitor = CharacterTextSplitter() docs = loader.load_and_split(text_splitor) else: @@ -86,13 +102,10 @@ def _load_from_url(self, url): """Load data from url address""" pass - def query(self, q): - """Query similar doc from Vector """ + """Query similar doc from Vector""" vector_store = self.init_vector_store() docs = vector_store.similarity_search_with_score(q, k=self.top_k) for doc in docs: dc, s = doc yield s, dc - - \ No newline at end of file diff --git a/pilot/vector_store/milvus_store.py b/pilot/vector_store/milvus_store.py index a61027850..3ae265f7b 100644 --- a/pilot/vector_store/milvus_store.py +++ b/pilot/vector_store/milvus_store.py @@ -1,15 +1,17 @@ -from typing import List, Optional, Iterable, Tuple, Any - -from pymilvus import connections, Collection, DataType +from typing import Any, Iterable, List, Optional, Tuple from langchain.docstore.document import Document +from pymilvus import Collection, DataType, connections from pilot.configs.config import Config from pilot.vector_store.vector_store_base import VectorStoreBase CFG = Config() + + class MilvusStore(VectorStoreBase): """Milvus database""" + def __init__(self, ctx: {}) -> None: """init a milvus storage connection. @@ -66,12 +68,12 @@ def __init__(self, ctx: {}) -> None: def init_schema_and_load(self, vector_name, documents): """Create a Milvus collection, indexes it with HNSW, load document. - Args: - vector_name (Embeddings): your collection name. - documents (List[str]): Text to insert. - Returns: - VectorStore: The MilvusStore vector store. - """ + Args: + vector_name (Embeddings): your collection name. + documents (List[str]): Text to insert. + Returns: + VectorStore: The MilvusStore vector store. + """ try: from pymilvus import ( Collection, @@ -237,13 +239,10 @@ def _add_texts( partition_name: Optional[str] = None, timeout: Optional[int] = None, ) -> List[str]: - """add text data into Milvus. - """ + """add text data into Milvus.""" insert_dict: Any = {self.text_field: list(texts)} try: - insert_dict[self.vector_field] = self.embedding.embed_documents( - list(texts) - ) + insert_dict[self.vector_field] = self.embedding.embed_documents(list(texts)) except NotImplementedError: insert_dict[self.vector_field] = [ self.embedding.embed_query(x) for x in texts diff --git a/pilot/vector_store/vector_store_base.py b/pilot/vector_store/vector_store_base.py index b483b3116..70888f5aa 100644 --- a/pilot/vector_store/vector_store_base.py +++ b/pilot/vector_store/vector_store_base.py @@ -12,4 +12,4 @@ def load_document(self, documents) -> None: @abstractmethod def similar_search(self, text, topk) -> None: """Initialize schema in vector database.""" - pass \ No newline at end of file + pass diff --git a/tests/unit/test_plugins.py b/tests/unit/test_plugins.py index 21dbaaf27..a2a3d2506 100644 --- a/tests/unit/test_plugins.py +++ b/tests/unit/test_plugins.py @@ -1,6 +1,6 @@ -import pytest import os +import pytest from pilot.configs.config import Config from pilot.plugins import ( @@ -15,10 +15,13 @@ PLUGIN_TEST_INIT_PY = "Auto-GPT-Plugin-Test-master/src/auto_gpt_vicuna/__init__.py" PLUGIN_TEST_OPENAI = "https://weathergpt.vercel.app/" + def test_inspect_zip_for_modules(): current_dir = os.getcwd() print(current_dir) - result = inspect_zip_for_modules(str(f"{current_dir}/{PLUGINS_TEST_DIR_TEMP}/{PLUGIN_TEST_ZIP_FILE}")) + result = inspect_zip_for_modules( + str(f"{current_dir}/{PLUGINS_TEST_DIR_TEMP}/{PLUGIN_TEST_ZIP_FILE}") + ) assert result == [PLUGIN_TEST_INIT_PY] @@ -99,6 +102,7 @@ def mock_config_openai_plugin(): class MockConfig: """Mock config object for testing the scan_plugins function""" + current_dir = os.getcwd() plugins_dir = f"{current_dir}/{PLUGINS_TEST_DIR_TEMP}/" plugins_openai = [PLUGIN_TEST_OPENAI] diff --git a/tools/knowlege_init.py b/tools/knowlege_init.py index 23ca33a80..df8697273 100644 --- a/tools/knowlege_init.py +++ b/tools/knowlege_init.py @@ -2,8 +2,13 @@ # -*- coding: utf-8 -*- import argparse -from pilot.configs.model_config import DATASETS_DIR, LLM_MODEL_CONFIG, VECTOR_SEARCH_TOP_K, VECTOR_STORE_CONFIG, \ - VECTOR_STORE_TYPE +from pilot.configs.model_config import ( + DATASETS_DIR, + LLM_MODEL_CONFIG, + VECTOR_SEARCH_TOP_K, + VECTOR_STORE_CONFIG, + VECTOR_STORE_TYPE, +) from pilot.source_embedding.knowledge_embedding import KnowledgeEmbedding @@ -16,22 +21,24 @@ def __init__(self, vector_store_config) -> None: self.vector_store_config = vector_store_config def knowledge_persist(self, file_path, append_mode): - """ knowledge persist """ + """knowledge persist""" kv = KnowledgeEmbedding( file_path=file_path, model_name=LLM_MODEL_CONFIG["text2vec"], - vector_store_config= self.vector_store_config) + vector_store_config=self.vector_store_config, + ) vector_store = kv.knowledge_persist_initialization(append_mode) return vector_store def query(self, q): - """Query similar doc from Vector """ + """Query similar doc from Vector""" vector_store = self.init_vector_store() docs = vector_store.similarity_search_with_score(q, k=self.top_k) for doc in docs: dc, s = doc yield s, dc + if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--vector_name", type=str, default="default") @@ -41,8 +48,12 @@ def query(self, q): vector_name = args.vector_name append_mode = args.append store_type = VECTOR_STORE_TYPE - vector_store_config = {"url": VECTOR_STORE_CONFIG["url"], "port": VECTOR_STORE_CONFIG["port"], "vector_store_name":vector_name} + vector_store_config = { + "url": VECTOR_STORE_CONFIG["url"], + "port": VECTOR_STORE_CONFIG["port"], + "vector_store_name": vector_name, + } print(vector_store_config) - kv = LocalKnowledgeInit(vector_store_config=vector_store_config) + kv = LocalKnowledgeInit(vector_store_config=vector_store_config) vector_store = kv.knowledge_persist(file_path=DATASETS_DIR, append_mode=append_mode) - print("your knowledge embedding success...") \ No newline at end of file + print("your knowledge embedding success...")