From 82e97eb20718ebe926137d9704a251fd141d4b52 Mon Sep 17 00:00:00 2001 From: Jerome Kelleher Date: Sat, 15 Jul 2023 16:30:40 +0100 Subject: [PATCH] Fixed bug in next proba expression --- c/tskit/haplotype_matching.c | 21 ++++++++++++++++- python/tests/test_haplotype_matching.py | 31 +++++++++++++++++++------ 2 files changed, 44 insertions(+), 8 deletions(-) diff --git a/c/tskit/haplotype_matching.c b/c/tskit/haplotype_matching.c index eb6eeb0f4f..07df5f014e 100644 --- a/c/tskit/haplotype_matching.c +++ b/c/tskit/haplotype_matching.c @@ -1140,7 +1140,7 @@ tsk_ls_hmm_process_site_backward(tsk_ls_hmm_t *self, const tsk_site_t *site, b_last_sum = self->compute_normalisation_factor(self); for (j = 0; j < self->num_transitions; j++) { tsk_bug_assert(T[j].tree_node != TSK_NULL); - x = T[j].value * b_last_sum / n + (1 - rho) * T[j].value; + x = rho * b_last_sum / n + (1 - rho) * T[j].value; x /= normalisation_factor; T[j].value = tsk_round(x, precision); } @@ -1148,6 +1148,25 @@ tsk_ls_hmm_process_site_backward(tsk_ls_hmm_t *self, const tsk_site_t *site, return ret; } +/* def process_site(self, site, haplotype_state, s): */ +/* self.output.store_site( */ +/* site.id, */ +/* s, */ +/* # We need to filter out the -1 nodes here for the first site */ +/* # we examine. This is a bit of a hack */ +/* [(st.tree_node, st.value) for st in self.T if st.tree_node != -1], */ +/* ) */ +/* self.update_probabilities(site, haplotype_state) */ +/* self.compress() */ +/* b_last_sum = self.compute_normalisation_factor() */ +/* n = self.ts.num_samples */ +/* rho = self.rho[site.id] */ +/* for st in self.T: */ +/* if st.tree_node != tskit.NULL: */ +/* st.value = rho * b_last_sum / n + (1 - rho) * st.value */ +/* st.value /= s */ +/* st.value = round(st.value, self.precision) */ + static int tsk_ls_hmm_run_backward( tsk_ls_hmm_t *self, int32_t *haplotype, const double *forward_norm) diff --git a/python/tests/test_haplotype_matching.py b/python/tests/test_haplotype_matching.py index dbc98ed53e..bdcde5fad3 100644 --- a/python/tests/test_haplotype_matching.py +++ b/python/tests/test_haplotype_matching.py @@ -1118,7 +1118,7 @@ def check_forward_matrix(ts, h, recombination=None, mutation=None): def check_backward_matrix(ts, h, forward_cm, recombination=None, mutation=None): - # precision = 22 + precision = 22 h = np.array(h).astype(np.int8) n = ts.num_samples assert len(h) == ts.num_sites @@ -1143,6 +1143,7 @@ def check_backward_matrix(ts, h, forward_cm, recombination=None, mutation=None): recombination, mutation, forward_cm.normalisation_factor, + precision=precision, ) nt.assert_array_equal( backward_cm.normalisation_factor, forward_cm.normalisation_factor @@ -1150,14 +1151,30 @@ def check_backward_matrix(ts, h, forward_cm, recombination=None, mutation=None): B_tree = backward_cm.decode() nt.assert_array_almost_equal(B, B_tree) - # ll_ts = ts._ll_tree_sequence - # ls_hmm = _tskit.LsHmm(ll_ts, recombination, mutation, precision=precision) - # cm_lib = _tskit.CompressedMatrix(ll_ts) - # ls_hmm.backward_matrix(h, forward_cm.normalisation_factor, cm_lib) - # B_lib = cm_lib.decode() + ll_ts = ts._ll_tree_sequence + ls_hmm = _tskit.LsHmm(ll_ts, recombination, mutation, precision=precision) + cm_lib = _tskit.CompressedMatrix(ll_ts) + ls_hmm.backward_matrix(h, forward_cm.normalisation_factor, cm_lib) + + for j in range(ts.num_sites): + py_site = backward_cm.get_site(j) + lib_site = backward_cm.get_site(j) + assert len(py_site) == len(lib_site) + py_site = dict(py_site) + lib_site = dict(lib_site) + assert set(py_site.keys()) == set(lib_site.keys()) + # NOTE this probably won't work always and we'll need to put in + # some wiggle. But, they should be identical values, up to precision. + # However, the C and Python round() implementations are slightly different + # so this will almost certainly break. + for node in py_site.keys(): + assert py_site[node] == lib_site[node] + + nt.assert_array_equal(cm_lib.normalisation_factor, forward_cm.normalisation_factor) + B_lib = cm_lib.decode() # print(B_lib) # print(B) - # nt.assert_array_almost_equal(B, B_lib) + nt.assert_array_almost_equal(B, B_lib) def add_unique_sample_mutations(ts, start=0):