Skip to content

Commit

Permalink
Fix SA2.0 (query->select) in galaxy.tools
Browse files Browse the repository at this point in the history
  • Loading branch information
jdavcs committed Aug 14, 2023
1 parent d692e1a commit 955cd84
Show file tree
Hide file tree
Showing 7 changed files with 96 additions and 60 deletions.
20 changes: 14 additions & 6 deletions lib/galaxy/tools/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,11 @@
from lxml import etree
from mako.template import Template
from packaging.version import Version
from sqlalchemy import (
delete,
func,
select,
)

from galaxy import (
exceptions,
Expand All @@ -39,6 +44,8 @@
from galaxy.job_execution import output_collect
from galaxy.metadata import get_metadata_compute_strategy
from galaxy.model.base import transaction
from galaxy.model.repositories.job import JobRepository
from galaxy.model.repositories.stored_workflow import StoredWorkflowRepository
from galaxy.tool_shed.util.repository_util import get_installed_repository
from galaxy.tool_shed.util.shed_util_common import set_image_paths
from galaxy.tool_util.deps import (
Expand Down Expand Up @@ -346,9 +353,9 @@ def __init__(self, app):

def reset_tags(self):
log.info(
f"removing all tool tag associations ({str(self.sa_session.query(self.app.model.ToolTagAssociation).count())})"
f"removing all tool tag associations ({str(self.sa_session.scalar(select(func.count(self.app.model.ToolTagAssociation))))})"
)
self.sa_session.query(self.app.model.ToolTagAssociation).delete()
self.sa_session.execute(delete(self.app.model.ToolTagAssociation))
with transaction(self.sa_session):
self.sa_session.commit()

Expand All @@ -359,7 +366,8 @@ def handle_tags(self, tool_id, tool_definition_source):
for tag_name in tag_names:
if tag_name == "":
continue
tag = self.sa_session.query(self.app.model.Tag).filter_by(name=tag_name).first()
stmt = select(self.app.model.Tag).filter_by(name=tag_name).limit(1)
tag = self.sa_session.scalars(stmt).first()
if not tag:
tag = self.app.model.Tag(name=tag_name)
self.sa_session.add(tag)
Expand Down Expand Up @@ -618,7 +626,7 @@ def _load_workflow(self, workflow_id):
which is encoded in the tool panel.
"""
id = self.app.security.decode_id(workflow_id)
stored = self.app.model.context.query(self.app.model.StoredWorkflow).get(id)
stored = StoredWorkflowRepository(self.app.model.context).get(id)
return stored.latest_workflow

def __build_tool_version_select_field(self, tools, tool_id, set_selected):
Expand Down Expand Up @@ -3121,7 +3129,7 @@ def exec_after_process(self, app, inp_data, out_data, param_dict, job=None, **kw
self.sa_session.commit()

def job_failed(self, job_wrapper, message, exception=False):
job = job_wrapper.sa_session.query(model.Job).get(job_wrapper.job_id)
job = JobRepository(job_wrapper.sa_session).get(job_wrapper.job_id)
if job:
inp_data = {}
for dataset_assoc in job.input_datasets:
Expand Down Expand Up @@ -3168,7 +3176,7 @@ def exec_after_process(self, app, inp_data, out_data, param_dict, job=None, **kw

def job_failed(self, job_wrapper, message, exception=False):
super().job_failed(job_wrapper, message, exception=exception)
job = job_wrapper.sa_session.query(model.Job).get(job_wrapper.job_id)
job = JobRepository(job_wrapper.sa_session).get(job_wrapper.job_id)
self.__remove_interactivetool_by_job(job)


Expand Down
6 changes: 4 additions & 2 deletions lib/galaxy/tools/actions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
from galaxy.model.base import transaction
from galaxy.model.dataset_collections.builder import CollectionBuilder
from galaxy.model.none_like import NoneDataset
from galaxy.model.repositories.hda import HistoryDatasetAssociationRepository as hda_repo
from galaxy.model.repositories.job import JobRepository
from galaxy.objectstore import ObjectStorePopulator
from galaxy.tools.parameters import update_dataset_ids
from galaxy.tools.parameters.basic import (
Expand Down Expand Up @@ -481,7 +483,7 @@ def handle_output(name, output, hidden=None):
if async_tool and name in incoming:
# HACK: output data has already been created as a result of the async controller
dataid = incoming[name]
data = trans.sa_session.query(app.model.HistoryDatasetAssociation).get(dataid)
data = hda_repo(trans.sa_session).get(dataid)
assert data is not None
out_data[name] = data
else:
Expand Down Expand Up @@ -745,7 +747,7 @@ def _remap_job_on_rerun(self, trans, galaxy_session, rerun_remap_job_id, current
input datasets to be those of the job that is being rerun.
"""
try:
old_job = trans.sa_session.query(trans.app.model.Job).get(rerun_remap_job_id)
old_job = JobRepository(trans.sa_session).get(rerun_remap_job_id)
assert old_job is not None, f"({rerun_remap_job_id}/{current_job.id}): Old job id is invalid"
assert (
old_job.tool_id == current_job.tool_id
Expand Down
20 changes: 13 additions & 7 deletions lib/galaxy/tools/actions/upload_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
Optional,
)

from sqlalchemy import select
from sqlalchemy.orm import joinedload
from webob.compat import cgi_FieldStorage

Expand All @@ -30,6 +31,9 @@
Role,
)
from galaxy.model.base import transaction
from galaxy.model.repositories.form_definition import FormDefinitionRepository
from galaxy.model.repositories.library_folder import LibraryFolderRepository
from galaxy.model.repositories.role import RoleRepository
from galaxy.util import is_url
from galaxy.util.path import external_chown

Expand Down Expand Up @@ -94,21 +98,22 @@ def handle_library_params(
# See if we have any template field contents
template_field_contents = {}
template_id = params.get("template_id", None)
folder = trans.sa_session.query(LibraryFolder).get(folder_id)
folder = LibraryFolderRepository(trans.sa_session).get(folder_id)
# We are inheriting the folder's info_association, so we may have received inherited contents or we may have redirected
# here after the user entered template contents ( due to errors ).
template: Optional[FormDefinition] = None
if template_id not in [None, "None"]:
template = trans.sa_session.query(FormDefinition).get(template_id)
template = FormDefinitionRepository(trans.sa_session).get(template_id)
assert template
for field in template.fields:
field_name = field["name"]
if params.get(field_name, False):
field_value = util.restore_text(params.get(field_name, ""))
template_field_contents[field_name] = field_value
roles: List[Role] = []
role_repo = RoleRepository(trans.sa_session)
for role_id in util.listify(params.get("roles", [])):
role = trans.sa_session.query(Role).get(role_id)
role = role_repo.get(role_id)
roles.append(role)
tags = params.get("tags", None)
return LibraryParams(
Expand Down Expand Up @@ -436,10 +441,11 @@ def active_folders(trans, folder):
# Stolen from galaxy.web.controllers.library_common (importing from which causes a circular issues).
# Much faster way of retrieving all active sub-folders within a given folder than the
# performance of the mapper. This query also eagerloads the permissions on each folder.
return (
trans.sa_session.query(LibraryFolder)
stmt = (
select(LibraryFolder)
.filter_by(parent=folder, deleted=False)
.options(joinedload(LibraryFolder.actions))
.order_by(LibraryFolder.table.c.name)
.all()
.unique()
.order_by(LibraryFolder.name)
)
return trans.sa_session.scalars(stmt).all()
6 changes: 4 additions & 2 deletions lib/galaxy/tools/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
util,
web,
)
from galaxy.model.repositories.hda import HistoryDatasetAssociationRepository as hda_repo
from galaxy.security.validate_user_input import validate_email_str
from galaxy.util import unicodify

Expand Down Expand Up @@ -135,12 +136,13 @@ def __init__(self, hda, app):
# Get the dataset
sa_session = app.model.context
if not isinstance(hda, model.HistoryDatasetAssociation):
_hda_repo = hda_repo(sa_session)
hda_id = hda
try:
hda = sa_session.query(model.HistoryDatasetAssociation).get(hda_id)
hda = _hda_repo.get(hda_id)
assert hda is not None, ValueError("No HDA yet")
except Exception:
hda = sa_session.query(model.HistoryDatasetAssociation).get(app.security.decode_id(hda_id))
hda = _hda_repo.get(app.security.decode_id(hda_id))
assert isinstance(hda, model.HistoryDatasetAssociation), ValueError(f"Bad value provided for HDA ({hda}).")
self.hda = hda
# Get the associated job
Expand Down
5 changes: 4 additions & 1 deletion lib/galaxy/tools/imp_exp/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
import shutil
from typing import Optional

from sqlalchemy import select

from galaxy import model
from galaxy.model import store
from galaxy.model.base import transaction
Expand Down Expand Up @@ -49,7 +51,8 @@ def cleanup_after_job(self):
# Import history.
#

jiha = self.sa_session.query(model.JobImportHistoryArchive).filter_by(job_id=self.job_id).first()
stmt = select(model.JobImportHistoryArchive).filter_by(job_id=self.job_id).limit(1)
jiha = self.sa_session.scalars(stmt).first()
if not jiha:
return None
user = jiha.job.user
Expand Down
Loading

0 comments on commit 955cd84

Please sign in to comment.