Skip to content

Commit

Permalink
Fix tests for lshmm 0.0.8
Browse files Browse the repository at this point in the history
  • Loading branch information
szhan authored and mergify[bot] committed Jun 24, 2024
1 parent d3c59ba commit c818a6c
Show file tree
Hide file tree
Showing 5 changed files with 176 additions and 109 deletions.
2 changes: 1 addition & 1 deletion python/requirements/CI-complete/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ coverage==7.2.7
dendropy==4.6.1
h5py==3.9.0
kastore==0.3.2
lshmm==0.0.5
lshmm==0.0.8
msgpack==1.0.5
msprime==1.2.0
networkx==3.1
Expand Down
4 changes: 2 additions & 2 deletions python/requirements/CI-tests-pip/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
lshmm==0.0.5
lshmm==0.0.8
numpy==1.24.1
pytest==7.1.3
pytest-cov==4.0.0
Expand All @@ -14,4 +14,4 @@ newick==1.3.2
tszip==0.2.2
kastore==0.3.2
lxml==4.9.2
numba<=0.59.1 #Pinned directly as 0.60.0 fails
numba<=0.59.1 #Pinned directly as 0.60.0 fails
2 changes: 1 addition & 1 deletion python/requirements/development.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ h5py>=2.6.0
jsonschema>=3.0.0
jupyter-book>=0.12.1
kastore
lshmm>=0.0.5
lshmm>=0.0.8
matplotlib
meson>=0.61.0
msgpack>=1.0.0
Expand Down
103 changes: 63 additions & 40 deletions python/tests/test_genotype_matching.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import lshmm as ls
import msprime
import numpy as np
import pytest

import tskit

Expand Down Expand Up @@ -1257,18 +1258,21 @@ def assertAllClose(self, A, B):

# Define a bunch of very small tree-sequences for testing a collection of
# parameters on
@pytest.mark.skip(reason="No plans to implement diploid LS HMM yet.")
def test_simple_n_10_no_recombination(self):
ts = msprime.simulate(
10, recombination_rate=0, mutation_rate=0.5, random_seed=42
)
assert ts.num_sites > 3
self.verify(ts)

@pytest.mark.skip(reason="No plans to implement diploid LS HMM yet.")
def test_simple_n_6(self):
ts = msprime.simulate(6, recombination_rate=2, mutation_rate=7, random_seed=42)
assert ts.num_sites > 5
self.verify(ts)

@pytest.mark.skip(reason="No plans to implement diploid LS HMM yet.")
def test_simple_n_8_high_recombination(self):
ts = msprime.simulate(8, recombination_rate=20, mutation_rate=5, random_seed=42)
assert ts.num_trees > 15
Expand Down Expand Up @@ -1326,11 +1330,10 @@ def verify(self, ts):
ts_check, mapping = ts.simplify(
range(1, n + 1), filter_sites=False, map_nodes=True
)
H_check = ts_check.genotype_matrix()
G_check = np.zeros((m, n, n))
for i in range(m):
G_check[i, :, :] = np.add.outer(
ts_check.genotype_matrix()[i, :], ts_check.genotype_matrix()[i, :]
)
G_check[i, :, :] = np.add.outer(H_check[i, :], H_check[i, :])

cm_d = ls_forward_tree(s[0, :], ts_check, r, mu)
ll_tree = np.sum(np.log10(cm_d.normalisation_factor))
Expand All @@ -1345,12 +1348,16 @@ def verify(self, ts):
self.assertAllClose(ll_tree, ll_mirror_tree_dict)

# Ensure that the decoded matrices are the same
flipped_H_check = np.flip(H_check, axis=0)
flipped_s = np.flip(s, axis=1)

F_mirror_matrix, c, ll = ls.forwards(
np.flip(G_check, axis=0),
np.flip(s, axis=1),
r_flip,
p_mutation=np.flip(mu),
scale_mutation_based_on_n_alleles=False,
flipped_H_check,
flipped_s,
ploidy=2,
prob_recombination=r_flip,
prob_mutation=np.flip(mu),
scale_mutation_rate=False,
)

