From c818a6c8a8769e29100a4747eea74d6396c1bd5a Mon Sep 17 00:00:00 2001 From: szhan Date: Thu, 20 Jun 2024 17:46:50 +0100 Subject: [PATCH] Fix tests for lshmm 0.0.8 --- .../requirements/CI-complete/requirements.txt | 2 +- .../CI-tests-pip/requirements.txt | 4 +- python/requirements/development.txt | 2 +- python/tests/test_genotype_matching.py | 103 +++++++---- python/tests/test_haplotype_matching.py | 174 +++++++++++------- 5 files changed, 176 insertions(+), 109 deletions(-) diff --git a/python/requirements/CI-complete/requirements.txt b/python/requirements/CI-complete/requirements.txt index c5975a6153..62e2d3d365 100644 --- a/python/requirements/CI-complete/requirements.txt +++ b/python/requirements/CI-complete/requirements.txt @@ -3,7 +3,7 @@ coverage==7.2.7 dendropy==4.6.1 h5py==3.9.0 kastore==0.3.2 -lshmm==0.0.5 +lshmm==0.0.8 msgpack==1.0.5 msprime==1.2.0 networkx==3.1 diff --git a/python/requirements/CI-tests-pip/requirements.txt b/python/requirements/CI-tests-pip/requirements.txt index 284a1ccfbd..d713bbdea0 100644 --- a/python/requirements/CI-tests-pip/requirements.txt +++ b/python/requirements/CI-tests-pip/requirements.txt @@ -1,4 +1,4 @@ -lshmm==0.0.5 +lshmm==0.0.8 numpy==1.24.1 pytest==7.1.3 pytest-cov==4.0.0 @@ -14,4 +14,4 @@ newick==1.3.2 tszip==0.2.2 kastore==0.3.2 lxml==4.9.2 -numba<=0.59.1 #Pinned directly as 0.60.0 fails \ No newline at end of file +numba<=0.59.1 #Pinned directly as 0.60.0 fails diff --git a/python/requirements/development.txt b/python/requirements/development.txt index c6a27a2416..6bb0e13496 100644 --- a/python/requirements/development.txt +++ b/python/requirements/development.txt @@ -10,7 +10,7 @@ h5py>=2.6.0 jsonschema>=3.0.0 jupyter-book>=0.12.1 kastore -lshmm>=0.0.5 +lshmm>=0.0.8 matplotlib meson>=0.61.0 msgpack>=1.0.0 diff --git a/python/tests/test_genotype_matching.py b/python/tests/test_genotype_matching.py index a04e3873a6..01c8b40a59 100644 --- a/python/tests/test_genotype_matching.py +++ b/python/tests/test_genotype_matching.py @@ -4,6 +4,7 @@ import lshmm as ls import msprime import numpy as np +import pytest import tskit @@ -1257,6 +1258,7 @@ def assertAllClose(self, A, B): # Define a bunch of very small tree-sequences for testing a collection of # parameters on + @pytest.mark.skip(reason="No plans to implement diploid LS HMM yet.") def test_simple_n_10_no_recombination(self): ts = msprime.simulate( 10, recombination_rate=0, mutation_rate=0.5, random_seed=42 @@ -1264,11 +1266,13 @@ def test_simple_n_10_no_recombination(self): assert ts.num_sites > 3 self.verify(ts) + @pytest.mark.skip(reason="No plans to implement diploid LS HMM yet.") def test_simple_n_6(self): ts = msprime.simulate(6, recombination_rate=2, mutation_rate=7, random_seed=42) assert ts.num_sites > 5 self.verify(ts) + @pytest.mark.skip(reason="No plans to implement diploid LS HMM yet.") def test_simple_n_8_high_recombination(self): ts = msprime.simulate(8, recombination_rate=20, mutation_rate=5, random_seed=42) assert ts.num_trees > 15 @@ -1326,11 +1330,10 @@ def verify(self, ts): ts_check, mapping = ts.simplify( range(1, n + 1), filter_sites=False, map_nodes=True ) + H_check = ts_check.genotype_matrix() G_check = np.zeros((m, n, n)) for i in range(m): - G_check[i, :, :] = np.add.outer( - ts_check.genotype_matrix()[i, :], ts_check.genotype_matrix()[i, :] - ) + G_check[i, :, :] = np.add.outer(H_check[i, :], H_check[i, :]) cm_d = ls_forward_tree(s[0, :], ts_check, r, mu) ll_tree = np.sum(np.log10(cm_d.normalisation_factor)) @@ -1345,12 +1348,16 @@ def verify(self, ts): 
self.assertAllClose(ll_tree, ll_mirror_tree_dict) # Ensure that the decoded matrices are the same + flipped_H_check = np.flip(H_check, axis=0) + flipped_s = np.flip(s, axis=1) + F_mirror_matrix, c, ll = ls.forwards( - np.flip(G_check, axis=0), - np.flip(s, axis=1), - r_flip, - p_mutation=np.flip(mu), - scale_mutation_based_on_n_alleles=False, + flipped_H_check, + flipped_s, + ploidy=2, + prob_recombination=r_flip, + prob_mutation=np.flip(mu), + scale_mutation_rate=False, ) self.assertAllClose(F_mirror_matrix, cm_mirror.decode()) @@ -1367,14 +1374,18 @@ def verify(self, ts): ts_check, mapping = ts.simplify( range(1, n + 1), filter_sites=False, map_nodes=True ) + H_check = ts_check.genotype_matrix() G_check = np.zeros((m, n, n)) for i in range(m): - G_check[i, :, :] = np.add.outer( - ts_check.genotype_matrix()[i, :], ts_check.genotype_matrix()[i, :] - ) + G_check[i, :, :] = np.add.outer(H_check[i, :], H_check[i, :]) F, c, ll = ls.forwards( - G_check, s, r, p_mutation=mu, scale_mutation_based_on_n_alleles=False + reference_panel=H_check, + query=s, + ploidy=2, + prob_recombination=r, + prob_mutation=mu, + scale_mutation_rate=False, ) cm_d = ls_forward_tree(s[0, :], ts_check, r, mu) self.assertAllClose(cm_d.decode(), F) @@ -1393,22 +1404,27 @@ def verify(self, ts): ts_check, mapping = ts.simplify( range(1, n + 1), filter_sites=False, map_nodes=True ) + H_check = ts_check.genotype_matrix() G_check = np.zeros((m, n, n)) for i in range(m): - G_check[i, :, :] = np.add.outer( - ts_check.genotype_matrix()[i, :], ts_check.genotype_matrix()[i, :] - ) + G_check[i, :, :] = np.add.outer(H_check[i, :], H_check[i, :]) F, c, ll = ls.forwards( - G_check, s, r, p_mutation=mu, scale_mutation_based_on_n_alleles=False + reference_panel=H_check, + query=s, + ploidy=2, + prob_recombination=r, + prob_mutation=mu, + scale_mutation_rate=False, ) B = ls.backwards( - G_check, - s, - c, - r, - p_mutation=mu, - scale_mutation_based_on_n_alleles=False, + reference_panel=H_check, + query=s, + ploidy=2, + normalisation_factor_from_forward=c, + prob_recombination=r, + prob_mutation=mu, + scale_mutation_rate=False, ) # Note, need to remove the first sample from the ts, and ensure that @@ -1447,22 +1463,28 @@ def verify(self, ts): ts_check, mapping = ts.simplify( range(1, n + 1), filter_sites=False, map_nodes=True ) + H_check = ts_check.genotype_matrix() G_check = np.zeros((m, n, n)) for i in range(m): - G_check[i, :, :] = np.add.outer( - ts_check.genotype_matrix()[i, :], ts_check.genotype_matrix()[i, :] - ) + G_check[i, :, :] = np.add.outer(H_check[i, :], H_check[i, :]) ts_check = ts.simplify(range(1, n + 1), filter_sites=False) + phased_path, ll = ls.viterbi( - G_check, s, r, p_mutation=mu, scale_mutation_based_on_n_alleles=False + reference_panel=H_check, + query=s, + ploidy=2, + prob_recombination=r, + prob_mutation=mu, + scale_mutation_rate=False, ) - path_ll_matrix = ls.path_ll( - G_check, - s, - phased_path, - r, - p_mutation=mu, - scale_mutation_based_on_n_alleles=False, + path_ll_matrix = ls.path_loglik( + reference_panel=H_check, + query=s, + ploidy=2, + path=phased_path, + prob_recombination=r, + prob_mutation=mu, + scale_mutation_rate=False, ) c_v = ls_viterbi_tree(s[0, :], ts_check, r, mu) @@ -1471,13 +1493,14 @@ def verify(self, ts): # Attempt to get the path path_tree_dict = c_v.traceback() # Work out the likelihood of the proposed path - path_ll_tree = ls.path_ll( - G_check, - s, - np.transpose(path_tree_dict), - r, - p_mutation=mu, - scale_mutation_based_on_n_alleles=False, + path_ll_tree = ls.path_loglik( + 
reference_panel=H_check, + query=s, + ploidy=2, + path=np.transpose(path_tree_dict), + prob_recombination=r, + prob_mutation=mu, + scale_mutation_rate=False, ) self.assertAllClose(ll, ll_tree) diff --git a/python/tests/test_haplotype_matching.py b/python/tests/test_haplotype_matching.py index dcc1d684fb..dc01e8370d 100644 --- a/python/tests/test_haplotype_matching.py +++ b/python/tests/test_haplotype_matching.py @@ -333,11 +333,12 @@ def update_probabilities(self, site, haplotype_state): while allelic_state[v] == -1: v = tree.parent(v) assert v != -1 - match = ( - haplotype_state == MISSING or haplotype_state == allelic_state[v] - ) + match = haplotype_state == allelic_state[v] + is_query_missing = haplotype_state == MISSING # Note that the node u is used only by Viterbi - st.value = self.compute_next_probability(site.id, st.value, match, u) + st.value = self.compute_next_probability( + site.id, st.value, match, is_query_missing, u + ) # Unset the states allelic_state[tree.root] = -1 @@ -404,7 +405,9 @@ def run(self, h): def compute_normalisation_factor(self): raise NotImplementedError() - def compute_next_probability(self, site_id, p_last, is_match, node): + def compute_next_probability( + self, site_id, p_last, is_match, is_query_missing, node + ): raise NotImplementedError() @@ -435,10 +438,15 @@ def compute_normalisation_factor(self): s += self.N[j] * st.value return s - def compute_next_probability(self, site_id, p_last, is_match, node): + def compute_next_probability( + self, site_id, p_last, is_match, is_query_missing, node + ): rho = self.rho[site_id] n = self.ts.num_samples - p_e = self.compute_emission_proba(site_id, is_match) + if is_query_missing: + p_e = 1.0 + else: + p_e = self.compute_emission_proba(site_id, is_match) p_t = p_last * (1 - rho) + rho / n return p_t * p_e @@ -448,8 +456,13 @@ class BackwardAlgorithm(ForwardAlgorithm): The Li and Stephens backward algorithm. 
""" - def compute_next_probability(self, site_id, p_next, is_match, node): - p_e = self.compute_emission_proba(site_id, is_match) + def compute_next_probability( + self, site_id, p_next, is_match, is_query_missing, node + ): + if is_query_missing: + p_e = 1.0 + else: + p_e = self.compute_emission_proba(site_id, is_match) return p_next * p_e def process_site(self, site, haplotype_state, s): @@ -515,7 +528,9 @@ def compute_normalisation_factor(self): ) return max_st.value - def compute_next_probability(self, site_id, p_last, is_match, node): + def compute_next_probability( + self, site_id, p_last, is_match, is_query_missing, node + ): rho = self.rho[site_id] n = self.ts.num_samples @@ -529,7 +544,11 @@ def compute_next_probability(self, site_id, p_last, is_match, node): recombination_required = True self.output.add_recombination_required(site_id, node, recombination_required) - p_e = self.compute_emission_proba(site_id, is_match) + if is_query_missing: + p_e = 1.0 + else: + p_e = self.compute_emission_proba(site_id, is_match) + return p_t * p_e @@ -679,12 +698,12 @@ def traceback(self): def get_site_alleles(ts, h, alleles): if alleles is None: - n_alleles = np.int8( - [ - len(np.unique(np.append(ts.genotype_matrix()[j, :], h[j]))) - for j in range(ts.num_sites) - ] - ) + n_alleles = np.zeros(ts.num_sites, dtype=np.int8) - 1 + for j in range(ts.num_sites): + uniq_alleles = np.unique(np.append(ts.genotype_matrix()[j, :], h[j])) + uniq_alleles = uniq_alleles[uniq_alleles != MISSING] + n_alleles[j] = len(uniq_alleles) + assert np.all(n_alleles > 0) alleles = tskit.ALLELES_ACGT if len(set(alleles).intersection(next(ts.variants()).alleles)) == 0: alleles = tskit.ALLELES_01 @@ -925,12 +944,15 @@ def verify(self, ts): self.assertAllClose(ll_tree, ll_mirror_tree) # Ensure that the decoded matrices are the same + flipped_H = np.flip(H, axis=0) + flipped_s = np.flip(s, axis=1) F_mirror_matrix, c, ll = ls.forwards( - np.flip(H, axis=0), - np.flip(s, axis=1), - r_flip, - p_mutation=np.flip(mu), - scale_mutation_based_on_n_alleles=False, + reference_panel=flipped_H, + query=flipped_s, + ploidy=1, + prob_recombination=r_flip, + prob_mutation=np.flip(mu), + scale_mutation_rate=False, ) self.assertAllClose(F_mirror_matrix, cm_mirror.decode()) @@ -950,11 +972,12 @@ def verify(self, ts): # Passed a vector of mutation rates, but rescaling each mutation # rate conditional on the number of alleles F, c, ll = ls.forwards( - H, - s, - r, - p_mutation=mu, - scale_mutation_based_on_n_alleles=scale_mutation, + reference_panel=H, + query=s, + ploidy=1, + prob_recombination=r, + prob_mutation=mu, + scale_mutation_rate=scale_mutation, ) # Note, need to remove the first sample from the ts, and ensure # that invariant sites aren't removed. 
@@ -978,15 +1001,21 @@ class TestForwardBackwardTree(FBAlgorithmBase): def verify(self, ts): for n, H, s, r, mu in self.example_parameters_haplotypes(ts): F, c, ll = ls.forwards( - H, s, r, p_mutation=mu, scale_mutation_based_on_n_alleles=False + reference_panel=H, + query=s, + ploidy=1, + prob_recombination=r, + prob_mutation=mu, + scale_mutation_rate=False, ) B = ls.backwards( - H, - s, - c, - r, - p_mutation=mu, - scale_mutation_based_on_n_alleles=False, + reference_panel=H, + query=s, + ploidy=1, + normalisation_factor_from_forward=c, + prob_recombination=r, + prob_mutation=mu, + scale_mutation_rate=False, ) # Note, need to remove the first sample from the ts, and ensure that @@ -1018,7 +1047,12 @@ class TestTreeViterbiHap(VitAlgorithmBase): def verify(self, ts): for n, H, s, r, mu in self.example_parameters_haplotypes(ts): path, ll = ls.viterbi( - H, s, r, p_mutation=mu, scale_mutation_based_on_n_alleles=False + reference_panel=H, + query=s, + ploidy=1, + prob_recombination=r, + prob_mutation=mu, + scale_mutation_rate=False, ) ts_check = ts.simplify(range(1, n + 1), filter_sites=False) cm = ls_viterbi_tree(s[0, :], ts_check, r, mu) @@ -1028,13 +1062,14 @@ def verify(self, ts): # Now, need to ensure that the likelihood of the preferred path is # the same as ll_tree (and ll). path_tree = cm.traceback() - ll_check = ls.path_ll( - H, - s, - path_tree, - r, - p_mutation=mu, - scale_mutation_based_on_n_alleles=False, + ll_check = ls.path_loglik( + reference_panel=H, + query=s, + ploidy=1, + path=path_tree, + prob_recombination=r, + prob_mutation=mu, + scale_mutation_rate=False, ) self.assertAllClose(ll, ll_check) @@ -1051,13 +1086,15 @@ def check_viterbi(ts, h, recombination=None, mutation=None): precision = 22 G = ts.genotype_matrix() + s = h.reshape(1, m) path, ll = ls.viterbi( - G, - h.reshape(1, m), - recombination, - p_mutation=mutation, - scale_mutation_based_on_n_alleles=False, + reference_panel=G, + query=s, + ploidy=1, + prob_recombination=recombination, + prob_mutation=mutation, + scale_mutation_rate=False, ) assert np.isscalar(ll) @@ -1069,13 +1106,14 @@ def check_viterbi(ts, h, recombination=None, mutation=None): # Check that the likelihood of the preferred path is # the same as ll_tree (and ll). 
path_tree = cm.traceback() - ll_check = ls.path_ll( - G, - h.reshape(1, m), - path_tree, - recombination, - p_mutation=mutation, - scale_mutation_based_on_n_alleles=False, + ll_check = ls.path_loglik( + reference_panel=G, + query=s, + ploidy=1, + path=path_tree, + prob_recombination=recombination, + prob_mutation=mutation, + scale_mutation_rate=False, ) nt.assert_allclose(ll_check, ll) @@ -1106,12 +1144,15 @@ def check_forward_matrix(ts, h, recombination=None, mutation=None): mutation = np.zeros(ts.num_sites) G = ts.genotype_matrix() + s = h.reshape(1, m) + F, c, ll = ls.forwards( - G, - h.reshape(1, m), - recombination, - p_mutation=mutation, - scale_mutation_based_on_n_alleles=False, + reference_panel=G, + query=s, + ploidy=1, + prob_recombination=recombination, + prob_mutation=mutation, + scale_mutation_rate=False, ) assert F.shape == (m, n) assert c.shape == (m,) @@ -1150,13 +1191,16 @@ def check_backward_matrix(ts, h, forward_cm, recombination=None, mutation=None): mutation = np.zeros(ts.num_sites) G = ts.genotype_matrix() + s = h.reshape(1, m) + B = ls.backwards( - G, - h.reshape(1, m), - forward_cm.normalisation_factor, - recombination, - p_mutation=mutation, - scale_mutation_based_on_n_alleles=False, + reference_panel=G, + query=s, + ploidy=1, + normalisation_factor_from_forward=forward_cm.normalisation_factor, + prob_recombination=recombination, + prob_mutation=mutation, + scale_mutation_rate=False, ) backward_cm = ls_backward_tree(
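
Overall, these test updates follow the lshmm 0.0.8 keyword API exercised throughout the patch: calls name reference_panel, query, ploidy, prob_recombination and prob_mutation; scale_mutation_based_on_n_alleles becomes scale_mutation_rate; path_ll becomes path_loglik; and backwards takes normalisation_factor_from_forward. A minimal haploid sketch of the updated calls, using a made-up panel, query and rates (only the keyword names and array shapes are taken from the tests above):

    import lshmm as ls
    import numpy as np

    # Toy reference panel: rows are sites, columns are haplotypes.
    H = np.array(
        [
            [0, 1, 0],
            [1, 1, 0],
            [0, 0, 1],
            [1, 0, 1],
        ]
    )
    s = np.array([[0, 1, 0, 1]])  # query haplotype, shape (1, num_sites)
    r = np.append([0], np.full(3, 0.01))  # no recombination before the first site
    mu = np.full(4, 1e-3)  # per-site mutation probabilities

    F, c, ll = ls.forwards(
        reference_panel=H,
        query=s,
        ploidy=1,
        prob_recombination=r,
        prob_mutation=mu,
        scale_mutation_rate=False,
    )
    B = ls.backwards(
        reference_panel=H,
        query=s,
        ploidy=1,
        normalisation_factor_from_forward=c,
        prob_recombination=r,
        prob_mutation=mu,
        scale_mutation_rate=False,
    )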