# lshmm/core.py -- additions from the patch "Support multiple allele states"
# (commit d102032067e6c661f55bee22cfa8919a2f1c5731, szhan, 1 Jul 2024).
# In core.py these definitions follow get_emission_probability_haploid.


@jit.numba_njit
def get_emission_matrix_haploid_tstv(mu, kappa=None):
    """
    Return an emission probability matrix that allows for mutational bias
    towards transitions or transversions.

    Alleles are assumed to be encoded as A = 0, C = 1, G = 2, T = 3, so
    that the transitions are A <-> G and C <-> T.

    Transition and transversion probabilities are defined such that
    the probability of a particular type of transition is equal to
    `kappa` * the probability of a particular type of transversion,
    and that the total probability of mutation is equal to `mu`.

    When `kappa` is set to None, it defaults to 1, in which case each of
    the three possible mutations at a site has probability `mu` / 3.

    :param numpy.ndarray mu: Per-site probability of mutation to any allele.
    :param float kappa: Transition-to-transversion rate ratio.
    :return: Emission probability matrix of size (m, 4, 4), where
        m = number of sites (the length of `mu`). Entry [i, j, k] is the
        probability of emitting allele k from allele j at site i.
    :rtype: numpy.ndarray
    :raises ValueError: If any entry of `mu` is outside [0, 1], if `kappa`
        is not positive, or if a row of the matrix does not sum to one.
    """
    if np.any(mu < 0.0) or np.any(mu > 1.0):
        raise ValueError("Probability of mutation must be in [0, 1].")

    if kappa is not None and kappa <= 0:
        raise ValueError("Transition-to-transversion rate ratio must be positive.")

    if kappa is None:
        kappa = 1.0

    num_sites = len(mu)
    num_alleles = 4  # Assume that ACGT are encoded as 0 to 3.

    # Fill with -1 (not 0). Every entry is overwritten in the loops below;
    # an entry left at -1 would show up in the row-sum check.
    emission_matrix = np.full(
        (num_sites, num_alleles, num_alleles), -1.0, dtype=np.float64
    )

    for i in range(num_sites):
        # Each allele has one transition partner and two transversion
        # partners, so kappa * p + 2 * p = mu[i] for transversion
        # probability p. These values depend only on the site, so compute
        # them once per site rather than once per matrix entry.
        prob_no_mutation = 1 - mu[i]
        prob_transversion = mu[i] / (2.0 + kappa)
        prob_transition = prob_transversion * kappa

        for j in range(num_alleles):
            for k in range(num_alleles):
                if j == k:
                    emission_matrix[i, j, k] = prob_no_mutation
                elif j % 2 == k % 2:
                    # Transitions: A <-> G (codes 0 and 2) and
                    # C <-> T (codes 1 and 3), i.e. distinct alleles
                    # whose codes have the same parity.
                    emission_matrix[i, j, k] = prob_transition
                else:
                    emission_matrix[i, j, k] = prob_transversion

            row_sum = np.sum(emission_matrix[i, j, :])
            if not np.isclose(row_sum, 1.0):
                err_msg = f"Row values must sum to one. {row_sum}"
                raise ValueError(err_msg)

    return emission_matrix


@jit.numba_njit
def get_emission_probability_haploid_tstv(
    ref_allele, query_allele, site, emission_matrix
):
    """
    Return the emission probability at a specified site for the haploid case,
    given an emission probability matrix.

    The emission probability matrix is an array of size (m, 4, 4),
    where m = number of sites, indexed as [site, ref_allele, query_allele].

    :param int ref_allele: Reference allele.
    :param int query_allele: Query allele.
    :param int site: Site index.
    :param numpy.ndarray emission_matrix: Emission probability matrix.
    :return: Emission probability: 0.0 if the reference allele is NONCOPY,
        1.0 if the query allele is MISSING, otherwise the matrix entry.
    :rtype: float
    :raises ValueError: If the reference allele is MISSING, the query allele
        is NONCOPY, or the matrix is not of size (m, 4, 4).
    """
    if ref_allele == MISSING:
        raise ValueError("Reference allele cannot be MISSING.")
    if query_allele == NONCOPY:
        raise ValueError("Query allele cannot be NONCOPY.")
    # Check the rank first so that a matrix of the wrong dimensionality is
    # reported as a shape error rather than failing on shape[2].
    if (
        emission_matrix.ndim != 3
        or emission_matrix.shape[1] != 4
        or emission_matrix.shape[2] != 4
    ):
        raise ValueError("Emission probability matrix has incorrect shape.")
    if ref_allele == NONCOPY:
        return 0.0
    elif query_allele == MISSING:
        return 1.0
    else:
        return emission_matrix[site, ref_allele, query_allele]


# Functions to assign emission probabilities for diploid LS HMM.
# Patch context: in lshmm/core.py the diploid section continues with
# get_emission_matrix_diploid(mu, num_sites, num_alleles, scale_mutation_rate),
# which this change leaves untouched.
#
# tests/test_nontree_vit_haploid_tstv.py -- new file added by the same patch.
import itertools
import pytest

import numpy as np
import numba as nb