self.assertAllClose(F_mirror_matrix, cm_mirror.decode())
Expand All @@ -1367,14 +1374,18 @@ def verify(self, ts):
ts_check, mapping = ts.simplify(
range(1, n + 1), filter_sites=False, map_nodes=True
)
H_check = ts_check.genotype_matrix()
G_check = np.zeros((m, n, n))
for i in range(m):
G_check[i, :, :] = np.add.outer(
ts_check.genotype_matrix()[i, :], ts_check.genotype_matrix()[i, :]
)
G_check[i, :, :] = np.add.outer(H_check[i, :], H_check[i, :])

F, c, ll = ls.forwards(
G_check, s, r, p_mutation=mu, scale_mutation_based_on_n_alleles=False
reference_panel=H_check,
query=s,
ploidy=2,
prob_recombination=r,
prob_mutation=mu,
scale_mutation_rate=False,
)
cm_d = ls_forward_tree(s[0, :], ts_check, r, mu)
self.assertAllClose(cm_d.decode(), F)
Expand All @@ -1393,22 +1404,27 @@ def verify(self, ts):
ts_check, mapping = ts.simplify(
range(1, n + 1), filter_sites=False, map_nodes=True
)
H_check = ts_check.genotype_matrix()
G_check = np.zeros((m, n, n))
for i in range(m):
G_check[i, :, :] = np.add.outer(
ts_check.genotype_matrix()[i, :], ts_check.genotype_matrix()[i, :]
)
G_check[i, :, :] = np.add.outer(H_check[i, :], H_check[i, :])

F, c, ll = ls.forwards(
G_check, s, r, p_mutation=mu, scale_mutation_based_on_n_alleles=False
reference_panel=H_check,
query=s,
ploidy=2,
prob_recombination=r,
prob_mutation=mu,
scale_mutation_rate=False,
)
B = ls.backwards(
G_check,
s,
c,
r,
p_mutation=mu,
scale_mutation_based_on_n_alleles=False,
reference_panel=H_check,
query=s,
ploidy=2,
normalisation_factor_from_forward=c,
prob_recombination=r,
prob_mutation=mu,
scale_mutation_rate=False,
)

# Note, need to remove the first sample from the ts, and ensure that
Expand Down Expand Up @@ -1447,22 +1463,28 @@ def verify(self, ts):
ts_check, mapping = ts.simplify(
range(1, n + 1), filter_sites=False, map_nodes=True
)
H_check = ts_check.genotype_matrix()
G_check = np.zeros((m, n, n))
for i in range(m):
G_check[i, :, :] = np.add.outer(
ts_check.genotype_matrix()[i, :], ts_check.genotype_matrix()[i, :]
)
G_check[i, :, :] = np.add.outer(H_check[i, :], H_check[i, :])
ts_check = ts.simplify(range(1, n + 1), filter_sites=False)

phased_path, ll = ls.viterbi(
G_check, s, r, p_mutation=mu, scale_mutation_based_on_n_alleles=False
reference_panel=H_check,
query=s,
ploidy=2,
prob_recombination=r,
prob_mutation=mu,
scale_mutation_rate=False,
)
path_ll_matrix = ls.path_ll(
G_check,
s,
phased_path,
r,
p_mutation=mu,
scale_mutation_based_on_n_alleles=False,
path_ll_matrix = ls.path_loglik(
reference_panel=H_check,
query=s,
ploidy=2,
path=phased_path,
prob_recombination=r,
prob_mutation=mu,
scale_mutation_rate=False,
)

c_v = ls_viterbi_tree(s[0, :], ts_check, r, mu)
Expand All @@ -1471,13 +1493,14 @@ def verify(self, ts):
# Attempt to get the path
path_tree_dict = c_v.traceback()
# Work out the likelihood of the proposed path
path_ll_tree = ls.path_ll(
G_check,
s,
np.transpose(path_tree_dict),
r,
p_mutation=mu,
scale_mutation_based_on_n_alleles=False,
path_ll_tree = ls.path_loglik(
reference_panel=H_check,
query=s,
ploidy=2,
path=np.transpose(path_tree_dict),
prob_recombination=r,
prob_mutation=mu,
scale_mutation_rate=False,
)

self.assertAllClose(ll, ll_tree)
Expand Down
Loading

0 comments on commit c818a6c

Please sign in to comment.