Skip to content

Commit

Permalink
Fix SA2.0 (query->select) in galaxy.webapps.base (note)
Browse files Browse the repository at this point in the history
Also drop unused parameters and logic from select methods in legacy
controller.
  • Loading branch information
jdavcs committed Aug 10, 2023
1 parent 8ef4020 commit 6c31e3f
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 81 deletions.
102 changes: 30 additions & 72 deletions lib/galaxy/webapps/base/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,10 @@
Optional,
)

from sqlalchemy import true
from sqlalchemy import (
select,
true,
)
from webob.exc import (
HTTPBadRequest,
HTTPInternalServerError,
Expand Down Expand Up @@ -447,7 +450,7 @@ def _copy_hdca_to_library_folder(self, trans, hda_manager, from_hdca_id: int, fo
Fetches the collection identified by `from_hcda_id` and dispatches individual collection elements to
_copy_hda_to_library_folder
"""
hdca = trans.sa_session.query(trans.app.model.HistoryDatasetCollectionAssociation).get(from_hdca_id)
hdca = trans.sa_session.get(trans.app.model.HistoryDatasetCollectionAssociation, from_hdca_id)
if hdca.collection.collection_type != "list":
raise exceptions.NotImplemented(
"Cannot add nested collections to library. Please flatten your collection first."
Expand Down Expand Up @@ -614,85 +617,40 @@ def get_visualization(self, trans, id, check_ownership=True, check_accessible=Fa
"""
# Load workflow from database
try:
visualization = trans.sa_session.query(trans.model.Visualization).get(trans.security.decode_id(id))
visualization = trans.sa_session.get(trans.model.Visualization, trans.security.decode_id(id))
except TypeError:
visualization = None
if not visualization:
error("Visualization not found")
else:
return self.security_check(trans, visualization, check_ownership, check_accessible)

def get_visualizations_by_user(self, trans, user, order_by=None, query_only=False):
"""
Return query or query results of visualizations filtered by a user.
Set `order_by` to a column or list of columns to change the order
returned. Defaults to `DEFAULT_ORDER_BY`.
Set `query_only` to return just the query for further filtering or
processing.
"""
# TODO: move into model (as class attr)
DEFAULT_ORDER_BY = [model.Visualization.title]
if not order_by:
order_by = DEFAULT_ORDER_BY
if not isinstance(order_by, list):
order_by = [order_by]
query = trans.sa_session.query(model.Visualization)
query = query.filter(model.Visualization.user == user)
if order_by:
query = query.order_by(*order_by)
if query_only:
return query
return query.all()

def get_visualizations_shared_with_user(self, trans, user, order_by=None, query_only=False):
"""
Return query or query results for visualizations shared with the given user.
def get_visualizations_by_user(self, trans, user):
"""Return query results of visualizations filtered by a user."""
stmt = select(model.Visualization).filter(model.Visualization.user == user).order_by(model.Visualization.title)
return trans.sa_session.scalars(stmt).all()

def get_visualizations_shared_with_user(self, trans, user):
"""Return query results for visualizations shared with the given user."""
# The second `where` clause removes duplicates when a user shares with themselves.
stmt = (
select(model.Visualization)
.join(model.VisualizationUserShareAssociation)
.where(model.VisualizationUserShareAssociation.user_id == user.id)
.where(model.Visualization.user_id != user.id)
.order_by(model.Visualization.title)
)
return trans.sa_session.scalars(stmt).all()

Set `order_by` to a column or list of columns to change the order
returned. Defaults to `DEFAULT_ORDER_BY`.
Set `query_only` to return just the query for further filtering or
processing.
def get_published_visualizations(self, trans, exclude_user=None):
"""
DEFAULT_ORDER_BY = [model.Visualization.title]
if not order_by:
order_by = DEFAULT_ORDER_BY
if not isinstance(order_by, list):
order_by = [order_by]
query = trans.sa_session.query(model.Visualization).join(model.VisualizationUserShareAssociation)
query = query.filter(model.VisualizationUserShareAssociation.user_id == user.id)
# remove duplicates when a user shares with themselves?
query = query.filter(model.Visualization.user_id != user.id)
if order_by:
query = query.order_by(*order_by)
if query_only:
return query
return query.all()

def get_published_visualizations(self, trans, exclude_user=None, order_by=None, query_only=False):
"""
Return query or query results for published visualizations optionally excluding
the user in `exclude_user`.
Set `order_by` to a column or list of columns to change the order
returned. Defaults to `DEFAULT_ORDER_BY`.
Set `query_only` to return just the query for further filtering or
processing.
Return query results for published visualizations optionally excluding the user in `exclude_user`.
"""
DEFAULT_ORDER_BY = [model.Visualization.title]
if not order_by:
order_by = DEFAULT_ORDER_BY
if not isinstance(order_by, list):
order_by = [order_by]
query = trans.sa_session.query(model.Visualization)
query = query.filter(model.Visualization.published == true())
stmt = select(model.Visualization).filter(model.Visualization.published == true())
if exclude_user:
query = query.filter(model.Visualization.user != exclude_user)
if order_by:
query = query.order_by(*order_by)
if query_only:
return query
return query.all()
stmt = stmt.filter(model.Visualization.user != exclude_user)
stmt = stmt.order_by(model.Visualization.title)
return trans.sa_session.scalars(stmt).all()

# TODO: move into model (to_dict)
def get_visualization_summary_dict(self, visualization):
Expand Down Expand Up @@ -837,7 +795,7 @@ def save_visualization(self, trans, config, type, id=None, title=None, dbkey=Non
vis = self._create_visualization(trans, title, type, dbkey, slug, annotation)
else:
decoded_id = trans.security.decode_id(id)
vis = session.query(trans.model.Visualization).get(decoded_id)
vis = session.get(trans.model.Visualization, decoded_id)
# TODO: security check?

# Create new VisualizationRevision that will be attached to the viz
Expand Down Expand Up @@ -1071,7 +1029,7 @@ def get_hda(self, trans, dataset_id, check_ownership=True, check_accessible=Fals
raise HTTPBadRequest(f"Invalid dataset id: {str(dataset_id)}.")

try:
data = trans.sa_session.query(trans.app.model.HistoryDatasetAssociation).get(int(dataset_id))
data = trans.sa_session.get(trans.app.model.HistoryDatasetAssociation, int(dataset_id))
except Exception:
raise HTTPBadRequest(f"Invalid dataset id: {str(dataset_id)}.")

Expand Down
22 changes: 13 additions & 9 deletions lib/galaxy/webapps/base/webapp.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from paste.urlmap import URLMap
from sqlalchemy import (
and_,
select,
true,
)
from sqlalchemy.orm.exc import NoResultFound
Expand Down Expand Up @@ -866,13 +867,14 @@ def handle_user_logout(self, logout_all=False):
self.sa_session.add_all((prev_galaxy_session, self.galaxy_session))
galaxy_user_id = prev_galaxy_session.user_id
if logout_all and galaxy_user_id is not None:
for other_galaxy_session in self.sa_session.query(self.app.model.GalaxySession).filter(
stmt = select(self.app.model.GalaxySession).filter(
and_(
self.app.model.GalaxySession.table.c.user_id == galaxy_user_id,
self.app.model.GalaxySession.table.c.is_valid == true(),
self.app.model.GalaxySession.table.c.id != prev_galaxy_session.id,
self.app.model.GalaxySession.user_id == galaxy_user_id,
self.app.model.GalaxySession.is_valid == true(),
self.app.model.GalaxySession.id != prev_galaxy_session.id,
)
):
)
for other_galaxy_session in self.sa_session.scalars(stmt):
other_galaxy_session.is_valid = False
self.sa_session.add(other_galaxy_session)
with transaction(self.sa_session):
Expand Down Expand Up @@ -933,9 +935,10 @@ def get_or_create_default_history(self):
# Look for default history that (a) has default name + is not deleted and
# (b) has no datasets. If suitable history found, use it; otherwise, create
# new history.
unnamed_histories = self.sa_session.query(self.app.model.History).filter_by(
stmt = select(self.app.model.History).filter_by(
user=self.galaxy_session.user, name=self.app.model.History.default_name, deleted=False
)
unnamed_histories = self.sa_session.scalars(stmt)
default_history = None
for history in unnamed_histories:
if history.empty:
Expand All @@ -962,12 +965,13 @@ def get_most_recent_history(self):
if not user:
return None
try:
recent_history = (
self.sa_session.query(self.app.model.History)
stmt = (
select(self.app.model.History)
.filter_by(user=user, deleted=False)
.order_by(self.app.model.History.update_time.desc())
.first()
.limit(1)
)
recent_history = self.sa_session.scalars(stmt).first()
except NoResultFound:
return None
self.set_history(recent_history)
Expand Down

0 comments on commit 6c31e3f

Please sign in to comment.