diff --git a/aiida_siesta/commands/data_psf.py b/aiida_siesta/commands/data_psf.py index e0065ac2..807cec69 100644 --- a/aiida_siesta/commands/data_psf.py +++ b/aiida_siesta/commands/data_psf.py @@ -55,22 +55,14 @@ def psf_listfamilies(elements, with_description): """ from aiida import orm from aiida.plugins import DataFactory - from aiida_siesta.data.psf import PSFGROUP_TYPE + from aiida_siesta.groups.pseudos import PsfFamily PsfData = DataFactory('siesta.psf') # pylint: disable=invalid-name query = orm.QueryBuilder() query.append(PsfData, tag='psfdata') if elements is not None: query.add_filter(PsfData, {'attributes.element': {'in': elements}}) - query.append( - orm.Group, - with_node='psfdata', - tag='group', - project=['label', 'description'], - filters={'type_string': { - '==': PSFGROUP_TYPE - }} - ) + query.append(PsfFamily, with_node='psfdata', tag='group', project=['label', 'description']) query.distinct() if query.count() > 0: diff --git a/aiida_siesta/commands/data_psml.py b/aiida_siesta/commands/data_psml.py index 381b654e..866958b6 100644 --- a/aiida_siesta/commands/data_psml.py +++ b/aiida_siesta/commands/data_psml.py @@ -55,22 +55,14 @@ def psml_listfamilies(elements, with_description): """ from aiida import orm from aiida.plugins import DataFactory - from aiida_siesta.data.psml import PSMLGROUP_TYPE + from aiida_siesta.groups.pseudos import PsmlFamily PsmlData = DataFactory('siesta.psml') # pylint: disable=invalid-name query = orm.QueryBuilder() query.append(PsmlData, tag='psmldata') if elements is not None: query.add_filter(PsmlData, {'attributes.element': {'in': elements}}) - query.append( - orm.Group, - with_node='psmldata', - tag='group', - project=['label', 'description'], - filters={'type_string': { - '==': PSMLGROUP_TYPE - }} - ) + query.append(PsmlFamily, with_node='psmldata', tag='group', project=['label', 'description']) query.distinct() if query.count() > 0: diff --git a/aiida_siesta/data/psf.py b/aiida_siesta/data/psf.py index 7ff6be91..af32362e 100644 --- a/aiida_siesta/data/psf.py +++ b/aiida_siesta/data/psf.py @@ -3,12 +3,9 @@ """ import io -from aiida.common.utils import classproperty from aiida.common.files import md5_file from aiida.orm.nodes import SinglefileData -PSFGROUP_TYPE = 'data.psf.family' - def get_pseudos_from_structure(structure, family_name): """ @@ -64,6 +61,7 @@ def upload_psf_family(folder, group_label, group_description, stop_if_existing=T from aiida.common import AIIDA_LOGGER as aiidalogger from aiida.common.exceptions import UniquenessError from aiida.orm.querybuilder import QueryBuilder + from aiida_siesta.groups.pseudos import PsfFamily if not os.path.isdir(folder): raise ValueError("folder must be a directory") @@ -79,9 +77,7 @@ def upload_psf_family(folder, group_label, group_description, stop_if_existing=T nfiles = len(files) automatic_user = orm.User.objects.get_default() - group, group_created = orm.Group.objects.get_or_create( - label=group_label, type_string=PSFGROUP_TYPE, user=automatic_user - ) + group, group_created = PsfFamily.objects.get_or_create(label=group_label, user=automatic_user) if group.user.email != automatic_user.email: raise UniquenessError( @@ -257,10 +253,6 @@ def get_or_create(cls, filename, use_first=False, store_psf=True): return (pseudos[0], False) - @classproperty - def psffamily_type_string(cls): # pylint: disable=no-self-argument,no-self-use - return PSFGROUP_TYPE - def store(self, *args, **kwargs): # pylint: disable=arguments-differ """ Store the node, reparsing the file so that the md5 and the element @@ -327,9 +319,14 @@ def get_psf_family_names(self): """ Get the list of all psf family names to which the pseudo belongs """ - from aiida.orm import Group + from aiida.orm import QueryBuilder + from aiida_siesta.groups.pseudos import PsfFamily - return [_.name for _ in Group.query(nodes=self, type_string=self.psffamily_type_string)] + query = QueryBuilder() + query.append(PsfFamily, tag='group', project='label') + query.append(PsfData, filters={'id': {'==': self.id}}, with_group='group') + + return [_[0] for _ in query.all()] @property def element(self): @@ -383,9 +380,9 @@ def get_psf_group(cls, group_label): """ Return the PsfFamily group with the given name. """ - from aiida.orm import Group + from aiida_siesta.groups.pseudos import PsfFamily - return Group.get(label=group_label, type_string=cls.psffamily_type_string) + return PsfFamily.get(label=group_label) @classmethod def get_psf_groups(cls, filter_elements=None, user=None): @@ -400,25 +397,22 @@ def get_psf_groups(cls, filter_elements=None, user=None): If defined, it should be either a DbUser instance, or a string for the username (that is, the user email). """ - from aiida.orm import Group + from aiida.orm import QueryBuilder + from aiida.orm import User + from aiida_siesta.groups.pseudos import PsfFamily - group_query_params = {"type_string": cls.psffamily_type_string} + query = QueryBuilder() + query.append(PsfFamily, tag='group', project='*') - if user is not None: - group_query_params['user'] = user + if user: + query.append(User, filters={'email': {'==': user}}, with_group='group') if isinstance(filter_elements, str): filter_elements = [filter_elements] if filter_elements is not None: - actual_filter_elements = {_.capitalize() for _ in filter_elements} - - group_query_params['node_attributes'] = {'element': actual_filter_elements} + query.append(PsfData, filters={'attributes.element': {'in': filter_elements}}, with_group='group') - all_psf_groups = Group.query(**group_query_params) + query.order_by({PsfFamily: {'id': 'asc'}}) - groups = [(g.name, g) for g in all_psf_groups] - # Sort by name - groups.sort() - # Return the groups, without name - return [_[1] for _ in groups] + return [_[0] for _ in query.all()] diff --git a/aiida_siesta/data/psml.py b/aiida_siesta/data/psml.py index 8c4f6582..d3d9d525 100644 --- a/aiida_siesta/data/psml.py +++ b/aiida_siesta/data/psml.py @@ -2,13 +2,12 @@ This module manages the PSML pseudopotentials in the local repository. """ -from aiida.common.utils import classproperty from aiida.common.files import md5_file from aiida.orm.nodes import SinglefileData # See LICENSE.txt and AUTHORS.txt -PSMLGROUP_TYPE = 'data.psml.family' +#PSMLGROUP_TYPE = 'data.psml.family' def get_pseudos_from_structure(structure, family_name): @@ -65,6 +64,7 @@ def upload_psml_family(folder, group_label, group_description, stop_if_existing= from aiida.common import AIIDA_LOGGER as aiidalogger from aiida.common.exceptions import UniquenessError from aiida.orm.querybuilder import QueryBuilder + from aiida_siesta.groups.pseudos import PsmlFamily if not os.path.isdir(folder): raise ValueError("folder must be a directory") @@ -80,9 +80,10 @@ def upload_psml_family(folder, group_label, group_description, stop_if_existing= nfiles = len(files) automatic_user = orm.User.objects.get_default() - group, group_created = orm.Group.objects.get_or_create( - label=group_label, type_string=PSMLGROUP_TYPE, user=automatic_user - ) + #group, group_created = orm.Group.objects.get_or_create( + # label=group_label, type_string=PSMLGROUP_TYPE, user=automatic_user + #) + group, group_created = PsmlFamily.objects.get_or_create(label=group_label, user=automatic_user) if group.user.email != automatic_user.email: raise UniquenessError( @@ -258,9 +259,9 @@ def get_or_create(cls, filename, use_first=False, store_psml=True): return (pseudos[0], False) - @classproperty - def psmlfamily_type_string(cls): # pylint: disable=no-self-argument,no-self-use - return PSMLGROUP_TYPE + #@classproperty + #def psmlfamily_type_string(cls): # pylint: disable=no-self-argument,no-self-use + # return PSMLGROUP_TYPE def store(self, *args, **kwargs): # pylint: disable=arguments-differ """ @@ -328,9 +329,15 @@ def get_psml_family_names(self): """ Get the list of all psml family names to which the pseudo belongs """ - from aiida.orm import Group + #from aiida.orm import Group + from aiida.orm import QueryBuilder + from aiida_siesta.groups.pseudos import PsmlFamily - return [_.name for _ in Group.query(nodes=self, type_string=self.psmlfamily_type_string)] + query = QueryBuilder() + query.append(PsmlFamily, tag='group', project='label') + query.append(PsmlData, filters={'id': {'==': self.id}}, with_group='group') + + return [_[0] for _ in query.all()] @property def element(self): @@ -385,9 +392,10 @@ def get_psml_group(cls, group_label): """ Return the PsmlFamily group with the given name. """ - from aiida.orm import Group + #from aiida.orm import Group + from aiida_siesta.groups.pseudos import PsmlFamily - return Group.get(label=group_label, type_string=cls.psmlfamily_type_string) + return PsmlFamily.get(label=group_label) @classmethod def get_psml_groups(cls, filter_elements=None, user=None): @@ -402,25 +410,25 @@ def get_psml_groups(cls, filter_elements=None, user=None): If defined, it should be either a DbUser instance, or a string for the username (that is, the user email). """ - from aiida.orm import Group + #from aiida.orm import Group + from aiida.orm import QueryBuilder + from aiida.orm import User + from aiida_siesta.groups.pseudos import PsmlFamily + + query = QueryBuilder() + #filters = {'type_string': {'==': cls.psmlfamily_type_string}} - group_query_params = {"type_string": cls.psmlfamily_type_string} + query.append(PsmlFamily, tag='group', project='*') - if user is not None: - group_query_params['user'] = user + if user: + query.append(User, filters={'email': {'==': user}}, with_group='group') if isinstance(filter_elements, str): filter_elements = [filter_elements] if filter_elements is not None: - actual_filter_elements = {_.capitalize() for _ in filter_elements} - - group_query_params['node_attributes'] = {'element': actual_filter_elements} + query.append(PsmlData, filters={'attributes.element': {'in': filter_elements}}, with_group='group') - all_psml_groups = Group.query(**group_query_params) + query.order_by({PsmlFamily: {'id': 'asc'}}) - groups = [(g.name, g) for g in all_psml_groups] - # Sort by name - groups.sort() - # Return the groups, without name - return [_[1] for _ in groups] + return [_[0] for _ in query.all()] diff --git a/aiida_siesta/groups/pseudos.py b/aiida_siesta/groups/pseudos.py new file mode 100644 index 00000000..003c79b4 --- /dev/null +++ b/aiida_siesta/groups/pseudos.py @@ -0,0 +1,9 @@ +from aiida.orm.groups import Group + + +class PsfFamily(Group): + """Group that represents a pseudo potential family containing `PsfData` nodes.""" + + +class PsmlFamily(Group): + """Group that represents a pseudo potential family containing `PsmlData` nodes.""" diff --git a/setup.json b/setup.json index 3deebea6..6a4e6e75 100644 --- a/setup.json +++ b/setup.json @@ -55,6 +55,10 @@ "aiida.cmdline.data": [ "psf = aiida_siesta.commands.data_psf:psfdata", "psml = aiida_siesta.commands.data_psml:psmldata" - ] + ], + "aiida.groups": [ + "data.psf.family = aiida_siesta.groups.pseudos:PsfFamily", + "data.psml.family = aiida_siesta.groups.pseudos:PsmlFamily" + ] } } diff --git a/tests/data/test_pseudos.py b/tests/data/test_pseudos.py new file mode 100644 index 00000000..2b837a11 --- /dev/null +++ b/tests/data/test_pseudos.py @@ -0,0 +1,24 @@ +def test_pseudos_classmethods(generate_psf_data, generate_psml_data): + + from aiida_siesta.data.psml import PsmlData + from aiida_siesta.data.psf import PsfData + + assert PsfData.get_psf_groups() == [] + assert PsmlData.get_psml_groups() == [] + + #get_psf_group(cls, group_label) + + #from_md5 + + #get_or_create + +def test_pseudo(generate_psf_data, generate_psml_data): + + psf = generate_psf_data('Si') + psml = generate_psml_data('Si') + + assert psf.get_psf_family_names() == [] + assert psml.get_psml_family_names() == [] + + #set_file + #assert 'md5' in psf.attributes