diff --git a/malariagen_data/anoph/dipclust.py b/malariagen_data/anoph/dipclust.py index bb9c6781..8da94030 100644 --- a/malariagen_data/anoph/dipclust.py +++ b/malariagen_data/anoph/dipclust.py @@ -1,4 +1,4 @@ -from typing import Optional, Tuple, Sequence +from typing import Optional, Tuple import allel # type: ignore import numpy as np @@ -549,11 +549,52 @@ def _dipclust_concat_subplots( return fig + def _insert_dipclust_snp_trace( + self, + *, + figures, + subplot_heights, + snp_row_height: plotly_params.height = 25, + transcript: base_params.transcript, + snp_query: Optional[base_params.snp_query] = AA_CHANGE_QUERY, + sample_sets: Optional[base_params.sample_sets], + sample_query: Optional[base_params.sample_query], + sample_query_options: Optional[base_params.sample_query_options], + site_mask: Optional[base_params.site_mask], + dendro_sample_id_order: np.ndarray, + snp_filter_min_maf: float, + snp_colorscale: Optional[plotly_params.color_continuous_scale], + chunks: base_params.chunks = base_params.native_chunks, + inline_array: base_params.inline_array = base_params.inline_array_default, + ): + snp_trace, n_snps_transcript = self._dipclust_snp_trace( + transcript=transcript, + sample_sets=sample_sets, + sample_query=sample_query, + sample_query_options=sample_query_options, + snp_query=snp_query, + site_mask=site_mask, + dendro_sample_id_order=dendro_sample_id_order, + snp_filter_min_maf=snp_filter_min_maf, + snp_colorscale=snp_colorscale, + chunks=chunks, + inline_array=inline_array, + ) + + if snp_trace: + figures.append(snp_trace) + subplot_heights.append(snp_row_height * n_snps_transcript) + else: + print( + f"No SNPs were found below {snp_filter_min_maf} allele frequency. Omitting SNP genotype plot." + ) + return figures, subplot_heights + @doc( summary="Perform diplotype clustering, annotated with heterozygosity, gene copy number and amino acid variants.", parameters=dict( heterozygosity="Plot heterozygosity track.", - snp_transcripts="Plot amino acid variants for these transcripts.", + snp_transcript="Plot amino acid variants for these transcripts.", cnv_region="Plot gene CNV calls for this region.", snp_filter_min_maf="Filter amino acid variants with alternate allele frequency below this threshold.", ), @@ -563,7 +604,7 @@ def plot_diplotype_clustering_advanced( region: base_params.regions, heterozygosity: bool = True, heterozygosity_colorscale: plotly_params.color_continuous_scale = "Greys", - snp_transcripts: Sequence[base_params.transcript] = [], + snp_transcript: dipclust_params.snp_transcript = None, snp_colorscale: plotly_params.color_continuous_scale = "Greys", snp_filter_min_maf: float = 0.05, snp_query: Optional[base_params.snp_query] = AA_CHANGE_QUERY, @@ -603,7 +644,7 @@ def plot_diplotype_clustering_advanced( chunks: base_params.chunks = base_params.native_chunks, inline_array: base_params.inline_array = base_params.inline_array_default, ): - if cohort_size and snp_transcripts: + if cohort_size and snp_transcript: cohort_size = None print( "Cohort size is not supported with amino acid heatmap. Overriding cohort size to None." @@ -684,9 +725,11 @@ def plot_diplotype_clustering_advanced( figures.append(cnv_trace) subplot_heights.append(cnv_row_height * n_cnv_genes) - for snp_transcript in snp_transcripts: - snp_trace, n_snps_transcript = self._dipclust_snp_trace( + if isinstance(snp_transcript, str): + figures, subplot_heights = self._insert_dipclust_snp_trace( transcript=snp_transcript, + figures=figures, + subplot_heights=subplot_heights, sample_sets=sample_sets, sample_query=sample_query, sample_query_options=sample_query_options, @@ -698,13 +741,22 @@ def plot_diplotype_clustering_advanced( chunks=chunks, inline_array=inline_array, ) - - if snp_trace: - figures.append(snp_trace) - subplot_heights.append(snp_row_height * n_snps_transcript) - else: - print( - f"No SNPs were found below {snp_filter_min_maf} allele frequency. Omitting SNP genotype plot." + elif isinstance(snp_transcript, list): + for st in snp_transcript: + figures, subplot_heights = self._insert_dipclust_snp_trace( + transcript=st, + figures=figures, + subplot_heights=subplot_heights, + sample_sets=sample_sets, + sample_query=sample_query, + sample_query_options=sample_query_options, + snp_query=snp_query, + site_mask=site_mask, + dendro_sample_id_order=dendro_sample_id_order, + snp_filter_min_maf=snp_filter_min_maf, + snp_colorscale=snp_colorscale, + chunks=chunks, + inline_array=inline_array, ) # Calculate total height based on subplot heights, plus a fixed diff --git a/malariagen_data/anoph/dipclust_params.py b/malariagen_data/anoph/dipclust_params.py index f2f0376a..0fedf326 100644 --- a/malariagen_data/anoph/dipclust_params.py +++ b/malariagen_data/anoph/dipclust_params.py @@ -1,9 +1,17 @@ """Parameters for diplotype clustering functions.""" +from typing_extensions import Annotated, TypeAlias, Union, Sequence + from .distance_params import distance_metric from .clustering_params import linkage_method +from .base_params import transcript linkage_method_default: linkage_method = "complete" distance_metric_default: distance_metric = "cityblock" + +snp_transcript: TypeAlias = Annotated[ + Union[None, transcript, Sequence[transcript]], + "A transcript or a list of transcripts", +] diff --git a/notebooks/plot_diplotype_clustering.ipynb b/notebooks/plot_diplotype_clustering.ipynb index fb6e6665..83d9c505 100644 --- a/notebooks/plot_diplotype_clustering.ipynb +++ b/notebooks/plot_diplotype_clustering.ipynb @@ -36,14 +36,14 @@ { "cell_type": "code", "execution_count": null, - "id": "92a9bdfd-d808-4c07-8bdd-f2eb3d1e5614", + "id": "15d0bf3e-bbbe-4b67-a39d-c7ed5622cc3b", "metadata": {}, "outputs": [], "source": [ - "ag3.plot_diplotype_clustering_advanced(\n", - " region='2L:2,350,000-2,680,000',\n", - " snp_transcripts=['AGAP004707-RD', 'AGAP004717-RA'],\n", - " snp_query=\"effect == 'NON_SYNONYMOUS_CODING'\",\n", + "fig = ag3.plot_diplotype_clustering_advanced(\n", + " region=\"2R:28,480,000-28,500,000\",\n", + " cnv_region=\"2R:28,480,000-28,500,000\",\n", + " snp_transcript='AGAP002862-RA',\n", " snp_filter_min_maf=0.05,\n", " sample_sets=\"AG1000G-GH\",\n", " site_mask=\"gamb_colu\",\n", @@ -52,21 +52,23 @@ " linkage_method=\"complete\",\n", " count_sort=True,\n", " distance_sort=False,\n", - ")" + " show=False,\n", + ")\n", + "fig" ] }, { "cell_type": "code", "execution_count": null, - "id": "15d0bf3e-bbbe-4b67-a39d-c7ed5622cc3b", + "id": "92a9bdfd-d808-4c07-8bdd-f2eb3d1e5614", "metadata": {}, "outputs": [], "source": [ - "fig = ag3.plot_diplotype_clustering_advanced(\n", - " region=\"2R:28,480,000-28,500,000\",\n", - " cnv_region=\"2R:28,480,000-28,500,000\",\n", - " snp_transcripts=['AGAP002862-RA'],\n", - " snp_filter_min_maf=0.05,\n", + "ag3.plot_diplotype_clustering_advanced(\n", + " region='2R:28,480,000-28,490,000',\n", + " snp_transcript=['AGAP002862-RA', 'AGAP002864-RA'],\n", + " snp_query=\"effect == 'NON_SYNONYMOUS_CODING'\",\n", + " snp_filter_min_maf=0.1,\n", " sample_sets=\"AG1000G-GH\",\n", " site_mask=\"gamb_colu\",\n", " color=\"taxon\",\n", @@ -74,9 +76,7 @@ " linkage_method=\"complete\",\n", " count_sort=True,\n", " distance_sort=False,\n", - " show=False,\n", - ")\n", - "fig" + ")" ] }, { @@ -89,7 +89,7 @@ "ag3.plot_diplotype_clustering_advanced(\n", " region=\"2R:28,480,000-28,500,000\",\n", " cnv_region = \"2R:28,480,000-28,500,000\",\n", - " snp_transcripts=['AGAP002862-RA'],\n", + " snp_transcript=None,\n", " sample_sets=[\"AG1000G-GH\", 'AG1000G-BF-A'],\n", " snp_filter_min_maf=0.05,\n", " site_mask=\"gamb_colu\",\n", @@ -412,7 +412,7 @@ "source": [ "af1.plot_diplotype_clustering_advanced(\n", " region = \"X:8,438,477-8,460,887\",\n", - " snp_transcripts=[\"LOC125764232_t1\"],\n", + " snp_transcript=[\"LOC125764232_t1\"],\n", " cnv_region=\"X:8,418,477-8,480,887\",\n", " sample_sets=[\"1232-VO-KE-OCHOMO-VMF00044\", \"1231-VO-MULTI-WONDJI-VMF00043\", \"1236-VO-TZ-OKUMU-VMF00090\"],\n", " sample_query=\"country in ['Kenya', 'Uganda', 'Tanzania'] and taxon == 'funestus'\"\n", diff --git a/tests/anoph/test_dipclust.py b/tests/anoph/test_dipclust.py index 1c1de419..0090030a 100644 --- a/tests/anoph/test_dipclust.py +++ b/tests/anoph/test_dipclust.py @@ -159,7 +159,7 @@ def test_plot_diplotype_clustering_advanced_with_transcript( sample_queries = (None, "sex_call == 'F'") dipclust_params = dict( region=contig, - snp_transcripts=transcripts, + snp_transcript=transcripts, sample_sets=[random.choice(all_sample_sets)], linkage_method=random.choice(linkage_methods), distance_metric="cityblock",