Skip to content

Commit

Permalink
Fix tests for lshmm 0.0.8
Browse files Browse the repository at this point in the history
  • Loading branch information
szhan committed Jun 24, 2024
1 parent d3c59ba commit a343802
Show file tree
Hide file tree
Showing 5 changed files with 136 additions and 92 deletions.
2 changes: 1 addition & 1 deletion python/requirements/CI-complete/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ coverage==7.2.7
dendropy==4.6.1
h5py==3.9.0
kastore==0.3.2
lshmm==0.0.5
lshmm==0.0.8
msgpack==1.0.5
msprime==1.2.0
networkx==3.1
Expand Down
4 changes: 2 additions & 2 deletions python/requirements/CI-tests-pip/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
lshmm==0.0.5
lshmm==0.0.8
numpy==1.24.1
pytest==7.1.3
pytest-cov==4.0.0
Expand All @@ -14,4 +14,4 @@ newick==1.3.2
tszip==0.2.2
kastore==0.3.2
lxml==4.9.2
numba<=0.59.1 #Pinned directly as 0.60.0 fails
numba<=0.59.1 #Pinned directly as 0.60.0 fails
2 changes: 1 addition & 1 deletion python/requirements/development.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ h5py>=2.6.0
jsonschema>=3.0.0
jupyter-book>=0.12.1
kastore
lshmm>=0.0.5
lshmm>=0.0.8
matplotlib
meson>=0.61.0
msgpack>=1.0.0
Expand Down
99 changes: 59 additions & 40 deletions python/tests/test_genotype_matching.py
Original file line number Diff line number Diff line change
Expand Up @@ -1326,11 +1326,10 @@ def verify(self, ts):
ts_check, mapping = ts.simplify(
range(1, n + 1), filter_sites=False, map_nodes=True
)
H_check = ts_check.genotype_matrix()
G_check = np.zeros((m, n, n))
for i in range(m):
G_check[i, :, :] = np.add.outer(
ts_check.genotype_matrix()[i, :], ts_check.genotype_matrix()[i, :]
)
G_check[i, :, :] = np.add.outer(H_check[i, :], H_check[i, :])

cm_d = ls_forward_tree(s[0, :], ts_check, r, mu)
ll_tree = np.sum(np.log10(cm_d.normalisation_factor))
Expand All @@ -1345,12 +1344,16 @@ def verify(self, ts):
self.assertAllClose(ll_tree, ll_mirror_tree_dict)

# Ensure that the decoded matrices are the same
flipped_H_check = np.flip(H_check, axis=0)
flipped_s = np.flip(s, axis=1)

F_mirror_matrix, c, ll = ls.forwards(
np.flip(G_check, axis=0),
np.flip(s, axis=1),
r_flip,
p_mutation=np.flip(mu),
scale_mutation_based_on_n_alleles=False,
flipped_H_check,
flipped_s,
ploidy=2,
prob_recombination=r_flip,
prob_mutation=np.flip(mu),
scale_mutation_rate=False,
)

self.assertAllClose(F_mirror_matrix, cm_mirror.decode())
Expand All @@ -1367,14 +1370,18 @@ def verify(self, ts):
ts_check, mapping = ts.simplify(
range(1, n + 1), filter_sites=False, map_nodes=True
)
H_check = ts_check.genotype_matrix()
G_check = np.zeros((m, n, n))
for i in range(m):
G_check[i, :, :] = np.add.outer(
ts_check.genotype_matrix()[i, :], ts_check.genotype_matrix()[i, :]
)
G_check[i, :, :] = np.add.outer(H_check[i, :], H_check[i, :])

F, c, ll = ls.forwards(
G_check, s, r, p_mutation=mu, scale_mutation_based_on_n_alleles=False
reference_panel=H_check,
query=s,
ploidy=2,
prob_recombination=r,
prob_mutation=mu,
scale_mutation_rate=False,
)
cm_d = ls_forward_tree(s[0, :], ts_check, r, mu)
self.assertAllClose(cm_d.decode(), F)
Expand All @@ -1393,22 +1400,27 @@ def verify(self, ts):
ts_check, mapping = ts.simplify(
range(1, n + 1), filter_sites=False, map_nodes=True
)
H_check = ts_check.genotype_matrix()
G_check = np.zeros((m, n, n))
for i in range(m):
G_check[i, :, :] = np.add.outer(
ts_check.genotype_matrix()[i, :], ts_check.genotype_matrix()[i, :]
)
G_check[i, :, :] = np.add.outer(H_check[i, :], H_check[i, :])

