Skip to content

Commit

Permalink
Add tests to check LS HMM of _tskit.lshmm compared to BEAGLE
Browse files Browse the repository at this point in the history
  • Loading branch information
szhan committed Aug 9, 2023
1 parent b6f9872 commit b4628f8
Showing 1 changed file with 301 additions and 0 deletions.
301 changes: 301 additions & 0 deletions python/tests/test_imputation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,301 @@
"""
Tests for genotype imputation (forward and Baum-Welch algorithms).
"""
import io

import numpy as np

import _tskit
import tskit


# A toy tree sequence containing 3 diploid individuals with 5 sites and 5 mutations.
# Two toy query haplotypes are targets for imputation.
#
# The tables below are plain-text inputs that get_toy_data() hands to
# tskit.load_text(); the first line of each string is the column header.

# Node table: ids 0-5 are the sample nodes (is_sample = 1, time 0), assigned
# two per individual (individuals 0, 1 and 2). Ids 6-13 are non-sample nodes
# with individual = -1. The metadata column is left empty.
toy_ts_nodes_text = """\
id is_sample time population individual metadata
0 1 0.000000 0 0
1 1 0.000000 0 0
2 1 0.000000 0 1
3 1 0.000000 0 1
4 1 0.000000 0 2
5 1 0.000000 0 2
6 0 0.029768 0 -1
7 0 0.133017 0 -1
8 0 0.223233 0 -1
9 0 0.651586 0 -1
10 0 0.698831 0 -1
11 0 2.114867 0 -1
12 0 4.322031 0 -1
13 0 7.432311 0 -1
"""

# Edge table: 17 edges with coordinates between 0 and 1,000,000. Interval
# boundaries occur at 505438, 549484 and 781157.
toy_ts_edges_text = """\
left right parent child metadata
0.000000 1000000.000000 6 0
0.000000 1000000.000000 6 3
0.000000 1000000.000000 7 2
0.000000 1000000.000000 7 5
0.000000 1000000.000000 8 1
0.000000 1000000.000000 8 4
0.000000 781157.000000 9 6
0.000000 781157.000000 9 7
0.000000 505438.000000 10 8
0.000000 505438.000000 10 9
505438.000000 549484.000000 11 8
505438.000000 549484.000000 11 9
781157.000000 1000000.000000 12 6
781157.000000 1000000.000000 12 7
549484.000000 1000000.000000 13 8
549484.000000 781157.000000 13 9
781157.000000 1000000.000000 13 12
"""

# Site table: five sites, each with a single-letter ancestral state.
toy_ts_sites_text = """\
position ancestral_state metadata
200000.000000 A
300000.000000 C
520000.000000 G
600000.000000 T
900000.000000 A
"""

# Mutation table: exactly one mutation per site, each with time "unknown"
# and no parent mutation (parent = -1).
toy_ts_mutations_text = """\
site node time derived_state parent metadata
0 9 unknown G -1
1 8 unknown A -1
2 9 unknown T -1
3 9 unknown C -1
4 12 unknown C -1
"""

# Individual table: three individuals, all with flags = 0.
toy_ts_individuals_text = """\
flags
0
0
0
"""

# Query haplotypes, one row per haplotype and one column per site of the toy
# tree sequence. -1 marks the third site (position 520,000), which is the one
# left missing in the queries.

# 0/1 coding: 0 where the query carries the site's ancestral state, 1 where it
# carries the derived state listed in toy_ts_mutations_text.
toy_query_haplotypes_01 = np.array(
    [
        [1, 0, -1, 0, 0],
        [0, 1, -1, 1, 0],
    ],
    dtype=np.int32,
)

# The same two haplotypes with alleles coded as A=0, C=1, G=2, T=3.
toy_query_haplotypes_ACGT = np.array(
    [
        [2, 1, -1, 3, 0],  # GCTA
        [0, 0, -1, 1, 0],  # AACA
    ],
    dtype=np.int32,
)


