Skip to content

Commit

Permalink
feat(core): Support complex variables parsing
Browse files Browse the repository at this point in the history
  • Loading branch information
fangyinc committed Aug 11, 2024
1 parent 3b4b34c commit a348e6c
Show file tree
Hide file tree
Showing 8 changed files with 827 additions and 93 deletions.
27 changes: 20 additions & 7 deletions dbgpt/core/awel/flow/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -380,27 +380,40 @@ def pre_fill(cls, values: Dict[str, Any]) -> Dict[str, Any]:
return values

@classmethod
def _covert_to_real_type(cls, type_cls: str, v: Any):
def _covert_to_real_type(cls, type_cls: str, v: Any) -> Any:
if type_cls and v is not None:
typed_value: Any = v
try:
# Try to convert the value to the type.
if type_cls == "builtins.str":
return str(v)
typed_value = str(v)
elif type_cls == "builtins.int":
return int(v)
typed_value = int(v)
elif type_cls == "builtins.float":
return float(v)
typed_value = float(v)
elif type_cls == "builtins.bool":
if str(v).lower() in ["false", "0", "", "no", "off"]:
return False
return bool(v)
typed_value = bool(v)
return typed_value
except ValueError:
raise ValidationError(f"Value '{v}' is not valid for type {type_cls}")
return v

def get_typed_value(self) -> Any:
"""Get the typed value."""
return self._covert_to_real_type(self.type_cls, self.value)
"""Get the typed value.
Returns:
Any: The typed value. VariablesPlaceHolder if the value is a variable
string. Otherwise, the real type value.
"""
from ...interface.variables import VariablesPlaceHolder, is_variable_string

is_variables = is_variable_string(self.value) if self.value else False
if is_variables and self.value is not None and isinstance(self.value, str):
return VariablesPlaceHolder(self.name, self.value)
else:
return self._covert_to_real_type(self.type_cls, self.value)

def get_typed_default(self) -> Any:
"""Get the typed default."""
Expand Down
223 changes: 223 additions & 0 deletions dbgpt/core/awel/flow/tests/test_flow_variables.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,223 @@
import json
from typing import cast

import pytest

from dbgpt.core.awel import BaseOperator, DAGVar, MapOperator
from dbgpt.core.awel.flow import (
IOField,
OperatorCategory,
Parameter,
VariablesDynamicOptions,
ViewMetadata,
ui,
)
from dbgpt.core.awel.flow.flow_factory import FlowData, FlowFactory, FlowPanel

from ...tests.conftest import variables_provider


class MyVariablesOperator(MapOperator[str, str]):
metadata = ViewMetadata(
label="My Test Variables Operator",
name="my_test_variables_operator",
category=OperatorCategory.EXAMPLE,
description="An example flow operator that includes a variables option.",
parameters=[
Parameter.build_from(
"OpenAI API Key",
"openai_api_key",
type=str,
placeholder="Please select the OpenAI API key",
description="The OpenAI API key to use.",
options=VariablesDynamicOptions(),
ui=ui.UIPasswordInput(
key="dbgpt.model.openai.api_key",
),
),
Parameter.build_from(
"Model",
"model",
type=str,
placeholder="Please select the model",
description="The model to use.",
options=VariablesDynamicOptions(),
ui=ui.UIVariablesInput(
key="dbgpt.model.openai.model",
),
),
],
inputs=[
IOField.build_from(
"User Name",
"user_name",
str,
description="The name of the user.",
),
],
outputs=[
IOField.build_from(
"Model info",
"model",
str,
description="The model info.",
),
],
)

def __init__(self, openai_api_key: str, model: str, **kwargs):
super().__init__(**kwargs)
self._openai_api_key = openai_api_key
self._model = model

async def map(self, user_name: str) -> str:
dict_dict = {
"openai_api_key": self._openai_api_key,
"model": self._model,
}
json_data = json.dumps(dict_dict, ensure_ascii=False)
return "Your name is %s, and your model info is %s." % (user_name, json_data)


class EndOperator(MapOperator[str, str]):
metadata = ViewMetadata(
label="End Operator",
name="end_operator",
category=OperatorCategory.EXAMPLE,
description="An example flow operator that ends the flow.",
parameters=[],
inputs=[
IOField.build_from(
"Input",
"input",
str,
description="The input to the end operator.",
),
],
outputs=[
IOField.build_from(
"Output",
"output",
str,
description="The output of the end operator.",
),
],
)

async def map(self, input: str) -> str:
return f"End operator received input: {input}"


@pytest.fixture
def json_flow():
operators = [MyVariablesOperator, EndOperator]
metadata_list = [operator.metadata.to_dict() for operator in operators]
node_names = {}
name_to_parameters_dict = {
"my_test_variables_operator": {
"openai_api_key": "${dbgpt.model.openai.api_key:my_key@global}",
"model": "${dbgpt.model.openai.model:default_model@global}",
}
}
name_to_metadata_dict = {metadata["name"]: metadata for metadata in metadata_list}
ui_nodes = []
for metadata in metadata_list:
type_name = metadata["type_name"]
name = metadata["name"]
id = metadata["id"]
if type_name in node_names:
raise ValueError(f"Duplicate node type name: {type_name}")
# Replace id to flow data id.
metadata["id"] = f"{id}_0"
parameters = metadata["parameters"]
parameters_dict = name_to_parameters_dict.get(name, {})
for parameter in parameters:
parameter_name = parameter["name"]
if parameter_name in parameters_dict:
parameter["value"] = parameters_dict[parameter_name]
ui_nodes.append(
{
"width": 288,
"height": 352,
"id": metadata["id"],
"position": {
"x": -149.98120112708142,
"y": 666.9468497341901,
"zoom": 0.0,
},
"type": "customNode",
"position_absolute": {
"x": -149.98120112708142,
"y": 666.9468497341901,
"zoom": 0.0,
},
"data": metadata,
}
)