from . import lsbase
import lshmm.core as core
import lshmm.vit_haploid as vh


class TestNonTreeViterbiHaploid(lsbase.ViterbiAlgorithmBase):
    """
    Cross-check the haploid Viterbi implementations in `lshmm.vit_haploid`
    when emission probabilities come from the transition/transversion
    emission matrix.
    """

    def _assert_path_ll(self, path, ll, ll_ref, hmm_args):
        """
        Check a Viterbi path against the log-likelihood reported for it.

        :param path: Path returned by a traceback function.
        :param ll: Log-likelihood returned by the forward pass for `path`.
        :param ll_ref: Log-likelihood from the reference implementation to
            compare `ll` with, or None to skip that comparison.
        :param dict hmm_args: Keyword arguments (n, m, H, s, e, r,
            emission_func) shared by the forward passes and `path_ll_hap`.
        """
        ll_check = vh.path_ll_hap(path=path, **hmm_args)
        self.assertAllClose(ll, ll_check)
        if ll_ref is not None:
            self.assertAllClose(ll_ref, ll)

    def verify(self, ts, include_ancestors):
        """
        Run every forward Viterbi variant on the examples derived from `ts`
        over a grid of recombination probabilities, mutation probabilities
        and transition-to-transversion ratios.

        For each variant, the log-likelihood of the traced-back path must
        match the value reported by the forward pass, and must also match
        the value reported by `forwards_viterbi_hap_naive`.
        """
        H, queries = self.get_examples_haploid(ts, include_ancestors)
        m = H.shape[0]
        n = H.shape[1]

        r_s = [
            np.zeros(m) + 0.01,
            np.random.rand(m),
            1e-5 * (np.random.rand(m) + 0.5) / 2,
            np.zeros(m) + 0.2,
            np.zeros(m) + 1e-6,
        ]
        mu_s = [
            np.zeros(m) + 0.01,
            np.random.rand(m) * 0.2,
            1e-5 * (np.random.rand(m) + 0.5) / 2,
            np.zeros(m) + 0.2,
            np.zeros(m) + 1e-6,
        ]
        kappa_s = [0.25, 0.5, 1.0, 1.5, 2.0]

        # Variants whose returned V is passed to the traceback as is
        # (unlike the two naive variants, whose V is indexed at row m - 1).
        forwards_returning_last_V = [
            vh.forwards_viterbi_hap_naive_low_mem,
            vh.forwards_viterbi_hap_naive_low_mem_rescaling,
            vh.forwards_viterbi_hap_low_mem_rescaling,
            vh.forwards_viterbi_hap_lower_mem_rescaling,
        ]

        for s, r, mu, kappa in itertools.product(queries, r_s, mu_s, kappa_s):
            e = core.get_emission_matrix_haploid_tstv(mu, kappa)
            hmm_args = dict(
                n=n,
                m=m,
                H=H,
                s=s,
                e=e,
                r=r,
                emission_func=core.get_emission_probability_haploid_tstv,
            )

            # Reference implementation: only checked against its own path.
            V_vs, P_vs, ll_vs = vh.forwards_viterbi_hap_naive(**hmm_args)
            path_vs = vh.backwards_viterbi_hap(m=m, V_last=V_vs[m - 1, :], P=P_vs)
            self._assert_path_ll(path_vs, ll_vs, None, hmm_args)

            V_tmp, P_tmp, ll_tmp = vh.forwards_viterbi_hap_naive_vec(**hmm_args)
            path_tmp = vh.backwards_viterbi_hap(m=m, V_last=V_tmp[m - 1, :], P=P_tmp)
            self._assert_path_ll(path_tmp, ll_tmp, ll_vs, hmm_args)

            for forwards in forwards_returning_last_V:
                V_tmp, P_tmp, ll_tmp = forwards(**hmm_args)
                path_tmp = vh.backwards_viterbi_hap(m=m, V_last=V_tmp, P=P_tmp)
                self._assert_path_ll(path_tmp, ll_tmp, ll_vs, hmm_args)

            # The pointer-free variant returns per-site argmaxes and
            # recombination lists, and has its own traceback function.
            (
                V_tmp,
                V_argmaxes_tmp,
                recombs,
                ll_tmp,
            ) = vh.forwards_viterbi_hap_lower_mem_rescaling_no_pointer(**hmm_args)
            path_tmp = vh.backwards_viterbi_hap_no_pointer(
                m=m,
                V_argmaxes=V_argmaxes_tmp,
                recombs=nb.typed.List(recombs),
            )
            self._assert_path_ll(path_tmp, ll_tmp, ll_vs, hmm_args)

    @pytest.mark.parametrize("include_ancestors", [True, False])
    def test_ts_multiallelic_n10_no_recomb(self, include_ancestors):
        ts = self.get_ts_multiallelic_n10_no_recomb()
        self.verify(ts, include_ancestors)

    @pytest.mark.parametrize("num_samples", [8, 16, 32])
    @pytest.mark.parametrize("include_ancestors", [True, False])
    def test_ts_multiallelic(self, num_samples, include_ancestors):
        ts = self.get_ts_multiallelic(num_samples)
        self.verify(ts, include_ancestors)