Skip to content

Commit

Permalink
Process functions: Support class member functions as process functions
Browse files Browse the repository at this point in the history
The `FunctionProcess` class that is generated dynamically from the
process function signature used the function's `__name__` attribute to
construct the new dynamic type. This prevented process functions from
being defined as class member methods. By using `__qualname__` instead,
this is now enabled and allows for example the following:

    class CalcFunctionWorkChain(WorkChain):

        @classmethod
        def define(cls, spec):
            super().define(spec)
            spec.input('x')
            spec.input('y')
            spec.output('sum')
            spec.outline(
                cls.run_compute_sum,
            )

        @staticmethod
        @calcfunction
        def compute_sum(x, y):
            return x + y

        def run_compute_sum(self):
            self.out('sum', self.compute_sum(self.inputs.x, self.inputs.y))

The changes also allow these class member process functions to be valid
cache sources.
  • Loading branch information
dev-zero authored and sphuber committed Feb 1, 2023
1 parent b874e41 commit 2af3028
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 10 deletions.
3 changes: 2 additions & 1 deletion aiida/engine/processes/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,9 +280,10 @@ def _define(cls, spec): # pylint: disable=unused-argument
spec.outputs.valid_type = (Data, dict)

return type(
func.__name__, (FunctionProcess,), {
func.__qualname__, (FunctionProcess,), {
'__module__': func.__module__,
'__name__': func.__name__,
'__qualname__': func.__qualname__,
'_func': staticmethod(func),
Process.define.__name__: classmethod(_define),
'_func_args': args,
Expand Down
25 changes: 18 additions & 7 deletions aiida/orm/nodes/process/process.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,14 +219,25 @@ def process_class(self) -> Type['Process']:
except exceptions.EntryPointError as exception:
raise ValueError(
f'could not load process class for entry point `{self.process_type}` for Node<{self.pk}>: {exception}'
)
) from exception
except ValueError:
try:
import importlib
module_name, class_name = self.process_type.rsplit('.', 1)
module = importlib.import_module(module_name)
process_class = getattr(module, class_name)
except (AttributeError, ValueError, ImportError) as exception:
import importlib

def str_rsplit_iter(string, sep='.'):
components = string.split(sep)
for idx in range(1, len(components)):
yield sep.join(components[:-idx]), components[-idx:]

for module_name, class_names in str_rsplit_iter(self.process_type):
try:
module = importlib.import_module(module_name)
process_class = module
for objname in class_names:
process_class = getattr(process_class, objname)
break
except (AttributeError, ValueError, ImportError):
pass
else:
raise ValueError(
f'could not load process class from `{self.process_type}` for Node<{self.pk}>: {exception}'
)
Expand Down
4 changes: 2 additions & 2 deletions aiida/plugins/entry_point.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,8 +162,8 @@ def parse_entry_point_string(entry_point_string: str) -> Tuple[str, str]:

try:
group, name = entry_point_string.split(ENTRY_POINT_STRING_SEPARATOR)
except ValueError:
raise ValueError('invalid entry_point_string format')
except ValueError as exc:
raise ValueError(f'invalid entry_point_string format: {entry_point_string}') from exc

return group, name

Expand Down

0 comments on commit 2af3028

Please sign in to comment.