Skip to content

Commit

Permalink
global: apply new SQLAlchemy rules
Browse files Browse the repository at this point in the history
* change from model Model.query to db.session.query(Model)

* without an update in pytest-invenio and without this commit it
  produces following error: sqlalchemy.exc.InvalidRequestError: Can't
  operate on closed transaction inside context manager.

* with the updated db fixture in pytest-inveniothis this commit fixes
  following TypeError: 'Session' object is not callable
  • Loading branch information
utnapischtim committed Oct 1, 2024
1 parent 2a234ab commit 3cd34db
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 19 deletions.
31 changes: 17 additions & 14 deletions invenio_banners/records/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,15 +71,15 @@ def create(cls, data):
def update(cls, data, id):
"""Update an existing banner."""
with db.session.begin_nested():
cls.query.filter_by(id=id).update(data)
db.session.query(cls).filter_by(id=id).update(data)

db.session.commit()

@classmethod
def get(cls, id):
"""Get banner by its id."""
try:
return cls.query.filter_by(id=id).one()
return db.session.query(cls).filter_by(id=id).one()
except NoResultFound:
raise BannerNotExistsError(id)

Expand All @@ -96,7 +96,8 @@ def get_active(cls, url_path):
now = datetime.utcnow()

query = (
cls.query.filter(cls.active.is_(True))
db.session.query(cls)
.filter(cls.active.is_(True))
.filter(cls.start_datetime <= now)
.filter((cls.end_datetime.is_(None)) | (now <= cls.end_datetime))
)
Expand All @@ -114,16 +115,17 @@ def get_active(cls, url_path):
@classmethod
def search(cls, search_params, filters):
"""Filter banners accordingly to query params."""
banners = (
BannerModel.query.filter(or_(False, *filters))
.order_by(
search_params["sort_direction"](text(",".join(search_params["sort"])))
)
.paginate(
page=search_params["page"],
per_page=search_params["size"],
error_out=False,
)
if filters == []:
filtered = db.session.query(BannerModel).filter()
else:
filtered = db.session.query(BannerModel).filter(or_(*filters))

banners = filtered.order_by(
search_params["sort_direction"](text(",".join(search_params["sort"])))
).paginate(
page=search_params["page"],
per_page=search_params["size"],
error_out=False,
)

return banners
Expand All @@ -134,7 +136,8 @@ def disable_expired(cls):
now = datetime.utcnow()

query = (
cls.query.filter(cls.active.is_(True))
db.session.query(cls)
.filter(cls.active.is_(True))
.filter(cls.end_datetime.isnot(None))
.filter(cls.end_datetime < now)
)
Expand Down
3 changes: 2 additions & 1 deletion tests/resources/test_resources.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from datetime import date, datetime

import pytest
from invenio_db import db
from invenio_records_resources.services.errors import PermissionDeniedError

from invenio_banners.records import BannerModel
Expand Down Expand Up @@ -193,7 +194,7 @@ def test_delete_banner(client, admin, headers):
_delete_banner(client, banner.id, headers, 204)

# check that it's not present in db
assert BannerModel.query.filter_by(id=banner.id).one_or_none() is None
assert db.session.query(BannerModel).filter_by(id=banner.id).one_or_none() is None


def test_delete_is_forbidden(client, user, headers):
Expand Down
10 changes: 6 additions & 4 deletions tests/services/test_services.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from datetime import datetime, timedelta

import pytest
from invenio_db import db
from invenio_records_resources.services.errors import PermissionDeniedError

from invenio_banners.proxies import current_banners_service as service
Expand Down Expand Up @@ -131,7 +132,7 @@ def test_delete_banner(app, superuser_identity):
service.delete(superuser_identity, banner.id)

# check that it's not present in db
assert BannerModel.query.filter_by(id=banner.id).one_or_none() is None
assert db.session.query(BannerModel).filter_by(id=banner.id).one_or_none() is None


def test_delete_is_forbidden(app, simple_user_identity):
Expand Down Expand Up @@ -211,11 +212,12 @@ def test_disable_expired_banners(app, superuser_identity):
BannerModel.create(banners["expired"])
BannerModel.create(banners["active"])

assert BannerModel.query.filter(BannerModel.active.is_(True)).count() == 2

assert (
db.session.query(BannerModel).filter(BannerModel.active.is_(True)).count() == 2
)
service.disable_expired(superuser_identity)

_banners = BannerModel.query.filter(BannerModel.active.is_(True)).all()
_banners = db.session.query(BannerModel).filter(BannerModel.active.is_(True)).all()

assert len(_banners) == 1
assert _banners[0].message == "active"
Expand Down

0 comments on commit 3cd34db

Please sign in to comment.