"""
Tests for genotype imputation (forward and Baum-Welch algorithms).
"""
import io

import numpy as np

import _tskit
import tskit


# A toy tree sequence containing 3 diploid individuals with 5 sites and 5 mutations.
# Two toy query haplotypes are targets for imputation.

toy_ts_nodes_text = """\
id is_sample time population individual metadata
0 1 0.000000 0 0
1 1 0.000000 0 0
2 1 0.000000 0 1
3 1 0.000000 0 1
4 1 0.000000 0 2
5 1 0.000000 0 2
6 0 0.029768 0 -1
7 0 0.133017 0 -1
8 0 0.223233 0 -1
9 0 0.651586 0 -1
10 0 0.698831 0 -1
11 0 2.114867 0 -1
12 0 4.322031 0 -1
13 0 7.432311 0 -1
"""

toy_ts_edges_text = """\
left right parent child metadata
0.000000 1000000.000000 6 0
0.000000 1000000.000000 6 3
0.000000 1000000.000000 7 2
0.000000 1000000.000000 7 5
0.000000 1000000.000000 8 1
0.000000 1000000.000000 8 4
0.000000 781157.000000 9 6
0.000000 781157.000000 9 7
0.000000 505438.000000 10 8
0.000000 505438.000000 10 9
505438.000000 549484.000000 11 8
505438.000000 549484.000000 11 9
781157.000000 1000000.000000 12 6
781157.000000 1000000.000000 12 7
549484.000000 1000000.000000 13 8
549484.000000 781157.000000 13 9
781157.000000 1000000.000000 13 12
"""

toy_ts_sites_text = """\
position ancestral_state metadata
200000.000000 A
300000.000000 C
520000.000000 G
600000.000000 T
900000.000000 A
"""

toy_ts_mutations_text = """\
site node time derived_state parent metadata
0 9 unknown G -1
1 8 unknown A -1
2 9 unknown T -1
3 9 unknown C -1
4 12 unknown C -1
"""

toy_ts_individuals_text = """\
flags
0
0
0
"""

# Query haplotypes in 0/1 coding; -1 marks the site left out of the query
# (the third site, at position 520,000).
toy_query_haplotypes_01 = np.array(
    [
        [1, 0, -1, 0, 0],
        [0, 1, -1, 1, 0],
    ],
    dtype=np.int32,
)

# The same kind of query, coded as indices into ACGT (A=0, C=1, G=2, T=3).
toy_query_haplotypes_ACGT = np.array(
    [
        [2, 1, -1, 3, 0],  # GCTA
        [0, 0, -1, 1, 0],  # AACA
    ],
    dtype=np.int32,
)


def get_toy_data():
    """
    Build the toy reference tree sequence from the text tables above.

    Returns a two-element list ``[ref_ts, query_h]`` where ``ref_ts`` is the
    tree sequence loaded with ``tskit.load_text`` and ``query_h`` is the
    ACGT-coded query haplotype array ``toy_query_haplotypes_ACGT``.
    """
    ref_ts = tskit.load_text(
        nodes=io.StringIO(toy_ts_nodes_text),
        edges=io.StringIO(toy_ts_edges_text),
        sites=io.StringIO(toy_ts_sites_text),
        mutations=io.StringIO(toy_ts_mutations_text),
        individuals=io.StringIO(toy_ts_individuals_text),
        strict=False,
    )
    query_h = toy_query_haplotypes_ACGT
    return [ref_ts, query_h]


def get_tskit_forward_backward_matrices(ts, h):
    """
    Run the low-level tskit Li & Stephens HMM on haplotype ``h`` against ``ts``.

    Both per-site rate arrays handed to ``_tskit.LsHmm`` are filled with 0.1
    (presumably the recombination and mutation rates, in that order -- confirm
    against the ``_tskit.LsHmm`` signature). The backward pass reuses the
    normalisation factors produced by the forward pass.

    Returns a two-element list ``[forward, backward]`` holding the decoded
    forward and backward matrices.
    """
    m = ts.num_sites
    ll_ts = ts._ll_tree_sequence
    fm = _tskit.CompressedMatrix(ll_ts)
    bm = _tskit.CompressedMatrix(ll_ts)
    ls_hmm = _tskit.LsHmm(
        ll_ts, np.full(m, 0.1), np.full(m, 0.1), acgt_alleles=True
    )
    ls_hmm.forward_matrix(h, fm)
    ls_hmm.backward_matrix(h, fm.normalisation_factor, bm)
    return [fm.decode(), bm.decode()]


