Skip to content

Commit

Permalink
Made snp_transcript polymorphic to avoid breaking the API
Browse files Browse the repository at this point in the history
  • Loading branch information
jonbrenas committed Dec 13, 2024
1 parent 911010b commit 7a8aaf6
Show file tree
Hide file tree
Showing 4 changed files with 91 additions and 31 deletions.
78 changes: 65 additions & 13 deletions malariagen_data/anoph/dipclust.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Optional, Tuple, Sequence
from typing import Optional, Tuple

import allel # type: ignore
import numpy as np
Expand Down Expand Up @@ -549,11 +549,52 @@ def _dipclust_concat_subplots(

return fig

def _insert_dipclust_snp_trace(
    self,
    *,
    figures,
    subplot_heights,
    snp_row_height: plotly_params.height = 25,
    transcript: base_params.transcript,
    snp_query: Optional[base_params.snp_query] = AA_CHANGE_QUERY,
    sample_sets: Optional[base_params.sample_sets],
    sample_query: Optional[base_params.sample_query],
    sample_query_options: Optional[base_params.sample_query_options],
    site_mask: Optional[base_params.site_mask],
    dendro_sample_id_order: np.ndarray,
    snp_filter_min_maf: float,
    snp_colorscale: Optional[plotly_params.color_continuous_scale],
    chunks: base_params.chunks = base_params.native_chunks,
    inline_array: base_params.inline_array = base_params.inline_array_default,
):
    """Build the SNP trace for one transcript and add it to the subplot stack.

    Delegates to ``self._dipclust_snp_trace`` for a single ``transcript``.
    When a trace is returned, it is appended to ``figures`` and a matching
    height (``snp_row_height`` times the number of SNPs reported for the
    transcript) is appended to ``subplot_heights``. When no trace is
    returned, a message is printed and both lists are left untouched.

    Parameters
    ----------
    figures
        List of subplot traces collected so far; mutated in place.
    subplot_heights
        List of subplot heights, parallel to ``figures``; mutated in place.
    snp_row_height
        Height allotted to each SNP row in the trace.
    transcript
        The single transcript to plot amino acid variants for.

    All remaining keyword arguments are forwarded unchanged to
    ``self._dipclust_snp_trace``.

    Returns
    -------
    tuple
        ``(figures, subplot_heights)`` -- the same list objects that were
        passed in, returned for convenient re-binding by the caller.
    """
    snp_trace, n_snps_transcript = self._dipclust_snp_trace(
        transcript=transcript,
        sample_sets=sample_sets,
        sample_query=sample_query,
        sample_query_options=sample_query_options,
        snp_query=snp_query,
        site_mask=site_mask,
        dendro_sample_id_order=dendro_sample_id_order,
        snp_filter_min_maf=snp_filter_min_maf,
        snp_colorscale=snp_colorscale,
        chunks=chunks,
        inline_array=inline_array,
    )

    if snp_trace:
        figures.append(snp_trace)
        subplot_heights.append(snp_row_height * n_snps_transcript)
    else:
        # snp_filter_min_maf is a *minimum* frequency: variants below it are
        # filtered out, so an empty result means nothing was found above the
        # threshold. Name the transcript so the message stays unambiguous
        # when several transcripts are plotted in one call.
        print(
            f"No SNPs were found above {snp_filter_min_maf} allele frequency "
            f"for transcript {transcript}. Omitting SNP genotype plot."
        )
    return figures, subplot_heights

@doc(
summary="Perform diplotype clustering, annotated with heterozygosity, gene copy number and amino acid variants.",
parameters=dict(
heterozygosity="Plot heterozygosity track.",
snp_transcripts="Plot amino acid variants for these transcripts.",
snp_transcript="Plot amino acid variants for these transcripts.",
cnv_region="Plot gene CNV calls for this region.",
snp_filter_min_maf="Filter amino acid variants with alternate allele frequency below this threshold.",
),
Expand All @@ -563,7 +604,7 @@ def plot_diplotype_clustering_advanced(
region: base_params.regions,
heterozygosity: bool = True,
heterozygosity_colorscale: plotly_params.color_continuous_scale = "Greys",
snp_transcripts: Sequence[base_params.transcript] = [],
snp_transcript: dipclust_params.snp_transcript = None,
snp_colorscale: plotly_params.color_continuous_scale = "Greys",
snp_filter_min_maf: float = 0.05,
snp_query: Optional[base_params.snp_query] = AA_CHANGE_QUERY,
Expand Down Expand Up @@ -603,7 +644,7 @@ def plot_diplotype_clustering_advanced(
chunks: base_params.chunks = base_params.native_chunks,
inline_array: base_params.inline_array = base_params.inline_array_default,
):
if cohort_size and snp_transcripts:
if cohort_size and snp_transcript:
cohort_size = None
print(
"Cohort size is not supported with amino acid heatmap. Overriding cohort size to None."
Expand Down Expand Up @@ -684,9 +725,11 @@ def plot_diplotype_clustering_advanced(
figures.append(cnv_trace)
subplot_heights.append(cnv_row_height * n_cnv_genes)

for snp_transcript in snp_transcripts:
snp_trace, n_snps_transcript = self._dipclust_snp_trace(
if isinstance(snp_transcript, str):
figures, subplot_heights = self._insert_dipclust_snp_trace(
transcript=snp_transcript,
figures=figures,
subplot_heights=subplot_heights,
sample_sets=sample_sets,
sample_query=sample_query,
sample_query_options=sample_query_options,
Expand All @@ -698,13 +741,22 @@ def plot_diplotype_clustering_advanced(
chunks=chunks,
inline_array=inline_array,
)

if snp_trace:
figures.append(snp_trace)
subplot_heights.append(snp_row_height * n_snps_transcript)
else:
print(
f"No SNPs were found below {snp_filter_min_maf} allele frequency. Omitting SNP genotype plot."
elif isinstance(snp_transcript, list):
for st in snp_transcript:
figures, subplot_heights = self._insert_dipclust_snp_trace(
transcript=st,
figures=figures,
subplot_heights=subplot_heights,
sample_sets=sample_sets,
sample_query=sample_query,
sample_query_options=sample_query_options,
snp_query=snp_query,
site_mask=site_mask,
dendro_sample_id_order=dendro_sample_id_order,
snp_filter_min_maf=snp_filter_min_maf,
snp_colorscale=snp_colorscale,
chunks=chunks,
inline_array=inline_array,
)

# Calculate total height based on subplot heights, plus a fixed
Expand Down
8 changes: 8 additions & 0 deletions malariagen_data/anoph/dipclust_params.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,17 @@
"""Parameters for diplotype clustering functions."""

# Union and Sequence are plain ``typing`` names; importing them from the
# standard library avoids relying on typing_extensions re-exporting them,
# which only newer typing_extensions releases do.
from typing import Sequence, Union

from typing_extensions import Annotated, TypeAlias

from .distance_params import distance_metric
from .clustering_params import linkage_method
from .base_params import transcript


# Module-level default values, typed with the corresponding parameter aliases.
linkage_method_default: linkage_method = "complete"

distance_metric_default: distance_metric = "cityblock"

# Polymorphic transcript parameter: accepts None, a single transcript, or a
# sequence of transcripts. The second Annotated argument is the parameter's
# human-readable description.
snp_transcript: TypeAlias = Annotated[
    Union[None, transcript, Sequence[transcript]],
    "A transcript or a list of transcripts",
]
34 changes: 17 additions & 17 deletions notebooks/plot_diplotype_clustering.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -36,14 +36,14 @@
{
"cell_type": "code",
"execution_count": null,
"id": "92a9bdfd-d808-4c07-8bdd-f2eb3d1e5614",
"id": "15d0bf3e-bbbe-4b67-a39d-c7ed5622cc3b",
"metadata": {},
"outputs": [],
"source": [
"ag3.plot_diplotype_clustering_advanced(\n",
" region='2L:2,350,000-2,680,000',\n",
" snp_transcripts=['AGAP004707-RD', 'AGAP004717-RA'],\n",
" snp_query=\"effect == 'NON_SYNONYMOUS_CODING'\",\n",
"fig = ag3.plot_diplotype_clustering_advanced(\n",
" region=\"2R:28,480,000-28,500,000\",\n",
" cnv_region=\"2R:28,480,000-28,500,000\",\n",
" snp_transcript='AGAP002862-RA',\n",
" snp_filter_min_maf=0.05,\n",
" sample_sets=\"AG1000G-GH\",\n",
" site_mask=\"gamb_colu\",\n",
Expand All @@ -52,31 +52,31 @@
" linkage_method=\"complete\",\n",
" count_sort=True,\n",
" distance_sort=False,\n",
")"
" show=False,\n",
")\n",
"fig"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "15d0bf3e-bbbe-4b67-a39d-c7ed5622cc3b",
"id": "92a9bdfd-d808-4c07-8bdd-f2eb3d1e5614",
"metadata": {},
"outputs": [],
"source": [
"fig = ag3.plot_diplotype_clustering_advanced(\n",
" region=\"2R:28,480,000-28,500,000\",\n",
" cnv_region=\"2R:28,480,000-28,500,000\",\n",
" snp_transcripts=['AGAP002862-RA'],\n",
" snp_filter_min_maf=0.05,\n",
"ag3.plot_diplotype_clustering_advanced(\n",
" region='2R:28,480,000-28,490,000',\n",
" snp_transcript=['AGAP002862-RA', 'AGAP002864-RA'],\n",
" snp_query=\"effect == 'NON_SYNONYMOUS_CODING'\",\n",
" snp_filter_min_maf=0.1,\n",
" sample_sets=\"AG1000G-GH\",\n",
" site_mask=\"gamb_colu\",\n",
" color=\"taxon\",\n",
" symbol=\"country\",\n",
" linkage_method=\"complete\",\n",
" count_sort=True,\n",
" distance_sort=False,\n",
" show=False,\n",
")\n",
"fig"
")"
]
},
{
Expand All @@ -89,7 +89,7 @@
"ag3.plot_diplotype_clustering_advanced(\n",
" region=\"2R:28,480,000-28,500,000\",\n",
" cnv_region = \"2R:28,480,000-28,500,000\",\n",
" snp_transcripts=['AGAP002862-RA'],\n",
" snp_transcript=None,\n",
" sample_sets=[\"AG1000G-GH\", 'AG1000G-BF-A'],\n",
" snp_filter_min_maf=0.05,\n",
" site_mask=\"gamb_colu\",\n",
Expand Down Expand Up @@ -412,7 +412,7 @@
"source": [
"af1.plot_diplotype_clustering_advanced(\n",
" region = \"X:8,438,477-8,460,887\",\n",
" snp_transcripts=[\"LOC125764232_t1\"],\n",
" snp_transcript=[\"LOC125764232_t1\"],\n",
" cnv_region=\"X:8,418,477-8,480,887\",\n",
" sample_sets=[\"1232-VO-KE-OCHOMO-VMF00044\", \"1231-VO-MULTI-WONDJI-VMF00043\", \"1236-VO-TZ-OKUMU-VMF00090\"],\n",
" sample_query=\"country in ['Kenya', 'Uganda', 'Tanzania'] and taxon == 'funestus'\"\n",
Expand Down
2 changes: 1 addition & 1 deletion tests/anoph/test_dipclust.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ def test_plot_diplotype_clustering_advanced_with_transcript(
sample_queries = (None, "sex_call == 'F'")
dipclust_params = dict(
region=contig,
snp_transcripts=transcripts,
snp_transcript=transcripts,
sample_sets=[random.choice(all_sample_sets)],
linkage_method=random.choice(linkage_methods),
distance_metric="cityblock",
Expand Down

0 comments on commit 7a8aaf6

Please sign in to comment.