Skip to content

Commit

Permalink
Update how MISSING is handled when updating probabilities
Browse files Browse the repository at this point in the history
  • Loading branch information
szhan committed Jun 24, 2024
1 parent a343802 commit 54556c5
Showing 1 changed file with 26 additions and 17 deletions.
43 changes: 26 additions & 17 deletions python/tests/test_haplotype_matching.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,11 +333,10 @@ def update_probabilities(self, site, haplotype_state):
while allelic_state[v] == -1:
v = tree.parent(v)
assert v != -1
match = (
haplotype_state == MISSING or haplotype_state == allelic_state[v]
)
match = haplotype_state == allelic_state[v]
is_query_missing = haplotype_state == MISSING
# Note that the node u is used only by Viterbi
st.value = self.compute_next_probability(site.id, st.value, match, u)
st.value = self.compute_next_probability(site.id, st.value, match, is_query_missing, u)

# Unset the states
allelic_state[tree.root] = -1
Expand Down Expand Up @@ -404,7 +403,7 @@ def run(self, h):
def compute_normalisation_factor(self):
raise NotImplementedError()

def compute_next_probability(self, site_id, p_last, is_match, node):
def compute_next_probability(self, site_id, p_last, is_match, is_query_missing, node):
raise NotImplementedError()


Expand Down Expand Up @@ -435,10 +434,13 @@ def compute_normalisation_factor(self):
s += self.N[j] * st.value
return s

def compute_next_probability(self, site_id, p_last, is_match, node):
def compute_next_probability(self, site_id, p_last, is_match, is_query_missing, node):
rho = self.rho[site_id]
n = self.ts.num_samples
p_e = self.compute_emission_proba(site_id, is_match)
if is_query_missing:
p_e = 1.0
else:
p_e = self.compute_emission_proba(site_id, is_match)
p_t = p_last * (1 - rho) + rho / n
return p_t * p_e

Expand All @@ -448,8 +450,11 @@ class BackwardAlgorithm(ForwardAlgorithm):
The Li and Stephens backward algorithm.
"""

def compute_next_probability(self, site_id, p_next, is_match, node):
p_e = self.compute_emission_proba(site_id, is_match)
def compute_next_probability(self, site_id, p_next, is_match, is_query_missing, node):
if is_query_missing:
p_e = 1.0
else:
p_e = self.compute_emission_proba(site_id, is_match)
return p_next * p_e

def process_site(self, site, haplotype_state, s):
Expand Down Expand Up @@ -515,7 +520,7 @@ def compute_normalisation_factor(self):
)
return max_st.value

def compute_next_probability(self, site_id, p_last, is_match, node):
def compute_next_probability(self, site_id, p_last, is_match, is_query_missing, node):
rho = self.rho[site_id]
n = self.ts.num_samples

Expand All @@ -529,7 +534,11 @@ def compute_next_probability(self, site_id, p_last, is_match, node):
recombination_required = True
self.output.add_recombination_required(site_id, node, recombination_required)

p_e = self.compute_emission_proba(site_id, is_match)
if is_query_missing:
p_e = 1.0
else:
p_e = self.compute_emission_proba(site_id, is_match)

return p_t * p_e


Expand Down Expand Up @@ -679,12 +688,12 @@ def traceback(self):

def get_site_alleles(ts, h, alleles):
if alleles is None:
n_alleles = np.int8(
[
len(np.unique(np.append(ts.genotype_matrix()[j, :], h[j])))
for j in range(ts.num_sites)
]
)
n_alleles = np.zeros(ts.num_sites, dtype=np.int8) - 1
for j in range(ts.num_sites):
uniq_alleles = np.unique(np.append(ts.genotype_matrix()[j, :], h[j]))
uniq_alleles = uniq_alleles[uniq_alleles != MISSING]
n_alleles[j] = len(uniq_alleles)
assert np.all(n_alleles > 0)
alleles = tskit.ALLELES_ACGT
if len(set(alleles).intersection(next(ts.variants()).alleles)) == 0:
alleles = tskit.ALLELES_01
Expand Down

0 comments on commit 54556c5

Please sign in to comment.