diff --git a/.flake8 b/.flake8 new file mode 100644 index 000000000..61608bcaf --- /dev/null +++ b/.flake8 @@ -0,0 +1,40 @@ +[flake8] +exclude = + .eggs/ + build/ + */tests/* + *_private +max-line-length = 88 +inline-quotes = " +ignore = + C408 + C417 + E121 + E123 + E126 + E203 + E226 + E24 + E704 + W503 + W504 + W605 + I + N + B001 + B002 + B003 + B004 + B005 + B007 + B008 + B009 + B010 + B011 + B012 + B013 + B014 + B015 + B016 + B017 +avoid-escape = no diff --git a/.isort.cfg b/.isort.cfg new file mode 100644 index 000000000..f2dfd3350 --- /dev/null +++ b/.isort.cfg @@ -0,0 +1,13 @@ +[settings] +# This is to make isort compatible with Black. See +# https://black.readthedocs.io/en/stable/the_black_code_style.html#how-black-wraps-lines. +line_length=88 +profile=black +multi_line_output=3 +include_trailing_comma=True +use_parentheses=True +float_to_top=True +filter_files=True + +skip_glob=examples/notebook/* +sections=FUTURE,STDLIB,THIRDPARTY,FIRSTPARTY,LOCALFOLDER,AFTERRAY diff --git a/.mypy.ini b/.mypy.ini new file mode 100644 index 000000000..70bae183a --- /dev/null +++ b/.mypy.ini @@ -0,0 +1,20 @@ +[mypy] +exclude = /tests/ +# plugins = pydantic.mypy + +[mypy-graphviz.*] +ignore_missing_imports = True + +[mypy-cachetools.*] +ignore_missing_imports = True + +[mypy-coloredlogs.*] +ignore_missing_imports = True + +[mypy-termcolor.*] +ignore_missing_imports = True + +[mypy-pydantic.*] +strict_optional = False +ignore_missing_imports = True +follow_imports = skip \ No newline at end of file diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 0a4e1ad34..2dd7e535d 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -20,4 +20,22 @@ repos: stages: [commit] pass_filenames: false args: [] + - id: python-test-doc + name: Python Doc Test + entry: make test-doc + language: system + exclude: '^dbgpt/app/static/|^web/' + types: [python] + stages: [commit] + pass_filenames: false + args: [] + - id: python-lint-mypy + name: Python Lint mypy + entry: make mypy + language: system + exclude: '^dbgpt/app/static/|^web/' + types: [python] + stages: [commit] + pass_filenames: false + args: [] diff --git a/Makefile b/Makefile index 57e83f42e..c8b06e756 100644 --- a/Makefile +++ b/Makefile @@ -14,7 +14,7 @@ setup: $(VENV)/bin/activate $(VENV)/bin/activate: $(VENV)/.venv-timestamp -$(VENV)/.venv-timestamp: setup.py +$(VENV)/.venv-timestamp: setup.py requirements # Create new virtual environment if setup.py has changed python3 -m venv $(VENV) $(VENV_BIN)/pip install --upgrade pip @@ -46,15 +46,14 @@ fmt: setup ## Format Python code # $(VENV_BIN)/blackdoc . $(VENV_BIN)/blackdoc dbgpt $(VENV_BIN)/blackdoc examples - # TODO: Type checking of Python code. - # https://github.com/python/mypy - # $(VENV_BIN)/mypy dbgpt - # TODO: uUse flake8 to enforce Python style guide. + # TODO: Use flake8 to enforce Python style guide. # https://flake8.pycqa.org/en/latest/ - # $(VENV_BIN)/flake8 dbgpt + $(VENV_BIN)/flake8 dbgpt/core/ + # TODO: More package checks with flake8. + .PHONY: pre-commit -pre-commit: fmt test ## Run formatting and unit tests before committing +pre-commit: fmt test test-doc mypy ## Run formatting and unit tests before committing test: $(VENV)/.testenv ## Run unit tests $(VENV_BIN)/pytest dbgpt
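The `.mypy.ini` above enables import-aware type checking, and the `make mypy` target added below wires it into `pre-commit`. Strict Optional handling is what drives the many `Optional[...]` annotation fixes later in this patch; a minimal illustrative sketch (not taken from the codebase) of the pattern the new gate enforces:

.. code-block:: python

    from typing import Optional

    def node_label(node_name: Optional[str] = None) -> str:
        # mypy rejects returning `node_name` directly here: it may be None.
        if node_name is None:
            return "unknown"
        return node_name

@@ -64,6 +63,12 @@ test-doc: $(VENV)/.testenv ## Run doctests # -k "not test_" skips tests that are not doctests.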
$(VENV_BIN)/pytest --doctest-modules -k "not test_" dbgpt/core +.PHONY: mypy +mypy: $(VENV)/.testenv ## Run mypy checks + # https://github.com/python/mypy + $(VENV_BIN)/mypy --config-file .mypy.ini dbgpt/core/ + # TODO: More package checks with mypy. + .PHONY: coverage coverage: setup ## Run tests and report coverage $(VENV_BIN)/pytest dbgpt --cov=dbgpt diff --git a/dbgpt/agent/plugin/commands/built_in/disply_type/show_chart_gen.py b/dbgpt/agent/plugin/commands/built_in/disply_type/show_chart_gen.py index bfd62fe07..33d4706f9 100644 --- a/dbgpt/agent/plugin/commands/built_in/disply_type/show_chart_gen.py +++ b/dbgpt/agent/plugin/commands/built_in/disply_type/show_chart_gen.py @@ -1,22 +1,22 @@ +import logging import os import uuid import matplotlib +import matplotlib.pyplot as plt +import matplotlib.ticker as mtick import pandas as pd import seaborn as sns +from matplotlib.font_manager import FontManager from pandas import DataFrame +from dbgpt.configs.model_config import PILOT_PATH +from dbgpt.util.string_utils import is_scientific_notation + from ...command_mange import command matplotlib.use("Agg") -import logging -import matplotlib.pyplot as plt -import matplotlib.ticker as mtick -from matplotlib.font_manager import FontManager - -from dbgpt.configs.model_config import PILOT_PATH -from dbgpt.util.string_utils import is_scientific_notation logger = logging.getLogger(__name__) diff --git a/dbgpt/app/base.py b/dbgpt/app/base.py index 54c6f5122..622440666 100644 --- a/dbgpt/app/base.py +++ b/dbgpt/app/base.py @@ -192,7 +192,8 @@ def _create_mysql_database(db_name: str, db_url: str, try_to_create_db: bool = F with engine_no_db.connect() as conn: conn.execute( DDL( - f"CREATE DATABASE {db_name} CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci" + f"CREATE DATABASE {db_name} CHARACTER SET utf8mb4 COLLATE " + f"utf8mb4_unicode_ci" ) ) logger.info(f"Database {db_name} successfully created") @@ -218,26 +219,31 @@ class WebServerParameters(BaseParameters): controller_addr: Optional[str] = field( default=None, metadata={ - "help": "The Model controller address to connect. If None, read model controller address from environment key `MODEL_SERVER`." + "help": "The Model controller address to connect. If None, read model " + "controller address from environment key `MODEL_SERVER`." }, ) model_name: str = field( default=None, metadata={ - "help": "The default model name to use. If None, read model name from environment key `LLM_MODEL`.", + "help": "The default model name to use. If None, read model name from " + "environment key `LLM_MODEL`.", "tags": "fixed", }, ) share: Optional[bool] = field( default=False, metadata={ - "help": "Whether to create a publicly shareable link for the interface. Creates an SSH tunnel to make your UI accessible from anywhere. " + "help": "Whether to create a publicly shareable link for the interface. " + "Creates an SSH tunnel to make your UI accessible from anywhere. " }, ) remote_embedding: Optional[bool] = field( default=False, metadata={ - "help": "Whether to enable remote embedding models. If it is True, you need to start a embedding model through `dbgpt start worker --worker_type text2vec --model_name xxx --model_path xxx`" + "help": "Whether to enable remote embedding models. 
If it is True, you need" " to start an embedding model through `dbgpt start worker --worker_type " "text2vec --model_name xxx --model_path xxx`" }, ) log_level: Optional[str] = field( @@ -286,3 +292,10 @@ class WebServerParameters(BaseParameters): "help": "The directories to search awel files, split by `,`", }, ) + default_thread_pool_size: Optional[int] = field( + default=None, + metadata={ + "help": "The default thread pool size. If None, " + "use the default config of the Python thread pool", + }, + ) diff --git a/dbgpt/app/component_configs.py b/dbgpt/app/component_configs.py index fefe57a16..d68c6de2f 100644 --- a/dbgpt/app/component_configs.py +++ b/dbgpt/app/component_configs.py @@ -25,7 +25,9 @@ def initialize_components( from dbgpt.model.cluster.controller.controller import controller # Register global default executor factory first - system_app.register(DefaultExecutorFactory) + system_app.register( + DefaultExecutorFactory, max_workers=param.default_thread_pool_size + ) system_app.register_instance(controller) from dbgpt.serve.agent.hub.controller import module_agent diff --git a/dbgpt/app/dbgpt_server.py b/dbgpt/app/dbgpt_server.py index b4328d257..12e79dab2 100644 --- a/dbgpt/app/dbgpt_server.py +++ b/dbgpt/app/dbgpt_server.py @@ -3,8 +3,6 @@ import sys from typing import List -ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) -sys.path.append(ROOT_PATH) from fastapi import FastAPI from fastapi.exceptions import RequestValidationError from fastapi.middleware.cors import CORSMiddleware @@ -41,6 +39,10 @@ setup_logging, ) +ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +sys.path.append(ROOT_PATH) + + static_file_path = os.path.join(ROOT_PATH, "dbgpt", "app/static") CFG = Config() diff --git a/dbgpt/app/knowledge/_cli/knowledge_client.py b/dbgpt/app/knowledge/_cli/knowledge_client.py index 7f2b5ea1b..ac87a11e6 100644 --- a/dbgpt/app/knowledge/_cli/knowledge_client.py +++ b/dbgpt/app/knowledge/_cli/knowledge_client.py @@ -5,6 +5,7 @@ from urllib.parse import urljoin import requests +from prettytable import PrettyTable from dbgpt.app.knowledge.request.request import ( ChunkQueryRequest, @@ -193,9 +194,6 @@ def upload(filename: str): return -from prettytable import PrettyTable - - class _KnowledgeVisualizer: def __init__(self, api_address: str, out_format: str): self.client = KnowledgeApiClient(api_address) diff --git a/dbgpt/app/llmserver.py b/dbgpt/app/llmserver.py index be07c1a02..8fedc5a6e 100644 --- a/dbgpt/app/llmserver.py +++ b/dbgpt/app/llmserver.py @@ -4,13 +4,14 @@ import os import sys -ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) -sys.path.append(ROOT_PATH) - from dbgpt._private.config import Config from dbgpt.configs.model_config import EMBEDDING_MODEL_CONFIG, LLM_MODEL_CONFIG from dbgpt.model.cluster import run_worker_manager +ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +sys.path.append(ROOT_PATH) + + CFG = Config() model_path = LLM_MODEL_CONFIG.get(CFG.LLM_MODEL)
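The base_chat.py change below moves conversation persistence off the event loop with `blocking_func_to_async`, using the executor wired up above. A rough sketch of the contract that helper provides (an approximation for illustration; the real implementation in `dbgpt.util` may differ in details):

.. code-block:: python

    import asyncio
    from concurrent.futures import Executor
    from functools import partial

    async def blocking_func_to_async(executor: Executor, func, *args, **kwargs):
        # Run a blocking callable in the executor so the event loop stays free.
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(executor, partial(func, *args, **kwargs))

diff --git a/dbgpt/app/scene/base_chat.py b/dbgpt/app/scene/base_chat.py index e2647fb4c..f48feceb0 100644 --- a/dbgpt/app/scene/base_chat.py +++ b/dbgpt/app/scene/base_chat.py @@ -313,8 +313,9 @@ async def stream_call(self): ) ### store current conversation span.end(metadata={"error": str(e)}) - # self.memory.append(self.current_message) - self.current_message.end_current_round() + await blocking_func_to_async( + self._executor,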
self.current_message.end_current_round + ) async def nostream_call(self): payload = await self._build_model_request() @@ -381,8 +382,9 @@ async def nostream_call(self): ) span.end(metadata={"error": str(e)}) ### store dialogue - # self.memory.append(self.current_message) - self.current_message.end_current_round() + await blocking_func_to_async( + self._executor, self.current_message.end_current_round + ) return self.current_ai_response() async def get_llm_response(self): diff --git a/dbgpt/component.py b/dbgpt/component.py index f094213a1..2468906b9 100644 --- a/dbgpt/component.py +++ b/dbgpt/component.py @@ -104,28 +104,37 @@ def init_app(self, system_app: SystemApp): @classmethod def get_instance( - cls, + cls: Type[T], system_app: SystemApp, default_component=_EMPTY_DEFAULT_COMPONENT, - or_register_component: Type[BaseComponent] = None, + or_register_component: Optional[Type[T]] = None, *args, **kwargs, - ) -> BaseComponent: + ) -> T: """Get the current component instance. Args: system_app (SystemApp): The system app default_component : The default component instance if not retrieve by name - or_register_component (Type[BaseComponent]): The new component to register if not retrieve by name + or_register_component (Type[T]): The new component to register if not retrieved by name Returns: - BaseComponent: The component instance + T: The component instance """ + # Check for keyword argument conflicts + if "default_component" in kwargs: + raise ValueError( + "default_component argument given in both fixed and **kwargs" + ) + if "or_register_component" in kwargs: + raise ValueError( + "or_register_component argument given in both fixed and **kwargs" + ) + kwargs["default_component"] = default_component + kwargs["or_register_component"] = or_register_component return system_app.get_component( cls.name, cls, *args, **kwargs, ) @@ -159,11 +168,11 @@ def config(self) -> AppConfig: """Returns the internal AppConfig.""" return self._app_config - def register(self, component: Type[BaseComponent], *args, **kwargs) -> T: + def register(self, component: Type[T], *args, **kwargs) -> T: """Register a new component by its type.
Args: - component (Type[BaseComponent]): The component class to register + component (Type[T]): The component class to register Returns: T: The instance of registered component @@ -198,7 +207,7 @@ def get_component( name: Union[str, ComponentType], component_type: Type[T], default_component=_EMPTY_DEFAULT_COMPONENT, - or_register_component: Type[BaseComponent] = None, + or_register_component: Optional[Type[T]] = None, *args, **kwargs, ) -> T: @@ -208,7 +217,7 @@ def get_component( name (Union[str, ComponentType]): Component name component_type (Type[T]): The type of current retrieve component default_component : The default component instance if not retrieve by name - or_register_component (Type[BaseComponent]): The new component to register if not retrieve by name + or_register_component (Type[T]): The new component to register if not retrieved by name Returns: T: The instance retrieved by component name diff --git a/dbgpt/core/__init__.py b/dbgpt/core/__init__.py index 6cee5890e..fe955bffb 100644 --- a/dbgpt/core/__init__.py +++ b/dbgpt/core/__init__.py @@ -1,11 +1,13 @@ -from dbgpt.core.interface.cache import ( +"""The core module contains the core interfaces and classes for dbgpt.""" + +from dbgpt.core.interface.cache import ( # noqa: F401 CacheClient, CacheConfig, CacheKey, CachePolicy, CacheValue, ) -from dbgpt.core.interface.llm import ( +from dbgpt.core.interface.llm import ( # noqa: F401 DefaultMessageConverter, LLMClient, MessageConverter, @@ -16,7 +18,7 @@ ModelRequest, ModelRequestContext, ) -from dbgpt.core.interface.message import ( +from dbgpt.core.interface.message import ( # noqa: F401 AIMessage, BaseMessage, ConversationIdentifier, @@ -29,8 +31,11 @@ StorageConversation, SystemMessage, ) -from dbgpt.core.interface.output_parser import BaseOutputParser, SQLOutputParser -from dbgpt.core.interface.prompt import ( +from dbgpt.core.interface.output_parser import ( # noqa: F401 + BaseOutputParser, + SQLOutputParser, +) +from dbgpt.core.interface.prompt import ( # noqa: F401 BasePromptTemplate, ChatPromptTemplate, HumanPromptTemplate, @@ -40,8 +45,8 @@ StoragePromptTemplate, SystemPromptTemplate, ) -from dbgpt.core.interface.serialization import Serializable, Serializer -from dbgpt.core.interface.storage import ( +from dbgpt.core.interface.serialization import Serializable, Serializer # noqa: F401 +from dbgpt.core.interface.storage import ( # noqa: F401 DefaultStorageItemAdapter, InMemoryStorage, QuerySpec, diff --git a/dbgpt/core/_private/example_base.py b/dbgpt/core/_private/example_base.py index 687059b67..f94dde4dc 100644 --- a/dbgpt/core/_private/example_base.py +++ b/dbgpt/core/_private/example_base.py @@ -1,27 +1,34 @@ +"""Example selector base class.""" + from abc import ABC from enum import Enum -from typing import List +from typing import List, Optional from dbgpt._private.pydantic import BaseModel class ExampleType(Enum): + """Example type.""" + ONE_SHOT = "one_shot" FEW_SHOT = "few_shot" class ExampleSelector(BaseModel, ABC): + """Example selector base class.""" + examples_record: List[dict] use_example: bool = False type: str = ExampleType.ONE_SHOT.value def examples(self, count: int = 2): + """Return examples.""" if ExampleType.ONE_SHOT.value == self.type: - return self.__one_show_context() + return self.__one_shot_context() else: return self.__few_shot_context(count) - def __few_shot_context(self, count: int = 2) -> List[dict]: + def __few_shot_context(self, count: int = 2) -> Optional[List[dict]]: """ Use 2 or more examples, default 2 Returns: example text @@ -31,14
+38,14 @@ def __few_shot_context(self, count: int = 2) -> List[dict]: return need_use return None - def __one_show_context(self) -> dict: + def __one_shot_context(self) -> Optional[dict]: """ Use one examples Returns: """ if self.use_example: - need_use = self.examples_record[:1] + need_use = self.examples_record[-1] return need_use return None diff --git a/dbgpt/core/_private/prompt_registry.py b/dbgpt/core/_private/prompt_registry.py index 5684b1f6a..bf0a58daa 100644 --- a/dbgpt/core/_private/prompt_registry.py +++ b/dbgpt/core/_private/prompt_registry.py @@ -1,8 +1,12 @@ +"""Prompt template registry. + +This module is deprecated. We will remove it in the future. +""" #!/usr/bin/env python3 # -*- coding: utf-8 -*- from collections import defaultdict -from typing import Dict, List +from typing import Dict, List, Optional _DEFAULT_MODEL_KEY = "___default_prompt_template_model_key__" _DEFUALT_LANGUAGE_KEY = "___default_prompt_template_language_key__" @@ -14,15 +18,15 @@ class PromptTemplateRegistry: """ def __init__(self) -> None: - self.registry = defaultdict(dict) + self.registry = defaultdict(dict) # type: ignore def register( self, prompt_template, language: str = "en", is_default: bool = False, - model_names: List[str] = None, - scene_name: str = None, + model_names: Optional[List[str]] = None, + scene_name: Optional[str] = None, ) -> None: """Register prompt template with scene name, language registry dict format: @@ -43,7 +47,7 @@ def register( if not scene_name: raise ValueError("Prompt template scene name cannot be empty") if not model_names: - model_names: List[str] = [_DEFAULT_MODEL_KEY] + model_names = [_DEFAULT_MODEL_KEY] scene_registry = self.registry[scene_name] _register_scene_prompt_template( scene_registry, prompt_template, language, model_names ) @@ -64,7 +68,7 @@ def get_prompt_template( scene_name: str, language: str, model_name: str, - proxyllm_backend: str = None, + proxyllm_backend: Optional[str] = None, ): """Get prompt template with scene name, language and model name proxyllm_backend: see CFG.PROXYLLM_BACKEND
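The AWEL module docstring rewritten below describes an operator-composition DSL. A minimal usage sketch of that style, assuming `DAG` and `MapOperator` are exported from `dbgpt.core.awel` and that call data is fed to the root node (the exact call semantics may differ):

.. code-block:: python

    import asyncio

    from dbgpt.core.awel import DAG, MapOperator

    with DAG("awel_hello_dag") as dag:
        double = MapOperator(map_function=lambda x: x * 2)
        triple = MapOperator(map_function=lambda x: x * 3)
        double >> triple  # `>>` wires `triple` as downstream of `double`

    # Calling the leaf node runs the DAG; call_data=2 should yield
    # 2 * 2 * 3 == 12 under these assumptions.
    print(asyncio.run(triple.call(call_data=2)))

diff --git a/dbgpt/core/awel/__init__.py b/dbgpt/core/awel/__init__.py index 01c91153c..84e2fc775 100644 --- a/dbgpt/core/awel/__init__.py +++ b/dbgpt/core/awel/__init__.py @@ -1,9 +1,10 @@ -"""Agentic Workflow Expression Language (AWEL) +"""Agentic Workflow Expression Language (AWEL). -Note: - -AWEL is still an experimental feature and only opens the lowest level API. -The stability of this API cannot be guaranteed at present. +Agentic Workflow Expression Language(AWEL) is a set of intelligent agent workflow +expression language specially designed for large model application development. It +provides great functionality and flexibility. Through the AWEL API, you can focus on +the development of business logic for LLM applications without paying attention to +cumbersome model and environment details.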
""" @@ -71,10 +72,12 @@ "TransformStreamAbsOperator", "HttpTrigger", "setup_dev_environment", + "_is_async_iterator", ] def initialize_awel(system_app: SystemApp, dag_dirs: List[str]): + """Initialize AWEL.""" from .dag.base import DAGVar from .dag.dag_manager import DAGManager from .operator.base import initialize_runner @@ -92,13 +95,13 @@ def initialize_awel(system_app: SystemApp, dag_dirs: List[str]): def setup_dev_environment( dags: List[DAG], - host: Optional[str] = "127.0.0.1", - port: Optional[int] = 5555, + host: str = "127.0.0.1", + port: int = 5555, logging_level: Optional[str] = None, logger_filename: Optional[str] = None, show_dag_graph: Optional[bool] = True, ) -> None: - """Setup a development environment for AWEL. + """Run AWEL in development environment. Just using in development environment, not production environment. @@ -107,9 +110,11 @@ def setup_dev_environment( host (Optional[str], optional): The host. Defaults to "127.0.0.1" port (Optional[int], optional): The port. Defaults to 5555. logging_level (Optional[str], optional): The logging level. Defaults to None. - logger_filename (Optional[str], optional): The logger filename. Defaults to None. - show_dag_graph (Optional[bool], optional): Whether show the DAG graph. Defaults to True. - If True, the DAG graph will be saved to a file and open it automatically. + logger_filename (Optional[str], optional): The logger filename. + Defaults to None. + show_dag_graph (Optional[bool], optional): Whether show the DAG graph. + Defaults to True. If True, the DAG graph will be saved to a file and open + it automatically. """ import uvicorn from fastapi import FastAPI @@ -138,7 +143,9 @@ def setup_dev_environment( logger.info(f"Visualize DAG {str(dag)} to {dag_graph_file}") except Exception as e: logger.warning( - f"Visualize DAG {str(dag)} failed: {e}, if your system has no graphviz, you can install it by `pip install graphviz` or `sudo apt install graphviz`" + f"Visualize DAG {str(dag)} failed: {e}, if your system has no " + f"graphviz, you can install it by `pip install graphviz` or " + f"`sudo apt install graphviz`" ) for trigger in dag.trigger_nodes: trigger_manager.register_trigger(trigger) diff --git a/dbgpt/core/awel/base.py b/dbgpt/core/awel/base.py index 97cb8ad05..a0cb26cb8 100644 --- a/dbgpt/core/awel/base.py +++ b/dbgpt/core/awel/base.py @@ -1,7 +1,10 @@ +"""Base classes for AWEL.""" from abc import ABC, abstractmethod class Trigger(ABC): + """Base class for trigger.""" + @abstractmethod async def trigger(self) -> None: """Trigger the workflow or a specific operation in the workflow.""" diff --git a/dbgpt/core/awel/dag/__init__.py b/dbgpt/core/awel/dag/__init__.py index e69de29bb..59d722c63 100644 --- a/dbgpt/core/awel/dag/__init__.py +++ b/dbgpt/core/awel/dag/__init__.py @@ -0,0 +1 @@ +"""The module of DAGs.""" diff --git a/dbgpt/core/awel/dag/base.py b/dbgpt/core/awel/dag/base.py index 7f1017c7a..5db95f244 100644 --- a/dbgpt/core/awel/dag/base.py +++ b/dbgpt/core/awel/dag/base.py @@ -1,3 +1,7 @@ +"""The base module of DAG. + +DAG is the core component of AWEL, it is used to define the relationship between tasks. 
+""" import asyncio import contextvars import logging @@ -6,7 +10,7 @@ from abc import ABC, abstractmethod from collections import deque from concurrent.futures import Executor -from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Union +from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Union, cast from dbgpt.component import SystemApp @@ -27,86 +31,108 @@ def _is_async_context(): class DependencyMixin(ABC): + """The mixin class for DAGNode. + + This class defines the interface for setting upstream and downstream nodes. + + And it also implements the operator << and >> for setting upstream + and downstream nodes. + """ + @abstractmethod - def set_upstream(self, nodes: DependencyType) -> "DependencyMixin": + def set_upstream(self, nodes: DependencyType) -> None: """Set one or more upstream nodes for this node. Args: nodes (DependencyType): Upstream nodes to be set to current node. - Returns: - DependencyMixin: Returns self to allow method chaining. - Raises: - ValueError: If no upstream nodes are provided or if an argument is not a DependencyMixin. + ValueError: If no upstream nodes are provided or if an argument is + not a DependencyMixin. """ @abstractmethod - def set_downstream(self, nodes: DependencyType) -> "DependencyMixin": + def set_downstream(self, nodes: DependencyType) -> None: """Set one or more downstream nodes for this node. Args: nodes (DependencyType): Downstream nodes to be set to current node. - Returns: - DependencyMixin: Returns self to allow method chaining. - Raises: - ValueError: If no downstream nodes are provided or if an argument is not a DependencyMixin. + ValueError: If no downstream nodes are provided or if an argument is + not a DependencyMixin. """ def __lshift__(self, nodes: DependencyType) -> DependencyType: - """Implements self << nodes + """Set upstream nodes for current node. - Example: + Implements: self << nodes. - .. code-block:: python + Example: + .. code-block:: python - # means node.set_upstream(input_node) - node << input_node + # means node.set_upstream(input_node) + node << input_node + # means node2.set_upstream([input_node]) + node2 << [input_node] - # means node2.set_upstream([input_node]) - node2 << [input_node] """ self.set_upstream(nodes) return nodes def __rshift__(self, nodes: DependencyType) -> DependencyType: - """Implements self >> nodes + """Set downstream nodes for current node. - Example: + Implements: self >> nodes. - .. code-block:: python + Examples: + .. code-block:: python - # means node.set_downstream(next_node) - node >> next_node + # means node.set_downstream(next_node) + node >> next_node - # means node2.set_downstream([next_node]) - node2 >> [next_node] + # means node2.set_downstream([next_node]) + node2 >> [next_node] """ self.set_downstream(nodes) return nodes def __rrshift__(self, nodes: DependencyType) -> "DependencyMixin": - """Implements [node] >> self""" + """Set upstream nodes for current node. + + Implements: [node] >> self + """ self.__lshift__(nodes) return self def __rlshift__(self, nodes: DependencyType) -> "DependencyMixin": - """Implements [node] << self""" + """Set downstream nodes for current node. 
+ + Implements: [node] << self + """ self.__rshift__(nodes) return self class DAGVar: + """The DAGVar is used to store the current DAG context.""" + _thread_local = threading.local() - _async_local = contextvars.ContextVar("current_dag_stack", default=deque()) - _system_app: SystemApp = None - _executor: Executor = None + _async_local: contextvars.ContextVar = contextvars.ContextVar( + "current_dag_stack", default=deque() + ) + _system_app: Optional[SystemApp] = None + # The executor for current DAG, this is used to run some sync tasks in an async DAG + _executor: Optional[Executor] = None @classmethod def enter_dag(cls, dag) -> None: + """Enter a DAG context. + + Args: + dag (DAG): The DAG to enter + """ is_async = _is_async_context() if is_async: stack = cls._async_local.get() @@ -119,6 +145,7 @@ def enter_dag(cls, dag) -> None: @classmethod def exit_dag(cls) -> None: + """Exit a DAG context.""" is_async = _is_async_context() if is_async: stack = cls._async_local.get() @@ -134,6 +161,11 @@ def exit_dag(cls) -> None: @classmethod def get_current_dag(cls) -> Optional["DAG"]: + """Get the current DAG. + + Returns: + Optional[DAG]: The current DAG + """ is_async = _is_async_context() if is_async: stack = cls._async_local.get() @@ -147,36 +179,56 @@ def get_current_dag(cls) -> Optional["DAG"]: return None @classmethod - def get_current_system_app(cls) -> SystemApp: + def get_current_system_app(cls) -> Optional[SystemApp]: + """Get the current system app. + + Returns: + Optional[SystemApp]: The current system app + """ # if not cls._system_app: # raise RuntimeError("System APP not set for DAGVar") return cls._system_app @classmethod def set_current_system_app(cls, system_app: SystemApp) -> None: + """Set the current system app. + + Args: + system_app (SystemApp): The system app to set + """ if cls._system_app: - logger.warn("System APP has already set, nothing to do") + logger.warning("System APP has already been set, nothing to do") else: cls._system_app = system_app @classmethod - def get_executor(cls) -> Executor: + def get_executor(cls) -> Optional[Executor]: + """Get the current executor. + + Returns: + Optional[Executor]: The current executor + """ return cls._executor @classmethod def set_executor(cls, executor: Executor) -> None: + """Set the current executor. + + Args: + executor (Executor): The executor to set + """ cls._executor = executor class DAGLifecycle: - """The lifecycle of DAG""" + """The lifecycle of DAG.""" async def before_dag_run(self): - """The callback before DAG run""" + """Execute before DAG run.""" pass async def after_dag_end(self): - """The callback after DAG end, + """Execute after DAG end. This method may be called multiple times, please make sure it is idempotent. """ @@ -184,6 +236,8 @@ async def after_dag_end(self): class DAGNode(DAGLifecycle, DependencyMixin, ABC): + """The base class of DAGNode.""" + resource_group: Optional[ResourceGroup] = None """The resource group of current DAGNode""" @@ -196,6 +250,17 @@ def __init__( executor: Optional[Executor] = None, **kwargs, ) -> None: + """Initialize a DAGNode. + + Args: + dag (Optional["DAG"], optional): The DAG to add this node to. + Defaults to None. + node_id (Optional[str], optional): The node id. Defaults to None. + node_name (Optional[str], optional): The node name. Defaults to None. + system_app (Optional[SystemApp], optional): The system app. + Defaults to None. + executor (Optional[Executor], optional): The executor. Defaults to None.
+ """ super().__init__() self._upstream: List["DAGNode"] = [] self._downstream: List["DAGNode"] = [] @@ -206,24 +271,28 @@ def __init__( self._executor: Optional[Executor] = executor or DAGVar.get_executor() if not node_id and self._dag: node_id = self._dag._new_node_id() - self._node_id: str = node_id - self._node_name: str = node_name + self._node_id: Optional[str] = node_id + self._node_name: Optional[str] = node_name @property def node_id(self) -> str: + """Return the node id of current DAGNode.""" + if not self._node_id: + raise ValueError("Node id not set for current DAGNode") return self._node_id @property @abstractmethod def dev_mode(self) -> bool: - """Whether current DAGNode is in dev mode""" + """Whether current DAGNode is in dev mode.""" @property - def system_app(self) -> SystemApp: + def system_app(self) -> Optional[SystemApp]: + """Return the system app of current DAGNode.""" return self._system_app def set_system_app(self, system_app: SystemApp) -> None: - """Set system app for current DAGNode + """Set system app for current DAGNode. Args: system_app (SystemApp): The system app @@ -231,50 +300,97 @@ def set_system_app(self, system_app: SystemApp) -> None: self._system_app = system_app def set_node_id(self, node_id: str) -> None: + """Set node id for current DAGNode. + + Args: + node_id (str): The node id + """ self._node_id = node_id def __hash__(self) -> int: + """Return the hash value of current DAGNode. + + If the node_id is not None, return the hash value of node_id. + """ if self.node_id: return hash(self.node_id) else: return super().__hash__() def __eq__(self, other: Any) -> bool: + """Return whether the current DAGNode is equal to other DAGNode.""" if not isinstance(other, DAGNode): return False return self.node_id == other.node_id @property - def node_name(self) -> str: + def node_name(self) -> Optional[str]: + """Return the node name of current DAGNode. + + Returns: + Optional[str]: The node name of current DAGNode + """ return self._node_name @property - def dag(self) -> "DAG": + def dag(self) -> Optional["DAG"]: + """Return the DAG of current DAGNode. + + Returns: + Optional["DAG"]: The DAG of current DAGNode + """ return self._dag - def set_upstream(self, nodes: DependencyType) -> "DAGNode": + def set_upstream(self, nodes: DependencyType) -> None: + """Set upstream nodes for current node. + + Args: + nodes (DependencyType): Upstream nodes to be set to current node. + """ self.set_dependency(nodes) - def set_downstream(self, nodes: DependencyType) -> "DAGNode": + def set_downstream(self, nodes: DependencyType) -> None: + """Set downstream nodes for current node. + + Args: + nodes (DependencyType): Downstream nodes to be set to current node. + """ self.set_dependency(nodes, is_upstream=False) @property def upstream(self) -> List["DAGNode"]: + """Return the upstream nodes of current DAGNode. + + Returns: + List["DAGNode"]: The upstream nodes of current DAGNode + """ return self._upstream @property def downstream(self) -> List["DAGNode"]: + """Return the downstream nodes of current DAGNode. + + Returns: + List["DAGNode"]: The downstream nodes of current DAGNode + """ return self._downstream def set_dependency(self, nodes: DependencyType, is_upstream: bool = True) -> None: + """Set dependency for current node. + + Args: + nodes (DependencyType): The nodes to set dependency to current node. + is_upstream (bool, optional): Whether set upstream nodes. Defaults to True. 
+ """ if not isinstance(nodes, Sequence): nodes = [nodes] if not all(isinstance(node, DAGNode) for node in nodes): raise ValueError( - "all nodes to set dependency to current node must be instance of 'DAGNode'" + "all nodes to set dependency to current node must be instance " + "of 'DAGNode'" ) - nodes: Sequence[DAGNode] = nodes - dags = set([node.dag for node in nodes if node.dag]) + nodes = cast(Sequence[DAGNode], nodes) + dags = set([node.dag for node in nodes if node.dag]) # noqa: C403 if self.dag: dags.add(self.dag) if not dags: @@ -302,6 +418,7 @@ def set_dependency(self, nodes: DependencyType, is_upstream: bool = True) -> Non node._upstream.append(self) def __repr__(self): + """Return the representation of current DAGNode.""" cls_name = self.__class__.__name__ if self.node_name and self.node_name: return f"{cls_name}(node_id={self.node_id}, node_name={self.node_name})" @@ -313,6 +430,7 @@ def __repr__(self): return f"{cls_name}" def __str__(self): + """Return the string of current DAGNode.""" return self.__repr__() @@ -321,7 +439,7 @@ def _build_task_key(task_name: str, key: str) -> str: class DAGContext: - """The context of current DAG, created when the DAG is running + """The context of current DAG, created when the DAG is running. Every DAG has been triggered will create a new DAGContext. """ @@ -329,22 +447,32 @@ class DAGContext: def __init__( self, streaming_call: bool = False, - node_to_outputs: Dict[str, TaskContext] = None, - node_name_to_ids: Dict[str, str] = None, + node_to_outputs: Optional[Dict[str, TaskContext]] = None, + node_name_to_ids: Optional[Dict[str, str]] = None, ) -> None: + """Initialize a DAGContext. + + Args: + streaming_call (bool, optional): Whether the current DAG is streaming call. + Defaults to False. + node_to_outputs (Optional[Dict[str, TaskContext]], optional): + The task outputs of current DAG. Defaults to None. + node_name_to_ids (Optional[Dict[str, str]], optional): + The task name to task id mapping. Defaults to None. + """ if not node_to_outputs: node_to_outputs = {} if not node_name_to_ids: node_name_to_ids = {} self._streaming_call = streaming_call - self._curr_task_ctx = None + self._curr_task_ctx: Optional[TaskContext] = None self._share_data: Dict[str, Any] = {} - self._node_to_outputs = node_to_outputs - self._node_name_to_ids = node_name_to_ids + self._node_to_outputs: Dict[str, TaskContext] = node_to_outputs + self._node_name_to_ids: Dict[str, str] = node_name_to_ids @property def _task_outputs(self) -> Dict[str, TaskContext]: - """The task outputs of current DAG + """Return the task outputs of current DAG. Just use for internal for now. Returns: @@ -354,18 +482,28 @@ def _task_outputs(self) -> Dict[str, TaskContext]: @property def current_task_context(self) -> TaskContext: + """Return the current task context.""" + if not self._curr_task_ctx: + raise RuntimeError("Current task context not set") return self._curr_task_ctx @property def streaming_call(self) -> bool: - """Whether the current DAG is streaming call""" + """Whether the current DAG is streaming call.""" return self._streaming_call def set_current_task_context(self, _curr_task_ctx: TaskContext) -> None: + """Set the current task context. + + When the task is running, the current task context + will be set to the task context. + + TODO: We should support parallel task running in the future. + """ self._curr_task_ctx = _curr_task_ctx def get_task_output(self, task_name: str) -> TaskOutput: - """Get the task output by task name + """Get the task output by task name. 
Args: task_name (str): The task name @@ -376,22 +514,41 @@ if task_name is None: raise ValueError("task_name can't be None") node_id = self._node_name_to_ids.get(task_name) - if node_id: + if not node_id: raise ValueError(f"Task name {task_name} not exists in DAG") - return self._task_outputs.get(node_id).task_output + task_output = self._task_outputs.get(node_id) + if not task_output: + raise ValueError(f"Task output for task {task_name} not exists") + return task_output.task_output async def get_from_share_data(self, key: str) -> Any: + """Get share data by key. + + Args: + key (str): The share data key + + Returns: + Any: The share data; you can cast it to the real type + """ return self._share_data.get(key) async def save_to_share_data( self, key: str, data: Any, overwrite: bool = False ) -> None: + """Save share data by key. + + Args: + key (str): The share data key + data (Any): The share data + overwrite (bool): Whether to overwrite the share data if the key + already exists. Defaults to False. + """ if key in self._share_data and not overwrite: raise ValueError(f"Share data key {key} already exists") self._share_data[key] = data async def get_task_share_data(self, task_name: str, key: str) -> Any: - """Get share data by task name and key + """Get share data by task name and key. Args: task_name (str): The task name @@ -409,14 +566,14 @@ async def get_task_share_data(self, task_name: str, key: str) -> Any: async def save_task_share_data( self, task_name: str, key: str, data: Any, overwrite: bool = False ) -> None: - """Save share data by task name and key + """Save share data by task name and key. Args: task_name (str): The task name key (str): The share data key data (Any): The share data - overwrite (bool): Whether overwrite the share data if the key already exists. - Defaults to None. + overwrite (bool): Whether to overwrite the share data if the key + already exists. Defaults to False. Raises: ValueError: If the share data key already exists and overwrite is not True
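The share-data methods above give each running DAG a context-scoped key/value store. A short sketch of the contract, written directly against the methods shown in this hunk:

.. code-block:: python

    async def exchange(ctx: DAGContext) -> None:
        await ctx.save_to_share_data("db_name", "orders")
        # Saving the same key again without overwrite=True raises ValueError.
        await ctx.save_to_share_data("db_name", "orders_v2", overwrite=True)
        assert await ctx.get_from_share_data("db_name") == "orders_v2"

@@ -429,15 +586,22 @@ async def save_task_share_data( class DAG: + """The DAG class. + + Manage the DAG nodes and the relationship between them.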
+ """ + def __init__( self, dag_id: str, resource_group: Optional[ResourceGroup] = None ) -> None: + """Initialize a DAG.""" self._dag_id = dag_id self.node_map: Dict[str, DAGNode] = {} self.node_name_to_node: Dict[str, DAGNode] = {} - self._root_nodes: List[DAGNode] = None - self._leaf_nodes: List[DAGNode] = None - self._trigger_nodes: List[DAGNode] = None + self._root_nodes: List[DAGNode] = [] + self._leaf_nodes: List[DAGNode] = [] + self._trigger_nodes: List[DAGNode] = [] + self._resource_group: Optional[ResourceGroup] = resource_group def _append_node(self, node: DAGNode) -> None: if node.node_id in self.node_map: @@ -448,22 +612,26 @@ def _append_node(self, node: DAGNode) -> None: f"Node name {node.node_name} already exists in DAG {self.dag_id}" ) self.node_name_to_node[node.node_name] = node - self.node_map[node.node_id] = node + node_id = node.node_id + if not node_id: + raise ValueError("Node id can't be None") + self.node_map[node_id] = node # clear cached nodes - self._root_nodes = None - self._leaf_nodes = None + self._root_nodes = [] + self._leaf_nodes = [] def _new_node_id(self) -> str: return str(uuid.uuid4()) @property def dag_id(self) -> str: + """Return the dag id of current DAG.""" return self._dag_id def _build(self) -> None: from ..operator.common_operator import TriggerOperator - nodes = set() + nodes: Set[DAGNode] = set() for _, node in self.node_map.items(): nodes = nodes.union(_get_nodes(node)) self._root_nodes = list(set(filter(lambda x: not x.upstream, nodes))) @@ -474,7 +642,7 @@ def _build(self) -> None: @property def root_nodes(self) -> List[DAGNode]: - """The root nodes of current DAG + """Return the root nodes of current DAG. Returns: List[DAGNode]: The root nodes of current DAG, no repeat @@ -485,7 +653,7 @@ def root_nodes(self) -> List[DAGNode]: @property def leaf_nodes(self) -> List[DAGNode]: - """The leaf nodes of current DAG + """Return the leaf nodes of current DAG. Returns: List[DAGNode]: The leaf nodes of current DAG, no repeat @@ -496,7 +664,7 @@ def leaf_nodes(self) -> List[DAGNode]: @property def trigger_nodes(self) -> List[DAGNode]: - """The trigger nodes of current DAG + """Return the trigger nodes of current DAG. Returns: List[DAGNode]: The trigger nodes of current DAG, no repeat @@ -506,34 +674,42 @@ def trigger_nodes(self) -> List[DAGNode]: return self._trigger_nodes async def _after_dag_end(self) -> None: - """The callback after DAG end""" + """Execute after DAG end.""" tasks = [] for node in self.node_map.values(): tasks.append(node.after_dag_end()) await asyncio.gather(*tasks) def print_tree(self) -> None: - """Print the DAG tree""" + """Print the DAG tree""" # noqa: D400 _print_format_dag_tree(self) def visualize_dag(self, view: bool = True, **kwargs) -> Optional[str]: - """Create the DAG graph""" + """Visualize the DAG. + + Args: + view (bool, optional): Whether view the DAG graph. Defaults to True, + if True, it will open the graph file with your default viewer. 
+ """ self.print_tree() return _visualize_dag(self, view=view, **kwargs) def __enter__(self): + """Enter a DAG context.""" DAGVar.enter_dag(self) return self def __exit__(self, exc_type, exc_val, exc_tb): + """Exit a DAG context.""" DAGVar.exit_dag() def __repr__(self): + """Return the representation of current DAG.""" return f"DAG(dag_id={self.dag_id})" -def _get_nodes(node: DAGNode, is_upstream: Optional[bool] = True) -> set[DAGNode]: - nodes = set() +def _get_nodes(node: DAGNode, is_upstream: Optional[bool] = True) -> Set[DAGNode]: + nodes: Set[DAGNode] = set() if not node: return nodes nodes.add(node) @@ -553,7 +729,7 @@ def _print_dag( level: int = 0, prefix: str = "", last: bool = True, - level_dict: Dict[str, Any] = None, + level_dict: Optional[Dict[int, Any]] = None, ): if level_dict is None: level_dict = {} @@ -606,7 +782,7 @@ def _handle_dag_nodes( def _visualize_dag(dag: DAG, view: bool = True, **kwargs) -> Optional[str]: - """Visualize the DAG + """Visualize the DAG. Args: dag (DAG): The DAG to visualize @@ -641,7 +817,7 @@ def add_edges(node: DAGNode): filename = kwargs["filename"] del kwargs["filename"] - if not "directory" in kwargs: + if "directory" not in kwargs: from dbgpt.configs.model_config import LOGDIR kwargs["directory"] = LOGDIR diff --git a/dbgpt/core/awel/dag/dag_manager.py b/dbgpt/core/awel/dag/dag_manager.py index 90214e934..ffa09bdc3 100644 --- a/dbgpt/core/awel/dag/dag_manager.py +++ b/dbgpt/core/awel/dag/dag_manager.py @@ -1,3 +1,8 @@ +"""DAGManager is a component of AWEL, it is used to manage DAGs. + +DAGManager will load DAGs from dag_dirs, and register the trigger nodes +to TriggerManager. +""" import logging from typing import Dict, List @@ -10,24 +15,35 @@ class DAGManager(BaseComponent): + """The component of DAGManager.""" + name = ComponentType.AWEL_DAG_MANAGER def __init__(self, system_app: SystemApp, dag_dirs: List[str]): + """Initialize a DAGManager. + + Args: + system_app (SystemApp): The system app. + dag_dirs (List[str]): The directories to load DAGs. + """ super().__init__(system_app) self.dag_loader = LocalFileDAGLoader(dag_dirs) self.system_app = system_app self.dag_map: Dict[str, DAG] = {} def init_app(self, system_app: SystemApp): + """Initialize the DAGManager.""" self.system_app = system_app def load_dags(self): + """Load DAGs from dag_dirs.""" dags = self.dag_loader.load_dags() triggers = [] for dag in dags: dag_id = dag.dag_id if dag_id in self.dag_map: raise ValueError(f"Load DAG error, DAG ID {dag_id} has already exist") + self.dag_map[dag_id] = dag triggers += dag.trigger_nodes from ..trigger.trigger_manager import DefaultTriggerManager diff --git a/dbgpt/core/awel/dag/loader.py b/dbgpt/core/awel/dag/loader.py index 325d4733a..bd74f7b47 100644 --- a/dbgpt/core/awel/dag/loader.py +++ b/dbgpt/core/awel/dag/loader.py @@ -1,3 +1,8 @@ +"""DAG loader. + +DAGLoader will load DAGs from dag_dirs or other sources. +Now only support load DAGs from local files. +""" import hashlib import logging import os @@ -12,16 +17,26 @@ class DAGLoader(ABC): + """Abstract base class representing a loader for loading DAGs.""" + @abstractmethod def load_dags(self) -> List[DAG]: - """Load dags""" + """Load dags.""" class LocalFileDAGLoader(DAGLoader): + """DAG loader for loading DAGs from local files.""" + def __init__(self, dag_dirs: List[str]) -> None: + """Initialize a LocalFileDAGLoader. + + Args: + dag_dirs (List[str]): The directories to load DAGs. 
+ """ self._dag_dirs = dag_dirs def load_dags(self) -> List[DAG]: + """Load dags from local files.""" dags = [] for filepath in self._dag_dirs: if not os.path.exists(filepath): @@ -70,7 +85,7 @@ def parse(mod_name, filepath): sys.modules[spec.name] = new_module loader.exec_module(new_module) return [new_module] - except Exception as e: + except Exception: msg = traceback.format_exc() logger.error(f"Failed to import: {filepath}, error message: {msg}") # TODO save error message diff --git a/dbgpt/core/awel/operator/__init__.py b/dbgpt/core/awel/operator/__init__.py index e69de29bb..da9f7d7d1 100644 --- a/dbgpt/core/awel/operator/__init__.py +++ b/dbgpt/core/awel/operator/__init__.py @@ -0,0 +1 @@ +"""The module of operator.""" diff --git a/dbgpt/core/awel/operator/base.py b/dbgpt/core/awel/operator/base.py index b56ed9b86..0af890e7c 100644 --- a/dbgpt/core/awel/operator/base.py +++ b/dbgpt/core/awel/operator/base.py @@ -1,7 +1,7 @@ +"""Base classes for operators that can be executed within a workflow.""" import asyncio import functools from abc import ABC, ABCMeta, abstractmethod -from inspect import signature from types import FunctionType from typing import ( Any, @@ -9,7 +9,6 @@ Dict, Generic, Iterator, - List, Optional, TypeVar, Union, @@ -21,7 +20,6 @@ AsyncToSyncIterator, BlockingFunction, DefaultExecutorFactory, - ExecutorFactory, blocking_func_to_async, ) @@ -54,13 +52,15 @@ async def execute_workflow( node (RunnableDAGNode): The starting node of the workflow to be executed. call_data (CALL_DATA): The data pass to root operator node. streaming_call (bool): Whether the call is a streaming call. - exist_dag_ctx (DAGContext): The context of the DAG when this node is run, Defaults to None. + exist_dag_ctx (DAGContext): The context of the DAG when this node is run, + Defaults to None. Returns: - DAGContext: The context after executing the workflow, containing the final state and data. + DAGContext: The context after executing the workflow, containing the final + state and data. 
""" -default_runner: WorkflowRunner = None +default_runner: Optional[WorkflowRunner] = None class BaseOperatorMeta(ABCMeta): @@ -68,8 +68,7 @@ class BaseOperatorMeta(ABCMeta): @classmethod def _apply_defaults(cls, func: F) -> F: - sig_cache = signature(func) - + # sig_cache = signature(func) @functools.wraps(func) def apply_defaults(self: "BaseOperator", *args: Any, **kwargs: Any) -> Any: dag: Optional[DAG] = kwargs.get("dag") or DAGVar.get_current_dag() @@ -81,7 +80,7 @@ def apply_defaults(self: "BaseOperator", *args: Any, **kwargs: Any) -> Any: if not executor: if system_app: executor = system_app.get_component( - ComponentType.EXECUTOR_DEFAULT, ExecutorFactory + ComponentType.EXECUTOR_DEFAULT, DefaultExecutorFactory ).create() else: executor = DefaultExecutorFactory().create() @@ -107,9 +106,10 @@ def apply_defaults(self: "BaseOperator", *args: Any, **kwargs: Any) -> Any: real_obj = func(self, *args, **kwargs) return real_obj - return cast(T, apply_defaults) + return cast(F, apply_defaults) def __new__(cls, name, bases, namespace, **kwargs): + """Create a new BaseOperator class with default arguments.""" new_cls = super().__new__(cls, name, bases, namespace, **kwargs) new_cls.__init__ = cls._apply_defaults(new_cls.__init__) return new_cls @@ -126,13 +126,14 @@ def __init__( task_id: Optional[str] = None, task_name: Optional[str] = None, dag: Optional[DAG] = None, - runner: WorkflowRunner = None, + runner: Optional[WorkflowRunner] = None, **kwargs, ) -> None: - """Initializes a BaseOperator with an optional workflow runner. + """Create a BaseOperator with an optional workflow runner. Args: - runner (WorkflowRunner, optional): The runner used to execute the workflow. Defaults to None. + runner (WorkflowRunner, optional): The runner used to execute the workflow. + Defaults to None. """ super().__init__(node_id=task_id, node_name=task_name, dag=dag, **kwargs) if not runner: @@ -141,19 +142,24 @@ def __init__( runner = DefaultWorkflowRunner() self._runner: WorkflowRunner = runner - self._dag_ctx: DAGContext = None + self._dag_ctx: Optional[DAGContext] = None @property def current_dag_context(self) -> DAGContext: + """Return the current DAG context.""" + if not self._dag_ctx: + raise ValueError("DAGContext is not set") return self._dag_ctx @property def dev_mode(self) -> bool: """Whether the operator is in dev mode. + In production mode, the default runner is not None. Returns: - bool: Whether the operator is in dev mode. True if the default runner is None. + bool: Whether the operator is in dev mode. True if the + default runner is None. """ return default_runner is None @@ -186,7 +192,8 @@ async def call( Args: call_data (CALL_DATA): The data pass to root operator node. - dag_ctx (DAGContext): The context of the DAG when this node is run, Defaults to None. + dag_ctx (DAGContext): The context of the DAG when this node is run, + Defaults to None. Returns: OUT: The output of the node after execution. """ @@ -196,7 +203,9 @@ async def call( return out_ctx.current_task_context.task_output.output def _blocking_call( - self, call_data: Optional[CALL_DATA] = None, loop: asyncio.BaseEventLoop = None + self, + call_data: Optional[CALL_DATA] = None, + loop: Optional[asyncio.BaseEventLoop] = None, ) -> OUT: """Execute the node and return the output. 
@@ -213,6 +222,7 @@ def _blocking_call( if not loop: loop = get_or_create_event_loop() + loop = cast(asyncio.BaseEventLoop, loop) return loop.run_until_complete(self.call(call_data)) async def call_stream( @@ -226,7 +236,8 @@ def _blocking_call( Args: call_data (CALL_DATA): The data pass to root operator node. - dag_ctx (DAGContext): The context of the DAG when this node is run, Defaults to None. + dag_ctx (DAGContext): The context of the DAG when this node is run. + Defaults to None. Returns: AsyncIterator[OUT]: An asynchronous iterator over the output stream. @@ -237,7 +248,9 @@ async def call_stream( return out_ctx.current_task_context.task_output.output_stream def _blocking_call_stream( - self, call_data: Optional[CALL_DATA] = None, loop: asyncio.BaseEventLoop = None + self, + call_data: Optional[CALL_DATA] = None, + loop: Optional[asyncio.BaseEventLoop] = None, ) -> Iterator[OUT]: """Execute the node and return the output as a stream. @@ -259,9 +272,22 @@ def _blocking_call_stream( async def blocking_func_to_async( self, func: BlockingFunction, *args, **kwargs ) -> Any: + """Execute a blocking function asynchronously. + + In AWEL, the operators are executed asynchronously. However, + some functions are blocking, so we run them in a separate thread. + + Args: + func (BlockingFunction): The blocking function to be executed. + *args: Positional arguments for the function. + **kwargs: Keyword arguments for the function. + """ + if not self._executor: + raise ValueError("Executor is not set") return await blocking_func_to_async(self._executor, func, *args, **kwargs) def initialize_runner(runner: WorkflowRunner): + """Initialize the default runner.""" global default_runner default_runner = runner diff --git a/dbgpt/core/awel/operator/common_operator.py b/dbgpt/core/awel/operator/common_operator.py index 4fbb266f9..616c1c2a1 100644 --- a/dbgpt/core/awel/operator/common_operator.py +++ b/dbgpt/core/awel/operator/common_operator.py @@ -1,7 +1,7 @@ +"""Common operators of AWEL.""" import asyncio import logging from typing import ( - Any, AsyncIterator, Awaitable, Callable, @@ -13,7 +13,17 @@ ) from ..dag.base import DAGContext -from ..task.base import IN, OUT, InputContext, InputSource, TaskContext, TaskOutput +from ..task.base import ( + IN, + OUT, + InputContext, + InputSource, + JoinFunc, + MapFunc, + ReduceFunc, + TaskContext, + TaskOutput, +) from .base import BaseOperator logger = logging.getLogger(__name__) @@ -25,7 +35,12 @@ class JoinOperator(BaseOperator, Generic[OUT]): This node type is useful for combining the outputs of upstream nodes. """ - def __init__(self, combine_function, **kwargs): + def __init__(self, combine_function: JoinFunc, **kwargs): + """Create a JoinOperator with a combine function. + + Args: + combine_function: A function that defines how to combine inputs. + """ super().__init__(**kwargs) if not callable(combine_function): raise ValueError("combine_function must be callable") @@ -33,6 +48,7 @@ async def _do_run(self, dag_ctx: DAGContext) -> TaskOutput[OUT]: """Run the join operation on the DAG context's inputs. + Args: dag_ctx (DAGContext): The current context of the DAG.
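A hedged sketch of `JoinOperator` combining two upstream branches (assumes `DAG` and `MapOperator` from this package, and that the combine function receives parent outputs in wiring order):

.. code-block:: python

    with DAG("join_example") as dag:
        start = MapOperator(map_function=lambda x: x)
        plus_one = MapOperator(map_function=lambda x: x + 1)
        times_two = MapOperator(map_function=lambda x: x * 2)
        combined = JoinOperator(combine_function=lambda a, b: a + b)
        start >> plus_one >> combined
        start >> times_two >> combined

    # For call_data=3: (3 + 1) + (3 * 2) == 10, under these assumptions.

@@ -50,8 +66,10 @@ class ReduceStreamOperator(BaseOperator, Generic[IN, OUT]): - def __init__(self, reduce_function=None, **kwargs): - """Initializes a ReduceStreamOperator with a combine function.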
+ """Operator that reduces inputs using a custom reduce function.""" + + def __init__(self, reduce_function: Optional[ReduceFunc] = None, **kwargs): + """Create a ReduceStreamOperator with a combine function. Args: combine_function: A function that defines how to combine inputs. @@ -89,6 +107,7 @@ async def _do_run(self, dag_ctx: DAGContext) -> TaskOutput[OUT]: return reduce_output async def reduce(self, input_value: AsyncIterator[IN]) -> OUT: + """Reduce the input stream to a single value.""" raise NotImplementedError @@ -99,8 +118,8 @@ class MapOperator(BaseOperator, Generic[IN, OUT]): passes the transformed data downstream. """ - def __init__(self, map_function=None, **kwargs): - """Initializes a MapDAGNode with a mapping function. + def __init__(self, map_function: Optional[MapFunc] = None, **kwargs): + """Create a MapDAGNode with a mapping function. Args: map_function: A function that defines how to map the input data. @@ -133,13 +152,18 @@ async def _do_run(self, dag_ctx: DAGContext) -> TaskOutput[OUT]: if not call_data and not curr_task_ctx.task_input.check_single_parent(): num_parents = len(curr_task_ctx.task_input.parent_outputs) raise ValueError( - f"task {curr_task_ctx.task_id} MapDAGNode expects single parent, now number of parents: {num_parents}" + f"task {curr_task_ctx.task_id} MapDAGNode expects single parent," + f"now number of parents: {num_parents}" ) map_function = self.map_function or self.map if call_data: - call_data = await curr_task_ctx._call_data_to_output() - output = await call_data.map(map_function) + wrapped_call_data = await curr_task_ctx._call_data_to_output() + if not wrapped_call_data: + raise ValueError( + f"task {curr_task_ctx.task_id} MapDAGNode expects wrapped_call_data" + ) + output: TaskOutput[OUT] = await wrapped_call_data.map(map_function) curr_task_ctx.set_task_output(output) return output @@ -150,6 +174,7 @@ async def _do_run(self, dag_ctx: DAGContext) -> TaskOutput[OUT]: return output async def map(self, input_value: IN) -> OUT: + """Map the input data to a new value.""" raise NotImplementedError @@ -161,6 +186,11 @@ class BranchOperator(BaseOperator, Generic[IN, OUT]): This node filters its input data using a branching function and allows for conditional paths in the workflow. + + If a branch function returns True, the corresponding task will be executed. + otherwise, the corresponding task will be skipped, and the output of + this skip node will be set to `SKIP_DATA` + """ def __init__( @@ -168,11 +198,11 @@ def __init__( branches: Optional[Dict[BranchFunc[IN], Union[BaseOperator, str]]] = None, **kwargs, ): - """ - Initializes a BranchDAGNode with a branching function. + """Create a BranchDAGNode with a branching function. Args: - branches (Dict[BranchFunc[IN], Union[BaseOperator, str]]): Dict of function that defines the branching condition. + branches (Dict[BranchFunc[IN], Union[BaseOperator, str]]): + Dict of function that defines the branching condition. Raises: ValueError: If the branch_function is not callable. 
@@ -183,7 +213,9 @@ def __init__( if not callable(branch_function): raise ValueError("branch_function must be callable") if isinstance(value, BaseOperator): - branches[branch_function] = value.node_name or value.node_name + if not value.node_name: + raise ValueError("branch node name must be set") + branches[branch_function] = value.node_name self._branches = branches async def _do_run(self, dag_ctx: DAGContext) -> TaskOutput[OUT]: @@ -210,7 +242,7 @@ async def _do_run(self, dag_ctx: DAGContext) -> TaskOutput[OUT]: branches = await self.branches() branch_func_tasks = [] - branch_nodes: List[str] = [] + branch_nodes: List[Union[BaseOperator, str]] = [] for func, node_name in branches.items(): branch_nodes.append(node_name) branch_func_tasks.append( @@ -225,20 +257,25 @@ async def _do_run(self, dag_ctx: DAGContext) -> TaskOutput[OUT]: node_name = branch_nodes[i] branch_out = ctx.parent_outputs[0].task_output logger.info( - f"branch_input_ctxs {i} result {branch_out.output}, is_empty: {branch_out.is_empty}" + f"branch_input_ctxs {i} result {branch_out.output}, " + f"is_empty: {branch_out.is_empty}" ) - if ctx.parent_outputs[0].task_output.is_empty: + if ctx.parent_outputs[0].task_output.is_none: logger.info(f"Skip node name {node_name}") skip_node_names.append(node_name) curr_task_ctx.update_metadata("skip_node_names", skip_node_names) return parent_output async def branches(self) -> Dict[BranchFunc[IN], Union[BaseOperator, str]]: + """Return branch logic based on input data.""" raise NotImplementedError class InputOperator(BaseOperator, Generic[OUT]): + """Operator node that reads data from an input source.""" + def __init__(self, input_source: InputSource[OUT], **kwargs) -> None: + """Create an InputOperator with an input source.""" super().__init__(**kwargs) self._input_source = input_source @@ -250,7 +287,10 @@ async def _do_run(self, dag_ctx: DAGContext) -> TaskOutput[OUT]: class TriggerOperator(InputOperator, Generic[OUT]): + """Operator node that triggers the DAG to run.""" + def __init__(self, **kwargs) -> None: + """Create a TriggerOperator.""" from ..task.task_impl import SimpleCallDataInputSource super().__init__(input_source=SimpleCallDataInputSource(), **kwargs) diff --git a/dbgpt/core/awel/operator/stream_operator.py b/dbgpt/core/awel/operator/stream_operator.py index 526249704..8893a51f9 100644 --- a/dbgpt/core/awel/operator/stream_operator.py +++ b/dbgpt/core/awel/operator/stream_operator.py @@ -1,3 +1,4 @@ +"""The module of stream operator.""" from abc import ABC, abstractmethod from typing import AsyncIterator, Generic @@ -7,12 +8,18 @@ class StreamifyAbsOperator(BaseOperator[OUT], ABC, Generic[IN, OUT]): + """An abstract operator that converts a value of IN to an AsyncIterator[OUT].""" + async def _do_run(self, dag_ctx: DAGContext) -> TaskOutput[OUT]: curr_task_ctx: TaskContext[OUT] = dag_ctx.current_task_context call_data = curr_task_ctx.call_data if call_data: - call_data = await curr_task_ctx._call_data_to_output() - output = await call_data.streamify(self.streamify) + wrapped_call_data = await curr_task_ctx._call_data_to_output() + if not wrapped_call_data: + raise ValueError( + f"task {curr_task_ctx.task_id} StreamifyAbsOperator expects wrapped_call_data" + ) + output = await wrapped_call_data.streamify(self.streamify) curr_task_ctx.set_task_output(output) return output output = await curr_task_ctx.task_input.parent_outputs[0].task_output.streamify( @@ -23,26 +30,28 @@ async def _do_run(self, dag_ctx: DAGContext) -> TaskOutput[OUT]: @abstractmethod async def streamify(self,
input_value: IN) -> AsyncIterator[OUT]: - """Convert a value of IN to an AsyncIterator[OUT] + """Convert a value of IN to an AsyncIterator[OUT]. Args: input_value (IN): The data of parent operator's output - Example: + Examples: + .. code-block:: python - .. code-block:: python + class MyStreamOperator(StreamifyAbsOperator[int, int]): + async def streamify(self, input_value: int) -> AsyncIterator[int]: + for i in range(input_value): + yield i - class MyStreamOperator(StreamifyAbsOperator[int, int]): - async def streamify(self, input_value: int) -> AsyncIterator[int]: - for i in range(input_value): - yield i """ class UnstreamifyAbsOperator(BaseOperator[OUT], Generic[IN, OUT]): + """An abstract operator that converts a value of AsyncIterator[IN] to an OUT.""" + async def _do_run(self, dag_ctx: DAGContext) -> TaskOutput[OUT]: curr_task_ctx: TaskContext[OUT] = dag_ctx.current_task_context - output = await curr_task_ctx.task_input.parent_outputs[ + output: TaskOutput[OUT] = await curr_task_ctx.task_input.parent_outputs[ 0 ].task_output.unstreamify(self.unstreamify) curr_task_ctx.set_task_output(output) @@ -56,24 +65,30 @@ async def unstreamify(self, input_value: AsyncIterator[IN]) -> OUT: input_value (AsyncIterator[IN])): The data of parent operator's output Example: - - .. code-block:: python - - class MyUnstreamOperator(UnstreamifyAbsOperator[int, int]): - async def unstreamify(self, input_value: AsyncIterator[int]) -> int: - value_cnt = 0 - async for v in input_value: - value_cnt += 1 - return value_cnt + .. code-block:: python + + class MyUnstreamOperator(UnstreamifyAbsOperator[int, int]): + async def unstreamify(self, input_value: AsyncIterator[int]) -> int: + value_cnt = 0 + async for v in input_value: + value_cnt += 1 + return value_cnt """ class TransformStreamAbsOperator(BaseOperator[OUT], Generic[IN, OUT]): + """Transform one stream into another stream. + + An abstract operator that transforms a value of + AsyncIterator[IN] to another AsyncIterator[OUT]. + """ + async def _do_run(self, dag_ctx: DAGContext) -> TaskOutput[OUT]: curr_task_ctx: TaskContext[OUT] = dag_ctx.current_task_context - output = await curr_task_ctx.task_input.parent_outputs[ + output: TaskOutput[OUT] = await curr_task_ctx.task_input.parent_outputs[ 0 ].task_output.transform_stream(self.transform_stream) + curr_task_ctx.set_task_output(output) return output @@ -81,19 +96,18 @@ async def _do_run(self, dag_ctx: DAGContext) -> TaskOutput[OUT]: async def transform_stream( self, input_value: AsyncIterator[IN] ) -> AsyncIterator[OUT]: - """Transform an AsyncIterator[IN] to another AsyncIterator[OUT] using a given function. + """Transform an AsyncIterator[IN] to another AsyncIterator[OUT]. Args: input_value (AsyncIterator[IN])): The data of parent operator's output - Example: - - .. code-block:: python + Examples: + .. code-block:: python - class MyTransformStreamOperator(TransformStreamAbsOperator[int, int]): - async def unstreamify( - self, input_value: AsyncIterator[int] - ) -> AsyncIterator[int]: - async for v in input_value: - yield v + 1 + class MyTransformStreamOperator(TransformStreamAbsOperator[int, int]): + async def transform_stream( + self, input_value: AsyncIterator[int] + ) -> AsyncIterator[int]: + async for v in input_value: + yield v + 1 """ diff --git a/dbgpt/core/awel/resource/__init__.py b/dbgpt/core/awel/resource/__init__.py index e69de29bb..cd9d9c59d 100644 --- a/dbgpt/core/awel/resource/__init__.py +++ b/dbgpt/core/awel/resource/__init__.py @@ -0,0 +1,4 @@ +"""The module of AWEL resource. + +Not implemented yet.
+""" diff --git a/dbgpt/core/awel/resource/base.py b/dbgpt/core/awel/resource/base.py index 97fefbbc3..70f6bc8a5 100644 --- a/dbgpt/core/awel/resource/base.py +++ b/dbgpt/core/awel/resource/base.py @@ -1,8 +1,15 @@ +"""Base class for resource group.""" from abc import ABC, abstractmethod class ResourceGroup(ABC): + """Base class for resource group. + + A resource group is a group of resources that are related to each other. + It contains the all resources that are needed to run a workflow. + """ + @property @abstractmethod def name(self) -> str: - """The name of current resource group""" + """Return the name of current resource group.""" diff --git a/dbgpt/core/awel/runner/__init__.py b/dbgpt/core/awel/runner/__init__.py index e69de29bb..54d50c954 100644 --- a/dbgpt/core/awel/runner/__init__.py +++ b/dbgpt/core/awel/runner/__init__.py @@ -0,0 +1,4 @@ +"""The module to run AWEL operators. + +You can implement your own runner by inheriting the `WorkflowRunner` class. +""" diff --git a/dbgpt/core/awel/runner/job_manager.py b/dbgpt/core/awel/runner/job_manager.py index b8b7d2ebd..f67705ffd 100644 --- a/dbgpt/core/awel/runner/job_manager.py +++ b/dbgpt/core/awel/runner/job_manager.py @@ -1,33 +1,38 @@ +"""Job manager for DAG.""" import asyncio import logging import uuid -from typing import Dict, List, Optional, Set +from typing import Dict, List, Optional, cast -from ..dag.base import DAG, DAGLifecycle +from ..dag.base import DAGLifecycle from ..operator.base import CALL_DATA, BaseOperator logger = logging.getLogger(__name__) -class DAGNodeInstance: - def __init__(self, node_instance: DAG) -> None: - pass - - -class DAGInstance: - def __init__(self, dag: DAG) -> None: - self._dag = dag +class JobManager(DAGLifecycle): + """Job manager for DAG. + This class is used to manage the DAG lifecycle. + """ -class JobManager(DAGLifecycle): def __init__( self, root_nodes: List[BaseOperator], all_nodes: List[BaseOperator], end_node: BaseOperator, - id2call_data: Dict[str, Dict], + id2call_data: Dict[str, Optional[Dict]], node_name_to_ids: Dict[str, str], ) -> None: + """Create a job manager. + + Args: + root_nodes (List[BaseOperator]): The root nodes of the DAG. + all_nodes (List[BaseOperator]): All nodes of the DAG. + end_node (BaseOperator): The end node of the DAG. + id2call_data (Dict[str, Optional[Dict]]): The call data of each node. + node_name_to_ids (Dict[str, str]): The node name to node id mapping. + """ self._root_nodes = root_nodes self._all_nodes = all_nodes self._end_node = end_node @@ -38,6 +43,15 @@ def __init__( def build_from_end_node( end_node: BaseOperator, call_data: Optional[CALL_DATA] = None ) -> "JobManager": + """Build a job manager from the end node. + + This will get all upstream nodes from the end node, and build a job manager. + + Args: + end_node (BaseOperator): The end node of the DAG. + call_data (Optional[CALL_DATA], optional): The call data of the end node. + Defaults to None. + """ nodes = _build_from_end_node(end_node) root_nodes = _get_root_nodes(nodes) id2call_data = _save_call_data(root_nodes, call_data) @@ -50,17 +64,22 @@ def build_from_end_node( return JobManager(root_nodes, nodes, end_node, id2call_data, node_name_to_ids) def get_call_data_by_id(self, node_id: str) -> Optional[Dict]: + """Get the call data by node id. + + Args: + node_id (str): The node id. 
+ """ return self._id2node_data.get(node_id) async def before_dag_run(self): - """The callback before DAG run""" + """Execute the callback before DAG run.""" tasks = [] for node in self._all_nodes: tasks.append(node.before_dag_run()) await asyncio.gather(*tasks) async def after_dag_end(self): - """The callback after DAG end""" + """Execute the callback after DAG end.""" tasks = [] for node in self._all_nodes: tasks.append(node.after_dag_end()) @@ -68,9 +87,9 @@ async def after_dag_end(self): def _save_call_data( - root_nodes: List[BaseOperator], call_data: CALL_DATA -) -> Dict[str, Dict]: - id2call_data = {} + root_nodes: List[BaseOperator], call_data: Optional[CALL_DATA] +) -> Dict[str, Optional[Dict]]: + id2call_data: Dict[str, Optional[Dict]] = {} logger.debug(f"_save_call_data: {call_data}, root_nodes: {root_nodes}") if not call_data: return id2call_data @@ -82,7 +101,8 @@ def _save_call_data( for node in root_nodes: node_id = node.node_id logger.debug( - f"Save call data to node {node.node_id}, call_data: {call_data.get(node_id)}" + f"Save call data to node {node.node_id}, call_data: " + f"{call_data.get(node_id)}" ) id2call_data[node_id] = call_data.get(node_id) return id2call_data @@ -91,13 +111,11 @@ def _save_call_data( def _build_from_end_node(end_node: BaseOperator) -> List[BaseOperator]: """Build all nodes from the end node.""" nodes = [] - if isinstance(end_node, BaseOperator): - task_id = end_node.node_id - if not task_id: - task_id = str(uuid.uuid4()) - end_node.set_node_id(task_id) + if isinstance(end_node, BaseOperator) and not end_node._node_id: + end_node.set_node_id(str(uuid.uuid4())) nodes.append(end_node) for node in end_node.upstream: + node = cast(BaseOperator, node) nodes += _build_from_end_node(node) return nodes diff --git a/dbgpt/core/awel/runner/local_runner.py b/dbgpt/core/awel/runner/local_runner.py index 8eb16f574..b0882059b 100644 --- a/dbgpt/core/awel/runner/local_runner.py +++ b/dbgpt/core/awel/runner/local_runner.py @@ -1,12 +1,16 @@ +"""Local runner for workflow. + +This runner will run the workflow in the current process. +""" import logging -from typing import Dict, List, Optional, Set +from typing import Dict, List, Optional, Set, cast from dbgpt.component import SystemApp from ..dag.base import DAGContext, DAGVar from ..operator.base import CALL_DATA, BaseOperator, WorkflowRunner -from ..operator.common_operator import BranchOperator, JoinOperator, TriggerOperator -from ..task.base import TaskContext, TaskState +from ..operator.common_operator import BranchOperator, JoinOperator +from ..task.base import SKIP_DATA, TaskContext, TaskState from ..task.task_impl import DefaultInputContext, DefaultTaskContext, SimpleTaskOutput from .job_manager import JobManager @@ -14,6 +18,8 @@ class DefaultWorkflowRunner(WorkflowRunner): + """The default workflow runner.""" + async def execute_workflow( self, node: BaseOperator, @@ -21,6 +27,17 @@ async def execute_workflow( streaming_call: bool = False, exist_dag_ctx: Optional[DAGContext] = None, ) -> DAGContext: + """Execute the workflow. + + Args: + node (BaseOperator): The end node of the workflow. + call_data (Optional[CALL_DATA], optional): The call data of the end node. + Defaults to None. + streaming_call (bool, optional): Whether the call is streaming call. + Defaults to False. + exist_dag_ctx (Optional[DAGContext], optional): The exist DAG context. + Defaults to None. 
+ """ # Save node output # dag = node.dag job_manager = JobManager.build_from_end_node(node, call_data) @@ -37,8 +54,8 @@ async def execute_workflow( ) logger.info(f"Begin run workflow from end operator, id: {node.node_id}") logger.debug(f"Node id {node.node_id}, call_data: {call_data}") - skip_node_ids = set() - system_app: SystemApp = DAGVar.get_current_system_app() + skip_node_ids: Set[str] = set() + system_app: Optional[SystemApp] = DAGVar.get_current_system_app() await job_manager.before_dag_run() await self._execute_node( @@ -57,7 +74,7 @@ async def _execute_node( dag_ctx: DAGContext, node_outputs: Dict[str, TaskContext], skip_node_ids: Set[str], - system_app: SystemApp, + system_app: Optional[SystemApp], ): # Skip run node if node.node_id in node_outputs: @@ -79,8 +96,12 @@ async def _execute_node( node_outputs[upstream_node.node_id] for upstream_node in node.upstream ] input_ctx = DefaultInputContext(inputs) - task_ctx = DefaultTaskContext(node.node_id, TaskState.INIT, task_output=None) - task_ctx.set_call_data(job_manager.get_call_data_by_id(node.node_id)) + task_ctx: DefaultTaskContext = DefaultTaskContext( + node.node_id, TaskState.INIT, task_output=None + ) + current_call_data = job_manager.get_call_data_by_id(node.node_id) + if current_call_data: + task_ctx.set_call_data(current_call_data) task_ctx.set_task_input(input_ctx) dag_ctx.set_current_task_context(task_ctx) @@ -88,12 +109,13 @@ async def _execute_node( if node.node_id in skip_node_ids: task_ctx.set_current_state(TaskState.SKIP) - task_ctx.set_task_output(SimpleTaskOutput(None)) + task_ctx.set_task_output(SimpleTaskOutput(SKIP_DATA)) node_outputs[node.node_id] = task_ctx return try: logger.debug( - f"Begin run operator, node id: {node.node_id}, node name: {node.node_name}, cls: {node}" + f"Begin run operator, node id: {node.node_id}, node name: " + f"{node.node_name}, cls: {node}" ) if system_app is not None and node.system_app is None: node.set_system_app(system_app) @@ -120,6 +142,7 @@ def _skip_current_downstream_by_node_name( if not skip_nodes: return for child in branch_node.downstream: + child = cast(BaseOperator, child) if child.node_name in skip_nodes: logger.info(f"Skip node name {child.node_name}, node id {child.node_id}") _skip_downstream_by_id(child, skip_node_ids) @@ -131,4 +154,5 @@ def _skip_downstream_by_id(node: BaseOperator, skip_node_ids: Set[str]): return skip_node_ids.add(node.node_id) for child in node.downstream: + child = cast(BaseOperator, child) _skip_downstream_by_id(child, skip_node_ids) diff --git a/dbgpt/core/awel/task/__init__.py b/dbgpt/core/awel/task/__init__.py index e69de29bb..9a41a1476 100644 --- a/dbgpt/core/awel/task/__init__.py +++ b/dbgpt/core/awel/task/__init__.py @@ -0,0 +1 @@ +"""The module of Task.""" diff --git a/dbgpt/core/awel/task/base.py b/dbgpt/core/awel/task/base.py index 603bc8609..f0c5712bc 100644 --- a/dbgpt/core/awel/task/base.py +++ b/dbgpt/core/awel/task/base.py @@ -1,8 +1,10 @@ +"""Base classes for task-related objects.""" from abc import ABC, abstractmethod from enum import Enum from typing import ( Any, AsyncIterator, + Awaitable, Callable, Dict, Generic, @@ -17,6 +19,24 @@ T = TypeVar("T") +class _EMPTY_DATA_TYPE: + def __bool__(self): + return False + + +EMPTY_DATA = _EMPTY_DATA_TYPE() +SKIP_DATA = _EMPTY_DATA_TYPE() +PLACEHOLDER_DATA = _EMPTY_DATA_TYPE() + +MapFunc = Union[Callable[[IN], OUT], Callable[[IN], Awaitable[OUT]]] +ReduceFunc = Union[Callable[[IN], OUT], Callable[[IN], Awaitable[OUT]]] +StreamFunc = Callable[[IN], Awaitable[AsyncIterator[OUT]]] 
+UnStreamFunc = Callable[[AsyncIterator[IN]], OUT] +TransformFunc = Callable[[AsyncIterator[IN]], Awaitable[AsyncIterator[OUT]]] +PredicateFunc = Union[Callable[[IN], bool], Callable[[IN], Awaitable[bool]]] +JoinFunc = Union[Callable[..., OUT], Callable[..., Awaitable[OUT]]] + + class TaskState(str, Enum): """Enumeration representing the state of a task in the workflow. @@ -33,8 +53,8 @@ class TaskState(str, Enum): class TaskOutput(ABC, Generic[T]): """Abstract base class representing the output of a task. - This class encapsulates the output of a task and provides methods to access the output data. - It can be subclassed to implement specific output behaviors. + This class encapsulates the output of a task and provides methods to access the + output data. It can be subclassed to implement specific output behaviors. """ @property @@ -56,20 +76,30 @@ def is_empty(self) -> bool: return False @property - def output(self) -> Optional[T]: + def is_none(self) -> bool: + """Check if the output is None. + + Returns: + bool: True if the output is None, False otherwise. + """ + return False + + @property + def output(self) -> T: """Return the output of the task. Returns: - T: The output of the task. None if the output is empty. + T: The output of the task. """ raise NotImplementedError @property - def output_stream(self) -> Optional[AsyncIterator[T]]: + def output_stream(self) -> AsyncIterator[T]: """Return the output of the task as an asynchronous stream. Returns: - AsyncIterator[T]: An asynchronous iterator over the output. None if the output is empty. + AsyncIterator[T]: An asynchronous iterator over the output. """ raise NotImplementedError @@ -83,39 +113,38 @@ def set_output(self, output_data: Union[T, AsyncIterator[T]]) -> None: @abstractmethod def new_output(self) -> "TaskOutput[T]": - """Create new output object""" + """Create new output object.""" - async def map(self, map_func) -> "TaskOutput[T]": + async def map(self, map_func: MapFunc) -> "TaskOutput[OUT]": """Apply a mapping function to the task's output. Args: - map_func: A function to apply to the task's output. + map_func (MapFunc): A function to apply to the task's output. Returns: - TaskOutput[T]: The result of applying the mapping function. + TaskOutput[OUT]: The result of applying the mapping function. """ raise NotImplementedError - async def reduce(self, reduce_func) -> "TaskOutput[T]": + async def reduce(self, reduce_func: ReduceFunc) -> "TaskOutput[OUT]": """Apply a reducing function to the task's output. - Stream TaskOutput to Nonstream TaskOutput. + Convert a stream TaskOutput to a non-stream TaskOutput. Args: reduce_func: A reducing function to apply to the task's output. Returns: - TaskOutput[T]: The result of applying the reducing function. + TaskOutput[OUT]: The result of applying the reducing function. """ raise NotImplementedError - async def streamify( - self, transform_func: Callable[[T], AsyncIterator[T]] - ) -> "TaskOutput[T]": + async def streamify(self, transform_func: StreamFunc) -> "TaskOutput[T]": """Convert a value of type T to an AsyncIterator[T] using a transform function. Args: - transform_func (Callable[[T], AsyncIterator[T]]): Function to transform a T value into an AsyncIterator[T]. + transform_func (StreamFunc): Function to transform a T value into an + AsyncIterator[OUT]. Returns: TaskOutput[T]: The result of applying the reducing function.
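To make the new callable aliases above concrete, here is a small sketch of functions that satisfy them (illustrative only, not patch content; all names are invented):

.. code-block:: python

    from typing import AsyncIterator

    def double(x: int) -> int:  # matches MapFunc (plain callable form)
        return x * 2

    async def double_async(x: int) -> int:  # matches MapFunc (awaitable form)
        return x * 2

    async def count(xs: AsyncIterator[int]) -> int:
        # Consumes a stream down to one value; this is the coroutine form
        # accepted by unstreamify at run time (UnStreamFunc declares the
        # plain form).
        n = 0
        async for _ in xs:
            n += 1
        return n

    # TaskOutput.map accepts either MapFunc form, e.g.:
    # result = await some_task_output.map(double)
    # result = await some_task_output.map(double_async)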
@@ -123,38 +152,39 @@ async def streamify( raise NotImplementedError async def transform_stream( - self, transform_func: Callable[[AsyncIterator[T]], AsyncIterator[T]] - ) -> "TaskOutput[T]": - """Transform an AsyncIterator[T] to another AsyncIterator[T] using a given function. + self, transform_func: TransformFunc + ) -> "TaskOutput[OUT]": + """Transform an AsyncIterator[T] to another AsyncIterator[T]. Args: - transform_func (Callable[[AsyncIterator[T]], AsyncIterator[T]]): Function to apply to the AsyncIterator[T]. + transform_func (Callable[[AsyncIterator[T]], AsyncIterator[T]]): Function to + apply to the AsyncIterator[T]. Returns: TaskOutput[T]: The result of applying the reducing function. """ raise NotImplementedError - async def unstreamify( - self, transform_func: Callable[[AsyncIterator[T]], T] - ) -> "TaskOutput[T]": + async def unstreamify(self, transform_func: UnStreamFunc) -> "TaskOutput[OUT]": """Convert an AsyncIterator[T] to a value of type T using a transform function. Args: - transform_func (Callable[[AsyncIterator[T]], T]): Function to transform an AsyncIterator[T] into a T value. + transform_func (UnStreamFunc): Function to transform an AsyncIterator[T] + into a T value. Returns: TaskOutput[T]: The result of applying the reducing function. """ raise NotImplementedError - async def check_condition(self, condition_func) -> bool: + async def check_condition(self, condition_func) -> "TaskOutput[OUT]": """Check if current output meets a given condition. Args: condition_func: A function to determine if the condition is met. Returns: - bool: True if current output meet the condition, False otherwise. + TaskOutput[T]: The result of applying the condition function. + If the condition is not met, return empty output. """ raise NotImplementedError @@ -182,6 +212,9 @@ def task_input(self) -> "InputContext": Returns: InputContext: The InputContext of current task. + + Raises: + Exception: If the InputContext is not set. """ @abstractmethod @@ -216,7 +249,7 @@ def current_state(self) -> TaskState: @abstractmethod def set_current_state(self, task_state: TaskState) -> None: - """Set current task state + """Set current task state. Args: task_state (TaskState): The task state to be set. @@ -224,7 +257,7 @@ def set_current_state(self, task_state: TaskState) -> None: @abstractmethod def new_ctx(self) -> "TaskContext": - """Create new task context + """Create new task context. Returns: TaskContext: A new instance of a TaskContext. @@ -233,14 +266,14 @@ def new_ctx(self) -> "TaskContext": @property @abstractmethod def metadata(self) -> Dict[str, Any]: - """Get the metadata of current task + """Return the metadata of current task. Returns: Dict[str, Any]: The metadata """ def update_metadata(self, key: str, value: Any) -> None: - """Update metadata with key and value + """Update metadata with key and value.
Args: key (str): The key of metadata @@ -250,15 +283,15 @@ def update_metadata(self, key: str, value: Any) -> None: @property def call_data(self) -> Optional[Dict]: - """Get the call data for current data""" + """Return the call data for current data.""" return self.metadata.get("call_data") @abstractmethod async def _call_data_to_output(self) -> Optional[TaskOutput[T]]: - """Get the call data for current data""" + """Get the call data for current data.""" def set_call_data(self, call_data: Dict) -> None: - """Set call data for current task""" + """Save the call data for current task.""" self.update_metadata("call_data", call_data) @@ -315,7 +348,8 @@ async def filter(self, filter_func: Callable[[Any], bool]) -> "InputContext": """Filter the inputs based on a provided function. Args: - filter_func (Callable[[Any], bool]): A function that returns True for inputs to keep. + filter_func (Callable[[Any], bool]): A function that returns True for + inputs to keep. Returns: InputContext: A new InputContext instance with the filtered inputs. @@ -323,13 +357,15 @@ async def filter(self, filter_func: Callable[[Any], bool]) -> "InputContext": @abstractmethod async def predicate_map( - self, predicate_func: Callable[[Any], bool], failed_value: Any = None + self, predicate_func: PredicateFunc, failed_value: Any = None ) -> "InputContext": """Predicate the inputs based on a provided function. Args: - predicate_func (Callable[[Any], bool]): A function that returns True for inputs is predicate True. - failed_value (Any): The value to be set if the return value of predicate function is False + predicate_func (PredicateFunc): A function that returns True for + inputs that pass the predicate. + failed_value (Any): The value to be set if the return value of predicate + function is False. Returns: InputContext: A new InputContext instance with the predicate inputs. """ diff --git a/dbgpt/core/awel/task/task_impl.py b/dbgpt/core/awel/task/task_impl.py index 5f113aeec..8877c5cfe 100644 --- a/dbgpt/core/awel/task/task_impl.py +++ b/dbgpt/core/awel/task/task_impl.py @@ -1,3 +1,7 @@ +"""The default implementation of Task. + +This implementation can run workflows on the local machine. +""" import asyncio import logging from abc import ABC, abstractmethod @@ -8,15 +12,32 @@ Coroutine, Dict, Generic, - Iterator, List, Optional, Tuple, - TypeVar, Union, + cast, ) -from .base import InputContext, InputSource, T, TaskContext, TaskOutput, TaskState +from .base import ( + _EMPTY_DATA_TYPE, + EMPTY_DATA, + OUT, + PLACEHOLDER_DATA, + SKIP_DATA, + InputContext, + InputSource, + MapFunc, + PredicateFunc, + ReduceFunc, + StreamFunc, + T, + TaskContext, + TaskOutput, + TaskState, + TransformFunc, + UnStreamFunc, +) logger = logging.getLogger(__name__) @@ -37,101 +58,197 @@ async def _reduce_stream(stream: AsyncIterator, reduce_function) -> Any: class SimpleTaskOutput(TaskOutput[T], Generic[T]): - def __init__(self, data: T) -> None: + """The default implementation of TaskOutput. + + It wraps non-stream data and provides some basic data operations. + """ + + def __init__(self, data: Union[T, _EMPTY_DATA_TYPE] = EMPTY_DATA) -> None: + """Create a SimpleTaskOutput. + + Args: + data (Union[T, _EMPTY_DATA_TYPE], optional): The output data. Defaults to + EMPTY_DATA.
+ """ super().__init__() self._data = data @property def output(self) -> T: - return self._data + """Return the output data.""" + if self._data == EMPTY_DATA: + raise ValueError("No output data for current task output") + return cast(T, self._data) def set_output(self, output_data: T | AsyncIterator[T]) -> None: - self._data = output_data + """Save the output data to current object. + + Args: + output_data (T | AsyncIterator[T]): The output data. + """ + if _is_async_iterator(output_data): + raise ValueError( + f"Can not set stream data {output_data} to SimpleTaskOutput" + ) + self._data = cast(T, output_data) def new_output(self) -> TaskOutput[T]: - return SimpleTaskOutput(None) + """Create new output object with empty data.""" + return SimpleTaskOutput() @property def is_empty(self) -> bool: + """Return True if the output data is empty.""" + return self._data == EMPTY_DATA or self._data == SKIP_DATA + + @property + def is_none(self) -> bool: + """Return True if the output data is None.""" return self._data is None async def _apply_func(self, func) -> Any: + """Apply the function to current output data.""" if asyncio.iscoroutinefunction(func): out = await func(self._data) else: out = func(self._data) return out - async def map(self, map_func) -> TaskOutput[T]: + async def map(self, map_func: MapFunc) -> TaskOutput[OUT]: + """Apply a mapping function to the task's output. + + Args: + map_func (MapFunc): A function to apply to the task's output. + + Returns: + TaskOutput[OUT]: The result of applying the mapping function. + """ out = await self._apply_func(map_func) return SimpleTaskOutput(out) - async def check_condition(self, condition_func) -> bool: - return await self._apply_func(condition_func) + async def check_condition(self, condition_func) -> TaskOutput[OUT]: + """Check the condition function.""" + out = await self._apply_func(condition_func) + if out: + return SimpleTaskOutput(PLACEHOLDER_DATA) + return SimpleTaskOutput(EMPTY_DATA) + + async def streamify(self, transform_func: StreamFunc) -> TaskOutput[OUT]: + """Transform the task's output to a stream output. + + Args: + transform_func (StreamFunc): A function to transform the task's output to a + stream output. - async def streamify( - self, transform_func: Callable[[T], AsyncIterator[T]] - ) -> TaskOutput[T]: + Returns: + TaskOutput[OUT]: The result of transforming the task's output to a stream + output. + """ out = await self._apply_func(transform_func) return SimpleStreamTaskOutput(out) class SimpleStreamTaskOutput(TaskOutput[T], Generic[T]): - def __init__(self, data: AsyncIterator[T]) -> None: + """The default stream implementation of TaskOutput.""" + + def __init__( + self, data: Union[AsyncIterator[T], _EMPTY_DATA_TYPE] = EMPTY_DATA + ) -> None: + """Create a SimpleStreamTaskOutput. + + Args: + data (Union[AsyncIterator[T], _EMPTY_DATA_TYPE], optional): The output data. + Defaults to EMPTY_DATA. + """ super().__init__() self._data = data @property def is_stream(self) -> bool: + """Return True if the output data is a stream.""" return True @property def is_empty(self) -> bool: - return not self._data + """Return True if the output data is empty.""" + return self._data == EMPTY_DATA or self._data == SKIP_DATA + + @property + def is_none(self) -> bool: + """Return True if the output data is None.""" + return self._data is None @property def output_stream(self) -> AsyncIterator[T]: - return self._data + """Return the output data. + + Returns: + AsyncIterator[T]: The output data. 
+ + Raises: + ValueError: If the output data is empty. + """ + if self._data == EMPTY_DATA: + raise ValueError("No output data for current task output") + return cast(AsyncIterator[T], self._data) def set_output(self, output_data: T | AsyncIterator[T]) -> None: - self._data = output_data + """Save the output data to current object. + + Raises: + ValueError: If the output data is not a stream. + """ + if not _is_async_iterator(output_data): + raise ValueError( + f"Can not set non-stream data {output_data} to SimpleStreamTaskOutput" + ) + self._data = cast(AsyncIterator[T], output_data) def new_output(self) -> TaskOutput[T]: - return SimpleStreamTaskOutput(None) + """Create new output object with empty data.""" + return SimpleStreamTaskOutput() - async def map(self, map_func) -> TaskOutput[T]: + async def map(self, map_func: MapFunc) -> TaskOutput[OUT]: + """Apply a mapping function to the task's output.""" is_async = asyncio.iscoroutinefunction(map_func) - async def new_iter() -> AsyncIterator[T]: - async for out in self._data: + async def new_iter() -> AsyncIterator[OUT]: + async for out in self.output_stream: if is_async: - out = await map_func(out) + new_out: OUT = await map_func(out) else: - out = map_func(out) - yield out + new_out = cast(OUT, map_func(out)) + yield new_out return SimpleStreamTaskOutput(new_iter()) - async def reduce(self, reduce_func) -> TaskOutput[T]: - out = await _reduce_stream(self._data, reduce_func) + async def reduce(self, reduce_func: ReduceFunc) -> TaskOutput[OUT]: + """Apply a reduce function to the task's output.""" + out = await _reduce_stream(self.output_stream, reduce_func) return SimpleTaskOutput(out) - async def unstreamify( ... + async def unstreamify(self, transform_func: UnStreamFunc) -> TaskOutput[OUT]: + """Transform the task's output to a non-stream output.""" if asyncio.iscoroutinefunction(transform_func): - out = await transform_func(self._data) + out = await transform_func(self.output_stream) else: - out = transform_func(self._data) + out = transform_func(self.output_stream) return SimpleTaskOutput(out) - async def transform_stream( - self, transform_func: Callable[[AsyncIterator[T]], AsyncIterator[T]] - ) -> TaskOutput[T]: + async def transform_stream(self, transform_func: TransformFunc) -> TaskOutput[OUT]: + """Transform an AsyncIterator[T] to another AsyncIterator[T]. + + Args: + transform_func (Callable[[AsyncIterator[T]], AsyncIterator[T]]): Function to + apply to the AsyncIterator[T]. + + Returns: + TaskOutput[T]: The result of applying the transform function. + """ if asyncio.iscoroutinefunction(transform_func): - out = await transform_func(self._data) + out: AsyncIterator[OUT] = await transform_func(self.output_stream) else: - out = transform_func(self._data) + out = cast(AsyncIterator[OUT], transform_func(self.output_stream)) return SimpleStreamTaskOutput(out) @@ -145,20 +262,34 @@ def _is_async_iterator(obj): class BaseInputSource(InputSource, ABC): + """The base class of InputSource.""" + def __init__(self) -> None: + """Create a BaseInputSource.""" super().__init__() self._is_read = False @abstractmethod def _read_data(self, task_ctx: TaskContext) -> Any: - """Read data with task context""" + """Return data with task context.""" async def read(self, task_ctx: TaskContext) -> TaskOutput: + """Read data with task context. + + Args: + task_ctx (TaskContext): The task context. + + Returns: + TaskOutput: The task output.
+ + Raises: + ValueError: If the input source is a stream and has been read. + """ data = self._read_data(task_ctx) if _is_async_iterator(data): if self._is_read: raise ValueError(f"Input iterator {data} has been read!") - output = SimpleStreamTaskOutput(data) + output: TaskOutput = SimpleStreamTaskOutput(data) else: output = SimpleTaskOutput(data) self._is_read = True @@ -166,7 +297,14 @@ async def read(self, task_ctx: TaskContext) -> TaskOutput: class SimpleInputSource(BaseInputSource): + """The default implementation of InputSource.""" + def __init__(self, data: Any) -> None: + """Create a SimpleInputSource. + + Args: + data (Any): The input data. + """ super().__init__() self._data = data @@ -175,63 +313,121 @@ def _read_data(self, task_ctx: TaskContext) -> Any: class SimpleCallDataInputSource(BaseInputSource): + """The implementation of InputSource for call data.""" + def __init__(self) -> None: + """Create a SimpleCallDataInputSource.""" super().__init__() def _read_data(self, task_ctx: TaskContext) -> Any: + """Read data from task context. + + Returns: + Any: The data. + + Raises: + ValueError: If the call data is empty. + """ call_data = task_ctx.call_data - data = call_data.get("data") if call_data else None - if not (call_data and data): + data = call_data.get("data", EMPTY_DATA) if call_data else EMPTY_DATA + if data == EMPTY_DATA: raise ValueError("No call data for current SimpleCallDataInputSource") return data class DefaultTaskContext(TaskContext, Generic[T]): + """The default implementation of TaskContext.""" + def __init__( - self, task_id: str, task_state: TaskState, task_output: TaskOutput[T] + self, + task_id: str, + task_state: TaskState, + task_output: Optional[TaskOutput[T]] = None, ) -> None: + """Create a DefaultTaskContext. + + Args: + task_id (str): The task id. + task_state (TaskState): The task state. + task_output (Optional[TaskOutput[T]], optional): The task output. Defaults + to None. + """ super().__init__() self._task_id = task_id self._task_state = task_state - self._output = task_output - self._task_input = None - self._metadata = {} + self._output: Optional[TaskOutput[T]] = task_output + self._task_input: Optional[InputContext] = None + self._metadata: Dict[str, Any] = {} @property def task_id(self) -> str: + """Return the task id.""" return self._task_id @property def task_input(self) -> InputContext: + """Return the task input.""" + if not self._task_input: + raise ValueError("No input for current task context") return self._task_input - def set_task_input(self, input_ctx: "InputContext") -> None: + def set_task_input(self, input_ctx: InputContext) -> None: + """Save the task input to current task.""" self._task_input = input_ctx @property def task_output(self) -> TaskOutput: + """Return the task output. + + Returns: + TaskOutput: The task output. + + Raises: + ValueError: If the task output is empty. + """ + if not self._output: + raise ValueError("No output for current task context") return self._output def set_task_output(self, task_output: TaskOutput) -> None: + """Save the task output to current task. + + Args: + task_output (TaskOutput): The task output. 
+ """ self._output = task_output @property def current_state(self) -> TaskState: + """Return the current task state.""" return self._task_state def set_current_state(self, task_state: TaskState) -> None: + """Save the current task state to current task.""" self._task_state = task_state def new_ctx(self) -> TaskContext: + """Create new task context with empty output.""" + if not self._output: + raise ValueError("No output for current task context") new_output = self._output.new_output() return DefaultTaskContext(self._task_id, self._task_state, new_output) @property def metadata(self) -> Dict[str, Any]: + """Return the metadata of current task. + + Returns: + Dict[str, Any]: The metadata. + """ return self._metadata async def _call_data_to_output(self) -> Optional[TaskOutput[T]]: - """Get the call data for current data""" + """Return the call data of current task. + + Returns: + Optional[TaskOutput[T]]: The call data. + """ call_data = self.call_data if not call_data: return None @@ -240,24 +436,48 @@ async def _call_data_to_output(self) -> Optional[TaskOutput[T]]: class DefaultInputContext(InputContext): + """The default implementation of InputContext. + + It wraps the all inputs from parent tasks and provide some basic data operations. + """ + def __init__(self, outputs: List[TaskContext]) -> None: + """Create a DefaultInputContext. + + Args: + outputs (List[TaskContext]): The outputs from parent tasks. + """ super().__init__() self._outputs = outputs @property def parent_outputs(self) -> List[TaskContext]: + """Return the outputs from parent tasks. + + Returns: + List[TaskContext]: The outputs from parent tasks. + """ return self._outputs async def _apply_func( self, func: Callable[[Any], Any], apply_type: str = "map" ) -> Tuple[List[TaskContext], List[TaskOutput]]: + """Apply the function to all parent outputs. + + Args: + func (Callable[[Any], Any]): The function to apply. + apply_type (str, optional): The apply type. Defaults to "map". + + Returns: + Tuple[List[TaskContext], List[TaskOutput]]: The new parent outputs and the + results of applying the function. + """ new_outputs: List[TaskContext] = [] map_tasks = [] for out in self._outputs: new_outputs.append(out.new_ctx()) - result = None if apply_type == "map": - result = out.task_output.map(func) + result: Coroutine[Any, Any, TaskOutput[Any]] = out.task_output.map(func) elif apply_type == "reduce": result = out.task_output.reduce(func) elif apply_type == "check_condition": @@ -269,29 +489,40 @@ async def _apply_func( return new_outputs, results async def map(self, map_func: Callable[[Any], Any]) -> InputContext: + """Apply a mapping function to all parent outputs.""" new_outputs, results = await self._apply_func(map_func) for i, task_ctx in enumerate(new_outputs): - task_ctx: TaskContext = task_ctx + task_ctx = cast(TaskContext, task_ctx) task_ctx.set_task_output(results[i]) return DefaultInputContext(new_outputs) async def map_all(self, map_func: Callable[..., Any]) -> InputContext: + """Apply a mapping function to all parent outputs. + + The parent outputs will be unpacked and passed to the mapping function. + + Args: + map_func (Callable[..., Any]): The mapping function. + + Returns: + InputContext: The new input context. + """ if not self._outputs: return DefaultInputContext([]) # Some parent may be empty not_empty_idx = 0 for i, p in enumerate(self._outputs): if p.task_output.is_empty: + # Skip empty parent continue not_empty_idx = i break # All output is empty? 
is_steam = self._outputs[not_empty_idx].task_output.is_stream - if is_steam: - if not self.check_stream(skip_empty=True): - raise ValueError( - "The output in all tasks must has same output format to map_all" - ) + if is_steam and not self.check_stream(skip_empty=True): + raise ValueError( + "The output in all tasks must have the same output format to map_all" + ) outputs = [] for out in self._outputs: if out.task_output.is_stream: @@ -305,22 +536,26 @@ async def map_all(self, map_func: Callable[..., Any]) -> InputContext: single_output: TaskContext = self._outputs[not_empty_idx].new_ctx() single_output.task_output.set_output(map_res) logger.debug( - f"Current map_all map_res: {map_res}, is steam: {single_output.task_output.is_stream}" + f"Current map_all map_res: {map_res}, is stream: " + f"{single_output.task_output.is_stream}" ) return DefaultInputContext([single_output]) async def reduce(self, reduce_func: Callable[[Any], Any]) -> InputContext: + """Apply a reduce function to all parent outputs.""" if not self.check_stream(): raise ValueError( - "The output in all tasks must has same output format of stream to apply reduce function" + "The output in all tasks must have the same stream output format to apply" + " the reduce function" ) new_outputs, results = await self._apply_func(reduce_func, apply_type="reduce") for i, task_ctx in enumerate(new_outputs): - task_ctx: TaskContext = task_ctx + task_ctx = cast(TaskContext, task_ctx) task_ctx.set_task_output(results[i]) return DefaultInputContext(new_outputs) async def filter(self, filter_func: Callable[[Any], bool]) -> InputContext: + """Filter all parent outputs.""" new_outputs, results = await self._apply_func( filter_func, apply_type="check_condition" ) @@ -331,15 +566,16 @@ async def filter... async def predicate_map( - self, predicate_func: Callable[[Any], bool], failed_value: Any = None + self, predicate_func: PredicateFunc, failed_value: Any = None ) -> "InputContext": + """Apply a predicate function to all parent outputs.""" new_outputs, results = await self._apply_func( predicate_func, apply_type="check_condition" ) result_outputs = [] for i, task_ctx in enumerate(new_outputs): - task_ctx: TaskContext = task_ctx - if results[i]: + task_ctx = cast(TaskContext, task_ctx) + if not results[i].is_empty: task_ctx.task_output.set_output(True) result_outputs.append(task_ctx) else: diff --git a/dbgpt/core/awel/tests/conftest.py b/dbgpt/core/awel/tests/conftest.py index a6fbb1c76..d68ddcfc8 100644 --- a/dbgpt/core/awel/tests/conftest.py +++ b/dbgpt/core/awel/tests/conftest.py @@ -66,10 +66,10 @@ async def _create_input_node(**kwargs): else: outputs = kwargs.get("outputs", ["Hello."]) nodes = [] - for output in outputs: + for i, output in enumerate(outputs): print(f"output: {output}") input_source = SimpleInputSource(output) - input_node = InputOperator(input_source) + input_node = InputOperator(input_source, task_id="input_node_" + str(i)) nodes.append(input_node) yield nodes diff --git a/dbgpt/core/awel/tests/test_run_dag.py b/dbgpt/core/awel/tests/test_run_dag.py index f797c6ccc..ca9f1c4d2 100644 --- a/dbgpt/core/awel/tests/test_run_dag.py +++ b/dbgpt/core/awel/tests/test_run_dag.py @@ -26,7 +26,7 @@ @pytest.mark.asyncio async def test_input_node(runner: WorkflowRunner): - input_node = InputOperator(SimpleInputSource("hello")) + input_node = InputOperator(SimpleInputSource("hello"), task_id="112232") res: DAGContext[str] = await
runner.execute_workflow(input_node) assert res.current_task_context.current_state == TaskState.SUCCESS assert res.current_task_context.task_output.output == "hello" @@ -36,7 +36,9 @@ async def new_steam_iter(n: int): yield i num_iter = 10 - steam_input_node = InputOperator(SimpleInputSource(new_steam_iter(num_iter))) + steam_input_node = InputOperator( + SimpleInputSource(new_steam_iter(num_iter)), task_id="112232" + ) res: DAGContext[str] = await runner.execute_workflow(steam_input_node) assert res.current_task_context.current_state == TaskState.SUCCESS output_steam = res.current_task_context.task_output.output_stream diff --git a/dbgpt/core/awel/trigger/__init__.py b/dbgpt/core/awel/trigger/__init__.py index e69de29bb..825dbe29e 100644 --- a/dbgpt/core/awel/trigger/__init__.py +++ b/dbgpt/core/awel/trigger/__init__.py @@ -0,0 +1 @@ +"""The trigger module of AWEL.""" diff --git a/dbgpt/core/awel/trigger/base.py b/dbgpt/core/awel/trigger/base.py index 28662498f..c2410199f 100644 --- a/dbgpt/core/awel/trigger/base.py +++ b/dbgpt/core/awel/trigger/base.py @@ -1,3 +1,4 @@ +"""Base class for all trigger classes.""" from __future__ import annotations from abc import ABC, abstractmethod @@ -6,6 +7,11 @@ class Trigger(TriggerOperator, ABC): + """Base class for all trigger classes. + + Now only support http trigger. + """ + @abstractmethod async def trigger(self) -> None: """Trigger the workflow or a specific operation in the workflow.""" diff --git a/dbgpt/core/awel/trigger/http_trigger.py b/dbgpt/core/awel/trigger/http_trigger.py index 33a6e3ad9..a48dfe385 100644 --- a/dbgpt/core/awel/trigger/http_trigger.py +++ b/dbgpt/core/awel/trigger/http_trigger.py @@ -1,10 +1,11 @@ +"""Http trigger for AWEL.""" from __future__ import annotations import logging -from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Type, Union +from enum import Enum +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Type, Union, cast from starlette.requests import Request -from starlette.responses import Response from dbgpt._private.pydantic import BaseModel @@ -13,29 +14,35 @@ from .base import Trigger if TYPE_CHECKING: - from fastapi import APIRouter, FastAPI + from fastapi import APIRouter -RequestBody = Union[Type[Request], Type[BaseModel], str] +RequestBody = Union[Type[Request], Type[BaseModel], Type[str]] StreamingPredictFunc = Callable[[Union[Request, BaseModel]], bool] logger = logging.getLogger(__name__) class HttpTrigger(Trigger): + """Http trigger for AWEL. + + Http trigger is used to trigger a DAG by http request. 
+ """ + def __init__( self, endpoint: str, methods: Optional[Union[str, List[str]]] = "GET", request_body: Optional[RequestBody] = None, - streaming_response: Optional[bool] = False, + streaming_response: bool = False, streaming_predict_func: Optional[StreamingPredictFunc] = None, response_model: Optional[Type] = None, response_headers: Optional[Dict[str, str]] = None, response_media_type: Optional[str] = None, status_code: Optional[int] = 200, - router_tags: Optional[List[str]] = None, + router_tags: Optional[List[str | Enum]] = None, **kwargs, ) -> None: + """Initialize a HttpTrigger.""" super().__init__(**kwargs) if not endpoint.startswith("/"): endpoint = "/" + endpoint @@ -49,15 +56,21 @@ def __init__( self._router_tags = router_tags self._response_headers = response_headers self._response_media_type = response_media_type - self._end_node: BaseOperator = None + self._end_node: Optional[BaseOperator] = None async def trigger(self) -> None: + """Trigger the DAG. Not used in HttpTrigger.""" pass def mount_to_router(self, router: "APIRouter") -> None: + """Mount the trigger to a router. + + Args: + router (APIRouter): The router to mount the trigger. + """ from fastapi import Depends - methods = self._methods if isinstance(self._methods, list) else [self._methods] + methods = [self._methods] if isinstance(self._methods, str) else self._methods def create_route_function(name, req_body_cls: Optional[Type[BaseModel]]): async def _request_body_dependency(request: Request): @@ -87,7 +100,8 @@ async def route_function(body=Depends(_request_body_dependency)): ) dynamic_route_function = create_route_function(function_name, request_model) logger.info( - f"mount router function {dynamic_route_function}({function_name}), endpoint: {self._endpoint}, methods: {methods}" + f"mount router function {dynamic_route_function}({function_name}), " + f"endpoint: {self._endpoint}, methods: {methods}" ) router.api_route( @@ -100,17 +114,27 @@ async def route_function(body=Depends(_request_body_dependency)): async def _parse_request_body( - request: Request, request_body_cls: Optional[Type[BaseModel]] + request: Request, request_body_cls: Optional[RequestBody] ): if not request_body_cls: return None + if request_body_cls == Request: + return request if request.method == "POST": - json_data = await request.json() - return request_body_cls(**json_data) + if request_body_cls == str: + bytes_body = await request.body() + str_body = bytes_body.decode("utf-8") + return str_body + elif issubclass(request_body_cls, BaseModel): + json_data = await request.json() + return request_body_cls(**json_data) + else: + raise ValueError(f"Invalid request body cls: {request_body_cls}") elif request.method == "GET": - return request_body_cls(**request.query_params) - else: - return request + if issubclass(request_body_cls, BaseModel): + return request_body_cls(**request.query_params) + else: + raise ValueError(f"Invalid request body cls: {request_body_cls}") async def _trigger_dag( @@ -123,10 +147,10 @@ async def _trigger_dag( from fastapi import BackgroundTasks from fastapi.responses import StreamingResponse - end_node = dag.leaf_nodes - if len(end_node) != 1: + leaf_nodes = dag.leaf_nodes + if len(leaf_nodes) != 1: raise ValueError("HttpTrigger just support one leaf node in dag") - end_node = end_node[0] + end_node = cast(BaseOperator, leaf_nodes[0]) if not streaming_response: return await end_node.call(call_data={"data": body}) else: @@ -141,7 +165,7 @@ async def _trigger_dag( } generator = await 
end_node.call_stream(call_data={"data": body}) background_tasks = BackgroundTasks() - background_tasks.add_task(end_node.dag._after_dag_end) + background_tasks.add_task(dag._after_dag_end) return StreamingResponse( generator, headers=headers, diff --git a/dbgpt/core/awel/trigger/trigger_manager.py b/dbgpt/core/awel/trigger/trigger_manager.py index 95b4b89ab..c9baed58d 100644 --- a/dbgpt/core/awel/trigger/trigger_manager.py +++ b/dbgpt/core/awel/trigger/trigger_manager.py @@ -1,41 +1,63 @@ +"""Trigger manager for AWEL. + +After DB-GPT starts, the trigger manager is initialized and registers all triggers. +""" import logging from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, Any, Optional +from typing import TYPE_CHECKING, Any, Dict, Optional + +from dbgpt.component import BaseComponent, ComponentType, SystemApp + +from .base import Trigger if TYPE_CHECKING: from fastapi import APIRouter -from dbgpt.component import BaseComponent, ComponentType, SystemApp logger = logging.getLogger(__name__) class TriggerManager(ABC): + """Base class for trigger manager.""" + @abstractmethod def register_trigger(self, trigger: Any) -> None: - """ "Register a trigger to current manager""" + """Register a trigger to current manager.""" class HttpTriggerManager(TriggerManager): + """Http trigger manager. + + Register all http triggers to a router. + """ + def __init__( self, router: Optional["APIRouter"] = None, - router_prefix: Optional[str] = "/api/v1/awel/trigger", + router_prefix: str = "/api/v1/awel/trigger", ) -> None: + """Initialize a HttpTriggerManager. + + Args: + router (Optional["APIRouter"], optional): The router. Defaults to None. + If None, will create a new FastAPI router. + router_prefix (str, optional): The router prefix. Defaults + to "/api/v1/awel/trigger". + """ if not router: from fastapi import APIRouter router = APIRouter() self._router_prefix = router_prefix self._router = router - self._trigger_map = {} + self._trigger_map: Dict[str, Trigger] = {} def register_trigger(self, trigger: Any) -> None: + """Register a trigger to current manager.""" from .http_trigger import HttpTrigger if not isinstance(trigger, HttpTrigger): raise ValueError(f"Current trigger {trigger} not an object of HttpTrigger") - trigger: HttpTrigger = trigger trigger_id = trigger.node_id if trigger_id not in self._trigger_map: trigger.mount_to_router(self._router) @@ -45,23 +67,32 @@ def _init_app(self, system_app: SystemApp): logger.info( f"Include router {self._router} to prefix path {self._router_prefix}" ) - system_app.app.include_router( - self._router, prefix=self._router_prefix, tags=["AWEL"] - ) + app = system_app.app + if not app: + raise RuntimeError("System app not initialized") + app.include_router(self._router, prefix=self._router_prefix, tags=["AWEL"]) class DefaultTriggerManager(TriggerManager, BaseComponent): + """Default trigger manager for AWEL. + + Manages all trigger managers. Only the http trigger is supported now.
+ """ + name = ComponentType.AWEL_TRIGGER_MANAGER def __init__(self, system_app: SystemApp | None = None): + """Initialize a DefaultTriggerManager.""" self.system_app = system_app self.http_trigger = HttpTriggerManager() super().__init__(None) def init_app(self, system_app: SystemApp): + """Initialize the trigger manager.""" self.system_app = system_app def register_trigger(self, trigger: Any) -> None: + """Register a trigger to current manager.""" from .http_trigger import HttpTrigger if isinstance(trigger, HttpTrigger): @@ -71,4 +102,6 @@ def register_trigger(self, trigger: Any) -> None: raise ValueError(f"Unsupport trigger: {trigger}") def after_register(self) -> None: - self.http_trigger._init_app(self.system_app) + """After register, init the trigger manager.""" + if self.system_app: + self.http_trigger._init_app(self.system_app) diff --git a/dbgpt/core/interface/__init__.py b/dbgpt/core/interface/__init__.py index e69de29bb..2c49ce509 100644 --- a/dbgpt/core/interface/__init__.py +++ b/dbgpt/core/interface/__init__.py @@ -0,0 +1,4 @@ +"""The core interface of DB-GPT. + +Just include the core interface to keep our dependencies clean. +""" diff --git a/dbgpt/core/interface/cache.py b/dbgpt/core/interface/cache.py index 63babb7a2..f7ae4aafb 100644 --- a/dbgpt/core/interface/cache.py +++ b/dbgpt/core/interface/cache.py @@ -1,3 +1,10 @@ +"""The cache interface. + +The cache interface is used to cache LLM results and embedding results. + +Maybe we can cache more server results in the future. +""" + from abc import ABC, abstractmethod from dataclasses import dataclass from enum import Enum @@ -10,17 +17,23 @@ class RetrievalPolicy(str, Enum): + """The retrieval policy of the cache.""" + EXACT_MATCH = "exact_match" SIMILARITY_MATCH = "similarity_match" class CachePolicy(str, Enum): + """The cache policy of the cache.""" + LRU = "lru" FIFO = "fifo" @dataclass class CacheConfig: + """The cache config.""" + retrieval_policy: Optional[RetrievalPolicy] = RetrievalPolicy.EXACT_MATCH cache_policy: Optional[CachePolicy] = CachePolicy.LRU @@ -30,7 +43,8 @@ class CacheKey(Serializable, ABC, Generic[K]): Supported cache keys: - The LLM cache key: Include user prompt and the parameters to LLM. - - The embedding model cache key: Include the texts to embedding and the parameters to embedding model. + - The embedding model cache key: Include the texts to embedding and the parameters + to embedding model. """ @abstractmethod @@ -76,7 +90,8 @@ async def get( cache_config (Optional[CacheConfig]): Cache config Returns: - Optional[CacheValue[V]]: The value retrieved according to key. If cache key not exist, return None. + Optional[CacheValue[V]]: The value retrieved according to key. If cache key + not exist, return None. 
""" @abstractmethod @@ -110,8 +125,8 @@ async def exists( @abstractmethod def new_key(self, **kwargs) -> CacheKey[K]: - """Create a cache key with params""" + """Create a cache key with params.""" @abstractmethod def new_value(self, **kwargs) -> CacheValue[K]: - """Create a cache key with params""" + """Create a cache key with params.""" diff --git a/dbgpt/core/interface/llm.py b/dbgpt/core/interface/llm.py index 50a668a61..6d0bd98c6 100644 --- a/dbgpt/core/interface/llm.py +++ b/dbgpt/core/interface/llm.py @@ -1,3 +1,5 @@ +"""The interface for LLM.""" + import collections import copy import logging @@ -31,7 +33,8 @@ class ModelInferenceMetrics: """The timestamp (in milliseconds) when the model inference ends.""" current_time_ms: Optional[int] = None - """The current timestamp (in milliseconds) when the model inference return partially output(stream).""" + """The current timestamp (in milliseconds) when the model inference return + partially output(stream).""" first_token_time_ms: Optional[int] = None """The timestamp (in milliseconds) when the first token is generated.""" @@ -64,6 +67,14 @@ class ModelInferenceMetrics: def create_metrics( last_metrics: Optional["ModelInferenceMetrics"] = None, ) -> "ModelInferenceMetrics": + """Create metrics for model inference. + + Args: + last_metrics(ModelInferenceMetrics): The last metrics. + + Returns: + ModelInferenceMetrics: The metrics for model inference. + """ start_time_ms = last_metrics.start_time_ms if last_metrics else None first_token_time_ms = last_metrics.first_token_time_ms if last_metrics else None first_completion_time_ms = ( @@ -100,15 +111,21 @@ def create_metrics( ) def to_dict(self) -> Dict: + """Convert the model inference metrics to dict.""" return asdict(self) @dataclass @PublicAPI(stability="beta") class ModelRequestContext: - stream: Optional[bool] = False + """A class to represent the context of a LLM model request.""" + + stream: bool = False """Whether to return a stream of responses.""" + cache_enable: bool = False + """Whether to enable the cache for the model inference""" + user_name: Optional[str] = None """The user name of the model request.""" @@ -129,8 +146,6 @@ class ModelRequestContext: request_id: Optional[str] = None """The request id of the model inference.""" - cache_enable: Optional[bool] = False - """Whether to enable the cache for the model inference""" @dataclass @@ -141,27 +156,31 @@ class ModelOutput: text: str """The generated text.""" error_code: int - """The error code of the model inference. If the model inference is successful, the error code is 0.""" - model_context: Dict = None - finish_reason: str = None - usage: Dict[str, Any] = None + """The error code of the model inference. 
If the model inference is successful, + the error code is 0.""" + model_context: Optional[Dict] = None + finish_reason: Optional[str] = None + usage: Optional[Dict[str, Any]] = None metrics: Optional[ModelInferenceMetrics] = None """Some metrics for model inference""" def to_dict(self) -> Dict: + """Convert the model output to dict.""" return asdict(self) -_ModelMessageType = Union[ModelMessage, Dict[str, Any]] +_ModelMessageType = Union[List[ModelMessage], List[Dict[str, Any]]] @dataclass @PublicAPI(stability="beta") class ModelRequest: + """The model request.""" + model: str """The name of the model.""" - messages: List[_ModelMessageType] + messages: _ModelMessageType """The input messages.""" temperature: Optional[float] = None @@ -189,28 +208,42 @@ class ModelRequest: @property def stream(self) -> bool: """Whether to return a stream of responses.""" - return self.context and self.context.stream + return bool(self.context and self.context.stream) + + def copy(self) -> "ModelRequest": + """Copy the model request. - def copy(self): + Returns: + ModelRequest: The copied model request. + """ new_request = copy.deepcopy(self) # Transform messages to List[ModelMessage] - new_request.messages = list( - map( - lambda m: m if isinstance(m, ModelMessage) else ModelMessage(**m), - new_request.messages, - ) - ) + new_request.messages = new_request.get_messages() return new_request def to_dict(self) -> Dict[str, Any]: + """Convert the model request to dict. + + Returns: + Dict[str, Any]: The model request in dict. + """ new_reqeust = copy.deepcopy(self) - new_reqeust.messages = list( - map(lambda m: m if isinstance(m, dict) else m.dict(), new_reqeust.messages) - ) + new_messages = [] + for message in new_reqeust.messages: + if isinstance(message, dict): + new_messages.append(message) + else: + new_messages.append(message.dict()) + new_reqeust.messages = new_messages # Skip None fields return {k: v for k, v in asdict(new_reqeust).items() if v is not None} - def to_trace_metadata(self): + def to_trace_metadata(self) -> Dict[str, Any]: + """Convert the model request to trace metadata. + + Returns: + Dict[str, Any]: The trace metadata. + """ metadata = self.to_dict() metadata["prompt"] = self.messages_to_string() return metadata @@ -218,16 +251,19 @@ def to_trace_metadata(self): def get_messages(self) -> List[ModelMessage]: """Get the messages. - If the messages is not a list of ModelMessage, it will be converted to a list of ModelMessage. + If the messages is not a list of ModelMessage, it will be converted to a list + of ModelMessage. + Returns: List[ModelMessage]: The messages. """ - return list( - map( - lambda m: m if isinstance(m, ModelMessage) else ModelMessage(**m), - self.messages, - ) - ) + messages = [] + for message in self.messages: + if isinstance(message, dict): + messages.append(ModelMessage(**message)) + else: + messages.append(message) + return messages def get_single_user_message(self) -> Optional[ModelMessage]: """Get the single user message. @@ -245,20 +281,35 @@ def build_request( model: str, messages: List[ModelMessage], context: Optional[Union[ModelRequestContext, Dict[str, Any], BaseModel]] = None, - stream: Optional[bool] = False, - echo: Optional[bool] = False, + stream: bool = False, + echo: bool = False, **kwargs, ): + """Build a model request. + + Args: + model(str): The model name. + messages(List[ModelMessage]): The messages. + context(Optional[Union[ModelRequestContext, Dict[str, Any], BaseModel]]): + The context. + stream(bool): Whether to return a stream of responses. 
Defaults to False. + echo(bool): Whether to echo the input messages. Defaults to False. + **kwargs: Other arguments. + """ if not context: context = ModelRequestContext(stream=stream) - context_dict = None - if isinstance(context, dict): - context_dict = context - elif isinstance(context, BaseModel): - context_dict = context.dict() - if context_dict and "stream" not in context_dict: - context_dict["stream"] = stream - context = ModelRequestContext(**context_dict) + elif not isinstance(context, ModelRequestContext): + context_dict = None + if isinstance(context, dict): + context_dict = context + elif isinstance(context, BaseModel): + context_dict = context.dict() + if context_dict and "stream" not in context_dict: + context_dict["stream"] = stream + if context_dict: + context = ModelRequestContext(**context_dict) + else: + context = ModelRequestContext(stream=stream) return ModelRequest( model=model, messages=messages, @@ -292,7 +343,6 @@ def to_common_messages( ValueError: If the message role is not supported Examples: - .. code-block:: python from dbgpt.core.interface.message import ( @@ -337,7 +387,7 @@ def messages_to_string(self) -> str: class ModelExtraMedata(BaseParameters): """A class to represent the extra metadata of a LLM.""" - prompt_roles: Optional[List[str]] = field( + prompt_roles: List[str] = field( default_factory=lambda: [ ModelMessageRoleType.SYSTEM, ModelMessageRoleType.HUMAN, @@ -356,7 +406,8 @@ class ModelExtraMedata(BaseParameters): prompt_chat_template: Optional[str] = field( default=None, metadata={ - "help": "The chat template, see: https://huggingface.co/docs/transformers/main/en/chat_templating" + "help": "The chat template, see: " + "https://huggingface.co/docs/transformers/main/en/chat_templating" }, ) @@ -403,19 +454,19 @@ class ModelMetadata(BaseParameters): def from_dict( cls, data: dict, ignore_extra_fields: bool = False ) -> "ModelMetadata": + """Create a new model metadata from a dict.""" if "ext_metadata" in data: data["ext_metadata"] = ModelExtraMedata(**data["ext_metadata"]) return cls(**data) class MessageConverter(ABC): - """An abstract class for message converter. + r"""An abstract class for message converter. - Different LLMs may have different message formats, this class is used to convert the messages - to the format of the LLM. + Different LLMs may have different message formats, this class is used to convert + the messages to the format of the LLM. Examples: - >>> from typing import List >>> from dbgpt.core.interface.message import ModelMessage, ModelMessageRoleType >>> from dbgpt.core.interface.llm import MessageConverter, ModelMetadata @@ -425,7 +476,8 @@ class MessageConverter(ABC): ... messages: List[ModelMessage], ... model_metadata: Optional[ModelMetadata] = None, ... ) -> List[ModelMessage]: - ... # Convert the messages, merge system messages to the last user message. + ... # Convert the messages, merge system messages to the last user + ... # message. ... system_message = None ... other_messages = [] ... sep = "\\n" @@ -478,6 +530,7 @@ class DefaultMessageConverter(MessageConverter): """The default message converter.""" def __init__(self, prompt_sep: Optional[str] = None): + """Create a new default message converter.""" self._prompt_sep = prompt_sep def convert( @@ -493,7 +546,8 @@ def convert( 2. Move the last user's message to the end of the list - 3. Convert the messages to no system message if the model does not support system message + 3. 
Convert the messages to no system message if the model does not support + system message Args: messages(List[ModelMessage]): The messages. @@ -520,10 +574,11 @@ def convert_to_no_system_message( messages: List[ModelMessage], model_metadata: Optional[ModelMetadata] = None, ) -> List[ModelMessage]: - """Convert the messages to no system message. + r"""Convert the messages to no system message. Examples: - >>> # Convert the messages to no system message, just merge system messages to the last user message + >>> # Convert the messages to no system message, just merge system messages + >>> # to the last user message >>> from typing import List >>> from dbgpt.core.interface.message import ( ... ModelMessage, @@ -550,7 +605,7 @@ def convert_to_no_system_message( >>> assert converted_messages == [ ... ModelMessage( ... role=ModelMessageRoleType.HUMAN, - ... content="You are a helpful assistant\\nWho are you", + ... content="You are a helpful assistant\nWho are you", ... ), ... ] """ @@ -562,7 +617,8 @@ def convert_to_no_system_message( result_messages = [] for message in messages: if message.role == ModelMessageRoleType.SYSTEM: - # Not support system message, append system message to the last user message + # The model does not support system messages; append them to the last + # user message system_messages.append(message) elif message.role in [ ModelMessageRoleType.HUMAN, @@ -578,7 +634,8 @@ def convert_to_no_system_message( system_message_str = system_messages[0].content if system_message_str and result_messages: - # Not support system messages, merge system messages to the last user message + # The model does not support system messages; merge them into the last + # user message result_messages[-1].content = ( system_message_str + prompt_sep + result_messages[-1].content ) @@ -587,10 +644,9 @@ def convert_to_no_system_message( def move_last_user_message_to_end( self, messages: List[ModelMessage] ) -> List[ModelMessage]: - """Move the last user message to the end of the list. + """Try to move the last user message to the end of the list. Examples: - >>> from typing import List >>> from dbgpt.core.interface.message import ( ... ModelMessage, @@ -660,7 +716,7 @@ class LLMClient(ABC): @property def cache(self) -> collections.abc.MutableMapping: - """The cache object to cache the model metadata. + """Return the cache object to cache the model metadata. You can override this property to use your own cache object. Returns: @@ -677,7 +733,8 @@ async def generate( """Generate a response for a given model request. Sometimes, different LLMs may have different message formats, - you can use the message converter to convert the messages to the format of the LLM. + you can use the message converter to convert the messages to the format of the + LLM. Args: request(ModelRequest): The model request. @@ -697,7 +754,8 @@ async def generate_stream( """Generate a stream of responses for a given model request. Sometimes, different LLMs may have different message formats, - you can use the message converter to convert the messages to the format of the LLM. + you can use the message converter to convert the messages to the format of the + LLM. Args: request(ModelRequest): The model request. @@ -733,6 +791,7 @@ async def covert_message( message_converter: Optional[MessageConverter] = None, ) -> ModelRequest: """Convert the message. + If no message converter is provided, the original request will be returned.
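A rough usage sketch of the client interface described above; this is not part of the patch, `client` stands for any concrete LLMClient implementation, and the model name is a placeholder:

from dbgpt.core.interface.llm import (
    DefaultMessageConverter,
    LLMClient,
    ModelRequest,
)
from dbgpt.core.interface.message import ModelMessage, ModelMessageRoleType


async def ask(client: LLMClient) -> str:
    # Build a request that carries a single human message.
    request = ModelRequest.build_request(
        model="my-model",  # placeholder model name
        messages=[ModelMessage(role=ModelMessageRoleType.HUMAN, content="Hi")],
    )
    # A converter lets the client adapt messages to the model's own format;
    # with no converter, the request is used unchanged.
    output = await client.generate(request, DefaultMessageConverter())
    return output.text  # `text` is assumed to be the generated-text field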
Args: @@ -746,14 +805,15 @@ async def covert_message( return request new_request = request.copy() model_metadata = await self.get_model_metadata(request.model) - new_messages = message_converter.convert(request.messages, model_metadata) + new_messages = message_converter.convert(request.get_messages(), model_metadata) new_request.messages = new_messages return new_request async def cached_models(self) -> List[ModelMetadata]: """Get all the models from the cache or the llm server. - If the model metadata is not in the cache, it will be fetched from the llm server. + If the model metadata is not in the cache, it will be fetched from the + llm server. Returns: List[ModelMetadata]: A list of model metadata. diff --git a/dbgpt/core/interface/message.py b/dbgpt/core/interface/message.py index d6b548f55..d78308600 100755 --- a/dbgpt/core/interface/message.py +++ b/dbgpt/core/interface/message.py @@ -1,8 +1,10 @@ +"""The conversation and message module.""" + from __future__ import annotations from abc import ABC, abstractmethod from datetime import datetime -from typing import Callable, Dict, List, Optional, Tuple, Union +from typing import Callable, Dict, List, Optional, Tuple, Union, cast from dbgpt._private.pydantic import BaseModel, Field from dbgpt.core.interface.storage import ( @@ -29,11 +31,11 @@ def type(self) -> str: @property def pass_to_model(self) -> bool: - """Whether the message will be passed to the model""" + """Whether the message will be passed to the model.""" return True def to_dict(self) -> Dict: - """Convert to dict + """Convert to dict. Returns: Dict: The dict object @@ -47,7 +49,7 @@ def to_dict(self) -> Dict: @staticmethod def messages_to_string(messages: List["BaseMessage"]) -> str: - """Convert messages to str + """Convert messages to str. Args: messages (List[BaseMessage]): The messages @@ -92,7 +94,7 @@ def type(self) -> str: @property def pass_to_model(self) -> bool: - """Whether the message will be passed to the model + """Whether the message will be passed to the model. The view message will not be passed to the model """ @@ -109,7 +111,7 @@ def type(self) -> str: class ModelMessageRoleType: - """ "Type of ModelMessage role""" + """Type of ModelMessage role.""" SYSTEM = "system" HUMAN = "human" @@ -118,7 +120,7 @@ class ModelMessageRoleType: class ModelMessage(BaseModel): - """Type of message that interaction between dbgpt-server and llm-server""" + """Type of message exchanged between dbgpt-server and llm-server.""" """Similar to openai's message format""" role: str @@ -127,7 +129,7 @@ class ModelMessage(BaseModel): @property def pass_to_model(self) -> bool: - """Whether the message will be passed to the model + """Whether the message will be passed to the model. The view message will not be passed to the model @@ -142,6 +144,14 @@ def pass_to_model(self) -> bool: @staticmethod def from_base_messages(messages: List[BaseMessage]) -> List["ModelMessage"]: + """Convert BaseMessage format to current ModelMessage format.
+ + Args: + messages (List[BaseMessage]): The base messages + + Returns: + List[ModelMessage]: The model messages + """ result = [] for message in messages: content, round_index = message.content, message.round_index @@ -173,7 +183,7 @@ def from_base_messages(messages: List[BaseMessage]) -> List["ModelMessage"]: def from_openai_messages( messages: Union[str, List[Dict[str, str]]] ) -> List["ModelMessage"]: - """Openai message format to current ModelMessage format""" + """Convert OpenAI message format to current ModelMessage format.""" if isinstance(messages, str): return [ModelMessage(role=ModelMessageRoleType.HUMAN, content=messages)] result = [] @@ -202,8 +212,11 @@ def to_common_messages( convert_to_compatible_format: bool = False, support_system_role: bool = True, ) -> List[Dict[str, str]]: - """Convert to common message format(e.g. OpenAI message format) and - huggingface [Templates of Chat Models](https://huggingface.co/docs/transformers/v4.34.1/en/chat_templating) + """Convert to common message format. + + Convert to common message format(e.g. OpenAI message format) and + huggingface [Templates of Chat Models] + (https://huggingface.co/docs/transformers/v4.34.1/en/chat_templating) Args: messages (List["ModelMessage"]): The model messages @@ -243,15 +256,38 @@ def to_common_messages( @staticmethod def to_dict_list(messages: List["ModelMessage"]) -> List[Dict[str, str]]: + """Convert to dict list. + + Args: + messages (List["ModelMessage"]): The model messages + + Returns: + List[Dict[str, str]]: The dict list + """ return list(map(lambda m: m.dict(), messages)) @staticmethod def build_human_message(content: str) -> "ModelMessage": + """Build human message. + + Args: + content (str): The content + + Returns: + ModelMessage: The model message + """ return ModelMessage(role=ModelMessageRoleType.HUMAN, content=content) @staticmethod def get_printable_message(messages: List["ModelMessage"]) -> str: - """Get the printable message""" + """Get the printable message. + + Args: + messages (List["ModelMessage"]): The model messages + + Returns: + str: The printable message + """ str_msg = "" for message in messages: curr_message = ( @@ -263,7 +299,7 @@ def get_printable_message(messages: List["ModelMessage"]) -> str: @staticmethod def messages_to_string(messages: List["ModelMessage"]) -> str: - """Convert messages to str + """Convert messages to str. Args: messages (List[ModelMessage]): The messages @@ -287,12 +323,12 @@ def _messages_to_dict(messages: List[BaseMessage]) -> List[Dict]: def _messages_to_str( - messages: List[Union[BaseMessage, ModelMessage]], + messages: Union[List[BaseMessage], List[ModelMessage]], human_prefix: str = "Human", ai_prefix: str = "AI", system_prefix: str = "System", ) -> str: - """Convert messages to str + """Convert messages to str. Args: messages (List[Union[BaseMessage, ModelMessage]]): The messages @@ -343,21 +379,27 @@ def _messages_from_dict(messages: List[Dict]) -> List[BaseMessage]: def parse_model_messages( messages: List[ModelMessage], -) -> Tuple[str, List[str], List[List[str, str]]]: - """ - Parse model messages to extract the user prompt, system messages, and a history of conversation. +) -> Tuple[str, List[str], List[List[str]]]: + """Parse model messages. + + Parse model messages to extract the user prompt, system messages, and a history of + conversation. - This function analyzes a list of ModelMessage objects, identifying the role of each message (e.g., human, system, ai) - and categorizes them accordingly.
The last message is expected to be from the user (human), and it's treated as - the current user prompt. System messages are extracted separately, and the conversation history is compiled into - pairs of human and AI messages. + This function analyzes a list of ModelMessage objects, identifying the role of each + message (e.g., human, system, ai) + and categorizes them accordingly. The last message is expected to be from the user + (human), and it's treated as + the current user prompt. System messages are extracted separately, and the + conversation history is compiled into pairs of human and AI messages. Args: messages (List[ModelMessage]): List of messages from a chat conversation. Returns: - tuple: A tuple containing the user prompt, list of system messages, and the conversation history. - The conversation history is a list of message pairs, each containing a user message and the corresponding AI response. + tuple: A tuple containing the user prompt, list of system messages, and the + conversation history. + The conversation history is a list of message pairs, each containing a + user message and the corresponding AI response. Examples: .. code-block:: python @@ -399,7 +441,6 @@ def parse_model_messages( # system_messages: ["Error 404"] # history: [["Hi", "Hello!"], ["What's the error?", "Just a joke."]] """ - system_messages: List[str] = [] history_messages: List[List[str]] = [[]] @@ -420,27 +461,30 @@ def parse_model_messages( class OnceConversation: - """All the information of a conversation, the current single service in memory, - can expand cache and database support distributed services. + """Once conversation. + All the information of a conversation, the current single service in memory, + can expand cache and database support distributed services. """ def __init__( self, chat_mode: str, - user_name: str = None, - sys_code: str = None, - summary: str = None, + user_name: Optional[str] = None, + sys_code: Optional[str] = None, + summary: Optional[str] = None, **kwargs, ): + """Create a new conversation.""" self.chat_mode: str = chat_mode - self.user_name: str = user_name - self.sys_code: str = sys_code - self.summary: str = summary + self.user_name: Optional[str] = user_name + self.sys_code: Optional[str] = sys_code + self.summary: Optional[str] = summary self.messages: List[BaseMessage] = kwargs.get("messages", []) self.start_date: str = kwargs.get("start_date", "") - # After each complete round of dialogue, the current value will be increased by 1 + # After each complete round of dialogue, the current value will be + # increased by 1 self.chat_order: int = int(kwargs.get("chat_order", 0)) self.model_name: str = kwargs.get("model_name", "") self.param_type: str = kwargs.get("param_type", "") @@ -460,10 +504,9 @@ def _append_message(self, message: BaseMessage) -> None: self.messages.append(message) def start_new_round(self) -> None: - """Start a new round of conversation + """Start a new round of conversation. 
Example: - >>> conversation = OnceConversation("chat_normal") >>> # The chat order will be 0, then we start a new round of conversation >>> assert conversation.chat_order == 0 @@ -473,7 +516,8 @@ def start_new_round(self) -> None: >>> conversation.add_user_message("hello") >>> conversation.add_ai_message("hi") >>> conversation.end_current_round() - >>> # Now the chat order will be 1, then we start a new round of conversation + >>> # Now the chat order will be 1, then we start a new round of + >>> # conversation >>> conversation.start_new_round() >>> # Now the chat order will be 2 >>> assert conversation.chat_order == 2 @@ -485,7 +529,7 @@ def start_new_round(self) -> None: self.chat_order += 1 def end_current_round(self) -> None: - """End the current round of conversation + """Execute the end of the current round of conversation. We do nothing here, just for the interface """ @@ -494,7 +538,7 @@ def end_current_round(self) -> None: def add_user_message( self, message: str, check_duplicate_type: Optional[bool] = False ) -> None: - """Add a user message to the conversation + """Save a user message to the conversation. Args: message (str): The message content @@ -514,11 +558,12 @@ def add_user_message( def add_ai_message( self, message: str, update_if_exist: Optional[bool] = False ) -> None: - """Add an AI message to the conversation + """Save an AI message to the current conversation. Args: message (str): The message content - update_if_exist (bool): Whether to update the message if the message type is duplicate + update_if_exist (bool): Whether to update the message if the message type + is duplicate """ if not update_if_exist: self._append_message(AIMessage(content=message)) @@ -530,51 +575,57 @@ def add_ai_message( self._append_message(AIMessage(content=message)) def _update_ai_message(self, new_message: str) -> None: - """ - stream out message update - Args: - new_message: + """Update all AI messages to the new message. - Returns: + stream out message update + Args: + new_message (str): The new message """ - for item in self.messages: if item.type == "ai": item.content = new_message def add_view_message(self, message: str) -> None: - """Add an AI message to the store""" + """Save a view message to the current conversation.""" self._append_message(ViewMessage(content=message)) def add_system_message(self, message: str) -> None: - """Add a system message to the store""" + """Save a system message to the current conversation.""" self._append_message(SystemMessage(content=message)) def set_start_time(self, datatime: datetime): + """Set the start time of the conversation.""" dt_str = datatime.strftime("%Y-%m-%d %H:%M:%S") self.start_date = dt_str def clear(self) -> None: - """Remove all messages from the store""" + """Remove all messages from the store.""" self.messages.clear() def get_latest_user_message(self) -> Optional[HumanMessage]: - """Get the latest user message""" + """Get the latest user message.""" for message in self.messages[::-1]: if isinstance(message, HumanMessage): return message return None def get_system_messages(self) -> List[SystemMessage]: - """Get the latest user message""" - return list(filter(lambda x: isinstance(x, SystemMessage), self.messages)) + """Get the system messages.
+ + Returns: + List[SystemMessage]: The system messages + """ + return cast( + List[SystemMessage], + list(filter(lambda x: isinstance(x, SystemMessage), self.messages)), + ) def _to_dict(self) -> Dict: return _conversation_to_dict(self) def from_conversation(self, conversation: OnceConversation) -> None: - """Load the conversation from the storage""" + """Load the conversation from the storage.""" self.chat_mode = conversation.chat_mode self.messages = conversation.messages self.start_date = conversation.start_date @@ -592,7 +643,7 @@ def from_conversation(self, conversation: OnceConversation) -> None: self._message_index = conversation._message_index def get_messages_by_round(self, round_index: int) -> List[BaseMessage]: - """Get the messages by round index + """Get the messages by round index. Args: round_index (int): The round index @@ -603,7 +654,7 @@ def get_messages_by_round(self, round_index: int) -> List[BaseMessage]: return list(filter(lambda x: x.round_index == round_index, self.messages)) def get_latest_round(self) -> List[BaseMessage]: - """Get the latest round messages + """Get the latest round messages. Returns: List[BaseMessage]: The messages @@ -611,7 +662,7 @@ def get_latest_round(self) -> List[BaseMessage]: return self.get_messages_by_round(self.chat_order) def get_messages_with_round(self, round_count: int) -> List[BaseMessage]: - """Get the messages with round count + """Get the messages with round count. If the round count is 1, the history messages will not be included. @@ -660,16 +711,19 @@ def get_messages_with_round(self, round_count: int) -> List[BaseMessage]: return messages def get_model_messages(self) -> List[ModelMessage]: - """Get the model messages + """Get the model messages. Model messages just include human, ai and system messages. - Model messages maybe include the history messages, The order of the messages is the same as the order of + Model messages may include the history messages. The order of the messages is + the same as the order of the messages in the conversation, the last message is the latest message. - If you want to hand the message with your own logic, you can override this method. + If you want to handle the messages with your own logic, you can override this + method. Examples: - If you not need the history messages, you can override this method like this: + If you do not need the history messages, you can override this method + like this: .. code-block:: python def get_model_messages(self) -> List[ModelMessage]: @@ -681,7 +735,8 @@ def get_model_messages(self) -> List[ModelMessage]: ) return messages - If you want to add the one round history messages, you can override this method like this: + If you want to add the one round history messages, you can override this + method like this: .. code-block:: python def get_model_messages(self) -> List[ModelMessage]: @@ -717,7 +772,7 @@ def get_model_messages(self) -> List[ModelMessage]: def get_history_message( self, include_system_message: bool = False ) -> List[BaseMessage]: - """Get the history message + """Get the history message. System messages are not included by default.
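The round bookkeeping above is easiest to see end to end. A minimal sketch, with the expected counts inferred from the docstrings in this diff:

from dbgpt.core.interface.message import OnceConversation

conv = OnceConversation("chat_normal")
conv.start_new_round()          # chat_order: 0 -> 1
conv.add_user_message("hello")  # stored with round_index 1
conv.add_ai_message("hi")
conv.end_current_round()

assert conv.chat_order == 1
assert len(conv.get_messages_by_round(1)) == 2  # the human and AI messages
assert len(conv.get_latest_round()) == 2        # same two messages
assert len(conv.get_model_messages()) == 2      # both pass to the model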
@@ -729,46 +784,60 @@ def get_history_message( """ messages = [] for message in self.messages: - if message.pass_to_model: - if include_system_message: - messages.append(message) - elif message.type != "system": - messages.append(message) + if message.pass_to_model and ( + include_system_message + or message.type != "system" + ): + messages.append(message) return messages class ConversationIdentifier(ResourceIdentifier): - """Conversation identifier""" + """Conversation identifier.""" def __init__(self, conv_uid: str, identifier_type: str = "conversation"): + """Create a conversation identifier. + + Args: + conv_uid (str): The conversation uid + identifier_type (str): The identifier type + """ self.conv_uid = conv_uid self.identifier_type = identifier_type @property def str_identifier(self) -> str: + """Return the str identifier.""" return f"{self.identifier_type}:{self.conv_uid}" def to_dict(self) -> Dict: + """Convert to dict.""" return {"conv_uid": self.conv_uid, "identifier_type": self.identifier_type} class MessageIdentifier(ResourceIdentifier): - """Message identifier""" + """Message identifier.""" identifier_split = "___" def __init__(self, conv_uid: str, index: int, identifier_type: str = "message"): + """Create a message identifier.""" self.conv_uid = conv_uid self.index = index self.identifier_type = identifier_type @property def str_identifier(self) -> str: - return f"{self.identifier_type}{self.identifier_split}{self.conv_uid}{self.identifier_split}{self.index}" + """Return the str identifier.""" + return ( + f"{self.identifier_type}{self.identifier_split}{self.conv_uid}" + f"{self.identifier_split}{self.index}" + ) @staticmethod def from_str_identifier(str_identifier: str) -> MessageIdentifier: - """Convert from str identifier + """Convert from str identifier. Args: str_identifier (str): The str identifier @@ -782,6 +851,7 @@ def from_str_identifier(str_identifier: str) -> MessageIdentifier: return MessageIdentifier(parts[1], int(parts[2])) def to_dict(self) -> Dict: + """Convert to dict.""" return { "conv_uid": self.conv_uid, "index": self.index, @@ -790,17 +860,31 @@ def to_dict(self) -> Dict: class MessageStorageItem(StorageItem): + """The message storage item. + + Keep the message detail and the message index. + """ + @property def identifier(self) -> MessageIdentifier: + """Return the identifier.""" return self._id def __init__(self, conv_uid: str, index: int, message_detail: Dict): + """Create a message storage item. + + Args: + conv_uid (str): The conversation uid + index (int): The message index + message_detail (Dict): The message detail + """ self.conv_uid = conv_uid self.index = index self.message_detail = message_detail self._id = MessageIdentifier(conv_uid, index) def to_dict(self) -> Dict: + """Convert to dict.""" return { "conv_uid": self.conv_uid, "index": self.index, @@ -808,7 +892,8 @@ def to_dict(self) -> Dict: } def to_message(self) -> BaseMessage: - """Convert to message object + """Convert to message object. + Returns: BaseMessage: The message object @@ -818,7 +903,7 @@ def to_message(self) -> BaseMessage: return _message_from_dict(self.message_detail) def merge(self, other: "StorageItem") -> None: - """Merge the other message to self + """Merge the other message to self. Args: other (StorageItem): The other message @@ -829,16 +914,20 @@ def merge(self, other: "StorageItem") -> None: class StorageConversation(OnceConversation, StorageItem): - """All the information of a conversation, the current single service in memory, + """The storage conversation.
+ + All the information of a conversation, the current single service in memory, can expand cache and database support distributed services. """ @property def identifier(self) -> ConversationIdentifier: + """Return the identifier.""" return self._id def to_dict(self) -> Dict: + """Convert to dict.""" dict_data = self._to_dict() messages: Dict = dict_data.pop("messages") message_ids = [] @@ -859,7 +948,7 @@ def to_dict(self) -> Dict: return dict_data def merge(self, other: "StorageItem") -> None: - """Merge the other conversation to self + """Merge the other conversation to self. Args: other (StorageItem): The other conversation @@ -871,17 +960,18 @@ def merge(self, other: "StorageItem") -> None: def __init__( self, conv_uid: str, - chat_mode: str = None, - user_name: str = None, - sys_code: str = None, - message_ids: List[str] = None, - summary: str = None, - save_message_independent: Optional[bool] = True, - conv_storage: StorageInterface = None, - message_storage: StorageInterface = None, + chat_mode: str = "chat_normal", + user_name: Optional[str] = None, + sys_code: Optional[str] = None, + message_ids: Optional[List[str]] = None, + summary: Optional[str] = None, + save_message_independent: bool = True, + conv_storage: Optional[StorageInterface] = None, + message_storage: Optional[StorageInterface] = None, load_message: bool = True, **kwargs, ): + """Create a conversation.""" super().__init__(chat_mode, user_name, sys_code, summary, **kwargs) self.conv_uid = conv_uid self._message_ids = message_ids @@ -905,7 +995,7 @@ def __init__( @property def message_ids(self) -> List[str]: - """Get the message ids + """Return the message ids. Returns: List[str]: The message ids @@ -913,7 +1003,7 @@ def message_ids(self) -> List[str]: return self._message_ids if self._message_ids else [] def end_current_round(self) -> None: - """End the current round of conversation + """End the current round of conversation. Save the conversation to the storage after a round of conversation """ @@ -926,7 +1016,7 @@ def _get_message_items(self) -> List[MessageStorageItem]: ] def save_to_storage(self) -> None: - """Save the conversation to the storage""" + """Save the conversation to the storage.""" # Save messages first message_list = self._get_message_items() self._message_ids = [ @@ -943,7 +1033,7 @@ def save_to_storage(self) -> None: def load_from_storage( self, conv_storage: StorageInterface, message_storage: StorageInterface ) -> None: - """Load the conversation from the storage + """Load the conversation from the storage. Warning: This will overwrite the current conversation. @@ -952,7 +1042,7 @@ def load_from_storage( message_storage (StorageInterface): The storage interface """ # Load conversation first - conversation: StorageConversation = conv_storage.load( + conversation: Optional[StorageConversation] = conv_storage.load( self._id, StorageConversation ) if conversation is None: @@ -988,18 +1078,18 @@ def load_from_storage( def _append_additional_kwargs( self, conversation: StorageConversation, messages: List[BaseMessage] ) -> None: - """Parse the additional kwargs and append to the conversation + """Parse the additional kwargs and append to the conversation. 
Args: conversation (StorageConversation): The conversation messages (List[BaseMessage]): The messages """ - param_type = None - param_value = None + param_type = "" + param_value = "" for message in messages[::-1]: if message.additional_kwargs: - param_type = message.additional_kwargs.get("param_type") - param_value = message.additional_kwargs.get("param_value") + param_type = message.additional_kwargs.get("param_type", "") + param_value = message.additional_kwargs.get("param_value", "") break if not conversation.param_type: conversation.param_type = param_type @@ -1007,7 +1097,7 @@ def _append_additional_kwargs( conversation.param_value = param_value def delete(self) -> None: - """Delete all the messages and conversation from the storage""" + """Delete all the messages and conversation.""" # Delete messages first message_list = self._get_message_items() message_ids = [message.identifier for message in message_list] @@ -1055,13 +1145,13 @@ def _conversations_to_dict(conversations: List[OnceConversation]) -> List[dict]: def _conversation_from_dict(once: dict) -> OnceConversation: conversation = OnceConversation( - once.get("chat_mode"), once.get("user_name"), once.get("sys_code") + once.get("chat_mode", ""), once.get("user_name"), once.get("sys_code") ) conversation.cost = once.get("cost", 0) conversation.chat_mode = once.get("chat_mode", "chat_normal") conversation.tokens = once.get("tokens", 0) conversation.start_date = once.get("start_date", "") - conversation.chat_order = int(once.get("chat_order")) + conversation.chat_order = int(once.get("chat_order", 0)) conversation.param_type = once.get("param_type", "") conversation.param_value = once.get("param_value", "") conversation.model_name = once.get("model_name", "proxyllm") @@ -1093,7 +1183,8 @@ def _split_messages_by_round(messages: List[BaseMessage]) -> List[List[BaseMessa def _append_view_messages(messages: List[BaseMessage]) -> List[BaseMessage]: - """Append the view message to the messages + """Append the view message to the messages. + Just for show in DB-GPT-Web. If already have view message, do nothing. diff --git a/dbgpt/core/interface/operator/__init__.py b/dbgpt/core/interface/operator/__init__.py index e69de29bb..d1b2dec96 100644 --- a/dbgpt/core/interface/operator/__init__.py +++ b/dbgpt/core/interface/operator/__init__.py @@ -0,0 +1 @@ +"""The module includes all core operators of DB-GPT.""" diff --git a/dbgpt/core/interface/operator/composer_operator.py b/dbgpt/core/interface/operator/composer_operator.py index 36eb97fbc..a8896702a 100644 --- a/dbgpt/core/interface/operator/composer_operator.py +++ b/dbgpt/core/interface/operator/composer_operator.py @@ -1,5 +1,9 @@ +"""The chat history prompt composer operator. + +We can wrap some atomic operators into a complex operator.
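A hypothetical wiring sketch of the composition this module docstring describes; `prompt_template` and `chat_input` are placeholders, since their construction is not shown in this diff:

from dbgpt.core.interface.operator.composer_operator import (
    ChatHistoryPromptComposerOperator,
)

composer = ChatHistoryPromptComposerOperator(
    prompt_template=prompt_template,  # a ChatPromptTemplate (placeholder)
    history_key="chat_history",
)
# Inside a DAG, the composer runs its internal sub-DAG and returns a
# ModelRequest built from the history, the prompt, and the model params:
# request = await composer.call(call_data={"data": chat_input})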
+""" import dataclasses -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, cast from dbgpt.core import ( ChatPromptTemplate, @@ -51,6 +55,7 @@ def __init__( message_storage: Optional[StorageInterface[MessageStorageItem, Any]] = None, **kwargs, ): + """Create a new chat history prompt composer operator.""" super().__init__(**kwargs) self._prompt_template = prompt_template self._history_key = history_key @@ -61,7 +66,8 @@ def __init__( self._sub_compose_dag = self._build_composer_dag() async def map(self, input_value: ChatComposerInput) -> ModelRequest: - end_node: BaseOperator = self._sub_compose_dag.leaf_nodes[0] + """Compose the chat history prompt.""" + end_node: BaseOperator = cast(BaseOperator, self._sub_compose_dag.leaf_nodes[0]) # Sub dag, use the same dag context in the parent dag return await end_node.call( call_data={"data": input_value}, dag_ctx=self.current_dag_context @@ -82,7 +88,9 @@ def _build_composer_dag(self) -> DAG: history_prompt_build_task = HistoryPromptBuilderOperator( prompt=self._prompt_template, history_key=self._history_key ) - model_request_build_task = JoinOperator(self._build_model_request) + model_request_build_task: JoinOperator[ModelRequest] = JoinOperator( + combine_function=self._build_model_request + ) # Build composer dag ( @@ -113,5 +121,6 @@ def _build_model_request( return ModelRequest.build_request(messages=messages, **model_dict) async def after_dag_end(self): + """Execute after dag end.""" # Should call after_dag_end() of sub dag await self._sub_compose_dag._after_dag_end() diff --git a/dbgpt/core/interface/operator/llm_operator.py b/dbgpt/core/interface/operator/llm_operator.py index 5570aa1d6..b7e803436 100644 --- a/dbgpt/core/interface/operator/llm_operator.py +++ b/dbgpt/core/interface/operator/llm_operator.py @@ -1,9 +1,12 @@ +"""The LLM operator.""" + import dataclasses from abc import ABC from typing import Any, AsyncIterator, Dict, List, Optional, Union from dbgpt._private.pydantic import BaseModel from dbgpt.core.awel import ( + BaseOperator, BranchFunc, BranchOperator, DAGContext, @@ -32,11 +35,13 @@ class RequestBuilderOperator(MapOperator[RequestInput, ModelRequest], ABC): """Build the model request from the input value.""" def __init__(self, model: Optional[str] = None, **kwargs): + """Create a new request builder operator.""" self._model = model super().__init__(**kwargs) async def map(self, input_value: RequestInput) -> ModelRequest: - req_dict = {} + """Transform the input value to a model request.""" + req_dict: Dict[str, Any] = {} if not input_value: raise ValueError("input_value is not set") if isinstance(input_value, str): @@ -47,7 +52,9 @@ async def map(self, input_value: RequestInput) -> ModelRequest: req_dict = {"messages": [input_value]} elif isinstance(input_value, list) and isinstance(input_value[0], ModelMessage): req_dict = {"messages": input_value} - elif dataclasses.is_dataclass(input_value): + elif dataclasses.is_dataclass(input_value) and not isinstance( + input_value, type + ): req_dict = dataclasses.asdict(input_value) elif isinstance(input_value, BaseModel): req_dict = input_value.dict() @@ -90,6 +97,7 @@ class BaseLLM: SHARE_DATA_KEY_MODEL_OUTPUT = "share_data_key_model_output" def __init__(self, llm_client: Optional[LLMClient] = None): + """Create a new LLM operator.""" self._llm_client = llm_client @property @@ -118,10 +126,19 @@ class BaseLLMOperator(BaseLLM, MapOperator[ModelRequest, ModelOutput], ABC): """ def __init__(self, llm_client: Optional[LLMClient] = None, 
**kwargs): + """Create a new LLM operator.""" super().__init__(llm_client=llm_client) MapOperator.__init__(self, **kwargs) async def map(self, request: ModelRequest) -> ModelOutput: + """Generate the model output. + + Args: + request (ModelRequest): The model request. + + Returns: + ModelOutput: The model output. + """ await self.current_dag_context.save_to_share_data( self.SHARE_DATA_KEY_MODEL_NAME, request.model ) @@ -142,15 +159,23 @@ class BaseStreamingLLMOperator( """ def __init__(self, llm_client: Optional[LLMClient] = None, **kwargs): + """Create a streaming operator for a LLM. + + Args: + llm_client (LLMClient, optional): The LLM client. Defaults to None. + """ super().__init__(llm_client=llm_client) - StreamifyAbsOperator.__init__(self, **kwargs) + BaseOperator.__init__(self, **kwargs) - async def streamify(self, request: ModelRequest) -> AsyncIterator[ModelOutput]: + async def streamify( # type: ignore + self, request: ModelRequest # type: ignore + ) -> AsyncIterator[ModelOutput]: # type: ignore + """Streamify the request.""" await self.current_dag_context.save_to_share_data( self.SHARE_DATA_KEY_MODEL_NAME, request.model ) model_output = None - async for output in self.llm_client.generate_stream(request): + async for output in self.llm_client.generate_stream(request): # type: ignore model_output = output yield output if model_output: @@ -160,10 +185,17 @@ async def streamify(self, request: ModelRequest) -> AsyncIterator[ModelOutput]: class LLMBranchOperator(BranchOperator[ModelRequest, ModelRequest]): """Branch operator for LLM. - This operator will branch the workflow based on the stream flag of the request. + This operator will branch the workflow based on + the stream flag of the request. """ def __init__(self, stream_task_name: str, no_stream_task_name: str, **kwargs): + """Create a new LLM branch operator. + + Args: + stream_task_name (str): The name of the streaming task. + no_stream_task_name (str): The name of the non-streaming task. + """ super().__init__(**kwargs) if not stream_task_name: raise ValueError("stream_task_name is not set") @@ -172,18 +204,22 @@ def __init__(self, stream_task_name: str, no_stream_task_name: str, **kwargs): self._stream_task_name = stream_task_name self._no_stream_task_name = no_stream_task_name - async def branches(self) -> Dict[BranchFunc[ModelRequest], str]: + async def branches( + self, + ) -> Dict[BranchFunc[ModelRequest], Union[BaseOperator, str]]: """ Return a dict of branch function and task name. Returns: - Dict[BranchFunc[ModelRequest], str]: A dict of branch function and task name. - the key is a predicate function, the value is the task name. If the predicate function returns True, - we will run the corresponding task. + Dict[BranchFunc[ModelRequest], str]: A dict of branch function and task + name. the key is a predicate function, the value is the task name. + If the predicate function returns True, we will run the corresponding + task. """ async def check_stream_true(r: ModelRequest) -> bool: - # If stream is true, we will run the streaming task. otherwise, we will run the non-streaming task. + # If stream is true, we will run the streaming task. otherwise, we will run + # the non-streaming task. 
return r.stream return { diff --git a/dbgpt/core/interface/operator/message_operator.py b/dbgpt/core/interface/operator/message_operator.py index fea21d0f9..a692d8d22 100644 --- a/dbgpt/core/interface/operator/message_operator.py +++ b/dbgpt/core/interface/operator/message_operator.py @@ -1,3 +1,4 @@ +"""The message operator.""" import uuid from abc import ABC from typing import Any, Callable, Dict, List, Optional, Union @@ -36,6 +37,7 @@ def __init__( check_storage: bool = True, **kwargs, ): + """Create a new BaseConversationOperator.""" self._check_storage = check_storage self._storage = storage self._message_storage = message_storage @@ -102,7 +104,8 @@ class PreChatHistoryLoadOperator( ): """The operator to prepare the storage conversation. - In DB-GPT, conversation record and the messages in the conversation are stored in the storage, + In DB-GPT, conversation record and the messages in the conversation are stored in + the storage, and they can store in different storage(for high performance). This operator just load the conversation and messages from storage. @@ -115,6 +118,7 @@ def __init__( include_system_message: bool = False, **kwargs, ): + """Create a new PreChatHistoryLoadOperator.""" super().__init__(storage=storage, message_storage=message_storage) MapOperator.__init__(self, **kwargs) self._include_system_message = include_system_message @@ -139,7 +143,8 @@ async def map(self, input_value: ChatHistoryLoadType) -> List[BaseMessage]: chat_mode = input_value.chat_mode - # Create a new storage conversation, this will load the conversation from storage, so we must do this async + # Create a new storage conversation, this will load the conversation from + # storage, so we must do this async storage_conv: StorageConversation = await self.blocking_func_to_async( StorageConversation, conv_uid=input_value.conv_uid, @@ -167,14 +172,21 @@ async def map(self, input_value: ChatHistoryLoadType) -> List[BaseMessage]: class ConversationMapperOperator( BaseConversationOperator, MapOperator[List[BaseMessage], List[BaseMessage]] ): - def __init__(self, message_mapper: _MultiRoundMessageMapper = None, **kwargs): + """The base conversation mapper operator.""" + + def __init__( + self, message_mapper: Optional[_MultiRoundMessageMapper] = None, **kwargs + ): + """Create a new ConversationMapperOperator.""" MapOperator.__init__(self, **kwargs) self._message_mapper = message_mapper async def map(self, input_value: List[BaseMessage]) -> List[BaseMessage]: + """Map the input value to a ModelRequest.""" return await self.map_messages(input_value) async def map_messages(self, messages: List[BaseMessage]) -> List[BaseMessage]: + """Map multi round messages to a list of BaseMessage.""" messages_by_round: List[List[BaseMessage]] = _split_messages_by_round(messages) message_mapper = self._message_mapper or self.map_multi_round_messages return message_mapper(messages_by_round) @@ -184,11 +196,11 @@ def map_multi_round_messages( ) -> List[BaseMessage]: """Map multi round messages to a list of BaseMessage. - By default, just merge all multi round messages to a list of BaseMessage according origin order. + By default, just merge all multi round messages to a list of BaseMessage + according origin order. And you can overwrite this method to implement your own logic. Examples: - Merge multi round messages to a list of BaseMessage according origin order. >>> from dbgpt.core.interface.message import ( @@ -215,7 +227,8 @@ def map_multi_round_messages( ... AIMessage(content="Just a joke.", round_index=2), ... 
] - Map multi round messages to a list of BaseMessage just keep the last one round. + Map multi round messages to a list of BaseMessage just keep the last one + round. >>> class MyMapper(ConversationMapperOperator): ... def __init__(self, **kwargs): @@ -234,13 +247,16 @@ def map_multi_round_messages( ... ] Args: + messages_by_round (List[List[BaseMessage]]): + The messages grouped by round. """ # Just merge and return return _merge_multi_round_messages(messages_by_round) class BufferedConversationMapperOperator(ConversationMapperOperator): - """ + """Buffered conversation mapper operator. + The buffered conversation mapper operator which can be configured to keep a certain number of starting and/or ending rounds of a conversation. @@ -249,39 +265,44 @@ class BufferedConversationMapperOperator(ConversationMapperOperator): keep_end_rounds (Optional[int]): Number of final rounds to keep. Examples: - # Keeping the first 2 and the last 1 rounds of a conversation - import asyncio - from dbgpt.core.interface.message import AIMessage, HumanMessage - from dbgpt.core.operator import BufferedConversationMapperOperator - - operator = BufferedConversationMapperOperator(keep_start_rounds=2, keep_end_rounds=1) - messages = [ - # Assume each HumanMessage and AIMessage belongs to separate rounds - HumanMessage(content="Hi", round_index=1), - AIMessage(content="Hello!", round_index=1), - HumanMessage(content="How are you?", round_index=2), - AIMessage(content="I'm good, thanks!", round_index=2), - HumanMessage(content="What's new today?", round_index=3), - AIMessage(content="Lots of things!", round_index=3), - ] - # This will keep rounds 1, 2, and 3 - assert asyncio.run(operator.map_messages(messages)) == [ - HumanMessage(content="Hi", round_index=1), - AIMessage(content="Hello!", round_index=1), - HumanMessage(content="How are you?", round_index=2), - AIMessage(content="I'm good, thanks!", round_index=2), - HumanMessage(content="What's new today?", round_index=3), - AIMessage(content="Lots of things!", round_index=3), - ] + .. 
code-block:: python + + # Keeping the first 2 and the last 1 rounds of a conversation + import asyncio + from dbgpt.core.interface.message import AIMessage, HumanMessage + from dbgpt.core.operator import BufferedConversationMapperOperator + + operator = BufferedConversationMapperOperator( + keep_start_rounds=2, keep_end_rounds=1 + ) + messages = [ + # Assume each HumanMessage and AIMessage belongs to separate rounds + HumanMessage(content="Hi", round_index=1), + AIMessage(content="Hello!", round_index=1), + HumanMessage(content="How are you?", round_index=2), + AIMessage(content="I'm good, thanks!", round_index=2), + HumanMessage(content="What's new today?", round_index=3), + AIMessage(content="Lots of things!", round_index=3), + ] + # This will keep rounds 1, 2, and 3 + assert asyncio.run(operator.map_messages(messages)) == [ + HumanMessage(content="Hi", round_index=1), + AIMessage(content="Hello!", round_index=1), + HumanMessage(content="How are you?", round_index=2), + AIMessage(content="I'm good, thanks!", round_index=2), + HumanMessage(content="What's new today?", round_index=3), + AIMessage(content="Lots of things!", round_index=3), + ] """ def __init__( self, keep_start_rounds: Optional[int] = None, keep_end_rounds: Optional[int] = None, - message_mapper: _MultiRoundMessageMapper = None, + message_mapper: Optional[_MultiRoundMessageMapper] = None, **kwargs, ): + """Create a new BufferedConversationMapperOperator.""" # Validate the input parameters if keep_start_rounds is not None and keep_start_rounds < 0: raise ValueError("keep_start_rounds must be non-negative") @@ -311,10 +332,11 @@ def new_message_mapper( def _filter_round_messages( self, messages_by_round: List[List[BaseMessage]] ) -> List[List[BaseMessage]]: - """Filters the messages to keep only the specified starting and/or ending rounds. + """Return a filtered list of messages. - Examples: + Filters the messages to keep only the specified starting and/or ending rounds. + Examples: >>> from dbgpt.core import AIMessage, HumanMessage >>> from dbgpt.core.operator import BufferedConversationMapperOperator >>> messages = [ @@ -395,15 +417,18 @@ def _filter_round_messages( ... ] Args: - messages_by_round (List[List[BaseMessage]]): The messages grouped by round. + messages_by_round (List[List[BaseMessage]]): + The messages grouped by round. + + Returns: + List[List[BaseMessage]]: Filtered list of messages. - Returns: - List[List[BaseMessage]]: Filtered list of messages. """ total_rounds = len(messages_by_round) if self._keep_start_rounds is not None and self._keep_end_rounds is not None: if self._keep_start_rounds + self._keep_end_rounds > total_rounds: - # Avoid overlapping when the sum of start and end rounds exceeds total rounds + # Avoid overlapping when the sum of start and end rounds exceeds total + # rounds return messages_by_round return ( messages_by_round[: self._keep_start_rounds] @@ -423,14 +448,16 @@ def _filter_round_messages( class TokenBufferedConversationMapperOperator(ConversationMapperOperator): """The token buffered conversation mapper operator. - If the token count of the messages is greater than the max token limit, we will evict the messages by round. + If the token count of the messages is greater than the max token limit, we will + evict the messages by round. Args: model (str): The model name. llm_client (LLMClient): The LLM client. max_token_limit (int): The max token limit. eviction_policy (EvictionPolicyType): The eviction policy. 
- message_mapper (_MultiRoundMessageMapper): The message mapper, it applies after all messages are handled. + message_mapper (_MultiRoundMessageMapper): The message mapper, it applies after + all messages are handled. """ def __init__( @@ -438,10 +465,11 @@ def __init__( model: str, llm_client: LLMClient, max_token_limit: int = 2000, - eviction_policy: EvictionPolicyType = None, - message_mapper: _MultiRoundMessageMapper = None, + eviction_policy: Optional[EvictionPolicyType] = None, + message_mapper: Optional[_MultiRoundMessageMapper] = None, **kwargs, ): + """Create a new TokenBufferedConversationMapperOperator.""" if max_token_limit < 0: raise ValueError("Max token limit can't be negative") self._model = model @@ -452,6 +480,7 @@ def __init__( super().__init__(**kwargs) async def map_messages(self, messages: List[BaseMessage]) -> List[BaseMessage]: + """Map multi round messages to a list of BaseMessage.""" eviction_policy = self._eviction_policy or self.eviction_policy messages_by_round: List[List[BaseMessage]] = _split_messages_by_round(messages) messages_str = _messages_to_str(_merge_multi_round_messages(messages_by_round)) @@ -459,7 +488,8 @@ async def map_messages(self, messages: List[BaseMessage]) -> List[BaseMessage]: current_tokens = await self._llm_client.count_token(self._model, messages_str) while current_tokens > self._max_token_limit: - # Evict the messages by round after all tokens are not greater than the max token limit + # Evict the messages by round after all tokens are not greater than the max + # token limit # TODO: We should find a high performance way to do this messages_by_round = eviction_policy(messages_by_round) messages_str = _messages_to_str( diff --git a/dbgpt/core/interface/operator/prompt_operator.py b/dbgpt/core/interface/operator/prompt_operator.py index 7cdde4349..5049c7026 100644 --- a/dbgpt/core/interface/operator/prompt_operator.py +++ b/dbgpt/core/interface/operator/prompt_operator.py @@ -1,8 +1,8 @@ +"""The prompt operator.""" from abc import ABC from typing import Any, Dict, List, Optional, Union from dbgpt.core import ( - BasePromptTemplate, ChatPromptTemplate, ModelMessage, ModelMessageRoleType, @@ -13,7 +13,13 @@ from dbgpt.core.interface.message import BaseMessage from dbgpt.core.interface.operator.llm_operator import BaseLLM from dbgpt.core.interface.operator.message_operator import BaseConversationOperator -from dbgpt.core.interface.prompt import HumanPromptTemplate, MessageType +from dbgpt.core.interface.prompt import ( + BaseChatPromptTemplate, + HumanPromptTemplate, + MessagesPlaceholder, + MessageType, + PromptTemplate, +) from dbgpt.util.function_utils import rearrange_args_by_type @@ -21,6 +27,7 @@ class BasePromptBuilderOperator(BaseConversationOperator, ABC): """The base prompt builder operator.""" def __init__(self, check_storage: bool, **kwargs): + """Create a new prompt builder operator.""" super().__init__(check_storage=check_storage, **kwargs) async def format_prompt( @@ -39,10 +46,10 @@ async def format_prompt( kwargs.update(prompt_dict) pass_kwargs = {k: v for k, v in kwargs.items() if k in prompt.input_variables} messages = prompt.format_messages(**pass_kwargs) - messages = ModelMessage.from_base_messages(messages) + model_messages = ModelMessage.from_base_messages(messages) # Start new round conversation, and save user message to storage - await self.start_new_round_conv(messages) - return messages + await self.start_new_round_conv(model_messages) + return model_messages async def start_new_round_conv(self, messages: 
List[ModelMessage]) -> None: """Start a new round conversation. @@ -50,7 +57,6 @@ async def start_new_round_conv(self, messages: List[ModelMessage]) -> None: Args: messages (List[ModelMessage]): The messages. """ - lass_user_message = None for message in messages[::-1]: if message.role == ModelMessageRoleType.HUMAN: @@ -58,7 +64,9 @@ async def start_new_round_conv(self, messages: List[ModelMessage]) -> None: break if not lass_user_message: raise ValueError("No user message") - storage_conv: StorageConversation = await self.get_storage_conversation() + storage_conv: Optional[ + StorageConversation + ] = await self.get_storage_conversation() if not storage_conv: return # Start new round @@ -66,13 +74,17 @@ async def start_new_round_conv(self, messages: List[ModelMessage]) -> None: storage_conv.add_user_message(lass_user_message) async def after_dag_end(self): - """The callback after DAG end""" - # TODO remove this to start_new_round() + """Execute after the DAG finished.""" # Save the storage conversation to storage after the whole DAG finished - storage_conv: StorageConversation = await self.get_storage_conversation() + storage_conv: Optional[ + StorageConversation + ] = await self.get_storage_conversation() + if not storage_conv: return - model_output: ModelOutput = await self.current_dag_context.get_from_share_data( + model_output: Optional[ + ModelOutput + ] = await self.current_dag_context.get_from_share_data( BaseLLM.SHARE_DATA_KEY_MODEL_OUTPUT ) if model_output: @@ -82,7 +94,7 @@ async def after_dag_end(self): storage_conv.end_current_round() -PromptTemplateType = Union[ChatPromptTemplate, BasePromptTemplate, MessageType, str] +PromptTemplateType = Union[ChatPromptTemplate, PromptTemplate, MessageType, str] class PromptBuilderOperator( @@ -91,7 +103,6 @@ class PromptBuilderOperator( """The operator to build the prompt with static prompt. Examples: - .. 
code-block:: python import asyncio @@ -119,7 +130,8 @@ class PromptBuilderOperator( ChatPromptTemplate( messages=[ HumanPromptTemplate.from_template( - "Please write a {dialect} SQL count the length of a field" + "Please write a {dialect} SQL count the length of a" + " field" ) ] ) @@ -131,7 +143,8 @@ class PromptBuilderOperator( "You are a {dialect} SQL expert" ), HumanPromptTemplate.from_template( - "Please write a {dialect} SQL count the length of a field" + "Please write a {dialect} SQL count the length of a" + " field" ), ], ) @@ -171,17 +184,18 @@ class PromptBuilderOperator( """ def __init__(self, prompt: PromptTemplateType, **kwargs): + """Create a new prompt builder operator.""" if isinstance(prompt, str): prompt = ChatPromptTemplate( messages=[HumanPromptTemplate.from_template(prompt)] ) - elif isinstance(prompt, BasePromptTemplate) and not isinstance( - prompt, ChatPromptTemplate - ): + elif isinstance(prompt, PromptTemplate): prompt = ChatPromptTemplate( messages=[HumanPromptTemplate.from_template(prompt.template)] ) - elif isinstance(prompt, MessageType): + elif isinstance( + prompt, (BaseChatPromptTemplate, MessagesPlaceholder, BaseMessage) + ): prompt = ChatPromptTemplate(messages=[prompt]) self._prompt = prompt @@ -190,6 +204,7 @@ def __init__(self, prompt: PromptTemplateType, **kwargs): @rearrange_args_by_type async def merge_prompt(self, prompt_dict: Dict[str, Any]) -> List[ModelMessage]: + """Format the prompt.""" return await self.format_prompt(self._prompt, prompt_dict) @@ -202,6 +217,7 @@ class DynamicPromptBuilderOperator( """ def __init__(self, **kwargs): + """Create a new dynamic prompt builder operator.""" super().__init__(check_storage=False, **kwargs) JoinOperator.__init__(self, combine_function=self.merge_prompt, **kwargs) @@ -209,20 +225,37 @@ def __init__(self, **kwargs): async def merge_prompt( self, prompt: ChatPromptTemplate, prompt_dict: Dict[str, Any] ) -> List[ModelMessage]: + """Merge the prompt and history.""" return await self.format_prompt(prompt, prompt_dict) class HistoryPromptBuilderOperator( BasePromptBuilderOperator, JoinOperator[List[ModelMessage]] ): + """The operator to build the prompt with static prompt. + + The prompt will pass to this operator. + """ + def __init__( self, prompt: ChatPromptTemplate, - history_key: Optional[str] = None, + history_key: str = "chat_history", check_storage: bool = True, str_history: bool = False, **kwargs, ): + """Create a new history prompt builder operator. + + Args: + prompt (ChatPromptTemplate): The prompt. + history_key (str, optional): The key of history in prompt dict. Defaults + to "chat_history". + check_storage (bool, optional): Whether to check the storage. + Defaults to True. + str_history (bool, optional): Whether to convert the history to string. + Defaults to False. 
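A hedged construction example for the history prompt builder documented above; the MessagesPlaceholder keyword is an assumption chosen to match `history_key`:

from dbgpt.core import ChatPromptTemplate
from dbgpt.core.interface.operator.prompt_operator import (
    HistoryPromptBuilderOperator,
)
from dbgpt.core.interface.prompt import HumanPromptTemplate, MessagesPlaceholder

prompt = ChatPromptTemplate(
    messages=[
        MessagesPlaceholder(variable_name="chat_history"),  # assumed keyword
        HumanPromptTemplate.from_template("{user_input}"),
    ]
)
history_task = HistoryPromptBuilderOperator(
    prompt=prompt, history_key="chat_history"
)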
+ """ self._prompt = prompt self._history_key = history_key self._str_history = str_history @@ -233,6 +266,7 @@ def __init__( async def merge_history( self, history: List[BaseMessage], prompt_dict: Dict[str, Any] ) -> List[ModelMessage]: + """Merge the prompt and history.""" if self._str_history: prompt_dict[self._history_key] = BaseMessage.messages_to_string(history) else: @@ -250,11 +284,12 @@ class HistoryDynamicPromptBuilderOperator( def __init__( self, - history_key: Optional[str] = None, + history_key: str = "chat_history", check_storage: bool = True, str_history: bool = False, **kwargs, ): + """Create a new history dynamic prompt builder operator.""" self._history_key = history_key self._str_history = str_history BasePromptBuilderOperator.__init__(self, check_storage=check_storage) @@ -267,6 +302,7 @@ async def merge_history( history: List[BaseMessage], prompt_dict: Dict[str, Any], ) -> List[ModelMessage]: + """Merge the prompt and history.""" if self._str_history: prompt_dict[self._history_key] = BaseMessage.messages_to_string(history) else: diff --git a/dbgpt/core/interface/retriever.py b/dbgpt/core/interface/operator/retriever.py similarity index 89% rename from dbgpt/core/interface/retriever.py rename to dbgpt/core/interface/operator/retriever.py index 385295534..0ea4de8c7 100644 --- a/dbgpt/core/interface/retriever.py +++ b/dbgpt/core/interface/operator/retriever.py @@ -1,3 +1,4 @@ +"""The Abstract Retriever Operator.""" from abc import abstractmethod from dbgpt.core.awel import MapOperator @@ -16,7 +17,8 @@ async def map(self, input_value: IN) -> OUT: Returns: OUT: The output value. """ - # The retrieve function is blocking, so we need to wrap it in a blocking_func_to_async. + # The retrieve function is blocking, so we need to wrap it in a + # blocking_func_to_async. return await self.blocking_func_to_async(self.retrieve, input_value) @abstractmethod diff --git a/dbgpt/core/interface/output_parser.py b/dbgpt/core/interface/output_parser.py index b06c45dc6..5a9c5f447 100644 --- a/dbgpt/core/interface/output_parser.py +++ b/dbgpt/core/interface/output_parser.py @@ -1,10 +1,15 @@ +"""The output parser is used to parse the output of an LLM call. + +TODO: Make this more general and clear. +""" + from __future__ import annotations import json import logging from abc import ABC from dataclasses import asdict -from typing import Any, Dict, TypeVar, Union +from typing import Any, TypeVar, Union from dbgpt.core import ModelOutput from dbgpt.core.awel import MapOperator @@ -22,11 +27,16 @@ class BaseOutputParser(MapOperator[ModelOutput, Any], ABC): """ def __init__(self, is_stream_out: bool = True, **kwargs): + """Create a new output parser.""" super().__init__(**kwargs) self.is_stream_out = is_stream_out self.data_schema = None def update(self, data_schema): + """Update the data schema. + + TODO: Remove this method. + """ self.data_schema = data_schema def __post_process_code(self, code): @@ -40,9 +50,16 @@ def __post_process_code(self, code): return code def parse_model_stream_resp_ex(self, chunk: ResponseTye, skip_echo_len): - data = _parse_model_response(chunk) - """ TODO Multi mode output handler, rewrite this for multi model, use adapter mode. + """Parse the output of an LLM call. + + Args: + chunk (ResponseTye): The output of an LLM call. + skip_echo_len (int): The length of the prompt to skip. """ + data = _parse_model_response(chunk) + # TODO: Multi mode output handler, rewrite this for multi model, use adapter + # mode. 
+ model_context = data.get("model_context") has_echo = False if model_context and "prompt_echo_len_char" in model_context: @@ -65,6 +82,7 @@ def parse_model_stream_resp_ex(self, chunk: ResponseTye, skip_echo_len): return output def parse_model_nostream_resp(self, response: ResponseTye, sep: str): + """Parse the output of an LLM call.""" resp_obj_ex = _parse_model_response(response) if isinstance(resp_obj_ex, str): resp_obj_ex = json.loads(resp_obj_ex) @@ -89,7 +107,8 @@ def parse_model_nostream_resp(self, response: ResponseTye, sep: str): return ai_response else: raise ValueError( - f"""Model server error!code={resp_obj_ex["error_code"]}, errmsg is {resp_obj_ex["text"]}""" + f"Model server error! code={resp_obj_ex['error_code']}, error msg is " + f"{resp_obj_ex['text']}" ) def _illegal_json_ends(self, s): @@ -117,7 +136,7 @@ def _extract_json(self, s): temp_json = self._illegal_json_ends(temp_json) return temp_json - except Exception as e: + except Exception: raise ValueError("Failed to find a valid json in LLM response!" + temp_json) def _json_interception(self, s, is_json_array: bool = False): @@ -150,17 +169,17 @@ def _json_interception(self, s, is_json_array: bool = False): break assert count == 0 return s[i : j + 1] - except Exception as e: + except Exception: return "" - def parse_prompt_response(self, model_out_text) -> T: - """ - parse model out text to prompt define response + def parse_prompt_response(self, model_out_text) -> Any: + """Parse model output text into the prompt-defined response. + Args: - model_out_text: + model_out_text: The output of an LLM call. Returns: - + Any: The parsed output of an LLM call. """ cleaned_output = model_out_text.rstrip() if "```json" in cleaned_output: @@ -194,12 +213,15 @@ def parse_prompt_response(self, model_out_text) -> T: def parse_view_response( self, ai_text, data, parse_prompt_response: Any = None ) -> str: - """ - parse the ai response info to user view + """Parse the AI response info to user view. + Args: - text: + ai_text (str): The output of an LLM call. + data (dict): The data that has been handled by some scene. + parse_prompt_response (Any): The parsed prompt response. Returns: + str: The parsed output of an LLM call.
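To make the brace-counting interception in _json_interception above concrete, here is a simplified, standalone re-implementation of the same idea (an illustration of the technique, not the patched method; it deliberately ignores braces inside string literals):

.. code-block:: python

    import json

    def extract_first_json_object(s: str) -> str:
        """Return the first balanced {...} object in s, or "" if none is found."""
        start = s.find("{")
        if start == -1:
            return ""
        depth = 0
        for i in range(start, len(s)):
            if s[i] == "{":
                depth += 1
            elif s[i] == "}":
                depth -= 1
                if depth == 0:
                    # Balanced again: slice out the complete object.
                    return s[start : i + 1]
        return ""

    raw = 'Sure! {"sql": "SELECT COUNT(*) FROM users", "thoughts": "simple count"}'
    assert json.loads(extract_first_json_object(raw))["sql"].startswith("SELECT")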
""" return ai_text @@ -240,10 +262,14 @@ def _parse_model_response(response: ResponseTye): class SQLOutputParser(BaseOutputParser): + """Parse the SQL output of an LLM call.""" + def __init__(self, is_stream_out: bool = False, **kwargs): + """Create a new SQL output parser.""" super().__init__(is_stream_out=is_stream_out, **kwargs) def parse_model_nostream_resp(self, response: ResponseTye, sep: str): + """Parse the output of an LLM call.""" model_out_text = super().parse_model_nostream_resp(response, sep) clean_str = super().parse_prompt_response(model_out_text) return json.loads(clean_str, strict=True) diff --git a/dbgpt/core/interface/prompt.py b/dbgpt/core/interface/prompt.py index 1788357c4..bc536aab0 100644 --- a/dbgpt/core/interface/prompt.py +++ b/dbgpt/core/interface/prompt.py @@ -1,3 +1,5 @@ +"""The prompt template interface.""" + from __future__ import annotations import dataclasses @@ -7,9 +9,7 @@ from typing import Any, Callable, Dict, List, Optional, Set, Union from dbgpt._private.pydantic import BaseModel, root_validator -from dbgpt.core._private.example_base import ExampleSelector from dbgpt.core.interface.message import BaseMessage, HumanMessage, SystemMessage -from dbgpt.core.interface.output_parser import BaseOutputParser from dbgpt.core.interface.storage import ( InMemoryStorage, QuerySpec, @@ -42,36 +42,56 @@ def _jinja2_formatter(template: str, **kwargs: Any) -> str: class BasePromptTemplate(BaseModel): + """Base class for all prompt templates, returning a prompt.""" + input_variables: List[str] """A list of the names of the variables the prompt template expects.""" - template: Optional[str] + +class PromptTemplate(BasePromptTemplate): + """Prompt template.""" + + template: str """The prompt template.""" - template_format: Optional[str] = "f-string" + template_format: str = "f-string" """The format of the prompt template. 
Options are: 'f-string', 'jinja2'.""" + response_key: str = "response" + + template_is_strict: bool = True + """A strict template checks the template arguments.""" + response_format: Optional[str] = None - response_key: Optional[str] = "response" + template_scene: Optional[str] = None - template_is_strict: Optional[bool] = True - """strict template will check template args""" + template_define: Optional[str] = None + """The definition of this template.""" + + class Config: + """Configuration for this pydantic object.""" + + arbitrary_types_allowed = True + + @property + def _prompt_type(self) -> str: + """Return the prompt type key.""" + return "prompt" def format(self, **kwargs: Any) -> str: """Format the prompt with the inputs.""" - if self.template: - if self.response_format: - kwargs[self.response_key] = json.dumps( - self.response_format, ensure_ascii=False, indent=4 - ) - return _DEFAULT_FORMATTER_MAPPING[self.template_format]( - self.template_is_strict - )(self.template, **kwargs) + if self.response_format: + kwargs[self.response_key] = json.dumps( + self.response_format, ensure_ascii=False, indent=4 + ) + return _DEFAULT_FORMATTER_MAPPING[self.template_format]( + self.template_is_strict + )(self.template, **kwargs) @classmethod def from_template( - cls, template: str, template_format: Optional[str] = "f-string", **kwargs: Any + cls, template: str, template_format: str = "f-string", **kwargs: Any ) -> BasePromptTemplate: """Create a prompt template from a template string.""" input_variables = get_template_vars(template, template_format) @@ -83,41 +103,14 @@ def from_template( ) -class PromptTemplate(BasePromptTemplate): - template_scene: Optional[str] - template_define: Optional[str] - """this template define""" - """default use stream out""" - stream_out: bool = True - """""" - output_parser: BaseOutputParser = None - """""" - sep: str = "###" - - example_selector: ExampleSelector = None - - need_historical_messages: bool = False - - temperature: float = 0.6 - max_new_tokens: int = 1024 - - class Config: - """Configuration for this pydantic object.""" - - arbitrary_types_allowed = True - - @property - def _prompt_type(self) -> str: - """Return the prompt type key.""" - return "prompt" - - class BaseChatPromptTemplate(BaseModel, ABC): + """The base chat prompt template.""" + prompt: BasePromptTemplate @property def input_variables(self) -> List[str]: - """A list of the names of the variables the prompt template expects.""" + """Return a list of the names of the variables the prompt template expects.""" return self.prompt.input_variables @abstractmethod def format_messages(self, **kwargs: Any) -> List[BaseMessage]: @@ -128,14 +121,14 @@ def format_messages(self, **kwargs: Any) -> List[BaseMessage]: def from_template( cls, template: str, - template_format: Optional[str] = "f-string", + template_format: str = "f-string", response_format: Optional[str] = None, - response_key: Optional[str] = "response", + response_key: str = "response", template_is_strict: bool = True, **kwargs: Any, ) -> BaseChatPromptTemplate: """Create a prompt template from a template string.""" - prompt = BasePromptTemplate.from_template( + prompt = PromptTemplate.from_template( template, template_format, response_format=response_format, @@ -149,6 +142,11 @@ class SystemPromptTemplate(BaseChatPromptTemplate): """The system prompt template.""" def format_messages(self, **kwargs: Any) -> List[BaseMessage]: + """Format the prompt with the inputs. + + Returns: + List[BaseMessage]: The formatted messages.
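A short sketch of the slimmed-down PromptTemplate in use (it no longer carries output_parser, example_selector, or sampling settings). This assumes get_template_vars extracts the {placeholder} names as the f-string path suggests; the template text and expected outputs are illustrative.

.. code-block:: python

    from dbgpt.core.interface.prompt import PromptTemplate

    prompt = PromptTemplate.from_template(
        "Please write a {dialect} SQL query against table {table}"
    )
    # input_variables is derived from the template string by from_template.
    print(prompt.input_variables)  # expected: ["dialect", "table"]
    print(prompt.format(dialect="MySQL", table="users"))
    # expected: "Please write a MySQL SQL query against table users"

    # With template_is_strict=True (the default), a missing variable should
    # raise instead of formatting silently:
    # prompt.format(dialect="MySQL")  -> error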
+ """ content = self.prompt.format(**kwargs) return [SystemMessage(content=content)] @@ -157,20 +155,31 @@ class HumanPromptTemplate(BaseChatPromptTemplate): """The human prompt template.""" def format_messages(self, **kwargs: Any) -> List[BaseMessage]: + """Format the prompt with the inputs. + + Returns: + List[BaseMessage]: The formatted messages. + """ content = self.prompt.format(**kwargs) return [HumanMessage(content=content)] -class MessagesPlaceholder(BaseChatPromptTemplate): +class MessagesPlaceholder(BaseModel): """The messages placeholder template. Mostly used for the chat history. """ variable_name: str - prompt: BasePromptTemplate = None def format_messages(self, **kwargs: Any) -> List[BaseMessage]: + """Format the prompt with the inputs. + + Just return the messages from the kwargs with the variable name. + + Returns: + List[BaseMessage]: The messages. + """ messages = kwargs.get(self.variable_name, []) if not isinstance(messages, list): raise ValueError( @@ -185,7 +194,7 @@ def format_messages(self, **kwargs: Any) -> List[BaseMessage]: @property def input_variables(self) -> List[str]: - """A list of the names of the variables the prompt template expects. + """Return a list of the names of the variables the prompt template expects. Returns: List[str]: The input variables. @@ -193,10 +202,26 @@ def input_variables(self) -> List[str]: return [self.variable_name] -MessageType = Union[BaseChatPromptTemplate, BaseMessage] +MessageType = Union[BaseChatPromptTemplate, MessagesPlaceholder, BaseMessage] class ChatPromptTemplate(BasePromptTemplate): + """The chat prompt template. + + Examples: + .. code-block:: python + + prompt_template = ChatPromptTemplate( + messages=[ + SystemPromptTemplate.from_template( + "You are a helpful AI assistant." + ), + MessagesPlaceholder(variable_name="chat_history"), + HumanPromptTemplate.from_template("{question}"), + ] + ) + """ + messages: List[MessageType] def format_messages(self, **kwargs: Any) -> List[BaseMessage]: @@ -205,12 +230,7 @@ def format_messages(self, **kwargs: Any) -> List[BaseMessage]: for message in self.messages: if isinstance(message, BaseMessage): result_messages.append(message) - elif isinstance(message, BaseChatPromptTemplate): - pass_kwargs = { - k: v for k, v in kwargs.items() if k in message.input_variables - } - result_messages.extend(message.format_messages(**pass_kwargs)) - elif isinstance(message, MessagesPlaceholder): + elif isinstance(message, (BaseChatPromptTemplate, MessagesPlaceholder)): pass_kwargs = { k: v for k, v in kwargs.items() if k in message.input_variables } @@ -227,7 +247,7 @@ def pre_fill(cls, values: Dict[str, Any]) -> Dict[str, Any]: if not input_variables: input_variables = set() for message in messages: - if isinstance(message, BaseChatPromptTemplate): + if isinstance(message, (BaseChatPromptTemplate, MessagesPlaceholder)): input_variables.update(message.input_variables) values["input_variables"] = sorted(input_variables) return values @@ -235,6 +255,8 @@ def pre_fill(cls, values: Dict[str, Any]) -> Dict[str, Any]: @dataclasses.dataclass class PromptTemplateIdentifier(ResourceIdentifier): + """The identifier of a prompt template.""" + identifier_split: str = dataclasses.field(default="___$$$$___", init=False) prompt_name: str prompt_language: Optional[str] = None @@ -242,6 +264,7 @@ class PromptTemplateIdentifier(ResourceIdentifier): model: Optional[str] = None def __post_init__(self): + """Post init method.""" if self.prompt_name is None: raise ValueError("prompt_name cannot be None") @@ -256,11 
+279,13 @@ def __post_init__(self): if key is not None ): raise ValueError( - f"identifier_split {self.identifier_split} is not allowed in prompt_name, prompt_language, sys_code, model" + f"identifier_split {self.identifier_split} is not allowed in " + f"prompt_name, prompt_language, sys_code, model" ) @property def str_identifier(self) -> str: + """Return the string identifier of the identifier.""" return self.identifier_split.join( key for key in [ @@ -273,6 +298,11 @@ def str_identifier(self) -> str: ) def to_dict(self) -> Dict: + """Convert the identifier to a dict. + + Returns: + Dict: The dict of the identifier. + """ return { "prompt_name": self.prompt_name, "prompt_language": self.prompt_language, @@ -283,6 +313,8 @@ def to_dict(self) -> Dict: @dataclasses.dataclass class StoragePromptTemplate(StorageItem): + """The storage prompt template.""" + prompt_name: str content: Optional[str] = None prompt_language: Optional[str] = None @@ -297,25 +329,28 @@ class StoragePromptTemplate(StorageItem): _identifier: PromptTemplateIdentifier = dataclasses.field(init=False) def __post_init__(self): + """Post init method.""" self._identifier = PromptTemplateIdentifier( prompt_name=self.prompt_name, prompt_language=self.prompt_language, sys_code=self.sys_code, model=self.model, ) - self._check() # Assuming _check() is a method you need to call after initialization + # Assuming _check() is a method you need to call after initialization + self._check() def to_prompt_template(self) -> PromptTemplate: """Convert the storage prompt template to a prompt template.""" input_variables = ( [] if not self.input_variables else self.input_variables.strip().split(",") ) + template_format = self.prompt_format or "f-string" return PromptTemplate( input_variables=input_variables, template=self.content, template_scene=self.chat_scene, - prompt_name=self.prompt_name, - template_format=self.prompt_format, + # prompt_name=self.prompt_name, + template_format=template_format, ) @staticmethod @@ -335,12 +370,18 @@ def from_prompt_template( Args: prompt_template (PromptTemplate): The prompt template to convert from. prompt_name (str): The name of the prompt. - prompt_language (Optional[str], optional): The language of the prompt. Defaults to None. e.g. zh-cn, en. - prompt_type (Optional[str], optional): The type of the prompt. Defaults to None. e.g. common, private. - sys_code (Optional[str], optional): The system code of the prompt. Defaults to None. - user_name (Optional[str], optional): The username of the prompt. Defaults to None. - sub_chat_scene (Optional[str], optional): The sub chat scene of the prompt. Defaults to None. - model (Optional[str], optional): The model name of the prompt. Defaults to None. + prompt_language (Optional[str], optional): The language of the prompt. + Defaults to None. e.g. zh-cn, en. + prompt_type (Optional[str], optional): The type of the prompt. + Defaults to None. e.g. common, private. + sys_code (Optional[str], optional): The system code of the prompt. + Defaults to None. + user_name (Optional[str], optional): The username of the prompt. + Defaults to None. + sub_chat_scene (Optional[str], optional): The sub chat scene of the prompt. + Defaults to None. + model (Optional[str], optional): The model name of the prompt. + Defaults to None. kwargs (Dict): Other params to build the storage prompt template. 
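A hypothetical round trip through the identifier machinery above; the field values are invented, and the expected str_identifier simply follows the join logic shown in this hunk (None fields are skipped), assuming _check() passes for these fields.

.. code-block:: python

    from dbgpt.core.interface.prompt import StoragePromptTemplate

    item = StoragePromptTemplate(
        prompt_name="hello",
        content="Hello, {name}!",
        prompt_language="en",
        model="vicuna-13b-v1.5",
    )
    # __post_init__ builds the identifier, then runs _check().
    print(item.identifier.str_identifier)
    # expected: "hello___$$$$___en___$$$$___vicuna-13b-v1.5"
    # (sys_code is None, so it is dropped from the joined key)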
""" input_variables = prompt_template.input_variables or kwargs.get( @@ -365,6 +406,7 @@ def from_prompt_template( @property def identifier(self) -> PromptTemplateIdentifier: + """Return the identifier of the storage prompt template.""" return self._identifier def merge(self, other: "StorageItem") -> None: @@ -375,11 +417,17 @@ def merge(self, other: "StorageItem") -> None: """ if not isinstance(other, StoragePromptTemplate): raise ValueError( - f"Cannot merge {type(other)} into {type(self)} because they are not the same type." + f"Cannot merge {type(other)} into {type(self)} because they are not " + f"the same type." ) self.from_object(other) def to_dict(self) -> Dict: + """Convert the storage prompt template to a dict. + + Returns: + Dict: The dict of the storage prompt template. + """ return { "prompt_name": self.prompt_name, "content": self.content, @@ -422,7 +470,6 @@ class PromptManager: Simple wrapper for the storage interface. Examples: - .. code-block:: python # Default use InMemoryStorage @@ -458,13 +505,14 @@ class PromptManager: def __init__( self, storage: Optional[StorageInterface[StoragePromptTemplate, Any]] = None ): + """Create a new prompt manager.""" if storage is None: storage = InMemoryStorage() self._storage = storage @property def storage(self) -> StorageInterface[StoragePromptTemplate, Any]: - """The storage interface for prompt templates.""" + """Return the storage interface for prompt templates.""" return self._storage def prefer_query( @@ -477,11 +525,12 @@ def prefer_query( ) -> List[StoragePromptTemplate]: """Query prompt templates from storage with prefer params. - Sometimes, we want to query prompt templates with prefer params(e.g. some language or some model). - This method will query prompt templates with prefer params first, if not found, will query all prompt templates. + Sometimes, we want to query prompt templates with prefer params(e.g. some + language or some model). + This method will query prompt templates with prefer params first, if not found, + will query all prompt templates. Examples: - Query a prompt template. .. code-block:: python @@ -500,7 +549,8 @@ def prefer_query( .. code-block:: python # First query with prompt name "hello" exactly. - # Second filter with prompt language "zh-cn", if not found, will return all prompt templates. + # Second filter with prompt language "zh-cn", if not found, will return + # all prompt templates. prompt_template_list = prompt_manager.prefer_query( "hello", prefer_prompt_language="zh-cn" ) @@ -510,17 +560,22 @@ def prefer_query( .. code-block:: python # First query with prompt name "hello" exactly. - # Second filter with model "vicuna-13b-v1.5", if not found, will return all prompt templates. + # Second filter with model "vicuna-13b-v1.5", if not found, will return + # all prompt templates. prompt_template_list = prompt_manager.prefer_query( "hello", prefer_model="vicuna-13b-v1.5" ) Args: prompt_name (str): The name of the prompt template. - sys_code (Optional[str], optional): The system code of the prompt template. Defaults to None. - prefer_prompt_language (Optional[str], optional): The language of the prompt template. Defaults to None. - prefer_model (Optional[str], optional): The model of the prompt template. Defaults to None. - kwargs (Dict): Other query params(If some key and value not None, wo we query it exactly). + sys_code (Optional[str], optional): The system code of the prompt template. + Defaults to None. + prefer_prompt_language (Optional[str], optional): The language of the + prompt template. 
Defaults to None. + prefer_model (Optional[str], optional): The model of the prompt template. + Defaults to None. + kwargs (Dict): Other query params (if some key and value are not None, + we query it exactly). """ query_spec = QuerySpec( conditions={ @@ -559,7 +614,6 @@ def save(self, prompt_template: PromptTemplate, prompt_name: str, **kwargs) -> N """Save a prompt template to storage. Examples: - .. code-block:: python prompt_template = PromptTemplate( @@ -618,15 +672,17 @@ def query_or_save( if exist_prompt_template: return exist_prompt_template self.save(prompt_template, prompt_name, **kwargs) - return self.storage.load( + prompt = self.storage.load( storage_prompt_template.identifier, StoragePromptTemplate ) + if not prompt: + raise ValueError("Can't read prompt from storage") + return prompt def list(self, **kwargs) -> List[StoragePromptTemplate]: """List prompt templates from storage. Examples: - List all prompt templates. .. code-block:: python @@ -656,7 +712,6 @@ def delete( """Delete a prompt template from storage. Examples: - Delete a prompt template. .. code-block:: python @@ -673,9 +728,12 @@ def delete( Args: prompt_name (str): The name of the prompt template. - prompt_language (Optional[str], optional): The language of the prompt template. Defaults to None. - sys_code (Optional[str], optional): The system code of the prompt template. Defaults to None. - model (Optional[str], optional): The model of the prompt template. Defaults to None. + prompt_language (Optional[str], optional): The language of the prompt + template. Defaults to None. + sys_code (Optional[str], optional): The system code of the prompt template. + Defaults to None. + model (Optional[str], optional): The model of the prompt template. + Defaults to None. """ identifier = PromptTemplateIdentifier( prompt_name=prompt_name, diff --git a/dbgpt/core/interface/serialization.py b/dbgpt/core/interface/serialization.py index b1d8d60eb..72f1fe3b0 100644 --- a/dbgpt/core/interface/serialization.py +++ b/dbgpt/core/interface/serialization.py @@ -1,11 +1,15 @@ +"""The interface for serializing.""" + from __future__ import annotations from abc import ABC, abstractmethod -from typing import Dict, Type +from typing import Dict, Optional, Type class Serializable(ABC): - serializer: "Serializer" = None + """The serializable abstract class.""" + + serializer: Optional["Serializer"] = None @abstractmethod def to_dict(self) -> Dict: diff --git a/dbgpt/core/interface/storage.py b/dbgpt/core/interface/storage.py index 6a486cab7..86c28f3fa 100644 --- a/dbgpt/core/interface/storage.py +++ b/dbgpt/core/interface/storage.py @@ -1,5 +1,7 @@ +"""The storage interface for storing and loading data.""" + from abc import ABC, abstractmethod -from typing import Any, Dict, Generic, List, Optional, Type, TypeVar +from typing import Any, Dict, Generic, List, Optional, Type, TypeVar, cast from dbgpt.core.interface.serialization import Serializable, Serializer from dbgpt.util.annotations import PublicAPI @@ -55,17 +57,22 @@ def merge(self, other: "StorageItem") -> None: """ +ID = TypeVar("ID", bound=ResourceIdentifier) T = TypeVar("T", bound=StorageItem) TDataRepresentation = TypeVar("TDataRepresentation") class StorageItemAdapter(Generic[T, TDataRepresentation]): - """The storage item adapter for converting storage items to and from the storage format. + """Storage item adapter. + + The storage item adapter for converting storage items to and from the storage + format.
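Pulling the PromptManager pieces from earlier in this file together, a small hypothetical session against the default in-memory storage (the save kwargs are assumed to flow through to from_prompt_template, as the docstrings suggest; names and templates are invented):

.. code-block:: python

    from dbgpt.core.interface.prompt import PromptManager, PromptTemplate

    manager = PromptManager()  # Defaults to InMemoryStorage
    manager.save(
        PromptTemplate.from_template("Hello, {name}!"),
        prompt_name="hello",
        prompt_language="en",
    )
    # Exact match on the name first, then prefer a zh-cn variant. Only an
    # "en" variant exists, so the preference filter matches nothing and all
    # "hello" templates are returned instead.
    templates = manager.prefer_query("hello", prefer_prompt_language="zh-cn")
    print(len(templates))  # expected: 1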
Sometimes, the storage item is not the same as the storage format, so we need to convert the storage item to the storage format and vice versa. - In database storage, the storage format is database model, but the StorageItem is the user-defined object. + In database storage, the storage format is database model, but the StorageItem is + the user-defined object. """ @abstractmethod @@ -110,20 +117,44 @@ def get_query_for_identifier( class DefaultStorageItemAdapter(StorageItemAdapter[T, T]): - """The default storage item adapter for converting storage items to and from the storage format. + """Default storage item adapter. + + The default storage item adapter for converting storage items to and from the + storage format. The storage item is the same as the storage format, so no conversion is required. """ def to_storage_format(self, item: T) -> T: + """Convert the storage item to the storage format. + + Returns the storage item itself. + + Args: + item (T): The storage item + + Returns: + T: The data in the storage format + """ return item def from_storage_format(self, data: T) -> T: + """Convert the storage format to the storage item. + + Returns the storage format itself. + + Args: + data (T): The data in the storage format + + Returns: + T: The storage item + """ return data def get_query_for_identifier( - self, storage_format: Type[T], resource_id: ResourceIdentifier, **kwargs + self, storage_format: Type[T], resource_id: ID, **kwargs ) -> bool: + """Return the query for the resource identifier.""" return True @@ -132,6 +163,7 @@ class StorageError(Exception): """The base exception class for storage errors.""" def __init__(self, message: str): + """Create a new StorageError.""" super().__init__(message) @@ -146,8 +178,9 @@ class QuerySpec: """ def __init__( - self, conditions: Dict[str, Any], limit: int = None, offset: int = 0 + self, conditions: Dict[str, Any], limit: Optional[int] = None, offset: int = 0 ) -> None: + """Create a new QuerySpec.""" self.conditions = conditions self.limit = limit self.offset = offset @@ -162,6 +195,7 @@ def __init__( serializer: Optional[Serializer] = None, adapter: Optional[StorageItemAdapter[T, TDataRepresentation]] = None, ): + """Create a new StorageInterface.""" self._serializer = serializer or JsonSerializer() self._storage_item_adapter = adapter or DefaultStorageItemAdapter() @@ -238,7 +272,7 @@ def save_or_update_list(self, data: List[T]) -> None: self.save_or_update(d) @abstractmethod - def load(self, resource_id: ResourceIdentifier, cls: Type[T]) -> Optional[T]: + def load(self, resource_id: ID, cls: Type[T]) -> Optional[T]: """Load the data from the storage. None will be returned if the data does not exist in the storage. @@ -247,14 +281,14 @@ def load(self, resource_id: ResourceIdentifier, cls: Type[T]) -> Optional[T]: so we suggest to use load if possible. Args: - resource_id (ResourceIdentifier): The resource identifier of the data + resource_id (ID): The resource identifier of the data cls (Type[T]): The type of the data Returns: Optional[T]: The loaded data """ - def load_list(self, resource_id: List[ResourceIdentifier], cls: Type[T]) -> List[T]: + def load_list(self, resource_id: List[ID], cls: Type[T]) -> List[T]: """Load the data from the storage. None will be returned if the data does not exist in the storage. @@ -263,7 +297,7 @@ def load_list(self, resource_id: List[ResourceIdentifier], cls: Type[T]) -> List so we suggest to use load if possible. 
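The load-versus-query advice above in practice: loading by identifier is a direct key lookup, so prefer it when the full identifier is known. A sketch reusing the prompt types from this patch (names invented):

.. code-block:: python

    from dbgpt.core.interface.prompt import (
        PromptManager,
        PromptTemplate,
        PromptTemplateIdentifier,
        StoragePromptTemplate,
    )

    manager = PromptManager()
    manager.save(PromptTemplate.from_template("Hi, {name}"), prompt_name="greet")

    # Direct key lookup by identifier (fast path); unset fields stay None.
    ident = PromptTemplateIdentifier(prompt_name="greet")
    item = manager.storage.load(ident, StoragePromptTemplate)
    assert item is not None

    manager.storage.delete(ident)
    assert manager.storage.load(ident, StoragePromptTemplate) is None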
Args: - resource_id (ResourceIdentifier): The resource identifier of the data + resource_id (ID): The resource identifier of the data cls (Type[T]): The type of the data Returns: @@ -277,18 +311,18 @@ def load_list(self, resource_id: List[ResourceIdentifier], cls: Type[T]) -> List return result @abstractmethod - def delete(self, resource_id: ResourceIdentifier) -> None: + def delete(self, resource_id: ID) -> None: """Delete the data from the storage. Args: - resource_id (ResourceIdentifier): The resource identifier of the data + resource_id (ID): The resource identifier of the data """ - def delete_list(self, resource_id: List[ResourceIdentifier]) -> None: + def delete_list(self, resource_id: List[ID]) -> None: """Delete the data from the storage. Args: - resource_id (ResourceIdentifier): The resource identifier of the data + resource_id (ID): The resource identifier of the data """ for r in resource_id: self.delete(r) @@ -297,7 +331,8 @@ def delete_list(self, resource_id: List[ResourceIdentifier]) -> None: def query(self, spec: QuerySpec, cls: Type[T]) -> List[T]: """Query data from the storage. - Query data with resource_id will be faster than query data with conditions, so please use load if possible. + Query data with resource_id will be faster than query data with conditions, + so please use load if possible. Args: spec (QuerySpec): The query specification @@ -328,7 +363,8 @@ def paginate_query( page (int): The page number page_size (int): The number of items per page cls (Type[T]): The type of the data - spec (Optional[QuerySpec], optional): The query specification. Defaults to None. + spec (Optional[QuerySpec], optional): The query specification. + Defaults to None. Returns: PaginationResult[T]: The pagination result @@ -356,10 +392,17 @@ def __init__( self, serializer: Optional[Serializer] = None, ): + """Create a new InMemoryStorage.""" super().__init__(serializer) - self._data = {} # Key: ResourceIdentifier, Value: Serialized data + # Key: ResourceIdentifier, Value: Serialized data + self._data: Dict[str, bytes] = {} def save(self, data: T) -> None: + """Save the data to the storage. + + Args: + data (T): The data to save + """ if not data: raise StorageError("Data cannot be None") if not data.serializer: @@ -372,6 +415,7 @@ def save(self, data: T) -> None: self._data[data.identifier.str_identifier] = data.serialize() def update(self, data: T) -> None: + """Update the data to the storage.""" if not data: raise StorageError("Data cannot be None") if not data.serializer: @@ -379,22 +423,34 @@ def update(self, data: T) -> None: self._data[data.identifier.str_identifier] = data.serialize() def save_or_update(self, data: T) -> None: + """Save or update the data to the storage.""" self.update(data) - def load(self, resource_id: ResourceIdentifier, cls: Type[T]) -> Optional[T]: + def load(self, resource_id: ID, cls: Type[T]) -> Optional[T]: + """Load the data from the storage.""" serialized_data = self._data.get(resource_id.str_identifier) if serialized_data is None: return None - return self.serializer.deserialize(serialized_data, cls) + return cast(T, self.serializer.deserialize(serialized_data, cls)) - def delete(self, resource_id: ResourceIdentifier) -> None: + def delete(self, resource_id: ID) -> None: + """Delete the data from the storage.""" if resource_id.str_identifier in self._data: del self._data[resource_id.str_identifier] def query(self, spec: QuerySpec, cls: Type[T]) -> List[T]: + """Query data from the storage. 
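The QuerySpec used by query and paginate_query is plain limit/offset arithmetic; a sketch assuming the in-memory backend applies conditions, offset, and limit as paginate_query expects:

.. code-block:: python

    from dbgpt.core.interface.prompt import StoragePromptTemplate
    from dbgpt.core.interface.storage import InMemoryStorage, QuerySpec

    storage = InMemoryStorage()
    # ... save some StoragePromptTemplate items with sys_code="dbgpt" ...

    # Page 2 with a page size of 10 is equivalent to offset=10, limit=10.
    page, page_size = 2, 10
    spec = QuerySpec(
        conditions={"sys_code": "dbgpt"},
        limit=page_size,
        offset=(page - 1) * page_size,
    )
    items = storage.query(spec, StoragePromptTemplate)
    total = storage.count(
        QuerySpec(conditions={"sys_code": "dbgpt"}), StoragePromptTemplate
    )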
+ + Args: + spec (QuerySpec): The query specification + cls (Type[T]): The type of the data + + Returns: + List[T]: The queried data + """ result = [] for serialized_data in self._data.values(): - data = self._serializer.deserialize(serialized_data, cls) + data = cast(T, self._serializer.deserialize(serialized_data, cls)) if all( getattr(data, key) == value for key, value in spec.conditions.items() ): @@ -408,6 +464,15 @@ def query(self, spec: QuerySpec, cls: Type[T]) -> List[T]: return result def count(self, spec: QuerySpec, cls: Type[T]) -> int: + """Count the number of items in the storage. + + Args: + spec (QuerySpec): The query specification + cls (Type[T]): The type of the data + + Returns: + int: The number of matching items + """ count = 0 for serialized_data in self._data.values(): data = self._serializer.deserialize(serialized_data, cls) diff --git a/dbgpt/core/operator/__init__.py b/dbgpt/core/operator/__init__.py index ffc60d2c7..0136adf8d 100644 --- a/dbgpt/core/operator/__init__.py +++ b/dbgpt/core/operator/__init__.py @@ -1,22 +1,24 @@ -from dbgpt.core.interface.operator.composer_operator import ( +"""All core operators.""" + +from dbgpt.core.interface.operator.composer_operator import ( # noqa: F401 ChatComposerInput, ChatHistoryPromptComposerOperator, ) -from dbgpt.core.interface.operator.llm_operator import ( +from dbgpt.core.interface.operator.llm_operator import ( # noqa: F401 BaseLLM, BaseLLMOperator, BaseStreamingLLMOperator, LLMBranchOperator, RequestBuilderOperator, ) -from dbgpt.core.interface.operator.message_operator import ( +from dbgpt.core.interface.operator.message_operator import ( # noqa: F401 BaseConversationOperator, BufferedConversationMapperOperator, ConversationMapperOperator, PreChatHistoryLoadOperator, TokenBufferedConversationMapperOperator, ) -from dbgpt.core.interface.operator.prompt_operator import ( +from dbgpt.core.interface.operator.prompt_operator import ( # noqa: F401 DynamicPromptBuilderOperator, HistoryDynamicPromptBuilderOperator, HistoryPromptBuilderOperator, diff --git a/dbgpt/model/adapter/fschat_adapter.py b/dbgpt/model/adapter/fschat_adapter.py index 956cccf66..fc5ab48f7 100644 --- a/dbgpt/model/adapter/fschat_adapter.py +++ b/dbgpt/model/adapter/fschat_adapter.py @@ -8,6 +8,9 @@ from functools import cache from typing import TYPE_CHECKING, Callable, List, Optional, Tuple +from dbgpt.model.adapter.base import LLMModelAdapter +from dbgpt.model.adapter.template import ConversationAdapter, PromptType + try: from fastchat.conversation import ( Conversation, @@ -20,8 +23,6 @@ "Please install fastchat by command `pip install fschat` " ) from exc -from dbgpt.model.adapter.base import LLMModelAdapter -from dbgpt.model.adapter.template import ConversationAdapter, PromptType if TYPE_CHECKING: from fastchat.model.model_adapter import BaseModelAdapter diff --git a/dbgpt/model/cluster/worker/default_worker.py b/dbgpt/model/cluster/worker/default_worker.py index f07e13d36..9023f3ced 100644 --- a/dbgpt/model/cluster/worker/default_worker.py +++ b/dbgpt/model/cluster/worker/default_worker.py @@ -196,7 +196,14 @@ def count_token(self, prompt: str) -> int: return _try_to_count_token(prompt, self.tokenizer, self.model) async def async_count_token(self, prompt: str) -> int: - # TODO if we deploy the model by vllm, it can't work, we should run transformer _try_to_count_token to async + # TODO: if we deploy the model with vllm, this can't work; we should run + # transformer _try_to_count_token asynchronously + from dbgpt.model.proxy.llms.proxy_model import ProxyModel + + if
isinstance(self.model, ProxyModel) and self.model.proxy_llm_client: + return await self.model.proxy_llm_client.count_token( + self.model.proxy_llm_client.default_model, prompt + ) raise NotImplementedError def get_model_metadata(self, params: Dict) -> ModelMetadata: diff --git a/dbgpt/model/llm/monkey_patch.py b/dbgpt/model/llm/monkey_patch.py index f9c2b3119..355a9eb51 100644 --- a/dbgpt/model/llm/monkey_patch.py +++ b/dbgpt/model/llm/monkey_patch.py @@ -118,9 +118,6 @@ def replace_llama_attn_with_non_inplace_operations(): transformers.models.llama.modeling_llama.LlamaAttention.forward = forward -import transformers - - def replace_llama_attn_with_non_inplace_operations(): """Avoid bugs in mps backend by not using in-place operations.""" transformers.models.llama.modeling_llama.LlamaAttention.forward = forward diff --git a/dbgpt/model/proxy/base.py b/dbgpt/model/proxy/base.py index 0faec81bd..129dcf11e 100644 --- a/dbgpt/model/proxy/base.py +++ b/dbgpt/model/proxy/base.py @@ -196,6 +196,15 @@ async def models(self) -> List[ModelMetadata]: """ return self._models() + @property + def default_model(self) -> str: + """Get the default model name. + + Returns: + str: The default model name. + """ + return self.model_names[0] + @cache def _models(self) -> List[ModelMetadata]: results = [] @@ -237,6 +246,7 @@ async def count_token(self, model: str, prompt: str) -> int: Returns: int: token count, -1 if failed """ - return await blocking_func_to_async( + counts = await blocking_func_to_async( self.executor, self.proxy_tokenizer.count_token, model, [prompt] - )[0] + ) + return counts[0] diff --git a/dbgpt/model/proxy/llms/chatgpt.py b/dbgpt/model/proxy/llms/chatgpt.py index 0d200953d..3db1526a9 100755 --- a/dbgpt/model/proxy/llms/chatgpt.py +++ b/dbgpt/model/proxy/llms/chatgpt.py @@ -86,6 +86,11 @@ def __init__( self._openai_kwargs = openai_kwargs or {} super().__init__(model_names=[model_alias], context_length=context_length) + if self._openai_less_then_v1: + from dbgpt.model.utils.chatgpt_utils import _initialize_openai + + _initialize_openai(self._init_params) + @classmethod def new_client( cls, diff --git a/dbgpt/model/proxy/llms/gemini.py b/dbgpt/model/proxy/llms/gemini.py index 87af2f1a2..7a50b5e93 100644 --- a/dbgpt/model/proxy/llms/gemini.py +++ b/dbgpt/model/proxy/llms/gemini.py @@ -114,7 +114,6 @@ def __init__( self._api_key = api_key if api_key else os.getenv("GEMINI_PROXY_API_KEY") self._api_base = api_base if api_base else os.getenv("GEMINI_PROXY_API_BASE") self._model = model - self.default_model = self._model if not self._api_key: raise RuntimeError("api_key can't be empty") @@ -148,6 +147,10 @@ def new_client( executor=default_executor, ) + @property + def default_model(self) -> str: + return self._model + def sync_generate_stream( self, request: ModelRequest, diff --git a/dbgpt/model/proxy/llms/spark.py b/dbgpt/model/proxy/llms/spark.py index 57bb9f906..2f0847ead 100644 --- a/dbgpt/model/proxy/llms/spark.py +++ b/dbgpt/model/proxy/llms/spark.py @@ -8,9 +8,6 @@ from time import mktime from typing import Iterator, Optional from urllib.parse import urlencode, urlparse -from wsgiref.handlers import format_date_time - -from websockets.sync.client import connect from dbgpt.core import MessageConverter, ModelOutput, ModelRequest, ModelRequestContext from dbgpt.model.parameter import ProxyModelParameters @@ -56,6 +53,8 @@ def spark_generate_stream( def get_response(request_url, data): + from websockets.sync.client import connect + with connect(request_url) as ws: ws.send(json.dumps(data,
ensure_ascii=False)) result = "" @@ -87,6 +86,8 @@ def __init__( self.spark_url = spark_url def gen_url(self): + from wsgiref.handlers import format_date_time + # Generate an RFC1123-formatted timestamp now = datetime.now() date = format_date_time(mktime(now.timetuple())) @@ -145,7 +146,6 @@ def __init__( if not api_domain: api_domain = domain self._model = model - self.default_model = self._model self._model_version = model_version self._api_base = api_base self._domain = api_domain @@ -183,6 +183,10 @@ def new_client( executor=default_executor, ) + @property + def default_model(self) -> str: + return self._model + def sync_generate_stream( self, request: ModelRequest, diff --git a/dbgpt/model/proxy/llms/tongyi.py b/dbgpt/model/proxy/llms/tongyi.py index a657ab160..40143db76 100644 --- a/dbgpt/model/proxy/llms/tongyi.py +++ b/dbgpt/model/proxy/llms/tongyi.py @@ -51,7 +51,6 @@ def __init__( if api_region: dashscope.api_region = api_region self._model = model - self.default_model = self._model super().__init__( model_names=[model, model_alias], @@ -73,6 +72,10 @@ def new_client( executor=default_executor, ) + @property + def default_model(self) -> str: + return self._model + def sync_generate_stream( self, request: ModelRequest, diff --git a/dbgpt/model/proxy/llms/wenxin.py b/dbgpt/model/proxy/llms/wenxin.py index 0197b8fa6..8d797aeae 100644 --- a/dbgpt/model/proxy/llms/wenxin.py +++ b/dbgpt/model/proxy/llms/wenxin.py @@ -121,7 +121,6 @@ def __init__( self._api_key = api_key self._api_secret = api_secret self._model_version = model_version - self.default_model = self._model super().__init__( model_names=[model, model_alias], @@ -145,6 +144,10 @@ def new_client( executor=default_executor, ) + @property + def default_model(self) -> str: + return self._model + def sync_generate_stream( self, request: ModelRequest, diff --git a/dbgpt/model/proxy/llms/zhipu.py b/dbgpt/model/proxy/llms/zhipu.py index 80974ae16..6e9d36224 100644 --- a/dbgpt/model/proxy/llms/zhipu.py +++ b/dbgpt/model/proxy/llms/zhipu.py @@ -54,7 +54,6 @@ def __init__( if api_key: zhipuai.api_key = api_key self._model = model - self.default_model = self._model super().__init__( model_names=[model, model_alias], @@ -76,6 +75,10 @@ def new_client( executor=default_executor, ) + @property + def default_model(self) -> str: + return self._model + def sync_generate_stream( self, request: ModelRequest, diff --git a/dbgpt/model/utils/chatgpt_utils.py b/dbgpt/model/utils/chatgpt_utils.py index 489a5f3ce..fce5a3259 100644 --- a/dbgpt/model/utils/chatgpt_utils.py +++ b/dbgpt/model/utils/chatgpt_utils.py @@ -88,6 +88,42 @@ def _initialize_openai_v1(init_params: OpenAIParameters): return openai_params, api_type, api_version +def _initialize_openai(params: OpenAIParameters): + try: + import openai + except ImportError as exc: + raise ValueError( + "Could not import python package: openai " + "Please install openai by command `pip install openai` " + ) from exc + + api_type = params.api_type or os.getenv("OPENAI_API_TYPE", "open_ai") + + api_base = params.api_base or os.getenv( + "OPENAI_API_BASE", + os.getenv("AZURE_OPENAI_ENDPOINT") if api_type == "azure" else None, + ) + api_key = params.api_key or os.getenv( + "OPENAI_API_KEY", + os.getenv("AZURE_OPENAI_KEY") if api_type == "azure" else None, + ) + api_version = params.api_version or os.getenv("OPENAI_API_VERSION") + + if not api_base and params.full_url: + # Adapt previous proxy_server_url configuration + api_base = params.full_url.split("/chat/completions")[0] + if api_type: + openai.api_type = api_type + if
api_base: + openai.api_base = api_base + if api_key: + openai.api_key = api_key + if api_version: + openai.api_version = api_version + if params.proxies: + openai.proxy = params.proxies + + def _build_openai_client(init_params: OpenAIParameters) -> Tuple[str, ClientType]: import httpx @@ -112,9 +148,7 @@ def _build_openai_client(init_params: OpenAIParameters) -> Tuple[str, ClientType class OpenAIStreamingOutputOperator(TransformStreamAbsOperator[ModelOutput, str]): """Transform ModelOutput to openai stream format.""" - async def transform_stream( - self, input_value: AsyncIterator[ModelOutput] - ) -> AsyncIterator[str]: + async def transform_stream(self, input_value: AsyncIterator[ModelOutput]): async def model_caller() -> str: """Read model name from share data. In streaming mode, this transform_stream function will be executed diff --git a/dbgpt/rag/operator/datasource.py b/dbgpt/rag/operator/datasource.py index ea138dc1c..09cedeb26 100644 --- a/dbgpt/rag/operator/datasource.py +++ b/dbgpt/rag/operator/datasource.py @@ -1,6 +1,6 @@ from typing import Any -from dbgpt.core.interface.retriever import RetrieverOperator +from dbgpt.core.interface.operator.retriever import RetrieverOperator from dbgpt.datasource.rdbms.base import RDBMSDatabase from dbgpt.rag.summary.rdbms_db_summary import _parse_db_summary diff --git a/dbgpt/rag/operator/db_schema.py b/dbgpt/rag/operator/db_schema.py index 988b1674a..79c6b4012 100644 --- a/dbgpt/rag/operator/db_schema.py +++ b/dbgpt/rag/operator/db_schema.py @@ -1,7 +1,7 @@ from typing import Any, Optional from dbgpt.core.awel.task.base import IN -from dbgpt.core.interface.retriever import RetrieverOperator +from dbgpt.core.interface.operator.retriever import RetrieverOperator from dbgpt.datasource.rdbms.base import RDBMSDatabase from dbgpt.rag.retriever.db_schema import DBSchemaRetriever from dbgpt.storage.vector_store.connector import VectorStoreConnector diff --git a/dbgpt/rag/operator/embedding.py b/dbgpt/rag/operator/embedding.py index 99e3ab341..1d749ec1b 100644 --- a/dbgpt/rag/operator/embedding.py +++ b/dbgpt/rag/operator/embedding.py @@ -2,7 +2,7 @@ from typing import Any, Optional from dbgpt.core.awel.task.base import IN -from dbgpt.core.interface.retriever import RetrieverOperator +from dbgpt.core.interface.operator.retriever import RetrieverOperator from dbgpt.rag.retriever.embedding import EmbeddingRetriever from dbgpt.rag.retriever.rerank import Ranker from dbgpt.rag.retriever.rewrite import QueryRewrite diff --git a/dbgpt/serve/agent/agents/controller.py b/dbgpt/serve/agent/agents/controller.py index 2eb99ea31..48a37cee5 100644 --- a/dbgpt/serve/agent/agents/controller.py +++ b/dbgpt/serve/agent/agents/controller.py @@ -1,3 +1,4 @@ +import asyncio import json import logging import uuid @@ -30,7 +31,6 @@ CFG = Config() -import asyncio router = APIRouter() logger = logging.getLogger(__name__) diff --git a/dbgpt/storage/chat_history/chat_hisotry_factory.py b/dbgpt/storage/chat_history/chat_hisotry_factory.py index f202c20ff..9a556fe5c 100644 --- a/dbgpt/storage/chat_history/chat_hisotry_factory.py +++ b/dbgpt/storage/chat_history/chat_hisotry_factory.py @@ -6,13 +6,13 @@ from .base import MemoryStoreType +# Import first for auto create table +from .store_type.meta_db_history import DbHistoryMemory + # TODO remove global variable CFG = Config() logger = logging.getLogger(__name__) -# Import first for auto create table -from .store_type.meta_db_history import DbHistoryMemory - class ChatHistory: def __init__(self): diff --git 
a/dbgpt/storage/metadata/_base_dao.py b/dbgpt/storage/metadata/_base_dao.py index 96294c00b..b7fe5b003 100644 --- a/dbgpt/storage/metadata/_base_dao.py +++ b/dbgpt/storage/metadata/_base_dao.py @@ -5,6 +5,8 @@ from dbgpt.util.pagination_utils import PaginationResult +from .db_manager import BaseQuery, DatabaseManager, db + # The entity type T = TypeVar("T") # The request schema type @@ -12,7 +14,6 @@ # The response schema type RES = TypeVar("RES") -from .db_manager import BaseQuery, DatabaseManager, db QUERY_SPEC = Union[REQ, Dict[str, Any]] diff --git a/dbgpt/util/annotations.py b/dbgpt/util/annotations.py index f97075339..d20f223db 100644 --- a/dbgpt/util/annotations.py +++ b/dbgpt/util/annotations.py @@ -1,3 +1,6 @@ +from typing import Optional + + def PublicAPI(*args, **kwargs): """Decorator to mark a function or class as a public API. @@ -64,7 +67,7 @@ def decorator(obj): return decorator -def _modify_docstring(obj, message: str = None): +def _modify_docstring(obj, message: Optional[str] = None): if not message: return if not obj.__doc__: @@ -81,6 +84,7 @@ def _modify_docstring(obj, message: str = None): if min_indent == float("inf"): min_indent = 0 + min_indent = int(min_indent) indented_message = message.rstrip() + "\n" + (" " * min_indent) obj.__doc__ = indented_message + original_doc diff --git a/dbgpt/util/config_utils.py b/dbgpt/util/config_utils.py index cf721f2c8..a31fe20af 100644 --- a/dbgpt/util/config_utils.py +++ b/dbgpt/util/config_utils.py @@ -1,6 +1,6 @@ import os from functools import cache -from typing import Any, Dict, Optional +from typing import Any, Dict, Optional, cast class AppConfig: @@ -46,7 +46,7 @@ def get_current_lang(self, default: Optional[str] = None) -> str: """ env_lang = ( "zh" - if os.getenv("LANG") and os.getenv("LANG").startswith("zh") + if os.getenv("LANG") and cast(str, os.getenv("LANG")).startswith("zh") else default ) return self.get("dbgpt.app.global.language", env_lang) diff --git a/dbgpt/util/formatting.py b/dbgpt/util/formatting.py index 51f374314..31a294f71 100644 --- a/dbgpt/util/formatting.py +++ b/dbgpt/util/formatting.py @@ -1,7 +1,7 @@ """Utilities for formatting strings.""" import json from string import Formatter -from typing import Any, List, Mapping, Sequence, Union +from typing import Any, List, Mapping, Sequence, Set, Union class StrictFormatter(Formatter): @@ -9,7 +9,7 @@ class StrictFormatter(Formatter): def check_unused_args( self, - used_args: Sequence[Union[int, str]], + used_args: Set[Union[int, str]], args: Sequence, kwargs: Mapping[str, Any], ) -> None: @@ -39,7 +39,7 @@ def validate_input_variables( class NoStrictFormatter(StrictFormatter): def check_unused_args( self, - used_args: Sequence[Union[int, str]], + used_args: Set[Union[int, str]], args: Sequence, kwargs: Mapping[str, Any], ) -> None: diff --git a/dbgpt/util/parameter_utils.py b/dbgpt/util/parameter_utils.py index 202c5f8a6..e4ea48ec6 100644 --- a/dbgpt/util/parameter_utils.py +++ b/dbgpt/util/parameter_utils.py @@ -12,14 +12,14 @@ @dataclass class ParameterDescription: - param_class: str - param_name: str - param_type: str - default_value: Optional[Any] - description: str - required: Optional[bool] - valid_values: Optional[List[Any]] - ext_metadata: Dict + required: bool = False + param_class: Optional[str] = None + param_name: Optional[str] = None + param_type: Optional[str] = None + description: Optional[str] = None + default_value: Optional[Any] = None + valid_values: Optional[List[Any]] = None + ext_metadata: Optional[Dict[str, Any]] = None @dataclass @@ 
-186,7 +186,9 @@ class Person: return "******" -def _genenv_ignoring_key_case(env_key: str, env_prefix: str = None, default_value=None): +def _genenv_ignoring_key_case( + env_key: str, env_prefix: Optional[str] = None, default_value: Optional[str] = None +): """Get the value from the environment variable, ignoring the case of the key""" if env_prefix: env_key = env_prefix + env_key @@ -196,7 +198,9 @@ def _genenv_ignoring_key_case(env_key: str, env_prefix: str = None, default_valu def _genenv_ignoring_key_case_with_prefixes( - env_key: str, env_prefixes: List[str] = None, default_value=None + env_key: str, + env_prefixes: Optional[List[str]] = None, + default_value: Optional[str] = None, ) -> str: if env_prefixes: for env_prefix in env_prefixes: @@ -208,7 +212,7 @@ def _genenv_ignoring_key_case_with_prefixes( class EnvArgumentParser: @staticmethod - def get_env_prefix(env_key: str) -> str: + def get_env_prefix(env_key: str) -> Optional[str]: if not env_key: return None env_key = env_key.replace("-", "_") @@ -217,14 +221,14 @@ def get_env_prefix(env_key: str) -> str: def parse_args_into_dataclass( self, dataclass_type: Type, - env_prefixes: List[str] = None, - command_args: List[str] = None, + env_prefixes: Optional[List[str]] = None, + command_args: Optional[List[str]] = None, **kwargs, ) -> Any: """Parse parameters from environment variables and command lines and populate them into data class""" parser = argparse.ArgumentParser() for field in fields(dataclass_type): - env_var_value = _genenv_ignoring_key_case_with_prefixes( + env_var_value: Any = _genenv_ignoring_key_case_with_prefixes( field.name, env_prefixes ) if env_var_value: @@ -313,7 +317,8 @@ def _create_click_option_from_field(field_name: str, field: Type, is_func=True): @staticmethod def create_click_option( - *dataclass_types: Type, _dynamic_factory: Callable[[None], List[Type]] = None + *dataclass_types: Type, + _dynamic_factory: Optional[Callable[[], List[Type]]] = None, ): import functools from collections import OrderedDict @@ -322,8 +327,9 @@ def create_click_option( if _dynamic_factory: _types = _dynamic_factory() if _types: - dataclass_types = list(_types) + dataclass_types = list(_types) # type: ignore for dataclass_type in dataclass_types: + # type: ignore for field in fields(dataclass_type): if field.name not in combined_fields: combined_fields[field.name] = field @@ -345,7 +351,8 @@ def wrapper(*args, **kwargs): @staticmethod def _create_raw_click_option( - *dataclass_types: Type, _dynamic_factory: Callable[[None], List[Type]] = None + *dataclass_types: Type, + _dynamic_factory: Optional[Callable[[], List[Type]]] = None, ): combined_fields = _merge_dataclass_types( *dataclass_types, _dynamic_factory=_dynamic_factory @@ -362,7 +369,8 @@ def _create_raw_click_option( @staticmethod def create_argparse_option( - *dataclass_types: Type, _dynamic_factory: Callable[[None], List[Type]] = None + *dataclass_types: Type, + _dynamic_factory: Optional[Callable[[], List[Type]]] = None, ) -> argparse.ArgumentParser: combined_fields = _merge_dataclass_types( *dataclass_types, _dynamic_factory=_dynamic_factory @@ -429,7 +437,7 @@ def _get_argparse_type_str(field_type: Type) -> str: return "str" @staticmethod - def _is_require_type(field_type: Type) -> str: + def _is_require_type(field_type: Type) -> bool: return field_type not in [Optional[int], Optional[float], Optional[bool]] @staticmethod @@ -455,13 +463,13 @@ def _read_env_key_value( def _merge_dataclass_types( - *dataclass_types: Type, _dynamic_factory: Callable[[None], 
List[Type]] = None + *dataclass_types: Type, _dynamic_factory: Optional[Callable[[], List[Type]]] = None ) -> OrderedDict: combined_fields = OrderedDict() if _dynamic_factory: _types = _dynamic_factory() if _types: - dataclass_types = list(_types) + dataclass_types = list(_types) # type: ignore for dataclass_type in dataclass_types: for field in fields(dataclass_type): if field.name not in combined_fields: @@ -511,11 +519,12 @@ def _build_parameter_class(desc: List[ParameterDescription]) -> Type: if not desc: raise ValueError("Parameter descriptions cant be empty") param_class_str = desc[0].param_class + class_name = None if param_class_str: param_class = import_from_string(param_class_str, ignore_import_error=True) if param_class: return param_class - module_name, _, class_name = param_class_str.rpartition(".") + module_name, _, class_name = param_class_str.rpartition(".") fields_dict = {} # This will store field names and their default values or field() annotations = {} # This will store the type annotations for the fields @@ -526,25 +535,30 @@ def _build_parameter_class(desc: List[ParameterDescription]) -> Type: metadata["valid_values"] = d.valid_values annotations[d.param_name] = _type_str_to_python_type( - d.param_type + d.param_type # type: ignore ) # Set type annotation fields_dict[d.param_name] = field(default=d.default_value, metadata=metadata) # Create the new class. Note the setting of __annotations__ for type hints new_class = type( - class_name, (object,), {**fields_dict, "__annotations__": annotations} + class_name, # type: ignore + (object,), + {**fields_dict, "__annotations__": annotations}, # type: ignore ) - result_class = dataclass(new_class) # Make it a dataclass + # Make it a dataclass + result_class = dataclass(new_class) # type: ignore return result_class def _extract_parameter_details( parser: argparse.ArgumentParser, - param_class: str = None, - skip_names: List[str] = None, - overwrite_default_values: Dict = {}, + param_class: Optional[str] = None, + skip_names: Optional[List[str]] = None, + overwrite_default_values: Optional[Dict[str, Any]] = None, ) -> List[ParameterDescription]: + if overwrite_default_values is None: + overwrite_default_values = {} descriptions = [] for action in parser._actions: @@ -575,7 +589,9 @@ def _extract_parameter_details( if param_name in overwrite_default_values: default_value = overwrite_default_values[param_name] arg_type = ( - action.type if not callable(action.type) else str(action.type.__name__) + action.type + if not callable(action.type) + else str(action.type.__name__) # type: ignore ) description = action.help @@ -583,10 +599,10 @@ def _extract_parameter_details( required = action.required # extract valid values for choices, if provided - valid_values = action.choices if action.choices is not None else None + valid_values = list(action.choices) if action.choices is not None else None # set ext_metadata as an empty dict for now, can be updated later if needed - ext_metadata = {} + ext_metadata: Dict[str, Any] = {} descriptions.append( ParameterDescription( @@ -621,7 +637,7 @@ def _get_dict_from_obj(obj, default_value=None) -> Optional[Dict]: def _get_base_model_descriptions(model_cls: "BaseModel") -> List[ParameterDescription]: from dbgpt._private import pydantic - version = int(pydantic.VERSION.split(".")[0]) + version = int(pydantic.VERSION.split(".")[0]) # type: ignore schema = model_cls.model_json_schema() if version >= 2 else model_cls.schema() required_fields = set(schema.get("required", [])) param_descs = [] @@ -661,7 
+677,7 @@ def _get_base_model_descriptions(model_cls: "BaseModel") -> List[ParameterDescri ext_metadata = ( field.field_info.extra if hasattr(field.field_info, "extra") else None ) - param_class = (f"{model_cls.__module__}.{model_cls.__name__}",) + param_class = f"{model_cls.__module__}.{model_cls.__name__}" param_desc = ParameterDescription( param_class=param_class, param_name=field_name, diff --git a/dbgpt/util/utils.py b/dbgpt/util/utils.py index bed637cba..03ae1473d 100644 --- a/dbgpt/util/utils.py +++ b/dbgpt/util/utils.py @@ -5,7 +5,7 @@ import logging import logging.handlers import os -from typing import Any, List +from typing import Any, List, Optional, cast from dbgpt.configs.model_config import LOGDIR @@ -28,19 +28,25 @@ def _get_logging_level() -> str: return os.getenv("DBGPT_LOG_LEVEL", "INFO") -def setup_logging_level(logging_level=None, logger_name: str = None): +def setup_logging_level( + logging_level: Optional[str] = None, logger_name: Optional[str] = None +): if not logging_level: logging_level = _get_logging_level() if type(logging_level) is str: logging_level = logging.getLevelName(logging_level.upper()) if logger_name: logger = logging.getLogger(logger_name) - logger.setLevel(logging_level) + logger.setLevel(cast(str, logging_level)) else: logging.basicConfig(level=logging_level, encoding="utf-8") -def setup_logging(logger_name: str, logging_level=None, logger_filename: str = None): +def setup_logging( + logger_name: str, + logging_level: Optional[str] = None, + logger_filename: Optional[str] = None, +): if not logging_level: logging_level = _get_logging_level() logger = _build_logger(logger_name, logging_level, logger_filename) @@ -74,7 +80,11 @@ def get_gpu_memory(max_gpus=None): return gpu_memory -def _build_logger(logger_name, logging_level=None, logger_filename: str = None): +def _build_logger( + logger_name, + logging_level: Optional[str] = None, + logger_filename: Optional[str] = None, +): global handler formatter = logging.Formatter( @@ -111,14 +121,14 @@ def get_or_create_event_loop() -> asyncio.BaseEventLoop: try: loop = asyncio.get_event_loop() assert loop is not None - return loop + return cast(asyncio.BaseEventLoop, loop) except RuntimeError as e: if not "no running event loop" in str(e) and not "no current event loop" in str( e ): raise e logging.warning("Cant not get running event loop, create new event loop now") - return asyncio.get_event_loop_policy().new_event_loop() + return cast(asyncio.BaseEventLoop, asyncio.get_event_loop_policy().new_event_loop()) def logging_str_to_uvicorn_level(log_level_str): @@ -152,7 +162,7 @@ def filter(self, record: logging.LogRecord) -> bool: return record.getMessage().find(self._path) == -1 -def setup_http_service_logging(exclude_paths: List[str] = None): +def setup_http_service_logging(exclude_paths: Optional[List[str]] = None): """Setup http service logging Now just disable some logs diff --git a/examples/awel/simple_rag_summary_example.py b/examples/awel/simple_rag_summary_example.py index f16cbec73..5447af032 100644 --- a/examples/awel/simple_rag_summary_example.py +++ b/examples/awel/simple_rag_summary_example.py @@ -53,7 +53,7 @@ async def map(self, input_value: TriggerReqBody) -> Dict: return params -with DAG("dbgpt_awel_simple_rag_rewrite_example") as dag: +with DAG("dbgpt_awel_simple_rag_summary_example") as dag: trigger = HttpTrigger( "/examples/rag/summary", methods="POST", request_body=TriggerReqBody ) diff --git a/requirements/dev-requirements.txt b/requirements/dev-requirements.txt index 
dc49dd0aa..8061d24d2 100644 --- a/requirements/dev-requirements.txt +++ b/requirements/dev-requirements.txt @@ -13,4 +13,4 @@ aioresponses # for git hooks pre-commit # Type checking -mypy==0.991 \ No newline at end of file +mypy==1.7.0 \ No newline at end of file
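Editor's note on the recurring proxy-client change in this patch (gemini, spark, tongyi, wenxin, zhipu): assigning self.default_model in __init__ collides with the new read-only default_model property on the base client, so each subclass now overrides the property instead. A condensed, self-contained sketch of the pattern (class names here are illustrative, not the real client classes):

.. code-block:: python

    from typing import List

    class BaseClient:
        def __init__(self, model_names: List[str]):
            self.model_names = model_names

        @property
        def default_model(self) -> str:
            # Base fallback: the first registered model name.
            return self.model_names[0]

    class SomeProxyClient(BaseClient):
        def __init__(self, model: str, model_alias: str):
            self._model = model
            # NOTE: `self.default_model = model` would raise AttributeError
            # here, because the class now defines a read-only property.
            super().__init__(model_names=[model, model_alias])

        @property
        def default_model(self) -> str:
            return self._model

    client = SomeProxyClient("glm-4", "zhipu_proxyllm")
    assert client.default_model == "glm-4"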