Skip to content

Commit

Permalink
model: make forward compatible to sqlalchemy >= 2
Browse files Browse the repository at this point in the history
  • Loading branch information
utnapischtim committed Nov 5, 2024
1 parent 037e2ec commit ae541a5
Show file tree
Hide file tree
Showing 8 changed files with 49 additions and 35 deletions.
7 changes: 5 additions & 2 deletions invenio_accounts/datastore.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#
# This file is part of Invenio.
# Copyright (C) 2015-2024 CERN.
# Copyright (C) 2024 Graz University of Technology.
#
# Invenio is free software; you can redistribute it and/or modify it
# under the terms of the MIT License; see LICENSE file for more details.
Expand All @@ -12,6 +13,7 @@

from flask import current_app
from flask_security import SQLAlchemyUserDatastore, user_confirmed
from invenio_db import db
from sqlalchemy.orm import joinedload

from .models import Domain, Role, User
Expand Down Expand Up @@ -108,12 +110,13 @@ def create_role(self, **kwargs):

def find_role_by_id(self, role_id):
"""Fetches roles searching by id."""
return self.role_model.query.filter_by(id=role_id).one_or_none()
return db.session.query(self.role_model).filter_by(id=role_id).one_or_none()

def find_domain(self, domain):
"""Find a domain."""
return (
Domain.query.filter_by(domain=domain)
db.session.query(Domain)
.filter_by(domain=domain)
.options(joinedload(Domain.category_name))
.one_or_none()
)
Expand Down
15 changes: 9 additions & 6 deletions invenio_accounts/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
# This file is part of Invenio.
# Copyright (C) 2015-2024 CERN.
# Copyright (C) 2022 KTH Royal Institute of Technology
# Copyright (C) 2024 Graz University of Technology.
#
# Invenio is free software; you can redistribute it and/or modify it
# under the terms of the MIT License; see LICENSE file for more details.
Expand Down Expand Up @@ -415,12 +416,12 @@ def query_by_expired(cls):
"""Query to select all expired sessions."""
lifetime = current_app.permanent_session_lifetime
expired_moment = datetime.utcnow() - lifetime
return cls.query.filter(cls.created < expired_moment)
return db.session.query(cls).filter(cls.created < expired_moment)

@classmethod
def query_by_user(cls, user_id):
"""Query to select user sessions."""
return cls.query.filter_by(user_id=user_id)
return db.session.query(cls).filter_by(user_id=user_id)

@classmethod
def is_current(cls, sid_s):
Expand All @@ -446,7 +447,9 @@ class UserIdentity(db.Model, Timestamp):
@classmethod
def get_user(cls, method, external_id):
"""Get the user for a given identity."""
identity = cls.query.filter_by(id=external_id, method=method).one_or_none()
identity = (
db.session.query(cls).filter_by(id=external_id, method=method).one_or_none()
)
if identity is not None:
return identity.user
return None
Expand Down Expand Up @@ -474,13 +477,13 @@ def create(cls, user, method, external_id):
def delete_by_external_id(cls, method, external_id):
"""Unlink a user from an external id."""
with db.session.begin_nested():
cls.query.filter_by(id=external_id, method=method).delete()
db.session.query(cls).filter_by(id=external_id, method=method).delete()

@classmethod
def delete_by_user(cls, method, user):
"""Unlink a user from an external id."""
with db.session.begin_nested():
cls.query.filter_by(id_user=user.id, method=method).delete()
db.session.query(cls).filter_by(id_user=user.id, method=method).delete()


class DomainOrg(db.Model):
Expand Down Expand Up @@ -538,7 +541,7 @@ def create(cls, label):
@classmethod
def get(cls, label):
"""Get a domain category."""
return cls.query.filter_by(label=label).one_or_none()
return db.session.query(cls).filter_by(label=label).one_or_none()


