Skip to content

Commit

Permalink
Fix tests for lshmm 0.0.6
Browse files Browse the repository at this point in the history
  • Loading branch information
szhan committed Jun 20, 2024
1 parent d3c59ba commit c3cce05
Show file tree
Hide file tree
Showing 2 changed files with 135 additions and 76 deletions.
83 changes: 55 additions & 28 deletions python/tests/test_genotype_matching.py
Original file line number Diff line number Diff line change
Expand Up @@ -1345,12 +1345,17 @@ def verify(self, ts):
self.assertAllClose(ll_tree, ll_mirror_tree_dict)

# Ensure that the decoded matrices are the same
flipped_G_check = np.flip(G_check, axis=0)
flipped_s = np.flip(s, axis=1)
num_alleles = ls.core.get_num_alleles(flipped_G_check, flipped_s)

F_mirror_matrix, c, ll = ls.forwards(
np.flip(G_check, axis=0),
np.flip(s, axis=1),
r_flip,
p_mutation=np.flip(mu),
scale_mutation_based_on_n_alleles=False,
flipped_G_check,
flipped_s,
num_alleles=num_alleles,
prob_recombination=r_flip,
prob_mutation=np.flip(mu),
scale_mutation_rate=False,
)

self.assertAllClose(F_mirror_matrix, cm_mirror.decode())
Expand All @@ -1373,8 +1378,14 @@ def verify(self, ts):
ts_check.genotype_matrix()[i, :], ts_check.genotype_matrix()[i, :]
)

num_alleles = ls.core.get_num_alleles(G_check, s)
F, c, ll = ls.forwards(
G_check, s, r, p_mutation=mu, scale_mutation_based_on_n_alleles=False
reference_panel=G_check,
query=s,
num_alleles=num_alleles,
prob_recombination=r,
prob_mutation=mu,
scale_mutation_rate=False,
)
cm_d = ls_forward_tree(s[0, :], ts_check, r, mu)
self.assertAllClose(cm_d.decode(), F)
Expand All @@ -1399,16 +1410,23 @@ def verify(self, ts):
ts_check.genotype_matrix()[i, :], ts_check.genotype_matrix()[i, :]
)

num_alleles = ls.core.get_num_alleles(G_check, s)
F, c, ll = ls.forwards(
G_check, s, r, p_mutation=mu, scale_mutation_based_on_n_alleles=False
reference_panel=G_check,
query=s,
num_alleles=num_alleles,
prob_recombination=r,
prob_mutation=mu,
scale_mutation_rate=False,
)
B = ls.backwards(
G_check,
s,
c,
r,
p_mutation=mu,
scale_mutation_based_on_n_alleles=False,
reference_panel=G_check,
query=s,
num_alleles=num_alleles,
normalisation_factor_from_forward=c,
prob_recombination=r,
prob_mutation=mu,
scale_mutation_rate=False,
)

# Note, need to remove the first sample from the ts, and ensure that
Expand Down Expand Up @@ -1453,16 +1471,24 @@ def verify(self, ts):
ts_check.genotype_matrix()[i, :], ts_check.genotype_matrix()[i, :]
)
ts_check = ts.simplify(range(1, n + 1), filter_sites=False)

