From b797476d807abb974f72a0ea2748784d4f10ba8c Mon Sep 17 00:00:00 2001 From: Alistair Miles Date: Mon, 11 Dec 2023 21:16:42 +0000 Subject: [PATCH] Misc bug fixes (#480) * use Dataset.sizes[] instead of Dataset.dims[] * fix param typing * allow taxon order override * fix param handling * more params * more progress * relax typing --- malariagen_data/anoph/hap_data.py | 4 +- malariagen_data/anoph/plotly_params.py | 15 ++- malariagen_data/anoph/sample_metadata.py | 6 +- malariagen_data/anoph/snp_data.py | 135 ++++++++++++----------- malariagen_data/anopheles.py | 35 ++++-- malariagen_data/util.py | 4 +- notebooks/plot_njt.ipynb | 31 ++++++ notebooks/plot_pca.ipynb | 57 ++++++++++ tests/anoph/conftest.py | 2 +- tests/anoph/test_aim_data.py | 8 +- tests/anoph/test_cnv_data.py | 6 +- tests/anoph/test_hap_data.py | 6 +- tests/anoph/test_snp_data.py | 52 ++++----- tests/test_af1.py | 4 +- tests/test_ag3.py | 14 +-- tests/test_amin1.py | 8 +- tests/test_pf7_integration.py | 8 +- tests/test_pv4_integration.py | 8 +- 18 files changed, 261 insertions(+), 142 deletions(-) diff --git a/malariagen_data/anoph/hap_data.py b/malariagen_data/anoph/hap_data.py index b3863372c..14e161de7 100644 --- a/malariagen_data/anoph/hap_data.py +++ b/malariagen_data/anoph/hap_data.py @@ -270,7 +270,7 @@ def haplotypes( if min_cohort_size is not None: # Handle min cohort size. - n_samples = ds.dims["samples"] + n_samples = ds.sizes["samples"] if n_samples < min_cohort_size: raise ValueError( f"Not enough samples ({n_samples}) for minimum cohort size ({min_cohort_size})" @@ -278,7 +278,7 @@ def haplotypes( if max_cohort_size is not None: # Handle max cohort size. - n_samples = ds.dims["samples"] + n_samples = ds.sizes["samples"] if n_samples > max_cohort_size: rng = np.random.default_rng(seed=random_seed) loc_downsample = rng.choice( diff --git a/malariagen_data/anoph/plotly_params.py b/malariagen_data/anoph/plotly_params.py index dbc18e604..943fba104 100644 --- a/malariagen_data/anoph/plotly_params.py +++ b/malariagen_data/anoph/plotly_params.py @@ -62,7 +62,7 @@ ] category_order: TypeAlias = Annotated[ - Optional[List], + Optional[Union[List, Mapping]], "Control the order in which values appear in the legend.", ] @@ -97,12 +97,12 @@ ] color: TypeAlias = Annotated[ - Optional[str], + Optional[Union[str, Mapping]], "Name of variable to use to color the markers.", ] symbol: TypeAlias = Annotated[ - Optional[str], + Optional[Union[str, Mapping]], "Name of the variable to use to choose marker symbols.", ] @@ -166,3 +166,12 @@ Union[int, float], "The upper end of the range of values that the colormap covers.", ] + +legend_sizing: TypeAlias = Annotated[ + Literal["constant", "trace"], + """ + Determines if the legend items symbols scale with their corresponding + "trace" attributes or remain "constant" independent of the symbol size + on the graph. + """, +] diff --git a/malariagen_data/anoph/sample_metadata.py b/malariagen_data/anoph/sample_metadata.py index 7a98c80ff..031a0a8b0 100644 --- a/malariagen_data/anoph/sample_metadata.py +++ b/malariagen_data/anoph/sample_metadata.py @@ -1,5 +1,5 @@ import io -from typing import Any, Dict, List, Mapping, Optional, Tuple, Union +from typing import Any, Dict, List, Mapping, Optional, Sequence, Tuple, Union import ipyleaflet import numpy as np @@ -508,14 +508,14 @@ def count_samples( self, sample_sets: Optional[base_params.sample_sets] = None, sample_query: Optional[base_params.sample_query] = None, - index: Union[str, Tuple[str, ...]] = ( + index: Union[str, Sequence[str]] = ( "country", "admin1_iso", "admin1_name", "admin2_name", "year", ), - columns: Union[str, Tuple[str, ...]] = "taxon", + columns: Union[str, Sequence[str]] = "taxon", ) -> pd.DataFrame: # Load sample metadata. df_samples = self.sample_metadata( diff --git a/malariagen_data/anoph/snp_data.py b/malariagen_data/anoph/snp_data.py index 06c3b7111..216cea519 100644 --- a/malariagen_data/anoph/snp_data.py +++ b/malariagen_data/anoph/snp_data.py @@ -925,7 +925,7 @@ def _snp_calls( # Handle min cohort size. if min_cohort_size is not None: - n_samples = ds.dims["samples"] + n_samples = ds.sizes["samples"] if n_samples < min_cohort_size: raise ValueError( f"not enough samples ({n_samples}) for minimum cohort size ({min_cohort_size})" @@ -933,7 +933,7 @@ def _snp_calls( # Handle max cohort size. if max_cohort_size is not None: - n_samples = ds.dims["samples"] + n_samples = ds.sizes["samples"] if n_samples > max_cohort_size: rng = np.random.default_rng(seed=random_seed) loc_downsample = rng.choice( @@ -1439,74 +1439,79 @@ def biallelic_snp_calls( chunks=chunks, ) - # Subset to biallelic sites. - ds_bi = ds.isel(variants=loc_bi) + with self._spinner("Prepare biallelic SNP calls"): + # Subset to biallelic sites. + ds_bi = ds.isel(variants=loc_bi) - # Start building a new dataset. - coords: Dict[str, Any] = dict() - data_vars: Dict[str, Any] = dict() + # Start building a new dataset. + coords: Dict[str, Any] = dict() + data_vars: Dict[str, Any] = dict() - # Store sample IDs. - coords["sample_id"] = ("samples",), ds_bi["sample_id"].data + # Store sample IDs. + coords["sample_id"] = ("samples",), ds_bi["sample_id"].data - # Store contig. - coords["variant_contig"] = ("variants",), ds_bi["variant_contig"].data + # Store contig. + coords["variant_contig"] = ("variants",), ds_bi["variant_contig"].data - # Store position. - coords["variant_position"] = ("variants",), ds_bi["variant_position"].data + # Store position. + coords["variant_position"] = ("variants",), ds_bi["variant_position"].data - # Store alleles, transformed. - variant_allele = ds_bi["variant_allele"].data - variant_allele = variant_allele.rechunk((variant_allele.chunks[0], -1)) - variant_allele_out = da.map_blocks( - lambda block: apply_allele_mapping(block, allele_mapping, max_allele=1), - variant_allele, - dtype=variant_allele.dtype, - chunks=(variant_allele.chunks[0], [2]), - ) - data_vars["variant_allele"] = ("variants", "alleles"), variant_allele_out - - # Store allele counts, transformed, so we don't have to recompute. - ac_out = apply_allele_mapping(ac_bi, allele_mapping, max_allele=1) - data_vars["variant_allele_count"] = ("variants", "alleles"), ac_out - - # Store genotype calls, transformed. - gt = ds_bi["call_genotype"].data - gt_out = allel.GenotypeDaskArray(gt).map_alleles(allele_mapping) - data_vars["call_genotype"] = ("variants", "samples", "ploidy"), gt_out.values - - # Build dataset. - ds_out = xr.Dataset(coords=coords, data_vars=data_vars, attrs=ds.attrs) - - # Apply conditions. - if max_missing_an or min_minor_ac: - loc_out = np.ones(ds_out.dims["variants"], dtype=bool) - - # Apply missingness condition. - if max_missing_an is not None: - an = ac_out.sum(axis=1) - an_missing = (ds_out.dims["samples"] * ds_out.dims["ploidy"]) - an - loc_missing = an_missing <= max_missing_an - loc_out &= loc_missing - - # Apply minor allele count condition. - if min_minor_ac is not None: - ac_minor = ac_out.min(axis=1) - loc_minor = ac_minor >= min_minor_ac - loc_out &= loc_minor - - ds_out = ds_out.isel(variants=loc_out) - - # Try to meet target number of SNPs. - if n_snps is not None: - if ds_out.dims["variants"] > (n_snps * 2): - # Do some thinning. - thin_step = ds_out.dims["variants"] // n_snps - loc_thin = slice(thin_offset, None, thin_step) - ds_out = ds_out.isel(variants=loc_thin) - - elif ds_out.dims["variants"] < n_snps: - raise ValueError("Not enough SNPs.") + # Store alleles, transformed. + variant_allele = ds_bi["variant_allele"].data + variant_allele = variant_allele.rechunk((variant_allele.chunks[0], -1)) + variant_allele_out = da.map_blocks( + lambda block: apply_allele_mapping(block, allele_mapping, max_allele=1), + variant_allele, + dtype=variant_allele.dtype, + chunks=(variant_allele.chunks[0], [2]), + ) + data_vars["variant_allele"] = ("variants", "alleles"), variant_allele_out + + # Store allele counts, transformed, so we don't have to recompute. + ac_out = apply_allele_mapping(ac_bi, allele_mapping, max_allele=1) + data_vars["variant_allele_count"] = ("variants", "alleles"), ac_out + + # Store genotype calls, transformed. + gt = ds_bi["call_genotype"].data + gt_out = allel.GenotypeDaskArray(gt).map_alleles(allele_mapping) + data_vars["call_genotype"] = ( + "variants", + "samples", + "ploidy", + ), gt_out.values + + # Build dataset. + ds_out = xr.Dataset(coords=coords, data_vars=data_vars, attrs=ds.attrs) + + # Apply conditions. + if max_missing_an or min_minor_ac: + loc_out = np.ones(ds_out.sizes["variants"], dtype=bool) + + # Apply missingness condition. + if max_missing_an is not None: + an = ac_out.sum(axis=1) + an_missing = (ds_out.sizes["samples"] * ds_out.sizes["ploidy"]) - an + loc_missing = an_missing <= max_missing_an + loc_out &= loc_missing + + # Apply minor allele count condition. + if min_minor_ac is not None: + ac_minor = ac_out.min(axis=1) + loc_minor = ac_minor >= min_minor_ac + loc_out &= loc_minor + + ds_out = ds_out.isel(variants=loc_out) + + # Try to meet target number of SNPs. + if n_snps is not None: + if ds_out.sizes["variants"] > (n_snps * 2): + # Do some thinning. + thin_step = ds_out.sizes["variants"] // n_snps + loc_thin = slice(thin_offset, None, thin_step) + ds_out = ds_out.isel(variants=loc_thin) + + elif ds_out.sizes["variants"] < n_snps: + raise ValueError("Not enough SNPs.") return ds_out diff --git a/malariagen_data/anopheles.py b/malariagen_data/anopheles.py index 374cbbd6c..906ca046b 100644 --- a/malariagen_data/anopheles.py +++ b/malariagen_data/anopheles.py @@ -563,7 +563,7 @@ def _map_snp_to_aa_change_frq_ds(ds): "event_nobs", ] - if ds.dims["variants"] == 1: + if ds.sizes["variants"] == 1: # keep everything as-is, no need for aggregation ds_out = ds[keep_vars + ["variant_alt_allele", "event_count"]] @@ -2311,7 +2311,7 @@ def _gene_cnv_frequencies( debug( "setup output dataframe - two rows for each gene, one for amplification and one for deletion" ) - n_genes = ds_cnv.dims["genes"] + n_genes = ds_cnv.sizes["genes"] df_genes = ds_cnv[ [ "gene_id", @@ -2569,7 +2569,7 @@ def _gene_cnv_frequencies_advanced( is_called = cn >= 0 debug("set up main event variables") - n_genes = ds_cnv.dims["genes"] + n_genes = ds_cnv.sizes["genes"] n_variants, n_cohorts = n_genes * 2, len(df_cohorts) count = np.zeros((n_variants, n_cohorts), dtype=int) nobs = np.zeros((n_variants, n_cohorts), dtype=int) @@ -3545,6 +3545,7 @@ def plot_frequencies_time_series( height: plotly_params.height = None, width: plotly_params.width = None, title: plotly_params.title = True, + legend_sizing: plotly_params.legend_sizing = "constant", show: plotly_params.show = True, renderer: plotly_params.renderer = None, **kwargs, @@ -3634,7 +3635,10 @@ def plot_frequencies_time_series( ) debug("tidy plot") - fig.update_layout(yaxis_range=[-0.05, 1.05]) + fig.update_layout( + yaxis_range=[-0.05, 1.05], + legend=dict(itemsizing=legend_sizing, tracegroupgap=0), + ) if show: # pragma: no cover fig.show(renderer=renderer) @@ -3833,11 +3837,15 @@ def plot_pca_coords( color_discrete_sequence: plotly_params.color_discrete_sequence = None, color_discrete_map: plotly_params.color_discrete_map = None, category_orders: plotly_params.category_order = None, + legend_sizing: plotly_params.legend_sizing = "constant", show: plotly_params.show = True, renderer: plotly_params.renderer = None, render_mode: plotly_params.render_mode = "svg", **kwargs, ) -> plotly_params.figure: + # Copy input data to avoid overwriting. + data = data.copy() + # Apply jitter if desired - helps spread out points when tightly clustered. if jitter_frac: np.random.seed(random_seed) @@ -3894,7 +3902,7 @@ def plot_pca_coords( # Tidy up. fig.update_layout( - legend=dict(itemsizing="constant"), + legend=dict(itemsizing=legend_sizing, tracegroupgap=0), ) fig.update_traces(marker={"size": marker_size}) @@ -3930,10 +3938,14 @@ def plot_pca_coords_3d( color_discrete_sequence: plotly_params.color_discrete_sequence = None, color_discrete_map: plotly_params.color_discrete_map = None, category_orders: plotly_params.category_order = None, + legend_sizing: plotly_params.legend_sizing = "constant", show: plotly_params.show = True, renderer: plotly_params.renderer = None, **kwargs, ) -> plotly_params.figure: + # Copy input data to avoid overwriting. + data = data.copy() + # Apply jitter if desired - helps spread out points when tightly clustered. if jitter_frac: np.random.seed(random_seed) @@ -3989,7 +4001,7 @@ def plot_pca_coords_3d( # Tidy up. fig.update_layout( scene=dict(aspectmode="cube"), - legend=dict(itemsizing="constant"), + legend=dict(itemsizing=legend_sizing, tracegroupgap=0), ) fig.update_traces(marker={"size": marker_size}) @@ -6486,6 +6498,7 @@ def plot_haplotype_clustering( color_discrete_sequence: plotly_params.color_discrete_sequence = None, color_discrete_map: plotly_params.color_discrete_map = None, category_orders: plotly_params.category_order = None, + legend_sizing: plotly_params.legend_sizing = "constant", ) -> plotly_params.figure: import sys @@ -6588,6 +6601,7 @@ def plot_haplotype_clustering( title_font=dict( size=title_font_size, ), + legend=dict(itemsizing=legend_sizing, tracegroupgap=0), ) if show: # pragma: no cover @@ -7120,7 +7134,7 @@ def plot_njt( category_orders: plotly_params.category_order = None, edge_legend: bool = False, leaf_legend: bool = True, - legend_sizing: str = "trace", + legend_sizing: plotly_params.legend_sizing = "constant", thin_offset: base_params.thin_offset = 0, sample_sets: Optional[base_params.sample_sets] = None, sample_query: Optional[base_params.sample_query] = None, @@ -7291,7 +7305,7 @@ def plot_njt( title_font=dict( size=title_font_size, ), - legend=dict(itemsizing=legend_sizing), + legend=dict(itemsizing=legend_sizing, tracegroupgap=0), ) # Style axes. @@ -7370,7 +7384,10 @@ def _setup_plotly_sample_colors( # Special case, default taxon colors and order. color_params = self._setup_taxon_colors() color_discrete_map_prepped = color_params["color_discrete_map"] - category_orders_prepped = color_params["category_orders"] + if category_orders is None: + category_orders_prepped = color_params["category_orders"] + else: + category_orders_prepped = category_orders color_prepped = color # Bail out early. return ( diff --git a/malariagen_data/util.py b/malariagen_data/util.py index 9501c5b12..c41482238 100644 --- a/malariagen_data/util.py +++ b/malariagen_data/util.py @@ -213,7 +213,7 @@ def dask_compress_dataset(ds, indexer, dim): assert isinstance(indexer, da.Array) assert indexer.ndim == 1 assert indexer.dtype == bool - assert indexer.shape[0] == ds.dims[dim] + assert indexer.shape[0] == ds.sizes[dim] # temporarily compute the indexer once, to avoid multiple reads from # the underlying data @@ -571,7 +571,7 @@ def _simple_xarray_concat_arrays( # Iterate over variable names. for k in names: - # Access the variable from the virst dataset. + # Access the variable from the first dataset. v = ds0[k] if dim in v.dims: diff --git a/notebooks/plot_njt.ipynb b/notebooks/plot_njt.ipynb index 4e8a32ddf..c7cb47d14 100644 --- a/notebooks/plot_njt.ipynb +++ b/notebooks/plot_njt.ipynb @@ -80,6 +80,7 @@ " color=\"taxon\",\n", " width=700,\n", " legend_sizing=\"constant\",\n", + " category_orders=dict(taxon=[\"coluzzii\", \"gambiae\"]),\n", ")" ] }, @@ -302,6 +303,36 @@ "id": "9c1dc76a-1fea-49ee-9341-a7542b2a74f0", "metadata": {}, "outputs": [], + "source": [ + "new_cohorts = {\n", + " \"East\": \"country in ['Malawi', 'Tanzania', 'Kenya', 'Uganda']\",\n", + " \"West\": \"country in ['Mali', 'Burkina Faso', 'Cameroon']\",\n", + "}" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "bc7b2566-ac5d-4ca0-bc91-0208fe9d2cb0", + "metadata": {}, + "outputs": [], + "source": [ + "af1.plot_njt(\n", + " n_snps=10_000,\n", + " region=\"3RL:15,000,000-16,000,000\",\n", + " sample_sets=\"1.0\",\n", + " color=new_cohorts,\n", + " metric=\"euclidean\",\n", + " distance_sort=True,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9f5cd515-4933-4f08-82cc-c0fa57e77fb3", + "metadata": {}, + "outputs": [], "source": [] } ], diff --git a/notebooks/plot_pca.ipynb b/notebooks/plot_pca.ipynb index 04a7583e5..9dc79289b 100644 --- a/notebooks/plot_pca.ipynb +++ b/notebooks/plot_pca.ipynb @@ -162,6 +162,7 @@ "ag3.plot_pca_coords(\n", " df_pca,\n", " color=\"taxon\",\n", + " category_orders=dict(taxon=[\"coluzzii\", \"gambiae\"]),\n", ")" ] }, @@ -263,6 +264,62 @@ ")" ] }, + { + "cell_type": "code", + "execution_count": null, + "id": "e0c10b0d-3244-4862-acb7-cf60b9a6a9f2", + "metadata": {}, + "outputs": [], + "source": [ + "ag3.plot_pca_coords_3d(\n", + " df_pca,\n", + " color=\"taxon\",\n", + " category_orders=dict(taxon=[\"coluzzii\", \"gambiae\", \"arabiensis\", \"gcx1\", \"gcx2\", \"gcx3\"]),\n", + " marker_size=2,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "727402e7-43a1-4c7d-b21e-a051cd881f01", + "metadata": {}, + "outputs": [], + "source": [ + "new_cohorts = {\n", + " \"East\": \"country in ['Malawi', 'Tanzania', 'Kenya', 'Uganda']\",\n", + " \"West\": \"country in ['Mali', 'Burkina Faso', 'Cameroon']\",\n", + "}" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "cccf1386-c578-4af3-97ec-dc7704840102", + "metadata": {}, + "outputs": [], + "source": [ + "ag3.plot_pca_coords(\n", + " df_pca,\n", + " color=new_cohorts,\n", + " marker_size=5,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1b1a399a-762d-42db-9e40-776575d9ef5a", + "metadata": {}, + "outputs": [], + "source": [ + "ag3.plot_pca_coords_3d(\n", + " df_pca,\n", + " color=new_cohorts,\n", + " marker_size=2,\n", + ")" + ] + }, { "cell_type": "code", "execution_count": null, diff --git a/tests/anoph/conftest.py b/tests/anoph/conftest.py index 332a508fd..47e4ffc6c 100644 --- a/tests/anoph/conftest.py +++ b/tests/anoph/conftest.py @@ -1577,7 +1577,7 @@ def init_aim_calls(self): # Add call_genotype variable. gt = np.random.choice( np.arange(2, dtype="i1"), - size=(ds.dims["variants"], ds.dims["samples"], 2), + size=(ds.sizes["variants"], ds.sizes["samples"], 2), replace=True, ) ds["call_genotype"] = ("variants", "samples", "ploidy"), gt diff --git a/tests/anoph/test_aim_data.py b/tests/anoph/test_aim_data.py index 2ecb36297..813fabc1d 100644 --- a/tests/anoph/test_aim_data.py +++ b/tests/anoph/test_aim_data.py @@ -75,7 +75,7 @@ def test_aim_variants(aims, ag3_sim_api): assert tuple(ds.attrs["contigs"]) == api.contigs # Check dimension lengths. - assert ds.dims["alleles"] == 2 + assert ds.sizes["alleles"] == 2 @pytest.mark.parametrize("aims", ["gambcolu_vs_arab", "gamb_vs_colu"]) @@ -152,9 +152,9 @@ def test_aim_calls(aims, ag3_sim_api): assert tuple(ds.attrs["contigs"]) == api.contigs # Check dimension lengths. - assert ds.dims["samples"] == len(df_samples) - assert ds.dims["alleles"] == 2 - assert ds.dims["ploidy"] == 2 + assert ds.sizes["samples"] == len(df_samples) + assert ds.sizes["alleles"] == 2 + assert ds.sizes["ploidy"] == 2 def test_aim_calls_errors(ag3_sim_api): diff --git a/tests/anoph/test_cnv_data.py b/tests/anoph/test_cnv_data.py index 68d25b779..599bd85dd 100644 --- a/tests/anoph/test_cnv_data.py +++ b/tests/anoph/test_cnv_data.py @@ -239,7 +239,7 @@ def test_cnv_hmm__sample_query(ag3_sim_fixture, ag3_sim_api: AnophelesCnvData): ) expected_samples = df_samples["sample_id"].tolist() n_samples_expected = len(expected_samples) - assert ds.dims["samples"] == n_samples_expected + assert ds.sizes["samples"] == n_samples_expected # check sample IDs assert ds["sample_id"].values.tolist() == df_samples["sample_id"].tolist() @@ -295,7 +295,7 @@ def test_cnv_hmm(fixture, api: AnophelesCnvData): # check dim lengths df_samples = api.sample_metadata(sample_sets=sample_sets) n_samples_expected = len(df_samples) - assert ds.dims["samples"] == n_samples_expected + assert ds.sizes["samples"] == n_samples_expected # check sample IDs assert ds["sample_id"].values.tolist() == df_samples["sample_id"].tolist() @@ -502,7 +502,7 @@ def test_cnv_discordant_read_calls(fixture, api: AnophelesCnvData): # check dim lengths df_samples = api.sample_metadata(sample_sets=sample_sets) n_samples = len(df_samples) - assert ds.dims["samples"] == n_samples + assert ds.sizes["samples"] == n_samples # check sample IDs assert ds["sample_id"].values.tolist() == df_samples["sample_id"].tolist() diff --git a/tests/anoph/test_hap_data.py b/tests/anoph/test_hap_data.py index 409f2bf0d..9c78e4db7 100644 --- a/tests/anoph/test_hap_data.py +++ b/tests/anoph/test_hap_data.py @@ -272,12 +272,12 @@ def check_haplotypes( n_samples_expected = min(n_samples_selected, max_cohort_size) else: n_samples_expected = n_samples_selected - n_samples = ds.dims["samples"] + n_samples = ds.sizes["samples"] assert n_samples == n_samples_expected if min_cohort_size: assert n_samples >= min_cohort_size - assert ds.dims["ploidy"] == 2 - assert ds.dims["alleles"] == 2 + assert ds.sizes["ploidy"] == 2 + assert ds.sizes["alleles"] == 2 # Check shapes. for f in expected_coords | expected_data_vars: diff --git a/tests/anoph/test_snp_data.py b/tests/anoph/test_snp_data.py index 851418985..f616c501c 100644 --- a/tests/anoph/test_snp_data.py +++ b/tests/anoph/test_snp_data.py @@ -258,10 +258,10 @@ def test_open_site_annotations(fixture, api): def _check_site_annotations(api: AnophelesSnpData, region, site_mask): ds_snp = api.snp_variants(region=region, site_mask=site_mask) - n_variants = ds_snp.dims["variants"] + n_variants = ds_snp.sizes["variants"] ds_ann = api.site_annotations(region=region, site_mask=site_mask) # Site annotations dataset should be aligned with SNP sites. - assert ds_ann.dims["variants"] == n_variants + assert ds_ann.sizes["variants"] == n_variants assert isinstance(ds_ann, xr.Dataset) for f in ( "codon_degeneracy", @@ -481,10 +481,10 @@ def check_snp_calls(api, sample_sets, region, site_mask): n_variants = len(pos) df_samples = api.sample_metadata(sample_sets=sample_sets) n_samples = len(df_samples) - assert ds.dims["variants"] == n_variants - assert ds.dims["samples"] == n_samples - assert ds.dims["ploidy"] == 2 - assert ds.dims["alleles"] == 4 + assert ds.sizes["variants"] == n_variants + assert ds.sizes["samples"] == n_samples + assert ds.sizes["ploidy"] == 2 + assert ds.sizes["alleles"] == 4 # Check shapes. for f in expected_coords | expected_data_vars: @@ -613,7 +613,7 @@ def test_snp_calls_with_sample_query_param(ag3_sim_api: AnophelesSnpData, sample else: ds = ag3_sim_api.snp_calls(region="3L", sample_query=sample_query) - assert ds.dims["samples"] == len(df_samples) + assert ds.sizes["samples"] == len(df_samples) assert_array_equal(ds["sample_id"].values, df_samples["sample_id"].values) @@ -631,7 +631,7 @@ def test_snp_calls_with_min_cohort_size_param(fixture, api: AnophelesSnpData): min_cohort_size=10, ) assert isinstance(ds, xr.Dataset) - assert ds.dims["samples"] >= 10 + assert ds.sizes["samples"] >= 10 with pytest.raises(ValueError): api.snp_calls( sample_sets=sample_sets, @@ -654,7 +654,7 @@ def test_snp_calls_with_max_cohort_size_param(fixture, api: AnophelesSnpData): max_cohort_size=15, ) assert isinstance(ds, xr.Dataset) - assert ds.dims["samples"] <= 15 + assert ds.sizes["samples"] <= 15 @parametrize_with_cases("fixture,api", cases=".") @@ -672,7 +672,7 @@ def test_snp_calls_with_cohort_size_param(fixture, api: AnophelesSnpData): cohort_size=cohort_size, ) assert isinstance(ds, xr.Dataset) - assert ds.dims["samples"] == cohort_size + assert ds.sizes["samples"] == cohort_size with pytest.raises(ValueError): api.snp_calls( sample_sets=sample_sets, @@ -699,7 +699,7 @@ def test_snp_calls_with_cohort_size_param(fixture, api: AnophelesSnpData): def test_snp_calls_with_site_class_param(ag3_sim_api: AnophelesSnpData, site_class): ds1 = ag3_sim_api.snp_calls(region="3L") ds2 = ag3_sim_api.snp_calls(region="3L", site_class=site_class) - assert ds2.dims["variants"] < ds1.dims["variants"] + assert ds2.sizes["variants"] < ds1.sizes["variants"] def check_snp_allele_counts(api, region, sample_sets, sample_query, site_mask): @@ -922,10 +922,10 @@ def check_biallelic_snp_calls_and_diplotypes( # Check dim lengths. df_samples = api.sample_metadata(sample_sets=sample_sets) n_samples = len(df_samples) - n_variants = ds.dims["variants"] - assert ds.dims["samples"] == n_samples - assert ds.dims["ploidy"] == 2 - assert ds.dims["alleles"] == 2 + n_variants = ds.sizes["variants"] + assert ds.sizes["samples"] == n_samples + assert ds.sizes["ploidy"] == 2 + assert ds.sizes["alleles"] == 2 # Check shapes. for f in expected_coords | expected_data_vars: @@ -967,7 +967,7 @@ def check_biallelic_snp_calls_and_diplotypes( assert isinstance(d1, xr.DataArray) # Check if any variants found, could be zero. - if ds.dims["variants"] == 0: + if ds.sizes["variants"] == 0: # Bail out early, can't run further tests. return ds @@ -992,8 +992,8 @@ def check_biallelic_snp_calls_and_diplotypes( assert isinstance(gn, np.ndarray) assert isinstance(samples, np.ndarray) assert gn.ndim == 2 - assert gn.shape[0] == ds.dims["variants"] - assert gn.shape[1] == ds.dims["samples"] + assert gn.shape[0] == ds.sizes["variants"] + assert gn.shape[1] == ds.sizes["samples"] assert np.all(gn >= 0) assert np.all(gn <= 2) ac = ds["variant_allele_count"].values @@ -1090,7 +1090,7 @@ def test_biallelic_snp_calls_and_diplotypes_with_sample_query_param( else: ds = ag3_sim_api.biallelic_snp_calls(region="3L", sample_query=sample_query) - assert ds.dims["samples"] == len(df_samples) + assert ds.sizes["samples"] == len(df_samples) assert_array_equal(ds["sample_id"].values, df_samples["sample_id"].values) @@ -1110,7 +1110,7 @@ def test_biallelic_snp_calls_and_diplotypes_with_min_cohort_size_param( min_cohort_size=10, ) assert isinstance(ds, xr.Dataset) - assert ds.dims["samples"] >= 10 + assert ds.sizes["samples"] >= 10 with pytest.raises(ValueError): api.biallelic_snp_calls( sample_sets=sample_sets, @@ -1135,7 +1135,7 @@ def test_biallelic_snp_calls_and_diplotypes_with_max_cohort_size_param( max_cohort_size=15, ) assert isinstance(ds, xr.Dataset) - assert ds.dims["samples"] <= 15 + assert ds.sizes["samples"] <= 15 @parametrize_with_cases("fixture,api", cases=".") @@ -1155,7 +1155,7 @@ def test_biallelic_snp_calls_and_diplotypes_with_cohort_size_param( cohort_size=cohort_size, ) assert isinstance(ds, xr.Dataset) - assert ds.dims["samples"] == cohort_size + assert ds.sizes["samples"] == cohort_size with pytest.raises(ValueError): api.biallelic_snp_calls( sample_sets=sample_sets, @@ -1185,7 +1185,7 @@ def test_biallelic_snp_calls_and_diplotypes_with_site_class_param( contig = random.choice(ag3_sim_api.contigs) ds1 = ag3_sim_api.biallelic_snp_calls(region=contig) ds2 = ag3_sim_api.biallelic_snp_calls(region=contig, site_class=site_class) - assert ds2.dims["variants"] < ds1.dims["variants"] + assert ds2.sizes["variants"] < ds1.sizes["variants"] check_biallelic_snp_calls_and_diplotypes( ag3_sim_api, region=contig, site_class=site_class ) @@ -1220,14 +1220,14 @@ def test_biallelic_snp_calls_and_diplotypes_with_conditions( ac_min = ac.min(axis=1) assert np.all(ac_min >= min_minor_ac) an = ac.sum(axis=1) - an_missing = (ds.dims["samples"] * ds.dims["ploidy"]) - an + an_missing = (ds.sizes["samples"] * ds.sizes["ploidy"]) - an assert np.all(an_missing <= max_missing_an) gt = ds["call_genotype"].values ac_check = allel.GenotypeArray(gt).count_alleles(max_allele=1) assert np.all(ac == ac_check) # Run tests with thinning. - n_snps_available = ds.dims["variants"] + n_snps_available = ds.sizes["variants"] # This should always be true, although depends on min_minor_ac and max_missing_an, # so the range of values for those parameters needs to be chosen with some case. assert n_snps_available > 2 @@ -1241,7 +1241,7 @@ def test_biallelic_snp_calls_and_diplotypes_with_conditions( max_missing_an=max_missing_an, n_snps=n_snps_requested, ) - n_snps_thinned = ds_thinned.dims["variants"] + n_snps_thinned = ds_thinned.sizes["variants"] assert n_snps_thinned >= n_snps_requested assert n_snps_thinned <= 2 * n_snps_requested diff --git a/tests/test_af1.py b/tests/test_af1.py index 53b614fab..750ec45da 100644 --- a/tests/test_af1.py +++ b/tests/test_af1.py @@ -397,7 +397,7 @@ def _check_snp_allele_frequencies_advanced( assert cohort_labels == expect_cohort_labels # check variants are consistent - assert ds.dims["variants"] == len(df_af) + assert ds.sizes["variants"] == len(df_af) for v in expected_variant_vars: c = v.split("variant_")[1] actual = ds[v] @@ -550,7 +550,7 @@ def _check_aa_allele_frequencies_advanced( assert cohort_labels == expect_cohort_labels # check variants are consistent - assert ds.dims["variants"] == len(df_af) + assert ds.sizes["variants"] == len(df_af) for v in expected_variant_vars: c = v.split("variant_")[1] actual = ds[v] diff --git a/tests/test_ag3.py b/tests/test_ag3.py index 37769c7fb..767f8977b 100644 --- a/tests/test_ag3.py +++ b/tests/test_ag3.py @@ -571,11 +571,11 @@ def test_gene_cnv(region, sample_sets): # check dim lengths df_samples = ag3.sample_metadata(sample_sets=sample_sets) n_samples = len(df_samples) - assert ds.dims["samples"] == n_samples + assert ds.sizes["samples"] == n_samples df_genome_features = ag3.genome_features(region=region) df_genes = df_genome_features.query("type == 'gene'") n_genes = len(df_genes) - assert ds.dims["genes"] == n_genes + assert ds.sizes["genes"] == n_genes # check IDs assert ds["sample_id"].values.tolist() == df_samples["sample_id"].tolist() @@ -638,11 +638,11 @@ def test_gene_cnv_xarray_indexing(region, sample_sets): o = ds.sel(genes=gene) assert isinstance(o, xr.Dataset) assert set(o.dims) == {"samples"} - assert o.dims["samples"] == ds.dims["samples"] + assert o.sizes["samples"] == ds.sizes["samples"] o = ds.sel(samples=sample) assert isinstance(o, xr.Dataset) assert set(o.dims) == {"genes"} - assert o.dims["genes"] == ds.dims["genes"] + assert o.sizes["genes"] == ds.sizes["genes"] o = ds.sel(genes=gene, samples=sample) assert isinstance(o, xr.Dataset) assert set(o.dims) == set() @@ -1122,7 +1122,7 @@ def _check_snp_allele_frequencies_advanced( assert cohort_labels == expect_cohort_labels # check variants are consistent - assert ds.dims["variants"] == len(df_af) + assert ds.sizes["variants"] == len(df_af) for v in expected_variant_vars: c = v.split("variant_")[1] actual = ds[v] @@ -1271,7 +1271,7 @@ def _check_aa_allele_frequencies_advanced( assert cohort_labels == expect_cohort_labels # check variants are consistent - assert ds.dims["variants"] == len(df_af) + assert ds.sizes["variants"] == len(df_af) for v in expected_variant_vars: c = v.split("variant_")[1] actual = ds[v] @@ -1525,7 +1525,7 @@ def _check_gene_cnv_frequencies_advanced( assert cohort_labels == expect_cohort_labels # check variants are consistent - assert ds.dims["variants"] == len(df_af) + assert ds.sizes["variants"] == len(df_af) for v in expected_variant_vars: c = v.split("variant_")[1] actual = ds[v] diff --git a/tests/test_amin1.py b/tests/test_amin1.py index db4f8edb4..91137abe8 100644 --- a/tests/test_amin1.py +++ b/tests/test_amin1.py @@ -142,10 +142,10 @@ def test_snp_calls(region, site_mask): # check dim lengths df_samples = amin1.sample_metadata() n_samples = len(df_samples) - n_variants = ds.dims["variants"] - assert ds.dims["samples"] == n_samples - assert ds.dims["ploidy"] == 2 - assert ds.dims["alleles"] == 4 + n_variants = ds.sizes["variants"] + assert ds.sizes["samples"] == n_samples + assert ds.sizes["ploidy"] == 2 + assert ds.sizes["alleles"] == 4 # check shapes for f in expected_coords | expected_data_vars: diff --git a/tests/test_pf7_integration.py b/tests/test_pf7_integration.py index 34c4aee35..a481a8689 100644 --- a/tests/test_pf7_integration.py +++ b/tests/test_pf7_integration.py @@ -225,10 +225,10 @@ def test_variant_calls(extended): # check dim lengths df_samples = pf7.sample_metadata() n_samples = len(df_samples) - n_variants = ds.dims["variants"] - assert ds.dims["samples"] == n_samples - assert ds.dims["ploidy"] == 2 - assert ds.dims["alleles"] == 7 + n_variants = ds.sizes["variants"] + assert ds.sizes["samples"] == n_samples + assert ds.sizes["ploidy"] == 2 + assert ds.sizes["alleles"] == 7 # check shapes for f in expected_coords | expected_data_vars: diff --git a/tests/test_pv4_integration.py b/tests/test_pv4_integration.py index bf3463594..8a62de8d1 100644 --- a/tests/test_pv4_integration.py +++ b/tests/test_pv4_integration.py @@ -173,10 +173,10 @@ def test_variant_calls(extended): # check dim lengths df_samples = pv4.sample_metadata() n_samples = len(df_samples) - n_variants = ds.dims["variants"] - assert ds.dims["samples"] == n_samples - assert ds.dims["ploidy"] == 2 - assert ds.dims["alleles"] == 7 + n_variants = ds.sizes["variants"] + assert ds.sizes["samples"] == n_samples + assert ds.sizes["ploidy"] == 2 + assert ds.sizes["alleles"] == 7 # check shapes for f in expected_coords | expected_data_vars: