Skip to content

Commit

Permalink
Fixup tests
Browse files Browse the repository at this point in the history
  • Loading branch information
jeromekelleher committed Dec 16, 2024
1 parent 1d672ed commit 0a12df0
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 53 deletions.
2 changes: 2 additions & 0 deletions sc2ts/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -1645,6 +1645,8 @@ def get_closest_mutation(node, site_id):


def extract_haplotypes(ts, samples):
# Annoyingly tskit doesn't allow us to specify duplicate samples, which can
# happen perfectly well here, so we must work around.
unique_samples = list(set(samples))
H = ts.genotype_matrix(samples=unique_samples, isolated_as_missing=False).T
ret = []
Expand Down
84 changes: 31 additions & 53 deletions tests/test_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,59 +50,6 @@ def recombinant_example_1(ts_map):
return ts, s


def recombinant_example_3_way_same_parent(recombinant_example_1):
    """
    Build a tree sequence and matching sample in which the sample's
    haplotype agrees with one parent at both ends and with a newly added
    child of that parent in the middle.

    A new node (time -1) is attached below the sample whose strain is
    "recombinant_example_1_0". That child is given "A" mutations at site
    IDs 11000..11049, plus four extra "A" mutations immediately on each
    side of that window. The returned sample haplotype is the parent's
    haplotype with allele index 0 ("A" under the "ACGT-" encoding) written
    over the 11000..11049 window only; the flanking sites keep the parent's
    alleles.

    Returns a ``(tree_sequence, sample)`` tuple.
    """
    base_ts = recombinant_example_1
    strain = "recombinant_example_1_0"
    strain_index = base_ts.metadata["sc2ts"]["samples_strain"].index(strain)
    parent = base_ts.samples()[strain_index]
    genotypes = base_ts.genotype_matrix(samples=[parent], alleles=tuple("ACGT-"))
    h = genotypes.T[0]

    tables = base_ts.dump_tables()
    node_time = -1
    child = tables.nodes.add_row(time=node_time)
    tables.edges.add_row(0, base_ts.sequence_length, parent=parent, child=child)

    def add_child_mutation(site):
        # Every mutation added here sits on the new child node and derives "A".
        tables.mutations.add_row(
            site=site, derived_state="A", node=child, time=node_time
        )

    start, stop = 11_000, 11_050
    # Central window: the child mutates to "A" and the sample haplotype follows.
    for site in range(start, stop):
        add_child_mutation(site)
        h[site] = 0

    # Flanking mutations on the child only (the haplotype is left untouched),
    # to force switching back to the same parent on either side of the window.
    for offset in range(1, 5):
        add_child_mutation(start - offset)
        add_child_mutation(stop + offset)

    tables.sort()
    sample = sc2ts.Sample("frankentype", "2020-02-14", haplotype=h.astype(np.int8))
    return tables.tree_sequence(), sample


def tmp_metadata_db(tmp_path, strains, date):
    """
    Create a throwaway metadata database under ``tmp_path``.

    One record per entry in ``strains`` is written (each with the same
    ``date``) to ``tmp_path / "metadata.csv"``; that CSV is then imported
    into ``tmp_path / "metadata.db"`` and an ``sc2ts.MetadataDb`` opened on
    the resulting database file is returned.
    """
    records = [{"strain": strain, "date": date} for strain in strains]
    csv_path = tmp_path / "metadata.csv"
    pd.DataFrame(records).to_csv(csv_path)

    db_path = tmp_path / "metadata.db"
    sc2ts.MetadataDb.import_csv(csv_path, db_path, sep=",")
    return sc2ts.MetadataDb(db_path)


def test_get_group_strains(fx_ts_map):
ts = fx_ts_map["2020-02-13"]
groups = sc2ts.get_group_strains(ts)
Expand Down Expand Up @@ -1379,3 +1326,34 @@ def test_example_3_way_same_parent(self, fx_recombinant_example_3):
m = s.hmm_match
assert m.parents == [55, 54, 55]
assert m.breakpoints == [0, 15001, 29825, 29904]


class TestExtractHaplotypes:
    """
    Checks for ``sc2ts.extract_haplotypes`` covering single, reordered and
    duplicated sample lists.
    """

    @pytest.mark.parametrize(
        ("samples", "result"),
        [
            ([0], [[0]]),
            ([0, 1], [[0], [0]]),
            ([0, 3], [[0], [1]]),
            ([3, 0], [[1], [0]]),
            ([0, 1, 2, 3], [[0], [0], [0], [1]]),
            ([3, 1, 2, 3], [[1], [0], [0], [1]]),
            ([3, 3, 3, 3], [[1], [1], [1], [1]]),
        ],
    )
    def test_one_leaf_mutation(self, samples, result):
        """
        A four-leaf comb tree with one site whose only mutation is on leaf 3:
        each requested sample must yield one row, in the order requested.
        """
        # 3.00┊ 6       ┊
        #     ┊ ┏━┻━┓   ┊
        # 2.00┊ ┃   5   ┊
        #     ┊ ┃ ┏━┻┓  ┊
        # 1.00┊ ┃ ┃  4  ┊
        #     ┊ ┃ ┃ ┏┻┓ ┊
        # 0.00┊ 0 1 2 3x┊
        #     0         1
        comb_tables = tskit.Tree.generate_comb(4).tree_sequence.dump_tables()
        # Single site at position 0 with ancestral state "A"; leaf 3 carries "T".
        comb_tables.sites.add_row(0, "A")
        comb_tables.mutations.add_row(site=0, node=3, derived_state="T")
        mutated_ts = comb_tables.tree_sequence()

        haplotypes = sc2ts.extract_haplotypes(mutated_ts, samples)
        nt.assert_array_equal(haplotypes, result)

0 comments on commit 0a12df0

Please sign in to comment.