Skip to content

Commit

Permalink
Merge pull request #18811 from nsoranzo/mypy_1.11
Browse files Browse the repository at this point in the history
Update Mypy to 1.11.2 and fix new signature override errors
  • Loading branch information
mvdbeek authored Sep 14, 2024
2 parents ac71c06 + 1f8260d commit 948f38b
Show file tree
Hide file tree
Showing 13 changed files with 64 additions and 51 deletions.
2 changes: 1 addition & 1 deletion lib/galaxy/dependencies/pinned-typecheck-requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ cryptography==42.0.8 ; python_version >= "3.8" and python_version < "3.13"
lxml-stubs==0.5.1 ; python_version >= "3.8" and python_version < "3.13"
mypy-boto3-s3==1.34.138 ; python_version >= "3.8" and python_version < "3.13"
mypy-extensions==1.0.0 ; python_version >= "3.8" and python_version < "3.13"
mypy==1.10.1 ; python_version >= "3.8" and python_version < "3.13"
mypy==1.11.2 ; python_version >= "3.8" and python_version < "3.13"
pycparser==2.22 ; python_version >= "3.8" and python_version < "3.13" and platform_python_implementation != "PyPy"
pydantic-core==2.20.1 ; python_version >= "3.8" and python_version < "3.13"
pydantic==2.8.2 ; python_version >= "3.8" and python_version < "3.13"
Expand Down
23 changes: 18 additions & 5 deletions lib/galaxy/jobs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
Dict,
Iterable,
List,
Optional,
TYPE_CHECKING,
)

Expand Down Expand Up @@ -99,6 +100,7 @@

if TYPE_CHECKING:
from galaxy.jobs.handler import JobHandlerQueue
from galaxy.tools import Tool

log = logging.getLogger(__name__)

Expand Down Expand Up @@ -984,11 +986,17 @@ class MinimalJobWrapper(HasResourceParameters):

is_task = False

def __init__(self, job: model.Job, app: MinimalManagerApp, use_persisted_destination: bool = False, tool=None):
def __init__(
self,
job: model.Job,
app: MinimalManagerApp,
use_persisted_destination: bool = False,
tool: Optional["Tool"] = None,
):
self.job_id = job.id
self.session_id = job.session_id
self.user_id = job.user_id
self.app: MinimalManagerApp = app
self.app = app
self.tool = tool
self.sa_session = self.app.model.context
self.extra_filenames: List[str] = []
Expand Down Expand Up @@ -2531,10 +2539,15 @@ def set_container(self, container):


