chore: update api (#1972)
Dreammy23 authored Sep 4, 2024
2 parents 1794069 + 6ca1381 commit cbfd3d2
Showing 12 changed files with 236 additions and 4 deletions.
5 changes: 5 additions & 0 deletions .dockerignore
@@ -1,10 +1,15 @@
.env
.git/
./.mypy_cache/
models/
plugins/
pilot/data
pilot/message
logs/
venv/
web/node_modules/
web/.next/
web/.env
docs/node_modules/
build/
docs/build/
4 changes: 4 additions & 0 deletions dbgpt/_private/config.py
@@ -332,6 +332,10 @@ def __init__(self) -> None:
os.getenv("MULTI_INSTANCE", "False").lower() == "true"
)

self.SCHEDULER_ENABLED = (
os.getenv("SCHEDULER_ENABLED", "True").lower() == "true"
)

@property
def local_db_manager(self) -> "ConnectorManager":
from dbgpt.datasource.manages import ConnectorManager
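The flag is read once, when Config is constructed, so it has to be set in the environment (or .env) before the singleton is built. A minimal sketch of the intended toggle, runnable against this module (the surrounding setup is illustrative):

    import os

    # Must be set before Config() is first constructed; any value other
    # than "true" (case-insensitive) disables the scheduler.
    os.environ["SCHEDULER_ENABLED"] = "false"

    from dbgpt._private.config import Config

    cfg = Config()
    assert cfg.SCHEDULER_ENABLED is False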
4 changes: 3 additions & 1 deletion dbgpt/app/initialization/scheduler.py
@@ -19,12 +19,14 @@ def __init__(
system_app: SystemApp,
scheduler_delay_ms: int = 5000,
scheduler_interval_ms: int = 1000,
scheduler_enable: bool = True,
):
super().__init__(system_app)
self.system_app = system_app
self._scheduler_interval_ms = scheduler_interval_ms
self._scheduler_delay_ms = scheduler_delay_ms
self._stop_event = threading.Event()
self._scheduler_enable = scheduler_enable

def init_app(self, system_app: SystemApp):
self.system_app = system_app
@@ -39,7 +41,7 @@ def before_stop(self):

def _scheduler(self):
time.sleep(self._scheduler_delay_ms / 1000)
while not self._stop_event.is_set():
while self._scheduler_enable and not self._stop_event.is_set():
try:
schedule.run_pending()
except Exception as e:
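When the flag is False, the scheduler thread wakes once after the initial delay, fails the while condition, and exits without ever calling schedule.run_pending(). A hedged wiring sketch (the class name DefaultScheduler and the SystemApp construction are assumptions; only the scheduler_enable parameter comes from this diff):

    from dbgpt.component import SystemApp
    from dbgpt._private.config import Config
    from dbgpt.app.initialization.scheduler import DefaultScheduler  # name assumed

    system_app = SystemApp()
    cfg = Config()
    # Pass the new config flag straight through to the component.
    scheduler = DefaultScheduler(system_app, scheduler_enable=cfg.SCHEDULER_ENABLED)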
37 changes: 37 additions & 0 deletions dbgpt/core/awel/dag/base.py
@@ -145,6 +145,9 @@ class DAGVar:
_executor: Optional[Executor] = None

_variables_provider: Optional["VariablesProvider"] = None
# Whether to check serializability for AWEL; it is set to True when running an
# AWEL operator in a remote environment
_check_serializable: Optional[bool] = None

@classmethod
def enter_dag(cls, dag) -> None:
@@ -257,6 +260,24 @@ def set_variables_provider(cls, variables_provider: "VariablesProvider") -> None
"""
cls._variables_provider = variables_provider

@classmethod
def get_check_serializable(cls) -> Optional[bool]:
"""Get the check serializable flag.
Returns:
Optional[bool]: The check serializable flag
"""
return cls._check_serializable

@classmethod
def set_check_serializable(cls, check_serializable: bool) -> None:
"""Set the check serializable flag.
Args:
check_serializable (bool): The check serializable flag to set
"""
cls._check_serializable = check_serializable


class DAGLifecycle:
"""The lifecycle of DAG."""
@@ -286,6 +307,7 @@ def __init__(
node_name: Optional[str] = None,
system_app: Optional[SystemApp] = None,
executor: Optional[Executor] = None,
check_serializable: Optional[bool] = None,
**kwargs,
) -> None:
"""Initialize a DAGNode.
@@ -311,6 +333,7 @@ def __init__(
node_id = self._dag._new_node_id()
self._node_id: Optional[str] = node_id
self._node_name: Optional[str] = node_name
self._check_serializable = check_serializable
if self._dag:
self._dag._append_node(self)

@@ -486,6 +509,20 @@ def __str__(self):
"""Return the string of current DAGNode."""
return self.__repr__()

@classmethod
def _do_check_serializable(cls, obj: Any, obj_name: str = "Object"):
"""Check whether the current DAGNode is serializable."""
from dbgpt.util.serialization.check import check_serializable

check_serializable(obj, obj_name)

@property
def check_serializable(self) -> bool:
"""Whether check serializable for current DAGNode."""
if self._check_serializable is not None:
return self._check_serializable or False
return DAGVar.get_check_serializable() or False


def _build_task_key(task_name: str, key: str) -> str:
return f"{task_name}___$$$$$$___{key}"
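The property gives a per-node override (the check_serializable constructor argument) priority over the process-wide DAGVar flag, falling back to False when neither is set. A short sketch of the two levels (import path taken from this file):

    from dbgpt.core.awel.dag.base import DAGVar

    # Enable checks process-wide, e.g. before submitting operators to a
    # remote runner.
    DAGVar.set_check_serializable(True)
    assert DAGVar.get_check_serializable() is True

    # A node constructed with an explicit value wins over the global flag:
    # DAGNode(..., check_serializable=False).check_serializable -> False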
19 changes: 18 additions & 1 deletion dbgpt/core/awel/operators/base.py
@@ -193,12 +193,29 @@ def __init__(
self.incremental_output = bool(kwargs["incremental_output"])
if "output_format" in kwargs:
self.output_format = kwargs["output_format"]

self._runner: WorkflowRunner = runner
self._dag_ctx: Optional[DAGContext] = None
self._can_skip_in_branch = can_skip_in_branch
self._variables_provider = variables_provider

def __getstate__(self):
"""Customize the pickling process."""
state = self.__dict__.copy()
if "_runner" in state:
del state["_runner"]
if "_executor" in state:
del state["_executor"]
if "_system_app" in state:
del state["_system_app"]
return state

def __setstate__(self, state):
"""Customize the unpickling process."""
self.__dict__.update(state)
self._runner = default_runner
self._system_app = DAGVar.get_current_system_app()
self._executor = DAGVar.get_executor()

@property
def current_dag_context(self) -> DAGContext:
"""Return the current DAG context."""
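__getstate__ drops the three runtime-only attributes that cannot cross a process boundary (workflow runner, executor, system app), and __setstate__ re-binds them to process-local defaults on the receiving side. A hedged round-trip sketch (assuming MapOperator is exported from dbgpt.core.awel):

    import cloudpickle

    from dbgpt.core.awel import MapOperator  # public export assumed

    op = MapOperator(map_function=lambda x: x + 1)
    data = cloudpickle.dumps(op)        # runner/executor/system_app are dropped
    restored = cloudpickle.loads(data)  # ...and re-bound to local defaults here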
17 changes: 17 additions & 0 deletions dbgpt/core/awel/operators/common_operator.py
@@ -41,6 +41,12 @@ def __init__(
super().__init__(can_skip_in_branch=can_skip_in_branch, **kwargs)
if not callable(combine_function):
raise ValueError("combine_function must be callable")

if self.check_serializable:
super()._do_check_serializable(
combine_function,
f"JoinOperator: {self}, combine_function: {combine_function}",
)
self.combine_function = combine_function

async def _do_run(self, dag_ctx: DAGContext) -> TaskOutput[OUT]:
@@ -83,6 +89,11 @@ def __init__(self, reduce_function: Optional[ReduceFunc] = None, **kwargs):
super().__init__(**kwargs)
if reduce_function and not callable(reduce_function):
raise ValueError("reduce_function must be callable")
if reduce_function and self.check_serializable:
super()._do_check_serializable(
reduce_function, f"Operator: {self}, reduce_function: {reduce_function}"
)

self.reduce_function = reduce_function

async def _do_run(self, dag_ctx: DAGContext) -> TaskOutput[OUT]:
@@ -133,6 +144,12 @@ def __init__(self, map_function: Optional[MapFunc] = None, **kwargs):
super().__init__(**kwargs)
if map_function and not callable(map_function):
raise ValueError("map_function must be callable")

if map_function and self.check_serializable:
super()._do_check_serializable(
map_function, f"Operator: {self}, map_function: {map_function}"
)

self.map_function = map_function

async def _do_run(self, dag_ctx: DAGContext) -> TaskOutput[OUT]:
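With checking enabled, each functional operator now validates its user-supplied callable at construction time, so a closure over an unpicklable object fails fast with a readable report instead of at remote dispatch. A sketch of the failure mode (import paths assumed as above):

    import threading

    from dbgpt.core.awel import MapOperator  # public export assumed
    from dbgpt.core.awel.dag.base import DAGVar

    DAGVar.set_check_serializable(True)
    lock = threading.Lock()  # thread locks cannot be pickled

    # Raises TypeError, with a report that pinpoints the captured lock:
    op = MapOperator(map_function=lambda x: (x, lock))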
11 changes: 11 additions & 0 deletions dbgpt/model/proxy/base.py
@@ -94,6 +94,17 @@ def __init__(
self.executor = executor or ThreadPoolExecutor()
self.proxy_tokenizer = proxy_tokenizer or TiktokenProxyTokenizer()

def __getstate__(self):
"""Customize the serialization of the object"""
state = self.__dict__.copy()
state.pop("executor")
return state

def __setstate__(self, state):
"""Customize the deserialization of the object"""
self.__dict__.update(state)
self.executor = ThreadPoolExecutor()

@classmethod
@abstractmethod
def new_client(
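The same drop-and-rebuild pattern as the operator base class: the thread pool is excluded from the pickled state and recreated on deserialization. A self-contained sketch of the pattern, deliberately independent of the repo's abstract client class:

    import cloudpickle
    from concurrent.futures import ThreadPoolExecutor

    class Client:  # stand-in for a ProxyLLMClient subclass, illustrative only
        def __init__(self):
            self.model = "demo-model"
            self.executor = ThreadPoolExecutor()

        def __getstate__(self):
            state = self.__dict__.copy()
            state.pop("executor")  # thread pools cannot cross processes
            return state

        def __setstate__(self, state):
            self.__dict__.update(state)
            self.executor = ThreadPoolExecutor()  # fresh pool on this side

    restored = cloudpickle.loads(cloudpickle.dumps(Client()))
    assert restored.model == "demo-model"
    assert isinstance(restored.executor, ThreadPoolExecutor)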
2 changes: 1 addition & 1 deletion dbgpt/serve/flow/api/variables_provider.py
@@ -341,7 +341,7 @@ def get_variables(
StorageVariables(
key=key,
name=agent["name"],
label=agent["desc"],
label=agent["name"],
value=agent["name"],
scope=scope,
scope_key=scope_key,
3 changes: 3 additions & 0 deletions dbgpt/storage/metadata/_base_dao.py
@@ -285,6 +285,9 @@ def _create_query_object(
else model_to_dict(query_request)
)
for key, value in query_dict.items():
if value and isinstance(value, (list, tuple, dict, set)):
# Skip collection values (list, tuple, dict, set)
continue
if value is not None and hasattr(model_cls, key):
if isinstance(value, list):
if len(value) > 0:
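The effect on query construction, sketched with a hypothetical request dict (field names are illustrative):

    query_dict = {"name": "my-flow", "sys_code": None, "user_ids": [1, 2]}
    # Only `name` reaches the equality filter: `sys_code` is None and was
    # already skipped, and `user_ids` is a list, so the new guard skips it.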
32 changes: 32 additions & 0 deletions dbgpt/util/net_utils.py
@@ -1,5 +1,6 @@
import errno
import socket
from typing import Set, Tuple


def _get_ip_address(address: str = "10.254.254.254:1") -> str:
Expand All @@ -22,3 +23,34 @@ def _get_ip_address(address: str = "10.254.254.254:1") -> str:
finally:
s.close()
return curr_address


async def _async_get_free_port(
port_range: Tuple[int, int], timeout: int, used_ports: Set[int]
):
import asyncio

loop = asyncio.get_running_loop()
return await loop.run_in_executor(
None, _get_free_port, port_range, timeout, used_ports
)


def _get_free_port(port_range: Tuple[int, int], timeout: int, used_ports: Set[int]):
import random

available_ports = set(range(port_range[0], port_range[1] + 1)) - used_ports
if not available_ports:
raise RuntimeError("No available ports in the specified range")

while available_ports:
port = random.choice(list(available_ports))
try:
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.bind(("", port))
used_ports.add(port)
return port
except OSError:
available_ports.remove(port)

raise RuntimeError("No available ports in the specified range")
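Usage sketch for the new helpers: the synchronous probe binds random candidate ports until one succeeds, and the async wrapper offloads that same probe to the default executor (leading underscores mark these as module-internal):

    from dbgpt.util.net_utils import _get_free_port

    used_ports = set()

    # Probe random ports in [8000, 8100]; the shared set prevents handing
    # out the same port twice.
    port = _get_free_port((8000, 8100), timeout=10, used_ports=used_ports)

    # From a coroutine:
    #     port = await _async_get_free_port((8000, 8100), 10, used_ports)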
85 changes: 85 additions & 0 deletions dbgpt/util/serialization/check.py
@@ -0,0 +1,85 @@
import inspect
from io import StringIO
from typing import Any, Dict, Optional, TextIO

import cloudpickle


def check_serializable(
obj: Any, obj_name: str = "Object", error_msg: str = "Object is not serializable"
):
try:
cloudpickle.dumps(obj)
except Exception as e:
inspect_info = inspect_serializability(obj, obj_name)
msg = f"{error_msg}\n{inspect_info['report']}"
raise TypeError(msg) from e


class SerializabilityInspector:
def __init__(self, stream: Optional[TextIO] = None):
self.stream = stream or StringIO()
self.failures = {}
self.indent_level = 0

def log(self, message: str):
indent = " " * self.indent_level
self.stream.write(f"{indent}{message}\n")

def inspect(self, obj: Any, name: str, depth: int = 3) -> bool:
self.log(f"Inspecting '{name}'")
self.indent_level += 1

try:
cloudpickle.dumps(obj)
self.indent_level -= 1
return True
except Exception as e:
self.failures[name] = str(e)
self.log(f"Failure: {str(e)}")

if depth > 0:
if inspect.isfunction(obj) or inspect.ismethod(obj):
self._inspect_function(obj, depth - 1)
elif hasattr(obj, "__dict__"):
self._inspect_object(obj, depth - 1)

self.indent_level -= 1
return False

def _inspect_function(self, func, depth):
closure = inspect.getclosurevars(func)
for name, value in closure.nonlocals.items():
self.inspect(value, f"{func.__name__}.{name}", depth)
for name, value in closure.globals.items():
self.inspect(value, f"global:{name}", depth)

def _inspect_object(self, obj, depth):
for name, value in inspect.getmembers(obj):
if not name.startswith("__"):
self.inspect(value, f"{type(obj).__name__}.{name}", depth)

def get_report(self) -> str:
summary = "\nSummary of Serialization Failures:\n"
if not self.failures:
summary += "All components are serializable.\n"
else:
for name, error in self.failures.items():
summary += f" - {name}: {error}\n"

return self.stream.getvalue() + summary


def inspect_serializability(
obj: Any,
name: Optional[str] = None,
depth: int = 5,
stream: Optional[TextIO] = None,
) -> Dict[str, Any]:
inspector = SerializabilityInspector(stream)
success = inspector.inspect(obj, name or type(obj).__name__, depth)
return {
"success": success,
"failures": inspector.failures,
"report": inspector.get_report(),
}
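A quick, runnable demonstration of the inspector on a deliberately unpicklable closure (only cloudpickle is required):

    import threading

    from dbgpt.util.serialization.check import check_serializable, inspect_serializability

    lock = threading.Lock()
    result = inspect_serializability(lambda: lock, name="locked_closure")
    print(result["success"])  # False: the captured lock cannot be pickled
    print(result["report"])   # indented inspection trace plus failure summary

    # check_serializable() wraps the same inspection and raises TypeError:
    # check_serializable(lambda: lock, "locked_closure")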