diff --git a/aiida/backends/djsite/db/migrations/0044_dbgroup_type_string.py b/aiida/backends/djsite/db/migrations/0044_dbgroup_type_string.py new file mode 100644 index 0000000000..8c577ce397 --- /dev/null +++ b/aiida/backends/djsite/db/migrations/0044_dbgroup_type_string.py @@ -0,0 +1,44 @@ +# -*- coding: utf-8 -*- +########################################################################### +# Copyright (c), The AiiDA team. All rights reserved. # +# This file is part of the AiiDA code. # +# # +# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # +# For further information on the license, see the LICENSE.txt file # +# For further information please visit http://www.aiida.net # +########################################################################### +# pylint: disable=invalid-name,too-few-public-methods +"""Migration after the `Group` class became pluginnable and so the group `type_string` changed.""" + +# pylint: disable=no-name-in-module,import-error +from django.db import migrations +from aiida.backends.djsite.db.migrations import upgrade_schema_version + +REVISION = '1.0.44' +DOWN_REVISION = '1.0.43' + +forward_sql = [ + """UPDATE db_dbgroup SET type_string = 'core' WHERE type_string = 'user';""", + """UPDATE db_dbgroup SET type_string = 'core.upf' WHERE type_string = 'data.upf';""", + """UPDATE db_dbgroup SET type_string = 'core.import' WHERE type_string = 'auto.import';""", + """UPDATE db_dbgroup SET type_string = 'core.auto' WHERE type_string = 'auto.run';""", +] + +reverse_sql = [ + """UPDATE db_dbgroup SET type_string = 'user' WHERE type_string = 'core';""", + """UPDATE db_dbgroup SET type_string = 'data.upf' WHERE type_string = 'core.upf';""", + """UPDATE db_dbgroup SET type_string = 'auto.import' WHERE type_string = 'core.import';""", + """UPDATE db_dbgroup SET type_string = 'auto.run' WHERE type_string = 'core.auto';""", +] + + +class Migration(migrations.Migration): + """Migration after the update of group `type_string`""" + dependencies = [ + ('db', '0043_default_link_label'), + ] + + operations = [ + migrations.RunSQL(sql='\n'.join(forward_sql), reverse_sql='\n'.join(reverse_sql)), + upgrade_schema_version(REVISION, DOWN_REVISION), + ] diff --git a/aiida/backends/djsite/db/migrations/__init__.py b/aiida/backends/djsite/db/migrations/__init__.py index a832b4e5f7..41ee2b3d2c 100644 --- a/aiida/backends/djsite/db/migrations/__init__.py +++ b/aiida/backends/djsite/db/migrations/__init__.py @@ -21,7 +21,7 @@ class DeserializationException(AiidaException): pass -LATEST_MIGRATION = '0043_default_link_label' +LATEST_MIGRATION = '0044_dbgroup_type_string' def _update_schema_version(version, apps, _): diff --git a/aiida/backends/sqlalchemy/migrations/versions/bf591f31dd12_dbgroup_type_string.py b/aiida/backends/sqlalchemy/migrations/versions/bf591f31dd12_dbgroup_type_string.py new file mode 100644 index 0000000000..8231d8ebb7 --- /dev/null +++ b/aiida/backends/sqlalchemy/migrations/versions/bf591f31dd12_dbgroup_type_string.py @@ -0,0 +1,45 @@ +# -*- coding: utf-8 -*- +"""Migration after the `Group` class became pluginnable and so the group `type_string` changed. + +Revision ID: bf591f31dd12 +Revises: 118349c10896 +Create Date: 2020-03-31 10:00:52.609146 + +""" +# pylint: disable=no-name-in-module,import-error,invalid-name,no-member +from alembic import op +from sqlalchemy.sql import text + +forward_sql = [ + """UPDATE db_dbgroup SET type_string = 'core' WHERE type_string = 'user';""", + """UPDATE db_dbgroup SET type_string = 'core.upf' WHERE type_string = 'data.upf';""", + """UPDATE db_dbgroup SET type_string = 'core.import' WHERE type_string = 'auto.import';""", + """UPDATE db_dbgroup SET type_string = 'core.auto' WHERE type_string = 'auto.run';""", +] + +reverse_sql = [ + """UPDATE db_dbgroup SET type_string = 'user' WHERE type_string = 'core';""", + """UPDATE db_dbgroup SET type_string = 'data.upf' WHERE type_string = 'core.upf';""", + """UPDATE db_dbgroup SET type_string = 'auto.import' WHERE type_string = 'core.import';""", + """UPDATE db_dbgroup SET type_string = 'auto.run' WHERE type_string = 'core.auto';""", +] + +# revision identifiers, used by Alembic. +revision = 'bf591f31dd12' +down_revision = '118349c10896' +branch_labels = None +depends_on = None + + +def upgrade(): + """Migrations for the upgrade.""" + conn = op.get_bind() + statement = text('\n'.join(forward_sql)) + conn.execute(statement) + + +def downgrade(): + """Migrations for the downgrade.""" + conn = op.get_bind() + statement = text('\n'.join(reverse_sql)) + conn.execute(statement) diff --git a/aiida/cmdline/commands/cmd_data/cmd_upf.py b/aiida/cmdline/commands/cmd_data/cmd_upf.py index 78f79b0d9e..745f4af7a2 100644 --- a/aiida/cmdline/commands/cmd_data/cmd_upf.py +++ b/aiida/cmdline/commands/cmd_data/cmd_upf.py @@ -64,22 +64,13 @@ def upf_listfamilies(elements, with_description): """ from aiida import orm from aiida.plugins import DataFactory - from aiida.orm.nodes.data.upf import UPFGROUP_TYPE UpfData = DataFactory('upf') # pylint: disable=invalid-name query = orm.QueryBuilder() query.append(UpfData, tag='upfdata') if elements is not None: query.add_filter(UpfData, {'attributes.element': {'in': elements}}) - query.append( - orm.Group, - with_node='upfdata', - tag='group', - project=['label', 'description'], - filters={'type_string': { - '==': UPFGROUP_TYPE - }} - ) + query.append(orm.UpfFamily, with_node='upfdata', tag='group', project=['label', 'description']) query.distinct() if query.count() > 0: diff --git a/aiida/cmdline/commands/cmd_group.py b/aiida/cmdline/commands/cmd_group.py index 16978379ae..e48c361b33 100644 --- a/aiida/cmdline/commands/cmd_group.py +++ b/aiida/cmdline/commands/cmd_group.py @@ -13,7 +13,7 @@ from aiida.common.exceptions import UniquenessError from aiida.cmdline.commands.cmd_verdi import verdi -from aiida.cmdline.params import options, arguments, types +from aiida.cmdline.params import options, arguments from aiida.cmdline.utils import echo from aiida.cmdline.utils.decorators import with_dbenv @@ -178,18 +178,6 @@ def group_show(group, raw, limit, uuid): echo.echo(tabulate(table, headers=header)) -@with_dbenv() -def valid_group_type_strings(): - from aiida.orm import GroupTypeString - return tuple(i.value for i in GroupTypeString) - - -@with_dbenv() -def user_defined_group(): - from aiida.orm import GroupTypeString - return GroupTypeString.USER.value - - @verdi_group.command('list') @options.ALL_USERS(help='Show groups for all users, rather than only for the current user') @click.option( @@ -204,8 +192,7 @@ def user_defined_group(): '-t', '--type', 'group_type', - type=types.LazyChoice(valid_group_type_strings), - default=user_defined_group, + default='core', help='Show groups of a specific type, instead of user-defined groups. Start with semicolumn if you want to ' 'specify aiida-internal type' ) @@ -330,9 +317,8 @@ def group_list( def group_create(group_label): """Create an empty group with a given name.""" from aiida import orm - from aiida.orm import GroupTypeString - group, created = orm.Group.objects.get_or_create(label=group_label, type_string=GroupTypeString.USER.value) + group, created = orm.Group.objects.get_or_create(label=group_label) if created: echo.echo_success("Group created with PK = {} and name '{}'".format(group.id, group.label)) @@ -351,7 +337,7 @@ def group_copy(source_group, destination_group): Note that the destination group may not exist.""" from aiida import orm - dest_group, created = orm.Group.objects.get_or_create(label=destination_group, type_string=source_group.type_string) + dest_group, created = orm.Group.objects.get_or_create(label=destination_group) # Issue warning if destination group is not empty and get user confirmation to continue if not created and not dest_group.is_empty: @@ -386,8 +372,7 @@ def verdi_group_path(): '-t', '--type', 'group_type', - type=types.LazyChoice(valid_group_type_strings), - default=user_defined_group, + default='core', help='Show groups of a specific type, instead of user-defined groups. Start with semicolumn if you want to ' 'specify aiida-internal type' ) @@ -396,10 +381,11 @@ def verdi_group_path(): def group_path_ls(path, recursive, as_table, no_virtual, group_type, with_description, no_warn): # pylint: disable=too-many-arguments """Show a list of existing group paths.""" + from aiida.plugins import GroupFactory from aiida.tools.groups.paths import GroupPath, InvalidPath try: - path = GroupPath(path or '', type_string=group_type, warn_invalid_child=not no_warn) + path = GroupPath(path or '', cls=GroupFactory(group_type), warn_invalid_child=not no_warn) except InvalidPath as err: echo.echo_critical(str(err)) diff --git a/aiida/cmdline/commands/cmd_run.py b/aiida/cmdline/commands/cmd_run.py index bd3972b841..d46b6f984c 100644 --- a/aiida/cmdline/commands/cmd_run.py +++ b/aiida/cmdline/commands/cmd_run.py @@ -150,5 +150,6 @@ def run(scriptname, varargs, auto_group, auto_group_label_prefix, group_name, ex # Re-raise the exception to have the error code properly returned at the end raise finally: + autogroup.current_autogroup = None if handle: handle.close() diff --git a/aiida/cmdline/params/types/group.py b/aiida/cmdline/params/types/group.py index ef216044e7..6150f6d062 100644 --- a/aiida/cmdline/params/types/group.py +++ b/aiida/cmdline/params/types/group.py @@ -40,12 +40,12 @@ def orm_class_loader(self): @with_dbenv() def convert(self, value, param, ctx): - from aiida.orm import Group, GroupTypeString + from aiida.orm import Group try: group = super().convert(value, param, ctx) except click.BadParameter: if self._create_if_not_exist: - group = Group(label=value, type_string=GroupTypeString.USER.value) + group = Group(label=value) else: raise diff --git a/aiida/orm/autogroup.py b/aiida/orm/autogroup.py index 16bf03f1c1..06e83185e3 100644 --- a/aiida/orm/autogroup.py +++ b/aiida/orm/autogroup.py @@ -14,21 +14,18 @@ from aiida.common import exceptions, timezone from aiida.common.escaping import escape_for_sql_like, get_regex_pattern_from_sql from aiida.common.warnings import AiidaDeprecationWarning -from aiida.orm import GroupTypeString, Group +from aiida.orm import AutoGroup from aiida.plugins.entry_point import get_entry_point_string_from_class CURRENT_AUTOGROUP = None -VERDIAUTOGROUP_TYPE = GroupTypeString.VERDIAUTOGROUP_TYPE.value - class Autogroup: - """ - An object used for the autogrouping of objects. - The autogrouping is checked by the Node.store() method. - In the store(), the Node will check if CURRENT_AUTOGROUP is != None. - If so, it will call Autogroup.is_to_be_grouped, and decide whether to put it in a group. - Such autogroups are going to be of the VERDIAUTOGROUP_TYPE. + """Class to create a new `AutoGroup` instance that will, while active, automatically contain all nodes being stored. + + The autogrouping is checked by the `Node.store()` method which, if `CURRENT_AUTOGROUP is not None` the method + `Autogroup.is_to_be_grouped` is called to decide whether to put the current node being stored in the current + `AutoGroup` instance. The exclude/include lists are lists of strings like: ``aiida.data:int``, ``aiida.calculation:quantumespresso.pw``, @@ -198,7 +195,7 @@ def clear_group_cache(self): self._group_label = None def get_or_create_group(self): - """Return the current Autogroup, or create one if None has been set yet. + """Return the current `AutoGroup`, or create one if None has been set yet. This function implements a somewhat complex logic that is however needed to make sure that, even if `verdi run` is called at the same time multiple @@ -219,16 +216,10 @@ def get_or_create_group(self): # So the group with the same name can be returned quickly in future # calls of this method. if self._group_label is not None: - results = [ - res[0] for res in QueryBuilder(). - append(Group, filters={ - 'label': self._group_label, - 'type_string': VERDIAUTOGROUP_TYPE - }, project='*').iterall() - ] + builder = QueryBuilder().append(AutoGroup, filters={'label': self._group_label}) + results = [res[0] for res in builder.iterall()] if results: - # If it is not empty, it should have only one result due to the - # uniqueness constraints + # If it is not empty, it should have only one result due to the uniqueness constraints assert len(results) == 1, 'I got more than one autogroup with the same label!' return results[0] # There are no results: probably the group has been deleted. @@ -239,7 +230,7 @@ def get_or_create_group(self): # Try to do a preliminary QB query to avoid to do too many try/except # if many of the prefix_NUMBER groups already exist queryb = QueryBuilder().append( - Group, + AutoGroup, filters={ 'or': [{ 'label': { @@ -274,7 +265,7 @@ def get_or_create_group(self): while True: try: label = label_prefix if counter == 0 else '{}_{}'.format(label_prefix, counter) - group = Group(label=label, type_string=VERDIAUTOGROUP_TYPE).store() + group = AutoGroup(label=label).store() self._group_label = group.label except exceptions.IntegrityError: counter += 1 diff --git a/aiida/orm/convert.py b/aiida/orm/convert.py index 197253cffd..d6b577773b 100644 --- a/aiida/orm/convert.py +++ b/aiida/orm/convert.py @@ -61,8 +61,9 @@ def _(backend_entity): @get_orm_entity.register(BackendGroup) def _(backend_entity): - from . import groups - return groups.Group.from_backend_entity(backend_entity) + from .groups import load_group_class + group_class = load_group_class(backend_entity.type_string) + return group_class.from_backend_entity(backend_entity) @get_orm_entity.register(BackendComputer) diff --git a/aiida/orm/groups.py b/aiida/orm/groups.py index cb7b4af801..f2c726c0f2 100644 --- a/aiida/orm/groups.py +++ b/aiida/orm/groups.py @@ -8,7 +8,7 @@ # For further information please visit http://www.aiida.net # ########################################################################### """ AiiDA Group entites""" - +from abc import ABCMeta from enum import Enum import warnings @@ -21,19 +21,63 @@ from . import entities from . import users -__all__ = ('Group', 'GroupTypeString') +__all__ = ('Group', 'GroupTypeString', 'AutoGroup', 'ImportGroup', 'UpfFamily') + + +def load_group_class(type_string): + """Load the sub class of `Group` that corresponds to the given `type_string`. + + .. note:: will fall back on `aiida.orm.groups.Group` if `type_string` cannot be resolved to loadable entry point. + + :param type_string: the entry point name of the `Group` sub class + :return: sub class of `Group` registered through an entry point + """ + from aiida.common.exceptions import EntryPointError + from aiida.plugins.entry_point import load_entry_point + + try: + group_class = load_entry_point('aiida.groups', type_string) + except EntryPointError: + message = 'could not load entry point `{}`, falling back onto `Group` base class.'.format(type_string) + warnings.warn(message) # pylint: disable=no-member + group_class = Group + + return group_class + + +class GroupMeta(ABCMeta): + """Meta class for `aiida.orm.groups.Group` to automatically set the `type_string` attribute.""" + + def __new__(mcs, name, bases, namespace, **kwargs): + from aiida.plugins.entry_point import get_entry_point_from_class + + newcls = ABCMeta.__new__(mcs, name, bases, namespace, **kwargs) # pylint: disable=too-many-function-args + + entry_point_group, entry_point = get_entry_point_from_class(namespace['__module__'], name) + + if entry_point_group is None or entry_point_group != 'aiida.groups': + newcls._type_string = None + message = 'no registered entry point for `{}` so its instances will not be storable.'.format(name) + warnings.warn(message) # pylint: disable=no-member + else: + newcls._type_string = entry_point.name # pylint: disable=protected-access + + return newcls class GroupTypeString(Enum): - """A simple enum of allowed group type strings.""" + """A simple enum of allowed group type strings. + .. deprecated:: 1.2.0 + This enum is deprecated and will be removed in `v2.0.0`. + """ UPFGROUP_TYPE = 'data.upf' IMPORTGROUP_TYPE = 'auto.import' VERDIAUTOGROUP_TYPE = 'auto.run' USER = 'user' -class Group(entities.Entity): +class Group(entities.Entity, metaclass=GroupMeta): """An AiiDA ORM implementation of group of nodes.""" class Collection(entities.Collection): @@ -54,21 +98,10 @@ def get_or_create(self, label=None, **kwargs): if not label: raise ValueError('Group label must be provided') - filters = {'label': label} - - if 'type_string' in kwargs: - if not isinstance(kwargs['type_string'], str): - raise exceptions.ValidationError( - 'type_string must be {}, you provided an object of type ' - '{}'.format(str, type(kwargs['type_string'])) - ) - - filters['type_string'] = kwargs['type_string'] - - res = self.find(filters=filters) + res = self.find(filters={'label': label}) if not res: - return Group(label, backend=self.backend, **kwargs).store(), True + return self.entity_type(label, backend=self.backend, **kwargs).store(), True if len(res) > 1: raise exceptions.MultipleObjectsError('More than one groups found in the database') @@ -83,12 +116,15 @@ def delete(self, id): # pylint: disable=invalid-name, redefined-builtin """ self._backend.groups.delete(id) - def __init__(self, label=None, user=None, description='', type_string=GroupTypeString.USER.value, backend=None): + def __init__(self, label=None, user=None, description='', type_string=None, backend=None): """ Create a new group. Either pass a dbgroup parameter, to reload a group from the DB (and then, no further parameters are allowed), or pass the parameters for the Group creation. + .. deprecated:: 1.2.0 + The parameter `type_string` will be removed in `v2.0.0` and is now determined automatically. + :param label: The group label, required on creation :type label: str @@ -105,12 +141,11 @@ def __init__(self, label=None, user=None, description='', type_string=GroupTypeS if not label: raise ValueError('Group label must be provided') - # Check that chosen type_string is allowed - if not isinstance(type_string, str): - raise exceptions.ValidationError( - 'type_string must be {}, you provided an object of type ' - '{}'.format(str, type(type_string)) - ) + if type_string is not None: + message = '`type_string` is deprecated because it is determined automatically, using default `core`' + warnings.warn(message) # pylint: disable=no-member + + type_string = self._type_string backend = backend or get_manager().get_backend() user = user or users.User.objects(backend).get_default() @@ -130,6 +165,13 @@ def __str__(self): return '"{}" [user-defined], of user {}'.format(self.label, self.user.email) + def store(self): + """Verify that the group is allowed to be stored, which is the case along as `type_string` is set.""" + if self._type_string is None: + raise exceptions.StoringNotAllowed('`type_string` is `None` so the group cannot be stored.') + + return super().store() + @property def label(self): """ @@ -295,11 +337,7 @@ def get(cls, **kwargs): filters = {} if 'type_string' in kwargs: - if not isinstance(kwargs['type_string'], str): - raise exceptions.ValidationError( - 'type_string must be {}, you provided an object of type ' - '{}'.format(str, type(kwargs['type_string'])) - ) + type_check(kwargs['type_string'], str) query = QueryBuilder() for key, val in kwargs.items(): @@ -382,3 +420,15 @@ def get_schema(): 'type': 'unicode' } } + + +class AutoGroup(Group): + """Group to be used to contain selected nodes generated while `aiida.orm.autogroup.CURRENT_AUTOGROUP` is set.""" + + +class ImportGroup(Group): + """Group to be used to contain all nodes from an export archive that has been imported.""" + + +class UpfFamily(Group): + """Group that represents a pseudo potential family containing `UpfData` nodes.""" diff --git a/aiida/orm/implementation/groups.py b/aiida/orm/implementation/groups.py index 74349e25e6..f39314060f 100644 --- a/aiida/orm/implementation/groups.py +++ b/aiida/orm/implementation/groups.py @@ -101,7 +101,7 @@ def get_or_create(cls, *args, **kwargs): :return: (group, created) where group is the group (new or existing, in any case already stored) and created is a boolean saying """ - res = cls.query(name=kwargs.get('name'), type_string=kwargs.get('type_string')) + res = cls.query(name=kwargs.get('name')) if not res: return cls.create(*args, **kwargs), True diff --git a/aiida/orm/nodes/data/upf.py b/aiida/orm/nodes/data/upf.py index d35e1b35ee..33cf9b6421 100644 --- a/aiida/orm/nodes/data/upf.py +++ b/aiida/orm/nodes/data/upf.py @@ -8,20 +8,14 @@ # For further information please visit http://www.aiida.net # ########################################################################### """Module of `Data` sub class to represent a pseudopotential single file in UPF format and related utilities.""" - import json import re from upf_to_json import upf_to_json - -from aiida.common.lang import classproperty -from aiida.orm import GroupTypeString from .singlefile import SinglefileData __all__ = ('UpfData',) -UPFGROUP_TYPE = GroupTypeString.UPFGROUP_TYPE.value - REGEX_UPF_VERSION = re.compile(r""" \s*.*)"> @@ -107,9 +101,7 @@ def upload_upf_family(folder, group_label, group_description, stop_if_existing=T nfiles = len(filenames) automatic_user = orm.User.objects.get_default() - group, group_created = orm.Group.objects.get_or_create( - label=group_label, type_string=UPFGROUP_TYPE, user=automatic_user - ) + group, group_created = orm.UpfFamily.objects.get_or_create(label=group_label, user=automatic_user) if group.user.email != automatic_user.email: raise UniquenessError( @@ -312,12 +304,6 @@ def get_or_create(cls, filepath, use_first=False, store_upf=True): return (pseudos[0], False) - @classproperty - def upffamily_type_string(cls): - """Return the type string used for UPF family groups.""" - # pylint: disable=no-self-argument,no-self-use - return UPFGROUP_TYPE - def store(self, *args, **kwargs): """Store the node, reparsing the file so that the md5 and the element are correctly reset.""" # pylint: disable=arguments-differ @@ -388,11 +374,11 @@ def set_file(self, file, filename=None): def get_upf_family_names(self): """Get the list of all upf family names to which the pseudo belongs.""" - from aiida.orm import Group + from aiida.orm import UpfFamily from aiida.orm import QueryBuilder query = QueryBuilder() - query.append(Group, filters={'type_string': {'==': self.upffamily_type_string}}, tag='group', project='label') + query.append(UpfFamily, tag='group', project='label') query.append(UpfData, filters={'id': {'==': self.id}}, with_group='group') return [label for label, in query.all()] @@ -465,9 +451,9 @@ def get_upf_group(cls, group_label): :param group_label: the family group label :return: the `Group` with the given label, if it exists """ - from aiida.orm import Group + from aiida.orm import UpfFamily - return Group.get(label=group_label, type_string=cls.upffamily_type_string) + return UpfFamily.get(label=group_label) @classmethod def get_upf_groups(cls, filter_elements=None, user=None): @@ -480,12 +466,12 @@ def get_upf_groups(cls, filter_elements=None, user=None): If defined, it should be either a `User` instance or the user email. :return: list of `Group` entities of type UPF. """ - from aiida.orm import Group + from aiida.orm import UpfFamily from aiida.orm import QueryBuilder from aiida.orm import User builder = QueryBuilder() - builder.append(Group, filters={'type_string': {'==': cls.upffamily_type_string}}, tag='group', project='*') + builder.append(UpfFamily, tag='group', project='*') if user: builder.append(User, filters={'email': {'==': user}}, with_group='group') @@ -496,7 +482,7 @@ def get_upf_groups(cls, filter_elements=None, user=None): if filter_elements is not None: builder.append(UpfData, filters={'attributes.element': {'in': filter_elements}}, with_group='group') - builder.order_by({Group: {'id': 'asc'}}) + builder.order_by({UpfFamily: {'id': 'asc'}}) return [group for group, in builder.all()] diff --git a/aiida/plugins/entry_point.py b/aiida/plugins/entry_point.py index 2abe6be077..46e4bf3c7e 100644 --- a/aiida/plugins/entry_point.py +++ b/aiida/plugins/entry_point.py @@ -54,6 +54,7 @@ class EntryPointFormat(enum.Enum): 'aiida.calculations': 'aiida.orm.nodes.process.calculation.calcjob', 'aiida.cmdline.data': 'aiida.cmdline.data', 'aiida.data': 'aiida.orm.nodes.data', + 'aiida.groups': 'aiida.orm.groups', 'aiida.node': 'aiida.orm.nodes', 'aiida.parsers': 'aiida.parsers.plugins', 'aiida.schedulers': 'aiida.schedulers.plugins', @@ -78,6 +79,7 @@ def validate_registered_entry_points(): # pylint: disable=invalid-name factory_mapping = { 'aiida.calculations': factories.CalculationFactory, 'aiida.data': factories.DataFactory, + 'aiida.groups': factories.GroupFactory, 'aiida.parsers': factories.ParserFactory, 'aiida.schedulers': factories.SchedulerFactory, 'aiida.transports': factories.TransportFactory, diff --git a/aiida/plugins/factories.py b/aiida/plugins/factories.py index 6e5a9296e9..1675ac6cb6 100644 --- a/aiida/plugins/factories.py +++ b/aiida/plugins/factories.py @@ -14,8 +14,8 @@ from aiida.common.exceptions import InvalidEntryPointTypeError __all__ = ( - 'BaseFactory', 'CalculationFactory', 'DataFactory', 'DbImporterFactory', 'OrbitalFactory', 'ParserFactory', - 'SchedulerFactory', 'TransportFactory', 'WorkflowFactory' + 'BaseFactory', 'CalculationFactory', 'DataFactory', 'DbImporterFactory', 'GroupFactory', 'OrbitalFactory', + 'ParserFactory', 'SchedulerFactory', 'TransportFactory', 'WorkflowFactory' ) @@ -107,6 +107,25 @@ def DbImporterFactory(entry_point_name): raise_invalid_type_error(entry_point_name, entry_point_group, valid_classes) +def GroupFactory(entry_point_name): + """Return the `Group` sub class registered under the given entry point. + + :param entry_point_name: the entry point name + :return: sub class of :py:class:`~aiida.orm.groups.Group` + :raises aiida.common.InvalidEntryPointTypeError: if the type of the loaded entry point is invalid. + """ + from aiida.orm import Group + + entry_point_group = 'aiida.groups' + entry_point = BaseFactory(entry_point_group, entry_point_name) + valid_classes = (Group,) + + if isclass(entry_point) and issubclass(entry_point, Group): + return entry_point + + raise_invalid_type_error(entry_point_name, entry_point_group, valid_classes) + + def OrbitalFactory(entry_point_name): """Return the `Orbital` sub class registered under the given entry point. diff --git a/aiida/tools/groups/paths.py b/aiida/tools/groups/paths.py index 9d20ea9c55..b025ab250e 100644 --- a/aiida/tools/groups/paths.py +++ b/aiida/tools/groups/paths.py @@ -17,7 +17,7 @@ from aiida import orm from aiida.common.exceptions import NotExistent -__all__ = ('GroupPath', 'InvalidPath') +__all__ = ('GroupPath', 'InvalidPath', 'GroupNotFoundError', 'GroupNotUniqueError', 'NoGroupsInPathError') REGEX_ATTR = re.compile('^[a-zA-Z][\\_a-zA-Z0-9]*$') @@ -60,19 +60,20 @@ class GroupPath: See tests for usage examples. """ - def __init__(self, path='', type_string=orm.GroupTypeString.USER.value, warn_invalid_child=True): + def __init__(self, path='', cls=orm.Group, warn_invalid_child=True): # type: (str, Optional[str], Optional[GroupPath]) """Instantiate the class. :param path: The initial path of the group. - :param type_string: Used to query for and instantiate a ``Group`` with. + :param cls: The subclass of `Group` to operate on. :param warn_invalid_child: Issue a warning, when iterating children, if a child path is invalid. """ + if not issubclass(cls, orm.Group): + raise TypeError('cls must a subclass of Group: {}'.format(cls)) + self._delimiter = '/' - if not isinstance(type_string, str): - raise TypeError('type_string must a str: {}'.format(type_string)) - self._type_string = type_string + self._cls = cls self._path_string = self._validate_path(path) self._path_list = self._path_string.split(self._delimiter) if path else [] self._warn_invalid_child = warn_invalid_child @@ -90,21 +91,21 @@ def _validate_path(self, path): def __repr__(self): # type: () -> str """Represent the instantiated class.""" - return "{}('{}', type='{}')".format(self.__class__.__name__, self.path, self.type_string) + return "{}('{}', cls='{}')".format(self.__class__.__name__, self.path, self.cls) def __eq__(self, other): # type: (Any) -> bool - """Compare equality of path and type string to another ``GroupPath`` object.""" + """Compare equality of path and ``Group`` subclass to another ``GroupPath`` object.""" if not isinstance(other, GroupPath): return NotImplemented - return (self.path, self.type_string) == (other.path, other.type_string) + return (self.path, self.cls) == (other.path, other.cls) def __lt__(self, other): # type: (Any) -> bool - """Compare less-than operator of path and type string to another ``GroupPath`` object.""" + """Compare less-than operator of path and ``Group`` subclass to another ``GroupPath`` object.""" if not isinstance(other, GroupPath): return NotImplemented - return (self.path, self.type_string) < (other.path, other.type_string) + return (self.path, self.cls) < (other.path, other.cls) @property def path(self): @@ -133,10 +134,10 @@ def delimiter(self): return self._delimiter @property - def type_string(self): + def cls(self): # type: () -> str - """Return the type_string used to query for and instantiate a ``Group`` with.""" - return self._type_string + """Return the cls used to query for and instantiate a ``Group`` with.""" + return self._cls @property def parent(self): @@ -144,9 +145,7 @@ def parent(self): """Return the parent path.""" if self.path_list: return GroupPath( - self.delimiter.join(self.path_list[:-1]), - type_string=self.type_string, - warn_invalid_child=self._warn_invalid_child + self.delimiter.join(self.path_list[:-1]), cls=self.cls, warn_invalid_child=self._warn_invalid_child ) return None @@ -158,7 +157,7 @@ def __truediv__(self, path): path = self._validate_path(path) child = GroupPath( path=self.path + self.delimiter + path if self.path else path, - type_string=self.type_string, + cls=self.cls, warn_invalid_child=self._warn_invalid_child ) return child @@ -169,10 +168,10 @@ def __getitem__(self, path): return self.__truediv__(path) def get_group(self): - # type: () -> Optional[orm.Group] + # type: () -> Optional[self.cls] """Return the concrete group associated with this path.""" try: - return orm.Group.objects.get(label=self.path, type_string=self.type_string) + return self.cls.objects.get(label=self.path) except NotExistent: return None @@ -182,16 +181,14 @@ def group_ids(self): """Return all the UUID associated with this GroupPath. :returns: and empty list, if no group associated with this label, - or can be multiple if type_string was None + or can be multiple if cls was None This is an efficient method for checking existence, which does not require the (slow) loading of the ORM entity. """ query = orm.QueryBuilder() filters = {'label': self.path} - if self.type_string is not None: - filters['type_string'] = self.type_string - query.append(orm.Group, filters=filters, project='id') + query.append(self.cls, filters=filters, project='id') return [r[0] for r in query.all()] @property @@ -201,11 +198,9 @@ def is_virtual(self): return len(self.group_ids) == 0 def get_or_create_group(self): - # type: () -> (orm.Group, bool) + # type: () -> (self.cls, bool) """Return the concrete group associated with this path or, create it, if it does not already exist.""" - if self.type_string is not None: - return orm.Group.objects.get_or_create(label=self.path, type_string=self.type_string) - return orm.Group.objects.get_or_create(label=self.path) + return self.cls.objects.get_or_create(label=self.path) def delete_group(self): """Delete the concrete group associated with this path. @@ -217,7 +212,7 @@ def delete_group(self): raise GroupNotFoundError(self) if len(ids) > 1: raise GroupNotUniqueError(self) - orm.Group.objects.delete(ids[0]) + self.cls.objects.delete(ids[0]) @property def children(self): @@ -227,9 +222,7 @@ def children(self): filters = {} if self.path: filters['label'] = {'like': self.path + self.delimiter + '%'} - if self.type_string is not None: - filters['type_string'] = self.type_string - query.append(orm.Group, filters=filters, project='label') + query.append(self.cls, filters=filters, project='label') if query.count() == 0 and self.is_virtual: raise NoGroupsInPathError(self) @@ -242,9 +235,7 @@ def children(self): if (path_string not in yielded and path[:len(self._path_list)] == self._path_list): yielded.append(path_string) try: - yield GroupPath( - path=path_string, type_string=self.type_string, warn_invalid_child=self._warn_invalid_child - ) + yield GroupPath(path=path_string, cls=self.cls, warn_invalid_child=self._warn_invalid_child) except InvalidPath: if self._warn_invalid_child: warnings.warn('invalid path encountered: {}'.format(path_string)) # pylint: disable=no-member @@ -291,9 +282,7 @@ def walk_nodes(self, filters=None, node_class=None, query_batch=None): group_filters = {} if self.path: group_filters['label'] = {'or': [{'==': self.path}, {'like': self.path + self.delimiter + '%'}]} - if self.type_string is not None: - group_filters['type_string'] = self.type_string - query.append(orm.Group, filters=group_filters, project='label', tag='group') + query.append(self.cls, filters=group_filters, project='label', tag='group') query.append( orm.Node if node_class is None else node_class, with_group='group', @@ -301,7 +290,7 @@ def walk_nodes(self, filters=None, node_class=None, query_batch=None): project=['*'], ) for (label, node) in query.iterall(query_batch) if query_batch else query.all(): - yield WalkNodeResult(GroupPath(label, type_string=self.type_string), node) + yield WalkNodeResult(GroupPath(label, cls=self.cls), node) @property def browse(self): @@ -330,9 +319,7 @@ def __init__(self, group_path): def __repr__(self): # type: () -> str """Represent the instantiated class.""" - return "{}('{}', type='{}')".format( - self.__class__.__name__, self._group_path.path, self._group_path.type_string - ) + return "{}('{}', type='{}')".format(self.__class__.__name__, self._group_path.path, self._group_path.cls) def __call__(self): # type: () -> GroupPath diff --git a/aiida/tools/importexport/common/config.py b/aiida/tools/importexport/common/config.py index 0baac376c9..549c22be7d 100644 --- a/aiida/tools/importexport/common/config.py +++ b/aiida/tools/importexport/common/config.py @@ -9,15 +9,13 @@ ########################################################################### # pylint: disable=invalid-name """ Configuration file for AiiDA Import/Export module """ - -from aiida.orm import Computer, Group, GroupTypeString, Node, User, Log, Comment +from aiida.orm import Computer, Group, Node, User, Log, Comment __all__ = ('EXPORT_VERSION',) # Current export version EXPORT_VERSION = '0.8' -IMPORTGROUP_TYPE = GroupTypeString.IMPORTGROUP_TYPE.value DUPL_SUFFIX = ' (Imported #{})' # The name of the subfolder in which the node files are stored diff --git a/aiida/tools/importexport/dbimport/backends/django/__init__.py b/aiida/tools/importexport/dbimport/backends/django/__init__.py index d97ad70d1d..aa463f5ffb 100644 --- a/aiida/tools/importexport/dbimport/backends/django/__init__.py +++ b/aiida/tools/importexport/dbimport/backends/django/__init__.py @@ -21,10 +21,10 @@ from aiida.common.links import LinkType, validate_link_label from aiida.common.utils import grouper, get_object_from_string from aiida.orm.utils.repository import Repository -from aiida.orm import QueryBuilder, Node, Group +from aiida.orm import QueryBuilder, Node, Group, ImportGroup from aiida.tools.importexport.common import exceptions from aiida.tools.importexport.common.archive import extract_tree, extract_tar, extract_zip -from aiida.tools.importexport.common.config import DUPL_SUFFIX, IMPORTGROUP_TYPE, EXPORT_VERSION, NODES_EXPORT_SUBFOLDER +from aiida.tools.importexport.common.config import DUPL_SUFFIX, EXPORT_VERSION, NODES_EXPORT_SUBFOLDER from aiida.tools.importexport.common.config import ( NODE_ENTITY_NAME, GROUP_ENTITY_NAME, COMPUTER_ENTITY_NAME, USER_ENTITY_NAME, LOG_ENTITY_NAME, COMMENT_ENTITY_NAME ) @@ -673,7 +673,7 @@ def import_data_dj( "Overflow of import groups (more than 100 import groups exists with basename '{}')" ''.format(basename) ) - group = Group(label=group_label, type_string=IMPORTGROUP_TYPE).store() + group = ImportGroup(label=group_label).store() # Add all the nodes to the new group # TODO: decide if we want to return the group label diff --git a/aiida/tools/importexport/dbimport/backends/sqla/__init__.py b/aiida/tools/importexport/dbimport/backends/sqla/__init__.py index f08de125ec..2e800b1361 100644 --- a/aiida/tools/importexport/dbimport/backends/sqla/__init__.py +++ b/aiida/tools/importexport/dbimport/backends/sqla/__init__.py @@ -20,13 +20,13 @@ from aiida.common.folders import SandboxFolder, RepositoryFolder from aiida.common.links import LinkType from aiida.common.utils import get_object_from_string -from aiida.orm import QueryBuilder, Node, Group, WorkflowNode, CalculationNode, Data +from aiida.orm import QueryBuilder, Node, Group, ImportGroup from aiida.orm.utils.links import link_triple_exists, validate_link from aiida.orm.utils.repository import Repository from aiida.tools.importexport.common import exceptions from aiida.tools.importexport.common.archive import extract_tree, extract_tar, extract_zip -from aiida.tools.importexport.common.config import DUPL_SUFFIX, IMPORTGROUP_TYPE, EXPORT_VERSION, NODES_EXPORT_SUBFOLDER +from aiida.tools.importexport.common.config import DUPL_SUFFIX, EXPORT_VERSION, NODES_EXPORT_SUBFOLDER from aiida.tools.importexport.common.config import ( NODE_ENTITY_NAME, GROUP_ENTITY_NAME, COMPUTER_ENTITY_NAME, USER_ENTITY_NAME, LOG_ENTITY_NAME, COMMENT_ENTITY_NAME ) @@ -664,7 +664,7 @@ def import_data_sqla( "Overflow of import groups (more than 100 import groups exists with basename '{}')" ''.format(basename) ) - group = Group(label=group_label, type_string=IMPORTGROUP_TYPE) + group = ImportGroup(label=group_label) session.add(group.backend_entity._dbmodel) # Adding nodes to group avoiding the SQLA ORM to increase speed diff --git a/docs/source/working_with_aiida/groups.rst b/docs/source/working_with_aiida/groups.rst index 58eadcc024..55bd82ebd5 100644 --- a/docs/source/working_with_aiida/groups.rst +++ b/docs/source/working_with_aiida/groups.rst @@ -18,144 +18,162 @@ be performed with groups: Create a new Group ------------------ - From the command line interface:: +From the command line interface:: - verdi group create test_group + verdi group create test_group - From the python interface:: +From the python interface:: - In [1]: group = Group(label="test_group") - - In [2]: group.store() - Out[2]: + In [1]: group = Group(label="test_group") + In [2]: group.store() + Out[2]: List available Groups --------------------- - Example:: +Example:: - verdi group list + verdi group list - By default ``verdi group list`` only shows groups of the type *user*. - In case you want to show groups of another type use ``-t/--type`` option. If - you want to show groups of all types, use the ``-a/--all-types`` option. +By default ``verdi group list`` only shows groups of the type *user*. +In case you want to show groups of another type use ``-t/--type`` option. If +you want to show groups of all types, use the ``-a/--all-types`` option. - From the command line interface:: +From the command line interface:: - verdi group list -t user + verdi group list -t user - From the python interface:: +From the python interface:: - In [1]: query = QueryBuilder() + In [1]: query = QueryBuilder() - In [2]: query.append(Group, filters={'type_string':'user'}) - Out[2]: + In [2]: query.append(Group, filters={'type_string':'user'}) + Out[2]: - In [3]: query.all() - Out[3]: - [[], - [], - []] + In [3]: query.all() + Out[3]: + [[], + [], + []] Add nodes to a Group -------------------- - Once the ``test_group`` has been created, we can add nodes to it. To add the node with ``pk=1`` to the group we need to do the following. - - From the command line interface:: - - verdi group add-nodes -G test_group 1 - Do you really want to add 1 nodes to Group? [y/N]: y - - From the python interface:: +Once the ``test_group`` has been created, we can add nodes to it. To add the node with ``pk=1`` to the group we need to do the following. - In [1]: group = Group.get(label='test_group') +From the command line interface:: - In [2]: from aiida.orm import Dict + verdi group add-nodes -G test_group 1 + Do you really want to add 1 nodes to Group? [y/N]: y - In [3]: p = Dict().store() +From the python interface:: - In [4]: p - Out[4]: + In [1]: group = Group.get(label='test_group') + In [2]: from aiida.orm import Dict + In [3]: p = Dict().store() + In [4]: p + Out[4]: + In [5]: group.add_nodes(p) - In [5]: group.add_nodes(p) Show information about a Group ------------------------------ - From the command line interface:: - - verdi group show test_group - ----------------- ---------------- - Group label test_group - Group type_string user - Group description - ----------------- ---------------- - # Nodes: - PK Type Created - ---- ------ --------------- - 1 Code 26D:21h:45m ago +From the command line interface:: + verdi group show test_group + ----------------- ---------------- + Group label test_group + Group type_string user + Group description + ----------------- ---------------- + # Nodes: + PK Type Created + ---- ------ --------------- + 1 Code 26D:21h:45m ago Remove nodes from a Group ------------------------- - From the command line interface:: +From the command line interface:: - verdi group remove-nodes -G test_group 1 - Do you really want to remove 1 nodes from Group? [y/N]: y + verdi group remove-nodes -G test_group 1 + Do you really want to remove 1 nodes from Group? [y/N]: y - From the python interface:: +From the python interface:: - In [1]: group = Group.get(label='test_group') + In [1]: group = Group.get(label='test_group') + In [2]: group.clear() - In [2]: group.clear() Rename Group ------------ - From the command line interface:: +From the command line interface:: verdi group relabel test_group old_group Success: Label changed to old_group - From the python interface:: +From the python interface:: - In [1]: group = Group.get(label='old_group') + In [1]: group = Group.get(label='old_group') + In [2]: group.label = "another_group" - In [2]: group.label = "another_group" Delete Group ------------ - From the command line interface:: +From the command line interface:: verdi group delete another_group Are you sure to delete Group? [y/N]: y Success: Group deleted. - Copy one group into another --------------------------- - This operation will copy the nodes of the source group into the destination - group. Moreover, if the destination group did not exist before, it will - be created automatically. +This operation will copy the nodes of the source group into the destination +group. Moreover, if the destination group did not exist before, it will +be created automatically. + +From the command line interface:: + + verdi group copy source_group dest_group + Success: Nodes copied from group to group + +From the python interface:: + + In [1]: src_group = Group.objects.get(label='source_group') + In [2]: dest_group = Group(label='destination_group').store() + In [3]: dest_group.add_nodes(list(src_group.nodes)) + + +Create a `Group` subclass +------------------------- +It is possible to create a subclass of `Group` to implement custom functionality. +To make the instances of the subclass storable and loadable, it has to be registered through an entry point in the ``aiida.groups`` entry point category. +For example, assuming we have a subclass ``SubClassGroup`` in the module ``aiida_plugin.groups.sub_class:SubClassGroup``, to register it, one has to add the following to the ``setup.py`` of the plugin package:: - From the command line interface:: + "entry_points": { + "aiida.groups": [ + "plugin.sub_class = aiida_plugin.groups.sub_class:SubClassGroup" + ] + } - verdi group copy source_group dest_group - Success: Nodes copied from group to group +Now that the subclass is properly registered, instances can be stored:: - From the python interface:: + group = SubClassGroup(label='sub-class-group') + group.store() - In [1]: src_group = Group.objects.get(label='source_group') +The ``type_string`` of the group instance corresponds to the entry point name and so in this example is ``plugin.sub_class``. +This is what AiiDA uses to load the correct class when reloading the group from the database:: - In [2]: dest_group = Group(label='destination_group').store() + group = load_group(group.pk) + assert isinstance(group, SubClassGroup) - In [3]: dest_group.add_nodes(list(src_group.nodes)) +If the entry point is not currently registered, because the corresponding plugin package is not installed for example, AiiDA will issue a warning and fallback onto the ``Group`` base class. diff --git a/setup.json b/setup.json index ebe8360174..7f4b7783ca 100644 --- a/setup.json +++ b/setup.json @@ -159,6 +159,12 @@ "structure = aiida.orm.nodes.data.structure:StructureData", "upf = aiida.orm.nodes.data.upf:UpfData" ], + "aiida.groups": [ + "core = aiida.orm.groups:Group", + "core.auto = aiida.orm.groups:AutoGroup", + "core.import = aiida.orm.groups:ImportGroup", + "core.upf = aiida.orm.groups:UpfFamily" + ], "aiida.node": [ "data = aiida.orm.nodes.data.data:Data", "process = aiida.orm.nodes.process.process:ProcessNode", diff --git a/tests/backends/aiida_django/migrations/test_migrations_0044_dbgroup_type_string.py b/tests/backends/aiida_django/migrations/test_migrations_0044_dbgroup_type_string.py new file mode 100644 index 0000000000..ab1b31d518 --- /dev/null +++ b/tests/backends/aiida_django/migrations/test_migrations_0044_dbgroup_type_string.py @@ -0,0 +1,63 @@ +# -*- coding: utf-8 -*- +########################################################################### +# Copyright (c), The AiiDA team. All rights reserved. # +# This file is part of the AiiDA code. # +# # +# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # +# For further information on the license, see the LICENSE.txt file # +# For further information please visit http://www.aiida.net # +########################################################################### +# pylint: disable=import-error,no-name-in-module,invalid-name +"""Test migration of `type_string` after the `Group` class became pluginnable.""" + +from .test_migrations_common import TestMigrations + + +class TestGroupTypeStringMigration(TestMigrations): + """Test migration of `type_string` after the `Group` class became pluginnable.""" + + migrate_from = '0043_default_link_label' + migrate_to = '0044_dbgroup_type_string' + + def setUpBeforeMigration(self): + DbGroup = self.apps.get_model('db', 'DbGroup') + + # test user group type_string: 'user' -> 'core' + group_user = DbGroup(label='01', user_id=self.default_user.id, type_string='user') + group_user.save() + self.group_user_pk = group_user.pk + + # test data.upf group type_string: 'data.upf' -> 'core.upf' + group_data_upf = DbGroup(label='02', user_id=self.default_user.id, type_string='data.upf') + group_data_upf.save() + self.group_data_upf_pk = group_data_upf.pk + + # test auto.import group type_string: 'auto.import' -> 'core.import' + group_autoimport = DbGroup(label='03', user_id=self.default_user.id, type_string='auto.import') + group_autoimport.save() + self.group_autoimport_pk = group_autoimport.pk + + # test auto.run group type_string: 'auto.run' -> 'core.auto' + group_autorun = DbGroup(label='04', user_id=self.default_user.id, type_string='auto.run') + group_autorun.save() + self.group_autorun_pk = group_autorun.pk + + def test_group_string_update(self): + """Test that the type_string were updated correctly.""" + DbGroup = self.apps.get_model('db', 'DbGroup') + + # 'user' -> 'core' + group_user = DbGroup.objects.get(pk=self.group_user_pk) + self.assertEqual(group_user.type_string, 'core') + + # 'data.upf' -> 'core.upf' + group_data_upf = DbGroup.objects.get(pk=self.group_data_upf_pk) + self.assertEqual(group_data_upf.type_string, 'core.upf') + + # 'auto.import' -> 'core.import' + group_autoimport = DbGroup.objects.get(pk=self.group_autoimport_pk) + self.assertEqual(group_autoimport.type_string, 'core.import') + + # 'auto.run' -> 'core.auto' + group_autorun = DbGroup.objects.get(pk=self.group_autorun_pk) + self.assertEqual(group_autorun.type_string, 'core.auto') diff --git a/tests/backends/aiida_sqlalchemy/test_migrations.py b/tests/backends/aiida_sqlalchemy/test_migrations.py index 8e2046f293..2bb52ceecc 100644 --- a/tests/backends/aiida_sqlalchemy/test_migrations.py +++ b/tests/backends/aiida_sqlalchemy/test_migrations.py @@ -1642,3 +1642,69 @@ def test_data_migrated(self): finally: session.close() + + +class TestGroupTypeStringMigration(TestMigrationsSQLA): + """Test the migration that renames the DbGroup type strings.""" + + migrate_from = '118349c10896' # 118349c10896_default_link_label.py + migrate_to = 'bf591f31dd12' # bf591f31dd12_dbgroup_type_string.py + + def setUpBeforeMigration(self): + """Create the DbGroups with the old type strings.""" + DbGroup = self.get_current_table('db_dbgroup') # pylint: disable=invalid-name + DbUser = self.get_current_table('db_dbuser') # pylint: disable=invalid-name + + with self.get_session() as session: + try: + default_user = DbUser(email='{}@aiida.net'.format(self.id())) + session.add(default_user) + session.commit() + + # test user group type_string: 'user' -> 'core' + group_user = DbGroup(label='01', user_id=default_user.id, type_string='user') + session.add(group_user) + # test data.upf group type_string: 'data.upf' -> 'core.upf' + group_data_upf = DbGroup(label='02', user_id=default_user.id, type_string='data.upf') + session.add(group_data_upf) + # test auto.import group type_string: 'auto.import' -> 'core.import' + group_autoimport = DbGroup(label='03', user_id=default_user.id, type_string='auto.import') + session.add(group_autoimport) + # test auto.run group type_string: 'auto.run' -> 'core.auto' + group_autorun = DbGroup(label='04', user_id=default_user.id, type_string='auto.run') + session.add(group_autorun) + + session.commit() + + # Store values for later tests + self.group_user_pk = group_user.id + self.group_data_upf_pk = group_data_upf.id + self.group_autoimport_pk = group_autoimport.id + self.group_autorun_pk = group_autorun.id + + finally: + session.close() + + def test_group_string_update(self): + """Test that the type strings are properly migrated.""" + DbGroup = self.get_current_table('db_dbgroup') # pylint: disable=invalid-name + + with self.get_session() as session: + try: + # test user group type_string: 'user' -> 'core' + group_user = session.query(DbGroup).filter(DbGroup.id == self.group_user_pk).one() + self.assertEqual(group_user.type_string, 'core') + + # test data.upf group type_string: 'data.upf' -> 'core.upf' + group_data_upf = session.query(DbGroup).filter(DbGroup.id == self.group_data_upf_pk).one() + self.assertEqual(group_data_upf.type_string, 'core.upf') + + # test auto.import group type_string: 'auto.import' -> 'core.import' + group_autoimport = session.query(DbGroup).filter(DbGroup.id == self.group_autoimport_pk).one() + self.assertEqual(group_autoimport.type_string, 'core.import') + + # test auto.run group type_string: 'auto.run' -> 'core.auto' + group_autorun = session.query(DbGroup).filter(DbGroup.id == self.group_autorun_pk).one() + self.assertEqual(group_autorun.type_string, 'core.auto') + finally: + session.close() diff --git a/tests/cmdline/commands/test_group.py b/tests/cmdline/commands/test_group.py index 4302420633..ab79f650b1 100644 --- a/tests/cmdline/commands/test_group.py +++ b/tests/cmdline/commands/test_group.py @@ -165,7 +165,7 @@ def test_show(self): self.assertClickResultNoException(result) for grpline in [ - 'Group label', 'dummygroup1', 'Group type_string', 'user', 'Group description', '' + 'Group label', 'dummygroup1', 'Group type_string', 'core', 'Group description', '' ]: self.assertIn(grpline, result.output) diff --git a/tests/cmdline/commands/test_group_ls.py b/tests/cmdline/commands/test_group_ls.py index 7cc01079a4..1a8bb7d2ef 100644 --- a/tests/cmdline/commands/test_group_ls.py +++ b/tests/cmdline/commands/test_group_ls.py @@ -22,12 +22,13 @@ def setup_groups(clear_database_before_test): """Setup some groups for testing.""" for label in ['a', 'a/b', 'a/c/d', 'a/c/e/g', 'a/f']: - group, _ = orm.Group.objects.get_or_create(label, type_string=orm.GroupTypeString.USER.value) + group, _ = orm.Group.objects.get_or_create(label) group.description = 'A description of {}'.format(label) - orm.Group.objects.get_or_create('a/x', type_string=orm.GroupTypeString.UPFGROUP_TYPE.value) + orm.UpfFamily.objects.get_or_create('a/x') yield +@pytest.mark.skip('Reenable when subclassing in the query builder is implemented (#3902)') def test_with_no_opts(setup_groups): """Test ``verdi group path ls``""" @@ -46,6 +47,7 @@ def test_with_no_opts(setup_groups): assert result.output == 'a/c/d\na/c/e\n' +@pytest.mark.skip('Reenable when subclassing in the query builder is implemented (#3902)') def test_recursive(setup_groups): """Test ``verdi group path ls --recursive``""" @@ -61,6 +63,7 @@ def test_recursive(setup_groups): assert result.output == 'a/c/d\na/c/e\na/c/e/g\n' +@pytest.mark.skip('Reenable when subclassing in the query builder is implemented (#3902)') @pytest.mark.parametrize('tag', ['-l', '--long']) def test_long(setup_groups, tag): """Test ``verdi group path ls --long``""" @@ -106,6 +109,7 @@ def test_long(setup_groups, tag): ) +@pytest.mark.skip('Reenable when subclassing in the query builder is implemented (#3902)') @pytest.mark.parametrize('tag', ['--no-virtual']) def test_groups_only(setup_groups, tag): """Test ``verdi group path ls --no-virtual``""" diff --git a/tests/cmdline/commands/test_run.py b/tests/cmdline/commands/test_run.py index 78c858420f..4ed690bb20 100644 --- a/tests/cmdline/commands/test_run.py +++ b/tests/cmdline/commands/test_run.py @@ -9,6 +9,7 @@ ########################################################################### """Tests for `verdi run`.""" import tempfile +import textwrap import warnings from click.testing import CliRunner @@ -31,21 +32,22 @@ def test_run_workfunction(self): that are defined within the script will fail, as the inspect module will not correctly be able to determin the full path of the source file. """ - from aiida.orm import load_node - from aiida.orm import WorkFunctionNode + from aiida.orm import load_node, WorkFunctionNode - script_content = """ -#!/usr/bin/env python -from aiida.engine import workfunction + script_content = textwrap.dedent( + """\ + #!/usr/bin/env python + from aiida.engine import workfunction -@workfunction -def wf(): - pass + @workfunction + def wf(): + pass -if __name__ == '__main__': - result, node = wf.run_get_node() - print(node.pk) - """ + if __name__ == '__main__': + result, node = wf.run_get_node() + print(node.pk) + """ + ) # If `verdi run` is not setup correctly, the script above when run with `verdi run` will fail, because when # the engine will try to create the node for the workfunction and create a copy of its sourcefile, namely the @@ -77,9 +79,8 @@ def setUp(self): super().setUp() self.cli_runner = CliRunner() - # I need to disable the global variable of this test environment, - # because invoke is just calling the function and therefore inheriting - # the global variable + # I need to disable the global variable of this test environment, because invoke is just calling the function + # and therefore inheriting the global variable self._old_autogroup = autogroup.CURRENT_AUTOGROUP autogroup.CURRENT_AUTOGROUP = None @@ -92,12 +93,15 @@ def tearDown(self): def test_autogroup(self): """Check if the autogroup is properly generated.""" - from aiida.orm import QueryBuilder, Node, Group, load_node + from aiida.orm import QueryBuilder, Node, AutoGroup, load_node - script_content = """from aiida.orm import Data -node = Data().store() -print(node.pk) -""" + script_content = textwrap.dedent( + """\ + from aiida.orm import Data + node = Data().store() + print(node.pk) + """ + ) with tempfile.NamedTemporaryFile(mode='w+') as fhandle: fhandle.write(script_content) @@ -111,7 +115,7 @@ def test_autogroup(self): _ = load_node(pk) # Check if the node can be loaded queryb = QueryBuilder().append(Node, filters={'id': pk}, tag='node') - queryb.append(Group, with_node='node', filters={'type_string': 'auto.run'}, project='*') + queryb.append(AutoGroup, with_node='node', project='*') all_auto_groups = queryb.all() self.assertEqual( len(all_auto_groups), 1, 'There should be only one autogroup associated with the node just created' @@ -119,12 +123,16 @@ def test_autogroup(self): def test_autogroup_custom_label(self): """Check if the autogroup is properly generated with the label specified.""" - from aiida.orm import QueryBuilder, Node, Group, load_node + from aiida.orm import QueryBuilder, Node, AutoGroup, load_node + + script_content = textwrap.dedent( + """\ + from aiida.orm import Data + node = Data().store() + print(node.pk) + """ + ) - script_content = """from aiida.orm import Data -node = Data().store() -print(node.pk) -""" autogroup_label = 'SOME_group_LABEL' with tempfile.NamedTemporaryFile(mode='w+') as fhandle: fhandle.write(script_content) @@ -138,7 +146,7 @@ def test_autogroup_custom_label(self): _ = load_node(pk) # Check if the node can be loaded queryb = QueryBuilder().append(Node, filters={'id': pk}, tag='node') - queryb.append(Group, with_node='node', filters={'type_string': 'auto.run'}, project='*') + queryb.append(AutoGroup, with_node='node', project='*') all_auto_groups = queryb.all() self.assertEqual( len(all_auto_groups), 1, 'There should be only one autogroup associated with the node just created' @@ -147,12 +155,15 @@ def test_autogroup_custom_label(self): def test_no_autogroup(self): """Check if the autogroup is not generated if ``verdi run`` is asked not to.""" - from aiida.orm import QueryBuilder, Node, Group, load_node + from aiida.orm import QueryBuilder, Node, AutoGroup, load_node - script_content = """from aiida.orm import Data -node = Data().store() -print(node.pk) -""" + script_content = textwrap.dedent( + """\ + from aiida.orm import Data + node = Data().store() + print(node.pk) + """ + ) with tempfile.NamedTemporaryFile(mode='w+') as fhandle: fhandle.write(script_content) @@ -166,61 +177,64 @@ def test_no_autogroup(self): _ = load_node(pk) # Check if the node can be loaded queryb = QueryBuilder().append(Node, filters={'id': pk}, tag='node') - queryb.append(Group, with_node='node', filters={'type_string': 'auto.run'}, project='*') + queryb.append(AutoGroup, with_node='node', project='*') all_auto_groups = queryb.all() self.assertEqual(len(all_auto_groups), 0, 'There should be no autogroup generated') def test_autogroup_filter_class(self): # pylint: disable=too-many-locals """Check if the autogroup is properly generated but filtered classes are skipped.""" - from aiida.orm import QueryBuilder, Node, Group, load_node - - script_content = """import sys -from aiida.orm import Computer, Int, ArrayData, KpointsData, CalculationNode, WorkflowNode -from aiida.plugins import CalculationFactory -from aiida.engine import run_get_node -ArithmeticAdd = CalculationFactory('arithmetic.add') - -computer = Computer( - name='localhost-example-{}'.format(sys.argv[1]), - hostname='localhost', - description='my computer', - transport_type='local', - scheduler_type='direct', - workdir='/tmp' -).store() -computer.configure() - -code = Code( - input_plugin_name='arithmetic.add', - remote_computer_exec=[computer, '/bin/true']).store() -inputs = { - 'x': Int(1), - 'y': Int(2), - 'code': code, - 'metadata': { - 'options': { - 'resources': { - 'num_machines': 1, - 'num_mpiprocs_per_machine': 1 + from aiida.orm import Code, QueryBuilder, Node, AutoGroup, load_node + + script_content = textwrap.dedent( + """\ + import sys + from aiida.orm import Computer, Int, ArrayData, KpointsData, CalculationNode, WorkflowNode + from aiida.plugins import CalculationFactory + from aiida.engine import run_get_node + ArithmeticAdd = CalculationFactory('arithmetic.add') + + computer = Computer( + name='localhost-example-{}'.format(sys.argv[1]), + hostname='localhost', + description='my computer', + transport_type='local', + scheduler_type='direct', + workdir='/tmp' + ).store() + computer.configure() + + code = Code( + input_plugin_name='arithmetic.add', + remote_computer_exec=[computer, '/bin/true']).store() + inputs = { + 'x': Int(1), + 'y': Int(2), + 'code': code, + 'metadata': { + 'options': { + 'resources': { + 'num_machines': 1, + 'num_mpiprocs_per_machine': 1 + } + } + } } - } - } -} - -node1 = KpointsData().store() -node2 = ArrayData().store() -node3 = Int(3).store() -node4 = CalculationNode().store() -node5 = WorkflowNode().store() -_, node6 = run_get_node(ArithmeticAdd, **inputs) -print(node1.pk) -print(node2.pk) -print(node3.pk) -print(node4.pk) -print(node5.pk) -print(node6.pk) -""" - from aiida.orm import Code + + node1 = KpointsData().store() + node2 = ArrayData().store() + node3 = Int(3).store() + node4 = CalculationNode().store() + node5 = WorkflowNode().store() + _, node6 = run_get_node(ArithmeticAdd, **inputs) + print(node1.pk) + print(node2.pk) + print(node3.pk) + print(node4.pk) + print(node5.pk) + print(node6.pk) + """ + ) + Code() for idx, ( flags, @@ -283,27 +297,27 @@ def test_autogroup_filter_class(self): # pylint: disable=too-many-locals _ = load_node(pk6) # Check if the node can be loaded queryb = QueryBuilder().append(Node, filters={'id': pk1}, tag='node') - queryb.append(Group, with_node='node', filters={'type_string': 'auto.run'}, project='*') + queryb.append(AutoGroup, with_node='node', project='*') all_auto_groups_kptdata = queryb.all() queryb = QueryBuilder().append(Node, filters={'id': pk2}, tag='node') - queryb.append(Group, with_node='node', filters={'type_string': 'auto.run'}, project='*') + queryb.append(AutoGroup, with_node='node', project='*') all_auto_groups_arraydata = queryb.all() queryb = QueryBuilder().append(Node, filters={'id': pk3}, tag='node') - queryb.append(Group, with_node='node', filters={'type_string': 'auto.run'}, project='*') + queryb.append(AutoGroup, with_node='node', project='*') all_auto_groups_int = queryb.all() queryb = QueryBuilder().append(Node, filters={'id': pk4}, tag='node') - queryb.append(Group, with_node='node', filters={'type_string': 'auto.run'}, project='*') + queryb.append(AutoGroup, with_node='node', project='*') all_auto_groups_calc = queryb.all() queryb = QueryBuilder().append(Node, filters={'id': pk5}, tag='node') - queryb.append(Group, with_node='node', filters={'type_string': 'auto.run'}, project='*') + queryb.append(AutoGroup, with_node='node', project='*') all_auto_groups_wf = queryb.all() queryb = QueryBuilder().append(Node, filters={'id': pk6}, tag='node') - queryb.append(Group, with_node='node', filters={'type_string': 'auto.run'}, project='*') + queryb.append(AutoGroup, with_node='node', project='*') all_auto_groups_calcarithmetic = queryb.all() self.assertEqual( @@ -339,12 +353,16 @@ def test_autogroup_filter_class(self): # pylint: disable=too-many-locals def test_autogroup_clashing_label(self): """Check if the autogroup label is properly (re)generated when it clashes with an existing group name.""" - from aiida.orm import QueryBuilder, Node, Group, load_node + from aiida.orm import QueryBuilder, Node, AutoGroup, load_node + + script_content = textwrap.dedent( + """\ + from aiida.orm import Data + node = Data().store() + print(node.pk) + """ + ) - script_content = """from aiida.orm import Data -node = Data().store() -print(node.pk) -""" autogroup_label = 'SOME_repeated_group_LABEL' with tempfile.NamedTemporaryFile(mode='w+') as fhandle: fhandle.write(script_content) @@ -358,7 +376,7 @@ def test_autogroup_clashing_label(self): pk = int(result.output) _ = load_node(pk) # Check if the node can be loaded queryb = QueryBuilder().append(Node, filters={'id': pk}, tag='node') - queryb.append(Group, with_node='node', filters={'type_string': 'auto.run'}, project='*') + queryb.append(AutoGroup, with_node='node', project='*') all_auto_groups = queryb.all() self.assertEqual( len(all_auto_groups), 1, 'There should be only one autogroup associated with the node just created' @@ -374,7 +392,7 @@ def test_autogroup_clashing_label(self): pk = int(result.output) _ = load_node(pk) # Check if the node can be loaded queryb = QueryBuilder().append(Node, filters={'id': pk}, tag='node') - queryb.append(Group, with_node='node', filters={'type_string': 'auto.run'}, project='*') + queryb.append(AutoGroup, with_node='node', project='*') all_auto_groups = queryb.all() self.assertEqual( len(all_auto_groups), 1, 'There should be only one autogroup associated with the node just created' @@ -383,12 +401,15 @@ def test_autogroup_clashing_label(self): def test_legacy_autogroup_name(self): """Check if the autogroup is properly generated when using the legacy --group-name flag.""" - from aiida.orm import QueryBuilder, Node, Group, load_node - - script_content = """from aiida.orm import Data -node = Data().store() -print(node.pk) -""" + from aiida.orm import QueryBuilder, Node, AutoGroup, load_node + + script_content = textwrap.dedent( + """\ + from aiida.orm import Data + node = Data().store() + print(node.pk) + """ + ) group_label = 'legacy-group-name' with tempfile.NamedTemporaryFile(mode='w+') as fhandle: @@ -409,7 +430,7 @@ def test_legacy_autogroup_name(self): _ = load_node(pk) # Check if the node can be loaded queryb = QueryBuilder().append(Node, filters={'id': pk}, tag='node') - queryb.append(Group, with_node='node', filters={'type_string': 'auto.run'}, project='*') + queryb.append(AutoGroup, with_node='node', project='*') all_auto_groups = queryb.all() self.assertEqual( len(all_auto_groups), 1, 'There should be only one autogroup associated with the node just created' diff --git a/tests/orm/data/test_upf.py b/tests/orm/data/test_upf.py index 228f8d9b77..02922bc60f 100644 --- a/tests/orm/data/test_upf.py +++ b/tests/orm/data/test_upf.py @@ -10,7 +10,6 @@ """ This module contains tests for UpfData and UpfData related functions. """ - import errno import tempfile import shutil @@ -95,8 +94,8 @@ def setUp(self): def tearDown(self): """Delete all groups and destroy the temporary directory created.""" - for group in orm.Group.objects.find(filters={'type_string': orm.GroupTypeString.UPFGROUP_TYPE.value}): - orm.Group.objects.delete(group.pk) + for group in orm.UpfFamily.objects.find(): + orm.UpfFamily.objects.delete(group.pk) try: shutil.rmtree(self.temp_dir) @@ -122,32 +121,31 @@ def test_get_upf_family_names(self): """Test the `UpfData.get_upf_family_names` method.""" label = 'family' - family, _ = orm.Group.objects.get_or_create(label=label, type_string=orm.GroupTypeString.UPFGROUP_TYPE.value) + family, _ = orm.UpfFamily.objects.get_or_create(label=label) family.add_nodes([self.pseudo_barium]) family.store() - self.assertEqual({group.label for group in orm.UpfData.get_upf_groups()}, {label}) + self.assertEqual({group.label for group in orm.UpfFamily.objects.all()}, {label}) self.assertEqual(self.pseudo_barium.get_upf_family_names(), [label]) def test_get_upf_groups(self): """Test the `UpfData.get_upf_groups` class method.""" - type_string = orm.GroupTypeString.UPFGROUP_TYPE.value label_01 = 'family_01' label_02 = 'family_02' user = orm.User(email='alternate@localhost').store() - self.assertEqual(orm.UpfData.get_upf_groups(), []) + self.assertEqual(orm.UpfFamily.objects.all(), []) # Create group with default user and add `Ba` pseudo - family_01, _ = orm.Group.objects.get_or_create(label=label_01, type_string=type_string) + family_01, _ = orm.UpfFamily.objects.get_or_create(label=label_01) family_01.add_nodes([self.pseudo_barium]) family_01.store() self.assertEqual({group.label for group in orm.UpfData.get_upf_groups()}, {label_01}) # Create group with different user and add `O` pseudo - family_02, _ = orm.Group.objects.get_or_create(label=label_02, type_string=type_string, user=user) + family_02, _ = orm.UpfFamily.objects.get_or_create(label=label_02, user=user) family_02.add_nodes([self.pseudo_oxygen]) family_02.store() diff --git a/tests/orm/test_groups.py b/tests/orm/test_groups.py index ce2797daad..67b189195c 100644 --- a/tests/orm/test_groups.py +++ b/tests/orm/test_groups.py @@ -8,6 +8,8 @@ # For further information please visit http://www.aiida.net # ########################################################################### """Test for the Group ORM class.""" +import pytest + from aiida import orm from aiida.backends.testbase import AiidaTestCase from aiida.common import exceptions @@ -272,3 +274,54 @@ def test_group_uuid_hashing_for_querybuidler(self): # And that the results are correct self.assertEqual(builder.count(), 1) self.assertEqual(builder.first()[0], group.id) + + +class TestGroupsSubclasses(AiidaTestCase): + """Test rules around creating `Group` subclasses.""" + + @staticmethod + def test_creation_registered(): + """Test rules around creating registered `Group` subclasses.""" + group = orm.AutoGroup('some-label') + assert isinstance(group, orm.AutoGroup) + assert group.type_string == 'core.auto' + + group, _ = orm.AutoGroup.objects.get_or_create('some-auto-group') + assert isinstance(group, orm.AutoGroup) + assert group.type_string == 'core.auto' + + @staticmethod + def test_creation_unregistered(): + """Test rules around creating `Group` subclasses without a registered entry point.""" + + # Defining an unregistered subclas should issue a warning and its type string should be set to `None` + with pytest.warns(UserWarning): + + class SubGroup(orm.Group): + pass + + assert SubGroup._type_string is None # pylint: disable=protected-access + + # Creating an instance is allowed + instance = SubGroup(label='subgroup') + assert instance._type_string is None # pylint: disable=protected-access + + # Storing the instance, however, is forbidden and should raise + with pytest.raises(exceptions.StoringNotAllowed): + instance.store() + + @staticmethod + def test_loading_unregistered(): + """Test rules around loading `Group` subclasses without a registered entry point. + + Storing instances of unregistered subclasses is not allowed so we have to create one sneakily by instantiating + a normal group and manipulating the type string directly on the database model. + """ + group = orm.Group(label='group') + group.backend_entity.dbmodel.type_string = 'unregistered.subclass' + group.store() + + with pytest.warns(UserWarning): + loaded = orm.load_group(group.pk) + + assert isinstance(loaded, orm.Group) diff --git a/tests/tools/graph/test_age.py b/tests/tools/graph/test_age.py index dddf2323c2..538087c7d7 100644 --- a/tests/tools/graph/test_age.py +++ b/tests/tools/graph/test_age.py @@ -494,7 +494,7 @@ def test_groups(self): # Rule that only gets nodes connected by the same group queryb = orm.QueryBuilder() queryb.append(orm.Node, tag='nodes_in_set') - queryb.append(orm.Group, with_node='nodes_in_set', tag='groups_considered', filters={'type_string': 'user'}) + queryb.append(orm.Group, with_node='nodes_in_set', tag='groups_considered') queryb.append(orm.Data, with_group='groups_considered') initial_node = [node2.id] @@ -513,7 +513,7 @@ def test_groups(self): # But two rules chained should get both nodes and groups... queryb = orm.QueryBuilder() queryb.append(orm.Node, tag='nodes_in_set') - queryb.append(orm.Group, with_node='nodes_in_set', filters={'type_string': 'user'}) + queryb.append(orm.Group, with_node='nodes_in_set') rule1 = UpdateRule(queryb) queryb = orm.QueryBuilder() @@ -569,7 +569,7 @@ def test_groups(self): qb1 = orm.QueryBuilder() qb1.append(orm.Node, tag='nodes_in_set') - qb1.append(orm.Group, with_node='nodes_in_set', filters={'type_string': 'user'}) + qb1.append(orm.Group, with_node='nodes_in_set') rule1 = UpdateRule(qb1, track_edges=True) qb2 = orm.QueryBuilder() diff --git a/tests/tools/groups/test_paths.py b/tests/tools/groups/test_paths.py index a6f1cdb757..b6e8940ce5 100644 --- a/tests/tools/groups/test_paths.py +++ b/tests/tools/groups/test_paths.py @@ -7,19 +7,19 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### -"""Tests for GroupPath""" # pylint: disable=redefined-outer-name,unused-argument +"""Tests for GroupPath""" import pytest from aiida import orm -from aiida.tools.groups.paths import (GroupAttr, GroupPath, InvalidPath, GroupNotFoundError, NoGroupsInPathError) +from aiida.tools.groups.paths import GroupAttr, GroupPath, InvalidPath, GroupNotFoundError, NoGroupsInPathError @pytest.fixture def setup_groups(clear_database_before_test): """Setup some groups for testing.""" for label in ['a', 'a/b', 'a/c/d', 'a/c/e/g', 'a/f']: - group, _ = orm.Group.objects.get_or_create(label, type_string=orm.GroupTypeString.USER.value) + group, _ = orm.Group.objects.get_or_create(label) group.description = 'A description of {}'.format(label) yield @@ -117,16 +117,17 @@ def test_walk(setup_groups): def test_walk_with_invalid_path(clear_database_before_test): + """Test the ``GroupPath.walk`` method with invalid paths.""" for label in ['a', 'a/b', 'a/c/d', 'a/c/e/g', 'a/f', 'bad//group', 'bad/other']: - orm.Group.objects.get_or_create(label, type_string=orm.GroupTypeString.USER.value) + orm.Group.objects.get_or_create(label) group_path = GroupPath() - assert [c.path for c in sorted(group_path.walk()) - ] == ['a', 'a/b', 'a/c', 'a/c/d', 'a/c/e', 'a/c/e/g', 'a/f', 'bad', 'bad/other'] + expected = ['a', 'a/b', 'a/c', 'a/c/d', 'a/c/e', 'a/c/e/g', 'a/f', 'bad', 'bad/other'] + assert [c.path for c in sorted(group_path.walk())] == expected def test_walk_nodes(clear_database): """Test the ``GroupPath.walk_nodes()`` function.""" - group, _ = orm.Group.objects.get_or_create('a', type_string=orm.GroupTypeString.USER.value) + group, _ = orm.Group.objects.get_or_create('a') node = orm.Data() node.set_attribute_many({'i': 1, 'j': 2}) node.store() @@ -135,17 +136,18 @@ def test_walk_nodes(clear_database): assert [(r.group_path.path, r.node.attributes) for r in group_path.walk_nodes()] == [('a', {'i': 1, 'j': 2})] -def test_type_string(clear_database_before_test): - """Test that only the type_string instantiated in ``GroupPath`` is returned.""" +@pytest.mark.skip('Reenable when subclassing in the query builder is implemented (#3902)') +def test_cls(clear_database_before_test): + """Test that only instances of `cls` or its subclasses are matched by ``GroupPath``.""" for label in ['a', 'a/b', 'a/c/d', 'a/c/e/g']: - orm.Group.objects.get_or_create(label, type_string=orm.GroupTypeString.USER.value) + orm.Group.objects.get_or_create(label) for label in ['a/c/e', 'a/f']: - orm.Group.objects.get_or_create(label, type_string=orm.GroupTypeString.UPFGROUP_TYPE.value) + orm.UpfFamily.objects.get_or_create(label) group_path = GroupPath() assert sorted([c.path for c in group_path.walk()]) == ['a', 'a/b', 'a/c', 'a/c/d', 'a/c/e', 'a/c/e/g'] - group_path = GroupPath(type_string=orm.GroupTypeString.UPFGROUP_TYPE.value) + group_path = GroupPath(cls=orm.UpfFamily) assert sorted([c.path for c in group_path.walk()]) == ['a', 'a/c', 'a/c/e', 'a/f'] - assert GroupPath('a/b/c') != GroupPath('a/b/c', type_string=orm.GroupTypeString.UPFGROUP_TYPE.value) + assert GroupPath('a/b/c') != GroupPath('a/b/c', cls=orm.UpfFamily) def test_attr(clear_database_before_test): diff --git a/tests/tools/importexport/test_prov_redesign.py b/tests/tools/importexport/test_prov_redesign.py index 37f9a485a0..5ef849c51c 100644 --- a/tests/tools/importexport/test_prov_redesign.py +++ b/tests/tools/importexport/test_prov_redesign.py @@ -229,7 +229,7 @@ def test_group_name_and_type_change(self, temp_dir): groups_type_string = [g.type_string for g in [group_user, group_upf]] # Assert correct type strings exists prior to export - self.assertListEqual(groups_type_string, ['user', 'data.upf']) + self.assertListEqual(groups_type_string, ['core', 'core.upf']) # Export node filename = os.path.join(temp_dir, 'export.tar.gz') @@ -268,4 +268,4 @@ def test_group_name_and_type_change(self, temp_dir): # Check type_string content of "import group" import_group = orm.load_group(imported_groups_uuid[0]) - self.assertEqual(import_group.type_string, 'auto.import') + self.assertEqual(import_group.type_string, 'core.import')