diff --git a/invenio_banners/records/models.py b/invenio_banners/records/models.py index 05baa65..c9238d1 100644 --- a/invenio_banners/records/models.py +++ b/invenio_banners/records/models.py @@ -73,8 +73,6 @@ def update(cls, data, id): # returned and this classmethod would be called db.session.query(cls).filter_by(id=id).update(data) - db.session.commit() - @classmethod def get(cls, id): """Get banner by its id.""" @@ -143,5 +141,3 @@ def disable_expired(cls): for old in query.all(): old.active = False - - db.session.commit() diff --git a/invenio_banners/services/service.py b/invenio_banners/services/service.py index ecfa336..3ed1993 100644 --- a/invenio_banners/services/service.py +++ b/invenio_banners/services/service.py @@ -1,6 +1,7 @@ # -*- coding: utf-8 -*- # # Copyright (C) 2022-2023 CERN. +# Copyright (C) 2024 Graz University of Technology. # # Invenio-Banners is free software; you can redistribute it and/or modify it # under the terms of the MIT License; see LICENSE file for more details. @@ -10,6 +11,7 @@ import distutils.util import arrow +from invenio_db.uow import unit_of_work from invenio_records_resources.services import RecordService from invenio_records_resources.services.base import LinksTemplate from invenio_records_resources.services.base.utils import map_search_params @@ -81,7 +83,8 @@ def search(self, identity, params): links_item_tpl=self.links_item_tpl, ) - def create(self, identity, data, raise_errors=True): + @unit_of_work() + def create(self, identity, data, raise_errors=True, uow=None): """Create a banner.""" self.require_permission(identity, "create") @@ -99,17 +102,18 @@ def create(self, identity, data, raise_errors=True): self, identity, banner, links_tpl=self.links_item_tpl, errors=errors ) - def delete(self, identity, id): + @unit_of_work() + def delete(self, identity, id, uow=None): """Delete a banner from database.""" self.require_permission(identity, "delete") banner = self.record_cls.get(id) - self.record_cls.delete(banner) return self.result_item(self, identity, banner, links_tpl=self.links_item_tpl) - def update(self, identity, id, data): + @unit_of_work() + def update(self, identity, id, data, uow=None): """Update a banner.""" self.require_permission(identity, "update") @@ -131,7 +135,8 @@ def update(self, identity, id, data): links_tpl=self.links_item_tpl, ) - def disable_expired(self, identity): + @unit_of_work() + def disable_expired(self, identity, uow=None): """Disable expired banners.""" self.require_permission(identity, "disable") self.record_cls.disable_expired()