class Domain(db.Model, Timestamp):
Expand Down
5 changes: 3 additions & 2 deletions invenio_accounts/sessions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#
# This file is part of Invenio.
# Copyright (C) 2015-2024 CERN.
# Copyright (C) 2024 Graz University of Technology.
#
# Invenio is free software; you can redistribute it and/or modify it
# under the terms of the MIT License; see LICENSE file for more details.
Expand Down Expand Up @@ -134,7 +135,7 @@ def delete_session(sid_s):
# Find and remove the corresponding SessionActivity entry
if request and "_impersonator_id" not in session:
with db.session.begin_nested():
SessionActivity.query.filter_by(sid_s=sid_s).delete()
db.session.query(SessionActivity).filter_by(sid_s=sid_s).delete()
return 1


Expand All @@ -148,7 +149,7 @@ def delete_user_sessions(user):
for s in user.active_sessions:
_sessionstore.delete(s.sid_s)

SessionActivity.query.filter_by(user=user).delete()
db.session.query(SessionActivity).filter_by(user=user).delete()

return True

Expand Down
5 changes: 3 additions & 2 deletions invenio_accounts/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#
# This file is part of Invenio.
# Copyright (C) 2015-2018 CERN.
# Copyright (C) 2024 Graz University of Technology.
#
# Invenio is free software; you can redistribute it and/or modify it
# under the terms of the MIT License; see LICENSE file for more details.
Expand Down Expand Up @@ -64,12 +65,12 @@ def delete_ips():
datetime.utcnow() - current_app.config["ACCOUNTS_RETENTION_PERIOD"]
)

LoginInformation.query.filter(
db.session.query(LoginInformation).filter(
LoginInformation.last_login_ip.isnot(None),
LoginInformation.last_login_at < expiration_date,
).update({LoginInformation.last_login_ip: None})

LoginInformation.query.filter(
db.session.query(LoginInformation).filter(
LoginInformation.current_login_ip.isnot(None),
LoginInformation.current_login_at < expiration_date,
).update({LoginInformation.current_login_ip: None})
Expand Down
7 changes: 4 additions & 3 deletions invenio_accounts/views/security.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#
# This file is part of Invenio.
# Copyright (C) 2017-2018 CERN.
# Copyright (C) 2024 Graz University of Technology.
#
# Invenio is free software; you can redistribute it and/or modify it
# under the terms of the MIT License; see LICENSE file for more details.
Expand Down Expand Up @@ -48,9 +49,9 @@ def revoke_session():

sid_s = form.data["sid_s"]
if (
SessionActivity.query.filter_by(
user_id=current_user.get_id(), sid_s=sid_s
).count()
db.session.query(SessionActivity)
.filter_by(user_id=current_user.get_id(), sid_s=sid_s)
.count()
== 1
):
delete_session(sid_s=sid_s)
Expand Down
12 changes: 7 additions & 5 deletions tests/test_invenio_accounts.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,9 @@ def test_datastore_usercreate(app):
ds.commit()
u2 = ds.find_user(email="[email protected]")
assert u1 == u2
assert 1 == User.query.filter_by(email="[email protected]").count()
assert (
1 == db.session.query(User).filter_by(email="[email protected]").count()
)


def test_datastore_rolecreate(app):
Expand All @@ -162,7 +164,7 @@ def test_datastore_rolecreate(app):
ds.commit()
r2 = ds.find_role("superuser")
assert r1 == r2
assert 1 == Role.query.filter_by(name="superuser").count()
assert 1 == db.session.query(Role).filter_by(name="superuser").count()


def test_datastore_update_role(app):
Expand All @@ -173,7 +175,7 @@ def test_datastore_update_role(app):
ds.commit()
r2 = ds.find_role("superuser")
assert r1 == r2
assert 1 == Role.query.filter_by(name="superuser").count()
assert 1 == db.session.query(Role).filter_by(name="superuser").count()
assert r2.is_managed is True

r1 = ds.update_role(
Expand All @@ -186,8 +188,8 @@ def test_datastore_update_role(app):
assert r1 == r2
assert r2.description == "updated description"
assert r2.is_managed is False
assert 1 == Role.query.filter_by(name="megauser").count()
assert 0 == Role.query.filter_by(name="superuser").count()
assert 1 == db.session.query(Role).filter_by(name="megauser").count()
assert 0 == db.session.query(Role).filter_by(name="superuser").count()


def test_datastore_assignrole(app):
Expand Down
16 changes: 9 additions & 7 deletions tests/test_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#
# This file is part of Invenio.
# Copyright (C) 2015-2018 CERN.
# Copyright (C) 2024 Graz University of Technology.
#
# Invenio is free software; you can redistribute it and/or modify it
# under the terms of the MIT License; see LICENSE file for more details.
Expand All @@ -16,6 +17,7 @@
from flask_login import login_required
from flask_mail import Message
from flask_security import url_for_security
from invenio_db import db

from invenio_accounts.models import SessionActivity, User
from invenio_accounts.tasks import clean_session_table, delete_ips, send_security_email
Expand Down Expand Up @@ -79,7 +81,7 @@ def test():
password=user1.password_plaintext,
),
)
assert len(SessionActivity.query.all()) == 1
assert len(db.session.query(SessionActivity).all()) == 1
sleep(15)

with task_app.test_client() as client:
Expand All @@ -90,11 +92,11 @@ def test():
password=user2.password_plaintext,
),
)
assert len(SessionActivity.query.all()) == 2
assert len(db.session.query(SessionActivity).all()) == 2
sleep(10)