# BEAGLE 4.1 was run on the toy data set above using default parameters.
#
# In the query VCF, the site at position 520,000 was redacted and then imputed.
# Note that the ancestral allele in the simulated tree sequence is
# treated as the REF in the VCFs.
#
# The following are the forward probability matrices and backward probability
# matrices calculated when imputing into the third individual above. There are
# two sets of matrices, one for each haplotype.
#
# Notes about calculations:
#   n = number of haplotypes in ref. panel
#   M = number of markers
#   m = index of marker (site)
#   h = index of haplotype in ref. panel
#
# In forward probability matrix,
#   fwd[m][h] = emission prob., if m = 0 (first marker)
#   fwd[m][h] = emission prob. * (scale * fwd[m - 1][h] + shift), otherwise
#       where scale = (1 - switch prob.)/sum of fwd[m - 1],
#       and shift = switch prob./n.
#
# In backward probability matrix,
#   bwd[m][h] = 1, if m = M - 1 (last marker) // DON'T SEE THIS IN BEAGLE
#   unadj. bwd[m][h] = emission prob. / n
#   bwd[m][h] = (unadj. bwd[m][h] + shift) * scale, otherwise
#       where scale = (1 - switch prob.)/sum of unadj. bwd[m],
#       and shift = switch prob./n.
#
# For each site, the sum of backward value over all haplotypes is calculated
# before scaling and shifting.

beagle_fwd_matrix_text_1 = """
m,h,probRec,probNoRec,noErrProb,errProb,refAl,queryAl,shiftFac,scaleFac,sumSite,val,
0,0,0.000000,1.000000,0.999900,0.000100,1,1,0.000000,1.000000,0.999900,0.999900,
0,1,0.000000,1.000000,0.999900,0.000100,0,1,0.000000,1.000000,1.000000,0.000100,
0,2,0.000000,1.000000,0.999900,0.000100,1,1,0.000000,1.000000,1.999900,0.999900,
0,3,0.000000,1.000000,0.999900,0.000100,1,1,0.000000,1.000000,2.999800,0.999900,
0,4,0.000000,1.000000,0.999900,0.000100,0,1,0.000000,1.000000,2.999900,0.000100,
0,5,0.000000,1.000000,0.999900,0.000100,1,1,0.000000,1.000000,3.999800,0.999900,
1,0,1.000000,0.000000,0.999900,0.000100,0,0,0.166667,0.000000,0.166650,0.166650,
1,1,1.000000,0.000000,0.999900,0.000100,1,0,0.166667,0.000000,0.166667,0.000017,
1,2,1.000000,0.000000,0.999900,0.000100,0,0,0.166667,0.000000,0.333317,0.166650,
1,3,1.000000,0.000000,0.999900,0.000100,0,0,0.166667,0.000000,0.499967,0.166650,
1,4,1.000000,0.000000,0.999900,0.000100,1,0,0.166667,0.000000,0.499983,0.000017,
1,5,1.000000,0.000000,0.999900,0.000100,0,0,0.166667,0.000000,0.666633,0.166650,
2,0,1.000000,0.000000,0.999900,0.000100,1,0,0.166667,0.000000,0.000017,0.000017,
2,1,1.000000,0.000000,0.999900,0.000100,0,0,0.166667,0.000000,0.166667,0.166650,
2,2,1.000000,0.000000,0.999900,0.000100,1,0,0.166667,0.000000,0.166683,0.000017,
2,3,1.000000,0.000000,0.999900,0.000100,1,0,0.166667,0.000000,0.166700,0.000017,
2,4,1.000000,0.000000,0.999900,0.000100,0,0,0.166667,0.000000,0.333350,0.166650,
2,5,1.000000,0.000000,0.999900,0.000100,1,0,0.166667,0.000000,0.333367,0.000017,
3,0,1.000000,0.000000,0.999900,0.000100,1,0,0.166667,0.000000,0.000017,0.000017,
3,1,1.000000,0.000000,0.999900,0.000100,0,0,0.166667,0.000000,0.166667,0.166650,
3,2,1.000000,0.000000,0.999900,0.000100,1,0,0.166667,0.000000,0.166683,0.000017,
3,3,1.000000,0.000000,0.999900,0.000100,1,0,0.166667,0.000000,0.166700,0.000017,
3,4,1.000000,0.000000,0.999900,0.000100,0,0,0.166667,0.000000,0.333350,0.166650,
3,5,1.000000,0.000000,0.999900,0.000100,1,0,0.166667,0.000000,0.333367,0.000017,
"""

