diff --git a/src/aiida/orm/nodes/data/array/kpoints.py b/src/aiida/orm/nodes/data/array/kpoints.py index ddf0cd6a96..c048d56eb4 100644 --- a/src/aiida/orm/nodes/data/array/kpoints.py +++ b/src/aiida/orm/nodes/data/array/kpoints.py @@ -223,16 +223,24 @@ def set_cell_from_structure(self, structuredata): :param structuredata: an instance of StructureData """ - from aiida.orm import StructureData + from aiida.orm.nodes.data.structure import StructureData as LegacyStructureData + from aiida.orm.nodes.data.structure import has_atomistic - if not isinstance(structuredata, StructureData): - raise ValueError( - 'An instance of StructureData should be passed to ' 'the KpointsData, found instead {}'.format( - structuredata.__class__ - ) + if not has_atomistic(): + structures_classes = (LegacyStructureData,) # type: tuple + else: + from aiida_atomistic import StructureData # type: ignore[import-untyped] + + structures_classes = (LegacyStructureData, StructureData) + + if not isinstance(structuredata, structures_classes): + raise TypeError( + f'An instance of {structures_classes} should be passed to ' + f'the KpointsData, found instead {type(structuredata)}' ) - cell = structuredata.cell - self.set_cell(cell, structuredata.pbc) + else: + cell = structuredata.cell + self.set_cell(cell, structuredata.pbc) def set_cell(self, cell, pbc=None): """Set a cell to be used for symmetry analysis. diff --git a/src/aiida/orm/nodes/data/structure.py b/src/aiida/orm/nodes/data/structure.py index f046898b52..a57abfc078 100644 --- a/src/aiida/orm/nodes/data/structure.py +++ b/src/aiida/orm/nodes/data/structure.py @@ -102,6 +102,15 @@ def has_pymatgen(): return True +def has_atomistic() -> bool: + """:return: True if the StructureData and StructureDataMutable from aiida-atomistic module can be imported, False otherwise.""" + try: + import aiida_atomistic # noqa: F401 + except ImportError: + return False + return True + + def get_pymatgen_version(): """:return: string with pymatgen version, None if can not import.""" if not has_pymatgen(): @@ -1876,6 +1885,32 @@ def _get_object_pymatgen_molecule(self, **kwargs): positions = [list(site.position) for site in self.sites] return Molecule(species, positions) + def to_atomistic(self): + """ + Returns the atomistic StructureData version of the orm.StructureData one. + """ + if not has_atomistic(): + raise ImportError( + 'aiida-atomistic plugin is not installed, \ + please install it to have full support for atomistic structures' + ) + else: + from aiida_atomistic import StructureData, StructureDataMutable + + atomistic = StructureDataMutable() + atomistic.set_pbc(self.pbc) + atomistic.set_cell(self.cell) + + for site in self.sites: + atomistic.add_atom( + symbols=self.get_kind(site.kind_name).symbol, + masses=self.get_kind(site.kind_name).mass, + positions=site.position, + kinds=site.kind_name, + ) + + return StructureData.from_mutable(atomistic) + class Kind: """This class contains the information about the species (kinds) of the system. diff --git a/tests/orm/nodes/data/test_kpoints.py b/tests/orm/nodes/data/test_kpoints.py index 112a46859e..a8461e9a69 100644 --- a/tests/orm/nodes/data/test_kpoints.py +++ b/tests/orm/nodes/data/test_kpoints.py @@ -10,15 +10,27 @@ import numpy as np import pytest -from aiida.orm import KpointsData, StructureData, load_node +from aiida.orm import KpointsData, load_node +from aiida.orm import StructureData as LegacyStructureData +from aiida.orm.nodes.data.structure import has_atomistic +skip_atomistic = pytest.mark.skipif(not has_atomistic(), reason='aiida-atomistic not installed') +if not has_atomistic(): + structures_classes = [LegacyStructureData, pytest.param('StructureData', marks=skip_atomistic)] +else: + from aiida_atomistic import StructureData # type: ignore[import-untyped] + + structures_classes = [LegacyStructureData, StructureData] + + +@pytest.mark.parametrize('structure_class', structures_classes) class TestKpoints: """Test for the `Kpointsdata` class.""" @pytest.fixture(autouse=True) - def init_profile(self): - """Initialize the profile.""" + def generate_structure(self, structure_class): + """Generate the StructureData.""" alat = 5.430 # angstrom cell = [ [ @@ -34,10 +46,13 @@ def init_profile(self): [0.5 * alat, 0.0, 0.5 * alat], ] self.alat = alat - structure = StructureData(cell=cell) + structure = LegacyStructureData(cell=cell) structure.append_atom(position=(0.000 * alat, 0.000 * alat, 0.000 * alat), symbols=['Si']) structure.append_atom(position=(0.250 * alat, 0.250 * alat, 0.250 * alat), symbols=['Si']) - self.structure = structure + if structure_class == LegacyStructureData: + self.structure = structure + else: + self.structure = LegacyStructureData.to_atomistic(structure) # Define the expected reciprocal cell val = 2.0 * np.pi / alat self.expected_reciprocal_cell = np.array([[val, val, -val], [-val, val, val], [val, -val, val]]) diff --git a/tests/test_dataclasses.py b/tests/test_dataclasses.py index 799b6454d7..1ea06f0569 100644 --- a/tests/test_dataclasses.py +++ b/tests/test_dataclasses.py @@ -26,6 +26,7 @@ get_formula, get_pymatgen_version, has_ase, + has_atomistic, has_pymatgen, has_spglib, ) @@ -67,6 +68,7 @@ def simplify(string): skip_spglib = pytest.mark.skipif(not has_spglib(), reason='Unable to import spglib') skip_pycifrw = pytest.mark.skipif(not has_pycifrw(), reason='Unable to import PyCifRW') skip_pymatgen = pytest.mark.skipif(not has_pymatgen(), reason='Unable to import pymatgen') +skip_atomistic = pytest.mark.skipif(not has_atomistic(), reason='Unable to import aiida-atomistic') @skip_pymatgen @@ -1851,6 +1853,26 @@ def test_clone(self): assert round(abs(c.sites[1].position[i] - 1.0), 7) == 0 +@skip_atomistic +def test_to_atomistic(self): + """Test the conversion from orm.StructureData to the atomistic structure.""" + + # Create a structure with a single atom + from aiida_atomistic import StructureData as AtomisticStructureData + + legacy = StructureData(cell=((1.0, 0.0, 0.0), (0.0, 2.0, 0.0), (0.0, 0.0, 3.0))) + legacy.append_atom(position=(0.0, 0.0, 0.0), symbols=['Ba'], name='Ba1') + + # Convert to atomistic structure + structure = legacy.to_atomistic() + + # Check that the structure is as expected + assert isinstance(structure, AtomisticStructureData) + assert structure.properties.sites[0].kinds == legacy.sites[0].kind_name + assert structure.properties.sites[0].positions == list(legacy.sites[0].position) + assert structure.properties.cell == legacy.cell + + class TestStructureDataFromAse: """Tests the creation of Sites from/to a ASE object."""