Skip to content

Commit

Permalink
Fixup numeric comparisons
Browse files Browse the repository at this point in the history
  • Loading branch information
jeromekelleher committed Jul 16, 2023
1 parent ea2c822 commit dca554e
Showing 1 changed file with 32 additions and 26 deletions.
58 changes: 32 additions & 26 deletions python/tests/test_haplotype_matching.py
Original file line number Diff line number Diff line change
Expand Up @@ -529,6 +529,22 @@ def compute_next_probability(self, site_id, p_last, is_match, node):
return p_t * p_e


def assert_compressed_matrices_equal(cm1, cm2, rtol=1e-5, atol=1e-8):
    """
    Assert that two compressed matrices hold the same values, up to
    floating point tolerance.

    The normalisation factors are compared first. Then, for each site
    index in ``range(cm1.num_sites)``, the ``(node, value)`` pairs returned
    by ``get_site`` must have the same length and the same set of nodes in
    both matrices, and the value stored for each node must agree.

    :param cm1: First matrix; its ``num_sites`` determines how many sites
        are compared.
    :param cm2: Second matrix, compared site by site against ``cm1``.
    :param float rtol: Relative tolerance used for the per-node values.
    :param float atol: Absolute tolerance used for the per-node values.
    :raises AssertionError: If any of the comparisons fail.
    """
    nt.assert_array_almost_equal(cm1.normalisation_factor, cm2.normalisation_factor)

    for j in range(cm1.num_sites):
        site1 = cm1.get_site(j)
        site2 = cm2.get_site(j)
        # Compare lengths before converting to dicts, so that repeated
        # nodes in one of the sites cannot be silently collapsed.
        assert len(site1) == len(site2)
        site1 = dict(site1)
        site2 = dict(site2)
        assert set(site1) == set(site2)
        for node, value in site1.items():
            nt.assert_allclose(value, site2[node], rtol=rtol, atol=atol)


class CompressedMatrix:
"""
Class representing a num_samples x num_sites matrix compressed by a
Expand Down Expand Up @@ -1068,7 +1084,7 @@ def check_viterbi(ts, h, recombination=None, mutation=None):
# Not true in general, but let's see how far it goes
nt.assert_array_equal(path_lib, path_tree)

nt.assert_array_almost_equal(cm_lib.normalisation_factor, cm.normalisation_factor)
nt.assert_allclose(cm_lib.normalisation_factor, cm.normalisation_factor)

return path

Expand Down Expand Up @@ -1101,8 +1117,8 @@ def check_forward_matrix(ts, h, recombination=None, mutation=None):
h, ts, recombination, mutation, scale_mutation_based_on_n_alleles=False
)
F2 = cm.decode()
nt.assert_array_almost_equal(F, F2)
nt.assert_array_almost_equal(c, cm.normalisation_factor)
nt.assert_allclose(F, F2)
nt.assert_allclose(c, cm.normalisation_factor)
ll_tree = np.sum(np.log10(cm.normalisation_factor))
nt.assert_allclose(ll_tree, ll)

Expand All @@ -1112,8 +1128,10 @@ def check_forward_matrix(ts, h, recombination=None, mutation=None):
ls_hmm.forward_matrix(h, cm_lib)
F3 = cm_lib.decode()

nt.assert_array_almost_equal(F, F3)
nt.assert_array_almost_equal(c, cm_lib.normalisation_factor)
assert_compressed_matrices_equal(cm, cm_lib)

nt.assert_allclose(F, F3)
nt.assert_allclose(c, cm_lib.normalisation_factor)
return cm_lib


Expand Down Expand Up @@ -1154,29 +1172,17 @@ def check_backward_matrix(ts, h, forward_cm, recombination=None, mutation=None):
cm_lib = _tskit.CompressedMatrix(ll_ts)
ls_hmm.backward_matrix(h, forward_cm.normalisation_factor, cm_lib)

for j in range(ts.num_sites):
py_site = backward_cm.get_site(j)
lib_site = backward_cm.get_site(j)
assert len(py_site) == len(lib_site)
py_site = dict(py_site)
lib_site = dict(lib_site)
assert set(py_site.keys()) == set(lib_site.keys())
# NOTE this probably won't work always and we'll need to put in
# some wiggle. But, they should be identical values, up to precision.
# However, the C and Python round() implementations are slightly different
# so this will almost certainly break.
for node in py_site.keys():
assert py_site[node] == lib_site[node]

# Something weird is happening here - why don't these agree?

nt.assert_array_equal(cm_lib.normalisation_factor, forward_cm.normalisation_factor)
assert_compressed_matrices_equal(backward_cm, cm_lib)

# FIXME!! There must be an error in the lib decode method
B_lib = cm_lib.decode()
B_tree = backward_cm.decode()
# print(B_tree)
# print(B_lib)
nt.assert_array_almost_equal(B_tree, B_lib)
nt.assert_array_almost_equal(B, B_lib)
print()
print(B)
print(B_tree)
print(B_lib)
nt.assert_allclose(B_tree, B_lib)
nt.assert_allclose(B, B_lib)
# print(B_lib)
# print(B)

Expand Down

0 comments on commit dca554e

Please sign in to comment.