beagle_bwd_matrix_text_1 = """
m,h,probRec,probNoRec,noErrProb,errProb,refAl,queryAl,shiftFac,scaleFac,sumSite,val,
3,0,-1,-1,-1,-1,-1,-1,-1,-1,-1,1.000000,
3,1,-1,-1,-1,-1,-1,-1,-1,-1,-1,1.000000,
3,2,-1,-1,-1,-1,-1,-1,-1,-1,-1,1.000000,
3,3,-1,-1,-1,-1,-1,-1,-1,-1,-1,1.000000,
3,4,-1,-1,-1,-1,-1,-1,-1,-1,-1,1.000000,
3,5,-1,-1,-1,-1,-1,-1,-1,-1,-1,1.000000,
2,0,1.000000,0.000000,0.999900,0.000100,1,0,0.000000,0.166667,0.333367,0.166667,
2,1,1.000000,0.000000,0.999900,0.000100,0,0,0.000000,0.166667,0.333367,0.166667,
2,2,1.000000,0.000000,0.999900,0.000100,1,0,0.000000,0.166667,0.333367,0.166667,
2,3,1.000000,0.000000,0.999900,0.000100,1,0,0.000000,0.166667,0.333367,0.166667,
2,4,1.000000,0.000000,0.999900,0.000100,0,0,0.000000,0.166667,0.333367,0.166667,
2,5,1.000000,0.000000,0.999900,0.000100,1,0,0.000000,0.166667,0.333367,0.166667,
1,0,1.000000,0.000000,0.999900,0.000100,0,0,0.000000,0.166667,0.333367,0.166667,
1,1,1.000000,0.000000,0.999900,0.000100,1,0,0.000000,0.166667,0.333367,0.166667,
1,2,1.000000,0.000000,0.999900,0.000100,0,0,0.000000,0.166667,0.333367,0.166667,
1,3,1.000000,0.000000,0.999900,0.000100,0,0,0.000000,0.166667,0.333367,0.166667,
1,4,1.000000,0.000000,0.999900,0.000100,1,0,0.000000,0.166667,0.333367,0.166667,
1,5,1.000000,0.000000,0.999900,0.000100,0,0,0.000000,0.166667,0.333367,0.166667,
0,0,1.000000,0.000000,0.999900,0.000100,1,0,0.000000,0.166667,0.666633,0.166667,
0,1,1.000000,0.000000,0.999900,0.000100,0,0,0.000000,0.166667,0.666633,0.166667,
0,2,1.000000,0.000000,0.999900,0.000100,1,0,0.000000,0.166667,0.666633,0.166667,
0,3,1.000000,0.000000,0.999900,0.000100,1,0,0.000000,0.166667,0.666633,0.166667,
0,4,1.000000,0.000000,0.999900,0.000100,0,0,0.000000,0.166667,0.666633,0.166667,
0,5,1.000000,0.000000,0.999900,0.000100,1,0,0.000000,0.166667,0.666633,0.166667,
"""

