Skip to content

Commit

Permalink
Fix SA2.0 (query->select) in galaxy.tools
Browse files Browse the repository at this point in the history
  • Loading branch information
jdavcs committed Aug 11, 2023
1 parent 49fee79 commit 37e9708
Show file tree
Hide file tree
Showing 7 changed files with 59 additions and 51 deletions.
18 changes: 12 additions & 6 deletions lib/galaxy/tools/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,11 @@
from lxml import etree
from mako.template import Template
from packaging.version import Version
from sqlalchemy import (
delete,
func,
select,
)

from galaxy import (
exceptions,
Expand Down Expand Up @@ -346,9 +351,9 @@ def __init__(self, app):

def reset_tags(self):
log.info(
f"removing all tool tag associations ({str(self.sa_session.query(self.app.model.ToolTagAssociation).count())})"
f"removing all tool tag associations ({str(self.sa_session.scalar(select(func.count(self.app.model.ToolTagAssociation))))})"
)
self.sa_session.query(self.app.model.ToolTagAssociation).delete()
self.sa_session.execute(delete(self.app.model.ToolTagAssociation))
with transaction(self.sa_session):
self.sa_session.commit()

Expand All @@ -359,7 +364,8 @@ def handle_tags(self, tool_id, tool_definition_source):
for tag_name in tag_names:
if tag_name == "":
continue
tag = self.sa_session.query(self.app.model.Tag).filter_by(name=tag_name).first()
stmt = select(self.app.model.Tag).filter_by(name=tag_name).limit(1)
tag = self.sa_session.scalars(stmt).first()
if not tag:
tag = self.app.model.Tag(name=tag_name)
self.sa_session.add(tag)
Expand Down Expand Up @@ -618,7 +624,7 @@ def _load_workflow(self, workflow_id):
which is encoded in the tool panel.
"""
id = self.app.security.decode_id(workflow_id)
stored = self.app.model.context.query(self.app.model.StoredWorkflow).get(id)
stored = self.app.model.context.get(self.app.model.StoredWorkflow, id)
return stored.latest_workflow

def __build_tool_version_select_field(self, tools, tool_id, set_selected):
Expand Down Expand Up @@ -3121,7 +3127,7 @@ def exec_after_process(self, app, inp_data, out_data, param_dict, job=None, **kw
self.sa_session.commit()

def job_failed(self, job_wrapper, message, exception=False):
job = job_wrapper.sa_session.query(model.Job).get(job_wrapper.job_id)
job = job_wrapper.sa_session.get(model.Job, job_wrapper.job_id)
if job:
inp_data = {}
for dataset_assoc in job.input_datasets:
Expand Down Expand Up @@ -3168,7 +3174,7 @@ def exec_after_process(self, app, inp_data, out_data, param_dict, job=None, **kw

def job_failed(self, job_wrapper, message, exception=False):
super().job_failed(job_wrapper, message, exception=exception)
job = job_wrapper.sa_session.query(model.Job).get(job_wrapper.job_id)
job = job_wrapper.sa_session.get(model.Job, job_wrapper.job_id)
self.__remove_interactivetool_by_job(job)


Expand Down
4 changes: 2 additions & 2 deletions lib/galaxy/tools/actions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -481,7 +481,7 @@ def handle_output(name, output, hidden=None):
if async_tool and name in incoming:
# HACK: output data has already been created as a result of the async controller
dataid = incoming[name]
data = trans.sa_session.query(app.model.HistoryDatasetAssociation).get(dataid)
data = trans.sa_session.get(app.model.HistoryDatasetAssociation, dataid)
assert data is not None
out_data[name] = data
else:
Expand Down Expand Up @@ -745,7 +745,7 @@ def _remap_job_on_rerun(self, trans, galaxy_session, rerun_remap_job_id, current
input datasets to be those of the job that is being rerun.
"""
try:
old_job = trans.sa_session.query(trans.app.model.Job).get(rerun_remap_job_id)
old_job = trans.sa_session.get(trans.app.model.Job, rerun_remap_job_id)
assert old_job is not None, f"({rerun_remap_job_id}/{current_job.id}): Old job id is invalid"
assert (
old_job.tool_id == current_job.tool_id
Expand Down
16 changes: 9 additions & 7 deletions lib/galaxy/tools/actions/upload_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
Optional,
)

from sqlalchemy import select
from sqlalchemy.orm import joinedload
from webob.compat import cgi_FieldStorage

Expand Down Expand Up @@ -94,12 +95,12 @@ def handle_library_params(
# See if we have any template field contents
template_field_contents = {}
template_id = params.get("template_id", None)
folder = trans.sa_session.query(LibraryFolder).get(folder_id)
folder = trans.sa_session.get(LibraryFolder, folder_id)
# We are inheriting the folder's info_association, so we may have received inherited contents or we may have redirected
# here after the user entered template contents ( due to errors ).
template: Optional[FormDefinition] = None
if template_id not in [None, "None"]:
template = trans.sa_session.query(FormDefinition).get(template_id)
template = trans.sa_session.get(FormDefinition, template_id)
assert template
for field in template.fields:
field_name = field["name"]
Expand All @@ -108,7 +109,7 @@ def handle_library_params(
template_field_contents[field_name] = field_value
roles: List[Role] = []
for role_id in util.listify(params.get("roles", [])):
role = trans.sa_session.query(Role).get(role_id)
role = trans.sa_session.get(Role, role_id)
roles.append(role)
tags = params.get("tags", None)
return LibraryParams(
Expand Down Expand Up @@ -436,10 +437,11 @@ def active_folders(trans, folder):
# Stolen from galaxy.web.controllers.library_common (importing from which causes a circular issues).
# Much faster way of retrieving all active sub-folders within a given folder than the
# performance of the mapper. This query also eagerloads the permissions on each folder.
return (
trans.sa_session.query(LibraryFolder)
stmt = (
select(LibraryFolder)
.filter_by(parent=folder, deleted=False)
.options(joinedload(LibraryFolder.actions))
.order_by(LibraryFolder.table.c.name)
.all()
.unique()
.order_by(LibraryFolder.name)
)
return trans.sa_session.scalars(stmt).all()
4 changes: 2 additions & 2 deletions lib/galaxy/tools/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,10 +137,10 @@ def __init__(self, hda, app):
if not isinstance(hda, model.HistoryDatasetAssociation):
hda_id = hda
try:
hda = sa_session.query(model.HistoryDatasetAssociation).get(hda_id)
hda = sa_session.get(model.HistoryDatasetAssociation, hda_id)
assert hda is not None, ValueError("No HDA yet")
except Exception:
hda = sa_session.query(model.HistoryDatasetAssociation).get(app.security.decode_id(hda_id))
hda = sa_session.get(model.HistoryDatasetAssociation, app.security.decode_id(hda_id))
assert isinstance(hda, model.HistoryDatasetAssociation), ValueError(f"Bad value provided for HDA ({hda}).")
self.hda = hda
# Get the associated job
Expand Down
5 changes: 4 additions & 1 deletion lib/galaxy/tools/imp_exp/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
import shutil
from typing import Optional

from sqlalchemy import select

from galaxy import model
from galaxy.model import store
from galaxy.model.base import transaction
Expand Down Expand Up @@ -49,7 +51,8 @@ def cleanup_after_job(self):
# Import history.
#

jiha = self.sa_session.query(model.JobImportHistoryArchive).filter_by(job_id=self.job_id).first()
stmt = select(model.JobImportHistoryArchive).filter_by(job_id=self.job_id).limit(1)
jiha = self.sa_session.scalars(stmt).first()
if not jiha:
return None
user = jiha.job.user
Expand Down
61 changes: 29 additions & 32 deletions lib/galaxy/tools/parameters/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -1943,13 +1943,13 @@ def single_to_python(value):
if isinstance(value, dict) and "src" in value:
id = value["id"] if isinstance(value["id"], int) else app.security.decode_id(value["id"])
if value["src"] == "dce":
return app.model.context.query(DatasetCollectionElement).get(id)
return app.model.context.get(DatasetCollectionElement, id)
elif value["src"] == "hdca":
return app.model.context.query(HistoryDatasetCollectionAssociation).get(id)
return app.model.context.get(HistoryDatasetCollectionAssociation, id)
elif value["src"] == "ldda":
return app.model.context.query(LibraryDatasetDatasetAssociation).get(id)
return app.model.context.get(LibraryDatasetDatasetAssociation, id)
else:
return app.model.context.query(HistoryDatasetAssociation).get(id)
return app.model.context.get(HistoryDatasetAssociation, id)

if isinstance(value, dict) and "values" in value:
if hasattr(self, "multiple") and self.multiple is True:
Expand All @@ -1963,21 +1963,21 @@ def single_to_python(value):
return None
if isinstance(value, str) and value.find(",") > -1:
return [
app.model.context.query(HistoryDatasetAssociation).get(int(v))
app.model.context.get(HistoryDatasetAssociation, int(v))
for v in value.split(",")
if v not in none_values
]
elif str(value).startswith("__collection_reduce__|"):
decoded_id = str(value)[len("__collection_reduce__|") :]
if not decoded_id.isdigit():
decoded_id = app.security.decode_id(decoded_id)
return app.model.context.query(HistoryDatasetCollectionAssociation).get(int(decoded_id))
return app.model.context.get(HistoryDatasetCollectionAssociation, int(decoded_id))
elif str(value).startswith("dce:"):
return app.model.context.query(DatasetCollectionElement).get(int(value[len("dce:") :]))
return app.model.context.get(DatasetCollectionElement, int(value[len("dce:") :]))
elif str(value).startswith("hdca:"):
return app.model.context.query(HistoryDatasetCollectionAssociation).get(int(value[len("hdca:") :]))
return app.model.context.get(HistoryDatasetCollectionAssociation, int(value[len("hdca:") :]))
else:
return app.model.context.query(HistoryDatasetAssociation).get(int(value))
return app.model.context.get(HistoryDatasetAssociation, int(value))

def validate(self, value, trans=None):
def do_validate(v):
Expand Down Expand Up @@ -2097,17 +2097,17 @@ def from_json(self, value, trans, other_values=None):
if isinstance(single_value, dict) and "src" in single_value and "id" in single_value:
if single_value["src"] == "hda":
decoded_id = trans.security.decode_id(single_value["id"])
rval.append(trans.sa_session.query(HistoryDatasetAssociation).get(decoded_id))
rval.append(trans.sa_session.get(HistoryDatasetAssociation, decoded_id))
elif single_value["src"] == "hdca":
found_hdca = True
decoded_id = trans.security.decode_id(single_value["id"])
rval.append(trans.sa_session.query(HistoryDatasetCollectionAssociation).get(decoded_id))
rval.append(trans.sa_session.get(HistoryDatasetCollectionAssociation, decoded_id))
elif single_value["src"] == "ldda":
decoded_id = trans.security.decode_id(single_value["id"])
rval.append(trans.sa_session.query(LibraryDatasetDatasetAssociation).get(decoded_id))
rval.append(trans.sa_session.get(LibraryDatasetDatasetAssociation, decoded_id))
elif single_value["src"] == "dce":
decoded_id = trans.security.decode_id(single_value["id"])
rval.append(trans.sa_session.query(DatasetCollectionElement).get(decoded_id))
rval.append(trans.sa_session.get(DatasetCollectionElement, decoded_id))
else:
raise ValueError(f"Unknown input source {single_value['src']} passed to job submission API.")
elif isinstance(
Expand All @@ -2126,7 +2126,7 @@ def from_json(self, value, trans, other_values=None):
# support that for integer column types.
log.warning("Encoded ID where unencoded ID expected.")
single_value = trans.security.decode_id(single_value)
rval.append(trans.sa_session.query(HistoryDatasetAssociation).get(single_value))
rval.append(trans.sa_session.get(HistoryDatasetAssociation, single_value))
if found_hdca:
for val in rval:
if not isinstance(val, HistoryDatasetCollectionAssociation):
Expand All @@ -2139,26 +2139,26 @@ def from_json(self, value, trans, other_values=None):
elif isinstance(value, dict) and "src" in value and "id" in value:
if value["src"] == "hda":
decoded_id = trans.security.decode_id(value["id"])
rval.append(trans.sa_session.query(HistoryDatasetAssociation).get(decoded_id))
rval.append(trans.sa_session.get(HistoryDatasetAssociation, decoded_id))
elif value["src"] == "hdca":
decoded_id = trans.security.decode_id(value["id"])
rval.append(trans.sa_session.query(HistoryDatasetCollectionAssociation).get(decoded_id))
rval.append(trans.sa_session.get(HistoryDatasetCollectionAssociation, decoded_id))
elif value["src"] == "dce":
decoded_id = trans.security.decode_id(value["id"])
rval.append(trans.sa_session.query(DatasetCollectionElement).get(decoded_id))
rval.append(trans.sa_session.get(DatasetCollectionElement, decoded_id))
else:
raise ValueError(f"Unknown input source {value['src']} passed to job submission API.")
elif str(value).startswith("__collection_reduce__|"):
encoded_ids = [v[len("__collection_reduce__|") :] for v in str(value).split(",")]
decoded_ids = map(trans.security.decode_id, encoded_ids)
rval = []
for decoded_id in decoded_ids:
hdca = trans.sa_session.query(HistoryDatasetCollectionAssociation).get(decoded_id)
hdca = trans.sa_session.get(HistoryDatasetCollectionAssociation, decoded_id)
rval.append(hdca)
elif isinstance(value, HistoryDatasetCollectionAssociation) or isinstance(value, DatasetCollectionElement):
rval.append(value)
else:
rval.append(trans.sa_session.query(HistoryDatasetAssociation).get(value))
rval.append(trans.sa_session.get(HistoryDatasetAssociation, value))
dataset_matcher_factory = get_dataset_matcher_factory(trans)
dataset_matcher = dataset_matcher_factory.dataset_matcher(self, other_values)
for v in rval:
Expand Down Expand Up @@ -2443,28 +2443,24 @@ def from_json(self, value, trans, other_values=None):
rval = value
elif isinstance(value, dict) and "src" in value and "id" in value:
if value["src"] == "hdca":
rval = trans.sa_session.query(HistoryDatasetCollectionAssociation).get(
trans.security.decode_id(value["id"])
)
rval = trans.sa_session.get(HistoryDatasetCollectionAssociation, trans.security.decode_id(value["id"]))
elif isinstance(value, list):
if len(value) > 0:
value = value[0]
if isinstance(value, dict) and "src" in value and "id" in value:
if value["src"] == "hdca":
rval = trans.sa_session.query(HistoryDatasetCollectionAssociation).get(
trans.security.decode_id(value["id"])
rval = trans.sa_session.get(
HistoryDatasetCollectionAssociation, trans.security.decode_id(value["id"])
)
elif value["src"] == "dce":
rval = trans.sa_session.query(DatasetCollectionElement).get(
trans.security.decode_id(value["id"])
)
rval = trans.sa_session.get(DatasetCollectionElement, trans.security.decode_id(value["id"]))
elif isinstance(value, str):
if value.startswith("dce:"):
rval = trans.sa_session.query(DatasetCollectionElement).get(value[len("dce:") :])
rval = trans.sa_session.get(DatasetCollectionElement, value[len("dce:") :])
elif value.startswith("hdca:"):
rval = trans.sa_session.query(HistoryDatasetCollectionAssociation).get(value[len("hdca:") :])
rval = trans.sa_session.get(HistoryDatasetCollectionAssociation, value[len("hdca:") :])
else:
rval = trans.sa_session.query(HistoryDatasetCollectionAssociation).get(value)
rval = trans.sa_session.get(HistoryDatasetCollectionAssociation, value)
if rval and isinstance(rval, HistoryDatasetCollectionAssociation):
if rval.deleted:
raise ParameterValueError("the previously selected dataset collection has been deleted", self.name)
Expand Down Expand Up @@ -2634,8 +2630,9 @@ def to_python(self, value, app, other_values=None, validate=False):
else:
lst = []
break
lda = app.model.context.query(LibraryDatasetDatasetAssociation).get(
lda_id if isinstance(lda_id, int) else app.security.decode_id(lda_id)
lda = app.model.context.get(
LibraryDatasetDatasetAssociation,
lda_id if isinstance(lda_id, int) else app.security.decode_id(lda_id),
)
if lda is not None:
lst.append(lda)
Expand Down
2 changes: 1 addition & 1 deletion lib/galaxy/tools/parameters/meta.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,7 +265,7 @@ def __expand_collection_parameter(trans, input_key, incoming_val, collections_to
encoded_hdc_id = incoming_val
subcollection_type = None
hdc_id = trans.app.security.decode_id(encoded_hdc_id)
hdc = trans.sa_session.query(model.HistoryDatasetCollectionAssociation).get(hdc_id)
hdc = trans.sa_session.get(model.HistoryDatasetCollectionAssociation, hdc_id)
collections_to_match.add(input_key, hdc, subcollection_type=subcollection_type, linked=linked)
if subcollection_type is not None:
subcollection_elements = subcollections.split_dataset_collection_instance(hdc, subcollection_type)
Expand Down

0 comments on commit 37e9708

Please sign in to comment.