class JobWrapper(MinimalJobWrapper):
def __init__(self, job, queue: "JobHandlerQueue", use_persisted_destination=False, app=None):
super().__init__(job, app=queue.app, use_persisted_destination=use_persisted_destination)
def __init__(self, job, queue: "JobHandlerQueue", use_persisted_destination=False):
app = queue.app
super().__init__(
job,
app=app,
use_persisted_destination=use_persisted_destination,
tool=app.toolbox.get_tool(job.tool_id, job.tool_version, exact=True),
)
self.queue = queue
self.tool = self.app.toolbox.get_tool(job.tool_id, job.tool_version, exact=True)
self.job_runner_mapper = JobRunnerMapper(self, queue.dispatcher.url_to_destination, self.app.job_config)
if use_persisted_destination:
self.job_runner_mapper.cached_job_destination = JobDestination(from_job=job)
Expand Down
1 change: 1 addition & 0 deletions lib/galaxy/jobs/command_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,7 @@ def build_command(

if job_wrapper.is_cwl_job:
# Minimal metadata needed by the relocate script
assert job_wrapper.tool
cwl_metadata_params = {
"job_metadata": join("working", job_wrapper.tool.provided_metadata_file),
"job_id_tag": job_wrapper.get_id_tag(),
Expand Down
6 changes: 4 additions & 2 deletions lib/galaxy/jobs/runners/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -502,6 +502,7 @@ def get_job_file(self, job_wrapper: "MinimalJobWrapper", **kwds) -> str:
env_setup_commands.append(env_to_statement(env))
command_line = job_wrapper.runner_command_line
tmp_dir_creation_statement = job_wrapper.tmp_dir_creation_statement
assert job_wrapper.tool
options = dict(
tmp_dir_creation_statement=tmp_dir_creation_statement,
job_instrumenter=job_instrumenter,
Expand Down Expand Up @@ -538,13 +539,14 @@ def _find_container(
if not compute_job_directory:
compute_job_directory = job_wrapper.working_directory

tool = job_wrapper.tool
assert tool
if not compute_tool_directory:
compute_tool_directory = job_wrapper.tool.tool_dir
compute_tool_directory = tool.tool_dir

if not compute_tmp_directory:
compute_tmp_directory = job_wrapper.tmp_directory()

tool = job_wrapper.tool
guest_ports = job_wrapper.guest_ports
tool_info = ToolInfo(
tool.containers,
Expand Down
20 changes: 11 additions & 9 deletions lib/galaxy/managers/genomes.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,17 @@
text,
)

from galaxy import model as m
from galaxy.exceptions import (
ReferenceDataError,
RequestParameterInvalidException,
)
from galaxy.managers.context import ProvidesUserContext
from galaxy.structured_app import StructuredApp
from galaxy.model import User
from galaxy.model.database_utils import is_postgres
from galaxy.structured_app import (
MinimalManagerApp,
StructuredApp,
)
from .base import raise_filter_err

if TYPE_CHECKING:
Expand All @@ -28,10 +32,10 @@ def __init__(self, app: StructuredApp):
self._app = app
self.genomes = app.genomes

def get_dbkeys(self, user: Optional[m.User], chrom_info: bool) -> List[List[str]]:
def get_dbkeys(self, user: Optional[User], chrom_info: bool) -> List[List[str]]:
return self.genomes.get_dbkeys(user, chrom_info)

def is_registered_dbkey(self, dbkey: str, user: Optional[m.User]) -> bool:
def is_registered_dbkey(self, dbkey: str, user: Optional[User]) -> bool:
dbkeys = self.get_dbkeys(user, chrom_info=False)
for _, key in dbkeys:
if dbkey == key:
Expand Down Expand Up @@ -78,8 +82,8 @@ def _get_index_filename(self, id, tbl_entries, ext, index_type):


class GenomeFilterMixin:
app: MinimalManagerApp
orm_filter_parsers: "OrmFilterParsersType"
database_connection: str
valid_ops = ("eq", "contains", "has")

def create_genome_filter(self, attr, op, val):
Expand All @@ -91,8 +95,7 @@ def _create_genome_filter(model_class=None):
# Doesn't filter genome_build for collections
if model_class.__name__ == "HistoryDatasetCollectionAssociation":
return False
# TODO: should use is_postgres(self.database_connection) in 23.2
if self.database_connection.startswith("postgres"):
if is_postgres(self.app.config.database_connection):
column = text("convert_from(metadata, 'UTF8')::json ->> 'dbkey'")
else:
column = func.json_extract(model_class.table.c._metadata, "$.dbkey") # type:ignore[assignment]
Expand All @@ -106,6 +109,5 @@ def _create_genome_filter(model_class=None):

return _create_genome_filter

def _add_parsers(self, database_connection: str):
self.database_connection = database_connection
def _add_parsers(self):
self.orm_filter_parsers.update({"genome_build": self.create_genome_filter})
6 changes: 3 additions & 3 deletions lib/galaxy/managers/hdas.py
Original file line number Diff line number Diff line change
Expand Up @@ -613,14 +613,14 @@ def add_serializers(self):
}
self.serializers.update(serializers)

def serialize(self, hda, keys, user=None, **context):
def serialize(self, item, keys, user=None, **context):
"""
Override to hide information to users not able to access.
"""
# TODO: to DatasetAssociationSerializer
if not self.manager.is_accessible(hda, user, **context):
if not self.manager.is_accessible(item, user, **context):
keys = self._view_to_keys("inaccessible")
return super().serialize(hda, keys, user=user, **context)
return super().serialize(item, keys, user=user, **context)

def serialize_display_apps(self, item, key, trans=None, **context):
"""
Expand Down
3 changes: 1 addition & 2 deletions lib/galaxy/managers/history_contents.py
Original file line number Diff line number Diff line change
Expand Up @@ -637,9 +637,8 @@ def parse_type_id_list(self, type_id_list_string, sep=","):

def _add_parsers(self):
super()._add_parsers()
database_connection: str = self.app.config.database_connection
annotatable.AnnotatableFilterMixin._add_parsers(self)
genomes.GenomeFilterMixin._add_parsers(self, database_connection)
genomes.GenomeFilterMixin._add_parsers(self)
deletable.PurgableFiltersMixin._add_parsers(self)
taggable.TaggableFilterMixin._add_parsers(self)
tools.ToolFilterMixin._add_parsers(self)
Expand Down
7 changes: 2 additions & 5 deletions lib/galaxy/model/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5980,11 +5980,8 @@ def to_dict(self, view="collection"):


class LibraryDatasetDatasetAssociation(DatasetInstance, HasName, Serializable):
message: Mapped[Optional[str]] = mapped_column(TrimmedString(255))
tags: Mapped[List["LibraryDatasetDatasetAssociationTagAssociation"]] = relationship(
order_by=lambda: LibraryDatasetDatasetAssociationTagAssociation.id,
back_populates="library_dataset_dataset_association",
)
message: Mapped[Optional[str]]
tags: Mapped[List["LibraryDatasetDatasetAssociationTagAssociation"]]

def __init__(
self,
Expand Down
32 changes: 16 additions & 16 deletions lib/galaxy/tools/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,7 @@

if TYPE_CHECKING:
from galaxy.app import UniverseApplication
from galaxy.managers.context import ProvidesUserContext
from galaxy.managers.jobs import JobSearch
from galaxy.tools.actions.metadata import SetMetadataToolAction

Expand Down Expand Up @@ -2345,7 +2346,7 @@ def call_hook(self, hook_name, *args, **kwargs):
def exec_before_job(self, app, inp_data, out_data, param_dict=None):
pass

def exec_after_process(self, app, inp_data, out_data, param_dict, job=None, **kwds):
def exec_after_process(self, app, inp_data, out_data, param_dict, job, final_job_state: Optional[str] = None):
pass

def job_failed(self, job_wrapper, message, exception=False):
Expand Down Expand Up @@ -2976,7 +2977,7 @@ def exec_before_job(self, app, inp_data, out_data, param_dict=None):
with open(expression_inputs_path, "w") as f:
json.dump(expression_inputs, f)

def exec_after_process(self, app, inp_data, out_data, param_dict, job=None, **kwds):
def exec_after_process(self, app, inp_data, out_data, param_dict, job, final_job_state=None):
for key, val in self.outputs.items():
if key not in out_data:
# Skip filtered outputs
Expand Down Expand Up @@ -3156,7 +3157,7 @@ def regenerate_imported_metadata_if_needed(
)
self.app.job_manager.enqueue(job=job, tool=self)

def exec_after_process(self, app, inp_data, out_data, param_dict, job=None, **kwds):
def exec_after_process(self, app, inp_data, out_data, param_dict, job, final_job_state=None):
working_directory = app.object_store.get_filename(job, base_dir="job_work", dir_only=True, obj_dir=True)
for name, dataset in inp_data.items():
external_metadata = get_metadata_compute_strategy(app.config, job.id, tool_id=self.id)
Expand Down Expand Up @@ -3214,8 +3215,8 @@ class ExportHistoryTool(Tool):
class ImportHistoryTool(Tool):
tool_type = "import_history"

def exec_after_process(self, app, inp_data, out_data, param_dict, job, final_job_state=None, **kwds):
super().exec_after_process(app, inp_data, out_data, param_dict, job=job, **kwds)
def exec_after_process(self, app, inp_data, out_data, param_dict, job, final_job_state=None):
super().exec_after_process(app, inp_data, out_data, param_dict, job=job, final_job_state=final_job_state)
if final_job_state != DETECTED_JOB_STATE.OK:
return
JobImportHistoryArchiveWrapper(self.app, job.id).cleanup_after_job()
Expand All @@ -3239,9 +3240,8 @@ def __remove_interactivetool_by_job(self, job):
else:
log.warning("Could not determine job to stop InteractiveTool: %s", job)

def exec_after_process(self, app, inp_data, out_data, param_dict, job=None, **kwds):
# run original exec_after_process
super().exec_after_process(app, inp_data, out_data, param_dict, job=job, **kwds)
def exec_after_process(self, app, inp_data, out_data, param_dict, job, final_job_state=None):
super().exec_after_process(app, inp_data, out_data, param_dict, job=job, final_job_state=final_job_state)
self.__remove_interactivetool_by_job(job)

def job_failed(self, job_wrapper, message, exception=False):
Expand All @@ -3260,12 +3260,11 @@ def __init__(self, config_file, root, app, guid=None, data_manager_id=None, **kw
if self.data_manager_id is None:
self.data_manager_id = self.id

def exec_after_process(self, app, inp_data, out_data, param_dict, job=None, final_job_state=None, **kwds):
def exec_after_process(self, app, inp_data, out_data, param_dict, job, final_job_state=None):
assert self.allow_user_access(job.user), "You must be an admin to access this tool."
if final_job_state != DETECTED_JOB_STATE.OK:
return
# run original exec_after_process
super().exec_after_process(app, inp_data, out_data, param_dict, job=job, **kwds)
super().exec_after_process(app, inp_data, out_data, param_dict, job=job, final_job_state=final_job_state)
# process results of tool
data_manager_id = job.data_manager_association.data_manager_id
data_manager = self.app.data_managers.get_manager(data_manager_id)
Expand Down Expand Up @@ -3402,7 +3401,7 @@ def _add_datasets_to_history(self, history, elements, datasets_visible=False):
element_object.visible = datasets_visible
history.stage_addition(element_object)

def produce_outputs(self, trans, out_data, output_collections, incoming, history, **kwds):
def produce_outputs(self, trans: "ProvidesUserContext", out_data, output_collections, incoming, history, **kwds):
return self._outputs_dict()

def _outputs_dict(self):
Expand Down Expand Up @@ -3579,7 +3578,7 @@ class ExtractDatasetCollectionTool(DatabaseOperationTool):
require_terminal_states = False
require_dataset_ok = False

def produce_outputs(self, trans, out_data, output_collections, incoming, history, tags=None, **kwds):
def produce_outputs(self, trans, out_data, output_collections, incoming, history, **kwds):
has_collection = incoming["input"]
if hasattr(has_collection, "element_type"):
# It is a DCE
Expand Down Expand Up @@ -3992,15 +3991,15 @@ def add_copied_value_to_new_elements(new_label, dce_object):
class ApplyRulesTool(DatabaseOperationTool):
tool_type = "apply_rules"

def produce_outputs(self, trans, out_data, output_collections, incoming, history, tag_handler, **kwds):
def produce_outputs(self, trans, out_data, output_collections, incoming, history, **kwds):
hdca = incoming["input"]
rule_set = RuleSet(incoming["rules"])
copied_datasets = []

def copy_dataset(dataset, tags):
copied_dataset = dataset.copy(copy_tags=dataset.tags, flush=False)
if tags is not None:
tag_handler.set_tags_from_list(
trans.tag_handler.set_tags_from_list(
trans.get_user(),
copied_dataset,
tags,
Expand Down Expand Up @@ -4029,14 +4028,15 @@ class TagFromFileTool(DatabaseOperationTool):
# require_terminal_states = True
# require_dataset_ok = False

def produce_outputs(self, trans, out_data, output_collections, incoming, history, tag_handler, **kwds):
def produce_outputs(self, trans, out_data, output_collections, incoming, history, **kwds):
hdca = incoming["input"]
how = incoming["how"]
new_tags_dataset_assoc = incoming["tags"]
new_elements = {}
new_datasets = []

def add_copied_value_to_new_elements(new_tags_dict, dce):
tag_handler = trans.tag_handler
if getattr(dce.element_object, "history_content_type", None) == "dataset":
copied_value = dce.element_object.copy(copy_tags=dce.element_object.tags, flush=False)
# copy should never be visible, since part of a collection
Expand Down
2 changes: 0 additions & 2 deletions lib/galaxy/tools/actions/model_operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,6 @@ def execute(
def _produce_outputs(
self, trans: "ProvidesUserContext", tool, out_data, output_collections, incoming, history, tags, hdca_tags, skip
):
tag_handler = trans.tag_handler
tool.produce_outputs(
trans,
out_data,
Expand All @@ -148,7 +147,6 @@ def _produce_outputs(
history=history,
tags=tags,
hdca_tags=hdca_tags,
tag_handler=tag_handler,
)
if mapped_over_elements := output_collections.dataset_collection_elements:
for name, value in out_data.items():
Expand Down
2 changes: 1 addition & 1 deletion lib/galaxy/util/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -1133,7 +1133,7 @@ def commaify(amount):


@overload
def unicodify( # type: ignore[overload-overlap] # ignore can be removed in mypy >=1.11.0
def unicodify(
value: Literal[None],
encoding: str = DEFAULT_ENCODING,
error: str = "replace",
Expand Down
3 changes: 2 additions & 1 deletion lib/galaxy/util/topsort.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@
a good deal more code than topsort itself!
"""

from random import choice


class CycleError(Exception):
def __init__(self, sofar, numpreds, succs):
Expand Down Expand Up @@ -112,7 +114,6 @@ def pick_a_cycle(self):
# crawl backward over the preds until we hit a duplicate, then
# reverse the path.
preds = self.get_preds()
from random import choice

x = choice(remaining_elts)
answer = []
Expand Down
Loading

0 comments on commit 948f38b

Please sign in to comment.