diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 62ae29e398..8ceab33c0e 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -53,7 +53,6 @@ aiida/common/datastructures.py| aiida/engine/daemon/execmanager.py| aiida/engine/processes/calcjobs/tasks.py| - aiida/orm/autogroup.py| aiida/orm/querybuilder.py| aiida/orm/nodes/data/array/bands.py| aiida/orm/nodes/data/array/projection.py| @@ -66,7 +65,6 @@ aiida/parsers/plugins/arithmetic/add.py| aiida/parsers/plugins/templatereplacer/doubler.py| aiida/parsers/plugins/templatereplacer/__init__.py| - aiida/plugins/entry_point.py| aiida/plugins/entry.py| aiida/plugins/info.py| aiida/plugins/registry.py| diff --git a/aiida/cmdline/commands/cmd_plugin.py b/aiida/cmdline/commands/cmd_plugin.py index f09c064950..3232441379 100644 --- a/aiida/cmdline/commands/cmd_plugin.py +++ b/aiida/cmdline/commands/cmd_plugin.py @@ -13,7 +13,7 @@ from aiida.cmdline.commands.cmd_verdi import verdi from aiida.cmdline.utils import decorators, echo -from aiida.plugins.entry_point import entry_point_group_to_module_path_map +from aiida.plugins.entry_point import ENTRY_POINT_GROUP_TO_MODULE_PATH_MAP @verdi.group('plugin') @@ -22,7 +22,7 @@ def verdi_plugin(): @verdi_plugin.command('list') -@click.argument('entry_point_group', type=click.Choice(entry_point_group_to_module_path_map.keys()), required=False) +@click.argument('entry_point_group', type=click.Choice(ENTRY_POINT_GROUP_TO_MODULE_PATH_MAP.keys()), required=False) @click.argument('entry_point', type=click.STRING, required=False) @decorators.with_dbenv() def plugin_list(entry_point_group, entry_point): @@ -34,7 +34,7 @@ def plugin_list(entry_point_group, entry_point): if entry_point_group is None: echo.echo_info('Available entry point groups:') - for group in sorted(entry_point_group_to_module_path_map.keys()): + for group in sorted(ENTRY_POINT_GROUP_TO_MODULE_PATH_MAP.keys()): echo.echo('* {}'.format(group)) echo.echo('') diff --git a/aiida/cmdline/commands/cmd_run.py b/aiida/cmdline/commands/cmd_run.py index 5a43cad6f5..fb591bd0ed 100644 --- a/aiida/cmdline/commands/cmd_run.py +++ b/aiida/cmdline/commands/cmd_run.py @@ -10,13 +10,17 @@ """`verdi run` command.""" import contextlib import os +import functools import sys +import warnings import click from aiida.cmdline.commands.cmd_verdi import verdi from aiida.cmdline.params.options.multivalue import MultipleValueOption from aiida.cmdline.utils import decorators, echo +from aiida.common.warnings import AiidaDeprecationWarning +from aiida.orm import autogroup @contextlib.contextmanager @@ -37,35 +41,65 @@ def update_environment(argv): sys.path = _path +def validate_entrypoint_string_or_all(ctx, param, value, allow_all=True): # pylint: disable=unused-argument,invalid-name + """Validate that `value` is a valid entrypoint string or the string 'all'.""" + try: + autogroup.Autogroup.validate(value, allow_all=allow_all) + except Exception as exc: + raise click.BadParameter(str(exc) + ' ({})'.format(value)) + + return value + + @verdi.command('run', context_settings=dict(ignore_unknown_options=True,)) @click.argument('scriptname', type=click.STRING) @click.argument('varargs', nargs=-1, type=click.UNPROCESSED) -@click.option('-g', '--group', is_flag=True, default=True, show_default=True, help='Enables the autogrouping') -@click.option('-n', '--group-name', type=click.STRING, required=False, help='Specify the name of the auto group') -@click.option('-e', '--exclude', cls=MultipleValueOption, default=[], help='Exclude these classes from auto grouping') +@click.option('--group/--no-group', default=True, show_default=True, help='Enables the autogrouping') +@click.option('-l', '--group-label', type=click.STRING, required=False, help='Specify the label of the auto group') @click.option( - '-i', '--include', cls=MultipleValueOption, default=['all'], help='Include these classes from auto grouping' + '-n', + '--group-name', + type=click.STRING, + required=False, + help='Specify the name of the auto group [DEPRECATED, USE --group-label instead]' +) +@click.option( + '-e', + '--exclude', + cls=MultipleValueOption, + default=[], + help='Exclude these classes from auto grouping (use full entrypoint strings)', + callback=functools.partial(validate_entrypoint_string_or_all, allow_all=False) +) +@click.option( + '-i', + '--include', + cls=MultipleValueOption, + default=['all'], + help='Include these classes from auto grouping (use full entrypoint strings or "all")', + callback=validate_entrypoint_string_or_all ) @click.option( '-E', '--excludesubclasses', cls=MultipleValueOption, default=[], - help='Exclude these classes and their sub classes from auto grouping' + help='Exclude these classes and their sub classes from auto grouping (use full entrypoint strings)', + callback=functools.partial(validate_entrypoint_string_or_all, allow_all=False) ) @click.option( '-I', '--includesubclasses', cls=MultipleValueOption, default=[], - help='Include these classes and their sub classes from auto grouping' + help='Include these classes and their sub classes from auto grouping (use full entrypoint strings)', + callback=functools.partial(validate_entrypoint_string_or_all, allow_all=False) ) @decorators.with_dbenv() -def run(scriptname, varargs, group, group_name, exclude, excludesubclasses, include, includesubclasses): +def run(scriptname, varargs, group, group_label, group_name, exclude, excludesubclasses, include, includesubclasses): # pylint: disable=too-many-arguments,exec-used """Execute scripts with preloaded AiiDA environment.""" from aiida.cmdline.utils.shell import DEFAULT_MODULES_LIST - from aiida.orm import autogroup # Prepare the environment for the script to be run globals_dict = { @@ -80,22 +114,25 @@ def run(scriptname, varargs, group, group_name, exclude, excludesubclasses, incl for app_mod, model_name, alias in DEFAULT_MODULES_LIST: globals_dict['{}'.format(alias)] = getattr(__import__(app_mod, {}, {}, model_name), model_name) - if group: - automatic_group_name = group_name - if automatic_group_name is None: - from aiida.common import timezone + if group_name: + warnings.warn('--group-name is deprecated, use `--group-label` instead', AiidaDeprecationWarning) # pylint: disable=no-member + if group_label: + raise click.BadParameter('You cannot specify both --group-name and --group-label; use --group-label only') + group_label = group_name - automatic_group_name = 'Verdi autogroup on ' + timezone.now().strftime('%Y-%m-%d %H:%M:%S') + if group: + automatic_group_label = group_label aiida_verdilib_autogroup = autogroup.Autogroup() + if automatic_group_label is not None: + aiida_verdilib_autogroup.set_group_label(automatic_group_label) aiida_verdilib_autogroup.set_exclude(exclude) aiida_verdilib_autogroup.set_include(include) aiida_verdilib_autogroup.set_exclude_with_subclasses(excludesubclasses) aiida_verdilib_autogroup.set_include_with_subclasses(includesubclasses) - aiida_verdilib_autogroup.set_group_name(automatic_group_name) # Note: this is also set in the exec environment! This is the intended behavior - autogroup.current_autogroup = aiida_verdilib_autogroup + autogroup.CURRENT_AUTOGROUP = aiida_verdilib_autogroup # Initialize the variable here, otherwise we get UnboundLocalError in the finally clause if it fails to open handle = None diff --git a/aiida/orm/autogroup.py b/aiida/orm/autogroup.py index ed4551a3ad..f11bf69a36 100644 --- a/aiida/orm/autogroup.py +++ b/aiida/orm/autogroup.py @@ -7,23 +7,24 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### - +"""Module to manage the autogrouping functionality by ``verdi run``.""" +import warnings from aiida.common import exceptions, timezone +from aiida.common.warnings import AiidaDeprecationWarning from aiida.orm import GroupTypeString +from aiida.plugins import load_entry_point_from_string - -current_autogroup = None +CURRENT_AUTOGROUP = None VERDIAUTOGROUP_TYPE = GroupTypeString.VERDIAUTOGROUP_TYPE.value -# TODO: make the Autogroup usable to the user, and not only to the verdi run class Autogroup: """ An object used for the autogrouping of objects. The autogrouping is checked by the Node.store() method. - In the store(), the Node will check if current_autogroup is != None. + In the store(), the Node will check if CURRENT_AUTOGROUP is != None. If so, it will call Autogroup.is_to_be_grouped, and decide whether to put it in a group. Such autogroups are going to be of the VERDIAUTOGROUP_TYPE. @@ -32,127 +33,124 @@ class Autogroup: i.e.: a string identifying the base class, than the path to the class as in Calculation/Data -Factories """ - def _validate(self, param, is_exact=True): + def __init__(self): + """Initialize with defaults.""" + self.exclude = [] + self.exclude_with_subclasses = [] + self.include = ['all'] + self.include_with_subclasses = [] + + now = timezone.now() + gname = 'Verdi autogroup on ' + now.strftime('%Y-%m-%d %H:%M:%S') + self.group_label = gname + + @staticmethod + def validate(param, allow_all=True): """ Used internally to verify the sanity of exclude, include lists + + :param param: should be a list of valid entrypoint strings """ - from aiida.plugins import CalculationFactory, DataFactory - - for i in param: - if not any([i.startswith('calculation'), - i.startswith('code'), - i.startswith('data'), - i == 'all', - ]): - raise exceptions.ValidationError('Module not recognized, allow prefixes ' - ' are: calculation, code or data') - the_param = [i + '.' for i in param] - - factorydict = {'calculation': locals()['CalculationFactory'], - 'data': locals()['DataFactory']} - - for i in the_param: - base, module = i.split('.', 1) - if base == 'code': - if module: - raise exceptions.ValidationError('Cannot have subclasses for codes') - elif base == 'all': + for string in param: + if allow_all and string == 'all': continue - else: - if is_exact: - try: - factorydict[base](module.rstrip('.')) - except exceptions.EntryPointError: - raise exceptions.ValidationError('Cannot find the class to be excluded') - return the_param + load_entry_point_from_string(string) # This will raise a MissingEntryPointError if invalid def get_exclude(self): """Return the list of classes to exclude from autogrouping.""" - try: - return self.exclude - except AttributeError: - return [] + return self.exclude def get_exclude_with_subclasses(self): """ Return the list of classes to exclude from autogrouping. Will also exclude their derived subclasses """ - try: - return self.exclude_with_subclasses - except AttributeError: - return [] + return self.exclude_with_subclasses def get_include(self): """Return the list of classes to include in the autogrouping.""" - try: - return self.include - except AttributeError: - return [] + return self.include def get_include_with_subclasses(self): """Return the list of classes to include in the autogrouping. Will also include their derived subclasses.""" - try: - return self.include_with_subclasses - except AttributeError: - return [] + return self.include_with_subclasses - def get_group_name(self): + def get_group_label(self): """Get the name of the group. If no group name was set, it will set a default one by itself.""" - try: - return self.group_name - except AttributeError: - now = timezone.now() - gname = 'Verdi autogroup on ' + now.strftime('%Y-%m-%d %H:%M:%S') - self.set_group_name(gname) - return self.group_name + return self.group_label + + def get_group_name(self): + """Get the label of the group. + If no group label was set, it will set a default one by itself. + + .. deprecated:: 1.1.0 + Will be removed in `v2.0.0`, use :py:meth:`.get_group_label` instead. + """ + warnings.warn('function is deprecated, use `get_group_label` instead', AiidaDeprecationWarning) # pylint: disable=no-member + return self.get_group_label() def set_exclude(self, exclude): - """Return the list of classes to exclude from autogrouping.""" - the_exclude_classes = self._validate(exclude) - if self.get_include() is not None: - if 'all.' in self.get_include(): - if 'all.' in the_exclude_classes: - raise exceptions.ValidationError('Cannot exclude and include all classes') - self.exclude = the_exclude_classes + """Return the list of classes to exclude from autogrouping. + + :param exclude: a list of valid entry point strings (one of which could be the string 'all') + """ + self.validate(exclude) + if 'all' in self.get_include(): + if 'all' in exclude: + raise exceptions.ValidationError('Cannot exclude and include all classes') + self.exclude = exclude def set_exclude_with_subclasses(self, exclude): """ Set the list of classes to exclude from autogrouping. Will also exclude their derived subclasses + + :param exclude: a list of valid entry point strings (one of which could be the string 'all') """ - the_exclude_classes = self._validate(exclude, is_exact=False) - self.exclude_with_subclasses = the_exclude_classes + self.validate(exclude) + self.exclude_with_subclasses = exclude def set_include(self, include): """ Set the list of classes to include in the autogrouping. + + :param include: a list of valid entry point strings (one of which could be the string 'all') """ - the_include_classes = self._validate(include) - if self.get_exclude() is not None: - if 'all.' in self.get_exclude(): - if 'all.' in the_include_classes: - raise exceptions.ValidationError('Cannot exclude and include all classes') + self.validate(include) + if 'all' in self.get_exclude(): + if 'all' in include: + raise exceptions.ValidationError('Cannot exclude and include all classes') - self.include = the_include_classes + self.include = include def set_include_with_subclasses(self, include): """ Set the list of classes to include in the autogrouping. Will also include their derived subclasses. + + :param include: a list of valid entry point strings (one of which could be the string 'all') """ - the_include_classes = self._validate(include, is_exact=False) - self.include_with_subclasses = the_include_classes + self.validate(include) + self.include_with_subclasses = include - def set_group_name(self, gname): + def set_group_label(self, label): + """ + Set the label of the group to be created """ - Set the name of the group to be created + if not isinstance(label, str): + raise exceptions.ValidationError('group label must be a string') + self.group_label = label + + def set_group_name(self, gname): + """Set the name of the group. + + .. deprecated:: 1.1.0 + Will be removed in `v2.0.0`, use :py:meth:`.set_group_label` instead. """ - if not isinstance(gname, str): - raise exceptions.ValidationError('group name must be a string') - self.group_name = gname + warnings.warn('function is deprecated, use `set_group_label` instead', AiidaDeprecationWarning) # pylint: disable=no-member + return self.set_group_label(label=gname) def is_to_be_grouped(self, the_class): """ @@ -160,20 +158,39 @@ def is_to_be_grouped(self, the_class): :return (bool): True if the_class is to be included in the autogroup """ - include = self.get_include() - include_ws = self.get_include_with_subclasses() - if (('all.' in include) or - (the_class._plugin_type_string in include) or - any([the_class._plugin_type_string.startswith(i) for i in include_ws]) + # strings, including possibly 'all' + include_exact = self.get_include() + include_with_subclasses = self.get_include_with_subclasses() + + # actual classes, with 'all' stripped out + include_exact_classes = tuple( + load_entry_point_from_string(ep_string) for ep_string in include_exact if ep_string != 'all' + ) + include_with_subclasses_classes = tuple( + load_entry_point_from_string(ep_string) for ep_string in include_with_subclasses if ep_string != 'all' + ) + + if ( + 'all' in include_exact or the_class in include_exact_classes or + issubclass(the_class, include_with_subclasses_classes) ): - exclude = self.get_exclude() - exclude_ws = self.get_exclude_with_subclasses() - if ((not 'all.' in exclude) or - (the_class._plugin_type_string in exclude) or - any([the_class._plugin_type_string.startswith(i) for i in exclude_ws]) - ): + # According to the include, this class should be included + # strings, including possibly 'all' + exclude_exact = self.get_exclude() + exclude_with_subclasses = self.get_exclude_with_subclasses() + + # actual classes, with 'all' stripped out + exclude_exact_classes = tuple( + load_entry_point_from_string(ep_string) for ep_string in exclude_exact if ep_string != 'all' + ) + exclude_with_subclasses_classes = tuple( + load_entry_point_from_string(ep_string) for ep_string in exclude_with_subclasses if ep_string != 'all' + ) + + if (the_class not in exclude_exact_classes and not issubclass(the_class, exclude_with_subclasses_classes)): + # If we are here, it's not excluded return True - else: - return False - else: + # If we're here, it's both in the include and in the exclude - exclude it return False + # If we're here, the class is not in the include - return False + return False diff --git a/aiida/orm/nodes/node.py b/aiida/orm/nodes/node.py index 86f6d9ace3..fecfd555ca 100644 --- a/aiida/orm/nodes/node.py +++ b/aiida/orm/nodes/node.py @@ -1024,15 +1024,15 @@ def store(self, with_transaction=True, use_cache=None): # pylint: disable=argum self._store(with_transaction=with_transaction, clean=True) # Set up autogrouping used by verdi run - from aiida.orm.autogroup import current_autogroup, Autogroup, VERDIAUTOGROUP_TYPE + from aiida.orm.autogroup import CURRENT_AUTOGROUP, Autogroup, VERDIAUTOGROUP_TYPE from aiida.orm import Group - if current_autogroup is not None: - if not isinstance(current_autogroup, Autogroup): - raise exceptions.ValidationError('`current_autogroup` is not of type `Autogroup`') + if CURRENT_AUTOGROUP is not None: + if not isinstance(CURRENT_AUTOGROUP, Autogroup): + raise exceptions.ValidationError('`CURRENT_AUTOGROUP` is not of type `Autogroup`') - if current_autogroup.is_to_be_grouped(self): - group_label = current_autogroup.get_group_name() + if CURRENT_AUTOGROUP.is_to_be_grouped(self.__class__): + group_label = CURRENT_AUTOGROUP.get_group_label() if group_label is not None: group = Group.objects.get_or_create(label=group_label, type_string=VERDIAUTOGROUP_TYPE)[0] group.add_nodes(self) diff --git a/aiida/orm/utils/node.py b/aiida/orm/utils/node.py index f48ec2ae16..44e657d0b4 100644 --- a/aiida/orm/utils/node.py +++ b/aiida/orm/utils/node.py @@ -82,13 +82,13 @@ def get_type_string_from_class(class_module, class_name): :param class_module: module of the class :param class_name: name of the class """ - from aiida.plugins.entry_point import get_entry_point_from_class, entry_point_group_to_module_path_map + from aiida.plugins.entry_point import get_entry_point_from_class, ENTRY_POINT_GROUP_TO_MODULE_PATH_MAP group, entry_point = get_entry_point_from_class(class_module, class_name) # If we can reverse engineer an entry point group and name, we're dealing with an external class if group and entry_point: - module_base_path = entry_point_group_to_module_path_map[group] + module_base_path = ENTRY_POINT_GROUP_TO_MODULE_PATH_MAP[group] type_string = '{}.{}.{}.'.format(module_base_path, entry_point.name, class_name) # Otherwise we are dealing with an internal class diff --git a/aiida/plugins/entry_point.py b/aiida/plugins/entry_point.py index 0d47792626..9be505f7dd 100644 --- a/aiida/plugins/entry_point.py +++ b/aiida/plugins/entry_point.py @@ -7,7 +7,7 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### - +"""Module to manage loading entrypoints.""" import enum import traceback import functools @@ -24,7 +24,6 @@ __all__ = ('load_entry_point', 'load_entry_point_from_string') - ENTRY_POINT_GROUP_PREFIX = 'aiida.' ENTRY_POINT_STRING_SEPARATOR = ':' @@ -51,7 +50,7 @@ class EntryPointFormat(enum.Enum): MINIMAL = 3 -entry_point_group_to_module_path_map = { +ENTRY_POINT_GROUP_TO_MODULE_PATH_MAP = { 'aiida.calculations': 'aiida.orm.nodes.process.calculation.calcjob', 'aiida.cmdline.data': 'aiida.cmdline.data', 'aiida.data': 'aiida.orm.nodes.data', @@ -65,7 +64,7 @@ class EntryPointFormat(enum.Enum): } -def validate_registered_entry_points(): +def validate_registered_entry_points(): # pylint: disable=invalid-name """Validate all registered entry points by loading them with the corresponding factory. :raises EntryPointError: if any of the registered entry points cannot be loaded. This can happen if: @@ -108,12 +107,11 @@ def format_entry_point_string(group, name, fmt=EntryPointFormat.FULL): if fmt == EntryPointFormat.FULL: return '{}{}{}'.format(group, ENTRY_POINT_STRING_SEPARATOR, name) - elif fmt == EntryPointFormat.PARTIAL: + if fmt == EntryPointFormat.PARTIAL: return '{}{}{}'.format(group[len(ENTRY_POINT_GROUP_PREFIX):], ENTRY_POINT_STRING_SEPARATOR, name) - elif fmt == EntryPointFormat.MINIMAL: + if fmt == EntryPointFormat.MINIMAL: return '{}'.format(name) - else: - raise ValueError('invalid EntryPointFormat') + raise ValueError('invalid EntryPointFormat') def parse_entry_point_string(entry_point_string): @@ -146,14 +144,13 @@ def get_entry_point_string_format(entry_point_string): :rtype: EntryPointFormat """ try: - group, name = entry_point_string.split(ENTRY_POINT_STRING_SEPARATOR) + group, _ = entry_point_string.split(ENTRY_POINT_STRING_SEPARATOR) except ValueError: return EntryPointFormat.MINIMAL else: if group.startswith(ENTRY_POINT_GROUP_PREFIX): return EntryPointFormat.FULL - else: - return EntryPointFormat.PARTIAL + return EntryPointFormat.PARTIAL def get_entry_point_from_string(entry_point_string): @@ -186,6 +183,7 @@ def load_entry_point_from_string(entry_point_string): group, name = parse_entry_point_string(entry_point_string) return load_entry_point(group, name) + def load_entry_point(group, name): """ Load the class registered under the entry point for a given name and group @@ -215,7 +213,7 @@ def get_entry_point_groups(): :return: a list of valid entry point groups """ - return entry_point_group_to_module_path_map.keys() + return ENTRY_POINT_GROUP_TO_MODULE_PATH_MAP.keys() def get_entry_point_names(group, sort=True): @@ -244,6 +242,7 @@ def get_entry_points(group): """ return [ep for ep in ENTRYPOINT_MANAGER.iter_entry_points(group=group)] + @functools.lru_cache(maxsize=None) def get_entry_point(group, name): """ @@ -289,7 +288,7 @@ def get_entry_point_from_class(class_module, class_name): return None, None -def get_entry_point_string_from_class(class_module, class_name): +def get_entry_point_string_from_class(class_module, class_name): # pylint: disable=invalid-name """ Given the module and name of a class, attempt to obtain the corresponding entry point if it exists and return the entry point string which will be the entry point group and entry point @@ -311,8 +310,7 @@ def get_entry_point_string_from_class(class_module, class_name): if group and entry_point: return ENTRY_POINT_STRING_SEPARATOR.join([group, entry_point.name]) - else: - return None + return None def is_valid_entry_point_string(entry_point_string): @@ -326,9 +324,9 @@ def is_valid_entry_point_string(entry_point_string): :return: True if the string is considered valid, False otherwise """ try: - group, name = entry_point_string.split(ENTRY_POINT_STRING_SEPARATOR) + group, _ = entry_point_string.split(ENTRY_POINT_STRING_SEPARATOR) except (AttributeError, ValueError): # Either `entry_point_string` is not a string or it does not contain the separator return False - return group in entry_point_group_to_module_path_map + return group in ENTRY_POINT_GROUP_TO_MODULE_PATH_MAP diff --git a/docs/source/verdi/verdi_user_guide.rst b/docs/source/verdi/verdi_user_guide.rst index 1e384344e1..041dabf572 100644 --- a/docs/source/verdi/verdi_user_guide.rst +++ b/docs/source/verdi/verdi_user_guide.rst @@ -702,14 +702,20 @@ Below is a list with all available subcommands. Execute scripts with preloaded AiiDA environment. Options: - -g, --group Enables the autogrouping [default: True] + --group / --no-group Enables the autogrouping [default: True] + -l, --group-label TEXT Specify the label of the auto group -n, --group-name TEXT Specify the name of the auto group - -e, --exclude TEXT Exclude these classes from auto grouping - -i, --include TEXT Include these classes from auto grouping + [DEPRECATED, USE --group-label instead] + -e, --exclude TEXT Exclude these classes from auto grouping (use + full entrypoint strings) + -i, --include TEXT Include these classes from auto grouping (use + full entrypoint strings or "all") -E, --excludesubclasses TEXT Exclude these classes and their sub classes - from auto grouping + from auto grouping (use full entrypoint + strings) -I, --includesubclasses TEXT Include these classes and their sub classes - from auto grouping + from auto grouping (use full entrypoint + strings) --help Show this message and exit. diff --git a/tests/backends/aiida_django/migrations/test_migrations_common.py b/tests/backends/aiida_django/migrations/test_migrations_common.py index f8de61f9a6..43f4f03b3d 100644 --- a/tests/backends/aiida_django/migrations/test_migrations_common.py +++ b/tests/backends/aiida_django/migrations/test_migrations_common.py @@ -38,8 +38,8 @@ def setUp(self): from aiida.backends.djsite import get_scoped_session from aiida.orm import autogroup - self.current_autogroup = autogroup.current_autogroup - autogroup.current_autogroup = None + self.current_autogroup = autogroup.CURRENT_AUTOGROUP + autogroup.CURRENT_AUTOGROUP = None assert self.migrate_from and self.migrate_to, \ "TestCase '{}' must define migrate_from and migrate_to properties".format(type(self).__name__) self.migrate_from = [(self.app, self.migrate_from)] @@ -85,7 +85,7 @@ def tearDown(self): """At the end make sure we go back to the latest schema version.""" from aiida.orm import autogroup self._revert_database_schema() - autogroup.current_autogroup = self.current_autogroup + autogroup.CURRENT_AUTOGROUP = self.current_autogroup def setUpBeforeMigration(self): """Anything to do before running the migrations, which should be implemented in test subclasses.""" diff --git a/tests/backends/aiida_sqlalchemy/test_migrations.py b/tests/backends/aiida_sqlalchemy/test_migrations.py index 8bdda5d145..fdd27298ca 100644 --- a/tests/backends/aiida_sqlalchemy/test_migrations.py +++ b/tests/backends/aiida_sqlalchemy/test_migrations.py @@ -57,8 +57,8 @@ def setUp(self): super().setUp() from aiida.orm import autogroup - self.current_autogroup = autogroup.current_autogroup - autogroup.current_autogroup = None + self.current_autogroup = autogroup.CURRENT_AUTOGROUP + autogroup.CURRENT_AUTOGROUP = None assert self.migrate_from and self.migrate_to, \ "TestCase '{}' must define migrate_from and migrate_to properties".format(type(self).__name__) @@ -99,7 +99,7 @@ def tearDown(self): """ from aiida.orm import autogroup self._reset_database_and_schema() - autogroup.current_autogroup = self.current_autogroup + autogroup.CURRENT_AUTOGROUP = self.current_autogroup super().tearDown() def setUpBeforeMigration(self): # pylint: disable=invalid-name @@ -218,8 +218,8 @@ def setUp(self): AiidaTestCase.setUp(self) # pylint: disable=bad-super-call from aiida.orm import autogroup - self.current_autogroup = autogroup.current_autogroup - autogroup.current_autogroup = None + self.current_autogroup = autogroup.CURRENT_AUTOGROUP + autogroup.CURRENT_AUTOGROUP = None assert self.migrate_from and self.migrate_to, \ "TestCase '{}' must define migrate_from and migrate_to properties".format(type(self).__name__) @@ -234,6 +234,11 @@ def setUp(self): self._reset_database_and_schema() raise + def tearDown(self): + """Put back the correct autogroup.""" + from aiida.orm import autogroup + autogroup.CURRENT_AUTOGROUP = self.current_autogroup + class TestMigrationEngine(TestMigrationsSQLA): """ diff --git a/tests/cmdline/commands/test_calcjob.py b/tests/cmdline/commands/test_calcjob.py index dc7895c5d4..2f1945d45a 100644 --- a/tests/cmdline/commands/test_calcjob.py +++ b/tests/cmdline/commands/test_calcjob.py @@ -98,6 +98,7 @@ def setUpClass(cls, *args, **kwargs): cls.arithmetic_job = calculations[0] def setUp(self): + super().setUp() self.cli_runner = CliRunner() def test_calcjob_res(self): diff --git a/tests/cmdline/commands/test_run.py b/tests/cmdline/commands/test_run.py index 3cae78fb70..6af129d479 100644 --- a/tests/cmdline/commands/test_run.py +++ b/tests/cmdline/commands/test_run.py @@ -8,6 +8,8 @@ # For further information please visit http://www.aiida.net # ########################################################################### """Tests for `verdi run`.""" +import tempfile + from click.testing import CliRunner from aiida.backends.testbase import AiidaTestCase @@ -28,7 +30,6 @@ def test_run_workfunction(self): that are defined within the script will fail, as the inspect module will not correctly be able to determin the full path of the source file. """ - import tempfile from aiida.orm import load_node from aiida.orm import WorkFunctionNode @@ -64,3 +65,158 @@ def wf(): self.assertTrue(isinstance(node, WorkFunctionNode)) self.assertEqual(node.function_name, 'wf') self.assertEqual(node.get_function_source_code(), script_content) + + +class TestAutoGroups(AiidaTestCase): + """Test the autogroup functionality.""" + + def setUp(self): + """Setup the CLI runner to run command line commands.""" + from aiida.orm import autogroup + + super().setUp() + self.cli_runner = CliRunner() + # I need to disable the global variable of this test environment, + # because invoke is just calling the function and therefore inheriting + # the global variable + self._old_autogroup = autogroup.CURRENT_AUTOGROUP + autogroup.CURRENT_AUTOGROUP = None + + def tearDown(self): + """Setup the CLI runner to run command line commands.""" + from aiida.orm import autogroup + + super().tearDown() + autogroup.CURRENT_AUTOGROUP = self._old_autogroup + + def test_autogroup(self): + """Check if the autogroup is properly generated.""" + from aiida.orm import QueryBuilder, Node, Group, load_node + + script_content = """from aiida.orm import Data +node = Data().store() +print(node.pk) +""" + + with tempfile.NamedTemporaryFile(mode='w+') as fhandle: + fhandle.write(script_content) + fhandle.flush() + + options = [fhandle.name] + result = self.cli_runner.invoke(cmd_run.run, options) + self.assertClickResultNoException(result) + + pk = int(result.output) + _ = load_node(pk) # Check if the node can be loaded + + queryb = QueryBuilder().append(Node, filters={'id': pk}, tag='node') + queryb.append(Group, with_node='node', filters={'type_string': 'auto.run'}, project='*') + all_auto_groups = queryb.all() + self.assertEqual( + len(all_auto_groups), 1, 'There should be only one autogroup associated with the node just created' + ) + + def test_autogroup_custom_label(self): + """Check if the autogroup is properly generated with the label specified.""" + from aiida.orm import QueryBuilder, Node, Group, load_node + + script_content = """from aiida.orm import Data +node = Data().store() +print(node.pk) +""" + autogroup_label = 'SOME_group_LABEL' + with tempfile.NamedTemporaryFile(mode='w+') as fhandle: + fhandle.write(script_content) + fhandle.flush() + + options = [fhandle.name, '--group-label', autogroup_label] + result = self.cli_runner.invoke(cmd_run.run, options) + self.assertClickResultNoException(result) + + pk = int(result.output) + _ = load_node(pk) # Check if the node can be loaded + + queryb = QueryBuilder().append(Node, filters={'id': pk}, tag='node') + queryb.append(Group, with_node='node', filters={'type_string': 'auto.run'}, project='*') + all_auto_groups = queryb.all() + self.assertEqual( + len(all_auto_groups), 1, 'There should be only one autogroup associated with the node just created' + ) + self.assertEqual(all_auto_groups[0][0].label, autogroup_label) + + def test_no_autogroup(self): + """Check if the autogroup is not generated if ``verdi run`` is asked not to.""" + from aiida.orm import QueryBuilder, Node, Group, load_node + + script_content = """from aiida.orm import Data +node = Data().store() +print(node.pk) +""" + + with tempfile.NamedTemporaryFile(mode='w+') as fhandle: + fhandle.write(script_content) + fhandle.flush() + + options = [fhandle.name, '--no-group'] + result = self.cli_runner.invoke(cmd_run.run, options) + self.assertClickResultNoException(result) + + pk = int(result.output) + _ = load_node(pk) # Check if the node can be loaded + + queryb = QueryBuilder().append(Node, filters={'id': pk}, tag='node') + queryb.append(Group, with_node='node', filters={'type_string': 'auto.run'}, project='*') + all_auto_groups = queryb.all() + self.assertEqual(len(all_auto_groups), 0, 'There should be no autogroup generated') + + def test_autogroup_filter_class(self): # pylint: disable=too-many-locals + """Check if the autogroup is properly generated but filtered classes are skipped.""" + from aiida.orm import QueryBuilder, Node, Group, load_node + + script_content = """from aiida.orm import Data +node1 = Data().store() +node2 = Int(3).store() +print(node1.pk) +print(node2.pk) +""" + + for flags, data_in_autogroup, int_in_autogroup in [ + [['--exclude', 'aiida.node:data'], False, True], + [['--exclude', 'aiida.data:int'], True, False], + [['--excludesubclasses', 'aiida.node:data'], False, False], + [['--excludesubclasses', 'aiida.data:int'], True, False], + [['--excludesubclasses', 'aiida.node:data', 'aiida.data:int'], False, False], + [['--include', 'aiida.node:process'], False, False], + [['--exclude', 'aiida.node:data', 'aiida.data:int'], False, False], + ]: + with tempfile.NamedTemporaryFile(mode='w+') as fhandle: + fhandle.write(script_content) + fhandle.flush() + + options = [fhandle.name] + flags + ['--'] + result = self.cli_runner.invoke(cmd_run.run, options) + self.assertClickResultNoException(result) + + pk1_str, pk2_str = result.output.split() + pk1 = int(pk1_str) + pk2 = int(pk2_str) + _ = load_node(pk1) # Check if the node can be loaded + _ = load_node(pk2) # Check if the node can be loaded + + queryb = QueryBuilder().append(Node, filters={'id': pk1}, tag='node') + queryb.append(Group, with_node='node', filters={'type_string': 'auto.run'}, project='*') + all_auto_groups_data = queryb.all() + + queryb = QueryBuilder().append(Node, filters={'id': pk2}, tag='node') + queryb.append(Group, with_node='node', filters={'type_string': 'auto.run'}, project='*') + all_auto_groups_int = queryb.all() + self.assertEqual( + len(all_auto_groups_data), 1 if data_in_autogroup else 0, + 'Wrong number of nodes in autogroup associated with the Data node ' + "just created with flags '{}'".format(' '.join(flags)) + ) + self.assertEqual( + len(all_auto_groups_int), 1 if int_in_autogroup else 0, + 'Wrong number of nodes in autogroup associated with the Int node ' + "just created with flags '{}'".format(' '.join(flags)) + ) diff --git a/tests/orm/test_groups.py b/tests/orm/test_groups.py index 9c842aa2c4..ce2797daad 100644 --- a/tests/orm/test_groups.py +++ b/tests/orm/test_groups.py @@ -8,7 +8,6 @@ # For further information please visit http://www.aiida.net # ########################################################################### """Test for the Group ORM class.""" - from aiida import orm from aiida.backends.testbase import AiidaTestCase from aiida.common import exceptions