diff --git a/malariagen_data/anoph/pca.py b/malariagen_data/anoph/pca.py
index a5a28adeb..cc3b8dbdb 100644
--- a/malariagen_data/anoph/pca.py
+++ b/malariagen_data/anoph/pca.py
@@ -71,13 +71,15 @@ def pca(
         cohort_size: Optional[base_params.cohort_size] = None,
         min_cohort_size: Optional[base_params.min_cohort_size] = None,
         max_cohort_size: Optional[base_params.max_cohort_size] = None,
+        exclude_samples: Optional[base_params.samples] = None,
+        fit_exclude_samples: Optional[base_params.samples] = None,
         random_seed: base_params.random_seed = 42,
         inline_array: base_params.inline_array = base_params.inline_array_default,
         chunks: base_params.chunks = base_params.chunks_default,
     ) -> Tuple[pca_params.df_pca, pca_params.evr]:
         # Change this name if you ever change the behaviour of this function, to
         # invalidate any previously cached data.
-        name = "pca_v3"
+        name = "pca_v4"
 
         # Normalize params for consistent hash value.
         (
@@ -104,6 +106,8 @@ def pca(
             cohort_size=cohort_size,
             min_cohort_size=min_cohort_size,
             max_cohort_size=max_cohort_size,
+            exclude_samples=exclude_samples,
+            fit_exclude_samples=fit_exclude_samples,
             random_seed=random_seed,
         )
 
@@ -119,11 +123,11 @@ def pca(
         coords = results["coords"]
         evr = results["evr"]
         samples = results["samples"]
+        loc_keep_fit = results["loc_keep_fit"]
 
         # Load sample metadata.
         df_samples = self.sample_metadata(
             sample_sets=sample_sets,
-            sample_indices=sample_indices_prepped,
         )
 
         # Ensure aligned with genotype data.
@@ -134,6 +138,8 @@ def pca(
             {f"PC{i + 1}": coords[:, i] for i in range(coords.shape[1])}
         )
         df_pca = df_samples.join(df_coords, how="inner")
+        # Add a column for which samples were included in fitting.
+        df_pca["pca_fit"] = loc_keep_fit
 
         return df_pca, evr
 
@@ -153,6 +159,8 @@ def _pca(
         cohort_size,
         min_cohort_size,
         max_cohort_size,
+        exclude_samples,
+        fit_exclude_samples,
         random_seed,
         chunks,
         inline_array,
@@ -177,12 +185,39 @@ def _pca(
         )
 
         with self._spinner(desc="Compute PCA"):
+            # Exclude any samples prior to computing PCA.
+            if exclude_samples is not None:
+                x = np.array(exclude_samples, dtype="U")
+                loc_keep = ~np.isin(samples, x)
+                samples = samples[loc_keep]
+                gn = gn[:, loc_keep]
+
+            # Exclude any samples from fitting only.
+            if fit_exclude_samples is not None:
+                xf = np.array(fit_exclude_samples, dtype="U")
+                loc_keep_fit = ~np.isin(samples, xf)
+                gn_fit = gn[:, loc_keep_fit]
+            else:
+                loc_keep_fit = np.ones(len(samples), dtype=bool)
+                gn_fit = gn
+
             # Remove any sites where all genotypes are identical.
-            loc_var = np.any(gn != gn[:, 0, np.newaxis], axis=1)
+            loc_var = np.any(gn_fit != gn_fit[:, 0, np.newaxis], axis=1)
+            gn_fit_var = np.compress(loc_var, gn_fit, axis=0)
             gn_var = np.compress(loc_var, gn, axis=0)
 
             # Run the PCA.
-            coords, model = allel.pca(gn_var, n_components=n_components)
+            if fit_exclude_samples is None:
+                # Simple fit and transform on the same data.
+                coords, model = allel.pca(gn_fit_var, n_components=n_components)
+
+            else:
+                # Fit and transform separately.
+                model = allel.stats.decomposition.GenotypePCA(
+                    n_components=n_components,
+                )
+                model.fit(gn_fit_var)
+                coords = model.transform(gn_var, copy=False)
 
             # Work around sign indeterminacy.
             for i in range(coords.shape[1]):
@@ -191,7 +226,10 @@ def _pca(
                     coords[:, i] = c * -1
 
         results = dict(
-            samples=samples, coords=coords, evr=model.explained_variance_ratio_
+            samples=samples,
+            coords=coords,
+            evr=model.explained_variance_ratio_,
+            loc_keep_fit=loc_keep_fit,
        )
 
        return results
diff --git a/notebooks/plot_pca.ipynb b/notebooks/plot_pca.ipynb
index 6feb52895..ad1ec7e77 100644
--- a/notebooks/plot_pca.ipynb
+++ b/notebooks/plot_pca.ipynb
@@ -1,5 +1,21 @@
 {
  "cells": [
+  {
+   "cell_type": "markdown",
+   "id": "7d93e741-bddf-4a29-aea2-7bc6be697cb7",
+   "metadata": {},
+   "source": [
+    "# PCA plotting"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "2f274c61-c660-47b0-92b9-83ab701eb9eb",
+   "metadata": {},
+   "source": [
+    "### Setup"
+   ]
+  },
   {
    "cell_type": "code",
    "execution_count": null,
@@ -50,6 +66,14 @@
     "!rm -rf results_cache"
    ]
   },
+  {
+   "cell_type": "markdown",
+   "id": "b3760474-2ab2-4291-b5ee-48feccd091f4",
+   "metadata": {},
+   "source": [
+    "## Mayotte"
+   ]
+  },
   {
    "cell_type": "code",
    "execution_count": null,
@@ -137,6 +161,14 @@
     ")"
    ]
   },
+  {
+   "cell_type": "markdown",
+   "id": "467dbd69-8ba2-41a2-8d01-d9fa78d42326",
+   "metadata": {},
+   "source": [
+    "## Burkina Faso"
+   ]
+  },
   {
    "cell_type": "code",
    "execution_count": null,
@@ -180,6 +212,14 @@
     ")"
    ]
   },
+  {
+   "cell_type": "markdown",
+   "id": "a2b46eb9-832d-417c-a082-9fc365b0fc62",
+   "metadata": {},
+   "source": [
+    "## Ag3.0"
+   ]
+  },
   {
    "cell_type": "code",
    "execution_count": null,
@@ -320,6 +360,14 @@
     ")"
    ]
   },
+  {
+   "cell_type": "markdown",
+   "id": "14d04103-2f6f-4f53-8d35-f286de5980d2",
+   "metadata": {},
+   "source": [
+    "## Af1.0"
+   ]
+  },
   {
    "cell_type": "code",
    "execution_count": null,
@@ -373,13 +421,48 @@
     ")"
    ]
   },
+  {
+   "cell_type": "markdown",
+   "id": "fdfbcd5a-5ef4-4cbc-b430-c0b541925142",
+   "metadata": {},
+   "source": [
+    "## Excluding samples"
+   ]
+  },
   {
    "cell_type": "code",
    "execution_count": null,
    "id": "6c3251e2-a95c-4d4a-bff1-8e632be0eaf5",
    "metadata": {},
    "outputs": [],
-   "source": []
+   "source": [
+    "df_pca, evr = ag3.pca(\n",
+    "    region=\"3L:15,000,000-16,000,000\",\n",
+    "    sample_sets=\"AG1000G-BF-A\",\n",
+    "    n_snps=10_000,\n",
+    "    max_cohort_size=50,\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "1f0b88a7-d9f6-44a8-9778-4cbaad397765",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "df_pca.head()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "4e370936-925a-4e97-86cb-b22f72b39e64",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "ag3.plot_pca_variance(evr)"
+   ]
   },
   {
    "cell_type": "code",
@@ -387,12 +470,160 @@
    "id": "9ea24435-d752-4269-a391-b057e1650d44",
    "metadata": {},
    "outputs": [],
-   "source": []
+   "source": [
+    "ag3.plot_pca_coords(\n",
+    "    df_pca,\n",
+    "    color=\"taxon\",\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "0f114dd9-965d-4579-8cce-76f45fde6a37",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "exclude_samples = [\"AB0096-C\", \"AB0241-C\", \"AB0275-C\", \"AB0197-C\"]\n",
+    "\n",
+    "df_pca_ex, evr_ex = ag3.pca(\n",
+    "    region=\"3L:15,000,000-16,000,000\",\n",
+    "    sample_sets=\"AG1000G-BF-A\",\n",
+    "    n_snps=10_000,\n",
+    "    max_cohort_size=50,\n",
+    "    exclude_samples=exclude_samples,\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "a5392844-a828-48f3-8222-bfc1817c3493",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "df_pca_ex.head()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "6d87eeb5-8dc6-4a31-a8a3-e936d4d05215",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "df_pca_ex.query(f\"sample_id in {exclude_samples}\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "770b01ee-642e-4903-93fa-eef084924234",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "ag3.plot_pca_variance(evr_ex)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "9438fd91-b463-4d28-8279-5eaa728a8f2e",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "ag3.plot_pca_coords(\n",
+    "    df_pca_ex,\n",
+    "    color=\"taxon\",\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "58e7b2c1-b4c1-4c93-823c-8531693636e4",
+   "metadata": {},
+   "source": [
+    "## Excluding samples during fit"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "1014a17c-12bb-46c3-9a7b-f2449d205ce3",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "fit_exclude_samples = [\"AB0096-C\", \"AB0241-C\", \"AB0275-C\", \"AB0197-C\"]\n",
+    "\n",
+    "df_pca_fex, evr_fex = ag3.pca(\n",
+    "    region=\"3L:15,000,000-16,000,000\",\n",
+    "    sample_sets=\"AG1000G-BF-A\",\n",
+    "    n_snps=10_000,\n",
+    "    max_cohort_size=50,\n",
+    "    fit_exclude_samples=fit_exclude_samples,\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "e48d46dd-5c7f-4573-a5f5-c15a389d2e1e",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "df_pca_fex.head()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "1fb5192e-f63f-443d-91a9-e82b2de14c69",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "df_pca_fex.query(f\"sample_id in {fit_exclude_samples}\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "873c53f5-8690-4690-85f0-8ce1845e2862",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "ag3.plot_pca_variance(evr_fex)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "0bb86ba2-e428-4ae0-9e0c-703345c89324",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "ag3.plot_pca_coords(\n",
+    "    df_pca_fex,\n",
+    "    color=\"taxon\",\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "7b28c4c8-9382-4889-8651-77ef0be929f6",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "ag3.plot_pca_coords(\n",
+    "    df_pca_fex,\n",
+    "    color=\"pca_fit\",\n",
+    ")"
+   ]
   },
   {
    "cell_type": "code",
    "execution_count": null,
-   "id": "1746f075-8b25-47fc-8f97-35093d8cc899",
+   "id": "33d788a2-f256-4930-b1e5-b4f31e681a36",
    "metadata": {},
    "outputs": [],
    "source": []
@@ -414,7 +645,7 @@
    "name": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
-   "version": "3.8.18"
+   "version": "3.10.12"
   },
   "widgets": {
    "application/vnd.jupyter.widget-state+json": {
diff --git a/tests/anoph/test_pca.py b/tests/anoph/test_pca.py
index 8d452fcce..027d23055 100644
--- a/tests/anoph/test_pca.py
+++ b/tests/anoph/test_pca.py
@@ -158,3 +158,130 @@ def test_pca_plotting(fixture, api: AnophelesPca):
         symbol=symbol,
     )
     assert isinstance(fig_3d, go.Figure)
+
+
+@parametrize_with_cases("fixture,api", cases=".")
+def test_pca_exclude_samples(fixture, api: AnophelesPca):
+    # Parameters for selecting input data.
+    all_sample_sets = api.sample_sets()["sample_set"].to_list()
+    data_params = dict(
+        region=random.choice(api.contigs),
+        sample_sets=random.sample(all_sample_sets, 2),
+        site_mask=random.choice((None,) + api.site_mask_ids),
+    )
+    ds = api.biallelic_snp_calls(
+        min_minor_ac=pca_params.min_minor_ac_default,
+        max_missing_an=pca_params.max_missing_an_default,
+        **data_params,
+    )
+
+    # Exclusion parameters.
+    n_samples_excluded = random.randint(1, 5)
+    samples = ds["sample_id"].values.tolist()
+    exclude_samples = random.sample(samples, n_samples_excluded)
+
+    # PCA parameters.
+    n_samples = ds.sizes["samples"] - n_samples_excluded
+    n_snps_available = ds.sizes["variants"]
+    n_snps = random.randint(1, n_snps_available)
+    n_components = random.randint(3, min(n_samples, n_snps, 10))
+
+    # Run the PCA.
+    pca_df, pca_evr = api.pca(
+        n_snps=n_snps,
+        n_components=n_components,
+        exclude_samples=exclude_samples,
+        **data_params,
+    )
+
+    # Check types.
+    assert isinstance(pca_df, pd.DataFrame)
+    assert isinstance(pca_evr, np.ndarray)
+
+    # Check sizes.
+    assert len(pca_df) == n_samples
+    for i in range(n_components):
+        assert f"PC{i+1}" in pca_df.columns, (
+            "n_components",
+            n_components,
+            "n_samples",
+            n_samples,
+            "n_snps_available",
+            n_snps_available,
+            "n_snps",
+            n_snps,
+        )
+    assert f"PC{n_components+1}" not in pca_df.columns
+    assert "pca_fit" in pca_df.columns
+    assert pca_df["pca_fit"].all()
+    assert pca_evr.ndim == 1
+    assert pca_evr.shape[0] == n_components
+
+    # Check exclusions.
+    assert len(pca_df.query(f"sample_id in {exclude_samples}")) == 0
+
+
+@parametrize_with_cases("fixture,api", cases=".")
+def test_pca_fit_exclude_samples(fixture, api: AnophelesPca):
+    # Parameters for selecting input data.
+    all_sample_sets = api.sample_sets()["sample_set"].to_list()
+    data_params = dict(
+        region=random.choice(api.contigs),
+        sample_sets=random.sample(all_sample_sets, 2),
+        site_mask=random.choice((None,) + api.site_mask_ids),
+    )
+    ds = api.biallelic_snp_calls(
+        min_minor_ac=pca_params.min_minor_ac_default,
+        max_missing_an=pca_params.max_missing_an_default,
+        **data_params,
+    )
+
+    # Exclusion parameters.
+    n_samples_excluded = random.randint(1, 5)
+    samples = ds["sample_id"].values.tolist()
+    exclude_samples = random.sample(samples, n_samples_excluded)
+
+    # PCA parameters.
+    n_samples = ds.sizes["samples"]
+    n_snps_available = ds.sizes["variants"]
+    n_snps = random.randint(1, n_snps_available)
+    n_components = random.randint(3, min(n_samples, n_snps, 10))
+
+    # Run the PCA.
+    pca_df, pca_evr = api.pca(
+        n_snps=n_snps,
+        n_components=n_components,
+        fit_exclude_samples=exclude_samples,
+        **data_params,
+    )
+
+    # Check types.
+    assert isinstance(pca_df, pd.DataFrame)
+    assert isinstance(pca_evr, np.ndarray)
+
+    # Check sizes.
+    assert len(pca_df) == n_samples
+    for i in range(n_components):
+        assert f"PC{i+1}" in pca_df.columns, (
+            "n_components",
+            n_components,
+            "n_samples",
+            n_samples,
+            "n_snps_available",
+            n_snps_available,
+            "n_snps",
+            n_snps,
+        )
+    assert f"PC{n_components+1}" not in pca_df.columns
+    assert "pca_fit" in pca_df.columns
+    assert pca_evr.ndim == 1
+    assert pca_evr.shape[0] == n_components
+
+    # Check exclusions.
+    assert not pca_df["pca_fit"].all()
+    assert pca_df["pca_fit"].sum() == n_samples - n_samples_excluded
+    assert len(pca_df.query(f"sample_id in {exclude_samples}")) == n_samples_excluded
+    assert (
+        len(pca_df.query(f"sample_id in {exclude_samples} and not pca_fit"))
+        == n_samples_excluded
+    )
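
For anyone reviewing the behaviour above, here is a minimal sketch of the fit/transform split that the new `fit_exclude_samples` path performs, run on synthetic genotype counts rather than real data. It assumes only that scikit-allel and NumPy are installed; the sample counts, the choice of "outlier" samples, and `n_components=3` are invented for illustration, while the variable names mirror those in the patch.

```python
import allel
import numpy as np

rng = np.random.default_rng(42)
n_variants, n_samples = 1_000, 20

# Synthetic biallelic genotypes, coded as 0, 1 or 2 alternate alleles.
gn = rng.integers(0, 3, size=(n_variants, n_samples)).astype("i1")

# Pretend the last two samples are outliers: we still want coordinates
# for them, but they should not influence the fitted components.
loc_keep_fit = np.ones(n_samples, dtype=bool)
loc_keep_fit[-2:] = False
gn_fit = gn[:, loc_keep_fit]

# Drop sites that are invariant within the fit subset, as _pca() does.
loc_var = np.any(gn_fit != gn_fit[:, 0, np.newaxis], axis=1)
gn_fit_var = np.compress(loc_var, gn_fit, axis=0)
gn_var = np.compress(loc_var, gn, axis=0)

# Fit on the subset only, then transform all samples, so the excluded
# samples are projected onto components they did not help to define.
model = allel.stats.decomposition.GenotypePCA(n_components=3)
model.fit(gn_fit_var)
coords = model.transform(gn_var)

assert coords.shape == (n_samples, 3)
print(model.explained_variance_ratio_)
```

This split is also what distinguishes the two new parameters: `fit_exclude_samples` keeps the excluded samples in the returned `df_pca` (flagged with `pca_fit == False`), whereas `exclude_samples` removes them from both the computation and the result entirely, as the two new tests assert.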