From 8037c811cf6c969bd16512e18317b5bc1ac8553e Mon Sep 17 00:00:00 2001 From: Fangyin Cheng Date: Fri, 23 Feb 2024 11:44:44 +0800 Subject: [PATCH] feat(core): Support RAG chat flow (#1185) --- dbgpt/core/awel/runner/local_runner.py | 7 ++- dbgpt/serve/flow/service/service.py | 59 ++++++++++++------- dbgpt/storage/vector_store/milvus_store.py | 6 +- dbgpt/util/dbgpts/cli.py | 12 +++- dbgpt/util/dbgpts/repo.py | 42 +++++++++---- dbgpt/util/function_utils.py | 50 +++++++++++++--- .../quickstart_basic_awel_workflow.md | 4 +- scripts/setup_autodl_env.sh | 6 +- setup.py | 2 + 9 files changed, 134 insertions(+), 54 deletions(-) diff --git a/dbgpt/core/awel/runner/local_runner.py b/dbgpt/core/awel/runner/local_runner.py index 480f3b89a..8ded92ffc 100644 --- a/dbgpt/core/awel/runner/local_runner.py +++ b/dbgpt/core/awel/runner/local_runner.py @@ -3,6 +3,7 @@ This runner will run the workflow in the current process. """ import logging +import traceback from typing import Any, Dict, List, Optional, Set, cast from dbgpt.component import SystemApp @@ -143,7 +144,11 @@ async def _execute_node( ) _skip_current_downstream_by_node_name(node, skip_nodes, skip_node_ids) except Exception as e: - logger.info(f"Run operator {node.node_id} error, error message: {str(e)}") + msg = traceback.format_exc() + logger.info( + f"Run operator {type(node)}({node.node_id}) error, error message: " + f"{msg}" + ) task_ctx.set_current_state(TaskState.FAILED) raise e diff --git a/dbgpt/serve/flow/service/service.py b/dbgpt/serve/flow/service/service.py index 0ed5e462c..025d0a832 100644 --- a/dbgpt/serve/flow/service/service.py +++ b/dbgpt/serve/flow/service/service.py @@ -370,22 +370,16 @@ def _parse_flow_category(self, dag: DAG) -> FlowCategory: return FlowCategory.COMMON -def _is_chat_flow_type(obj: Any, is_class: bool = False) -> bool: - try: - from dbgpt.model.utils.chatgpt_utils import OpenAIStreamingOutputOperator - except ImportError: - OpenAIStreamingOutputOperator = None +def _is_chat_flow_type(output_obj: Any, is_class: bool = False) -> bool: if is_class: return ( - obj == str - or obj == CommonLLMHttpResponseBody - or (OpenAIStreamingOutputOperator and obj == OpenAIStreamingOutputOperator) + output_obj == str + or output_obj == CommonLLMHttpResponseBody + or output_obj == ModelOutput ) else: chat_types = (str, CommonLLMHttpResponseBody) - if OpenAIStreamingOutputOperator: - chat_types += (OpenAIStreamingOutputOperator,) - return isinstance(obj, chat_types) + return isinstance(output_obj, chat_types) async def _chat_with_dag_task( @@ -439,29 +433,50 @@ async def _chat_with_dag_task( yield f"data:{full_text}\n\n" else: async for output in await task.call_stream(request): + str_msg = "" + should_return = False if isinstance(output, str): if output.strip(): - yield output + str_msg = output + elif isinstance(output, ModelOutput): + if output.error_code != 0: + str_msg = f"[SERVER_ERROR]{output.text}" + should_return = True + else: + str_msg = output.text else: - yield "data:[SERVER_ERROR]The output is not a stream format\n\n" - return + str_msg = ( + f"[SERVER_ERROR]The output is not a valid format" + f"({type(output)})" + ) + should_return = True + if str_msg: + str_msg = str_msg.replace("\n", "\\n") + yield f"data:{str_msg}\n\n" + if should_return: + return else: result = await task.call(request) + str_msg = "" if result is None: - yield "data:[SERVER_ERROR]The result is None\n\n" + str_msg = "[SERVER_ERROR]The result is None!" elif isinstance(result, str): - yield f"data:{result}\n\n" + str_msg = result elif isinstance(result, ModelOutput): if result.error_code != 0: - yield f"data:[SERVER_ERROR]{result.text}\n\n" + str_msg = f"[SERVER_ERROR]{result.text}" else: - yield f"data:{result.text}\n\n" + str_msg = result.text elif isinstance(result, CommonLLMHttpResponseBody): if result.error_code != 0: - yield f"data:[SERVER_ERROR]{result.text}\n\n" + str_msg = f"[SERVER_ERROR]{result.text}" else: - yield f"data:{result.text}\n\n" + str_msg = result.text elif isinstance(result, dict): - yield f"data:{json.dumps(result, ensure_ascii=False)}\n\n" + str_msg = json.dumps(result, ensure_ascii=False) else: - yield f"data:[SERVER_ERROR]The result is not a valid format({type(result)})\n\n" + str_msg = f"[SERVER_ERROR]The result is not a valid format({type(result)})" + + if str_msg: + str_msg = str_msg.replace("\n", "\\n") + yield f"data:{str_msg}\n\n" diff --git a/dbgpt/storage/vector_store/milvus_store.py b/dbgpt/storage/vector_store/milvus_store.py index a29e137d7..1529e967a 100644 --- a/dbgpt/storage/vector_store/milvus_store.py +++ b/dbgpt/storage/vector_store/milvus_store.py @@ -339,9 +339,7 @@ def similar_search_with_scores(self, text, topk, score_threshold) -> List[Chunk] self.vector_field = x.name _, docs_and_scores = self._search(text, topk) if any(score < 0.0 or score > 1.0 for _, score, id in docs_and_scores): - import warnings - - warnings.warn( + logger.warning( "similarity score need between" f" 0 and 1, got {docs_and_scores}" ) @@ -357,7 +355,7 @@ def similar_search_with_scores(self, text, topk, score_threshold) -> List[Chunk] if score >= score_threshold ] if len(docs_and_scores) == 0: - warnings.warn( + logger.warning( "No relevant docs were retrieved using the relevance score" f" threshold {score_threshold}" ) diff --git a/dbgpt/util/dbgpts/cli.py b/dbgpt/util/dbgpts/cli.py index 53b7980a9..c3bb00acc 100644 --- a/dbgpt/util/dbgpts/cli.py +++ b/dbgpt/util/dbgpts/cli.py @@ -56,17 +56,25 @@ def list_repos(): @click.command(name="add") @add_tap_options +@click.option( + "-b", + "--branch", + type=str, + default=None, + required=False, + help="The branch of the repository(Just for git repo)", +) @click.option( "--url", type=str, required=True, help="The URL of the repo", ) -def add_repo(repo: str, url: str): +def add_repo(repo: str, branch: str | None, url: str): """Add a new repo""" from .repo import add_repo - add_repo(repo, url) + add_repo(repo, url, branch) @click.command(name="remove") diff --git a/dbgpt/util/dbgpts/repo.py b/dbgpt/util/dbgpts/repo.py index e86c836a7..946107963 100644 --- a/dbgpt/util/dbgpts/repo.py +++ b/dbgpt/util/dbgpts/repo.py @@ -63,12 +63,13 @@ def _list_repos_details() -> List[Tuple[str, str]]: return results -def add_repo(repo: str, repo_url: str): +def add_repo(repo: str, repo_url: str, branch: str | None = None): """Add a new repo Args: repo (str): The name of the repo repo_url (str): The URL of the repo + branch (str): The branch of the repo """ exist_repos = list_repos() if repo in exist_repos and repo_url not in DEFAULT_REPO_MAP.values(): @@ -84,7 +85,7 @@ def add_repo(repo: str, repo_url: str): repo_group_dir = os.path.join(DBGPTS_REPO_HOME, repo_arr[0]) os.makedirs(repo_group_dir, exist_ok=True) if repo_url.startswith("http") or repo_url.startswith("git"): - clone_repo(repo, repo_group_dir, repo_name, repo_url) + clone_repo(repo, repo_group_dir, repo_name, repo_url, branch) elif os.path.isdir(repo_url): # Create soft link os.symlink(repo_url, os.path.join(repo_group_dir, repo_name)) @@ -106,7 +107,13 @@ def remove_repo(repo: str): logger.info(f"Repo '{repo}' removed successfully.") -def clone_repo(repo: str, repo_group_dir: str, repo_name: str, repo_url: str): +def clone_repo( + repo: str, + repo_group_dir: str, + repo_name: str, + repo_url: str, + branch: str | None = None, +): """Clone the specified repo Args: @@ -114,10 +121,22 @@ def clone_repo(repo: str, repo_group_dir: str, repo_name: str, repo_url: str): repo_group_dir (str): The directory of the repo group repo_name (str): The name of the repo repo_url (str): The URL of the repo + branch (str): The branch of the repo """ os.chdir(repo_group_dir) - subprocess.run(["git", "clone", repo_url, repo_name], check=True) - logger.info(f"Repo '{repo}' cloned from {repo_url} successfully.") + clone_command = ["git", "clone", repo_url, repo_name] + + # If the branch is specified, add it to the clone command + if branch: + clone_command += ["-b", branch] + + subprocess.run(clone_command, check=True) + if branch: + click.echo( + f"Repo '{repo}' cloned from {repo_url} with branch '{branch}' successfully." + ) + else: + click.echo(f"Repo '{repo}' cloned from {repo_url} successfully.") def update_repo(repo: str): @@ -217,7 +236,7 @@ def _write_install_metadata(name: str, repo: str, install_path: Path): def check_with_retry( name: str, - repo: str | None = None, + spec_repo: str | None = None, with_update: bool = False, is_first: bool = False, ) -> Tuple[str, Path] | None: @@ -225,18 +244,17 @@ def check_with_retry( Args: name (str): The name of the dbgpt - repo (str): The name of the repo + spec_repo (str): The name of the repo with_update (bool): Whether to update the repo before installing is_first (bool): Whether it's the first time to check the dbgpt - Returns: Tuple[str, Path] | None: The repo and the path of the dbgpt """ repos = _list_repos_details() - if repo: + if spec_repo: repos = list(filter(lambda x: x[0] == repo, repos)) if not repos: - logger.error(f"The specified repo '{repo}' does not exist.") + logger.error(f"The specified repo '{spec_repo}' does not exist.") return if is_first and with_update: for repo in repos: @@ -253,7 +271,9 @@ def check_with_retry( ): return repo[0], dbgpt_path if is_first: - return check_with_retry(name, repo, with_update=with_update, is_first=False) + return check_with_retry( + name, spec_repo, with_update=with_update, is_first=False + ) return None diff --git a/dbgpt/util/function_utils.py b/dbgpt/util/function_utils.py index ccce14ddd..c28a00f7a 100644 --- a/dbgpt/util/function_utils.py +++ b/dbgpt/util/function_utils.py @@ -3,6 +3,14 @@ from functools import wraps from typing import Any, get_args, get_origin, get_type_hints +from typeguard import check_type + + +def _is_typing(obj): + from typing import _Final # type: ignore + + return isinstance(obj, _Final) + def _is_instance_of_generic_type(obj, generic_type): """Check if an object is an instance of a generic type.""" @@ -18,18 +26,44 @@ def _is_instance_of_generic_type(obj, generic_type): return isinstance(obj, origin) # Check if object matches the generic origin (like list, dict) - if not isinstance(obj, origin): - return False + if not _is_typing(origin): + return isinstance(obj, origin) + + objs = [obj for _ in range(len(args))] # For each item in the object, check if it matches the corresponding type argument - for sub_obj, arg in zip(obj, args): + for sub_obj, arg in zip(objs, args): # Skip check if the type argument is Any - if arg is not Any and not isinstance(sub_obj, arg): - return False - + if arg is not Any: + if _is_typing(arg): + sub_args = get_args(arg) + if ( + sub_args + and not _is_typing(sub_args[0]) + and not isinstance(sub_obj, sub_args[0]) + ): + return False + elif not isinstance(sub_obj, arg): + return False return True +def _check_type(obj, t) -> bool: + try: + check_type(obj, t) + return True + except Exception: + return False + + +def _get_orders(obj, arg_types): + try: + orders = [i for i, t in enumerate(arg_types) if _check_type(obj, t)] + return orders[0] if orders else int(1e8) + except Exception: + return int(1e8) + + def _sort_args(func, args, kwargs): sig = inspect.signature(func) type_hints = get_type_hints(func) @@ -49,9 +83,7 @@ def _sort_args(func, args, kwargs): sorted_args = sorted( other_args, - key=lambda x: next( - i for i, t in enumerate(arg_types) if _is_instance_of_generic_type(x, t) - ), + key=lambda x: _get_orders(x, arg_types), ) return (*self_arg, *sorted_args), kwargs diff --git a/docs/docs/awel/cookbook/quickstart_basic_awel_workflow.md b/docs/docs/awel/cookbook/quickstart_basic_awel_workflow.md index 7fddd8ff0..6719044d1 100644 --- a/docs/docs/awel/cookbook/quickstart_basic_awel_workflow.md +++ b/docs/docs/awel/cookbook/quickstart_basic_awel_workflow.md @@ -4,7 +4,7 @@ At first, install dbgpt, and necessary dependencies: -```python +```shell pip install dbgpt --upgrade pip install openai ``` @@ -14,7 +14,7 @@ Create a python file `simple_sdk_llm_example_dag.py` and write the following con ```python from dbgpt.core import BaseOutputParser from dbgpt.core.awel import DAG -from dbgpt.core.operator import ( +from dbgpt.core.operators import ( PromptBuilderOperator, RequestBuilderOperator, ) diff --git a/scripts/setup_autodl_env.sh b/scripts/setup_autodl_env.sh index 13d08736d..236c1d35c 100644 --- a/scripts/setup_autodl_env.sh +++ b/scripts/setup_autodl_env.sh @@ -35,14 +35,14 @@ clone_repositories() { cd /root && git clone https://github.com/eosphoros-ai/DB-GPT.git mkdir -p /root/DB-GPT/models && cd /root/DB-GPT/models git clone https://huggingface.co/GanymedeNil/text2vec-large-chinese - git clone https://huggingface.co/THUDM/chatglm2-6b + git clone https://huggingface.co/Qwen/Qwen-1_8B-Chat rm -rf /root/DB-GPT/models/text2vec-large-chinese/.git - rm -rf /root/DB-GPT/models/chatglm2-6b/.git + rm -rf /root/DB-GPT/models/Qwen-1_8B-Chat/.git } install_dbgpt_packages() { conda activate dbgpt && cd /root/DB-GPT && pip install -e ".[default]" - cp .env.template .env && sed -i 's/LLM_MODEL=vicuna-13b-v1.5/LLM_MODEL=chatglm2-6b/' .env + cp .env.template .env && sed -i 's/LLM_MODEL=vicuna-13b-v1.5/LLM_MODEL=qwen-1.8b-chat/' .env } clean_up() { diff --git a/setup.py b/setup.py index a8074296f..17c0e2f23 100644 --- a/setup.py +++ b/setup.py @@ -367,6 +367,8 @@ def core_requires(): "python-dotenv==1.0.0", "cachetools", "pydantic<2,>=1", + # For AWEL type checking + "typeguard", ] # Simple command line dependencies setup_spec.extras["cli"] = setup_spec.extras["core"] + [