From 374543fc3517520d2527719d49e21422687e4745 Mon Sep 17 00:00:00 2001 From: Giovanni Pizzi Date: Thu, 12 Dec 2019 01:08:35 +0100 Subject: [PATCH 1/5] Clean-up of the autogroup logic In particular: - deprecated use of 'group-name' in favour of 'group-label' both in the internal API of AutoGroups and in the CLI parameters of `verdi run` - Improved the Autogroup implementation and API - add various tests of the functionality of autogroups - fixed the CLI parameters to include/exclude groups, that were badly broken (and untested). Now it accepts any entrypoint string (before it was accepting the first part of an node_type string). Also added click validation of the parameters. - fixed the --group flag to activate/deactivate autogrouping (it was wrongly defined and thus always True) - linter fixes --- .pre-commit-config.yaml | 2 - aiida/cmdline/commands/cmd_plugin.py | 6 +- aiida/cmdline/commands/cmd_run.py | 67 ++++-- aiida/orm/autogroup.py | 205 ++++++++++-------- aiida/orm/nodes/node.py | 12 +- aiida/orm/utils/node.py | 4 +- aiida/plugins/entry_point.py | 32 ++- docs/source/verdi/verdi_user_guide.rst | 16 +- .../migrations/test_migrations_common.py | 6 +- .../aiida_sqlalchemy/test_migrations.py | 15 +- tests/cmdline/commands/test_calcjob.py | 1 + tests/cmdline/commands/test_run.py | 158 +++++++++++++- tests/orm/test_groups.py | 1 - 13 files changed, 371 insertions(+), 154 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 62ae29e398..8ceab33c0e 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -53,7 +53,6 @@ aiida/common/datastructures.py| aiida/engine/daemon/execmanager.py| aiida/engine/processes/calcjobs/tasks.py| - aiida/orm/autogroup.py| aiida/orm/querybuilder.py| aiida/orm/nodes/data/array/bands.py| aiida/orm/nodes/data/array/projection.py| @@ -66,7 +65,6 @@ aiida/parsers/plugins/arithmetic/add.py| aiida/parsers/plugins/templatereplacer/doubler.py| aiida/parsers/plugins/templatereplacer/__init__.py| - aiida/plugins/entry_point.py| aiida/plugins/entry.py| aiida/plugins/info.py| aiida/plugins/registry.py| diff --git a/aiida/cmdline/commands/cmd_plugin.py b/aiida/cmdline/commands/cmd_plugin.py index f09c064950..3232441379 100644 --- a/aiida/cmdline/commands/cmd_plugin.py +++ b/aiida/cmdline/commands/cmd_plugin.py @@ -13,7 +13,7 @@ from aiida.cmdline.commands.cmd_verdi import verdi from aiida.cmdline.utils import decorators, echo -from aiida.plugins.entry_point import entry_point_group_to_module_path_map +from aiida.plugins.entry_point import ENTRY_POINT_GROUP_TO_MODULE_PATH_MAP @verdi.group('plugin') @@ -22,7 +22,7 @@ def verdi_plugin(): @verdi_plugin.command('list') -@click.argument('entry_point_group', type=click.Choice(entry_point_group_to_module_path_map.keys()), required=False) +@click.argument('entry_point_group', type=click.Choice(ENTRY_POINT_GROUP_TO_MODULE_PATH_MAP.keys()), required=False) @click.argument('entry_point', type=click.STRING, required=False) @decorators.with_dbenv() def plugin_list(entry_point_group, entry_point): @@ -34,7 +34,7 @@ def plugin_list(entry_point_group, entry_point): if entry_point_group is None: echo.echo_info('Available entry point groups:') - for group in sorted(entry_point_group_to_module_path_map.keys()): + for group in sorted(ENTRY_POINT_GROUP_TO_MODULE_PATH_MAP.keys()): echo.echo('* {}'.format(group)) echo.echo('') diff --git a/aiida/cmdline/commands/cmd_run.py b/aiida/cmdline/commands/cmd_run.py index 5a43cad6f5..fb591bd0ed 100644 --- a/aiida/cmdline/commands/cmd_run.py +++ b/aiida/cmdline/commands/cmd_run.py @@ -10,13 +10,17 @@ """`verdi run` command.""" import contextlib import os +import functools import sys +import warnings import click from aiida.cmdline.commands.cmd_verdi import verdi from aiida.cmdline.params.options.multivalue import MultipleValueOption from aiida.cmdline.utils import decorators, echo +from aiida.common.warnings import AiidaDeprecationWarning +from aiida.orm import autogroup @contextlib.contextmanager @@ -37,35 +41,65 @@ def update_environment(argv): sys.path = _path +def validate_entrypoint_string_or_all(ctx, param, value, allow_all=True): # pylint: disable=unused-argument,invalid-name + """Validate that `value` is a valid entrypoint string or the string 'all'.""" + try: + autogroup.Autogroup.validate(value, allow_all=allow_all) + except Exception as exc: + raise click.BadParameter(str(exc) + ' ({})'.format(value)) + + return value + + @verdi.command('run', context_settings=dict(ignore_unknown_options=True,)) @click.argument('scriptname', type=click.STRING) @click.argument('varargs', nargs=-1, type=click.UNPROCESSED) -@click.option('-g', '--group', is_flag=True, default=True, show_default=True, help='Enables the autogrouping') -@click.option('-n', '--group-name', type=click.STRING, required=False, help='Specify the name of the auto group') -@click.option('-e', '--exclude', cls=MultipleValueOption, default=[], help='Exclude these classes from auto grouping') +@click.option('--group/--no-group', default=True, show_default=True, help='Enables the autogrouping') +@click.option('-l', '--group-label', type=click.STRING, required=False, help='Specify the label of the auto group') @click.option( - '-i', '--include', cls=MultipleValueOption, default=['all'], help='Include these classes from auto grouping' + '-n', + '--group-name', + type=click.STRING, + required=False, + help='Specify the name of the auto group [DEPRECATED, USE --group-label instead]' +) +@click.option( + '-e', + '--exclude', + cls=MultipleValueOption, + default=[], + help='Exclude these classes from auto grouping (use full entrypoint strings)', + callback=functools.partial(validate_entrypoint_string_or_all, allow_all=False) +) +@click.option( + '-i', + '--include', + cls=MultipleValueOption, + default=['all'], + help='Include these classes from auto grouping (use full entrypoint strings or "all")', + callback=validate_entrypoint_string_or_all ) @click.option( '-E', '--excludesubclasses', cls=MultipleValueOption, default=[], - help='Exclude these classes and their sub classes from auto grouping' + help='Exclude these classes and their sub classes from auto grouping (use full entrypoint strings)', + callback=functools.partial(validate_entrypoint_string_or_all, allow_all=False) ) @click.option( '-I', '--includesubclasses', cls=MultipleValueOption, default=[], - help='Include these classes and their sub classes from auto grouping' + help='Include these classes and their sub classes from auto grouping (use full entrypoint strings)', + callback=functools.partial(validate_entrypoint_string_or_all, allow_all=False) ) @decorators.with_dbenv() -def run(scriptname, varargs, group, group_name, exclude, excludesubclasses, include, includesubclasses): +def run(scriptname, varargs, group, group_label, group_name, exclude, excludesubclasses, include, includesubclasses): # pylint: disable=too-many-arguments,exec-used """Execute scripts with preloaded AiiDA environment.""" from aiida.cmdline.utils.shell import DEFAULT_MODULES_LIST - from aiida.orm import autogroup # Prepare the environment for the script to be run globals_dict = { @@ -80,22 +114,25 @@ def run(scriptname, varargs, group, group_name, exclude, excludesubclasses, incl for app_mod, model_name, alias in DEFAULT_MODULES_LIST: globals_dict['{}'.format(alias)] = getattr(__import__(app_mod, {}, {}, model_name), model_name) - if group: - automatic_group_name = group_name - if automatic_group_name is None: - from aiida.common import timezone + if group_name: + warnings.warn('--group-name is deprecated, use `--group-label` instead', AiidaDeprecationWarning) # pylint: disable=no-member + if group_label: + raise click.BadParameter('You cannot specify both --group-name and --group-label; use --group-label only') + group_label = group_name - automatic_group_name = 'Verdi autogroup on ' + timezone.now().strftime('%Y-%m-%d %H:%M:%S') + if group: + automatic_group_label = group_label aiida_verdilib_autogroup = autogroup.Autogroup() + if automatic_group_label is not None: + aiida_verdilib_autogroup.set_group_label(automatic_group_label) aiida_verdilib_autogroup.set_exclude(exclude) aiida_verdilib_autogroup.set_include(include) aiida_verdilib_autogroup.set_exclude_with_subclasses(excludesubclasses) aiida_verdilib_autogroup.set_include_with_subclasses(includesubclasses) - aiida_verdilib_autogroup.set_group_name(automatic_group_name) # Note: this is also set in the exec environment! This is the intended behavior - autogroup.current_autogroup = aiida_verdilib_autogroup + autogroup.CURRENT_AUTOGROUP = aiida_verdilib_autogroup # Initialize the variable here, otherwise we get UnboundLocalError in the finally clause if it fails to open handle = None diff --git a/aiida/orm/autogroup.py b/aiida/orm/autogroup.py index ed4551a3ad..f11bf69a36 100644 --- a/aiida/orm/autogroup.py +++ b/aiida/orm/autogroup.py @@ -7,23 +7,24 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### - +"""Module to manage the autogrouping functionality by ``verdi run``.""" +import warnings from aiida.common import exceptions, timezone +from aiida.common.warnings import AiidaDeprecationWarning from aiida.orm import GroupTypeString +from aiida.plugins import load_entry_point_from_string - -current_autogroup = None +CURRENT_AUTOGROUP = None VERDIAUTOGROUP_TYPE = GroupTypeString.VERDIAUTOGROUP_TYPE.value -# TODO: make the Autogroup usable to the user, and not only to the verdi run class Autogroup: """ An object used for the autogrouping of objects. The autogrouping is checked by the Node.store() method. - In the store(), the Node will check if current_autogroup is != None. + In the store(), the Node will check if CURRENT_AUTOGROUP is != None. If so, it will call Autogroup.is_to_be_grouped, and decide whether to put it in a group. Such autogroups are going to be of the VERDIAUTOGROUP_TYPE. @@ -32,127 +33,124 @@ class Autogroup: i.e.: a string identifying the base class, than the path to the class as in Calculation/Data -Factories """ - def _validate(self, param, is_exact=True): + def __init__(self): + """Initialize with defaults.""" + self.exclude = [] + self.exclude_with_subclasses = [] + self.include = ['all'] + self.include_with_subclasses = [] + + now = timezone.now() + gname = 'Verdi autogroup on ' + now.strftime('%Y-%m-%d %H:%M:%S') + self.group_label = gname + + @staticmethod + def validate(param, allow_all=True): """ Used internally to verify the sanity of exclude, include lists + + :param param: should be a list of valid entrypoint strings """ - from aiida.plugins import CalculationFactory, DataFactory - - for i in param: - if not any([i.startswith('calculation'), - i.startswith('code'), - i.startswith('data'), - i == 'all', - ]): - raise exceptions.ValidationError('Module not recognized, allow prefixes ' - ' are: calculation, code or data') - the_param = [i + '.' for i in param] - - factorydict = {'calculation': locals()['CalculationFactory'], - 'data': locals()['DataFactory']} - - for i in the_param: - base, module = i.split('.', 1) - if base == 'code': - if module: - raise exceptions.ValidationError('Cannot have subclasses for codes') - elif base == 'all': + for string in param: + if allow_all and string == 'all': continue - else: - if is_exact: - try: - factorydict[base](module.rstrip('.')) - except exceptions.EntryPointError: - raise exceptions.ValidationError('Cannot find the class to be excluded') - return the_param + load_entry_point_from_string(string) # This will raise a MissingEntryPointError if invalid def get_exclude(self): """Return the list of classes to exclude from autogrouping.""" - try: - return self.exclude - except AttributeError: - return [] + return self.exclude def get_exclude_with_subclasses(self): """ Return the list of classes to exclude from autogrouping. Will also exclude their derived subclasses """ - try: - return self.exclude_with_subclasses - except AttributeError: - return [] + return self.exclude_with_subclasses def get_include(self): """Return the list of classes to include in the autogrouping.""" - try: - return self.include - except AttributeError: - return [] + return self.include def get_include_with_subclasses(self): """Return the list of classes to include in the autogrouping. Will also include their derived subclasses.""" - try: - return self.include_with_subclasses - except AttributeError: - return [] + return self.include_with_subclasses - def get_group_name(self): + def get_group_label(self): """Get the name of the group. If no group name was set, it will set a default one by itself.""" - try: - return self.group_name - except AttributeError: - now = timezone.now() - gname = 'Verdi autogroup on ' + now.strftime('%Y-%m-%d %H:%M:%S') - self.set_group_name(gname) - return self.group_name + return self.group_label + + def get_group_name(self): + """Get the label of the group. + If no group label was set, it will set a default one by itself. + + .. deprecated:: 1.1.0 + Will be removed in `v2.0.0`, use :py:meth:`.get_group_label` instead. + """ + warnings.warn('function is deprecated, use `get_group_label` instead', AiidaDeprecationWarning) # pylint: disable=no-member + return self.get_group_label() def set_exclude(self, exclude): - """Return the list of classes to exclude from autogrouping.""" - the_exclude_classes = self._validate(exclude) - if self.get_include() is not None: - if 'all.' in self.get_include(): - if 'all.' in the_exclude_classes: - raise exceptions.ValidationError('Cannot exclude and include all classes') - self.exclude = the_exclude_classes + """Return the list of classes to exclude from autogrouping. + + :param exclude: a list of valid entry point strings (one of which could be the string 'all') + """ + self.validate(exclude) + if 'all' in self.get_include(): + if 'all' in exclude: + raise exceptions.ValidationError('Cannot exclude and include all classes') + self.exclude = exclude def set_exclude_with_subclasses(self, exclude): """ Set the list of classes to exclude from autogrouping. Will also exclude their derived subclasses + + :param exclude: a list of valid entry point strings (one of which could be the string 'all') """ - the_exclude_classes = self._validate(exclude, is_exact=False) - self.exclude_with_subclasses = the_exclude_classes + self.validate(exclude) + self.exclude_with_subclasses = exclude def set_include(self, include): """ Set the list of classes to include in the autogrouping. + + :param include: a list of valid entry point strings (one of which could be the string 'all') """ - the_include_classes = self._validate(include) - if self.get_exclude() is not None: - if 'all.' in self.get_exclude(): - if 'all.' in the_include_classes: - raise exceptions.ValidationError('Cannot exclude and include all classes') + self.validate(include) + if 'all' in self.get_exclude(): + if 'all' in include: + raise exceptions.ValidationError('Cannot exclude and include all classes') - self.include = the_include_classes + self.include = include def set_include_with_subclasses(self, include): """ Set the list of classes to include in the autogrouping. Will also include their derived subclasses. + + :param include: a list of valid entry point strings (one of which could be the string 'all') """ - the_include_classes = self._validate(include, is_exact=False) - self.include_with_subclasses = the_include_classes + self.validate(include) + self.include_with_subclasses = include - def set_group_name(self, gname): + def set_group_label(self, label): + """ + Set the label of the group to be created """ - Set the name of the group to be created + if not isinstance(label, str): + raise exceptions.ValidationError('group label must be a string') + self.group_label = label + + def set_group_name(self, gname): + """Set the name of the group. + + .. deprecated:: 1.1.0 + Will be removed in `v2.0.0`, use :py:meth:`.set_group_label` instead. """ - if not isinstance(gname, str): - raise exceptions.ValidationError('group name must be a string') - self.group_name = gname + warnings.warn('function is deprecated, use `set_group_label` instead', AiidaDeprecationWarning) # pylint: disable=no-member + return self.set_group_label(label=gname) def is_to_be_grouped(self, the_class): """ @@ -160,20 +158,39 @@ def is_to_be_grouped(self, the_class): :return (bool): True if the_class is to be included in the autogroup """ - include = self.get_include() - include_ws = self.get_include_with_subclasses() - if (('all.' in include) or - (the_class._plugin_type_string in include) or - any([the_class._plugin_type_string.startswith(i) for i in include_ws]) + # strings, including possibly 'all' + include_exact = self.get_include() + include_with_subclasses = self.get_include_with_subclasses() + + # actual classes, with 'all' stripped out + include_exact_classes = tuple( + load_entry_point_from_string(ep_string) for ep_string in include_exact if ep_string != 'all' + ) + include_with_subclasses_classes = tuple( + load_entry_point_from_string(ep_string) for ep_string in include_with_subclasses if ep_string != 'all' + ) + + if ( + 'all' in include_exact or the_class in include_exact_classes or + issubclass(the_class, include_with_subclasses_classes) ): - exclude = self.get_exclude() - exclude_ws = self.get_exclude_with_subclasses() - if ((not 'all.' in exclude) or - (the_class._plugin_type_string in exclude) or - any([the_class._plugin_type_string.startswith(i) for i in exclude_ws]) - ): + # According to the include, this class should be included + # strings, including possibly 'all' + exclude_exact = self.get_exclude() + exclude_with_subclasses = self.get_exclude_with_subclasses() + + # actual classes, with 'all' stripped out + exclude_exact_classes = tuple( + load_entry_point_from_string(ep_string) for ep_string in exclude_exact if ep_string != 'all' + ) + exclude_with_subclasses_classes = tuple( + load_entry_point_from_string(ep_string) for ep_string in exclude_with_subclasses if ep_string != 'all' + ) + + if (the_class not in exclude_exact_classes and not issubclass(the_class, exclude_with_subclasses_classes)): + # If we are here, it's not excluded return True - else: - return False - else: + # If we're here, it's both in the include and in the exclude - exclude it return False + # If we're here, the class is not in the include - return False + return False diff --git a/aiida/orm/nodes/node.py b/aiida/orm/nodes/node.py index 86f6d9ace3..fecfd555ca 100644 --- a/aiida/orm/nodes/node.py +++ b/aiida/orm/nodes/node.py @@ -1024,15 +1024,15 @@ def store(self, with_transaction=True, use_cache=None): # pylint: disable=argum self._store(with_transaction=with_transaction, clean=True) # Set up autogrouping used by verdi run - from aiida.orm.autogroup import current_autogroup, Autogroup, VERDIAUTOGROUP_TYPE + from aiida.orm.autogroup import CURRENT_AUTOGROUP, Autogroup, VERDIAUTOGROUP_TYPE from aiida.orm import Group - if current_autogroup is not None: - if not isinstance(current_autogroup, Autogroup): - raise exceptions.ValidationError('`current_autogroup` is not of type `Autogroup`') + if CURRENT_AUTOGROUP is not None: + if not isinstance(CURRENT_AUTOGROUP, Autogroup): + raise exceptions.ValidationError('`CURRENT_AUTOGROUP` is not of type `Autogroup`') - if current_autogroup.is_to_be_grouped(self): - group_label = current_autogroup.get_group_name() + if CURRENT_AUTOGROUP.is_to_be_grouped(self.__class__): + group_label = CURRENT_AUTOGROUP.get_group_label() if group_label is not None: group = Group.objects.get_or_create(label=group_label, type_string=VERDIAUTOGROUP_TYPE)[0] group.add_nodes(self) diff --git a/aiida/orm/utils/node.py b/aiida/orm/utils/node.py index f48ec2ae16..44e657d0b4 100644 --- a/aiida/orm/utils/node.py +++ b/aiida/orm/utils/node.py @@ -82,13 +82,13 @@ def get_type_string_from_class(class_module, class_name): :param class_module: module of the class :param class_name: name of the class """ - from aiida.plugins.entry_point import get_entry_point_from_class, entry_point_group_to_module_path_map + from aiida.plugins.entry_point import get_entry_point_from_class, ENTRY_POINT_GROUP_TO_MODULE_PATH_MAP group, entry_point = get_entry_point_from_class(class_module, class_name) # If we can reverse engineer an entry point group and name, we're dealing with an external class if group and entry_point: - module_base_path = entry_point_group_to_module_path_map[group] + module_base_path = ENTRY_POINT_GROUP_TO_MODULE_PATH_MAP[group] type_string = '{}.{}.{}.'.format(module_base_path, entry_point.name, class_name) # Otherwise we are dealing with an internal class diff --git a/aiida/plugins/entry_point.py b/aiida/plugins/entry_point.py index 0d47792626..9be505f7dd 100644 --- a/aiida/plugins/entry_point.py +++ b/aiida/plugins/entry_point.py @@ -7,7 +7,7 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### - +"""Module to manage loading entrypoints.""" import enum import traceback import functools @@ -24,7 +24,6 @@ __all__ = ('load_entry_point', 'load_entry_point_from_string') - ENTRY_POINT_GROUP_PREFIX = 'aiida.' ENTRY_POINT_STRING_SEPARATOR = ':' @@ -51,7 +50,7 @@ class EntryPointFormat(enum.Enum): MINIMAL = 3 -entry_point_group_to_module_path_map = { +ENTRY_POINT_GROUP_TO_MODULE_PATH_MAP = { 'aiida.calculations': 'aiida.orm.nodes.process.calculation.calcjob', 'aiida.cmdline.data': 'aiida.cmdline.data', 'aiida.data': 'aiida.orm.nodes.data', @@ -65,7 +64,7 @@ class EntryPointFormat(enum.Enum): } -def validate_registered_entry_points(): +def validate_registered_entry_points(): # pylint: disable=invalid-name """Validate all registered entry points by loading them with the corresponding factory. :raises EntryPointError: if any of the registered entry points cannot be loaded. This can happen if: @@ -108,12 +107,11 @@ def format_entry_point_string(group, name, fmt=EntryPointFormat.FULL): if fmt == EntryPointFormat.FULL: return '{}{}{}'.format(group, ENTRY_POINT_STRING_SEPARATOR, name) - elif fmt == EntryPointFormat.PARTIAL: + if fmt == EntryPointFormat.PARTIAL: return '{}{}{}'.format(group[len(ENTRY_POINT_GROUP_PREFIX):], ENTRY_POINT_STRING_SEPARATOR, name) - elif fmt == EntryPointFormat.MINIMAL: + if fmt == EntryPointFormat.MINIMAL: return '{}'.format(name) - else: - raise ValueError('invalid EntryPointFormat') + raise ValueError('invalid EntryPointFormat') def parse_entry_point_string(entry_point_string): @@ -146,14 +144,13 @@ def get_entry_point_string_format(entry_point_string): :rtype: EntryPointFormat """ try: - group, name = entry_point_string.split(ENTRY_POINT_STRING_SEPARATOR) + group, _ = entry_point_string.split(ENTRY_POINT_STRING_SEPARATOR) except ValueError: return EntryPointFormat.MINIMAL else: if group.startswith(ENTRY_POINT_GROUP_PREFIX): return EntryPointFormat.FULL - else: - return EntryPointFormat.PARTIAL + return EntryPointFormat.PARTIAL def get_entry_point_from_string(entry_point_string): @@ -186,6 +183,7 @@ def load_entry_point_from_string(entry_point_string): group, name = parse_entry_point_string(entry_point_string) return load_entry_point(group, name) + def load_entry_point(group, name): """ Load the class registered under the entry point for a given name and group @@ -215,7 +213,7 @@ def get_entry_point_groups(): :return: a list of valid entry point groups """ - return entry_point_group_to_module_path_map.keys() + return ENTRY_POINT_GROUP_TO_MODULE_PATH_MAP.keys() def get_entry_point_names(group, sort=True): @@ -244,6 +242,7 @@ def get_entry_points(group): """ return [ep for ep in ENTRYPOINT_MANAGER.iter_entry_points(group=group)] + @functools.lru_cache(maxsize=None) def get_entry_point(group, name): """ @@ -289,7 +288,7 @@ def get_entry_point_from_class(class_module, class_name): return None, None -def get_entry_point_string_from_class(class_module, class_name): +def get_entry_point_string_from_class(class_module, class_name): # pylint: disable=invalid-name """ Given the module and name of a class, attempt to obtain the corresponding entry point if it exists and return the entry point string which will be the entry point group and entry point @@ -311,8 +310,7 @@ def get_entry_point_string_from_class(class_module, class_name): if group and entry_point: return ENTRY_POINT_STRING_SEPARATOR.join([group, entry_point.name]) - else: - return None + return None def is_valid_entry_point_string(entry_point_string): @@ -326,9 +324,9 @@ def is_valid_entry_point_string(entry_point_string): :return: True if the string is considered valid, False otherwise """ try: - group, name = entry_point_string.split(ENTRY_POINT_STRING_SEPARATOR) + group, _ = entry_point_string.split(ENTRY_POINT_STRING_SEPARATOR) except (AttributeError, ValueError): # Either `entry_point_string` is not a string or it does not contain the separator return False - return group in entry_point_group_to_module_path_map + return group in ENTRY_POINT_GROUP_TO_MODULE_PATH_MAP diff --git a/docs/source/verdi/verdi_user_guide.rst b/docs/source/verdi/verdi_user_guide.rst index 1e384344e1..041dabf572 100644 --- a/docs/source/verdi/verdi_user_guide.rst +++ b/docs/source/verdi/verdi_user_guide.rst @@ -702,14 +702,20 @@ Below is a list with all available subcommands. Execute scripts with preloaded AiiDA environment. Options: - -g, --group Enables the autogrouping [default: True] + --group / --no-group Enables the autogrouping [default: True] + -l, --group-label TEXT Specify the label of the auto group -n, --group-name TEXT Specify the name of the auto group - -e, --exclude TEXT Exclude these classes from auto grouping - -i, --include TEXT Include these classes from auto grouping + [DEPRECATED, USE --group-label instead] + -e, --exclude TEXT Exclude these classes from auto grouping (use + full entrypoint strings) + -i, --include TEXT Include these classes from auto grouping (use + full entrypoint strings or "all") -E, --excludesubclasses TEXT Exclude these classes and their sub classes - from auto grouping + from auto grouping (use full entrypoint + strings) -I, --includesubclasses TEXT Include these classes and their sub classes - from auto grouping + from auto grouping (use full entrypoint + strings) --help Show this message and exit. diff --git a/tests/backends/aiida_django/migrations/test_migrations_common.py b/tests/backends/aiida_django/migrations/test_migrations_common.py index f8de61f9a6..43f4f03b3d 100644 --- a/tests/backends/aiida_django/migrations/test_migrations_common.py +++ b/tests/backends/aiida_django/migrations/test_migrations_common.py @@ -38,8 +38,8 @@ def setUp(self): from aiida.backends.djsite import get_scoped_session from aiida.orm import autogroup - self.current_autogroup = autogroup.current_autogroup - autogroup.current_autogroup = None + self.current_autogroup = autogroup.CURRENT_AUTOGROUP + autogroup.CURRENT_AUTOGROUP = None assert self.migrate_from and self.migrate_to, \ "TestCase '{}' must define migrate_from and migrate_to properties".format(type(self).__name__) self.migrate_from = [(self.app, self.migrate_from)] @@ -85,7 +85,7 @@ def tearDown(self): """At the end make sure we go back to the latest schema version.""" from aiida.orm import autogroup self._revert_database_schema() - autogroup.current_autogroup = self.current_autogroup + autogroup.CURRENT_AUTOGROUP = self.current_autogroup def setUpBeforeMigration(self): """Anything to do before running the migrations, which should be implemented in test subclasses.""" diff --git a/tests/backends/aiida_sqlalchemy/test_migrations.py b/tests/backends/aiida_sqlalchemy/test_migrations.py index 8bdda5d145..fdd27298ca 100644 --- a/tests/backends/aiida_sqlalchemy/test_migrations.py +++ b/tests/backends/aiida_sqlalchemy/test_migrations.py @@ -57,8 +57,8 @@ def setUp(self): super().setUp() from aiida.orm import autogroup - self.current_autogroup = autogroup.current_autogroup - autogroup.current_autogroup = None + self.current_autogroup = autogroup.CURRENT_AUTOGROUP + autogroup.CURRENT_AUTOGROUP = None assert self.migrate_from and self.migrate_to, \ "TestCase '{}' must define migrate_from and migrate_to properties".format(type(self).__name__) @@ -99,7 +99,7 @@ def tearDown(self): """ from aiida.orm import autogroup self._reset_database_and_schema() - autogroup.current_autogroup = self.current_autogroup + autogroup.CURRENT_AUTOGROUP = self.current_autogroup super().tearDown() def setUpBeforeMigration(self): # pylint: disable=invalid-name @@ -218,8 +218,8 @@ def setUp(self): AiidaTestCase.setUp(self) # pylint: disable=bad-super-call from aiida.orm import autogroup - self.current_autogroup = autogroup.current_autogroup - autogroup.current_autogroup = None + self.current_autogroup = autogroup.CURRENT_AUTOGROUP + autogroup.CURRENT_AUTOGROUP = None assert self.migrate_from and self.migrate_to, \ "TestCase '{}' must define migrate_from and migrate_to properties".format(type(self).__name__) @@ -234,6 +234,11 @@ def setUp(self): self._reset_database_and_schema() raise + def tearDown(self): + """Put back the correct autogroup.""" + from aiida.orm import autogroup + autogroup.CURRENT_AUTOGROUP = self.current_autogroup + class TestMigrationEngine(TestMigrationsSQLA): """ diff --git a/tests/cmdline/commands/test_calcjob.py b/tests/cmdline/commands/test_calcjob.py index dc7895c5d4..2f1945d45a 100644 --- a/tests/cmdline/commands/test_calcjob.py +++ b/tests/cmdline/commands/test_calcjob.py @@ -98,6 +98,7 @@ def setUpClass(cls, *args, **kwargs): cls.arithmetic_job = calculations[0] def setUp(self): + super().setUp() self.cli_runner = CliRunner() def test_calcjob_res(self): diff --git a/tests/cmdline/commands/test_run.py b/tests/cmdline/commands/test_run.py index 3cae78fb70..6af129d479 100644 --- a/tests/cmdline/commands/test_run.py +++ b/tests/cmdline/commands/test_run.py @@ -8,6 +8,8 @@ # For further information please visit http://www.aiida.net # ########################################################################### """Tests for `verdi run`.""" +import tempfile + from click.testing import CliRunner from aiida.backends.testbase import AiidaTestCase @@ -28,7 +30,6 @@ def test_run_workfunction(self): that are defined within the script will fail, as the inspect module will not correctly be able to determin the full path of the source file. """ - import tempfile from aiida.orm import load_node from aiida.orm import WorkFunctionNode @@ -64,3 +65,158 @@ def wf(): self.assertTrue(isinstance(node, WorkFunctionNode)) self.assertEqual(node.function_name, 'wf') self.assertEqual(node.get_function_source_code(), script_content) + + +class TestAutoGroups(AiidaTestCase): + """Test the autogroup functionality.""" + + def setUp(self): + """Setup the CLI runner to run command line commands.""" + from aiida.orm import autogroup + + super().setUp() + self.cli_runner = CliRunner() + # I need to disable the global variable of this test environment, + # because invoke is just calling the function and therefore inheriting + # the global variable + self._old_autogroup = autogroup.CURRENT_AUTOGROUP + autogroup.CURRENT_AUTOGROUP = None + + def tearDown(self): + """Setup the CLI runner to run command line commands.""" + from aiida.orm import autogroup + + super().tearDown() + autogroup.CURRENT_AUTOGROUP = self._old_autogroup + + def test_autogroup(self): + """Check if the autogroup is properly generated.""" + from aiida.orm import QueryBuilder, Node, Group, load_node + + script_content = """from aiida.orm import Data +node = Data().store() +print(node.pk) +""" + + with tempfile.NamedTemporaryFile(mode='w+') as fhandle: + fhandle.write(script_content) + fhandle.flush() + + options = [fhandle.name] + result = self.cli_runner.invoke(cmd_run.run, options) + self.assertClickResultNoException(result) + + pk = int(result.output) + _ = load_node(pk) # Check if the node can be loaded + + queryb = QueryBuilder().append(Node, filters={'id': pk}, tag='node') + queryb.append(Group, with_node='node', filters={'type_string': 'auto.run'}, project='*') + all_auto_groups = queryb.all() + self.assertEqual( + len(all_auto_groups), 1, 'There should be only one autogroup associated with the node just created' + ) + + def test_autogroup_custom_label(self): + """Check if the autogroup is properly generated with the label specified.""" + from aiida.orm import QueryBuilder, Node, Group, load_node + + script_content = """from aiida.orm import Data +node = Data().store() +print(node.pk) +""" + autogroup_label = 'SOME_group_LABEL' + with tempfile.NamedTemporaryFile(mode='w+') as fhandle: + fhandle.write(script_content) + fhandle.flush() + + options = [fhandle.name, '--group-label', autogroup_label] + result = self.cli_runner.invoke(cmd_run.run, options) + self.assertClickResultNoException(result) + + pk = int(result.output) + _ = load_node(pk) # Check if the node can be loaded + + queryb = QueryBuilder().append(Node, filters={'id': pk}, tag='node') + queryb.append(Group, with_node='node', filters={'type_string': 'auto.run'}, project='*') + all_auto_groups = queryb.all() + self.assertEqual( + len(all_auto_groups), 1, 'There should be only one autogroup associated with the node just created' + ) + self.assertEqual(all_auto_groups[0][0].label, autogroup_label) + + def test_no_autogroup(self): + """Check if the autogroup is not generated if ``verdi run`` is asked not to.""" + from aiida.orm import QueryBuilder, Node, Group, load_node + + script_content = """from aiida.orm import Data +node = Data().store() +print(node.pk) +""" + + with tempfile.NamedTemporaryFile(mode='w+') as fhandle: + fhandle.write(script_content) + fhandle.flush() + + options = [fhandle.name, '--no-group'] + result = self.cli_runner.invoke(cmd_run.run, options) + self.assertClickResultNoException(result) + + pk = int(result.output) + _ = load_node(pk) # Check if the node can be loaded + + queryb = QueryBuilder().append(Node, filters={'id': pk}, tag='node') + queryb.append(Group, with_node='node', filters={'type_string': 'auto.run'}, project='*') + all_auto_groups = queryb.all() + self.assertEqual(len(all_auto_groups), 0, 'There should be no autogroup generated') + + def test_autogroup_filter_class(self): # pylint: disable=too-many-locals + """Check if the autogroup is properly generated but filtered classes are skipped.""" + from aiida.orm import QueryBuilder, Node, Group, load_node + + script_content = """from aiida.orm import Data +node1 = Data().store() +node2 = Int(3).store() +print(node1.pk) +print(node2.pk) +""" + + for flags, data_in_autogroup, int_in_autogroup in [ + [['--exclude', 'aiida.node:data'], False, True], + [['--exclude', 'aiida.data:int'], True, False], + [['--excludesubclasses', 'aiida.node:data'], False, False], + [['--excludesubclasses', 'aiida.data:int'], True, False], + [['--excludesubclasses', 'aiida.node:data', 'aiida.data:int'], False, False], + [['--include', 'aiida.node:process'], False, False], + [['--exclude', 'aiida.node:data', 'aiida.data:int'], False, False], + ]: + with tempfile.NamedTemporaryFile(mode='w+') as fhandle: + fhandle.write(script_content) + fhandle.flush() + + options = [fhandle.name] + flags + ['--'] + result = self.cli_runner.invoke(cmd_run.run, options) + self.assertClickResultNoException(result) + + pk1_str, pk2_str = result.output.split() + pk1 = int(pk1_str) + pk2 = int(pk2_str) + _ = load_node(pk1) # Check if the node can be loaded + _ = load_node(pk2) # Check if the node can be loaded + + queryb = QueryBuilder().append(Node, filters={'id': pk1}, tag='node') + queryb.append(Group, with_node='node', filters={'type_string': 'auto.run'}, project='*') + all_auto_groups_data = queryb.all() + + queryb = QueryBuilder().append(Node, filters={'id': pk2}, tag='node') + queryb.append(Group, with_node='node', filters={'type_string': 'auto.run'}, project='*') + all_auto_groups_int = queryb.all() + self.assertEqual( + len(all_auto_groups_data), 1 if data_in_autogroup else 0, + 'Wrong number of nodes in autogroup associated with the Data node ' + "just created with flags '{}'".format(' '.join(flags)) + ) + self.assertEqual( + len(all_auto_groups_int), 1 if int_in_autogroup else 0, + 'Wrong number of nodes in autogroup associated with the Int node ' + "just created with flags '{}'".format(' '.join(flags)) + ) diff --git a/tests/orm/test_groups.py b/tests/orm/test_groups.py index 9c842aa2c4..ce2797daad 100644 --- a/tests/orm/test_groups.py +++ b/tests/orm/test_groups.py @@ -8,7 +8,6 @@ # For further information please visit http://www.aiida.net # ########################################################################### """Test for the Group ORM class.""" - from aiida import orm from aiida.backends.testbase import AiidaTestCase from aiida.common import exceptions From 43d852ebf0048b0a623061d4b22007ce9dd3a576 Mon Sep 17 00:00:00 2001 From: Giovanni Pizzi Date: Thu, 2 Apr 2020 13:13:00 +0200 Subject: [PATCH 2/5] Centralised the creation of the Autogroup This also remove an overzelous isinstance check, and moves additional checks in a cached function that is run only when storing the very first node (that needs to be put in an autogroup), making storing of nodes faster (even if times oscillates so it's hard to estimate exactly by how much). Also, added logic to allow for concurrent creation of multiple groups (and test). This fixes #997 --- .ci/workchains.py | 4 +- .gitignore | 1 + aiida/backends/testbase.py | 11 ++ aiida/cmdline/commands/cmd_run.py | 34 +++-- aiida/engine/processes/calcjobs/calcjob.py | 4 +- aiida/manage/caching.py | 8 +- aiida/orm/autogroup.py | 133 ++++++++++++++---- aiida/orm/nodes/node.py | 17 +-- aiida/plugins/entry_point.py | 14 +- docs/source/verdi/verdi_user_guide.rst | 33 +++-- .../aiida_sqlalchemy/test_migrations.py | 51 +++---- tests/cmdline/commands/test_run.py | 46 +++++- .../engine/processes/workchains/test_utils.py | 6 + tests/tools/importexport/orm/test_codes.py | 2 + tests/tools/visualization/test_graph.py | 2 + utils/dependency_management.py | 5 +- 16 files changed, 252 insertions(+), 119 deletions(-) diff --git a/.ci/workchains.py b/.ci/workchains.py index 110334f0ae..f5ab3872d7 100644 --- a/.ci/workchains.py +++ b/.ci/workchains.py @@ -68,8 +68,8 @@ def a_magic_unicorn_appeared(self, node): @process_handler(priority=400, exit_codes=ArithmeticAddCalculation.exit_codes.ERROR_NEGATIVE_NUMBER) def error_negative_sum(self, node): """What even is a negative number, how can I have minus three melons?!.""" - self.ctx.inputs.x = Int(abs(node.inputs.x.value)) - self.ctx.inputs.y = Int(abs(node.inputs.y.value)) + self.ctx.inputs.x = Int(abs(node.inputs.x.value)) # pylint: disable=invalid-name + self.ctx.inputs.y = Int(abs(node.inputs.y.value)) # pylint: disable=invalid-name return ProcessHandlerReport(True) diff --git a/.gitignore b/.gitignore index 9d225c3ef0..1983db653d 100644 --- a/.gitignore +++ b/.gitignore @@ -20,6 +20,7 @@ .cache .pytest_cache .coverage +coverage.xml # Files created by RPN tests .ci/polish/polish_workchains/polish* diff --git a/aiida/backends/testbase.py b/aiida/backends/testbase.py index de855eec4b..ed18f27566 100644 --- a/aiida/backends/testbase.py +++ b/aiida/backends/testbase.py @@ -99,7 +99,11 @@ def tearDown(self): def reset_database(self): """Reset the database to the default state deleting any content currently stored""" + from aiida.orm import autogroup + self.clean_db() + if autogroup.CURRENT_AUTOGROUP is not None: + autogroup.CURRENT_AUTOGROUP.clear_group_cache() self.insert_data() @classmethod @@ -109,7 +113,10 @@ def insert_data(cls): inserts default data into the database (which is for the moment a default computer). """ + from aiida.orm import User + cls.create_user() + User.objects.reset() cls.create_computer() @classmethod @@ -180,7 +187,11 @@ def user_email(cls): # pylint: disable=no-self-argument def tearDownClass(cls, *args, **kwargs): # pylint: disable=arguments-differ # Double check for double security to avoid to run the tearDown # if this is not a test profile + from aiida.orm import autogroup + check_if_tests_can_run() + if autogroup.CURRENT_AUTOGROUP is not None: + autogroup.CURRENT_AUTOGROUP.clear_group_cache() cls.clean_db() cls.clean_repository() cls.__backend_instance.tearDownClass_method(*args, **kwargs) diff --git a/aiida/cmdline/commands/cmd_run.py b/aiida/cmdline/commands/cmd_run.py index fb591bd0ed..6c09ff6cd7 100644 --- a/aiida/cmdline/commands/cmd_run.py +++ b/aiida/cmdline/commands/cmd_run.py @@ -20,7 +20,6 @@ from aiida.cmdline.params.options.multivalue import MultipleValueOption from aiida.cmdline.utils import decorators, echo from aiida.common.warnings import AiidaDeprecationWarning -from aiida.orm import autogroup @contextlib.contextmanager @@ -43,6 +42,8 @@ def update_environment(argv): def validate_entrypoint_string_or_all(ctx, param, value, allow_all=True): # pylint: disable=unused-argument,invalid-name """Validate that `value` is a valid entrypoint string or the string 'all'.""" + from aiida.orm import autogroup + try: autogroup.Autogroup.validate(value, allow_all=allow_all) except Exception as exc: @@ -55,19 +56,26 @@ def validate_entrypoint_string_or_all(ctx, param, value, allow_all=True): # pyl @click.argument('scriptname', type=click.STRING) @click.argument('varargs', nargs=-1, type=click.UNPROCESSED) @click.option('--group/--no-group', default=True, show_default=True, help='Enables the autogrouping') -@click.option('-l', '--group-label', type=click.STRING, required=False, help='Specify the label of the auto group') +@click.option( + '-l', + '--group-label-prefix', + type=click.STRING, + required=False, + help='Specify the prefix of the label of the auto group (numbers might be automatically ' + 'appended to generate unique names per run)' +) @click.option( '-n', '--group-name', type=click.STRING, required=False, - help='Specify the name of the auto group [DEPRECATED, USE --group-label instead]' + help='Specify the name of the auto group [DEPRECATED, USE --group-label-prefix instead]' ) @click.option( '-e', '--exclude', cls=MultipleValueOption, - default=[], + default=lambda: [], help='Exclude these classes from auto grouping (use full entrypoint strings)', callback=functools.partial(validate_entrypoint_string_or_all, allow_all=False) ) @@ -75,7 +83,7 @@ def validate_entrypoint_string_or_all(ctx, param, value, allow_all=True): # pyl '-i', '--include', cls=MultipleValueOption, - default=['all'], + default=lambda: ['all'], help='Include these classes from auto grouping (use full entrypoint strings or "all")', callback=validate_entrypoint_string_or_all ) @@ -83,7 +91,7 @@ def validate_entrypoint_string_or_all(ctx, param, value, allow_all=True): # pyl '-E', '--excludesubclasses', cls=MultipleValueOption, - default=[], + default=lambda: [], help='Exclude these classes and their sub classes from auto grouping (use full entrypoint strings)', callback=functools.partial(validate_entrypoint_string_or_all, allow_all=False) ) @@ -91,15 +99,18 @@ def validate_entrypoint_string_or_all(ctx, param, value, allow_all=True): # pyl '-I', '--includesubclasses', cls=MultipleValueOption, - default=[], + default=lambda: [], help='Include these classes and their sub classes from auto grouping (use full entrypoint strings)', callback=functools.partial(validate_entrypoint_string_or_all, allow_all=False) ) @decorators.with_dbenv() -def run(scriptname, varargs, group, group_label, group_name, exclude, excludesubclasses, include, includesubclasses): +def run( + scriptname, varargs, group, group_label_prefix, group_name, exclude, excludesubclasses, include, includesubclasses +): # pylint: disable=too-many-arguments,exec-used """Execute scripts with preloaded AiiDA environment.""" from aiida.cmdline.utils.shell import DEFAULT_MODULES_LIST + from aiida.orm import autogroup # Prepare the environment for the script to be run globals_dict = { @@ -121,11 +132,10 @@ def run(scriptname, varargs, group, group_label, group_name, exclude, excludesub group_label = group_name if group: - automatic_group_label = group_label - aiida_verdilib_autogroup = autogroup.Autogroup() - if automatic_group_label is not None: - aiida_verdilib_autogroup.set_group_label(automatic_group_label) + # if group_label_prefix is None, use autogenerated name + if group_label_prefix is not None: + aiida_verdilib_autogroup.set_group_label_prefix(group_label_prefix) aiida_verdilib_autogroup.set_exclude(exclude) aiida_verdilib_autogroup.set_include(include) aiida_verdilib_autogroup.set_exclude_with_subclasses(excludesubclasses) diff --git a/aiida/engine/processes/calcjobs/calcjob.py b/aiida/engine/processes/calcjobs/calcjob.py index 0e6b234ef0..9f3d2d765f 100644 --- a/aiida/engine/processes/calcjobs/calcjob.py +++ b/aiida/engine/processes/calcjobs/calcjob.py @@ -64,7 +64,7 @@ def validate_calc_job(inputs, ctx): ) -def validate_parser(parser_name, ctx): +def validate_parser(parser_name, ctx): # pylint: disable=unused-argument """Validate the parser. :raises InputValidationError: if the parser name does not correspond to a loadable `Parser` class. @@ -78,7 +78,7 @@ def validate_parser(parser_name, ctx): raise exceptions.InputValidationError('invalid parser specified: {}'.format(exception)) -def validate_resources(resources, ctx): +def validate_resources(resources, ctx): # pylint: disable=unused-argument """Validate the resources. :raises InputValidationError: if `num_machines` is not specified or is not an integer. diff --git a/aiida/manage/caching.py b/aiida/manage/caching.py index d8079fd747..9b7f1d427d 100644 --- a/aiida/manage/caching.py +++ b/aiida/manage/caching.py @@ -22,7 +22,7 @@ from aiida.common import exceptions from aiida.common.lang import type_check -from aiida.plugins.entry_point import ENTRY_POINT_STRING_SEPARATOR, entry_point_group_to_module_path_map +from aiida.plugins.entry_point import ENTRY_POINT_STRING_SEPARATOR, ENTRY_POINT_GROUP_TO_MODULE_PATH_MAP __all__ = ('get_use_cache', 'enable_caching', 'disable_caching') @@ -248,7 +248,7 @@ def _validate_identifier_pattern(*, identifier): 1. - where `group_name` is one of the keys in `entry_point_group_to_module_path_map` + where `group_name` is one of the keys in `ENTRY_POINT_GROUP_TO_MODULE_PATH_MAP` and `tail` can be anything _except_ `ENTRY_POINT_STRING_SEPARATOR`. 2. a fully qualified Python name @@ -276,7 +276,7 @@ def _validate_identifier_pattern(*, identifier): group_pattern, _ = identifier.split(ENTRY_POINT_STRING_SEPARATOR) if not any( _match_wildcard(string=group_name, pattern=group_pattern) - for group_name in entry_point_group_to_module_path_map + for group_name in ENTRY_POINT_GROUP_TO_MODULE_PATH_MAP ): raise ValueError( common_error_msg + "Group name pattern '{}' does not match any of the AiiDA entry point group names.". @@ -290,7 +290,7 @@ def _validate_identifier_pattern(*, identifier): # aiida.* or aiida.calculations* if '*' in identifier: group_part, _ = identifier.split('*', 1) - if any(group_name.startswith(group_part) for group_name in entry_point_group_to_module_path_map): + if any(group_name.startswith(group_part) for group_name in ENTRY_POINT_GROUP_TO_MODULE_PATH_MAP): return # Finally, check if it could be a fully qualified Python name for identifier_part in identifier.split('.'): diff --git a/aiida/orm/autogroup.py b/aiida/orm/autogroup.py index f11bf69a36..cbba2f1c4e 100644 --- a/aiida/orm/autogroup.py +++ b/aiida/orm/autogroup.py @@ -11,8 +11,9 @@ import warnings from aiida.common import exceptions, timezone +from aiida.common.escaping import escape_for_sql_like from aiida.common.warnings import AiidaDeprecationWarning -from aiida.orm import GroupTypeString +from aiida.orm import GroupTypeString, Group from aiida.plugins import load_entry_point_from_string CURRENT_AUTOGROUP = None @@ -35,14 +36,15 @@ class Autogroup: def __init__(self): """Initialize with defaults.""" - self.exclude = [] - self.exclude_with_subclasses = [] - self.include = ['all'] - self.include_with_subclasses = [] + self._exclude = [] + self._exclude_with_subclasses = [] + self._include = ['all'] + self._include_with_subclasses = [] now = timezone.now() - gname = 'Verdi autogroup on ' + now.strftime('%Y-%m-%d %H:%M:%S') - self.group_label = gname + default_label_prefix = 'Verdi autogroup on ' + now.strftime('%Y-%m-%d %H:%M:%S') + self._group_label_prefix = default_label_prefix + self._group_label = None # Actual group label, set by `get_or_create_group` @staticmethod def validate(param, allow_all=True): @@ -58,38 +60,38 @@ def validate(param, allow_all=True): def get_exclude(self): """Return the list of classes to exclude from autogrouping.""" - return self.exclude + return self._exclude def get_exclude_with_subclasses(self): """ Return the list of classes to exclude from autogrouping. Will also exclude their derived subclasses """ - return self.exclude_with_subclasses + return self._exclude_with_subclasses def get_include(self): """Return the list of classes to include in the autogrouping.""" - return self.include + return self._include def get_include_with_subclasses(self): """Return the list of classes to include in the autogrouping. Will also include their derived subclasses.""" - return self.include_with_subclasses + return self._include_with_subclasses - def get_group_label(self): - """Get the name of the group. - If no group name was set, it will set a default one by itself.""" - return self.group_label + def get_group_label_prefix(self): + """Get the prefix of the label of the group. + If no group label prefix was set, it will set a default one by itself.""" + return self._group_label_prefix def get_group_name(self): """Get the label of the group. If no group label was set, it will set a default one by itself. .. deprecated:: 1.1.0 - Will be removed in `v2.0.0`, use :py:meth:`.get_group_label` instead. + Will be removed in `v2.0.0`, use :py:meth:`.get_group_label_prefix` instead. """ - warnings.warn('function is deprecated, use `get_group_label` instead', AiidaDeprecationWarning) # pylint: disable=no-member - return self.get_group_label() + warnings.warn('function is deprecated, use `get_group_label_prefix` instead', AiidaDeprecationWarning) # pylint: disable=no-member + return self.get_group_label_prefix() def set_exclude(self, exclude): """Return the list of classes to exclude from autogrouping. @@ -100,7 +102,7 @@ def set_exclude(self, exclude): if 'all' in self.get_include(): if 'all' in exclude: raise exceptions.ValidationError('Cannot exclude and include all classes') - self.exclude = exclude + self._exclude = exclude def set_exclude_with_subclasses(self, exclude): """ @@ -110,7 +112,7 @@ def set_exclude_with_subclasses(self, exclude): :param exclude: a list of valid entry point strings (one of which could be the string 'all') """ self.validate(exclude) - self.exclude_with_subclasses = exclude + self._exclude_with_subclasses = exclude def set_include(self, include): """ @@ -123,7 +125,7 @@ def set_include(self, include): if 'all' in include: raise exceptions.ValidationError('Cannot exclude and include all classes') - self.include = include + self._include = include def set_include_with_subclasses(self, include): """ @@ -133,24 +135,24 @@ def set_include_with_subclasses(self, include): :param include: a list of valid entry point strings (one of which could be the string 'all') """ self.validate(include) - self.include_with_subclasses = include + self._include_with_subclasses = include - def set_group_label(self, label): + def set_group_label_prefix(self, label_prefix): """ Set the label of the group to be created """ - if not isinstance(label, str): + if not isinstance(label_prefix, str): raise exceptions.ValidationError('group label must be a string') - self.group_label = label + self._group_label_prefix = label_prefix def set_group_name(self, gname): """Set the name of the group. .. deprecated:: 1.1.0 - Will be removed in `v2.0.0`, use :py:meth:`.set_group_label` instead. + Will be removed in `v2.0.0`, use :py:meth:`.set_group_label_prefix` instead. """ - warnings.warn('function is deprecated, use `set_group_label` instead', AiidaDeprecationWarning) # pylint: disable=no-member - return self.set_group_label(label=gname) + warnings.warn('function is deprecated, use `set_group_label_prefix` instead', AiidaDeprecationWarning) # pylint: disable=no-member + return self.set_group_label_prefix(label_prefix=gname) def is_to_be_grouped(self, the_class): """ @@ -194,3 +196,78 @@ def is_to_be_grouped(self, the_class): return False # If we're here, the class is not in the include - return False return False + + def clear_group_cache(self): + """Clear the cache of the group name. + + This is mostly used by tests when they reset the database. + """ + self._group_label = None + + def get_or_create_group(self): + """Return the current Autogroup, or create one if None has been set yet.""" + from aiida.orm import QueryBuilder + if self._group_label is not None: + results = [ + res[0] for res in QueryBuilder(). + append(Group, filters={ + 'label': self._group_label, + 'type_string': VERDIAUTOGROUP_TYPE + }, project='*').iterall() + ] + if results: + # If it is not empty, it should have only one result due to the + # uniqueness constraints + return results[0] + # There are no results: probably the group has been deleted. + # I continue as if it was not cached + self._group_label = None + + label_prefix = self.get_group_label_prefix() + # Try to do a preliminary QB query to avoid to do too many try/except + # if many of the prefix_NUMBER groups already exist + queryb = QueryBuilder().append( + Group, + filters={ + 'or': [{ + 'label': { + '==': label_prefix + } + }, { + 'label': { + 'like': escape_for_sql_like(label_prefix + '_') + '%' + } + }] + }, + project='label' + ) + existing_group_labels = [res[0][len(label_prefix):] for res in queryb.all()] + existing_group_ints = [] + for label in existing_group_labels: + if label == '': + # This is just the prefix without name - corresponds to counter = 0 + existing_group_ints.append(0) + else: + if label.startswith('_'): + try: + existing_group_ints.append(int(label[1:])) + except ValueError: + # It's not an integer, so it will never collide - just ignore it + pass + + if not existing_group_ints: + counter = 0 + else: + counter = max(existing_group_ints) + 1 + + while True: + try: + label = label_prefix if counter == 0 else '{}_{}'.format(label_prefix, counter) + group = Group(label=label, type_string=VERDIAUTOGROUP_TYPE).store() + self._group_label = group.label + except exceptions.IntegrityError: + counter += 1 + else: + break + + return group diff --git a/aiida/orm/nodes/node.py b/aiida/orm/nodes/node.py index fecfd555ca..dedbddcfc5 100644 --- a/aiida/orm/nodes/node.py +++ b/aiida/orm/nodes/node.py @@ -23,6 +23,7 @@ from aiida.orm.utils.links import LinkManager, LinkTriple from aiida.orm.utils.repository import Repository from aiida.orm.utils.node import AbstractNodeMeta, validate_attribute_extra_key +from aiida.orm import autogroup from ..comments import Comment from ..computers import Computer @@ -1024,18 +1025,10 @@ def store(self, with_transaction=True, use_cache=None): # pylint: disable=argum self._store(with_transaction=with_transaction, clean=True) # Set up autogrouping used by verdi run - from aiida.orm.autogroup import CURRENT_AUTOGROUP, Autogroup, VERDIAUTOGROUP_TYPE - from aiida.orm import Group - - if CURRENT_AUTOGROUP is not None: - if not isinstance(CURRENT_AUTOGROUP, Autogroup): - raise exceptions.ValidationError('`CURRENT_AUTOGROUP` is not of type `Autogroup`') - - if CURRENT_AUTOGROUP.is_to_be_grouped(self.__class__): - group_label = CURRENT_AUTOGROUP.get_group_label() - if group_label is not None: - group = Group.objects.get_or_create(label=group_label, type_string=VERDIAUTOGROUP_TYPE)[0] - group.add_nodes(self) + if autogroup.CURRENT_AUTOGROUP is not None: + if autogroup.CURRENT_AUTOGROUP.is_to_be_grouped(self.__class__): + group = autogroup.CURRENT_AUTOGROUP.get_or_create_group() + group.add_nodes(self) return self diff --git a/aiida/plugins/entry_point.py b/aiida/plugins/entry_point.py index 9be505f7dd..d2b7132e06 100644 --- a/aiida/plugins/entry_point.py +++ b/aiida/plugins/entry_point.py @@ -240,7 +240,7 @@ def get_entry_points(group): :param group: the entry point group :return: a list of entry points """ - return [ep for ep in ENTRYPOINT_MANAGER.iter_entry_points(group=group)] + return list(ENTRYPOINT_MANAGER.iter_entry_points(group=group)) @functools.lru_cache(maxsize=None) @@ -257,12 +257,16 @@ def get_entry_point(group, name): entry_points = [ep for ep in get_entry_points(group) if ep.name == name] if not entry_points: - raise MissingEntryPointError("Entry point '{}' not found in group '{}'.".format(name, group) + - 'Try running `reentry scan` to update the entry point cache.') + raise MissingEntryPointError( + "Entry point '{}' not found in group '{}'.".format(name, group) + + 'Try running `reentry scan` to update the entry point cache.' + ) if len(entry_points) > 1: - raise MultipleEntryPointError("Multiple entry points '{}' found in group '{}'. ".format(name, group) + - 'Try running `reentry scan` to repopulate the entry point cache.') + raise MultipleEntryPointError( + "Multiple entry points '{}' found in group '{}'. ".format(name, group) + + 'Try running `reentry scan` to repopulate the entry point cache.' + ) return entry_points[0] diff --git a/docs/source/verdi/verdi_user_guide.rst b/docs/source/verdi/verdi_user_guide.rst index 041dabf572..ee43c3999a 100644 --- a/docs/source/verdi/verdi_user_guide.rst +++ b/docs/source/verdi/verdi_user_guide.rst @@ -702,21 +702,24 @@ Below is a list with all available subcommands. Execute scripts with preloaded AiiDA environment. Options: - --group / --no-group Enables the autogrouping [default: True] - -l, --group-label TEXT Specify the label of the auto group - -n, --group-name TEXT Specify the name of the auto group - [DEPRECATED, USE --group-label instead] - -e, --exclude TEXT Exclude these classes from auto grouping (use - full entrypoint strings) - -i, --include TEXT Include these classes from auto grouping (use - full entrypoint strings or "all") - -E, --excludesubclasses TEXT Exclude these classes and their sub classes - from auto grouping (use full entrypoint - strings) - -I, --includesubclasses TEXT Include these classes and their sub classes - from auto grouping (use full entrypoint - strings) - --help Show this message and exit. + --group / --no-group Enables the autogrouping [default: True] + -l, --group-label-prefix TEXT Specify the prefix of the label of the auto + group (numbers might be automatically + appended to generate unique names per run) + -n, --group-name TEXT Specify the name of the auto group + [DEPRECATED, USE --group-label-prefix + instead] + -e, --exclude TEXT Exclude these classes from auto grouping (use + full entrypoint strings) + -i, --include TEXT Include these classes from auto grouping + (use full entrypoint strings or "all") + -E, --excludesubclasses TEXT Exclude these classes and their sub classes + from auto grouping (use full entrypoint + strings) + -I, --includesubclasses TEXT Include these classes and their sub classes + from auto grouping (use full entrypoint + strings) + --help Show this message and exit. .. _verdi_setup: diff --git a/tests/backends/aiida_sqlalchemy/test_migrations.py b/tests/backends/aiida_sqlalchemy/test_migrations.py index fdd27298ca..2147414f0d 100644 --- a/tests/backends/aiida_sqlalchemy/test_migrations.py +++ b/tests/backends/aiida_sqlalchemy/test_migrations.py @@ -22,7 +22,6 @@ from aiida.backends.sqlalchemy.models.base import Base from aiida.backends.sqlalchemy.utils import flag_modified from aiida.backends.testbase import AiidaTestCase -from aiida.common.utils import Capturing from .test_utils import new_database @@ -63,16 +62,21 @@ def setUp(self): "TestCase '{}' must define migrate_from and migrate_to properties".format(type(self).__name__) try: - with Capturing(): - self.migrate_db_down(self.migrate_from) + self.migrate_db_down(self.migrate_from) self.setUpBeforeMigration() - with Capturing(): - self.migrate_db_up(self.migrate_to) + self._perform_actual_migration() except Exception: # Bring back the DB to the correct state if this setup part fails self._reset_database_and_schema() raise + def _perform_actual_migration(self): + """Perform the actual migration (upwards, to migrate_to). + + Must be called after we are properly set to be in migrate_from. + """ + self.migrate_db_up(self.migrate_to) + def migrate_db_up(self, destination): """ Perform a migration upwards (upgrade) with alembic @@ -116,8 +120,7 @@ def _reset_database_and_schema(self): of tests. """ self.reset_database() - with Capturing(): - self.migrate_db_up('head') + self.migrate_db_up('head') @property def current_rev(self): @@ -210,34 +213,12 @@ class TestBackwardMigrationsSQLA(TestMigrationsSQLA): than the migrate_to revision. """ - def setUp(self): - """ - Go to the migrate_from revision, apply setUpBeforeMigration, then - run the migration. - """ - AiidaTestCase.setUp(self) # pylint: disable=bad-super-call - from aiida.orm import autogroup + def _perform_actual_migration(self): + """Perform the actual migration (downwards, to migrate_to). - self.current_autogroup = autogroup.CURRENT_AUTOGROUP - autogroup.CURRENT_AUTOGROUP = None - assert self.migrate_from and self.migrate_to, \ - "TestCase '{}' must define migrate_from and migrate_to properties".format(type(self).__name__) - - try: - with Capturing(): - self.migrate_db_down(self.migrate_from) - self.setUpBeforeMigration() - with Capturing(): - self.migrate_db_down(self.migrate_to) - except Exception: - # Bring back the DB to the correct state if this setup part fails - self._reset_database_and_schema() - raise - - def tearDown(self): - """Put back the correct autogroup.""" - from aiida.orm import autogroup - autogroup.CURRENT_AUTOGROUP = self.current_autogroup + Must be called after we are properly set to be in migrate_from. + """ + self.migrate_db_down(self.migrate_to) class TestMigrationEngine(TestMigrationsSQLA): @@ -1008,7 +989,7 @@ class TestDbLogUUIDAddition(TestMigrationsSQLA): """ Test that the UUID column is correctly added to the DbLog table and that the uniqueness constraint is added without problems (if the migration arrives until 375c2db70663 then the - constraint is added properly. + constraint is added properly). """ migrate_from = '041a79fc615f' # 041a79fc615f_dblog_cleaning diff --git a/tests/cmdline/commands/test_run.py b/tests/cmdline/commands/test_run.py index 6af129d479..8df034f8ef 100644 --- a/tests/cmdline/commands/test_run.py +++ b/tests/cmdline/commands/test_run.py @@ -129,7 +129,7 @@ def test_autogroup_custom_label(self): fhandle.write(script_content) fhandle.flush() - options = [fhandle.name, '--group-label', autogroup_label] + options = [fhandle.name, '--group-label-prefix', autogroup_label] result = self.cli_runner.invoke(cmd_run.run, options) self.assertClickResultNoException(result) @@ -220,3 +220,47 @@ def test_autogroup_filter_class(self): # pylint: disable=too-many-locals 'Wrong number of nodes in autogroup associated with the Int node ' "just created with flags '{}'".format(' '.join(flags)) ) + + def test_autogroup_clashing_label(self): + """Check if the autogroup label is properly (re)generated when it clashes with an existing group name.""" + from aiida.orm import QueryBuilder, Node, Group, load_node + + script_content = """from aiida.orm import Data +node = Data().store() +print(node.pk) +""" + autogroup_label = 'SOME_repeated_group_LABEL' + with tempfile.NamedTemporaryFile(mode='w+') as fhandle: + fhandle.write(script_content) + fhandle.flush() + + # First run + options = [fhandle.name, '--group-label-prefix', autogroup_label] + result = self.cli_runner.invoke(cmd_run.run, options) + self.assertClickResultNoException(result) + + pk = int(result.output) + _ = load_node(pk) # Check if the node can be loaded + queryb = QueryBuilder().append(Node, filters={'id': pk}, tag='node') + queryb.append(Group, with_node='node', filters={'type_string': 'auto.run'}, project='*') + all_auto_groups = queryb.all() + self.assertEqual( + len(all_auto_groups), 1, 'There should be only one autogroup associated with the node just created' + ) + self.assertEqual(all_auto_groups[0][0].label, autogroup_label) + + # A few more runs with the same label - it should not crash but append something to the group name + for _ in range(10): + options = [fhandle.name, '--group-label-prefix', autogroup_label] + result = self.cli_runner.invoke(cmd_run.run, options) + self.assertClickResultNoException(result) + + pk = int(result.output) + _ = load_node(pk) # Check if the node can be loaded + queryb = QueryBuilder().append(Node, filters={'id': pk}, tag='node') + queryb.append(Group, with_node='node', filters={'type_string': 'auto.run'}, project='*') + all_auto_groups = queryb.all() + self.assertEqual( + len(all_auto_groups), 1, 'There should be only one autogroup associated with the node just created' + ) + self.assertTrue(all_auto_groups[0][0].label.startswith(autogroup_label)) diff --git a/tests/engine/processes/workchains/test_utils.py b/tests/engine/processes/workchains/test_utils.py index 00f1e127a3..5591229d5b 100644 --- a/tests/engine/processes/workchains/test_utils.py +++ b/tests/engine/processes/workchains/test_utils.py @@ -53,6 +53,7 @@ def test_priority(self): attribute_key = 'handlers_called' class ArithmeticAddBaseWorkChain(BaseRestartWorkChain): + """Implementation of a possible BaseRestartWorkChain for the ArithmeticAdd calculation.""" _process_class = ArithmeticAddCalculation @@ -61,6 +62,7 @@ class ArithmeticAddBaseWorkChain(BaseRestartWorkChain): # This can then be checked after invoking `inspect_process` to ensure they were called in the right order @process_handler(priority=100) def handler_01(self, node): + """Example handler returing ExitCode 100.""" handlers_called = node.get_attribute(attribute_key, default=[]) handlers_called.append('handler_01') node.set_attribute(attribute_key, handlers_called) @@ -68,6 +70,7 @@ def handler_01(self, node): @process_handler(priority=300) def handler_03(self, node): + """Example handler returing ExitCode 300.""" handlers_called = node.get_attribute(attribute_key, default=[]) handlers_called.append('handler_03') node.set_attribute(attribute_key, handlers_called) @@ -75,6 +78,7 @@ def handler_03(self, node): @process_handler(priority=200) def handler_02(self, node): + """Example handler returing ExitCode 200.""" handlers_called = node.get_attribute(attribute_key, default=[]) handlers_called.append('handler_02') node.set_attribute(attribute_key, handlers_called) @@ -82,6 +86,7 @@ def handler_02(self, node): @process_handler(priority=400) def handler_04(self, node): + """Example handler returing ExitCode 400.""" handlers_called = node.get_attribute(attribute_key, default=[]) handlers_called.append('handler_04') node.set_attribute(attribute_key, handlers_called) @@ -159,6 +164,7 @@ def test_exit_codes_filter(self): node_skip.set_exit_status(200) # Some other exit status class ArithmeticAddBaseWorkChain(BaseRestartWorkChain): + """Minimal base restart workchain ofr the ArithmeticAdd calculation""" _process_class = ArithmeticAddCalculation diff --git a/tests/tools/importexport/orm/test_codes.py b/tests/tools/importexport/orm/test_codes.py index 5a11e07b94..d8f173107b 100644 --- a/tests/tools/importexport/orm/test_codes.py +++ b/tests/tools/importexport/orm/test_codes.py @@ -24,9 +24,11 @@ class TestCode(AiidaTestCase): """Test ex-/import cases related to Codes""" def setUp(self): + super().setUp() self.reset_database() def tearDown(self): + super().tearDown() self.reset_database() @with_temp_dir diff --git a/tests/tools/visualization/test_graph.py b/tests/tools/visualization/test_graph.py index 9f15cab9ca..d48a3e6800 100644 --- a/tests/tools/visualization/test_graph.py +++ b/tests/tools/visualization/test_graph.py @@ -22,9 +22,11 @@ class TestVisGraph(AiidaTestCase): """Tests for verdi graph""" def setUp(self): + super().setUp() self.reset_database() def tearDown(self): + super().tearDown() self.reset_database() def create_provenance(self): diff --git a/utils/dependency_management.py b/utils/dependency_management.py index 17442d66af..af476de3e7 100755 --- a/utils/dependency_management.py +++ b/utils/dependency_management.py @@ -239,7 +239,6 @@ def validate_environment_yml(): # pylint: disable=too-many-branches # Check that all requirements specified in the setup.json file are found in the # conda environment specification. - missing_from_env = set() for req in install_requirements: if any(re.match(ignore, str(req)) for ignore in CONDA_IGNORE): continue # skip explicitly ignored packages @@ -251,7 +250,7 @@ def validate_environment_yml(): # pylint: disable=too-many-branches # The only dependency left should be the one for Python itself, which is not part of # the install_requirements for setuptools. - if len(conda_dependencies) > 0: + if conda_dependencies: raise DependencySpecificationError( "The 'environment.yml' file contains dependencies that are missing " "in 'setup.json':\n- {}".format('\n- '.join(map(str, conda_dependencies))) @@ -304,7 +303,7 @@ def validate_pyproject_toml(): "Missing requirement '{}' in 'pyproject.toml'.".format(reentry_requirement) ) - except FileNotFoundError as error: + except FileNotFoundError: raise DependencySpecificationError("The 'pyproject.toml' file is missing!") click.secho('Pyproject.toml dependency specification is consistent.', fg='green') From ba4b2174d1a08348eb914492d4f7c8714dae478a Mon Sep 17 00:00:00 2001 From: Giovanni Pizzi Date: Fri, 3 Apr 2020 13:49:46 +0200 Subject: [PATCH 3/5] Addressed comments by Sebastiaan. In particular, the most important changes include: - now the auto-group flag is called `--auto-group`, is a flag and is False by default - only kept `--include` and `--exclude` options, checking typestrings and allowing to end with `%`. One benefit is that while reimplementing I replaced some `isinstance` with string comparisons, with potential further benefits. - `--include` and `--exclude` are now mutually exclusive - improved documentation of the main point of the issue --- aiida/cmdline/commands/cmd_run.py | 71 +++--- aiida/orm/autogroup.py | 231 ++++++++++-------- aiida/orm/nodes/node.py | 7 +- aiida/tools/ipython/ipython_magics.py | 4 +- docs/source/verdi/verdi_user_guide.rst | 31 +-- .../aiida_sqlalchemy/test_migrations.py | 1 + tests/cmdline/commands/test_run.py | 198 +++++++++++++-- .../engine/processes/workchains/test_utils.py | 4 +- tests/orm/test_autogroups.py | 129 ++++++++++ 9 files changed, 478 insertions(+), 198 deletions(-) create mode 100644 tests/orm/test_autogroups.py diff --git a/aiida/cmdline/commands/cmd_run.py b/aiida/cmdline/commands/cmd_run.py index 6c09ff6cd7..281e3092a1 100644 --- a/aiida/cmdline/commands/cmd_run.py +++ b/aiida/cmdline/commands/cmd_run.py @@ -40,12 +40,12 @@ def update_environment(argv): sys.path = _path -def validate_entrypoint_string_or_all(ctx, param, value, allow_all=True): # pylint: disable=unused-argument,invalid-name - """Validate that `value` is a valid entrypoint string or the string 'all'.""" +def validate_entrypoint_string(ctx, param, value): # pylint: disable=unused-argument,invalid-name + """Validate that `value` is a valid entrypoint string.""" from aiida.orm import autogroup try: - autogroup.Autogroup.validate(value, allow_all=allow_all) + autogroup.Autogroup.validate(value) except Exception as exc: raise click.BadParameter(str(exc) + ' ({})'.format(value)) @@ -55,58 +55,41 @@ def validate_entrypoint_string_or_all(ctx, param, value, allow_all=True): # pyl @verdi.command('run', context_settings=dict(ignore_unknown_options=True,)) @click.argument('scriptname', type=click.STRING) @click.argument('varargs', nargs=-1, type=click.UNPROCESSED) -@click.option('--group/--no-group', default=True, show_default=True, help='Enables the autogrouping') +@click.option('--auto-group', is_flag=True, help='Enables the autogrouping') @click.option( '-l', - '--group-label-prefix', + '--auto-group-label-prefix', type=click.STRING, required=False, help='Specify the prefix of the label of the auto group (numbers might be automatically ' - 'appended to generate unique names per run)' + 'appended to generate unique names per run).' ) @click.option( '-n', '--group-name', type=click.STRING, required=False, - help='Specify the name of the auto group [DEPRECATED, USE --group-label-prefix instead]' + help='Specify the name of the auto group [DEPRECATED, USE --auto-group-label-prefix instead]. ' + 'This also enables auto-grouping.' ) @click.option( '-e', '--exclude', cls=MultipleValueOption, - default=lambda: [], - help='Exclude these classes from auto grouping (use full entrypoint strings)', - callback=functools.partial(validate_entrypoint_string_or_all, allow_all=False) + default=None, + help='Exclude these classes from auto grouping (use full entrypoint strings).', + callback=functools.partial(validate_entrypoint_string) ) @click.option( '-i', '--include', cls=MultipleValueOption, - default=lambda: ['all'], - help='Include these classes from auto grouping (use full entrypoint strings or "all")', - callback=validate_entrypoint_string_or_all -) -@click.option( - '-E', - '--excludesubclasses', - cls=MultipleValueOption, - default=lambda: [], - help='Exclude these classes and their sub classes from auto grouping (use full entrypoint strings)', - callback=functools.partial(validate_entrypoint_string_or_all, allow_all=False) -) -@click.option( - '-I', - '--includesubclasses', - cls=MultipleValueOption, - default=lambda: [], - help='Include these classes and their sub classes from auto grouping (use full entrypoint strings)', - callback=functools.partial(validate_entrypoint_string_or_all, allow_all=False) + default=None, + help='Include these classes from auto grouping (use full entrypoint strings or "all").', + callback=validate_entrypoint_string ) @decorators.with_dbenv() -def run( - scriptname, varargs, group, group_label_prefix, group_name, exclude, excludesubclasses, include, includesubclasses -): +def run(scriptname, varargs, auto_group, auto_group_label_prefix, group_name, exclude, include): # pylint: disable=too-many-arguments,exec-used """Execute scripts with preloaded AiiDA environment.""" from aiida.cmdline.utils.shell import DEFAULT_MODULES_LIST @@ -126,20 +109,22 @@ def run( globals_dict['{}'.format(alias)] = getattr(__import__(app_mod, {}, {}, model_name), model_name) if group_name: - warnings.warn('--group-name is deprecated, use `--group-label` instead', AiidaDeprecationWarning) # pylint: disable=no-member - if group_label: - raise click.BadParameter('You cannot specify both --group-name and --group-label; use --group-label only') - group_label = group_name - - if group: + warnings.warn('--group-name is deprecated, use `--auto-group-label-prefix` instead', AiidaDeprecationWarning) # pylint: disable=no-member + if auto_group_label_prefix: + raise click.BadParameter( + 'You cannot specify both --group-name and --auto-group-label-prefix; use --group-label only' + ) + auto_group_label_prefix = group_name + # To have the old behavior, with auto-group enabled. + auto_group = True + + if auto_group: aiida_verdilib_autogroup = autogroup.Autogroup() - # if group_label_prefix is None, use autogenerated name - if group_label_prefix is not None: - aiida_verdilib_autogroup.set_group_label_prefix(group_label_prefix) + # Set the ``group_label_prefix`` if defined, otherwise a default prefix will be used + if auto_group_label_prefix is not None: + aiida_verdilib_autogroup.set_group_label_prefix(auto_group_label_prefix) aiida_verdilib_autogroup.set_exclude(exclude) aiida_verdilib_autogroup.set_include(include) - aiida_verdilib_autogroup.set_exclude_with_subclasses(excludesubclasses) - aiida_verdilib_autogroup.set_include_with_subclasses(includesubclasses) # Note: this is also set in the exec environment! This is the intended behavior autogroup.CURRENT_AUTOGROUP = aiida_verdilib_autogroup diff --git a/aiida/orm/autogroup.py b/aiida/orm/autogroup.py index cbba2f1c4e..eee4239b28 100644 --- a/aiida/orm/autogroup.py +++ b/aiida/orm/autogroup.py @@ -14,7 +14,7 @@ from aiida.common.escaping import escape_for_sql_like from aiida.common.warnings import AiidaDeprecationWarning from aiida.orm import GroupTypeString, Group -from aiida.plugins import load_entry_point_from_string +from aiida.plugins.entry_point import get_entry_point_string_from_class CURRENT_AUTOGROUP = None @@ -29,17 +29,21 @@ class Autogroup: If so, it will call Autogroup.is_to_be_grouped, and decide whether to put it in a group. Such autogroups are going to be of the VERDIAUTOGROUP_TYPE. - The exclude/include lists, can have values 'all' if you want to include/exclude all classes. - Otherwise, they are lists of strings like: calculation.quantumespresso.pw, data.array.kpoints, ... - i.e.: a string identifying the base class, than the path to the class as in Calculation/Data -Factories + The exclude/include lists are lists of strings like: + ``aiida.data:int``, ``aiida.calculation:quantumespresso.pw``, + ``aiida.data:array.%``, ... + i.e.: a string identifying the base class, followed a colona and by the path to the class + as accepted by CalculationFactory/DataFactory. + Each string contain the wildcard ``%`` at the end; + in this case this is used in a ``like`` comparison with the QueryBuilder. + + Only one of the two (between exclude and include) can be set. If none of the two is set, everything is included. """ def __init__(self): """Initialize with defaults.""" - self._exclude = [] - self._exclude_with_subclasses = [] - self._include = ['all'] - self._include_with_subclasses = [] + self._exclude = None + self._include = None now = timezone.now() default_label_prefix = 'Verdi autogroup on ' + now.strftime('%Y-%m-%d %H:%M:%S') @@ -47,36 +51,43 @@ def __init__(self): self._group_label = None # Actual group label, set by `get_or_create_group` @staticmethod - def validate(param, allow_all=True): - """ - Used internally to verify the sanity of exclude, include lists - - :param param: should be a list of valid entrypoint strings - """ - for string in param: - if allow_all and string == 'all': - continue - load_entry_point_from_string(string) # This will raise a MissingEntryPointError if invalid + def validate(strings): + """Validate the list of strings passed to set_include and set_exclude.""" + if strings is None: + return + valid_prefixes = set(['aiida.node', 'aiida.calculations', 'aiida.workflows', 'aiida.data']) + for string in strings: + pieces = string.split(':') + if len(pieces) != 2: + raise exceptions.ValidationError( + "'{}' is not a valid include/exclude filter, must contain two parts split by a colon". + format(string) + ) + if pieces[0] not in valid_prefixes: + raise exceptions.ValidationError( + "'{}' has an invalid prefix, must be among: {}".format(string, sorted(valid_prefixes)) + ) + + # If a % is present, it can only be the last character + string_without_percent = pieces[1] + if string_without_percent.endswith('%'): + string_without_percent = string_without_percent[:-1] + if '%' in string_without_percent: + raise exceptions.ValidationError( + "'{}' can only contain a '%' character, if any, at the end of the string".format(string) + ) def get_exclude(self): - """Return the list of classes to exclude from autogrouping.""" - return self._exclude + """Return the list of classes to exclude from autogrouping. - def get_exclude_with_subclasses(self): - """ - Return the list of classes to exclude from autogrouping. - Will also exclude their derived subclasses - """ - return self._exclude_with_subclasses + Returns ``None`` if no exclusion list has been set.""" + return self._exclude def get_include(self): - """Return the list of classes to include in the autogrouping.""" - return self._include - - def get_include_with_subclasses(self): """Return the list of classes to include in the autogrouping. - Will also include their derived subclasses.""" - return self._include_with_subclasses + + Returns ``None`` if no exclusion list has been set.""" + return self._include def get_group_label_prefix(self): """Get the prefix of the label of the group. @@ -87,7 +98,7 @@ def get_group_name(self): """Get the label of the group. If no group label was set, it will set a default one by itself. - .. deprecated:: 1.1.0 + .. deprecated:: 1.2.0 Will be removed in `v2.0.0`, use :py:meth:`.get_group_label_prefix` instead. """ warnings.warn('function is deprecated, use `get_group_label_prefix` instead', AiidaDeprecationWarning) # pylint: disable=no-member @@ -98,45 +109,28 @@ def set_exclude(self, exclude): :param exclude: a list of valid entry point strings (one of which could be the string 'all') """ + if isinstance(exclude, str): + exclude = [exclude] self.validate(exclude) - if 'all' in self.get_include(): - if 'all' in exclude: - raise exceptions.ValidationError('Cannot exclude and include all classes') + if exclude is not None and self.get_include() is not None: + # It's ok to set None, both as a default, or to 'undo' the exclude list + raise exceptions.ValidationError('Cannot both specify exclude and include') self._exclude = exclude - def set_exclude_with_subclasses(self, exclude): - """ - Set the list of classes to exclude from autogrouping. - Will also exclude their derived subclasses - - :param exclude: a list of valid entry point strings (one of which could be the string 'all') - """ - self.validate(exclude) - self._exclude_with_subclasses = exclude - def set_include(self, include): """ Set the list of classes to include in the autogrouping. :param include: a list of valid entry point strings (one of which could be the string 'all') """ + if isinstance(include, str): + include = [include] self.validate(include) - if 'all' in self.get_exclude(): - if 'all' in include: - raise exceptions.ValidationError('Cannot exclude and include all classes') - + if include is not None and self.get_exclude() is not None: + # It's ok to set None, both as a default, or to 'undo' the include list + raise exceptions.ValidationError('Cannot both specify exclude and include') self._include = include - def set_include_with_subclasses(self, include): - """ - Set the list of classes to include in the autogrouping. - Will also include their derived subclasses. - - :param include: a list of valid entry point strings (one of which could be the string 'all') - """ - self.validate(include) - self._include_with_subclasses = include - def set_group_label_prefix(self, label_prefix): """ Set the label of the group to be created @@ -148,54 +142,61 @@ def set_group_label_prefix(self, label_prefix): def set_group_name(self, gname): """Set the name of the group. - .. deprecated:: 1.1.0 + .. deprecated:: 1.2.0 Will be removed in `v2.0.0`, use :py:meth:`.set_group_label_prefix` instead. """ warnings.warn('function is deprecated, use `set_group_label_prefix` instead', AiidaDeprecationWarning) # pylint: disable=no-member return self.set_group_label_prefix(label_prefix=gname) - def is_to_be_grouped(self, the_class): + @staticmethod + def _matches(string, filter_string): + """Check if 'string' matches the 'filter_string' (used for include and exclude filters). + + If 'filter_string' does not end with a % sign, perform an exact match. + Otherwise, strip the '%' sign and match with string.startswith(filter_string[:-1]). + + :param string: the string to match. + :param filter_string: the filter string. """ - Return whether the given class has to be included in the autogroup according to include/exclude list + if filter_string.endswith('%'): + return string.startswith(filter_string[:-1]) + return string == filter_string - :return (bool): True if the_class is to be included in the autogroup + def is_to_be_grouped(self, node): """ - # strings, including possibly 'all' - include_exact = self.get_include() - include_with_subclasses = self.get_include_with_subclasses() + Return whether the given node has to be included in the autogroup according to include/exclude list - # actual classes, with 'all' stripped out - include_exact_classes = tuple( - load_entry_point_from_string(ep_string) for ep_string in include_exact if ep_string != 'all' - ) - include_with_subclasses_classes = tuple( - load_entry_point_from_string(ep_string) for ep_string in include_with_subclasses if ep_string != 'all' - ) + :return (bool): True if ``node`` is to be included in the autogroup + """ + # I import here to avoid circular imports + from aiida.orm.nodes.process import ProcessNode - if ( - 'all' in include_exact or the_class in include_exact_classes or - issubclass(the_class, include_with_subclasses_classes) - ): - # According to the include, this class should be included - # strings, including possibly 'all' - exclude_exact = self.get_exclude() - exclude_with_subclasses = self.get_exclude_with_subclasses() - - # actual classes, with 'all' stripped out - exclude_exact_classes = tuple( - load_entry_point_from_string(ep_string) for ep_string in exclude_exact if ep_string != 'all' - ) - exclude_with_subclasses_classes = tuple( - load_entry_point_from_string(ep_string) for ep_string in exclude_with_subclasses if ep_string != 'all' - ) - - if (the_class not in exclude_exact_classes and not issubclass(the_class, exclude_with_subclasses_classes)): - # If we are here, it's not excluded - return True - # If we're here, it's both in the include and in the exclude - exclude it - return False - # If we're here, the class is not in the include - return False - return False + # strings, including possibly 'all' + include = self.get_include() + exclude = self.get_exclude() + if include is None and exclude is None: + # Include all classes by default if nothing is explicitly specified. + return True + if include is not None and exclude is not None: + # We should never be here, anyway - this should be catched by the `set_include/exclude` methods + raise ValueError("You cannot specify both an 'include' and an 'exclude' list") + + the_class = node.__class__ + if issubclass(the_class, ProcessNode): + try: + the_class = node.process_class + except ValueError: + # It does not have a process class - we just check the node class then, it could be e.g. + # a bare CalculationNode. + pass + class_entry_point_string = get_entry_point_string_from_class(the_class.__module__, the_class.__name__) + if include is not None: + # As soon as a filter string matches, we include the class + return any(self._matches(class_entry_point_string, filter_string) for filter_string in include) + # If we are here, exclude is not None + # include *only* in *none* of the filters match (that is, exclude as + # soon as any of the filters matches) + return not any(self._matches(class_entry_point_string, filter_string) for filter_string in exclude) def clear_group_cache(self): """Clear the cache of the group name. @@ -205,8 +206,26 @@ def clear_group_cache(self): self._group_label = None def get_or_create_group(self): - """Return the current Autogroup, or create one if None has been set yet.""" + """Return the current Autogroup, or create one if None has been set yet. + + This function implements a somewhat complex logic that is however needed + to make sure that, even if `verdi run` is called at the same time multiple + times, e.g. in a for loop in bash, there is never the risk that two ``verdi run`` + Unix processes try to create the same group, with the same label, ending + up in a crash of the code (see PR #3650). + + Here, instead, we make sure that if this concurrency issue happens, + one of the two will get a IntegrityError from the DB, and then recover + trying to create a group with a different label (with a numeric suffix appended), + until it manages to create it. + """ from aiida.orm import QueryBuilder + + # When this function is called, if it is the first time, just generate + # a new group name (later on, after this ``if`` block`). + # In that case, we will later cache in ``self._group_label`` the group label, + # So the group with the same name can be returned quickly in future + # calls of this method. if self._group_label is not None: results = [ res[0] for res in QueryBuilder(). @@ -218,6 +237,7 @@ def get_or_create_group(self): if results: # If it is not empty, it should have only one result due to the # uniqueness constraints + assert len(results) == 1, 'I got more than one autogroup with the same label!' return results[0] # There are no results: probably the group has been deleted. # I continue as if it was not cached @@ -247,13 +267,12 @@ def get_or_create_group(self): if label == '': # This is just the prefix without name - corresponds to counter = 0 existing_group_ints.append(0) - else: - if label.startswith('_'): - try: - existing_group_ints.append(int(label[1:])) - except ValueError: - # It's not an integer, so it will never collide - just ignore it - pass + elif label.startswith('_'): + try: + existing_group_ints.append(int(label[1:])) + except ValueError: + # It's not an integer, so it will never collide - just ignore it + pass if not existing_group_ints: counter = 0 diff --git a/aiida/orm/nodes/node.py b/aiida/orm/nodes/node.py index dedbddcfc5..c5c7f85ce5 100644 --- a/aiida/orm/nodes/node.py +++ b/aiida/orm/nodes/node.py @@ -1025,10 +1025,9 @@ def store(self, with_transaction=True, use_cache=None): # pylint: disable=argum self._store(with_transaction=with_transaction, clean=True) # Set up autogrouping used by verdi run - if autogroup.CURRENT_AUTOGROUP is not None: - if autogroup.CURRENT_AUTOGROUP.is_to_be_grouped(self.__class__): - group = autogroup.CURRENT_AUTOGROUP.get_or_create_group() - group.add_nodes(self) + if autogroup.CURRENT_AUTOGROUP is not None and autogroup.CURRENT_AUTOGROUP.is_to_be_grouped(self): + group = autogroup.CURRENT_AUTOGROUP.get_or_create_group() + group.add_nodes(self) return self diff --git a/aiida/tools/ipython/ipython_magics.py b/aiida/tools/ipython/ipython_magics.py index af3d8cb395..66310c37b9 100644 --- a/aiida/tools/ipython/ipython_magics.py +++ b/aiida/tools/ipython/ipython_magics.py @@ -34,8 +34,8 @@ In [2]: %aiida """ -from IPython import version_info -from IPython.core import magic +from IPython import version_info # pylint: disable=no-name-in-module +from IPython.core import magic # pylint: disable=no-name-in-module,import-error from aiida.common import json diff --git a/docs/source/verdi/verdi_user_guide.rst b/docs/source/verdi/verdi_user_guide.rst index ee43c3999a..3c24ffeb1b 100644 --- a/docs/source/verdi/verdi_user_guide.rst +++ b/docs/source/verdi/verdi_user_guide.rst @@ -702,24 +702,19 @@ Below is a list with all available subcommands. Execute scripts with preloaded AiiDA environment. Options: - --group / --no-group Enables the autogrouping [default: True] - -l, --group-label-prefix TEXT Specify the prefix of the label of the auto - group (numbers might be automatically - appended to generate unique names per run) - -n, --group-name TEXT Specify the name of the auto group - [DEPRECATED, USE --group-label-prefix - instead] - -e, --exclude TEXT Exclude these classes from auto grouping (use - full entrypoint strings) - -i, --include TEXT Include these classes from auto grouping - (use full entrypoint strings or "all") - -E, --excludesubclasses TEXT Exclude these classes and their sub classes - from auto grouping (use full entrypoint - strings) - -I, --includesubclasses TEXT Include these classes and their sub classes - from auto grouping (use full entrypoint - strings) - --help Show this message and exit. + --auto-group Enables the autogrouping + -l, --auto-group-label-prefix TEXT + Specify the prefix of the label of the auto + group (numbers might be automatically + appended to generate unique names per run). + -n, --group-name TEXT Specify the name of the auto group + [DEPRECATED, USE --auto-group-label-prefix + instead]. This also enables auto-grouping. + -e, --exclude TEXT Exclude these classes from auto grouping + (use full entrypoint strings). + -i, --include TEXT Include these classes from auto grouping + (use full entrypoint strings or "all"). + --help Show this message and exit. .. _verdi_setup: diff --git a/tests/backends/aiida_sqlalchemy/test_migrations.py b/tests/backends/aiida_sqlalchemy/test_migrations.py index 2147414f0d..8e2046f293 100644 --- a/tests/backends/aiida_sqlalchemy/test_migrations.py +++ b/tests/backends/aiida_sqlalchemy/test_migrations.py @@ -68,6 +68,7 @@ def setUp(self): except Exception: # Bring back the DB to the correct state if this setup part fails self._reset_database_and_schema() + autogroup.CURRENT_AUTOGROUP = self.current_autogroup raise def _perform_actual_migration(self): diff --git a/tests/cmdline/commands/test_run.py b/tests/cmdline/commands/test_run.py index 8df034f8ef..ae82826e52 100644 --- a/tests/cmdline/commands/test_run.py +++ b/tests/cmdline/commands/test_run.py @@ -9,6 +9,7 @@ ########################################################################### """Tests for `verdi run`.""" import tempfile +import warnings from click.testing import CliRunner @@ -102,7 +103,7 @@ def test_autogroup(self): fhandle.write(script_content) fhandle.flush() - options = [fhandle.name] + options = ['--auto-group', fhandle.name] result = self.cli_runner.invoke(cmd_run.run, options) self.assertClickResultNoException(result) @@ -129,7 +130,7 @@ def test_autogroup_custom_label(self): fhandle.write(script_content) fhandle.flush() - options = [fhandle.name, '--group-label-prefix', autogroup_label] + options = [fhandle.name, '--auto-group', '--auto-group-label-prefix', autogroup_label] result = self.cli_runner.invoke(cmd_run.run, options) self.assertClickResultNoException(result) @@ -157,7 +158,7 @@ def test_no_autogroup(self): fhandle.write(script_content) fhandle.flush() - options = [fhandle.name, '--no-group'] + options = [fhandle.name] # Not storing an autogroup by default result = self.cli_runner.invoke(cmd_run.run, options) self.assertClickResultNoException(result) @@ -173,46 +174,144 @@ def test_autogroup_filter_class(self): # pylint: disable=too-many-locals """Check if the autogroup is properly generated but filtered classes are skipped.""" from aiida.orm import QueryBuilder, Node, Group, load_node - script_content = """from aiida.orm import Data -node1 = Data().store() -node2 = Int(3).store() + script_content = """import sys +from aiida.orm import Computer, Int, ArrayData, KpointsData, CalculationNode, WorkflowNode +from aiida.plugins import CalculationFactory +from aiida.engine import run_get_node +ArithmeticAdd = CalculationFactory('arithmetic.add') + +computer = Computer( + name='localhost-example-{}'.format(sys.argv[1]), + hostname='localhost', + description='my computer', + transport_type='local', + scheduler_type='direct', + workdir='/tmp' +).store() +computer.configure() + +code = Code( + input_plugin_name='arithmetic.add', + remote_computer_exec=[computer, '/bin/true']).store() +inputs = { + 'x': Int(1), + 'y': Int(2), + 'code': code, + 'metadata': { + 'options': { + 'resources': { + 'num_machines': 1, + 'num_mpiprocs_per_machine': 1 + } + } + } +} + +node1 = KpointsData().store() +node2 = ArrayData().store() +node3 = Int(3).store() +node4 = CalculationNode().store() +node5 = WorkflowNode().store() +_, node6 = run_get_node(ArithmeticAdd, **inputs) print(node1.pk) print(node2.pk) +print(node3.pk) +print(node4.pk) +print(node5.pk) +print(node6.pk) """ - - for flags, data_in_autogroup, int_in_autogroup in [ - [['--exclude', 'aiida.node:data'], False, True], - [['--exclude', 'aiida.data:int'], True, False], - [['--excludesubclasses', 'aiida.node:data'], False, False], - [['--excludesubclasses', 'aiida.data:int'], True, False], - [['--excludesubclasses', 'aiida.node:data', 'aiida.data:int'], False, False], - [['--include', 'aiida.node:process'], False, False], - [['--exclude', 'aiida.node:data', 'aiida.data:int'], False, False], - ]: + from aiida.orm import Code + Code() + for idx, ( + flags, + kptdata_in_autogroup, + arraydata_in_autogroup, + int_in_autogroup, + calc_in_autogroup, + wf_in_autogroup, + calcarithmetic_in_autogroup, + ) in enumerate([ + [['--exclude', 'aiida.data:array.kpoints'], False, True, True, True, True, True], + [['--exclude', 'aiida.data:int'], True, True, False, True, True, True], + [['--exclude', 'aiida.data:%'], False, False, False, True, True, True], + [['--exclude', 'aiida.data:array', 'aiida.data:array.%'], False, False, True, True, True, True], + [['--exclude', 'aiida.data:array', 'aiida.data:array.%', 'aiida.data:int'], False, False, False, True, True, + True], + [['--exclude', 'aiida.calculations:arithmetic.add'], True, True, True, True, True, False], + [ + ['--include', 'aiida.node:process.calculation'], # Base type, no specific plugin + False, + False, + False, + True, + False, + False + ], + [ + ['--include', 'aiida.node:process.workflow'], # Base type, no specific plugin + False, + False, + False, + False, + True, + False + ], + [[], True, True, True, True, True, True], + ]): with tempfile.NamedTemporaryFile(mode='w+') as fhandle: fhandle.write(script_content) fhandle.flush() - options = [fhandle.name] + flags + ['--'] + options = ['--auto-group'] + flags + ['--', fhandle.name, str(idx)] result = self.cli_runner.invoke(cmd_run.run, options) self.assertClickResultNoException(result) - pk1_str, pk2_str = result.output.split() + pk1_str, pk2_str, pk3_str, pk4_str, pk5_str, pk6_str = result.output.split() pk1 = int(pk1_str) pk2 = int(pk2_str) + pk3 = int(pk3_str) + pk4 = int(pk4_str) + pk5 = int(pk5_str) + pk6 = int(pk6_str) _ = load_node(pk1) # Check if the node can be loaded _ = load_node(pk2) # Check if the node can be loaded + _ = load_node(pk3) # Check if the node can be loaded + _ = load_node(pk4) # Check if the node can be loaded + _ = load_node(pk5) # Check if the node can be loaded + _ = load_node(pk6) # Check if the node can be loaded queryb = QueryBuilder().append(Node, filters={'id': pk1}, tag='node') queryb.append(Group, with_node='node', filters={'type_string': 'auto.run'}, project='*') - all_auto_groups_data = queryb.all() + all_auto_groups_kptdata = queryb.all() queryb = QueryBuilder().append(Node, filters={'id': pk2}, tag='node') queryb.append(Group, with_node='node', filters={'type_string': 'auto.run'}, project='*') + all_auto_groups_arraydata = queryb.all() + + queryb = QueryBuilder().append(Node, filters={'id': pk3}, tag='node') + queryb.append(Group, with_node='node', filters={'type_string': 'auto.run'}, project='*') all_auto_groups_int = queryb.all() + + queryb = QueryBuilder().append(Node, filters={'id': pk4}, tag='node') + queryb.append(Group, with_node='node', filters={'type_string': 'auto.run'}, project='*') + all_auto_groups_calc = queryb.all() + + queryb = QueryBuilder().append(Node, filters={'id': pk5}, tag='node') + queryb.append(Group, with_node='node', filters={'type_string': 'auto.run'}, project='*') + all_auto_groups_wf = queryb.all() + + queryb = QueryBuilder().append(Node, filters={'id': pk6}, tag='node') + queryb.append(Group, with_node='node', filters={'type_string': 'auto.run'}, project='*') + all_auto_groups_calcarithmetic = queryb.all() + + self.assertEqual( + len(all_auto_groups_kptdata), 1 if kptdata_in_autogroup else 0, + 'Wrong number of nodes in autogroup associated with the KpointsData node ' + "just created with flags '{}'".format(' '.join(flags)) + ) self.assertEqual( - len(all_auto_groups_data), 1 if data_in_autogroup else 0, - 'Wrong number of nodes in autogroup associated with the Data node ' + len(all_auto_groups_arraydata), 1 if arraydata_in_autogroup else 0, + 'Wrong number of nodes in autogroup associated with the ArrayData node ' "just created with flags '{}'".format(' '.join(flags)) ) self.assertEqual( @@ -220,6 +319,21 @@ def test_autogroup_filter_class(self): # pylint: disable=too-many-locals 'Wrong number of nodes in autogroup associated with the Int node ' "just created with flags '{}'".format(' '.join(flags)) ) + self.assertEqual( + len(all_auto_groups_calc), 1 if calc_in_autogroup else 0, + 'Wrong number of nodes in autogroup associated with the CalculationNode ' + "just created with flags '{}'".format(' '.join(flags)) + ) + self.assertEqual( + len(all_auto_groups_wf), 1 if wf_in_autogroup else 0, + 'Wrong number of nodes in autogroup associated with the WorkflowNode ' + "just created with flags '{}'".format(' '.join(flags)) + ) + self.assertEqual( + len(all_auto_groups_calcarithmetic), 1 if calcarithmetic_in_autogroup else 0, + 'Wrong number of nodes in autogroup associated with the ArithmeticAdd CalcJobNode ' + "just created with flags '{}'".format(' '.join(flags)) + ) def test_autogroup_clashing_label(self): """Check if the autogroup label is properly (re)generated when it clashes with an existing group name.""" @@ -235,7 +349,7 @@ def test_autogroup_clashing_label(self): fhandle.flush() # First run - options = [fhandle.name, '--group-label-prefix', autogroup_label] + options = [fhandle.name, '--auto-group', '--auto-group-label-prefix', autogroup_label] result = self.cli_runner.invoke(cmd_run.run, options) self.assertClickResultNoException(result) @@ -251,7 +365,7 @@ def test_autogroup_clashing_label(self): # A few more runs with the same label - it should not crash but append something to the group name for _ in range(10): - options = [fhandle.name, '--group-label-prefix', autogroup_label] + options = [fhandle.name, '--auto-group', '--auto-group-label-prefix', autogroup_label] result = self.cli_runner.invoke(cmd_run.run, options) self.assertClickResultNoException(result) @@ -264,3 +378,41 @@ def test_autogroup_clashing_label(self): len(all_auto_groups), 1, 'There should be only one autogroup associated with the node just created' ) self.assertTrue(all_auto_groups[0][0].label.startswith(autogroup_label)) + + def test_legacy_autogroup_name(self): + """Check if the autogroup is properly generated when using the legacy --group-name flag.""" + from aiida.orm import QueryBuilder, Node, Group, load_node + + script_content = """from aiida.orm import Data +node = Data().store() +print(node.pk) +""" + group_label = 'legacy-group-name' + + with tempfile.NamedTemporaryFile(mode='w+') as fhandle: + fhandle.write(script_content) + fhandle.flush() + + options = ['--group-name', group_label, fhandle.name] + with warnings.catch_warnings(record=True) as warns: # pylint: disable=no-member + result = self.cli_runner.invoke(cmd_run.run, options) + self.assertTrue( + any(['use `--auto-group-label-prefix` instead' in str(warn.message) for warn in warns]), + "No warning for '--group-name' was raised" + ) + + self.assertClickResultNoException(result) + + pk = int(result.output) + _ = load_node(pk) # Check if the node can be loaded + + queryb = QueryBuilder().append(Node, filters={'id': pk}, tag='node') + queryb.append(Group, with_node='node', filters={'type_string': 'auto.run'}, project='*') + all_auto_groups = queryb.all() + self.assertEqual( + len(all_auto_groups), 1, 'There should be only one autogroup associated with the node just created' + ) + self.assertEqual( + all_auto_groups[0][0].label, group_label, + 'The auto group label is "{}" instead of "{}"'.format(all_auto_groups[0][0].label, group_label) + ) diff --git a/tests/engine/processes/workchains/test_utils.py b/tests/engine/processes/workchains/test_utils.py index 5591229d5b..51a787235e 100644 --- a/tests/engine/processes/workchains/test_utils.py +++ b/tests/engine/processes/workchains/test_utils.py @@ -53,7 +53,7 @@ def test_priority(self): attribute_key = 'handlers_called' class ArithmeticAddBaseWorkChain(BaseRestartWorkChain): - """Implementation of a possible BaseRestartWorkChain for the ArithmeticAdd calculation.""" + """Implementation of a possible BaseRestartWorkChain for the ``ArithmeticAddCalculation``.""" _process_class = ArithmeticAddCalculation @@ -164,7 +164,7 @@ def test_exit_codes_filter(self): node_skip.set_exit_status(200) # Some other exit status class ArithmeticAddBaseWorkChain(BaseRestartWorkChain): - """Minimal base restart workchain ofr the ArithmeticAdd calculation""" + """Minimal base restart workchain for the ``ArithmeticAddCalculation``.""" _process_class = ArithmeticAddCalculation diff --git a/tests/orm/test_autogroups.py b/tests/orm/test_autogroups.py new file mode 100644 index 0000000000..e1426ad2e8 --- /dev/null +++ b/tests/orm/test_autogroups.py @@ -0,0 +1,129 @@ +# -*- coding: utf-8 -*- +########################################################################### +# Copyright (c), The AiiDA team. All rights reserved. # +# This file is part of the AiiDA code. # +# # +# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # +# For further information on the license, see the LICENSE.txt file # +# For further information please visit http://www.aiida.net # +########################################################################### +"""Tests for the Autogroup functionality.""" +from aiida.backends.testbase import AiidaTestCase +from aiida.orm import Group, QueryBuilder +from aiida.orm.autogroup import Autogroup + + +class TestAutogroup(AiidaTestCase): + """Tests the Autogroup logic.""" + + def test_get_or_create(self): + """Test the ``get_or_create_group`` method of ``Autogroup``.""" + label_prefix = 'test_prefix_TestAutogroup' + + # Check that there are no groups to begin with + queryb = QueryBuilder().append(Group, filters={'type_string': 'auto.run', 'label': label_prefix}, project='*') + assert not list(queryb.all()) + queryb = QueryBuilder().append( + Group, filters={ + 'type_string': 'auto.run', + 'label': { + 'like': r'{}\_%'.format(label_prefix) + } + }, project='*' + ) + assert not list(queryb.all()) + + # First group (no existing one) + autogroup = Autogroup() + autogroup.set_group_label_prefix(label_prefix) + group = autogroup.get_or_create_group() + expected_label = label_prefix + self.assertEqual( + group.label, expected_label, + "The auto-group should be labelled '{}', it is instead '{}'".format(expected_label, group.label) + ) + + # Second group (only one with no suffix existing) + autogroup = Autogroup() + autogroup.set_group_label_prefix(label_prefix) + group = autogroup.get_or_create_group() + expected_label = label_prefix + '_1' + self.assertEqual( + group.label, expected_label, + "The auto-group should be labelled '{}', it is instead '{}'".format(expected_label, group.label) + ) + + # Second group (only one suffix _1 existing) + autogroup = Autogroup() + autogroup.set_group_label_prefix(label_prefix) + group = autogroup.get_or_create_group() + expected_label = label_prefix + '_2' + self.assertEqual( + group.label, expected_label, + "The auto-group should be labelled '{}', it is instead '{}'".format(expected_label, group.label) + ) + + # I create a group with a large integer suffix (9) + Group(label='{}_9'.format(label_prefix), type_string='auto.run').store() + # The next autogroup should become number 10 + autogroup = Autogroup() + autogroup.set_group_label_prefix(label_prefix) + group = autogroup.get_or_create_group() + expected_label = label_prefix + '_10' + self.assertEqual( + group.label, expected_label, + "The auto-group should be labelled '{}', it is instead '{}'".format(expected_label, group.label) + ) + + # I create a group with a non-integer suffix (15a), it should be ignored + Group(label='{}_15b'.format(label_prefix), type_string='auto.run').store() + # The next autogroup should become number 11 + autogroup = Autogroup() + autogroup.set_group_label_prefix(label_prefix) + group = autogroup.get_or_create_group() + expected_label = label_prefix + '_11' + self.assertEqual( + group.label, expected_label, + "The auto-group should be labelled '{}', it is instead '{}'".format(expected_label, group.label) + ) + + def test_get_or_create_invalid_prefix(self): + """Test the ``get_or_create_group`` method of ``Autogroup`` when there is already a group + with the same prefix, but followed by other non-underscore characters.""" + label_prefix = 'new_test_prefix_TestAutogroup' + # I create a group with the same prefix, but followed by non-underscore + # characters. These should be ignored in the logic. + Group(label='{}xx'.format(label_prefix), type_string='auto.run').store() + + # Check that there are no groups to begin with + queryb = QueryBuilder().append(Group, filters={'type_string': 'auto.run', 'label': label_prefix}, project='*') + assert not list(queryb.all()) + queryb = QueryBuilder().append( + Group, filters={ + 'type_string': 'auto.run', + 'label': { + 'like': r'{}\_%'.format(label_prefix) + } + }, project='*' + ) + assert not list(queryb.all()) + + # First group (no existing one) + autogroup = Autogroup() + autogroup.set_group_label_prefix(label_prefix) + group = autogroup.get_or_create_group() + expected_label = label_prefix + self.assertEqual( + group.label, expected_label, + "The auto-group should be labelled '{}', it is instead '{}'".format(expected_label, group.label) + ) + + # Second group (only one with no suffix existing) + autogroup = Autogroup() + autogroup.set_group_label_prefix(label_prefix) + group = autogroup.get_or_create_group() + expected_label = label_prefix + '_1' + self.assertEqual( + group.label, expected_label, + "The auto-group should be labelled '{}', it is instead '{}'".format(expected_label, group.label) + ) From 31d0470543c2213bd0ca0dbd99cf92b49e09c100 Mon Sep 17 00:00:00 2001 From: Giovanni Pizzi Date: Fri, 3 Apr 2020 17:09:45 +0200 Subject: [PATCH 4/5] A few further comments addressed. Most notable: now the % sign can be anywhere in the string of included or excluded classes, and not only at the end. --- aiida/cmdline/commands/cmd_run.py | 3 +- aiida/orm/autogroup.py | 70 +++++++++++++----------------- tests/cmdline/commands/test_run.py | 2 + 3 files changed, 35 insertions(+), 40 deletions(-) diff --git a/aiida/cmdline/commands/cmd_run.py b/aiida/cmdline/commands/cmd_run.py index 281e3092a1..bd3972b841 100644 --- a/aiida/cmdline/commands/cmd_run.py +++ b/aiida/cmdline/commands/cmd_run.py @@ -112,7 +112,8 @@ def run(scriptname, varargs, auto_group, auto_group_label_prefix, group_name, ex warnings.warn('--group-name is deprecated, use `--auto-group-label-prefix` instead', AiidaDeprecationWarning) # pylint: disable=no-member if auto_group_label_prefix: raise click.BadParameter( - 'You cannot specify both --group-name and --auto-group-label-prefix; use --group-label only' + 'You cannot specify both --group-name and --auto-group-label-prefix; ' + 'use --auto-group-label-prefix only' ) auto_group_label_prefix = group_name # To have the old behavior, with auto-group enabled. diff --git a/aiida/orm/autogroup.py b/aiida/orm/autogroup.py index eee4239b28..16bf03f1c1 100644 --- a/aiida/orm/autogroup.py +++ b/aiida/orm/autogroup.py @@ -8,10 +8,11 @@ # For further information please visit http://www.aiida.net # ########################################################################### """Module to manage the autogrouping functionality by ``verdi run``.""" +import re import warnings from aiida.common import exceptions, timezone -from aiida.common.escaping import escape_for_sql_like +from aiida.common.escaping import escape_for_sql_like, get_regex_pattern_from_sql from aiida.common.warnings import AiidaDeprecationWarning from aiida.orm import GroupTypeString, Group from aiida.plugins.entry_point import get_entry_point_string_from_class @@ -34,10 +35,13 @@ class Autogroup: ``aiida.data:array.%``, ... i.e.: a string identifying the base class, followed a colona and by the path to the class as accepted by CalculationFactory/DataFactory. - Each string contain the wildcard ``%`` at the end; + Each string can contain one or more wildcard characters ``%``; in this case this is used in a ``like`` comparison with the QueryBuilder. + Note that in this case you have to remember that ``_`` means "any character" + in the QueryBuilder, and you need to escape it if you mean a literal underscore. - Only one of the two (between exclude and include) can be set. If none of the two is set, everything is included. + Only one of the two (between exclude and include) can be set. + If none of the two is set, everything is included. """ def __init__(self): @@ -68,15 +72,6 @@ def validate(strings): "'{}' has an invalid prefix, must be among: {}".format(string, sorted(valid_prefixes)) ) - # If a % is present, it can only be the last character - string_without_percent = pieces[1] - if string_without_percent.endswith('%'): - string_without_percent = string_without_percent[:-1] - if '%' in string_without_percent: - raise exceptions.ValidationError( - "'{}' can only contain a '%' character, if any, at the end of the string".format(string) - ) - def get_exclude(self): """Return the list of classes to exclude from autogrouping. @@ -86,7 +81,7 @@ def get_exclude(self): def get_include(self): """Return the list of classes to include in the autogrouping. - Returns ``None`` if no exclusion list has been set.""" + Returns ``None`` if no inclusion list has been set.""" return self._include def get_group_label_prefix(self): @@ -105,9 +100,11 @@ def get_group_name(self): return self.get_group_label_prefix() def set_exclude(self, exclude): - """Return the list of classes to exclude from autogrouping. + """Set the list of classes to exclude in the autogrouping. - :param exclude: a list of valid entry point strings (one of which could be the string 'all') + :param exclude: a list of valid entry point strings (might contain '%' to be used as + string to be matched using SQL's ``LIKE`` pattern-making logic), or ``None`` + to specify no include list. """ if isinstance(exclude, str): exclude = [exclude] @@ -118,10 +115,11 @@ def set_exclude(self, exclude): self._exclude = exclude def set_include(self, include): - """ - Set the list of classes to include in the autogrouping. + """Set the list of classes to include in the autogrouping. - :param include: a list of valid entry point strings (one of which could be the string 'all') + :param include: a list of valid entry point strings (might contain '%' to be used as + string to be matched using SQL's ``LIKE`` pattern-making logic), or ``None`` + to specify no include list. """ if isinstance(include, str): include = [include] @@ -152,14 +150,16 @@ def set_group_name(self, gname): def _matches(string, filter_string): """Check if 'string' matches the 'filter_string' (used for include and exclude filters). - If 'filter_string' does not end with a % sign, perform an exact match. - Otherwise, strip the '%' sign and match with string.startswith(filter_string[:-1]). + If 'filter_string' does not contain any % sign, perform an exact match. + Otherwise, match with a SQL-like query, where % means any character sequence, + and _ means a single character (these caracters can be escaped with a backslash). :param string: the string to match. :param filter_string: the filter string. """ - if filter_string.endswith('%'): - return string.startswith(filter_string[:-1]) + if '%' in filter_string: + regex_filter = get_regex_pattern_from_sql(filter_string) + return re.match(regex_filter, string) is not None return string == filter_string def is_to_be_grouped(self, node): @@ -168,35 +168,27 @@ def is_to_be_grouped(self, node): :return (bool): True if ``node`` is to be included in the autogroup """ - # I import here to avoid circular imports - from aiida.orm.nodes.process import ProcessNode - # strings, including possibly 'all' include = self.get_include() exclude = self.get_exclude() if include is None and exclude is None: # Include all classes by default if nothing is explicitly specified. return True - if include is not None and exclude is not None: - # We should never be here, anyway - this should be catched by the `set_include/exclude` methods - raise ValueError("You cannot specify both an 'include' and an 'exclude' list") - the_class = node.__class__ - if issubclass(the_class, ProcessNode): - try: - the_class = node.process_class - except ValueError: - # It does not have a process class - we just check the node class then, it could be e.g. - # a bare CalculationNode. - pass - class_entry_point_string = get_entry_point_string_from_class(the_class.__module__, the_class.__name__) + # We should never be here, anyway - this should be catched by the `set_include/exclude` methods + assert include is None or exclude is None, "You cannot specify both an 'include' and an 'exclude' list" + + entry_point_string = node.process_type + # If there is no `process_type` we are dealing with a `Data` node so we get the entry point from the class + if not entry_point_string: + entry_point_string = get_entry_point_string_from_class(node.__class__.__module__, node.__class__.__name__) if include is not None: # As soon as a filter string matches, we include the class - return any(self._matches(class_entry_point_string, filter_string) for filter_string in include) + return any(self._matches(entry_point_string, filter_string) for filter_string in include) # If we are here, exclude is not None # include *only* in *none* of the filters match (that is, exclude as # soon as any of the filters matches) - return not any(self._matches(class_entry_point_string, filter_string) for filter_string in exclude) + return not any(self._matches(entry_point_string, filter_string) for filter_string in exclude) def clear_group_cache(self): """Clear the cache of the group name. diff --git a/tests/cmdline/commands/test_run.py b/tests/cmdline/commands/test_run.py index ae82826e52..78c858420f 100644 --- a/tests/cmdline/commands/test_run.py +++ b/tests/cmdline/commands/test_run.py @@ -232,6 +232,8 @@ def test_autogroup_filter_class(self): # pylint: disable=too-many-locals calcarithmetic_in_autogroup, ) in enumerate([ [['--exclude', 'aiida.data:array.kpoints'], False, True, True, True, True, True], + # Check if % works anywhere - both 'int' and 'array.kpoints' contain an 'i' + [['--exclude', 'aiida.data:%i%'], False, True, False, True, True, True], [['--exclude', 'aiida.data:int'], True, True, False, True, True, True], [['--exclude', 'aiida.data:%'], False, False, False, True, True, True], [['--exclude', 'aiida.data:array', 'aiida.data:array.%'], False, False, True, True, True, True], From 9a681009d2d3095fc54c0f25aa0d129e7ab97d7f Mon Sep 17 00:00:00 2001 From: Giovanni Pizzi Date: Fri, 3 Apr 2020 23:26:07 +0200 Subject: [PATCH 5/5] Minor pre-commit fixes after merge. --- aiida/orm/utils/node.py | 2 +- aiida/plugins/entry_point.py | 18 ++++++++++-------- 2 files changed, 11 insertions(+), 9 deletions(-) diff --git a/aiida/orm/utils/node.py b/aiida/orm/utils/node.py index eafbee0b79..0432964467 100644 --- a/aiida/orm/utils/node.py +++ b/aiida/orm/utils/node.py @@ -258,7 +258,7 @@ class AbstractNodeMeta(ABCMeta): # pylint: disable=too-few-public-methods """Some python black magic to set correctly the logger also in subclasses.""" def __new__(mcs, name, bases, namespace): # pylint: disable=arguments-differ,protected-access,too-many-function-args - newcls = ABCMeta.__new__(mcs, name, bases, namespace) + newcls = ABCMeta.__new__(mcs, name, bases, namespace) # pylint: disable=too-many-function-args newcls._logger = logging.getLogger('{}.{}'.format(namespace['__module__'], name)) # Set the plugin type string and query type string based on the plugin type string diff --git a/aiida/plugins/entry_point.py b/aiida/plugins/entry_point.py index a1b32e3a55..2abe6be077 100644 --- a/aiida/plugins/entry_point.py +++ b/aiida/plugins/entry_point.py @@ -243,7 +243,6 @@ def get_entry_points(group): return list(ENTRYPOINT_MANAGER.iter_entry_points(group=group)) - @functools.lru_cache(maxsize=None) def get_entry_point(group, name): """ @@ -258,12 +257,16 @@ def get_entry_point(group, name): entry_points = [ep for ep in get_entry_points(group) if ep.name == name] if not entry_points: - raise MissingEntryPointError("Entry point '{}' not found in group '{}'. Try running `reentry scan` to update " - 'the entry point cache.'.format(name, group)) + raise MissingEntryPointError( + "Entry point '{}' not found in group '{}'. Try running `reentry scan` to update " + 'the entry point cache.'.format(name, group) + ) if len(entry_points) > 1: - raise MultipleEntryPointError("Multiple entry points '{}' found in group '{}'.Try running `reentry scan` to " - 'repopulate the entry point cache.'.format(name, group)) + raise MultipleEntryPointError( + "Multiple entry points '{}' found in group '{}'.Try running `reentry scan` to " + 'repopulate the entry point cache.'.format(name, group) + ) return entry_points[0] @@ -346,11 +349,10 @@ def is_registered_entry_point(class_module, class_name, groups=None): :return: boolean, True if the class is a registered entry point, False otherwise. """ if groups is None: - groups = list(entry_point_group_to_module_path_map.keys()) + groups = list(ENTRY_POINT_GROUP_TO_MODULE_PATH_MAP.keys()) for group in groups: for entry_point in ENTRYPOINT_MANAGER.iter_entry_points(group): if class_module == entry_point.module_name and [class_name] == entry_point.attrs: return True - else: - return False + return False