feat(core): Support RAG chat flow #1185

Merged · 2 commits · Feb 23, 2024

7 changes: 6 additions & 1 deletion dbgpt/core/awel/runner/local_runner.py
@@ -3,6 +3,7 @@
This runner will run the workflow in the current process.
"""
import logging
import traceback
from typing import Any, Dict, List, Optional, Set, cast

from dbgpt.component import SystemApp
@@ -143,7 +144,11 @@ async def _execute_node(
)
_skip_current_downstream_by_node_name(node, skip_nodes, skip_node_ids)
except Exception as e:
logger.info(f"Run operator {node.node_id} error, error message: {str(e)}")
msg = traceback.format_exc()
logger.info(
f"Run operator {type(node)}({node.node_id}) error, error message: "
f"{msg}"
)
task_ctx.set_current_state(TaskState.FAILED)
raise e
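The change above swaps `str(e)` for the full traceback when an operator fails. A minimal, self-contained sketch (not the project's runner code) of the pattern this hunk adopts:

```python
import logging
import traceback

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

try:
    1 / 0  # stand-in for a failing operator
except Exception:
    # format_exc() returns the complete formatted traceback of the active
    # exception, which is far more useful for debugging than str(e) alone.
    logger.info("Run operator failed, error message: %s", traceback.format_exc())
```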

59 changes: 37 additions & 22 deletions dbgpt/serve/flow/service/service.py
@@ -370,22 +370,16 @@ def _parse_flow_category(self, dag: DAG) -> FlowCategory:
return FlowCategory.COMMON


def _is_chat_flow_type(obj: Any, is_class: bool = False) -> bool:
try:
from dbgpt.model.utils.chatgpt_utils import OpenAIStreamingOutputOperator
except ImportError:
OpenAIStreamingOutputOperator = None
def _is_chat_flow_type(output_obj: Any, is_class: bool = False) -> bool:
if is_class:
return (
obj == str
or obj == CommonLLMHttpResponseBody
or (OpenAIStreamingOutputOperator and obj == OpenAIStreamingOutputOperator)
output_obj == str
or output_obj == CommonLLMHttpResponseBody
or output_obj == ModelOutput
)
else:
chat_types = (str, CommonLLMHttpResponseBody)
if OpenAIStreamingOutputOperator:
chat_types += (OpenAIStreamingOutputOperator,)
return isinstance(obj, chat_types)
return isinstance(output_obj, chat_types)


async def _chat_with_dag_task(
@@ -439,29 +433,50 @@ async def _chat_with_dag_task(
yield f"data:{full_text}\n\n"
else:
async for output in await task.call_stream(request):
str_msg = ""
should_return = False
if isinstance(output, str):
if output.strip():
yield output
str_msg = output
elif isinstance(output, ModelOutput):
if output.error_code != 0:
str_msg = f"[SERVER_ERROR]{output.text}"
should_return = True
else:
str_msg = output.text
else:
yield "data:[SERVER_ERROR]The output is not a stream format\n\n"
return
str_msg = (
f"[SERVER_ERROR]The output is not a valid format"
f"({type(output)})"
)
should_return = True
if str_msg:
str_msg = str_msg.replace("\n", "\\n")
yield f"data:{str_msg}\n\n"
if should_return:
return
else:
result = await task.call(request)
str_msg = ""
if result is None:
yield "data:[SERVER_ERROR]The result is None\n\n"
str_msg = "[SERVER_ERROR]The result is None!"
elif isinstance(result, str):
yield f"data:{result}\n\n"
str_msg = result
elif isinstance(result, ModelOutput):
if result.error_code != 0:
yield f"data:[SERVER_ERROR]{result.text}\n\n"
str_msg = f"[SERVER_ERROR]{result.text}"
else:
yield f"data:{result.text}\n\n"
str_msg = result.text
elif isinstance(result, CommonLLMHttpResponseBody):
if result.error_code != 0:
yield f"data:[SERVER_ERROR]{result.text}\n\n"
str_msg = f"[SERVER_ERROR]{result.text}"
else:
yield f"data:{result.text}\n\n"
str_msg = result.text
elif isinstance(result, dict):
yield f"data:{json.dumps(result, ensure_ascii=False)}\n\n"
str_msg = json.dumps(result, ensure_ascii=False)
else:
yield f"data:[SERVER_ERROR]The result is not a valid format({type(result)})\n\n"
str_msg = f"[SERVER_ERROR]The result is not a valid format({type(result)})"

if str_msg:
str_msg = str_msg.replace("\n", "\\n")
yield f"data:{str_msg}\n\n"
6 changes: 2 additions & 4 deletions dbgpt/storage/vector_store/milvus_store.py
@@ -339,9 +339,7 @@ def similar_search_with_scores(self, text, topk, score_threshold) -> List[Chunk]
self.vector_field = x.name
_, docs_and_scores = self._search(text, topk)
if any(score < 0.0 or score > 1.0 for _, score, id in docs_and_scores):
import warnings

warnings.warn(
logger.warning(
"similarity score need between" f" 0 and 1, got {docs_and_scores}"
)

@@ -357,7 +355,7 @@ def similar_search_with_scores(self, text, topk, score_threshold) -> List[Chunk]
if score >= score_threshold
]
if len(docs_and_scores) == 0:
warnings.warn(
logger.warning(
"No relevant docs were retrieved using the relevance score"
f" threshold {score_threshold}"
)
12 changes: 10 additions & 2 deletions dbgpt/util/dbgpts/cli.py
@@ -56,17 +56,25 @@ def list_repos():

@click.command(name="add")
@add_tap_options
@click.option(
"-b",
"--branch",
type=str,
default=None,
required=False,
help="The branch of the repository(Just for git repo)",
)
@click.option(
"--url",
type=str,
required=True,
help="The URL of the repo",
)
def add_repo(repo: str, url: str):
def add_repo(repo: str, branch: str | None, url: str):
"""Add a new repo"""
from .repo import add_repo

add_repo(repo, url)
add_repo(repo, url, branch)


@click.command(name="remove")
42 changes: 31 additions & 11 deletions dbgpt/util/dbgpts/repo.py
@@ -63,12 +63,13 @@ def _list_repos_details() -> List[Tuple[str, str]]:
return results


def add_repo(repo: str, repo_url: str):
def add_repo(repo: str, repo_url: str, branch: str | None = None):
"""Add a new repo

Args:
repo (str): The name of the repo
repo_url (str): The URL of the repo
branch (str): The branch of the repo
"""
exist_repos = list_repos()
if repo in exist_repos and repo_url not in DEFAULT_REPO_MAP.values():
@@ -84,7 +85,7 @@ def add_repo(repo: str, repo_url: str):
repo_group_dir = os.path.join(DBGPTS_REPO_HOME, repo_arr[0])
os.makedirs(repo_group_dir, exist_ok=True)
if repo_url.startswith("http") or repo_url.startswith("git"):
clone_repo(repo, repo_group_dir, repo_name, repo_url)
clone_repo(repo, repo_group_dir, repo_name, repo_url, branch)
elif os.path.isdir(repo_url):
# Create soft link
os.symlink(repo_url, os.path.join(repo_group_dir, repo_name))
@@ -106,18 +107,36 @@ def remove_repo(repo: str):
logger.info(f"Repo '{repo}' removed successfully.")


def clone_repo(repo: str, repo_group_dir: str, repo_name: str, repo_url: str):
def clone_repo(
repo: str,
repo_group_dir: str,
repo_name: str,
repo_url: str,
branch: str | None = None,
):
"""Clone the specified repo

Args:
repo (str): The name of the repo
repo_group_dir (str): The directory of the repo group
repo_name (str): The name of the repo
repo_url (str): The URL of the repo
branch (str): The branch of the repo
"""
os.chdir(repo_group_dir)
subprocess.run(["git", "clone", repo_url, repo_name], check=True)
logger.info(f"Repo '{repo}' cloned from {repo_url} successfully.")
clone_command = ["git", "clone", repo_url, repo_name]

# If the branch is specified, add it to the clone command
if branch:
clone_command += ["-b", branch]

subprocess.run(clone_command, check=True)
if branch:
click.echo(
f"Repo '{repo}' cloned from {repo_url} with branch '{branch}' successfully."
)
else:
click.echo(f"Repo '{repo}' cloned from {repo_url} successfully.")


def update_repo(repo: str):
@@ -217,26 +236,25 @@ def _write_install_metadata(name: str, repo: str, install_path: Path):

def check_with_retry(
name: str,
repo: str | None = None,
spec_repo: str | None = None,
with_update: bool = False,
is_first: bool = False,
) -> Tuple[str, Path] | None:
"""Check the specified dbgpt with retry.

Args:
name (str): The name of the dbgpt
repo (str): The name of the repo
spec_repo (str): The name of the repo
with_update (bool): Whether to update the repo before installing
is_first (bool): Whether it's the first time to check the dbgpt

Returns:
Tuple[str, Path] | None: The repo and the path of the dbgpt
"""
repos = _list_repos_details()
if repo:
if spec_repo:
repos = list(filter(lambda x: x[0] == repo, repos))
if not repos:
logger.error(f"The specified repo '{repo}' does not exist.")
logger.error(f"The specified repo '{spec_repo}' does not exist.")
return
if is_first and with_update:
for repo in repos:
Expand All @@ -253,7 +271,9 @@ def check_with_retry(
):
return repo[0], dbgpt_path
if is_first:
return check_with_retry(name, repo, with_update=with_update, is_first=False)
return check_with_retry(
name, spec_repo, with_update=with_update, is_first=False
)
return None
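With the new `branch` parameter threaded from the CLI option down to `git clone -b`, a repo can be pinned to a specific branch. A hedged usage sketch (the repo name and URL below are illustrative, not an official source):

```python
from dbgpt.util.dbgpts.repo import add_repo

# Clone a dbgpts repo pinned to the "main" branch instead of its default branch.
add_repo(
    "my-org/my-dbgpts",                         # repo name (illustrative)
    "https://github.com/my-org/my-dbgpts.git",  # repo URL (illustrative)
    branch="main",
)
```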


50 changes: 41 additions & 9 deletions dbgpt/util/function_utils.py
@@ -3,6 +3,14 @@
from functools import wraps
from typing import Any, get_args, get_origin, get_type_hints

from typeguard import check_type


def _is_typing(obj):
from typing import _Final # type: ignore

return isinstance(obj, _Final)


def _is_instance_of_generic_type(obj, generic_type):
"""Check if an object is an instance of a generic type."""
@@ -18,18 +26,44 @@ def _is_instance_of_generic_type(obj, generic_type):
return isinstance(obj, origin)

# Check if object matches the generic origin (like list, dict)
if not isinstance(obj, origin):
return False
if not _is_typing(origin):
return isinstance(obj, origin)

objs = [obj for _ in range(len(args))]

# For each item in the object, check if it matches the corresponding type argument
for sub_obj, arg in zip(obj, args):
for sub_obj, arg in zip(objs, args):
# Skip check if the type argument is Any
if arg is not Any and not isinstance(sub_obj, arg):
return False

if arg is not Any:
if _is_typing(arg):
sub_args = get_args(arg)
if (
sub_args
and not _is_typing(sub_args[0])
and not isinstance(sub_obj, sub_args[0])
):
return False
elif not isinstance(sub_obj, arg):
return False
return True


def _check_type(obj, t) -> bool:
try:
check_type(obj, t)
return True
except Exception:
return False


def _get_orders(obj, arg_types):
try:
orders = [i for i, t in enumerate(arg_types) if _check_type(obj, t)]
return orders[0] if orders else int(1e8)
except Exception:
return int(1e8)


def _sort_args(func, args, kwargs):
sig = inspect.signature(func)
type_hints = get_type_hints(func)
@@ -49,9 +83,7 @@ def _sort_args(func, args, kwargs):

sorted_args = sorted(
other_args,
key=lambda x: next(
i for i, t in enumerate(arg_types) if _is_instance_of_generic_type(x, t)
),
key=lambda x: _get_orders(x, arg_types),
)
return (*self_arg, *sorted_args), kwargs
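The refactor above replaces the strict `_is_instance_of_generic_type` lookup with a `typeguard`-based ranking, so arguments that match no type hint sort to the end instead of raising. A standalone sketch of that ranking idea (not the project's decorator; the names here are illustrative):

```python
from typing import Any, Dict, List

from typeguard import check_type


def _matches(obj: Any, expected_type: Any) -> bool:
    # typeguard raises on mismatch; treat any failure as "does not match".
    try:
        check_type(obj, expected_type)
        return True
    except Exception:
        return False


arg_types = [List[str], Dict[str, int], int]
values = [3, {"a": 1}, ["x", "y"]]

# Rank each value by the first annotated type it satisfies; unmatched values
# get a large sentinel so they sort last rather than breaking the sort.
reordered = sorted(
    values,
    key=lambda v: next((i for i, t in enumerate(arg_types) if _matches(v, t)), int(1e8)),
)
print(reordered)  # [['x', 'y'], {'a': 1}, 3]
```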

4 changes: 2 additions & 2 deletions docs/docs/awel/cookbook/quickstart_basic_awel_workflow.md
@@ -4,7 +4,7 @@

At first, install dbgpt, and necessary dependencies:

```python
```shell
pip install dbgpt --upgrade
pip install openai
```
@@ -14,7 +14,7 @@ Create a python file `simple_sdk_llm_example_dag.py` and write the following content:
```python
from dbgpt.core import BaseOutputParser
from dbgpt.core.awel import DAG
from dbgpt.core.operator import (
from dbgpt.core.operators import (
PromptBuilderOperator,
RequestBuilderOperator,
)
6 changes: 3 additions & 3 deletions scripts/setup_autodl_env.sh
@@ -35,14 +35,14 @@ clone_repositories() {
cd /root && git clone https://github.com/eosphoros-ai/DB-GPT.git
mkdir -p /root/DB-GPT/models && cd /root/DB-GPT/models
git clone https://huggingface.co/GanymedeNil/text2vec-large-chinese
git clone https://huggingface.co/THUDM/chatglm2-6b
git clone https://huggingface.co/Qwen/Qwen-1_8B-Chat
rm -rf /root/DB-GPT/models/text2vec-large-chinese/.git
rm -rf /root/DB-GPT/models/chatglm2-6b/.git
rm -rf /root/DB-GPT/models/Qwen-1_8B-Chat/.git
}

install_dbgpt_packages() {
conda activate dbgpt && cd /root/DB-GPT && pip install -e ".[default]"
cp .env.template .env && sed -i 's/LLM_MODEL=vicuna-13b-v1.5/LLM_MODEL=chatglm2-6b/' .env
cp .env.template .env && sed -i 's/LLM_MODEL=vicuna-13b-v1.5/LLM_MODEL=qwen-1.8b-chat/' .env
}

clean_up() {
2 changes: 2 additions & 0 deletions setup.py
@@ -367,6 +367,8 @@ def core_requires():
"python-dotenv==1.0.0",
"cachetools",
"pydantic<2,>=1",
# For AWEL type checking
"typeguard",
]
# Simple command line dependencies
setup_spec.extras["cli"] = setup_spec.extras["core"] + [