Skip to content

Commit

Permalink
Merge branch 'release/2.2.1' into main
Browse files Browse the repository at this point in the history
  • Loading branch information
sphuber committed Dec 22, 2022
2 parents d5ebf0e + 9aca3e1 commit 4be23fe
Show file tree
Hide file tree
Showing 17 changed files with 121 additions and 119 deletions.
6 changes: 6 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
# Changelog

## v2.2.1 - 2022-12-22

### Fixes
- Critical bug fix: Revert the changes of PR [#5804](https://github.com/aiidateam/aiida-core/pull/5804) released with v2.2.0, which addressed a bug when mutating nodes during `QueryBuilder.iterall`. Unfortunately, the change caused changes performed by `verdi` commands (as well as changes made in `verdi shell`) to not be persisted to the database. [[#5851]](https://github.com/aiidateam/aiida-core/pull/5851)


## v2.2.0 - 2022-12-13

This feature release comes with a significant feature and a number of improvements and fixes.
Expand Down
2 changes: 1 addition & 1 deletion aiida/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
'For further information please visit http://www.aiida.net/. All rights reserved.'
)
__license__ = 'MIT license, see LICENSE.txt file.'
__version__ = '2.2.0'
__version__ = '2.2.1'
__authors__ = 'The AiiDA team.'
__paper__ = (
'S. P. Huber et al., "AiiDA 1.0, a scalable computational infrastructure for automated reproducible workflows and '
Expand Down
43 changes: 21 additions & 22 deletions aiida/storage/psql_dos/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
from typing import TYPE_CHECKING, Iterator, List, Optional, Sequence, Set, Union

from disk_objectstore import Container
from sqlalchemy.exc import IntegrityError as SqlaIntegrityError
from sqlalchemy.orm import Session, scoped_session, sessionmaker

from aiida.common.exceptions import ClosedStorage, ConfigurationError, IntegrityError
Expand Down Expand Up @@ -233,23 +232,20 @@ def users(self):

@contextmanager
def transaction(self) -> Iterator[Session]:
"""Open a transaction and yield the current session.
"""Open a transaction to be used as a context manager.
If there is an exception within the context then the changes will be rolled back and the state will be as before
entering, otherwise the changes will be committed and the transaction closed. Transactions can be nested.
entering. Transactions can be nested.
"""
session = self.get_session()

try:
if session.in_transaction():
if session.in_transaction():
with session.begin_nested():
yield session
session.commit()
else:
with session.begin():
with session.begin_nested():
yield session
else:
with session.begin():
with session.begin_nested():
yield session
except SqlaIntegrityError as exception:
raise IntegrityError(str(exception)) from exception

@property
def in_transaction(self) -> bool:
Expand Down Expand Up @@ -327,16 +323,19 @@ def delete_nodes_and_connections(self, pks_to_delete: Sequence[int]) -> None: #
from aiida.storage.psql_dos.models.group import DbGroupNode
from aiida.storage.psql_dos.models.node import DbLink, DbNode

with self.transaction() as session:
# Delete the membership of these nodes to groups.
session.query(DbGroupNode).filter(DbGroupNode.dbnode_id.in_(list(pks_to_delete))
).delete(synchronize_session='fetch')
# Delete the links coming out of the nodes marked for deletion.
session.query(DbLink).filter(DbLink.input_id.in_(list(pks_to_delete))).delete(synchronize_session='fetch')
# Delete the links pointing to the nodes marked for deletion.
session.query(DbLink).filter(DbLink.output_id.in_(list(pks_to_delete))).delete(synchronize_session='fetch')
# Delete the actual nodes
session.query(DbNode).filter(DbNode.id.in_(list(pks_to_delete))).delete(synchronize_session='fetch')
if not self.in_transaction:
raise AssertionError('Cannot delete nodes and links outside a transaction')

session = self.get_session()
# Delete the membership of these nodes to groups.
session.query(DbGroupNode).filter(DbGroupNode.dbnode_id.in_(list(pks_to_delete))
).delete(synchronize_session='fetch')
# Delete the links coming out of the nodes marked for deletion.
session.query(DbLink).filter(DbLink.input_id.in_(list(pks_to_delete))).delete(synchronize_session='fetch')
# Delete the links pointing to the nodes marked for deletion.
session.query(DbLink).filter(DbLink.output_id.in_(list(pks_to_delete))).delete(synchronize_session='fetch')
# Delete the actual nodes
session.query(DbNode).filter(DbNode.id.in_(list(pks_to_delete))).delete(synchronize_session='fetch')

def get_backend_entity(self, model: base.Base) -> BackendEntity:
"""
Expand Down
34 changes: 20 additions & 14 deletions aiida/storage/psql_dos/orm/groups.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,9 @@
from aiida.orm.implementation.groups import BackendGroup, BackendGroupCollection
from aiida.storage.psql_dos.models.group import DbGroup, DbGroupNode

from . import entities, users
from . import entities, users, utils
from .extras_mixin import ExtrasMixin
from .nodes import SqlaNode
from .utils import ModelWrapper

_LOGGER = logging.getLogger(__name__)

Expand Down Expand Up @@ -47,7 +46,7 @@ def __init__(self, backend, label, user, description='', type_string=''):
super().__init__(backend)

dbgroup = self.MODEL_CLASS(label=label, description=description, user=user.bare_model, type_string=type_string)
self._model = ModelWrapper(dbgroup, backend)
self._model = utils.ModelWrapper(dbgroup, backend)

@property
def label(self):
Expand Down Expand Up @@ -116,9 +115,8 @@ def is_stored(self):
return self.pk is not None

def store(self):
with self.backend.transaction():
self.model.save()
return self
self.model.save()
return self

def count(self):
"""Return the number of entities in this group.
Expand All @@ -130,9 +128,10 @@ def count(self):

def clear(self):
"""Remove all the nodes from this group."""
with self.backend.transaction():
# Note we have to call `bare_model` to circumvent flushing data to the database
self.bare_model.dbnodes = []
session = self.backend.get_session()
# Note we have to call `bare_model` to circumvent flushing data to the database
self.bare_model.dbnodes = []
session.commit()

@property
def nodes(self):
Expand Down Expand Up @@ -193,7 +192,7 @@ def check_node(given_node):
if not given_node.is_stored:
raise ValueError('At least one of the provided nodes is unstored, stopping...')

with self.backend.transaction() as session:
with utils.disable_expire_on_commit(self.backend.get_session()) as session:
if not skip_orm:
# Get dbnodes here ONCE, otherwise each call to dbnodes will re-read the current value in the database
dbnodes = self.model.dbnodes
Expand All @@ -220,6 +219,9 @@ def check_node(given_node):
ins = insert(table).values(ins_dict)
session.execute(ins.on_conflict_do_nothing(index_elements=['dbnode_id', 'dbgroup_id']))

# Commit everything as up till now we've just flushed
session.commit()

def remove_nodes(self, nodes, **kwargs):
"""Remove a node or a set of nodes from the group.
Expand Down Expand Up @@ -247,7 +249,7 @@ def check_node(node):

list_nodes = []

with self.backend.transaction() as session:
with utils.disable_expire_on_commit(self.backend.get_session()) as session:
if not skip_orm:
for node in nodes:
check_node(node)
Expand All @@ -266,13 +268,17 @@ def check_node(node):
statement = table.delete().where(clause)
session.execute(statement)

session.commit()


class SqlaGroupCollection(BackendGroupCollection):
"""The SLQA collection of groups"""

ENTITY_CLASS = SqlaGroup

def delete(self, id): # pylint: disable=redefined-builtin
with self.backend.transaction() as session:
row = session.get(self.ENTITY_CLASS.MODEL_CLASS, id)
session.delete(row)
session = self.backend.get_session()

row = session.get(self.ENTITY_CLASS.MODEL_CLASS, id)
session.delete(row)
session.commit()
53 changes: 34 additions & 19 deletions aiida/storage/psql_dos/orm/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,8 @@ def user(self, user):
self.model.user = user.bare_model

def add_incoming(self, source, link_type, link_label):
session = self.backend.get_session()

type_check(source, self.__class__)

if not self.is_stored:
Expand All @@ -192,33 +194,43 @@ def add_incoming(self, source, link_type, link_label):
raise exceptions.ModificationNotAllowed('source node has to be stored when adding a link from it')

self._add_link(source, link_type, link_label)
session.commit()

def _add_link(self, source, link_type, link_label):
"""Add a single link"""
with self.backend.transaction() as session:
try:
session = self.backend.get_session()

try:
with session.begin_nested():
link = self.LINK_CLASS(input_id=source.pk, output_id=self.pk, label=link_label, type=link_type.value)
session.add(link)
except SQLAlchemyError as exception:
raise exceptions.UniquenessError(f'failed to create the link: {exception}') from exception
except SQLAlchemyError as exception:
raise exceptions.UniquenessError(f'failed to create the link: {exception}') from exception

def clean_values(self):
self.model.attributes = clean_value(self.model.attributes)
self.model.extras = clean_value(self.model.extras)

def store(self, links=None, with_transaction=True, clean=True): # pylint: disable=arguments-differ,unused-argument
with self.backend.transaction():
def store(self, links=None, with_transaction=True, clean=True): # pylint: disable=arguments-differ
session = self.backend.get_session()

if clean:
self.clean_values()

if clean:
self.clean_values()
session.add(self.model)

self.model.save()
if links:
for link_triple in links:
self._add_link(*link_triple)

if links:
for link_triple in links:
self._add_link(*link_triple)
if with_transaction:
try:
session.commit()
except SQLAlchemyError:
session.rollback()
raise

return self
return self

@property
def attributes(self):
Expand Down Expand Up @@ -301,6 +313,7 @@ class SqlaNodeCollection(BackendNodeCollection):

def get(self, pk):
session = self.backend.get_session()

try:
return self.ENTITY_CLASS.from_dbmodel(
session.query(self.ENTITY_CLASS.MODEL_CLASS).filter_by(id=pk).one(), self.backend
Expand All @@ -309,9 +322,11 @@ def get(self, pk):
raise exceptions.NotExistent(f"Node with pk '{pk}' not found") from NoResultFound

def delete(self, pk):
with self.backend.transaction() as session:
try:
row = session.query(self.ENTITY_CLASS.MODEL_CLASS).filter_by(id=pk).one()
session.delete(row)
except NoResultFound:
raise exceptions.NotExistent(f"Node with pk '{pk}' not found") from NoResultFound
session = self.backend.get_session()

try:
row = session.query(self.ENTITY_CLASS.MODEL_CLASS).filter_by(id=pk).one()
session.delete(row)
session.commit()
except NoResultFound:
raise exceptions.NotExistent(f"Node with pk '{pk}' not found") from NoResultFound
16 changes: 16 additions & 0 deletions aiida/storage/psql_dos/orm/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
# For further information please visit http://www.aiida.net #
###########################################################################
"""Utilities for the implementation of the SqlAlchemy backend."""
import contextlib
from typing import TYPE_CHECKING

# pylint: disable=import-error,no-name-in-module
Expand Down Expand Up @@ -166,3 +167,18 @@ def _in_transaction(self):
:return: boolean, True if currently in open transaction, False otherwise.
"""
return self.session.in_nested_transaction()


@contextlib.contextmanager
def disable_expire_on_commit(session):
    """Temporarily switch off ``expire_on_commit`` for the given session.

    The previous value of the flag is restored when the context exits, even if
    the body raised an exception.

    :param session: The SQLA session
    :type session: :class:`sqlalchemy.orm.session.Session`
    """
    # Remember the current setting and disable expiration in one step.
    previous, session.expire_on_commit = session.expire_on_commit, False
    try:
        yield session
    finally:
        # Always restore the caller's original configuration.
        session.expire_on_commit = previous
18 changes: 7 additions & 11 deletions aiida/storage/sqlite_temp/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,9 @@
from functools import cached_property
from typing import Any, Iterator, Sequence

from sqlalchemy.exc import IntegrityError as SqlaIntegrityError
from sqlalchemy.orm import Session

from aiida.common.exceptions import ClosedStorage, IntegrityError
from aiida.common.exceptions import ClosedStorage
from aiida.manage import Profile, get_config_option
from aiida.orm.entities import EntityTypes
from aiida.orm.implementation import BackendEntity, StorageBackend
Expand Down Expand Up @@ -145,17 +144,14 @@ def transaction(self) -> Iterator[Session]:
entering. Transactions can be nested.
"""
session = self.get_session()

try:
if session.in_transaction():
if session.in_transaction():
with session.begin_nested():
yield session
session.commit()
else:
with session.begin():
with session.begin_nested():
yield session
else:
with session.begin():
with session.begin_nested():
yield session
except SqlaIntegrityError as exception:
raise IntegrityError(str(exception)) from exception

def _clear(self) -> None:
raise NotImplementedError
Expand Down
3 changes: 2 additions & 1 deletion aiida/tools/graph/deletions.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,8 @@ def _missing_callback(_pks: Iterable[int]):
return (pks_set_to_delete, True)

DELETE_LOGGER.report('Starting node deletion...')
backend.delete_nodes_and_connections(pks_set_to_delete)
with backend.transaction():
backend.delete_nodes_and_connections(pks_set_to_delete)
DELETE_LOGGER.report('Deletion of nodes completed.')

return (pks_set_to_delete, True)
Expand Down
7 changes: 6 additions & 1 deletion tests/orm/implementation/test_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,12 @@ def test_delete_nodes_and_connections(self):
assert len(calc_node.base.links.get_outgoing().all()) == 1
assert len(group.nodes) == 1

self.backend.delete_nodes_and_connections([node_pk])
# cannot call outside a transaction
with pytest.raises(AssertionError):
self.backend.delete_nodes_and_connections([node_pk])

with self.backend.transaction():
self.backend.delete_nodes_and_connections([node_pk])

# checks after deletion
with pytest.raises(exceptions.NotExistent):
Expand Down
3 changes: 2 additions & 1 deletion tests/orm/nodes/test_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -860,7 +860,8 @@ def test_delete_through_backend(self):
assert len(Log.collection.get_logs_for(data_two)) == 1
assert Log.collection.get_logs_for(data_two)[0].pk == log_two.pk

backend.delete_nodes_and_connections([data_two.pk])
with backend.transaction():
backend.delete_nodes_and_connections([data_two.pk])

assert len(Log.collection.get_logs_for(data_one)) == 1
assert Log.collection.get_logs_for(data_one)[0].pk == log_one.pk
Expand Down
1 change: 1 addition & 0 deletions tests/orm/test_querybuilder.py
Original file line number Diff line number Diff line change
Expand Up @@ -1495,6 +1495,7 @@ def test_iterall_with_mutation(self):
assert orm.load_node(pk).get_extra('key') == 'value'

@pytest.mark.usefixtures('aiida_profile_clean')
@pytest.mark.skip('enable when https://github.com/aiidateam/aiida-core/issues/5802 is fixed')
def test_iterall_with_store(self):
"""Test that nodes can be stored while being iterated using ``QueryBuilder.iterall``.
Expand Down
2 changes: 1 addition & 1 deletion tests/storage/psql_dos/test_nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ def test_multiple_node_creation(self):
from aiida.storage.psql_dos.models.node import DbNode

# Get the automatic user
dbuser = self.backend.users.create(get_new_uuid()).store().bare_model
dbuser = self.backend.users.create('[email protected]').store().bare_model
# Create a new node but don't add it to the session
node_uuid = get_new_uuid()
DbNode(user=dbuser, uuid=node_uuid, node_type=None)
Expand Down
Loading

0 comments on commit 4be23fe

Please sign in to comment.