Skip to content

Commit

Permalink
Remove pandas dep for csv parsing
Browse files Browse the repository at this point in the history
  • Loading branch information
jeromekelleher committed Aug 9, 2023
1 parent a6df928 commit 2345c72
Showing 1 changed file with 44 additions and 45 deletions.
89 changes: 44 additions & 45 deletions python/tests/test_imputation.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import io

import numpy as np
import pandas as pd

import _tskit
import tskit
Expand Down Expand Up @@ -258,54 +257,51 @@ def get_tskit_forward_backward_matrices(ts, h):
"""


def convert_to_numpy(matrix_text):
    """Parse a forward or backward matrix given as CSV text into a 4 x 6 numpy array."""
    table = pd.read_csv(io.StringIO(matrix_text))
    # Complementary probability columns: switch/non-switch first, then
    # non-mismatch/mismatch. Each pair must sum to 1 on every row.
    # NOTE(review): a sum of -2 is also accepted; presumably both columns
    # then hold a -1 placeholder -- confirm against the matrix text.
    complementary_columns = (
        ("probRec", "probNoRec"),
        ("noErrProb", "errProb"),
    )
    for first, second in complementary_columns:
        row_totals = getattr(table, first) + getattr(table, second)
        assert np.all(np.isin(row_totals, [1, -2]))
    values = table.val.to_numpy()
    return values.reshape((4, 6))
def parse_matrix(csv_text):
    """Parse CSV text with a header row into a numpy record array.

    The result can be used much like a pandas dataframe for column access:
    a column is available by key (``rec["val"]``) or by attribute
    (``rec.val``). Column names are lower-cased, exactly as ``recfromcsv``
    did, so a ``probRec`` header is accessed as ``rec["probrec"]``.

    :param str csv_text: CSV-formatted text whose first line names the columns.
    :return: One record per data row, with a dtype inferred per column.
    :rtype: numpy.recarray
    """
    # np.lib.npyio.recfromcsv is not reachable on NumPy >= 2.0 (recfromcsv is
    # deprecated and np.lib.npyio no longer exposes it), so call the public
    # genfromtxt with the same settings that recfromcsv applied.
    records = np.genfromtxt(
        io.StringIO(csv_text),
        delimiter=",",
        names=True,
        dtype=None,
        case_sensitive="lower",
    )
    return records.view(np.recarray)


def get_forward_backward_matrices():
    """Convert the four beagle matrix texts to numpy arrays.

    Returns a list ordered as: forward 1, backward 1, forward 2, backward 2.
    """
    matrix_texts = (
        beagle_fwd_matrix_text_1,
        beagle_bwd_matrix_text_1,
        beagle_fwd_matrix_text_2,
        beagle_bwd_matrix_text_2,
    )
    return [convert_to_numpy(text) for text in matrix_texts]
# def get_forward_backward_matrices():
# fwd_matrix_1 = convert_to_numpy(beagle_fwd_matrix_text_1)
# bwd_matrix_1 = convert_to_numpy(beagle_bwd_matrix_text_1)
# fwd_matrix_2 = convert_to_numpy(beagle_fwd_matrix_text_2)
# bwd_matrix_2 = convert_to_numpy(beagle_bwd_matrix_text_2)
# return [fwd_matrix_1, bwd_matrix_1, fwd_matrix_2, bwd_matrix_2]


def get_test_data(matrix_text, field):
    """Return one named quantity from a matrix CSV, for checking forward or
    backward probability matrix calculations.

    Raises ValueError if ``field`` is not one of the known names.
    """
    num_markers = 4
    num_haplotypes = 6
    # (field name, CSV column, True when only one value per site is returned)
    field_specs = (
        ("switch", "probRec", True),  # switch probability
        ("mismatch", "errProb", True),  # mismatch probability
        # Alleles are coded 0 = ref allele, 1 = alt allele.
        ("ref_hap_allele", "refAl", False),  # reference panel haplotypes
        ("query_hap_allele", "queryAl", True),  # query haplotype
        ("shift", "shiftFac", True),  # shift factor
        ("scale", "scaleFac", True),  # scale factor
        ("sum", "sumSite", True),  # sum of values over haplotypes
    )
    # The CSV is parsed before the field name is validated.
    table = pd.read_csv(io.StringIO(matrix_text))
    spec = next((entry for entry in field_specs if field == entry[0]), None)
    if spec is None:
        raise ValueError(f"Unknown field: {field}")
    _, column, one_per_site = spec
    grid = getattr(table, column).to_numpy().reshape((num_markers, num_haplotypes))
    if one_per_site:
        # Keep the first haplotype's entry for each site.
        return grid[:, 0]
    return grid
# def get_test_data(matrix_text, field):
# # JK: Not sure I see the point of this function?
# """Extracts data to check forward or backward probability matrix calculations."""
# df = pd.read_csv(io.StringIO(matrix_text))
# m = 4 # Number of markers
# n = 6 # Number of haplotypes
# if field == "switch":
# # Switch probability, one per site
# return df.probRec.to_numpy().reshape((m, n))[:, 0]
# elif field == "mismatch":
# # Mismatch probability, one per site
# return df.errProb.to_numpy().reshape((m, n))[:, 0]
# elif field == "ref_hap_allele":
# # Allele in haplotype in reference panel
# # 0 = ref allele, 1 = alt allele
# return df.refAl.to_numpy().reshape((m, n))
# elif field == "query_hap_allele":
# # Allele in haplotype in query
# # 0 = ref allele, 1 = alt allele
# return df.queryAl.to_numpy().reshape((m, n))[:, 0]
# elif field == "shift":
# # Shift factor, one per site
# return df.shiftFac.to_numpy().reshape((m, n))[:, 0]
# elif field == "scale":
# # Scale factor, one per site
# return df.scaleFac.to_numpy().reshape((m, n))[:, 0]
# elif field == "sum":
# # Sum of values over haplotypes
# return df.sumSite.to_numpy().reshape((m, n))[:, 0]
# else:
# raise ValueError(f"Unknown field: {field}")


def test_toy_example():
Expand All @@ -316,3 +312,6 @@ def test_toy_example():
fw, bw = get_tskit_forward_backward_matrices(ref_ts, query[0])
print(fw)
print(bw)
df_beagle = parse_matrix(beagle_fwd_matrix_text_1)
# compare beagle results with tskit
print(df_beagle["val"])

0 comments on commit 2345c72

Please sign in to comment.