SqliteTempBackend: Allow reading and writing to archives
sphuber committed Sep 17, 2023
1 parent 7a3f108 commit 6cc61e2
Showing 3 changed files with 178 additions and 12 deletions.
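In short, this commit lets a temporary profile be exported to and populated from standard AiiDA archives. A minimal sketch of the new entry points (paths are placeholders; the full round trip, including re-import into a second backend, is exercised by the new test at the bottom of this commit):

from pathlib import Path

from aiida import orm
from aiida.storage.sqlite_temp import SqliteTempBackend
from aiida.tools.archive import create_archive

# Temporary profile whose sandbox file repository is rooted at an explicit path.
profile = SqliteTempBackend.create_profile(repo_path='/tmp/aiida_temp_repo')
backend = SqliteTempBackend(profile)

# Store some data against that backend, then export the whole profile.
orm.Dict({'key': 'value'}, backend=backend).store()
create_archive(None, backend=backend, filename=Path('/tmp/export.aiida'))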
139 changes: 129 additions & 10 deletions aiida/storage/sqlite_temp/backend.py
@@ -10,14 +10,19 @@
"""Definition of the ``SqliteTempBackend`` backend."""
from __future__ import annotations

from contextlib import contextmanager
from contextlib import contextmanager, nullcontext
import functools
from functools import cached_property
from typing import Any, Iterator, Sequence
import hashlib
import os
from pathlib import Path
import shutil
from typing import Any, BinaryIO, Iterator, Sequence

from sqlalchemy.orm import Session

from aiida.common.exceptions import ClosedStorage
from aiida.manage import Profile, get_config_option
from aiida.common.exceptions import ClosedStorage, IntegrityError
from aiida.manage import Profile
from aiida.orm.entities import EntityTypes
from aiida.orm.implementation import BackendEntity, StorageBackend
from aiida.repository.backend.sandbox import SandboxRepositoryBackend
@@ -42,7 +47,8 @@ def create_profile(
name: str = 'temp',
default_user_email='[email protected]',
options: dict | None = None,
debug: bool = False
debug: bool = False,
repo_path: str | Path | None = None,
) -> Profile:
"""Create a new profile instance for this backend, from the path to the zip file."""
return Profile(
@@ -52,6 +58,7 @@
'backend': 'core.sqlite_temp',
'config': {
'debug': debug,
'repo_path': repo_path,
}
},
'process_control': {
@@ -81,7 +88,7 @@ def migrate(cls, profile: Profile):
def __init__(self, profile: Profile):
super().__init__(profile)
self._session: Session | None = None
self._repo: SandboxRepositoryBackend | None = None
self._repo: SandboxShaRepositoryBackend | None = None
self._globals: dict[str, tuple[Any, str | None]] = {}
self._closed = False
self.get_session() # load the database on initialization
@@ -124,12 +131,13 @@ def get_session(self) -> Session:
self._session.commit()
return self._session

def get_repository(self) -> SandboxRepositoryBackend:
def get_repository(self) -> SandboxShaRepositoryBackend:
if self._closed:
raise ClosedStorage(str(self))
if self._repo is None:
# to-do this does not seem to be removing the folder on garbage collection?
self._repo = SandboxRepositoryBackend(filepath=get_config_option('storage.sandbox') or None)
repo_path = self.profile.storage_config.get('repo_path')
self._repo = SandboxShaRepositoryBackend(filepath=Path(repo_path) if repo_path else None)
return self._repo

@property
@@ -199,11 +207,122 @@ def get_info(self, detailed: bool = False) -> dict:
# results['repository'] = self.get_repository().get_info(detailed)
return results

@staticmethod
@functools.lru_cache(maxsize=18)
def _get_mapper_from_entity(entity_type: EntityTypes, with_pk: bool):
"""Return the Sqlalchemy mapper and fields corresponding to the given entity.
:param with_pk: if True, the fields returned will include the primary key
"""
from sqlalchemy import inspect

from aiida.storage.sqlite_zip.models import (
DbAuthInfo,
DbComment,
DbComputer,
DbGroup,
DbGroupNodes,
DbLink,
DbLog,
DbNode,
DbUser,
)

model = {
EntityTypes.AUTHINFO: DbAuthInfo,
EntityTypes.COMMENT: DbComment,
EntityTypes.COMPUTER: DbComputer,
EntityTypes.GROUP: DbGroup,
EntityTypes.LOG: DbLog,
EntityTypes.NODE: DbNode,
EntityTypes.USER: DbUser,
EntityTypes.LINK: DbLink,
EntityTypes.GROUP_NODE: DbGroupNodes,
}[entity_type]
mapper = inspect(model).mapper
keys = {key for key, col in mapper.c.items() if with_pk or col not in mapper.primary_key}
return mapper, keys

def bulk_insert(self, entity_type: EntityTypes, rows: list[dict], allow_defaults: bool = False) -> list[int]:
raise NotImplementedError
mapper, keys = self._get_mapper_from_entity(entity_type, False)
if not rows:
return []
if entity_type in (EntityTypes.COMPUTER, EntityTypes.LOG, EntityTypes.AUTHINFO):
for row in rows:
row['_metadata'] = row.pop('metadata')
if allow_defaults:
for row in rows:
if not keys.issuperset(row):
raise IntegrityError(f'Incorrect fields given for {entity_type}: {set(row)} not subset of {keys}')
else:
for row in rows:
if set(row) != keys:
raise IntegrityError(f'Incorrect fields given for {entity_type}: {set(row)} != {keys}')
session = self.get_session()
with (nullcontext() if self.in_transaction else self.transaction()):
session.bulk_insert_mappings(mapper, rows, render_nulls=True, return_defaults=True)
return [row['id'] for row in rows]

def bulk_update(self, entity_type: EntityTypes, rows: list[dict]) -> None:
raise NotImplementedError
mapper, keys = self._get_mapper_from_entity(entity_type, True)
if not rows:
return None
for row in rows:
if 'id' not in row:
raise IntegrityError(f"'id' field not given for {entity_type}: {set(row)}")
if not keys.issuperset(row):
raise IntegrityError(f'Incorrect fields given for {entity_type}: {set(row)} not subset of {keys}')
session = self.get_session()
with (nullcontext() if self.in_transaction else self.transaction()):
session.bulk_update_mappings(mapper, rows)

def delete_nodes_and_connections(self, pks_to_delete: Sequence[int]):
raise NotImplementedError


class SandboxShaRepositoryBackend(SandboxRepositoryBackend):
"""A sandbox repository backend that uses the sha256 of the file as the key.
This allows for compatibility with the archive format (i.e. `SqliteZipBackend`).
Which allows for temporary profiles to be exported and imported.
"""

@property
def key_format(self) -> str | None:
return 'sha256'

def get_object_hash(self, key: str) -> str:
return key

def _put_object_from_filelike(self, handle: BinaryIO) -> str:
"""Store the byte contents of a file in the repository.
:param handle: filelike object with the byte content to be stored.
:return: the generated fully qualified identifier for the object within the repository.
:raises TypeError: if the handle is not a byte stream.
"""
# we first compute the hash of the file contents
hsh = hashlib.sha256()
position = handle.tell()
while True:
buf = handle.read(1024 * 1024)
if not buf:
break
hsh.update(buf)
key = hsh.hexdigest()

filepath = os.path.join(self.sandbox.abspath, key)
if not os.path.exists(filepath):
# if a file with this hash does not already exist
# then we reset the file pointer and copy the contents
handle.seek(position)
with open(filepath, 'wb') as target:
shutil.copyfileobj(handle, target)

return key

def get_info(self, detailed: bool = False, **kwargs) -> dict:
return {'objects': {'count': len(list(self.list_objects()))}}

def maintain(self, dry_run: bool = False, live: bool = True, **kwargs) -> None:
pass
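For context on the new SandboxShaRepositoryBackend: repository objects are keyed by the sha256 digest of their contents, which is what makes the sandbox translatable to the archive format, since identical payloads collapse to a single object and the key doubles as the content hash. A standalone sketch of that keying scheme, not using AiiDA itself (the put_object helper is purely illustrative):

import hashlib
import tempfile
from pathlib import Path

def put_object(sandbox: Path, content: bytes) -> str:
    """Store ``content`` under the sha256 of its bytes, writing it at most once."""
    key = hashlib.sha256(content).hexdigest()
    target = sandbox / key
    if not target.exists():  # identical content is deduplicated automatically
        target.write_bytes(content)
    return key

sandbox = Path(tempfile.mkdtemp())
key1 = put_object(sandbox, b'hello world')
key2 = put_object(sandbox, b'hello world')
assert key1 == key2                       # same content, same key
assert len(list(sandbox.iterdir())) == 1  # stored only once
assert key1 == hashlib.sha256(b'hello world').hexdigest()  # the key is the hash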
4 changes: 2 additions & 2 deletions aiida/tools/archive/imports.py
@@ -778,7 +778,7 @@ def _transform(row):
def _transform(row):
# to-do this is probably not the most efficient way to do this
uuid, new_mtime, new_comment = row
cmt = orm.Comment.collection.get(uuid=uuid)
cmt = orm.comments.CommentCollection(orm.Comment, backend).get(uuid=uuid)
if cmt.mtime < new_mtime:
cmt.set_mtime(new_mtime)
cmt.set_content(new_comment)
@@ -1086,7 +1086,7 @@ def _make_import_group(
break
else:
raise ImportUniquenessError(f'New import Group has existing label {label!r} and re-labelling failed')
dummy_orm = orm.ImportGroup(label)
dummy_orm = orm.ImportGroup(label, backend=backend_to)
row = {
'label': label,
'description': 'Group generated by archive import',
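Both hunks above follow the same pattern: operations that previously went through the default profile's storage (orm.Comment.collection, orm.ImportGroup(label)) are now bound to the explicit target backend, so that an import into a non-default storage such as the temporary backend resolves and creates entities in the right place. A minimal sketch of the pattern, assuming a temporary profile as in the new test below:

from aiida import orm
from aiida.storage.sqlite_temp import SqliteTempBackend

backend = SqliteTempBackend(SqliteTempBackend.create_profile())

# Entities are created against an explicit backend rather than the loaded profile ...
group = orm.Group(label='imported', user=backend.default_user, backend=backend).store()

# ... and queries must likewise be bound to the backend they should read from.
assert orm.QueryBuilder(backend=backend).append(orm.Group).count() == 1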
47 changes: 47 additions & 0 deletions tests/storage/sqlite/test_archive.py
@@ -0,0 +1,47 @@
# -*- coding: utf-8 -*-
"""Test export and import of AiiDA archives to/from a temporary profile."""
from pathlib import Path

from aiida import orm
from aiida.storage.sqlite_temp import SqliteTempBackend
from aiida.tools.archive import create_archive, import_archive


def test_basic(tmp_path):
"""Test the creation of an archive and re-import."""
filename = Path(tmp_path / 'export.aiida')

# generate a temporary backend
profile1 = SqliteTempBackend.create_profile(repo_path=str(tmp_path / 'repo1'))
backend1 = SqliteTempBackend(profile1)

# add simple node
dict_data = {'key1': 'value1'}
node = orm.Dict(dict_data, backend=backend1).store()
# add a comment to the node
node.base.comments.add('test comment', backend1.default_user)
# add node with repository data
path = Path(tmp_path / 'test.txt')
text_data = 'test'
path.write_text(text_data, encoding='utf-8')
orm.SinglefileData(str(path), backend=backend1).store()

# export to archive
create_archive(None, backend=backend1, filename=filename)

# create a new temporary backend and import
profile2 = SqliteTempBackend.create_profile(repo_path=str(tmp_path / 'repo2'))
backend2 = SqliteTempBackend(profile2)
import_archive(filename, backend=backend2)

# check that the nodes are there
assert orm.QueryBuilder(backend=backend2).append(orm.Data).count() == 2

# check that we can retrieve the attributes and comment data
node = orm.QueryBuilder(backend=backend2).append(orm.Dict).first(flat=True)
assert node.get_dict() == dict_data
assert len(node.base.comments.all()) == 1

# check that we can retrieve the repository data
node = orm.QueryBuilder(backend=backend2).append(orm.SinglefileData).first(flat=True)
assert node.get_content() == text_data