F, c, ll = ls.forwards(
G_check, s, r, p_mutation=mu, scale_mutation_based_on_n_alleles=False
reference_panel=H_check,
query=s,
ploidy=2,
prob_recombination=r,
prob_mutation=mu,
scale_mutation_rate=False,
)
B = ls.backwards(
G_check,
s,
c,
r,
p_mutation=mu,
scale_mutation_based_on_n_alleles=False,
reference_panel=H_check,
query=s,
ploidy=2,
normalisation_factor_from_forward=c,
prob_recombination=r,
prob_mutation=mu,
scale_mutation_rate=False,
)

# Note, need to remove the first sample from the ts, and ensure that
Expand Down Expand Up @@ -1447,22 +1459,28 @@ def verify(self, ts):
ts_check, mapping = ts.simplify(
range(1, n + 1), filter_sites=False, map_nodes=True
)
H_check = ts_check.genotype_matrix()
G_check = np.zeros((m, n, n))
for i in range(m):
G_check[i, :, :] = np.add.outer(
ts_check.genotype_matrix()[i, :], ts_check.genotype_matrix()[i, :]
)
G_check[i, :, :] = np.add.outer(H_check[i, :], H_check[i, :])
ts_check = ts.simplify(range(1, n + 1), filter_sites=False)

