Skip to content

Commit

Permalink
Set the default value to an empty dictionary for namespace socket (#317)
Browse files Browse the repository at this point in the history
* update docstring of `update_nested_dict`
* When creating a task from a workgraph, the top level input of a task should be a namespace
* Fix aiida-shell version to 0.7.3
* Set the default value to an empty dictionary for namespace socket
  • Loading branch information
superstar54 authored and agoscinski committed Sep 19, 2024
1 parent 61a5b98 commit f689ef1
Show file tree
Hide file tree
Showing 6 changed files with 49 additions and 11 deletions.
14 changes: 12 additions & 2 deletions aiida_workgraph/decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -407,15 +407,25 @@ def build_task_from_workgraph(wg: any) -> Task:
# add all the inputs/outputs from the tasks in the workgraph
for task in wg.tasks:
# inputs
inputs.append({"identifier": "workgraph.any", "name": f"{task.name}"})
inputs.append(
{
"identifier": "workgraph.namespace",
"name": f"{task.name}",
}
)
for socket in task.inputs:
if socket.name == "_wait":
continue
inputs.append(
{"identifier": socket.identifier, "name": f"{task.name}.{socket.name}"}
)
# outputs
outputs.append({"identifier": "workgraph.any", "name": f"{task.name}"})
outputs.append(
{
"identifier": "workgraph.namespace",
"name": f"{task.name}",
}
)
for socket in task.outputs:
if socket.name in ["_wait", "_outputs"]:
continue
Expand Down
2 changes: 2 additions & 0 deletions aiida_workgraph/sockets/builtins.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ def __init__(
**kwargs: Any
) -> None:
super().__init__(name, node, type, index, uuid=uuid)
# Set the default value to an empty dictionary
kwargs.setdefault("default", {})
self.add_property("workgraph.any", name, **kwargs)


Expand Down
40 changes: 32 additions & 8 deletions aiida_workgraph/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,20 +149,44 @@ def get_nested_dict(d: Dict, name: str, **kwargs) -> Any:
return current


def update_nested_dict(d: Dict[str, Any], key: str, value: Any) -> None:
def update_nested_dict(d: Optional[Dict[str, Any]], key: str, value: Any) -> None:
"""
d = {}
key = "base.pw.parameters"
value = 2
will give:
d = {"base": {"pw": {"parameters": 2}}
Update or create a nested dictionary structure based on a dotted key path.
This function allows updating a nested dictionary or creating one if `d` is `None`.
Given a dictionary and a key path (e.g., "base.pw.parameters"), it will traverse
or create the necessary nested structure to set the provided value at the specified
key location. If intermediate dictionaries do not exist, they will be created.
If the resulting dictionary is empty, it is set to `None`.
Args:
d (Dict[str, Any] | None): The dictionary to update, which can be `None`.
If `None`, an empty dictionary will be created.
key (str): A dotted key path string representing the nested structure.
value (Any): The value to set at the specified key.
Example:
d = None
key = "base.pw.parameters"
value = 2
After running:
update_nested_dict(d, key, value)
The result will be:
d = {"base": {"pw": {"parameters": 2}}}
Edge Case:
If the resulting dictionary is empty after the update, it will be set to `None`.
"""
keys = key.split(".")
current = d
current = {} if current is None else current
current = d if d is not None else {}
for k in keys[:-1]:
current = current.setdefault(k, {})
current[keys[-1]] = value
# if current is empty, set it to None
if not current:
current = None
return current


def is_empty(value: Any) -> bool:
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ dependencies = [
"node-graph>=0.1.0",
"aiida-core>=2.3",
"cloudpickle",
"aiida-shell",
"aiida-shell==0.7.3",
"fastapi",
"uvicorn",
"pydantic_settings",
Expand Down
1 change: 1 addition & 0 deletions tests/test_socket.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ def test(a, b=1, **kwargs):
test1 = test.node()
assert test1.inputs["kwargs"].link_limit == 1e6
assert test1.inputs["kwargs"].identifier == "workgraph.namespace"
assert test1.inputs["kwargs"].property.value == {}


@pytest.mark.parametrize(
Expand Down
1 change: 1 addition & 0 deletions tests/test_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ def test_build_task_from_workgraph(
wg = WorkGraph("build_task_from_workgraph")
add1_task = wg.add_task(decorated_add, name="add1", x=1, y=3)
wg_task = wg.add_task(wg_calcfunction, name="wg_calcfunction")
assert wg_task.inputs["sumdiff1"].value == {}
wg.add_task(decorated_add, name="add2", y=3)
wg.add_link(add1_task.outputs["result"], wg_task.inputs["sumdiff1.x"])
wg.add_link(wg_task.outputs["sumdiff2.sum"], wg.tasks["add2"].inputs["x"])
Expand Down

0 comments on commit f689ef1

Please sign in to comment.