ui_edges = []
source_id = name_to_metadata_dict["my_test_variables_operator"]["id"]
target_id = name_to_metadata_dict["end_operator"]["id"]
ui_edges.append(
{
"source": source_id,
"target": target_id,
"source_order": 0,
"target_order": 0,
"id": f"{source_id}|{target_id}",
"source_handle": f"{source_id}|outputs|0",
"target_handle": f"{target_id}|inputs|0",
"type": "buttonedge",
}
)
return {
"nodes": ui_nodes,
"edges": ui_edges,
"viewport": {
"x": 509.2191773722104,
"y": -66.11286175905718,
"zoom": 1.252741002590748,
},
}


@pytest.mark.asyncio
@pytest.mark.parametrize(
"variables_provider",
[
(
{
"vars": {
"openai_api_key": {
"key": "${dbgpt.model.openai.api_key:my_key@global}",
"value": "my_openai_api_key",
"value_type": "str",
"category": "secret",
},
"model": {
"key": "${dbgpt.model.openai.model:default_model@global}",
"value": "GPT-4o",
"value_type": "str",
},
}
}
),
],
indirect=["variables_provider"],
)
async def test_build_flow(json_flow, variables_provider):
DAGVar.set_variables_provider(variables_provider)
flow_data = FlowData(**json_flow)
flow_panel = FlowPanel(
label="My Test Flow", name="my_test_flow", flow_data=flow_data, state="deployed"
)
factory = FlowFactory()
dag = factory.build(flow_panel)

leaf_node: BaseOperator = cast(BaseOperator, dag.leaf_nodes[0])
result = await leaf_node.call("Alice")
assert (
result
== "End operator received input: Your name is Alice, and your model info is "
'{"openai_api_key": "my_openai_api_key", "model": "GPT-4o"}.'
)
2 changes: 1 addition & 1 deletion dbgpt/core/awel/flow/ui.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
"time_picker",
"tree_select",
"upload",
"variable",
"variables",
"password",
"code_editor",
]
Expand Down
1 change: 1 addition & 0 deletions dbgpt/core/awel/operators/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -380,6 +380,7 @@ async def _resolve_variables(self, _: DAGContext):

if not self._variables_provider:
return
# TODO: Resolve variables parallel
for attr, value in self.__dict__.items():
if isinstance(value, VariablesPlaceHolder):
resolved_value = await self.blocking_func_to_async(
Expand Down
43 changes: 35 additions & 8 deletions dbgpt/core/awel/tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,15 @@
from contextlib import asynccontextmanager, contextmanager
from contextlib import asynccontextmanager
from typing import AsyncIterator, List

import pytest
import pytest_asyncio

from .. import (
DAGContext,
DefaultWorkflowRunner,
InputOperator,
SimpleInputSource,
TaskState,
WorkflowRunner,
from ...interface.variables import (
StorageVariables,
StorageVariablesProvider,
VariablesIdentifier,
)
from .. import DefaultWorkflowRunner, InputOperator, SimpleInputSource
from ..task.task_impl import _is_async_iterator


Expand Down Expand Up @@ -102,3 +100,32 @@ async def stream_input_nodes(request):
param["is_stream"] = True
async with _create_input_node(**param) as input_nodes:
yield input_nodes


@asynccontextmanager
async def _create_variables(**kwargs):
vp = StorageVariablesProvider()
vars = kwargs.get("vars")
if vars and isinstance(vars, dict):
for param_key, param_var in vars.items():
key = param_var.get("key")
value = param_var.get("value")
value_type = param_var.get("value_type")
category = param_var.get("category", "common")
id = VariablesIdentifier.from_str_identifier(key)
vp.save(
StorageVariables.from_identifier(
id, value, value_type, label="", category=category
)
)
else:
raise ValueError("vars is required.")

yield vp


@pytest_asyncio.fixture
async def variables_provider(request):
param = getattr(request, "param", {})
async with _create_variables(**param) as vp:
yield vp
8 changes: 4 additions & 4 deletions dbgpt/core/awel/tests/test_dag_variables.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ async def _create_variables(**kwargs):
id, value, value_type, label="", category=category
)
)
variables[param_key] = VariablesPlaceHolder(param_key, key, value_type)
variables[param_key] = VariablesPlaceHolder(param_key, key)
else:
raise ValueError("vars is required.")

Expand Down Expand Up @@ -85,17 +85,17 @@ async def test_default_dag(default_dag: DAG):
{
"vars": {
"int_var": {
"key": "int_key@my_int_var@global",
"key": "${int_key:my_int_var@global}",
"value": 0,
"value_type": "int",
},
"str_var": {
"key": "str_key@my_str_var@global",
"key": "${str_key:my_str_var@global}",
"value": "1",
"value_type": "str",
},
"secret": {
"key": "secret_key@my_secret_var@global",
"key": "${secret_key:my_secret_var@global}",
"value": "2131sdsdf",
"value_type": "str",
"category": "secret",
Expand Down
Loading

0 comments on commit a348e6c

Please sign in to comment.