Skip to content

Commit

Permalink
Merge pull request #452 from jeromekelleher/3-way-recomb-same-parent
Browse files Browse the repository at this point in the history
Support matching back to same parent in recombinant
  • Loading branch information
jeromekelleher authored Dec 16, 2024
2 parents 0a43c33 + 0a12df0 commit 562c467
Show file tree
Hide file tree
Showing 2 changed files with 75 additions and 14 deletions.
14 changes: 12 additions & 2 deletions sc2ts/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -1644,6 +1644,17 @@ def get_closest_mutation(node, site_id):
logger.debug(f"Characterised {num_mutations} mutations")


def extract_haplotypes(ts, samples):
# Annoyingly tskit doesn't allow us to specify duplicate samples, which can
# happen perfectly well here, so we must work around.
unique_samples = list(set(samples))
H = ts.genotype_matrix(samples=unique_samples, isolated_as_missing=False).T
ret = []
for node_id in samples:
ret.append(H[unique_samples.index(node_id)])
return ret


def characterise_recombinants(ts, samples):
"""
Update the metadata for any recombinants to add interval information to the metadata.
Expand All @@ -1657,8 +1668,7 @@ def characterise_recombinants(ts, samples):
# but recombinants are rare so let's keep this simple
for s in recombinants:
parents = [seg.parent for seg in s.hmm_match.path]
# Can't have missing data here, so we're OK.
H = ts.genotype_matrix(samples=parents, isolated_as_missing=False).T
H = extract_haplotypes(ts, parents)
breakpoint_intervals = []
for j in range(len(parents) - 1):
parents_differ = np.where(H[j] != H[j + 1])[0]
Expand Down
75 changes: 63 additions & 12 deletions tests/test_inference.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import dataclasses
import collections
import hashlib
import logging
Expand Down Expand Up @@ -49,18 +50,6 @@ def recombinant_example_1(ts_map):
return ts, s


def tmp_metadata_db(tmp_path, strains, date):
data = []
for strain in strains:
data.append({"strain": strain, "date": date})
df = pd.DataFrame(data)
csv_path = tmp_path / "metadata.csv"
df.to_csv(csv_path)
db_path = tmp_path / "metadata.db"
sc2ts.MetadataDb.import_csv(csv_path, db_path, sep=",")
return sc2ts.MetadataDb(db_path)


def test_get_group_strains(fx_ts_map):
ts = fx_ts_map["2020-02-13"]
groups = sc2ts.get_group_strains(ts)
Expand Down Expand Up @@ -1306,3 +1295,65 @@ def test_example_3(self, fx_recombinant_example_3):
m = s.hmm_match
assert m.parents == [53, 54, 55]
assert m.breakpoints == [0, 114, 15010, 29904]

def test_example_3_way_same_parent(self, fx_recombinant_example_3):
ts = fx_recombinant_example_3
strains = ts.metadata["sc2ts"]["samples_strain"]
assert strains[-1].startswith("recomb")
u = ts.samples()[-1]
h = ts.genotype_matrix(samples=[u], alleles=tuple(sc2ts.IUPAC_ALLELES)).T[0]
tables = ts.dump_tables()
keep_edges = ts.edges_child < u
tables.edges.keep_rows(keep_edges)
keep_nodes = np.ones(ts.num_nodes, dtype=bool)
tables.nodes[u] = tables.nodes[u].replace(flags=0)
tables.sort()
base_ts = tables.tree_sequence()

s = sc2ts.Sample("3way", "2020-02-14", haplotype=h.astype(np.int8))
sc2ts.match_tsinfer(
samples=[s],
ts=base_ts,
num_mismatches=2,
mismatch_threshold=10,
mirror_coordinates=False,
)
# Force back to the same parent so we can check that we're robust to
# same parent
s.hmm_match.path[0] = dataclasses.replace(s.hmm_match.path[0], parent=55)
sc2ts.characterise_recombinants(ts, [s])

m = s.hmm_match
assert m.parents == [55, 54, 55]
assert m.breakpoints == [0, 15001, 29825, 29904]


class TestExtractHaplotypes:

@pytest.mark.parametrize(
["samples", "result"],
[
([0], [[0]]),
([0, 1], [[0], [0]]),
([0, 3], [[0], [1]]),
([3, 0], [[1], [0]]),
([0, 1, 2, 3], [[0], [0], [0], [1]]),
([3, 1, 2, 3], [[1], [0], [0], [1]]),
([3, 3, 3, 3], [[1], [1], [1], [1]]),
],
)
def test_one_leaf_mutation(self, samples, result):
# 3.00┊ 6 ┊
# ┊ ┏━┻━┓ ┊
# 2.00┊ ┃ 5 ┊
# ┊ ┃ ┏━┻┓ ┊
# 1.00┊ ┃ ┃ 4 ┊
# ┊ ┃ ┃ ┏┻┓ ┊
# 0.00┊ 0 1 2 3x┊
# 0 1
ts = tskit.Tree.generate_comb(4).tree_sequence
tables = ts.dump_tables()
tables.sites.add_row(0, "A")
tables.mutations.add_row(site=0, node=3, derived_state="T")
ts = tables.tree_sequence()
nt.assert_array_equal(sc2ts.extract_haplotypes(ts, samples), result)

0 comments on commit 562c467

Please sign in to comment.