phased_path, ll = ls.viterbi(
G_check, s, r, p_mutation=mu, scale_mutation_based_on_n_alleles=False
reference_panel=H_check,
query=s,
ploidy=2,
prob_recombination=r,
prob_mutation=mu,
scale_mutation_rate=False,
)
path_ll_matrix = ls.path_ll(
G_check,
s,
phased_path,
r,
p_mutation=mu,
scale_mutation_based_on_n_alleles=False,
path_ll_matrix = ls.path_loglik(
reference_panel=H_check,
query=s,
ploidy=2,
path=phased_path,
prob_recombination=r,
prob_mutation=mu,
scale_mutation_rate=False,
)

c_v = ls_viterbi_tree(s[0, :], ts_check, r, mu)
Expand All @@ -1471,13 +1489,14 @@ def verify(self, ts):
# Attempt to get the path
path_tree_dict = c_v.traceback()
# Work out the likelihood of the proposed path
path_ll_tree = ls.path_ll(
G_check,
s,
np.transpose(path_tree_dict),
r,
p_mutation=mu,
scale_mutation_based_on_n_alleles=False,
path_ll_tree = ls.path_loglik(
reference_panel=H_check,
query=s,
ploidy=2,
path=np.transpose(path_tree_dict),
prob_recombination=r,
prob_mutation=mu,
scale_mutation_rate=False,
)

self.assertAllClose(ll, ll_tree)
Expand Down
121 changes: 73 additions & 48 deletions python/tests/test_haplotype_matching.py
Original file line number Diff line number Diff line change
Expand Up @@ -925,12 +925,15 @@ def verify(self, ts):
self.assertAllClose(ll_tree, ll_mirror_tree)

# Ensure that the decoded matrices are the same
flipped_H = np.flip(H, axis=0)
flipped_s = np.flip(s, axis=1)
F_mirror_matrix, c, ll = ls.forwards(
np.flip(H, axis=0),
np.flip(s, axis=1),
r_flip,
p_mutation=np.flip(mu),
scale_mutation_based_on_n_alleles=False,
reference_panel=flipped_H,
query=flipped_s,
ploidy=1,
prob_recombination=r_flip,
prob_mutation=np.flip(mu),
scale_mutation_rate=False,
)

self.assertAllClose(F_mirror_matrix, cm_mirror.decode())
Expand All @@ -950,11 +953,12 @@ def verify(self, ts):
# Passed a vector of mutation rates, but rescaling each mutation
# rate conditional on the number of alleles
F, c, ll = ls.forwards(
H,
s,
r,
p_mutation=mu,
scale_mutation_based_on_n_alleles=scale_mutation,
reference_panel=H,
query=s,
ploidy=1,
prob_recombination=r,
prob_mutation=mu,
scale_mutation_rate=scale_mutation,
)
# Note, need to remove the first sample from the ts, and ensure
# that invariant sites aren't removed.
Expand All @@ -978,15 +982,21 @@ class TestForwardBackwardTree(FBAlgorithmBase):
def verify(self, ts):
for n, H, s, r, mu in self.example_parameters_haplotypes(ts):
F, c, ll = ls.forwards(
H, s, r, p_mutation=mu, scale_mutation_based_on_n_alleles=False
reference_panel=H,
query=s,
ploidy=1,
prob_recombination=r,
prob_mutation=mu,
scale_mutation_rate=False,
)
B = ls.backwards(
H,
s,
c,
r,
p_mutation=mu,
scale_mutation_based_on_n_alleles=False,
reference_panel=H,
query=s,
ploidy=1,
normalisation_factor_from_forward=c,
prob_recombination=r,
prob_mutation=mu,
scale_mutation_rate=False,
)

# Note, need to remove the first sample from the ts, and ensure that
Expand Down Expand Up @@ -1018,7 +1028,12 @@ class TestTreeViterbiHap(VitAlgorithmBase):
def verify(self, ts):
for n, H, s, r, mu in self.example_parameters_haplotypes(ts):
path, ll = ls.viterbi(
H, s, r, p_mutation=mu, scale_mutation_based_on_n_alleles=False
reference_panel=H,
query=s,
ploidy=1,
prob_recombination=r,
prob_mutation=mu,
scale_mutation_rate=False,
)
ts_check = ts.simplify(range(1, n + 1), filter_sites=False)
cm = ls_viterbi_tree(s[0, :], ts_check, r, mu)
Expand All @@ -1028,13 +1043,14 @@ def verify(self, ts):
# Now, need to ensure that the likelihood of the preferred path is
# the same as ll_tree (and ll).
path_tree = cm.traceback()
ll_check = ls.path_ll(
H,
s,
path_tree,
r,
p_mutation=mu,
scale_mutation_based_on_n_alleles=False,
ll_check = ls.path_loglik(
reference_panel=H,
query=s,
ploidy=1,
path=path_tree,
prob_recombination=r,
prob_mutation=mu,
scale_mutation_rate=False,
)
self.assertAllClose(ll, ll_check)

Expand All @@ -1051,13 +1067,15 @@ def check_viterbi(ts, h, recombination=None, mutation=None):
precision = 22

G = ts.genotype_matrix()
s = h.reshape(1, m)

path, ll = ls.viterbi(
G,
h.reshape(1, m),
recombination,
p_mutation=mutation,
scale_mutation_based_on_n_alleles=False,
reference_panel=G,
query=s,
ploidy=1,
prob_recombination=recombination,
prob_mutation=mutation,
scale_mutation_rate=False,
)
assert np.isscalar(ll)

Expand All @@ -1069,13 +1087,14 @@ def check_viterbi(ts, h, recombination=None, mutation=None):
# Check that the likelihood of the preferred path is
# the same as ll_tree (and ll).
path_tree = cm.traceback()
ll_check = ls.path_ll(
G,
h.reshape(1, m),
path_tree,
recombination,
p_mutation=mutation,
scale_mutation_based_on_n_alleles=False,
ll_check = ls.path_loglik(
reference_panel=G,
query=s,
ploidy=1,
path=path_tree,
prob_recombination=recombination,
prob_mutation=mutation,
scale_mutation_rate=False,
)
nt.assert_allclose(ll_check, ll)

Expand Down Expand Up @@ -1106,12 +1125,15 @@ def check_forward_matrix(ts, h, recombination=None, mutation=None):
mutation = np.zeros(ts.num_sites)

G = ts.genotype_matrix()
s = h.reshape(1, m)

F, c, ll = ls.forwards(
G,
h.reshape(1, m),
recombination,
p_mutation=mutation,
scale_mutation_based_on_n_alleles=False,
reference_panel=G,
query=s,
ploidy=1,
prob_recombination=recombination,
prob_mutation=mutation,
scale_mutation_rate=False,
)
assert F.shape == (m, n)
assert c.shape == (m,)
Expand Down Expand Up @@ -1150,13 +1172,16 @@ def check_backward_matrix(ts, h, forward_cm, recombination=None, mutation=None):
mutation = np.zeros(ts.num_sites)

G = ts.genotype_matrix()
s = h.reshape(1, m)

B = ls.backwards(
G,
h.reshape(1, m),
forward_cm.normalisation_factor,
recombination,
p_mutation=mutation,
scale_mutation_based_on_n_alleles=False,
reference_panel=G,
query=s,
ploidy=1,
normalisation_factor_from_forward=forward_cm.normalisation_factor,
prob_recombination=recombination,
prob_mutation=mutation,
scale_mutation_rate=False,
)

backward_cm = ls_backward_tree(
Expand Down

0 comments on commit a343802

Please sign in to comment.