beagle_fwd_matrix_text_2 = """
m,h,probRec,probNoRec,noErrProb,errProb,refAl,queryAl,shiftFac,scaleFac,sumSite,val,
0,0,0.000000,1.000000,0.999900,0.000100,1,0,0.000000,1.000000,0.000100,0.000100,
0,1,0.000000,1.000000,0.999900,0.000100,0,0,0.000000,1.000000,1.000000,0.999900,
0,2,0.000000,1.000000,0.999900,0.000100,1,0,0.000000,1.000000,1.000100,0.000100,
0,3,0.000000,1.000000,0.999900,0.000100,1,0,0.000000,1.000000,1.000200,0.000100,
0,4,0.000000,1.000000,0.999900,0.000100,0,0,0.000000,1.000000,2.000100,0.999900,
0,5,0.000000,1.000000,0.999900,0.000100,1,0,0.000000,1.000000,2.000200,0.000100,
1,0,1.000000,0.000000,0.999900,0.000100,0,1,0.166667,0.000000,0.000017,0.000017,
1,1,1.000000,0.000000,0.999900,0.000100,1,1,0.166667,0.000000,0.166667,0.166650,
1,2,1.000000,0.000000,0.999900,0.000100,0,1,0.166667,0.000000,0.166683,0.000017,
1,3,1.000000,0.000000,0.999900,0.000100,0,1,0.166667,0.000000,0.166700,0.000017,
1,4,1.000000,0.000000,0.999900,0.000100,1,1,0.166667,0.000000,0.333350,0.166650,
1,5,1.000000,0.000000,0.999900,0.000100,0,1,0.166667,0.000000,0.333367,0.000017,
2,0,1.000000,0.000000,0.999900,0.000100,1,1,0.166667,0.000000,0.166650,0.166650,
2,1,1.000000,0.000000,0.999900,0.000100,0,1,0.166667,0.000000,0.166667,0.000017,
2,2,1.000000,0.000000,0.999900,0.000100,1,1,0.166667,0.000000,0.333317,0.166650,
2,3,1.000000,0.000000,0.999900,0.000100,1,1,0.166667,0.000000,0.499967,0.166650,
2,4,1.000000,0.000000,0.999900,0.000100,0,1,0.166667,0.000000,0.499983,0.000017,
2,5,1.000000,0.000000,0.999900,0.000100,1,1,0.166667,0.000000,0.666633,0.166650,
3,0,1.000000,0.000000,0.999900,0.000100,1,0,0.166667,0.000000,0.000017,0.000017,
3,1,1.000000,0.000000,0.999900,0.000100,0,0,0.166667,0.000000,0.166667,0.166650,
3,2,1.000000,0.000000,0.999900,0.000100,1,0,0.166667,0.000000,0.166683,0.000017,
3,3,1.000000,0.000000,0.999900,0.000100,1,0,0.166667,0.000000,0.166700,0.000017,
3,4,1.000000,0.000000,0.999900,0.000100,0,0,0.166667,0.000000,0.333350,0.166650,
3,5,1.000000,0.000000,0.999900,0.000100,1,0,0.166667,0.000000,0.333367,0.000017,
"""