def get_toy_data():
    """
    Build the toy reference tree sequence and return it with the query
    haplotypes.

    The five module-level text tables are each wrapped in an in-memory text
    stream and passed to ``tskit.load_text`` with ``strict=False``.

    :return: A two-element list ``[ref_ts, query_h]`` where ``ref_ts`` is the
        loaded tree sequence and ``query_h`` is ``toy_query_haplotypes_ACGT``.
    """
    table_text = {
        "nodes": toy_ts_nodes_text,
        "edges": toy_ts_edges_text,
        "sites": toy_ts_sites_text,
        "mutations": toy_ts_mutations_text,
        "individuals": toy_ts_individuals_text,
    }
    # load_text reads from file-like objects, so wrap every table in StringIO.
    streams = {name: io.StringIO(text) for name, text in table_text.items()}
    reference = tskit.load_text(strict=False, **streams)
    return [reference, toy_query_haplotypes_ACGT]


def get_tskit_forward_backward_matrices(ts, h):
    """
    Run the low-level ``_tskit.LsHmm`` forward and backward passes for one
    query haplotype and return the decoded matrices.

    :param ts: Reference panel; must expose ``num_sites`` and
        ``_ll_tree_sequence`` (test_toy_example passes the tree sequence
        returned by get_toy_data).
    :param h: Query haplotype handed to ``forward_matrix`` and
        ``backward_matrix``; the HMM is created with ``acgt_alleles=True``.
    :return: A two-element list ``[forward, backward]`` holding the result of
        ``decode()`` on each compressed matrix.
    """
    ll_ts = ts._ll_tree_sequence
    forward = _tskit.CompressedMatrix(ll_ts)
    backward = _tskit.CompressedMatrix(ll_ts)
    # Two independent per-site arrays, each 0.1 at every site.
    # NOTE(review): presumably the recombination and mutation rates, in that
    # order -- confirm against the _tskit.LsHmm signature.
    first_rates = np.full(ts.num_sites, 0.1)
    second_rates = np.full(ts.num_sites, 0.1)
    hmm = _tskit.LsHmm(ll_ts, first_rates, second_rates, acgt_alleles=True)
    hmm.forward_matrix(h, forward)
    # The backward pass reuses the normalisation factors from the forward pass.
    hmm.backward_matrix(h, forward.normalisation_factor, backward)
    return [forward.decode(), backward.decode()]


# BEAGLE 4.1 was run on the toy data set above using default parameters.
#
# In the query VCF, the site at position 520,000 was redacted and then imputed.
# Note that the ancestral allele in the simulated tree sequence is
# treated as the REF in the VCFs.
#
# The following are the forward probability matrices and backward probability
# matrices calculated when imputing into the third individual above. There are
# two sets of matrices, one for each haplotype.
#
# Notes about calculations:
# n = number of haplotypes in ref. panel
# M = number of markers
# m = index of marker (site)
# h = index of haplotype in ref. panel
#
# In forward probability matrix,
# fwd[m][h] = emission prob., if m = 0 (first marker)
# fwd[m][h] = emission prob. * (scale * fwd[m - 1][h] + shift), otherwise
# where scale = (1 - switch prob.)/sum of fwd[m - 1],
# and shift = switch prob./n.
#
# In backward probability matrix,
# bwd[m][h] = 1, if m = M - 1 (last marker) // DON'T SEE THIS IN BEAGLE
# unadj. bwd[m][h] = emission prob. / n
# bwd[m][h] = (unadj. bwd[m][h] + shift) * scale, otherwise
# where scale = (1 - switch prob.)/sum of unadj. bwd[m],
# and shift = switch prob./n.
#
# For each site, the sum of the backward values over all haplotypes is calculated
# before scaling and shifting.

