From f8ce7d4580fa695e001582d167575d53e48754ac Mon Sep 17 00:00:00 2001 From: Fangyin Cheng Date: Wed, 4 Sep 2024 22:08:55 +0800 Subject: [PATCH] feat: Check serialization for AWEL operator function --- .dockerignore | 5 ++ dbgpt/_private/config.py | 4 + dbgpt/app/initialization/scheduler.py | 4 +- dbgpt/core/awel/dag/base.py | 37 +++++++++ dbgpt/core/awel/operators/base.py | 19 ++++- dbgpt/core/awel/operators/common_operator.py | 17 ++++ dbgpt/model/proxy/base.py | 11 +++ dbgpt/serve/flow/api/variables_provider.py | 2 +- dbgpt/storage/metadata/_base_dao.py | 3 + dbgpt/util/net_utils.py | 32 ++++++++ dbgpt/util/serialization/check.py | 85 ++++++++++++++++++++ docker/base/build_image.sh | 21 ++++- 12 files changed, 236 insertions(+), 4 deletions(-) create mode 100644 dbgpt/util/serialization/check.py diff --git a/.dockerignore b/.dockerignore index 823dcbd59..7e4266596 100644 --- a/.dockerignore +++ b/.dockerignore @@ -1,3 +1,6 @@ +.env +.git/ +./.mypy_cache/ models/ plugins/ pilot/data @@ -5,6 +8,8 @@ pilot/message logs/ venv/ web/node_modules/ +web/.next/ +web/.env docs/node_modules/ build/ docs/build/ diff --git a/dbgpt/_private/config.py b/dbgpt/_private/config.py index cd7936a60..eaf1e953c 100644 --- a/dbgpt/_private/config.py +++ b/dbgpt/_private/config.py @@ -332,6 +332,10 @@ def __init__(self) -> None: os.getenv("MULTI_INSTANCE", "False").lower() == "true" ) + self.SCHEDULER_ENABLED = ( + os.getenv("SCHEDULER_ENABLED", "True").lower() == "true" + ) + @property def local_db_manager(self) -> "ConnectorManager": from dbgpt.datasource.manages import ConnectorManager diff --git a/dbgpt/app/initialization/scheduler.py b/dbgpt/app/initialization/scheduler.py index 70b7bb71a..36a3107db 100644 --- a/dbgpt/app/initialization/scheduler.py +++ b/dbgpt/app/initialization/scheduler.py @@ -19,12 +19,14 @@ def __init__( system_app: SystemApp, scheduler_delay_ms: int = 5000, scheduler_interval_ms: int = 1000, + scheduler_enable: bool = True, ): super().__init__(system_app) self.system_app = system_app self._scheduler_interval_ms = scheduler_interval_ms self._scheduler_delay_ms = scheduler_delay_ms self._stop_event = threading.Event() + self._scheduler_enable = scheduler_enable def init_app(self, system_app: SystemApp): self.system_app = system_app @@ -39,7 +41,7 @@ def before_stop(self): def _scheduler(self): time.sleep(self._scheduler_delay_ms / 1000) - while not self._stop_event.is_set(): + while self._scheduler_enable and not self._stop_event.is_set(): try: schedule.run_pending() except Exception as e: diff --git a/dbgpt/core/awel/dag/base.py b/dbgpt/core/awel/dag/base.py index 2f3521d24..2c12b3bad 100644 --- a/dbgpt/core/awel/dag/base.py +++ b/dbgpt/core/awel/dag/base.py @@ -145,6 +145,9 @@ class DAGVar: _executor: Optional[Executor] = None _variables_provider: Optional["VariablesProvider"] = None + # Whether check serializable for AWEL, it will be set to True when running AWEL + # operator in remote environment + _check_serializable: Optional[bool] = None @classmethod def enter_dag(cls, dag) -> None: @@ -257,6 +260,24 @@ def set_variables_provider(cls, variables_provider: "VariablesProvider") -> None """ cls._variables_provider = variables_provider + @classmethod + def get_check_serializable(cls) -> Optional[bool]: + """Get the check serializable flag. + + Returns: + Optional[bool]: The check serializable flag + """ + return cls._check_serializable + + @classmethod + def set_check_serializable(cls, check_serializable: bool) -> None: + """Set the check serializable flag. 
+ + Args: + check_serializable (bool): The check serializable flag to set + """ + cls._check_serializable = check_serializable + class DAGLifecycle: """The lifecycle of DAG.""" @@ -286,6 +307,7 @@ def __init__( node_name: Optional[str] = None, system_app: Optional[SystemApp] = None, executor: Optional[Executor] = None, + check_serializable: Optional[bool] = None, **kwargs, ) -> None: """Initialize a DAGNode. @@ -311,6 +333,7 @@ def __init__( node_id = self._dag._new_node_id() self._node_id: Optional[str] = node_id self._node_name: Optional[str] = node_name + self._check_serializable = check_serializable if self._dag: self._dag._append_node(self) @@ -486,6 +509,20 @@ def __str__(self): """Return the string of current DAGNode.""" return self.__repr__() + @classmethod + def _do_check_serializable(cls, obj: Any, obj_name: str = "Object"): + """Check whether the current DAGNode is serializable.""" + from dbgpt.util.serialization.check import check_serializable + + check_serializable(obj, obj_name) + + @property + def check_serializable(self) -> bool: + """Whether check serializable for current DAGNode.""" + if self._check_serializable is not None: + return self._check_serializable or False + return DAGVar.get_check_serializable() or False + def _build_task_key(task_name: str, key: str) -> str: return f"{task_name}___$$$$$$___{key}" diff --git a/dbgpt/core/awel/operators/base.py b/dbgpt/core/awel/operators/base.py index da82d2856..7c66c0adc 100644 --- a/dbgpt/core/awel/operators/base.py +++ b/dbgpt/core/awel/operators/base.py @@ -193,12 +193,29 @@ def __init__( self.incremental_output = bool(kwargs["incremental_output"]) if "output_format" in kwargs: self.output_format = kwargs["output_format"] - self._runner: WorkflowRunner = runner self._dag_ctx: Optional[DAGContext] = None self._can_skip_in_branch = can_skip_in_branch self._variables_provider = variables_provider + def __getstate__(self): + """Customize the pickling process.""" + state = self.__dict__.copy() + if "_runner" in state: + del state["_runner"] + if "_executor" in state: + del state["_executor"] + if "_system_app" in state: + del state["_system_app"] + return state + + def __setstate__(self, state): + """Customize the unpickling process.""" + self.__dict__.update(state) + self._runner = default_runner + self._system_app = DAGVar.get_current_system_app() + self._executor = DAGVar.get_executor() + @property def current_dag_context(self) -> DAGContext: """Return the current DAG context.""" diff --git a/dbgpt/core/awel/operators/common_operator.py b/dbgpt/core/awel/operators/common_operator.py index f8bc25370..763992323 100644 --- a/dbgpt/core/awel/operators/common_operator.py +++ b/dbgpt/core/awel/operators/common_operator.py @@ -41,6 +41,12 @@ def __init__( super().__init__(can_skip_in_branch=can_skip_in_branch, **kwargs) if not callable(combine_function): raise ValueError("combine_function must be callable") + + if self.check_serializable: + super()._do_check_serializable( + combine_function, + f"JoinOperator: {self}, combine_function: {combine_function}", + ) self.combine_function = combine_function async def _do_run(self, dag_ctx: DAGContext) -> TaskOutput[OUT]: @@ -83,6 +89,11 @@ def __init__(self, reduce_function: Optional[ReduceFunc] = None, **kwargs): super().__init__(**kwargs) if reduce_function and not callable(reduce_function): raise ValueError("reduce_function must be callable") + if reduce_function and self.check_serializable: + super()._do_check_serializable( + reduce_function, f"Operator: {self}, reduce_function: 
{reduce_function}" + ) + self.reduce_function = reduce_function async def _do_run(self, dag_ctx: DAGContext) -> TaskOutput[OUT]: @@ -133,6 +144,12 @@ def __init__(self, map_function: Optional[MapFunc] = None, **kwargs): super().__init__(**kwargs) if map_function and not callable(map_function): raise ValueError("map_function must be callable") + + if map_function and self.check_serializable: + super()._do_check_serializable( + map_function, f"Operator: {self}, map_function: {map_function}" + ) + self.map_function = map_function async def _do_run(self, dag_ctx: DAGContext) -> TaskOutput[OUT]: diff --git a/dbgpt/model/proxy/base.py b/dbgpt/model/proxy/base.py index 129dcf11e..2a1a3b6b8 100644 --- a/dbgpt/model/proxy/base.py +++ b/dbgpt/model/proxy/base.py @@ -94,6 +94,17 @@ def __init__( self.executor = executor or ThreadPoolExecutor() self.proxy_tokenizer = proxy_tokenizer or TiktokenProxyTokenizer() + def __getstate__(self): + """Customize the serialization of the object""" + state = self.__dict__.copy() + state.pop("executor") + return state + + def __setstate__(self, state): + """Customize the deserialization of the object""" + self.__dict__.update(state) + self.executor = ThreadPoolExecutor() + @classmethod @abstractmethod def new_client( diff --git a/dbgpt/serve/flow/api/variables_provider.py b/dbgpt/serve/flow/api/variables_provider.py index 27ed63bf5..10bb8a656 100644 --- a/dbgpt/serve/flow/api/variables_provider.py +++ b/dbgpt/serve/flow/api/variables_provider.py @@ -341,7 +341,7 @@ def get_variables( StorageVariables( key=key, name=agent["name"], - label=agent["desc"], + label=agent["name"], value=agent["name"], scope=scope, scope_key=scope_key, diff --git a/dbgpt/storage/metadata/_base_dao.py b/dbgpt/storage/metadata/_base_dao.py index 04b74e81c..199b3eabc 100644 --- a/dbgpt/storage/metadata/_base_dao.py +++ b/dbgpt/storage/metadata/_base_dao.py @@ -285,6 +285,9 @@ def _create_query_object( else model_to_dict(query_request) ) for key, value in query_dict.items(): + if value and isinstance(value, (list, tuple, dict, set)): + # Skip the list, tuple, dict, set + continue if value is not None and hasattr(model_cls, key): if isinstance(value, list): if len(value) > 0: diff --git a/dbgpt/util/net_utils.py b/dbgpt/util/net_utils.py index fc9fb3f86..ce41ba781 100644 --- a/dbgpt/util/net_utils.py +++ b/dbgpt/util/net_utils.py @@ -1,5 +1,6 @@ import errno import socket +from typing import Set, Tuple def _get_ip_address(address: str = "10.254.254.254:1") -> str: @@ -22,3 +23,34 @@ def _get_ip_address(address: str = "10.254.254.254:1") -> str: finally: s.close() return curr_address + + +async def _async_get_free_port( + port_range: Tuple[int, int], timeout: int, used_ports: Set[int] +): + import asyncio + + loop = asyncio.get_running_loop() + return await loop.run_in_executor( + None, _get_free_port, port_range, timeout, used_ports + ) + + +def _get_free_port(port_range: Tuple[int, int], timeout: int, used_ports: Set[int]): + import random + + available_ports = set(range(port_range[0], port_range[1] + 1)) - used_ports + if not available_ports: + raise RuntimeError("No available ports in the specified range") + + while available_ports: + port = random.choice(list(available_ports)) + try: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.bind(("", port)) + used_ports.add(port) + return port + except OSError: + available_ports.remove(port) + + raise RuntimeError("No available ports in the specified range") diff --git a/dbgpt/util/serialization/check.py 
b/dbgpt/util/serialization/check.py new file mode 100644 index 000000000..10a86edb2 --- /dev/null +++ b/dbgpt/util/serialization/check.py @@ -0,0 +1,85 @@ +import inspect +from io import StringIO +from typing import Any, Dict, Optional, TextIO + +import cloudpickle + + +def check_serializable( + obj: Any, obj_name: str = "Object", error_msg: str = "Object is not serializable" +): + try: + cloudpickle.dumps(obj) + except Exception as e: + inspect_info = inspect_serializability(obj, obj_name) + msg = f"{error_msg}\n{inspect_info['report']}" + raise TypeError(msg) from e + + +class SerializabilityInspector: + def __init__(self, stream: Optional[TextIO] = None): + self.stream = stream or StringIO() + self.failures = {} + self.indent_level = 0 + + def log(self, message: str): + indent = " " * self.indent_level + self.stream.write(f"{indent}{message}\n") + + def inspect(self, obj: Any, name: str, depth: int = 3) -> bool: + self.log(f"Inspecting '{name}'") + self.indent_level += 1 + + try: + cloudpickle.dumps(obj) + self.indent_level -= 1 + return True + except Exception as e: + self.failures[name] = str(e) + self.log(f"Failure: {str(e)}") + + if depth > 0: + if inspect.isfunction(obj) or inspect.ismethod(obj): + self._inspect_function(obj, depth - 1) + elif hasattr(obj, "__dict__"): + self._inspect_object(obj, depth - 1) + + self.indent_level -= 1 + return False + + def _inspect_function(self, func, depth): + closure = inspect.getclosurevars(func) + for name, value in closure.nonlocals.items(): + self.inspect(value, f"{func.__name__}.{name}", depth) + for name, value in closure.globals.items(): + self.inspect(value, f"global:{name}", depth) + + def _inspect_object(self, obj, depth): + for name, value in inspect.getmembers(obj): + if not name.startswith("__"): + self.inspect(value, f"{type(obj).__name__}.{name}", depth) + + def get_report(self) -> str: + summary = "\nSummary of Serialization Failures:\n" + if not self.failures: + summary += "All components are serializable.\n" + else: + for name, error in self.failures.items(): + summary += f" - {name}: {error}\n" + + return self.stream.getvalue() + summary + + +def inspect_serializability( + obj: Any, + name: Optional[str] = None, + depth: int = 5, + stream: Optional[TextIO] = None, +) -> Dict[str, Any]: + inspector = SerializabilityInspector(stream) + success = inspector.inspect(obj, name or type(obj).__name__, depth) + return { + "success": success, + "failures": inspector.failures, + "report": inspector.get_report(), + } diff --git a/docker/base/build_image.sh b/docker/base/build_image.sh index 028dcc809..08cd0b549 100755 --- a/docker/base/build_image.sh +++ b/docker/base/build_image.sh @@ -20,16 +20,21 @@ LOAD_EXAMPLES="true" BUILD_NETWORK="" DB_GPT_INSTALL_MODEL="default" +DOCKERFILE="Dockerfile" +IMAGE_NAME_SUFFIX="" + usage () { echo "USAGE: $0 [--base-image nvidia/cuda:12.1.0-runtime-ubuntu22.04] [--image-name db-gpt]" echo " [-b|--base-image base image name] Base image name" echo " [-n|--image-name image name] Current image name, default: db-gpt" + echo " [--image-name-suffix image name suffix] Image name suffix" echo " [-i|--pip-index-url pip index url] Pip index url, default: https://pypi.org/simple" echo " [--language en or zh] You language, default: en" echo " [--build-local-code true or false] Whether to use the local project code to package the image, default: true" echo " [--load-examples true or false] Whether to load examples to default database default: true" echo " [--network network name] The network of docker build" echo " 
[--install-mode mode name] Installation mode name, default: default, If you completely use openai's service, you can set the mode name to 'openai'" + echo " [-f|--dockerfile dockerfile] Dockerfile name, default: Dockerfile" echo " [-h|--help] Usage message" } @@ -46,6 +51,11 @@ while [[ $# -gt 0 ]]; do shift # past argument shift # past value ;; + --image-name-suffix) + IMAGE_NAME_SUFFIX="$2" + shift # past argument + shift # past value + ;; -i|--pip-index-url) PIP_INDEX_URL="$2" shift @@ -80,6 +90,11 @@ while [[ $# -gt 0 ]]; do shift # past argument shift # past value ;; + -f|--dockerfile) + DOCKERFILE="$2" + shift # past argument + shift # past value + ;; -h|--help) help="true" shift @@ -111,6 +126,10 @@ else BASE_IMAGE=$IMAGE_NAME_ARGS fi +if [ -n "$IMAGE_NAME_SUFFIX" ]; then + IMAGE_NAME="$IMAGE_NAME-$IMAGE_NAME_SUFFIX" +fi + echo "Begin build docker image, base image: ${BASE_IMAGE}, target image name: ${IMAGE_NAME}" docker build $BUILD_NETWORK \ @@ -120,5 +139,5 @@ docker build $BUILD_NETWORK \ --build-arg BUILD_LOCAL_CODE=$BUILD_LOCAL_CODE \ --build-arg LOAD_EXAMPLES=$LOAD_EXAMPLES \ --build-arg DB_GPT_INSTALL_MODEL=$DB_GPT_INSTALL_MODEL \ - -f Dockerfile \ + -f $DOCKERFILE \ -t $IMAGE_NAME $WORK_DIR/../../
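
Usage sketch for the new serialization check helpers (illustrative only, not part of the diff above). It assumes cloudpickle is installed and uses the dbgpt.util.serialization.check module introduced by this patch; the closure-over-a-lock callable is a hypothetical example of a non-picklable function:

    import threading

    from dbgpt.util.serialization.check import (
        check_serializable,
        inspect_serializability,
    )

    def make_combiner():
        lock = threading.Lock()  # thread locks cannot be pickled

        def combine(x, y):
            # The closure captures `lock`, so cloudpickle fails on this function.
            with lock:
                return x + y

        return combine

    combiner = make_combiner()

    # Inspect without raising: returns a success flag, per-component failures,
    # and a human-readable report.
    result = inspect_serializability(combiner, name="combine_function")
    print(result["success"])   # False
    print(result["report"])    # names combine.lock as the offending member

    # check_serializable raises TypeError carrying the same report.
    try:
        check_serializable(combiner, "combine_function")
    except TypeError as e:
        print(e)

With the operator changes above, the same check would run at construction time for a JoinOperator, MapOperator, or ReduceStreamOperator when check_serializable is enabled on the node or globally via DAGVar.set_check_serializable(True).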