diff --git a/aiida_siesta/data/psf.py b/aiida_siesta/data/psf.py index 7ff6be91..3b939b86 100644 --- a/aiida_siesta/data/psf.py +++ b/aiida_siesta/data/psf.py @@ -328,8 +328,13 @@ def get_psf_family_names(self): Get the list of all psf family names to which the pseudo belongs """ from aiida.orm import Group + from aiida.orm import QueryBuilder - return [_.name for _ in Group.query(nodes=self, type_string=self.psffamily_type_string)] + query = QueryBuilder() + query.append(Group, filters={'type_string': {'==': self.psffamily_type_string}}, tag='group', project='label') + query.append(PsfData, filters={'id': {'==': self.id}}, with_group='group') + + return [_[0] for _ in query.all()] @property def element(self): @@ -401,24 +406,23 @@ def get_psf_groups(cls, filter_elements=None, user=None): for the username (that is, the user email). """ from aiida.orm import Group + from aiida.orm import QueryBuilder + from aiida.orm import User + + query = QueryBuilder() + filters = {'type_string': {'==': cls.psffamily_type_string}} - group_query_params = {"type_string": cls.psffamily_type_string} + query.append(Group, filters=filters, tag='group', project='*') - if user is not None: - group_query_params['user'] = user + if user: + query.append(User, filters={'email': {'==': user}}, with_group='group') if isinstance(filter_elements, str): filter_elements = [filter_elements] if filter_elements is not None: - actual_filter_elements = {_.capitalize() for _ in filter_elements} - - group_query_params['node_attributes'] = {'element': actual_filter_elements} + query.append(PsfData, filters={'attributes.element': {'in': filter_elements}}, with_group='group') - all_psf_groups = Group.query(**group_query_params) + query.order_by({Group: {'id': 'asc'}}) - groups = [(g.name, g) for g in all_psf_groups] - # Sort by name - groups.sort() - # Return the groups, without name - return [_[1] for _ in groups] + return [_[0] for _ in query.all()] diff --git a/aiida_siesta/data/psml.py b/aiida_siesta/data/psml.py index 8c4f6582..c61c29a5 100644 --- a/aiida_siesta/data/psml.py +++ b/aiida_siesta/data/psml.py @@ -329,8 +329,13 @@ def get_psml_family_names(self): Get the list of all psml family names to which the pseudo belongs """ from aiida.orm import Group + from aiida.orm import QueryBuilder - return [_.name for _ in Group.query(nodes=self, type_string=self.psmlfamily_type_string)] + query = QueryBuilder() + query.append(Group, filters={'type_string': {'==': self.psmlfamily_type_string}}, tag='group', project='label') + query.append(PsmlData, filters={'id': {'==': self.id}}, with_group='group') + + return [_[0] for _ in query.all()] @property def element(self): @@ -403,24 +408,23 @@ def get_psml_groups(cls, filter_elements=None, user=None): for the username (that is, the user email). """ from aiida.orm import Group + from aiida.orm import QueryBuilder + from aiida.orm import User + + query = QueryBuilder() + filters = {'type_string': {'==': cls.psmlfamily_type_string}} - group_query_params = {"type_string": cls.psmlfamily_type_string} + query.append(Group, filters=filters, tag='group', project='*') - if user is not None: - group_query_params['user'] = user + if user: + query.append(User, filters={'email': {'==': user}}, with_group='group') if isinstance(filter_elements, str): filter_elements = [filter_elements] if filter_elements is not None: - actual_filter_elements = {_.capitalize() for _ in filter_elements} - - group_query_params['node_attributes'] = {'element': actual_filter_elements} + query.append(PsmlData, filters={'attributes.element': {'in': filter_elements}}, with_group='group') - all_psml_groups = Group.query(**group_query_params) + query.order_by({Group: {'id': 'asc'}}) - groups = [(g.name, g) for g in all_psml_groups] - # Sort by name - groups.sort() - # Return the groups, without name - return [_[1] for _ in groups] + return [_[0] for _ in query.all()] diff --git a/tests/data/test_pseudos.py b/tests/data/test_pseudos.py new file mode 100644 index 00000000..2b837a11 --- /dev/null +++ b/tests/data/test_pseudos.py @@ -0,0 +1,24 @@ +def test_pseudos_classmethods(generate_psf_data, generate_psml_data): + + from aiida_siesta.data.psml import PsmlData + from aiida_siesta.data.psf import PsfData + + assert PsfData.get_psf_groups() == [] + assert PsmlData.get_psml_groups() == [] + + #get_psf_group(cls, group_label) + + #from_md5 + + #get_or_create + +def test_pseudo(generate_psf_data, generate_psml_data): + + psf = generate_psf_data('Si') + psml = generate_psml_data('Si') + + assert psf.get_psf_family_names() == [] + assert psml.get_psml_family_names() == [] + + #set_file + #assert 'md5' in psf.attributes