Skip to content

Commit 2d90519

Browse files
authored
refactor: Refactor for core SDK (#1092)
1 parent ba7248a commit 2d90519

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

45 files changed

+236
-133
lines changed

Makefile

+13
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,19 @@ clean: ## Clean up the environment
8181
find . -type d -name '.pytest_cache' -delete
8282
find . -type d -name '.coverage' -delete
8383

84+
.PHONY: clean-dist
85+
clean-dist: ## Clean up the distribution
86+
rm -rf dist/ *.egg-info build/
87+
88+
.PHONY: package
89+
package: clean-dist ## Package the project for distribution
90+
IS_DEV_MODE=false python setup.py sdist bdist_wheel
91+
92+
.PHONY: upload
93+
upload: package ## Upload the package to PyPI
94+
# upload to testpypi: twine upload --repository testpypi dist/*
95+
twine upload dist/*
96+
8497
.PHONY: help
8598
help: ## Display this help screen
8699
@echo "Available commands:"

dbgpt/__init__.py

+9-3
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,18 @@
1-
from dbgpt.component import BaseComponent, SystemApp
2-
3-
__ALL__ = ["SystemApp", "BaseComponent"]
1+
"""DB-GPT: Next Generation Data Interaction Solution with LLMs.
2+
"""
3+
from dbgpt import _version # noqa: E402
4+
from dbgpt.component import BaseComponent, SystemApp # noqa: F401
45

56
_CORE_LIBS = ["core", "rag", "model", "agent", "datasource", "vis", "storage", "train"]
67
_SERVE_LIBS = ["serve"]
78
_LIBS = _CORE_LIBS + _SERVE_LIBS
89

910

11+
__version__ = _version.version
12+
13+
__ALL__ = ["__version__", "SystemApp", "BaseComponent"]
14+
15+
1016
def __getattr__(name: str):
1117
# Lazy load
1218
import importlib

dbgpt/_version.py

+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
version = "0.4.7"

dbgpt/agent/agents/expand/retrieve_summary_assistant_agent.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -579,7 +579,7 @@ def _num_token_from_text(self, text: str, model: str = "gpt-3.5-turbo-0613"):
579579
from dbgpt.agent.agents.agent import AgentContext
580580
from dbgpt.agent.agents.user_proxy_agent import UserProxyAgent
581581
from dbgpt.core.interface.llm import ModelMetadata
582-
from dbgpt.model import OpenAILLMClient
582+
from dbgpt.model.proxy import OpenAILLMClient
583583

584584
llm_client = OpenAILLMClient()
585585
context: AgentContext = AgentContext(

dbgpt/app/chat_adapter.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from typing import Dict, List, Tuple
1010

1111
from dbgpt.core.interface.message import ModelMessage, ModelMessageRoleType
12-
from dbgpt.model.conversation import Conversation, get_conv_template
12+
from dbgpt.model.llm.conversation import Conversation, get_conv_template
1313

1414

1515
class BaseChatAdpter:
@@ -21,7 +21,7 @@ def match(self, model_path: str):
2121

2222
def get_generate_stream_func(self, model_path: str):
2323
"""Return the generate stream handler func"""
24-
from dbgpt.model.inference import generate_stream
24+
from dbgpt.model.llm.inference import generate_stream
2525

2626
return generate_stream
2727

dbgpt/app/scene/base_chat.py

+5-7
Original file line numberDiff line numberDiff line change
@@ -171,13 +171,13 @@ def llm_client(self) -> LLMClient:
171171

172172
async def call_llm_operator(self, request: ModelRequest) -> ModelOutput:
173173
llm_task = build_cached_chat_operator(self.llm_client, False, CFG.SYSTEM_APP)
174-
return await llm_task.call(call_data={"data": request})
174+
return await llm_task.call(call_data=request)
175175

176176
async def call_streaming_operator(
177177
self, request: ModelRequest
178178
) -> AsyncIterator[ModelOutput]:
179179
llm_task = build_cached_chat_operator(self.llm_client, True, CFG.SYSTEM_APP)
180-
async for out in await llm_task.call_stream(call_data={"data": request}):
180+
async for out in await llm_task.call_stream(call_data=request):
181181
yield out
182182

183183
def do_action(self, prompt_response):
@@ -251,11 +251,9 @@ async def _build_model_request(self) -> ModelRequest:
251251
str_history=self.prompt_template.str_history,
252252
request_context=req_ctx,
253253
)
254-
node_input = {
255-
"data": ChatComposerInput(
256-
messages=self.history_messages, prompt_dict=input_values
257-
)
258-
}
254+
node_input = ChatComposerInput(
255+
messages=self.history_messages, prompt_dict=input_values
256+
)
259257
# llm_messages = self.generate_llm_messages()
260258
model_request: ModelRequest = await node.call(call_data=node_input)
261259
model_request.context.cache_enable = self.model_cache_enable

dbgpt/app/scene/operator/app_operator.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ async def map(self, input_value: ChatComposerInput) -> ModelRequest:
8787
end_node: BaseOperator = self._sub_compose_dag.leaf_nodes[0]
8888
# Sub dag, use the same dag context in the parent dag
8989
messages = await end_node.call(
90-
call_data={"data": input_value}, dag_ctx=self.current_dag_context
90+
call_data=input_value, dag_ctx=self.current_dag_context
9191
)
9292
span_id = self._request_context.span_id
9393
model_request = ModelRequest.build_request(

dbgpt/component.py

+4
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,7 @@
1+
"""Component module for dbgpt.
2+
3+
Manages the lifecycle and registration of components.
4+
"""
15
from __future__ import annotations
26

37
import asyncio

dbgpt/core/awel/__init__.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
JoinOperator,
2323
MapOperator,
2424
ReduceStreamOperator,
25+
TriggerOperator,
2526
)
2627
from .operator.stream_operator import (
2728
StreamifyAbsOperator,
@@ -50,6 +51,7 @@
5051
"BaseOperator",
5152
"JoinOperator",
5253
"ReduceStreamOperator",
54+
"TriggerOperator",
5355
"MapOperator",
5456
"BranchOperator",
5557
"InputOperator",
@@ -150,4 +152,6 @@ def setup_dev_environment(
150152
for trigger in dag.trigger_nodes:
151153
trigger_manager.register_trigger(trigger)
152154
trigger_manager.after_register()
153-
uvicorn.run(app, host=host, port=port)
155+
if trigger_manager.keep_running():
156+
# Should keep running
157+
uvicorn.run(app, host=host, port=port)

dbgpt/core/awel/operator/base.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828

2929
F = TypeVar("F", bound=FunctionType)
3030

31-
CALL_DATA = Union[Dict, Dict[str, Dict]]
31+
CALL_DATA = Union[Dict[str, Any], Any]
3232

3333

3434
class WorkflowRunner(ABC, Generic[T]):
@@ -197,6 +197,8 @@ async def call(
197197
Returns:
198198
OUT: The output of the node after execution.
199199
"""
200+
if call_data:
201+
call_data = {"data": call_data}
200202
out_ctx = await self._runner.execute_workflow(
201203
self, call_data, exist_dag_ctx=dag_ctx
202204
)
@@ -242,6 +244,8 @@ async def call_stream(
242244
Returns:
243245
AsyncIterator[OUT]: An asynchronous iterator over the output stream.
244246
"""
247+
if call_data:
248+
call_data = {"data": call_data}
245249
out_ctx = await self._runner.execute_workflow(
246250
self, call_data, streaming_call=True, exist_dag_ctx=dag_ctx
247251
)

dbgpt/core/awel/task/base.py

+8
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,14 @@ def __bool__(self):
2828
SKIP_DATA = _EMPTY_DATA_TYPE()
2929
PLACEHOLDER_DATA = _EMPTY_DATA_TYPE()
3030

31+
32+
def is_empty_data(data: Any):
33+
"""Check if the data is empty."""
34+
if isinstance(data, _EMPTY_DATA_TYPE):
35+
return data in (EMPTY_DATA, SKIP_DATA)
36+
return False
37+
38+
3139
MapFunc = Union[Callable[[IN], OUT], Callable[[IN], Awaitable[OUT]]]
3240
ReduceFunc = Union[Callable[[IN], OUT], Callable[[IN], Awaitable[OUT]]]
3341
StreamFunc = Callable[[IN], Awaitable[AsyncIterator[OUT]]]

dbgpt/core/awel/task/task_impl.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@
2424
EMPTY_DATA,
2525
OUT,
2626
PLACEHOLDER_DATA,
27-
SKIP_DATA,
2827
InputContext,
2928
InputSource,
3029
MapFunc,
@@ -37,6 +36,7 @@
3736
TaskState,
3837
TransformFunc,
3938
UnStreamFunc,
39+
is_empty_data,
4040
)
4141

4242
logger = logging.getLogger(__name__)
@@ -99,7 +99,7 @@ def new_output(self) -> TaskOutput[T]:
9999
@property
100100
def is_empty(self) -> bool:
101101
"""Return True if the output data is empty."""
102-
return self._data == EMPTY_DATA or self._data == SKIP_DATA
102+
return is_empty_data(self._data)
103103

104104
@property
105105
def is_none(self) -> bool:
@@ -171,7 +171,7 @@ def is_stream(self) -> bool:
171171
@property
172172
def is_empty(self) -> bool:
173173
"""Return True if the output data is empty."""
174-
return self._data == EMPTY_DATA or self._data == SKIP_DATA
174+
return is_empty_data(self._data)
175175

176176
@property
177177
def is_none(self) -> bool:
@@ -330,7 +330,7 @@ def _read_data(self, task_ctx: TaskContext) -> Any:
330330
"""
331331
call_data = task_ctx.call_data
332332
data = call_data.get("data", EMPTY_DATA) if call_data else EMPTY_DATA
333-
if data == EMPTY_DATA:
333+
if is_empty_data(data):
334334
raise ValueError("No call data for current SimpleCallDataInputSource")
335335
return data
336336

dbgpt/core/awel/trigger/http_trigger.py

+11-11
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,8 @@
11
"""Http trigger for AWEL."""
2-
from __future__ import annotations
3-
42
import logging
53
from enum import Enum
64
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Type, Union, cast
75

8-
from starlette.requests import Request
9-
106
from dbgpt._private.pydantic import BaseModel
117

128
from ..dag.base import DAG
@@ -15,9 +11,10 @@
1511

1612
if TYPE_CHECKING:
1713
from fastapi import APIRouter
14+
from starlette.requests import Request
1815

19-
RequestBody = Union[Type[Request], Type[BaseModel], Type[str]]
20-
StreamingPredictFunc = Callable[[Union[Request, BaseModel]], bool]
16+
RequestBody = Union[Type[Request], Type[BaseModel], Type[str]]
17+
StreamingPredictFunc = Callable[[Union[Request, BaseModel]], bool]
2118

2219
logger = logging.getLogger(__name__)
2320

@@ -32,9 +29,9 @@ def __init__(
3229
self,
3330
endpoint: str,
3431
methods: Optional[Union[str, List[str]]] = "GET",
35-
request_body: Optional[RequestBody] = None,
32+
request_body: Optional["RequestBody"] = None,
3633
streaming_response: bool = False,
37-
streaming_predict_func: Optional[StreamingPredictFunc] = None,
34+
streaming_predict_func: Optional["StreamingPredictFunc"] = None,
3835
response_model: Optional[Type] = None,
3936
response_headers: Optional[Dict[str, str]] = None,
4037
response_media_type: Optional[str] = None,
@@ -69,6 +66,7 @@ def mount_to_router(self, router: "APIRouter") -> None:
6966
router (APIRouter): The router to mount the trigger.
7067
"""
7168
from fastapi import Depends
69+
from starlette.requests import Request
7270

7371
methods = [self._methods] if isinstance(self._methods, str) else self._methods
7472

@@ -114,8 +112,10 @@ async def route_function(body=Depends(_request_body_dependency)):
114112

115113

116114
async def _parse_request_body(
117-
request: Request, request_body_cls: Optional[RequestBody]
115+
request: "Request", request_body_cls: Optional["RequestBody"]
118116
):
117+
from starlette.requests import Request
118+
119119
if not request_body_cls:
120120
return None
121121
if request_body_cls == Request:
@@ -152,7 +152,7 @@ async def _trigger_dag(
152152
raise ValueError("HttpTrigger just support one leaf node in dag")
153153
end_node = cast(BaseOperator, leaf_nodes[0])
154154
if not streaming_response:
155-
return await end_node.call(call_data={"data": body})
155+
return await end_node.call(call_data=body)
156156
else:
157157
headers = response_headers
158158
media_type = response_media_type if response_media_type else "text/event-stream"
@@ -163,7 +163,7 @@ async def _trigger_dag(
163163
"Connection": "keep-alive",
164164
"Transfer-Encoding": "chunked",
165165
}
166-
generator = await end_node.call_stream(call_data={"data": body})
166+
generator = await end_node.call_stream(call_data=body)
167167
background_tasks = BackgroundTasks()
168168
background_tasks.add_task(dag._after_dag_end)
169169
return StreamingResponse(

dbgpt/core/awel/trigger/trigger_manager.py

+26
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,14 @@ class TriggerManager(ABC):
2424
def register_trigger(self, trigger: Any) -> None:
2525
"""Register a trigger to current manager."""
2626

27+
def keep_running(self) -> bool:
28+
"""Whether keep running.
29+
30+
Returns:
31+
bool: Whether keep running, True means keep running, False means stop.
32+
"""
33+
return False
34+
2735

2836
class HttpTriggerManager(TriggerManager):
2937
"""Http trigger manager.
@@ -64,6 +72,8 @@ def register_trigger(self, trigger: Any) -> None:
6472
self._trigger_map[trigger_id] = trigger
6573

6674
def _init_app(self, system_app: SystemApp):
75+
if not self.keep_running():
76+
return
6777
logger.info(
6878
f"Include router {self._router} to prefix path {self._router_prefix}"
6979
)
@@ -72,6 +82,14 @@ def _init_app(self, system_app: SystemApp):
7282
raise RuntimeError("System app not initialized")
7383
app.include_router(self._router, prefix=self._router_prefix, tags=["AWEL"])
7484

85+
def keep_running(self) -> bool:
86+
"""Whether keep running.
87+
88+
Returns:
89+
bool: Whether keep running, True means keep running, False means stop.
90+
"""
91+
return len(self._trigger_map) > 0
92+
7593

7694
class DefaultTriggerManager(TriggerManager, BaseComponent):
7795
"""Default trigger manager for AWEL.
@@ -105,3 +123,11 @@ def after_register(self) -> None:
105123
"""After register, init the trigger manager."""
106124
if self.system_app:
107125
self.http_trigger._init_app(self.system_app)
126+
127+
def keep_running(self) -> bool:
128+
"""Whether keep running.
129+
130+
Returns:
131+
bool: Whether keep running, True means keep running, False means stop.
132+
"""
133+
return self.http_trigger.keep_running()

dbgpt/core/interface/operator/composer_operator.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ async def map(self, input_value: ChatComposerInput) -> ModelRequest:
7070
end_node: BaseOperator = cast(BaseOperator, self._sub_compose_dag.leaf_nodes[0])
7171
# Sub dag, use the same dag context in the parent dag
7272
return await end_node.call(
73-
call_data={"data": input_value}, dag_ctx=self.current_dag_context
73+
call_data=input_value, dag_ctx=self.current_dag_context
7474
)
7575

7676
def _build_composer_dag(self) -> DAG:

dbgpt/core/interface/operator/prompt_operator.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,7 @@ class PromptBuilderOperator(
150150
)
151151
)
152152
153-
single_input = {"data": {"dialect": "mysql"}}
153+
single_input = {"dialect": "mysql"}
154154
single_expected_messages = [
155155
ModelMessage(
156156
content="Please write a mysql SQL count the length of a field",

dbgpt/model/__init__.py

+10-7
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,12 @@
1-
from dbgpt.model.cluster.client import DefaultLLMClient
1+
try:
2+
from dbgpt.model.cluster.client import DefaultLLMClient
3+
except ImportError as exc:
4+
# logging.warning("Can't import dbgpt.model.DefaultLLMClient")
5+
DefaultLLMClient = None
26

3-
# from dbgpt.model.utils.chatgpt_utils import OpenAILLMClient
4-
from dbgpt.model.proxy.llms.chatgpt import OpenAILLMClient
57

6-
__ALL__ = [
7-
"DefaultLLMClient",
8-
"OpenAILLMClient",
9-
]
8+
_exports = []
9+
if DefaultLLMClient:
10+
_exports.append("DefaultLLMClient")
11+
12+
__ALL__ = _exports

0 commit comments

Comments
 (0)