chore: update api #1972

Merged · 2 commits · Sep 4, 2024
5 changes: 5 additions & 0 deletions .dockerignore
@@ -1,10 +1,15 @@
.env
.git/
./.mypy_cache/
models/
plugins/
pilot/data
pilot/message
logs/
venv/
web/node_modules/
web/.next/
web/.env
docs/node_modules/
build/
docs/build/
4 changes: 4 additions & 0 deletions dbgpt/_private/config.py
@@ -332,6 +332,10 @@ def __init__(self) -> None:
os.getenv("MULTI_INSTANCE", "False").lower() == "true"
)

self.SCHEDULER_ENABLED = (
os.getenv("SCHEDULER_ENABLED", "True").lower() == "true"
)

@property
def local_db_manager(self) -> "ConnectorManager":
from dbgpt.datasource.manages import ConnectorManager
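For context, a minimal sketch of exercising the new option (the `Config` import path comes from this file; everything else is illustrative): the flag is read from the `SCHEDULER_ENABLED` environment variable and defaults to true.

```python
# Illustrative only: toggle the new flag through the environment before the
# Config object is created. The import path is taken from this diff.
import os

os.environ["SCHEDULER_ENABLED"] = "False"  # unset or "True" keeps the default

from dbgpt._private.config import Config

cfg = Config()
assert cfg.SCHEDULER_ENABLED is False
```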
4 changes: 3 additions & 1 deletion dbgpt/app/initialization/scheduler.py
@@ -19,12 +19,14 @@ def __init__(
system_app: SystemApp,
scheduler_delay_ms: int = 5000,
scheduler_interval_ms: int = 1000,
scheduler_enable: bool = True,
):
super().__init__(system_app)
self.system_app = system_app
self._scheduler_interval_ms = scheduler_interval_ms
self._scheduler_delay_ms = scheduler_delay_ms
self._stop_event = threading.Event()
self._scheduler_enable = scheduler_enable

def init_app(self, system_app: SystemApp):
self.system_app = system_app
@@ -39,7 +41,7 @@ def before_stop(self):

def _scheduler(self):
time.sleep(self._scheduler_delay_ms / 1000)
while not self._stop_event.is_set():
while self._scheduler_enable and not self._stop_event.is_set():
try:
schedule.run_pending()
except Exception as e:
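A hedged sketch of the new constructor parameter (the class name `DefaultScheduler` and the bare `SystemApp()` construction are assumptions, not shown in this diff): passing `scheduler_enable=False` makes the background thread skip the `run_pending` loop entirely.

```python
# Assumed names: DefaultScheduler (this module's component) and SystemApp().
from dbgpt.app.initialization.scheduler import DefaultScheduler
from dbgpt.component import SystemApp

system_app = SystemApp()
scheduler = DefaultScheduler(system_app, scheduler_enable=False)
# With scheduler_enable=False, the _scheduler loop condition is False, so the
# thread returns right after the initial delay instead of polling every second.
```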
37 changes: 37 additions & 0 deletions dbgpt/core/awel/dag/base.py
@@ -145,6 +145,9 @@ class DAGVar:
_executor: Optional[Executor] = None

_variables_provider: Optional["VariablesProvider"] = None
# Whether to check serializability for AWEL; it will be set to True when
# running an AWEL operator in a remote environment
_check_serializable: Optional[bool] = None

@classmethod
def enter_dag(cls, dag) -> None:
@@ -257,6 +260,24 @@ def set_variables_provider(cls, variables_provider: "VariablesProvider") -> None
"""
cls._variables_provider = variables_provider

@classmethod
def get_check_serializable(cls) -> Optional[bool]:
"""Get the check serializable flag.

Returns:
Optional[bool]: The check serializable flag
"""
return cls._check_serializable

@classmethod
def set_check_serializable(cls, check_serializable: bool) -> None:
"""Set the check serializable flag.

Args:
check_serializable (bool): The check serializable flag to set
"""
cls._check_serializable = check_serializable


class DAGLifecycle:
"""The lifecycle of DAG."""
@@ -286,6 +307,7 @@ def __init__(
node_name: Optional[str] = None,
system_app: Optional[SystemApp] = None,
executor: Optional[Executor] = None,
check_serializable: Optional[bool] = None,
**kwargs,
) -> None:
"""Initialize a DAGNode.
@@ -311,6 +333,7 @@ def __init__(
node_id = self._dag._new_node_id()
self._node_id: Optional[str] = node_id
self._node_name: Optional[str] = node_name
self._check_serializable = check_serializable
if self._dag:
self._dag._append_node(self)

@@ -486,6 +509,20 @@ def __str__(self):
"""Return the string of current DAGNode."""
return self.__repr__()

@classmethod
def _do_check_serializable(cls, obj: Any, obj_name: str = "Object"):
"""Check whether the current DAGNode is serializable."""
from dbgpt.util.serialization.check import check_serializable

check_serializable(obj, obj_name)

@property
def check_serializable(self) -> bool:
"""Whether check serializable for current DAGNode."""
if self._check_serializable is not None:
return self._check_serializable or False
return DAGVar.get_check_serializable() or False


def _build_task_key(task_name: str, key: str) -> str:
return f"{task_name}___$$$$$$___{key}"
19 changes: 18 additions & 1 deletion dbgpt/core/awel/operators/base.py
@@ -193,12 +193,29 @@ def __init__(
self.incremental_output = bool(kwargs["incremental_output"])
if "output_format" in kwargs:
self.output_format = kwargs["output_format"]

self._runner: WorkflowRunner = runner
self._dag_ctx: Optional[DAGContext] = None
self._can_skip_in_branch = can_skip_in_branch
self._variables_provider = variables_provider

def __getstate__(self):
"""Customize the pickling process."""
state = self.__dict__.copy()
if "_runner" in state:
del state["_runner"]
if "_executor" in state:
del state["_executor"]
if "_system_app" in state:
del state["_system_app"]
return state

def __setstate__(self, state):
"""Customize the unpickling process."""
self.__dict__.update(state)
self._runner = default_runner
self._system_app = DAGVar.get_current_system_app()
self._executor = DAGVar.get_executor()

@property
def current_dag_context(self) -> DAGContext:
"""Return the current DAG context."""
17 changes: 17 additions & 0 deletions dbgpt/core/awel/operators/common_operator.py
@@ -41,6 +41,12 @@ def __init__(
super().__init__(can_skip_in_branch=can_skip_in_branch, **kwargs)
if not callable(combine_function):
raise ValueError("combine_function must be callable")

if self.check_serializable:
super()._do_check_serializable(
combine_function,
f"JoinOperator: {self}, combine_function: {combine_function}",
)
self.combine_function = combine_function

async def _do_run(self, dag_ctx: DAGContext) -> TaskOutput[OUT]:
@@ -83,6 +89,11 @@ def __init__(self, reduce_function: Optional[ReduceFunc] = None, **kwargs):
super().__init__(**kwargs)
if reduce_function and not callable(reduce_function):
raise ValueError("reduce_function must be callable")
if reduce_function and self.check_serializable:
super()._do_check_serializable(
reduce_function, f"Operator: {self}, reduce_function: {reduce_function}"
)

self.reduce_function = reduce_function

async def _do_run(self, dag_ctx: DAGContext) -> TaskOutput[OUT]:
@@ -133,6 +144,12 @@ def __init__(self, map_function: Optional[MapFunc] = None, **kwargs):
super().__init__(**kwargs)
if map_function and not callable(map_function):
raise ValueError("map_function must be callable")

if map_function and self.check_serializable:
super()._do_check_serializable(
map_function, f"Operator: {self}, map_function: {map_function}"
)

self.map_function = map_function

async def _do_run(self, dag_ctx: DAGContext) -> TaskOutput[OUT]:
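And a sketch of the failure path (illustrative, not from this diff): with the global flag enabled, an operator whose function captures an unpicklable object is rejected at construction time.

```python
import threading

from dbgpt.core.awel import MapOperator  # assumed public export
from dbgpt.core.awel.dag.base import DAGVar

DAGVar.set_check_serializable(True)
lock = threading.Lock()

def guarded(x):
    with lock:  # the captured Lock cannot be pickled by cloudpickle
        return x

try:
    MapOperator(map_function=guarded)
except TypeError as err:
    print(err)  # carries the serializability report from check_serializable()
```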
11 changes: 11 additions & 0 deletions dbgpt/model/proxy/base.py
@@ -94,6 +94,17 @@ def __init__(
self.executor = executor or ThreadPoolExecutor()
self.proxy_tokenizer = proxy_tokenizer or TiktokenProxyTokenizer()

def __getstate__(self):
"""Customize the serialization of the object"""
state = self.__dict__.copy()
state.pop("executor")
return state

def __setstate__(self, state):
"""Customize the deserialization of the object"""
self.__dict__.update(state)
self.executor = ThreadPoolExecutor()

@classmethod
@abstractmethod
def new_client(
2 changes: 1 addition & 1 deletion dbgpt/serve/flow/api/variables_provider.py
@@ -341,7 +341,7 @@ def get_variables(
StorageVariables(
key=key,
name=agent["name"],
label=agent["desc"],
label=agent["name"],
value=agent["name"],
scope=scope,
scope_key=scope_key,
3 changes: 3 additions & 0 deletions dbgpt/storage/metadata/_base_dao.py
@@ -285,6 +285,9 @@ def _create_query_object(
else model_to_dict(query_request)
)
for key, value in query_dict.items():
if value and isinstance(value, (list, tuple, dict, set)):
# Skip collection values (list, tuple, dict, set)
continue
if value is not None and hasattr(model_cls, key):
if isinstance(value, list):
if len(value) > 0:
32 changes: 32 additions & 0 deletions dbgpt/util/net_utils.py
@@ -1,5 +1,6 @@
import errno
import socket
from typing import Set, Tuple


def _get_ip_address(address: str = "10.254.254.254:1") -> str:
@@ -22,3 +23,34 @@ def _get_ip_address(address: str = "10.254.254.254:1") -> str:
finally:
s.close()
return curr_address


async def _async_get_free_port(
port_range: Tuple[int, int], timeout: int, used_ports: Set[int]
):
import asyncio

loop = asyncio.get_running_loop()
return await loop.run_in_executor(
None, _get_free_port, port_range, timeout, used_ports
)


def _get_free_port(port_range: Tuple[int, int], timeout: int, used_ports: Set[int]):
import random

available_ports = set(range(port_range[0], port_range[1] + 1)) - used_ports
if not available_ports:
raise RuntimeError("No available ports in the specified range")

while available_ports:
port = random.choice(list(available_ports))
try:
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.bind(("", port))
used_ports.add(port)
return port
except OSError:
available_ports.remove(port)

raise RuntimeError("No available ports in the specified range")
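A brief usage sketch for these internal helpers: pick an unused port from a range while tracking ports already handed out in this process (note that the timeout argument is accepted but not used by the current implementation).

```python
from dbgpt.util.net_utils import _get_free_port

used_ports: set = set()
first = _get_free_port((8000, 8100), timeout=5, used_ports=used_ports)
second = _get_free_port((8000, 8100), timeout=5, used_ports=used_ports)
assert first != second  # previously returned ports are excluded
```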
85 changes: 85 additions & 0 deletions dbgpt/util/serialization/check.py
@@ -0,0 +1,85 @@
import inspect
from io import StringIO
from typing import Any, Dict, Optional, TextIO

import cloudpickle


def check_serializable(
obj: Any, obj_name: str = "Object", error_msg: str = "Object is not serializable"
):
try:
cloudpickle.dumps(obj)
except Exception as e:
inspect_info = inspect_serializability(obj, obj_name)
msg = f"{error_msg}\n{inspect_info['report']}"
raise TypeError(msg) from e


class SerializabilityInspector:
def __init__(self, stream: Optional[TextIO] = None):
self.stream = stream or StringIO()
self.failures = {}
self.indent_level = 0

def log(self, message: str):
indent = " " * self.indent_level
self.stream.write(f"{indent}{message}\n")

def inspect(self, obj: Any, name: str, depth: int = 3) -> bool:
self.log(f"Inspecting '{name}'")
self.indent_level += 1

try:
cloudpickle.dumps(obj)
self.indent_level -= 1
return True
except Exception as e:
self.failures[name] = str(e)
self.log(f"Failure: {str(e)}")

if depth > 0:
if inspect.isfunction(obj) or inspect.ismethod(obj):
self._inspect_function(obj, depth - 1)
elif hasattr(obj, "__dict__"):
self._inspect_object(obj, depth - 1)

self.indent_level -= 1
return False

def _inspect_function(self, func, depth):
closure = inspect.getclosurevars(func)
for name, value in closure.nonlocals.items():
self.inspect(value, f"{func.__name__}.{name}", depth)
for name, value in closure.globals.items():
self.inspect(value, f"global:{name}", depth)

def _inspect_object(self, obj, depth):
for name, value in inspect.getmembers(obj):
if not name.startswith("__"):
self.inspect(value, f"{type(obj).__name__}.{name}", depth)

def get_report(self) -> str:
summary = "\nSummary of Serialization Failures:\n"
if not self.failures:
summary += "All components are serializable.\n"
else:
for name, error in self.failures.items():
summary += f" - {name}: {error}\n"

return self.stream.getvalue() + summary


def inspect_serializability(
obj: Any,
name: Optional[str] = None,
depth: int = 5,
stream: Optional[TextIO] = None,
) -> Dict[str, Any]:
inspector = SerializabilityInspector(stream)
success = inspector.inspect(obj, name or type(obj).__name__, depth)
return {
"success": success,
"failures": inspector.failures,
"report": inspector.get_report(),
}
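A short sketch of using the new inspector directly (illustrative object): an instance holding a thread lock fails to pickle, and the report points at the offending member.

```python
import threading

from dbgpt.util.serialization.check import inspect_serializability

class Holder:
    def __init__(self):
        self.lock = threading.Lock()  # not serializable by cloudpickle

result = inspect_serializability(Holder(), name="holder")
print(result["success"])  # False
print(result["report"])   # per-member failures plus the summary section
```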