Skip to content

Commit

Permalink
Tests: Add tests for calcfunction as WorkChain class member methods
Browse files Browse the repository at this point in the history
The test defines a workchain that defines a calcfunction as a class
member in two ways:

 * A proper staticmethod
 * A class attribute

Both versions work, but the former is more correct and is the only one
that can be called from an instance of the workchain. The other has to
be invoked by retrieving the calcfunction as an attribute and then
calling it. The former is the only one that is offcially documented.

The tests verify that both methods of declaring the calcfunction can be
called within the `WorkChain` and that it also works when submitting the
workchain to the daemon. A test is also added that verifies that caching
of the calcfunctions functions properly.

The changes also broke one existing test. The test took a calcfunction,
originally defined on the module level, and modifying it within the test
function scope. The goal of the test was twofold:

 * Check that changing the source code would be recognized by the
   caching mechanism and so an invocation of the changed function would
   not be cached from an invocation of the original
 * Check that it is possible to cache from a function defined inside the
   scope of a function.

The changes to the dynamically built `FunctionProcess`, notably changing
the type from `func.__name__` to `func.__qualname__` stopped the second
point from working. In the original code, the type name would be simply
`tests.engine.test_calcfunctions.add_function`, both for the module
level function as well as the inlined function. However, with the change
this becomes:

 `tests.engine.test_calcfunctions.TestCalcFunction.test_calcfunction_caching_change_code.<locals>.add_calcfunction`

this can no longer be loaded by the `ProcessNode.process_class` property
and so `is_valid_cache` returns `False`, whereas in the original code
it was a valid cache as the process class could be loaded.

Arguably, the new code is more correct, but it is breaking. Before
inlined functions were valid cache sources, but that is no longer the
case. In exchange, class member functions are now valid cache sources
where they weren't before. Arguably, it is preferable to support class
member functions over inline functions.

The broken test is fixed by moving the inlined `add_calcfunction` to a
separate module such that it becomes a valid cache source again.
  • Loading branch information
sphuber committed Feb 1, 2023
1 parent 2af3028 commit e5ab1fb
Show file tree
Hide file tree
Showing 5 changed files with 131 additions and 13 deletions.
6 changes: 3 additions & 3 deletions aiida/orm/nodes/process/process.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,7 @@ def process_class(self) -> Type['Process']:
raise ValueError(
f'could not load process class for entry point `{self.process_type}` for Node<{self.pk}>: {exception}'
) from exception
except ValueError:
except ValueError as exception:
import importlib

def str_rsplit_iter(string, sep='.'):
Expand All @@ -239,8 +239,8 @@ def str_rsplit_iter(string, sep='.'):
pass
else:
raise ValueError(
f'could not load process class from `{self.process_type}` for Node<{self.pk}>: {exception}'
)
f'could not load process class from `{self.process_type}` for Node<{self.pk}>'
) from exception

return process_class

Expand Down
33 changes: 33 additions & 0 deletions docs/source/topics/processes/functions.rst
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,39 @@ The question you should ask yourself is whether a potential problem merits throw
Or maybe, as in the example above, the problem is easily foreseeable and classifiable with a well defined exit status, in which case it might make more sense to return the exit code.
At the end one should think which solution makes it easier for a workflow calling the function to respond based on the result and what makes it easier to query for these specific failure modes.

As class member methods
=======================

.. versionadded:: 2.3

Process functions can also be declared as class member methods, for example as part of a :class:`~aiida.engine.processes.workchains.workchain.WorkChain`:

.. code-block:: python
class CalcFunctionWorkChain(WorkChain):
@classmethod
def define(cls, spec):
super().define(spec)
spec.input('x')
spec.input('y')
spec.output('sum')
spec.outline(
cls.run_compute_sum,
)
@staticmethod
@calcfunction
def compute_sum(x, y):
return x + y
def run_compute_sum(self):
self.out('sum', self.compute_sum(self.inputs.x, self.inputs.y))
In this example, the work chain declares a class method called ``compute_sum`` which is decorated with the ``calcfunction`` decorator to turn it into a calculation function.
It is important that the method is also decorated with the ``staticmethod`` (see the `Python documentation <https://docs.python.org/3/library/functions.html#staticmethod>`_) such that the work chain instance is not passed when the method is invoked.
The calcfunction can be called from a work chain step like any other class method, as is shown in the last line.


Provenance
==========
Expand Down
10 changes: 10 additions & 0 deletions tests/engine/calcfunctions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
# -*- coding: utf-8 -*-
"""Definition of a calculation function used in ``test_calcfunctions.py``."""
from aiida.engine import calcfunction
from aiida.orm import Int


@calcfunction
def add_calcfunction(data):
"""Calcfunction mirroring a ``test_calcfunctions`` calcfunction but has a slightly different implementation."""
return Int(data.value + 2)
22 changes: 13 additions & 9 deletions tests/engine/test_calcfunctions.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,23 +100,27 @@ def test_calcfunction_caching(self):
assert cached.base.links.get_incoming().one().node.uuid == input_node.uuid

