diff --git a/python/tests/test_haplotype_matching.py b/python/tests/test_haplotype_matching.py index bdcde5fad3..43f84a1503 100644 --- a/python/tests/test_haplotype_matching.py +++ b/python/tests/test_haplotype_matching.py @@ -1022,19 +1022,19 @@ def verify(self, ts): # TODO add params to run the various checks def check_viterbi(ts, h, recombination=None, mutation=None): h = np.array(h).astype(np.int8) - n = ts.num_samples - assert len(h) == ts.num_sites + m = ts.num_sites + assert len(h) == m if recombination is None: recombination = np.zeros(ts.num_sites) + 1e-9 if mutation is None: mutation = np.zeros(ts.num_sites) precision = 22 - H = ts.genotype_matrix().T + G = ts.genotype_matrix() path, ll = ls.viterbi( - H, - h.reshape(1, n), + G, + h.reshape(1, m), recombination, mutation_rate=mutation, scale_mutation_based_on_n_alleles=False, @@ -1050,8 +1050,8 @@ def check_viterbi(ts, h, recombination=None, mutation=None): # the same as ll_tree (and ll). path_tree = cm.traceback() ll_check = ls.path_ll( - H, - h.reshape(1, n), + G, + h.reshape(1, m), path_tree, recombination, mutation_rate=mutation, @@ -1079,16 +1079,16 @@ def check_forward_matrix(ts, h, recombination=None, mutation=None): h = np.array(h).astype(np.int8) n = ts.num_samples m = ts.num_sites - assert len(h) == ts.num_sites + assert len(h) == m if recombination is None: recombination = np.zeros(ts.num_sites) + 1e-9 if mutation is None: mutation = np.zeros(ts.num_sites) - H = ts.genotype_matrix().T + G = ts.genotype_matrix() F, c, ll = ls.forwards( - H, - h.reshape(1, n), + G, + h.reshape(1, m), recombination, mutation_rate=mutation, scale_mutation_based_on_n_alleles=False, @@ -1120,17 +1120,17 @@ def check_forward_matrix(ts, h, recombination=None, mutation=None): def check_backward_matrix(ts, h, forward_cm, recombination=None, mutation=None): precision = 22 h = np.array(h).astype(np.int8) - n = ts.num_samples - assert len(h) == ts.num_sites + m = ts.num_sites + assert len(h) == m if recombination is None: recombination = np.zeros(ts.num_sites) + 1e-9 if mutation is None: mutation = np.zeros(ts.num_sites) - H = ts.genotype_matrix().T + G = ts.genotype_matrix() B = ls.backwards( - H, - h.reshape(1, n), + G, + h.reshape(1, m), forward_cm.normalisation_factor, recombination, mutation_rate=mutation, @@ -1148,8 +1148,6 @@ def check_backward_matrix(ts, h, forward_cm, recombination=None, mutation=None): nt.assert_array_equal( backward_cm.normalisation_factor, forward_cm.normalisation_factor ) - B_tree = backward_cm.decode() - nt.assert_array_almost_equal(B, B_tree) ll_ts = ts._ll_tree_sequence ls_hmm = _tskit.LsHmm(ll_ts, recombination, mutation, precision=precision) @@ -1170,11 +1168,17 @@ def check_backward_matrix(ts, h, forward_cm, recombination=None, mutation=None): for node in py_site.keys(): assert py_site[node] == lib_site[node] + # Something weird is happening here - why don't these agree? + nt.assert_array_equal(cm_lib.normalisation_factor, forward_cm.normalisation_factor) B_lib = cm_lib.decode() + B_tree = backward_cm.decode() + # print(B_tree) # print(B_lib) - # print(B) + nt.assert_array_almost_equal(B_tree, B_lib) nt.assert_array_almost_equal(B, B_lib) + # print(B_lib) + # print(B) def add_unique_sample_mutations(ts, start=0): @@ -1258,3 +1262,16 @@ def test_switch_each_sample_missing_middle(self): nt.assert_array_equal([0, 3, 3, 3], path) cm = check_forward_matrix(ts, h) check_backward_matrix(ts, h, cm) + + +class TestSimulationExamples: + @pytest.mark.parametrize("n", [5, 10, 50]) + @pytest.mark.parametrize("L", [1, 10, 100]) + def test_continuous_genome(self, n, L): + ts = msprime.simulate( + n, length=L, recombination_rate=1, mutation_rate=1, random_seed=42 + ) + h = np.zeros(ts.num_sites, dtype=np.int8) + check_viterbi(ts, h) + cm = check_forward_matrix(ts, h) + check_backward_matrix(ts, h, cm)