From 9c9737685a92ab5435e39b3e45805673ded1c250 Mon Sep 17 00:00:00 2001 From: Sebastiaan Huber Date: Mon, 25 Nov 2019 23:10:02 +0100 Subject: [PATCH] Make `Group` sub classable through entry points We add the `aiida.groups` entry point group where sub classes of the `aiida.orm.groups.Group` class can be registered. A new metaclass is used to automatically set the `type_string` based on the entry point of the `Group` sub class. This will make it possible to reload the correct sub class when reloading from the database. If the `GroupMeta` metaclass cannot retrieve the corresponding entry point of the subclass, a warning is issued that any instances of this class will not be storable and the `_type_string` attribute is set to `None`. This is checked by the `store` method, which will then raise. We choose to only raise in the `store` method such that it is still possible to define and instantiate subclasses of `Group` that have not yet been registered. This is useful for testing and experimenting. Since the group type strings are now based on the entry point names, the existing group type strings in the database have to be migrated: * `user` -> `core` * `data.upf` -> `core.upf` * `auto.import` -> `core.import` * `auto.run` -> `core.auto` When loading a `Group` instance from the database, the loader will try to resolve the type string to the corresponding subclass through the entry points. If this fails, a warning is issued and we fall back on the base `Group` class. 
--- .../db/migrations/0044_dbgroup_type_string.py | 44 ++++ .../backends/djsite/db/migrations/__init__.py | 2 +- .../bf591f31dd12_dbgroup_type_string.py | 45 ++++ aiida/cmdline/commands/cmd_data/cmd_upf.py | 11 +- aiida/cmdline/commands/cmd_group.py | 28 +-- aiida/cmdline/commands/cmd_run.py | 1 + aiida/cmdline/params/types/group.py | 4 +- aiida/orm/autogroup.py | 33 +-- aiida/orm/convert.py | 5 +- aiida/orm/groups.py | 108 ++++++--- aiida/orm/implementation/groups.py | 2 +- aiida/orm/nodes/data/upf.py | 30 +-- aiida/plugins/entry_point.py | 2 + aiida/plugins/factories.py | 23 +- aiida/tools/groups/paths.py | 71 +++--- aiida/tools/importexport/common/config.py | 4 +- .../dbimport/backends/django/__init__.py | 6 +- .../dbimport/backends/sqla/__init__.py | 6 +- docs/source/working_with_aiida/groups.rst | 160 +++++++------ setup.json | 6 + ...est_migrations_0044_dbgroup_type_string.py | 63 +++++ .../aiida_sqlalchemy/test_migrations.py | 66 ++++++ tests/cmdline/commands/test_group.py | 2 +- tests/cmdline/commands/test_group_ls.py | 8 +- tests/cmdline/commands/test_run.py | 223 ++++++++++-------- tests/orm/data/test_upf.py | 16 +- tests/orm/test_groups.py | 53 +++++ tests/tools/graph/test_age.py | 6 +- tests/tools/groups/test_paths.py | 28 ++- .../tools/importexport/test_prov_redesign.py | 4 +- 30 files changed, 696 insertions(+), 364 deletions(-) create mode 100644 aiida/backends/djsite/db/migrations/0044_dbgroup_type_string.py create mode 100644 aiida/backends/sqlalchemy/migrations/versions/bf591f31dd12_dbgroup_type_string.py create mode 100644 tests/backends/aiida_django/migrations/test_migrations_0044_dbgroup_type_string.py diff --git a/aiida/backends/djsite/db/migrations/0044_dbgroup_type_string.py b/aiida/backends/djsite/db/migrations/0044_dbgroup_type_string.py new file mode 100644 index 0000000000..8c577ce397 --- /dev/null +++ b/aiida/backends/djsite/db/migrations/0044_dbgroup_type_string.py @@ -0,0 +1,44 @@ +# -*- coding: utf-8 -*- 
+########################################################################### +# Copyright (c), The AiiDA team. All rights reserved. # +# This file is part of the AiiDA code. # +# # +# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # +# For further information on the license, see the LICENSE.txt file # +# For further information please visit http://www.aiida.net # +########################################################################### +# pylint: disable=invalid-name,too-few-public-methods +"""Migration after the `Group` class became pluginnable and so the group `type_string` changed.""" + +# pylint: disable=no-name-in-module,import-error +from django.db import migrations +from aiida.backends.djsite.db.migrations import upgrade_schema_version + +REVISION = '1.0.44' +DOWN_REVISION = '1.0.43' + +forward_sql = [ + """UPDATE db_dbgroup SET type_string = 'core' WHERE type_string = 'user';""", + """UPDATE db_dbgroup SET type_string = 'core.upf' WHERE type_string = 'data.upf';""", + """UPDATE db_dbgroup SET type_string = 'core.import' WHERE type_string = 'auto.import';""", + """UPDATE db_dbgroup SET type_string = 'core.auto' WHERE type_string = 'auto.run';""", +] + +reverse_sql = [ + """UPDATE db_dbgroup SET type_string = 'user' WHERE type_string = 'core';""", + """UPDATE db_dbgroup SET type_string = 'data.upf' WHERE type_string = 'core.upf';""", + """UPDATE db_dbgroup SET type_string = 'auto.import' WHERE type_string = 'core.import';""", + """UPDATE db_dbgroup SET type_string = 'auto.run' WHERE type_string = 'core.auto';""", +] + + +class Migration(migrations.Migration): + """Migration after the update of group `type_string`""" + dependencies = [ + ('db', '0043_default_link_label'), + ] + + operations = [ + migrations.RunSQL(sql='\n'.join(forward_sql), reverse_sql='\n'.join(reverse_sql)), + upgrade_schema_version(REVISION, DOWN_REVISION), + ] diff --git a/aiida/backends/djsite/db/migrations/__init__.py 
b/aiida/backends/djsite/db/migrations/__init__.py index a832b4e5f7..41ee2b3d2c 100644 --- a/aiida/backends/djsite/db/migrations/__init__.py +++ b/aiida/backends/djsite/db/migrations/__init__.py @@ -21,7 +21,7 @@ class DeserializationException(AiidaException): pass -LATEST_MIGRATION = '0043_default_link_label' +LATEST_MIGRATION = '0044_dbgroup_type_string' def _update_schema_version(version, apps, _): diff --git a/aiida/backends/sqlalchemy/migrations/versions/bf591f31dd12_dbgroup_type_string.py b/aiida/backends/sqlalchemy/migrations/versions/bf591f31dd12_dbgroup_type_string.py new file mode 100644 index 0000000000..8231d8ebb7 --- /dev/null +++ b/aiida/backends/sqlalchemy/migrations/versions/bf591f31dd12_dbgroup_type_string.py @@ -0,0 +1,45 @@ +# -*- coding: utf-8 -*- +"""Migration after the `Group` class became pluginnable and so the group `type_string` changed. + +Revision ID: bf591f31dd12 +Revises: 118349c10896 +Create Date: 2020-03-31 10:00:52.609146 + +""" +# pylint: disable=no-name-in-module,import-error,invalid-name,no-member +from alembic import op +from sqlalchemy.sql import text + +forward_sql = [ + """UPDATE db_dbgroup SET type_string = 'core' WHERE type_string = 'user';""", + """UPDATE db_dbgroup SET type_string = 'core.upf' WHERE type_string = 'data.upf';""", + """UPDATE db_dbgroup SET type_string = 'core.import' WHERE type_string = 'auto.import';""", + """UPDATE db_dbgroup SET type_string = 'core.auto' WHERE type_string = 'auto.run';""", +] + +reverse_sql = [ + """UPDATE db_dbgroup SET type_string = 'user' WHERE type_string = 'core';""", + """UPDATE db_dbgroup SET type_string = 'data.upf' WHERE type_string = 'core.upf';""", + """UPDATE db_dbgroup SET type_string = 'auto.import' WHERE type_string = 'core.import';""", + """UPDATE db_dbgroup SET type_string = 'auto.run' WHERE type_string = 'core.auto';""", +] + +# revision identifiers, used by Alembic. 
+revision = 'bf591f31dd12' +down_revision = '118349c10896' +branch_labels = None +depends_on = None + + +def upgrade(): + """Migrations for the upgrade.""" + conn = op.get_bind() + statement = text('\n'.join(forward_sql)) + conn.execute(statement) + + +def downgrade(): + """Migrations for the downgrade.""" + conn = op.get_bind() + statement = text('\n'.join(reverse_sql)) + conn.execute(statement) diff --git a/aiida/cmdline/commands/cmd_data/cmd_upf.py b/aiida/cmdline/commands/cmd_data/cmd_upf.py index 78f79b0d9e..745f4af7a2 100644 --- a/aiida/cmdline/commands/cmd_data/cmd_upf.py +++ b/aiida/cmdline/commands/cmd_data/cmd_upf.py @@ -64,22 +64,13 @@ def upf_listfamilies(elements, with_description): """ from aiida import orm from aiida.plugins import DataFactory - from aiida.orm.nodes.data.upf import UPFGROUP_TYPE UpfData = DataFactory('upf') # pylint: disable=invalid-name query = orm.QueryBuilder() query.append(UpfData, tag='upfdata') if elements is not None: query.add_filter(UpfData, {'attributes.element': {'in': elements}}) - query.append( - orm.Group, - with_node='upfdata', - tag='group', - project=['label', 'description'], - filters={'type_string': { - '==': UPFGROUP_TYPE - }} - ) + query.append(orm.UpfFamily, with_node='upfdata', tag='group', project=['label', 'description']) query.distinct() if query.count() > 0: diff --git a/aiida/cmdline/commands/cmd_group.py b/aiida/cmdline/commands/cmd_group.py index 16978379ae..e48c361b33 100644 --- a/aiida/cmdline/commands/cmd_group.py +++ b/aiida/cmdline/commands/cmd_group.py @@ -13,7 +13,7 @@ from aiida.common.exceptions import UniquenessError from aiida.cmdline.commands.cmd_verdi import verdi -from aiida.cmdline.params import options, arguments, types +from aiida.cmdline.params import options, arguments from aiida.cmdline.utils import echo from aiida.cmdline.utils.decorators import with_dbenv @@ -178,18 +178,6 @@ def group_show(group, raw, limit, uuid): echo.echo(tabulate(table, headers=header)) -@with_dbenv() -def 
valid_group_type_strings(): - from aiida.orm import GroupTypeString - return tuple(i.value for i in GroupTypeString) - - -@with_dbenv() -def user_defined_group(): - from aiida.orm import GroupTypeString - return GroupTypeString.USER.value - - @verdi_group.command('list') @options.ALL_USERS(help='Show groups for all users, rather than only for the current user') @click.option( @@ -204,8 +192,7 @@ def user_defined_group(): '-t', '--type', 'group_type', - type=types.LazyChoice(valid_group_type_strings), - default=user_defined_group, + default='core', help='Show groups of a specific type, instead of user-defined groups. Start with semicolumn if you want to ' 'specify aiida-internal type' ) @@ -330,9 +317,8 @@ def group_list( def group_create(group_label): """Create an empty group with a given name.""" from aiida import orm - from aiida.orm import GroupTypeString - group, created = orm.Group.objects.get_or_create(label=group_label, type_string=GroupTypeString.USER.value) + group, created = orm.Group.objects.get_or_create(label=group_label) if created: echo.echo_success("Group created with PK = {} and name '{}'".format(group.id, group.label)) @@ -351,7 +337,7 @@ def group_copy(source_group, destination_group): Note that the destination group may not exist.""" from aiida import orm - dest_group, created = orm.Group.objects.get_or_create(label=destination_group, type_string=source_group.type_string) + dest_group, created = orm.Group.objects.get_or_create(label=destination_group) # Issue warning if destination group is not empty and get user confirmation to continue if not created and not dest_group.is_empty: @@ -386,8 +372,7 @@ def verdi_group_path(): '-t', '--type', 'group_type', - type=types.LazyChoice(valid_group_type_strings), - default=user_defined_group, + default='core', help='Show groups of a specific type, instead of user-defined groups. 
Start with semicolumn if you want to ' 'specify aiida-internal type' ) @@ -396,10 +381,11 @@ def verdi_group_path(): def group_path_ls(path, recursive, as_table, no_virtual, group_type, with_description, no_warn): # pylint: disable=too-many-arguments """Show a list of existing group paths.""" + from aiida.plugins import GroupFactory from aiida.tools.groups.paths import GroupPath, InvalidPath try: - path = GroupPath(path or '', type_string=group_type, warn_invalid_child=not no_warn) + path = GroupPath(path or '', cls=GroupFactory(group_type), warn_invalid_child=not no_warn) except InvalidPath as err: echo.echo_critical(str(err)) diff --git a/aiida/cmdline/commands/cmd_run.py b/aiida/cmdline/commands/cmd_run.py index bd3972b841..d46b6f984c 100644 --- a/aiida/cmdline/commands/cmd_run.py +++ b/aiida/cmdline/commands/cmd_run.py @@ -150,5 +150,6 @@ def run(scriptname, varargs, auto_group, auto_group_label_prefix, group_name, ex # Re-raise the exception to have the error code properly returned at the end raise finally: + autogroup.current_autogroup = None if handle: handle.close() diff --git a/aiida/cmdline/params/types/group.py b/aiida/cmdline/params/types/group.py index ef216044e7..6150f6d062 100644 --- a/aiida/cmdline/params/types/group.py +++ b/aiida/cmdline/params/types/group.py @@ -40,12 +40,12 @@ def orm_class_loader(self): @with_dbenv() def convert(self, value, param, ctx): - from aiida.orm import Group, GroupTypeString + from aiida.orm import Group try: group = super().convert(value, param, ctx) except click.BadParameter: if self._create_if_not_exist: - group = Group(label=value, type_string=GroupTypeString.USER.value) + group = Group(label=value) else: raise diff --git a/aiida/orm/autogroup.py b/aiida/orm/autogroup.py index 16bf03f1c1..06e83185e3 100644 --- a/aiida/orm/autogroup.py +++ b/aiida/orm/autogroup.py @@ -14,21 +14,18 @@ from aiida.common import exceptions, timezone from aiida.common.escaping import escape_for_sql_like, get_regex_pattern_from_sql from 
aiida.common.warnings import AiidaDeprecationWarning -from aiida.orm import GroupTypeString, Group +from aiida.orm import AutoGroup from aiida.plugins.entry_point import get_entry_point_string_from_class CURRENT_AUTOGROUP = None -VERDIAUTOGROUP_TYPE = GroupTypeString.VERDIAUTOGROUP_TYPE.value - class Autogroup: - """ - An object used for the autogrouping of objects. - The autogrouping is checked by the Node.store() method. - In the store(), the Node will check if CURRENT_AUTOGROUP is != None. - If so, it will call Autogroup.is_to_be_grouped, and decide whether to put it in a group. - Such autogroups are going to be of the VERDIAUTOGROUP_TYPE. + """Class to create a new `AutoGroup` instance that will, while active, automatically contain all nodes being stored. + + The autogrouping is checked by the `Node.store()` method which, if `CURRENT_AUTOGROUP is not None` the method + `Autogroup.is_to_be_grouped` is called to decide whether to put the current node being stored in the current + `AutoGroup` instance. The exclude/include lists are lists of strings like: ``aiida.data:int``, ``aiida.calculation:quantumespresso.pw``, @@ -198,7 +195,7 @@ def clear_group_cache(self): self._group_label = None def get_or_create_group(self): - """Return the current Autogroup, or create one if None has been set yet. + """Return the current `AutoGroup`, or create one if None has been set yet. This function implements a somewhat complex logic that is however needed to make sure that, even if `verdi run` is called at the same time multiple @@ -219,16 +216,10 @@ def get_or_create_group(self): # So the group with the same name can be returned quickly in future # calls of this method. if self._group_label is not None: - results = [ - res[0] for res in QueryBuilder(). 
- append(Group, filters={ - 'label': self._group_label, - 'type_string': VERDIAUTOGROUP_TYPE - }, project='*').iterall() - ] + builder = QueryBuilder().append(AutoGroup, filters={'label': self._group_label}) + results = [res[0] for res in builder.iterall()] if results: - # If it is not empty, it should have only one result due to the - # uniqueness constraints + # If it is not empty, it should have only one result due to the uniqueness constraints assert len(results) == 1, 'I got more than one autogroup with the same label!' return results[0] # There are no results: probably the group has been deleted. @@ -239,7 +230,7 @@ def get_or_create_group(self): # Try to do a preliminary QB query to avoid to do too many try/except # if many of the prefix_NUMBER groups already exist queryb = QueryBuilder().append( - Group, + AutoGroup, filters={ 'or': [{ 'label': { @@ -274,7 +265,7 @@ def get_or_create_group(self): while True: try: label = label_prefix if counter == 0 else '{}_{}'.format(label_prefix, counter) - group = Group(label=label, type_string=VERDIAUTOGROUP_TYPE).store() + group = AutoGroup(label=label).store() self._group_label = group.label except exceptions.IntegrityError: counter += 1 diff --git a/aiida/orm/convert.py b/aiida/orm/convert.py index 197253cffd..d6b577773b 100644 --- a/aiida/orm/convert.py +++ b/aiida/orm/convert.py @@ -61,8 +61,9 @@ def _(backend_entity): @get_orm_entity.register(BackendGroup) def _(backend_entity): - from . 
import groups - return groups.Group.from_backend_entity(backend_entity) + from .groups import load_group_class + group_class = load_group_class(backend_entity.type_string) + return group_class.from_backend_entity(backend_entity) @get_orm_entity.register(BackendComputer) diff --git a/aiida/orm/groups.py b/aiida/orm/groups.py index cb7b4af801..f2c726c0f2 100644 --- a/aiida/orm/groups.py +++ b/aiida/orm/groups.py @@ -8,7 +8,7 @@ # For further information please visit http://www.aiida.net # ########################################################################### """ AiiDA Group entites""" - +from abc import ABCMeta from enum import Enum import warnings @@ -21,19 +21,63 @@ from . import entities from . import users -__all__ = ('Group', 'GroupTypeString') +__all__ = ('Group', 'GroupTypeString', 'AutoGroup', 'ImportGroup', 'UpfFamily') + + +def load_group_class(type_string): + """Load the sub class of `Group` that corresponds to the given `type_string`. + + .. note:: will fall back on `aiida.orm.groups.Group` if `type_string` cannot be resolved to loadable entry point. 
+ + :param type_string: the entry point name of the `Group` sub class + :return: sub class of `Group` registered through an entry point + """ + from aiida.common.exceptions import EntryPointError + from aiida.plugins.entry_point import load_entry_point + + try: + group_class = load_entry_point('aiida.groups', type_string) + except EntryPointError: + message = 'could not load entry point `{}`, falling back onto `Group` base class.'.format(type_string) + warnings.warn(message) # pylint: disable=no-member + group_class = Group + + return group_class + + +class GroupMeta(ABCMeta): + """Meta class for `aiida.orm.groups.Group` to automatically set the `type_string` attribute.""" + + def __new__(mcs, name, bases, namespace, **kwargs): + from aiida.plugins.entry_point import get_entry_point_from_class + + newcls = ABCMeta.__new__(mcs, name, bases, namespace, **kwargs) # pylint: disable=too-many-function-args + + entry_point_group, entry_point = get_entry_point_from_class(namespace['__module__'], name) + + if entry_point_group is None or entry_point_group != 'aiida.groups': + newcls._type_string = None + message = 'no registered entry point for `{}` so its instances will not be storable.'.format(name) + warnings.warn(message) # pylint: disable=no-member + else: + newcls._type_string = entry_point.name # pylint: disable=protected-access + + return newcls class GroupTypeString(Enum): - """A simple enum of allowed group type strings.""" + """A simple enum of allowed group type strings. + .. deprecated:: 1.2.0 + This enum is deprecated and will be removed in `v2.0.0`. 
+ """ UPFGROUP_TYPE = 'data.upf' IMPORTGROUP_TYPE = 'auto.import' VERDIAUTOGROUP_TYPE = 'auto.run' USER = 'user' -class Group(entities.Entity): +class Group(entities.Entity, metaclass=GroupMeta): """An AiiDA ORM implementation of group of nodes.""" class Collection(entities.Collection): @@ -54,21 +98,10 @@ def get_or_create(self, label=None, **kwargs): if not label: raise ValueError('Group label must be provided') - filters = {'label': label} - - if 'type_string' in kwargs: - if not isinstance(kwargs['type_string'], str): - raise exceptions.ValidationError( - 'type_string must be {}, you provided an object of type ' - '{}'.format(str, type(kwargs['type_string'])) - ) - - filters['type_string'] = kwargs['type_string'] - - res = self.find(filters=filters) + res = self.find(filters={'label': label}) if not res: - return Group(label, backend=self.backend, **kwargs).store(), True + return self.entity_type(label, backend=self.backend, **kwargs).store(), True if len(res) > 1: raise exceptions.MultipleObjectsError('More than one groups found in the database') @@ -83,12 +116,15 @@ def delete(self, id): # pylint: disable=invalid-name, redefined-builtin """ self._backend.groups.delete(id) - def __init__(self, label=None, user=None, description='', type_string=GroupTypeString.USER.value, backend=None): + def __init__(self, label=None, user=None, description='', type_string=None, backend=None): """ Create a new group. Either pass a dbgroup parameter, to reload a group from the DB (and then, no further parameters are allowed), or pass the parameters for the Group creation. + .. deprecated:: 1.2.0 + The parameter `type_string` will be removed in `v2.0.0` and is now determined automatically. 
+ :param label: The group label, required on creation :type label: str @@ -105,12 +141,11 @@ def __init__(self, label=None, user=None, description='', type_string=GroupTypeS if not label: raise ValueError('Group label must be provided') - # Check that chosen type_string is allowed - if not isinstance(type_string, str): - raise exceptions.ValidationError( - 'type_string must be {}, you provided an object of type ' - '{}'.format(str, type(type_string)) - ) + if type_string is not None: + message = '`type_string` is deprecated because it is determined automatically, using default `core`' + warnings.warn(message) # pylint: disable=no-member + + type_string = self._type_string backend = backend or get_manager().get_backend() user = user or users.User.objects(backend).get_default() @@ -130,6 +165,13 @@ def __str__(self): return '"{}" [user-defined], of user {}'.format(self.label, self.user.email) + def store(self): + """Verify that the group is allowed to be stored, which is the case along as `type_string` is set.""" + if self._type_string is None: + raise exceptions.StoringNotAllowed('`type_string` is `None` so the group cannot be stored.') + + return super().store() + @property def label(self): """ @@ -295,11 +337,7 @@ def get(cls, **kwargs): filters = {} if 'type_string' in kwargs: - if not isinstance(kwargs['type_string'], str): - raise exceptions.ValidationError( - 'type_string must be {}, you provided an object of type ' - '{}'.format(str, type(kwargs['type_string'])) - ) + type_check(kwargs['type_string'], str) query = QueryBuilder() for key, val in kwargs.items(): @@ -382,3 +420,15 @@ def get_schema(): 'type': 'unicode' } } + + +class AutoGroup(Group): + """Group to be used to contain selected nodes generated while `aiida.orm.autogroup.CURRENT_AUTOGROUP` is set.""" + + +class ImportGroup(Group): + """Group to be used to contain all nodes from an export archive that has been imported.""" + + +class UpfFamily(Group): + """Group that represents a pseudo potential 
family containing `UpfData` nodes.""" diff --git a/aiida/orm/implementation/groups.py b/aiida/orm/implementation/groups.py index 74349e25e6..f39314060f 100644 --- a/aiida/orm/implementation/groups.py +++ b/aiida/orm/implementation/groups.py @@ -101,7 +101,7 @@ def get_or_create(cls, *args, **kwargs): :return: (group, created) where group is the group (new or existing, in any case already stored) and created is a boolean saying """ - res = cls.query(name=kwargs.get('name'), type_string=kwargs.get('type_string')) + res = cls.query(name=kwargs.get('name')) if not res: return cls.create(*args, **kwargs), True diff --git a/aiida/orm/nodes/data/upf.py b/aiida/orm/nodes/data/upf.py index d35e1b35ee..33cf9b6421 100644 --- a/aiida/orm/nodes/data/upf.py +++ b/aiida/orm/nodes/data/upf.py @@ -8,20 +8,14 @@ # For further information please visit http://www.aiida.net # ########################################################################### """Module of `Data` sub class to represent a pseudopotential single file in UPF format and related utilities.""" - import json import re from upf_to_json import upf_to_json - -from aiida.common.lang import classproperty -from aiida.orm import GroupTypeString from .singlefile import SinglefileData __all__ = ('UpfData',) -UPFGROUP_TYPE = GroupTypeString.UPFGROUP_TYPE.value - REGEX_UPF_VERSION = re.compile(r""" \s*.*)"> @@ -107,9 +101,7 @@ def upload_upf_family(folder, group_label, group_description, stop_if_existing=T nfiles = len(filenames) automatic_user = orm.User.objects.get_default() - group, group_created = orm.Group.objects.get_or_create( - label=group_label, type_string=UPFGROUP_TYPE, user=automatic_user - ) + group, group_created = orm.UpfFamily.objects.get_or_create(label=group_label, user=automatic_user) if group.user.email != automatic_user.email: raise UniquenessError( @@ -312,12 +304,6 @@ def get_or_create(cls, filepath, use_first=False, store_upf=True): return (pseudos[0], False) - @classproperty - def 
upffamily_type_string(cls): - """Return the type string used for UPF family groups.""" - # pylint: disable=no-self-argument,no-self-use - return UPFGROUP_TYPE - def store(self, *args, **kwargs): """Store the node, reparsing the file so that the md5 and the element are correctly reset.""" # pylint: disable=arguments-differ @@ -388,11 +374,11 @@ def set_file(self, file, filename=None): def get_upf_family_names(self): """Get the list of all upf family names to which the pseudo belongs.""" - from aiida.orm import Group + from aiida.orm import UpfFamily from aiida.orm import QueryBuilder query = QueryBuilder() - query.append(Group, filters={'type_string': {'==': self.upffamily_type_string}}, tag='group', project='label') + query.append(UpfFamily, tag='group', project='label') query.append(UpfData, filters={'id': {'==': self.id}}, with_group='group') return [label for label, in query.all()] @@ -465,9 +451,9 @@ def get_upf_group(cls, group_label): :param group_label: the family group label :return: the `Group` with the given label, if it exists """ - from aiida.orm import Group + from aiida.orm import UpfFamily - return Group.get(label=group_label, type_string=cls.upffamily_type_string) + return UpfFamily.get(label=group_label) @classmethod def get_upf_groups(cls, filter_elements=None, user=None): @@ -480,12 +466,12 @@ def get_upf_groups(cls, filter_elements=None, user=None): If defined, it should be either a `User` instance or the user email. :return: list of `Group` entities of type UPF. 
""" - from aiida.orm import Group + from aiida.orm import UpfFamily from aiida.orm import QueryBuilder from aiida.orm import User builder = QueryBuilder() - builder.append(Group, filters={'type_string': {'==': cls.upffamily_type_string}}, tag='group', project='*') + builder.append(UpfFamily, tag='group', project='*') if user: builder.append(User, filters={'email': {'==': user}}, with_group='group') @@ -496,7 +482,7 @@ def get_upf_groups(cls, filter_elements=None, user=None): if filter_elements is not None: builder.append(UpfData, filters={'attributes.element': {'in': filter_elements}}, with_group='group') - builder.order_by({Group: {'id': 'asc'}}) + builder.order_by({UpfFamily: {'id': 'asc'}}) return [group for group, in builder.all()] diff --git a/aiida/plugins/entry_point.py b/aiida/plugins/entry_point.py index 2abe6be077..46e4bf3c7e 100644 --- a/aiida/plugins/entry_point.py +++ b/aiida/plugins/entry_point.py @@ -54,6 +54,7 @@ class EntryPointFormat(enum.Enum): 'aiida.calculations': 'aiida.orm.nodes.process.calculation.calcjob', 'aiida.cmdline.data': 'aiida.cmdline.data', 'aiida.data': 'aiida.orm.nodes.data', + 'aiida.groups': 'aiida.orm.groups', 'aiida.node': 'aiida.orm.nodes', 'aiida.parsers': 'aiida.parsers.plugins', 'aiida.schedulers': 'aiida.schedulers.plugins', @@ -78,6 +79,7 @@ def validate_registered_entry_points(): # pylint: disable=invalid-name factory_mapping = { 'aiida.calculations': factories.CalculationFactory, 'aiida.data': factories.DataFactory, + 'aiida.groups': factories.GroupFactory, 'aiida.parsers': factories.ParserFactory, 'aiida.schedulers': factories.SchedulerFactory, 'aiida.transports': factories.TransportFactory, diff --git a/aiida/plugins/factories.py b/aiida/plugins/factories.py index 6e5a9296e9..1675ac6cb6 100644 --- a/aiida/plugins/factories.py +++ b/aiida/plugins/factories.py @@ -14,8 +14,8 @@ from aiida.common.exceptions import InvalidEntryPointTypeError __all__ = ( - 'BaseFactory', 'CalculationFactory', 'DataFactory', 
'DbImporterFactory', 'OrbitalFactory', 'ParserFactory', - 'SchedulerFactory', 'TransportFactory', 'WorkflowFactory' + 'BaseFactory', 'CalculationFactory', 'DataFactory', 'DbImporterFactory', 'GroupFactory', 'OrbitalFactory', + 'ParserFactory', 'SchedulerFactory', 'TransportFactory', 'WorkflowFactory' ) @@ -107,6 +107,25 @@ def DbImporterFactory(entry_point_name): raise_invalid_type_error(entry_point_name, entry_point_group, valid_classes) +def GroupFactory(entry_point_name): + """Return the `Group` sub class registered under the given entry point. + + :param entry_point_name: the entry point name + :return: sub class of :py:class:`~aiida.orm.groups.Group` + :raises aiida.common.InvalidEntryPointTypeError: if the type of the loaded entry point is invalid. + """ + from aiida.orm import Group + + entry_point_group = 'aiida.groups' + entry_point = BaseFactory(entry_point_group, entry_point_name) + valid_classes = (Group,) + + if isclass(entry_point) and issubclass(entry_point, Group): + return entry_point + + raise_invalid_type_error(entry_point_name, entry_point_group, valid_classes) + + def OrbitalFactory(entry_point_name): """Return the `Orbital` sub class registered under the given entry point. diff --git a/aiida/tools/groups/paths.py b/aiida/tools/groups/paths.py index 9d20ea9c55..b025ab250e 100644 --- a/aiida/tools/groups/paths.py +++ b/aiida/tools/groups/paths.py @@ -17,7 +17,7 @@ from aiida import orm from aiida.common.exceptions import NotExistent -__all__ = ('GroupPath', 'InvalidPath') +__all__ = ('GroupPath', 'InvalidPath', 'GroupNotFoundError', 'GroupNotUniqueError', 'NoGroupsInPathError') REGEX_ATTR = re.compile('^[a-zA-Z][\\_a-zA-Z0-9]*$') @@ -60,19 +60,20 @@ class GroupPath: See tests for usage examples. """ - def __init__(self, path='', type_string=orm.GroupTypeString.USER.value, warn_invalid_child=True): + def __init__(self, path='', cls=orm.Group, warn_invalid_child=True): # type: (str, Optional[str], Optional[GroupPath]) """Instantiate the class. 
:param path: The initial path of the group. - :param type_string: Used to query for and instantiate a ``Group`` with. + :param cls: The subclass of `Group` to operate on. :param warn_invalid_child: Issue a warning, when iterating children, if a child path is invalid. """ + if not issubclass(cls, orm.Group): + raise TypeError('cls must a subclass of Group: {}'.format(cls)) + self._delimiter = '/' - if not isinstance(type_string, str): - raise TypeError('type_string must a str: {}'.format(type_string)) - self._type_string = type_string + self._cls = cls self._path_string = self._validate_path(path) self._path_list = self._path_string.split(self._delimiter) if path else [] self._warn_invalid_child = warn_invalid_child @@ -90,21 +91,21 @@ def _validate_path(self, path): def __repr__(self): # type: () -> str """Represent the instantiated class.""" - return "{}('{}', type='{}')".format(self.__class__.__name__, self.path, self.type_string) + return "{}('{}', cls='{}')".format(self.__class__.__name__, self.path, self.cls) def __eq__(self, other): # type: (Any) -> bool - """Compare equality of path and type string to another ``GroupPath`` object.""" + """Compare equality of path and ``Group`` subclass to another ``GroupPath`` object.""" if not isinstance(other, GroupPath): return NotImplemented - return (self.path, self.type_string) == (other.path, other.type_string) + return (self.path, self.cls) == (other.path, other.cls) def __lt__(self, other): # type: (Any) -> bool - """Compare less-than operator of path and type string to another ``GroupPath`` object.""" + """Compare less-than operator of path and ``Group`` subclass to another ``GroupPath`` object.""" if not isinstance(other, GroupPath): return NotImplemented - return (self.path, self.type_string) < (other.path, other.type_string) + return (self.path, self.cls) < (other.path, other.cls) @property def path(self): @@ -133,10 +134,10 @@ def delimiter(self): return self._delimiter @property - def type_string(self): + def 
cls(self): # type: () -> str - """Return the type_string used to query for and instantiate a ``Group`` with.""" - return self._type_string + """Return the cls used to query for and instantiate a ``Group`` with.""" + return self._cls @property def parent(self): @@ -144,9 +145,7 @@ def parent(self): """Return the parent path.""" if self.path_list: return GroupPath( - self.delimiter.join(self.path_list[:-1]), - type_string=self.type_string, - warn_invalid_child=self._warn_invalid_child + self.delimiter.join(self.path_list[:-1]), cls=self.cls, warn_invalid_child=self._warn_invalid_child ) return None @@ -158,7 +157,7 @@ def __truediv__(self, path): path = self._validate_path(path) child = GroupPath( path=self.path + self.delimiter + path if self.path else path, - type_string=self.type_string, + cls=self.cls, warn_invalid_child=self._warn_invalid_child ) return child @@ -169,10 +168,10 @@ def __getitem__(self, path): return self.__truediv__(path) def get_group(self): - # type: () -> Optional[orm.Group] + # type: () -> Optional[self.cls] """Return the concrete group associated with this path.""" try: - return orm.Group.objects.get(label=self.path, type_string=self.type_string) + return self.cls.objects.get(label=self.path) except NotExistent: return None @@ -182,16 +181,14 @@ def group_ids(self): """Return all the UUID associated with this GroupPath. :returns: and empty list, if no group associated with this label, - or can be multiple if type_string was None + or can be multiple if cls was None This is an efficient method for checking existence, which does not require the (slow) loading of the ORM entity. 
""" query = orm.QueryBuilder() filters = {'label': self.path} - if self.type_string is not None: - filters['type_string'] = self.type_string - query.append(orm.Group, filters=filters, project='id') + query.append(self.cls, filters=filters, project='id') return [r[0] for r in query.all()] @property @@ -201,11 +198,9 @@ def is_virtual(self): return len(self.group_ids) == 0 def get_or_create_group(self): - # type: () -> (orm.Group, bool) + # type: () -> (self.cls, bool) """Return the concrete group associated with this path or, create it, if it does not already exist.""" - if self.type_string is not None: - return orm.Group.objects.get_or_create(label=self.path, type_string=self.type_string) - return orm.Group.objects.get_or_create(label=self.path) + return self.cls.objects.get_or_create(label=self.path) def delete_group(self): """Delete the concrete group associated with this path. @@ -217,7 +212,7 @@ def delete_group(self): raise GroupNotFoundError(self) if len(ids) > 1: raise GroupNotUniqueError(self) - orm.Group.objects.delete(ids[0]) + self.cls.objects.delete(ids[0]) @property def children(self): @@ -227,9 +222,7 @@ def children(self): filters = {} if self.path: filters['label'] = {'like': self.path + self.delimiter + '%'} - if self.type_string is not None: - filters['type_string'] = self.type_string - query.append(orm.Group, filters=filters, project='label') + query.append(self.cls, filters=filters, project='label') if query.count() == 0 and self.is_virtual: raise NoGroupsInPathError(self) @@ -242,9 +235,7 @@ def children(self): if (path_string not in yielded and path[:len(self._path_list)] == self._path_list): yielded.append(path_string) try: - yield GroupPath( - path=path_string, type_string=self.type_string, warn_invalid_child=self._warn_invalid_child - ) + yield GroupPath(path=path_string, cls=self.cls, warn_invalid_child=self._warn_invalid_child) except InvalidPath: if self._warn_invalid_child: warnings.warn('invalid path encountered: 
{}'.format(path_string)) # pylint: disable=no-member @@ -291,9 +282,7 @@ def walk_nodes(self, filters=None, node_class=None, query_batch=None): group_filters = {} if self.path: group_filters['label'] = {'or': [{'==': self.path}, {'like': self.path + self.delimiter + '%'}]} - if self.type_string is not None: - group_filters['type_string'] = self.type_string - query.append(orm.Group, filters=group_filters, project='label', tag='group') + query.append(self.cls, filters=group_filters, project='label', tag='group') query.append( orm.Node if node_class is None else node_class, with_group='group', @@ -301,7 +290,7 @@ def walk_nodes(self, filters=None, node_class=None, query_batch=None): project=['*'], ) for (label, node) in query.iterall(query_batch) if query_batch else query.all(): - yield WalkNodeResult(GroupPath(label, type_string=self.type_string), node) + yield WalkNodeResult(GroupPath(label, cls=self.cls), node) @property def browse(self): @@ -330,9 +319,7 @@ def __init__(self, group_path): def __repr__(self): # type: () -> str """Represent the instantiated class.""" - return "{}('{}', type='{}')".format( - self.__class__.__name__, self._group_path.path, self._group_path.type_string - ) + return "{}('{}', type='{}')".format(self.__class__.__name__, self._group_path.path, self._group_path.cls) def __call__(self): # type: () -> GroupPath diff --git a/aiida/tools/importexport/common/config.py b/aiida/tools/importexport/common/config.py index 0baac376c9..549c22be7d 100644 --- a/aiida/tools/importexport/common/config.py +++ b/aiida/tools/importexport/common/config.py @@ -9,15 +9,13 @@ ########################################################################### # pylint: disable=invalid-name """ Configuration file for AiiDA Import/Export module """ - -from aiida.orm import Computer, Group, GroupTypeString, Node, User, Log, Comment +from aiida.orm import Computer, Group, Node, User, Log, Comment __all__ = ('EXPORT_VERSION',) # Current export version EXPORT_VERSION = '0.8' 
-IMPORTGROUP_TYPE = GroupTypeString.IMPORTGROUP_TYPE.value DUPL_SUFFIX = ' (Imported #{})' # The name of the subfolder in which the node files are stored diff --git a/aiida/tools/importexport/dbimport/backends/django/__init__.py b/aiida/tools/importexport/dbimport/backends/django/__init__.py index d97ad70d1d..aa463f5ffb 100644 --- a/aiida/tools/importexport/dbimport/backends/django/__init__.py +++ b/aiida/tools/importexport/dbimport/backends/django/__init__.py @@ -21,10 +21,10 @@ from aiida.common.links import LinkType, validate_link_label from aiida.common.utils import grouper, get_object_from_string from aiida.orm.utils.repository import Repository -from aiida.orm import QueryBuilder, Node, Group +from aiida.orm import QueryBuilder, Node, Group, ImportGroup from aiida.tools.importexport.common import exceptions from aiida.tools.importexport.common.archive import extract_tree, extract_tar, extract_zip -from aiida.tools.importexport.common.config import DUPL_SUFFIX, IMPORTGROUP_TYPE, EXPORT_VERSION, NODES_EXPORT_SUBFOLDER +from aiida.tools.importexport.common.config import DUPL_SUFFIX, EXPORT_VERSION, NODES_EXPORT_SUBFOLDER from aiida.tools.importexport.common.config import ( NODE_ENTITY_NAME, GROUP_ENTITY_NAME, COMPUTER_ENTITY_NAME, USER_ENTITY_NAME, LOG_ENTITY_NAME, COMMENT_ENTITY_NAME ) @@ -673,7 +673,7 @@ def import_data_dj( "Overflow of import groups (more than 100 import groups exists with basename '{}')" ''.format(basename) ) - group = Group(label=group_label, type_string=IMPORTGROUP_TYPE).store() + group = ImportGroup(label=group_label).store() # Add all the nodes to the new group # TODO: decide if we want to return the group label diff --git a/aiida/tools/importexport/dbimport/backends/sqla/__init__.py b/aiida/tools/importexport/dbimport/backends/sqla/__init__.py index f08de125ec..2e800b1361 100644 --- a/aiida/tools/importexport/dbimport/backends/sqla/__init__.py +++ b/aiida/tools/importexport/dbimport/backends/sqla/__init__.py @@ -20,13 +20,13 @@ from 
aiida.common.folders import SandboxFolder, RepositoryFolder from aiida.common.links import LinkType from aiida.common.utils import get_object_from_string -from aiida.orm import QueryBuilder, Node, Group, WorkflowNode, CalculationNode, Data +from aiida.orm import QueryBuilder, Node, Group, ImportGroup from aiida.orm.utils.links import link_triple_exists, validate_link from aiida.orm.utils.repository import Repository from aiida.tools.importexport.common import exceptions from aiida.tools.importexport.common.archive import extract_tree, extract_tar, extract_zip -from aiida.tools.importexport.common.config import DUPL_SUFFIX, IMPORTGROUP_TYPE, EXPORT_VERSION, NODES_EXPORT_SUBFOLDER +from aiida.tools.importexport.common.config import DUPL_SUFFIX, EXPORT_VERSION, NODES_EXPORT_SUBFOLDER from aiida.tools.importexport.common.config import ( NODE_ENTITY_NAME, GROUP_ENTITY_NAME, COMPUTER_ENTITY_NAME, USER_ENTITY_NAME, LOG_ENTITY_NAME, COMMENT_ENTITY_NAME ) @@ -664,7 +664,7 @@ def import_data_sqla( "Overflow of import groups (more than 100 import groups exists with basename '{}')" ''.format(basename) ) - group = Group(label=group_label, type_string=IMPORTGROUP_TYPE) + group = ImportGroup(label=group_label) session.add(group.backend_entity._dbmodel) # Adding nodes to group avoiding the SQLA ORM to increase speed diff --git a/docs/source/working_with_aiida/groups.rst b/docs/source/working_with_aiida/groups.rst index 58eadcc024..55bd82ebd5 100644 --- a/docs/source/working_with_aiida/groups.rst +++ b/docs/source/working_with_aiida/groups.rst @@ -18,144 +18,162 @@ be performed with groups: Create a new Group ------------------ - From the command line interface:: +From the command line interface:: - verdi group create test_group + verdi group create test_group - From the python interface:: +From the python interface:: - In [1]: group = Group(label="test_group") - - In [2]: group.store() - Out[2]: + In [1]: group = Group(label="test_group") + In [2]: group.store() + Out[2]: List 
available Groups --------------------- - Example:: +Example:: - verdi group list + verdi group list - By default ``verdi group list`` only shows groups of the type *user*. - In case you want to show groups of another type use ``-t/--type`` option. If - you want to show groups of all types, use the ``-a/--all-types`` option. +By default ``verdi group list`` only shows groups of the type *user*. +In case you want to show groups of another type use ``-t/--type`` option. If +you want to show groups of all types, use the ``-a/--all-types`` option. - From the command line interface:: +From the command line interface:: - verdi group list -t user + verdi group list -t user - From the python interface:: +From the python interface:: - In [1]: query = QueryBuilder() + In [1]: query = QueryBuilder() - In [2]: query.append(Group, filters={'type_string':'user'}) - Out[2]: + In [2]: query.append(Group, filters={'type_string':'user'}) + Out[2]: - In [3]: query.all() - Out[3]: - [[], - [], - []] + In [3]: query.all() + Out[3]: + [[], + [], + []] Add nodes to a Group -------------------- - Once the ``test_group`` has been created, we can add nodes to it. To add the node with ``pk=1`` to the group we need to do the following. - - From the command line interface:: - - verdi group add-nodes -G test_group 1 - Do you really want to add 1 nodes to Group? [y/N]: y - - From the python interface:: +Once the ``test_group`` has been created, we can add nodes to it. To add the node with ``pk=1`` to the group we need to do the following. - In [1]: group = Group.get(label='test_group') +From the command line interface:: - In [2]: from aiida.orm import Dict + verdi group add-nodes -G test_group 1 + Do you really want to add 1 nodes to Group? 
[y/N]: y - In [3]: p = Dict().store() +From the python interface:: - In [4]: p - Out[4]: + In [1]: group = Group.get(label='test_group') + In [2]: from aiida.orm import Dict + In [3]: p = Dict().store() + In [4]: p + Out[4]: + In [5]: group.add_nodes(p) - In [5]: group.add_nodes(p) Show information about a Group ------------------------------ - From the command line interface:: - - verdi group show test_group - ----------------- ---------------- - Group label test_group - Group type_string user - Group description - ----------------- ---------------- - # Nodes: - PK Type Created - ---- ------ --------------- - 1 Code 26D:21h:45m ago +From the command line interface:: + verdi group show test_group + ----------------- ---------------- + Group label test_group + Group type_string user + Group description + ----------------- ---------------- + # Nodes: + PK Type Created + ---- ------ --------------- + 1 Code 26D:21h:45m ago Remove nodes from a Group ------------------------- - From the command line interface:: +From the command line interface:: - verdi group remove-nodes -G test_group 1 - Do you really want to remove 1 nodes from Group? [y/N]: y + verdi group remove-nodes -G test_group 1 + Do you really want to remove 1 nodes from Group? 
[y/N]: y - From the python interface:: +From the python interface:: - In [1]: group = Group.get(label='test_group') + In [1]: group = Group.get(label='test_group') + In [2]: group.clear() - In [2]: group.clear() Rename Group ------------ - From the command line interface:: +From the command line interface:: verdi group relabel test_group old_group Success: Label changed to old_group - From the python interface:: +From the python interface:: - In [1]: group = Group.get(label='old_group') + In [1]: group = Group.get(label='old_group') + In [2]: group.label = "another_group" - In [2]: group.label = "another_group" Delete Group ------------ - From the command line interface:: +From the command line interface:: verdi group delete another_group Are you sure to delete Group? [y/N]: y Success: Group deleted. - Copy one group into another --------------------------- - This operation will copy the nodes of the source group into the destination - group. Moreover, if the destination group did not exist before, it will - be created automatically. +This operation will copy the nodes of the source group into the destination +group. Moreover, if the destination group did not exist before, it will +be created automatically. + +From the command line interface:: + + verdi group copy source_group dest_group + Success: Nodes copied from group to group + +From the python interface:: + + In [1]: src_group = Group.objects.get(label='source_group') + In [2]: dest_group = Group(label='destination_group').store() + In [3]: dest_group.add_nodes(list(src_group.nodes)) + + +Create a `Group` subclass +------------------------- +It is possible to create a subclass of `Group` to implement custom functionality. +To make the instances of the subclass storable and loadable, it has to be registered through an entry point in the ``aiida.groups`` entry point category. 
+For example, assuming we have a subclass ``SubClassGroup`` in the module ``aiida_plugin.groups.sub_class:SubClassGroup``, to register it, one has to add the following to the ``setup.py`` of the plugin package:: - From the command line interface:: + "entry_points": { + "aiida.groups": [ + "plugin.sub_class = aiida_plugin.groups.sub_class:SubClassGroup" + ] + } - verdi group copy source_group dest_group - Success: Nodes copied from group to group +Now that the subclass is properly registered, instances can be stored:: - From the python interface:: + group = SubClassGroup(label='sub-class-group') + group.store() - In [1]: src_group = Group.objects.get(label='source_group') +The ``type_string`` of the group instance corresponds to the entry point name and so in this example is ``plugin.sub_class``. +This is what AiiDA uses to load the correct class when reloading the group from the database:: - In [2]: dest_group = Group(label='destination_group').store() + group = load_group(group.pk) + assert isinstance(group, SubClassGroup) - In [3]: dest_group.add_nodes(list(src_group.nodes)) +If the entry point is not currently registered, because the corresponding plugin package is not installed for example, AiiDA will issue a warning and fallback onto the ``Group`` base class. 
diff --git a/setup.json b/setup.json index ebe8360174..7f4b7783ca 100644 --- a/setup.json +++ b/setup.json @@ -159,6 +159,12 @@ "structure = aiida.orm.nodes.data.structure:StructureData", "upf = aiida.orm.nodes.data.upf:UpfData" ], + "aiida.groups": [ + "core = aiida.orm.groups:Group", + "core.auto = aiida.orm.groups:AutoGroup", + "core.import = aiida.orm.groups:ImportGroup", + "core.upf = aiida.orm.groups:UpfFamily" + ], "aiida.node": [ "data = aiida.orm.nodes.data.data:Data", "process = aiida.orm.nodes.process.process:ProcessNode", diff --git a/tests/backends/aiida_django/migrations/test_migrations_0044_dbgroup_type_string.py b/tests/backends/aiida_django/migrations/test_migrations_0044_dbgroup_type_string.py new file mode 100644 index 0000000000..ab1b31d518 --- /dev/null +++ b/tests/backends/aiida_django/migrations/test_migrations_0044_dbgroup_type_string.py @@ -0,0 +1,63 @@ +# -*- coding: utf-8 -*- +########################################################################### +# Copyright (c), The AiiDA team. All rights reserved. # +# This file is part of the AiiDA code. 
# +# # +# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # +# For further information on the license, see the LICENSE.txt file # +# For further information please visit http://www.aiida.net # +########################################################################### +# pylint: disable=import-error,no-name-in-module,invalid-name +"""Test migration of `type_string` after the `Group` class became pluginnable.""" + +from .test_migrations_common import TestMigrations + + +class TestGroupTypeStringMigration(TestMigrations): + """Test migration of `type_string` after the `Group` class became pluginnable.""" + + migrate_from = '0043_default_link_label' + migrate_to = '0044_dbgroup_type_string' + + def setUpBeforeMigration(self): + DbGroup = self.apps.get_model('db', 'DbGroup') + + # test user group type_string: 'user' -> 'core' + group_user = DbGroup(label='01', user_id=self.default_user.id, type_string='user') + group_user.save() + self.group_user_pk = group_user.pk + + # test data.upf group type_string: 'data.upf' -> 'core.upf' + group_data_upf = DbGroup(label='02', user_id=self.default_user.id, type_string='data.upf') + group_data_upf.save() + self.group_data_upf_pk = group_data_upf.pk + + # test auto.import group type_string: 'auto.import' -> 'core.import' + group_autoimport = DbGroup(label='03', user_id=self.default_user.id, type_string='auto.import') + group_autoimport.save() + self.group_autoimport_pk = group_autoimport.pk + + # test auto.run group type_string: 'auto.run' -> 'core.auto' + group_autorun = DbGroup(label='04', user_id=self.default_user.id, type_string='auto.run') + group_autorun.save() + self.group_autorun_pk = group_autorun.pk + + def test_group_string_update(self): + """Test that the type_string were updated correctly.""" + DbGroup = self.apps.get_model('db', 'DbGroup') + + # 'user' -> 'core' + group_user = DbGroup.objects.get(pk=self.group_user_pk) + self.assertEqual(group_user.type_string, 'core') + + # 'data.upf' -> 
'core.upf' + group_data_upf = DbGroup.objects.get(pk=self.group_data_upf_pk) + self.assertEqual(group_data_upf.type_string, 'core.upf') + + # 'auto.import' -> 'core.import' + group_autoimport = DbGroup.objects.get(pk=self.group_autoimport_pk) + self.assertEqual(group_autoimport.type_string, 'core.import') + + # 'auto.run' -> 'core.auto' + group_autorun = DbGroup.objects.get(pk=self.group_autorun_pk) + self.assertEqual(group_autorun.type_string, 'core.auto') diff --git a/tests/backends/aiida_sqlalchemy/test_migrations.py b/tests/backends/aiida_sqlalchemy/test_migrations.py index 8e2046f293..2bb52ceecc 100644 --- a/tests/backends/aiida_sqlalchemy/test_migrations.py +++ b/tests/backends/aiida_sqlalchemy/test_migrations.py @@ -1642,3 +1642,69 @@ def test_data_migrated(self): finally: session.close() + + +class TestGroupTypeStringMigration(TestMigrationsSQLA): + """Test the migration that renames the DbGroup type strings.""" + + migrate_from = '118349c10896' # 118349c10896_default_link_label.py + migrate_to = 'bf591f31dd12' # bf591f31dd12_dbgroup_type_string.py + + def setUpBeforeMigration(self): + """Create the DbGroups with the old type strings.""" + DbGroup = self.get_current_table('db_dbgroup') # pylint: disable=invalid-name + DbUser = self.get_current_table('db_dbuser') # pylint: disable=invalid-name + + with self.get_session() as session: + try: + default_user = DbUser(email='{}@aiida.net'.format(self.id())) + session.add(default_user) + session.commit() + + # test user group type_string: 'user' -> 'core' + group_user = DbGroup(label='01', user_id=default_user.id, type_string='user') + session.add(group_user) + # test data.upf group type_string: 'data.upf' -> 'core.upf' + group_data_upf = DbGroup(label='02', user_id=default_user.id, type_string='data.upf') + session.add(group_data_upf) + # test auto.import group type_string: 'auto.import' -> 'core.import' + group_autoimport = DbGroup(label='03', user_id=default_user.id, type_string='auto.import') + 
session.add(group_autoimport) + # test auto.run group type_string: 'auto.run' -> 'core.auto' + group_autorun = DbGroup(label='04', user_id=default_user.id, type_string='auto.run') + session.add(group_autorun) + + session.commit() + + # Store values for later tests + self.group_user_pk = group_user.id + self.group_data_upf_pk = group_data_upf.id + self.group_autoimport_pk = group_autoimport.id + self.group_autorun_pk = group_autorun.id + + finally: + session.close() + + def test_group_string_update(self): + """Test that the type strings are properly migrated.""" + DbGroup = self.get_current_table('db_dbgroup') # pylint: disable=invalid-name + + with self.get_session() as session: + try: + # test user group type_string: 'user' -> 'core' + group_user = session.query(DbGroup).filter(DbGroup.id == self.group_user_pk).one() + self.assertEqual(group_user.type_string, 'core') + + # test data.upf group type_string: 'data.upf' -> 'core.upf' + group_data_upf = session.query(DbGroup).filter(DbGroup.id == self.group_data_upf_pk).one() + self.assertEqual(group_data_upf.type_string, 'core.upf') + + # test auto.import group type_string: 'auto.import' -> 'core.import' + group_autoimport = session.query(DbGroup).filter(DbGroup.id == self.group_autoimport_pk).one() + self.assertEqual(group_autoimport.type_string, 'core.import') + + # test auto.run group type_string: 'auto.run' -> 'core.auto' + group_autorun = session.query(DbGroup).filter(DbGroup.id == self.group_autorun_pk).one() + self.assertEqual(group_autorun.type_string, 'core.auto') + finally: + session.close() diff --git a/tests/cmdline/commands/test_group.py b/tests/cmdline/commands/test_group.py index 4302420633..ab79f650b1 100644 --- a/tests/cmdline/commands/test_group.py +++ b/tests/cmdline/commands/test_group.py @@ -165,7 +165,7 @@ def test_show(self): self.assertClickResultNoException(result) for grpline in [ - 'Group label', 'dummygroup1', 'Group type_string', 'user', 'Group description', '' + 'Group label', 
'dummygroup1', 'Group type_string', 'core', 'Group description', '' ]: self.assertIn(grpline, result.output) diff --git a/tests/cmdline/commands/test_group_ls.py b/tests/cmdline/commands/test_group_ls.py index 7cc01079a4..1a8bb7d2ef 100644 --- a/tests/cmdline/commands/test_group_ls.py +++ b/tests/cmdline/commands/test_group_ls.py @@ -22,12 +22,13 @@ def setup_groups(clear_database_before_test): """Setup some groups for testing.""" for label in ['a', 'a/b', 'a/c/d', 'a/c/e/g', 'a/f']: - group, _ = orm.Group.objects.get_or_create(label, type_string=orm.GroupTypeString.USER.value) + group, _ = orm.Group.objects.get_or_create(label) group.description = 'A description of {}'.format(label) - orm.Group.objects.get_or_create('a/x', type_string=orm.GroupTypeString.UPFGROUP_TYPE.value) + orm.UpfFamily.objects.get_or_create('a/x') yield +@pytest.mark.skip('Reenable when subclassing in the query builder is implemented (#3902)') def test_with_no_opts(setup_groups): """Test ``verdi group path ls``""" @@ -46,6 +47,7 @@ def test_with_no_opts(setup_groups): assert result.output == 'a/c/d\na/c/e\n' +@pytest.mark.skip('Reenable when subclassing in the query builder is implemented (#3902)') def test_recursive(setup_groups): """Test ``verdi group path ls --recursive``""" @@ -61,6 +63,7 @@ def test_recursive(setup_groups): assert result.output == 'a/c/d\na/c/e\na/c/e/g\n' +@pytest.mark.skip('Reenable when subclassing in the query builder is implemented (#3902)') @pytest.mark.parametrize('tag', ['-l', '--long']) def test_long(setup_groups, tag): """Test ``verdi group path ls --long``""" @@ -106,6 +109,7 @@ def test_long(setup_groups, tag): ) +@pytest.mark.skip('Reenable when subclassing in the query builder is implemented (#3902)') @pytest.mark.parametrize('tag', ['--no-virtual']) def test_groups_only(setup_groups, tag): """Test ``verdi group path ls --no-virtual``""" diff --git a/tests/cmdline/commands/test_run.py b/tests/cmdline/commands/test_run.py index 78c858420f..4ed690bb20 100644 
--- a/tests/cmdline/commands/test_run.py +++ b/tests/cmdline/commands/test_run.py @@ -9,6 +9,7 @@ ########################################################################### """Tests for `verdi run`.""" import tempfile +import textwrap import warnings from click.testing import CliRunner @@ -31,21 +32,22 @@ def test_run_workfunction(self): that are defined within the script will fail, as the inspect module will not correctly be able to determin the full path of the source file. """ - from aiida.orm import load_node - from aiida.orm import WorkFunctionNode + from aiida.orm import load_node, WorkFunctionNode - script_content = """ -#!/usr/bin/env python -from aiida.engine import workfunction + script_content = textwrap.dedent( + """\ + #!/usr/bin/env python + from aiida.engine import workfunction -@workfunction -def wf(): - pass + @workfunction + def wf(): + pass -if __name__ == '__main__': - result, node = wf.run_get_node() - print(node.pk) - """ + if __name__ == '__main__': + result, node = wf.run_get_node() + print(node.pk) + """ + ) # If `verdi run` is not setup correctly, the script above when run with `verdi run` will fail, because when # the engine will try to create the node for the workfunction and create a copy of its sourcefile, namely the @@ -77,9 +79,8 @@ def setUp(self): super().setUp() self.cli_runner = CliRunner() - # I need to disable the global variable of this test environment, - # because invoke is just calling the function and therefore inheriting - # the global variable + # I need to disable the global variable of this test environment, because invoke is just calling the function + # and therefore inheriting the global variable self._old_autogroup = autogroup.CURRENT_AUTOGROUP autogroup.CURRENT_AUTOGROUP = None @@ -92,12 +93,15 @@ def tearDown(self): def test_autogroup(self): """Check if the autogroup is properly generated.""" - from aiida.orm import QueryBuilder, Node, Group, load_node + from aiida.orm import QueryBuilder, Node, AutoGroup, 
load_node - script_content = """from aiida.orm import Data -node = Data().store() -print(node.pk) -""" + script_content = textwrap.dedent( + """\ + from aiida.orm import Data + node = Data().store() + print(node.pk) + """ + ) with tempfile.NamedTemporaryFile(mode='w+') as fhandle: fhandle.write(script_content) @@ -111,7 +115,7 @@ def test_autogroup(self): _ = load_node(pk) # Check if the node can be loaded queryb = QueryBuilder().append(Node, filters={'id': pk}, tag='node') - queryb.append(Group, with_node='node', filters={'type_string': 'auto.run'}, project='*') + queryb.append(AutoGroup, with_node='node', project='*') all_auto_groups = queryb.all() self.assertEqual( len(all_auto_groups), 1, 'There should be only one autogroup associated with the node just created' @@ -119,12 +123,16 @@ def test_autogroup(self): def test_autogroup_custom_label(self): """Check if the autogroup is properly generated with the label specified.""" - from aiida.orm import QueryBuilder, Node, Group, load_node + from aiida.orm import QueryBuilder, Node, AutoGroup, load_node + + script_content = textwrap.dedent( + """\ + from aiida.orm import Data + node = Data().store() + print(node.pk) + """ + ) - script_content = """from aiida.orm import Data -node = Data().store() -print(node.pk) -""" autogroup_label = 'SOME_group_LABEL' with tempfile.NamedTemporaryFile(mode='w+') as fhandle: fhandle.write(script_content) @@ -138,7 +146,7 @@ def test_autogroup_custom_label(self): _ = load_node(pk) # Check if the node can be loaded queryb = QueryBuilder().append(Node, filters={'id': pk}, tag='node') - queryb.append(Group, with_node='node', filters={'type_string': 'auto.run'}, project='*') + queryb.append(AutoGroup, with_node='node', project='*') all_auto_groups = queryb.all() self.assertEqual( len(all_auto_groups), 1, 'There should be only one autogroup associated with the node just created' @@ -147,12 +155,15 @@ def test_autogroup_custom_label(self): def test_no_autogroup(self): """Check if the 
autogroup is not generated if ``verdi run`` is asked not to.""" - from aiida.orm import QueryBuilder, Node, Group, load_node + from aiida.orm import QueryBuilder, Node, AutoGroup, load_node - script_content = """from aiida.orm import Data -node = Data().store() -print(node.pk) -""" + script_content = textwrap.dedent( + """\ + from aiida.orm import Data + node = Data().store() + print(node.pk) + """ + ) with tempfile.NamedTemporaryFile(mode='w+') as fhandle: fhandle.write(script_content) @@ -166,61 +177,64 @@ def test_no_autogroup(self): _ = load_node(pk) # Check if the node can be loaded queryb = QueryBuilder().append(Node, filters={'id': pk}, tag='node') - queryb.append(Group, with_node='node', filters={'type_string': 'auto.run'}, project='*') + queryb.append(AutoGroup, with_node='node', project='*') all_auto_groups = queryb.all() self.assertEqual(len(all_auto_groups), 0, 'There should be no autogroup generated') def test_autogroup_filter_class(self): # pylint: disable=too-many-locals """Check if the autogroup is properly generated but filtered classes are skipped.""" - from aiida.orm import QueryBuilder, Node, Group, load_node - - script_content = """import sys -from aiida.orm import Computer, Int, ArrayData, KpointsData, CalculationNode, WorkflowNode -from aiida.plugins import CalculationFactory -from aiida.engine import run_get_node -ArithmeticAdd = CalculationFactory('arithmetic.add') - -computer = Computer( - name='localhost-example-{}'.format(sys.argv[1]), - hostname='localhost', - description='my computer', - transport_type='local', - scheduler_type='direct', - workdir='/tmp' -).store() -computer.configure() - -code = Code( - input_plugin_name='arithmetic.add', - remote_computer_exec=[computer, '/bin/true']).store() -inputs = { - 'x': Int(1), - 'y': Int(2), - 'code': code, - 'metadata': { - 'options': { - 'resources': { - 'num_machines': 1, - 'num_mpiprocs_per_machine': 1 + from aiida.orm import Code, QueryBuilder, Node, AutoGroup, load_node + + 
script_content = textwrap.dedent( + """\ + import sys + from aiida.orm import Computer, Int, ArrayData, KpointsData, CalculationNode, WorkflowNode + from aiida.plugins import CalculationFactory + from aiida.engine import run_get_node + ArithmeticAdd = CalculationFactory('arithmetic.add') + + computer = Computer( + name='localhost-example-{}'.format(sys.argv[1]), + hostname='localhost', + description='my computer', + transport_type='local', + scheduler_type='direct', + workdir='/tmp' + ).store() + computer.configure() + + code = Code( + input_plugin_name='arithmetic.add', + remote_computer_exec=[computer, '/bin/true']).store() + inputs = { + 'x': Int(1), + 'y': Int(2), + 'code': code, + 'metadata': { + 'options': { + 'resources': { + 'num_machines': 1, + 'num_mpiprocs_per_machine': 1 + } + } + } } - } - } -} - -node1 = KpointsData().store() -node2 = ArrayData().store() -node3 = Int(3).store() -node4 = CalculationNode().store() -node5 = WorkflowNode().store() -_, node6 = run_get_node(ArithmeticAdd, **inputs) -print(node1.pk) -print(node2.pk) -print(node3.pk) -print(node4.pk) -print(node5.pk) -print(node6.pk) -""" - from aiida.orm import Code + + node1 = KpointsData().store() + node2 = ArrayData().store() + node3 = Int(3).store() + node4 = CalculationNode().store() + node5 = WorkflowNode().store() + _, node6 = run_get_node(ArithmeticAdd, **inputs) + print(node1.pk) + print(node2.pk) + print(node3.pk) + print(node4.pk) + print(node5.pk) + print(node6.pk) + """ + ) + Code() for idx, ( flags, @@ -283,27 +297,27 @@ def test_autogroup_filter_class(self): # pylint: disable=too-many-locals _ = load_node(pk6) # Check if the node can be loaded queryb = QueryBuilder().append(Node, filters={'id': pk1}, tag='node') - queryb.append(Group, with_node='node', filters={'type_string': 'auto.run'}, project='*') + queryb.append(AutoGroup, with_node='node', project='*') all_auto_groups_kptdata = queryb.all() queryb = QueryBuilder().append(Node, filters={'id': pk2}, tag='node') - 
queryb.append(Group, with_node='node', filters={'type_string': 'auto.run'}, project='*') + queryb.append(AutoGroup, with_node='node', project='*') all_auto_groups_arraydata = queryb.all() queryb = QueryBuilder().append(Node, filters={'id': pk3}, tag='node') - queryb.append(Group, with_node='node', filters={'type_string': 'auto.run'}, project='*') + queryb.append(AutoGroup, with_node='node', project='*') all_auto_groups_int = queryb.all() queryb = QueryBuilder().append(Node, filters={'id': pk4}, tag='node') - queryb.append(Group, with_node='node', filters={'type_string': 'auto.run'}, project='*') + queryb.append(AutoGroup, with_node='node', project='*') all_auto_groups_calc = queryb.all() queryb = QueryBuilder().append(Node, filters={'id': pk5}, tag='node') - queryb.append(Group, with_node='node', filters={'type_string': 'auto.run'}, project='*') + queryb.append(AutoGroup, with_node='node', project='*') all_auto_groups_wf = queryb.all() queryb = QueryBuilder().append(Node, filters={'id': pk6}, tag='node') - queryb.append(Group, with_node='node', filters={'type_string': 'auto.run'}, project='*') + queryb.append(AutoGroup, with_node='node', project='*') all_auto_groups_calcarithmetic = queryb.all() self.assertEqual( @@ -339,12 +353,16 @@ def test_autogroup_filter_class(self): # pylint: disable=too-many-locals def test_autogroup_clashing_label(self): """Check if the autogroup label is properly (re)generated when it clashes with an existing group name.""" - from aiida.orm import QueryBuilder, Node, Group, load_node + from aiida.orm import QueryBuilder, Node, AutoGroup, load_node + + script_content = textwrap.dedent( + """\ + from aiida.orm import Data + node = Data().store() + print(node.pk) + """ + ) - script_content = """from aiida.orm import Data -node = Data().store() -print(node.pk) -""" autogroup_label = 'SOME_repeated_group_LABEL' with tempfile.NamedTemporaryFile(mode='w+') as fhandle: fhandle.write(script_content) @@ -358,7 +376,7 @@ def 
test_autogroup_clashing_label(self): pk = int(result.output) _ = load_node(pk) # Check if the node can be loaded queryb = QueryBuilder().append(Node, filters={'id': pk}, tag='node') - queryb.append(Group, with_node='node', filters={'type_string': 'auto.run'}, project='*') + queryb.append(AutoGroup, with_node='node', project='*') all_auto_groups = queryb.all() self.assertEqual( len(all_auto_groups), 1, 'There should be only one autogroup associated with the node just created' @@ -374,7 +392,7 @@ def test_autogroup_clashing_label(self): pk = int(result.output) _ = load_node(pk) # Check if the node can be loaded queryb = QueryBuilder().append(Node, filters={'id': pk}, tag='node') - queryb.append(Group, with_node='node', filters={'type_string': 'auto.run'}, project='*') + queryb.append(AutoGroup, with_node='node', project='*') all_auto_groups = queryb.all() self.assertEqual( len(all_auto_groups), 1, 'There should be only one autogroup associated with the node just created' @@ -383,12 +401,15 @@ def test_autogroup_clashing_label(self): def test_legacy_autogroup_name(self): """Check if the autogroup is properly generated when using the legacy --group-name flag.""" - from aiida.orm import QueryBuilder, Node, Group, load_node - - script_content = """from aiida.orm import Data -node = Data().store() -print(node.pk) -""" + from aiida.orm import QueryBuilder, Node, AutoGroup, load_node + + script_content = textwrap.dedent( + """\ + from aiida.orm import Data + node = Data().store() + print(node.pk) + """ + ) group_label = 'legacy-group-name' with tempfile.NamedTemporaryFile(mode='w+') as fhandle: @@ -409,7 +430,7 @@ def test_legacy_autogroup_name(self): _ = load_node(pk) # Check if the node can be loaded queryb = QueryBuilder().append(Node, filters={'id': pk}, tag='node') - queryb.append(Group, with_node='node', filters={'type_string': 'auto.run'}, project='*') + queryb.append(AutoGroup, with_node='node', project='*') all_auto_groups = queryb.all() self.assertEqual( 
len(all_auto_groups), 1, 'There should be only one autogroup associated with the node just created' diff --git a/tests/orm/data/test_upf.py b/tests/orm/data/test_upf.py index 228f8d9b77..02922bc60f 100644 --- a/tests/orm/data/test_upf.py +++ b/tests/orm/data/test_upf.py @@ -10,7 +10,6 @@ """ This module contains tests for UpfData and UpfData related functions. """ - import errno import tempfile import shutil @@ -95,8 +94,8 @@ def setUp(self): def tearDown(self): """Delete all groups and destroy the temporary directory created.""" - for group in orm.Group.objects.find(filters={'type_string': orm.GroupTypeString.UPFGROUP_TYPE.value}): - orm.Group.objects.delete(group.pk) + for group in orm.UpfFamily.objects.find(): + orm.UpfFamily.objects.delete(group.pk) try: shutil.rmtree(self.temp_dir) @@ -122,32 +121,31 @@ def test_get_upf_family_names(self): """Test the `UpfData.get_upf_family_names` method.""" label = 'family' - family, _ = orm.Group.objects.get_or_create(label=label, type_string=orm.GroupTypeString.UPFGROUP_TYPE.value) + family, _ = orm.UpfFamily.objects.get_or_create(label=label) family.add_nodes([self.pseudo_barium]) family.store() - self.assertEqual({group.label for group in orm.UpfData.get_upf_groups()}, {label}) + self.assertEqual({group.label for group in orm.UpfFamily.objects.all()}, {label}) self.assertEqual(self.pseudo_barium.get_upf_family_names(), [label]) def test_get_upf_groups(self): """Test the `UpfData.get_upf_groups` class method.""" - type_string = orm.GroupTypeString.UPFGROUP_TYPE.value label_01 = 'family_01' label_02 = 'family_02' user = orm.User(email='alternate@localhost').store() - self.assertEqual(orm.UpfData.get_upf_groups(), []) + self.assertEqual(orm.UpfFamily.objects.all(), []) # Create group with default user and add `Ba` pseudo - family_01, _ = orm.Group.objects.get_or_create(label=label_01, type_string=type_string) + family_01, _ = orm.UpfFamily.objects.get_or_create(label=label_01) family_01.add_nodes([self.pseudo_barium]) 
family_01.store() self.assertEqual({group.label for group in orm.UpfData.get_upf_groups()}, {label_01}) # Create group with different user and add `O` pseudo - family_02, _ = orm.Group.objects.get_or_create(label=label_02, type_string=type_string, user=user) + family_02, _ = orm.UpfFamily.objects.get_or_create(label=label_02, user=user) family_02.add_nodes([self.pseudo_oxygen]) family_02.store() diff --git a/tests/orm/test_groups.py b/tests/orm/test_groups.py index ce2797daad..67b189195c 100644 --- a/tests/orm/test_groups.py +++ b/tests/orm/test_groups.py @@ -8,6 +8,8 @@ # For further information please visit http://www.aiida.net # ########################################################################### """Test for the Group ORM class.""" +import pytest + from aiida import orm from aiida.backends.testbase import AiidaTestCase from aiida.common import exceptions @@ -272,3 +274,54 @@ def test_group_uuid_hashing_for_querybuidler(self): # And that the results are correct self.assertEqual(builder.count(), 1) self.assertEqual(builder.first()[0], group.id) + + +class TestGroupsSubclasses(AiidaTestCase): + """Test rules around creating `Group` subclasses.""" + + @staticmethod + def test_creation_registered(): + """Test rules around creating registered `Group` subclasses.""" + group = orm.AutoGroup('some-label') + assert isinstance(group, orm.AutoGroup) + assert group.type_string == 'core.auto' + + group, _ = orm.AutoGroup.objects.get_or_create('some-auto-group') + assert isinstance(group, orm.AutoGroup) + assert group.type_string == 'core.auto' + + @staticmethod + def test_creation_unregistered(): + """Test rules around creating `Group` subclasses without a registered entry point.""" + + # Defining an unregistered subclas should issue a warning and its type string should be set to `None` + with pytest.warns(UserWarning): + + class SubGroup(orm.Group): + pass + + assert SubGroup._type_string is None # pylint: disable=protected-access + + # Creating an instance is allowed 
+ instance = SubGroup(label='subgroup') + assert instance._type_string is None # pylint: disable=protected-access + + # Storing the instance, however, is forbidden and should raise + with pytest.raises(exceptions.StoringNotAllowed): + instance.store() + + @staticmethod + def test_loading_unregistered(): + """Test rules around loading `Group` subclasses without a registered entry point. + + Storing instances of unregistered subclasses is not allowed so we have to create one sneakily by instantiating + a normal group and manipulating the type string directly on the database model. + """ + group = orm.Group(label='group') + group.backend_entity.dbmodel.type_string = 'unregistered.subclass' + group.store() + + with pytest.warns(UserWarning): + loaded = orm.load_group(group.pk) + + assert isinstance(loaded, orm.Group) diff --git a/tests/tools/graph/test_age.py b/tests/tools/graph/test_age.py index dddf2323c2..538087c7d7 100644 --- a/tests/tools/graph/test_age.py +++ b/tests/tools/graph/test_age.py @@ -494,7 +494,7 @@ def test_groups(self): # Rule that only gets nodes connected by the same group queryb = orm.QueryBuilder() queryb.append(orm.Node, tag='nodes_in_set') - queryb.append(orm.Group, with_node='nodes_in_set', tag='groups_considered', filters={'type_string': 'user'}) + queryb.append(orm.Group, with_node='nodes_in_set', tag='groups_considered') queryb.append(orm.Data, with_group='groups_considered') initial_node = [node2.id] @@ -513,7 +513,7 @@ def test_groups(self): # But two rules chained should get both nodes and groups... 
queryb = orm.QueryBuilder() queryb.append(orm.Node, tag='nodes_in_set') - queryb.append(orm.Group, with_node='nodes_in_set', filters={'type_string': 'user'}) + queryb.append(orm.Group, with_node='nodes_in_set') rule1 = UpdateRule(queryb) queryb = orm.QueryBuilder() @@ -569,7 +569,7 @@ def test_groups(self): qb1 = orm.QueryBuilder() qb1.append(orm.Node, tag='nodes_in_set') - qb1.append(orm.Group, with_node='nodes_in_set', filters={'type_string': 'user'}) + qb1.append(orm.Group, with_node='nodes_in_set') rule1 = UpdateRule(qb1, track_edges=True) qb2 = orm.QueryBuilder() diff --git a/tests/tools/groups/test_paths.py b/tests/tools/groups/test_paths.py index a6f1cdb757..b6e8940ce5 100644 --- a/tests/tools/groups/test_paths.py +++ b/tests/tools/groups/test_paths.py @@ -7,19 +7,19 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### -"""Tests for GroupPath""" # pylint: disable=redefined-outer-name,unused-argument +"""Tests for GroupPath""" import pytest from aiida import orm -from aiida.tools.groups.paths import (GroupAttr, GroupPath, InvalidPath, GroupNotFoundError, NoGroupsInPathError) +from aiida.tools.groups.paths import GroupAttr, GroupPath, InvalidPath, GroupNotFoundError, NoGroupsInPathError @pytest.fixture def setup_groups(clear_database_before_test): """Setup some groups for testing.""" for label in ['a', 'a/b', 'a/c/d', 'a/c/e/g', 'a/f']: - group, _ = orm.Group.objects.get_or_create(label, type_string=orm.GroupTypeString.USER.value) + group, _ = orm.Group.objects.get_or_create(label) group.description = 'A description of {}'.format(label) yield @@ -117,16 +117,17 @@ def test_walk(setup_groups): def test_walk_with_invalid_path(clear_database_before_test): + """Test the ``GroupPath.walk`` method with invalid paths.""" for label in ['a', 'a/b', 'a/c/d', 'a/c/e/g', 'a/f', 'bad//group', 'bad/other']: - 
orm.Group.objects.get_or_create(label, type_string=orm.GroupTypeString.USER.value) + orm.Group.objects.get_or_create(label) group_path = GroupPath() - assert [c.path for c in sorted(group_path.walk()) - ] == ['a', 'a/b', 'a/c', 'a/c/d', 'a/c/e', 'a/c/e/g', 'a/f', 'bad', 'bad/other'] + expected = ['a', 'a/b', 'a/c', 'a/c/d', 'a/c/e', 'a/c/e/g', 'a/f', 'bad', 'bad/other'] + assert [c.path for c in sorted(group_path.walk())] == expected def test_walk_nodes(clear_database): """Test the ``GroupPath.walk_nodes()`` function.""" - group, _ = orm.Group.objects.get_or_create('a', type_string=orm.GroupTypeString.USER.value) + group, _ = orm.Group.objects.get_or_create('a') node = orm.Data() node.set_attribute_many({'i': 1, 'j': 2}) node.store() @@ -135,17 +136,18 @@ def test_walk_nodes(clear_database): assert [(r.group_path.path, r.node.attributes) for r in group_path.walk_nodes()] == [('a', {'i': 1, 'j': 2})] -def test_type_string(clear_database_before_test): - """Test that only the type_string instantiated in ``GroupPath`` is returned.""" +@pytest.mark.skip('Reenable when subclassing in the query builder is implemented (#3902)') +def test_cls(clear_database_before_test): + """Test that only instances of `cls` or its subclasses are matched by ``GroupPath``.""" for label in ['a', 'a/b', 'a/c/d', 'a/c/e/g']: - orm.Group.objects.get_or_create(label, type_string=orm.GroupTypeString.USER.value) + orm.Group.objects.get_or_create(label) for label in ['a/c/e', 'a/f']: - orm.Group.objects.get_or_create(label, type_string=orm.GroupTypeString.UPFGROUP_TYPE.value) + orm.UpfFamily.objects.get_or_create(label) group_path = GroupPath() assert sorted([c.path for c in group_path.walk()]) == ['a', 'a/b', 'a/c', 'a/c/d', 'a/c/e', 'a/c/e/g'] - group_path = GroupPath(type_string=orm.GroupTypeString.UPFGROUP_TYPE.value) + group_path = GroupPath(cls=orm.UpfFamily) assert sorted([c.path for c in group_path.walk()]) == ['a', 'a/c', 'a/c/e', 'a/f'] - assert GroupPath('a/b/c') != GroupPath('a/b/c', 
type_string=orm.GroupTypeString.UPFGROUP_TYPE.value) + assert GroupPath('a/b/c') != GroupPath('a/b/c', cls=orm.UpfFamily) def test_attr(clear_database_before_test): diff --git a/tests/tools/importexport/test_prov_redesign.py b/tests/tools/importexport/test_prov_redesign.py index 37f9a485a0..5ef849c51c 100644 --- a/tests/tools/importexport/test_prov_redesign.py +++ b/tests/tools/importexport/test_prov_redesign.py @@ -229,7 +229,7 @@ def test_group_name_and_type_change(self, temp_dir): groups_type_string = [g.type_string for g in [group_user, group_upf]] # Assert correct type strings exists prior to export - self.assertListEqual(groups_type_string, ['user', 'data.upf']) + self.assertListEqual(groups_type_string, ['core', 'core.upf']) # Export node filename = os.path.join(temp_dir, 'export.tar.gz') @@ -268,4 +268,4 @@ def test_group_name_and_type_change(self, temp_dir): # Check type_string content of "import group" import_group = orm.load_group(imported_groups_uuid[0]) - self.assertEqual(import_group.type_string, 'auto.import') + self.assertEqual(import_group.type_string, 'core.import')