diff --git a/.ci/workchains.py b/.ci/workchains.py index 110334f0ae..f5ab3872d7 100644 --- a/.ci/workchains.py +++ b/.ci/workchains.py @@ -68,8 +68,8 @@ def a_magic_unicorn_appeared(self, node): @process_handler(priority=400, exit_codes=ArithmeticAddCalculation.exit_codes.ERROR_NEGATIVE_NUMBER) def error_negative_sum(self, node): """What even is a negative number, how can I have minus three melons?!.""" - self.ctx.inputs.x = Int(abs(node.inputs.x.value)) - self.ctx.inputs.y = Int(abs(node.inputs.y.value)) + self.ctx.inputs.x = Int(abs(node.inputs.x.value)) # pylint: disable=invalid-name + self.ctx.inputs.y = Int(abs(node.inputs.y.value)) # pylint: disable=invalid-name return ProcessHandlerReport(True) diff --git a/.gitignore b/.gitignore index 9d225c3ef0..1983db653d 100644 --- a/.gitignore +++ b/.gitignore @@ -20,6 +20,7 @@ .cache .pytest_cache .coverage +coverage.xml # Files created by RPN tests .ci/polish/polish_workchains/polish* diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 62ae29e398..8ceab33c0e 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -53,7 +53,6 @@ aiida/common/datastructures.py| aiida/engine/daemon/execmanager.py| aiida/engine/processes/calcjobs/tasks.py| - aiida/orm/autogroup.py| aiida/orm/querybuilder.py| aiida/orm/nodes/data/array/bands.py| aiida/orm/nodes/data/array/projection.py| @@ -66,7 +65,6 @@ aiida/parsers/plugins/arithmetic/add.py| aiida/parsers/plugins/templatereplacer/doubler.py| aiida/parsers/plugins/templatereplacer/__init__.py| - aiida/plugins/entry_point.py| aiida/plugins/entry.py| aiida/plugins/info.py| aiida/plugins/registry.py| diff --git a/aiida/backends/testbase.py b/aiida/backends/testbase.py index de855eec4b..ed18f27566 100644 --- a/aiida/backends/testbase.py +++ b/aiida/backends/testbase.py @@ -99,7 +99,11 @@ def tearDown(self): def reset_database(self): """Reset the database to the default state deleting any content currently stored""" + from aiida.orm import autogroup + self.clean_db() + if autogroup.CURRENT_AUTOGROUP is not None: + autogroup.CURRENT_AUTOGROUP.clear_group_cache() self.insert_data() @classmethod @@ -109,7 +113,10 @@ def insert_data(cls): inserts default data into the database (which is for the moment a default computer). """ + from aiida.orm import User + cls.create_user() + User.objects.reset() cls.create_computer() @classmethod @@ -180,7 +187,11 @@ def user_email(cls): # pylint: disable=no-self-argument def tearDownClass(cls, *args, **kwargs): # pylint: disable=arguments-differ # Double check for double security to avoid to run the tearDown # if this is not a test profile + from aiida.orm import autogroup + check_if_tests_can_run() + if autogroup.CURRENT_AUTOGROUP is not None: + autogroup.CURRENT_AUTOGROUP.clear_group_cache() cls.clean_db() cls.clean_repository() cls.__backend_instance.tearDownClass_method(*args, **kwargs) diff --git a/aiida/cmdline/commands/cmd_plugin.py b/aiida/cmdline/commands/cmd_plugin.py index f09c064950..3232441379 100644 --- a/aiida/cmdline/commands/cmd_plugin.py +++ b/aiida/cmdline/commands/cmd_plugin.py @@ -13,7 +13,7 @@ from aiida.cmdline.commands.cmd_verdi import verdi from aiida.cmdline.utils import decorators, echo -from aiida.plugins.entry_point import entry_point_group_to_module_path_map +from aiida.plugins.entry_point import ENTRY_POINT_GROUP_TO_MODULE_PATH_MAP @verdi.group('plugin') @@ -22,7 +22,7 @@ def verdi_plugin(): @verdi_plugin.command('list') -@click.argument('entry_point_group', type=click.Choice(entry_point_group_to_module_path_map.keys()), required=False) +@click.argument('entry_point_group', type=click.Choice(ENTRY_POINT_GROUP_TO_MODULE_PATH_MAP.keys()), required=False) @click.argument('entry_point', type=click.STRING, required=False) @decorators.with_dbenv() def plugin_list(entry_point_group, entry_point): @@ -34,7 +34,7 @@ def plugin_list(entry_point_group, entry_point): if entry_point_group is None: echo.echo_info('Available entry point groups:') - for group in sorted(entry_point_group_to_module_path_map.keys()): + for group in sorted(ENTRY_POINT_GROUP_TO_MODULE_PATH_MAP.keys()): echo.echo('* {}'.format(group)) echo.echo('') diff --git a/aiida/cmdline/commands/cmd_run.py b/aiida/cmdline/commands/cmd_run.py index 5a43cad6f5..bd3972b841 100644 --- a/aiida/cmdline/commands/cmd_run.py +++ b/aiida/cmdline/commands/cmd_run.py @@ -10,13 +10,16 @@ """`verdi run` command.""" import contextlib import os +import functools import sys +import warnings import click from aiida.cmdline.commands.cmd_verdi import verdi from aiida.cmdline.params.options.multivalue import MultipleValueOption from aiida.cmdline.utils import decorators, echo +from aiida.common.warnings import AiidaDeprecationWarning @contextlib.contextmanager @@ -37,31 +40,56 @@ def update_environment(argv): sys.path = _path +def validate_entrypoint_string(ctx, param, value): # pylint: disable=unused-argument,invalid-name + """Validate that `value` is a valid entrypoint string.""" + from aiida.orm import autogroup + + try: + autogroup.Autogroup.validate(value) + except Exception as exc: + raise click.BadParameter(str(exc) + ' ({})'.format(value)) + + return value + + @verdi.command('run', context_settings=dict(ignore_unknown_options=True,)) @click.argument('scriptname', type=click.STRING) @click.argument('varargs', nargs=-1, type=click.UNPROCESSED) -@click.option('-g', '--group', is_flag=True, default=True, show_default=True, help='Enables the autogrouping') -@click.option('-n', '--group-name', type=click.STRING, required=False, help='Specify the name of the auto group') -@click.option('-e', '--exclude', cls=MultipleValueOption, default=[], help='Exclude these classes from auto grouping') +@click.option('--auto-group', is_flag=True, help='Enables the autogrouping') +@click.option( + '-l', + '--auto-group-label-prefix', + type=click.STRING, + required=False, + help='Specify the prefix of the label of the auto group (numbers might be automatically ' + 'appended to generate unique names per run).' +) @click.option( - '-i', '--include', cls=MultipleValueOption, default=['all'], help='Include these classes from auto grouping' + '-n', + '--group-name', + type=click.STRING, + required=False, + help='Specify the name of the auto group [DEPRECATED, USE --auto-group-label-prefix instead]. ' + 'This also enables auto-grouping.' ) @click.option( - '-E', - '--excludesubclasses', + '-e', + '--exclude', cls=MultipleValueOption, - default=[], - help='Exclude these classes and their sub classes from auto grouping' + default=None, + help='Exclude these classes from auto grouping (use full entrypoint strings).', + callback=functools.partial(validate_entrypoint_string) ) @click.option( - '-I', - '--includesubclasses', + '-i', + '--include', cls=MultipleValueOption, - default=[], - help='Include these classes and their sub classes from auto grouping' + default=None, + help='Include these classes from auto grouping (use full entrypoint strings or "all").', + callback=validate_entrypoint_string ) @decorators.with_dbenv() -def run(scriptname, varargs, group, group_name, exclude, excludesubclasses, include, includesubclasses): +def run(scriptname, varargs, auto_group, auto_group_label_prefix, group_name, exclude, include): # pylint: disable=too-many-arguments,exec-used """Execute scripts with preloaded AiiDA environment.""" from aiida.cmdline.utils.shell import DEFAULT_MODULES_LIST @@ -80,22 +108,27 @@ def run(scriptname, varargs, group, group_name, exclude, excludesubclasses, incl for app_mod, model_name, alias in DEFAULT_MODULES_LIST: globals_dict['{}'.format(alias)] = getattr(__import__(app_mod, {}, {}, model_name), model_name) - if group: - automatic_group_name = group_name - if automatic_group_name is None: - from aiida.common import timezone - - automatic_group_name = 'Verdi autogroup on ' + timezone.now().strftime('%Y-%m-%d %H:%M:%S') + if group_name: + warnings.warn('--group-name is deprecated, use `--auto-group-label-prefix` instead', AiidaDeprecationWarning) # pylint: disable=no-member + if auto_group_label_prefix: + raise click.BadParameter( + 'You cannot specify both --group-name and --auto-group-label-prefix; ' + 'use --auto-group-label-prefix only' + ) + auto_group_label_prefix = group_name + # To have the old behavior, with auto-group enabled. + auto_group = True + if auto_group: aiida_verdilib_autogroup = autogroup.Autogroup() + # Set the ``group_label_prefix`` if defined, otherwise a default prefix will be used + if auto_group_label_prefix is not None: + aiida_verdilib_autogroup.set_group_label_prefix(auto_group_label_prefix) aiida_verdilib_autogroup.set_exclude(exclude) aiida_verdilib_autogroup.set_include(include) - aiida_verdilib_autogroup.set_exclude_with_subclasses(excludesubclasses) - aiida_verdilib_autogroup.set_include_with_subclasses(includesubclasses) - aiida_verdilib_autogroup.set_group_name(automatic_group_name) # Note: this is also set in the exec environment! This is the intended behavior - autogroup.current_autogroup = aiida_verdilib_autogroup + autogroup.CURRENT_AUTOGROUP = aiida_verdilib_autogroup # Initialize the variable here, otherwise we get UnboundLocalError in the finally clause if it fails to open handle = None diff --git a/aiida/engine/processes/calcjobs/calcjob.py b/aiida/engine/processes/calcjobs/calcjob.py index 0e6b234ef0..9f3d2d765f 100644 --- a/aiida/engine/processes/calcjobs/calcjob.py +++ b/aiida/engine/processes/calcjobs/calcjob.py @@ -64,7 +64,7 @@ def validate_calc_job(inputs, ctx): ) -def validate_parser(parser_name, ctx): +def validate_parser(parser_name, ctx): # pylint: disable=unused-argument """Validate the parser. :raises InputValidationError: if the parser name does not correspond to a loadable `Parser` class. @@ -78,7 +78,7 @@ def validate_parser(parser_name, ctx): raise exceptions.InputValidationError('invalid parser specified: {}'.format(exception)) -def validate_resources(resources, ctx): +def validate_resources(resources, ctx): # pylint: disable=unused-argument """Validate the resources. :raises InputValidationError: if `num_machines` is not specified or is not an integer. diff --git a/aiida/manage/caching.py b/aiida/manage/caching.py index d8079fd747..9b7f1d427d 100644 --- a/aiida/manage/caching.py +++ b/aiida/manage/caching.py @@ -22,7 +22,7 @@ from aiida.common import exceptions from aiida.common.lang import type_check -from aiida.plugins.entry_point import ENTRY_POINT_STRING_SEPARATOR, entry_point_group_to_module_path_map +from aiida.plugins.entry_point import ENTRY_POINT_STRING_SEPARATOR, ENTRY_POINT_GROUP_TO_MODULE_PATH_MAP __all__ = ('get_use_cache', 'enable_caching', 'disable_caching') @@ -248,7 +248,7 @@ def _validate_identifier_pattern(*, identifier): 1. - where `group_name` is one of the keys in `entry_point_group_to_module_path_map` + where `group_name` is one of the keys in `ENTRY_POINT_GROUP_TO_MODULE_PATH_MAP` and `tail` can be anything _except_ `ENTRY_POINT_STRING_SEPARATOR`. 2. a fully qualified Python name @@ -276,7 +276,7 @@ def _validate_identifier_pattern(*, identifier): group_pattern, _ = identifier.split(ENTRY_POINT_STRING_SEPARATOR) if not any( _match_wildcard(string=group_name, pattern=group_pattern) - for group_name in entry_point_group_to_module_path_map + for group_name in ENTRY_POINT_GROUP_TO_MODULE_PATH_MAP ): raise ValueError( common_error_msg + "Group name pattern '{}' does not match any of the AiiDA entry point group names.". @@ -290,7 +290,7 @@ def _validate_identifier_pattern(*, identifier): # aiida.* or aiida.calculations* if '*' in identifier: group_part, _ = identifier.split('*', 1) - if any(group_name.startswith(group_part) for group_name in entry_point_group_to_module_path_map): + if any(group_name.startswith(group_part) for group_name in ENTRY_POINT_GROUP_TO_MODULE_PATH_MAP): return # Finally, check if it could be a fully qualified Python name for identifier_part in identifier.split('.'): diff --git a/aiida/orm/autogroup.py b/aiida/orm/autogroup.py index ed4551a3ad..16bf03f1c1 100644 --- a/aiida/orm/autogroup.py +++ b/aiida/orm/autogroup.py @@ -7,173 +7,278 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### - +"""Module to manage the autogrouping functionality by ``verdi run``.""" +import re +import warnings from aiida.common import exceptions, timezone -from aiida.orm import GroupTypeString - +from aiida.common.escaping import escape_for_sql_like, get_regex_pattern_from_sql +from aiida.common.warnings import AiidaDeprecationWarning +from aiida.orm import GroupTypeString, Group +from aiida.plugins.entry_point import get_entry_point_string_from_class -current_autogroup = None +CURRENT_AUTOGROUP = None VERDIAUTOGROUP_TYPE = GroupTypeString.VERDIAUTOGROUP_TYPE.value -# TODO: make the Autogroup usable to the user, and not only to the verdi run class Autogroup: """ An object used for the autogrouping of objects. The autogrouping is checked by the Node.store() method. - In the store(), the Node will check if current_autogroup is != None. + In the store(), the Node will check if CURRENT_AUTOGROUP is != None. If so, it will call Autogroup.is_to_be_grouped, and decide whether to put it in a group. Such autogroups are going to be of the VERDIAUTOGROUP_TYPE. - The exclude/include lists, can have values 'all' if you want to include/exclude all classes. - Otherwise, they are lists of strings like: calculation.quantumespresso.pw, data.array.kpoints, ... - i.e.: a string identifying the base class, than the path to the class as in Calculation/Data -Factories + The exclude/include lists are lists of strings like: + ``aiida.data:int``, ``aiida.calculation:quantumespresso.pw``, + ``aiida.data:array.%``, ... + i.e.: a string identifying the base class, followed a colona and by the path to the class + as accepted by CalculationFactory/DataFactory. + Each string can contain one or more wildcard characters ``%``; + in this case this is used in a ``like`` comparison with the QueryBuilder. + Note that in this case you have to remember that ``_`` means "any character" + in the QueryBuilder, and you need to escape it if you mean a literal underscore. + + Only one of the two (between exclude and include) can be set. + If none of the two is set, everything is included. """ - def _validate(self, param, is_exact=True): - """ - Used internally to verify the sanity of exclude, include lists - """ - from aiida.plugins import CalculationFactory, DataFactory - - for i in param: - if not any([i.startswith('calculation'), - i.startswith('code'), - i.startswith('data'), - i == 'all', - ]): - raise exceptions.ValidationError('Module not recognized, allow prefixes ' - ' are: calculation, code or data') - the_param = [i + '.' for i in param] - - factorydict = {'calculation': locals()['CalculationFactory'], - 'data': locals()['DataFactory']} - - for i in the_param: - base, module = i.split('.', 1) - if base == 'code': - if module: - raise exceptions.ValidationError('Cannot have subclasses for codes') - elif base == 'all': - continue - else: - if is_exact: - try: - factorydict[base](module.rstrip('.')) - except exceptions.EntryPointError: - raise exceptions.ValidationError('Cannot find the class to be excluded') - return the_param + def __init__(self): + """Initialize with defaults.""" + self._exclude = None + self._include = None + + now = timezone.now() + default_label_prefix = 'Verdi autogroup on ' + now.strftime('%Y-%m-%d %H:%M:%S') + self._group_label_prefix = default_label_prefix + self._group_label = None # Actual group label, set by `get_or_create_group` + + @staticmethod + def validate(strings): + """Validate the list of strings passed to set_include and set_exclude.""" + if strings is None: + return + valid_prefixes = set(['aiida.node', 'aiida.calculations', 'aiida.workflows', 'aiida.data']) + for string in strings: + pieces = string.split(':') + if len(pieces) != 2: + raise exceptions.ValidationError( + "'{}' is not a valid include/exclude filter, must contain two parts split by a colon". + format(string) + ) + if pieces[0] not in valid_prefixes: + raise exceptions.ValidationError( + "'{}' has an invalid prefix, must be among: {}".format(string, sorted(valid_prefixes)) + ) def get_exclude(self): - """Return the list of classes to exclude from autogrouping.""" - try: - return self.exclude - except AttributeError: - return [] + """Return the list of classes to exclude from autogrouping. - def get_exclude_with_subclasses(self): - """ - Return the list of classes to exclude from autogrouping. - Will also exclude their derived subclasses - """ - try: - return self.exclude_with_subclasses - except AttributeError: - return [] + Returns ``None`` if no exclusion list has been set.""" + return self._exclude def get_include(self): - """Return the list of classes to include in the autogrouping.""" - try: - return self.include - except AttributeError: - return [] - - def get_include_with_subclasses(self): """Return the list of classes to include in the autogrouping. - Will also include their derived subclasses.""" - try: - return self.include_with_subclasses - except AttributeError: - return [] + + Returns ``None`` if no inclusion list has been set.""" + return self._include + + def get_group_label_prefix(self): + """Get the prefix of the label of the group. + If no group label prefix was set, it will set a default one by itself.""" + return self._group_label_prefix def get_group_name(self): - """Get the name of the group. - If no group name was set, it will set a default one by itself.""" - try: - return self.group_name - except AttributeError: - now = timezone.now() - gname = 'Verdi autogroup on ' + now.strftime('%Y-%m-%d %H:%M:%S') - self.set_group_name(gname) - return self.group_name + """Get the label of the group. + If no group label was set, it will set a default one by itself. - def set_exclude(self, exclude): - """Return the list of classes to exclude from autogrouping.""" - the_exclude_classes = self._validate(exclude) - if self.get_include() is not None: - if 'all.' in self.get_include(): - if 'all.' in the_exclude_classes: - raise exceptions.ValidationError('Cannot exclude and include all classes') - self.exclude = the_exclude_classes - - def set_exclude_with_subclasses(self, exclude): + .. deprecated:: 1.2.0 + Will be removed in `v2.0.0`, use :py:meth:`.get_group_label_prefix` instead. """ - Set the list of classes to exclude from autogrouping. - Will also exclude their derived subclasses + warnings.warn('function is deprecated, use `get_group_label_prefix` instead', AiidaDeprecationWarning) # pylint: disable=no-member + return self.get_group_label_prefix() + + def set_exclude(self, exclude): + """Set the list of classes to exclude in the autogrouping. + + :param exclude: a list of valid entry point strings (might contain '%' to be used as + string to be matched using SQL's ``LIKE`` pattern-making logic), or ``None`` + to specify no include list. """ - the_exclude_classes = self._validate(exclude, is_exact=False) - self.exclude_with_subclasses = the_exclude_classes + if isinstance(exclude, str): + exclude = [exclude] + self.validate(exclude) + if exclude is not None and self.get_include() is not None: + # It's ok to set None, both as a default, or to 'undo' the exclude list + raise exceptions.ValidationError('Cannot both specify exclude and include') + self._exclude = exclude def set_include(self, include): - """ - Set the list of classes to include in the autogrouping. - """ - the_include_classes = self._validate(include) - if self.get_exclude() is not None: - if 'all.' in self.get_exclude(): - if 'all.' in the_include_classes: - raise exceptions.ValidationError('Cannot exclude and include all classes') + """Set the list of classes to include in the autogrouping. - self.include = the_include_classes + :param include: a list of valid entry point strings (might contain '%' to be used as + string to be matched using SQL's ``LIKE`` pattern-making logic), or ``None`` + to specify no include list. + """ + if isinstance(include, str): + include = [include] + self.validate(include) + if include is not None and self.get_exclude() is not None: + # It's ok to set None, both as a default, or to 'undo' the include list + raise exceptions.ValidationError('Cannot both specify exclude and include') + self._include = include - def set_include_with_subclasses(self, include): + def set_group_label_prefix(self, label_prefix): """ - Set the list of classes to include in the autogrouping. - Will also include their derived subclasses. + Set the label of the group to be created """ - the_include_classes = self._validate(include, is_exact=False) - self.include_with_subclasses = the_include_classes + if not isinstance(label_prefix, str): + raise exceptions.ValidationError('group label must be a string') + self._group_label_prefix = label_prefix def set_group_name(self, gname): + """Set the name of the group. + + .. deprecated:: 1.2.0 + Will be removed in `v2.0.0`, use :py:meth:`.set_group_label_prefix` instead. """ - Set the name of the group to be created + warnings.warn('function is deprecated, use `set_group_label_prefix` instead', AiidaDeprecationWarning) # pylint: disable=no-member + return self.set_group_label_prefix(label_prefix=gname) + + @staticmethod + def _matches(string, filter_string): + """Check if 'string' matches the 'filter_string' (used for include and exclude filters). + + If 'filter_string' does not contain any % sign, perform an exact match. + Otherwise, match with a SQL-like query, where % means any character sequence, + and _ means a single character (these caracters can be escaped with a backslash). + + :param string: the string to match. + :param filter_string: the filter string. """ - if not isinstance(gname, str): - raise exceptions.ValidationError('group name must be a string') - self.group_name = gname + if '%' in filter_string: + regex_filter = get_regex_pattern_from_sql(filter_string) + return re.match(regex_filter, string) is not None + return string == filter_string - def is_to_be_grouped(self, the_class): + def is_to_be_grouped(self, node): """ - Return whether the given class has to be included in the autogroup according to include/exclude list + Return whether the given node has to be included in the autogroup according to include/exclude list - :return (bool): True if the_class is to be included in the autogroup + :return (bool): True if ``node`` is to be included in the autogroup """ + # strings, including possibly 'all' include = self.get_include() - include_ws = self.get_include_with_subclasses() - if (('all.' in include) or - (the_class._plugin_type_string in include) or - any([the_class._plugin_type_string.startswith(i) for i in include_ws]) - ): - exclude = self.get_exclude() - exclude_ws = self.get_exclude_with_subclasses() - if ((not 'all.' in exclude) or - (the_class._plugin_type_string in exclude) or - any([the_class._plugin_type_string.startswith(i) for i in exclude_ws]) - ): - return True - else: - return False + exclude = self.get_exclude() + if include is None and exclude is None: + # Include all classes by default if nothing is explicitly specified. + return True + + # We should never be here, anyway - this should be catched by the `set_include/exclude` methods + assert include is None or exclude is None, "You cannot specify both an 'include' and an 'exclude' list" + + entry_point_string = node.process_type + # If there is no `process_type` we are dealing with a `Data` node so we get the entry point from the class + if not entry_point_string: + entry_point_string = get_entry_point_string_from_class(node.__class__.__module__, node.__class__.__name__) + if include is not None: + # As soon as a filter string matches, we include the class + return any(self._matches(entry_point_string, filter_string) for filter_string in include) + # If we are here, exclude is not None + # include *only* in *none* of the filters match (that is, exclude as + # soon as any of the filters matches) + return not any(self._matches(entry_point_string, filter_string) for filter_string in exclude) + + def clear_group_cache(self): + """Clear the cache of the group name. + + This is mostly used by tests when they reset the database. + """ + self._group_label = None + + def get_or_create_group(self): + """Return the current Autogroup, or create one if None has been set yet. + + This function implements a somewhat complex logic that is however needed + to make sure that, even if `verdi run` is called at the same time multiple + times, e.g. in a for loop in bash, there is never the risk that two ``verdi run`` + Unix processes try to create the same group, with the same label, ending + up in a crash of the code (see PR #3650). + + Here, instead, we make sure that if this concurrency issue happens, + one of the two will get a IntegrityError from the DB, and then recover + trying to create a group with a different label (with a numeric suffix appended), + until it manages to create it. + """ + from aiida.orm import QueryBuilder + + # When this function is called, if it is the first time, just generate + # a new group name (later on, after this ``if`` block`). + # In that case, we will later cache in ``self._group_label`` the group label, + # So the group with the same name can be returned quickly in future + # calls of this method. + if self._group_label is not None: + results = [ + res[0] for res in QueryBuilder(). + append(Group, filters={ + 'label': self._group_label, + 'type_string': VERDIAUTOGROUP_TYPE + }, project='*').iterall() + ] + if results: + # If it is not empty, it should have only one result due to the + # uniqueness constraints + assert len(results) == 1, 'I got more than one autogroup with the same label!' + return results[0] + # There are no results: probably the group has been deleted. + # I continue as if it was not cached + self._group_label = None + + label_prefix = self.get_group_label_prefix() + # Try to do a preliminary QB query to avoid to do too many try/except + # if many of the prefix_NUMBER groups already exist + queryb = QueryBuilder().append( + Group, + filters={ + 'or': [{ + 'label': { + '==': label_prefix + } + }, { + 'label': { + 'like': escape_for_sql_like(label_prefix + '_') + '%' + } + }] + }, + project='label' + ) + existing_group_labels = [res[0][len(label_prefix):] for res in queryb.all()] + existing_group_ints = [] + for label in existing_group_labels: + if label == '': + # This is just the prefix without name - corresponds to counter = 0 + existing_group_ints.append(0) + elif label.startswith('_'): + try: + existing_group_ints.append(int(label[1:])) + except ValueError: + # It's not an integer, so it will never collide - just ignore it + pass + + if not existing_group_ints: + counter = 0 else: - return False + counter = max(existing_group_ints) + 1 + + while True: + try: + label = label_prefix if counter == 0 else '{}_{}'.format(label_prefix, counter) + group = Group(label=label, type_string=VERDIAUTOGROUP_TYPE).store() + self._group_label = group.label + except exceptions.IntegrityError: + counter += 1 + else: + break + + return group diff --git a/aiida/orm/nodes/node.py b/aiida/orm/nodes/node.py index c8608bc36c..f776809952 100644 --- a/aiida/orm/nodes/node.py +++ b/aiida/orm/nodes/node.py @@ -23,6 +23,7 @@ from aiida.orm.utils.links import LinkManager, LinkTriple from aiida.orm.utils.repository import Repository from aiida.orm.utils.node import AbstractNodeMeta, validate_attribute_extra_key +from aiida.orm import autogroup from ..comments import Comment from ..computers import Computer @@ -1037,18 +1038,9 @@ def store(self, with_transaction=True, use_cache=None): # pylint: disable=argum self._store(with_transaction=with_transaction, clean=True) # Set up autogrouping used by verdi run - from aiida.orm.autogroup import current_autogroup, Autogroup, VERDIAUTOGROUP_TYPE - from aiida.orm import Group - - if current_autogroup is not None: - if not isinstance(current_autogroup, Autogroup): - raise exceptions.ValidationError('`current_autogroup` is not of type `Autogroup`') - - if current_autogroup.is_to_be_grouped(self): - group_label = current_autogroup.get_group_name() - if group_label is not None: - group = Group.objects.get_or_create(label=group_label, type_string=VERDIAUTOGROUP_TYPE)[0] - group.add_nodes(self) + if autogroup.CURRENT_AUTOGROUP is not None and autogroup.CURRENT_AUTOGROUP.is_to_be_grouped(self): + group = autogroup.CURRENT_AUTOGROUP.get_or_create_group() + group.add_nodes(self) return self diff --git a/aiida/orm/utils/node.py b/aiida/orm/utils/node.py index 6827225272..0432964467 100644 --- a/aiida/orm/utils/node.py +++ b/aiida/orm/utils/node.py @@ -90,13 +90,13 @@ def get_type_string_from_class(class_module, class_name): :param class_module: module of the class :param class_name: name of the class """ - from aiida.plugins.entry_point import get_entry_point_from_class, entry_point_group_to_module_path_map + from aiida.plugins.entry_point import get_entry_point_from_class, ENTRY_POINT_GROUP_TO_MODULE_PATH_MAP group, entry_point = get_entry_point_from_class(class_module, class_name) # If we can reverse engineer an entry point group and name, we're dealing with an external class if group and entry_point: - module_base_path = entry_point_group_to_module_path_map[group] + module_base_path = ENTRY_POINT_GROUP_TO_MODULE_PATH_MAP[group] type_string = '{}.{}.{}.'.format(module_base_path, entry_point.name, class_name) # Otherwise we are dealing with an internal class @@ -258,7 +258,7 @@ class AbstractNodeMeta(ABCMeta): # pylint: disable=too-few-public-methods """Some python black magic to set correctly the logger also in subclasses.""" def __new__(mcs, name, bases, namespace): # pylint: disable=arguments-differ,protected-access,too-many-function-args - newcls = ABCMeta.__new__(mcs, name, bases, namespace) + newcls = ABCMeta.__new__(mcs, name, bases, namespace) # pylint: disable=too-many-function-args newcls._logger = logging.getLogger('{}.{}'.format(namespace['__module__'], name)) # Set the plugin type string and query type string based on the plugin type string diff --git a/aiida/plugins/entry_point.py b/aiida/plugins/entry_point.py index 432e740852..2abe6be077 100644 --- a/aiida/plugins/entry_point.py +++ b/aiida/plugins/entry_point.py @@ -7,7 +7,7 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### - +"""Module to manage loading entrypoints.""" import enum import traceback import functools @@ -24,7 +24,6 @@ __all__ = ('load_entry_point', 'load_entry_point_from_string') - ENTRY_POINT_GROUP_PREFIX = 'aiida.' ENTRY_POINT_STRING_SEPARATOR = ':' @@ -51,7 +50,7 @@ class EntryPointFormat(enum.Enum): MINIMAL = 3 -entry_point_group_to_module_path_map = { +ENTRY_POINT_GROUP_TO_MODULE_PATH_MAP = { 'aiida.calculations': 'aiida.orm.nodes.process.calculation.calcjob', 'aiida.cmdline.data': 'aiida.cmdline.data', 'aiida.data': 'aiida.orm.nodes.data', @@ -65,7 +64,7 @@ class EntryPointFormat(enum.Enum): } -def validate_registered_entry_points(): +def validate_registered_entry_points(): # pylint: disable=invalid-name """Validate all registered entry points by loading them with the corresponding factory. :raises EntryPointError: if any of the registered entry points cannot be loaded. This can happen if: @@ -108,12 +107,11 @@ def format_entry_point_string(group, name, fmt=EntryPointFormat.FULL): if fmt == EntryPointFormat.FULL: return '{}{}{}'.format(group, ENTRY_POINT_STRING_SEPARATOR, name) - elif fmt == EntryPointFormat.PARTIAL: + if fmt == EntryPointFormat.PARTIAL: return '{}{}{}'.format(group[len(ENTRY_POINT_GROUP_PREFIX):], ENTRY_POINT_STRING_SEPARATOR, name) - elif fmt == EntryPointFormat.MINIMAL: + if fmt == EntryPointFormat.MINIMAL: return '{}'.format(name) - else: - raise ValueError('invalid EntryPointFormat') + raise ValueError('invalid EntryPointFormat') def parse_entry_point_string(entry_point_string): @@ -146,14 +144,13 @@ def get_entry_point_string_format(entry_point_string): :rtype: EntryPointFormat """ try: - group, name = entry_point_string.split(ENTRY_POINT_STRING_SEPARATOR) + group, _ = entry_point_string.split(ENTRY_POINT_STRING_SEPARATOR) except ValueError: return EntryPointFormat.MINIMAL else: if group.startswith(ENTRY_POINT_GROUP_PREFIX): return EntryPointFormat.FULL - else: - return EntryPointFormat.PARTIAL + return EntryPointFormat.PARTIAL def get_entry_point_from_string(entry_point_string): @@ -216,7 +213,7 @@ def get_entry_point_groups(): :return: a list of valid entry point groups """ - return entry_point_group_to_module_path_map.keys() + return ENTRY_POINT_GROUP_TO_MODULE_PATH_MAP.keys() def get_entry_point_names(group, sort=True): @@ -243,7 +240,7 @@ def get_entry_points(group): :param group: the entry point group :return: a list of entry points """ - return [ep for ep in ENTRYPOINT_MANAGER.iter_entry_points(group=group)] + return list(ENTRYPOINT_MANAGER.iter_entry_points(group=group)) @functools.lru_cache(maxsize=None) @@ -260,12 +257,16 @@ def get_entry_point(group, name): entry_points = [ep for ep in get_entry_points(group) if ep.name == name] if not entry_points: - raise MissingEntryPointError("Entry point '{}' not found in group '{}'. Try running `reentry scan` to update " - 'the entry point cache.'.format(name, group)) + raise MissingEntryPointError( + "Entry point '{}' not found in group '{}'. Try running `reentry scan` to update " + 'the entry point cache.'.format(name, group) + ) if len(entry_points) > 1: - raise MultipleEntryPointError("Multiple entry points '{}' found in group '{}'.Try running `reentry scan` to " - 'repopulate the entry point cache.'.format(name, group)) + raise MultipleEntryPointError( + "Multiple entry points '{}' found in group '{}'.Try running `reentry scan` to " + 'repopulate the entry point cache.'.format(name, group) + ) return entry_points[0] @@ -291,7 +292,7 @@ def get_entry_point_from_class(class_module, class_name): return None, None -def get_entry_point_string_from_class(class_module, class_name): +def get_entry_point_string_from_class(class_module, class_name): # pylint: disable=invalid-name """ Given the module and name of a class, attempt to obtain the corresponding entry point if it exists and return the entry point string which will be the entry point group and entry point @@ -313,8 +314,7 @@ def get_entry_point_string_from_class(class_module, class_name): if group and entry_point: return ENTRY_POINT_STRING_SEPARATOR.join([group, entry_point.name]) - else: - return None + return None def is_valid_entry_point_string(entry_point_string): @@ -328,12 +328,12 @@ def is_valid_entry_point_string(entry_point_string): :return: True if the string is considered valid, False otherwise """ try: - group, name = entry_point_string.split(ENTRY_POINT_STRING_SEPARATOR) + group, _ = entry_point_string.split(ENTRY_POINT_STRING_SEPARATOR) except (AttributeError, ValueError): # Either `entry_point_string` is not a string or it does not contain the separator return False - return group in entry_point_group_to_module_path_map + return group in ENTRY_POINT_GROUP_TO_MODULE_PATH_MAP @functools.lru_cache(maxsize=None) @@ -349,11 +349,10 @@ def is_registered_entry_point(class_module, class_name, groups=None): :return: boolean, True if the class is a registered entry point, False otherwise. """ if groups is None: - groups = list(entry_point_group_to_module_path_map.keys()) + groups = list(ENTRY_POINT_GROUP_TO_MODULE_PATH_MAP.keys()) for group in groups: for entry_point in ENTRYPOINT_MANAGER.iter_entry_points(group): if class_module == entry_point.module_name and [class_name] == entry_point.attrs: return True - else: - return False + return False diff --git a/aiida/tools/ipython/ipython_magics.py b/aiida/tools/ipython/ipython_magics.py index af3d8cb395..66310c37b9 100644 --- a/aiida/tools/ipython/ipython_magics.py +++ b/aiida/tools/ipython/ipython_magics.py @@ -34,8 +34,8 @@ In [2]: %aiida """ -from IPython import version_info -from IPython.core import magic +from IPython import version_info # pylint: disable=no-name-in-module +from IPython.core import magic # pylint: disable=no-name-in-module,import-error from aiida.common import json diff --git a/docs/source/verdi/verdi_user_guide.rst b/docs/source/verdi/verdi_user_guide.rst index 1e384344e1..3c24ffeb1b 100644 --- a/docs/source/verdi/verdi_user_guide.rst +++ b/docs/source/verdi/verdi_user_guide.rst @@ -702,15 +702,19 @@ Below is a list with all available subcommands. Execute scripts with preloaded AiiDA environment. Options: - -g, --group Enables the autogrouping [default: True] - -n, --group-name TEXT Specify the name of the auto group - -e, --exclude TEXT Exclude these classes from auto grouping - -i, --include TEXT Include these classes from auto grouping - -E, --excludesubclasses TEXT Exclude these classes and their sub classes - from auto grouping - -I, --includesubclasses TEXT Include these classes and their sub classes - from auto grouping - --help Show this message and exit. + --auto-group Enables the autogrouping + -l, --auto-group-label-prefix TEXT + Specify the prefix of the label of the auto + group (numbers might be automatically + appended to generate unique names per run). + -n, --group-name TEXT Specify the name of the auto group + [DEPRECATED, USE --auto-group-label-prefix + instead]. This also enables auto-grouping. + -e, --exclude TEXT Exclude these classes from auto grouping + (use full entrypoint strings). + -i, --include TEXT Include these classes from auto grouping + (use full entrypoint strings or "all"). + --help Show this message and exit. .. _verdi_setup: diff --git a/tests/backends/aiida_django/migrations/test_migrations_common.py b/tests/backends/aiida_django/migrations/test_migrations_common.py index f8de61f9a6..43f4f03b3d 100644 --- a/tests/backends/aiida_django/migrations/test_migrations_common.py +++ b/tests/backends/aiida_django/migrations/test_migrations_common.py @@ -38,8 +38,8 @@ def setUp(self): from aiida.backends.djsite import get_scoped_session from aiida.orm import autogroup - self.current_autogroup = autogroup.current_autogroup - autogroup.current_autogroup = None + self.current_autogroup = autogroup.CURRENT_AUTOGROUP + autogroup.CURRENT_AUTOGROUP = None assert self.migrate_from and self.migrate_to, \ "TestCase '{}' must define migrate_from and migrate_to properties".format(type(self).__name__) self.migrate_from = [(self.app, self.migrate_from)] @@ -85,7 +85,7 @@ def tearDown(self): """At the end make sure we go back to the latest schema version.""" from aiida.orm import autogroup self._revert_database_schema() - autogroup.current_autogroup = self.current_autogroup + autogroup.CURRENT_AUTOGROUP = self.current_autogroup def setUpBeforeMigration(self): """Anything to do before running the migrations, which should be implemented in test subclasses.""" diff --git a/tests/backends/aiida_sqlalchemy/test_migrations.py b/tests/backends/aiida_sqlalchemy/test_migrations.py index 8bdda5d145..8e2046f293 100644 --- a/tests/backends/aiida_sqlalchemy/test_migrations.py +++ b/tests/backends/aiida_sqlalchemy/test_migrations.py @@ -22,7 +22,6 @@ from aiida.backends.sqlalchemy.models.base import Base from aiida.backends.sqlalchemy.utils import flag_modified from aiida.backends.testbase import AiidaTestCase -from aiida.common.utils import Capturing from .test_utils import new_database @@ -57,22 +56,28 @@ def setUp(self): super().setUp() from aiida.orm import autogroup - self.current_autogroup = autogroup.current_autogroup - autogroup.current_autogroup = None + self.current_autogroup = autogroup.CURRENT_AUTOGROUP + autogroup.CURRENT_AUTOGROUP = None assert self.migrate_from and self.migrate_to, \ "TestCase '{}' must define migrate_from and migrate_to properties".format(type(self).__name__) try: - with Capturing(): - self.migrate_db_down(self.migrate_from) + self.migrate_db_down(self.migrate_from) self.setUpBeforeMigration() - with Capturing(): - self.migrate_db_up(self.migrate_to) + self._perform_actual_migration() except Exception: # Bring back the DB to the correct state if this setup part fails self._reset_database_and_schema() + autogroup.CURRENT_AUTOGROUP = self.current_autogroup raise + def _perform_actual_migration(self): + """Perform the actual migration (upwards, to migrate_to). + + Must be called after we are properly set to be in migrate_from. + """ + self.migrate_db_up(self.migrate_to) + def migrate_db_up(self, destination): """ Perform a migration upwards (upgrade) with alembic @@ -99,7 +104,7 @@ def tearDown(self): """ from aiida.orm import autogroup self._reset_database_and_schema() - autogroup.current_autogroup = self.current_autogroup + autogroup.CURRENT_AUTOGROUP = self.current_autogroup super().tearDown() def setUpBeforeMigration(self): # pylint: disable=invalid-name @@ -116,8 +121,7 @@ def _reset_database_and_schema(self): of tests. """ self.reset_database() - with Capturing(): - self.migrate_db_up('head') + self.migrate_db_up('head') @property def current_rev(self): @@ -210,29 +214,12 @@ class TestBackwardMigrationsSQLA(TestMigrationsSQLA): than the migrate_to revision. """ - def setUp(self): - """ - Go to the migrate_from revision, apply setUpBeforeMigration, then - run the migration. - """ - AiidaTestCase.setUp(self) # pylint: disable=bad-super-call - from aiida.orm import autogroup + def _perform_actual_migration(self): + """Perform the actual migration (downwards, to migrate_to). - self.current_autogroup = autogroup.current_autogroup - autogroup.current_autogroup = None - assert self.migrate_from and self.migrate_to, \ - "TestCase '{}' must define migrate_from and migrate_to properties".format(type(self).__name__) - - try: - with Capturing(): - self.migrate_db_down(self.migrate_from) - self.setUpBeforeMigration() - with Capturing(): - self.migrate_db_down(self.migrate_to) - except Exception: - # Bring back the DB to the correct state if this setup part fails - self._reset_database_and_schema() - raise + Must be called after we are properly set to be in migrate_from. + """ + self.migrate_db_down(self.migrate_to) class TestMigrationEngine(TestMigrationsSQLA): @@ -1003,7 +990,7 @@ class TestDbLogUUIDAddition(TestMigrationsSQLA): """ Test that the UUID column is correctly added to the DbLog table and that the uniqueness constraint is added without problems (if the migration arrives until 375c2db70663 then the - constraint is added properly. + constraint is added properly). """ migrate_from = '041a79fc615f' # 041a79fc615f_dblog_cleaning diff --git a/tests/cmdline/commands/test_calcjob.py b/tests/cmdline/commands/test_calcjob.py index dc7895c5d4..2f1945d45a 100644 --- a/tests/cmdline/commands/test_calcjob.py +++ b/tests/cmdline/commands/test_calcjob.py @@ -98,6 +98,7 @@ def setUpClass(cls, *args, **kwargs): cls.arithmetic_job = calculations[0] def setUp(self): + super().setUp() self.cli_runner = CliRunner() def test_calcjob_res(self): diff --git a/tests/cmdline/commands/test_run.py b/tests/cmdline/commands/test_run.py index 3cae78fb70..78c858420f 100644 --- a/tests/cmdline/commands/test_run.py +++ b/tests/cmdline/commands/test_run.py @@ -8,6 +8,9 @@ # For further information please visit http://www.aiida.net # ########################################################################### """Tests for `verdi run`.""" +import tempfile +import warnings + from click.testing import CliRunner from aiida.backends.testbase import AiidaTestCase @@ -28,7 +31,6 @@ def test_run_workfunction(self): that are defined within the script will fail, as the inspect module will not correctly be able to determin the full path of the source file. """ - import tempfile from aiida.orm import load_node from aiida.orm import WorkFunctionNode @@ -64,3 +66,355 @@ def wf(): self.assertTrue(isinstance(node, WorkFunctionNode)) self.assertEqual(node.function_name, 'wf') self.assertEqual(node.get_function_source_code(), script_content) + + +class TestAutoGroups(AiidaTestCase): + """Test the autogroup functionality.""" + + def setUp(self): + """Setup the CLI runner to run command line commands.""" + from aiida.orm import autogroup + + super().setUp() + self.cli_runner = CliRunner() + # I need to disable the global variable of this test environment, + # because invoke is just calling the function and therefore inheriting + # the global variable + self._old_autogroup = autogroup.CURRENT_AUTOGROUP + autogroup.CURRENT_AUTOGROUP = None + + def tearDown(self): + """Setup the CLI runner to run command line commands.""" + from aiida.orm import autogroup + + super().tearDown() + autogroup.CURRENT_AUTOGROUP = self._old_autogroup + + def test_autogroup(self): + """Check if the autogroup is properly generated.""" + from aiida.orm import QueryBuilder, Node, Group, load_node + + script_content = """from aiida.orm import Data +node = Data().store() +print(node.pk) +""" + + with tempfile.NamedTemporaryFile(mode='w+') as fhandle: + fhandle.write(script_content) + fhandle.flush() + + options = ['--auto-group', fhandle.name] + result = self.cli_runner.invoke(cmd_run.run, options) + self.assertClickResultNoException(result) + + pk = int(result.output) + _ = load_node(pk) # Check if the node can be loaded + + queryb = QueryBuilder().append(Node, filters={'id': pk}, tag='node') + queryb.append(Group, with_node='node', filters={'type_string': 'auto.run'}, project='*') + all_auto_groups = queryb.all() + self.assertEqual( + len(all_auto_groups), 1, 'There should be only one autogroup associated with the node just created' + ) + + def test_autogroup_custom_label(self): + """Check if the autogroup is properly generated with the label specified.""" + from aiida.orm import QueryBuilder, Node, Group, load_node + + script_content = """from aiida.orm import Data +node = Data().store() +print(node.pk) +""" + autogroup_label = 'SOME_group_LABEL' + with tempfile.NamedTemporaryFile(mode='w+') as fhandle: + fhandle.write(script_content) + fhandle.flush() + + options = [fhandle.name, '--auto-group', '--auto-group-label-prefix', autogroup_label] + result = self.cli_runner.invoke(cmd_run.run, options) + self.assertClickResultNoException(result) + + pk = int(result.output) + _ = load_node(pk) # Check if the node can be loaded + + queryb = QueryBuilder().append(Node, filters={'id': pk}, tag='node') + queryb.append(Group, with_node='node', filters={'type_string': 'auto.run'}, project='*') + all_auto_groups = queryb.all() + self.assertEqual( + len(all_auto_groups), 1, 'There should be only one autogroup associated with the node just created' + ) + self.assertEqual(all_auto_groups[0][0].label, autogroup_label) + + def test_no_autogroup(self): + """Check if the autogroup is not generated if ``verdi run`` is asked not to.""" + from aiida.orm import QueryBuilder, Node, Group, load_node + + script_content = """from aiida.orm import Data +node = Data().store() +print(node.pk) +""" + + with tempfile.NamedTemporaryFile(mode='w+') as fhandle: + fhandle.write(script_content) + fhandle.flush() + + options = [fhandle.name] # Not storing an autogroup by default + result = self.cli_runner.invoke(cmd_run.run, options) + self.assertClickResultNoException(result) + + pk = int(result.output) + _ = load_node(pk) # Check if the node can be loaded + + queryb = QueryBuilder().append(Node, filters={'id': pk}, tag='node') + queryb.append(Group, with_node='node', filters={'type_string': 'auto.run'}, project='*') + all_auto_groups = queryb.all() + self.assertEqual(len(all_auto_groups), 0, 'There should be no autogroup generated') + + def test_autogroup_filter_class(self): # pylint: disable=too-many-locals + """Check if the autogroup is properly generated but filtered classes are skipped.""" + from aiida.orm import QueryBuilder, Node, Group, load_node + + script_content = """import sys +from aiida.orm import Computer, Int, ArrayData, KpointsData, CalculationNode, WorkflowNode +from aiida.plugins import CalculationFactory +from aiida.engine import run_get_node +ArithmeticAdd = CalculationFactory('arithmetic.add') + +computer = Computer( + name='localhost-example-{}'.format(sys.argv[1]), + hostname='localhost', + description='my computer', + transport_type='local', + scheduler_type='direct', + workdir='/tmp' +).store() +computer.configure() + +code = Code( + input_plugin_name='arithmetic.add', + remote_computer_exec=[computer, '/bin/true']).store() +inputs = { + 'x': Int(1), + 'y': Int(2), + 'code': code, + 'metadata': { + 'options': { + 'resources': { + 'num_machines': 1, + 'num_mpiprocs_per_machine': 1 + } + } + } +} + +node1 = KpointsData().store() +node2 = ArrayData().store() +node3 = Int(3).store() +node4 = CalculationNode().store() +node5 = WorkflowNode().store() +_, node6 = run_get_node(ArithmeticAdd, **inputs) +print(node1.pk) +print(node2.pk) +print(node3.pk) +print(node4.pk) +print(node5.pk) +print(node6.pk) +""" + from aiida.orm import Code + Code() + for idx, ( + flags, + kptdata_in_autogroup, + arraydata_in_autogroup, + int_in_autogroup, + calc_in_autogroup, + wf_in_autogroup, + calcarithmetic_in_autogroup, + ) in enumerate([ + [['--exclude', 'aiida.data:array.kpoints'], False, True, True, True, True, True], + # Check if % works anywhere - both 'int' and 'array.kpoints' contain an 'i' + [['--exclude', 'aiida.data:%i%'], False, True, False, True, True, True], + [['--exclude', 'aiida.data:int'], True, True, False, True, True, True], + [['--exclude', 'aiida.data:%'], False, False, False, True, True, True], + [['--exclude', 'aiida.data:array', 'aiida.data:array.%'], False, False, True, True, True, True], + [['--exclude', 'aiida.data:array', 'aiida.data:array.%', 'aiida.data:int'], False, False, False, True, True, + True], + [['--exclude', 'aiida.calculations:arithmetic.add'], True, True, True, True, True, False], + [ + ['--include', 'aiida.node:process.calculation'], # Base type, no specific plugin + False, + False, + False, + True, + False, + False + ], + [ + ['--include', 'aiida.node:process.workflow'], # Base type, no specific plugin + False, + False, + False, + False, + True, + False + ], + [[], True, True, True, True, True, True], + ]): + with tempfile.NamedTemporaryFile(mode='w+') as fhandle: + fhandle.write(script_content) + fhandle.flush() + + options = ['--auto-group'] + flags + ['--', fhandle.name, str(idx)] + result = self.cli_runner.invoke(cmd_run.run, options) + self.assertClickResultNoException(result) + + pk1_str, pk2_str, pk3_str, pk4_str, pk5_str, pk6_str = result.output.split() + pk1 = int(pk1_str) + pk2 = int(pk2_str) + pk3 = int(pk3_str) + pk4 = int(pk4_str) + pk5 = int(pk5_str) + pk6 = int(pk6_str) + _ = load_node(pk1) # Check if the node can be loaded + _ = load_node(pk2) # Check if the node can be loaded + _ = load_node(pk3) # Check if the node can be loaded + _ = load_node(pk4) # Check if the node can be loaded + _ = load_node(pk5) # Check if the node can be loaded + _ = load_node(pk6) # Check if the node can be loaded + + queryb = QueryBuilder().append(Node, filters={'id': pk1}, tag='node') + queryb.append(Group, with_node='node', filters={'type_string': 'auto.run'}, project='*') + all_auto_groups_kptdata = queryb.all() + + queryb = QueryBuilder().append(Node, filters={'id': pk2}, tag='node') + queryb.append(Group, with_node='node', filters={'type_string': 'auto.run'}, project='*') + all_auto_groups_arraydata = queryb.all() + + queryb = QueryBuilder().append(Node, filters={'id': pk3}, tag='node') + queryb.append(Group, with_node='node', filters={'type_string': 'auto.run'}, project='*') + all_auto_groups_int = queryb.all() + + queryb = QueryBuilder().append(Node, filters={'id': pk4}, tag='node') + queryb.append(Group, with_node='node', filters={'type_string': 'auto.run'}, project='*') + all_auto_groups_calc = queryb.all() + + queryb = QueryBuilder().append(Node, filters={'id': pk5}, tag='node') + queryb.append(Group, with_node='node', filters={'type_string': 'auto.run'}, project='*') + all_auto_groups_wf = queryb.all() + + queryb = QueryBuilder().append(Node, filters={'id': pk6}, tag='node') + queryb.append(Group, with_node='node', filters={'type_string': 'auto.run'}, project='*') + all_auto_groups_calcarithmetic = queryb.all() + + self.assertEqual( + len(all_auto_groups_kptdata), 1 if kptdata_in_autogroup else 0, + 'Wrong number of nodes in autogroup associated with the KpointsData node ' + "just created with flags '{}'".format(' '.join(flags)) + ) + self.assertEqual( + len(all_auto_groups_arraydata), 1 if arraydata_in_autogroup else 0, + 'Wrong number of nodes in autogroup associated with the ArrayData node ' + "just created with flags '{}'".format(' '.join(flags)) + ) + self.assertEqual( + len(all_auto_groups_int), 1 if int_in_autogroup else 0, + 'Wrong number of nodes in autogroup associated with the Int node ' + "just created with flags '{}'".format(' '.join(flags)) + ) + self.assertEqual( + len(all_auto_groups_calc), 1 if calc_in_autogroup else 0, + 'Wrong number of nodes in autogroup associated with the CalculationNode ' + "just created with flags '{}'".format(' '.join(flags)) + ) + self.assertEqual( + len(all_auto_groups_wf), 1 if wf_in_autogroup else 0, + 'Wrong number of nodes in autogroup associated with the WorkflowNode ' + "just created with flags '{}'".format(' '.join(flags)) + ) + self.assertEqual( + len(all_auto_groups_calcarithmetic), 1 if calcarithmetic_in_autogroup else 0, + 'Wrong number of nodes in autogroup associated with the ArithmeticAdd CalcJobNode ' + "just created with flags '{}'".format(' '.join(flags)) + ) + + def test_autogroup_clashing_label(self): + """Check if the autogroup label is properly (re)generated when it clashes with an existing group name.""" + from aiida.orm import QueryBuilder, Node, Group, load_node + + script_content = """from aiida.orm import Data +node = Data().store() +print(node.pk) +""" + autogroup_label = 'SOME_repeated_group_LABEL' + with tempfile.NamedTemporaryFile(mode='w+') as fhandle: + fhandle.write(script_content) + fhandle.flush() + + # First run + options = [fhandle.name, '--auto-group', '--auto-group-label-prefix', autogroup_label] + result = self.cli_runner.invoke(cmd_run.run, options) + self.assertClickResultNoException(result) + + pk = int(result.output) + _ = load_node(pk) # Check if the node can be loaded + queryb = QueryBuilder().append(Node, filters={'id': pk}, tag='node') + queryb.append(Group, with_node='node', filters={'type_string': 'auto.run'}, project='*') + all_auto_groups = queryb.all() + self.assertEqual( + len(all_auto_groups), 1, 'There should be only one autogroup associated with the node just created' + ) + self.assertEqual(all_auto_groups[0][0].label, autogroup_label) + + # A few more runs with the same label - it should not crash but append something to the group name + for _ in range(10): + options = [fhandle.name, '--auto-group', '--auto-group-label-prefix', autogroup_label] + result = self.cli_runner.invoke(cmd_run.run, options) + self.assertClickResultNoException(result) + + pk = int(result.output) + _ = load_node(pk) # Check if the node can be loaded + queryb = QueryBuilder().append(Node, filters={'id': pk}, tag='node') + queryb.append(Group, with_node='node', filters={'type_string': 'auto.run'}, project='*') + all_auto_groups = queryb.all() + self.assertEqual( + len(all_auto_groups), 1, 'There should be only one autogroup associated with the node just created' + ) + self.assertTrue(all_auto_groups[0][0].label.startswith(autogroup_label)) + + def test_legacy_autogroup_name(self): + """Check if the autogroup is properly generated when using the legacy --group-name flag.""" + from aiida.orm import QueryBuilder, Node, Group, load_node + + script_content = """from aiida.orm import Data +node = Data().store() +print(node.pk) +""" + group_label = 'legacy-group-name' + + with tempfile.NamedTemporaryFile(mode='w+') as fhandle: + fhandle.write(script_content) + fhandle.flush() + + options = ['--group-name', group_label, fhandle.name] + with warnings.catch_warnings(record=True) as warns: # pylint: disable=no-member + result = self.cli_runner.invoke(cmd_run.run, options) + self.assertTrue( + any(['use `--auto-group-label-prefix` instead' in str(warn.message) for warn in warns]), + "No warning for '--group-name' was raised" + ) + + self.assertClickResultNoException(result) + + pk = int(result.output) + _ = load_node(pk) # Check if the node can be loaded + + queryb = QueryBuilder().append(Node, filters={'id': pk}, tag='node') + queryb.append(Group, with_node='node', filters={'type_string': 'auto.run'}, project='*') + all_auto_groups = queryb.all() + self.assertEqual( + len(all_auto_groups), 1, 'There should be only one autogroup associated with the node just created' + ) + self.assertEqual( + all_auto_groups[0][0].label, group_label, + 'The auto group label is "{}" instead of "{}"'.format(all_auto_groups[0][0].label, group_label) + ) diff --git a/tests/engine/processes/workchains/test_utils.py b/tests/engine/processes/workchains/test_utils.py index 00f1e127a3..51a787235e 100644 --- a/tests/engine/processes/workchains/test_utils.py +++ b/tests/engine/processes/workchains/test_utils.py @@ -53,6 +53,7 @@ def test_priority(self): attribute_key = 'handlers_called' class ArithmeticAddBaseWorkChain(BaseRestartWorkChain): + """Implementation of a possible BaseRestartWorkChain for the ``ArithmeticAddCalculation``.""" _process_class = ArithmeticAddCalculation @@ -61,6 +62,7 @@ class ArithmeticAddBaseWorkChain(BaseRestartWorkChain): # This can then be checked after invoking `inspect_process` to ensure they were called in the right order @process_handler(priority=100) def handler_01(self, node): + """Example handler returing ExitCode 100.""" handlers_called = node.get_attribute(attribute_key, default=[]) handlers_called.append('handler_01') node.set_attribute(attribute_key, handlers_called) @@ -68,6 +70,7 @@ def handler_01(self, node): @process_handler(priority=300) def handler_03(self, node): + """Example handler returing ExitCode 300.""" handlers_called = node.get_attribute(attribute_key, default=[]) handlers_called.append('handler_03') node.set_attribute(attribute_key, handlers_called) @@ -75,6 +78,7 @@ def handler_03(self, node): @process_handler(priority=200) def handler_02(self, node): + """Example handler returing ExitCode 200.""" handlers_called = node.get_attribute(attribute_key, default=[]) handlers_called.append('handler_02') node.set_attribute(attribute_key, handlers_called) @@ -82,6 +86,7 @@ def handler_02(self, node): @process_handler(priority=400) def handler_04(self, node): + """Example handler returing ExitCode 400.""" handlers_called = node.get_attribute(attribute_key, default=[]) handlers_called.append('handler_04') node.set_attribute(attribute_key, handlers_called) @@ -159,6 +164,7 @@ def test_exit_codes_filter(self): node_skip.set_exit_status(200) # Some other exit status class ArithmeticAddBaseWorkChain(BaseRestartWorkChain): + """Minimal base restart workchain for the ``ArithmeticAddCalculation``.""" _process_class = ArithmeticAddCalculation diff --git a/tests/orm/test_autogroups.py b/tests/orm/test_autogroups.py new file mode 100644 index 0000000000..e1426ad2e8 --- /dev/null +++ b/tests/orm/test_autogroups.py @@ -0,0 +1,129 @@ +# -*- coding: utf-8 -*- +########################################################################### +# Copyright (c), The AiiDA team. All rights reserved. # +# This file is part of the AiiDA code. # +# # +# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # +# For further information on the license, see the LICENSE.txt file # +# For further information please visit http://www.aiida.net # +########################################################################### +"""Tests for the Autogroup functionality.""" +from aiida.backends.testbase import AiidaTestCase +from aiida.orm import Group, QueryBuilder +from aiida.orm.autogroup import Autogroup + + +class TestAutogroup(AiidaTestCase): + """Tests the Autogroup logic.""" + + def test_get_or_create(self): + """Test the ``get_or_create_group`` method of ``Autogroup``.""" + label_prefix = 'test_prefix_TestAutogroup' + + # Check that there are no groups to begin with + queryb = QueryBuilder().append(Group, filters={'type_string': 'auto.run', 'label': label_prefix}, project='*') + assert not list(queryb.all()) + queryb = QueryBuilder().append( + Group, filters={ + 'type_string': 'auto.run', + 'label': { + 'like': r'{}\_%'.format(label_prefix) + } + }, project='*' + ) + assert not list(queryb.all()) + + # First group (no existing one) + autogroup = Autogroup() + autogroup.set_group_label_prefix(label_prefix) + group = autogroup.get_or_create_group() + expected_label = label_prefix + self.assertEqual( + group.label, expected_label, + "The auto-group should be labelled '{}', it is instead '{}'".format(expected_label, group.label) + ) + + # Second group (only one with no suffix existing) + autogroup = Autogroup() + autogroup.set_group_label_prefix(label_prefix) + group = autogroup.get_or_create_group() + expected_label = label_prefix + '_1' + self.assertEqual( + group.label, expected_label, + "The auto-group should be labelled '{}', it is instead '{}'".format(expected_label, group.label) + ) + + # Second group (only one suffix _1 existing) + autogroup = Autogroup() + autogroup.set_group_label_prefix(label_prefix) + group = autogroup.get_or_create_group() + expected_label = label_prefix + '_2' + self.assertEqual( + group.label, expected_label, + "The auto-group should be labelled '{}', it is instead '{}'".format(expected_label, group.label) + ) + + # I create a group with a large integer suffix (9) + Group(label='{}_9'.format(label_prefix), type_string='auto.run').store() + # The next autogroup should become number 10 + autogroup = Autogroup() + autogroup.set_group_label_prefix(label_prefix) + group = autogroup.get_or_create_group() + expected_label = label_prefix + '_10' + self.assertEqual( + group.label, expected_label, + "The auto-group should be labelled '{}', it is instead '{}'".format(expected_label, group.label) + ) + + # I create a group with a non-integer suffix (15a), it should be ignored + Group(label='{}_15b'.format(label_prefix), type_string='auto.run').store() + # The next autogroup should become number 11 + autogroup = Autogroup() + autogroup.set_group_label_prefix(label_prefix) + group = autogroup.get_or_create_group() + expected_label = label_prefix + '_11' + self.assertEqual( + group.label, expected_label, + "The auto-group should be labelled '{}', it is instead '{}'".format(expected_label, group.label) + ) + + def test_get_or_create_invalid_prefix(self): + """Test the ``get_or_create_group`` method of ``Autogroup`` when there is already a group + with the same prefix, but followed by other non-underscore characters.""" + label_prefix = 'new_test_prefix_TestAutogroup' + # I create a group with the same prefix, but followed by non-underscore + # characters. These should be ignored in the logic. + Group(label='{}xx'.format(label_prefix), type_string='auto.run').store() + + # Check that there are no groups to begin with + queryb = QueryBuilder().append(Group, filters={'type_string': 'auto.run', 'label': label_prefix}, project='*') + assert not list(queryb.all()) + queryb = QueryBuilder().append( + Group, filters={ + 'type_string': 'auto.run', + 'label': { + 'like': r'{}\_%'.format(label_prefix) + } + }, project='*' + ) + assert not list(queryb.all()) + + # First group (no existing one) + autogroup = Autogroup() + autogroup.set_group_label_prefix(label_prefix) + group = autogroup.get_or_create_group() + expected_label = label_prefix + self.assertEqual( + group.label, expected_label, + "The auto-group should be labelled '{}', it is instead '{}'".format(expected_label, group.label) + ) + + # Second group (only one with no suffix existing) + autogroup = Autogroup() + autogroup.set_group_label_prefix(label_prefix) + group = autogroup.get_or_create_group() + expected_label = label_prefix + '_1' + self.assertEqual( + group.label, expected_label, + "The auto-group should be labelled '{}', it is instead '{}'".format(expected_label, group.label) + ) diff --git a/tests/orm/test_groups.py b/tests/orm/test_groups.py index 9c842aa2c4..ce2797daad 100644 --- a/tests/orm/test_groups.py +++ b/tests/orm/test_groups.py @@ -8,7 +8,6 @@ # For further information please visit http://www.aiida.net # ########################################################################### """Test for the Group ORM class.""" - from aiida import orm from aiida.backends.testbase import AiidaTestCase from aiida.common import exceptions diff --git a/tests/tools/importexport/orm/test_codes.py b/tests/tools/importexport/orm/test_codes.py index 5a11e07b94..d8f173107b 100644 --- a/tests/tools/importexport/orm/test_codes.py +++ b/tests/tools/importexport/orm/test_codes.py @@ -24,9 +24,11 @@ class TestCode(AiidaTestCase): """Test ex-/import cases related to Codes""" def setUp(self): + super().setUp() self.reset_database() def tearDown(self): + super().tearDown() self.reset_database() @with_temp_dir diff --git a/tests/tools/visualization/test_graph.py b/tests/tools/visualization/test_graph.py index 9f15cab9ca..d48a3e6800 100644 --- a/tests/tools/visualization/test_graph.py +++ b/tests/tools/visualization/test_graph.py @@ -22,9 +22,11 @@ class TestVisGraph(AiidaTestCase): """Tests for verdi graph""" def setUp(self): + super().setUp() self.reset_database() def tearDown(self): + super().tearDown() self.reset_database() def create_provenance(self): diff --git a/utils/dependency_management.py b/utils/dependency_management.py index 17442d66af..af476de3e7 100755 --- a/utils/dependency_management.py +++ b/utils/dependency_management.py @@ -239,7 +239,6 @@ def validate_environment_yml(): # pylint: disable=too-many-branches # Check that all requirements specified in the setup.json file are found in the # conda environment specification. - missing_from_env = set() for req in install_requirements: if any(re.match(ignore, str(req)) for ignore in CONDA_IGNORE): continue # skip explicitly ignored packages @@ -251,7 +250,7 @@ def validate_environment_yml(): # pylint: disable=too-many-branches # The only dependency left should be the one for Python itself, which is not part of # the install_requirements for setuptools. - if len(conda_dependencies) > 0: + if conda_dependencies: raise DependencySpecificationError( "The 'environment.yml' file contains dependencies that are missing " "in 'setup.json':\n- {}".format('\n- '.join(map(str, conda_dependencies))) @@ -304,7 +303,7 @@ def validate_pyproject_toml(): "Missing requirement '{}' in 'pyproject.toml'.".format(reentry_requirement) ) - except FileNotFoundError as error: + except FileNotFoundError: raise DependencySpecificationError("The 'pyproject.toml' file is missing!") click.secho('Pyproject.toml dependency specification is consistent.', fg='green')