Skip to content

Commit

Permalink
Fix #175
Browse files Browse the repository at this point in the history
The previous code has the assumption that all type hints which have an __origin__ attribute also possess an _name attribute. This is not true for certain built-in generic types like list, dict, etc. in Python, particularly when accessed via the __origin__ attribute, which is a legacy of Python's type hinting evolution.
  • Loading branch information
superstar54 committed Jul 17, 2024
1 parent 24c9239 commit e49393c
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 34 deletions.
6 changes: 4 additions & 2 deletions aiida_workgraph/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -535,8 +535,10 @@ def add_imports(type_hint):
type_hint, "__origin__"
): # This checks for higher-order types like List, Dict
module_name = type_hint.__module__
type_name = type_hint._name
for arg in type_hint.__args__:
type_name = (
getattr(type_hint, "_name", None) or type_hint.__origin__.__name__
)
for arg in getattr(type_hint, "__args__", []):
if arg is type(None): # noqa: E721
continue
add_imports(arg) # Recursively add imports for each argument
Expand Down
53 changes: 21 additions & 32 deletions tests/test_python.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,41 +66,30 @@ def add(x, **kwargs):
assert wg.tasks["add"].outputs["result"].value.value == 6


def test_PythonJob_typing(fixture_localhost):
def test_PythonJob_typing():
"""Test function with typing."""
from numpy import array

def add(x: array, y: array) -> array:
return x + y

def multiply(x: Any, y: Any) -> Any:
return x * y

wg = WorkGraph("test_PythonJob")
wg.add_task("PythonJob", function=add, name="add")
wg.add_task(
"PythonJob", function=multiply, name="multiply", x=wg.tasks["add"].outputs[0]
)
#
metadata = {
"options": {
"custom_scheduler_commands": "# test",
# "custom_scheduler_commands": 'module load anaconda\nconda activate py3.11\n',
}
from ase import Atoms
from aiida_workgraph.utils import get_required_imports
from typing import List

def generate_structures(
structures: List[Atoms],
strain_lst: list,
data: array,
strain_lst1: list = None,
data1: array = None,
structure1: Atoms = None,
) -> list[Atoms]:
pass

modules = get_required_imports(generate_structures)
assert modules == {
"ase.atoms": {"Atoms"},
"typing": {"List"},
"builtins": {"list"},
"numpy": {"array"},
}
wg.run(
inputs={
"add": {
"x": array([1, 2]),
"y": array([2, 3]),
"computer": "localhost",
"metadata": metadata,
},
"multiply": {"y": 4, "computer": "localhost", "metadata": metadata},
},
# wait=True,
)
assert (wg.tasks["multiply"].outputs["result"].value.value == array([12, 20])).all()


def test_PythonJob_outputs(fixture_localhost):
Expand Down

0 comments on commit e49393c

Please sign in to comment.