# BEAGLE forward matrix dump, set 1: one CSV row per (marker m, reference
# haplotype h) pair, 4 markers x 6 haplotypes, listed in ascending marker order.
# The queryAl column reads 1, 0, 0, 0 over markers 0-3, which matches
# toy_query_haplotypes_01[0] with the missing site dropped.
# The string starts with a blank line and every row ends with a trailing comma.
beagle_fwd_matrix_text_1 = """
m,h,probRec,probNoRec,noErrProb,errProb,refAl,queryAl,shiftFac,scaleFac,sumSite,val,
0,0,0.000000,1.000000,0.999900,0.000100,1,1,0.000000,1.000000,0.999900,0.999900,
0,1,0.000000,1.000000,0.999900,0.000100,0,1,0.000000,1.000000,1.000000,0.000100,
0,2,0.000000,1.000000,0.999900,0.000100,1,1,0.000000,1.000000,1.999900,0.999900,
0,3,0.000000,1.000000,0.999900,0.000100,1,1,0.000000,1.000000,2.999800,0.999900,
0,4,0.000000,1.000000,0.999900,0.000100,0,1,0.000000,1.000000,2.999900,0.000100,
0,5,0.000000,1.000000,0.999900,0.000100,1,1,0.000000,1.000000,3.999800,0.999900,
1,0,1.000000,0.000000,0.999900,0.000100,0,0,0.166667,0.000000,0.166650,0.166650,
1,1,1.000000,0.000000,0.999900,0.000100,1,0,0.166667,0.000000,0.166667,0.000017,
1,2,1.000000,0.000000,0.999900,0.000100,0,0,0.166667,0.000000,0.333317,0.166650,
1,3,1.000000,0.000000,0.999900,0.000100,0,0,0.166667,0.000000,0.499967,0.166650,
1,4,1.000000,0.000000,0.999900,0.000100,1,0,0.166667,0.000000,0.499983,0.000017,
1,5,1.000000,0.000000,0.999900,0.000100,0,0,0.166667,0.000000,0.666633,0.166650,
2,0,1.000000,0.000000,0.999900,0.000100,1,0,0.166667,0.000000,0.000017,0.000017,
2,1,1.000000,0.000000,0.999900,0.000100,0,0,0.166667,0.000000,0.166667,0.166650,
2,2,1.000000,0.000000,0.999900,0.000100,1,0,0.166667,0.000000,0.166683,0.000017,
2,3,1.000000,0.000000,0.999900,0.000100,1,0,0.166667,0.000000,0.166700,0.000017,
2,4,1.000000,0.000000,0.999900,0.000100,0,0,0.166667,0.000000,0.333350,0.166650,
2,5,1.000000,0.000000,0.999900,0.000100,1,0,0.166667,0.000000,0.333367,0.000017,
3,0,1.000000,0.000000,0.999900,0.000100,1,0,0.166667,0.000000,0.000017,0.000017,
3,1,1.000000,0.000000,0.999900,0.000100,0,0,0.166667,0.000000,0.166667,0.166650,
3,2,1.000000,0.000000,0.999900,0.000100,1,0,0.166667,0.000000,0.166683,0.000017,
3,3,1.000000,0.000000,0.999900,0.000100,1,0,0.166667,0.000000,0.166700,0.000017,
3,4,1.000000,0.000000,0.999900,0.000100,0,0,0.166667,0.000000,0.333350,0.166650,
3,5,1.000000,0.000000,0.999900,0.000100,1,0,0.166667,0.000000,0.333367,0.000017,
"""

