Skip to content

Commit

Permalink
Fix SA2.0 ORM usage in galaxy.controllers.page [partial]
Browse files Browse the repository at this point in the history
TODO: `build_initial_query`
  • Loading branch information
jdavcs committed Sep 25, 2023
1 parent a633b9a commit b73b989
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 22 deletions.
33 changes: 32 additions & 1 deletion lib/galaxy/managers/pages.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
)

from sqlalchemy import (
desc,
false,
or_,
select,
Expand All @@ -42,7 +43,12 @@
ready_galaxy_markdown_for_export,
ready_galaxy_markdown_for_import,
)
from galaxy.model import PageRevision
from galaxy.model import (
Page,
PageRevision,
PageUserShareAssociation,
User,
)
from galaxy.model.base import transaction
from galaxy.model.index_filter_util import (
append_user_filter,
Expand Down Expand Up @@ -631,3 +637,28 @@ def placeholderRenderForSave(trans: ProvidesHistoryContext, item_class, item_id,
def get_page_revision(session: Session, page_id: int):
stmt = select(PageRevision).filter_by(page_id=page_id)
return session.scalars(stmt)


def get_shared_pages(session: Session, user: User):
stmt = (
select(PageUserShareAssociation)
.where(PageUserShareAssociation.user == user)
.join(Page)
.where(Page.deleted == false())
.order_by(desc(Page.update_time))
)
return session.scalars(stmt)


def get_page(session: Session, user: User, slug: str):
stmt = _build_page_query(select(Page), user, slug)
return session.scalar(stmt).first()


def page_exists(session: Session, user: User, slug: str) -> bool:
stmt = _build_page_query(select(Page.id), user, slug)
return session.scalar(stmt).first() is not None


def _build_page_query(select_clause, user: User, slug: str):
return select_clause.where(Page.user == user).where(Page.slug == slug).where(Page.deleted == false()).limit(1)
36 changes: 15 additions & 21 deletions lib/galaxy/webapps/galaxy/controllers/page.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from markupsafe import escape
from sqlalchemy import (
desc,
false,
true,
)
Expand All @@ -19,8 +18,14 @@
HistoryManager,
HistorySerializer,
)
from galaxy.managers.pages import PageManager
from galaxy.managers.pages import (
get_page as get_page_,
get_shared_pages,
page_exists,
PageManager,
)
from galaxy.managers.sharable import SlugBuilder
from galaxy.managers.users import get_user_by_username
from galaxy.managers.workflows import WorkflowsManager
from galaxy.model.base import transaction
from galaxy.model.item_attrs import UsesItemRatings
Expand Down Expand Up @@ -377,7 +382,7 @@ def list(self, trans, *args, **kwargs):
ids = util.listify(kwargs["id"])
for id in ids:
if operation == "delete":
item = session.query(model.Page).get(self.decode_id(id))
item = session.get(model.Page, self.decode_id(id))
self.security_check(trans, item, check_ownership=True)
item.deleted = True
with transaction(session):
Expand All @@ -397,14 +402,7 @@ def list_published(self, trans, *args, **kwargs):

def _get_shared(self, trans):
"""Identify shared pages"""
shared_by_others = (
trans.sa_session.query(model.PageUserShareAssociation)
.filter_by(user=trans.get_user())
.join(model.Page.table)
.filter(model.Page.deleted == false())
.order_by(desc(model.Page.update_time))
.all()
)
shared_by_others = get_shared_pages(trans.sa_session, trans.get_user())
return [
{"username": p.page.user.username, "slug": p.page.slug, "title": p.page.title} for p in shared_by_others
]
Expand Down Expand Up @@ -480,7 +478,7 @@ def edit(self, trans, payload=None, **kwd):
return self.message_exception(trans, "No page id received for editing.")
decoded_id = self.decode_id(id)
user = trans.get_user()
p = trans.sa_session.query(model.Page).get(decoded_id)
p = trans.sa_session.get(model.Page, decoded_id)
p = self.security_check(trans, p, check_ownership=True)
if trans.request.method == "GET":
if p.slug is None:
Expand Down Expand Up @@ -515,10 +513,7 @@ def edit(self, trans, payload=None, **kwd):
return self.message_exception(
trans, "Page identifier can only contain lowercase letters, numbers, and dashes (-)."
)
elif (
p_slug != p.slug
and trans.sa_session.query(model.Page).filter_by(user=p.user, slug=p_slug, deleted=False).first()
):
elif p_slug != p.slug and page_exists(trans.sa_session, p.user, p_slug):
return self.message_exception(trans, "Page id must be unique.")
else:
p.title = p_title
Expand All @@ -535,7 +530,7 @@ def edit(self, trans, payload=None, **kwd):
@web.require_login()
def display(self, trans, id, **kwargs):
id = self.decode_id(id)
page = trans.sa_session.query(model.Page).get(id)
page = trans.sa_session.get(model.Page, id)
if not page:
raise web.httpexceptions.HTTPNotFound()
return self.display_by_username_and_slug(trans, page.user.username, page.slug)
Expand All @@ -545,9 +540,8 @@ def display_by_username_and_slug(self, trans, username, slug, **kwargs):
"""Display page based on a username and slug."""

# Get page.
session = trans.sa_session
user = session.query(model.User).filter_by(username=username).first()
page = trans.sa_session.query(model.Page).filter_by(user=user, slug=slug, deleted=False).first()
user = get_user_by_username(trans.sa_session, username)
page = get_page_(trans.sa_session, user, slug)
if page is None:
raise web.httpexceptions.HTTPNotFound()

Expand Down Expand Up @@ -605,7 +599,7 @@ def get_page(self, trans, id, check_ownership=True, check_accessible=False):
"""Get a page from the database by id."""
# Load history from database
id = self.decode_id(id)
page = trans.sa_session.query(model.Page).get(id)
page = trans.sa_session.get(model.Page, id)
if not page:
error("Page not found")
else:
Expand Down

0 comments on commit b73b989

Please sign in to comment.