Skip to content

Commit

Permalink
Refactor ACGT_ALLELES
Browse files: browse the repository at this point in the history
  • Loading branch information
szhan committed Feb 26, 2024
1 parent 2178c3c commit b010245
Showing 1 changed file with 10 additions and 10 deletions.
20 changes: 10 additions & 10 deletions python/tests/beagle_numba.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -50,7 +50,8 @@
__VERSION__ = "0.0.0"
__DATE__ = "20XXXXXX"

_ACGT_ALLELES = [0, 1, 2, 3, tskit.MISSING_DATA]
_ACGT_ALLELES_INT = [0, 1, 2, 3, tskit.MISSING_DATA]
_ACGT_ALLELES_STR = "ACGT"


@dataclass(frozen=True)
Expand Down Expand Up @@ -119,16 +120,16 @@ def __post_init__(self):
assert (
self.alleles.shape == self.allele_probs.shape
), "Dimensions in alleles and allele probabilities don't match."
for i in range(len(self.alleles)):
for i in range(self.alleles.shape[1]):
assert np.all(np.isin(self.alleles[:, i], [self.refs[i], self.alts[i]]))
assert np.all(
np.isin(np.unique(self.refs), _ACGT_ALLELES)
np.isin(np.unique(self.refs), _ACGT_ALLELES_INT)
), "Unrecognized alleles are in REF alleles."
assert np.all(
np.isin(np.unique(self.alts), _ACGT_ALLELES)
np.isin(np.unique(self.alts), _ACGT_ALLELES_INT)
), "Unrecognized alleles are in ALT alleles."
assert np.all(
np.isin(np.unique(self.alleles), _ACGT_ALLELES)
np.isin(np.unique(self.alleles), _ACGT_ALLELES_INT)
), "Unrecognized alleles are in alleles."

@property
Expand Down Expand Up @@ -170,13 +171,12 @@ def remap_alleles(a):
:return: Recoded alleles.
:rtype: np.ndarray(dtype=np.int8)
"""
_ALLELES_ACGT = "ACGT"
b = np.zeros(len(a), dtype=np.int8) - 1 # Encoded as missing by default
for i in range(len(a)):
if a[i] in [None, ""]:
continue
elif a[i] in _ALLELES_ACGT:
b[i] = _ALLELES_ACGT.index(a[i])
elif a[i] in _ACGT_ALLELES_STR:
b[i] = _ACGT_ALLELES_STR.index(a[i])
else:
raise AssertionError(f"Allele {a[i]} is not recognised.")
return b
Expand Down Expand Up @@ -864,8 +864,8 @@ def write_vcf(ref_ts, impdata, out_file, chr_name="1"):
line_str = chr_name + "\t"
line_str += str(int(impdata.site_pos[i])) + "\t"
line_str += str(i) + "\t"
REF = _ACGT_ALLELES[impdata.get_ref_allele_at_site(i)]
ALT = _ACGT_ALLELES[impdata.get_alt_allele_at_site(i)]
REF = _ACGT_ALLELES_STR[impdata.get_ref_allele_at_site(i)]
ALT = _ACGT_ALLELES_STR[impdata.get_alt_allele_at_site(i)]
line_str += REF + "\t"
line_str += ALT + "\t"
# QUAL
Expand Down

0 comments on commit b010245

Please sign in to comment.