# BEAGLE backward matrix dump, set 1 (presumably the counterpart of
# beagle_fwd_matrix_text_1, going by the name). Same CSV columns, 4 markers x
# 6 haplotypes.
# NOTE: rows are listed in DESCENDING marker order (m = 3 first), so a plain
# reshape((4, 6)) of the val column puts the last marker in row 0.
# The m = 3 rows hold -1 placeholders in every column except val, which is 1.
beagle_bwd_matrix_text_1 = """
m,h,probRec,probNoRec,noErrProb,errProb,refAl,queryAl,shiftFac,scaleFac,sumSite,val,
3,0,-1,-1,-1,-1,-1,-1,-1,-1,-1,1.000000,
3,1,-1,-1,-1,-1,-1,-1,-1,-1,-1,1.000000,
3,2,-1,-1,-1,-1,-1,-1,-1,-1,-1,1.000000,
3,3,-1,-1,-1,-1,-1,-1,-1,-1,-1,1.000000,
3,4,-1,-1,-1,-1,-1,-1,-1,-1,-1,1.000000,
3,5,-1,-1,-1,-1,-1,-1,-1,-1,-1,1.000000,
2,0,1.000000,0.000000,0.999900,0.000100,1,0,0.000000,0.166667,0.333367,0.166667,
2,1,1.000000,0.000000,0.999900,0.000100,0,0,0.000000,0.166667,0.333367,0.166667,
2,2,1.000000,0.000000,0.999900,0.000100,1,0,0.000000,0.166667,0.333367,0.166667,
2,3,1.000000,0.000000,0.999900,0.000100,1,0,0.000000,0.166667,0.333367,0.166667,
2,4,1.000000,0.000000,0.999900,0.000100,0,0,0.000000,0.166667,0.333367,0.166667,
2,5,1.000000,0.000000,0.999900,0.000100,1,0,0.000000,0.166667,0.333367,0.166667,
1,0,1.000000,0.000000,0.999900,0.000100,0,0,0.000000,0.166667,0.333367,0.166667,
1,1,1.000000,0.000000,0.999900,0.000100,1,0,0.000000,0.166667,0.333367,0.166667,
1,2,1.000000,0.000000,0.999900,0.000100,0,0,0.000000,0.166667,0.333367,0.166667,
1,3,1.000000,0.000000,0.999900,0.000100,0,0,0.000000,0.166667,0.333367,0.166667,
1,4,1.000000,0.000000,0.999900,0.000100,1,0,0.000000,0.166667,0.333367,0.166667,
1,5,1.000000,0.000000,0.999900,0.000100,0,0,0.000000,0.166667,0.333367,0.166667,
0,0,1.000000,0.000000,0.999900,0.000100,1,0,0.000000,0.166667,0.666633,0.166667,
0,1,1.000000,0.000000,0.999900,0.000100,0,0,0.000000,0.166667,0.666633,0.166667,
0,2,1.000000,0.000000,0.999900,0.000100,1,0,0.000000,0.166667,0.666633,0.166667,
0,3,1.000000,0.000000,0.999900,0.000100,1,0,0.000000,0.166667,0.666633,0.166667,
0,4,1.000000,0.000000,0.999900,0.000100,0,0,0.000000,0.166667,0.666633,0.166667,
0,5,1.000000,0.000000,0.999900,0.000100,1,0,0.000000,0.166667,0.666633,0.166667,
"""