clean_session_table.s().apply()
assert len(SessionActivity.query.all()) == 1
assert len(db.session.query(SessionActivity).all()) == 1

protected_url = url_for("test")

Expand All @@ -103,7 +105,7 @@ def test():

sleep(15)
clean_session_table.s().apply()
assert len(SessionActivity.query.all()) == 0
assert len(db.session.query(SessionActivity).all()) == 0

res = client.get(protected_url)
# check if the user is really logout
Expand Down Expand Up @@ -146,14 +148,14 @@ def test_delete_ips(task_app):

delete_ips()

user = User.query.filter(User.id == user1.id).one()
user = db.session.query(User).filter(User.id == user1.id).one()
assert user.last_login_ip is None
assert user.current_login_ip is None

user = User.query.filter(User.id == user2.id).one()
user = db.session.query(User).filter(User.id == user2.id).one()
assert user.last_login_ip is not None
assert user.current_login_ip is not None

user = User.query.filter(User.id == user3.id).one()
user = db.session.query(User).filter(User.id == user3.id).one()
assert user.last_login_ip is None
assert user.current_login_ip is not None
17 changes: 9 additions & 8 deletions tests/test_views.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#
# This file is part of Invenio.
# Copyright (C) 2015-2024 CERN.
# Copyright (C) 2024 Graz University of Technology.
#
# Invenio is free software; you can redistribute it and/or modify it
# under the terms of the MIT License; see LICENSE file for more details.
Expand Down Expand Up @@ -122,12 +123,12 @@ def test_view_list_sessions(app):
assert res.status_code == 200

# check session for user 1 is not in the list
sessions_1 = SessionActivity.query.filter_by(user_id=user1.id).all()
sessions_1 = db.session.query(SessionActivity).filter_by(user_id=user1.id).all()
assert len(sessions_1) == 1
assert sessions_1[0].sid_s not in res.data.decode("utf-8")

# check session for user 2 is in the list
sessions_2 = SessionActivity.query.filter_by(user_id=user2.id).all()
sessions_2 = db.session.query(SessionActivity).filter_by(user_id=user2.id).all()
assert len(sessions_2) == 1
assert sessions_2[0].sid_s in res.data.decode("utf-8")

Expand All @@ -136,9 +137,9 @@ def test_view_list_sessions(app):
res = client.post(url, data={"sid_s": sessions_1[0].sid_s})
assert res.status_code == 302
assert (
SessionActivity.query.filter_by(
user_id=user1.id, sid_s=sessions_1[0].sid_s
).count()
db.session.query(SessionActivity)
.filter_by(user_id=user1.id, sid_s=sessions_1[0].sid_s)
.count()
== 1
)

Expand All @@ -147,9 +148,9 @@ def test_view_list_sessions(app):
res = client.post(url, data={"sid_s": sessions_2[0].sid_s})
assert res.status_code == 302
assert (
SessionActivity.query.filter_by(
user_id=user1.id, sid_s=sessions_2[0].sid_s
).count()
db.session.query(SessionActivity)
.filter_by(user_id=user1.id, sid_s=sessions_2[0].sid_s)
.count()
== 0
)

Expand Down

0 comments on commit ae541a5

Please sign in to comment.