diff --git a/prp/cli.py b/prp/cli.py index 9ef9586..41acd8b 100644 --- a/prp/cli.py +++ b/prp/cli.py @@ -7,7 +7,7 @@ from pydantic import ValidationError from .models.metadata import SoupVersion, SoupType -from .models.phenotype import ElementType +from .models.phenotype import ElementType, ElementStressSubtype from .models.qc import QcMethodIndex from .models.sample import MethodIndex, PipelineResult from .models.typing import TypingMethod @@ -63,12 +63,18 @@ def cli(): "-k", "--kraken", type=click.File(), help="Kraken species annotation results" ) @click.option( - "-a", "--amrfinder", type=str, help="amrfinderplus anti-microbial resistance results" + "-a", + "--amrfinder", + type=str, + help="amrfinderplus anti-microbial resistance results", ) @click.option("-m", "--mlst", type=click.File(), help="MLST prediction results") @click.option("-c", "--cgmlst", type=click.File(), help="cgMLST prediction results") @click.option( - "-v", "--virulence", type=click.File(), help="Virulence factor prediction results" + "-v", + "--virulencefinder", + type=click.File(), + help="Virulence factor prediction results", ) @click.option( "-r", @@ -89,7 +95,7 @@ def create_output( kraken, mlst, cgmlst, - virulence, + virulencefinder, amrfinder, resfinder, quality, @@ -102,7 +108,7 @@ def create_output( LOG.info("Start generating pipeline result json") results = { "run_metadata": { - "run": parse_run_info(run_metadata), + "run": parse_run_info(run_metadata), "databases": get_database_info(process_metadata), }, "qc": [], @@ -133,7 +139,10 @@ def create_output( if resfinder: LOG.info("Parse resistance results") pred_res = json.load(resfinder) - methods = [ElementType.AMR, ElementType.BIOCIDE, ElementType.HEAT] + methods = [ + ElementType.AMR, + ElementType.STRESS, + ] for method in methods: res: MethodIndex = parse_resfinder_amr_pred(pred_res, method) # exclude empty results from output @@ -145,9 +154,7 @@ def create_output( LOG.info("Parse amr results") methods = [ ElementType.AMR, - 
ElementType.BIOCIDE, - ElementType.METAL, - ElementType.HEAT, + ElementType.STRESS, ] for method in methods: res: MethodIndex = parse_amrfinder_amr_pred(amrfinder, method) @@ -156,10 +163,11 @@ def create_output( results["element_type_result"].append(vir) # get virulence factors in sample - if virulence: - LOG.info("Parse virulence results") - vir: MethodIndex = parse_virulencefinder_vir_pred(virulence) - results["element_type_result"].append(vir) + if virulencefinder: + LOG.info("Parse virulencefinder results") + vir: MethodIndex | None = parse_virulencefinder_vir_pred(virulencefinder) + if vir is not None: + results["element_type_result"].append(vir) # species id if kraken: @@ -175,7 +183,7 @@ def create_output( pred_res = json.load(mykrobe) results["run_metadata"]["databases"].append( SoupVersion( - name="mykrobe-predictor", + name="mykrobe-predictor", version=pred_res[sample_id]["version"]["mykrobe-predictor"], type=SoupType.DB, ) @@ -211,9 +219,8 @@ def create_output( try: output_data = PipelineResult( - sample_id=sample_id, - schema_version=OUTPUT_SCHEMA_VERSION, - **results) + sample_id=sample_id, schema_version=OUTPUT_SCHEMA_VERSION, **results + ) except ValidationError as err: click.secho("Input failed Validation", fg="red") click.secho(err) diff --git a/prp/models/base.py b/prp/models/base.py index 3a90570..9466ecd 100644 --- a/prp/models/base.py +++ b/prp/models/base.py @@ -6,7 +6,7 @@ class RWModel(BaseModel): # pylint: disable=too-few-public-methods """Base model for read/ write operations""" model_config = ConfigDict( - allow_population_by_alias = True, - populate_by_name = True, - use_enum_values = True, + allow_population_by_alias=True, + populate_by_name=True, + use_enum_values=True, ) diff --git a/prp/models/metadata.py b/prp/models/metadata.py index 325b8c1..6bb3271 100644 --- a/prp/models/metadata.py +++ b/prp/models/metadata.py @@ -29,8 +29,14 @@ class RunInformation(RWModel): pipeline: str version: str commit: str - analysis_profile: str = 
Field(..., alias="analysisProfile") - configuration_files: List[str] = Field(..., alias="configurationFiles") + analysis_profile: str = Field( + ..., + alias="analysisProfile", + description="The analysis profile used when starting the pipeline", + ) + configuration_files: List[str] = Field( + ..., alias="configurationFiles", description="Nextflow configuration used" + ) workflow_name: str sample_name: str sequencing_platform: str diff --git a/prp/models/phenotype.py b/prp/models/phenotype.py index 3695b14..b6ba400 100644 --- a/prp/models/phenotype.py +++ b/prp/models/phenotype.py @@ -2,11 +2,18 @@ from enum import Enum from typing import Dict, List, Optional, Union -from pydantic import BaseModel +from pydantic import BaseModel, Field from .base import RWModel +class SequenceStand(Enum): + """Definition of DNA strand.""" + + FORWARD = "+" + REVERSE = "-" + + class PredictionSoftware(Enum): """Container for prediciton software names.""" @@ -29,89 +36,130 @@ class ElementType(Enum): """Categories of resistance and virulence genes.""" AMR = "AMR" - ACID = "STRESS_ACID" - BIOCIDE = "STRESS_BIOCIDE" - METAL = "STRESS_METAL" - HEAT = "STRESS_HEAT" + STRESS = "STRESS" VIR = "VIRULENCE" +class ElementStressSubtype(Enum): + """Subcategories of stress resistance elements.""" + + ACID = "ACID" + BIOCIDE = "BIOCIDE" + METAL = "METAL" + HEAT = "HEAT" + + +class ElementAmrSubtype(Enum): + """Subcategories of antimicrobial resistance elements.""" + + AMR = "AMR" + + +class ElementVirulenceSubtype(Enum): + """Subcategories of virulence elements.""" + + VIR = "VIRULENCE" + class DatabaseReference(RWModel): """Refernece to a database.""" - ref_database: Optional[str] - ref_id: Optional[str] + ref_database: Optional[str] = None + ref_id: Optional[str] = None class GeneBase(BaseModel): """Container for gene information""" - accession: Optional[str] + accession: Optional[str] = None # prediction info - depth: Optional[float] - identity: Optional[float] - coverage: Optional[float] 
- ref_start_pos: Optional[int] - ref_end_pos: Optional[int] - ref_gene_length: Optional[int] + depth: Optional[float] = None + identity: Optional[float] = None + coverage: Optional[float] = None + ref_start_pos: Optional[int] = None + ref_end_pos: Optional[int] = None + ref_gene_length: Optional[int] = Field( + default=None, + alias="target_length", + description="The length of the query protein or gene.", + ) alignment_length: Optional[int] # amrfinder extra info - contig_id: Optional[str] - gene_symbol: Optional[str] - sequence_name: Optional[str] - ass_start_pos: Optional[int] - ass_end_pos: Optional[int] - strand: Optional[str] - element_type: Optional[str] - element_subtype: Optional[str] - target_length: Optional[int] - res_class: Optional[str] - res_subclass: Optional[str] - method: Optional[str] - close_seq_name: Optional[str] + contig_id: Optional[str] = None + gene_symbol: Optional[str] = None + sequence_name: Optional[str] = Field( + default=None, description="Reference sequence name" + ) + ass_start_pos: Optional[int] = Field( + default=None, description="Start position on the assembly" + ) + ass_end_pos: Optional[int] = Field( + default=None, description="End position on the assembly" + ) + strand: Optional[SequenceStand] = None + element_type: ElementType = Field( + description="The predominant function fo the gene." + ) + element_subtype: ElementStressSubtype | ElementAmrSubtype | ElementVirulenceSubtype = Field( + description="Further functional categorization of the genes." 
+ ) + res_class: Optional[str] = None + res_subclass: Optional[str] = None + method: Optional[str] = Field( + default=None, description="Generic description of the prediction method" + ) + close_seq_name: Optional[str] = Field( + default=None, + description="Name of the closest competing hit if there are multiple equaly good hits", + ) class ResistanceGene(GeneBase, DatabaseReference): """Container for resistance gene information""" - phenotypes: List[str] + phenotypes: List[str] = [] class VirulenceGene(GeneBase, DatabaseReference): """Container for virulence gene information""" - virulence_category: Optional[str] + virulence_category: Optional[str] = None class VariantBase(DatabaseReference): """Container for mutation information""" - variant_type: Optional[VariantType] - genes: Optional[List[str]] - position: Optional[int] - ref_nt: Optional[str] - alt_nt: Optional[str] + variant_type: VariantType + genes: List[str] + position: int + ref_nt: str + alt_nt: str # prediction info - depth: Optional[float] - contig_id: Optional[str] - gene_symbol: Optional[str] - sequence_name: Optional[str] - ass_start_pos: Optional[int] - ass_end_pos: Optional[int] - strand: Optional[str] - element_type: Optional[str] - element_subtype: Optional[str] - target_length: Optional[int] - res_class: Optional[str] - res_subclass: Optional[str] - method: Optional[str] - close_seq_name: Optional[str] - type: Optional[str] - change: Optional[str] - nucleotide_change: Optional[str] - protein_change: Optional[str] - annotation: Optional[List[Dict]] - drugs: Optional[List[Dict]] + depth: Optional[float] = None + contig_id: Optional[str] = None + gene_symbol: Optional[str] = None + sequence_name: Optional[str] = Field( + default=None, description="Reference sequence name" + ) + ass_start_pos: Optional[int] = Field( + default=None, description="Assembly start position" + ) + ass_end_pos: Optional[int] = Field( + default=None, description="Assembly end position" + ) + strand: 
Optional[SequenceStand] = None + element_type: Optional[ElementType] = None + element_subtype: Optional[str] = None + target_length: Optional[int] = None + res_class: Optional[str] = None + res_subclass: Optional[str] = None + method: Optional[str] = None + close_seq_name: Optional[str] = None + type: Optional[str] = None + change: Optional[str] = None + nucleotide_change: Optional[str] = None + protein_change: Optional[str] = None + annotation: Optional[List[Dict]] = None + drugs: Optional[List[Dict]] = None class ResistanceVariant(VariantBase): diff --git a/prp/models/qc.py b/prp/models/qc.py index 35fa1ad..73d3941 100644 --- a/prp/models/qc.py +++ b/prp/models/qc.py @@ -1,6 +1,6 @@ """QC data models.""" from enum import Enum -from typing import Dict, Union +from typing import Dict from pydantic import BaseModel @@ -50,5 +50,5 @@ class QcMethodIndex(RWModel): """ software: QcSoftware - version: Union[str, None] - result: Union[QuastQcResult, PostAlignQcResult] + version: str | None = None + result: QuastQcResult | PostAlignQcResult diff --git a/prp/parse/metadata.py b/prp/parse/metadata.py index 8569a8f..815559f 100644 --- a/prp/parse/metadata.py +++ b/prp/parse/metadata.py @@ -6,6 +6,7 @@ LOG = logging.getLogger(__name__) + def get_database_info(process_metadata: List[TextIO]) -> List[SoupVersion]: """Get database or software information. 
@@ -35,4 +36,4 @@ def parse_run_info(run_metadata: TextIO) -> RunInformation: """ LOG.info("Parse run metadata.") run_info = RunInformation(**json.load(run_metadata)) - return run_info \ No newline at end of file + return run_info diff --git a/prp/parse/phenotype/amrfinder.py b/prp/parse/phenotype/amrfinder.py index 5a2a86e..279b7e1 100644 --- a/prp/parse/phenotype/amrfinder.py +++ b/prp/parse/phenotype/amrfinder.py @@ -1,6 +1,7 @@ """Parse AMRfinder plus result.""" from typing import Tuple +import logging import pandas as pd from ...models.phenotype import ElementType, ElementTypeResult @@ -9,24 +10,19 @@ from ...models.sample import MethodIndex from .utils import _default_resistance +LOG = logging.getLogger(__name__) + def _parse_amrfinder_amr_results(predictions: dict) -> Tuple[ResistanceGene, ...]: """Parse amrfinder prediction results from amrfinderplus.""" genes = [] for prediction in predictions: gene = ResistanceGene( - virulence_category=None, accession=prediction["close_seq_accn"], - depth=None, identity=prediction["ref_seq_identity"], coverage=prediction["ref_seq_cov"], - ref_start_pos=None, - ref_end_pos=None, ref_gene_length=prediction["ref_seq_len"], alignment_length=prediction["align_len"], - ref_database=None, - phenotypes=[], - ref_id=None, contig_id=prediction["contig_id"], gene_symbol=prediction["gene_symbol"], sequence_name=prediction["sequence_name"], @@ -42,10 +38,10 @@ def _parse_amrfinder_amr_results(predictions: dict) -> Tuple[ResistanceGene, ... 
close_seq_name=prediction["close_seq_name"], ) genes.append(gene) - return ElementTypeResult(phenotypes=[], genes=genes, mutations=[]) + return ElementTypeResult(phenotypes={}, genes=genes, mutations=[]) -def parse_amrfinder_amr_pred(file, element_type: str) -> ElementTypeResult: +def parse_amrfinder_amr_pred(file: str, element_type: ElementType) -> MethodIndex: """Parse amrfinder resistance prediction results.""" LOG.info("Parsing amrfinder amr prediction") with open(file, "rb") as tsvfile: @@ -68,27 +64,11 @@ def parse_amrfinder_amr_pred(file, element_type: str) -> ElementTypeResult: ) hits = hits.drop(columns=["Protein identifier", "HMM id", "HMM description"]) hits = hits.where(pd.notnull(hits), None) - if element_type == ElementType.AMR: - predictions = hits[hits["element_type"] == "AMR"].to_dict(orient="records") - results: ElementTypeResult = _parse_amrfinder_amr_results(predictions) - elif element_type == ElementType.HEAT: - predictions = hits[(hits["element_subtype"] == "HEAT")].to_dict( - orient="records" - ) - results: ElementTypeResult = _parse_amrfinder_amr_results(predictions) - elif element_type == ElementType.BIOCIDE: - predictions = hits[ - (hits["element_subtype"] == "ACID") - & (hits["element_subtype"] == "BIOCIDE") - ].to_dict(orient="records") - results: ElementTypeResult = _parse_amrfinder_amr_results(predictions) - elif element_type == ElementType.METAL: - predictions = hits[hits["element_subtype"] == "METAL"].to_dict( - orient="records" - ) - results: ElementTypeResult = _parse_amrfinder_amr_results(predictions) - else: - results = _default_resistance() + # keep hits of the requested element type; the column holds plain strings, so compare against the enum value + predictions = (hits + .loc[lambda row: row.element_type == ElementType(element_type).value] + .to_dict(orient="records")) + results: ElementTypeResult = _parse_amrfinder_amr_results(predictions) return MethodIndex(type=element_type, result=results, software=Software.AMRFINDER) @@ -125,7 +105,7 @@ def _parse_amrfinder_vir_results(predictions: dict) -> 
ElementTypeResult: close_seq_name=prediction["close_seq_name"], ) genes.append(gene) - return ElementTypeResult(phenotypes=[], genes=genes, mutations=[]) + return ElementTypeResult(phenotypes={}, genes=genes, mutations=[]) def parse_amrfinder_vir_pred(file: str): diff --git a/prp/parse/phenotype/resfinder.py b/prp/parse/phenotype/resfinder.py index ce8e86b..d417426 100644 --- a/prp/parse/phenotype/resfinder.py +++ b/prp/parse/phenotype/resfinder.py @@ -1,16 +1,44 @@ """Parse resfinder results.""" import logging -from typing import Any, Dict, Tuple +from typing import Any, Dict, Tuple, List +from itertools import chain from ...models.metadata import SoupVersions -from ...models.phenotype import ElementType, ElementTypeResult +from ...models.phenotype import ElementType, ElementAmrSubtype, ElementStressSubtype, ElementTypeResult from ...models.phenotype import PredictionSoftware as Software -from ...models.phenotype import ResistanceGene, ResistanceVariant +from ...models.phenotype import ResistanceGene, ResistanceVariant, VariantType from ...models.sample import MethodIndex from .utils import _default_resistance LOG = logging.getLogger(__name__) +STRESS_FACTORS = { + ElementStressSubtype.BIOCIDE: [ + "formaldehyde", + "benzylkonium chloride", + "ethidium bromide", + "chlorhexidine", + "cetylpyridinium chloride", + "hydrogen peroxide"], + ElementStressSubtype.HEAT: ["temperature"] +} + + +def _assign_res_subtype(prediction: Dict[str, Any], element_type: ElementType) -> ElementStressSubtype | None: + """Assign element subtype from resfinder prediction.""" + assigned_subtype = None + if element_type == ElementType.STRESS: + for sub_type, phenotypes in STRESS_FACTORS.items(): + # get intersection of subtype phenotypes and predicted phenos + intersect = set(phenotypes) & set(prediction["phenotypes"]) + if len(intersect) > 0: + assigned_subtype = sub_type + elif element_type == ElementType.AMR: + assigned_subtype = ElementAmrSubtype.AMR + else: + LOG.warning(f"Dont know 
how to assign subtype for {element_type}") + return assigned_subtype + def _get_resfinder_amr_sr_profie(resfinder_result, limit_to_phenotypes=None): """Get resfinder susceptibility/resistance profile.""" @@ -34,13 +62,11 @@ def _get_resfinder_amr_sr_profie(resfinder_result, limit_to_phenotypes=None): def _parse_resfinder_amr_genes( resfinder_result, limit_to_phenotypes=None -) -> Tuple[ResistanceGene, ...]: +) -> List[ResistanceGene]: """Get resistance genes from resfinder result.""" results = [] - if not "seq_regions" in resfinder_result: - results = _default_resistance().genes - return results + return _default_resistance().genes for info in resfinder_result["seq_regions"].values(): # Get only acquired resistance genes @@ -53,6 +79,11 @@ def _parse_resfinder_amr_genes( if len(intersect) == 0: continue + # get element type by peeking at first phenotype + pheno = info["phenotypes"][0] + res_category = resfinder_result["phenotypes"][pheno]["category"].upper() + category = ElementType(res_category) + # store results gene = ResistanceGene( gene_symbol=info["name"], @@ -67,18 +98,8 @@ def _parse_resfinder_amr_genes( phenotypes=info["phenotypes"], ref_database=info["ref_database"][0], ref_id=info["ref_id"], - contig_id=None, - sequence_name=None, - ass_start_pos=None, - ass_end_pos=None, - strand=None, - element_type=None, - element_subtype=None, - target_length=None, - res_class=None, - res_subclass=None, - method=None, - close_seq_name=None, + element_type=category, + element_subtype=_assign_res_subtype(info, category), ) results.append(gene) return results @@ -105,11 +126,11 @@ def _parse_resfinder_amr_variants( info["depth"] = 0 # translate variation type bools into classifier if info["substitution"]: - var_type = "substitution" + var_type = VariantType.SUBSTITUTION elif info["insertion"]: - var_type = "insertion" + var_type = VariantType.INSERTION elif info["deletion"]: - var_type = "deletion" + var_type = VariantType.DELETION else: raise ValueError("Output has no 
known mutation type") if not "seq_regions" in info: @@ -131,37 +152,25 @@ def _parse_resfinder_amr_variants( def parse_resfinder_amr_pred( - prediction: Dict[str, Any], resistance_category + prediction: Dict[str, Any], resistance_category: ElementType ) -> Tuple[SoupVersions, ElementTypeResult]: """Parse resfinder resistance prediction results.""" # resfinder missclassifies resistance the param amr_category by setting all to amr LOG.info("Parsing resistance prediction") # parse resistance based on the category + stress_factors = list(chain(*STRESS_FACTORS.values())) categories = { - ElementType.BIOCIDE: [ - "formaldehyde", - "benzylkonium chloride", - "ethidium bromide", - "chlorhexidine", - "cetylpyridinium chloride", - "hydrogen peroxide", - ], - ElementType.HEAT: ["temperature"], + ElementType.STRESS: stress_factors, + ElementType.AMR: list(set(prediction["phenotypes"]) - set(stress_factors)) } - categories[ElementType.AMR] = list( - set(prediction["phenotypes"]) - - set(categories[ElementType.BIOCIDE] + categories[ElementType.HEAT]) - ) - # parse resistance + sr_profile = _get_resfinder_amr_sr_profie( + prediction, categories[resistance_category] + ) + res_genes = _parse_resfinder_amr_genes(prediction, categories[resistance_category]) + res_mut = _parse_resfinder_amr_variants(prediction, categories[resistance_category]) resistance = ElementTypeResult( - phenotypes=_get_resfinder_amr_sr_profie( - prediction, categories[resistance_category] - ), - genes=_parse_resfinder_amr_genes(prediction, categories[resistance_category]), - mutations=_parse_resfinder_amr_variants( - prediction, categories[resistance_category] - ), + phenotypes=sr_profile, genes=res_genes, mutations=res_mut ) return MethodIndex( type=resistance_category, software=Software.RESFINDER, result=resistance diff --git a/prp/parse/phenotype/virulencefinder.py b/prp/parse/phenotype/virulencefinder.py index 43720b0..6736769 100644 --- a/prp/parse/phenotype/virulencefinder.py +++ 
b/prp/parse/phenotype/virulencefinder.py @@ -62,4 +62,4 @@ def parse_virulencefinder_vir_pred(file: str) -> ElementTypeResult | None: results: ElementTypeResult = _parse_virulencefinder_vir_results(pred) return MethodIndex( type=ElementType.VIR, software=Software.VIRFINDER, result=results - ) \ No newline at end of file + )