# BEAGLE forward matrix dump, set 2: one CSV row per (marker m, reference
# haplotype h) pair, 4 markers x 6 haplotypes, listed in ascending marker order.
# The queryAl column reads 0, 1, 1, 0 over markers 0-3, which matches
# toy_query_haplotypes_01[1] with the missing site dropped.
# The string starts with a blank line and every row ends with a trailing comma.
beagle_fwd_matrix_text_2 = """
m,h,probRec,probNoRec,noErrProb,errProb,refAl,queryAl,shiftFac,scaleFac,sumSite,val,
0,0,0.000000,1.000000,0.999900,0.000100,1,0,0.000000,1.000000,0.000100,0.000100,
0,1,0.000000,1.000000,0.999900,0.000100,0,0,0.000000,1.000000,1.000000,0.999900,
0,2,0.000000,1.000000,0.999900,0.000100,1,0,0.000000,1.000000,1.000100,0.000100,
0,3,0.000000,1.000000,0.999900,0.000100,1,0,0.000000,1.000000,1.000200,0.000100,
0,4,0.000000,1.000000,0.999900,0.000100,0,0,0.000000,1.000000,2.000100,0.999900,
0,5,0.000000,1.000000,0.999900,0.000100,1,0,0.000000,1.000000,2.000200,0.000100,
1,0,1.000000,0.000000,0.999900,0.000100,0,1,0.166667,0.000000,0.000017,0.000017,
1,1,1.000000,0.000000,0.999900,0.000100,1,1,0.166667,0.000000,0.166667,0.166650,
1,2,1.000000,0.000000,0.999900,0.000100,0,1,0.166667,0.000000,0.166683,0.000017,
1,3,1.000000,0.000000,0.999900,0.000100,0,1,0.166667,0.000000,0.166700,0.000017,
1,4,1.000000,0.000000,0.999900,0.000100,1,1,0.166667,0.000000,0.333350,0.166650,
1,5,1.000000,0.000000,0.999900,0.000100,0,1,0.166667,0.000000,0.333367,0.000017,
2,0,1.000000,0.000000,0.999900,0.000100,1,1,0.166667,0.000000,0.166650,0.166650,
2,1,1.000000,0.000000,0.999900,0.000100,0,1,0.166667,0.000000,0.166667,0.000017,
2,2,1.000000,0.000000,0.999900,0.000100,1,1,0.166667,0.000000,0.333317,0.166650,
2,3,1.000000,0.000000,0.999900,0.000100,1,1,0.166667,0.000000,0.499967,0.166650,
2,4,1.000000,0.000000,0.999900,0.000100,0,1,0.166667,0.000000,0.499983,0.000017,
2,5,1.000000,0.000000,0.999900,0.000100,1,1,0.166667,0.000000,0.666633,0.166650,
3,0,1.000000,0.000000,0.999900,0.000100,1,0,0.166667,0.000000,0.000017,0.000017,
3,1,1.000000,0.000000,0.999900,0.000100,0,0,0.166667,0.000000,0.166667,0.166650,
3,2,1.000000,0.000000,0.999900,0.000100,1,0,0.166667,0.000000,0.166683,0.000017,
3,3,1.000000,0.000000,0.999900,0.000100,1,0,0.166667,0.000000,0.166700,0.000017,
3,4,1.000000,0.000000,0.999900,0.000100,0,0,0.166667,0.000000,0.333350,0.166650,
3,5,1.000000,0.000000,0.999900,0.000100,1,0,0.166667,0.000000,0.333367,0.000017,
"""

