Skip to content

Commit

Permalink
Reorganise
Browse files Browse the repository at this point in the history
  • Loading branch information
szhan committed Sep 12, 2023
1 parent 7a6facd commit f0c4a2d
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 36 deletions.
48 changes: 48 additions & 0 deletions python/tests/beagle.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@

import numpy as np

import _tskit


def convert_to_genetic_map_position(pos):
"""
Expand Down Expand Up @@ -642,3 +644,49 @@ def run_beagle(ref_h, query_h, pos, miscall_rate=0.0001, ne=1e6, debug=False):
print(imputed_alleles)
assert len(imputed_alleles) == x
return (imputed_alleles, i_allele_probs)


def run_tsimpute(ref_ts, query_h, pos, mu, rho):
    """
    Run the simplified BEAGLE algorithm above on a tree sequence.

    Entries of ``query_h`` equal to -1 mark the sites to impute; all other
    entries are treated as genotyped markers. The forward and backward
    matrices are computed by the ``_tskit`` Li and Stephens HMM on the
    reference tree sequence restricted to the genotyped sites.

    :param ref_ts: Tree sequence holding the reference haplotypes.
    :param numpy.ndarray query_h: One query haplotype.
    :param numpy.ndarray pos: Site positions of all the markers.
    :param numpy.ndarray mu: Mismatch probabilities.
    :param numpy.ndarray rho: Switch probabilities.
    :return: Imputed alleles and interpolated allele probabilities.
    :rtype: tuple(numpy.ndarray, numpy.ndarray)
    """
    # Prepare marker positions.
    genotyped_site_ids = np.where(query_h != -1)[0]
    genotyped_pos = pos[genotyped_site_ids]
    imputed_site_ids = np.where(query_h == -1)[0]
    imputed_pos = pos[imputed_site_ids]
    # Prepare reference haplotypes: "_m" keeps only the genotyped sites,
    # "_x" keeps only the sites to impute.
    ref_ts_m = ref_ts.delete_sites(site_ids=imputed_site_ids)
    ref_h_m = ref_ts_m.genotype_matrix()
    ref_ts_x = ref_ts.delete_sites(site_ids=genotyped_site_ids)
    ref_h_x = ref_ts_x.genotype_matrix()
    query_h_m = query_h[genotyped_site_ids]
    # Get forward and backward matrices from ts.
    # NOTE(review): mu and rho are handed to an HMM built on ref_ts_m, so
    # they presumably need one entry per genotyped site - confirm with callers.
    fm = _tskit.CompressedMatrix(ref_ts_m._ll_tree_sequence)
    bm = _tskit.CompressedMatrix(ref_ts_m._ll_tree_sequence)
    ls_hmm = _tskit.LsHmm(ref_ts_m._ll_tree_sequence, mu, rho, acgt_alleles=True)
    ls_hmm.forward_matrix(query_h_m.T, fm)
    # The backward pass reuses the normalisation factors of the forward pass.
    ls_hmm.backward_matrix(query_h_m.T, fm.normalisation_factor, bm)
    # Compute state probability matrix.
    sm = compute_state_probability_matrix(
        fm.decode(),
        bm.decode(),
        ref_h_m,
        query_h_m,
    )
    # Interpolate allele probabilities.
    i_allele_probs = interpolate_allele_probabilities(
        sm, ref_h_x, genotyped_pos, imputed_pos
    )
    # Get MAP alleles at imputed markers.
    imputed_alleles = get_map_alleles(i_allele_probs)
    return (imputed_alleles, i_allele_probs)
36 changes: 0 additions & 36 deletions python/tests/test_imputation.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,7 @@
import numpy as np
import pytest

import _tskit
import tskit
from tests.beagle import compute_state_probability_matrix
from tests.beagle import get_map_alleles
from tests.beagle import interpolate_allele_probabilities
from tests.beagle import run_beagle


Expand Down Expand Up @@ -630,35 +626,3 @@ def parse_matrix(csv_text):
# This returns a record array, which is essentially the same as a
# pandas dataframe, which we can access via df["m"] etc.
return np.lib.npyio.recfromcsv(io.StringIO(csv_text))


def run_tsimpute(ref_ts, query_h, pos, mu, rho):
    """
    Impute the ungenotyped sites of a query haplotype from a tree sequence.

    Entries of ``query_h`` equal to -1 mark the sites to impute; every other
    entry is treated as a genotyped marker. Only the imputed alleles are
    returned.

    :param ref_ts: Tree sequence holding the reference haplotypes.
    :param query_h: Query haplotype, with -1 at the sites to impute.
    :param pos: Site positions, indexed by site id.
    :param mu: Passed unchanged to ``_tskit.LsHmm`` as its second argument.
    :param rho: Passed unchanged to ``_tskit.LsHmm`` as its third argument.
    :return: The result of ``get_map_alleles`` on the interpolated allele
        probabilities.
    """
    # Partition the site ids into observed (genotyped) and missing (imputed).
    observed_ids = np.nonzero(query_h != -1)[0]
    observed_pos = pos[observed_ids]
    missing_ids = np.nonzero(query_h == -1)[0]
    missing_pos = pos[missing_ids]

    # Reference panel restricted to the observed sites, and its genotypes.
    observed_ts = ref_ts.delete_sites(site_ids=missing_ids)
    observed_ref_h = observed_ts.genotype_matrix()
    # Reference genotypes at the sites that have to be imputed.
    missing_ref_h = ref_ts.delete_sites(site_ids=observed_ids).genotype_matrix()
    observed_query = query_h[observed_ids]

    # Forward and backward passes of the HMM on the observed sites.
    forward = _tskit.CompressedMatrix(observed_ts._ll_tree_sequence)
    backward = _tskit.CompressedMatrix(observed_ts._ll_tree_sequence)
    hmm = _tskit.LsHmm(observed_ts._ll_tree_sequence, mu, rho, acgt_alleles=True)
    hmm.forward_matrix(observed_query.T, forward)
    hmm.backward_matrix(observed_query.T, forward.normalisation_factor, backward)

    # State probabilities at the observed sites.
    decoded_forward = forward.decode()
    decoded_backward = backward.decode()
    state_probs = compute_state_probability_matrix(
        decoded_forward, decoded_backward, observed_ref_h, observed_query
    )

    # Allele probabilities at the missing sites, then the imputed alleles.
    allele_probs = interpolate_allele_probabilities(
        state_probs, missing_ref_h, observed_pos, missing_pos
    )
    return get_map_alleles(allele_probs)

0 comments on commit f0c4a2d

Please sign in to comment.