From ba4b2174d1a08348eb914492d4f7c8714dae478a Mon Sep 17 00:00:00 2001 From: Giovanni Pizzi Date: Fri, 3 Apr 2020 13:49:46 +0200 Subject: [PATCH] Addressed comments by Sebastiaan. In particular, the most important changes include: - now the auto-group flag is called `--auto-group`, is a flag and is False by default - only kept `--include` and `--exclude` options, checking typestrings and allowing to end with `%`. One benefit is that while reimplementing I replaced some `isinstance` with string comparisons, with potential further benefits. - `--include` and `--exclude` are now mutually exclusive - improved documentation of the main point of the issue --- aiida/cmdline/commands/cmd_run.py | 71 +++--- aiida/orm/autogroup.py | 231 ++++++++++-------- aiida/orm/nodes/node.py | 7 +- aiida/tools/ipython/ipython_magics.py | 4 +- docs/source/verdi/verdi_user_guide.rst | 31 +-- .../aiida_sqlalchemy/test_migrations.py | 1 + tests/cmdline/commands/test_run.py | 198 +++++++++++++-- .../engine/processes/workchains/test_utils.py | 4 +- tests/orm/test_autogroups.py | 129 ++++++++++ 9 files changed, 478 insertions(+), 198 deletions(-) create mode 100644 tests/orm/test_autogroups.py diff --git a/aiida/cmdline/commands/cmd_run.py b/aiida/cmdline/commands/cmd_run.py index 6c09ff6cd7..281e3092a1 100644 --- a/aiida/cmdline/commands/cmd_run.py +++ b/aiida/cmdline/commands/cmd_run.py @@ -40,12 +40,12 @@ def update_environment(argv): sys.path = _path -def validate_entrypoint_string_or_all(ctx, param, value, allow_all=True): # pylint: disable=unused-argument,invalid-name - """Validate that `value` is a valid entrypoint string or the string 'all'.""" +def validate_entrypoint_string(ctx, param, value): # pylint: disable=unused-argument,invalid-name + """Validate that `value` is a valid entrypoint string.""" from aiida.orm import autogroup try: - autogroup.Autogroup.validate(value, allow_all=allow_all) + autogroup.Autogroup.validate(value) except Exception as exc: raise click.BadParameter(str(exc) + ' ({})'.format(value)) @@ -55,58 +55,41 @@ def validate_entrypoint_string_or_all(ctx, param, value, allow_all=True): # pyl @verdi.command('run', context_settings=dict(ignore_unknown_options=True,)) @click.argument('scriptname', type=click.STRING) @click.argument('varargs', nargs=-1, type=click.UNPROCESSED) -@click.option('--group/--no-group', default=True, show_default=True, help='Enables the autogrouping') +@click.option('--auto-group', is_flag=True, help='Enables the autogrouping') @click.option( '-l', - '--group-label-prefix', + '--auto-group-label-prefix', type=click.STRING, required=False, help='Specify the prefix of the label of the auto group (numbers might be automatically ' - 'appended to generate unique names per run)' + 'appended to generate unique names per run).' ) @click.option( '-n', '--group-name', type=click.STRING, required=False, - help='Specify the name of the auto group [DEPRECATED, USE --group-label-prefix instead]' + help='Specify the name of the auto group [DEPRECATED, USE --auto-group-label-prefix instead]. ' + 'This also enables auto-grouping.' ) @click.option( '-e', '--exclude', cls=MultipleValueOption, - default=lambda: [], - help='Exclude these classes from auto grouping (use full entrypoint strings)', - callback=functools.partial(validate_entrypoint_string_or_all, allow_all=False) + default=None, + help='Exclude these classes from auto grouping (use full entrypoint strings).', + callback=functools.partial(validate_entrypoint_string) ) @click.option( '-i', '--include', cls=MultipleValueOption, - default=lambda: ['all'], - help='Include these classes from auto grouping (use full entrypoint strings or "all")', - callback=validate_entrypoint_string_or_all -) -@click.option( - '-E', - '--excludesubclasses', - cls=MultipleValueOption, - default=lambda: [], - help='Exclude these classes and their sub classes from auto grouping (use full entrypoint strings)', - callback=functools.partial(validate_entrypoint_string_or_all, allow_all=False) -) -@click.option( - '-I', - '--includesubclasses', - cls=MultipleValueOption, - default=lambda: [], - help='Include these classes and their sub classes from auto grouping (use full entrypoint strings)', - callback=functools.partial(validate_entrypoint_string_or_all, allow_all=False) + default=None, + help='Include these classes from auto grouping (use full entrypoint strings or "all").', + callback=validate_entrypoint_string ) @decorators.with_dbenv() -def run( - scriptname, varargs, group, group_label_prefix, group_name, exclude, excludesubclasses, include, includesubclasses -): +def run(scriptname, varargs, auto_group, auto_group_label_prefix, group_name, exclude, include): # pylint: disable=too-many-arguments,exec-used """Execute scripts with preloaded AiiDA environment.""" from aiida.cmdline.utils.shell import DEFAULT_MODULES_LIST @@ -126,20 +109,22 @@ def run( globals_dict['{}'.format(alias)] = getattr(__import__(app_mod, {}, {}, model_name), model_name) if group_name: - warnings.warn('--group-name is deprecated, use `--group-label` instead', AiidaDeprecationWarning) # pylint: disable=no-member - if group_label: - raise click.BadParameter('You cannot specify both --group-name and --group-label; use --group-label only') - group_label = group_name - - if group: + warnings.warn('--group-name is deprecated, use `--auto-group-label-prefix` instead', AiidaDeprecationWarning) # pylint: disable=no-member + if auto_group_label_prefix: + raise click.BadParameter( + 'You cannot specify both --group-name and --auto-group-label-prefix; use --group-label only' + ) + auto_group_label_prefix = group_name + # To have the old behavior, with auto-group enabled. + auto_group = True + + if auto_group: aiida_verdilib_autogroup = autogroup.Autogroup() - # if group_label_prefix is None, use autogenerated name - if group_label_prefix is not None: - aiida_verdilib_autogroup.set_group_label_prefix(group_label_prefix) + # Set the ``group_label_prefix`` if defined, otherwise a default prefix will be used + if auto_group_label_prefix is not None: + aiida_verdilib_autogroup.set_group_label_prefix(auto_group_label_prefix) aiida_verdilib_autogroup.set_exclude(exclude) aiida_verdilib_autogroup.set_include(include) - aiida_verdilib_autogroup.set_exclude_with_subclasses(excludesubclasses) - aiida_verdilib_autogroup.set_include_with_subclasses(includesubclasses) # Note: this is also set in the exec environment! This is the intended behavior autogroup.CURRENT_AUTOGROUP = aiida_verdilib_autogroup diff --git a/aiida/orm/autogroup.py b/aiida/orm/autogroup.py index cbba2f1c4e..eee4239b28 100644 --- a/aiida/orm/autogroup.py +++ b/aiida/orm/autogroup.py @@ -14,7 +14,7 @@ from aiida.common.escaping import escape_for_sql_like from aiida.common.warnings import AiidaDeprecationWarning from aiida.orm import GroupTypeString, Group -from aiida.plugins import load_entry_point_from_string +from aiida.plugins.entry_point import get_entry_point_string_from_class CURRENT_AUTOGROUP = None @@ -29,17 +29,21 @@ class Autogroup: If so, it will call Autogroup.is_to_be_grouped, and decide whether to put it in a group. Such autogroups are going to be of the VERDIAUTOGROUP_TYPE. - The exclude/include lists, can have values 'all' if you want to include/exclude all classes. - Otherwise, they are lists of strings like: calculation.quantumespresso.pw, data.array.kpoints, ... - i.e.: a string identifying the base class, than the path to the class as in Calculation/Data -Factories + The exclude/include lists are lists of strings like: + ``aiida.data:int``, ``aiida.calculation:quantumespresso.pw``, + ``aiida.data:array.%``, ... + i.e.: a string identifying the base class, followed a colona and by the path to the class + as accepted by CalculationFactory/DataFactory. + Each string contain the wildcard ``%`` at the end; + in this case this is used in a ``like`` comparison with the QueryBuilder. + + Only one of the two (between exclude and include) can be set. If none of the two is set, everything is included. """ def __init__(self): """Initialize with defaults.""" - self._exclude = [] - self._exclude_with_subclasses = [] - self._include = ['all'] - self._include_with_subclasses = [] + self._exclude = None + self._include = None now = timezone.now() default_label_prefix = 'Verdi autogroup on ' + now.strftime('%Y-%m-%d %H:%M:%S') @@ -47,36 +51,43 @@ def __init__(self): self._group_label = None # Actual group label, set by `get_or_create_group` @staticmethod - def validate(param, allow_all=True): - """ - Used internally to verify the sanity of exclude, include lists - - :param param: should be a list of valid entrypoint strings - """ - for string in param: - if allow_all and string == 'all': - continue - load_entry_point_from_string(string) # This will raise a MissingEntryPointError if invalid + def validate(strings): + """Validate the list of strings passed to set_include and set_exclude.""" + if strings is None: + return + valid_prefixes = set(['aiida.node', 'aiida.calculations', 'aiida.workflows', 'aiida.data']) + for string in strings: + pieces = string.split(':') + if len(pieces) != 2: + raise exceptions.ValidationError( + "'{}' is not a valid include/exclude filter, must contain two parts split by a colon". + format(string) + ) + if pieces[0] not in valid_prefixes: + raise exceptions.ValidationError( + "'{}' has an invalid prefix, must be among: {}".format(string, sorted(valid_prefixes)) + ) + + # If a % is present, it can only be the last character + string_without_percent = pieces[1] + if string_without_percent.endswith('%'): + string_without_percent = string_without_percent[:-1] + if '%' in string_without_percent: + raise exceptions.ValidationError( + "'{}' can only contain a '%' character, if any, at the end of the string".format(string) + ) def get_exclude(self): - """Return the list of classes to exclude from autogrouping.""" - return self._exclude + """Return the list of classes to exclude from autogrouping. - def get_exclude_with_subclasses(self): - """ - Return the list of classes to exclude from autogrouping. - Will also exclude their derived subclasses - """ - return self._exclude_with_subclasses + Returns ``None`` if no exclusion list has been set.""" + return self._exclude def get_include(self): - """Return the list of classes to include in the autogrouping.""" - return self._include - - def get_include_with_subclasses(self): """Return the list of classes to include in the autogrouping. - Will also include their derived subclasses.""" - return self._include_with_subclasses + + Returns ``None`` if no exclusion list has been set.""" + return self._include def get_group_label_prefix(self): """Get the prefix of the label of the group. @@ -87,7 +98,7 @@ def get_group_name(self): """Get the label of the group. If no group label was set, it will set a default one by itself. - .. deprecated:: 1.1.0 + .. deprecated:: 1.2.0 Will be removed in `v2.0.0`, use :py:meth:`.get_group_label_prefix` instead. """ warnings.warn('function is deprecated, use `get_group_label_prefix` instead', AiidaDeprecationWarning) # pylint: disable=no-member @@ -98,45 +109,28 @@ def set_exclude(self, exclude): :param exclude: a list of valid entry point strings (one of which could be the string 'all') """ + if isinstance(exclude, str): + exclude = [exclude] self.validate(exclude) - if 'all' in self.get_include(): - if 'all' in exclude: - raise exceptions.ValidationError('Cannot exclude and include all classes') + if exclude is not None and self.get_include() is not None: + # It's ok to set None, both as a default, or to 'undo' the exclude list + raise exceptions.ValidationError('Cannot both specify exclude and include') self._exclude = exclude - def set_exclude_with_subclasses(self, exclude): - """ - Set the list of classes to exclude from autogrouping. - Will also exclude their derived subclasses - - :param exclude: a list of valid entry point strings (one of which could be the string 'all') - """ - self.validate(exclude) - self._exclude_with_subclasses = exclude - def set_include(self, include): """ Set the list of classes to include in the autogrouping. :param include: a list of valid entry point strings (one of which could be the string 'all') """ + if isinstance(include, str): + include = [include] self.validate(include) - if 'all' in self.get_exclude(): - if 'all' in include: - raise exceptions.ValidationError('Cannot exclude and include all classes') - + if include is not None and self.get_exclude() is not None: + # It's ok to set None, both as a default, or to 'undo' the include list + raise exceptions.ValidationError('Cannot both specify exclude and include') self._include = include - def set_include_with_subclasses(self, include): - """ - Set the list of classes to include in the autogrouping. - Will also include their derived subclasses. - - :param include: a list of valid entry point strings (one of which could be the string 'all') - """ - self.validate(include) - self._include_with_subclasses = include - def set_group_label_prefix(self, label_prefix): """ Set the label of the group to be created @@ -148,54 +142,61 @@ def set_group_label_prefix(self, label_prefix): def set_group_name(self, gname): """Set the name of the group. - .. deprecated:: 1.1.0 + .. deprecated:: 1.2.0 Will be removed in `v2.0.0`, use :py:meth:`.set_group_label_prefix` instead. """ warnings.warn('function is deprecated, use `set_group_label_prefix` instead', AiidaDeprecationWarning) # pylint: disable=no-member return self.set_group_label_prefix(label_prefix=gname) - def is_to_be_grouped(self, the_class): + @staticmethod + def _matches(string, filter_string): + """Check if 'string' matches the 'filter_string' (used for include and exclude filters). + + If 'filter_string' does not end with a % sign, perform an exact match. + Otherwise, strip the '%' sign and match with string.startswith(filter_string[:-1]). + + :param string: the string to match. + :param filter_string: the filter string. """ - Return whether the given class has to be included in the autogroup according to include/exclude list + if filter_string.endswith('%'): + return string.startswith(filter_string[:-1]) + return string == filter_string - :return (bool): True if the_class is to be included in the autogroup + def is_to_be_grouped(self, node): """ - # strings, including possibly 'all' - include_exact = self.get_include() - include_with_subclasses = self.get_include_with_subclasses() + Return whether the given node has to be included in the autogroup according to include/exclude list - # actual classes, with 'all' stripped out - include_exact_classes = tuple( - load_entry_point_from_string(ep_string) for ep_string in include_exact if ep_string != 'all' - ) - include_with_subclasses_classes = tuple( - load_entry_point_from_string(ep_string) for ep_string in include_with_subclasses if ep_string != 'all' - ) + :return (bool): True if ``node`` is to be included in the autogroup + """ + # I import here to avoid circular imports + from aiida.orm.nodes.process import ProcessNode - if ( - 'all' in include_exact or the_class in include_exact_classes or - issubclass(the_class, include_with_subclasses_classes) - ): - # According to the include, this class should be included - # strings, including possibly 'all' - exclude_exact = self.get_exclude() - exclude_with_subclasses = self.get_exclude_with_subclasses() - - # actual classes, with 'all' stripped out - exclude_exact_classes = tuple( - load_entry_point_from_string(ep_string) for ep_string in exclude_exact if ep_string != 'all' - ) - exclude_with_subclasses_classes = tuple( - load_entry_point_from_string(ep_string) for ep_string in exclude_with_subclasses if ep_string != 'all' - ) - - if (the_class not in exclude_exact_classes and not issubclass(the_class, exclude_with_subclasses_classes)): - # If we are here, it's not excluded - return True - # If we're here, it's both in the include and in the exclude - exclude it - return False - # If we're here, the class is not in the include - return False - return False + # strings, including possibly 'all' + include = self.get_include() + exclude = self.get_exclude() + if include is None and exclude is None: + # Include all classes by default if nothing is explicitly specified. + return True + if include is not None and exclude is not None: + # We should never be here, anyway - this should be catched by the `set_include/exclude` methods + raise ValueError("You cannot specify both an 'include' and an 'exclude' list") + + the_class = node.__class__ + if issubclass(the_class, ProcessNode): + try: + the_class = node.process_class + except ValueError: + # It does not have a process class - we just check the node class then, it could be e.g. + # a bare CalculationNode. + pass + class_entry_point_string = get_entry_point_string_from_class(the_class.__module__, the_class.__name__) + if include is not None: + # As soon as a filter string matches, we include the class + return any(self._matches(class_entry_point_string, filter_string) for filter_string in include) + # If we are here, exclude is not None + # include *only* in *none* of the filters match (that is, exclude as + # soon as any of the filters matches) + return not any(self._matches(class_entry_point_string, filter_string) for filter_string in exclude) def clear_group_cache(self): """Clear the cache of the group name. @@ -205,8 +206,26 @@ def clear_group_cache(self): self._group_label = None def get_or_create_group(self): - """Return the current Autogroup, or create one if None has been set yet.""" + """Return the current Autogroup, or create one if None has been set yet. + + This function implements a somewhat complex logic that is however needed + to make sure that, even if `verdi run` is called at the same time multiple + times, e.g. in a for loop in bash, there is never the risk that two ``verdi run`` + Unix processes try to create the same group, with the same label, ending + up in a crash of the code (see PR #3650). + + Here, instead, we make sure that if this concurrency issue happens, + one of the two will get a IntegrityError from the DB, and then recover + trying to create a group with a different label (with a numeric suffix appended), + until it manages to create it. + """ from aiida.orm import QueryBuilder + + # When this function is called, if it is the first time, just generate + # a new group name (later on, after this ``if`` block`). + # In that case, we will later cache in ``self._group_label`` the group label, + # So the group with the same name can be returned quickly in future + # calls of this method. if self._group_label is not None: results = [ res[0] for res in QueryBuilder(). @@ -218,6 +237,7 @@ def get_or_create_group(self): if results: # If it is not empty, it should have only one result due to the # uniqueness constraints + assert len(results) == 1, 'I got more than one autogroup with the same label!' return results[0] # There are no results: probably the group has been deleted. # I continue as if it was not cached @@ -247,13 +267,12 @@ def get_or_create_group(self): if label == '': # This is just the prefix without name - corresponds to counter = 0 existing_group_ints.append(0) - else: - if label.startswith('_'): - try: - existing_group_ints.append(int(label[1:])) - except ValueError: - # It's not an integer, so it will never collide - just ignore it - pass + elif label.startswith('_'): + try: + existing_group_ints.append(int(label[1:])) + except ValueError: + # It's not an integer, so it will never collide - just ignore it + pass if not existing_group_ints: counter = 0 diff --git a/aiida/orm/nodes/node.py b/aiida/orm/nodes/node.py index dedbddcfc5..c5c7f85ce5 100644 --- a/aiida/orm/nodes/node.py +++ b/aiida/orm/nodes/node.py @@ -1025,10 +1025,9 @@ def store(self, with_transaction=True, use_cache=None): # pylint: disable=argum self._store(with_transaction=with_transaction, clean=True) # Set up autogrouping used by verdi run - if autogroup.CURRENT_AUTOGROUP is not None: - if autogroup.CURRENT_AUTOGROUP.is_to_be_grouped(self.__class__): - group = autogroup.CURRENT_AUTOGROUP.get_or_create_group() - group.add_nodes(self) + if autogroup.CURRENT_AUTOGROUP is not None and autogroup.CURRENT_AUTOGROUP.is_to_be_grouped(self): + group = autogroup.CURRENT_AUTOGROUP.get_or_create_group() + group.add_nodes(self) return self diff --git a/aiida/tools/ipython/ipython_magics.py b/aiida/tools/ipython/ipython_magics.py index af3d8cb395..66310c37b9 100644 --- a/aiida/tools/ipython/ipython_magics.py +++ b/aiida/tools/ipython/ipython_magics.py @@ -34,8 +34,8 @@ In [2]: %aiida """ -from IPython import version_info -from IPython.core import magic +from IPython import version_info # pylint: disable=no-name-in-module +from IPython.core import magic # pylint: disable=no-name-in-module,import-error from aiida.common import json diff --git a/docs/source/verdi/verdi_user_guide.rst b/docs/source/verdi/verdi_user_guide.rst index ee43c3999a..3c24ffeb1b 100644 --- a/docs/source/verdi/verdi_user_guide.rst +++ b/docs/source/verdi/verdi_user_guide.rst @@ -702,24 +702,19 @@ Below is a list with all available subcommands. Execute scripts with preloaded AiiDA environment. Options: - --group / --no-group Enables the autogrouping [default: True] - -l, --group-label-prefix TEXT Specify the prefix of the label of the auto - group (numbers might be automatically - appended to generate unique names per run) - -n, --group-name TEXT Specify the name of the auto group - [DEPRECATED, USE --group-label-prefix - instead] - -e, --exclude TEXT Exclude these classes from auto grouping (use - full entrypoint strings) - -i, --include TEXT Include these classes from auto grouping - (use full entrypoint strings or "all") - -E, --excludesubclasses TEXT Exclude these classes and their sub classes - from auto grouping (use full entrypoint - strings) - -I, --includesubclasses TEXT Include these classes and their sub classes - from auto grouping (use full entrypoint - strings) - --help Show this message and exit. + --auto-group Enables the autogrouping + -l, --auto-group-label-prefix TEXT + Specify the prefix of the label of the auto + group (numbers might be automatically + appended to generate unique names per run). + -n, --group-name TEXT Specify the name of the auto group + [DEPRECATED, USE --auto-group-label-prefix + instead]. This also enables auto-grouping. + -e, --exclude TEXT Exclude these classes from auto grouping + (use full entrypoint strings). + -i, --include TEXT Include these classes from auto grouping + (use full entrypoint strings or "all"). + --help Show this message and exit. .. _verdi_setup: diff --git a/tests/backends/aiida_sqlalchemy/test_migrations.py b/tests/backends/aiida_sqlalchemy/test_migrations.py index 2147414f0d..8e2046f293 100644 --- a/tests/backends/aiida_sqlalchemy/test_migrations.py +++ b/tests/backends/aiida_sqlalchemy/test_migrations.py @@ -68,6 +68,7 @@ def setUp(self): except Exception: # Bring back the DB to the correct state if this setup part fails self._reset_database_and_schema() + autogroup.CURRENT_AUTOGROUP = self.current_autogroup raise def _perform_actual_migration(self): diff --git a/tests/cmdline/commands/test_run.py b/tests/cmdline/commands/test_run.py index 8df034f8ef..ae82826e52 100644 --- a/tests/cmdline/commands/test_run.py +++ b/tests/cmdline/commands/test_run.py @@ -9,6 +9,7 @@ ########################################################################### """Tests for `verdi run`.""" import tempfile +import warnings from click.testing import CliRunner @@ -102,7 +103,7 @@ def test_autogroup(self): fhandle.write(script_content) fhandle.flush() - options = [fhandle.name] + options = ['--auto-group', fhandle.name] result = self.cli_runner.invoke(cmd_run.run, options) self.assertClickResultNoException(result) @@ -129,7 +130,7 @@ def test_autogroup_custom_label(self): fhandle.write(script_content) fhandle.flush() - options = [fhandle.name, '--group-label-prefix', autogroup_label] + options = [fhandle.name, '--auto-group', '--auto-group-label-prefix', autogroup_label] result = self.cli_runner.invoke(cmd_run.run, options) self.assertClickResultNoException(result) @@ -157,7 +158,7 @@ def test_no_autogroup(self): fhandle.write(script_content) fhandle.flush() - options = [fhandle.name, '--no-group'] + options = [fhandle.name] # Not storing an autogroup by default result = self.cli_runner.invoke(cmd_run.run, options) self.assertClickResultNoException(result) @@ -173,46 +174,144 @@ def test_autogroup_filter_class(self): # pylint: disable=too-many-locals """Check if the autogroup is properly generated but filtered classes are skipped.""" from aiida.orm import QueryBuilder, Node, Group, load_node - script_content = """from aiida.orm import Data -node1 = Data().store() -node2 = Int(3).store() + script_content = """import sys +from aiida.orm import Computer, Int, ArrayData, KpointsData, CalculationNode, WorkflowNode +from aiida.plugins import CalculationFactory +from aiida.engine import run_get_node +ArithmeticAdd = CalculationFactory('arithmetic.add') + +computer = Computer( + name='localhost-example-{}'.format(sys.argv[1]), + hostname='localhost', + description='my computer', + transport_type='local', + scheduler_type='direct', + workdir='/tmp' +).store() +computer.configure() + +code = Code( + input_plugin_name='arithmetic.add', + remote_computer_exec=[computer, '/bin/true']).store() +inputs = { + 'x': Int(1), + 'y': Int(2), + 'code': code, + 'metadata': { + 'options': { + 'resources': { + 'num_machines': 1, + 'num_mpiprocs_per_machine': 1 + } + } + } +} + +node1 = KpointsData().store() +node2 = ArrayData().store() +node3 = Int(3).store() +node4 = CalculationNode().store() +node5 = WorkflowNode().store() +_, node6 = run_get_node(ArithmeticAdd, **inputs) print(node1.pk) print(node2.pk) +print(node3.pk) +print(node4.pk) +print(node5.pk) +print(node6.pk) """ - - for flags, data_in_autogroup, int_in_autogroup in [ - [['--exclude', 'aiida.node:data'], False, True], - [['--exclude', 'aiida.data:int'], True, False], - [['--excludesubclasses', 'aiida.node:data'], False, False], - [['--excludesubclasses', 'aiida.data:int'], True, False], - [['--excludesubclasses', 'aiida.node:data', 'aiida.data:int'], False, False], - [['--include', 'aiida.node:process'], False, False], - [['--exclude', 'aiida.node:data', 'aiida.data:int'], False, False], - ]: + from aiida.orm import Code + Code() + for idx, ( + flags, + kptdata_in_autogroup, + arraydata_in_autogroup, + int_in_autogroup, + calc_in_autogroup, + wf_in_autogroup, + calcarithmetic_in_autogroup, + ) in enumerate([ + [['--exclude', 'aiida.data:array.kpoints'], False, True, True, True, True, True], + [['--exclude', 'aiida.data:int'], True, True, False, True, True, True], + [['--exclude', 'aiida.data:%'], False, False, False, True, True, True], + [['--exclude', 'aiida.data:array', 'aiida.data:array.%'], False, False, True, True, True, True], + [['--exclude', 'aiida.data:array', 'aiida.data:array.%', 'aiida.data:int'], False, False, False, True, True, + True], + [['--exclude', 'aiida.calculations:arithmetic.add'], True, True, True, True, True, False], + [ + ['--include', 'aiida.node:process.calculation'], # Base type, no specific plugin + False, + False, + False, + True, + False, + False + ], + [ + ['--include', 'aiida.node:process.workflow'], # Base type, no specific plugin + False, + False, + False, + False, + True, + False + ], + [[], True, True, True, True, True, True], + ]): with tempfile.NamedTemporaryFile(mode='w+') as fhandle: fhandle.write(script_content) fhandle.flush() - options = [fhandle.name] + flags + ['--'] + options = ['--auto-group'] + flags + ['--', fhandle.name, str(idx)] result = self.cli_runner.invoke(cmd_run.run, options) self.assertClickResultNoException(result) - pk1_str, pk2_str = result.output.split() + pk1_str, pk2_str, pk3_str, pk4_str, pk5_str, pk6_str = result.output.split() pk1 = int(pk1_str) pk2 = int(pk2_str) + pk3 = int(pk3_str) + pk4 = int(pk4_str) + pk5 = int(pk5_str) + pk6 = int(pk6_str) _ = load_node(pk1) # Check if the node can be loaded _ = load_node(pk2) # Check if the node can be loaded + _ = load_node(pk3) # Check if the node can be loaded + _ = load_node(pk4) # Check if the node can be loaded + _ = load_node(pk5) # Check if the node can be loaded + _ = load_node(pk6) # Check if the node can be loaded queryb = QueryBuilder().append(Node, filters={'id': pk1}, tag='node') queryb.append(Group, with_node='node', filters={'type_string': 'auto.run'}, project='*') - all_auto_groups_data = queryb.all() + all_auto_groups_kptdata = queryb.all() queryb = QueryBuilder().append(Node, filters={'id': pk2}, tag='node') queryb.append(Group, with_node='node', filters={'type_string': 'auto.run'}, project='*') + all_auto_groups_arraydata = queryb.all() + + queryb = QueryBuilder().append(Node, filters={'id': pk3}, tag='node') + queryb.append(Group, with_node='node', filters={'type_string': 'auto.run'}, project='*') all_auto_groups_int = queryb.all() + + queryb = QueryBuilder().append(Node, filters={'id': pk4}, tag='node') + queryb.append(Group, with_node='node', filters={'type_string': 'auto.run'}, project='*') + all_auto_groups_calc = queryb.all() + + queryb = QueryBuilder().append(Node, filters={'id': pk5}, tag='node') + queryb.append(Group, with_node='node', filters={'type_string': 'auto.run'}, project='*') + all_auto_groups_wf = queryb.all() + + queryb = QueryBuilder().append(Node, filters={'id': pk6}, tag='node') + queryb.append(Group, with_node='node', filters={'type_string': 'auto.run'}, project='*') + all_auto_groups_calcarithmetic = queryb.all() + + self.assertEqual( + len(all_auto_groups_kptdata), 1 if kptdata_in_autogroup else 0, + 'Wrong number of nodes in autogroup associated with the KpointsData node ' + "just created with flags '{}'".format(' '.join(flags)) + ) self.assertEqual( - len(all_auto_groups_data), 1 if data_in_autogroup else 0, - 'Wrong number of nodes in autogroup associated with the Data node ' + len(all_auto_groups_arraydata), 1 if arraydata_in_autogroup else 0, + 'Wrong number of nodes in autogroup associated with the ArrayData node ' "just created with flags '{}'".format(' '.join(flags)) ) self.assertEqual( @@ -220,6 +319,21 @@ def test_autogroup_filter_class(self): # pylint: disable=too-many-locals 'Wrong number of nodes in autogroup associated with the Int node ' "just created with flags '{}'".format(' '.join(flags)) ) + self.assertEqual( + len(all_auto_groups_calc), 1 if calc_in_autogroup else 0, + 'Wrong number of nodes in autogroup associated with the CalculationNode ' + "just created with flags '{}'".format(' '.join(flags)) + ) + self.assertEqual( + len(all_auto_groups_wf), 1 if wf_in_autogroup else 0, + 'Wrong number of nodes in autogroup associated with the WorkflowNode ' + "just created with flags '{}'".format(' '.join(flags)) + ) + self.assertEqual( + len(all_auto_groups_calcarithmetic), 1 if calcarithmetic_in_autogroup else 0, + 'Wrong number of nodes in autogroup associated with the ArithmeticAdd CalcJobNode ' + "just created with flags '{}'".format(' '.join(flags)) + ) def test_autogroup_clashing_label(self): """Check if the autogroup label is properly (re)generated when it clashes with an existing group name.""" @@ -235,7 +349,7 @@ def test_autogroup_clashing_label(self): fhandle.flush() # First run - options = [fhandle.name, '--group-label-prefix', autogroup_label] + options = [fhandle.name, '--auto-group', '--auto-group-label-prefix', autogroup_label] result = self.cli_runner.invoke(cmd_run.run, options) self.assertClickResultNoException(result) @@ -251,7 +365,7 @@ def test_autogroup_clashing_label(self): # A few more runs with the same label - it should not crash but append something to the group name for _ in range(10): - options = [fhandle.name, '--group-label-prefix', autogroup_label] + options = [fhandle.name, '--auto-group', '--auto-group-label-prefix', autogroup_label] result = self.cli_runner.invoke(cmd_run.run, options) self.assertClickResultNoException(result) @@ -264,3 +378,41 @@ def test_autogroup_clashing_label(self): len(all_auto_groups), 1, 'There should be only one autogroup associated with the node just created' ) self.assertTrue(all_auto_groups[0][0].label.startswith(autogroup_label)) + + def test_legacy_autogroup_name(self): + """Check if the autogroup is properly generated when using the legacy --group-name flag.""" + from aiida.orm import QueryBuilder, Node, Group, load_node + + script_content = """from aiida.orm import Data +node = Data().store() +print(node.pk) +""" + group_label = 'legacy-group-name' + + with tempfile.NamedTemporaryFile(mode='w+') as fhandle: + fhandle.write(script_content) + fhandle.flush() + + options = ['--group-name', group_label, fhandle.name] + with warnings.catch_warnings(record=True) as warns: # pylint: disable=no-member + result = self.cli_runner.invoke(cmd_run.run, options) + self.assertTrue( + any(['use `--auto-group-label-prefix` instead' in str(warn.message) for warn in warns]), + "No warning for '--group-name' was raised" + ) + + self.assertClickResultNoException(result) + + pk = int(result.output) + _ = load_node(pk) # Check if the node can be loaded + + queryb = QueryBuilder().append(Node, filters={'id': pk}, tag='node') + queryb.append(Group, with_node='node', filters={'type_string': 'auto.run'}, project='*') + all_auto_groups = queryb.all() + self.assertEqual( + len(all_auto_groups), 1, 'There should be only one autogroup associated with the node just created' + ) + self.assertEqual( + all_auto_groups[0][0].label, group_label, + 'The auto group label is "{}" instead of "{}"'.format(all_auto_groups[0][0].label, group_label) + ) diff --git a/tests/engine/processes/workchains/test_utils.py b/tests/engine/processes/workchains/test_utils.py index 5591229d5b..51a787235e 100644 --- a/tests/engine/processes/workchains/test_utils.py +++ b/tests/engine/processes/workchains/test_utils.py @@ -53,7 +53,7 @@ def test_priority(self): attribute_key = 'handlers_called' class ArithmeticAddBaseWorkChain(BaseRestartWorkChain): - """Implementation of a possible BaseRestartWorkChain for the ArithmeticAdd calculation.""" + """Implementation of a possible BaseRestartWorkChain for the ``ArithmeticAddCalculation``.""" _process_class = ArithmeticAddCalculation @@ -164,7 +164,7 @@ def test_exit_codes_filter(self): node_skip.set_exit_status(200) # Some other exit status class ArithmeticAddBaseWorkChain(BaseRestartWorkChain): - """Minimal base restart workchain ofr the ArithmeticAdd calculation""" + """Minimal base restart workchain for the ``ArithmeticAddCalculation``.""" _process_class = ArithmeticAddCalculation diff --git a/tests/orm/test_autogroups.py b/tests/orm/test_autogroups.py new file mode 100644 index 0000000000..e1426ad2e8 --- /dev/null +++ b/tests/orm/test_autogroups.py @@ -0,0 +1,129 @@ +# -*- coding: utf-8 -*- +########################################################################### +# Copyright (c), The AiiDA team. All rights reserved. # +# This file is part of the AiiDA code. # +# # +# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # +# For further information on the license, see the LICENSE.txt file # +# For further information please visit http://www.aiida.net # +########################################################################### +"""Tests for the Autogroup functionality.""" +from aiida.backends.testbase import AiidaTestCase +from aiida.orm import Group, QueryBuilder +from aiida.orm.autogroup import Autogroup + + +class TestAutogroup(AiidaTestCase): + """Tests the Autogroup logic.""" + + def test_get_or_create(self): + """Test the ``get_or_create_group`` method of ``Autogroup``.""" + label_prefix = 'test_prefix_TestAutogroup' + + # Check that there are no groups to begin with + queryb = QueryBuilder().append(Group, filters={'type_string': 'auto.run', 'label': label_prefix}, project='*') + assert not list(queryb.all()) + queryb = QueryBuilder().append( + Group, filters={ + 'type_string': 'auto.run', + 'label': { + 'like': r'{}\_%'.format(label_prefix) + } + }, project='*' + ) + assert not list(queryb.all()) + + # First group (no existing one) + autogroup = Autogroup() + autogroup.set_group_label_prefix(label_prefix) + group = autogroup.get_or_create_group() + expected_label = label_prefix + self.assertEqual( + group.label, expected_label, + "The auto-group should be labelled '{}', it is instead '{}'".format(expected_label, group.label) + ) + + # Second group (only one with no suffix existing) + autogroup = Autogroup() + autogroup.set_group_label_prefix(label_prefix) + group = autogroup.get_or_create_group() + expected_label = label_prefix + '_1' + self.assertEqual( + group.label, expected_label, + "The auto-group should be labelled '{}', it is instead '{}'".format(expected_label, group.label) + ) + + # Second group (only one suffix _1 existing) + autogroup = Autogroup() + autogroup.set_group_label_prefix(label_prefix) + group = autogroup.get_or_create_group() + expected_label = label_prefix + '_2' + self.assertEqual( + group.label, expected_label, + "The auto-group should be labelled '{}', it is instead '{}'".format(expected_label, group.label) + ) + + # I create a group with a large integer suffix (9) + Group(label='{}_9'.format(label_prefix), type_string='auto.run').store() + # The next autogroup should become number 10 + autogroup = Autogroup() + autogroup.set_group_label_prefix(label_prefix) + group = autogroup.get_or_create_group() + expected_label = label_prefix + '_10' + self.assertEqual( + group.label, expected_label, + "The auto-group should be labelled '{}', it is instead '{}'".format(expected_label, group.label) + ) + + # I create a group with a non-integer suffix (15a), it should be ignored + Group(label='{}_15b'.format(label_prefix), type_string='auto.run').store() + # The next autogroup should become number 11 + autogroup = Autogroup() + autogroup.set_group_label_prefix(label_prefix) + group = autogroup.get_or_create_group() + expected_label = label_prefix + '_11' + self.assertEqual( + group.label, expected_label, + "The auto-group should be labelled '{}', it is instead '{}'".format(expected_label, group.label) + ) + + def test_get_or_create_invalid_prefix(self): + """Test the ``get_or_create_group`` method of ``Autogroup`` when there is already a group + with the same prefix, but followed by other non-underscore characters.""" + label_prefix = 'new_test_prefix_TestAutogroup' + # I create a group with the same prefix, but followed by non-underscore + # characters. These should be ignored in the logic. + Group(label='{}xx'.format(label_prefix), type_string='auto.run').store() + + # Check that there are no groups to begin with + queryb = QueryBuilder().append(Group, filters={'type_string': 'auto.run', 'label': label_prefix}, project='*') + assert not list(queryb.all()) + queryb = QueryBuilder().append( + Group, filters={ + 'type_string': 'auto.run', + 'label': { + 'like': r'{}\_%'.format(label_prefix) + } + }, project='*' + ) + assert not list(queryb.all()) + + # First group (no existing one) + autogroup = Autogroup() + autogroup.set_group_label_prefix(label_prefix) + group = autogroup.get_or_create_group() + expected_label = label_prefix + self.assertEqual( + group.label, expected_label, + "The auto-group should be labelled '{}', it is instead '{}'".format(expected_label, group.label) + ) + + # Second group (only one with no suffix existing) + autogroup = Autogroup() + autogroup.set_group_label_prefix(label_prefix) + group = autogroup.get_or_create_group() + expected_label = label_prefix + '_1' + self.assertEqual( + group.label, expected_label, + "The auto-group should be labelled '{}', it is instead '{}'".format(expected_label, group.label) + )