def test_calcfunction_caching_change_code(self):
"""Verify that changing the source codde of a calcfunction invalidates any existing cached nodes."""
result_original = self.test_calcfunction(self.default_int)
"""Verify that changing the source code of a calcfunction invalidates any existing cached nodes.
# Intentionally using the same name, to check that caching anyway
# distinguishes between the calcfunctions.
@calcfunction
def add_calcfunction(data): # pylint: disable=redefined-outer-name
"""This calcfunction has a different source code from the one created at the module level."""
return Int(data.value + 2)
The ``add_calcfunction`` of the ``calcfunctions`` module uses the exact same name as the one defined in this
test module, however, it has a slightly different implementation. Note that we have to define the duplicate in
a different module, because we cannot define it in the same module (as the name clashes, on purpose) and we
cannot inline the calcfunction in this test, since inlined process functions are not valid cache sources.
"""
from .calcfunctions import add_calcfunction # pylint: disable=redefined-outer-name

result_original = self.test_calcfunction(self.default_int)

with enable_caching(identifier='*.add_calcfunction'):
result_cached, cached = add_calcfunction.run_get_node(self.default_int)
assert result_original != result_cached
assert not cached.base.caching.is_created_from_cache
assert cached.is_valid_cache

# Test that the locally-created calcfunction can be cached in principle
result2_cached, cached2 = add_calcfunction.run_get_node(self.default_int)
assert result_original != result2_cached
assert result2_cached != result_original
assert result2_cached == result_cached
assert cached2.base.caching.is_created_from_cache

def test_calcfunction_do_not_store_provenance(self):
Expand Down
73 changes: 72 additions & 1 deletion tests/engine/test_work_chain.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from aiida.common.utils import Capturing
from aiida.engine import ExitCode, Process, ToContext, WorkChain, append_, calcfunction, if_, launch, return_, while_
from aiida.engine.persistence import ObjectLoader
from aiida.manage import get_manager
from aiida.manage import enable_caching, get_manager
from aiida.orm import Bool, Float, Int, Str, load_node


Expand Down Expand Up @@ -146,6 +146,36 @@ def _set_finished(self, function_name):
self.finished_steps[function_name] = True


class CalcFunctionWorkChain(WorkChain):

@classmethod
def define(cls, spec):
super().define(spec)
spec.input('a')
spec.input('b')
spec.output('out_member')
spec.output('out_static')
spec.outline(
cls.run_add_member,
cls.run_add_static,
)

@calcfunction
def add_member(a, b): # pylint: disable=no-self-argument
return a + b

@staticmethod
@calcfunction
def add_static(a, b):
return a + b

def run_add_member(self):
self.out('out_member', CalcFunctionWorkChain.add_member(self.inputs.a, self.inputs.b))

def run_add_static(self):
self.out('out_static', self.add_static(self.inputs.a, self.inputs.b))


class PotentialFailureWorkChain(WorkChain):
"""Work chain that can finish with a non-zero exit code."""

Expand Down Expand Up @@ -1031,6 +1061,47 @@ def _run_with_checkpoints(wf_class, inputs=None):
proc = run_and_check_success(wf_class, **inputs)
return proc.finished_steps

def test_member_calcfunction(self):
"""Test defining a calcfunction as a ``WorkChain`` member method."""
results, node = launch.run.get_node(CalcFunctionWorkChain, a=Int(1), b=Int(2))
assert node.is_finished_ok
assert results['out_member'] == 3
assert results['out_static'] == 3

@pytest.mark.usefixtures('aiida_profile_clean')
def test_member_calcfunction_caching(self):
"""Test defining a calcfunction as a ``WorkChain`` member method with caching enabled."""
results, node = launch.run.get_node(CalcFunctionWorkChain, a=Int(1), b=Int(2))
assert node.is_finished_ok
assert results['out_member'] == 3
assert results['out_static'] == 3

with enable_caching():
results, cached_node = launch.run.get_node(CalcFunctionWorkChain, a=Int(1), b=Int(2))
assert cached_node.is_finished_ok
assert results['out_member'] == 3
assert results['out_static'] == 3

# Check that the calcfunctions called by the workchain have been cached
for called in cached_node.called:
assert called.base.caching.is_created_from_cache
assert called.base.caching.get_cache_source() in [n.uuid for n in node.called]

def test_member_calcfunction_daemon(self, entry_points, daemon_client, submit_and_await):
"""Test defining a calcfunction as a ``WorkChain`` member method submitted to the daemon."""
entry_points.add(CalcFunctionWorkChain, 'aiida.workflows:testing.calcfunction.workchain')

daemon_client.start_daemon()

builder = CalcFunctionWorkChain.get_builder()
builder.a = Int(1)
builder.b = Int(2)

node = submit_and_await(builder)
assert node.is_finished_ok
assert node.outputs.out_member == 3
assert node.outputs.out_static == 3


@pytest.mark.requires_rmq
class TestWorkChainAbort:
Expand Down

0 comments on commit e5ab1fb

Please sign in to comment.