num_alleles = ls.core.get_num_alleles(G_check, s)
phased_path, ll = ls.viterbi(
G_check, s, r, p_mutation=mu, scale_mutation_based_on_n_alleles=False
reference_panel=G_check,
query=s,
num_alleles=num_alleles,
prob_recombination=r,
prob_mutation=mu,
scale_mutation_rate=False,
)
path_ll_matrix = ls.path_ll(
G_check,
s,
phased_path,
r,
p_mutation=mu,
scale_mutation_based_on_n_alleles=False,
path_ll_matrix = ls.path_loglik(
reference_panel=G_check,
query=s,
num_alleles=num_alleles,
path=phased_path,
prob_recombination=r,
prob_mutation=mu,
scale_mutation_rate=False,
)

c_v = ls_viterbi_tree(s[0, :], ts_check, r, mu)
Expand All @@ -1471,13 +1497,14 @@ def verify(self, ts):
# Attempt to get the path
path_tree_dict = c_v.traceback()
# Work out the likelihood of the proposed path
path_ll_tree = ls.path_ll(
G_check,
s,
np.transpose(path_tree_dict),
r,
p_mutation=mu,
scale_mutation_based_on_n_alleles=False,
path_ll_tree = ls.path_loglik(
reference_panel=G_check,
query=s,
num_alleles=num_alleles,
path=np.transpose(path_tree_dict),
prob_recombination=r,
prob_mutation=mu,
scale_mutation_rate=False,
)

self.assertAllClose(ll, ll_tree)
Expand Down
128 changes: 80 additions & 48 deletions python/tests/test_haplotype_matching.py
Original file line number Diff line number Diff line change
Expand Up @@ -925,12 +925,16 @@ def verify(self, ts):
self.assertAllClose(ll_tree, ll_mirror_tree)

# Ensure that the decoded matrices are the same
flipped_H = np.flip(H, axis=0)
flipped_s = np.flip(s, axis=1)
num_alleles = ls.core.get_num_alleles(flipped_H, flipped_s)
F_mirror_matrix, c, ll = ls.forwards(
np.flip(H, axis=0),
np.flip(s, axis=1),
r_flip,
p_mutation=np.flip(mu),
scale_mutation_based_on_n_alleles=False,
reference_panel=flipped_H,
query=flipped_s,
num_alleles=num_alleles,
prob_recombination=r_flip,
prob_mutation=np.flip(mu),
scale_mutation_rate=False,
)

self.assertAllClose(F_mirror_matrix, cm_mirror.decode())
Expand All @@ -949,12 +953,14 @@ def verify(self, ts):
# Warning from lshmm:
# Passed a vector of mutation rates, but rescaling each mutation
# rate conditional on the number of alleles
num_alleles = ls.core.get_num_alleles(H, s)
F, c, ll = ls.forwards(
H,
s,
r,
p_mutation=mu,
scale_mutation_based_on_n_alleles=scale_mutation,
reference_panel=H,
query=s,
num_alleles=num_alleles,
prob_recombination=r,
prob_mutation=mu,
scale_mutation_rate=scale_mutation,
)
# Note, need to remove the first sample from the ts, and ensure
# that invariant sites aren't removed.
Expand All @@ -977,16 +983,23 @@ class TestForwardBackwardTree(FBAlgorithmBase):

def verify(self, ts):
for n, H, s, r, mu in self.example_parameters_haplotypes(ts):
num_alleles = ls.core.get_num_alleles(H, s)
F, c, ll = ls.forwards(
H, s, r, p_mutation=mu, scale_mutation_based_on_n_alleles=False
reference_panel=H,
query=s,
num_alleles=num_alleles,
prob_recombination=r,
prob_mutation=mu,
scale_mutation_rate=False,
)
B = ls.backwards(
H,
s,
c,
r,
p_mutation=mu,
scale_mutation_based_on_n_alleles=False,
reference_panel=H,
query=s,
num_alleles=num_alleles,
normalisation_factor_from_forward=c,
prob_recombination=r,
prob_mutation=mu,
scale_mutation_rate=False,
)

# Note, need to remove the first sample from the ts, and ensure that
Expand Down Expand Up @@ -1017,8 +1030,14 @@ class TestTreeViterbiHap(VitAlgorithmBase):

def verify(self, ts):
for n, H, s, r, mu in self.example_parameters_haplotypes(ts):
num_alleles = ls.core.get_num_alleles(H, s)
path, ll = ls.viterbi(
H, s, r, p_mutation=mu, scale_mutation_based_on_n_alleles=False
reference_panel=H,
query=s,
num_alleles=num_alleles,
prob_recombination=r,
prob_mutation=mu,
scale_mutation_rate=False,
)
ts_check = ts.simplify(range(1, n + 1), filter_sites=False)
cm = ls_viterbi_tree(s[0, :], ts_check, r, mu)
Expand All @@ -1028,13 +1047,14 @@ def verify(self, ts):
# Now, need to ensure that the likelihood of the preferred path is
# the same as ll_tree (and ll).
path_tree = cm.traceback()
ll_check = ls.path_ll(
H,
s,
path_tree,
r,
p_mutation=mu,
scale_mutation_based_on_n_alleles=False,
ll_check = ls.path_loglik(
reference_panel=H,
query=s,
num_alleles=num_alleles,
path=path_tree,
prob_recombination=r,
prob_mutation=mu,
scale_mutation_rate=False,
)
self.assertAllClose(ll, ll_check)

Expand All @@ -1051,13 +1071,16 @@ def check_viterbi(ts, h, recombination=None, mutation=None):
precision = 22

G = ts.genotype_matrix()
s = h.reshape(1, m)
num_alleles = ls.core.get_num_alleles(G, s)

path, ll = ls.viterbi(
G,
h.reshape(1, m),
recombination,
p_mutation=mutation,
scale_mutation_based_on_n_alleles=False,
reference_panel=G,
query=s,
num_alleles=num_alleles,
prob_recombination=recombination,
prob_mutation=mutation,
scale_mutation_rate=False,
)
assert np.isscalar(ll)

Expand All @@ -1069,13 +1092,14 @@ def check_viterbi(ts, h, recombination=None, mutation=None):
# Check that the likelihood of the preferred path is
# the same as ll_tree (and ll).
path_tree = cm.traceback()
ll_check = ls.path_ll(
G,
h.reshape(1, m),
path_tree,
recombination,
p_mutation=mutation,
scale_mutation_based_on_n_alleles=False,
ll_check = ls.path_loglik(
reference_panel=G,
query=s,
num_alleles=num_alleles,
path=path_tree,
prob_recombination=recombination,
prob_mutation=mutation,
scale_mutation_rate=False,
)
nt.assert_allclose(ll_check, ll)

Expand Down Expand Up @@ -1106,12 +1130,16 @@ def check_forward_matrix(ts, h, recombination=None, mutation=None):
mutation = np.zeros(ts.num_sites)

G = ts.genotype_matrix()
s = h.reshape(1, m)
num_alleles = ls.core.get_num_alleles(G, s)

F, c, ll = ls.forwards(
G,
h.reshape(1, m),
recombination,
p_mutation=mutation,
scale_mutation_based_on_n_alleles=False,
reference_panel=G,
query=s,
num_alleles=num_alleles,
prob_recombination=recombination,
prob_mutation=mutation,
scale_mutation_rate=False,
)
assert F.shape == (m, n)
assert c.shape == (m,)
Expand Down Expand Up @@ -1150,13 +1178,17 @@ def check_backward_matrix(ts, h, forward_cm, recombination=None, mutation=None):
mutation = np.zeros(ts.num_sites)

G = ts.genotype_matrix()
s = h.reshape(1, m)
num_alleles = ls.core.get_num_alleles(G, s)

B = ls.backwards(
G,
h.reshape(1, m),
forward_cm.normalisation_factor,
recombination,
p_mutation=mutation,
scale_mutation_based_on_n_alleles=False,
reference_panel=G,
query=s,
num_alleles=num_alleles,
normalisation_factor_from_forward=forward_cm.normalisation_factor,
prob_recombination=recombination,
prob_mutation=mutation,
scale_mutation_rate=False,
)

backward_cm = ls_backward_tree(
Expand Down

0 comments on commit c3cce05

Please sign in to comment.