diff --git a/python/tests/beagle.py b/python/tests/beagle.py index 170984e982..66ef7a85b4 100644 --- a/python/tests/beagle.py +++ b/python/tests/beagle.py @@ -45,6 +45,8 @@ import numpy as np +import _tskit + def convert_to_genetic_map_position(pos): """ @@ -642,3 +644,49 @@ def run_beagle(ref_h, query_h, pos, miscall_rate=0.0001, ne=1e6, debug=False): print(imputed_alleles) assert len(imputed_alleles) == x return (imputed_alleles, i_allele_probs) + + +def run_tsimpute(ref_ts, query_h, pos, mu, rho): + """ + Run the simplified BEAGLE algorithm above on a tree sequence. + + :param numpy.ndarray ref_h: Reference haplotypes. + :param numpy.ndarray query_h: One query haplotype. + :param numpy.ndarray pos: Site positions of all the markers. + :param numpy.ndarray mu: Mismatch probabilities. + :param numpy.ndarray rho: Switch probabilities. + :param bool debug: Whether to print intermediate results. + :return: Imputed alleles and interpolated allele probabilities. + :rtype: list(numpy.ndarray, numpy.ndarray) + """ + # Prepare marker positions. + genotyped_site_ids = np.where(query_h != -1)[0] + genotyped_pos = pos[genotyped_site_ids] + imputed_site_ids = np.where(query_h == -1)[0] + imputed_pos = pos[imputed_site_ids] + # Prepare reference haplotypes. + ref_ts_m = ref_ts.delete_sites(site_ids=imputed_site_ids) + ref_h_m = ref_ts_m.genotype_matrix() + ref_ts_x = ref_ts.delete_sites(site_ids=genotyped_site_ids) + ref_h_x = ref_ts_x.genotype_matrix() + query_h_m = query_h[genotyped_site_ids] + # Get forward and backward matrices from ts. + fm = _tskit.CompressedMatrix(ref_ts_m._ll_tree_sequence) + bm = _tskit.CompressedMatrix(ref_ts_m._ll_tree_sequence) + ls_hmm = _tskit.LsHmm(ref_ts_m._ll_tree_sequence, mu, rho, acgt_alleles=True) + ls_hmm.forward_matrix(query_h_m.T, fm) + ls_hmm.backward_matrix(query_h_m.T, fm.normalisation_factor, bm) + # Compute state probability matrix. + sm = compute_state_probability_matrix( + fm.decode(), + bm.decode(), + ref_h_m, + query_h_m, + ) + # Interpolate allele probabilities. + i_allele_probs = interpolate_allele_probabilities( + sm, ref_h_x, genotyped_pos, imputed_pos + ) + # Get MAP alleles at imputed markers. + imputed_alleles = get_map_alleles(i_allele_probs) + return (imputed_alleles, i_allele_probs) diff --git a/python/tests/test_imputation.py b/python/tests/test_imputation.py index bed9b8dd4d..05c886812e 100644 --- a/python/tests/test_imputation.py +++ b/python/tests/test_imputation.py @@ -6,11 +6,7 @@ import numpy as np import pytest -import _tskit import tskit -from tests.beagle import compute_state_probability_matrix -from tests.beagle import get_map_alleles -from tests.beagle import interpolate_allele_probabilities from tests.beagle import run_beagle @@ -630,35 +626,3 @@ def parse_matrix(csv_text): # This returns a record array, which is essentially the same as a # pandas dataframe, which we can access via df["m"] etc. return np.lib.npyio.recfromcsv(io.StringIO(csv_text)) - - -def run_tsimpute(ref_ts, query_h, pos, mu, rho): - # Prepare marker positions. - genotyped_site_ids = np.where(query_h != -1)[0] - genotyped_pos = pos[genotyped_site_ids] - imputed_site_ids = np.where(query_h == -1)[0] - imputed_pos = pos[imputed_site_ids] - # Prepare reference haplotypes. - ref_ts_m = ref_ts.delete_sites(site_ids=imputed_site_ids) - ref_h_m = ref_ts_m.genotype_matrix() - ref_ts_x = ref_ts.delete_sites(site_ids=genotyped_site_ids) - ref_h_x = ref_ts_x.genotype_matrix() - query_h_m = query_h[genotyped_site_ids] - # Get forward and backward matrices from ts. - fm = _tskit.CompressedMatrix(ref_ts_m._ll_tree_sequence) - bm = _tskit.CompressedMatrix(ref_ts_m._ll_tree_sequence) - ls_hmm = _tskit.LsHmm(ref_ts_m._ll_tree_sequence, mu, rho, acgt_alleles=True) - ls_hmm.forward_matrix(query_h_m.T, fm) - ls_hmm.backward_matrix(query_h_m.T, fm.normalisation_factor, bm) - # Compute state probability matrix. - sm = compute_state_probability_matrix( - fm.decode(), - bm.decode(), - ref_h_m, - query_h_m, - ) - # Interpolate allele probabilities. - ap = interpolate_allele_probabilities(sm, ref_h_x, genotyped_pos, imputed_pos) - # Get imputed alleles. - ia = get_map_alleles(ap) - return ia