Skip to content

Commit

Permalink
Refactored code for emission probability
Browse files Browse the repository at this point in the history
  • Loading branch information
jeromekelleher committed Jul 13, 2023
1 parent 256768e commit 040ac59
Showing 1 changed file with 37 additions and 82 deletions.
119 changes: 37 additions & 82 deletions python/tests/test_haplotype_matching.py
Original file line number Diff line number Diff line change
Expand Up @@ -359,6 +359,36 @@ def process_site(self, site, haplotype_state, forwards=True):
st.value /= s
st.value = round(st.value, self.precision)

def compute_emission_proba(self, site_id, is_match):
mu = self.mu[site_id]
n_alleles = self.n_alleles[site_id]
if self.scale_mutation_based_on_n_alleles:
if is_match:
# Scale mutation based on the number of alleles
# - so the mutation rate is the mutation rate to one of the
# alleles. The overall mutation rate is then
# (n_alleles - 1) * mutation_rate.
p_e = 1 - (n_alleles - 1) * mu
else:
p_e = mu - mu * (n_alleles == 1)
# Added boolean in case we're at an invariant site
else:
# No scaling based on the number of alleles
# - so the mutation rate is the mutation rate to anything.
# This means that we must rescale the mutation rate to a different
# allele, by the number of alleles.
if n_alleles == 1: # In case we're at an invariant site
if is_match:
p_e = 1
else:
p_e = 0
else:
if is_match:
p_e = 1 - mu
else:
p_e = mu / (n_alleles - 1)
return p_e

def run_forward(self, h):
n = self.ts.num_samples
self.tree.clear()
Expand Down Expand Up @@ -552,40 +582,11 @@ def compute_normalisation_factor(self):
s += self.N[j] * st.value
return s

def compute_next_probability(
self, site_id, p_last, is_match, node
): # Note node only used in Viterbi
# Note node only used in Viterbi
def compute_next_probability(self, site_id, p_last, is_match, node):
rho = self.rho[site_id]
mu = self.mu[site_id]
n = self.ts.num_samples
n_alleles = self.n_alleles[site_id]

if self.scale_mutation_based_on_n_alleles:
if is_match:
# Scale mutation based on the number of alleles
# - so the mutation rate is the mutation rate to one of the
# alleles. The overall mutation rate is then
# (n_alleles - 1) * mutation_rate.
p_e = 1 - (n_alleles - 1) * mu
else:
p_e = mu - mu * (n_alleles == 1)
# Added boolean in case we're at an invariant site
else:
# No scaling based on the number of alleles
# - so the mutation rate is the mutation rate to anything.
# This means that we must rescale the mutation rate to a different
# allele, by the number of alleles.
if n_alleles == 1: # In case we're at an invariant site
if is_match:
p_e = 1
else:
p_e = 0
else:
if is_match:
p_e = 1 - mu
else:
p_e = mu / (n_alleles - 1)

p_e = self.compute_emission_proba(site_id, is_match)
p_t = p_last * (1 - rho) + rho / n
return p_t * p_e

Expand Down Expand Up @@ -623,28 +624,9 @@ def compute_normalisation_factor(self):
s += self.N[j] * st.value
return s

def compute_next_probability(
self, site_id, p_next, is_match, node
): # Note node only used in Viterbi
mu = self.mu[site_id]
n_alleles = self.n_alleles[site_id]

if self.scale_mutation_based_on_n_alleles:
if is_match:
p_e = 1 - (n_alleles - 1) * mu
else:
p_e = mu - mu * (n_alleles == 1)
else:
if n_alleles == 1:
if is_match:
p_e = 1
else:
p_e = 0
else:
if is_match:
p_e = 1 - mu
else:
p_e = mu / (n_alleles - 1)
# Note node only used in Viterbi
def compute_next_probability(self, site_id, p_next, is_match, node):
p_e = self.compute_emission_proba(site_id, is_match)
return p_next * p_e


Expand Down Expand Up @@ -681,9 +663,7 @@ def compute_normalisation_factor(self):

def compute_next_probability(self, site_id, p_last, is_match, node):
rho = self.rho[site_id]
mu = self.mu[site_id]
n = self.ts.num_samples
n_alleles = self.n_alleles[site_id]

p_no_recomb = p_last * (1 - rho + rho / n)
p_recomb = rho / n
Expand All @@ -695,32 +675,7 @@ def compute_next_probability(self, site_id, p_last, is_match, node):
recombination_required = True
self.output.add_recombination_required(site_id, node, recombination_required)

if self.scale_mutation_based_on_n_alleles:
if is_match:
# Scale mutation based on the number of alleles
# - so the mutation rate is the mutation rate to one of the
# alleles. The overall mutation rate is then
# (n_alleles - 1) * mutation_rate.
p_e = 1 - (n_alleles - 1) * mu
else:
p_e = mu - mu * (n_alleles == 1)
# Added boolean in case we're at an invariant site
else:
# No scaling based on the number of alleles
# - so the mutation rate is the mutation rate to anything.
# This means that we must rescale the mutation rate to a different
# allele, by the number of alleles.
if n_alleles == 1: # In case we're at an invariant site
if is_match:
p_e = 1
else:
p_e = 0
else:
if is_match:
p_e = 1 - mu
else:
p_e = mu / (n_alleles - 1)

p_e = self.compute_emission_proba(site_id, is_match)
return p_t * p_e


Expand Down

0 comments on commit 040ac59

Please sign in to comment.