diff --git a/aiida/storage/sqlite_temp/backend.py b/aiida/storage/sqlite_temp/backend.py
index 685c192930..dd85d82f89 100644
--- a/aiida/storage/sqlite_temp/backend.py
+++ b/aiida/storage/sqlite_temp/backend.py
@@ -10,14 +10,19 @@
 """Definition of the ``SqliteTempBackend`` backend."""
 from __future__ import annotations
 
-from contextlib import contextmanager
+from contextlib import contextmanager, nullcontext
+import functools
 from functools import cached_property
-from typing import Any, Iterator, Sequence
+import hashlib
+import os
+from pathlib import Path
+import shutil
+from typing import Any, BinaryIO, Iterator, Sequence
 
 from sqlalchemy.orm import Session
 
-from aiida.common.exceptions import ClosedStorage
-from aiida.manage import Profile, get_config_option
+from aiida.common.exceptions import ClosedStorage, IntegrityError
+from aiida.manage import Profile
 from aiida.orm.entities import EntityTypes
 from aiida.orm.implementation import BackendEntity, StorageBackend
 from aiida.repository.backend.sandbox import SandboxRepositoryBackend
@@ -42,7 +47,8 @@ def create_profile(
         name: str = 'temp',
         default_user_email='user@email.com',
         options: dict | None = None,
-        debug: bool = False
+        debug: bool = False,
+        repo_path: str | Path | None = None,
     ) -> Profile:
         """Create a new profile instance for this backend, from the path to the zip file."""
         return Profile(
@@ -52,6 +58,7 @@
                     'backend': 'core.sqlite_temp',
                     'config': {
                         'debug': debug,
+                        'repo_path': repo_path,
                     }
                 },
                 'process_control': {
@@ -81,7 +88,7 @@ def migrate(cls, profile: Profile):
     def __init__(self, profile: Profile):
         super().__init__(profile)
         self._session: Session | None = None
-        self._repo: SandboxRepositoryBackend | None = None
+        self._repo: SandboxShaRepositoryBackend | None = None
         self._globals: dict[str, tuple[Any, str | None]] = {}
         self._closed = False
         self.get_session()  # load the database on initialization
@@ -124,12 +131,13 @@ def get_session(self) -> Session:
             self._session.commit()
         return self._session
 
-    def get_repository(self) -> SandboxRepositoryBackend:
+    def get_repository(self) -> SandboxShaRepositoryBackend:
         if self._closed:
             raise ClosedStorage(str(self))
         if self._repo is None:
             # to-do this does not seem to be removing the folder on garbage collection?
-            self._repo = SandboxRepositoryBackend(filepath=get_config_option('storage.sandbox') or None)
+            repo_path = self.profile.storage_config.get('repo_path')
+            self._repo = SandboxShaRepositoryBackend(filepath=Path(repo_path) if repo_path else None)
         return self._repo
 
     @property
@@ -199,11 +207,122 @@ def get_info(self, detailed: bool = False) -> dict:
             # results['repository'] = self.get_repository().get_info(detailed)
         return results
 
+    @staticmethod
+    @functools.lru_cache(maxsize=18)
+    def _get_mapper_from_entity(entity_type: EntityTypes, with_pk: bool):
+        """Return the Sqlalchemy mapper and fields corresponding to the given entity.
+
+        :param with_pk: if True, the fields returned will include the primary key
+        """
+        from sqlalchemy import inspect
+
+        from aiida.storage.sqlite_zip.models import (
+            DbAuthInfo,
+            DbComment,
+            DbComputer,
+            DbGroup,
+            DbGroupNodes,
+            DbLink,
+            DbLog,
+            DbNode,
+            DbUser,
+        )
+
+        model = {
+            EntityTypes.AUTHINFO: DbAuthInfo,
+            EntityTypes.COMMENT: DbComment,
+            EntityTypes.COMPUTER: DbComputer,
+            EntityTypes.GROUP: DbGroup,
+            EntityTypes.LOG: DbLog,
+            EntityTypes.NODE: DbNode,
+            EntityTypes.USER: DbUser,
+            EntityTypes.LINK: DbLink,
+            EntityTypes.GROUP_NODE: DbGroupNodes,
+        }[entity_type]
+        mapper = inspect(model).mapper
+        keys = {key for key, col in mapper.c.items() if with_pk or col not in mapper.primary_key}
+        return mapper, keys
+
     def bulk_insert(self, entity_type: EntityTypes, rows: list[dict], allow_defaults: bool = False) -> list[int]:
-        raise NotImplementedError
+        mapper, keys = self._get_mapper_from_entity(entity_type, False)
+        if not rows:
+            return []
+        if entity_type in (EntityTypes.COMPUTER, EntityTypes.LOG, EntityTypes.AUTHINFO):
+            for row in rows:
+                row['_metadata'] = row.pop('metadata')
+        if allow_defaults:
+            for row in rows:
+                if not keys.issuperset(row):
+                    raise IntegrityError(f'Incorrect fields given for {entity_type}: {set(row)} not subset of {keys}')
+        else:
+            for row in rows:
+                if set(row) != keys:
+                    raise IntegrityError(f'Incorrect fields given for {entity_type}: {set(row)} != {keys}')
+        session = self.get_session()
+        with (nullcontext() if self.in_transaction else self.transaction()):
+            session.bulk_insert_mappings(mapper, rows, render_nulls=True, return_defaults=True)
+        return [row['id'] for row in rows]
 
     def bulk_update(self, entity_type: EntityTypes, rows: list[dict]) -> None:
-        raise NotImplementedError
+        mapper, keys = self._get_mapper_from_entity(entity_type, True)
+        if not rows:
+            return None
+        for row in rows:
+            if 'id' not in row:
+                raise IntegrityError(f"'id' field not given for {entity_type}: {set(row)}")
+            if not keys.issuperset(row):
+                raise IntegrityError(f'Incorrect fields given for {entity_type}: {set(row)} not subset of {keys}')
+        session = self.get_session()
+        with (nullcontext() if self.in_transaction else self.transaction()):
+            session.bulk_update_mappings(mapper, rows)
 
     def delete_nodes_and_connections(self, pks_to_delete: Sequence[int]):
         raise NotImplementedError
+
+
+class SandboxShaRepositoryBackend(SandboxRepositoryBackend):
+    """A sandbox repository backend that uses the sha256 of the file as the key.
+
+    This allows for compatibility with the archive format (i.e. `SqliteZipBackend`).
+    Which allows for temporary profiles to be exported and imported.
+    """
+
+    @property
+    def key_format(self) -> str | None:
+        return 'sha256'
+
+    def get_object_hash(self, key: str) -> str:
+        return key
+
+    def _put_object_from_filelike(self, handle: BinaryIO) -> str:
+        """Store the byte contents of a file in the repository.
+
+        :param handle: filelike object with the byte content to be stored.
+        :return: the generated fully qualified identifier for the object within the repository.
+        :raises TypeError: if the handle is not a byte stream.
+ """ + # we first compute the hash of the file contents + hsh = hashlib.sha256() + position = handle.tell() + while True: + buf = handle.read(1024 * 1024) + if not buf: + break + hsh.update(buf) + key = hsh.hexdigest() + + filepath = os.path.join(self.sandbox.abspath, key) + if not os.path.exists(filepath): + # if a file with this hash does not already exist + # then we reset the file pointer and copy the contents + handle.seek(position) + with open(filepath, 'wb') as target: + shutil.copyfileobj(handle, target) + + return key + + def get_info(self, detailed: bool = False, **kwargs) -> dict: + return {'objects': {'count': len(list(self.list_objects()))}} + + def maintain(self, dry_run: bool = False, live: bool = True, **kwargs) -> None: + pass diff --git a/aiida/tools/archive/imports.py b/aiida/tools/archive/imports.py index 61dc349e8c..055214a3ec 100644 --- a/aiida/tools/archive/imports.py +++ b/aiida/tools/archive/imports.py @@ -778,7 +778,7 @@ def _transform(row): def _transform(row): # to-do this is probably not the most efficient way to do this uuid, new_mtime, new_comment = row - cmt = orm.Comment.collection.get(uuid=uuid) + cmt = orm.comments.CommentCollection(orm.Comment, backend).get(uuid=uuid) if cmt.mtime < new_mtime: cmt.set_mtime(new_mtime) cmt.set_content(new_comment) @@ -1086,7 +1086,7 @@ def _make_import_group( break else: raise ImportUniquenessError(f'New import Group has existing label {label!r} and re-labelling failed') - dummy_orm = orm.ImportGroup(label) + dummy_orm = orm.ImportGroup(label, backend=backend_to) row = { 'label': label, 'description': 'Group generated by archive import', diff --git a/tests/storage/sqlite/test_archive.py b/tests/storage/sqlite/test_archive.py new file mode 100644 index 0000000000..b370b67966 --- /dev/null +++ b/tests/storage/sqlite/test_archive.py @@ -0,0 +1,47 @@ +# -*- coding: utf-8 -*- +"""Test export and import of AiiDA archives to/from a temporary profile.""" +from pathlib import Path + +from aiida import orm +from aiida.storage.sqlite_temp import SqliteTempBackend +from aiida.tools.archive import create_archive, import_archive + + +def test_basic(tmp_path): + """Test the creation of an archive and re-import.""" + filename = Path(tmp_path / 'export.aiida') + + # generate a temporary backend + profile1 = SqliteTempBackend.create_profile(repo_path=str(tmp_path / 'repo1')) + backend1 = SqliteTempBackend(profile1) + + # add simple node + dict_data = {'key1': 'value1'} + node = orm.Dict(dict_data, backend=backend1).store() + # add a comment to the node + node.base.comments.add('test comment', backend1.default_user) + # add node with repository data + path = Path(tmp_path / 'test.txt') + text_data = 'test' + path.write_text(text_data, encoding='utf-8') + orm.SinglefileData(str(path), backend=backend1).store() + + # export to archive + create_archive(None, backend=backend1, filename=filename) + + # create a new temporary backend and import + profile2 = SqliteTempBackend.create_profile(repo_path=str(tmp_path / 'repo2')) + backend2 = SqliteTempBackend(profile2) + import_archive(filename, backend=backend2) + + # check that the nodes are there + assert orm.QueryBuilder(backend=backend2).append(orm.Data).count() == 2 + + # check that we can retrieve the attributes and comment data + node = orm.QueryBuilder(backend=backend2).append(orm.Dict).first(flat=True) + assert node.get_dict() == dict_data + assert len(node.base.comments.all()) == 1 + + # check that we can retrieve the repository data + node = 
orm.QueryBuilder(backend=backend2).append(orm.SinglefileData).first(flat=True) + assert node.get_content() == text_data
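
A minimal usage sketch of the content-addressing behaviour introduced by ``SandboxShaRepositoryBackend`` above, assuming the patched ``aiida.storage.sqlite_temp.backend`` module is importable; ``put_object_from_filelike``, ``list_objects`` and ``erase`` are the standard aiida repository-backend methods, so identical byte streams should collapse to a single sha256-keyed object in the sandbox:

import io

from aiida.storage.sqlite_temp.backend import SandboxShaRepositoryBackend

repo = SandboxShaRepositoryBackend()
key_a = repo.put_object_from_filelike(io.BytesIO(b'hello world'))  # key is the sha256 of the payload
key_b = repo.put_object_from_filelike(io.BytesIO(b'hello world'))  # same payload, same key
assert key_a == key_b
assert len(list(repo.list_objects())) == 1  # deduplicated: the bytes are stored only once
repo.erase()  # clean up the temporary sandbox folder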