beagle_bwd_matrix_text_2 = """
m,h,probRec,probNoRec,noErrProb,errProb,refAl,queryAl,shiftFac,scaleFac,sumSite,val,
3,0,-1,-1,-1,-1,-1,-1,-1,-1,-1,1.000000,
3,1,-1,-1,-1,-1,-1,-1,-1,-1,-1,1.000000,
3,2,-1,-1,-1,-1,-1,-1,-1,-1,-1,1.000000,
3,3,-1,-1,-1,-1,-1,-1,-1,-1,-1,1.000000,
3,4,-1,-1,-1,-1,-1,-1,-1,-1,-1,1.000000,
3,5,-1,-1,-1,-1,-1,-1,-1,-1,-1,1.000000,
2,0,1.000000,0.000000,0.999900,0.000100,1,0,0.000000,0.166667,0.333367,0.166667,
2,1,1.000000,0.000000,0.999900,0.000100,0,0,0.000000,0.166667,0.333367,0.166667,
2,2,1.000000,0.000000,0.999900,0.000100,1,0,0.000000,0.166667,0.333367,0.166667,
2,3,1.000000,0.000000,0.999900,0.000100,1,0,0.000000,0.166667,0.333367,0.166667,
2,4,1.000000,0.000000,0.999900,0.000100,0,0,0.000000,0.166667,0.333367,0.166667,
2,5,1.000000,0.000000,0.999900,0.000100,1,0,0.000000,0.166667,0.333367,0.166667,
1,0,1.000000,0.000000,0.999900,0.000100,0,1,0.000000,0.166667,0.666633,0.166667,
1,1,1.000000,0.000000,0.999900,0.000100,1,1,0.000000,0.166667,0.666633,0.166667,
1,2,1.000000,0.000000,0.999900,0.000100,0,1,0.000000,0.166667,0.666633,0.166667,
1,3,1.000000,0.000000,0.999900,0.000100,0,1,0.000000,0.166667,0.666633,0.166667,
1,4,1.000000,0.000000,0.999900,0.000100,1,1,0.000000,0.166667,0.666633,0.166667,
1,5,1.000000,0.000000,0.999900,0.000100,0,1,0.000000,0.166667,0.666633,0.166667,
0,0,1.000000,0.000000,0.999900,0.000100,1,1,0.000000,0.166667,0.333367,0.166667,
0,1,1.000000,0.000000,0.999900,0.000100,0,1,0.000000,0.166667,0.333367,0.166667,
0,2,1.000000,0.000000,0.999900,0.000100,1,1,0.000000,0.166667,0.333367,0.166667,
0,3,1.000000,0.000000,0.999900,0.000100,1,1,0.000000,0.166667,0.333367,0.166667,
0,4,1.000000,0.000000,0.999900,0.000100,0,1,0.000000,0.166667,0.333367,0.166667,
0,5,1.000000,0.000000,0.999900,0.000100,1,1,0.000000,0.166667,0.333367,0.166667,
"""


def parse_matrix(csv_text):
    """
    Parse one of the BEAGLE CSV dumps above into a numpy record array.

    The result is essentially the same as a pandas dataframe: columns are
    accessed by (lower-cased) header name, e.g. ``df["m"]`` or ``df["val"]``.
    """
    # np.genfromtxt is called with the defaults that the deprecated
    # recfromcsv helper used to apply (comma delimiter, names taken from the
    # header row, inferred dtypes, lower-cased field names), so the result is
    # the same without depending on a helper that newer numpy no longer ships.
    matrix = np.genfromtxt(
        io.StringIO(csv_text),
        delimiter=",",
        names=True,
        dtype=None,
        case_sensitive="lower",
    )
    return matrix.view(np.recarray)


def test_toy_example():
    """
    Print the tskit and BEAGLE forward/backward matrices for the first toy
    query haplotype side by side. Nothing is asserted yet.
    """
    ref_ts, query = get_toy_data()
    print(list(ref_ts.haplotypes()))
    print(ref_ts)
    print(query)
    tskit_fwd, tskit_bwd = get_tskit_forward_backward_matrices(ref_ts, query[0])
    beagle_fwd = parse_matrix(beagle_fwd_matrix_text_1)
    beagle_bwd = parse_matrix(beagle_bwd_matrix_text_1)
    # The BEAGLE tables hold 4 markers x 6 reference haplotypes.
    print("Forward probability matrix")
    print("tskit")
    print(tskit_fwd)
    print("beagle")
    print(beagle_fwd["val"].reshape((4, 6)))
    # NOTE: the BEAGLE backward table lists markers from m=3 down to m=0, so
    # the rows printed here are in descending marker order.
    print("Backward probability matrix")
    print("tskit")
    print(tskit_bwd)
    print("beagle")
    print(beagle_bwd["val"].reshape((4, 6)))