From 43d852ebf0048b0a623061d4b22007ce9dd3a576 Mon Sep 17 00:00:00 2001
From: Giovanni Pizzi
Date: Thu, 2 Apr 2020 13:13:00 +0200
Subject: [PATCH] Centralised the creation of the Autogroup

This also removes an overzealous isinstance check, and moves additional
checks into a cached function that is run only when storing the very first
node (that needs to be put in an autogroup), making storing of nodes faster
(even if timings fluctuate, so it is hard to estimate exactly by how much).

Also, added logic to allow for concurrent creation of multiple groups
(and a test).

This fixes #997
---
 .ci/workchains.py                             |   4 +-
 .gitignore                                    |   1 +
 aiida/backends/testbase.py                    |  11 ++
 aiida/cmdline/commands/cmd_run.py             |  34 +++--
 aiida/engine/processes/calcjobs/calcjob.py    |   4 +-
 aiida/manage/caching.py                       |   8 +-
 aiida/orm/autogroup.py                        | 133 ++++++++++++++----
 aiida/orm/nodes/node.py                       |  17 +--
 aiida/plugins/entry_point.py                  |  14 +-
 docs/source/verdi/verdi_user_guide.rst        |  33 +++--
 .../aiida_sqlalchemy/test_migrations.py       |  51 +++----
 tests/cmdline/commands/test_run.py            |  46 +++++-
 .../engine/processes/workchains/test_utils.py |   6 +
 tests/tools/importexport/orm/test_codes.py    |   2 +
 tests/tools/visualization/test_graph.py       |   2 +
 utils/dependency_management.py                |   5 +-
 16 files changed, 252 insertions(+), 119 deletions(-)

diff --git a/.ci/workchains.py b/.ci/workchains.py
index 110334f0ae..f5ab3872d7 100644
--- a/.ci/workchains.py
+++ b/.ci/workchains.py
@@ -68,8 +68,8 @@ def a_magic_unicorn_appeared(self, node):
     @process_handler(priority=400, exit_codes=ArithmeticAddCalculation.exit_codes.ERROR_NEGATIVE_NUMBER)
     def error_negative_sum(self, node):
         """What even is a negative number, how can I have minus three melons?!."""
-        self.ctx.inputs.x = Int(abs(node.inputs.x.value))
-        self.ctx.inputs.y = Int(abs(node.inputs.y.value))
+        self.ctx.inputs.x = Int(abs(node.inputs.x.value))  # pylint: disable=invalid-name
+        self.ctx.inputs.y = Int(abs(node.inputs.y.value))  # pylint: disable=invalid-name
         return ProcessHandlerReport(True)

diff --git a/.gitignore b/.gitignore
index 9d225c3ef0..1983db653d 100644
--- a/.gitignore
+++ b/.gitignore
@@ -20,6 +20,7 @@
 .cache
 .pytest_cache
 .coverage
+coverage.xml

 # Files created by RPN tests
 .ci/polish/polish_workchains/polish*
diff --git a/aiida/backends/testbase.py b/aiida/backends/testbase.py
index de855eec4b..ed18f27566 100644
--- a/aiida/backends/testbase.py
+++ b/aiida/backends/testbase.py
@@ -99,7 +99,11 @@ def tearDown(self):

     def reset_database(self):
         """Reset the database to the default state deleting any content currently stored"""
+        from aiida.orm import autogroup
+
         self.clean_db()
+        if autogroup.CURRENT_AUTOGROUP is not None:
+            autogroup.CURRENT_AUTOGROUP.clear_group_cache()
         self.insert_data()

     @classmethod
@@ -109,7 +113,10 @@ def insert_data(cls):
         inserts default data into the database (which is for the moment a
         default computer).
""" + from aiida.orm import User + cls.create_user() + User.objects.reset() cls.create_computer() @classmethod @@ -180,7 +187,11 @@ def user_email(cls): # pylint: disable=no-self-argument def tearDownClass(cls, *args, **kwargs): # pylint: disable=arguments-differ # Double check for double security to avoid to run the tearDown # if this is not a test profile + from aiida.orm import autogroup + check_if_tests_can_run() + if autogroup.CURRENT_AUTOGROUP is not None: + autogroup.CURRENT_AUTOGROUP.clear_group_cache() cls.clean_db() cls.clean_repository() cls.__backend_instance.tearDownClass_method(*args, **kwargs) diff --git a/aiida/cmdline/commands/cmd_run.py b/aiida/cmdline/commands/cmd_run.py index fb591bd0ed..6c09ff6cd7 100644 --- a/aiida/cmdline/commands/cmd_run.py +++ b/aiida/cmdline/commands/cmd_run.py @@ -20,7 +20,6 @@ from aiida.cmdline.params.options.multivalue import MultipleValueOption from aiida.cmdline.utils import decorators, echo from aiida.common.warnings import AiidaDeprecationWarning -from aiida.orm import autogroup @contextlib.contextmanager @@ -43,6 +42,8 @@ def update_environment(argv): def validate_entrypoint_string_or_all(ctx, param, value, allow_all=True): # pylint: disable=unused-argument,invalid-name """Validate that `value` is a valid entrypoint string or the string 'all'.""" + from aiida.orm import autogroup + try: autogroup.Autogroup.validate(value, allow_all=allow_all) except Exception as exc: @@ -55,19 +56,26 @@ def validate_entrypoint_string_or_all(ctx, param, value, allow_all=True): # pyl @click.argument('scriptname', type=click.STRING) @click.argument('varargs', nargs=-1, type=click.UNPROCESSED) @click.option('--group/--no-group', default=True, show_default=True, help='Enables the autogrouping') -@click.option('-l', '--group-label', type=click.STRING, required=False, help='Specify the label of the auto group') +@click.option( + '-l', + '--group-label-prefix', + type=click.STRING, + required=False, + help='Specify the prefix of the label of the auto group (numbers might be automatically ' + 'appended to generate unique names per run)' +) @click.option( '-n', '--group-name', type=click.STRING, required=False, - help='Specify the name of the auto group [DEPRECATED, USE --group-label instead]' + help='Specify the name of the auto group [DEPRECATED, USE --group-label-prefix instead]' ) @click.option( '-e', '--exclude', cls=MultipleValueOption, - default=[], + default=lambda: [], help='Exclude these classes from auto grouping (use full entrypoint strings)', callback=functools.partial(validate_entrypoint_string_or_all, allow_all=False) ) @@ -75,7 +83,7 @@ def validate_entrypoint_string_or_all(ctx, param, value, allow_all=True): # pyl '-i', '--include', cls=MultipleValueOption, - default=['all'], + default=lambda: ['all'], help='Include these classes from auto grouping (use full entrypoint strings or "all")', callback=validate_entrypoint_string_or_all ) @@ -83,7 +91,7 @@ def validate_entrypoint_string_or_all(ctx, param, value, allow_all=True): # pyl '-E', '--excludesubclasses', cls=MultipleValueOption, - default=[], + default=lambda: [], help='Exclude these classes and their sub classes from auto grouping (use full entrypoint strings)', callback=functools.partial(validate_entrypoint_string_or_all, allow_all=False) ) @@ -91,15 +99,18 @@ def validate_entrypoint_string_or_all(ctx, param, value, allow_all=True): # pyl '-I', '--includesubclasses', cls=MultipleValueOption, - default=[], + default=lambda: [], help='Include these classes and their sub classes from auto 
     help='Include these classes and their sub classes from auto grouping (use full entrypoint strings)',
     callback=functools.partial(validate_entrypoint_string_or_all, allow_all=False)
 )
 @decorators.with_dbenv()
-def run(scriptname, varargs, group, group_label, group_name, exclude, excludesubclasses, include, includesubclasses):
+def run(
+    scriptname, varargs, group, group_label_prefix, group_name, exclude, excludesubclasses, include, includesubclasses
+):  # pylint: disable=too-many-arguments,exec-used
     """Execute scripts with preloaded AiiDA environment."""
     from aiida.cmdline.utils.shell import DEFAULT_MODULES_LIST
+    from aiida.orm import autogroup

     # Prepare the environment for the script to be run
     globals_dict = {
@@ -121,11 +132,10 @@ def run(scriptname, varargs, group, group_label, group_name, exclude, excludesub
         group_label = group_name

     if group:
-        automatic_group_label = group_label
-
         aiida_verdilib_autogroup = autogroup.Autogroup()
-        if automatic_group_label is not None:
-            aiida_verdilib_autogroup.set_group_label(automatic_group_label)
+        # if group_label_prefix is None, use autogenerated name
+        if group_label_prefix is not None:
+            aiida_verdilib_autogroup.set_group_label_prefix(group_label_prefix)
         aiida_verdilib_autogroup.set_exclude(exclude)
         aiida_verdilib_autogroup.set_include(include)
         aiida_verdilib_autogroup.set_exclude_with_subclasses(excludesubclasses)
diff --git a/aiida/engine/processes/calcjobs/calcjob.py b/aiida/engine/processes/calcjobs/calcjob.py
index 0e6b234ef0..9f3d2d765f 100644
--- a/aiida/engine/processes/calcjobs/calcjob.py
+++ b/aiida/engine/processes/calcjobs/calcjob.py
@@ -64,7 +64,7 @@ def validate_calc_job(inputs, ctx):
     )


-def validate_parser(parser_name, ctx):
+def validate_parser(parser_name, ctx):  # pylint: disable=unused-argument
     """Validate the parser.

     :raises InputValidationError: if the parser name does not correspond to a loadable `Parser` class.
@@ -78,7 +78,7 @@ def validate_parser(parser_name, ctx):
         raise exceptions.InputValidationError('invalid parser specified: {}'.format(exception))


-def validate_resources(resources, ctx):
+def validate_resources(resources, ctx):  # pylint: disable=unused-argument
     """Validate the resources.

     :raises InputValidationError: if `num_machines` is not specified or is not an integer.
diff --git a/aiida/manage/caching.py b/aiida/manage/caching.py
index d8079fd747..9b7f1d427d 100644
--- a/aiida/manage/caching.py
+++ b/aiida/manage/caching.py
@@ -22,7 +22,7 @@

 from aiida.common import exceptions
 from aiida.common.lang import type_check
-from aiida.plugins.entry_point import ENTRY_POINT_STRING_SEPARATOR, entry_point_group_to_module_path_map
+from aiida.plugins.entry_point import ENTRY_POINT_STRING_SEPARATOR, ENTRY_POINT_GROUP_TO_MODULE_PATH_MAP

 __all__ = ('get_use_cache', 'enable_caching', 'disable_caching')

@@ -248,7 +248,7 @@ def _validate_identifier_pattern(*, identifier):

         1.

-    where `group_name` is one of the keys in `entry_point_group_to_module_path_map`
+    where `group_name` is one of the keys in `ENTRY_POINT_GROUP_TO_MODULE_PATH_MAP`
     and `tail` can be anything _except_ `ENTRY_POINT_STRING_SEPARATOR`.

     2. a fully qualified Python name
@@ -276,7 +276,7 @@ def _validate_identifier_pattern(*, identifier):
         group_pattern, _ = identifier.split(ENTRY_POINT_STRING_SEPARATOR)
         if not any(
             _match_wildcard(string=group_name, pattern=group_pattern)
-            for group_name in entry_point_group_to_module_path_map
+            for group_name in ENTRY_POINT_GROUP_TO_MODULE_PATH_MAP
         ):
             raise ValueError(
                 common_error_msg + "Group name pattern '{}' does not match any of the AiiDA entry point group names.".
@@ -290,7 +290,7 @@ def _validate_identifier_pattern(*, identifier):
     # aiida.* or aiida.calculations*
     if '*' in identifier:
         group_part, _ = identifier.split('*', 1)
-        if any(group_name.startswith(group_part) for group_name in entry_point_group_to_module_path_map):
+        if any(group_name.startswith(group_part) for group_name in ENTRY_POINT_GROUP_TO_MODULE_PATH_MAP):
             return
     # Finally, check if it could be a fully qualified Python name
     for identifier_part in identifier.split('.'):
diff --git a/aiida/orm/autogroup.py b/aiida/orm/autogroup.py
index f11bf69a36..cbba2f1c4e 100644
--- a/aiida/orm/autogroup.py
+++ b/aiida/orm/autogroup.py
@@ -11,8 +11,9 @@
 import warnings

 from aiida.common import exceptions, timezone
+from aiida.common.escaping import escape_for_sql_like
 from aiida.common.warnings import AiidaDeprecationWarning
-from aiida.orm import GroupTypeString
+from aiida.orm import GroupTypeString, Group
 from aiida.plugins import load_entry_point_from_string

 CURRENT_AUTOGROUP = None
@@ -35,14 +36,15 @@ class Autogroup:

     def __init__(self):
         """Initialize with defaults."""
-        self.exclude = []
-        self.exclude_with_subclasses = []
-        self.include = ['all']
-        self.include_with_subclasses = []
+        self._exclude = []
+        self._exclude_with_subclasses = []
+        self._include = ['all']
+        self._include_with_subclasses = []

         now = timezone.now()
-        gname = 'Verdi autogroup on ' + now.strftime('%Y-%m-%d %H:%M:%S')
-        self.group_label = gname
+        default_label_prefix = 'Verdi autogroup on ' + now.strftime('%Y-%m-%d %H:%M:%S')
+        self._group_label_prefix = default_label_prefix
+        self._group_label = None  # Actual group label, set by `get_or_create_group`

     @staticmethod
     def validate(param, allow_all=True):
@@ -58,38 +60,38 @@ def validate(param, allow_all=True):

     def get_exclude(self):
         """Return the list of classes to exclude from autogrouping."""
-        return self.exclude
+        return self._exclude

     def get_exclude_with_subclasses(self):
         """
         Return the list of classes to exclude from autogrouping.
         Will also exclude their derived subclasses
         """
-        return self.exclude_with_subclasses
+        return self._exclude_with_subclasses

     def get_include(self):
         """Return the list of classes to include in the autogrouping."""
-        return self.include
+        return self._include

     def get_include_with_subclasses(self):
         """Return the list of classes to include in the autogrouping.
         Will also include their derived subclasses."""
-        return self.include_with_subclasses
+        return self._include_with_subclasses

-    def get_group_label(self):
-        """Get the name of the group.
-        If no group name was set, it will set a default one by itself."""
-        return self.group_label
+    def get_group_label_prefix(self):
+        """Get the prefix of the label of the group.
+        If no group label prefix was set, it will set a default one by itself."""
+        return self._group_label_prefix

     def get_group_name(self):
         """Get the label of the group.
         If no group label was set, it will set a default one by itself.

         .. deprecated:: 1.1.0
-            Will be removed in `v2.0.0`, use :py:meth:`.get_group_label` instead.
+            Will be removed in `v2.0.0`, use :py:meth:`.get_group_label_prefix` instead.
         """
-        warnings.warn('function is deprecated, use `get_group_label` instead', AiidaDeprecationWarning)  # pylint: disable=no-member
-        return self.get_group_label()
+        warnings.warn('function is deprecated, use `get_group_label_prefix` instead', AiidaDeprecationWarning)  # pylint: disable=no-member
+        return self.get_group_label_prefix()

     def set_exclude(self, exclude):
         """Return the list of classes to exclude from autogrouping.
@@ -100,7 +102,7 @@ def set_exclude(self, exclude):
         if 'all' in self.get_include():
             if 'all' in exclude:
                 raise exceptions.ValidationError('Cannot exclude and include all classes')
-        self.exclude = exclude
+        self._exclude = exclude

     def set_exclude_with_subclasses(self, exclude):
         """
@@ -110,7 +112,7 @@ def set_exclude_with_subclasses(self, exclude):
         :param exclude: a list of valid entry point strings (one of which could be the string 'all')
         """
         self.validate(exclude)
-        self.exclude_with_subclasses = exclude
+        self._exclude_with_subclasses = exclude

     def set_include(self, include):
         """
@@ -123,7 +125,7 @@ def set_include(self, include):
             if 'all' in include:
                 raise exceptions.ValidationError('Cannot exclude and include all classes')

-        self.include = include
+        self._include = include

     def set_include_with_subclasses(self, include):
         """
@@ -133,24 +135,24 @@ def set_include_with_subclasses(self, include):
         :param include: a list of valid entry point strings (one of which could be the string 'all')
         """
         self.validate(include)
-        self.include_with_subclasses = include
+        self._include_with_subclasses = include

-    def set_group_label(self, label):
+    def set_group_label_prefix(self, label_prefix):
         """
         Set the label of the group to be created
         """
-        if not isinstance(label, str):
+        if not isinstance(label_prefix, str):
             raise exceptions.ValidationError('group label must be a string')
-        self.group_label = label
+        self._group_label_prefix = label_prefix

     def set_group_name(self, gname):
         """Set the name of the group.

         .. deprecated:: 1.1.0
-            Will be removed in `v2.0.0`, use :py:meth:`.set_group_label` instead.
+            Will be removed in `v2.0.0`, use :py:meth:`.set_group_label_prefix` instead.
         """
-        warnings.warn('function is deprecated, use `set_group_label` instead', AiidaDeprecationWarning)  # pylint: disable=no-member
-        return self.set_group_label(label=gname)
+        warnings.warn('function is deprecated, use `set_group_label_prefix` instead', AiidaDeprecationWarning)  # pylint: disable=no-member
+        return self.set_group_label_prefix(label_prefix=gname)

     def is_to_be_grouped(self, the_class):
         """
@@ -194,3 +196,78 @@ def is_to_be_grouped(self, the_class):
                 return False
         # If we're here, the class is not in the include - return False
         return False
+
+    def clear_group_cache(self):
+        """Clear the cache of the group name.
+
+        This is mostly used by tests when they reset the database.
+        """
+        self._group_label = None
+
+    def get_or_create_group(self):
+        """Return the current Autogroup, or create one if None has been set yet."""
+        from aiida.orm import QueryBuilder
+        if self._group_label is not None:
+            results = [
+                res[0] for res in QueryBuilder().
+                append(Group, filters={
+                    'label': self._group_label,
+                    'type_string': VERDIAUTOGROUP_TYPE
+                }, project='*').iterall()
+            ]
+            if results:
+                # If it is not empty, it should have only one result due to the
+                # uniqueness constraints
+                return results[0]
+            # There are no results: probably the group has been deleted.
+            # I continue as if it was not cached
+            self._group_label = None
+
+        label_prefix = self.get_group_label_prefix()
+        # Try to do a preliminary QB query to avoid to do too many try/except
+        # if many of the prefix_NUMBER groups already exist
+        queryb = QueryBuilder().append(
+            Group,
+            filters={
+                'or': [{
+                    'label': {
+                        '==': label_prefix
+                    }
+                }, {
+                    'label': {
+                        'like': escape_for_sql_like(label_prefix + '_') + '%'
+                    }
+                }]
+            },
+            project='label'
+        )
+        existing_group_labels = [res[0][len(label_prefix):] for res in queryb.all()]
+        existing_group_ints = []
+        for label in existing_group_labels:
+            if label == '':
+                # This is just the prefix without name - corresponds to counter = 0
+                existing_group_ints.append(0)
+            else:
+                if label.startswith('_'):
+                    try:
+                        existing_group_ints.append(int(label[1:]))
+                    except ValueError:
+                        # It's not an integer, so it will never collide - just ignore it
+                        pass
+
+        if not existing_group_ints:
+            counter = 0
+        else:
+            counter = max(existing_group_ints) + 1
+
+        while True:
+            try:
+                label = label_prefix if counter == 0 else '{}_{}'.format(label_prefix, counter)
+                group = Group(label=label, type_string=VERDIAUTOGROUP_TYPE).store()
+                self._group_label = group.label
+            except exceptions.IntegrityError:
+                counter += 1
+            else:
+                break
+
+        return group
diff --git a/aiida/orm/nodes/node.py b/aiida/orm/nodes/node.py
index fecfd555ca..dedbddcfc5 100644
--- a/aiida/orm/nodes/node.py
+++ b/aiida/orm/nodes/node.py
@@ -23,6 +23,7 @@
 from aiida.orm.utils.links import LinkManager, LinkTriple
 from aiida.orm.utils.repository import Repository
 from aiida.orm.utils.node import AbstractNodeMeta, validate_attribute_extra_key
+from aiida.orm import autogroup

 from ..comments import Comment
 from ..computers import Computer
@@ -1024,18 +1025,10 @@ def store(self, with_transaction=True, use_cache=None):  # pylint: disable=argum
             self._store(with_transaction=with_transaction, clean=True)

         # Set up autogrouping used by verdi run
-        from aiida.orm.autogroup import CURRENT_AUTOGROUP, Autogroup, VERDIAUTOGROUP_TYPE
-        from aiida.orm import Group
-
-        if CURRENT_AUTOGROUP is not None:
-            if not isinstance(CURRENT_AUTOGROUP, Autogroup):
-                raise exceptions.ValidationError('`CURRENT_AUTOGROUP` is not of type `Autogroup`')
-
-            if CURRENT_AUTOGROUP.is_to_be_grouped(self.__class__):
-                group_label = CURRENT_AUTOGROUP.get_group_label()
-                if group_label is not None:
-                    group = Group.objects.get_or_create(label=group_label, type_string=VERDIAUTOGROUP_TYPE)[0]
-                    group.add_nodes(self)
+        if autogroup.CURRENT_AUTOGROUP is not None:
+            if autogroup.CURRENT_AUTOGROUP.is_to_be_grouped(self.__class__):
+                group = autogroup.CURRENT_AUTOGROUP.get_or_create_group()
+                group.add_nodes(self)

         return self

diff --git a/aiida/plugins/entry_point.py b/aiida/plugins/entry_point.py
index 9be505f7dd..d2b7132e06 100644
--- a/aiida/plugins/entry_point.py
+++ b/aiida/plugins/entry_point.py
@@ -240,7 +240,7 @@ def get_entry_points(group):
     :param group: the entry point group
     :return: a list of entry points
     """
-    return [ep for ep in ENTRYPOINT_MANAGER.iter_entry_points(group=group)]
+    return list(ENTRYPOINT_MANAGER.iter_entry_points(group=group))


 @functools.lru_cache(maxsize=None)
@@ -257,12 +257,16 @@ def get_entry_point(group, name):
     entry_points = [ep for ep in get_entry_points(group) if ep.name == name]

     if not entry_points:
-        raise MissingEntryPointError("Entry point '{}' not found in group '{}'.".format(name, group) +
-                                     'Try running `reentry scan` to update the entry point cache.')
+        raise MissingEntryPointError(
+            "Entry point '{}' not found in group '{}'.".format(name, group) +
+            'Try running `reentry scan` to update the entry point cache.'
+        )

     if len(entry_points) > 1:
-        raise MultipleEntryPointError("Multiple entry points '{}' found in group '{}'. ".format(name, group) +
-                                      'Try running `reentry scan` to repopulate the entry point cache.')
+        raise MultipleEntryPointError(
+            "Multiple entry points '{}' found in group '{}'. ".format(name, group) +
+            'Try running `reentry scan` to repopulate the entry point cache.'
+        )

     return entry_points[0]

diff --git a/docs/source/verdi/verdi_user_guide.rst b/docs/source/verdi/verdi_user_guide.rst
index 041dabf572..ee43c3999a 100644
--- a/docs/source/verdi/verdi_user_guide.rst
+++ b/docs/source/verdi/verdi_user_guide.rst
@@ -702,21 +702,24 @@ Below is a list with all available subcommands.

     Execute scripts with preloaded AiiDA environment.

     Options:
-      --group / --no-group          Enables the autogrouping  [default: True]
-      -l, --group-label TEXT        Specify the label of the auto group
-      -n, --group-name TEXT         Specify the name of the auto group
-                                    [DEPRECATED, USE --group-label instead]
-      -e, --exclude TEXT            Exclude these classes from auto grouping (use
-                                    full entrypoint strings)
-      -i, --include TEXT            Include these classes from auto grouping (use
-                                    full entrypoint strings or "all")
-      -E, --excludesubclasses TEXT  Exclude these classes and their sub classes
-                                    from auto grouping (use full entrypoint
-                                    strings)
-      -I, --includesubclasses TEXT  Include these classes and their sub classes
-                                    from auto grouping (use full entrypoint
-                                    strings)
-      --help                        Show this message and exit.
+      --group / --no-group          Enables the autogrouping  [default: True]
+      -l, --group-label-prefix TEXT  Specify the prefix of the label of the auto
+                                    group (numbers might be automatically
+                                    appended to generate unique names per run)
+      -n, --group-name TEXT         Specify the name of the auto group
+                                    [DEPRECATED, USE --group-label-prefix
+                                    instead]
+      -e, --exclude TEXT            Exclude these classes from auto grouping (use
+                                    full entrypoint strings)
+      -i, --include TEXT            Include these classes from auto grouping
+                                    (use full entrypoint strings or "all")
+      -E, --excludesubclasses TEXT  Exclude these classes and their sub classes
+                                    from auto grouping (use full entrypoint
+                                    strings)
+      -I, --includesubclasses TEXT  Include these classes and their sub classes
+                                    from auto grouping (use full entrypoint
+                                    strings)
+      --help                        Show this message and exit.

 .. _verdi_setup:

diff --git a/tests/backends/aiida_sqlalchemy/test_migrations.py b/tests/backends/aiida_sqlalchemy/test_migrations.py
index fdd27298ca..2147414f0d 100644
--- a/tests/backends/aiida_sqlalchemy/test_migrations.py
+++ b/tests/backends/aiida_sqlalchemy/test_migrations.py
@@ -22,7 +22,6 @@
 from aiida.backends.sqlalchemy.models.base import Base
 from aiida.backends.sqlalchemy.utils import flag_modified
 from aiida.backends.testbase import AiidaTestCase
-from aiida.common.utils import Capturing

 from .test_utils import new_database

@@ -63,16 +62,21 @@ def setUp(self):
             "TestCase '{}' must define migrate_from and migrate_to properties".format(type(self).__name__)

         try:
-            with Capturing():
-                self.migrate_db_down(self.migrate_from)
+            self.migrate_db_down(self.migrate_from)
             self.setUpBeforeMigration()
-            with Capturing():
-                self.migrate_db_up(self.migrate_to)
+            self._perform_actual_migration()
         except Exception:
             # Bring back the DB to the correct state if this setup part fails
             self._reset_database_and_schema()
             raise

+    def _perform_actual_migration(self):
+        """Perform the actual migration (upwards, to migrate_to).
+
+        Must be called after we are properly set to be in migrate_from.
+        """
+        self.migrate_db_up(self.migrate_to)
+
     def migrate_db_up(self, destination):
         """
         Perform a migration upwards (upgrade) with alembic
@@ -116,8 +120,7 @@ def _reset_database_and_schema(self):
         of tests.
         """
         self.reset_database()
-        with Capturing():
-            self.migrate_db_up('head')
+        self.migrate_db_up('head')

     @property
     def current_rev(self):
@@ -210,34 +213,12 @@ class TestBackwardMigrationsSQLA(TestMigrationsSQLA):
     than the migrate_to revision.
     """

-    def setUp(self):
-        """
-        Go to the migrate_from revision, apply setUpBeforeMigration, then
-        run the migration.
-        """
-        AiidaTestCase.setUp(self)  # pylint: disable=bad-super-call
-        from aiida.orm import autogroup
+    def _perform_actual_migration(self):
+        """Perform the actual migration (downwards, to migrate_to).

-        self.current_autogroup = autogroup.CURRENT_AUTOGROUP
-        autogroup.CURRENT_AUTOGROUP = None
-        assert self.migrate_from and self.migrate_to, \
-            "TestCase '{}' must define migrate_from and migrate_to properties".format(type(self).__name__)
-
-        try:
-            with Capturing():
-                self.migrate_db_down(self.migrate_from)
-            self.setUpBeforeMigration()
-            with Capturing():
-                self.migrate_db_down(self.migrate_to)
-        except Exception:
-            # Bring back the DB to the correct state if this setup part fails
-            self._reset_database_and_schema()
-            raise
-
-    def tearDown(self):
-        """Put back the correct autogroup."""
-        from aiida.orm import autogroup
-        autogroup.CURRENT_AUTOGROUP = self.current_autogroup
+        Must be called after we are properly set to be in migrate_from.
+        """
+        self.migrate_db_down(self.migrate_to)


 class TestMigrationEngine(TestMigrationsSQLA):
@@ -1008,7 +989,7 @@ class TestDbLogUUIDAddition(TestMigrationsSQLA):
     """
     Test that the UUID column is correctly added to the DbLog table and that the uniqueness
     constraint is added without problems (if the migration arrives until 375c2db70663 then the
-    constraint is added properly.
+    constraint is added properly).
""" migrate_from = '041a79fc615f' # 041a79fc615f_dblog_cleaning diff --git a/tests/cmdline/commands/test_run.py b/tests/cmdline/commands/test_run.py index 6af129d479..8df034f8ef 100644 --- a/tests/cmdline/commands/test_run.py +++ b/tests/cmdline/commands/test_run.py @@ -129,7 +129,7 @@ def test_autogroup_custom_label(self): fhandle.write(script_content) fhandle.flush() - options = [fhandle.name, '--group-label', autogroup_label] + options = [fhandle.name, '--group-label-prefix', autogroup_label] result = self.cli_runner.invoke(cmd_run.run, options) self.assertClickResultNoException(result) @@ -220,3 +220,47 @@ def test_autogroup_filter_class(self): # pylint: disable=too-many-locals 'Wrong number of nodes in autogroup associated with the Int node ' "just created with flags '{}'".format(' '.join(flags)) ) + + def test_autogroup_clashing_label(self): + """Check if the autogroup label is properly (re)generated when it clashes with an existing group name.""" + from aiida.orm import QueryBuilder, Node, Group, load_node + + script_content = """from aiida.orm import Data +node = Data().store() +print(node.pk) +""" + autogroup_label = 'SOME_repeated_group_LABEL' + with tempfile.NamedTemporaryFile(mode='w+') as fhandle: + fhandle.write(script_content) + fhandle.flush() + + # First run + options = [fhandle.name, '--group-label-prefix', autogroup_label] + result = self.cli_runner.invoke(cmd_run.run, options) + self.assertClickResultNoException(result) + + pk = int(result.output) + _ = load_node(pk) # Check if the node can be loaded + queryb = QueryBuilder().append(Node, filters={'id': pk}, tag='node') + queryb.append(Group, with_node='node', filters={'type_string': 'auto.run'}, project='*') + all_auto_groups = queryb.all() + self.assertEqual( + len(all_auto_groups), 1, 'There should be only one autogroup associated with the node just created' + ) + self.assertEqual(all_auto_groups[0][0].label, autogroup_label) + + # A few more runs with the same label - it should not crash but append something to the group name + for _ in range(10): + options = [fhandle.name, '--group-label-prefix', autogroup_label] + result = self.cli_runner.invoke(cmd_run.run, options) + self.assertClickResultNoException(result) + + pk = int(result.output) + _ = load_node(pk) # Check if the node can be loaded + queryb = QueryBuilder().append(Node, filters={'id': pk}, tag='node') + queryb.append(Group, with_node='node', filters={'type_string': 'auto.run'}, project='*') + all_auto_groups = queryb.all() + self.assertEqual( + len(all_auto_groups), 1, 'There should be only one autogroup associated with the node just created' + ) + self.assertTrue(all_auto_groups[0][0].label.startswith(autogroup_label)) diff --git a/tests/engine/processes/workchains/test_utils.py b/tests/engine/processes/workchains/test_utils.py index 00f1e127a3..5591229d5b 100644 --- a/tests/engine/processes/workchains/test_utils.py +++ b/tests/engine/processes/workchains/test_utils.py @@ -53,6 +53,7 @@ def test_priority(self): attribute_key = 'handlers_called' class ArithmeticAddBaseWorkChain(BaseRestartWorkChain): + """Implementation of a possible BaseRestartWorkChain for the ArithmeticAdd calculation.""" _process_class = ArithmeticAddCalculation @@ -61,6 +62,7 @@ class ArithmeticAddBaseWorkChain(BaseRestartWorkChain): # This can then be checked after invoking `inspect_process` to ensure they were called in the right order @process_handler(priority=100) def handler_01(self, node): + """Example handler returing ExitCode 100.""" handlers_called = 
                 handlers_called = node.get_attribute(attribute_key, default=[])
                 handlers_called.append('handler_01')
                 node.set_attribute(attribute_key, handlers_called)
@@ -68,6 +70,7 @@ def handler_01(self, node):

             @process_handler(priority=300)
             def handler_03(self, node):
+                """Example handler returning ExitCode 300."""
                 handlers_called = node.get_attribute(attribute_key, default=[])
                 handlers_called.append('handler_03')
                 node.set_attribute(attribute_key, handlers_called)
@@ -75,6 +78,7 @@ def handler_03(self, node):

             @process_handler(priority=200)
             def handler_02(self, node):
+                """Example handler returning ExitCode 200."""
                 handlers_called = node.get_attribute(attribute_key, default=[])
                 handlers_called.append('handler_02')
                 node.set_attribute(attribute_key, handlers_called)
@@ -82,6 +86,7 @@ def handler_02(self, node):

             @process_handler(priority=400)
             def handler_04(self, node):
+                """Example handler returning ExitCode 400."""
                 handlers_called = node.get_attribute(attribute_key, default=[])
                 handlers_called.append('handler_04')
                 node.set_attribute(attribute_key, handlers_called)
@@ -159,6 +164,7 @@ def test_exit_codes_filter(self):
         node_skip.set_exit_status(200)  # Some other exit status

         class ArithmeticAddBaseWorkChain(BaseRestartWorkChain):
+            """Minimal base restart workchain for the ArithmeticAdd calculation"""

             _process_class = ArithmeticAddCalculation

diff --git a/tests/tools/importexport/orm/test_codes.py b/tests/tools/importexport/orm/test_codes.py
index 5a11e07b94..d8f173107b 100644
--- a/tests/tools/importexport/orm/test_codes.py
+++ b/tests/tools/importexport/orm/test_codes.py
@@ -24,9 +24,11 @@ class TestCode(AiidaTestCase):
     """Test ex-/import cases related to Codes"""

     def setUp(self):
+        super().setUp()
         self.reset_database()

     def tearDown(self):
+        super().tearDown()
         self.reset_database()

     @with_temp_dir
diff --git a/tests/tools/visualization/test_graph.py b/tests/tools/visualization/test_graph.py
index 9f15cab9ca..d48a3e6800 100644
--- a/tests/tools/visualization/test_graph.py
+++ b/tests/tools/visualization/test_graph.py
@@ -22,9 +22,11 @@ class TestVisGraph(AiidaTestCase):
     """Tests for verdi graph"""

     def setUp(self):
+        super().setUp()
         self.reset_database()

     def tearDown(self):
+        super().tearDown()
         self.reset_database()

     def create_provenance(self):
diff --git a/utils/dependency_management.py b/utils/dependency_management.py
index 17442d66af..af476de3e7 100755
--- a/utils/dependency_management.py
+++ b/utils/dependency_management.py
@@ -239,7 +239,6 @@ def validate_environment_yml():  # pylint: disable=too-many-branches

     # Check that all requirements specified in the setup.json file are found in the
     # conda environment specification.
-    missing_from_env = set()
     for req in install_requirements:
         if any(re.match(ignore, str(req)) for ignore in CONDA_IGNORE):
             continue  # skip explicitly ignored packages
@@ -251,7 +250,7 @@ def validate_environment_yml():  # pylint: disable=too-many-branches

     # The only dependency left should be the one for Python itself, which is not part of
     # the install_requirements for setuptools.
-    if len(conda_dependencies) > 0:
+    if conda_dependencies:
         raise DependencySpecificationError(
             "The 'environment.yml' file contains dependencies that are missing "
             "in 'setup.json':\n- {}".format('\n- '.join(map(str, conda_dependencies)))
         )
@@ -304,7 +303,7 @@ def validate_pyproject_toml():
                 "Missing requirement '{}' in 'pyproject.toml'.".format(reentry_requirement)
             )

-    except FileNotFoundError as error:
+    except FileNotFoundError:
         raise DependencySpecificationError("The 'pyproject.toml' file is missing!")

     click.secho('Pyproject.toml dependency specification is consistent.', fg='green')
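
Reviewer note (not part of the patch): below is a minimal, self-contained sketch of the label-collision scheme that the new `Autogroup.get_or_create_group` implements, handy when checking the expectations in `test_autogroup_clashing_label`. The helper name `next_autogroup_label` is hypothetical; the real method additionally catches `IntegrityError` and keeps incrementing the counter when two concurrent runs race for the same label.

def next_autogroup_label(prefix, existing_labels):
    """Return the label a new autogroup would receive, given the labels already taken.

    Mirrors the counter logic of the patch: the bare prefix counts as 0,
    '<prefix>_N' counts as N, and the new label uses max(taken) + 1
    (or the bare prefix if nothing is taken yet).
    """
    taken = []
    for label in existing_labels:
        if label == prefix:
            taken.append(0)
        elif label.startswith(prefix + '_'):
            try:
                taken.append(int(label[len(prefix) + 1:]))
            except ValueError:
                pass  # non-integer suffixes can never collide, so ignore them
    counter = max(taken) + 1 if taken else 0
    return prefix if counter == 0 else '{}_{}'.format(prefix, counter)


assert next_autogroup_label('Verdi autogroup', []) == 'Verdi autogroup'
assert next_autogroup_label('Verdi autogroup', ['Verdi autogroup']) == 'Verdi autogroup_1'
assert next_autogroup_label('Verdi autogroup', ['Verdi autogroup', 'Verdi autogroup_3']) == 'Verdi autogroup_4'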