# BEAGLE backward matrix dump, set 2 (presumably the counterpart of
# beagle_fwd_matrix_text_2, going by the name). Same CSV columns, 4 markers x
# 6 haplotypes.
# NOTE: rows are listed in DESCENDING marker order (m = 3 first), so a plain
# reshape((4, 6)) of the val column puts the last marker in row 0.
# The m = 3 rows hold -1 placeholders in every column except val, which is 1.
beagle_bwd_matrix_text_2 = """
m,h,probRec,probNoRec,noErrProb,errProb,refAl,queryAl,shiftFac,scaleFac,sumSite,val,
3,0,-1,-1,-1,-1,-1,-1,-1,-1,-1,1.000000,
3,1,-1,-1,-1,-1,-1,-1,-1,-1,-1,1.000000,
3,2,-1,-1,-1,-1,-1,-1,-1,-1,-1,1.000000,
3,3,-1,-1,-1,-1,-1,-1,-1,-1,-1,1.000000,
3,4,-1,-1,-1,-1,-1,-1,-1,-1,-1,1.000000,
3,5,-1,-1,-1,-1,-1,-1,-1,-1,-1,1.000000,
2,0,1.000000,0.000000,0.999900,0.000100,1,0,0.000000,0.166667,0.333367,0.166667,
2,1,1.000000,0.000000,0.999900,0.000100,0,0,0.000000,0.166667,0.333367,0.166667,
2,2,1.000000,0.000000,0.999900,0.000100,1,0,0.000000,0.166667,0.333367,0.166667,
2,3,1.000000,0.000000,0.999900,0.000100,1,0,0.000000,0.166667,0.333367,0.166667,
2,4,1.000000,0.000000,0.999900,0.000100,0,0,0.000000,0.166667,0.333367,0.166667,
2,5,1.000000,0.000000,0.999900,0.000100,1,0,0.000000,0.166667,0.333367,0.166667,
1,0,1.000000,0.000000,0.999900,0.000100,0,1,0.000000,0.166667,0.666633,0.166667,
1,1,1.000000,0.000000,0.999900,0.000100,1,1,0.000000,0.166667,0.666633,0.166667,
1,2,1.000000,0.000000,0.999900,0.000100,0,1,0.000000,0.166667,0.666633,0.166667,
1,3,1.000000,0.000000,0.999900,0.000100,0,1,0.000000,0.166667,0.666633,0.166667,
1,4,1.000000,0.000000,0.999900,0.000100,1,1,0.000000,0.166667,0.666633,0.166667,
1,5,1.000000,0.000000,0.999900,0.000100,0,1,0.000000,0.166667,0.666633,0.166667,
0,0,1.000000,0.000000,0.999900,0.000100,1,1,0.000000,0.166667,0.333367,0.166667,
0,1,1.000000,0.000000,0.999900,0.000100,0,1,0.000000,0.166667,0.333367,0.166667,
0,2,1.000000,0.000000,0.999900,0.000100,1,1,0.000000,0.166667,0.333367,0.166667,
0,3,1.000000,0.000000,0.999900,0.000100,1,1,0.000000,0.166667,0.333367,0.166667,
0,4,1.000000,0.000000,0.999900,0.000100,0,1,0.000000,0.166667,0.333367,0.166667,
0,5,1.000000,0.000000,0.999900,0.000100,1,1,0.000000,0.166667,0.333367,0.166667,
"""


def parse_matrix(csv_text):
    """
    Parse a comma-separated matrix dump into a NumPy record array.

    The first non-blank line is used as the header. Column names are
    lower-cased (e.g. ``probRec`` becomes ``probrec``) and column dtypes are
    inferred from the data. The result behaves much like a pandas dataframe:
    columns are available as ``rec["m"]`` or ``rec.m``.

    :param str csv_text: CSV text with a header row followed by data rows.
    :return: A ``numpy.recarray`` with one record per data row.
    """
    # recfromcsv is deprecated as of NumPy 2.0 in favour of genfromtxt with a
    # comma delimiter. The keyword arguments below are the defaults that
    # recfromcsv applied, so field names and dtypes are unchanged.
    table = np.genfromtxt(
        io.StringIO(csv_text),
        delimiter=",",
        names=True,
        dtype=None,
        case_sensitive="lower",
    )
    return table.view(np.recarray)


def test_toy_example():
    """
    Print the tskit and BEAGLE forward/backward matrices for the first toy
    query haplotype side by side.

    NOTE(review): this test makes no assertions; it only fails if one of the
    calls raises. The BEAGLE ``val`` columns are reshaped to (4, 6), i.e.
    4 markers by 6 reference haplotypes.
    """
    ref_ts, query = get_toy_data()
    for item in (list(ref_ts.haplotypes()), ref_ts, query):
        print(item)

    tskit_fwd, tskit_bwd = get_tskit_forward_backward_matrices(ref_ts, query[0])
    # Parse both BEAGLE dumps up front, before any matrix is printed.
    beagle_fwd = parse_matrix(beagle_fwd_matrix_text_1)
    beagle_bwd = parse_matrix(beagle_bwd_matrix_text_1)

    beagle_shape = (4, 6)
    sections = (
        ("Forward probability matrix", tskit_fwd, beagle_fwd),
        ("Backward probability matrix", tskit_bwd, beagle_bwd),
    )
    for title, tskit_matrix, beagle_table in sections:
        print(title)
        print("tskit")
        print(tskit_matrix)
        print("beagle")
        print(beagle_table["val"].reshape(beagle_shape))

0 comments on commit b4628f8

Please sign in to comment.