From 246e07ab09231e7e1502f9f91f7cf17b4f98e0ee Mon Sep 17 00:00:00 2001 From: nuwang <2070605+nuwang@users.noreply.github.com> Date: Fri, 26 Jul 2024 17:37:00 +0530 Subject: [PATCH 01/30] Refactor linter --- tpv/commands/linter.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/tpv/commands/linter.py b/tpv/commands/linter.py index 08e232e..f8bce57 100644 --- a/tpv/commands/linter.py +++ b/tpv/commands/linter.py @@ -23,6 +23,11 @@ def lint(self): except Exception as e: log.error(f"Linting failed due to syntax errors in yaml file: {e}") raise TPVLintError("Linting failed due to syntax errors in yaml file: ") from e + self.lint_tools(loader) + self.lint_destinations(loader) + self.print_errors_and_warnings() + + def lint_tools(self, loader): default_inherits = loader.global_settings.get('default_inherits') for tool_regex, tool in loader.tools.items(): try: @@ -34,6 +39,9 @@ def lint(self): f"The tool named: {default_inherits} is marked globally as the tool to inherit from " "by default. You may want to mark it as abstract if it is not an actual tool and it " "will be excluded from scheduling decisions.") + + def lint_destinations(self, loader): + default_inherits = loader.global_settings.get('default_inherits') for destination in loader.destinations.values(): if not destination.runner and not destination.abstract: self.errors.append(f"Destination '{destination.id}' does not define the runner parameter. " @@ -51,6 +59,8 @@ def lint(self): f"The destination named: {default_inherits} is marked globally as the destination to inherit from " "by default. You may want to mark it as abstract if it is not meant to be dispatched to, and it " "will be excluded from scheduling decisions.") + + def print_errors_and_warnings(self): if self.warnings: for w in self.warnings: log.warning(w) From e906d829a571c2d4ffc2f7c92bd92fc2ba32cd9f Mon Sep 17 00:00:00 2001 From: nuwang <2070605+nuwang@users.noreply.github.com> Date: Thu, 8 Aug 2024 13:21:09 +0530 Subject: [PATCH 02/30] Initial use of pydantic for entity validation --- tests/fixtures/mapping-destinations.yml | 4 +- .../fixtures/mapping-rule-argument-based.yml | 1 + tests/test_entity.py | 31 +- tpv/commands/linter.py | 2 +- tpv/core/entities.py | 1139 ++++++++--------- tpv/core/evaluator.py | 12 + tpv/core/loader.py | 124 +- tpv/core/mapper.py | 32 +- 8 files changed, 611 insertions(+), 734 deletions(-) create mode 100644 tpv/core/evaluator.py diff --git a/tests/fixtures/mapping-destinations.yml b/tests/fixtures/mapping-destinations.yml index 818a95c..18e5ca2 100644 --- a/tests/fixtures/mapping-destinations.yml +++ b/tests/fixtures/mapping-destinations.yml @@ -174,7 +174,7 @@ destinations: SPECIAL_FLAG: "first" params: memory_requests: "{mem}" - k8s_walltime_limit: 10 + k8s_walltime_limit: "10" rules: - if: input_size > 10 execute: | @@ -193,7 +193,7 @@ destinations: TEST_ENTITY_PRIORITY: "{cores*2}" params: memory_requests: "{mem*2}" - k8s_walltime_limit: 20 + k8s_walltime_limit: "20" rules: - if: input_size > 20 execute: | diff --git a/tests/fixtures/mapping-rule-argument-based.yml b/tests/fixtures/mapping-rule-argument-based.yml index 24eefb2..85d2800 100644 --- a/tests/fixtures/mapping-rule-argument-based.yml +++ b/tests/fixtures/mapping-rule-argument-based.yml @@ -16,6 +16,7 @@ tools: rules: - if: | helpers.job_args_match(job, app, {'input_opts': {'db_selector': 'db'}}) + id: limbo_rule scheduling: prefer: - pulsar diff --git a/tests/test_entity.py b/tests/test_entity.py index 03621d2..68dddff 100644 --- a/tests/test_entity.py +++ 
b/tests/test_entity.py @@ -2,8 +2,6 @@ import unittest from tpv.rules import gateway from tpv.core.entities import Destination -from tpv.core.entities import Tag -from tpv.core.entities import TagType from tpv.core.entities import Tool from tpv.core.loader import TPVConfigLoader from tpv.commands.test import mock_galaxy @@ -42,16 +40,15 @@ def test_all_entities_refer_to_same_loader(self): assert rule.loader == original_loader def test_destination_to_dict(self): - tpv_config = os.path.join(os.path.dirname(__file__), 'fixtures/mapping-rule-argument-based.yml') loader = TPVConfigLoader.from_url_or_path(tpv_config) # create a destination - destination = loader.destinations["k8s_environment"] + destination = loader.config.destinations["k8s_environment"] # serialize the destination - serialized_destination = destination.to_dict() + serialized_destination = destination.dict() # deserialize the same destination - deserialized_destination = Destination.from_dict(loader, serialized_destination) + deserialized_destination = Destination(loader=loader, **serialized_destination) # make sure the deserialized destination is the same as the original self.assertEqual(deserialized_destination, destination) @@ -60,24 +57,10 @@ def test_tool_to_dict(self): loader = TPVConfigLoader.from_url_or_path(tpv_config) # create a tool - tool = loader.tools["limbo"] + tool = loader.config.tools["limbo"] # serialize the tool - serialized_destination = tool.to_dict() + serialized_tool = tool.dict() # deserialize the same tool - deserialized_destination = Tool.from_dict(loader, serialized_destination) + deserialized_tool = Tool(loader=loader, **serialized_tool) # make sure the deserialized tool is the same as the original - self.assertEqual(deserialized_destination, tool) - - def test_tag_equivalence(self): - tag1 = Tag("tag_name", "tag_value", TagType.REQUIRE) - tag2 = Tag("tag_name2", "tag_value", TagType.REQUIRE) - tag3 = Tag("tag_name", "tag_value1", TagType.REQUIRE) - tag4 = Tag("tag_name", "tag_value1", TagType.PREFER) - same_as_tag1 = Tag("tag_name", "tag_value", TagType.REQUIRE) - - self.assertEqual(tag1, tag1) - self.assertEqual(tag1, same_as_tag1) - self.assertNotEqual(tag1, tag2) - self.assertNotEqual(tag1, tag3) - self.assertNotEqual(tag1, tag4) - self.assertNotEqual(tag1, "hello") + self.assertEqual(deserialized_tool, tool) diff --git a/tpv/commands/linter.py b/tpv/commands/linter.py index f8bce57..b96f9c0 100644 --- a/tpv/commands/linter.py +++ b/tpv/commands/linter.py @@ -42,7 +42,7 @@ def lint_tools(self, loader): def lint_destinations(self, loader): default_inherits = loader.global_settings.get('default_inherits') - for destination in loader.destinations.values(): + for destination in loader.config.destinations.values(): if not destination.runner and not destination.abstract: self.errors.append(f"Destination '{destination.id}' does not define the runner parameter. 
" "The runner parameter is mandatory.") diff --git a/tpv/core/entities.py b/tpv/core/entities.py index 1e214a5..19adf7b 100644 --- a/tpv/core/entities.py +++ b/tpv/core/entities.py @@ -1,362 +1,349 @@ -from __future__ import annotations - -from enum import Enum -import logging import copy +import itertools +import logging +from collections import defaultdict +from enum import IntEnum +from typing import Any, ClassVar, Dict, Generator, List, Optional, Union from galaxy import util as galaxy_util +from pydantic import ( + BaseModel, + ConfigDict, + Field, + field_validator, + model_validator, + ValidationInfo, +) +from pydantic.json_schema import SkipJsonSchema + +from .evaluator import TPVCodeBlockInterface log = logging.getLogger(__name__) -class TagType(Enum): - REQUIRE = 2 - PREFER = 1 - ACCEPT = 0 - REJECT = -1 +class TryNextDestinationOrFail(Exception): + # Try next destination, fail job if destination options exhausted + pass - def __int__(self): - return self.value +class TryNextDestinationOrWait(Exception): + # Try next destination, raise JobNotReadyException if destination options exhausted + pass -class Tag: - def __init__(self, name, value, tag_type: Enum): - self.name = name - self.value = value - self.tag_type = tag_type +def default_field_copier(entity1, entity2, property_name): + return ( + getattr( + entity1, + property_name, + ) + if getattr(entity1, property_name, None) is None + else getattr(entity2, property_name, None) + ) - def __eq__(self, other): - if not isinstance(other, Tag): - # don't attempt to compare against unrelated types - return NotImplemented - return self.name == other.name and self.value == other.value and self.tag_type == other.tag_type +def default_dict_copier(entity1, entity2, property_name): + new_dict = copy.deepcopy(getattr(entity2, property_name)) or {} + new_dict.update(getattr(entity2, property_name) or {}) + return new_dict - def __repr__(self): - return f"" + +class TagType(IntEnum): + REQUIRE = 2 + PREFER = 1 + ACCEPT = 0 + REJECT = -1 class IncompatibleTagsException(Exception): + def __init__(self, first_set: "SchedulingTags", second_set: "SchedulingTags"): + super().__init__( + "Cannot combine tag sets because require and reject tags mismatch. First" + f" tag set requires: {first_set.require} and rejects: {first_set.reject}." + f" Second tag set requires: {second_set.require} and rejects:" + f" {second_set.reject}." + ) - def __init__(self, first_set, second_set): - super().__init__( - f"Cannot combine tag sets because require and reject tags mismatch. First tag set requires:" - f" {[tag.value for tag in first_set.filter(TagType.REQUIRE)]} and rejects:" - f" {[tag.value for tag in first_set.filter(TagType.REJECT)]}. 
Second tag set requires:"
-            f" {[tag.value for tag in second_set.filter(TagType.REQUIRE)]} and rejects:"
-            f" {[tag.value for tag in second_set.filter(TagType.REJECT)]}.")
+class SchedulingTags(BaseModel):
+    require: Optional[List[str]] = Field(default_factory=list)
+    prefer: Optional[List[str]] = Field(default_factory=list)
+    accept: Optional[List[str]] = Field(default_factory=list)
+    reject: Optional[List[str]] = Field(default_factory=list)
+
+    @model_validator(mode="after")
+    def check_duplicates(self):
+        tag_occurrences = defaultdict(list)
-class TryNextDestinationOrFail(Exception):
-    # Try next destination, fail job if destination options exhausted
-    pass
+        # Track tag occurrences within each category and across categories
+        for tag_type in TagType:
+            field = tag_type.name.lower()
+            tags = getattr(self, field, []) or []
+            for tag in tags:
+                tag_occurrences[tag].append(field)
+
+        # Identify duplicates
+        duplicates = {
+            tag: fields for tag, fields in tag_occurrences.items() if len(fields) > 1
+        }
-class TryNextDestinationOrWait(Exception):
-    # Try next destination, raise JobNotReadyException if destination options exhausted
-    pass
+        # Build the detailed error message
+        if duplicates:
+            details = "; ".join(
+                [f"'{tag}' in {fields}" for tag, fields in duplicates.items()]
+            )
+            raise ValueError(f"Duplicate tags found: {details}")
+
+        return self
-class TagSetManager(object):
-
-    def __init__(self, tags=[]):
-        self.tags = tags or []
-
-    def add_tag_override(self, tag: Tag):
-        # pop the tag if it exists, as a tag can only belong to one type
-        self.tags = list(filter(lambda t: t.value != tag.value, self.tags))
-        self.tags.append(tag)
-
-    def filter(self, tag_type: TagType | list[TagType] = None,
-               tag_name: str = None, tag_value: str = None) -> list[Tag]:
-        filtered = self.tags
-        if tag_type:
-            if isinstance(tag_type, TagType):
-                filtered = (tag for tag in filtered if tag.tag_type == tag_type)
-            else:
-                filtered = (tag for tag in filtered if tag.tag_type in tag_type)
-        if tag_name:
-            filtered = (tag for tag in filtered if tag.name == tag_name)
-        if tag_value:
-            filtered = (tag for tag in filtered if tag.value == tag_value)
-        return filtered
-
-    def add_tag_overrides(self, tags: list[Tag]):
-        for tag in tags:
-            self.add_tag_override(tag)
-
-    def can_combine(self, other: TagSetManager) -> bool:
-        self_required = ((t.name, t.value) for t in self.filter(TagType.REQUIRE))
-        other_required = ((t.name, t.value) for t in other.filter(TagType.REQUIRE))
-        self_rejected = ((t.name, t.value) for t in self.filter(TagType.REJECT))
-        other_rejected = ((t.name, t.value) for t in other.filter(TagType.REJECT))
-        if set(self_required).intersection(set(other_rejected)):
-            return False
-        elif set(self_rejected).intersection(set(other_required)):
+    def all_tags(self) -> Generator[str, None, None]:
+        return itertools.chain(self.require, self.prefer, self.accept, self.reject)
+
+    def add_tag_override(self, tag_type: TagType, tag_value: str):
+        # Remove tag from all categories
+        for field in TagType:
+            field_name = field.name.lower()
+            if tag_value in getattr(self, field_name):
+                getattr(self, field_name).remove(tag_value)
+
+        # Add tag to the specified category (append mutates the list in place;
+        # assigning the result of append() would set the field to None)
+        tag_field = tag_type.name.lower()
+        getattr(self, tag_field).append(tag_value)
+
+    def inherit(self, other: "SchedulingTags") -> "SchedulingTags":
+        # Create new lists of tags that combine self and other
+        new_tags = copy.deepcopy(other)
+        for tag_type in [
+            TagType.ACCEPT,
+            TagType.PREFER,
+            TagType.REQUIRE,
+            TagType.REJECT,
+
]: + for tag in getattr(self, tag_type.name.lower(), []): + new_tags.add_tag_override(tag_type, tag) + return new_tags + + def can_combine(self, other: "SchedulingTags") -> bool: + self_required = set(self.require) + other_required = set(other.require) + self_rejected = set(self.reject) + other_rejected = set(other.reject) + + if self_required.intersection(other_rejected) or self_rejected.intersection( + other_required + ): return False - else: - return True + return True - def inherit(self, other) -> TagSetManager: - assert type(self) is type(other) - new_tag_set = TagSetManager() - new_tag_set.add_tag_overrides(other.filter(TagType.ACCEPT)) - new_tag_set.add_tag_overrides(other.filter(TagType.PREFER)) - new_tag_set.add_tag_overrides(other.filter(TagType.REQUIRE)) - new_tag_set.add_tag_overrides(other.filter(TagType.REJECT)) - new_tag_set.add_tag_overrides(self.filter(TagType.ACCEPT)) - new_tag_set.add_tag_overrides(self.filter(TagType.PREFER)) - new_tag_set.add_tag_overrides(self.filter(TagType.REQUIRE)) - new_tag_set.add_tag_overrides(self.filter(TagType.REJECT)) - return new_tag_set - - def combine(self, other: TagSetManager) -> TagSetManager: + def combine(self, other: "SchedulingTags") -> "SchedulingTags": if not self.can_combine(other): raise IncompatibleTagsException(self, other) - new_tag_set = TagSetManager() - # Add accept tags first, as they should be overridden by prefer, require and reject tags - new_tag_set.add_tag_overrides(other.filter(TagType.ACCEPT)) - new_tag_set.add_tag_overrides(self.filter(TagType.ACCEPT)) - # Next add preferred, as they should be overridden by require and reject tags - new_tag_set.add_tag_overrides(other.filter(TagType.PREFER)) - new_tag_set.add_tag_overrides(self.filter(TagType.PREFER)) - # Require and reject tags can be added in either order, as there's no overlap - new_tag_set.add_tag_overrides(other.filter(TagType.REQUIRE)) - new_tag_set.add_tag_overrides(self.filter(TagType.REQUIRE)) - new_tag_set.add_tag_overrides(other.filter(TagType.REJECT)) - new_tag_set.add_tag_overrides(self.filter(TagType.REJECT)) - return new_tag_set - - def match(self, other: TagSetManager) -> bool: - return (all(other.contains_tag(required) for required in self.filter(TagType.REQUIRE)) and - all(self.contains_tag(required) for required in other.filter(TagType.REQUIRE)) and - not any(other.contains_tag(rejected) for rejected in self.filter(TagType.REJECT)) and - not any(self.contains_tag(rejected) for rejected in other.filter(TagType.REJECT))) - - def contains_tag(self, tag) -> bool: - """ - Returns true if the name and value of the tag match. Ignores tag_type. - :param tag: - :return: - """ - return any(self.filter(tag_name=tag.name, tag_value=tag.value)) - def score(self, other: TagSetManager) -> bool: - """ - Computes a compatibility score between tag sets. 
- :param other: - :return: - """ - return (sum(int(tag.tag_type) * int(o.tag_type) for tag in self.tags for o in other.tags - if tag.name == o.name and tag.value == o.value) - # penalize tags that don't exist in the other - - sum(int(tag.tag_type) for tag in self.tags if not other.contains_tag(tag))) + new_tags = SchedulingTags() - def __eq__(self, other): - if not isinstance(other, TagSetManager): - # don't attempt to compare against unrelated types - return NotImplemented + # Add tags in the specific precedence order + for tag_type in [ + TagType.ACCEPT, + TagType.PREFER, + TagType.REQUIRE, + TagType.REJECT, + ]: + for tag in getattr(other, tag_type.name.lower()): + new_tags.add_tag_override(tag_type, tag) + for tag in getattr(self, tag_type.name.lower()): + new_tags.add_tag_override(tag_type, tag) - return self.tags == other.tags + return new_tags - def __repr__(self): - return f"{self.__class__} tags={[tag for tag in self.tags]}" + def match(self, other: "SchedulingTags") -> bool: + self_required = set(self.require) + other_required = set(other.require) + self_rejected = set(self.reject) + other_rejected = set(other.reject) - @staticmethod - def from_dict(tags: list[dict]) -> TagSetManager: - tag_list = [] - for tag_val in tags.get('require') or []: - tag_list.append(Tag(name="scheduling", value=tag_val, tag_type=TagType.REQUIRE)) - for tag_val in tags.get('prefer') or []: - tag_list.append(Tag(name="scheduling", value=tag_val, tag_type=TagType.PREFER)) - for tag_val in tags.get('accept') or []: - tag_list.append(Tag(name="scheduling", value=tag_val, tag_type=TagType.ACCEPT)) - for tag_val in tags.get('reject') or []: - tag_list.append(Tag(name="scheduling", value=tag_val, tag_type=TagType.REJECT)) - return TagSetManager(tags=tag_list) - - def to_dict(self) -> dict: - result_dict = { - 'require': [tag.value for tag in self.tags if tag.tag_type == TagType.REQUIRE], - 'prefer': [tag.value for tag in self.tags if tag.tag_type == TagType.PREFER], - 'accept': [tag.value for tag in self.tags if tag.tag_type == TagType.ACCEPT], - 'reject': [tag.value for tag in self.tags if tag.tag_type == TagType.REJECT] - } - return result_dict + return ( + self_required.issubset(other.all_tags()) + and other_required.issubset(self.all_tags()) + and not self_rejected.intersection(other.all_tags()) + and not other_rejected.intersection(self.all_tags()) + ) + + def score(self, other: "SchedulingTags") -> int: + score = 0 + for tag_type in TagType: + self_tags = set(getattr(self, tag_type.name.lower())) + other_tags = set(getattr(other, tag_type.name.lower())) + common_tags = self_tags & other_tags + score += len(common_tags) * int(tag_type) * int(tag_type) -class Entity(object): + unique_self_tags = self_tags - other_tags + score -= len(unique_self_tags) * int(tag_type) + + return score - merge_order = 0 - def __init__(self, loader, id=None, abstract=False, cores=None, mem=None, gpus=None, min_cores=None, min_mem=None, - min_gpus=None, max_cores=None, max_mem=None, max_gpus=None, env=None, params=None, resubmit=None, - tpv_tags=None, rank=None, inherits=None, context=None): +class Entity(BaseModel): + class Config: + arbitrary_types_allowed = True + + merge_order: ClassVar[int] = 0 + loader: SkipJsonSchema[Optional[TPVCodeBlockInterface]] = Field( + exclude=True, default=None + ) + id: Optional[str] = None + abstract: Optional[bool] = False + cores: Optional[Union[float, str]] = None + mem: Optional[Union[float, str]] = None + gpus: Optional[Union[float, str]] = None + min_cores: Optional[Union[float, str]] = 
None + min_mem: Optional[Union[float, str]] = None + min_gpus: Optional[Union[float, str]] = None + max_cores: Optional[Union[float, str]] = None + max_mem: Optional[Union[float, str]] = None + max_gpus: Optional[Union[float, str]] = None + env: Optional[List[Dict[str, str]]] = None + params: Optional[Dict[str, str]] = None + resubmit: Optional[Dict[str, str]] = Field(default_factory=dict) + tpv_tags: Optional[SchedulingTags] = Field( + alias="scheduling", default_factory=SchedulingTags + ) + rank: Optional[str] = None + inherits: Optional[str] = None + context: Optional[Dict[str, Any]] = Field(default_factory=dict) + + def __init__(self, **data: Any): + super().__init__(**data) + self.propagate_loader(self.loader) + + def propagate_loader(self, loader): self.loader = loader - self.id = id - self.abstract = galaxy_util.asbool(abstract) - self.cores = cores - self.mem = mem - self.gpus = gpus - self.min_cores = min_cores - self.min_mem = min_mem - self.min_gpus = min_gpus - self.max_cores = max_cores - self.max_mem = max_mem - self.max_gpus = max_gpus - self.env = self.convert_env(env) - self.params = params - self.resubmit = resubmit - self.tpv_tags = TagSetManager.from_dict(tpv_tags or {}) - self.rank = rank - self.inherits = inherits - self.context = context - self.validate() - - def __deepcopy__(self, memodict={}): + + def __deepcopy__(self, memo: dict): # make sure we don't deepcopy the loader: https://github.com/galaxyproject/total-perspective-vortex/issues/53 - # xref: https://stackoverflow.com/a/15774013 - cls = self.__class__ - result = cls.__new__(cls) - memodict[id(self)] = result - for k, v in self.__dict__.items(): - if k == "loader": - setattr(result, k, v) - else: - setattr(result, k, copy.deepcopy(v, memodict)) - return result - - def process_complex_property(self, prop, context, func, stringify=False): + # xref: https://stackoverflow.com/a/68746763/10971151 + memo[id(self.loader)] = self.loader + return super().__deepcopy__(memo) + + @staticmethod + def convert_env(env): + if isinstance(env, dict): + env = [dict(name=k, value=v) for (k, v) in env.items()] + return env + + @model_validator(mode="before") + @classmethod + def preprocess(cls, values): + if values: + values["abstract"] = galaxy_util.asbool(values.get("abstract", False)) + values["env"] = Entity.convert_env(values.get("env")) + # loader = values.get("loader") + # compile properties and check for errors + # if loader: + # for f in cls.model_fields: + # field = cls.model_fields[f] + # if f in values and field.metadata and field.metadata[0]: + # metadata = field.metadata[0] + # if metadata.complex_property: + # self.compile_complex_property(loader, values[f]) + # else: + # self.compile_code_block(loader, values[f]) + return values + + def process_complex_property(self, prop, context: Dict[str, Any], func): if isinstance(prop, str): return func(prop, context) elif isinstance(prop, dict): - evaluated_props = {key: self.process_complex_property(childprop, context, func, stringify=stringify) - for key, childprop in prop.items()} + evaluated_props = { + key: self.process_complex_property(childprop, context, func) + for key, childprop in prop.items() + } return evaluated_props elif isinstance(prop, list): - evaluated_props = [self.process_complex_property(childprop, context, func, stringify=stringify) - for childprop in prop] + evaluated_props = [ + self.process_complex_property(childprop, context, func) + for childprop in prop + ] return evaluated_props else: - return str(prop) if stringify else prop # To handle special 
case of env vars provided as ints + return prop - def compile_complex_property(self, prop): - return self.process_complex_property( - prop, None, lambda p, c: self.loader.compile_code_block(p, as_f_string=True)) + @classmethod + def compile_complex_property(cls, loader, prop): + return cls.process_complex_property( + prop, None, lambda p, c: loader.compile_code_block(p, as_f_string=True) + ) - def evaluate_complex_property(self, prop, context, stringify=False): + def evaluate_complex_property(self, prop, context: Dict[str, Any]): return self.process_complex_property( - prop, context, lambda p, c: self.loader.eval_code_block(p, c, as_f_string=True), stringify=stringify) - - def convert_env(self, env): - if isinstance(env, dict): - env = [dict(name=k, value=v) for (k, v) in env.items()] - return env - - def validate(self): - """ - Validates each code block and makes sure the code can be compiled. - This process also results in the compiled code being cached by the loader, - so that future evaluations are faster. - """ - if self.cores: - self.loader.compile_code_block(self.cores) - if self.mem: - self.loader.compile_code_block(self.mem) - if self.gpus: - self.loader.compile_code_block(self.gpus) - if self.min_cores: - self.loader.compile_code_block(self.min_cores) - if self.min_mem: - self.loader.compile_code_block(self.min_mem) - if self.min_gpus: - self.loader.compile_code_block(self.min_gpus) - if self.max_cores: - self.loader.compile_code_block(self.max_cores) - if self.max_mem: - self.loader.compile_code_block(self.max_mem) - if self.max_gpus: - self.loader.compile_code_block(self.max_gpus) - if self.env: - self.compile_complex_property(self.env) - if self.params: - self.compile_complex_property(self.params) - if self.resubmit: - self.compile_complex_property(self.resubmit) - if self.rank: - self.loader.compile_code_block(self.rank) - - def __repr__(self): - return f"{self.__class__} id={self.id}, abstract={self.abstract}, cores={self.cores}, mem={self.mem}, " \ - f"gpus={self.gpus}, min_cores = {self.min_cores}, min_mem = {self.min_mem}, " \ - f"min_gpus = {self.min_gpus}, max_cores = {self.max_cores}, max_mem = {self.max_mem}, " \ - f"max_gpus = {self.max_gpus}, env={self.env}, params={self.params}, resubmit={self.resubmit}, " \ - f"tags={self.tpv_tags}, rank={self.rank[:10] if self.rank else ''}, inherits={self.inherits}, "\ - f"context={self.context}" - - def __eq__(self, other): - if not isinstance(other, self.__class__): - # don't attempt to compare against unrelated types - return NotImplemented - - return ( - self.id == other.id and - self.abstract == other.abstract and - self.cores == other.cores and - self.mem == other.mem and - self.gpus == other.gpus and - self.min_cores == other.min_cores and - self.min_mem == other.min_mem and - self.min_gpus == other.min_gpus and - self.max_cores == other.max_cores and - self.max_mem == other.max_mem and - self.max_gpus == other.max_gpus and - self.env == other.env and - self.params == other.params and - self.resubmit == other.resubmit and - self.tpv_tags == other.tpv_tags and - self.inherits == other.inherits and - self.context == other.context + prop, + context, + lambda p, c: self.loader.eval_code_block(p, c, as_f_string=True), ) - def merge_env_list(self, original, replace): + @staticmethod + def merge_env_list(original, replace): for i, original_elem in enumerate(original): for j, replace_elem in enumerate(replace): - if (("name" in replace_elem and original_elem.get("name") == replace_elem["name"]) - or original_elem == 
replace_elem): + if ( + "name" in replace_elem + and original_elem.get("name") == replace_elem["name"] + ) or original_elem == replace_elem: original[i] = replace.pop(j) break original.extend(replace) return original - def override(self, entity): + @staticmethod + def override_single_property( + entity, entity1, entity2, property_name, field_copier=default_field_copier + ): + if ( + property_name in entity1.model_fields_set + or property_name in entity2.model_fields_set + ): + setattr( + entity, property_name, field_copier(entity1, entity2, property_name) + ) + + def override(self, entity: "Entity") -> "Entity": if entity.merge_order <= self.merge_order: # Use the broader class as a base when copying. Useful in particular for Rules - new_entity = copy.copy(self) + new_entity = self.copy() else: - new_entity = copy.copy(entity) - new_entity.id = self.id or entity.id - new_entity.abstract = self.abstract and entity.abstract - new_entity.cores = self.cores if self.cores is not None else entity.cores - new_entity.mem = self.mem if self.mem is not None else entity.mem - new_entity.gpus = self.gpus if self.gpus is not None else entity.gpus - new_entity.min_cores = self.min_cores if self.min_cores is not None else entity.min_cores - new_entity.min_mem = self.min_mem if self.min_mem is not None else entity.min_mem - new_entity.min_gpus = self.min_gpus if self.min_gpus is not None else entity.min_gpus - new_entity.max_cores = self.max_cores if self.max_cores is not None else entity.max_cores - new_entity.max_mem = self.max_mem if self.max_mem is not None else entity.max_mem - new_entity.max_gpus = self.max_gpus if self.max_gpus is not None else entity.max_gpus - new_entity.env = self.merge_env_list(copy.deepcopy(entity.env) or [], copy.deepcopy(self.env) or []) - new_entity.params = copy.copy(entity.params) or {} - new_entity.params.update(self.params or {}) - new_entity.resubmit = copy.copy(entity.resubmit) or {} - new_entity.resubmit.update(self.resubmit or {}) - new_entity.rank = self.rank if self.rank is not None else entity.rank - new_entity.inherits = self.inherits if self.inherits is not None else entity.inherits - new_entity.context = copy.copy(entity.context) or {} - new_entity.context.update(self.context or {}) + new_entity = entity.copy() + self.override_single_property(new_entity, self, entity, "id") + self.override_single_property(new_entity, self, entity, "abstract") + self.override_single_property(new_entity, self, entity, "cores") + self.override_single_property(new_entity, self, entity, "mem") + self.override_single_property(new_entity, self, entity, "gpus") + self.override_single_property(new_entity, self, entity, "min_cores") + self.override_single_property(new_entity, self, entity, "min_mem") + self.override_single_property(new_entity, self, entity, "min_gpus") + self.override_single_property(new_entity, self, entity, "max_cores") + self.override_single_property(new_entity, self, entity, "max_mem") + self.override_single_property(new_entity, self, entity, "max_gpus") + self.override_single_property(new_entity, self, entity, "max_gpus") + self.override_single_property( + new_entity, + self, + entity, + "env", + field_copier=lambda e1, e2, p: self.merge_env_list( + copy.deepcopy(entity.env) or [], copy.deepcopy(self.env) or [] + ), + ) + self.override_single_property( + new_entity, self, entity, "params", field_copier=default_dict_copier + ) + self.override_single_property( + new_entity, self, entity, "resubmit", field_copier=default_dict_copier + ) + 
self.override_single_property(new_entity, self, entity, "rank") + self.override_single_property(new_entity, self, entity, "inherits") + self.override_single_property( + new_entity, self, entity, "context", field_copier=default_dict_copier + ) return new_entity def inherit(self, entity): @@ -388,52 +375,66 @@ def combine(self, entity): :return: """ new_entity = self.override(entity) - new_entity.id = f"{type(self).__name__}: {self.id}, {type(entity).__name__}: {entity.id}" + new_entity.id = ( + f"{type(self).__name__}: {self.id}, {type(entity).__name__}: {entity.id}" + ) new_entity.tpv_tags = entity.tpv_tags.combine(self.tpv_tags) return new_entity - def evaluate_resources(self, context): + def evaluate_resources(self, context: Dict[str, Any]): new_entity = copy.deepcopy(self) context.update(self.context or {}) if self.min_gpus is not None: new_entity.min_gpus = self.loader.eval_code_block(self.min_gpus, context) - context['min_gpus'] = new_entity.min_gpus + context["min_gpus"] = new_entity.min_gpus if self.min_cores is not None: new_entity.min_cores = self.loader.eval_code_block(self.min_cores, context) - context['min_cores'] = new_entity.min_cores + context["min_cores"] = new_entity.min_cores if self.min_mem is not None: new_entity.min_mem = self.loader.eval_code_block(self.min_mem, context) - context['min_mem'] = new_entity.min_mem + context["min_mem"] = new_entity.min_mem if self.max_gpus is not None: new_entity.max_gpus = self.loader.eval_code_block(self.max_gpus, context) - context['max_gpus'] = new_entity.max_gpus + context["max_gpus"] = new_entity.max_gpus if self.max_cores is not None: new_entity.max_cores = self.loader.eval_code_block(self.max_cores, context) - context['max_cores'] = new_entity.max_cores + context["max_cores"] = new_entity.max_cores if self.max_mem is not None: new_entity.max_mem = self.loader.eval_code_block(self.max_mem, context) - context['max_mem'] = new_entity.max_mem + context["max_mem"] = new_entity.max_mem if self.gpus is not None: new_entity.gpus = self.loader.eval_code_block(self.gpus, context) # clamp gpus new_entity.gpus = max(new_entity.min_gpus or 0, new_entity.gpus or 0) - new_entity.gpus = min(new_entity.max_gpus, new_entity.gpus) if new_entity.max_gpus else new_entity.gpus - context['gpus'] = new_entity.gpus + new_entity.gpus = ( + min(new_entity.max_gpus, new_entity.gpus) + if new_entity.max_gpus + else new_entity.gpus + ) + context["gpus"] = new_entity.gpus if self.cores is not None: new_entity.cores = self.loader.eval_code_block(self.cores, context) # clamp cores new_entity.cores = max(new_entity.min_cores or 0, new_entity.cores or 0) - new_entity.cores = min(new_entity.max_cores, new_entity.cores) if new_entity.max_cores else new_entity.cores - context['cores'] = new_entity.cores + new_entity.cores = ( + min(new_entity.max_cores, new_entity.cores) + if new_entity.max_cores + else new_entity.cores + ) + context["cores"] = new_entity.cores if self.mem is not None: new_entity.mem = self.loader.eval_code_block(self.mem, context) # clamp mem new_entity.mem = max(new_entity.min_mem or 0, new_entity.mem or 0) - new_entity.mem = min(new_entity.max_mem, new_entity.mem or 0) if new_entity.max_mem else new_entity.mem - context['mem'] = new_entity.mem + new_entity.mem = ( + min(new_entity.max_mem, new_entity.mem or 0) + if new_entity.max_mem + else new_entity.mem + ) + context["mem"] = new_entity.mem return new_entity - def evaluate(self, context): + def evaluate(self, context: Dict[str, Any]): """ Evaluate expressions in entity properties that must be 
evaluated as late as possible, which is to say, after combining entity requirements. This includes env, params and resubmit, that rely on @@ -443,104 +444,108 @@ def evaluate(self, context): """ new_entity = self.evaluate_resources(context) if self.env: - new_entity.env = self.evaluate_complex_property(self.env, context, stringify=True) - context['env'] = new_entity.env + new_entity.env = self.evaluate_complex_property(self.env, context) + context["env"] = new_entity.env if self.params: new_entity.params = self.evaluate_complex_property(self.params, context) - context['params'] = new_entity.params + context["params"] = new_entity.params if self.resubmit: new_entity.resubmit = self.evaluate_complex_property(self.resubmit, context) - context['resubmit'] = new_entity.resubmit + context["resubmit"] = new_entity.resubmit return new_entity - def rank_destinations(self, destinations, context): + def rank_destinations( + self, destinations: List["Destination"], context: Dict[str, Any] + ): if self.rank: - log.debug(f"Ranking destinations: {destinations} for entity: {self} using custom function") - context['candidate_destinations'] = destinations + log.debug( + f"Ranking destinations: {destinations} for entity: {self} using custom" + " function" + ) + context["candidate_destinations"] = destinations return self.loader.eval_code_block(self.rank, context) else: # Sort destinations by priority - log.debug(f"Ranking destinations: {destinations} for entity: {self} using default ranker") + log.debug( + f"Ranking destinations: {destinations} for entity: {self} using default" + " ranker" + ) return sorted(destinations, key=lambda d: d.score(self), reverse=True) - def to_dict(self): - dict_obj = { - 'id': self.id, - 'abstract': self.abstract, - 'cores': self.cores, - 'mem': self.mem, - 'gpus': self.gpus, - 'min_cores': self.min_cores, - 'min_mem': self.min_mem, - 'min_gpus': self.min_gpus, - 'max_cores': self.max_cores, - 'max_mem': self.max_mem, - 'max_gpus': self.max_gpus, - 'env': self.env, - 'params': self.params, - 'resubmit': self.resubmit, - 'scheduling': self.tpv_tags.to_dict(), - 'inherits': self.inherits, - 'context': self.context - } - return dict_obj + def model_dump(self, **kwargs): + # Ensure by_alias is set to True to use the field aliases during serialization + kwargs.setdefault("by_alias", True) + return super().model_dump(**kwargs) + def dict(self, **kwargs): + # by_alias is set to True to use the field aliases during serialization + kwargs.setdefault("by_alias", True) + return super().dict(**kwargs) -class EntityWithRules(Entity): - merge_order = 1 - - def __init__(self, loader, id=None, abstract=False, cores=None, mem=None, gpus=None, min_cores=None, min_mem=None, - min_gpus=None, max_cores=None, max_mem=None, max_gpus=None, env=None, - params=None, resubmit=None, tpv_tags=None, rank=None, inherits=None, context=None, rules=None): - super().__init__(loader, id=id, abstract=abstract, cores=cores, mem=mem, gpus=gpus, min_cores=min_cores, - min_mem=min_mem, min_gpus=min_gpus, max_cores=max_cores, max_mem=max_mem, max_gpus=max_gpus, - env=env, params=params, resubmit=resubmit, tpv_tags=tpv_tags, rank=rank, inherits=inherits, - context=context) - self.rules = self.validate_rules(rules) - - def validate_rules(self, rules: list) -> list: - validated = {} - for rule in rules or []: - try: - validated_rule = Rule.from_dict(self.loader, rule) - validated[validated_rule.id] = validated_rule - except Exception: - log.exception(f"Could not load rule for entity: {self.__class__} with id: {self.id} and 
data: {rule}")
-                raise
-        return validated
+class Rule(Entity):
+    rule_counter: ClassVar[int] = 0
+    id: Optional[str] = Field(default_factory=lambda: Rule.set_default_id())
+    if_condition: str = Field(alias="if")
+    execute: Optional[str] = None
+    fail: Optional[str] = None
 
     @classmethod
-    def from_dict(cls: type, loader, entity_dict):
-        return cls(
-            loader=loader,
-            id=entity_dict.get('id'),
-            abstract=entity_dict.get('abstract'),
-            cores=entity_dict.get('cores'),
-            mem=entity_dict.get('mem'),
-            gpus=entity_dict.get('gpus'),
-            min_cores=entity_dict.get('min_cores'),
-            min_mem=entity_dict.get('min_mem'),
-            min_gpus=entity_dict.get('min_gpus'),
-            max_cores=entity_dict.get('max_cores'),
-            max_mem=entity_dict.get('max_mem'),
-            max_gpus=entity_dict.get('max_gpus'),
-            env=entity_dict.get('env'),
-            params=entity_dict.get('params'),
-            resubmit=entity_dict.get('resubmit'),
-            tpv_tags=entity_dict.get('scheduling'),
-            rank=entity_dict.get('rank'),
-            inherits=entity_dict.get('inherits'),
-            context=entity_dict.get('context'),
-            rules=entity_dict.get('rules')
-        )
-
-    def to_dict(self):
-        dict_obj = super().to_dict()
-        dict_obj['rules'] = [rule.to_dict() for rule in self.rules.values()]
-        return dict_obj
+    def set_default_id(cls):
+        cls.rule_counter += 1
+        return f"tpv_rule_{cls.rule_counter}"
 
     def override(self, entity):
+        new_entity = super().override(entity)
+        self.override_single_property(new_entity, self, entity, "if_condition")
+        self.override_single_property(new_entity, self, entity, "execute")
+        self.override_single_property(new_entity, self, entity, "fail")
+        return new_entity
+
+    def is_matching(self, context):
+        if self.loader.eval_code_block(self.if_condition, context):
+            return True
+        else:
+            return False
+
+    def evaluate(self, context):
+        if self.fail:
+            from galaxy.jobs.mapper import JobMappingException
+
+            raise JobMappingException(
+                self.loader.eval_code_block(self.fail, context, as_f_string=True)
+            )
+        if self.execute:
+            self.loader.eval_code_block(self.execute, context, exec_only=True)
+            # return any changes made to the entity
+            return context["entity"]
+        return self
+
+
+class EntityWithRules(Entity):
+    merge_order: ClassVar[int] = 1
+    rules: Optional[Dict[str, Rule]] = Field(default_factory=dict)
+
+    def propagate_loader(self, loader):
+        super().propagate_loader(loader)
+        for rule in self.rules.values():
+            rule.loader = loader
+
+    @field_validator("rules", mode="after")
+    def inject_loader(cls, v: Dict[str, Entity], info: ValidationInfo):
+        for element in v.values():
+            element.loader = info.data["loader"]
+        return v
+
+    @model_validator(mode="before")
+    @classmethod
+    def deserialize_rules(cls, values: Dict[str, Any]) -> Dict[str, Any]:
+        if "rules" in values and isinstance(values["rules"], list):
+            rules = (Rule(**r) for r in values["rules"])
+            values["rules"] = {rule.id: rule for rule in rules}
+        return values
+
+    def override(self, entity: Entity):
         new_entity = super().override(entity)
         new_entity.rules = copy.deepcopy(entity.rules)
         new_entity.rules.update(self.rules or {})
@@ -549,7 +554,7 @@ def override(self, entity):
             new_entity.rules[rule.id] = rule.inherit(entity.rules[rule.id])
         return new_entity
 
-    def evaluate_rules(self, context):
+    def evaluate_rules(self, context: Dict[str, str]):
         new_entity = copy.deepcopy(self)
         context.update(new_entity.context or {})
         for rule in self.rules.values():
@@ -560,184 +565,86 @@ def evaluate_rules(self, context):
                 new_entity.cores = rule.cores or new_entity.cores
                 new_entity.mem = rule.mem or new_entity.mem
                 new_entity.id = f"{new_entity.id}, Rule: {rule.id}"
-
context.update({ - 'entity': new_entity - }) + context.update({"entity": new_entity}) return new_entity - def evaluate(self, context): + def evaluate(self, context: Dict[str, str]): new_entity = self.evaluate_rules(context) return super(EntityWithRules, new_entity).evaluate(context) - def __repr__(self): - return super().__repr__() + f", rules={self.rules}" - - def __eq__(self, other): - return super().__eq__(other) and ( - self.rules == other.rules - ) - class Tool(EntityWithRules): - - merge_order = 2 + merge_order: ClassVar[int] = 2 + pass class Role(EntityWithRules): - - merge_order = 3 + merge_order: ClassVar[int] = 3 class User(EntityWithRules): - - merge_order = 4 + merge_order: ClassVar[int] = 4 class Destination(EntityWithRules): + merge_order: ClassVar[int] = 5 + runner: Optional[str] = None + max_accepted_cores: Optional[float] = None + max_accepted_mem: Optional[float] = None + max_accepted_gpus: Optional[float] = None + min_accepted_cores: Optional[float] = None + min_accepted_mem: Optional[float] = None + min_accepted_gpus: Optional[float] = None + dest_name: Optional[str] = Field(alias="destination_name_override", default=None) + # tpv_tags track what tags the entity being scheduled requested, while tpv_dest_tags track what the destination + # supports. When serializing a Destination, we don't need tpv_tags, only tpv_dest_tags. + tpv_tags: SkipJsonSchema[Optional[SchedulingTags]] = Field( + exclude=True, default_factory=SchedulingTags + ) + tpv_dest_tags: Optional[SchedulingTags] = Field( + alias="scheduling", default_factory=SchedulingTags + ) + handler_tags: Optional[List[str]] = Field(alias="tags", default_factory=list) + + @model_validator(mode="after") + def assign_defaults(self): + self.dest_name = self.dest_name or self.id + return self - merge_order = 5 - - def __init__(self, loader, id=None, abstract=False, runner=None, dest_name=None, cores=None, mem=None, gpus=None, - min_cores=None, min_mem=None, min_gpus=None, max_cores=None, max_mem=None, max_gpus=None, - min_accepted_cores=None, min_accepted_mem=None, min_accepted_gpus=None, - max_accepted_cores=None, max_accepted_mem=None, max_accepted_gpus=None, env=None, params=None, - resubmit=None, tpv_dest_tags=None, inherits=None, context=None, rules=None, handler_tags=None): - self.runner = runner - self.dest_name = dest_name or id - self.min_accepted_cores = min_accepted_cores - self.min_accepted_mem = min_accepted_mem - self.min_accepted_gpus = min_accepted_gpus - self.max_accepted_cores = max_accepted_cores - self.max_accepted_mem = max_accepted_mem - self.max_accepted_gpus = max_accepted_gpus - self.tpv_dest_tags = TagSetManager.from_dict(tpv_dest_tags or {}) - # Handler tags refer to Galaxy's job handler level tags - self.handler_tags = handler_tags - super().__init__(loader, id=id, abstract=abstract, cores=cores, mem=mem, gpus=gpus, min_cores=min_cores, - min_mem=min_mem, min_gpus=min_gpus, max_cores=max_cores, max_mem=max_mem, max_gpus=max_gpus, - env=env, params=params, resubmit=resubmit, tpv_tags=None, inherits=inherits, context=context, - rules=rules) - - @staticmethod - def from_dict(loader, entity_dict): - return Destination( - loader=loader, - id=entity_dict.get('id'), - abstract=entity_dict.get('abstract'), - runner=entity_dict.get('runner'), - dest_name=entity_dict.get('destination_name_override'), - cores=entity_dict.get('cores'), - mem=entity_dict.get('mem'), - gpus=entity_dict.get('gpus'), - min_cores=entity_dict.get('min_cores'), - min_mem=entity_dict.get('min_mem'), - 
min_gpus=entity_dict.get('min_gpus'), - max_cores=entity_dict.get('max_cores'), - max_mem=entity_dict.get('max_mem'), - max_gpus=entity_dict.get('max_gpus'), - min_accepted_cores=entity_dict.get('min_accepted_cores'), - min_accepted_mem=entity_dict.get('min_accepted_mem'), - min_accepted_gpus=entity_dict.get('min_accepted_gpus'), - max_accepted_cores=entity_dict.get('max_accepted_cores'), - max_accepted_mem=entity_dict.get('max_accepted_mem'), - max_accepted_gpus=entity_dict.get('max_accepted_gpus'), - env=entity_dict.get('env'), - params=entity_dict.get('params'), - resubmit=entity_dict.get('resubmit'), - tpv_dest_tags=entity_dict.get('scheduling'), - inherits=entity_dict.get('inherits'), - context=entity_dict.get('context'), - rules=entity_dict.get('rules'), - handler_tags=entity_dict.get('tags') - ) - - def to_dict(self): - dict_obj = super().to_dict() - dict_obj['runner'] = self.runner - dict_obj['destination_name_override'] = self.dest_name - dict_obj['min_accepted_cores'] = self.min_accepted_cores - dict_obj['min_accepted_mem'] = self.min_accepted_mem - dict_obj['min_accepted_gpus'] = self.min_accepted_gpus - dict_obj['max_accepted_cores'] = self.max_accepted_cores - dict_obj['max_accepted_mem'] = self.max_accepted_mem - dict_obj['max_accepted_gpus'] = self.max_accepted_gpus - dict_obj['scheduling'] = self.tpv_dest_tags.to_dict() - dict_obj['tags'] = self.handler_tags - return dict_obj - - def __eq__(self, other): - if not isinstance(other, Destination): - # don't attempt to compare against unrelated types - return NotImplemented - - return super().__eq__(other) and ( - self.runner == other.runner and - self.dest_name == other.dest_name and - self.min_accepted_cores == other.min_accepted_cores and - self.min_accepted_mem == other.min_accepted_mem and - self.min_accepted_gpus == other.min_accepted_gpus and - self.max_accepted_cores == other.max_accepted_cores and - self.max_accepted_mem == other.max_accepted_mem and - self.max_accepted_gpus == other.max_accepted_gpus and - self.tpv_dest_tags == other.tpv_dest_tags and - self.handler_tags == other.handler_tags - ) - - def __repr__(self): - return f"runner={self.runner}, dest_name={self.dest_name}, min_accepted_cores={self.min_accepted_cores}, "\ - f"min_accepted_mem={self.min_accepted_mem}, min_accepted_gpus={self.min_accepted_gpus}, "\ - f"max_accepted_cores={self.max_accepted_cores}, max_accepted_mem={self.max_accepted_mem}, "\ - f"max_accepted_gpus={self.max_accepted_gpus}, tpv_dest_tags={self.tpv_dest_tags}, "\ - f"handler_tags={self.handler_tags}" + super().__repr__() - - def override(self, entity): + def override(self, entity: Entity): new_entity = super().override(entity) - new_entity.runner = self.runner if self.runner is not None else getattr(entity, 'runner', None) - new_entity.dest_name = self.dest_name if self.dest_name is not None else getattr(entity, 'dest_name', None) - new_entity.min_accepted_cores = (self.min_accepted_cores if self.min_accepted_cores is not None - else getattr(entity, 'min_accepted_cores', None)) - new_entity.min_accepted_mem = (self.min_accepted_mem if self.min_accepted_mem is not None - else getattr(entity, 'min_accepted_mem', None)) - new_entity.min_accepted_gpus = (self.min_accepted_gpus if self.min_accepted_gpus is not None - else getattr(entity, 'min_accepted_gpus', None)) - new_entity.max_accepted_cores = (self.max_accepted_cores if self.max_accepted_cores is not None - else getattr(entity, 'max_accepted_cores', None)) - new_entity.max_accepted_mem = (self.max_accepted_mem if 
self.max_accepted_mem is not None - else getattr(entity, 'max_accepted_mem', None)) - new_entity.max_accepted_gpus = (self.max_accepted_gpus if self.max_accepted_gpus is not None - else getattr(entity, 'max_accepted_gpus', None)) - new_entity.handler_tags = self.handler_tags or getattr(entity, 'handler_tags', None) + self.override_single_property(new_entity, self, entity, "runner") + self.override_single_property(new_entity, self, entity, "dest_name") + self.override_single_property(new_entity, self, entity, "min_accepted_cores") + self.override_single_property(new_entity, self, entity, "min_accepted_mem") + self.override_single_property(new_entity, self, entity, "min_accepted_gpus") + self.override_single_property(new_entity, self, entity, "max_accepted_cores") + self.override_single_property(new_entity, self, entity, "max_accepted_mem") + self.override_single_property(new_entity, self, entity, "max_accepted_gpus") + self.override_single_property(new_entity, self, entity, "handler_tags") return new_entity - def validate(self): - """ - Validates each code block and makes sure the code can be compiled. - This process also results in the compiled code being cached by the loader, - so that future evaluations are faster. - """ - super().validate() - if self.dest_name: - self.loader.compile_code_block(self.dest_name, as_f_string=True) - if self.handler_tags: - self.compile_complex_property(self.handler_tags) - - def evaluate(self, context): + def evaluate(self, context: Dict[str, Any]): new_entity = super(Destination, self).evaluate(context) if self.dest_name is not None: - new_entity.dest_name = self.loader.eval_code_block(self.dest_name, context, as_f_string=True) - context['dest_name'] = new_entity.dest_name + new_entity.dest_name = self.loader.eval_code_block( + self.dest_name, context, as_f_string=True + ) + context["dest_name"] = new_entity.dest_name if self.handler_tags is not None: - new_entity.handler_tags = self.evaluate_complex_property(self.handler_tags, context) - context['handler_tags'] = new_entity.handler_tags + new_entity.handler_tags = self.evaluate_complex_property( + self.handler_tags, context + ) + context["handler_tags"] = new_entity.handler_tags return new_entity - def inherit(self, entity): + def inherit(self, entity: Entity): new_entity = super().inherit(entity) if entity: new_entity.tpv_dest_tags = self.tpv_dest_tags.inherit(entity.tpv_dest_tags) return new_entity - def matches(self, entity, context): + def matches(self, entity: Entity, context: Dict[str, Any]): """ The match operation checks whether @@ -754,17 +661,41 @@ def matches(self, entity, context): """ if self.abstract: return False - if self.max_accepted_cores is not None and entity.cores is not None and self.max_accepted_cores < entity.cores: + if ( + self.max_accepted_cores is not None + and entity.cores is not None + and self.max_accepted_cores < entity.cores + ): return False - if self.max_accepted_mem is not None and entity.mem is not None and self.max_accepted_mem < entity.mem: + if ( + self.max_accepted_mem is not None + and entity.mem is not None + and self.max_accepted_mem < entity.mem + ): return False - if self.max_accepted_gpus is not None and entity.gpus is not None and self.max_accepted_gpus < entity.gpus: + if ( + self.max_accepted_gpus is not None + and entity.gpus is not None + and self.max_accepted_gpus < entity.gpus + ): return False - if self.min_accepted_cores is not None and entity.cores is not None and self.min_accepted_cores > entity.cores: + if ( + self.min_accepted_cores is not None 
+ and entity.cores is not None + and self.min_accepted_cores > entity.cores + ): return False - if self.min_accepted_mem is not None and entity.mem is not None and self.min_accepted_mem > entity.mem: + if ( + self.min_accepted_mem is not None + and entity.mem is not None + and self.min_accepted_mem > entity.mem + ): return False - if self.min_accepted_gpus is not None and entity.gpus is not None and self.min_accepted_gpus > entity.gpus: + if ( + self.min_accepted_gpus is not None + and entity.gpus is not None + and self.min_accepted_gpus > entity.gpus + ): return False return entity.tpv_tags.match(self.tpv_dest_tags or {}) @@ -780,89 +711,37 @@ def score(self, entity): return score -class Rule(Entity): - - rule_counter = 0 - merge_order = 0 - - def __init__(self, loader, id=None, cores=None, mem=None, gpus=None, min_cores=None, min_mem=None, min_gpus=None, - max_cores=None, max_mem=None, max_gpus=None, env=None, params=None, resubmit=None, - tpv_tags=None, inherits=None, context=None, match=None, execute=None, fail=None): - if not id: - Rule.rule_counter += 1 - id = f"tpv_rule_{Rule.rule_counter}" - super().__init__(loader, id=id, abstract=False, cores=cores, mem=mem, gpus=gpus, min_cores=min_cores, - min_mem=min_mem, min_gpus=min_gpus, max_cores=max_cores, max_mem=max_mem, max_gpus=max_gpus, - env=env, params=params, resubmit=resubmit, tpv_tags=tpv_tags, context=context, - inherits=inherits) - self.match = match - self.execute = execute - self.fail = fail - if self.match: - self.loader.compile_code_block(self.match) - if self.execute: - self.loader.compile_code_block(self.execute, exec_only=True) - if self.fail: - self.loader.compile_code_block(self.fail, as_f_string=True) - - @staticmethod - def from_dict(loader, entity_dict): - return Rule( - loader=loader, - id=entity_dict.get('id'), - cores=entity_dict.get('cores'), - mem=entity_dict.get('mem'), - gpus=entity_dict.get('gpus'), - min_cores=entity_dict.get('min_cores'), - min_mem=entity_dict.get('min_mem'), - min_gpus=entity_dict.get('min_gpus'), - max_cores=entity_dict.get('max_cores'), - max_mem=entity_dict.get('max_mem'), - max_gpus=entity_dict.get('max_gpus'), - env=entity_dict.get('env'), - params=entity_dict.get('params'), - resubmit=entity_dict.get('resubmit'), - tpv_tags=entity_dict.get('scheduling'), - inherits=entity_dict.get('inherits'), - context=entity_dict.get('context'), - # TODO: Remove deprecated match clause in future - match=entity_dict.get('if') or entity_dict.get('match'), - execute=entity_dict.get('execute'), - fail=entity_dict.get('fail') - ) - - def to_dict(self): - dict_obj = super().to_dict() - dict_obj['if'] = self.match - dict_obj['execute'] = self.execute - dict_obj['fail'] = self.fail - return dict_obj - - def override(self, entity): - new_entity = super().override(entity) - new_entity.match = self.match if self.match is not None else getattr(entity, 'match', None) - new_entity.execute = self.execute if self.execute is not None else getattr(entity, 'execute', None) - new_entity.fail = self.fail if self.fail is not None else getattr(entity, 'fail', None) - return new_entity - - def __repr__(self): - return super().__repr__() + f", if={self.match[:10] if self.match else ''}, " \ - f"execute={self.execute[:10] if self.execute else ''}, " \ - f"fail={self.fail[:10] if self.fail else ''}" +class GlobalConfig(BaseModel): + default_inherits: Optional[str] = None + context: Optional[Dict[str, Any]] = Field(default_factory=dict) + + +class TPVConfig(BaseModel): + model_config = 
ConfigDict(arbitrary_types_allowed=True) + global_config: Optional[GlobalConfig] = Field(alias="global", default_factory=dict) + loader: SkipJsonSchema[Optional[TPVCodeBlockInterface]] = Field( + exclude=True, default=None + ) + tools: Optional[Dict[str, Tool]] = Field(default_factory=dict) + users: Optional[Dict[str, User]] = Field(default_factory=dict) + roles: Optional[Dict[str, Role]] = Field(default_factory=dict) + destinations: Optional[Dict[str, Destination]] = Field(default_factory=dict) + + @model_validator(mode="after") + def propagate_loader(self): + if self.loader: + for tool in self.tools.values(): + tool.propagate_loader(self.loader) + for user in self.users.values(): + user.propagate_loader(self.loader) + for role in self.roles.values(): + role.propagate_loader(self.loader) + for destination in self.destinations.values(): + destination.propagate_loader(self.loader) + return self - def is_matching(self, context): - if self.loader.eval_code_block(self.match, context): - return True - else: - return False - def evaluate(self, context): - if self.fail: - from galaxy.jobs.mapper import JobMappingException - raise JobMappingException( - self.loader.eval_code_block(self.fail, context, as_f_string=True)) - if self.execute: - self.loader.eval_code_block(self.execute, context, exec_only=True) - # return any changes made to the entity - return context['entity'] - return self +# from tpv.core import schema +# import yaml +# data = yaml.safe_load(open("/Users/nuwan/work/total-perspective-vortex/tests/fixtures/scenario.yml")) +# config = schema.TPVConfig(**data) diff --git a/tpv/core/evaluator.py b/tpv/core/evaluator.py new file mode 100644 index 0000000..b9a9f40 --- /dev/null +++ b/tpv/core/evaluator.py @@ -0,0 +1,12 @@ +import abc + + +class TPVCodeBlockInterface(abc.ABC): + + @abc.abstractmethod + def compile_code_block(self, code, as_f_string=False, exec_only=False): + pass + + @abc.abstractmethod + def eval_code_block(self, code, context, as_f_string=False, exec_only=False): + pass diff --git a/tpv/core/loader.py b/tpv/core/loader.py index 99f469c..03e576f 100644 --- a/tpv/core/loader.py +++ b/tpv/core/loader.py @@ -1,11 +1,13 @@ from __future__ import annotations + import ast import functools import logging +from typing import Dict -from . import helpers -from . import util -from .entities import Tool, User, Role, Destination, Entity +from . 
import helpers, util +from .entities import Entity, GlobalConfig, TPVConfig +from .evaluator import TPVCodeBlockInterface log = logging.getLogger(__name__) @@ -14,40 +16,51 @@ class InvalidParentException(Exception): pass -class TPVConfigLoader(object): +class TPVConfigLoader(TPVCodeBlockInterface): + + def __init__(self, tpv_config: TPVConfig): + self.compile_code_block = functools.lru_cache(maxsize=None)( + self.__compile_code_block + ) + self.config = TPVConfig(loader=self, **tpv_config) - def __init__(self, tpv_config: dict): - self.compile_code_block = functools.lru_cache(maxsize=None)(self.__compile_code_block) - self.global_settings = tpv_config.get('global', {}) - entities = self.load_entities(tpv_config) - self.tools = entities.get('tools') - self.users = entities.get('users') - self.roles = entities.get('roles') - self.destinations = entities.get('destinations') + def compile_code_block(self, code, as_f_string=False, exec_only=False): + # interface method, replaced with instance based lru cache in constructor + pass def __compile_code_block(self, code, as_f_string=False, exec_only=False): if as_f_string: code_str = "f'''" + str(code) + "'''" else: code_str = str(code) - block = ast.parse(code_str, mode='exec') + block = ast.parse(code_str, mode="exec") if exec_only: - return compile(block, '', mode='exec'), None + return compile(block, "", mode="exec"), None else: # assumes last node is an expression last = ast.Expression(block.body.pop().value) - return compile(block, '', mode='exec'), compile(last, '', mode='eval') + return compile(block, "", mode="exec"), compile( + last, "", mode="eval" + ) # https://stackoverflow.com/a/39381428 def eval_code_block(self, code, context, as_f_string=False, exec_only=False): - exec_block, eval_block = self.compile_code_block(code, as_f_string=as_f_string, exec_only=exec_only) + exec_block, eval_block = self.compile_code_block( + code, as_f_string=as_f_string, exec_only=exec_only + ) locals = dict(globals()) locals.update(context) - locals.update({ - 'helpers': helpers, - # Don't unnecessarily compute input_size unless it's referred to - 'input_size': helpers.input_size(context['job']) if 'input_size' in str(code) else 0 - }) + locals.update( + { + "helpers": helpers, + # Don't unnecessarily compute input_size unless it's referred to + "input_size": ( + helpers.input_size(context["job"]) + if "input_size" in str(code) + else 0 + ), + } + ) exec(exec_block, locals) if eval_block: return eval(eval_block, locals) @@ -58,8 +71,10 @@ def process_inheritance(self, entity_list: dict[str, Entity], entity: Entity): if entity.inherits: parent_entity = entity_list.get(entity.inherits) if not parent_entity: - raise InvalidParentException(f"The specified parent: {entity.inherits} for" - f" entity: {entity} does not exist") + raise InvalidParentException( + f"The specified parent: {entity.inherits} for" + f" entity: {entity} does not exist" + ) return entity.inherit(self.process_inheritance(entity_list, parent_entity)) # do not process default inheritance here, only at runtime, as multiple can cause default inheritance # to override later matches. 
@@ -69,38 +84,26 @@ def recompute_inheritance(self, entities: dict[str, Entity]): for key, entity in entities.items(): entities[key] = self.process_inheritance(entities, entity) - def validate_entities(self, entity_class: type, entity_list: dict) -> dict: - # This code relies on dict ordering guarantees provided since python 3.6 - validated = {} - for entity_id, entity_dict in entity_list.items(): - try: - if not entity_dict: - entity_dict = {} - entity_dict['id'] = entity_id - validated[entity_id] = entity_class.from_dict(self, entity_dict) - except Exception: - log.exception(f"Could not load entity of type: {entity_class} with data: {entity_dict}") - raise - self.recompute_inheritance(validated) - return validated - - def load_entities(self, tpv_config: dict) -> dict: - validated = { - 'tools': self.validate_entities(Tool, tpv_config.get('tools', {})), - 'users': self.validate_entities(User, tpv_config.get('users', {})), - 'roles': self.validate_entities(Role, tpv_config.get('roles', {})), - 'destinations': self.validate_entities(Destination, tpv_config.get('destinations', {})) - } - return validated - - def inherit_globals(self, globals_other): - if globals_other: - self.global_settings.update({'default_inherits': globals_other.get('default_inherits')} - if globals_other.get('default_inherits') else {}) - self.global_settings['context'] = self.global_settings.get('context') or {} - self.global_settings['context'].update(globals_other.get('context') or {}) + def validate_entities(self, entities: Dict[str, Entity]) -> dict: + self.recompute_inheritance(entities) - def inherit_existing_entities(self, entities_current, entities_new): + def load_entities(self, tpv_config: TPVConfig) -> dict: + self.validate_entities(tpv_config.tools), + self.validate_entities(tpv_config.users), + self.validate_entities(tpv_config.roles), + self.validate_entities(tpv_config.destinations) + + def inherit_globals(self, globals_other: GlobalConfig): + if globals_other: + self.config.global_config.default_inherits = ( + globals_other.default_inherits + or self.config.global_config.default_inherits + ) + self.config.global_config.context.update(globals_other.context) + + def inherit_existing_entities( + self, entities_current: dict[str, Entity], entities_new: dict[str, Entity] + ): for entity in entities_new.values(): if entities_current.get(entity.id): current_entity = entities_current.get(entity.id) @@ -111,12 +114,15 @@ def inherit_existing_entities(self, entities_current, entities_new): entities_current[entity.id] = entity self.recompute_inheritance(entities_current) + def merge_config(self, config: TPVConfig): + self.inherit_globals(config.global_config) + self.inherit_existing_entities(self.config.tools, config.tools) + self.inherit_existing_entities(self.config.users, config.users) + self.inherit_existing_entities(self.config.roles, config.roles) + self.inherit_existing_entities(self.config.destinations, config.destinations) + def merge_loader(self, loader: TPVConfigLoader): - self.inherit_globals(loader.global_settings) - self.inherit_existing_entities(self.tools, loader.tools) - self.inherit_existing_entities(self.users, loader.users) - self.inherit_existing_entities(self.roles, loader.roles) - self.inherit_existing_entities(self.destinations, loader.destinations) + self.merge_config(loader.config) @staticmethod def from_url_or_path(url_or_path: str): diff --git a/tpv/core/mapper.py b/tpv/core/mapper.py index 6d5db57..a981e89 100644 --- a/tpv/core/mapper.py +++ b/tpv/core/mapper.py @@ -2,7 +2,7 @@ import 
logging import re -from .entities import Tool, TryNextDestinationOrFail, TryNextDestinationOrWait +from .entities import Entity, Tool, TryNextDestinationOrFail, TryNextDestinationOrWait from .loader import TPVConfigLoader from galaxy.jobs import JobDestination @@ -15,16 +15,13 @@ class EntityToDestinationMapper(object): def __init__(self, loader: TPVConfigLoader): self.loader = loader - self.entities = { - "tools": loader.tools, - "users": loader.users, - "roles": loader.roles - } - self.destinations = loader.destinations - self.default_inherits = loader.global_settings.get('default_inherits') - self.global_context = loader.global_settings.get('context') + self.config = loader.config + self.destinations = self.config.destinations + self.default_inherits = self.config.global_config.default_inherits + self.global_context = self.config.global_config.context self.lookup_tool_regex = functools.lru_cache(maxsize=None)(self.__compile_tool_regex) - self.inherit_matching_entities = functools.lru_cache(maxsize=None)(self.__inherit_matching_entities) + # self.inherit_matching_entities = functools.lru_cache(maxsize=None)(self.__inherit_matching_entities) + self.inherit_matching_entities = self.__inherit_matching_entities def __compile_tool_regex(self, key): try: @@ -33,7 +30,7 @@ def __compile_tool_regex(self, key): log.error(f"Failed to compile regex: {key}") raise - def _find_entities_matching_id(self, entity_list, entity_name): + def _find_entities_matching_id(self, entity_list: dict[str, Entity], entity_name: str): default_inherits = self.__get_default_inherits(entity_list) if default_inherits: matches = [default_inherits] @@ -49,12 +46,11 @@ def _find_entities_matching_id(self, entity_list, entity_name): matches.append(match) return matches - def __inherit_matching_entities(self, entity_type, entity_name): - entity_list = self.entities.get(entity_type) + def __inherit_matching_entities(self, entity_list: dict[str, Entity], entity_name: str): matches = self._find_entities_matching_id(entity_list, entity_name) return self.inherit_entities(matches) - def __get_default_inherits(self, entity_list): + def __get_default_inherits(self, entity_list: dict[str, Entity]): if self.default_inherits: default_match = entity_list.get(self.default_inherits) if default_match: @@ -101,14 +97,14 @@ def to_galaxy_destination(self, destination): ) def _find_matching_entities(self, tool, user): - tool_entity = self.inherit_matching_entities("tools", tool.id) + tool_entity = self.inherit_matching_entities(self.config.tools, tool.id) if not tool_entity: - tool_entity = Tool.from_dict(self.loader, {'id': tool.id}) + tool_entity = Tool(loader=self.loader, id=tool.id) entity_list = [tool_entity] if user: - role_entities = (self.inherit_matching_entities("roles", role.name) + role_entities = (self.inherit_matching_entities(self.config.roles, role.name) for role in user.all_roles() if not role.deleted) # trim empty user_role_entities = (role for role in role_entities if role) @@ -116,7 +112,7 @@ def _find_matching_entities(self, tool, user): if user_role_entity: entity_list += [user_role_entity] - user_entity = self.inherit_matching_entities("users", user.email) + user_entity = self.inherit_matching_entities(self.config.users, user.email) if user_entity: entity_list += [user_entity] From b1b2325ce1dd191325c0986fd7094bd09258f477 Mon Sep 17 00:00:00 2001 From: nuwang <2070605+nuwang@users.noreply.github.com> Date: Mon, 12 Aug 2024 16:27:01 +0530 Subject: [PATCH 03/30] Add black configuration --- setup.cfg | 7 +++++++ 1 file 
changed, 7 insertions(+) diff --git a/setup.cfg b/setup.cfg index 99d372d..7d5b991 100644 --- a/setup.cfg +++ b/setup.cfg @@ -19,3 +19,10 @@ match=^[Tt]est [bdist_wheel] universal = 1 + +[black] +line-length = 88 + +[isort] +line_length = 88 +profile = black From d9922b6efa10af75b685653ce17235a329d8846c Mon Sep 17 00:00:00 2001 From: nuwang <2070605+nuwang@users.noreply.github.com> Date: Mon, 12 Aug 2024 16:28:20 +0530 Subject: [PATCH 04/30] Fix invalid tag tests --- tests/fixtures/mapping-basic.yml | 6 ---- tests/fixtures/mapping-invalid-regex.yml | 6 ---- tests/fixtures/mapping-invalid-tags.yml | 40 ++++++++++++++++++++++++ tests/test_mapper_basic.py | 5 +-- 4 files changed, 43 insertions(+), 14 deletions(-) create mode 100644 tests/fixtures/mapping-invalid-tags.yml diff --git a/tests/fixtures/mapping-basic.yml b/tests/fixtures/mapping-basic.yml index f06869d..724f0ab 100644 --- a/tests/fixtures/mapping-basic.yml +++ b/tests/fixtures/mapping-basic.yml @@ -28,12 +28,6 @@ tools: scheduling: require: - non_existent - invalidly_tagged_tool: - scheduling: - require: - - general - reject: - - general regex_tool.*: scheduling: require: diff --git a/tests/fixtures/mapping-invalid-regex.yml b/tests/fixtures/mapping-invalid-regex.yml index ce3de75..ffc43d0 100644 --- a/tests/fixtures/mapping-invalid-regex.yml +++ b/tests/fixtures/mapping-invalid-regex.yml @@ -28,12 +28,6 @@ tools: scheduling: require: - non_existent - invalidly_tagged_tool: - scheduling: - require: - - general - reject: - - general regex_tool.*: scheduling: require: diff --git a/tests/fixtures/mapping-invalid-tags.yml b/tests/fixtures/mapping-invalid-tags.yml new file mode 100644 index 0000000..1adca66 --- /dev/null +++ b/tests/fixtures/mapping-invalid-tags.yml @@ -0,0 +1,40 @@ +global: + default_inherits: default + +tools: + default: + abstract: true + cores: 2 + mem: 8 + env: {} + scheduling: + require: [] + prefer: + - general + accept: + reject: + - pulsar + rules: [] + invalidly_tagged_tool: + scheduling: + require: + - general + reject: + - general + +destinations: + local: + runner: local + max_accepted_cores: 4 + max_accepted_mem: 16 + scheduling: + prefer: + - general + k8s_environment: + runner: k8s + max_accepted_cores: 16 + max_accepted_mem: 64 + max_accepted_gpus: 2 + scheduling: + prefer: + - pulsar diff --git a/tests/test_mapper_basic.py b/tests/test_mapper_basic.py index abe59ea..a70e079 100644 --- a/tests/test_mapper_basic.py +++ b/tests/test_mapper_basic.py @@ -35,8 +35,9 @@ def test_map_unschedulable_tool(self): def test_map_invalidly_tagged_tool(self): tool = mock_galaxy.Tool('invalidly_tagged_tool') - with self.assertRaisesRegex(JobMappingException, "No destinations are available to fulfill request"): - self._map_to_destination(tool) + config = os.path.join(os.path.dirname(__file__), 'fixtures/mapping-invalid-tags.yml') + with self.assertRaisesRegex(Exception, r"Duplicate tags found: 'general' in \['require', 'reject'\]"): + self._map_to_destination(tool, tpv_config_path=config) def test_map_tool_by_regex(self): tool = mock_galaxy.Tool('regex_tool_test') From c46641bea02866d005bacd6122d06f799d2580a9 Mon Sep 17 00:00:00 2001 From: nuwang <2070605+nuwang@users.noreply.github.com> Date: Mon, 12 Aug 2024 16:29:28 +0530 Subject: [PATCH 05/30] Fix loader propagation and other minor fixes --- tpv/commands/linter.py | 4 +- tpv/core/entities.py | 86 +++++++++++++++++++++--------------------- 2 files changed, 45 insertions(+), 45 deletions(-) diff --git a/tpv/commands/linter.py b/tpv/commands/linter.py index 
b96f9c0..08f283d 100644 --- a/tpv/commands/linter.py +++ b/tpv/commands/linter.py @@ -28,7 +28,7 @@ def lint(self): self.print_errors_and_warnings() def lint_tools(self, loader): - default_inherits = loader.global_settings.get('default_inherits') + default_inherits = loader.config.global_config.get('default_inherits') for tool_regex, tool in loader.tools.items(): try: re.compile(tool_regex) @@ -41,7 +41,7 @@ def lint_tools(self, loader): "will be excluded from scheduling decisions.") def lint_destinations(self, loader): - default_inherits = loader.global_settings.get('default_inherits') + default_inherits = loader.config.global_config.get('default_inherits') for destination in loader.config.destinations.values(): if not destination.runner and not destination.abstract: self.errors.append(f"Destination '{destination.id}' does not define the runner parameter. " diff --git a/tpv/core/entities.py b/tpv/core/entities.py index 19adf7b..2c1060f 100644 --- a/tpv/core/entities.py +++ b/tpv/core/entities.py @@ -32,19 +32,20 @@ class TryNextDestinationOrWait(Exception): def default_field_copier(entity1, entity2, property_name): + # if property_name in entity1.model_fields_set return ( getattr( entity1, property_name, ) - if getattr(entity1, property_name, None) is None + if getattr(entity1, property_name, None) is not None else getattr(entity2, property_name, None) ) def default_dict_copier(entity1, entity2, property_name): new_dict = copy.deepcopy(getattr(entity2, property_name)) or {} - new_dict.update(getattr(entity2, property_name) or {}) + new_dict.update(copy.deepcopy(getattr(entity1, property_name)) or {}) return new_dict @@ -97,18 +98,18 @@ def check_duplicates(self): return self def all_tags(self) -> Generator[str, None, None]: - return itertools.chain(self.require, self.prefer, self.accept, self.reject) + return itertools.chain(self.require or [], self.prefer or [], self.accept or [], self.reject or []) def add_tag_override(self, tag_type: TagType, tag_value: str): # Remove tag from all categories for field in TagType: field_name = field.name.lower() - if tag_value in getattr(self, field_name): + if tag_value in (getattr(self, field_name) or []): getattr(self, field_name).remove(tag_value) # Add tag to the specified category tag_field = tag_type.name.lower() - setattr(self, tag_field, getattr(self, tag_field, []).append(tag_value)) + setattr(self, tag_field, getattr(self, tag_field, []) + [tag_value]) def inherit(self, other: "SchedulingTags") -> "SchedulingTags": # Create new lists of tags that combine self and other @@ -119,15 +120,15 @@ def inherit(self, other: "SchedulingTags") -> "SchedulingTags": TagType.REQUIRE, TagType.REJECT, ]: - for tag in getattr(self, tag_type.name.lower(), []): + for tag in (getattr(self, tag_type.name.lower()) or []): new_tags.add_tag_override(tag_type, tag) return new_tags def can_combine(self, other: "SchedulingTags") -> bool: - self_required = set(self.require) - other_required = set(other.require) - self_rejected = set(self.reject) - other_rejected = set(other.reject) + self_required = set(self.require or []) + other_required = set(other.require or []) + self_rejected = set(self.reject or []) + other_rejected = set(other.reject or []) if self_required.intersection(other_rejected) or self_rejected.intersection( other_required @@ -148,18 +149,18 @@ def combine(self, other: "SchedulingTags") -> "SchedulingTags": TagType.REQUIRE, TagType.REJECT, ]: - for tag in getattr(other, tag_type.name.lower()): + for tag in getattr(other, tag_type.name.lower()) or []: 
new_tags.add_tag_override(tag_type, tag) - for tag in getattr(self, tag_type.name.lower()): + for tag in getattr(self, tag_type.name.lower()) or []: new_tags.add_tag_override(tag_type, tag) return new_tags def match(self, other: "SchedulingTags") -> bool: - self_required = set(self.require) - other_required = set(other.require) - self_rejected = set(self.reject) - other_rejected = set(other.reject) + self_required = set(self.require or []) + other_required = set(other.require or []) + self_rejected = set(self.reject or []) + other_rejected = set(other.reject or []) return ( self_required.issubset(other.all_tags()) @@ -171,8 +172,9 @@ def match(self, other: "SchedulingTags") -> bool: def score(self, other: "SchedulingTags") -> int: score = 0 for tag_type in TagType: - self_tags = set(getattr(self, tag_type.name.lower())) - other_tags = set(getattr(other, tag_type.name.lower())) + tag_type_name = tag_type.name.lower() + self_tags = set(getattr(self, tag_type_name) or []) + other_tags = set(getattr(other, tag_type_name) or []) common_tags = self_tags & other_tags score += len(common_tags) * int(tag_type) * int(tag_type) @@ -214,9 +216,10 @@ class Config: def __init__(self, **data: Any): super().__init__(**data) - self.propagate_loader(self.loader) + self.propagate_parent_properties(id=self.id, loader=self.loader) - def propagate_loader(self, loader): + def propagate_parent_properties(self, id=None, loader=None): + self.id = id self.loader = loader def __deepcopy__(self, memo: dict): @@ -298,13 +301,11 @@ def merge_env_list(original, replace): def override_single_property( entity, entity1, entity2, property_name, field_copier=default_field_copier ): - if ( - property_name in entity1.model_fields_set - or property_name in entity2.model_fields_set - ): - setattr( - entity, property_name, field_copier(entity1, entity2, property_name) - ) + # if ( + # property_name in entity1.model_fields_set + # or property_name in entity2.model_fields_set + # ): + setattr(entity, property_name, field_copier(entity1, entity2, property_name)) def override(self, entity: "Entity") -> "Entity": if entity.merge_order <= self.merge_order: @@ -497,13 +498,13 @@ def set_default_id(cls): def override(self, entity): new_entity = super().override(entity) - self.override_single_property(new_entity, self, entity, "match") + self.override_single_property(new_entity, self, entity, "if_condition") self.override_single_property(new_entity, self, entity, "execute") self.override_single_property(new_entity, self, entity, "fail") return new_entity def is_matching(self, context): - if self.loader.eval_code_block(self.match, context): + if self.loader.eval_code_block(self.if_condition, context): return True else: return False @@ -526,8 +527,8 @@ class EntityWithRules(Entity): merge_order: ClassVar[int] = 1 rules: Optional[Dict[str, Rule]] = Field(default_factory=dict) - def propagate_loader(self, loader): - super().propagate_loader(loader) + def propagate_parent_properties(self, id=None, loader=None): + super().propagate_parent_properties(id=id, loader=loader) for rule in self.rules.values(): rule.loader = loader @@ -606,10 +607,9 @@ class Destination(EntityWithRules): ) handler_tags: Optional[List[str]] = Field(alias="tags", default_factory=list) - @model_validator(mode="after") - def assign_defaults(self): + def propagate_parent_properties(self, id=None, loader=None): + super().propagate_parent_properties(id=id, loader=loader) self.dest_name = self.dest_name or self.id - return self def override(self, entity: Entity): new_entity = 
super().override(entity) @@ -728,16 +728,16 @@ class TPVConfig(BaseModel): destinations: Optional[Dict[str, Destination]] = Field(default_factory=dict) @model_validator(mode="after") - def propagate_loader(self): + def propagate_parent_properties(self): if self.loader: - for tool in self.tools.values(): - tool.propagate_loader(self.loader) - for user in self.users.values(): - user.propagate_loader(self.loader) - for role in self.roles.values(): - role.propagate_loader(self.loader) - for destination in self.destinations.values(): - destination.propagate_loader(self.loader) + for id, tool in self.tools.items(): + tool.propagate_parent_properties(id=id, loader=self.loader) + for id, user in self.users.items(): + user.propagate_parent_properties(id=id, loader=self.loader) + for id, role in self.roles.items(): + role.propagate_parent_properties(id=id, loader=self.loader) + for id, destination in self.destinations.items(): + destination.propagate_parent_properties(id=id, loader=self.loader) return self From 43a7915960ff6adb7f805f5f228fcbfa011b7ec8 Mon Sep 17 00:00:00 2001 From: nuwang <2070605+nuwang@users.noreply.github.com> Date: Mon, 12 Aug 2024 16:51:50 +0530 Subject: [PATCH 06/30] Gpus can only be whole numbers --- tpv/core/entities.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/tpv/core/entities.py b/tpv/core/entities.py index 2c1060f..28c1341 100644 --- a/tpv/core/entities.py +++ b/tpv/core/entities.py @@ -197,13 +197,13 @@ class Config: abstract: Optional[bool] = False cores: Optional[Union[float, str]] = None mem: Optional[Union[float, str]] = None - gpus: Optional[Union[float, str]] = None + gpus: Optional[Union[int, str]] = None min_cores: Optional[Union[float, str]] = None min_mem: Optional[Union[float, str]] = None - min_gpus: Optional[Union[float, str]] = None + min_gpus: Optional[Union[int, str]] = None max_cores: Optional[Union[float, str]] = None max_mem: Optional[Union[float, str]] = None - max_gpus: Optional[Union[float, str]] = None + max_gpus: Optional[Union[int, str]] = None env: Optional[List[Dict[str, str]]] = None params: Optional[Dict[str, str]] = None resubmit: Optional[Dict[str, str]] = Field(default_factory=dict) @@ -592,10 +592,10 @@ class Destination(EntityWithRules): runner: Optional[str] = None max_accepted_cores: Optional[float] = None max_accepted_mem: Optional[float] = None - max_accepted_gpus: Optional[float] = None + max_accepted_gpus: Optional[int] = None min_accepted_cores: Optional[float] = None min_accepted_mem: Optional[float] = None - min_accepted_gpus: Optional[float] = None + min_accepted_gpus: Optional[int] = None dest_name: Optional[str] = Field(alias="destination_name_override", default=None) # tpv_tags track what tags the entity being scheduled requested, while tpv_dest_tags track what the destination # supports. When serializing a Destination, we don't need tpv_tags, only tpv_dest_tags. 
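
A minimal sketch of how the pydantic models introduced in the patches above validate a raw config dict. The dict values below are illustrative only (they are not taken from this patch series), and the rejection of a fractional GPU count assumes pydantic v2's default lax coercion, which only converts floats with no fractional part to int.

    from pydantic import ValidationError

    from tpv.core.entities import TPVConfig

    # Illustrative config dict; "global" is the input alias for global_config
    data = {
        "global": {"default_inherits": "default"},
        "tools": {"default": {"cores": 2, "mem": 8, "gpus": 1}},
        "destinations": {
            "local": {"runner": "local", "max_accepted_cores": 4, "max_accepted_gpus": 2},
        },
    }

    config = TPVConfig(**data)  # validates and coerces field types
    assert config.destinations["local"].max_accepted_gpus == 2

    try:
        # max_accepted_gpus is now Optional[int], so a fractional value should be rejected
        TPVConfig(**{"destinations": {"gpu_dest": {"runner": "k8s", "max_accepted_gpus": 1.5}}})
    except ValidationError as err:
        print(err)
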
From 63ecdea23ab799fe222083765ad946450161df34 Mon Sep 17 00:00:00 2001 From: nuwang <2070605+nuwang@users.noreply.github.com> Date: Mon, 12 Aug 2024 22:56:09 +0530 Subject: [PATCH 07/30] Process inheritance of loaded entities --- tpv/core/loader.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tpv/core/loader.py b/tpv/core/loader.py index 03e576f..707d53d 100644 --- a/tpv/core/loader.py +++ b/tpv/core/loader.py @@ -23,6 +23,7 @@ def __init__(self, tpv_config: TPVConfig): self.__compile_code_block ) self.config = TPVConfig(loader=self, **tpv_config) + self.process_entities(self.config) def compile_code_block(self, code, as_f_string=False, exec_only=False): # interface method, replaced with instance based lru cache in constructor @@ -87,7 +88,7 @@ def recompute_inheritance(self, entities: dict[str, Entity]): def validate_entities(self, entities: Dict[str, Entity]) -> dict: self.recompute_inheritance(entities) - def load_entities(self, tpv_config: TPVConfig) -> dict: + def process_entities(self, tpv_config: TPVConfig) -> dict: self.validate_entities(tpv_config.tools), self.validate_entities(tpv_config.users), self.validate_entities(tpv_config.roles), From 3f06eda4f29f500c6fb46fa9f7425ab6bc977e40 Mon Sep 17 00:00:00 2001 From: nuwang <2070605+nuwang@users.noreply.github.com> Date: Mon, 12 Aug 2024 22:56:31 +0530 Subject: [PATCH 08/30] Only apply Rule properties if instance of Rule --- tpv/core/entities.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/tpv/core/entities.py b/tpv/core/entities.py index 28c1341..9d364fd 100644 --- a/tpv/core/entities.py +++ b/tpv/core/entities.py @@ -109,7 +109,7 @@ def add_tag_override(self, tag_type: TagType, tag_value: str): # Add tag to the specified category tag_field = tag_type.name.lower() - setattr(self, tag_field, getattr(self, tag_field, []) + [tag_value]) + setattr(self, tag_field, getattr(self, tag_field, []) or [] + [tag_value]) def inherit(self, other: "SchedulingTags") -> "SchedulingTags": # Create new lists of tags that combine self and other @@ -498,9 +498,10 @@ def set_default_id(cls): def override(self, entity): new_entity = super().override(entity) - self.override_single_property(new_entity, self, entity, "if_condition") - self.override_single_property(new_entity, self, entity, "execute") - self.override_single_property(new_entity, self, entity, "fail") + if isinstance(entity, Rule): + self.override_single_property(new_entity, self, entity, "if_condition") + self.override_single_property(new_entity, self, entity, "execute") + self.override_single_property(new_entity, self, entity, "fail") return new_entity def is_matching(self, context): From 562202127823d282d433aa65864a4edc0cfd3ab3 Mon Sep 17 00:00:00 2001 From: nuwang <2070605+nuwang@users.noreply.github.com> Date: Mon, 12 Aug 2024 22:59:04 +0530 Subject: [PATCH 09/30] Change factory for GlobalConfig --- tpv/core/entities.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tpv/core/entities.py b/tpv/core/entities.py index 9d364fd..ef181e6 100644 --- a/tpv/core/entities.py +++ b/tpv/core/entities.py @@ -719,7 +719,7 @@ class GlobalConfig(BaseModel): class TPVConfig(BaseModel): model_config = ConfigDict(arbitrary_types_allowed=True) - global_config: Optional[GlobalConfig] = Field(alias="global", default_factory=dict) + global_config: Optional[GlobalConfig] = Field(alias="global", default_factory=GlobalConfig) loader: SkipJsonSchema[Optional[TPVCodeBlockInterface]] = Field( exclude=True, default=None ) From 
ab6ffac61ec376fb736be56e40d3cc88d0992405 Mon Sep 17 00:00:00 2001 From: nuwang <2070605+nuwang@users.noreply.github.com> Date: Tue, 13 Aug 2024 00:57:30 +0530 Subject: [PATCH 10/30] Allow cores and mem to be int as well --- tpv/core/entities.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tpv/core/entities.py b/tpv/core/entities.py index ef181e6..cf10fdb 100644 --- a/tpv/core/entities.py +++ b/tpv/core/entities.py @@ -195,8 +195,8 @@ class Config: ) id: Optional[str] = None abstract: Optional[bool] = False - cores: Optional[Union[float, str]] = None - mem: Optional[Union[float, str]] = None + cores: Optional[Union[int, float, str]] = None + mem: Optional[Union[int, float, str]] = None gpus: Optional[Union[int, str]] = None min_cores: Optional[Union[float, str]] = None min_mem: Optional[Union[float, str]] = None From 359386a394f24a5a8740239bd5d444d0e105ba4d Mon Sep 17 00:00:00 2001 From: nuwang <2070605+nuwang@users.noreply.github.com> Date: Tue, 13 Aug 2024 00:57:50 +0530 Subject: [PATCH 11/30] Fix bug in setting tags --- tpv/core/entities.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tpv/core/entities.py b/tpv/core/entities.py index cf10fdb..67caff3 100644 --- a/tpv/core/entities.py +++ b/tpv/core/entities.py @@ -109,7 +109,8 @@ def add_tag_override(self, tag_type: TagType, tag_value: str): # Add tag to the specified category tag_field = tag_type.name.lower() - setattr(self, tag_field, getattr(self, tag_field, []) or [] + [tag_value]) + current_tags = getattr(self, tag_field, []) or [] + setattr(self, tag_field, current_tags + [tag_value]) def inherit(self, other: "SchedulingTags") -> "SchedulingTags": # Create new lists of tags that combine self and other From c800655fd71f3458eb132112c05344d689f8138d Mon Sep 17 00:00:00 2001 From: nuwang <2070605+nuwang@users.noreply.github.com> Date: Tue, 13 Aug 2024 01:04:19 +0530 Subject: [PATCH 12/30] Fix incorrect variable reference --- tpv/commands/linter.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tpv/commands/linter.py b/tpv/commands/linter.py index 08f283d..dfc4315 100644 --- a/tpv/commands/linter.py +++ b/tpv/commands/linter.py @@ -28,8 +28,8 @@ def lint(self): self.print_errors_and_warnings() def lint_tools(self, loader): - default_inherits = loader.config.global_config.get('default_inherits') - for tool_regex, tool in loader.tools.items(): + default_inherits = loader.config.global_config.default_inherits + for tool_regex, tool in loader.config.tools.items(): try: re.compile(tool_regex) except re.error: @@ -41,7 +41,7 @@ def lint_tools(self, loader): "will be excluded from scheduling decisions.") def lint_destinations(self, loader): - default_inherits = loader.config.global_config.get('default_inherits') + default_inherits = loader.config.global_config.default_inherits for destination in loader.config.destinations.values(): if not destination.runner and not destination.abstract: self.errors.append(f"Destination '{destination.id}' does not define the runner parameter. 
" From 9fba3598fb6d23583fd01c435bf1911e9287f5fe Mon Sep 17 00:00:00 2001 From: nuwang <2070605+nuwang@users.noreply.github.com> Date: Tue, 13 Aug 2024 01:05:52 +0530 Subject: [PATCH 13/30] Change field ordering --- tpv/core/entities.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/tpv/core/entities.py b/tpv/core/entities.py index 67caff3..760754c 100644 --- a/tpv/core/entities.py +++ b/tpv/core/entities.py @@ -191,11 +191,9 @@ class Config: arbitrary_types_allowed = True merge_order: ClassVar[int] = 0 - loader: SkipJsonSchema[Optional[TPVCodeBlockInterface]] = Field( - exclude=True, default=None - ) id: Optional[str] = None abstract: Optional[bool] = False + inherits: Optional[str] = None cores: Optional[Union[int, float, str]] = None mem: Optional[Union[int, float, str]] = None gpus: Optional[Union[int, str]] = None @@ -208,12 +206,14 @@ class Config: env: Optional[List[Dict[str, str]]] = None params: Optional[Dict[str, str]] = None resubmit: Optional[Dict[str, str]] = Field(default_factory=dict) + rank: Optional[str] = None + context: Optional[Dict[str, Any]] = Field(default_factory=dict) + loader: SkipJsonSchema[Optional[TPVCodeBlockInterface]] = Field( + exclude=True, default=None + ) tpv_tags: Optional[SchedulingTags] = Field( alias="scheduling", default_factory=SchedulingTags ) - rank: Optional[str] = None - inherits: Optional[str] = None - context: Optional[Dict[str, Any]] = Field(default_factory=dict) def __init__(self, **data: Any): super().__init__(**data) From 8793622e25bf69b593a641de01bbd3a77b355765 Mon Sep 17 00:00:00 2001 From: nuwang <2070605+nuwang@users.noreply.github.com> Date: Tue, 13 Aug 2024 01:18:43 +0530 Subject: [PATCH 14/30] Fix pattern for tag filtering --- tests/fixtures/scenario-too-many-highmem-jobs.yml | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/tests/fixtures/scenario-too-many-highmem-jobs.yml b/tests/fixtures/scenario-too-many-highmem-jobs.yml index ea3e297..8c005c6 100644 --- a/tests/fixtures/scenario-too-many-highmem-jobs.yml +++ b/tests/fixtures/scenario-too-many-highmem-jobs.yml @@ -59,12 +59,11 @@ users: from galaxy.jobs.rule_helper import RuleHelper from tpv.core.entities import TagType - if entity.tpv_tags.filter(tag_value='highmem'): + if 'highmem' in entity.tpv_tags.all_tags(): rule_helper = RuleHelper(app) # Find all destinations that support highmem destinations = [d.dest_name for d in mapper.destinations.values() - if any(d.tpv_dest_tags.filter(tag_value='highmem', - tag_type=[TagType.REQUIRE, TagType.PREFER, TagType.ACCEPT]))] + if 'highmem' in (d.tpv_dest_tags.require + d.tpv_dest_tags.prefer + d.tpv_dest_tags.accept)] count = rule_helper.job_count(for_user_email=user.email, for_destinations=destinations) if count > 4: retval = True From 10991318b4c9a7f59f8350469af6542c6985bf27 Mon Sep 17 00:00:00 2001 From: nuwang <2070605+nuwang@users.noreply.github.com> Date: Tue, 13 Aug 2024 10:27:44 +0530 Subject: [PATCH 15/30] Add error message when failing to load a file --- tpv/core/loader.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/tpv/core/loader.py b/tpv/core/loader.py index 707d53d..aa84c14 100644 --- a/tpv/core/loader.py +++ b/tpv/core/loader.py @@ -128,4 +128,8 @@ def merge_loader(self, loader: TPVConfigLoader): @staticmethod def from_url_or_path(url_or_path: str): tpv_config = util.load_yaml_from_url_or_path(url_or_path) - return TPVConfigLoader(tpv_config) + try: + return TPVConfigLoader(tpv_config) + except Exception as e: + log.exception(f"Error loading 
TPV config: {url_or_path}") + raise e From fd81f04b8545db575266a2142bee0bd0cbfef2fa Mon Sep 17 00:00:00 2001 From: nuwang <2070605+nuwang@users.noreply.github.com> Date: Tue, 13 Aug 2024 11:22:30 +0530 Subject: [PATCH 16/30] Allow any datatype for params and convert env to str --- tpv/core/entities.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tpv/core/entities.py b/tpv/core/entities.py index 760754c..f48c151 100644 --- a/tpv/core/entities.py +++ b/tpv/core/entities.py @@ -204,7 +204,7 @@ class Config: max_mem: Optional[Union[float, str]] = None max_gpus: Optional[Union[int, str]] = None env: Optional[List[Dict[str, str]]] = None - params: Optional[Dict[str, str]] = None + params: Optional[Dict[str, Any]] = None resubmit: Optional[Dict[str, str]] = Field(default_factory=dict) rank: Optional[str] = None context: Optional[Dict[str, Any]] = Field(default_factory=dict) @@ -232,7 +232,7 @@ def __deepcopy__(self, memo: dict): @staticmethod def convert_env(env): if isinstance(env, dict): - env = [dict(name=k, value=v) for (k, v) in env.items()] + env = [dict(name=k, value=str(v)) for (k, v) in env.items()] return env @model_validator(mode="before") From e4d18323ad24d407e93db08426454fceb73db2e6 Mon Sep 17 00:00:00 2001 From: nuwang <2070605+nuwang@users.noreply.github.com> Date: Tue, 13 Aug 2024 18:18:33 +0530 Subject: [PATCH 17/30] Fix tag_values_match function in helpers --- tpv/core/helpers.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/tpv/core/helpers.py b/tpv/core/helpers.py index 343b79c..135b581 100644 --- a/tpv/core/helpers.py +++ b/tpv/core/helpers.py @@ -5,6 +5,7 @@ # If Galaxy is < 23.1 you need to have `packaging` in <= 21.3 from packaging.version import parse as parse_version +from .entities import SchedulingTags import random from functools import reduce from galaxy import model @@ -90,10 +91,8 @@ def concurrent_job_count_for_tool(app, tool, user=None): # requires galaxy vers def tag_values_match(entity, match_tag_values=[], exclude_tag_values=[]): # Return true if an entity has require/prefer/accept tags in the match_tags_values list # and no require/prefer/accept tags in the exclude_tag_values list - return ( - all([any(entity.tpv_tags.filter(tag_value=tag_value)) for tag_value in match_tag_values]) - and not any([any(entity.tpv_tags.filter(tag_value=tag_value)) for tag_value in exclude_tag_values]) - ) + tags = SchedulingTags(require=match_tag_values, reject=exclude_tag_values) + return entity.tpv_tags.match(tags) def tool_version_eq(tool, version): From 503ade69c721d1bed35f8d58303b9e93b4d31d7e Mon Sep 17 00:00:00 2001 From: nuwang <2070605+nuwang@users.noreply.github.com> Date: Thu, 15 Aug 2024 00:20:10 +0530 Subject: [PATCH 18/30] Add back entity.filter() and fix score function --- .../scenario-too-many-highmem-jobs.yml | 2 +- tpv/core/entities.py | 80 +++++++++++++------ 2 files changed, 58 insertions(+), 24 deletions(-) diff --git a/tests/fixtures/scenario-too-many-highmem-jobs.yml b/tests/fixtures/scenario-too-many-highmem-jobs.yml index 8c005c6..f512c85 100644 --- a/tests/fixtures/scenario-too-many-highmem-jobs.yml +++ b/tests/fixtures/scenario-too-many-highmem-jobs.yml @@ -59,7 +59,7 @@ users: from galaxy.jobs.rule_helper import RuleHelper from tpv.core.entities import TagType - if 'highmem' in entity.tpv_tags.all_tags(): + if entity.tpv_tags.filter(tag_value='highmem'): rule_helper = RuleHelper(app) # Find all destinations that support highmem destinations = [d.dest_name for d in mapper.destinations.values() diff 
--git a/tpv/core/entities.py b/tpv/core/entities.py index f48c151..7ad5759 100644 --- a/tpv/core/entities.py +++ b/tpv/core/entities.py @@ -2,17 +2,18 @@ import itertools import logging from collections import defaultdict +from dataclasses import dataclass from enum import IntEnum -from typing import Any, ClassVar, Dict, Generator, List, Optional, Union +from typing import Any, ClassVar, Dict, Iterable, List, Optional, Union from galaxy import util as galaxy_util from pydantic import ( BaseModel, ConfigDict, Field, + ValidationInfo, field_validator, model_validator, - ValidationInfo, ) from pydantic.json_schema import SkipJsonSchema @@ -56,6 +57,12 @@ class TagType(IntEnum): REJECT = -1 +@dataclass +class Tag: + value: str + tag_type: TagType + + class IncompatibleTagsException(Exception): def __init__(self, first_set: "SchedulingTags", second_set: "SchedulingTags"): super().__init__( @@ -97,8 +104,32 @@ def check_duplicates(self): return self - def all_tags(self) -> Generator[str, None, None]: - return itertools.chain(self.require or [], self.prefer or [], self.accept or [], self.reject or []) + @property + def tags(self) -> Iterable[Tag]: + return itertools.chain( + (Tag(value=tag, tag_type=TagType.REQUIRE) for tag in self.require or []), + (Tag(value=tag, tag_type=TagType.PREFER) for tag in self.prefer or []), + (Tag(value=tag, tag_type=TagType.ACCEPT) for tag in self.accept or []), + (Tag(value=tag, tag_type=TagType.REJECT) for tag in self.reject or []), + ) + + def all_tag_values(self) -> Iterable[str]: + return itertools.chain( + self.require or [], self.prefer or [], self.accept or [], self.reject or [] + ) + + def filter( + self, tag_type: TagType | list[TagType] = None, tag_value: str = None + ) -> list[Tag]: + filtered = self.tags + if tag_type: + if isinstance(tag_type, TagType): + filtered = (tag for tag in filtered if tag.tag_type == tag_type) + else: + filtered = (tag for tag in filtered if tag.tag_type in tag_type) + if tag_value: + filtered = (tag for tag in filtered if tag.value == tag_value) + return filtered def add_tag_override(self, tag_type: TagType, tag_value: str): # Remove tag from all categories @@ -121,7 +152,7 @@ def inherit(self, other: "SchedulingTags") -> "SchedulingTags": TagType.REQUIRE, TagType.REJECT, ]: - for tag in (getattr(self, tag_type.name.lower()) or []): + for tag in getattr(self, tag_type.name.lower()) or []: new_tags.add_tag_override(tag_type, tag) return new_tags @@ -164,26 +195,27 @@ def match(self, other: "SchedulingTags") -> bool: other_rejected = set(other.reject or []) return ( - self_required.issubset(other.all_tags()) - and other_required.issubset(self.all_tags()) - and not self_rejected.intersection(other.all_tags()) - and not other_rejected.intersection(self.all_tags()) + self_required.issubset(other.all_tag_values()) + and other_required.issubset(self.all_tag_values()) + and not self_rejected.intersection(other.all_tag_values()) + and not other_rejected.intersection(self.all_tag_values()) ) def score(self, other: "SchedulingTags") -> int: - score = 0 - for tag_type in TagType: - tag_type_name = tag_type.name.lower() - self_tags = set(getattr(self, tag_type_name) or []) - other_tags = set(getattr(other, tag_type_name) or []) - - common_tags = self_tags & other_tags - score += len(common_tags) * int(tag_type) * int(tag_type) - - unique_self_tags = self_tags - other_tags - score -= len(unique_self_tags) * int(tag_type) - - return score + return ( + sum( + int(tag.tag_type) * int(o.tag_type) + for tag in self.filter() + for o in 
other.filter() + if tag.value == o.value + ) + # penalize tags that don't exist in the other + - sum( + int(tag.tag_type) + for tag in self.tags + if tag not in other.tags + ) + ) class Entity(BaseModel): @@ -720,7 +752,9 @@ class GlobalConfig(BaseModel): class TPVConfig(BaseModel): model_config = ConfigDict(arbitrary_types_allowed=True) - global_config: Optional[GlobalConfig] = Field(alias="global", default_factory=GlobalConfig) + global_config: Optional[GlobalConfig] = Field( + alias="global", default_factory=GlobalConfig + ) loader: SkipJsonSchema[Optional[TPVCodeBlockInterface]] = Field( exclude=True, default=None ) From a938add510a42351cf41ab973011cf783c4147c8 Mon Sep 17 00:00:00 2001 From: nuwang <2070605+nuwang@users.noreply.github.com> Date: Thu, 15 Aug 2024 11:19:43 +0530 Subject: [PATCH 19/30] Make sure min_cores/mem etc. types match specified --- tpv/core/entities.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/tpv/core/entities.py b/tpv/core/entities.py index 7ad5759..6276ad3 100644 --- a/tpv/core/entities.py +++ b/tpv/core/entities.py @@ -229,11 +229,11 @@ class Config: cores: Optional[Union[int, float, str]] = None mem: Optional[Union[int, float, str]] = None gpus: Optional[Union[int, str]] = None - min_cores: Optional[Union[float, str]] = None - min_mem: Optional[Union[float, str]] = None + min_cores: Optional[Union[int, float, str]] = None + min_mem: Optional[Union[int, float, str]] = None min_gpus: Optional[Union[int, str]] = None - max_cores: Optional[Union[float, str]] = None - max_mem: Optional[Union[float, str]] = None + max_cores: Optional[Union[int, float, str]] = None + max_mem: Optional[Union[int, float, str]] = None max_gpus: Optional[Union[int, str]] = None env: Optional[List[Dict[str, str]]] = None params: Optional[Dict[str, Any]] = None @@ -624,11 +624,11 @@ class User(EntityWithRules): class Destination(EntityWithRules): merge_order: ClassVar[int] = 5 runner: Optional[str] = None - max_accepted_cores: Optional[float] = None - max_accepted_mem: Optional[float] = None + max_accepted_cores: Optional[int | float] = None + max_accepted_mem: Optional[int | float] = None max_accepted_gpus: Optional[int] = None - min_accepted_cores: Optional[float] = None - min_accepted_mem: Optional[float] = None + min_accepted_cores: Optional[int | float] = None + min_accepted_mem: Optional[int | float] = None min_accepted_gpus: Optional[int] = None dest_name: Optional[str] = Field(alias="destination_name_override", default=None) # tpv_tags track what tags the entity being scheduled requested, while tpv_dest_tags track what the destination From 41c38e327be73b89a5b5b53f7bf35ac4ad488b21 Mon Sep 17 00:00:00 2001 From: nuwang <2070605+nuwang@users.noreply.github.com> Date: Thu, 15 Aug 2024 11:20:09 +0530 Subject: [PATCH 20/30] Restore original tag_values_match helper function --- tpv/core/helpers.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/tpv/core/helpers.py b/tpv/core/helpers.py index 135b581..343b79c 100644 --- a/tpv/core/helpers.py +++ b/tpv/core/helpers.py @@ -5,7 +5,6 @@ # If Galaxy is < 23.1 you need to have `packaging` in <= 21.3 from packaging.version import parse as parse_version -from .entities import SchedulingTags import random from functools import reduce from galaxy import model @@ -91,8 +90,10 @@ def concurrent_job_count_for_tool(app, tool, user=None): # requires galaxy vers def tag_values_match(entity, match_tag_values=[], exclude_tag_values=[]): # Return true if an entity has require/prefer/accept 
tags in the match_tags_values list # and no require/prefer/accept tags in the exclude_tag_values list - tags = SchedulingTags(require=match_tag_values, reject=exclude_tag_values) - return entity.tpv_tags.match(tags) + return ( + all([any(entity.tpv_tags.filter(tag_value=tag_value)) for tag_value in match_tag_values]) + and not any([any(entity.tpv_tags.filter(tag_value=tag_value)) for tag_value in exclude_tag_values]) + ) def tool_version_eq(tool, version): From 0607018d5602f3118c4c33fcb0786f112d71b495 Mon Sep 17 00:00:00 2001 From: nuwang <2070605+nuwang@users.noreply.github.com> Date: Thu, 15 Aug 2024 13:38:14 +0530 Subject: [PATCH 21/30] Allow bool or str for if conditions --- tpv/core/entities.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tpv/core/entities.py b/tpv/core/entities.py index 6276ad3..75a5756 100644 --- a/tpv/core/entities.py +++ b/tpv/core/entities.py @@ -520,7 +520,7 @@ def dict(self, **kwargs): class Rule(Entity): rule_counter: ClassVar[int] = 0 id: Optional[str] = Field(default_factory=lambda: Rule.set_default_id()) - if_condition: str = Field(alias="if") + if_condition: str | bool = Field(alias="if") execute: Optional[str] = None fail: Optional[str] = None From 151040dd50af22e341169be5f8748e2fb5b876a5 Mon Sep 17 00:00:00 2001 From: nuwang <2070605+nuwang@users.noreply.github.com> Date: Thu, 15 Aug 2024 13:38:37 +0530 Subject: [PATCH 22/30] Fix score function --- tpv/core/entities.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tpv/core/entities.py b/tpv/core/entities.py index 75a5756..9902cdb 100644 --- a/tpv/core/entities.py +++ b/tpv/core/entities.py @@ -213,7 +213,7 @@ def score(self, other: "SchedulingTags") -> int: - sum( int(tag.tag_type) for tag in self.tags - if tag not in other.tags + if tag.value not in other.all_tag_values() ) ) From 696c8751a52ce98364c00033221ff77803e0f1fd Mon Sep 17 00:00:00 2001 From: nuwang <2070605+nuwang@users.noreply.github.com> Date: Thu, 15 Aug 2024 21:58:13 +0530 Subject: [PATCH 23/30] Fix tests for invalidly tagged users --- tests/fixtures/mapping-rank.yml | 6 --- tests/fixtures/mapping-role.yml | 6 --- tests/fixtures/mapping-rules-changed.yml | 6 --- tests/fixtures/mapping-rules.yml | 6 --- tests/fixtures/mapping-syntax-error.yml | 6 --- tests/fixtures/mapping-user-invalid-tags.yml | 50 ++++++++++++++++++++ tests/fixtures/mapping-user.yml | 6 --- tests/test_mapper_user.py | 9 ++-- 8 files changed, 55 insertions(+), 40 deletions(-) create mode 100644 tests/fixtures/mapping-user-invalid-tags.yml diff --git a/tests/fixtures/mapping-rank.yml b/tests/fixtures/mapping-rank.yml index 192facd..0c45741 100644 --- a/tests/fixtures/mapping-rank.yml +++ b/tests/fixtures/mapping-rank.yml @@ -46,12 +46,6 @@ users: scheduling: require: - earth - improbable@vortex.org: - scheduling: - require: - - pulsar - reject: - - pulsar .*@vortex.org: scheduling: require: diff --git a/tests/fixtures/mapping-role.yml b/tests/fixtures/mapping-role.yml index bca2058..2640d40 100644 --- a/tests/fixtures/mapping-role.yml +++ b/tests/fixtures/mapping-role.yml @@ -68,12 +68,6 @@ users: scheduling: require: - earth - improbable@vortex.org: - scheduling: - require: - - pulsar - reject: - - pulsar .*@vortex.org: scheduling: require: diff --git a/tests/fixtures/mapping-rules-changed.yml b/tests/fixtures/mapping-rules-changed.yml index 6860f81..8861cd6 100644 --- a/tests/fixtures/mapping-rules-changed.yml +++ b/tests/fixtures/mapping-rules-changed.yml @@ -52,12 +52,6 @@ users: max_mem: cores * 6 - if: input_size >= 5 fail: 
Just because - improbable@vortex.org: - scheduling: - require: - - pulsar - reject: - - pulsar .*@vortex.org: scheduling: require: diff --git a/tests/fixtures/mapping-rules.yml b/tests/fixtures/mapping-rules.yml index be5e53e..118f0ef 100644 --- a/tests/fixtures/mapping-rules.yml +++ b/tests/fixtures/mapping-rules.yml @@ -60,12 +60,6 @@ users: max_mem: cores * 6 - if: input_size >= 5 fail: Just because - improbable@vortex.org: - scheduling: - require: - - pulsar - reject: - - pulsar .*@vortex.org: scheduling: require: diff --git a/tests/fixtures/mapping-syntax-error.yml b/tests/fixtures/mapping-syntax-error.yml index bbd3944..dfe972f 100644 --- a/tests/fixtures/mapping-syntax-error.yml +++ b/tests/fixtures/mapping-syntax-error.yml @@ -42,12 +42,6 @@ users: scheduling: require: - earth - improbable@vortex.org: - scheduling: - require: - - pulsar - reject: - - pulsar .*@vortex.org: scheduling: require: diff --git a/tests/fixtures/mapping-user-invalid-tags.yml b/tests/fixtures/mapping-user-invalid-tags.yml new file mode 100644 index 0000000..aa880c2 --- /dev/null +++ b/tests/fixtures/mapping-user-invalid-tags.yml @@ -0,0 +1,50 @@ +global: + default_inherits: default + +tools: + default: + cores: 2 + mem: 8 + gpus: 1 + env: {} + scheduling: + require: [] + prefer: + - general + accept: + reject: + - pulsar + params: + native_spec: "--mem {mem} --cores {cores}" + rules: [] + +users: + default: + max_cores: 3 + max_mem: 4 + env: {} + scheduling: + require: [] + prefer: + - general + accept: + reject: + - pulsar + rules: [] + improbable@vortex.org: + scheduling: + require: + - pulsar + reject: + - pulsar + +destinations: + local: + runner: local + max_accepted_cores: 4 + max_accepted_mem: 16 + scheduling: + prefer: + - general + accept: + - pulsar diff --git a/tests/fixtures/mapping-user.yml b/tests/fixtures/mapping-user.yml index 3a49bc4..daaa206 100644 --- a/tests/fixtures/mapping-user.yml +++ b/tests/fixtures/mapping-user.yml @@ -60,12 +60,6 @@ users: - earth reject: - pulsar - improbable@vortex.org: - scheduling: - require: - - pulsar - reject: - - pulsar prefect@vortex.org: max_cores: 4 max_mem: 32 diff --git a/tests/test_mapper_user.py b/tests/test_mapper_user.py index 647a0d1..729f47f 100644 --- a/tests/test_mapper_user.py +++ b/tests/test_mapper_user.py @@ -8,10 +8,10 @@ class TestMapperUser(unittest.TestCase): @staticmethod - def _map_to_destination(tool, user): + def _map_to_destination(tool, user, tpv_config_path=None): galaxy_app = mock_galaxy.App(job_conf=os.path.join(os.path.dirname(__file__), 'fixtures/job_conf.yml')) job = mock_galaxy.Job() - tpv_config = os.path.join(os.path.dirname(__file__), 'fixtures/mapping-user.yml') + tpv_config = tpv_config_path or os.path.join(os.path.dirname(__file__), 'fixtures/mapping-user.yml') gateway.ACTIVE_DESTINATION_MAPPER = None return gateway.map_tool_to_destination(galaxy_app, job, tool, user, tpv_config_files=[tpv_config]) @@ -40,8 +40,9 @@ def test_map_invalidly_tagged_user(self): tool = mock_galaxy.Tool('bwa') user = mock_galaxy.User('infinitely', 'improbable@vortex.org') - with self.assertRaises(IncompatibleTagsException): - self._map_to_destination(tool, user) + config = os.path.join(os.path.dirname(__file__), 'fixtures/mapping-user-invalid-tags.yml') + with self.assertRaisesRegex(Exception, r"Duplicate tags found: 'pulsar' in \['require', 'reject'\]"): + self._map_to_destination(tool, user, tpv_config_path=config) def test_map_user_by_regex(self): tool = mock_galaxy.Tool('bwa') From 9e0e04712eb7554d1535a9926ef5d5a44913c0b9 Mon Sep 17 
00:00:00 2001 From: nuwang <2070605+nuwang@users.noreply.github.com> Date: Thu, 15 Aug 2024 22:15:05 +0530 Subject: [PATCH 24/30] Switch tests to python 3.11 --- tests/fixtures/linter/linter-invalid-regex.yml | 2 +- tox.ini | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/fixtures/linter/linter-invalid-regex.yml b/tests/fixtures/linter/linter-invalid-regex.yml index 78cf912..1f205e6 100644 --- a/tests/fixtures/linter/linter-invalid-regex.yml +++ b/tests/fixtures/linter/linter-invalid-regex.yml @@ -6,7 +6,7 @@ tools: cores: 2 params: native_spec: "--mem {mem} --cores {cores} --gpus {gpus}" - bwa[0-9]++: + bwa[0-9]++\: gpus: 2 destinations: diff --git a/tox.ini b/tox.ini index 0883ee8..04c5cad 100644 --- a/tox.ini +++ b/tox.ini @@ -6,7 +6,7 @@ # running the tests. [tox] -envlist = py3.10,lint +envlist = py3.11,lint [testenv] commands = # see setup.cfg for options sent to nosetests and coverage From 791f52d5fc7e484eacdc0c83fd767156898b29df Mon Sep 17 00:00:00 2001 From: nuwang <2070605+nuwang@users.noreply.github.com> Date: Fri, 16 Aug 2024 01:15:00 +0530 Subject: [PATCH 25/30] Refactor code evaluation interface --- tests/test_entity.py | 8 +- tpv/core/entities.py | 194 +++++++++++++++++++----------------------- tpv/core/evaluator.py | 39 ++++++++- tpv/core/loader.py | 6 +- 4 files changed, 130 insertions(+), 117 deletions(-) diff --git a/tests/test_entity.py b/tests/test_entity.py index 68dddff..6d1e4cf 100644 --- a/tests/test_entity.py +++ b/tests/test_entity.py @@ -35,9 +35,9 @@ def test_all_entities_refer_to_same_loader(self): } # make sure we are still referring to the same loader after evaluation evaluated_entity = gateway.ACTIVE_DESTINATION_MAPPER.match_combine_evaluate_entities(context, tool, user) - assert evaluated_entity.loader == original_loader + assert evaluated_entity.evaluator == original_loader for rule in evaluated_entity.rules: - assert rule.loader == original_loader + assert rule.evaluator == original_loader def test_destination_to_dict(self): tpv_config = os.path.join(os.path.dirname(__file__), 'fixtures/mapping-rule-argument-based.yml') @@ -48,7 +48,7 @@ def test_destination_to_dict(self): # serialize the destination serialized_destination = destination.dict() # deserialize the same destination - deserialized_destination = Destination(loader=loader, **serialized_destination) + deserialized_destination = Destination(evaluator=loader, **serialized_destination) # make sure the deserialized destination is the same as the original self.assertEqual(deserialized_destination, destination) @@ -61,6 +61,6 @@ def test_tool_to_dict(self): # serialize the tool serialized_tool = tool.dict() # deserialize the same tool - deserialized_tool = Tool(loader=loader, **serialized_tool) + deserialized_tool = Tool(evaluator=loader, **serialized_tool) # make sure the deserialized tool is the same as the original self.assertEqual(deserialized_tool, tool) diff --git a/tpv/core/entities.py b/tpv/core/entities.py index 9902cdb..f06e342 100644 --- a/tpv/core/entities.py +++ b/tpv/core/entities.py @@ -4,20 +4,18 @@ from collections import defaultdict from dataclasses import dataclass from enum import IntEnum -from typing import Any, ClassVar, Dict, Iterable, List, Optional, Union +from typing import Annotated, Any, ClassVar, Dict, Iterable, List, Optional from galaxy import util as galaxy_util from pydantic import ( BaseModel, ConfigDict, Field, - ValidationInfo, - field_validator, model_validator, ) from pydantic.json_schema import SkipJsonSchema -from .evaluator import 
TPVCodeBlockInterface +from .evaluator import TPVCodeEvaluator log = logging.getLogger(__name__) @@ -32,6 +30,11 @@ class TryNextDestinationOrWait(Exception): pass +@dataclass +class TPVFieldMetadata: + complex_property: bool = False + + def default_field_copier(entity1, entity2, property_name): # if property_name in entity1.model_fields_set return ( @@ -226,21 +229,27 @@ class Config: id: Optional[str] = None abstract: Optional[bool] = False inherits: Optional[str] = None - cores: Optional[Union[int, float, str]] = None - mem: Optional[Union[int, float, str]] = None - gpus: Optional[Union[int, str]] = None - min_cores: Optional[Union[int, float, str]] = None - min_mem: Optional[Union[int, float, str]] = None - min_gpus: Optional[Union[int, str]] = None - max_cores: Optional[Union[int, float, str]] = None - max_mem: Optional[Union[int, float, str]] = None - max_gpus: Optional[Union[int, str]] = None - env: Optional[List[Dict[str, str]]] = None - params: Optional[Dict[str, Any]] = None - resubmit: Optional[Dict[str, str]] = Field(default_factory=dict) - rank: Optional[str] = None + cores: Annotated[Optional[int | float | str], TPVFieldMetadata()] = None + mem: Annotated[Optional[int | float | str], TPVFieldMetadata()] = None + gpus: Annotated[Optional[int | str], TPVFieldMetadata()] = None + min_cores: Annotated[Optional[int | float | str], TPVFieldMetadata()] = None + min_mem: Annotated[Optional[int | float | str], TPVFieldMetadata()] = None + min_gpus: Annotated[Optional[int | str], TPVFieldMetadata()] = None + max_cores: Annotated[Optional[int | float | str], TPVFieldMetadata()] = None + max_mem: Annotated[Optional[int | float | str], TPVFieldMetadata()] = None + max_gpus: Annotated[Optional[int | str], TPVFieldMetadata()] = None + env: Annotated[ + Optional[List[Dict[str, str]]], TPVFieldMetadata(complex_property=True) + ] = None + params: Annotated[ + Optional[Dict[str, Any]], TPVFieldMetadata(complex_property=True) + ] = None + resubmit: Annotated[ + Optional[Dict[str, str]], TPVFieldMetadata(complex_property=True) + ] = Field(default_factory=dict) + rank: Annotated[Optional[str], TPVFieldMetadata()] = None context: Optional[Dict[str, Any]] = Field(default_factory=dict) - loader: SkipJsonSchema[Optional[TPVCodeBlockInterface]] = Field( + evaluator: SkipJsonSchema[Optional[TPVCodeEvaluator]] = Field( exclude=True, default=None ) tpv_tags: Optional[SchedulingTags] = Field( @@ -249,16 +258,31 @@ class Config: def __init__(self, **data: Any): super().__init__(**data) - self.propagate_parent_properties(id=self.id, loader=self.loader) + self.propagate_parent_properties(id=self.id, evaluator=self.evaluator) - def propagate_parent_properties(self, id=None, loader=None): + def propagate_parent_properties(self, id=None, evaluator=None): self.id = id - self.loader = loader + self.evaluator = evaluator + if evaluator: + self.precompile_properties(evaluator) + + def precompile_properties(self, evaluator: TPVCodeEvaluator): + # compile properties and check for errors + if evaluator: + for name, value in self: + field = self.model_fields[name] + if field.metadata and field.metadata[0]: + prop = field.metadata[0] + if isinstance(prop, TPVFieldMetadata): + if prop.complex_property: + evaluator.compile_complex_property(value) + else: + evaluator.compile_code_block(value) def __deepcopy__(self, memo: dict): - # make sure we don't deepcopy the loader: https://github.com/galaxyproject/total-perspective-vortex/issues/53 + # make sure we don't deepcopy the evaluator: 
https://github.com/galaxyproject/total-perspective-vortex/issues/53 # xref: https://stackoverflow.com/a/68746763/10971151 - memo[id(self.loader)] = self.loader + memo[id(self.evaluator)] = self.evaluator return super().__deepcopy__(memo) @staticmethod @@ -273,50 +297,8 @@ def preprocess(cls, values): if values: values["abstract"] = galaxy_util.asbool(values.get("abstract", False)) values["env"] = Entity.convert_env(values.get("env")) - # loader = values.get("loader") - # compile properties and check for errors - # if loader: - # for f in cls.model_fields: - # field = cls.model_fields[f] - # if f in values and field.metadata and field.metadata[0]: - # metadata = field.metadata[0] - # if metadata.complex_property: - # self.compile_complex_property(loader, values[f]) - # else: - # self.compile_code_block(loader, values[f]) return values - def process_complex_property(self, prop, context: Dict[str, Any], func): - if isinstance(prop, str): - return func(prop, context) - elif isinstance(prop, dict): - evaluated_props = { - key: self.process_complex_property(childprop, context, func) - for key, childprop in prop.items() - } - return evaluated_props - elif isinstance(prop, list): - evaluated_props = [ - self.process_complex_property(childprop, context, func) - for childprop in prop - ] - return evaluated_props - else: - return prop - - @classmethod - def compile_complex_property(cls, loader, prop): - return cls.process_complex_property( - prop, None, lambda p, c: loader.compile_code_block(p, as_f_string=True) - ) - - def evaluate_complex_property(self, prop, context: Dict[str, Any]): - return self.process_complex_property( - prop, - context, - lambda p, c: self.loader.eval_code_block(p, c, as_f_string=True), - ) - @staticmethod def merge_env_list(original, replace): for i, original_elem in enumerate(original): @@ -419,25 +401,29 @@ def evaluate_resources(self, context: Dict[str, Any]): new_entity = copy.deepcopy(self) context.update(self.context or {}) if self.min_gpus is not None: - new_entity.min_gpus = self.loader.eval_code_block(self.min_gpus, context) + new_entity.min_gpus = self.evaluator.eval_code_block(self.min_gpus, context) context["min_gpus"] = new_entity.min_gpus if self.min_cores is not None: - new_entity.min_cores = self.loader.eval_code_block(self.min_cores, context) + new_entity.min_cores = self.evaluator.eval_code_block( + self.min_cores, context + ) context["min_cores"] = new_entity.min_cores if self.min_mem is not None: - new_entity.min_mem = self.loader.eval_code_block(self.min_mem, context) + new_entity.min_mem = self.evaluator.eval_code_block(self.min_mem, context) context["min_mem"] = new_entity.min_mem if self.max_gpus is not None: - new_entity.max_gpus = self.loader.eval_code_block(self.max_gpus, context) + new_entity.max_gpus = self.evaluator.eval_code_block(self.max_gpus, context) context["max_gpus"] = new_entity.max_gpus if self.max_cores is not None: - new_entity.max_cores = self.loader.eval_code_block(self.max_cores, context) + new_entity.max_cores = self.evaluator.eval_code_block( + self.max_cores, context + ) context["max_cores"] = new_entity.max_cores if self.max_mem is not None: - new_entity.max_mem = self.loader.eval_code_block(self.max_mem, context) + new_entity.max_mem = self.evaluator.eval_code_block(self.max_mem, context) context["max_mem"] = new_entity.max_mem if self.gpus is not None: - new_entity.gpus = self.loader.eval_code_block(self.gpus, context) + new_entity.gpus = self.evaluator.eval_code_block(self.gpus, context) # clamp gpus new_entity.gpus = 
max(new_entity.min_gpus or 0, new_entity.gpus or 0) new_entity.gpus = ( @@ -447,7 +433,7 @@ def evaluate_resources(self, context: Dict[str, Any]): ) context["gpus"] = new_entity.gpus if self.cores is not None: - new_entity.cores = self.loader.eval_code_block(self.cores, context) + new_entity.cores = self.evaluator.eval_code_block(self.cores, context) # clamp cores new_entity.cores = max(new_entity.min_cores or 0, new_entity.cores or 0) new_entity.cores = ( @@ -457,7 +443,7 @@ def evaluate_resources(self, context: Dict[str, Any]): ) context["cores"] = new_entity.cores if self.mem is not None: - new_entity.mem = self.loader.eval_code_block(self.mem, context) + new_entity.mem = self.evaluator.eval_code_block(self.mem, context) # clamp mem new_entity.mem = max(new_entity.min_mem or 0, new_entity.mem or 0) new_entity.mem = ( @@ -478,13 +464,17 @@ def evaluate(self, context: Dict[str, Any]): """ new_entity = self.evaluate_resources(context) if self.env: - new_entity.env = self.evaluate_complex_property(self.env, context) + new_entity.env = self.evaluator.evaluate_complex_property(self.env, context) context["env"] = new_entity.env if self.params: - new_entity.params = self.evaluate_complex_property(self.params, context) + new_entity.params = self.evaluator.evaluate_complex_property( + self.params, context + ) context["params"] = new_entity.params if self.resubmit: - new_entity.resubmit = self.evaluate_complex_property(self.resubmit, context) + new_entity.resubmit = self.evaluator.evaluate_complex_property( + self.resubmit, context + ) context["resubmit"] = new_entity.resubmit return new_entity @@ -497,7 +487,7 @@ def rank_destinations( " function" ) context["candidate_destinations"] = destinations - return self.loader.eval_code_block(self.rank, context) + return self.evaluator.eval_code_block(self.rank, context) else: # Sort destinations by priority log.debug( @@ -538,7 +528,7 @@ def override(self, entity): return new_entity def is_matching(self, context): - if self.loader.eval_code_block(self.if_condition, context): + if self.evaluator.eval_code_block(self.if_condition, context): return True else: return False @@ -548,10 +538,10 @@ def evaluate(self, context): from galaxy.jobs.mapper import JobMappingException raise JobMappingException( - self.loader.eval_code_block(self.fail, context, as_f_string=True) + self.evaluator.eval_code_block(self.fail, context, as_f_string=True) ) if self.execute: - self.loader.eval_code_block(self.execute, context, exec_only=True) + self.evaluator.eval_code_block(self.execute, context, exec_only=True) # return any changes made to the entity return context["entity"] return self @@ -561,16 +551,10 @@ class EntityWithRules(Entity): merge_order: ClassVar[int] = 1 rules: Optional[Dict[str, Rule]] = Field(default_factory=dict) - def propagate_parent_properties(self, id=None, loader=None): - super().propagate_parent_properties(id=id, loader=loader) + def propagate_parent_properties(self, id=None, evaluator=None): + super().propagate_parent_properties(id=id, evaluator=evaluator) for rule in self.rules.values(): - rule.loader = loader - - @field_validator("rules", mode="after") - def inject_loader(cls, v: Dict[str, Entity], info: ValidationInfo): - for element in v.values(): - element.loader = info.data["loader"] - return v + rule.evaluator = evaluator @model_validator(mode="before") @classmethod @@ -639,10 +623,12 @@ class Destination(EntityWithRules): tpv_dest_tags: Optional[SchedulingTags] = Field( alias="scheduling", default_factory=SchedulingTags ) - handler_tags: 
Optional[List[str]] = Field(alias="tags", default_factory=list) + handler_tags: Annotated[ + Optional[List[str]], TPVFieldMetadata(complex_property=True) + ] = Field(alias="tags", default_factory=list) - def propagate_parent_properties(self, id=None, loader=None): - super().propagate_parent_properties(id=id, loader=loader) + def propagate_parent_properties(self, id=None, evaluator=None): + super().propagate_parent_properties(id=id, evaluator=evaluator) self.dest_name = self.dest_name or self.id def override(self, entity: Entity): @@ -661,12 +647,12 @@ def override(self, entity: Entity): def evaluate(self, context: Dict[str, Any]): new_entity = super(Destination, self).evaluate(context) if self.dest_name is not None: - new_entity.dest_name = self.loader.eval_code_block( + new_entity.dest_name = self.evaluator.eval_code_block( self.dest_name, context, as_f_string=True ) context["dest_name"] = new_entity.dest_name if self.handler_tags is not None: - new_entity.handler_tags = self.evaluate_complex_property( + new_entity.handler_tags = self.evaluator.evaluate_complex_property( self.handler_tags, context ) context["handler_tags"] = new_entity.handler_tags @@ -755,7 +741,7 @@ class TPVConfig(BaseModel): global_config: Optional[GlobalConfig] = Field( alias="global", default_factory=GlobalConfig ) - loader: SkipJsonSchema[Optional[TPVCodeBlockInterface]] = Field( + evaluator: SkipJsonSchema[Optional[TPVCodeEvaluator]] = Field( exclude=True, default=None ) tools: Optional[Dict[str, Tool]] = Field(default_factory=dict) @@ -765,19 +751,13 @@ class TPVConfig(BaseModel): @model_validator(mode="after") def propagate_parent_properties(self): - if self.loader: + if self.evaluator: for id, tool in self.tools.items(): - tool.propagate_parent_properties(id=id, loader=self.loader) + tool.propagate_parent_properties(id=id, evaluator=self.evaluator) for id, user in self.users.items(): - user.propagate_parent_properties(id=id, loader=self.loader) + user.propagate_parent_properties(id=id, evaluator=self.evaluator) for id, role in self.roles.items(): - role.propagate_parent_properties(id=id, loader=self.loader) + role.propagate_parent_properties(id=id, evaluator=self.evaluator) for id, destination in self.destinations.items(): - destination.propagate_parent_properties(id=id, loader=self.loader) + destination.propagate_parent_properties(id=id, evaluator=self.evaluator) return self - - -# from tpv.core import schema -# import yaml -# data = yaml.safe_load(open("/Users/nuwan/work/total-perspective-vortex/tests/fixtures/scenario.yml")) -# config = schema.TPVConfig(**data) diff --git a/tpv/core/evaluator.py b/tpv/core/evaluator.py index b9a9f40..4731ac9 100644 --- a/tpv/core/evaluator.py +++ b/tpv/core/evaluator.py @@ -1,12 +1,45 @@ import abc +from typing import Any, Dict -class TPVCodeBlockInterface(abc.ABC): +class TPVCodeEvaluator(abc.ABC): @abc.abstractmethod - def compile_code_block(self, code, as_f_string=False, exec_only=False): + def compile_code_block(self, code: str, as_f_string=False, exec_only=False): pass @abc.abstractmethod - def eval_code_block(self, code, context, as_f_string=False, exec_only=False): + def eval_code_block( + self, code: str, context: Dict[str, Any], as_f_string=False, exec_only=False + ): pass + + def process_complex_property(self, prop: Any, context: Dict[str, Any], func): + if isinstance(prop, str): + return func(prop, context) + elif isinstance(prop, dict): + evaluated_props = { + key: self.process_complex_property(childprop, context, func) + for key, childprop in prop.items() + } + 
return evaluated_props + elif isinstance(prop, list): + evaluated_props = [ + self.process_complex_property(childprop, context, func) + for childprop in prop + ] + return evaluated_props + else: + return prop + + def compile_complex_property(self, prop): + return self.process_complex_property( + prop, None, lambda p, c: self.compile_code_block(p, as_f_string=True) + ) + + def evaluate_complex_property(self, prop, context: Dict[str, Any]): + return self.process_complex_property( + prop, + context, + lambda p, c: self.eval_code_block(p, c, as_f_string=True), + ) diff --git a/tpv/core/loader.py b/tpv/core/loader.py index aa84c14..3e189e9 100644 --- a/tpv/core/loader.py +++ b/tpv/core/loader.py @@ -7,7 +7,7 @@ from . import helpers, util from .entities import Entity, GlobalConfig, TPVConfig -from .evaluator import TPVCodeBlockInterface +from .evaluator import TPVCodeEvaluator log = logging.getLogger(__name__) @@ -16,13 +16,13 @@ class InvalidParentException(Exception): pass -class TPVConfigLoader(TPVCodeBlockInterface): +class TPVConfigLoader(TPVCodeEvaluator): def __init__(self, tpv_config: TPVConfig): self.compile_code_block = functools.lru_cache(maxsize=None)( self.__compile_code_block ) - self.config = TPVConfig(loader=self, **tpv_config) + self.config = TPVConfig(evaluator=self, **tpv_config) self.process_entities(self.config) def compile_code_block(self, code, as_f_string=False, exec_only=False): From a137cf058adfe8a5390f0e0686ad20912609efa6 Mon Sep 17 00:00:00 2001 From: nuwang <2070605+nuwang@users.noreply.github.com> Date: Sun, 25 Aug 2024 18:57:46 +0530 Subject: [PATCH 26/30] Improve linting and formatting commands with black and isort --- tox.ini | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/tox.ini b/tox.ini index 04c5cad..14d82af 100644 --- a/tox.ini +++ b/tox.ini @@ -21,5 +21,16 @@ deps = coverage [testenv:lint] -commands = flake8 tpv -deps = flake8 +commands = + flake8 tpv + isort -c --df tpv + black --check --diff tpv +deps = + flake8 + isort + black + +[testenv:format] +commands = + isort tpv + black tpv From 5ea80d0ef2cf7dbf6b8c6eb8aaf58dd1bb2253ae Mon Sep 17 00:00:00 2001 From: nuwang <2070605+nuwang@users.noreply.github.com> Date: Sun, 25 Aug 2024 18:57:59 +0530 Subject: [PATCH 27/30] Run isort against codebase --- tpv/commands/dryrunner.py | 3 ++- tpv/commands/formatter.py | 1 + tpv/commands/shell.py | 3 +-- tpv/commands/test/mock_galaxy.py | 2 +- tpv/core/entities.py | 7 +------ tpv/core/helpers.py | 1 + tpv/core/mapper.py | 6 +++--- tpv/core/util.py | 2 +- tpv/rules/gateway.py | 1 + 9 files changed, 12 insertions(+), 14 deletions(-) diff --git a/tpv/commands/dryrunner.py b/tpv/commands/dryrunner.py index 111bf04..aba8989 100644 --- a/tpv/commands/dryrunner.py +++ b/tpv/commands/dryrunner.py @@ -1,6 +1,7 @@ -from .test import mock_galaxy from tpv.rules import gateway +from .test import mock_galaxy + class TPVDryRunner(): diff --git a/tpv/commands/formatter.py b/tpv/commands/formatter.py index cadcbf4..0e16f11 100644 --- a/tpv/commands/formatter.py +++ b/tpv/commands/formatter.py @@ -1,4 +1,5 @@ from __future__ import annotations + import logging from tpv.core import util diff --git a/tpv/commands/shell.py b/tpv/commands/shell.py index 9ea0c5b..473b736 100644 --- a/tpv/commands/shell.py +++ b/tpv/commands/shell.py @@ -6,8 +6,7 @@ from .dryrunner import TPVDryRunner from .formatter import TPVConfigFormatter -from .linter import TPVConfigLinter -from .linter import TPVLintError +from .linter import TPVConfigLinter, TPVLintError log = 
logging.getLogger(__name__) diff --git a/tpv/commands/test/mock_galaxy.py b/tpv/commands/test/mock_galaxy.py index be3a82f..22aa0bd 100644 --- a/tpv/commands/test/mock_galaxy.py +++ b/tpv/commands/test/mock_galaxy.py @@ -1,8 +1,8 @@ import hashlib -from galaxy.model import mapping from galaxy.job_metrics import JobMetrics from galaxy.jobs import JobConfiguration +from galaxy.model import mapping from galaxy.util import bunch from galaxy.web_stack import ApplicationStack diff --git a/tpv/core/entities.py b/tpv/core/entities.py index f06e342..5ff2d8d 100644 --- a/tpv/core/entities.py +++ b/tpv/core/entities.py @@ -7,12 +7,7 @@ from typing import Annotated, Any, ClassVar, Dict, Iterable, List, Optional from galaxy import util as galaxy_util -from pydantic import ( - BaseModel, - ConfigDict, - Field, - model_validator, -) +from pydantic import BaseModel, ConfigDict, Field, model_validator from pydantic.json_schema import SkipJsonSchema from .evaluator import TPVCodeEvaluator diff --git a/tpv/core/helpers.py b/tpv/core/helpers.py index 343b79c..0dfee51 100644 --- a/tpv/core/helpers.py +++ b/tpv/core/helpers.py @@ -7,6 +7,7 @@ import random from functools import reduce + from galaxy import model GIGABYTES = 1024.0**3 diff --git a/tpv/core/mapper.py b/tpv/core/mapper.py index a981e89..e551b9d 100644 --- a/tpv/core/mapper.py +++ b/tpv/core/mapper.py @@ -2,12 +2,12 @@ import logging import re -from .entities import Entity, Tool, TryNextDestinationOrFail, TryNextDestinationOrWait -from .loader import TPVConfigLoader - from galaxy.jobs import JobDestination from galaxy.jobs.mapper import JobNotReadyException +from .entities import Entity, Tool, TryNextDestinationOrFail, TryNextDestinationOrWait +from .loader import TPVConfigLoader + log = logging.getLogger(__name__) diff --git a/tpv/core/util.py b/tpv/core/util.py index 0bcfcda..5dea468 100644 --- a/tpv/core/util.py +++ b/tpv/core/util.py @@ -1,7 +1,7 @@ import os -import ruamel.yaml import requests +import ruamel.yaml def load_yaml_from_url_or_path(url_or_path: str): diff --git a/tpv/rules/gateway.py b/tpv/rules/gateway.py index f06a074..8f739a6 100644 --- a/tpv/rules/gateway.py +++ b/tpv/rules/gateway.py @@ -2,6 +2,7 @@ import os from galaxy.util.watcher import get_watcher + from tpv.core.loader import TPVConfigLoader from tpv.core.mapper import EntityToDestinationMapper From 97595071788b8cbb13bf37841cfb531597b11ae0 Mon Sep 17 00:00:00 2001 From: nuwang <2070605+nuwang@users.noreply.github.com> Date: Sun, 25 Aug 2024 18:59:19 +0530 Subject: [PATCH 28/30] reformat codebase using black --- tpv/commands/dryrunner.py | 27 +++++---- tpv/commands/formatter.py | 77 +++++++++++++------------ tpv/commands/linter.py | 31 +++++++---- tpv/commands/shell.py | 96 +++++++++++++++++++------------- tpv/commands/test/mock_galaxy.py | 18 +++--- tpv/core/helpers.py | 40 +++++++++---- tpv/core/mapper.py | 95 ++++++++++++++++++++----------- tpv/core/util.py | 4 +- tpv/rules/gateway.py | 33 ++++++++--- 9 files changed, 261 insertions(+), 160 deletions(-) diff --git a/tpv/commands/dryrunner.py b/tpv/commands/dryrunner.py index aba8989..b21906c 100644 --- a/tpv/commands/dryrunner.py +++ b/tpv/commands/dryrunner.py @@ -3,7 +3,7 @@ from .test import mock_galaxy -class TPVDryRunner(): +class TPVDryRunner: def __init__(self, job_conf, tpv_confs=None, user=None, tool=None, job=None): self.galaxy_app = mock_galaxy.App(job_conf=job_conf, create_model=True) @@ -14,25 +14,30 @@ def __init__(self, job_conf, tpv_confs=None, user=None, tool=None, job=None): self.tpv_config_files = 
tpv_confs else: self.tpv_config_files = self.galaxy_app.job_config.get_destination( - 'tpv_dispatcher').params['tpv_config_files'] + "tpv_dispatcher" + ).params["tpv_config_files"] def run(self): gateway.ACTIVE_DESTINATION_MAPPER = None - return gateway.map_tool_to_destination(self.galaxy_app, self.job, self.tool, self.user, - tpv_config_files=self.tpv_config_files) + return gateway.map_tool_to_destination( + self.galaxy_app, + self.job, + self.tool, + self.user, + tpv_config_files=self.tpv_config_files, + ) @staticmethod def from_params(job_conf, user=None, tool=None, tpv_confs=None, input_size=None): if user is not None: email = user - user = mock_galaxy.User('gargravarr', email) + user = mock_galaxy.User("gargravarr", email) else: user = None if tool: tool = mock_galaxy.Tool( - tool, - version=tool.split('/')[-1] if '/' in tool else None + tool, version=tool.split("/")[-1] if "/" in tool else None ) else: tool = None @@ -40,8 +45,10 @@ def from_params(job_conf, user=None, tool=None, tpv_confs=None, input_size=None) job = mock_galaxy.Job() if input_size: dataset = mock_galaxy.DatasetAssociation( - "test", - mock_galaxy.Dataset("test.txt", file_size=input_size*1024**3)) + "test", mock_galaxy.Dataset("test.txt", file_size=input_size * 1024**3) + ) job.add_input_dataset(dataset) - return TPVDryRunner(job_conf=job_conf, tpv_confs=tpv_confs, user=user, tool=tool, job=job) + return TPVDryRunner( + job_conf=job_conf, tpv_confs=tpv_confs, user=user, tool=tool, job=job + ) diff --git a/tpv/commands/formatter.py b/tpv/commands/formatter.py index 0e16f11..4a2e06c 100644 --- a/tpv/commands/formatter.py +++ b/tpv/commands/formatter.py @@ -21,6 +21,7 @@ def sort_criteria(key): index = len(keys_to_place_first) # sort by keys to place first, then potential toolshed tools, and finally alphabetically return (index, "/" not in key, key) + return sort_criteria @staticmethod @@ -66,57 +67,61 @@ def multi_level_dict_sorter(dict_to_sort, sort_order): if not sort_order: return dict_to_sort if isinstance(dict_to_sort, dict): - sorted_keys = sorted(dict_to_sort or [], key=TPVConfigFormatter.generic_key_sorter(sort_order.keys())) - return {key: TPVConfigFormatter.multi_level_dict_sorter(dict_to_sort.get(key), - sort_order.get(key, {}) or sort_order.get('*', {})) - for key in sorted_keys} + sorted_keys = sorted( + dict_to_sort or [], + key=TPVConfigFormatter.generic_key_sorter(sort_order.keys()), + ) + return { + key: TPVConfigFormatter.multi_level_dict_sorter( + dict_to_sort.get(key), + sort_order.get(key, {}) or sort_order.get("*", {}), + ) + for key in sorted_keys + } elif isinstance(dict_to_sort, list): - return [TPVConfigFormatter.multi_level_dict_sorter(item, sort_order.get('*', [])) - for item in dict_to_sort] + return [ + TPVConfigFormatter.multi_level_dict_sorter( + item, sort_order.get("*", []) + ) + for item in dict_to_sort + ] else: return dict_to_sort def format(self): - default_inherits = self.yaml_dict.get('global', {}).get('default_inherits') or 'default' + default_inherits = ( + self.yaml_dict.get("global", {}).get("default_inherits") or "default" + ) basic_entity_sort_order = { - 'id': {}, - 'inherits': {}, - 'if': {}, - 'context': {}, - 'gpus': {}, - 'cores': {}, - 'mem': {}, - 'env': { - '*': {} - }, - 'params': { - '*': {} + "id": {}, + "inherits": {}, + "if": {}, + "context": {}, + "gpus": {}, + "cores": {}, + "mem": {}, + "env": {"*": {}}, + "params": {"*": {}}, + "scheduling": { + "require": {}, + "prefer": {}, + "accept": {}, + "reject": {}, }, - 'scheduling': { - 'require': {}, - 'prefer': 
{}, - 'accept': {}, - 'reject': {}, - } } entity_with_rules_sort_order = { default_inherits: {}, - '*': { - **basic_entity_sort_order, - 'rules': { - '*': basic_entity_sort_order - } - } + "*": {**basic_entity_sort_order, "rules": {"*": basic_entity_sort_order}}, } global_field_sort_order = { - 'global': {}, - 'tools': entity_with_rules_sort_order, - 'roles': entity_with_rules_sort_order, - 'users': entity_with_rules_sort_order, - 'destinations': entity_with_rules_sort_order, + "global": {}, + "tools": entity_with_rules_sort_order, + "roles": entity_with_rules_sort_order, + "users": entity_with_rules_sort_order, + "destinations": entity_with_rules_sort_order, } return self.multi_level_dict_sorter(self.yaml_dict, global_field_sort_order) diff --git a/tpv/commands/linter.py b/tpv/commands/linter.py index dfc4315..7346c3a 100644 --- a/tpv/commands/linter.py +++ b/tpv/commands/linter.py @@ -22,7 +22,9 @@ def lint(self): loader = TPVConfigLoader.from_url_or_path(self.url_or_path) except Exception as e: log.error(f"Linting failed due to syntax errors in yaml file: {e}") - raise TPVLintError("Linting failed due to syntax errors in yaml file: ") from e + raise TPVLintError( + "Linting failed due to syntax errors in yaml file: " + ) from e self.lint_tools(loader) self.lint_destinations(loader) self.print_errors_and_warnings() @@ -38,27 +40,34 @@ def lint_tools(self, loader): self.warnings.append( f"The tool named: {default_inherits} is marked globally as the tool to inherit from " "by default. You may want to mark it as abstract if it is not an actual tool and it " - "will be excluded from scheduling decisions.") + "will be excluded from scheduling decisions." + ) def lint_destinations(self, loader): default_inherits = loader.config.global_config.default_inherits for destination in loader.config.destinations.values(): if not destination.runner and not destination.abstract: - self.errors.append(f"Destination '{destination.id}' does not define the runner parameter. " - "The runner parameter is mandatory.") - if ((destination.cores and not destination.max_accepted_cores) or - (destination.mem and not destination.max_accepted_mem) or - (destination.gpus and not destination.max_accepted_gpus)): + self.errors.append( + f"Destination '{destination.id}' does not define the runner parameter. " + "The runner parameter is mandatory." + ) + if ( + (destination.cores and not destination.max_accepted_cores) + or (destination.mem and not destination.max_accepted_mem) + or (destination.gpus and not destination.max_accepted_gpus) + ): self.errors.append( f"The destination named: {destination.id} defines the cores/mem/gpus property instead of " f"max_accepted_cores/mem/gpus. This is probably an error. If you're migrating from an older " f"version of TPV, the destination properties for cores/mem/gpus have been superseded by the " - f"max_accepted_cores/mem/gpus property. Simply renaming them will give you the same functionality.") + f"max_accepted_cores/mem/gpus property. Simply renaming them will give you the same functionality." + ) if default_inherits == destination.id: self.warnings.append( f"The destination named: {default_inherits} is marked globally as the destination to inherit from " "by default. You may want to mark it as abstract if it is not meant to be dispatched to, and it " - "will be excluded from scheduling decisions.") + "will be excluded from scheduling decisions." 
+ ) def print_errors_and_warnings(self): if self.warnings: @@ -67,7 +76,9 @@ def print_errors_and_warnings(self): if self.errors: for e in self.errors: log.error(e) - raise TPVLintError(f"The following errors occurred during linting: {self.errors}") + raise TPVLintError( + f"The following errors occurred during linting: {self.errors}" + ) @staticmethod def from_url_or_path(url_or_path: str): diff --git a/tpv/commands/shell.py b/tpv/commands/shell.py index 473b736..f5b78c7 100644 --- a/tpv/commands/shell.py +++ b/tpv/commands/shell.py @@ -13,14 +13,14 @@ # https://stackoverflow.com/a/64933809 def repr_str(dumper: RoundTripRepresenter, data: str): - if '\n' in data: - return dumper.represent_scalar('tag:yaml.org,2002:str', data, style='|') - return dumper.represent_scalar('tag:yaml.org,2002:str', data) + if "\n" in data: + return dumper.represent_scalar("tag:yaml.org,2002:str", data, style="|") + return dumper.represent_scalar("tag:yaml.org,2002:str", data) # https://stackoverflow.com/a/37445121 def repr_none(dumper: RoundTripRepresenter, data): - return dumper.represent_scalar(u'tag:yaml.org,2002:null', '') + return dumper.represent_scalar("tag:yaml.org,2002:null", "") def tpv_lint_config_file(args): @@ -37,7 +37,7 @@ def tpv_lint_config_file(args): def tpv_format_config_file(args): try: formatter = TPVConfigFormatter.from_url_or_path(args.config) - yaml = YAML(typ='unsafe', pure=True) + yaml = YAML(typ="unsafe", pure=True) yaml.Representer = RoundTripRepresenter yaml.Representer.add_representer(str, repr_str) yaml.Representer.add_representer(type(None), repr_none) @@ -51,10 +51,15 @@ def tpv_format_config_file(args): def tpv_dry_run_config_files(args): - dry_runner = TPVDryRunner.from_params(user=args.user, tool=args.tool, job_conf=args.job_conf, tpv_confs=args.config, - input_size=args.input_size) + dry_runner = TPVDryRunner.from_params( + user=args.user, + tool=args.tool, + job_conf=args.job_conf, + tpv_confs=args.config, + input_size=args.input_size, + ) destination = dry_runner.run() - yaml = YAML(typ='unsafe', pure=True) + yaml = YAML(typ="unsafe", pure=True) yaml.dump(destination, sys.stdout) @@ -63,52 +68,64 @@ def create_parser(): parser.set_defaults(func=lambda args: parser.print_help()) # debugging and logging settings - parser.add_argument("-v", "--verbose", action="count", - dest="verbosity_count", default=0, - help="increases log verbosity for each occurrence.") - subparsers = parser.add_subparsers(metavar='') + parser.add_argument( + "-v", + "--verbose", + action="count", + dest="verbosity_count", + default=0, + help="increases log verbosity for each occurrence.", + ) + subparsers = parser.add_subparsers(metavar="") # File copy commands lint_parser = subparsers.add_parser( - 'lint', - help='loads a TPV configuration file and checks it for syntax errors', - description="The linter will check yaml syntax and compile python code blocks") + "lint", + help="loads a TPV configuration file and checks it for syntax errors", + description="The linter will check yaml syntax and compile python code blocks", + ) lint_parser.add_argument( - 'config', type=str, - help="Path to the TPV config file to lint. Can be a local path or http url.") + "config", + type=str, + help="Path to the TPV config file to lint. 
Can be a local path or http url.", + ) lint_parser.set_defaults(func=tpv_lint_config_file) format_parser = subparsers.add_parser( - 'format', - help='Reformats a TPV configuration file and prints it to stdout.', - description="The formatter will reorder tools, users etc by name, moving defaults first") + "format", + help="Reformats a TPV configuration file and prints it to stdout.", + description="The formatter will reorder tools, users etc by name, moving defaults first", + ) format_parser.add_argument( - 'config', type=str, - help="Path to the TPV config file to format. Can be a local path or http url.") + "config", + type=str, + help="Path to the TPV config file to format. Can be a local path or http url.", + ) format_parser.set_defaults(func=tpv_format_config_file) dry_run_parser = subparsers.add_parser( - 'dry-run', - help="Perform a dry run test of a TPV configuration.", - description="") + "dry-run", help="Perform a dry run test of a TPV configuration.", description="" + ) dry_run_parser.add_argument( - '--job-conf', type=str, - required=True, - help="Galaxy job configuration file") + "--job-conf", type=str, required=True, help="Galaxy job configuration file" + ) dry_run_parser.add_argument( - '--input-size', type=int, - help="Input dataset size (in GB)") + "--input-size", type=int, help="Input dataset size (in GB)" + ) dry_run_parser.add_argument( - '--tool', type=str, - default='_default_', - help="Test mapping for Galaxy tool with given ID") + "--tool", + type=str, + default="_default_", + help="Test mapping for Galaxy tool with given ID", + ) dry_run_parser.add_argument( - '--user', type=str, - help="Test mapping for Galaxy user with username or email") + "--user", type=str, help="Test mapping for Galaxy user with username or email" + ) dry_run_parser.add_argument( - 'config', - nargs='*', - help="TPV configuration files, overrides tpv_config_files in Galaxy job configuration if provided") + "config", + nargs="*", + help="TPV configuration files, overrides tpv_config_files in Galaxy job configuration if provided", + ) dry_run_parser.set_defaults(func=tpv_dry_run_config_files) return parser @@ -123,7 +140,8 @@ def configure_logging(verbosity_count): logging.basicConfig( stream=sys.stdout, level=logging.DEBUG if verbosity_count > 3 else logging.ERROR, - format='%(levelname)-5s: %(name)s: %(message)s') + format="%(levelname)-5s: %(name)s: %(message)s", + ) # Set client log level if verbosity_count: log.setLevel(max(4 - verbosity_count, 1) * 10) diff --git a/tpv/commands/test/mock_galaxy.py b/tpv/commands/test/mock_galaxy.py index 22aa0bd..92cdcff 100644 --- a/tpv/commands/test/mock_galaxy.py +++ b/tpv/commands/test/mock_galaxy.py @@ -16,7 +16,9 @@ def __init__(self): self.parameters = [] def add_input_dataset(self, dataset_association): - self.input_datasets.append(JobToInputDatasetAssociation(dataset_association.name, dataset_association)) + self.input_datasets.append( + JobToInputDatasetAssociation(dataset_association.name, dataset_association) + ) def get_param_values(self, app): return self.param_values @@ -72,11 +74,7 @@ def __init__(self, job_conf=None, create_model=False): ) self.job_metrics = JobMetrics() if create_model: - self.model = mapping.init( - "/tmp", - "sqlite:///:memory:", - create_tables=True - ) + self.model = mapping.init("/tmp", "sqlite:///:memory:", create_tables=True) self.application_stack = ApplicationStack(app=self) self.job_config = JobConfiguration(self) @@ -86,9 +84,11 @@ def __init__(self, username, email, roles=[], id=None): self.username = username 
self.email = email self.roles = [Role(name) for name in roles] - self.id = id or int( - hashlib.sha256(f"{self.username}".encode("utf-8")).hexdigest(), 16 - ) % 1000000 + self.id = ( + id + or int(hashlib.sha256(f"{self.username}".encode("utf-8")).hexdigest(), 16) + % 1000000 + ) def all_roles(self): """ diff --git a/tpv/core/helpers.py b/tpv/core/helpers.py index 0dfee51..1ac9d71 100644 --- a/tpv/core/helpers.py +++ b/tpv/core/helpers.py @@ -25,20 +25,24 @@ def sum_total(prev, current): def calculate_dataset_total(datasets): if datasets: - unique_datasets = {inp_ds.dataset.dataset.id: inp_ds.dataset.dataset for inp_ds in datasets if inp_ds.dataset} + unique_datasets = { + inp_ds.dataset.dataset.id: inp_ds.dataset.dataset + for inp_ds in datasets + if inp_ds.dataset + } return reduce(sum_total, map(get_dataset_size, unique_datasets.values()), 0.0) else: return 0.0 def input_size(job): - return calculate_dataset_total(job.input_datasets)/GIGABYTES + return calculate_dataset_total(job.input_datasets) / GIGABYTES def weighted_random_sampling(destinations): if not destinations: return [] - rankings = [(d.params.get('weight', 1) if d.params else 1) for d in destinations] + rankings = [(d.params.get("weight", 1) if d.params else 1) for d in destinations] return random.choices(destinations, weights=rankings, k=len(destinations)) @@ -69,21 +73,25 @@ def job_args_match(job, app, args): try: options_value = reduce(dict.__getitem__, arg_keys_list, options) arg_value = reduce(dict.__getitem__, arg_keys_list, arg_dict) - if (arg_value != options_value): + if arg_value != options_value: matched = False except KeyError: matched = False return matched -def concurrent_job_count_for_tool(app, tool, user=None): # requires galaxy version >= 21.09 +def concurrent_job_count_for_tool( + app, tool, user=None +): # requires galaxy version >= 21.09 # Match all tools, regardless of version. 
For example, a tool id such as "fastqc/0.1.0+galaxy1" is # turned into "fastqc/.*" - tool_id_regex = '/'.join(tool.id.split('/')[:-1]) + '/.*' if '/' in tool.id else tool.id + tool_id_regex = ( + "/".join(tool.id.split("/")[:-1]) + "/.*" if "/" in tool.id else tool.id + ) query = app.model.context.query(model.Job) if user: query = query.filter(model.Job.table.c.user_id == user.id) - query = query.filter(model.Job.table.c.state.in_(['queued', 'running'])) + query = query.filter(model.Job.table.c.state.in_(["queued", "running"])) query = query.filter(model.Job.table.c.tool_id.regexp_match(tool_id_regex)) return query.count() @@ -91,9 +99,16 @@ def concurrent_job_count_for_tool(app, tool, user=None): # requires galaxy vers def tag_values_match(entity, match_tag_values=[], exclude_tag_values=[]): # Return true if an entity has require/prefer/accept tags in the match_tags_values list # and no require/prefer/accept tags in the exclude_tag_values list - return ( - all([any(entity.tpv_tags.filter(tag_value=tag_value)) for tag_value in match_tag_values]) - and not any([any(entity.tpv_tags.filter(tag_value=tag_value)) for tag_value in exclude_tag_values]) + return all( + [ + any(entity.tpv_tags.filter(tag_value=tag_value)) + for tag_value in match_tag_values + ] + ) and not any( + [ + any(entity.tpv_tags.filter(tag_value=tag_value)) + for tag_value in exclude_tag_values + ] ) @@ -122,7 +137,8 @@ def get_dataset_attributes(datasets): # and file sizes in bytes for all input datasets in a job return { i.dataset.dataset.id: { - 'object_store_id': i.dataset.dataset.object_store_id, - 'size': get_dataset_size(i.dataset.dataset)} + "object_store_id": i.dataset.dataset.object_store_id, + "size": get_dataset_size(i.dataset.dataset), + } for i in datasets or {} } diff --git a/tpv/core/mapper.py b/tpv/core/mapper.py index e551b9d..84dc1d5 100644 --- a/tpv/core/mapper.py +++ b/tpv/core/mapper.py @@ -19,7 +19,9 @@ def __init__(self, loader: TPVConfigLoader): self.destinations = self.config.destinations self.default_inherits = self.config.global_config.default_inherits self.global_context = self.config.global_config.context - self.lookup_tool_regex = functools.lru_cache(maxsize=None)(self.__compile_tool_regex) + self.lookup_tool_regex = functools.lru_cache(maxsize=None)( + self.__compile_tool_regex + ) # self.inherit_matching_entities = functools.lru_cache(maxsize=None)(self.__inherit_matching_entities) self.inherit_matching_entities = self.__inherit_matching_entities @@ -30,7 +32,9 @@ def __compile_tool_regex(self, key): log.error(f"Failed to compile regex: {key}") raise - def _find_entities_matching_id(self, entity_list: dict[str, Entity], entity_name: str): + def _find_entities_matching_id( + self, entity_list: dict[str, Entity], entity_name: str + ): default_inherits = self.__get_default_inherits(entity_list) if default_inherits: matches = [default_inherits] @@ -41,12 +45,17 @@ def _find_entities_matching_id(self, entity_list: dict[str, Entity], entity_name match = entity_list[key] if match.abstract: from galaxy.jobs.mapper import JobMappingException - raise JobMappingException(f"This entity is abstract and cannot be mapped : {match}") + + raise JobMappingException( + f"This entity is abstract and cannot be mapped : {match}" + ) else: matches.append(match) return matches - def __inherit_matching_entities(self, entity_list: dict[str, Entity], entity_name: str): + def __inherit_matching_entities( + self, entity_list: dict[str, Entity], entity_name: str + ): matches = self._find_entities_matching_id(entity_list, 
entity_name) return self.inherit_entities(matches) @@ -60,7 +69,10 @@ def __get_default_inherits(self, entity_list: dict[str, Entity]): def __apply_default_destination_inheritance(self, entity_list): default_inherits = self.__get_default_inherits(entity_list) if default_inherits: - return [self.inherit_entities([default_inherits, entity]) for entity in entity_list.values()] + return [ + self.inherit_entities([default_inherits, entity]) + for entity in entity_list.values() + ] else: return entity_list.values() @@ -82,8 +94,11 @@ def rank(self, entity, destinations, context): def match_and_rank_destinations(self, entity, destinations, context): # At this point, the resource requirements (cores, mem, gpus) are unevaluated. # So temporarily evaluate them so we can match up with a destination. - matches = [dest for dest in self.__apply_default_destination_inheritance(destinations) - if dest.matches(entity.evaluate_resources(context), context)] + matches = [ + dest + for dest in self.__apply_default_destination_inheritance(destinations) + if dest.matches(entity.evaluate_resources(context), context) + ] return self.rank(entity, matches, context) def to_galaxy_destination(self, destination): @@ -104,8 +119,11 @@ def _find_matching_entities(self, tool, user): entity_list = [tool_entity] if user: - role_entities = (self.inherit_matching_entities(self.config.roles, role.name) - for role in user.all_roles() if not role.deleted) + role_entities = ( + self.inherit_matching_entities(self.config.roles, role.name) + for role in user.all_roles() + if not role.deleted + ) # trim empty user_role_entities = (role for role in role_entities if role) user_role_entity = next(user_role_entities, None) @@ -124,18 +142,12 @@ def match_combine_evaluate_entities(self, context, tool, user): # 2. Combine entity requirements combined_entity = self.combine_entities(entity_list) - context.update({ - 'entity': combined_entity, - 'self': combined_entity - }) + context.update({"entity": combined_entity, "self": combined_entity}) # 3. Evaluate rules only, so that all expressions are collapsed into a flat entity. The final # values for expressions should be evaluated only after combining with the destination. evaluated_entity = combined_entity.evaluate_rules(context) - context.update({ - 'entity': evaluated_entity, - 'self': evaluated_entity - }) + context.update({"entity": evaluated_entity, "self": evaluated_entity}) # Remove the rules as they've already been evaluated, and should not be re-evaluated when combining # with destinations @@ -143,28 +155,40 @@ def match_combine_evaluate_entities(self, context, tool, user): return evaluated_entity - def map_to_destination(self, app, tool, user, job, job_wrapper=None, resource_params=None, - workflow_invocation_uuid=None): + def map_to_destination( + self, + app, + tool, + user, + job, + job_wrapper=None, + resource_params=None, + workflow_invocation_uuid=None, + ): # 1. Create evaluation context - these are the common variables available within any code block context = {} context.update(self.global_context or {}) - context.update({ - 'app': app, - 'tool': tool, - 'user': user, - 'job': job, - 'job_wrapper': job_wrapper, - 'resource_params': resource_params, - 'workflow_invocation_uuid': workflow_invocation_uuid, - 'mapper': self - }) + context.update( + { + "app": app, + "tool": tool, + "user": user, + "job": job, + "job_wrapper": job_wrapper, + "resource_params": resource_params, + "workflow_invocation_uuid": workflow_invocation_uuid, + "mapper": self, + } + ) # 2. 
Find, combine and evaluate entities that match this tool and user evaluated_entity = self.match_combine_evaluate_entities(context, tool, user) # 3. Match and rank destinations that best match the combined entity - ranked_dest_entities = self.match_and_rank_destinations(evaluated_entity, self.destinations, context) + ranked_dest_entities = self.match_and_rank_destinations( + evaluated_entity, self.destinations, context + ) # 4. Fully combine entity with matching destinations if ranked_dest_entities: @@ -176,8 +200,10 @@ def map_to_destination(self, app, tool, user, job, job_wrapper=None, resource_pa # 5. Return the top-ranked destination that evaluates successfully return self.to_galaxy_destination(evaluated_destination) except TryNextDestinationOrFail as ef: - log.exception(f"Destination entity: {d} matched but could not fulfill requirements due to: {ef}." - " Trying next candidate...") + log.exception( + f"Destination entity: {d} matched but could not fulfill requirements due to: {ef}." + " Trying next candidate..." + ) except TryNextDestinationOrWait: wait_exception_raised = True if wait_exception_raised: @@ -185,4 +211,7 @@ def map_to_destination(self, app, tool, user, job, job_wrapper=None, resource_pa # No matching destinations. Throw an exception from galaxy.jobs.mapper import JobMappingException - raise JobMappingException(f"No destinations are available to fulfill request: {evaluated_entity.id}") + + raise JobMappingException( + f"No destinations are available to fulfill request: {evaluated_entity.id}" + ) diff --git a/tpv/core/util.py b/tpv/core/util.py index 5dea468..3ae5dce 100644 --- a/tpv/core/util.py +++ b/tpv/core/util.py @@ -5,9 +5,9 @@ def load_yaml_from_url_or_path(url_or_path: str): - yaml = ruamel.yaml.YAML(typ='safe') + yaml = ruamel.yaml.YAML(typ="safe") if os.path.isfile(url_or_path): - with open(url_or_path, 'r') as f: + with open(url_or_path, "r") as f: return yaml.load(f) else: with requests.get(url_or_path) as r: diff --git a/tpv/rules/gateway.py b/tpv/rules/gateway.py index 8f739a6..c2acf29 100644 --- a/tpv/rules/gateway.py +++ b/tpv/rules/gateway.py @@ -31,23 +31,38 @@ def setup_destination_mapper(app, tpv_config_files): def reload_destination_mapper(path=None): # reload all config files when one file changes to preserve order of loading the files global ACTIVE_DESTINATION_MAPPER - ACTIVE_DESTINATION_MAPPER = load_destination_mapper(tpv_config_files, reload=True) + ACTIVE_DESTINATION_MAPPER = load_destination_mapper( + tpv_config_files, reload=True + ) for tpv_config_file in tpv_config_files: if os.path.isfile(tpv_config_file): log.info(f"Watching for changes in file: {tpv_config_file}") - CONFIG_WATCHERS[tpv_config_file] = ( - CONFIG_WATCHERS.get(tpv_config_file) or - get_watcher(app.config, 'watch_job_rules', monitor_what_str='job rules')) - CONFIG_WATCHERS[tpv_config_file].watch_file(tpv_config_file, callback=reload_destination_mapper) + CONFIG_WATCHERS[tpv_config_file] = CONFIG_WATCHERS.get( + tpv_config_file + ) or get_watcher( + app.config, "watch_job_rules", monitor_what_str="job rules" + ) + CONFIG_WATCHERS[tpv_config_file].watch_file( + tpv_config_file, callback=reload_destination_mapper + ) CONFIG_WATCHERS[tpv_config_file].start() return mapper -def map_tool_to_destination(app, job, tool, user, tpv_config_files, job_wrapper=None, resource_params=None, - workflow_invocation_uuid=None): +def map_tool_to_destination( + app, + job, + tool, + user, + tpv_config_files, + job_wrapper=None, + resource_params=None, + workflow_invocation_uuid=None, +): global 
ACTIVE_DESTINATION_MAPPER if not ACTIVE_DESTINATION_MAPPER: ACTIVE_DESTINATION_MAPPER = setup_destination_mapper(app, tpv_config_files) - return ACTIVE_DESTINATION_MAPPER.map_to_destination(app, tool, user, job, job_wrapper, resource_params, - workflow_invocation_uuid) + return ACTIVE_DESTINATION_MAPPER.map_to_destination( + app, tool, user, job, job_wrapper, resource_params, workflow_invocation_uuid + ) From af57588602418d479f6f06e365e13b54e66130a0 Mon Sep 17 00:00:00 2001 From: nuwang <2070605+nuwang@users.noreply.github.com> Date: Sun, 25 Aug 2024 19:05:49 +0530 Subject: [PATCH 29/30] Add format command and reformat test code --- tests/test_entity.py | 47 ++- tests/test_helpers.py | 11 +- tests/test_mapper_basic.py | 59 ++-- tests/test_mapper_context.py | 90 +++-- tests/test_mapper_destinations.py | 449 ++++++++++++++++++------- tests/test_mapper_inheritance.py | 217 ++++++++---- tests/test_mapper_merge_multiple.py | 270 +++++++++++---- tests/test_mapper_params_specific.py | 54 +-- tests/test_mapper_rank.py | 23 +- tests/test_mapper_resubmit.py | 27 +- tests/test_mapper_role.py | 96 ++++-- tests/test_mapper_rules.py | 413 ++++++++++++++++------- tests/test_mapper_sample.py | 21 +- tests/test_mapper_user.py | 79 +++-- tests/test_scenarios.py | 367 ++++++++++++++------ tests/test_shell.py | 479 +++++++++++++++++++-------- tox.ini | 3 + 17 files changed, 1949 insertions(+), 756 deletions(-) diff --git a/tests/test_entity.py b/tests/test_entity.py index 6d1e4cf..89f112b 100644 --- a/tests/test_entity.py +++ b/tests/test_entity.py @@ -1,27 +1,33 @@ import os import unittest -from tpv.rules import gateway -from tpv.core.entities import Destination -from tpv.core.entities import Tool -from tpv.core.loader import TPVConfigLoader + from tpv.commands.test import mock_galaxy +from tpv.core.entities import Destination, Tool +from tpv.core.loader import TPVConfigLoader +from tpv.rules import gateway class TestEntity(unittest.TestCase): @staticmethod def _map_to_destination(app, job, tool, user): - tpv_config = os.path.join(os.path.dirname(__file__), 'fixtures/mapping-rule-argument-based.yml') + tpv_config = os.path.join( + os.path.dirname(__file__), "fixtures/mapping-rule-argument-based.yml" + ) gateway.ACTIVE_DESTINATION_MAPPER = None - return gateway.map_tool_to_destination(app, job, tool, user, tpv_config_files=[tpv_config]) + return gateway.map_tool_to_destination( + app, job, tool, user, tpv_config_files=[tpv_config] + ) # issue: https://github.com/galaxyproject/total-perspective-vortex/issues/53 def test_all_entities_refer_to_same_loader(self): - app = mock_galaxy.App(job_conf=os.path.join(os.path.dirname(__file__), 'fixtures/job_conf.yml')) + app = mock_galaxy.App( + job_conf=os.path.join(os.path.dirname(__file__), "fixtures/job_conf.yml") + ) job = mock_galaxy.Job() - tool = mock_galaxy.Tool('bwa') - user = mock_galaxy.User('ford', 'prefect@vortex.org') + tool = mock_galaxy.Tool("bwa") + user = mock_galaxy.User("ford", "prefect@vortex.org") # just map something so the ACTIVE_DESTINATION_MAPPER is populated self._map_to_destination(app, job, tool, user) @@ -29,18 +35,21 @@ def test_all_entities_refer_to_same_loader(self): # get the original loader original_loader = gateway.ACTIVE_DESTINATION_MAPPER.loader - context = { - 'app': app, - 'job': job - } + context = {"app": app, "job": job} # make sure we are still referring to the same loader after evaluation - evaluated_entity = gateway.ACTIVE_DESTINATION_MAPPER.match_combine_evaluate_entities(context, tool, user) + evaluated_entity = ( + 
gateway.ACTIVE_DESTINATION_MAPPER.match_combine_evaluate_entities( + context, tool, user + ) + ) assert evaluated_entity.evaluator == original_loader for rule in evaluated_entity.rules: assert rule.evaluator == original_loader def test_destination_to_dict(self): - tpv_config = os.path.join(os.path.dirname(__file__), 'fixtures/mapping-rule-argument-based.yml') + tpv_config = os.path.join( + os.path.dirname(__file__), "fixtures/mapping-rule-argument-based.yml" + ) loader = TPVConfigLoader.from_url_or_path(tpv_config) # create a destination @@ -48,12 +57,16 @@ def test_destination_to_dict(self): # serialize the destination serialized_destination = destination.dict() # deserialize the same destination - deserialized_destination = Destination(evaluator=loader, **serialized_destination) + deserialized_destination = Destination( + evaluator=loader, **serialized_destination + ) # make sure the deserialized destination is the same as the original self.assertEqual(deserialized_destination, destination) def test_tool_to_dict(self): - tpv_config = os.path.join(os.path.dirname(__file__), 'fixtures/mapping-rule-argument-based.yml') + tpv_config = os.path.join( + os.path.dirname(__file__), "fixtures/mapping-rule-argument-based.yml" + ) loader = TPVConfigLoader.from_url_or_path(tpv_config) # create a tool diff --git a/tests/test_helpers.py b/tests/test_helpers.py index eee7492..221e52a 100644 --- a/tests/test_helpers.py +++ b/tests/test_helpers.py @@ -1,20 +1,25 @@ """Unit tests module for the helper functions""" + import unittest + from tpv.commands.test import mock_galaxy from tpv.core.helpers import get_dataset_attributes class TestHelpers(unittest.TestCase): """Tests for helper functions""" + def test_get_dataset_attributes(self): """Test that the function returns a dictionary with the correct attributes""" job = mock_galaxy.Job() job.add_input_dataset( mock_galaxy.DatasetAssociation( "test", - mock_galaxy.Dataset("test.txt", file_size=7*1024**3, object_store_id="files1") - ) + mock_galaxy.Dataset( + "test.txt", file_size=7 * 1024**3, object_store_id="files1" + ), ) + ) dataset_attributes = get_dataset_attributes(job.input_datasets) - expected_result = {0: {'object_store_id': 'files1', 'size': 7*1024**3}} + expected_result = {0: {"object_store_id": "files1", "size": 7 * 1024**3}} self.assertEqual(dataset_attributes, expected_result) diff --git a/tests/test_mapper_basic.py b/tests/test_mapper_basic.py index a70e079..fd3d938 100644 --- a/tests/test_mapper_basic.py +++ b/tests/test_mapper_basic.py @@ -1,66 +1,83 @@ import os import re import unittest -from tpv.rules import gateway -from tpv.commands.test import mock_galaxy + from galaxy.jobs.mapper import JobMappingException +from tpv.commands.test import mock_galaxy +from tpv.rules import gateway + class TestMapperBasic(unittest.TestCase): @staticmethod def _map_to_destination(tool, tpv_config_path=None): - galaxy_app = mock_galaxy.App(job_conf=os.path.join(os.path.dirname(__file__), 'fixtures/job_conf.yml')) + galaxy_app = mock_galaxy.App( + job_conf=os.path.join(os.path.dirname(__file__), "fixtures/job_conf.yml") + ) job = mock_galaxy.Job() - user = mock_galaxy.User('gargravarr', 'fairycake@vortex.org') - tpv_config = tpv_config_path or os.path.join(os.path.dirname(__file__), - 'fixtures/mapping-basic.yml') + user = mock_galaxy.User("gargravarr", "fairycake@vortex.org") + tpv_config = tpv_config_path or os.path.join( + os.path.dirname(__file__), "fixtures/mapping-basic.yml" + ) gateway.ACTIVE_DESTINATION_MAPPER = None - return 
gateway.map_tool_to_destination(galaxy_app, job, tool, user, tpv_config_files=[tpv_config])
+        return gateway.map_tool_to_destination(
+            galaxy_app, job, tool, user, tpv_config_files=[tpv_config]
+        )

     def test_map_default_tool(self):
-        tool = mock_galaxy.Tool('sometool')
+        tool = mock_galaxy.Tool("sometool")
         destination = self._map_to_destination(tool)
         self.assertEqual(destination.id, "local")

     def test_map_overridden_tool(self):
-        tool = mock_galaxy.Tool('bwa')
+        tool = mock_galaxy.Tool("bwa")
         destination = self._map_to_destination(tool)
         self.assertEqual(destination.id, "k8s_environment")

     def test_map_unschedulable_tool(self):
-        tool = mock_galaxy.Tool('unschedulable_tool')
-        with self.assertRaisesRegex(JobMappingException, "No destinations are available to fulfill request"):
+        tool = mock_galaxy.Tool("unschedulable_tool")
+        with self.assertRaisesRegex(
+            JobMappingException, "No destinations are available to fulfill request"
+        ):
             self._map_to_destination(tool)

     def test_map_invalidly_tagged_tool(self):
-        tool = mock_galaxy.Tool('invalidly_tagged_tool')
-        config = os.path.join(os.path.dirname(__file__), 'fixtures/mapping-invalid-tags.yml')
-        with self.assertRaisesRegex(Exception, r"Duplicate tags found: 'general' in \['require', 'reject'\]"):
+        tool = mock_galaxy.Tool("invalidly_tagged_tool")
+        config = os.path.join(
+            os.path.dirname(__file__), "fixtures/mapping-invalid-tags.yml"
+        )
+        with self.assertRaisesRegex(
+            Exception, r"Duplicate tags found: 'general' in \['require', 'reject'\]"
+        ):
             self._map_to_destination(tool, tpv_config_path=config)

     def test_map_tool_by_regex(self):
-        tool = mock_galaxy.Tool('regex_tool_test')
+        tool = mock_galaxy.Tool("regex_tool_test")
         destination = self._map_to_destination(tool)
         self.assertEqual(destination.id, "k8s_environment")

     def test_map_tool_by_regex_mismatch(self):
-        tool = mock_galaxy.Tool('regex_t_test')
+        tool = mock_galaxy.Tool("regex_t_test")
         destination = self._map_to_destination(tool)
         self.assertEqual(destination.id, "local")

     def test_map_tool_with_invalid_regex(self):
-        tool = mock_galaxy.Tool('sometool')
-        config = os.path.join(os.path.dirname(__file__), 'fixtures/mapping-invalid-regex.yml')
+        tool = mock_galaxy.Tool("sometool")
+        config = os.path.join(
+            os.path.dirname(__file__), "fixtures/mapping-invalid-regex.yml"
+        )
         with self.assertRaisesRegex(re.error, "bad escape"):
             self._map_to_destination(tool, tpv_config_path=config)

     def test_map_abstract_tool_should_fail(self):
-        tool = mock_galaxy.Tool('my_abstract_tool')
-        with self.assertRaisesRegex(JobMappingException, "This entity is abstract and cannot be mapped"):
+        tool = mock_galaxy.Tool("my_abstract_tool")
+        with self.assertRaisesRegex(
+            JobMappingException, "This entity is abstract and cannot be mapped"
+        ):
             self._map_to_destination(tool)

     def test_map_concrete_descendant_should_succeed(self):
-        tool = mock_galaxy.Tool('my_concrete_tool')
+        tool = mock_galaxy.Tool("my_concrete_tool")
         destination = self._map_to_destination(tool)
         self.assertEqual(destination.id, "local")
diff --git a/tests/test_mapper_context.py b/tests/test_mapper_context.py
index 905ddca..acefad5 100644
--- a/tests/test_mapper_context.py
+++ b/tests/test_mapper_context.py
@@ -1,56 +1,100 @@
 import os
 import unittest
-from tpv.rules import gateway
+
 from tpv.commands.test import mock_galaxy
+from tpv.rules import gateway


 class TestMapperContext(unittest.TestCase):
     @staticmethod
     def _map_to_destination(tool, user, datasets, tpv_config_path=None):
-        galaxy_app = mock_galaxy.App(job_conf=os.path.join(os.path.dirname(__file__), 'fixtures/job_conf.yml'))
+        galaxy_app = mock_galaxy.App(
+            job_conf=os.path.join(os.path.dirname(__file__), "fixtures/job_conf.yml")
+        )
         job = mock_galaxy.Job()
         for d in datasets:
             job.add_input_dataset(d)
-        tpv_config = tpv_config_path or os.path.join(os.path.dirname(__file__),
-                                                     'fixtures/mapping-context.yml')
+        tpv_config = tpv_config_path or os.path.join(
+            os.path.dirname(__file__), "fixtures/mapping-context.yml"
+        )
         gateway.ACTIVE_DESTINATION_MAPPER = None
-        return gateway.map_tool_to_destination(galaxy_app, job, tool, user, tpv_config_files=[tpv_config])
+        return gateway.map_tool_to_destination(
+            galaxy_app, job, tool, user, tpv_config_files=[tpv_config]
+        )

     def test_map_context_default_overrides_global(self):
-        tool = mock_galaxy.Tool('trinity')
-        user = mock_galaxy.User('gargravarr', 'fairycake@vortex.org')
-        datasets = [mock_galaxy.DatasetAssociation("test", mock_galaxy.Dataset("test.txt", file_size=5*1024**3))]
+        tool = mock_galaxy.Tool("trinity")
+        user = mock_galaxy.User("gargravarr", "fairycake@vortex.org")
+        datasets = [
+            mock_galaxy.DatasetAssociation(
+                "test", mock_galaxy.Dataset("test.txt", file_size=5 * 1024**3)
+            )
+        ]
         destination = self._map_to_destination(tool, user, datasets)
         self.assertEqual(destination.id, "local")
-        self.assertEqual([env['value'] for env in destination.env if env['name'] == 'TEST_JOB_SLOTS'], ['3'])
-        self.assertEqual(destination.params['native_spec'], '--mem 9 --cores 3 --gpus 3')
+        self.assertEqual(
+            [
+                env["value"]
+                for env in destination.env
+                if env["name"] == "TEST_JOB_SLOTS"
+            ],
+            ["3"],
+        )
+        self.assertEqual(
+            destination.params["native_spec"], "--mem 9 --cores 3 --gpus 3"
+        )

     def test_map_tool_overrides_default(self):
-        tool = mock_galaxy.Tool('bwa')
-        user = mock_galaxy.User('gargravarr', 'fairycake@vortex.org')
-        datasets = [mock_galaxy.DatasetAssociation("test", mock_galaxy.Dataset("test.txt", file_size=5*1024**3))]
+        tool = mock_galaxy.Tool("bwa")
+        user = mock_galaxy.User("gargravarr", "fairycake@vortex.org")
+        datasets = [
+            mock_galaxy.DatasetAssociation(
+                "test", mock_galaxy.Dataset("test.txt", file_size=5 * 1024**3)
+            )
+        ]
         destination = self._map_to_destination(tool, user, datasets)
         self.assertEqual(destination.id, "k8s_environment")
-        self.assertEqual([env['value'] for env in destination.env if env['name'] == 'TEST_JOB_SLOTS'], ['5'])
-        self.assertEqual(destination.params['native_spec'], '--mem 15 --cores 5 --gpus 4')
+        self.assertEqual(
+            [
+                env["value"]
+                for env in destination.env
+                if env["name"] == "TEST_JOB_SLOTS"
+            ],
+            ["5"],
+        )
+        self.assertEqual(
+            destination.params["native_spec"], "--mem 15 --cores 5 --gpus 4"
+        )

     def test_context_variable_overridden_in_rule(self):
         # test that job will not fail with 40GB input size because large_input_size has been set to 60
-        tool = mock_galaxy.Tool('bwa')
-        user = mock_galaxy.User('gargravarr', 'fairycake@vortex.org')
-        datasets = [mock_galaxy.DatasetAssociation("test", mock_galaxy.Dataset("test.txt", file_size=40*1024**3))]
+        tool = mock_galaxy.Tool("bwa")
+        user = mock_galaxy.User("gargravarr", "fairycake@vortex.org")
+        datasets = [
+            mock_galaxy.DatasetAssociation(
+                "test", mock_galaxy.Dataset("test.txt", file_size=40 * 1024**3)
+            )
+        ]
         destination = self._map_to_destination(tool, user, datasets)
-        self.assertEqual(destination.params['native_spec'], '--mem 15 --cores 5 --gpus 2')
+        self.assertEqual(
+            destination.params["native_spec"], "--mem 15 --cores 5 --gpus 2"
+        )

     def test_context_variable_defined_for_tool_in_rule(self):
         # test that context variable set for tool entity but not set in ancestor entities is defined
-        tool = mock_galaxy.Tool('canu')
-        user = mock_galaxy.User('gargravarr', 'fairycake@vortex.org')
-        datasets = [mock_galaxy.DatasetAssociation("test", mock_galaxy.Dataset("test.txt", file_size=3*1024**3))]
+        tool = mock_galaxy.Tool("canu")
+        user = mock_galaxy.User("gargravarr", "fairycake@vortex.org")
+        datasets = [
+            mock_galaxy.DatasetAssociation(
+                "test", mock_galaxy.Dataset("test.txt", file_size=3 * 1024**3)
+            )
+        ]
         destination = self._map_to_destination(tool, user, datasets)
-        self.assertEqual(destination.params['native_spec'], '--mem 9 --cores 3 --gpus 1')
+        self.assertEqual(
+            destination.params["native_spec"], "--mem 9 --cores 3 --gpus 1"
+        )
diff --git a/tests/test_mapper_destinations.py b/tests/test_mapper_destinations.py
index 49d6dfe..7a3ff62 100644
--- a/tests/test_mapper_destinations.py
+++ b/tests/test_mapper_destinations.py
@@ -1,224 +1,449 @@
 import os
 import unittest
-from tpv.rules import gateway
-from tpv.commands.test import mock_galaxy
+
 from galaxy.jobs.mapper import JobMappingException
+
+from tpv.commands.test import mock_galaxy
+from tpv.rules import gateway
+

 class TestMapperDestinations(unittest.TestCase):
     @staticmethod
     def _map_to_destination(tool, user, datasets, tpv_config_paths):
-        galaxy_app = mock_galaxy.App(job_conf=os.path.join(os.path.dirname(__file__), 'fixtures/job_conf.yml'))
+        galaxy_app = mock_galaxy.App(
+            job_conf=os.path.join(os.path.dirname(__file__), "fixtures/job_conf.yml")
+        )
         job = mock_galaxy.Job()
         for d in datasets:
             job.add_input_dataset(d)
         gateway.ACTIVE_DESTINATION_MAPPER = None
-        return gateway.map_tool_to_destination(galaxy_app, job, tool, user, tpv_config_files=tpv_config_paths)
+        return gateway.map_tool_to_destination(
+            galaxy_app, job, tool, user, tpv_config_files=tpv_config_paths
+        )

     def test_destination_no_rule_match(self):
-        tool = mock_galaxy.Tool('bwa')
-        user = mock_galaxy.User('ford', 'prefect@vortex.org')
+        tool = mock_galaxy.Tool("bwa")
+        user = mock_galaxy.User("ford", "prefect@vortex.org")

-        config = os.path.join(os.path.dirname(__file__), 'fixtures/mapping-destinations.yml')
+        config = os.path.join(
+            os.path.dirname(__file__), "fixtures/mapping-destinations.yml"
+        )

         # an intermediate file size should compute correct values
-        datasets = [mock_galaxy.DatasetAssociation("test", mock_galaxy.Dataset("test.txt", file_size=7*1024**3))]
-        destination = self._map_to_destination(tool, user, datasets, tpv_config_paths=[config])
+        datasets = [
+            mock_galaxy.DatasetAssociation(
+                "test", mock_galaxy.Dataset("test.txt", file_size=7 * 1024**3)
+            )
+        ]
+        destination = self._map_to_destination(
+            tool, user, datasets, tpv_config_paths=[config]
+        )
         self.assertEqual(destination.id, "k8s_environment")
-        self.assertEqual([env['value'] for env in destination.env if env['name'] == 'TEST_JOB_SLOTS'], ['2'])
-        self.assertEqual([env['value'] for env in destination.env if env['name'] == 'SPECIAL_FLAG'], ['first'])
-        self.assertEqual([env['value'] for env in destination.env if env['name'] == 'DOCKER_ENABLED'], ['true'])
-        self.assertEqual(destination.params['memory_requests'], '6')
+        self.assertEqual(
+            [
+                env["value"]
+                for env in destination.env
+                if env["name"] == "TEST_JOB_SLOTS"
+            ],
+            ["2"],
+        )
+        self.assertEqual(
+            [env["value"] for env in destination.env if env["name"] == "SPECIAL_FLAG"],
+            ["first"],
+        )
+        self.assertEqual(
+            [
+                env["value"]
+                for env in destination.env
+                if env["name"] ==
"DOCKER_ENABLED" + ], + ["true"], + ) + self.assertEqual(destination.params["memory_requests"], "6") def test_destination_rule_match_once(self): - tool = mock_galaxy.Tool('bwa') - user = mock_galaxy.User('ford', 'prefect@vortex.org') + tool = mock_galaxy.Tool("bwa") + user = mock_galaxy.User("ford", "prefect@vortex.org") - config = os.path.join(os.path.dirname(__file__), 'fixtures/mapping-destinations.yml') + config = os.path.join( + os.path.dirname(__file__), "fixtures/mapping-destinations.yml" + ) # an intermediate file size should compute correct values - datasets = [mock_galaxy.DatasetAssociation("test", mock_galaxy.Dataset("test.txt", file_size=12*1024**3))] - destination = self._map_to_destination(tool, user, datasets, tpv_config_paths=[config]) + datasets = [ + mock_galaxy.DatasetAssociation( + "test", mock_galaxy.Dataset("test.txt", file_size=12 * 1024**3) + ) + ] + destination = self._map_to_destination( + tool, user, datasets, tpv_config_paths=[config] + ) self.assertEqual(destination.id, "another_k8s_environment") - self.assertEqual([env['value'] for env in destination.env if env['name'] == 'TEST_JOB_SLOTS'], ['2']) - self.assertEqual([env['value'] for env in destination.env if env['name'] == 'TEST_ENTITY_PRIORITY'], ['4']) - self.assertEqual([env['value'] for env in destination.env if env['name'] == 'SPECIAL_FLAG'], ['second']) - self.assertEqual([env['value'] for env in destination.env if env['name'] == 'DOCKER_ENABLED'], []) - self.assertEqual(destination.params['memory_requests'], '12') + self.assertEqual( + [ + env["value"] + for env in destination.env + if env["name"] == "TEST_JOB_SLOTS" + ], + ["2"], + ) + self.assertEqual( + [ + env["value"] + for env in destination.env + if env["name"] == "TEST_ENTITY_PRIORITY" + ], + ["4"], + ) + self.assertEqual( + [env["value"] for env in destination.env if env["name"] == "SPECIAL_FLAG"], + ["second"], + ) + self.assertEqual( + [ + env["value"] + for env in destination.env + if env["name"] == "DOCKER_ENABLED" + ], + [], + ) + self.assertEqual(destination.params["memory_requests"], "12") def test_destination_rule_match_twice(self): - tool = mock_galaxy.Tool('bwa') - user = mock_galaxy.User('ford', 'prefect@vortex.org') + tool = mock_galaxy.Tool("bwa") + user = mock_galaxy.User("ford", "prefect@vortex.org") - config = os.path.join(os.path.dirname(__file__), 'fixtures/mapping-destinations.yml') + config = os.path.join( + os.path.dirname(__file__), "fixtures/mapping-destinations.yml" + ) # an intermediate file size should compute correct values - datasets = [mock_galaxy.DatasetAssociation("test", mock_galaxy.Dataset("test.txt", file_size=22*1024**3))] - with self.assertRaisesRegex(JobMappingException, "No destinations are available to fulfill request: bwa"): + datasets = [ + mock_galaxy.DatasetAssociation( + "test", mock_galaxy.Dataset("test.txt", file_size=22 * 1024**3) + ) + ] + with self.assertRaisesRegex( + JobMappingException, "No destinations are available to fulfill request: bwa" + ): self._map_to_destination(tool, user, datasets, tpv_config_paths=[config]) def test_destination_inheritance(self): - tool = mock_galaxy.Tool('inheritance_test_tool') - user = mock_galaxy.User('ford', 'prefect@vortex.org') + tool = mock_galaxy.Tool("inheritance_test_tool") + user = mock_galaxy.User("ford", "prefect@vortex.org") - config = os.path.join(os.path.dirname(__file__), 'fixtures/mapping-destinations.yml') + config = os.path.join( + os.path.dirname(__file__), "fixtures/mapping-destinations.yml" + ) # an intermediate file size should compute 
correct values - datasets = [mock_galaxy.DatasetAssociation("test", mock_galaxy.Dataset("test.txt", file_size=12*1024**3))] - destination = self._map_to_destination(tool, user, datasets, tpv_config_paths=[config]) + datasets = [ + mock_galaxy.DatasetAssociation( + "test", mock_galaxy.Dataset("test.txt", file_size=12 * 1024**3) + ) + ] + destination = self._map_to_destination( + tool, user, datasets, tpv_config_paths=[config] + ) self.assertEqual(destination.id, "inherited_k8s_environment") - self.assertEqual([env['value'] for env in destination.env if env['name'] == 'TEST_JOB_SLOTS'], ['2']) - self.assertEqual([env['value'] for env in destination.env if env['name'] == 'TEST_ENTITY_PRIORITY'], ['4']) - self.assertEqual([env['value'] for env in destination.env if env['name'] == 'TEST_ENTITY_GPUS'], ['0']) - self.assertEqual([env['value'] for env in destination.env if env['name'] == 'SPECIAL_FLAG'], ['third']) - self.assertEqual([env['value'] for env in destination.env if env['name'] == 'DOCKER_ENABLED'], []) - self.assertEqual(destination.params['memory_requests'], '18') + self.assertEqual( + [ + env["value"] + for env in destination.env + if env["name"] == "TEST_JOB_SLOTS" + ], + ["2"], + ) + self.assertEqual( + [ + env["value"] + for env in destination.env + if env["name"] == "TEST_ENTITY_PRIORITY" + ], + ["4"], + ) + self.assertEqual( + [ + env["value"] + for env in destination.env + if env["name"] == "TEST_ENTITY_GPUS" + ], + ["0"], + ) + self.assertEqual( + [env["value"] for env in destination.env if env["name"] == "SPECIAL_FLAG"], + ["third"], + ) + self.assertEqual( + [ + env["value"] + for env in destination.env + if env["name"] == "DOCKER_ENABLED" + ], + [], + ) + self.assertEqual(destination.params["memory_requests"], "18") def test_destination_can_raise_not_ready_exception(self): - tool = mock_galaxy.Tool('three_core_test_tool') - user = mock_galaxy.User('tricia', 'tmcmillan@vortex.org') - - config = os.path.join(os.path.dirname(__file__), 'fixtures/mapping-destinations.yml') - - datasets = [mock_galaxy.DatasetAssociation("test", mock_galaxy.Dataset("test.txt", file_size=12*1024**3))] + tool = mock_galaxy.Tool("three_core_test_tool") + user = mock_galaxy.User("tricia", "tmcmillan@vortex.org") + + config = os.path.join( + os.path.dirname(__file__), "fixtures/mapping-destinations.yml" + ) + + datasets = [ + mock_galaxy.DatasetAssociation( + "test", mock_galaxy.Dataset("test.txt", file_size=12 * 1024**3) + ) + ] from galaxy.jobs.mapper import JobNotReadyException + with self.assertRaises(JobNotReadyException): - destination = self._map_to_destination(tool, user, datasets, tpv_config_paths=[config]) + destination = self._map_to_destination( + tool, user, datasets, tpv_config_paths=[config] + ) print(destination) def test_custom_destination_naming(self): - tool = mock_galaxy.Tool('custom_tool') - user = mock_galaxy.User('ford', 'prefect@vortex.org') + tool = mock_galaxy.Tool("custom_tool") + user = mock_galaxy.User("ford", "prefect@vortex.org") - config = os.path.join(os.path.dirname(__file__), 'fixtures/mapping-destinations.yml') + config = os.path.join( + os.path.dirname(__file__), "fixtures/mapping-destinations.yml" + ) # an intermediate file size should compute correct values - datasets = [mock_galaxy.DatasetAssociation("test", mock_galaxy.Dataset("test.txt", file_size=12*1024**3))] - destination = self._map_to_destination(tool, user, datasets, tpv_config_paths=[config]) + datasets = [ + mock_galaxy.DatasetAssociation( + "test", mock_galaxy.Dataset("test.txt", file_size=12 * 
1024**3) + ) + ] + destination = self._map_to_destination( + tool, user, datasets, tpv_config_paths=[config] + ) self.assertEqual(destination.id, "my-dest-with-2-cores-6-mem") def test_destination_with_handler_tags(self): - tool = mock_galaxy.Tool('tool_with_handler_tags') - user = mock_galaxy.User('ford', 'prefect@vortex.org') + tool = mock_galaxy.Tool("tool_with_handler_tags") + user = mock_galaxy.User("ford", "prefect@vortex.org") - config = os.path.join(os.path.dirname(__file__), 'fixtures/mapping-destinations.yml') + config = os.path.join( + os.path.dirname(__file__), "fixtures/mapping-destinations.yml" + ) # an intermediate file size should compute correct values - datasets = [mock_galaxy.DatasetAssociation("test", mock_galaxy.Dataset("test.txt", file_size=12*1024**3))] - destination = self._map_to_destination(tool, user, datasets, tpv_config_paths=[config]) + datasets = [ + mock_galaxy.DatasetAssociation( + "test", mock_galaxy.Dataset("test.txt", file_size=12 * 1024**3) + ) + ] + destination = self._map_to_destination( + tool, user, datasets, tpv_config_paths=[config] + ) self.assertEqual(destination.id, "destination_with_handler_tags") self.assertEqual(destination.tags, ["registered_user_concurrent_jobs_20"]) def test_abstract_destinations_are_not_schedulable(self): - tool = mock_galaxy.Tool('tool_matching_abstract_dest') - user = mock_galaxy.User('ford', 'prefect@vortex.org') + tool = mock_galaxy.Tool("tool_matching_abstract_dest") + user = mock_galaxy.User("ford", "prefect@vortex.org") - config = os.path.join(os.path.dirname(__file__), 'fixtures/mapping-destinations.yml') + config = os.path.join( + os.path.dirname(__file__), "fixtures/mapping-destinations.yml" + ) # an intermediate file size should compute correct values - datasets = [mock_galaxy.DatasetAssociation("test", mock_galaxy.Dataset("test.txt", file_size=12*1024**3))] + datasets = [ + mock_galaxy.DatasetAssociation( + "test", mock_galaxy.Dataset("test.txt", file_size=12 * 1024**3) + ) + ] with self.assertRaises(JobMappingException): self._map_to_destination(tool, user, datasets, tpv_config_paths=[config]) def test_concrete_destinations_are_schedulable(self): - tool = mock_galaxy.Tool('tool_matching_abstract_inherited_dest') - user = mock_galaxy.User('ford', 'prefect@vortex.org') + tool = mock_galaxy.Tool("tool_matching_abstract_inherited_dest") + user = mock_galaxy.User("ford", "prefect@vortex.org") - config = os.path.join(os.path.dirname(__file__), 'fixtures/mapping-destinations.yml') + config = os.path.join( + os.path.dirname(__file__), "fixtures/mapping-destinations.yml" + ) # an intermediate file size should compute correct values - datasets = [mock_galaxy.DatasetAssociation("test", mock_galaxy.Dataset("test.txt", file_size=12 * 1024 ** 3))] - destination = self._map_to_destination(tool, user, datasets, tpv_config_paths=[config]) + datasets = [ + mock_galaxy.DatasetAssociation( + "test", mock_galaxy.Dataset("test.txt", file_size=12 * 1024**3) + ) + ] + destination = self._map_to_destination( + tool, user, datasets, tpv_config_paths=[config] + ) self.assertEqual(destination.id, "my_concrete_destination") def test_accepted_cpus_honoured(self): - user = mock_galaxy.User('toolmatchuser', 'toolmatchuser@vortex.org') - config = os.path.join(os.path.dirname(__file__), 'fixtures/mapping-destinations.yml') - datasets = [mock_galaxy.DatasetAssociation("test", mock_galaxy.Dataset("test.txt", file_size=12 * 1024 ** 3))] + user = mock_galaxy.User("toolmatchuser", "toolmatchuser@vortex.org") + config = os.path.join( + 
os.path.dirname(__file__), "fixtures/mapping-destinations.yml" + ) + datasets = [ + mock_galaxy.DatasetAssociation( + "test", mock_galaxy.Dataset("test.txt", file_size=12 * 1024**3) + ) + ] # First tool should not match - tool = mock_galaxy.Tool('tool_for_testing_cpu_acceptance_non_match') - destination = self._map_to_destination(tool, user, datasets, tpv_config_paths=[config]) + tool = mock_galaxy.Tool("tool_for_testing_cpu_acceptance_non_match") + destination = self._map_to_destination( + tool, user, datasets, tpv_config_paths=[config] + ) self.assertEqual(destination.id, "destination_without_min_cpu_accepted") # second tool should match the min requirements for the destination - tool = mock_galaxy.Tool('tool_for_testing_cpu_acceptance_match') - destination = self._map_to_destination(tool, user, datasets, tpv_config_paths=[config]) + tool = mock_galaxy.Tool("tool_for_testing_cpu_acceptance_match") + destination = self._map_to_destination( + tool, user, datasets, tpv_config_paths=[config] + ) self.assertEqual(destination.id, "destination_with_min_cpu_accepted") # third tool requesting zero cpus should match either: # - destination_without_min_cpu_accepted # - destination_zero_max_cpu_accepted - tool = mock_galaxy.Tool('tool_for_testing_cpu_acceptance_zero') - destination = self._map_to_destination(tool, user, datasets, tpv_config_paths=[config]) - self.assertIn(destination.id, {"destination_without_min_cpu_accepted", "destination_zero_max_cpu_accepted"}) + tool = mock_galaxy.Tool("tool_for_testing_cpu_acceptance_zero") + destination = self._map_to_destination( + tool, user, datasets, tpv_config_paths=[config] + ) + self.assertIn( + destination.id, + { + "destination_without_min_cpu_accepted", + "destination_zero_max_cpu_accepted", + }, + ) def test_accepted_gpus_honoured(self): - user = mock_galaxy.User('toolmatchuser', 'toolmatchuser@vortex.org') - config = os.path.join(os.path.dirname(__file__), 'fixtures/mapping-destinations.yml') - datasets = [mock_galaxy.DatasetAssociation("test", mock_galaxy.Dataset("test.txt", file_size=12 * 1024 ** 3))] + user = mock_galaxy.User("toolmatchuser", "toolmatchuser@vortex.org") + config = os.path.join( + os.path.dirname(__file__), "fixtures/mapping-destinations.yml" + ) + datasets = [ + mock_galaxy.DatasetAssociation( + "test", mock_galaxy.Dataset("test.txt", file_size=12 * 1024**3) + ) + ] # First tool should not match - tool = mock_galaxy.Tool('tool_for_testing_gpu_acceptance_non_match') - destination = self._map_to_destination(tool, user, datasets, tpv_config_paths=[config]) + tool = mock_galaxy.Tool("tool_for_testing_gpu_acceptance_non_match") + destination = self._map_to_destination( + tool, user, datasets, tpv_config_paths=[config] + ) self.assertEqual(destination.id, "destination_without_min_gpu_accepted") # second tool should match the min requirements for the destination - tool = mock_galaxy.Tool('tool_for_testing_gpu_acceptance_match') - destination = self._map_to_destination(tool, user, datasets, tpv_config_paths=[config]) + tool = mock_galaxy.Tool("tool_for_testing_gpu_acceptance_match") + destination = self._map_to_destination( + tool, user, datasets, tpv_config_paths=[config] + ) self.assertEqual(destination.id, "destination_with_min_gpu_accepted") # third tool requesting zero gpus should match either: # - destination_without_min_gpu_accepted # - destination_zero_max_gpu_accepted - tool = mock_galaxy.Tool('tool_for_testing_gpu_acceptance_zero') - destination = self._map_to_destination(tool, user, datasets, tpv_config_paths=[config]) - 
self.assertIn(destination.id, {"destination_without_min_gpu_accepted", "destination_zero_max_gpu_accepted"}) + tool = mock_galaxy.Tool("tool_for_testing_gpu_acceptance_zero") + destination = self._map_to_destination( + tool, user, datasets, tpv_config_paths=[config] + ) + self.assertIn( + destination.id, + { + "destination_without_min_gpu_accepted", + "destination_zero_max_gpu_accepted", + }, + ) def test_accepted_mem_honoured(self): - user = mock_galaxy.User('toolmatchuser', 'toolmatchuser@vortex.org') - config = os.path.join(os.path.dirname(__file__), 'fixtures/mapping-destinations.yml') - datasets = [mock_galaxy.DatasetAssociation("test", mock_galaxy.Dataset("test.txt", file_size=12 * 1024 ** 3))] + user = mock_galaxy.User("toolmatchuser", "toolmatchuser@vortex.org") + config = os.path.join( + os.path.dirname(__file__), "fixtures/mapping-destinations.yml" + ) + datasets = [ + mock_galaxy.DatasetAssociation( + "test", mock_galaxy.Dataset("test.txt", file_size=12 * 1024**3) + ) + ] # First tool should not match - tool = mock_galaxy.Tool('tool_for_testing_mem_acceptance_non_match') - destination = self._map_to_destination(tool, user, datasets, tpv_config_paths=[config]) + tool = mock_galaxy.Tool("tool_for_testing_mem_acceptance_non_match") + destination = self._map_to_destination( + tool, user, datasets, tpv_config_paths=[config] + ) self.assertEqual(destination.id, "destination_without_min_mem_accepted") # second tool should match the min requirements for the destination - tool = mock_galaxy.Tool('tool_for_testing_mem_acceptance_match') - destination = self._map_to_destination(tool, user, datasets, tpv_config_paths=[config]) + tool = mock_galaxy.Tool("tool_for_testing_mem_acceptance_match") + destination = self._map_to_destination( + tool, user, datasets, tpv_config_paths=[config] + ) self.assertEqual(destination.id, "destination_with_min_mem_accepted") # third tool requesting zero mem should match either: # - destination_without_min_mem_accepted # - destination_zero_max_mem_accepted - tool = mock_galaxy.Tool('tool_for_testing_mem_acceptance_zero') - destination = self._map_to_destination(tool, user, datasets, tpv_config_paths=[config]) - self.assertIn(destination.id, {"destination_without_min_mem_accepted", "destination_zero_max_mem_accepted"}) + tool = mock_galaxy.Tool("tool_for_testing_mem_acceptance_zero") + destination = self._map_to_destination( + tool, user, datasets, tpv_config_paths=[config] + ) + self.assertIn( + destination.id, + { + "destination_without_min_mem_accepted", + "destination_zero_max_mem_accepted", + }, + ) def test_user_map_to_destination_accepting_offline(self): - user = mock_galaxy.User('albo', 'pulsar_canberra_user@act.au') - - config = os.path.join(os.path.dirname(__file__), 'fixtures/mapping-destinations.yml') - - tool = mock_galaxy.Tool('toolshed_hifiasm') - datasets = [mock_galaxy.DatasetAssociation("test", mock_galaxy.Dataset("test.txt", file_size=12 * 1024 ** 3))] - destination = self._map_to_destination(tool, user, datasets, tpv_config_paths=[config]) + user = mock_galaxy.User("albo", "pulsar_canberra_user@act.au") + + config = os.path.join( + os.path.dirname(__file__), "fixtures/mapping-destinations.yml" + ) + + tool = mock_galaxy.Tool("toolshed_hifiasm") + datasets = [ + mock_galaxy.DatasetAssociation( + "test", mock_galaxy.Dataset("test.txt", file_size=12 * 1024**3) + ) + ] + destination = self._map_to_destination( + tool, user, datasets, tpv_config_paths=[config] + ) self.assertEqual(destination.id, "pulsar-canberra") def 
test_destination_clamping(self): """_summary_ Any variables defined in a tool should be evaluated as late as possible, so that destination level - clamping works. + clamping works. """ - user = mock_galaxy.User('albo', 'pulsar_canberra_user@act.au') - - config = os.path.join(os.path.dirname(__file__), 'fixtures/mapping-destinations.yml') - - tool = mock_galaxy.Tool('tool_for_testing_resource_clamping') - datasets = [mock_galaxy.DatasetAssociation("test", mock_galaxy.Dataset("test.txt", file_size=12 * 1024 ** 3))] - destination = self._map_to_destination(tool, user, datasets, tpv_config_paths=[config]) + user = mock_galaxy.User("albo", "pulsar_canberra_user@act.au") + + config = os.path.join( + os.path.dirname(__file__), "fixtures/mapping-destinations.yml" + ) + + tool = mock_galaxy.Tool("tool_for_testing_resource_clamping") + datasets = [ + mock_galaxy.DatasetAssociation( + "test", mock_galaxy.Dataset("test.txt", file_size=12 * 1024**3) + ) + ] + destination = self._map_to_destination( + tool, user, datasets, tpv_config_paths=[config] + ) self.assertEqual(destination.id, "clamped_destination") - self.assertEqual([env['value'] for env in destination.env if env['name'] == 'MY_DEST_ENV'], ['cores: 8 mem: 24 gpus: 1']) - self.assertEqual([env['value'] for env in destination.env if env['name'] == 'MY_TOOL_ENV'], ['cores: 8 mem: 24 gpus: 1']) + self.assertEqual( + [env["value"] for env in destination.env if env["name"] == "MY_DEST_ENV"], + ["cores: 8 mem: 24 gpus: 1"], + ) + self.assertEqual( + [env["value"] for env in destination.env if env["name"] == "MY_TOOL_ENV"], + ["cores: 8 mem: 24 gpus: 1"], + ) diff --git a/tests/test_mapper_inheritance.py b/tests/test_mapper_inheritance.py index 75e8f5c..a7ff24f 100644 --- a/tests/test_mapper_inheritance.py +++ b/tests/test_mapper_inheritance.py @@ -1,116 +1,213 @@ import os import unittest -from tpv.rules import gateway + from tpv.commands.test import mock_galaxy from tpv.core.loader import InvalidParentException +from tpv.rules import gateway class TestMapperInheritance(unittest.TestCase): @staticmethod - def _map_to_destination(tool, user, datasets, tpv_config_path=None, tpv_config_files = []): - galaxy_app = mock_galaxy.App(job_conf=os.path.join(os.path.dirname(__file__), 'fixtures/job_conf.yml')) + def _map_to_destination( + tool, user, datasets, tpv_config_path=None, tpv_config_files=[] + ): + galaxy_app = mock_galaxy.App( + job_conf=os.path.join(os.path.dirname(__file__), "fixtures/job_conf.yml") + ) job = mock_galaxy.Job() for d in datasets: job.add_input_dataset(d) - tpv_config = tpv_config_path or os.path.join(os.path.dirname(__file__), - 'fixtures/mapping-inheritance.yml') + tpv_config = tpv_config_path or os.path.join( + os.path.dirname(__file__), "fixtures/mapping-inheritance.yml" + ) if not tpv_config_files: tpv_config_files = [tpv_config] gateway.ACTIVE_DESTINATION_MAPPER = None - return gateway.map_tool_to_destination(galaxy_app, job, tool, user, tpv_config_files=tpv_config_files) + return gateway.map_tool_to_destination( + galaxy_app, job, tool, user, tpv_config_files=tpv_config_files + ) def test_map_inherit_twice(self): - tool = mock_galaxy.Tool('trinity') - user = mock_galaxy.User('gargravarr', 'fairycake@vortex.org') - datasets = [mock_galaxy.DatasetAssociation("test", mock_galaxy.Dataset("test.txt", file_size=5*1024**3))] + tool = mock_galaxy.Tool("trinity") + user = mock_galaxy.User("gargravarr", "fairycake@vortex.org") + datasets = [ + mock_galaxy.DatasetAssociation( + "test", mock_galaxy.Dataset("test.txt", file_size=5 * 
1024**3) + ) + ] destination = self._map_to_destination(tool, user, datasets) self.assertEqual(destination.id, "k8s_environment") - self.assertEqual([env['value'] for env in destination.env if env['name'] == 'TEST_JOB_SLOTS'], ['4']) - self.assertEqual(destination.params['native_spec'], '--mem 16 --cores 4 --gpus 3') + self.assertEqual( + [ + env["value"] + for env in destination.env + if env["name"] == "TEST_JOB_SLOTS" + ], + ["4"], + ) + self.assertEqual( + destination.params["native_spec"], "--mem 16 --cores 4 --gpus 3" + ) def test_map_inherit_thrice(self): - tool = mock_galaxy.Tool('hisat') - user = mock_galaxy.User('gargravarr', 'fairycake@vortex.org') - datasets = [mock_galaxy.DatasetAssociation("test", mock_galaxy.Dataset("test.txt", file_size=5*1024**3))] + tool = mock_galaxy.Tool("hisat") + user = mock_galaxy.User("gargravarr", "fairycake@vortex.org") + datasets = [ + mock_galaxy.DatasetAssociation( + "test", mock_galaxy.Dataset("test.txt", file_size=5 * 1024**3) + ) + ] destination = self._map_to_destination(tool, user, datasets) self.assertEqual(destination.id, "local") - self.assertEqual([env['value'] for env in destination.env if env['name'] == 'TEST_JOB_SLOTS'], ['4']) - self.assertEqual(destination.params['native_spec'], '--mem 16 --cores 4 --gpus 4') + self.assertEqual( + [ + env["value"] + for env in destination.env + if env["name"] == "TEST_JOB_SLOTS" + ], + ["4"], + ) + self.assertEqual( + destination.params["native_spec"], "--mem 16 --cores 4 --gpus 4" + ) def test_map_inherit_invalid(self): - tool = mock_galaxy.Tool('tophat') - user = mock_galaxy.User('gargravarr', 'fairycake@vortex.org') - datasets = [mock_galaxy.DatasetAssociation("test", mock_galaxy.Dataset("test.txt", file_size=5*1024**3))] - tpv_config_path = os.path.join(os.path.dirname(__file__), 'fixtures/mapping-inheritance-invalid.yml') + tool = mock_galaxy.Tool("tophat") + user = mock_galaxy.User("gargravarr", "fairycake@vortex.org") + datasets = [ + mock_galaxy.DatasetAssociation( + "test", mock_galaxy.Dataset("test.txt", file_size=5 * 1024**3) + ) + ] + tpv_config_path = os.path.join( + os.path.dirname(__file__), "fixtures/mapping-inheritance-invalid.yml" + ) with self.assertRaises(InvalidParentException): - self._map_to_destination(tool, user, datasets, tpv_config_path=tpv_config_path) + self._map_to_destination( + tool, user, datasets, tpv_config_path=tpv_config_path + ) def test_map_inherit_no_default(self): - tool = mock_galaxy.Tool('hisat') - user = mock_galaxy.User('gargravarr', 'fairycake@vortex.org') - datasets = [mock_galaxy.DatasetAssociation("test", mock_galaxy.Dataset("test.txt", file_size=5*1024**3))] - tpv_config_path = os.path.join(os.path.dirname(__file__), 'fixtures/mapping-inheritance-no-default.yml') + tool = mock_galaxy.Tool("hisat") + user = mock_galaxy.User("gargravarr", "fairycake@vortex.org") + datasets = [ + mock_galaxy.DatasetAssociation( + "test", mock_galaxy.Dataset("test.txt", file_size=5 * 1024**3) + ) + ] + tpv_config_path = os.path.join( + os.path.dirname(__file__), "fixtures/mapping-inheritance-no-default.yml" + ) - destination = self._map_to_destination(tool, user, datasets, tpv_config_path=tpv_config_path) + destination = self._map_to_destination( + tool, user, datasets, tpv_config_path=tpv_config_path + ) self.assertEqual(destination.id, "local") - self.assertFalse([env['value'] for env in destination.env if env['name'] == 'TEST_JOB_SLOTS']) - self.assertEqual(destination.params['another_spec'], '--gpus 4') + self.assertFalse( + [env["value"] for env in destination.env 
if env["name"] == "TEST_JOB_SLOTS"] + ) + self.assertEqual(destination.params["another_spec"], "--gpus 4") def test_map_inherit_no_default_no_tool_def(self): - tool = mock_galaxy.Tool('some_random_tool') - user = mock_galaxy.User('gargravarr', 'fairycake@vortex.org') - tpv_config_path = os.path.join(os.path.dirname(__file__), 'fixtures/mapping-inheritance-no-default.yml') - - destination = self._map_to_destination(tool, user, datasets=[], tpv_config_path=tpv_config_path) + tool = mock_galaxy.Tool("some_random_tool") + user = mock_galaxy.User("gargravarr", "fairycake@vortex.org") + tpv_config_path = os.path.join( + os.path.dirname(__file__), "fixtures/mapping-inheritance-no-default.yml" + ) + + destination = self._map_to_destination( + tool, user, datasets=[], tpv_config_path=tpv_config_path + ) self.assertEqual(destination.id, "local") # ref: https://github.com/galaxyproject/total-perspective-vortex/pull/68 def test_map_with_default_tool_in_same_file(self): - tool_id = 'toolshed.g2.bx.psu.edu/repos/bgruening/bionano_scaffold/bionano_scaffold/1.23a' - user = mock_galaxy.User('majikthise', 'majikthise@vortex.org') + tool_id = "toolshed.g2.bx.psu.edu/repos/bgruening/bionano_scaffold/bionano_scaffold/1.23a" + user = mock_galaxy.User("majikthise", "majikthise@vortex.org") tool = mock_galaxy.Tool(tool_id) tpv_config_files = [ - os.path.join(os.path.dirname(__file__), 'fixtures/default_tool/scenario-shared-rules.yml'), - os.path.join(os.path.dirname(__file__), 'fixtures/default_tool/scenario-local-config-default-tool.yml'), - os.path.join(os.path.dirname(__file__), 'fixtures/default_tool/scenario-local-config-tools.yml'), - os.path.join(os.path.dirname(__file__), 'fixtures/default_tool/scenario-local-config-destinations.yml'), + os.path.join( + os.path.dirname(__file__), + "fixtures/default_tool/scenario-shared-rules.yml", + ), + os.path.join( + os.path.dirname(__file__), + "fixtures/default_tool/scenario-local-config-default-tool.yml", + ), + os.path.join( + os.path.dirname(__file__), + "fixtures/default_tool/scenario-local-config-tools.yml", + ), + os.path.join( + os.path.dirname(__file__), + "fixtures/default_tool/scenario-local-config-destinations.yml", + ), ] - destination = self._map_to_destination(tool, user, datasets=[], tpv_config_files=tpv_config_files) - self.assertEqual('--cores=8 --mem=250', destination.params.get('submit_native_specification')) + destination = self._map_to_destination( + tool, user, datasets=[], tpv_config_files=tpv_config_files + ) + self.assertEqual( + "--cores=8 --mem=250", destination.params.get("submit_native_specification") + ) # ref: https://github.com/galaxyproject/total-perspective-vortex/pull/68 def test_map_with_default_rules_in_dedicated_file(self): - tool_id = 'toolshed.g2.bx.psu.edu/repos/bgruening/bionano_scaffold/bionano_scaffold/1.23a' - user = mock_galaxy.User('majikthise', 'majikthise@vortex.org') + tool_id = "toolshed.g2.bx.psu.edu/repos/bgruening/bionano_scaffold/bionano_scaffold/1.23a" + user = mock_galaxy.User("majikthise", "majikthise@vortex.org") tool = mock_galaxy.Tool(tool_id) tpv_config_files = [ - os.path.join(os.path.dirname(__file__), 'fixtures/default_tool/scenario-shared-rules.yml'), - os.path.join(os.path.dirname(__file__), 'fixtures/default_tool/scenario-local-config-tools-with-default.yml'), - os.path.join(os.path.dirname(__file__), 'fixtures/default_tool/scenario-local-config-destinations.yml'), + os.path.join( + os.path.dirname(__file__), + "fixtures/default_tool/scenario-shared-rules.yml", + ), + os.path.join( + 
os.path.dirname(__file__), + "fixtures/default_tool/scenario-local-config-tools-with-default.yml", + ), + os.path.join( + os.path.dirname(__file__), + "fixtures/default_tool/scenario-local-config-destinations.yml", + ), ] - destination = self._map_to_destination(tool, user, datasets=[], tpv_config_files=tpv_config_files) - self.assertEqual('--cores=8 --mem=250', destination.params.get('submit_native_specification')) + destination = self._map_to_destination( + tool, user, datasets=[], tpv_config_files=tpv_config_files + ) + self.assertEqual( + "--cores=8 --mem=250", destination.params.get("submit_native_specification") + ) def test_destination_inherits_runner_from_default(self): - tool_id = 'kraken2' - user = mock_galaxy.User('benjy', 'benjymouse@vortex.org') + tool_id = "kraken2" + user = mock_galaxy.User("benjy", "benjymouse@vortex.org") tool = mock_galaxy.Tool(tool_id) - datasets = [mock_galaxy.DatasetAssociation("test", mock_galaxy.Dataset("test.txt", file_size=5*1024**3))] + datasets = [ + mock_galaxy.DatasetAssociation( + "test", mock_galaxy.Dataset("test.txt", file_size=5 * 1024**3) + ) + ] destination = self._map_to_destination(tool, user, datasets) - self.assertEqual('destination_that_inherits_runner_from_default', destination.id) - self.assertEqual('local', destination.runner) + self.assertEqual( + "destination_that_inherits_runner_from_default", destination.id + ) + self.assertEqual("local", destination.runner) def test_general_destination_inheritance(self): - tool_id = 'kraken5' - user = mock_galaxy.User('frankie', 'frankiemouse@vortex.org') + tool_id = "kraken5" + user = mock_galaxy.User("frankie", "frankiemouse@vortex.org") tool = mock_galaxy.Tool(tool_id) - datasets = [mock_galaxy.DatasetAssociation("test", mock_galaxy.Dataset("test.txt", file_size=5*1024**3))] + datasets = [ + mock_galaxy.DatasetAssociation( + "test", mock_galaxy.Dataset("test.txt", file_size=5 * 1024**3) + ) + ] destination = self._map_to_destination(tool, user, datasets) - self.assertEqual('destination_that_inherits_everything_from_k8s', destination.id) - self.assertTrue('ABC' in [e.get('name') for e in destination.env]) - self.assertEqual('extra-args', destination.params.get('docker_extra')) - self.assertEqual('k8s', destination.runner) + self.assertEqual( + "destination_that_inherits_everything_from_k8s", destination.id + ) + self.assertTrue("ABC" in [e.get("name") for e in destination.env]) + self.assertEqual("extra-args", destination.params.get("docker_extra")) + self.assertEqual("k8s", destination.runner) diff --git a/tests/test_mapper_merge_multiple.py b/tests/test_mapper_merge_multiple.py index 4fa1c10..90a0ad5 100644 --- a/tests/test_mapper_merge_multiple.py +++ b/tests/test_mapper_merge_multiple.py @@ -1,123 +1,263 @@ import os import unittest -from tpv.rules import gateway -from tpv.commands.test import mock_galaxy + from galaxy.jobs.mapper import JobMappingException +from tpv.commands.test import mock_galaxy +from tpv.rules import gateway + class TestMapperMergeMultipleConfigs(unittest.TestCase): @staticmethod def _map_to_destination(tool, user, datasets, tpv_config_paths): - galaxy_app = mock_galaxy.App(job_conf=os.path.join(os.path.dirname(__file__), 'fixtures/job_conf.yml')) + galaxy_app = mock_galaxy.App( + job_conf=os.path.join(os.path.dirname(__file__), "fixtures/job_conf.yml") + ) job = mock_galaxy.Job() for d in datasets: job.add_input_dataset(d) gateway.ACTIVE_DESTINATION_MAPPER = None - return gateway.map_tool_to_destination(galaxy_app, job, tool, user, 
tpv_config_files=tpv_config_paths) + return gateway.map_tool_to_destination( + galaxy_app, job, tool, user, tpv_config_files=tpv_config_paths + ) def test_merge_remote_and_local(self): - tool = mock_galaxy.Tool('bwa') - user = mock_galaxy.User('ford', 'prefect@vortex.org') + tool = mock_galaxy.Tool("bwa") + user = mock_galaxy.User("ford", "prefect@vortex.org") - config_first = "https://github.com/galaxyproject/total-perspective-vortex/raw/main/" \ - "tests/fixtures/mapping-merge-multiple-remote.yml" - config_second = os.path.join(os.path.dirname(__file__), 'fixtures/mapping-merge-multiple-local.yml') + config_first = ( + "https://github.com/galaxyproject/total-perspective-vortex/raw/main/" + "tests/fixtures/mapping-merge-multiple-remote.yml" + ) + config_second = os.path.join( + os.path.dirname(__file__), "fixtures/mapping-merge-multiple-local.yml" + ) # a small file size should fail because of remote rule - datasets = [mock_galaxy.DatasetAssociation("test", mock_galaxy.Dataset("test.txt", file_size=1*1024**3))] - with self.assertRaisesRegex(JobMappingException, "We don't run piddling datasets"): - self._map_to_destination(tool, user, datasets, tpv_config_paths=[config_first, config_second]) + datasets = [ + mock_galaxy.DatasetAssociation( + "test", mock_galaxy.Dataset("test.txt", file_size=1 * 1024**3) + ) + ] + with self.assertRaisesRegex( + JobMappingException, "We don't run piddling datasets" + ): + self._map_to_destination( + tool, user, datasets, tpv_config_paths=[config_first, config_second] + ) # a large file size should fail because of local rule - datasets = [mock_galaxy.DatasetAssociation("test", mock_galaxy.Dataset("test.txt", file_size=25*1024**3))] - with self.assertRaisesRegex(JobMappingException, "Too much data, shouldn't run"): - self._map_to_destination(tool, user, datasets, tpv_config_paths=[config_first, config_second]) + datasets = [ + mock_galaxy.DatasetAssociation( + "test", mock_galaxy.Dataset("test.txt", file_size=25 * 1024**3) + ) + ] + with self.assertRaisesRegex( + JobMappingException, "Too much data, shouldn't run" + ): + self._map_to_destination( + tool, user, datasets, tpv_config_paths=[config_first, config_second] + ) # an intermediate file size should compute correct values - datasets = [mock_galaxy.DatasetAssociation("test", mock_galaxy.Dataset("test.txt", file_size=7*1024**3))] - destination = self._map_to_destination(tool, user, datasets, tpv_config_paths=[config_first, config_second]) + datasets = [ + mock_galaxy.DatasetAssociation( + "test", mock_galaxy.Dataset("test.txt", file_size=7 * 1024**3) + ) + ] + destination = self._map_to_destination( + tool, user, datasets, tpv_config_paths=[config_first, config_second] + ) self.assertEqual(destination.id, "k8s_environment") - self.assertEqual([env['value'] for env in destination.env if env['name'] == 'TEST_JOB_SLOTS'], ['4']) - self.assertEqual(destination.params['native_spec'], '--mem 8 --cores 2') - self.assertEqual(destination.params['custom_context_remote'], 'remote var') - self.assertEqual(destination.params['custom_context_local'], 'local var') - self.assertEqual(destination.params['custom_context_override'], 'local override') + self.assertEqual( + [ + env["value"] + for env in destination.env + if env["name"] == "TEST_JOB_SLOTS" + ], + ["4"], + ) + self.assertEqual(destination.params["native_spec"], "--mem 8 --cores 2") + self.assertEqual(destination.params["custom_context_remote"], "remote var") + self.assertEqual(destination.params["custom_context_local"], "local var") + self.assertEqual( + 
destination.params["custom_context_override"], "local override" + ) def test_merge_local_with_local(self): - tool = mock_galaxy.Tool('bwa') - user = mock_galaxy.User('ford', 'prefect@vortex.org') + tool = mock_galaxy.Tool("bwa") + user = mock_galaxy.User("ford", "prefect@vortex.org") - config_first = os.path.join(os.path.dirname(__file__), 'fixtures/mapping-merge-multiple-remote.yml') - config_second = os.path.join(os.path.dirname(__file__), 'fixtures/mapping-merge-multiple-local.yml') + config_first = os.path.join( + os.path.dirname(__file__), "fixtures/mapping-merge-multiple-remote.yml" + ) + config_second = os.path.join( + os.path.dirname(__file__), "fixtures/mapping-merge-multiple-local.yml" + ) # a small file size should fail because of remote rule - datasets = [mock_galaxy.DatasetAssociation("test", mock_galaxy.Dataset("test.txt", file_size=1*1024**3))] - with self.assertRaisesRegex(JobMappingException, "We don't run piddling datasets"): - self._map_to_destination(tool, user, datasets, tpv_config_paths=[config_first, config_second]) + datasets = [ + mock_galaxy.DatasetAssociation( + "test", mock_galaxy.Dataset("test.txt", file_size=1 * 1024**3) + ) + ] + with self.assertRaisesRegex( + JobMappingException, "We don't run piddling datasets" + ): + self._map_to_destination( + tool, user, datasets, tpv_config_paths=[config_first, config_second] + ) # a large file size should fail because of local rule - datasets = [mock_galaxy.DatasetAssociation("test", mock_galaxy.Dataset("test.txt", file_size=25*1024**3))] - with self.assertRaisesRegex(JobMappingException, "Too much data, shouldn't run"): - self._map_to_destination(tool, user, datasets, tpv_config_paths=[config_first, config_second]) + datasets = [ + mock_galaxy.DatasetAssociation( + "test", mock_galaxy.Dataset("test.txt", file_size=25 * 1024**3) + ) + ] + with self.assertRaisesRegex( + JobMappingException, "Too much data, shouldn't run" + ): + self._map_to_destination( + tool, user, datasets, tpv_config_paths=[config_first, config_second] + ) # an intermediate file size should compute correct values - datasets = [mock_galaxy.DatasetAssociation("test", mock_galaxy.Dataset("test.txt", file_size=7*1024**3))] - destination = self._map_to_destination(tool, user, datasets, tpv_config_paths=[config_first, config_second]) + datasets = [ + mock_galaxy.DatasetAssociation( + "test", mock_galaxy.Dataset("test.txt", file_size=7 * 1024**3) + ) + ] + destination = self._map_to_destination( + tool, user, datasets, tpv_config_paths=[config_first, config_second] + ) self.assertEqual(destination.id, "k8s_environment") - self.assertEqual([env['value'] for env in destination.env if env['name'] == 'TEST_JOB_SLOTS'], ['4']) - self.assertEqual(destination.params['native_spec'], '--mem 8 --cores 2') - self.assertEqual(destination.params['custom_context_remote'], 'remote var') - self.assertEqual(destination.params['custom_context_local'], 'local var') - self.assertEqual(destination.params['custom_context_override'], 'local override') + self.assertEqual( + [ + env["value"] + for env in destination.env + if env["name"] == "TEST_JOB_SLOTS" + ], + ["4"], + ) + self.assertEqual(destination.params["native_spec"], "--mem 8 --cores 2") + self.assertEqual(destination.params["custom_context_remote"], "remote var") + self.assertEqual(destination.params["custom_context_local"], "local var") + self.assertEqual( + destination.params["custom_context_override"], "local override" + ) def test_merge_rules(self): - tool = mock_galaxy.Tool('bwa') - user = mock_galaxy.User('ford', 
'prefect@vortex.org') + tool = mock_galaxy.Tool("bwa") + user = mock_galaxy.User("ford", "prefect@vortex.org") - config_first = os.path.join(os.path.dirname(__file__), 'fixtures/mapping-merge-multiple-remote.yml') - config_second = os.path.join(os.path.dirname(__file__), 'fixtures/mapping-merge-multiple-local.yml') + config_first = os.path.join( + os.path.dirname(__file__), "fixtures/mapping-merge-multiple-remote.yml" + ) + config_second = os.path.join( + os.path.dirname(__file__), "fixtures/mapping-merge-multiple-local.yml" + ) # the highmem rule should take effect - datasets = [mock_galaxy.DatasetAssociation("test", mock_galaxy.Dataset("test.txt", file_size=42*1024**3))] + datasets = [ + mock_galaxy.DatasetAssociation( + "test", mock_galaxy.Dataset("test.txt", file_size=42 * 1024**3) + ) + ] with self.assertRaisesRegex(JobMappingException, "a different kind of error"): - self._map_to_destination(tool, user, datasets, tpv_config_paths=[config_first]) + self._map_to_destination( + tool, user, datasets, tpv_config_paths=[config_first] + ) # the highmem rule should not take effect for this size, as we've overridden it - datasets = [mock_galaxy.DatasetAssociation("test", mock_galaxy.Dataset("test.txt", file_size=42*1024**3))] - destination = self._map_to_destination(tool, user, datasets, tpv_config_paths=[config_first, config_second]) + datasets = [ + mock_galaxy.DatasetAssociation( + "test", mock_galaxy.Dataset("test.txt", file_size=42 * 1024**3) + ) + ] + destination = self._map_to_destination( + tool, user, datasets, tpv_config_paths=[config_first, config_second] + ) self.assertEqual(destination.id, "another_k8s_environment") def test_merge_rules_with_multiple_matches(self): - tool = mock_galaxy.Tool('hisat2') - user = mock_galaxy.User('ford', 'prefect@vortex.org') + tool = mock_galaxy.Tool("hisat2") + user = mock_galaxy.User("ford", "prefect@vortex.org") - config_first = os.path.join(os.path.dirname(__file__), 'fixtures/mapping-merge-multiple-remote.yml') - config_second = os.path.join(os.path.dirname(__file__), 'fixtures/mapping-merge-multiple-local.yml') + config_first = os.path.join( + os.path.dirname(__file__), "fixtures/mapping-merge-multiple-remote.yml" + ) + config_second = os.path.join( + os.path.dirname(__file__), "fixtures/mapping-merge-multiple-local.yml" + ) # the highmem rule should take effect, with local override winning - datasets = [mock_galaxy.DatasetAssociation("test", mock_galaxy.Dataset("test.txt", file_size=42*1024**3))] - destination = self._map_to_destination(tool, user, datasets, tpv_config_paths=[config_first, config_second]) + datasets = [ + mock_galaxy.DatasetAssociation( + "test", mock_galaxy.Dataset("test.txt", file_size=42 * 1024**3) + ) + ] + destination = self._map_to_destination( + tool, user, datasets, tpv_config_paths=[config_first, config_second] + ) self.assertEqual(destination.id, "another_k8s_environment") # since the last defined hisat2 contains overridden defaults, those defaults will apply - self.assertEqual([env['value'] for env in destination.env if env['name'] == 'TEST_JOB_SLOTS'], ['6']) + self.assertEqual( + [ + env["value"] + for env in destination.env + if env["name"] == "TEST_JOB_SLOTS" + ], + ["6"], + ) # this var is not overridden by the last defined defaults, and therefore, the remote value of cores*2 applies - self.assertEqual([env['value'] for env in destination.env if env['name'] == 'MORE_JOB_SLOTS'], ['12']) - self.assertEqual(destination.params['native_spec'], '--mem 18 --cores 6') + self.assertEqual( + [ + env["value"] + for env 
in destination.env + if env["name"] == "MORE_JOB_SLOTS" + ], + ["12"], + ) + self.assertEqual(destination.params["native_spec"], "--mem 18 --cores 6") def test_merge_rules_local_defaults_do_not_override_remote_tool(self): - tool = mock_galaxy.Tool('toolshed.g2.bx.psu.edu/repos/iuc/disco/disco/v1.0') - user = mock_galaxy.User('ford', 'prefect@vortex.org') + tool = mock_galaxy.Tool("toolshed.g2.bx.psu.edu/repos/iuc/disco/disco/v1.0") + user = mock_galaxy.User("ford", "prefect@vortex.org") - config_first = os.path.join(os.path.dirname(__file__), 'fixtures/mapping-merge-multiple-remote.yml') - config_second = os.path.join(os.path.dirname(__file__), 'fixtures/mapping-merge-multiple-local.yml') + config_first = os.path.join( + os.path.dirname(__file__), "fixtures/mapping-merge-multiple-remote.yml" + ) + config_second = os.path.join( + os.path.dirname(__file__), "fixtures/mapping-merge-multiple-local.yml" + ) # the disco rules should take effect, with local override winning - datasets = [mock_galaxy.DatasetAssociation("test", mock_galaxy.Dataset("test.txt", file_size=42*1024**3))] - destination = self._map_to_destination(tool, user, datasets, tpv_config_paths=[config_first, config_second]) + datasets = [ + mock_galaxy.DatasetAssociation( + "test", mock_galaxy.Dataset("test.txt", file_size=42 * 1024**3) + ) + ] + destination = self._map_to_destination( + tool, user, datasets, tpv_config_paths=[config_first, config_second] + ) self.assertEqual(destination.id, "k8s_environment") # since the last defined hisat2 contains overridden defaults, those defaults will apply - self.assertEqual([env['value'] for env in destination.env if env['name'] == 'DISCO_MAX_MEMORY'], ['24']) - self.assertEqual([env['value'] for env in destination.env if env['name'] == 'DISCO_MORE_PARAMS'], ['just another param']) + self.assertEqual( + [ + env["value"] + for env in destination.env + if env["name"] == "DISCO_MAX_MEMORY" + ], + ["24"], + ) + self.assertEqual( + [ + env["value"] + for env in destination.env + if env["name"] == "DISCO_MORE_PARAMS" + ], + ["just another param"], + ) # this var is not overridden by the last defined defaults, and therefore, the remote value applies - self.assertEqual(destination.params['native_spec'], '--mem 24 --cores 8') + self.assertEqual(destination.params["native_spec"], "--mem 24 --cores 8") diff --git a/tests/test_mapper_params_specific.py b/tests/test_mapper_params_specific.py index a220363..dcd5cd1 100644 --- a/tests/test_mapper_params_specific.py +++ b/tests/test_mapper_params_specific.py @@ -1,53 +1,63 @@ import os import unittest -from tpv.rules import gateway + from tpv.commands.test import mock_galaxy +from tpv.rules import gateway class TestParamsSpecific(unittest.TestCase): @staticmethod def _map_to_destination(tool, user): - galaxy_app = mock_galaxy.App(job_conf=os.path.join(os.path.dirname(__file__), 'fixtures/job_conf.yml')) + galaxy_app = mock_galaxy.App( + job_conf=os.path.join(os.path.dirname(__file__), "fixtures/job_conf.yml") + ) job = mock_galaxy.Job() - tpv_config = os.path.join(os.path.dirname(__file__), 'fixtures/mapping-params-specific.yml') + tpv_config = os.path.join( + os.path.dirname(__file__), "fixtures/mapping-params-specific.yml" + ) gateway.ACTIVE_DESTINATION_MAPPER = None - return gateway.map_tool_to_destination(galaxy_app, job, tool, user, tpv_config_files=[tpv_config]) + return gateway.map_tool_to_destination( + galaxy_app, job, tool, user, tpv_config_files=[tpv_config] + ) def test_default_does_not_inherit_descendant_params(self): - tool = 
mock_galaxy.Tool('bwa') - user = mock_galaxy.User('ford', 'prefect@vortex.org') + tool = mock_galaxy.Tool("bwa") + user = mock_galaxy.User("ford", "prefect@vortex.org") destination = self._map_to_destination(tool, user) - self.assertTrue('earth' not in destination.params) + self.assertTrue("earth" not in destination.params) def test_default_does_not_inherit_descendant_env(self): - tool = mock_galaxy.Tool('agrajag') - user = mock_galaxy.User('ford', 'prefect@vortex.org') + tool = mock_galaxy.Tool("agrajag") + user = mock_galaxy.User("ford", "prefect@vortex.org") destination = self._map_to_destination(tool, user) - self.assertTrue('JAVA_MEM' not in [e['name'] for e in destination.env]) + self.assertTrue("JAVA_MEM" not in [e["name"] for e in destination.env]) def test_map_complex_parameter(self): - tool = mock_galaxy.Tool('bwa') - user = mock_galaxy.User('gargravarr', 'fairycake@vortex.org') + tool = mock_galaxy.Tool("bwa") + user = mock_galaxy.User("gargravarr", "fairycake@vortex.org") destination = self._map_to_destination(tool, user) - self.assertEqual(destination.params['container_override'][0]['identifier'], 'busybox:ubuntu-14.04-2') + self.assertEqual( + destination.params["container_override"][0]["identifier"], + "busybox:ubuntu-14.04-2", + ) def test_env_with_int_value_is_converted_to_string(self): - tool = mock_galaxy.Tool('grappa') - user = mock_galaxy.User('gargravarr', 'fairycake@vortex.org') + tool = mock_galaxy.Tool("grappa") + user = mock_galaxy.User("gargravarr", "fairycake@vortex.org") destination = self._map_to_destination(tool, user) - self.assertEqual(type(destination.env[1]['value']), str) - self.assertEqual(destination.env[1]['value'], '42') + self.assertEqual(type(destination.env[1]["value"]), str) + self.assertEqual(destination.env[1]["value"], "42") def test_param_with_int_or_bool_value_is_not_converted_to_string(self): - tool = mock_galaxy.Tool('grappa') - user = mock_galaxy.User('gargravarr', 'fairycake@vortex.org') + tool = mock_galaxy.Tool("grappa") + user = mock_galaxy.User("gargravarr", "fairycake@vortex.org") destination = self._map_to_destination(tool, user) - self.assertEqual(type(destination.params['is_a_bool']), bool) - self.assertEqual(destination.params['is_a_bool'], True) - self.assertEqual(destination.params['int_value'], 1010) + self.assertEqual(type(destination.params["is_a_bool"]), bool) + self.assertEqual(destination.params["is_a_bool"], True) + self.assertEqual(destination.params["int_value"], 1010) diff --git a/tests/test_mapper_rank.py b/tests/test_mapper_rank.py index 1787e96..d17858d 100644 --- a/tests/test_mapper_rank.py +++ b/tests/test_mapper_rank.py @@ -1,29 +1,36 @@ import os import unittest -from tpv.rules import gateway + from tpv.commands.test import mock_galaxy +from tpv.rules import gateway class TestMapperRank(unittest.TestCase): @staticmethod def _map_to_destination(tool, user): - galaxy_app = mock_galaxy.App(job_conf=os.path.join(os.path.dirname(__file__), 'fixtures/job_conf.yml')) + galaxy_app = mock_galaxy.App( + job_conf=os.path.join(os.path.dirname(__file__), "fixtures/job_conf.yml") + ) job = mock_galaxy.Job() - tpv_config = os.path.join(os.path.dirname(__file__), 'fixtures/mapping-rank.yml') + tpv_config = os.path.join( + os.path.dirname(__file__), "fixtures/mapping-rank.yml" + ) gateway.ACTIVE_DESTINATION_MAPPER = None - return gateway.map_tool_to_destination(galaxy_app, job, tool, user, tpv_config_files=[tpv_config]) + return gateway.map_tool_to_destination( + galaxy_app, job, tool, user, tpv_config_files=[tpv_config] + ) 
def test_map_custom_rank(self): - tool = mock_galaxy.Tool('bwa') - user = mock_galaxy.User('ford', 'prefect@vortex.org') + tool = mock_galaxy.Tool("bwa") + user = mock_galaxy.User("ford", "prefect@vortex.org") destination = self._map_to_destination(tool, user) self.assertEqual(destination.id, "k8s_environment") def test_map_default_rank_but_with_preference(self): - tool = mock_galaxy.Tool('trinity') - user = mock_galaxy.User('ford', 'prefect@vortex.org') + tool = mock_galaxy.Tool("trinity") + user = mock_galaxy.User("ford", "prefect@vortex.org") destination = self._map_to_destination(tool, user) self.assertEqual(destination.id, "another_k8s_environment") diff --git a/tests/test_mapper_resubmit.py b/tests/test_mapper_resubmit.py index 9af20af..03c3545 100644 --- a/tests/test_mapper_resubmit.py +++ b/tests/test_mapper_resubmit.py @@ -1,24 +1,31 @@ import os import pytest - -from galaxy_test.driver.integration_util import IntegrationTestCase from galaxy.webapps.base import webapp +from galaxy_test.driver.integration_util import IntegrationTestCase class TestMapperResubmission(IntegrationTestCase): - default_tool_conf = os.path.join(os.path.dirname(__file__), 'fixtures/resubmit/tool_conf_resubmit.xml') + default_tool_conf = os.path.join( + os.path.dirname(__file__), "fixtures/resubmit/tool_conf_resubmit.xml" + ) @classmethod def handle_galaxy_config_kwds(cls, config): - config["config_dir"] = os.path.join(os.path.dirname(__file__), 'fixtures') + config["config_dir"] = os.path.join(os.path.dirname(__file__), "fixtures") config["job_config_file"] = "job_conf_resubmit.yml" - config["tool_data_path"] = os.path.join(os.path.dirname(__file__), 'fixtures/resubmit') - config["tool_data_table_config_path"] = os.path.join(os.path.dirname(__file__), - 'fixtures/resubmit/tool_data_tables.xml.sample') - config["data_manager_config_file"] = os.path.join(os.path.dirname(__file__), - 'fixtures/resubmit/data_manager_conf.xml.sample') - config["template_path"] = os.path.abspath(os.path.join(os.path.dirname(webapp.__file__), 'templates')) + config["tool_data_path"] = os.path.join( + os.path.dirname(__file__), "fixtures/resubmit" + ) + config["tool_data_table_config_path"] = os.path.join( + os.path.dirname(__file__), "fixtures/resubmit/tool_data_tables.xml.sample" + ) + config["data_manager_config_file"] = os.path.join( + os.path.dirname(__file__), "fixtures/resubmit/data_manager_conf.xml.sample" + ) + config["template_path"] = os.path.abspath( + os.path.join(os.path.dirname(webapp.__file__), "templates") + ) def _assert_job_passes(self, tool_id="exit_code_oom", resource_parameters=None): resource_parameters = resource_parameters or {} diff --git a/tests/test_mapper_role.py b/tests/test_mapper_role.py index a7b65b4..27a5db1 100644 --- a/tests/test_mapper_role.py +++ b/tests/test_mapper_role.py @@ -1,68 +1,110 @@ import os import unittest -from tpv.rules import gateway -from tpv.commands.test import mock_galaxy + from galaxy.jobs.mapper import JobMappingException +from tpv.commands.test import mock_galaxy +from tpv.rules import gateway + class TestMapperRole(unittest.TestCase): @staticmethod def _map_to_destination(tool, user): - galaxy_app = mock_galaxy.App(job_conf=os.path.join(os.path.dirname(__file__), 'fixtures/job_conf.yml')) + galaxy_app = mock_galaxy.App( + job_conf=os.path.join(os.path.dirname(__file__), "fixtures/job_conf.yml") + ) job = mock_galaxy.Job() - tpv_config = os.path.join(os.path.dirname(__file__), 'fixtures/mapping-role.yml') + tpv_config = os.path.join( + os.path.dirname(__file__), 
"fixtures/mapping-role.yml" + ) gateway.ACTIVE_DESTINATION_MAPPER = None - return gateway.map_tool_to_destination(galaxy_app, job, tool, user, tpv_config_files=[tpv_config]) + return gateway.map_tool_to_destination( + galaxy_app, job, tool, user, tpv_config_files=[tpv_config] + ) def test_map_default_role(self): - tool = mock_galaxy.Tool('bwa') - user = mock_galaxy.User('ford', 'prefect@vortex.org') + tool = mock_galaxy.Tool("bwa") + user = mock_galaxy.User("ford", "prefect@vortex.org") destination = self._map_to_destination(tool, user) self.assertEqual(destination.id, "k8s_environment") def test_map_overridden_role(self): - tool = mock_galaxy.Tool('bwa') - user = mock_galaxy.User('gargravarr', 'fairycake@vortex.org', roles=["training"]) + tool = mock_galaxy.Tool("bwa") + user = mock_galaxy.User( + "gargravarr", "fairycake@vortex.org", roles=["training"] + ) - with self.assertRaisesRegex(JobMappingException, "No destinations are available to fulfill request"): + with self.assertRaisesRegex( + JobMappingException, "No destinations are available to fulfill request" + ): self._map_to_destination(tool, user) def test_map_role_by_regex(self): - tool = mock_galaxy.Tool('bwa') - user = mock_galaxy.User('gargravarr', 'fairycake@vortex.org', roles=["newtraining2021group"]) + tool = mock_galaxy.Tool("bwa") + user = mock_galaxy.User( + "gargravarr", "fairycake@vortex.org", roles=["newtraining2021group"] + ) destination = self._map_to_destination(tool, user) self.assertEqual(destination.id, "k8s_environment") def test_map_role_env_combine_order(self): - tool = mock_galaxy.Tool('bwa') - user = mock_galaxy.User('gargravarr', 'fairycake@vortex.org', roles=["newtraining2021group"]) + tool = mock_galaxy.Tool("bwa") + user = mock_galaxy.User( + "gargravarr", "fairycake@vortex.org", roles=["newtraining2021group"] + ) destination = self._map_to_destination(tool, user) self.assertEqual(destination.id, "k8s_environment") - self.assertEqual([env['value'] for env in destination.env if env['name'] == 'TOOL_AND_USER_DEFINED'], ['user']) - self.assertEqual([env['value'] for env in destination.env if env['name'] == 'TOOL_AND_ROLE_DEFINED'], ['role']) - self.assertEqual([env['value'] for env in destination.env if env['name'] == 'TOOL_USER_AND_ROLE_DEFINED'], - ['user']) - self.assertEqual([env['value'] for env in destination.env if env['name'] == 'USER_AND_ROLE_DEFINED'], - ['user']) + self.assertEqual( + [ + env["value"] + for env in destination.env + if env["name"] == "TOOL_AND_USER_DEFINED" + ], + ["user"], + ) + self.assertEqual( + [ + env["value"] + for env in destination.env + if env["name"] == "TOOL_AND_ROLE_DEFINED" + ], + ["role"], + ) + self.assertEqual( + [ + env["value"] + for env in destination.env + if env["name"] == "TOOL_USER_AND_ROLE_DEFINED" + ], + ["user"], + ) + self.assertEqual( + [ + env["value"] + for env in destination.env + if env["name"] == "USER_AND_ROLE_DEFINED" + ], + ["user"], + ) def test_map_role_training_matching_tag_values(self): - user = mock_galaxy.User('trin', 'tragula@perspective.org', roles=["seminar-ga"]) + user = mock_galaxy.User("trin", "tragula@perspective.org", roles=["seminar-ga"]) # test default training rule - tool = mock_galaxy.Tool('quast') + tool = mock_galaxy.Tool("quast") destination = self._map_to_destination(tool, user) - self.assertEqual(destination.params['native_spec'], '--mem 1 --cores 1') + self.assertEqual(destination.params["native_spec"], "--mem 1 --cores 1") # test training small pulsar rule - tool = mock_galaxy.Tool('fastqc') + tool = 
mock_galaxy.Tool("fastqc") destination = self._map_to_destination(tool, user) - self.assertEqual(destination.params['native_spec'], '--mem 2 --cores 2') + self.assertEqual(destination.params["native_spec"], "--mem 2 --cores 2") # test default training large pulsar rule - tool = mock_galaxy.Tool('flye') + tool = mock_galaxy.Tool("flye") destination = self._map_to_destination(tool, user) - self.assertEqual(destination.params['native_spec'], '--mem 3 --cores 3') + self.assertEqual(destination.params["native_spec"], "--mem 3 --cores 3") diff --git a/tests/test_mapper_rules.py b/tests/test_mapper_rules.py index 8c49be3..f82c084 100644 --- a/tests/test_mapper_rules.py +++ b/tests/test_mapper_rules.py @@ -1,164 +1,297 @@ import os -import time -import tempfile import shutil +import tempfile +import time import unittest -from tpv.rules import gateway + +from galaxy.jobs.mapper import JobMappingException, JobNotReadyException + from tpv.commands.test import mock_galaxy -from galaxy.jobs.mapper import JobMappingException -from galaxy.jobs.mapper import JobNotReadyException +from tpv.rules import gateway class TestMapperRules(unittest.TestCase): @staticmethod - def _map_to_destination(tool, user, datasets, param_values=None, tpv_config_files=None, app=None): - galaxy_app = app or mock_galaxy.App(job_conf=os.path.join(os.path.dirname(__file__), 'fixtures/job_conf.yml')) + def _map_to_destination( + tool, user, datasets, param_values=None, tpv_config_files=None, app=None + ): + galaxy_app = app or mock_galaxy.App( + job_conf=os.path.join(os.path.dirname(__file__), "fixtures/job_conf.yml") + ) job = mock_galaxy.Job() for d in datasets: job.add_input_dataset(d) if param_values: job.param_values = param_values - tpv_configs = tpv_config_files or [os.path.join(os.path.dirname(__file__), 'fixtures/mapping-rules.yml')] + tpv_configs = tpv_config_files or [ + os.path.join(os.path.dirname(__file__), "fixtures/mapping-rules.yml") + ] gateway.ACTIVE_DESTINATION_MAPPER = None - return gateway.map_tool_to_destination(galaxy_app, job, tool, user, tpv_config_files=tpv_configs) + return gateway.map_tool_to_destination( + galaxy_app, job, tool, user, tpv_config_files=tpv_configs + ) def test_map_rule_size_small(self): - tool = mock_galaxy.Tool('bwa') - user = mock_galaxy.User('ford', 'prefect@vortex.org') - datasets = [mock_galaxy.DatasetAssociation("test", mock_galaxy.Dataset("test.txt", file_size=1*1024**3))] + tool = mock_galaxy.Tool("bwa") + user = mock_galaxy.User("ford", "prefect@vortex.org") + datasets = [ + mock_galaxy.DatasetAssociation( + "test", mock_galaxy.Dataset("test.txt", file_size=1 * 1024**3) + ) + ] with self.assertRaises(JobMappingException): self._map_to_destination(tool, user, datasets) def test_map_rule_size_medium(self): - tool = mock_galaxy.Tool('bwa') - user = mock_galaxy.User('gargravarr', 'fairycake@vortex.org') - datasets = [mock_galaxy.DatasetAssociation("test", mock_galaxy.Dataset("test.txt", file_size=5*1024**3))] + tool = mock_galaxy.Tool("bwa") + user = mock_galaxy.User("gargravarr", "fairycake@vortex.org") + datasets = [ + mock_galaxy.DatasetAssociation( + "test", mock_galaxy.Dataset("test.txt", file_size=5 * 1024**3) + ) + ] destination = self._map_to_destination(tool, user, datasets) self.assertEqual(destination.id, "k8s_environment") - self.assertEqual([env['value'] for env in destination.env if env['name'] == 'TEST_JOB_SLOTS'], ['4']) - self.assertEqual(destination.params['native_spec'], '--mem 16 --cores 4') + self.assertEqual( + [ + env["value"] + for env in destination.env + 
if env["name"] == "TEST_JOB_SLOTS" + ], + ["4"], + ) + self.assertEqual(destination.params["native_spec"], "--mem 16 --cores 4") def test_map_rule_size_large(self): - tool = mock_galaxy.Tool('bwa') - user = mock_galaxy.User('gargravarr', 'fairycake@vortex.org') - datasets = [mock_galaxy.DatasetAssociation("test", mock_galaxy.Dataset("test.txt", file_size=15*1024**3))] + tool = mock_galaxy.Tool("bwa") + user = mock_galaxy.User("gargravarr", "fairycake@vortex.org") + datasets = [ + mock_galaxy.DatasetAssociation( + "test", mock_galaxy.Dataset("test.txt", file_size=15 * 1024**3) + ) + ] - with self.assertRaisesRegex(JobMappingException, "No destinations are available to fulfill request"): + with self.assertRaisesRegex( + JobMappingException, "No destinations are available to fulfill request" + ): self._map_to_destination(tool, user, datasets) def test_map_rule_user(self): - tool = mock_galaxy.Tool('bwa') - user = mock_galaxy.User('arthur', 'arthur@vortex.org') - datasets = [mock_galaxy.DatasetAssociation("test", mock_galaxy.Dataset("test.txt", file_size=15*1024**3))] + tool = mock_galaxy.Tool("bwa") + user = mock_galaxy.User("arthur", "arthur@vortex.org") + datasets = [ + mock_galaxy.DatasetAssociation( + "test", mock_galaxy.Dataset("test.txt", file_size=15 * 1024**3) + ) + ] with self.assertRaises(JobMappingException): self._map_to_destination(tool, user, datasets) def test_map_rule_user_params(self): - tool = mock_galaxy.Tool('bwa') - user = mock_galaxy.User('gargravarr', 'fairycake@vortex.org') - datasets = [mock_galaxy.DatasetAssociation("test", mock_galaxy.Dataset("test.txt", file_size=5*1024**3))] + tool = mock_galaxy.Tool("bwa") + user = mock_galaxy.User("gargravarr", "fairycake@vortex.org") + datasets = [ + mock_galaxy.DatasetAssociation( + "test", mock_galaxy.Dataset("test.txt", file_size=5 * 1024**3) + ) + ] destination = self._map_to_destination(tool, user, datasets) self.assertEqual(destination.id, "k8s_environment") - self.assertEqual([env['value'] for env in destination.env if env['name'] == 'TEST_JOB_SLOTS_USER'], ['4']) - self.assertEqual(destination.params['native_spec_user'], '--mem 16 --cores 4') + self.assertEqual( + [ + env["value"] + for env in destination.env + if env["name"] == "TEST_JOB_SLOTS_USER" + ], + ["4"], + ) + self.assertEqual(destination.params["native_spec_user"], "--mem 16 --cores 4") def test_rules_automatically_reload_on_update(self): - with tempfile.NamedTemporaryFile('w+t') as tmp_file: - rule_file = os.path.join(os.path.dirname(__file__), 'fixtures/mapping-rules.yml') + with tempfile.NamedTemporaryFile("w+t") as tmp_file: + rule_file = os.path.join( + os.path.dirname(__file__), "fixtures/mapping-rules.yml" + ) shutil.copy2(rule_file, tmp_file.name) - tool = mock_galaxy.Tool('bwa') - user = mock_galaxy.User('gargravarr', 'fairycake@vortex.org') - datasets = [mock_galaxy.DatasetAssociation("test", mock_galaxy.Dataset("test.txt", file_size=5*1024**3))] + tool = mock_galaxy.Tool("bwa") + user = mock_galaxy.User("gargravarr", "fairycake@vortex.org") + datasets = [ + mock_galaxy.DatasetAssociation( + "test", mock_galaxy.Dataset("test.txt", file_size=5 * 1024**3) + ) + ] - destination = self._map_to_destination(tool, user, datasets, tpv_config_files=[tmp_file.name]) - self.assertEqual([env['value'] for env in destination.env if env['name'] == 'TEST_JOB_SLOTS_USER'], ['4']) + destination = self._map_to_destination( + tool, user, datasets, tpv_config_files=[tmp_file.name] + ) + self.assertEqual( + [ + env["value"] + for env in destination.env + if env["name"] 
== "TEST_JOB_SLOTS_USER" + ], + ["4"], + ) # update the rule file - updated_rule_file = os.path.join(os.path.dirname(__file__), 'fixtures/mapping-rules-changed.yml') + updated_rule_file = os.path.join( + os.path.dirname(__file__), "fixtures/mapping-rules-changed.yml" + ) shutil.copy2(updated_rule_file, tmp_file.name) # wait for reload time.sleep(0.5) # should have loaded the new rules - destination = self._map_to_destination(tool, user, datasets, tpv_config_files=[tmp_file.name]) - self.assertEqual([env['value'] for env in destination.env if env['name'] == 'TEST_JOB_SLOTS_USER'], ['8']) + destination = self._map_to_destination( + tool, user, datasets, tpv_config_files=[tmp_file.name] + ) + self.assertEqual( + [ + env["value"] + for env in destination.env + if env["name"] == "TEST_JOB_SLOTS_USER" + ], + ["8"], + ) def test_multiple_files_automatically_reload_on_update(self): - with tempfile.NamedTemporaryFile('w+t') as tmp_file1, tempfile.NamedTemporaryFile('w+t') as tmp_file2: - rule_file = os.path.join(os.path.dirname(__file__), 'fixtures/mapping-rules.yml') + with tempfile.NamedTemporaryFile( + "w+t" + ) as tmp_file1, tempfile.NamedTemporaryFile("w+t") as tmp_file2: + rule_file = os.path.join( + os.path.dirname(__file__), "fixtures/mapping-rules.yml" + ) shutil.copy2(rule_file, tmp_file1.name) - rule_file = os.path.join(os.path.dirname(__file__), 'fixtures/mapping-rules-extra.yml') + rule_file = os.path.join( + os.path.dirname(__file__), "fixtures/mapping-rules-extra.yml" + ) shutil.copy2(rule_file, tmp_file2.name) - tool = mock_galaxy.Tool('bwa') - user = mock_galaxy.User('gargravarr', 'fairycake@vortex.org') - datasets = [mock_galaxy.DatasetAssociation("test", mock_galaxy.Dataset("test.txt", file_size=5*1024**3))] + tool = mock_galaxy.Tool("bwa") + user = mock_galaxy.User("gargravarr", "fairycake@vortex.org") + datasets = [ + mock_galaxy.DatasetAssociation( + "test", mock_galaxy.Dataset("test.txt", file_size=5 * 1024**3) + ) + ] - destination = self._map_to_destination(tool, user, datasets, tpv_config_files=[ - tmp_file1.name, tmp_file2.name]) - self.assertEqual([env['value'] for env in destination.env if env['name'] == 'TEST_JOB_SLOTS_USER'], ['3']) + destination = self._map_to_destination( + tool, user, datasets, tpv_config_files=[tmp_file1.name, tmp_file2.name] + ) + self.assertEqual( + [ + env["value"] + for env in destination.env + if env["name"] == "TEST_JOB_SLOTS_USER" + ], + ["3"], + ) # update the rule files - updated_rule_file = os.path.join(os.path.dirname(__file__), 'fixtures/mapping-rules-changed.yml') + updated_rule_file = os.path.join( + os.path.dirname(__file__), "fixtures/mapping-rules-changed.yml" + ) shutil.copy2(updated_rule_file, tmp_file1.name) - updated_rule_file = os.path.join(os.path.dirname(__file__), 'fixtures/mapping-rules-changed-extra.yml') + updated_rule_file = os.path.join( + os.path.dirname(__file__), "fixtures/mapping-rules-changed-extra.yml" + ) shutil.copy2(updated_rule_file, tmp_file2.name) # wait for reload time.sleep(0.5) # should have loaded the new rules - destination = self._map_to_destination(tool, user, datasets, tpv_config_files=[ - tmp_file1.name, tmp_file2.name]) - self.assertEqual([env['value'] for env in destination.env if env['name'] == 'TEST_JOB_SLOTS_USER'], ['10']) + destination = self._map_to_destination( + tool, user, datasets, tpv_config_files=[tmp_file1.name, tmp_file2.name] + ) + self.assertEqual( + [ + env["value"] + for env in destination.env + if env["name"] == "TEST_JOB_SLOTS_USER" + ], + ["10"], + ) def 
test_map_with_syntax_error(self): - tool = mock_galaxy.Tool('bwa') - user = mock_galaxy.User('ford', 'prefect@vortex.org') - datasets = [mock_galaxy.DatasetAssociation("test", mock_galaxy.Dataset("test.txt", file_size=1*1024**3))] + tool = mock_galaxy.Tool("bwa") + user = mock_galaxy.User("ford", "prefect@vortex.org") + datasets = [ + mock_galaxy.DatasetAssociation( + "test", mock_galaxy.Dataset("test.txt", file_size=1 * 1024**3) + ) + ] with self.assertRaises(SyntaxError): - tpv_config = os.path.join(os.path.dirname(__file__), 'fixtures/mapping-syntax-error.yml') - self._map_to_destination(tool, user, datasets, tpv_config_files=[tpv_config]) + tpv_config = os.path.join( + os.path.dirname(__file__), "fixtures/mapping-syntax-error.yml" + ) + self._map_to_destination( + tool, user, datasets, tpv_config_files=[tpv_config] + ) def test_map_with_execute_block(self): - tool = mock_galaxy.Tool('bwa') - user = mock_galaxy.User('ford', 'prefect@vortex.org') - datasets = [mock_galaxy.DatasetAssociation("test", mock_galaxy.Dataset("test.txt", file_size=7*1024**3))] + tool = mock_galaxy.Tool("bwa") + user = mock_galaxy.User("ford", "prefect@vortex.org") + datasets = [ + mock_galaxy.DatasetAssociation( + "test", mock_galaxy.Dataset("test.txt", file_size=7 * 1024**3) + ) + ] with self.assertRaises(JobNotReadyException): - tpv_config = os.path.join(os.path.dirname(__file__), 'fixtures/mapping-rule-execute.yml') - self._map_to_destination(tool, user, datasets, tpv_config_files=[tpv_config]) + tpv_config = os.path.join( + os.path.dirname(__file__), "fixtures/mapping-rule-execute.yml" + ) + self._map_to_destination( + tool, user, datasets, tpv_config_files=[tpv_config] + ) def test_map_with_execute_block_side_effects(self): - tool = mock_galaxy.Tool('bwa') - user = mock_galaxy.User('ford', 'prefect@vortex.org') - datasets = [mock_galaxy.DatasetAssociation("test", mock_galaxy.Dataset("test.txt", file_size=11*1024**3))] + tool = mock_galaxy.Tool("bwa") + user = mock_galaxy.User("ford", "prefect@vortex.org") + datasets = [ + mock_galaxy.DatasetAssociation( + "test", mock_galaxy.Dataset("test.txt", file_size=11 * 1024**3) + ) + ] - tpv_config = os.path.join(os.path.dirname(__file__), 'fixtures/mapping-rule-execute.yml') - destination = self._map_to_destination(tool, user, datasets, tpv_config_files=[tpv_config]) - self.assertEqual(destination.params['my_brand_new_param'], "hello_world") - self.assertEqual(destination.params['native_spec'], '--mem 24 --cores 8') + tpv_config = os.path.join( + os.path.dirname(__file__), "fixtures/mapping-rule-execute.yml" + ) + destination = self._map_to_destination( + tool, user, datasets, tpv_config_files=[tpv_config] + ) + self.assertEqual(destination.params["my_brand_new_param"], "hello_world") + self.assertEqual(destination.params["native_spec"], "--mem 24 --cores 8") def test_job_args_match_helper(self): - tool = mock_galaxy.Tool('limbo') - user = mock_galaxy.User('gag', 'gaghalfrunt@vortex.org') - tpv_config = os.path.join(os.path.dirname(__file__), 'fixtures/mapping-rule-argument-based.yml') - datasets = [mock_galaxy.DatasetAssociation("test", mock_galaxy.Dataset("test.txt", file_size=7*1024**3))] + tool = mock_galaxy.Tool("limbo") + user = mock_galaxy.User("gag", "gaghalfrunt@vortex.org") + tpv_config = os.path.join( + os.path.dirname(__file__), "fixtures/mapping-rule-argument-based.yml" + ) + datasets = [ + mock_galaxy.DatasetAssociation( + "test", mock_galaxy.Dataset("test.txt", file_size=7 * 1024**3) + ) + ] param_values = { - 'output_style': 'flat', - 'colour': 
{'nighttime': 'blue'}, - 'input_opts': {'tabs_to_spaces': False, 'db_selector': 'db'}, + "output_style": "flat", + "colour": {"nighttime": "blue"}, + "input_opts": {"tabs_to_spaces": False, "db_selector": "db"}, } - destination = self._map_to_destination(tool, user, datasets, param_values, tpv_config_files=[tpv_config]) - self.assertEqual(destination.id, 'k8s_environment') + destination = self._map_to_destination( + tool, user, datasets, param_values, tpv_config_files=[tpv_config] + ) + self.assertEqual(destination.id, "k8s_environment") def test_concurrent_job_count_helper(self): @@ -167,7 +300,7 @@ def create_user(app, mock_user): user = app.model.User( username=mock_user.username, email=mock_user.email, - password='helloworld', + password="helloworld", ) sa_session.add(user) sa_session.flush() @@ -185,21 +318,35 @@ def create_job(app, mock_user, mock_tool): job.user = user job.tool_id = mock_tool.id job.state = "running" - job.destination_id = 'local' + job.destination_id = "local" sa_session.add(job) sa_session.flush() return job.id - tool_user_limit_2 = mock_galaxy.Tool('toolshed.g2.bx.psu.edu/repos/rnateam/mafft/rbc_mafft/7.221.3') - tool_total_limit_3 = mock_galaxy.Tool('toolshed.g2.bx.psu.edu/repos/artbio/repenrich/repenrich/1.6.1') - user_eccentrica = mock_galaxy.User('eccentrica', 'eccentricagallumbits@vortex.org') - user_roosta = mock_galaxy.User('roosta', 'roosta@vortex.org') + tool_user_limit_2 = mock_galaxy.Tool( + "toolshed.g2.bx.psu.edu/repos/rnateam/mafft/rbc_mafft/7.221.3" + ) + tool_total_limit_3 = mock_galaxy.Tool( + "toolshed.g2.bx.psu.edu/repos/artbio/repenrich/repenrich/1.6.1" + ) + user_eccentrica = mock_galaxy.User( + "eccentrica", "eccentricagallumbits@vortex.org" + ) + user_roosta = mock_galaxy.User("roosta", "roosta@vortex.org") app = mock_galaxy.App( - job_conf=os.path.join(os.path.dirname(__file__), 'fixtures/job_conf.yml'), create_model=True) + job_conf=os.path.join(os.path.dirname(__file__), "fixtures/job_conf.yml"), + create_model=True, + ) - datasets = [mock_galaxy.DatasetAssociation("test", mock_galaxy.Dataset("test.txt", file_size=7*1024**3))] - tpv_config = os.path.join(os.path.dirname(__file__), 'fixtures/mapping-rule-tool-limits.yml') + datasets = [ + mock_galaxy.DatasetAssociation( + "test", mock_galaxy.Dataset("test.txt", file_size=7 * 1024**3) + ) + ] + tpv_config = os.path.join( + os.path.dirname(__file__), "fixtures/mapping-rule-tool-limits.yml" + ) user_roosta.id = create_user(app, user_roosta) user_eccentrica.id = create_user(app, user_eccentrica) @@ -210,12 +357,26 @@ def create_job(app, mock_user, mock_tool): # roosta cannot create another rbc_mafft job with self.assertRaises(JobNotReadyException): - tpv_config = os.path.join(os.path.dirname(__file__), 'fixtures/mapping-rule-tool-limits.yml') - self._map_to_destination(tool_user_limit_2, user_roosta, datasets, tpv_config_files=[tpv_config], app=app) + tpv_config = os.path.join( + os.path.dirname(__file__), "fixtures/mapping-rule-tool-limits.yml" + ) + self._map_to_destination( + tool_user_limit_2, + user_roosta, + datasets, + tpv_config_files=[tpv_config], + app=app, + ) # eccentrica can run a rbc_mafft job - destination = self._map_to_destination(tool_user_limit_2, user_eccentrica, datasets, tpv_config_files=[tpv_config], app=app) - self.assertEqual(destination.id, 'local') + destination = self._map_to_destination( + tool_user_limit_2, + user_eccentrica, + datasets, + tpv_config_files=[tpv_config], + app=app, + ) + self.assertEqual(destination.id, "local") # set up 3 running jobs for 
repenrich tool create_job(app, user_eccentrica, tool_total_limit_3) @@ -224,44 +385,64 @@ def create_job(app, mock_user, mock_tool): # roosta cannot create another repenrich job with self.assertRaises(JobNotReadyException): - tpv_config = os.path.join(os.path.dirname(__file__), 'fixtures/mapping-rule-tool-limits.yml') - self._map_to_destination(tool_total_limit_3, user_roosta, datasets, tpv_config_files=[tpv_config], app=app) + tpv_config = os.path.join( + os.path.dirname(__file__), "fixtures/mapping-rule-tool-limits.yml" + ) + self._map_to_destination( + tool_total_limit_3, + user_roosta, + datasets, + tpv_config_files=[tpv_config], + app=app, + ) def test_tool_version_comparison_helpers(self): - tpv_config = os.path.join(os.path.dirname(__file__), 'fixtures/mapping-rule-tool-limits.yml') - user = mock_galaxy.User('ford', 'prefect@vortex.org') - datasets = [mock_galaxy.DatasetAssociation("test", mock_galaxy.Dataset("test.txt", file_size=1*1024**3))] + tpv_config = os.path.join( + os.path.dirname(__file__), "fixtures/mapping-rule-tool-limits.yml" + ) + user = mock_galaxy.User("ford", "prefect@vortex.org") + datasets = [ + mock_galaxy.DatasetAssociation( + "test", mock_galaxy.Dataset("test.txt", file_size=1 * 1024**3) + ) + ] def mock_trinity_with_version(version): return mock_galaxy.Tool( - id=f'toolshed.g2.bx.psu.edu/repos/iuc/trinity/trinity/{version}', - version=version + id=f"toolshed.g2.bx.psu.edu/repos/iuc/trinity/trinity/{version}", + version=version, ) - env_keys = [ # env keys that may be added by cooked trinity rules - 'version_gte_2.15.1+galaxy0', - 'version_gt_2.15.1+galaxy0', - 'version_lt_2.10.1+galaxy7', - 'version_lte_2.10.1+galaxy7', + env_keys = [ # env keys that may be added by cooked trinity rules + "version_gte_2.15.1+galaxy0", + "version_gt_2.15.1+galaxy0", + "version_lt_2.10.1+galaxy7", + "version_lte_2.10.1+galaxy7", ] # trinity version 3.15.1+galaxy0 - tool = mock_trinity_with_version('3.15.1+galaxy0') - destination = self._map_to_destination(tool, user, datasets, tpv_config_files=[tpv_config]) + tool = mock_trinity_with_version("3.15.1+galaxy0") + destination = self._map_to_destination( + tool, user, datasets, tpv_config_files=[tpv_config] + ) self.assertCountEqual( - [e.get('name') for e in destination.env if e.get('name') in env_keys], - ['version_gte_2.15.1+galaxy0', 'version_gt_2.15.1+galaxy0'], + [e.get("name") for e in destination.env if e.get("name") in env_keys], + ["version_gte_2.15.1+galaxy0", "version_gt_2.15.1+galaxy0"], ) # trinity version 2.15.1+galaxy0 - tool = mock_trinity_with_version('2.15.1+galaxy0') - destination = self._map_to_destination(tool, user, datasets, tpv_config_files=[tpv_config]) + tool = mock_trinity_with_version("2.15.1+galaxy0") + destination = self._map_to_destination( + tool, user, datasets, tpv_config_files=[tpv_config] + ) self.assertCountEqual( - [e.get('name') for e in destination.env if e.get('name') in env_keys], - ['version_gte_2.15.1+galaxy0'], + [e.get("name") for e in destination.env if e.get("name") in env_keys], + ["version_gte_2.15.1+galaxy0"], ) # trinity version 2.10.1+galaxy6 - tool = mock_trinity_with_version('2.10.1+galaxy6') - destination = self._map_to_destination(tool, user, datasets, tpv_config_files=[tpv_config]) + tool = mock_trinity_with_version("2.10.1+galaxy6") + destination = self._map_to_destination( + tool, user, datasets, tpv_config_files=[tpv_config] + ) self.assertCountEqual( - [e.get('name') for e in destination.env if e.get('name') in env_keys], - ['version_lt_2.10.1+galaxy7', 
'version_lte_2.10.1+galaxy7'], + [e.get("name") for e in destination.env if e.get("name") in env_keys], + ["version_lt_2.10.1+galaxy7", "version_lte_2.10.1+galaxy7"], ) diff --git a/tests/test_mapper_sample.py b/tests/test_mapper_sample.py index d282a28..3a6cf62 100644 --- a/tests/test_mapper_sample.py +++ b/tests/test_mapper_sample.py @@ -1,22 +1,29 @@ import os import unittest -from tpv.rules import gateway + from tpv.commands.test import mock_galaxy +from tpv.rules import gateway class TestMapperSample(unittest.TestCase): @staticmethod def _map_to_destination(tool): - galaxy_app = mock_galaxy.App(job_conf=os.path.join(os.path.dirname(__file__), 'fixtures/job_conf.yml')) + galaxy_app = mock_galaxy.App( + job_conf=os.path.join(os.path.dirname(__file__), "fixtures/job_conf.yml") + ) job = mock_galaxy.Job() - user = mock_galaxy.User('gargravarr', 'fairycake@vortex.org') - tpv_config = os.path.join(os.path.dirname(__file__), 'fixtures/mapping-sample.yml') + user = mock_galaxy.User("gargravarr", "fairycake@vortex.org") + tpv_config = os.path.join( + os.path.dirname(__file__), "fixtures/mapping-sample.yml" + ) gateway.ACTIVE_DESTINATION_MAPPER = None - return gateway.map_tool_to_destination(galaxy_app, job, tool, user, tpv_config_files=[tpv_config]) + return gateway.map_tool_to_destination( + galaxy_app, job, tool, user, tpv_config_files=[tpv_config] + ) def test_map_sample_tool(self): - tool = mock_galaxy.Tool('sometool') + tool = mock_galaxy.Tool("sometool") destination = self._map_to_destination(tool) self.assertEqual(destination.id, "local") - self.assertEqual(destination.params['local_slots'], '2') + self.assertEqual(destination.params["local_slots"], "2") diff --git a/tests/test_mapper_user.py b/tests/test_mapper_user.py index 729f47f..2db15ab 100644 --- a/tests/test_mapper_user.py +++ b/tests/test_mapper_user.py @@ -1,95 +1,110 @@ import os import unittest -from tpv.rules import gateway -from tpv.core.entities import IncompatibleTagsException + from tpv.commands.test import mock_galaxy +from tpv.core.entities import IncompatibleTagsException +from tpv.rules import gateway class TestMapperUser(unittest.TestCase): @staticmethod def _map_to_destination(tool, user, tpv_config_path=None): - galaxy_app = mock_galaxy.App(job_conf=os.path.join(os.path.dirname(__file__), 'fixtures/job_conf.yml')) + galaxy_app = mock_galaxy.App( + job_conf=os.path.join(os.path.dirname(__file__), "fixtures/job_conf.yml") + ) job = mock_galaxy.Job() - tpv_config = tpv_config_path or os.path.join(os.path.dirname(__file__), 'fixtures/mapping-user.yml') + tpv_config = tpv_config_path or os.path.join( + os.path.dirname(__file__), "fixtures/mapping-user.yml" + ) gateway.ACTIVE_DESTINATION_MAPPER = None - return gateway.map_tool_to_destination(galaxy_app, job, tool, user, tpv_config_files=[tpv_config]) + return gateway.map_tool_to_destination( + galaxy_app, job, tool, user, tpv_config_files=[tpv_config] + ) def test_map_default_user(self): - tool = mock_galaxy.Tool('bwa') - user = mock_galaxy.User('ford', 'ford@vortex.org') + tool = mock_galaxy.Tool("bwa") + user = mock_galaxy.User("ford", "ford@vortex.org") destination = self._map_to_destination(tool, user) self.assertEqual(destination.id, "k8s_environment") def test_map_overridden_user(self): - tool = mock_galaxy.Tool('bwa') - user = mock_galaxy.User('gargravarr', 'fairycake@vortex.org') + tool = mock_galaxy.Tool("bwa") + user = mock_galaxy.User("gargravarr", "fairycake@vortex.org") destination = self._map_to_destination(tool, user) self.assertEqual(destination.id, 
"k8s_environment") def test_map_unschedulable_user(self): - tool = mock_galaxy.Tool('bwa') - user = mock_galaxy.User('arthur', 'arthur@vortex.org') + tool = mock_galaxy.Tool("bwa") + user = mock_galaxy.User("arthur", "arthur@vortex.org") with self.assertRaises(IncompatibleTagsException): self._map_to_destination(tool, user) def test_map_invalidly_tagged_user(self): - tool = mock_galaxy.Tool('bwa') - user = mock_galaxy.User('infinitely', 'improbable@vortex.org') - - config = os.path.join(os.path.dirname(__file__), 'fixtures/mapping-user-invalid-tags.yml') - with self.assertRaisesRegex(Exception, r"Duplicate tags found: 'pulsar' in \['require', 'reject'\]"): + tool = mock_galaxy.Tool("bwa") + user = mock_galaxy.User("infinitely", "improbable@vortex.org") + + config = os.path.join( + os.path.dirname(__file__), "fixtures/mapping-user-invalid-tags.yml" + ) + with self.assertRaisesRegex( + Exception, r"Duplicate tags found: 'pulsar' in \['require', 'reject'\]" + ): self._map_to_destination(tool, user, tpv_config_path=config) def test_map_user_by_regex(self): - tool = mock_galaxy.Tool('bwa') - user = mock_galaxy.User('gargravarr', 'fairycake@vortex.org') + tool = mock_galaxy.Tool("bwa") + user = mock_galaxy.User("gargravarr", "fairycake@vortex.org") destination = self._map_to_destination(tool, user) self.assertEqual(destination.id, "k8s_environment") def test_map_user_by_regex_mismatch(self): - tool = mock_galaxy.Tool('bwa') - user = mock_galaxy.User('gargravarr', 'fairycake@notvortex.org') + tool = mock_galaxy.Tool("bwa") + user = mock_galaxy.User("gargravarr", "fairycake@notvortex.org") with self.assertRaises(IncompatibleTagsException): self._map_to_destination(tool, user) def test_map_user_entity_usage_scenario_1(self): - tool = mock_galaxy.Tool('bwa') - user = mock_galaxy.User('ford', 'ford@vortex.org') + tool = mock_galaxy.Tool("bwa") + user = mock_galaxy.User("ford", "ford@vortex.org") destination = self._map_to_destination(tool, user) self.assertEqual(destination.id, "k8s_environment") # should use the lower of the two core and mem values for this user - self.assertEqual(destination.params['native_spec'], '--mem 4 --cores 2') + self.assertEqual(destination.params["native_spec"], "--mem 4 --cores 2") def test_map_user_entity_usage_scenario_2(self): - tool = mock_galaxy.Tool('bwa') - user = mock_galaxy.User('ford', 'fairycake@vortex.org') + tool = mock_galaxy.Tool("bwa") + user = mock_galaxy.User("ford", "fairycake@vortex.org") destination = self._map_to_destination(tool, user) self.assertEqual(destination.id, "k8s_environment") # should use the lower of the two core and mem values for this user - self.assertEqual(destination.params['native_spec'], '--mem 8 --cores 1') + self.assertEqual(destination.params["native_spec"], "--mem 8 --cores 1") def test_tool_below_min_resources_for_user(self): - tool = mock_galaxy.Tool('tool_below_min_resources') - user = mock_galaxy.User('prefect', 'prefect@vortex.org') + tool = mock_galaxy.Tool("tool_below_min_resources") + user = mock_galaxy.User("prefect", "prefect@vortex.org") destination = self._map_to_destination(tool, user) self.assertEqual(destination.id, "special_resource_environment") # should use the lower of the two core and special_resource_environment values for this user - self.assertEqual(destination.params['native_spec'], '--mem 16 --cores 2 --gpus 2') + self.assertEqual( + destination.params["native_spec"], "--mem 16 --cores 2 --gpus 2" + ) def test_tool_above_max_resources_for_user(self): - tool = 
mock_galaxy.Tool('tool_above_max_resources') - user = mock_galaxy.User('prefect', 'prefect@vortex.org') + tool = mock_galaxy.Tool("tool_above_max_resources") + user = mock_galaxy.User("prefect", "prefect@vortex.org") destination = self._map_to_destination(tool, user) self.assertEqual(destination.id, "special_resource_environment") # should use the lower of the two core and mem values for this user - self.assertEqual(destination.params['native_spec'], '--mem 32 --cores 4 --gpus 3') + self.assertEqual( + destination.params["native_spec"], "--mem 32 --cores 4 --gpus 3" + ) diff --git a/tests/test_scenarios.py b/tests/test_scenarios.py index d51cb49..594d310 100644 --- a/tests/test_scenarios.py +++ b/tests/test_scenarios.py @@ -1,52 +1,74 @@ import os -import time -import tempfile import pathlib -import responses import shutil +import tempfile +import time import unittest + +import responses from galaxy.jobs.mapper import JobMappingException -from tpv.rules import gateway + from tpv.commands.test import mock_galaxy +from tpv.rules import gateway class TestScenarios(unittest.TestCase): @staticmethod - def _map_to_destination(tool, user, datasets=[], tpv_config_path=None, job_conf=None, app=None): + def _map_to_destination( + tool, user, datasets=[], tpv_config_path=None, job_conf=None, app=None + ): if job_conf: - galaxy_app = mock_galaxy.App(job_conf=os.path.join(os.path.dirname(__file__), job_conf)) + galaxy_app = mock_galaxy.App( + job_conf=os.path.join(os.path.dirname(__file__), job_conf) + ) elif app: galaxy_app = app else: - galaxy_app = mock_galaxy.App(job_conf=os.path.join(os.path.dirname(__file__), 'fixtures/job_conf.yml')) + galaxy_app = mock_galaxy.App( + job_conf=os.path.join( + os.path.dirname(__file__), "fixtures/job_conf.yml" + ) + ) job = mock_galaxy.Job() for d in datasets: job.add_input_dataset(d) - tpv_config = tpv_config_path or os.path.join(os.path.dirname(__file__), 'fixtures/mapping-rules.yml') + tpv_config = tpv_config_path or os.path.join( + os.path.dirname(__file__), "fixtures/mapping-rules.yml" + ) gateway.ACTIVE_DESTINATION_MAPPER = None - return gateway.map_tool_to_destination(galaxy_app, job, tool, user, tpv_config_files=[tpv_config]) + return gateway.map_tool_to_destination( + galaxy_app, job, tool, user, tpv_config_files=[tpv_config] + ) def test_scenario_node_marked_offline(self): - with tempfile.NamedTemporaryFile('w+t') as tmp_file: - rule_file = os.path.join(os.path.dirname(__file__), 'fixtures/scenario-node-online.yml') + with tempfile.NamedTemporaryFile("w+t") as tmp_file: + rule_file = os.path.join( + os.path.dirname(__file__), "fixtures/scenario-node-online.yml" + ) shutil.copy2(rule_file, tmp_file.name) - tool = mock_galaxy.Tool('bwa') - user = mock_galaxy.User('gargravarr', 'fairycake@vortex.org') + tool = mock_galaxy.Tool("bwa") + user = mock_galaxy.User("gargravarr", "fairycake@vortex.org") - destination = self._map_to_destination(tool, user, tpv_config_path=tmp_file.name) + destination = self._map_to_destination( + tool, user, tpv_config_path=tmp_file.name + ) self.assertEqual(destination.id, "k8s_environment") # update the rule file with one node marked offline - updated_rule_file = os.path.join(os.path.dirname(__file__), 'fixtures/scenario-node-offline.yml') + updated_rule_file = os.path.join( + os.path.dirname(__file__), "fixtures/scenario-node-offline.yml" + ) shutil.copy2(updated_rule_file, tmp_file.name) # wait for reload time.sleep(0.5) # should now map to the available node - destination = self._map_to_destination(tool, user, 
tpv_config_path=tmp_file.name) + destination = self._map_to_destination( + tool, user, tpv_config_path=tmp_file.name + ) self.assertEqual(destination.id, "local") @responses.activate @@ -59,16 +81,31 @@ def test_scenario_job_too_small_for_high_memory_node(self): method=responses.GET, url="http://stats.genome.edu.au:8086/query", body=pathlib.Path( - os.path.join(os.path.dirname(__file__), 'fixtures/response-job-too-small-for-highmem.yml')).read_text(), + os.path.join( + os.path.dirname(__file__), + "fixtures/response-job-too-small-for-highmem.yml", + ) + ).read_text(), match_querystring=False, ) - tool = mock_galaxy.Tool('bwa-mem') - user = mock_galaxy.User('simon', 'simon@unimelb.edu.au') - datasets = [mock_galaxy.DatasetAssociation("input", mock_galaxy.Dataset("input.fastq", file_size=10*1024**3))] - rules_file = os.path.join(os.path.dirname(__file__), 'fixtures/scenario-job-too-small-for-highmem.yml') - destination = self._map_to_destination(tool, user, datasets=datasets, tpv_config_path=rules_file, - job_conf='fixtures/job_conf_scenario_usegalaxy_au.yml') + tool = mock_galaxy.Tool("bwa-mem") + user = mock_galaxy.User("simon", "simon@unimelb.edu.au") + datasets = [ + mock_galaxy.DatasetAssociation( + "input", mock_galaxy.Dataset("input.fastq", file_size=10 * 1024**3) + ) + ] + rules_file = os.path.join( + os.path.dirname(__file__), "fixtures/scenario-job-too-small-for-highmem.yml" + ) + destination = self._map_to_destination( + tool, + user, + datasets=datasets, + tpv_config_path=rules_file, + job_conf="fixtures/job_conf_scenario_usegalaxy_au.yml", + ) self.assertEqual(destination.id, "general_pulsar_2") @responses.activate @@ -81,17 +118,31 @@ def test_scenario_node_offline_high_cpu(self): method=responses.GET, url="http://stats.genome.edu.au:8086/query", body=pathlib.Path( - os.path.join(os.path.dirname(__file__), 'fixtures/response-node-offline-high-cpu.yml')).read_text(), + os.path.join( + os.path.dirname(__file__), + "fixtures/response-node-offline-high-cpu.yml", + ) + ).read_text(), match_querystring=False, ) - tool = mock_galaxy.Tool('bwa-mem') - user = mock_galaxy.User('steve', 'steve@unimelb.edu.au') - datasets = [mock_galaxy.DatasetAssociation("input", mock_galaxy.Dataset("input.fastq", - file_size=0.1*1024**3))] - rules_file = os.path.join(os.path.dirname(__file__), 'fixtures/scenario-node-offline-high-cpu.yml') - destination = self._map_to_destination(tool, user, datasets=datasets, tpv_config_path=rules_file, - job_conf='fixtures/job_conf_scenario_usegalaxy_au.yml') + tool = mock_galaxy.Tool("bwa-mem") + user = mock_galaxy.User("steve", "steve@unimelb.edu.au") + datasets = [ + mock_galaxy.DatasetAssociation( + "input", mock_galaxy.Dataset("input.fastq", file_size=0.1 * 1024**3) + ) + ] + rules_file = os.path.join( + os.path.dirname(__file__), "fixtures/scenario-node-offline-high-cpu.yml" + ) + destination = self._map_to_destination( + tool, + user, + datasets=datasets, + tpv_config_path=rules_file, + job_conf="fixtures/job_conf_scenario_usegalaxy_au.yml", + ) self.assertEqual(destination.id, "general_pulsar_2") @responses.activate @@ -103,17 +154,31 @@ def test_scenario_trinity_with_rules(self): method=responses.GET, url="http://stats.genome.edu.au:8086/query", body=pathlib.Path( - os.path.join(os.path.dirname(__file__), 'fixtures/response-trinity-job-with-rules.yml')).read_text(), + os.path.join( + os.path.dirname(__file__), + "fixtures/response-trinity-job-with-rules.yml", + ) + ).read_text(), match_querystring=False, ) - tool = mock_galaxy.Tool('trinity') - user = 
mock_galaxy.User('someone', 'someone@unimelb.edu.au') - datasets = [mock_galaxy.DatasetAssociation("input", mock_galaxy.Dataset("input.fastq", - file_size=0.1*1024**3))] - rules_file = os.path.join(os.path.dirname(__file__), 'fixtures/scenario-trinity-job-with-rules.yml') - destination = self._map_to_destination(tool, user, datasets=datasets, tpv_config_path=rules_file, - job_conf='fixtures/job_conf_scenario_usegalaxy_au.yml') + tool = mock_galaxy.Tool("trinity") + user = mock_galaxy.User("someone", "someone@unimelb.edu.au") + datasets = [ + mock_galaxy.DatasetAssociation( + "input", mock_galaxy.Dataset("input.fastq", file_size=0.1 * 1024**3) + ) + ] + rules_file = os.path.join( + os.path.dirname(__file__), "fixtures/scenario-trinity-job-with-rules.yml" + ) + destination = self._map_to_destination( + tool, + user, + datasets=datasets, + tpv_config_path=rules_file, + job_conf="fixtures/job_conf_scenario_usegalaxy_au.yml", + ) self.assertEqual(destination.id, "highmem_pulsar_2") @responses.activate @@ -125,19 +190,35 @@ def test_scenario_trinity_job_too_much_data(self): method=responses.GET, url="http://stats.genome.edu.au:8086/query", body=pathlib.Path( - os.path.join(os.path.dirname(__file__), 'fixtures/response-trinity-job-with-rules.yml')).read_text(), + os.path.join( + os.path.dirname(__file__), + "fixtures/response-trinity-job-with-rules.yml", + ) + ).read_text(), match_querystring=False, ) - tool = mock_galaxy.Tool('trinity') - user = mock_galaxy.User('someone', 'someone@unimelb.edu.au') - datasets = [mock_galaxy.DatasetAssociation("input", mock_galaxy.Dataset("input.fastq", - file_size=1000*1024**3))] - rules_file = os.path.join(os.path.dirname(__file__), 'fixtures/scenario-trinity-job-too-much-data.yml') - with self.assertRaisesRegex(JobMappingException, - "Input file size of 1000.0GB is > maximum allowed 200GB limit"): - self._map_to_destination(tool, user, datasets=datasets, tpv_config_path=rules_file, - job_conf='fixtures/job_conf_scenario_usegalaxy_au.yml') + tool = mock_galaxy.Tool("trinity") + user = mock_galaxy.User("someone", "someone@unimelb.edu.au") + datasets = [ + mock_galaxy.DatasetAssociation( + "input", mock_galaxy.Dataset("input.fastq", file_size=1000 * 1024**3) + ) + ] + rules_file = os.path.join( + os.path.dirname(__file__), "fixtures/scenario-trinity-job-too-much-data.yml" + ) + with self.assertRaisesRegex( + JobMappingException, + "Input file size of 1000.0GB is > maximum allowed 200GB limit", + ): + self._map_to_destination( + tool, + user, + datasets=datasets, + tpv_config_path=rules_file, + job_conf="fixtures/job_conf_scenario_usegalaxy_au.yml", + ) @responses.activate def test_scenario_non_pulsar_enabled_job(self): @@ -148,17 +229,31 @@ def test_scenario_non_pulsar_enabled_job(self): method=responses.GET, url="http://stats.genome.edu.au:8086/query", body=pathlib.Path( - os.path.join(os.path.dirname(__file__), 'fixtures/response-non-pulsar-enabled-job.yml')).read_text(), + os.path.join( + os.path.dirname(__file__), + "fixtures/response-non-pulsar-enabled-job.yml", + ) + ).read_text(), match_querystring=False, ) - tool = mock_galaxy.Tool('fastp') - user = mock_galaxy.User('kate', 'kate@unimelb.edu.au') - datasets = [mock_galaxy.DatasetAssociation("input", mock_galaxy.Dataset("input.fastq", - file_size=1000*1024**3))] - rules_file = os.path.join(os.path.dirname(__file__), 'fixtures/scenario-non-pulsar-enabled-job.yml') - destination = self._map_to_destination(tool, user, datasets=datasets, tpv_config_path=rules_file, - 
job_conf='fixtures/job_conf_scenario_usegalaxy_au.yml') + tool = mock_galaxy.Tool("fastp") + user = mock_galaxy.User("kate", "kate@unimelb.edu.au") + datasets = [ + mock_galaxy.DatasetAssociation( + "input", mock_galaxy.Dataset("input.fastq", file_size=1000 * 1024**3) + ) + ] + rules_file = os.path.join( + os.path.dirname(__file__), "fixtures/scenario-non-pulsar-enabled-job.yml" + ) + destination = self._map_to_destination( + tool, + user, + datasets=datasets, + tpv_config_path=rules_file, + job_conf="fixtures/job_conf_scenario_usegalaxy_au.yml", + ) self.assertEqual(destination.id, "slurm") @responses.activate @@ -170,17 +265,31 @@ def test_scenario_jenkins_bot_user(self): method=responses.GET, url="http://stats.genome.edu.au:8086/query", body=pathlib.Path( - os.path.join(os.path.dirname(__file__), 'fixtures/response-trinity-job-with-rules.yml')).read_text(), + os.path.join( + os.path.dirname(__file__), + "fixtures/response-trinity-job-with-rules.yml", + ) + ).read_text(), match_querystring=False, ) - tool = mock_galaxy.Tool('fastp') - user = mock_galaxy.User('jenkinsbot', 'jenkinsbot@unimelb.edu.au') - datasets = [mock_galaxy.DatasetAssociation("input", mock_galaxy.Dataset("input.fastq", - file_size=1000*1024**3))] - rules_file = os.path.join(os.path.dirname(__file__), 'fixtures/scenario-jenkins-bot-user.yml') - destination = self._map_to_destination(tool, user, datasets=datasets, tpv_config_path=rules_file, - job_conf='fixtures/job_conf_scenario_usegalaxy_au.yml') + tool = mock_galaxy.Tool("fastp") + user = mock_galaxy.User("jenkinsbot", "jenkinsbot@unimelb.edu.au") + datasets = [ + mock_galaxy.DatasetAssociation( + "input", mock_galaxy.Dataset("input.fastq", file_size=1000 * 1024**3) + ) + ] + rules_file = os.path.join( + os.path.dirname(__file__), "fixtures/scenario-jenkins-bot-user.yml" + ) + destination = self._map_to_destination( + tool, + user, + datasets=datasets, + tpv_config_path=rules_file, + job_conf="fixtures/job_conf_scenario_usegalaxy_au.yml", + ) self.assertEqual(destination.id, "slurm") @responses.activate @@ -193,17 +302,32 @@ def test_scenario_admin_group_user(self): method=responses.GET, url="http://stats.genome.edu.au:8086/query", body=pathlib.Path( - os.path.join(os.path.dirname(__file__), 'fixtures/response-admin-group-user.yml')).read_text(), + os.path.join( + os.path.dirname(__file__), "fixtures/response-admin-group-user.yml" + ) + ).read_text(), match_querystring=False, ) - tool = mock_galaxy.Tool('trinity') - user = mock_galaxy.User('pulsar-hm2-user', 'pulsar-hm2-user@unimelb.edu.au', roles=["ga_admins"]) - datasets = [mock_galaxy.DatasetAssociation("input", mock_galaxy.Dataset("input.fastq", - file_size=1000*1024**3))] - rules_file = os.path.join(os.path.dirname(__file__), 'fixtures/scenario-admin-group-user.yml') - destination = self._map_to_destination(tool, user, datasets=datasets, tpv_config_path=rules_file, - job_conf='fixtures/job_conf_scenario_usegalaxy_au.yml') + tool = mock_galaxy.Tool("trinity") + user = mock_galaxy.User( + "pulsar-hm2-user", "pulsar-hm2-user@unimelb.edu.au", roles=["ga_admins"] + ) + datasets = [ + mock_galaxy.DatasetAssociation( + "input", mock_galaxy.Dataset("input.fastq", file_size=1000 * 1024**3) + ) + ] + rules_file = os.path.join( + os.path.dirname(__file__), "fixtures/scenario-admin-group-user.yml" + ) + destination = self._map_to_destination( + tool, + user, + datasets=datasets, + tpv_config_path=rules_file, + job_conf="fixtures/job_conf_scenario_usegalaxy_au.yml", + ) self.assertEqual(destination.id, "highmem_pulsar_2") 
@responses.activate @@ -215,7 +339,10 @@ def test_scenario_too_many_highmem_jobs(self): method=responses.GET, url="http://stats.genome.edu.au:8086/query", body=pathlib.Path( - os.path.join(os.path.dirname(__file__), 'fixtures/response-admin-group-user.yml')).read_text(), + os.path.join( + os.path.dirname(__file__), "fixtures/response-admin-group-user.yml" + ) + ).read_text(), match_querystring=False, ) @@ -233,45 +360,76 @@ def create_job(app, destination): sa_session.flush() app = mock_galaxy.App( - job_conf=os.path.join(os.path.dirname(__file__), 'fixtures/job_conf_scenario_usegalaxy_au.yml'), - create_model=True) + job_conf=os.path.join( + os.path.dirname(__file__), "fixtures/job_conf_scenario_usegalaxy_au.yml" + ), + create_model=True, + ) create_job(app, "highmem_pulsar_1") create_job(app, "highmem_pulsar_2") create_job(app, "highmem_pulsar_1") create_job(app, "highmem_pulsar_2") - tool = mock_galaxy.Tool('trinity') - user = mock_galaxy.User('highmemuser', 'highmemuser@unimelb.edu.au', roles=["ga_admins"]) - datasets = [mock_galaxy.DatasetAssociation("input", mock_galaxy.Dataset("input.fastq", - file_size=1000*1024**3))] - rules_file = os.path.join(os.path.dirname(__file__), 'fixtures/scenario-too-many-highmem-jobs.yml') - destination = self._map_to_destination(tool, user, datasets=datasets, tpv_config_path=rules_file, app=app) + tool = mock_galaxy.Tool("trinity") + user = mock_galaxy.User( + "highmemuser", "highmemuser@unimelb.edu.au", roles=["ga_admins"] + ) + datasets = [ + mock_galaxy.DatasetAssociation( + "input", mock_galaxy.Dataset("input.fastq", file_size=1000 * 1024**3) + ) + ] + rules_file = os.path.join( + os.path.dirname(__file__), "fixtures/scenario-too-many-highmem-jobs.yml" + ) + destination = self._map_to_destination( + tool, user, datasets=datasets, tpv_config_path=rules_file, app=app + ) self.assertEqual(destination.id, "highmem_pulsar_1") # exceed the limit create_job(app, "highmem_pulsar_1") with self.assertRaisesRegex( - JobMappingException, "You cannot have more than 4 high-mem jobs running concurrently"): - self._map_to_destination(tool, user, datasets=datasets, tpv_config_path=rules_file, app=app) + JobMappingException, + "You cannot have more than 4 high-mem jobs running concurrently", + ): + self._map_to_destination( + tool, user, datasets=datasets, tpv_config_path=rules_file, app=app + ) @responses.activate def test_scenario_usegalaxy_dev(self): """ Check whether usegalaxy.au dev dispatch works """ - tool = mock_galaxy.Tool('upload1') - user = mock_galaxy.User('catherine', 'catherine@unimelb.edu.au') - datasets = [mock_galaxy.DatasetAssociation("input", mock_galaxy.Dataset("input.fastq", - file_size=1000*1024**3))] - rules_file = os.path.join(os.path.dirname(__file__), 'fixtures/scenario-usegalaxy-dev.yml') - destination = self._map_to_destination(tool, user, datasets=datasets, tpv_config_path=rules_file, - job_conf='fixtures/job_conf_scenario_usegalaxy_au.yml') + tool = mock_galaxy.Tool("upload1") + user = mock_galaxy.User("catherine", "catherine@unimelb.edu.au") + datasets = [ + mock_galaxy.DatasetAssociation( + "input", mock_galaxy.Dataset("input.fastq", file_size=1000 * 1024**3) + ) + ] + rules_file = os.path.join( + os.path.dirname(__file__), "fixtures/scenario-usegalaxy-dev.yml" + ) + destination = self._map_to_destination( + tool, + user, + datasets=datasets, + tpv_config_path=rules_file, + job_conf="fixtures/job_conf_scenario_usegalaxy_au.yml", + ) self.assertEqual(destination.id, "slurm") - tool = mock_galaxy.Tool('hifiasm') - destination = 
self._map_to_destination(tool, user, datasets=datasets, tpv_config_path=rules_file, - job_conf='fixtures/job_conf_scenario_usegalaxy_au.yml') + tool = mock_galaxy.Tool("hifiasm") + destination = self._map_to_destination( + tool, + user, + datasets=datasets, + tpv_config_path=rules_file, + job_conf="fixtures/job_conf_scenario_usegalaxy_au.yml", + ) self.assertEqual(destination.id, "pulsar-nci-test") @responses.activate @@ -280,10 +438,21 @@ def test_scenario_usegalaxy_eu_training(self): Check whether training groups are correctly generated. Specifically, this test checks whether execute blocks can modify entity properties. """ - tool = mock_galaxy.Tool('bwa') - user = mock_galaxy.User('gargravarr', 'fairycake@vortex.org', roles=["training-group1", "training-group2"]) + tool = mock_galaxy.Tool("bwa") + user = mock_galaxy.User( + "gargravarr", + "fairycake@vortex.org", + roles=["training-group1", "training-group2"], + ) - rules_file = os.path.join(os.path.dirname(__file__), 'fixtures/scenario-usegalaxy-eu-training.yml') + rules_file = os.path.join( + os.path.dirname(__file__), "fixtures/scenario-usegalaxy-eu-training.yml" + ) destination = self._map_to_destination(tool, user, tpv_config_path=rules_file) - self.assertEqual(destination.params['requirements'], '(GalaxyGroup == "compute") || ((GalaxyGroup == "training-group1") || (GalaxyGroup == "training-group2"))') - self.assertEqual(destination.params['+Group'], '"training-group1, training-group2"') + self.assertEqual( + destination.params["requirements"], + '(GalaxyGroup == "compute") || ((GalaxyGroup == "training-group1") || (GalaxyGroup == "training-group2"))', + ) + self.assertEqual( + destination.params["+Group"], '"training-group1, training-group2"' + ) diff --git a/tests/test_shell.py b/tests/test_shell.py index 21a4168..c14b468 100644 --- a/tests/test_shell.py +++ b/tests/test_shell.py @@ -1,9 +1,9 @@ -from collections import OrderedDict import contextlib import io import os import sys import unittest +from collections import OrderedDict import pytest import yaml @@ -40,137 +40,197 @@ def call_shell_command(*args): return run_python_script(main, list(args)) def test_lint_no_errors_non_verbose(self): - tpv_config = os.path.join(os.path.dirname(__file__), 'fixtures/scenario-usegalaxy-dev.yml') + tpv_config = os.path.join( + os.path.dirname(__file__), "fixtures/scenario-usegalaxy-dev.yml" + ) output = self.call_shell_command("tpv", "lint", tpv_config) self.assertTrue( "lint successful" in output, - f"Expected lint to be successful but output was: {output}") + f"Expected lint to be successful but output was: {output}", + ) def test_lint_no_errors_verbose(self): - tpv_config = os.path.join(os.path.dirname(__file__), 'fixtures/scenario-usegalaxy-dev.yml') + tpv_config = os.path.join( + os.path.dirname(__file__), "fixtures/scenario-usegalaxy-dev.yml" + ) output = self.call_shell_command("tpv", "-vvvv", "lint", tpv_config) self.assertTrue( "lint successful" in output, - f"Expected lint to be successful but output was: {output}") + f"Expected lint to be successful but output was: {output}", + ) def test_lint_syntax_error(self): - tpv_config = os.path.join(os.path.dirname(__file__), 'fixtures/mapping-syntax-error.yml') + tpv_config = os.path.join( + os.path.dirname(__file__), "fixtures/mapping-syntax-error.yml" + ) output = self.call_shell_command("tpv", "lint", tpv_config) self.assertTrue( - "lint failed" in output, - f"Expected lint to fail but output was: {output}") + "lint failed" in output, f"Expected lint to fail but output was: {output}" 
+ ) self.assertTrue( - "oops syntax!" in output, - f"Expected lint to fail but output was: {output}") + "oops syntax!" in output, f"Expected lint to fail but output was: {output}" + ) def test_lint_invalid_regex(self): - tpv_config = os.path.join(os.path.dirname(__file__), 'fixtures/linter/linter-invalid-regex.yml') + tpv_config = os.path.join( + os.path.dirname(__file__), "fixtures/linter/linter-invalid-regex.yml" + ) output = self.call_shell_command("tpv", "lint", tpv_config) self.assertTrue( - "lint failed" in output, - f"Expected lint to fail but output was: {output}") + "lint failed" in output, f"Expected lint to fail but output was: {output}" + ) self.assertTrue( "Failed to compile regex: bwa" in output, - f"Expected lint to fail but output was: {output}") + f"Expected lint to fail but output was: {output}", + ) def test_lint_no_runner_defined(self): - tpv_config = os.path.join(os.path.dirname(__file__), 'fixtures/linter/linter-no-runner-defined.yml') + tpv_config = os.path.join( + os.path.dirname(__file__), "fixtures/linter/linter-no-runner-defined.yml" + ) output = self.call_shell_command("tpv", "-vv", "lint", tpv_config) self.assertTrue( - "lint failed" in output, - f"Expected lint to fail but output was: {output}") + "lint failed" in output, f"Expected lint to fail but output was: {output}" + ) self.assertTrue( "Destination 'local'" in output, - f"Expected absence of runner param to be reported for destination local but output was: {output}") + f"Expected absence of runner param to be reported for destination local but output was: {output}", + ) self.assertTrue( "Destination 'another_env_without_runner'" in output, "Expected absence of runner param to be reported for destination another_env_without_runner " - f"but output was: {output}") + f"but output was: {output}", + ) self.assertTrue( "Destination 'k8s_environment'" not in output, - f"Did not expect 'k8s_environment' to be reported as it defines the runner param but output was: {output}") + f"Did not expect 'k8s_environment' to be reported as it defines the runner param but output was: {output}", + ) def test_lint_destination_defines_cores_instead_of_accepted_cores(self): - tpv_config = os.path.join(os.path.dirname(__file__), 'fixtures/linter/linter-legacy-destinations.yml') + tpv_config = os.path.join( + os.path.dirname(__file__), "fixtures/linter/linter-legacy-destinations.yml" + ) output = self.call_shell_command("tpv", "-vv", "lint", tpv_config) self.assertTrue( - "lint failed" in output, - f"Expected lint to fail but output was: {output}") + "lint failed" in output, f"Expected lint to fail but output was: {output}" + ) self.assertTrue( "The destination named: local_with_mem" in output, - f"Expected an errors when cores, mem or gpu are defined on a destination but output was: {output}") + f"Expected an errors when cores, mem or gpu are defined on a destination but output was: {output}", + ) self.assertTrue( "The destination named: k8s_environment_with_cores" in output, - f"Expected an errors when cores, mem or gpu are defined on a destination but output was: {output}") + f"Expected an errors when cores, mem or gpu are defined on a destination but output was: {output}", + ) self.assertTrue( "The destination named: another_env_with_gpus" in output, - f"Expected an errors when cores, mem or gpu are defined on a destination but output was: {output}") + f"Expected an errors when cores, mem or gpu are defined on a destination but output was: {output}", + ) self.assertTrue( "working_dest" not in output, - f"Did not expect 
destination: `working_dest` to be in the output, but found: {output}")
+            f"Did not expect destination: `working_dest` to be in the output, but found: {output}",
+        )

     def test_warn_if_default_inherits_not_marked_abstract(self):
-        tpv_config = os.path.join(os.path.dirname(__file__),
-                                  'fixtures/linter/linter-default-inherits-marked-abstract.yml')
+        tpv_config = os.path.join(
+            os.path.dirname(__file__),
+            "fixtures/linter/linter-default-inherits-marked-abstract.yml",
+        )
         output = self.call_shell_command("tpv", "-vvvv", "lint", tpv_config)
         self.assertTrue(
-            "WARNING" in output and "The tool named: default is marked globally as" in output,
+            "WARNING" in output
+            and "The tool named: default is marked globally as" in output,
             f"Expected a warning when the default abstract class for a tool is not marked abstract but output "
-            f"was: {output}")
+            f"was: {output}",
+        )
         self.assertTrue(
-            "WARNING" in output and "The destination named: default is marked globally as" in output,
+            "WARNING" in output
+            and "The destination named: default is marked globally as" in output,
             f"Expected a warning when the default abstract class for a tool is not marked abstract but output "
-            f"was: {output}")
+            f"was: {output}",
+        )

     def test_format_basic(self):
-        tpv_config = os.path.join(os.path.dirname(__file__), 'fixtures/formatter/formatter-basic.yml')
+        tpv_config = os.path.join(
+            os.path.dirname(__file__), "fixtures/formatter/formatter-basic.yml"
+        )
         with open(tpv_config) as f:
             before_formatting = yaml.safe_load(f)
         output = self.call_shell_command("tpv", "format", tpv_config)
         after_formatting = yaml.safe_load(output)
-        self.assertTrue(before_formatting == after_formatting,
-                        "Expected content to be the same after formatting")
-        self.assertTrue(OrderedDict(before_formatting) != OrderedDict(after_formatting),
-                        "Expected ordering to be different after formatting")
+        self.assertTrue(
+            before_formatting == after_formatting,
+            "Expected content to be the same after formatting",
+        )
+        self.assertTrue(
+            OrderedDict(before_formatting) != OrderedDict(after_formatting),
+            "Expected ordering to be different after formatting",
+        )
         # keys should be in expected order
-        self.assertEqual(list(before_formatting.keys()), ["global", "destinations", "users", "tools"])
-        self.assertEqual(list(after_formatting.keys()), ["global", "tools", "users", "destinations"])
+        self.assertEqual(
+            list(before_formatting.keys()), ["global", "destinations", "users", "tools"]
+        )
+        self.assertEqual(
+            list(after_formatting.keys()), ["global", "tools", "users", "destinations"]
+        )
         # default inherits should be first
-        self.assertEqual(list(before_formatting['tools']).index('base_default'), 3)
-        self.assertEqual(list(after_formatting['tools']).index('base_default'), 0)
-        self.assertEqual(list(before_formatting['destinations']).index('base_default'), 2)
-        self.assertEqual(list(after_formatting['destinations']).index('base_default'), 0)
+        self.assertEqual(list(before_formatting["tools"]).index("base_default"), 3)
+        self.assertEqual(list(after_formatting["tools"]).index("base_default"), 0)
+        self.assertEqual(
+            list(before_formatting["destinations"]).index("base_default"), 2
+        )
+        self.assertEqual(
+            list(after_formatting["destinations"]).index("base_default"), 0
+        )
         # scheduling tags should be in expected order
-        self.assertEqual(list(before_formatting['tools']['base_default']['scheduling'].keys()),
-                         ["prefer", "accept", "reject", "require"])
-        self.assertEqual(list(after_formatting['tools']['base_default']['scheduling'].keys()),
-                         ["require", "prefer", "accept", "reject"])
+        self.assertEqual(
+            list(before_formatting["tools"]["base_default"]["scheduling"].keys()),
+            ["prefer", "accept", "reject", "require"],
+        )
+        self.assertEqual(
+            list(after_formatting["tools"]["base_default"]["scheduling"].keys()),
+            ["require", "prefer", "accept", "reject"],
+        )
         # context var order should not be changed
-        self.assertEqual(list(before_formatting['tools']['base_default']['context'].keys()),
-                         ['my_context_var1', 'another_context_var2'])
-        self.assertEqual(list(before_formatting['tools']['base_default']['context'].keys()),
-                         list(after_formatting['tools']['base_default']['context'].keys()))
+        self.assertEqual(
+            list(before_formatting["tools"]["base_default"]["context"].keys()),
+            ["my_context_var1", "another_context_var2"],
+        )
+        self.assertEqual(
+            list(before_formatting["tools"]["base_default"]["context"].keys()),
+            list(after_formatting["tools"]["base_default"]["context"].keys()),
+        )
         # params should be in alphabetical order
-        self.assertEqual(list(before_formatting['tools']['base_default']['params'].keys()),
-                         ['nativeSpecification', 'anotherParam'])
-        self.assertEqual(list(after_formatting['tools']['base_default']['params'].keys()),
-                         ['anotherParam', 'nativeSpecification'])
+        self.assertEqual(
+            list(before_formatting["tools"]["base_default"]["params"].keys()),
+            ["nativeSpecification", "anotherParam"],
+        )
+        self.assertEqual(
+            list(after_formatting["tools"]["base_default"]["params"].keys()),
+            ["anotherParam", "nativeSpecification"],
+        )
         # env should be in alphabetical order
-        self.assertEqual(list(before_formatting['tools']['base_default']['env'].keys()),
-                         ['some_env', 'another_env'])
-        self.assertEqual(list(after_formatting['tools']['base_default']['env'].keys()),
-                         ['another_env', 'some_env'])
+        self.assertEqual(
+            list(before_formatting["tools"]["base_default"]["env"].keys()),
+            ["some_env", "another_env"],
+        )
+        self.assertEqual(
+            list(after_formatting["tools"]["base_default"]["env"].keys()),
+            ["another_env", "some_env"],
+        )

     def test_format_rules(self):
-        tpv_config = os.path.join(os.path.dirname(__file__), 'fixtures/formatter/formatter-basic.yml')
+        tpv_config = os.path.join(
+            os.path.dirname(__file__), "fixtures/formatter/formatter-basic.yml"
+        )
         with open(tpv_config) as f:
             before_formatting = yaml.safe_load(f)
@@ -178,135 +238,286 @@ def test_format_rules(self):
         after_formatting = yaml.safe_load(output)
         # rules should remain in original order
-        self.assertEqual(before_formatting['tools']['.*hifiasm.*']['rules'][0]['id'], "my_rule_2")
-        self.assertEqual(after_formatting['tools']['.*hifiasm.*']['rules'][0]['id'], "my_rule_2")
-        self.assertEqual(before_formatting['tools']['.*hifiasm.*']['rules'][1]['id'], "my_rule_1")
-        self.assertEqual(after_formatting['tools']['.*hifiasm.*']['rules'][1]['id'], "my_rule_1")
+        self.assertEqual(
+            before_formatting["tools"][".*hifiasm.*"]["rules"][0]["id"], "my_rule_2"
+        )
+        self.assertEqual(
+            after_formatting["tools"][".*hifiasm.*"]["rules"][0]["id"], "my_rule_2"
+        )
+        self.assertEqual(
+            before_formatting["tools"][".*hifiasm.*"]["rules"][1]["id"], "my_rule_1"
+        )
+        self.assertEqual(
+            after_formatting["tools"][".*hifiasm.*"]["rules"][1]["id"], "my_rule_1"
+        )
         # rule elements should be in expected order
-        self.assertEqual(list(before_formatting['tools']['.*hifiasm.*']['rules'][0].keys()),
-                         ['mem', 'if', 'id', 'cores', 'context', 'params', 'env', 'scheduling'])
-        self.assertEqual(list(after_formatting['tools']['.*hifiasm.*']['rules'][0].keys()),
-                         ['id', 'if', 'context', 'cores', 'mem', 'env', 'params', 'scheduling'])
+        self.assertEqual(
+            list(before_formatting["tools"][".*hifiasm.*"]["rules"][0].keys()),
+            ["mem", "if", "id", "cores", "context", "params", "env", "scheduling"],
+        )
+        self.assertEqual(
+            list(after_formatting["tools"][".*hifiasm.*"]["rules"][0].keys()),
+            ["id", "if", "context", "cores", "mem", "env", "params", "scheduling"],
+        )
         # scheduling tags within rules should be in expected order
-        self.assertEqual(list(before_formatting['tools']['.*hifiasm.*']['rules'][0]['scheduling'].keys()),
-                         ['accept', 'prefer', 'reject', 'require'])
-        self.assertEqual(list(after_formatting['tools']['.*hifiasm.*']['rules'][0]['scheduling'].keys()),
-                         ["require", "prefer", "accept", "reject"])
+        self.assertEqual(
+            list(
+                before_formatting["tools"][".*hifiasm.*"]["rules"][0][
+                    "scheduling"
+                ].keys()
+            ),
+            ["accept", "prefer", "reject", "require"],
+        )
+        self.assertEqual(
+            list(
+                after_formatting["tools"][".*hifiasm.*"]["rules"][0][
+                    "scheduling"
+                ].keys()
+            ),
+            ["require", "prefer", "accept", "reject"],
+        )
         # context var order should not be changed
-        self.assertEqual(list(before_formatting['tools']['.*hifiasm.*']['rules'][0]['context'].keys()),
-                         ['myvar', 'anothervar'])
-        self.assertEqual(list(before_formatting['tools']['.*hifiasm.*']['rules'][0]['context'].keys()),
-                         list(after_formatting['tools']['.*hifiasm.*']['rules'][0]['context'].keys()))
+        self.assertEqual(
+            list(
+                before_formatting["tools"][".*hifiasm.*"]["rules"][0]["context"].keys()
+            ),
+            ["myvar", "anothervar"],
+        )
+        self.assertEqual(
+            list(
+                before_formatting["tools"][".*hifiasm.*"]["rules"][0]["context"].keys()
+            ),
+            list(
+                after_formatting["tools"][".*hifiasm.*"]["rules"][0]["context"].keys()
+            ),
+        )
         # params order should not be changed
-        self.assertEqual(list(before_formatting['tools']['.*hifiasm.*']['rules'][0]['params'].keys()),
-                         ['MY_PARAM2', 'MY_PARAM1'])
-        self.assertEqual(list(after_formatting['tools']['.*hifiasm.*']['rules'][0]['params'].keys()),
-                         ['MY_PARAM1', 'MY_PARAM2'])
+        self.assertEqual(
+            list(
+                before_formatting["tools"][".*hifiasm.*"]["rules"][0]["params"].keys()
+            ),
+            ["MY_PARAM2", "MY_PARAM1"],
+        )
+        self.assertEqual(
+            list(after_formatting["tools"][".*hifiasm.*"]["rules"][0]["params"].keys()),
+            ["MY_PARAM1", "MY_PARAM2"],
+        )
         # env order should not be changed
-        self.assertEqual(list(before_formatting['tools']['.*hifiasm.*']['rules'][0]['env'].keys()),
-                         ['SOME_ENV2', 'SOME_ENV1'])
-        self.assertEqual(list(after_formatting['tools']['.*hifiasm.*']['rules'][0]['env'].keys()),
-                         ['SOME_ENV1', 'SOME_ENV2'])
+        self.assertEqual(
+            list(before_formatting["tools"][".*hifiasm.*"]["rules"][0]["env"].keys()),
+            ["SOME_ENV2", "SOME_ENV1"],
+        )
+        self.assertEqual(
+            list(after_formatting["tools"][".*hifiasm.*"]["rules"][0]["env"].keys()),
+            ["SOME_ENV1", "SOME_ENV2"],
+        )

     def test_format_error(self):
-        tpv_config = os.path.join(os.path.dirname(__file__), 'fixtures/file-does-not-exist.yml')
+        tpv_config = os.path.join(
+            os.path.dirname(__file__), "fixtures/file-does-not-exist.yml"
+        )
         output = self.call_shell_command("tpv", "format", tpv_config)
         self.assertTrue(
             "format failed" in output,
-            f"Expected format to fail but output was: {output}")
+            f"Expected format to fail but output was: {output}",
+        )

     def test_format_string_block_handling(self):
-        tpv_config = os.path.join(os.path.dirname(__file__), 'fixtures/formatter/formatter-string-types-input.yml')
+        tpv_config = os.path.join(
+            os.path.dirname(__file__),
+            "fixtures/formatter/formatter-string-types-input.yml",
+        )
         output = self.call_shell_command("tpv", "format", tpv_config)
-        with open(os.path.join(os.path.dirname(__file__),
-                               'fixtures/formatter/formatter-string-types-formatted.yml')) as f:
+        with open(
+            os.path.join(
+                os.path.dirname(__file__),
+                "fixtures/formatter/formatter-string-types-formatted.yml",
+            )
+        ) as f:
             expected_output = f.read()
         self.assertEqual(output, expected_output)

     def test_format_lengthy_key_handling(self):
-        tpv_config = os.path.join(os.path.dirname(__file__), 'fixtures/formatter/formatter-long-key-input.yml')
+        tpv_config = os.path.join(
+            os.path.dirname(__file__), "fixtures/formatter/formatter-long-key-input.yml"
+        )
         output = self.call_shell_command("tpv", "format", tpv_config)
-        with open(os.path.join(os.path.dirname(__file__), 'fixtures/formatter/formatter-long-key-formatted.yml')) as f:
+        with open(
+            os.path.join(
+                os.path.dirname(__file__),
+                "fixtures/formatter/formatter-long-key-formatted.yml",
+            )
+        ) as f:
            expected_output = f.read()
         self.assertEqual(output, expected_output)

     def test_format_tool_sort_order(self):
-        tpv_config = os.path.join(os.path.dirname(__file__), 'fixtures/formatter/formatter-tool-sort-order-input.yml')
+        tpv_config = os.path.join(
+            os.path.dirname(__file__),
+            "fixtures/formatter/formatter-tool-sort-order-input.yml",
+        )
         output = self.call_shell_command("tpv", "format", tpv_config)
-        with open(os.path.join(os.path.dirname(__file__),
-                               'fixtures/formatter/formatter-tool-sort-order-formatted.yml')) as f:
+        with open(
+            os.path.join(
+                os.path.dirname(__file__),
+                "fixtures/formatter/formatter-tool-sort-order-formatted.yml",
+            )
+        ) as f:
             expected_output = f.read()
         self.assertEqual(output, expected_output)

     @pytest.mark.usefixtures("chdir_tests")
     def test_dry_run_tpv_config_from_job_conf_default_tool(self):
-        job_config = 'fixtures/job_conf_dry_run.yml'
+        job_config = "fixtures/job_conf_dry_run.yml"
         output = self.call_shell_command("tpv", "dry-run", "--job-conf", job_config)
-        self.assertTrue("id: local" in output,
-                        f"Expected 'id: local' destination\n{output}")
+        self.assertTrue(
+            "id: local" in output, f"Expected 'id: local' destination\n{output}"
+        )

     @pytest.mark.usefixtures("chdir_tests")
     def test_dry_run_tpv_config_from_job_conf_pulsar_tool(self):
-        job_config = 'fixtures/job_conf_dry_run.yml'
-        output = self.call_shell_command("tpv", "dry-run", "--job-conf", job_config, "--tool", "bwa")
-        self.assertTrue("id: k8s_environment" in output,
-                        f"Expected 'id: k8s_environment' destination\n{output}")
+        job_config = "fixtures/job_conf_dry_run.yml"
+        output = self.call_shell_command(
+            "tpv", "dry-run", "--job-conf", job_config, "--tool", "bwa"
+        )
+        self.assertTrue(
+            "id: k8s_environment" in output,
+            f"Expected 'id: k8s_environment' destination\n{output}",
+        )

     @pytest.mark.usefixtures("chdir_tests")
     def test_dry_run_tpv_config_from_job_conf_unschedulable_tool(self):
-        job_config = 'fixtures/job_conf_dry_run.yml'
+        job_config = "fixtures/job_conf_dry_run.yml"
         with self.assertRaises(JobMappingException):
-            self.call_shell_command("tpv", "dry-run", "--job-conf", job_config, "--tool", "unschedulable_tool")
+            self.call_shell_command(
+                "tpv",
+                "dry-run",
+                "--job-conf",
+                job_config,
+                "--tool",
+                "unschedulable_tool",
+            )

     @pytest.mark.usefixtures("chdir_tests")
     def test_dry_run_tpv_config_from_job_conf_regex_tool(self):
-        job_config = 'fixtures/job_conf_dry_run.yml'
-        output = self.call_shell_command("tpv", "dry-run", "--job-conf", job_config, "--tool", "regex_tool/hoopy_frood")
-        self.assertTrue("id: k8s_environment" in output,
-                        f"Expected 'id: k8s_environment' destination\n{output}")
+        job_config = "fixtures/job_conf_dry_run.yml"
+        output = self.call_shell_command(
+            "tpv",
+            "dry-run",
+            "--job-conf",
+            job_config,
+            "--tool",
+            "regex_tool/hoopy_frood",
+        )
+        self.assertTrue(
+            "id: k8s_environment" in output,
+            f"Expected 'id: k8s_environment' destination\n{output}",
+        )

     def test_dry_run_input_size_piddling(self):
-        job_config = os.path.join(os.path.dirname(__file__), 'fixtures/job_conf_dry_run.yml')
-        tpv_config = os.path.join(os.path.dirname(__file__), 'fixtures/mapping-rules.yml')
+        job_config = os.path.join(
+            os.path.dirname(__file__), "fixtures/job_conf_dry_run.yml"
+        )
+        tpv_config = os.path.join(
+            os.path.dirname(__file__), "fixtures/mapping-rules.yml"
+        )
         with self.assertRaises(JobMappingException):
-            self.call_shell_command("tpv", "dry-run", "--job-conf", job_config, tpv_config)
+            self.call_shell_command(
+                "tpv", "dry-run", "--job-conf", job_config, tpv_config
+            )

     def test_dry_run_conditional_input_size_ok(self):
-        job_config = os.path.join(os.path.dirname(__file__), 'fixtures/job_conf_dry_run.yml')
-        tpv_config = os.path.join(os.path.dirname(__file__), 'fixtures/mapping-rules.yml')
+        job_config = os.path.join(
+            os.path.dirname(__file__), "fixtures/job_conf_dry_run.yml"
+        )
+        tpv_config = os.path.join(
+            os.path.dirname(__file__), "fixtures/mapping-rules.yml"
+        )
         output = self.call_shell_command(
-            "tpv", "dry-run", "--job-conf", job_config, "--tool", "bwa", "--input-size", "6", tpv_config)
-        self.assertTrue("id: k8s_environment" in output,
-                        f"Expected 'id: k8s_environment' destination\n{output}")
+            "tpv",
+            "dry-run",
+            "--job-conf",
+            job_config,
+            "--tool",
+            "bwa",
+            "--input-size",
+            "6",
+            tpv_config,
+        )
+        self.assertTrue(
+            "id: k8s_environment" in output,
+            f"Expected 'id: k8s_environment' destination\n{output}",
+        )

     def test_dry_run_conditional_input_size_too_big(self):
-        job_config = os.path.join(os.path.dirname(__file__), 'fixtures/job_conf_dry_run.yml')
-        tpv_config = os.path.join(os.path.dirname(__file__), 'fixtures/mapping-rules.yml')
+        job_config = os.path.join(
+            os.path.dirname(__file__), "fixtures/job_conf_dry_run.yml"
+        )
+        tpv_config = os.path.join(
+            os.path.dirname(__file__), "fixtures/mapping-rules.yml"
+        )
         with self.assertRaises(JobMappingException):
             self.call_shell_command(
-                "tpv", "dry-run", "--job-conf", job_config, "--tool", "bwa", "--input-size", "20", tpv_config)
+                "tpv",
+                "dry-run",
+                "--job-conf",
+                job_config,
+                "--tool",
+                "bwa",
+                "--input-size",
+                "20",
+                tpv_config,
+            )

     def test_dry_run_user_email(self):
-        job_config = os.path.join(os.path.dirname(__file__), 'fixtures/job_conf_dry_run.yml')
-        tpv_config = os.path.join(os.path.dirname(__file__), 'fixtures/mapping-rules.yml')
+        job_config = os.path.join(
+            os.path.dirname(__file__), "fixtures/job_conf_dry_run.yml"
+        )
+        tpv_config = os.path.join(
+            os.path.dirname(__file__), "fixtures/mapping-rules.yml"
+        )
         output = self.call_shell_command(
-            "tpv", "dry-run", "--job-conf", job_config, "--input-size", "6", "--user", "krikkitrobot@planetkrikkit.org",
-            tpv_config)
-        self.assertTrue("name: TEST_JOB_SLOTS_USER" in output,
-                        f"Expected 'name: TEST_JOB_SLOTS_USER' in destination\n{output}")
-
+            "tpv",
+            "dry-run",
+            "--job-conf",
+            job_config,
+            "--input-size",
+            "6",
+            "--user",
+            "krikkitrobot@planetkrikkit.org",
+            tpv_config,
+        )
+        self.assertTrue(
+            "name: TEST_JOB_SLOTS_USER" in output,
+            f"Expected 'name: TEST_JOB_SLOTS_USER' in destination\n{output}",
+        )
+
     def test_dry_run_tool_with_version(self):
-        job_config = os.path.join(os.path.dirname(__file__), 'fixtures/job_conf_dry_run.yml')
-        tpv_config = os.path.join(os.path.dirname(__file__), 'fixtures/mapping-rules.yml')
+        job_config = os.path.join(
+            os.path.dirname(__file__), "fixtures/job_conf_dry_run.yml"
+        )
+        tpv_config = os.path.join(
+            os.path.dirname(__file__), "fixtures/mapping-rules.yml"
+        )
         output = self.call_shell_command(
-            "tpv", "dry-run", "--job-conf", job_config, "--input-size", "6", "--user", "krikkitrobot@planetkrikkit.org",
-            "--tool", "toolshed.g2.bx.psu.edu/repos/iuc/bwameth/bwameth/42",
-            tpv_config)
-        self.assertTrue("bwameth_is_great" in output,
-                        f"Expected 'bwameth_is_great' in destination\n{output}")
-
+            "tpv",
+            "dry-run",
+            "--job-conf",
+            job_config,
+            "--input-size",
+            "6",
+            "--user",
+            "krikkitrobot@planetkrikkit.org",
+            "--tool",
+            "toolshed.g2.bx.psu.edu/repos/iuc/bwameth/bwameth/42",
+            tpv_config,
+        )
+        self.assertTrue(
+            "bwameth_is_great" in output,
+            f"Expected 'bwameth_is_great' in destination\n{output}",
+        )
diff --git a/tox.ini b/tox.ini
index 14d82af..f69fd1d 100644
--- a/tox.ini
+++ b/tox.ini
@@ -34,3 +34,6 @@ deps =
 commands =
     isort tpv
    black tpv
+deps =
+    isort
+    black

From 39acaf86658f338f5fc06c8a97cceb4b1327adc6 Mon Sep 17 00:00:00 2001
From: nuwang <2070605+nuwang@users.noreply.github.com>
Date: Sun, 25 Aug 2024 19:13:53 +0530
Subject: [PATCH 30/30] Disable slow tests

---
 .github/workflows/tests.yaml | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/.github/workflows/tests.yaml b/.github/workflows/tests.yaml
index 7091237..57702d8 100644
--- a/.github/workflows/tests.yaml
+++ b/.github/workflows/tests.yaml
@@ -81,7 +81,8 @@ jobs:
         run: pip install tox

       - name: Run tox
-        run: tox -e py${{ matrix.python-version }} -- --runslow
+        # run: tox -e py${{ matrix.python-version }} -- --runslow
+        run: tox -e py${{ matrix.python-version }}
         env:
           PYTHONUNBUFFERED: "True"
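
Note on the `--runslow` flag dropped from the CI invocation above: it is not a built-in tox or pytest option but a custom pytest command-line flag. The sketch below shows how such a flag is commonly registered in a conftest.py, following the standard pattern from the pytest documentation; it assumes TPV wires the flag up this way and that the tox `commands` pass `{posargs}` through to pytest, so treat it as illustrative rather than the repository's actual hook code.

    # conftest.py -- illustrative sketch only; the real project may differ
    import pytest


    def pytest_addoption(parser):
        # Register the custom --runslow command-line flag.
        parser.addoption(
            "--runslow", action="store_true", default=False, help="run slow tests"
        )


    def pytest_configure(config):
        # Declare the "slow" marker so pytest does not warn about unknown markers.
        config.addinivalue_line("markers", "slow: mark test as slow to run")


    def pytest_collection_modifyitems(config, items):
        # Without --runslow, skip every collected test carrying the "slow" marker.
        if config.getoption("--runslow"):
            return
        skip_slow = pytest.mark.skip(reason="need --runslow option to run")
        for item in items:
            if "slow" in item.keywords:
                item.add_marker(skip_slow)

With a hook like this in place, `tox -e py3.12 -- --runslow` would forward the flag to pytest and include the slow tests, while the plain `tox -e py3.12` invocation now used in the workflow skips them.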