From 040ac59078fdcf9db199b8517f63cf668125b5fe Mon Sep 17 00:00:00 2001 From: Jerome Kelleher Date: Thu, 13 Jul 2023 16:57:48 +0100 Subject: [PATCH] Refactored code for emission probability --- python/tests/test_haplotype_matching.py | 119 ++++++++---------------- 1 file changed, 37 insertions(+), 82 deletions(-) diff --git a/python/tests/test_haplotype_matching.py b/python/tests/test_haplotype_matching.py index ee2999296b..22d1c5c796 100644 --- a/python/tests/test_haplotype_matching.py +++ b/python/tests/test_haplotype_matching.py @@ -359,6 +359,36 @@ def process_site(self, site, haplotype_state, forwards=True): st.value /= s st.value = round(st.value, self.precision) + def compute_emission_proba(self, site_id, is_match): + mu = self.mu[site_id] + n_alleles = self.n_alleles[site_id] + if self.scale_mutation_based_on_n_alleles: + if is_match: + # Scale mutation based on the number of alleles + # - so mu is the mutation rate to one particular other allele. + # The overall mutation rate is then (n_alleles - 1) * mu, so a + # match is emitted with the complementary probability. + p_e = 1 - (n_alleles - 1) * mu + else: + p_e = mu - mu * (n_alleles == 1) + # The (n_alleles == 1) term zeroes p_e at an invariant site + else: + # No scaling based on the number of alleles + # - so mu is the total rate of mutation to any other allele. + # The probability of mutating to one specific different allele is + # therefore mu / (n_alleles - 1). 
+ if n_alleles == 1: # Invariant site: only a match is possible + if is_match: + p_e = 1 + else: + p_e = 0 + else: + if is_match: + p_e = 1 - mu + else: + p_e = mu / (n_alleles - 1) + return p_e + def run_forward(self, h): n = self.ts.num_samples self.tree.clear() @@ -552,40 +582,11 @@ def compute_normalisation_factor(self): s += self.N[j] * st.value return s - def compute_next_probability( - self, site_id, p_last, is_match, node - ): # Note node only used in Viterbi + # Note: node is only used in Viterbi + def compute_next_probability(self, site_id, p_last, is_match, node): rho = self.rho[site_id] - mu = self.mu[site_id] n = self.ts.num_samples - n_alleles = self.n_alleles[site_id] - - if self.scale_mutation_based_on_n_alleles: - if is_match: - # Scale mutation based on the number of alleles - # - so the mutation rate is the mutation rate to one of the - # alleles. The overall mutation rate is then - # (n_alleles - 1) * mutation_rate. - p_e = 1 - (n_alleles - 1) * mu - else: - p_e = mu - mu * (n_alleles == 1) - # Added boolean in case we're at an invariant site - else: - # No scaling based on the number of alleles - # - so the mutation rate is the mutation rate to anything. - # This means that we must rescale the mutation rate to a different - # allele, by the number of alleles. 
- if n_alleles == 1: # In case we're at an invariant site - if is_match: - p_e = 1 - else: - p_e = 0 - else: - if is_match: - p_e = 1 - mu - else: - p_e = mu / (n_alleles - 1) - + p_e = self.compute_emission_proba(site_id, is_match) p_t = p_last * (1 - rho) + rho / n return p_t * p_e @@ -623,28 +624,9 @@ def compute_normalisation_factor(self): s += self.N[j] * st.value return s - def compute_next_probability( - self, site_id, p_next, is_match, node - ): # Note node only used in Viterbi - mu = self.mu[site_id] - n_alleles = self.n_alleles[site_id] - - if self.scale_mutation_based_on_n_alleles: - if is_match: - p_e = 1 - (n_alleles - 1) * mu - else: - p_e = mu - mu * (n_alleles == 1) - else: - if n_alleles == 1: - if is_match: - p_e = 1 - else: - p_e = 0 - else: - if is_match: - p_e = 1 - mu - else: - p_e = mu / (n_alleles - 1) + # Note: node is only used in Viterbi + def compute_next_probability(self, site_id, p_next, is_match, node): + p_e = self.compute_emission_proba(site_id, is_match) return p_next * p_e @@ -681,9 +663,7 @@ def compute_normalisation_factor(self): def compute_next_probability(self, site_id, p_last, is_match, node): rho = self.rho[site_id] - mu = self.mu[site_id] n = self.ts.num_samples - n_alleles = self.n_alleles[site_id] p_no_recomb = p_last * (1 - rho + rho / n) p_recomb = rho / n @@ -695,32 +675,7 @@ def compute_next_probability(self, site_id, p_last, is_match, node): recombination_required = True self.output.add_recombination_required(site_id, node, recombination_required) - if self.scale_mutation_based_on_n_alleles: - if is_match: - # Scale mutation based on the number of alleles - # - so the mutation rate is the mutation rate to one of the - # alleles. The overall mutation rate is then - # (n_alleles - 1) * mutation_rate. 
- p_e = 1 - (n_alleles - 1) * mu - else: - p_e = mu - mu * (n_alleles == 1) - # Added boolean in case we're at an invariant site - else: - # No scaling based on the number of alleles - # - so the mutation rate is the mutation rate to anything. - # This means that we must rescale the mutation rate to a different - # allele, by the number of alleles. - if n_alleles == 1: # In case we're at an invariant site - if is_match: - p_e = 1 - else: - p_e = 0 - else: - if is_match: - p_e = 1 - mu - else: - p_e = mu / (n_alleles - 1) - + p_e = self.compute_emission_proba(site_id, is_match) return p_t * p_e