diff --git a/python/tests/beagle_numba.py b/python/tests/beagle_numba.py index f01252acbc..ca3a3b728d 100644 --- a/python/tests/beagle_numba.py +++ b/python/tests/beagle_numba.py @@ -50,7 +50,8 @@ __VERSION__ = "0.0.0" __DATE__ = "20XXXXXX" -_ACGT_ALLELES = [0, 1, 2, 3, tskit.MISSING_DATA] +_ACGT_ALLELES_INT = [0, 1, 2, 3, tskit.MISSING_DATA] +_ACGT_ALLELES_STR = "ACGT" @dataclass(frozen=True) @@ -119,16 +120,16 @@ def __post_init__(self): assert ( self.alleles.shape == self.allele_probs.shape ), "Dimensions in alleles and allele probabilities don't match." - for i in range(len(self.alleles)): + for i in range(self.alleles.shape[1]): assert np.all(np.isin(self.alleles[:, i], [self.refs[i], self.alts[i]])) assert np.all( - np.isin(np.unique(self.refs), _ACGT_ALLELES) + np.isin(np.unique(self.refs), _ACGT_ALLELES_INT) ), "Unrecognized alleles are in REF alleles." assert np.all( - np.isin(np.unique(self.alts), _ACGT_ALLELES) + np.isin(np.unique(self.alts), _ACGT_ALLELES_INT) ), "Unrecognized alleles are in ALT alleles." assert np.all( - np.isin(np.unique(self.alleles), _ACGT_ALLELES) + np.isin(np.unique(self.alleles), _ACGT_ALLELES_INT) ), "Unrecognized alleles are in alleles." @property @@ -170,13 +171,12 @@ def remap_alleles(a): :return: Recoded alleles. :rtype: np.ndarray(dtype=np.int8) """ - _ALLELES_ACGT = "ACGT" b = np.zeros(len(a), dtype=np.int8) - 1 # Encoded as missing by default for i in range(len(a)): if a[i] in [None, ""]: continue - elif a[i] in _ALLELES_ACGT: - b[i] = _ALLELES_ACGT.index(a[i]) + elif a[i] in _ACGT_ALLELES_STR: + b[i] = _ACGT_ALLELES_STR.index(a[i]) else: raise AssertionError(f"Allele {a[i]} is not recognised.") return b @@ -864,8 +864,8 @@ def write_vcf(ref_ts, impdata, out_file, chr_name="1"): line_str = chr_name + "\t" line_str += str(int(impdata.site_pos[i])) + "\t" line_str += str(i) + "\t" - REF = _ACGT_ALLELES[impdata.get_ref_allele_at_site(i)] - ALT = _ACGT_ALLELES[impdata.get_alt_allele_at_site(i)] + REF = _ACGT_ALLELES_STR[impdata.get_ref_allele_at_site(i)] + ALT = _ACGT_ALLELES_STR[impdata.get_alt_allele_at_site(i)] line_str += REF + "\t" line_str += ALT + "\t" # QUAL