forked from aiidateam/aiida-core
-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
SqliteTempBackend
: Allow reading and writing to archives
- Loading branch information
Showing
3 changed files
with
178 additions
and
12 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -10,14 +10,19 @@ | |
"""Definition of the ``SqliteTempBackend`` backend.""" | ||
from __future__ import annotations | ||
|
||
from contextlib import contextmanager | ||
from contextlib import contextmanager, nullcontext | ||
import functools | ||
from functools import cached_property | ||
from typing import Any, Iterator, Sequence | ||
import hashlib | ||
import os | ||
from pathlib import Path | ||
import shutil | ||
from typing import Any, BinaryIO, Iterator, Sequence | ||
|
||
from sqlalchemy.orm import Session | ||
|
||
from aiida.common.exceptions import ClosedStorage | ||
from aiida.manage import Profile, get_config_option | ||
from aiida.common.exceptions import ClosedStorage, IntegrityError | ||
from aiida.manage import Profile | ||
from aiida.orm.entities import EntityTypes | ||
from aiida.orm.implementation import BackendEntity, StorageBackend | ||
from aiida.repository.backend.sandbox import SandboxRepositoryBackend | ||
|
@@ -42,7 +47,8 @@ def create_profile( | |
name: str = 'temp', | ||
default_user_email='[email protected]', | ||
options: dict | None = None, | ||
debug: bool = False | ||
debug: bool = False, | ||
repo_path: str | Path | None = None, | ||
) -> Profile: | ||
"""Create a new profile instance for this backend, from the path to the zip file.""" | ||
return Profile( | ||
|
@@ -52,6 +58,7 @@ def create_profile( | |
'backend': 'core.sqlite_temp', | ||
'config': { | ||
'debug': debug, | ||
'repo_path': repo_path, | ||
} | ||
}, | ||
'process_control': { | ||
|
@@ -81,7 +88,7 @@ def migrate(cls, profile: Profile): | |
def __init__(self, profile: Profile): | ||
super().__init__(profile) | ||
self._session: Session | None = None | ||
self._repo: SandboxRepositoryBackend | None = None | ||
self._repo: SandboxShaRepositoryBackend | None = None | ||
self._globals: dict[str, tuple[Any, str | None]] = {} | ||
self._closed = False | ||
self.get_session() # load the database on initialization | ||
|
@@ -124,12 +131,13 @@ def get_session(self) -> Session: | |
self._session.commit() | ||
return self._session | ||
|
||
def get_repository(self) -> SandboxRepositoryBackend: | ||
def get_repository(self) -> SandboxShaRepositoryBackend: | ||
if self._closed: | ||
raise ClosedStorage(str(self)) | ||
if self._repo is None: | ||
# to-do this does not seem to be removing the folder on garbage collection? | ||
self._repo = SandboxRepositoryBackend(filepath=get_config_option('storage.sandbox') or None) | ||
repo_path = self.profile.storage_config.get('repo_path') | ||
self._repo = SandboxShaRepositoryBackend(filepath=Path(repo_path) if repo_path else None) | ||
return self._repo | ||
|
||
@property | ||
|
@@ -199,11 +207,122 @@ def get_info(self, detailed: bool = False) -> dict: | |
# results['repository'] = self.get_repository().get_info(detailed) | ||
return results | ||
|
||
@staticmethod | ||
@functools.lru_cache(maxsize=18) | ||
def _get_mapper_from_entity(entity_type: EntityTypes, with_pk: bool): | ||
"""Return the Sqlalchemy mapper and fields corresponding to the given entity. | ||
:param with_pk: if True, the fields returned will include the primary key | ||
""" | ||
from sqlalchemy import inspect | ||
|
||
from aiida.storage.sqlite_zip.models import ( | ||
DbAuthInfo, | ||
DbComment, | ||
DbComputer, | ||
DbGroup, | ||
DbGroupNodes, | ||
DbLink, | ||
DbLog, | ||
DbNode, | ||
DbUser, | ||
) | ||
|
||
model = { | ||
EntityTypes.AUTHINFO: DbAuthInfo, | ||
EntityTypes.COMMENT: DbComment, | ||
EntityTypes.COMPUTER: DbComputer, | ||
EntityTypes.GROUP: DbGroup, | ||
EntityTypes.LOG: DbLog, | ||
EntityTypes.NODE: DbNode, | ||
EntityTypes.USER: DbUser, | ||
EntityTypes.LINK: DbLink, | ||
EntityTypes.GROUP_NODE: DbGroupNodes, | ||
}[entity_type] | ||
mapper = inspect(model).mapper | ||
keys = {key for key, col in mapper.c.items() if with_pk or col not in mapper.primary_key} | ||
return mapper, keys | ||
|
||
def bulk_insert(self, entity_type: EntityTypes, rows: list[dict], allow_defaults: bool = False) -> list[int]: | ||
raise NotImplementedError | ||
mapper, keys = self._get_mapper_from_entity(entity_type, False) | ||
if not rows: | ||
return [] | ||
if entity_type in (EntityTypes.COMPUTER, EntityTypes.LOG, EntityTypes.AUTHINFO): | ||
for row in rows: | ||
row['_metadata'] = row.pop('metadata') | ||
if allow_defaults: | ||
for row in rows: | ||
if not keys.issuperset(row): | ||
raise IntegrityError(f'Incorrect fields given for {entity_type}: {set(row)} not subset of {keys}') | ||
else: | ||
for row in rows: | ||
if set(row) != keys: | ||
raise IntegrityError(f'Incorrect fields given for {entity_type}: {set(row)} != {keys}') | ||
session = self.get_session() | ||
with (nullcontext() if self.in_transaction else self.transaction()): | ||
session.bulk_insert_mappings(mapper, rows, render_nulls=True, return_defaults=True) | ||
return [row['id'] for row in rows] | ||
|
||
def bulk_update(self, entity_type: EntityTypes, rows: list[dict]) -> None: | ||
raise NotImplementedError | ||
mapper, keys = self._get_mapper_from_entity(entity_type, True) | ||
if not rows: | ||
return None | ||
for row in rows: | ||
if 'id' not in row: | ||
raise IntegrityError(f"'id' field not given for {entity_type}: {set(row)}") | ||
if not keys.issuperset(row): | ||
raise IntegrityError(f'Incorrect fields given for {entity_type}: {set(row)} not subset of {keys}') | ||
session = self.get_session() | ||
with (nullcontext() if self.in_transaction else self.transaction()): | ||
session.bulk_update_mappings(mapper, rows) | ||
|
||
def delete_nodes_and_connections(self, pks_to_delete: Sequence[int]): | ||
raise NotImplementedError | ||
|
||
|
||
class SandboxShaRepositoryBackend(SandboxRepositoryBackend): | ||
"""A sandbox repository backend that uses the sha256 of the file as the key. | ||
This allows for compatibility with the archive format (i.e. `SqliteZipBackend`). | ||
Which allows for temporary profiles to be exported and imported. | ||
""" | ||
|
||
@property | ||
def key_format(self) -> str | None: | ||
return 'sha256' | ||
|
||
def get_object_hash(self, key: str) -> str: | ||
return key | ||
|
||
def _put_object_from_filelike(self, handle: BinaryIO) -> str: | ||
"""Store the byte contents of a file in the repository. | ||
:param handle: filelike object with the byte content to be stored. | ||
:return: the generated fully qualified identifier for the object within the repository. | ||
:raises TypeError: if the handle is not a byte stream. | ||
""" | ||
# we first compute the hash of the file contents | ||
hsh = hashlib.sha256() | ||
position = handle.tell() | ||
while True: | ||
buf = handle.read(1024 * 1024) | ||
if not buf: | ||
break | ||
hsh.update(buf) | ||
key = hsh.hexdigest() | ||
|
||
filepath = os.path.join(self.sandbox.abspath, key) | ||
if not os.path.exists(filepath): | ||
# if a file with this hash does not already exist | ||
# then we reset the file pointer and copy the contents | ||
handle.seek(position) | ||
with open(filepath, 'wb') as target: | ||
shutil.copyfileobj(handle, target) | ||
|
||
return key | ||
|
||
def get_info(self, detailed: bool = False, **kwargs) -> dict: | ||
return {'objects': {'count': len(list(self.list_objects()))}} | ||
|
||
def maintain(self, dry_run: bool = False, live: bool = True, **kwargs) -> None: | ||
pass |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
# -*- coding: utf-8 -*- | ||
"""Test export and import of AiiDA archives to/from a temporary profile.""" | ||
from pathlib import Path | ||
|
||
from aiida import orm | ||
from aiida.storage.sqlite_temp import SqliteTempBackend | ||
from aiida.tools.archive import create_archive, import_archive | ||
|
||
|
||
def test_basic(tmp_path): | ||
"""Test the creation of an archive and re-import.""" | ||
filename = Path(tmp_path / 'export.aiida') | ||
|
||
# generate a temporary backend | ||
profile1 = SqliteTempBackend.create_profile(repo_path=str(tmp_path / 'repo1')) | ||
backend1 = SqliteTempBackend(profile1) | ||
|
||
# add simple node | ||
dict_data = {'key1': 'value1'} | ||
node = orm.Dict(dict_data, backend=backend1).store() | ||
# add a comment to the node | ||
node.base.comments.add('test comment', backend1.default_user) | ||
# add node with repository data | ||
path = Path(tmp_path / 'test.txt') | ||
text_data = 'test' | ||
path.write_text(text_data, encoding='utf-8') | ||
orm.SinglefileData(str(path), backend=backend1).store() | ||
|
||
# export to archive | ||
create_archive(None, backend=backend1, filename=filename) | ||
|
||
# create a new temporary backend and import | ||
profile2 = SqliteTempBackend.create_profile(repo_path=str(tmp_path / 'repo2')) | ||
backend2 = SqliteTempBackend(profile2) | ||
import_archive(filename, backend=backend2) | ||
|
||
# check that the nodes are there | ||
assert orm.QueryBuilder(backend=backend2).append(orm.Data).count() == 2 | ||
|
||
# check that we can retrieve the attributes and comment data | ||
node = orm.QueryBuilder(backend=backend2).append(orm.Dict).first(flat=True) | ||
assert node.get_dict() == dict_data | ||
assert len(node.base.comments.all()) == 1 | ||
|
||
# check that we can retrieve the repository data | ||
node = orm.QueryBuilder(backend=backend2).append(orm.SinglefileData).first(flat=True) | ||
assert node.get_content() == text_data |