From 0a12df0d89ed0d8c21248ca21645dc7444855005 Mon Sep 17 00:00:00 2001 From: Jerome Kelleher Date: Mon, 16 Dec 2024 23:15:14 +0000 Subject: [PATCH] Fixup tests --- sc2ts/inference.py | 2 + tests/test_inference.py | 84 +++++++++++++++-------------------------- 2 files changed, 33 insertions(+), 53 deletions(-) diff --git a/sc2ts/inference.py b/sc2ts/inference.py index 866c737..61e65f0 100644 --- a/sc2ts/inference.py +++ b/sc2ts/inference.py @@ -1645,6 +1645,8 @@ def get_closest_mutation(node, site_id): def extract_haplotypes(ts, samples): + # Annoyingly tskit doesn't allow us to specify duplicate samples, which can + # happen perfectly well here, so we must work around. unique_samples = list(set(samples)) H = ts.genotype_matrix(samples=unique_samples, isolated_as_missing=False).T ret = [] diff --git a/tests/test_inference.py b/tests/test_inference.py index 60ca0ea..82b5ec1 100644 --- a/tests/test_inference.py +++ b/tests/test_inference.py @@ -50,59 +50,6 @@ def recombinant_example_1(ts_map): return ts, s -def recombinant_example_3_way_same_parent(recombinant_example_1): - """ - Example recombinant created by cherry picking two samples that differ - by mutations on either end of the genome, and smushing them together. - Note there's only two mutations needed, so we need to set num_mismatches=2 - """ - ts = recombinant_example_1 - s = "recombinant_example_1_0" - parent = ts.samples()[ts.metadata["sc2ts"]["samples_strain"].index(s)] - # ts = ts_map["2020-02-13"] - # parent = ts.samples()[ts. - # parent = 45 - # assert ts.node(parent).metadata["strain"] == "SRR11597163" - # add a child that has a bunch of mutations at the start and end - h = ts.genotype_matrix(samples=[parent], alleles=tuple("ACGT-")).T[0] - - tables = ts.dump_tables() - node_time = -1 - child = tables.nodes.add_row(time=node_time) - tables.edges.add_row(0, ts.sequence_length, parent=parent, child=child) - - start = 11_000 - stop = 11_050 - for k in range(start, stop): - tables.mutations.add_row(site=k, derived_state="A", node=child, time=node_time) - h[k] = 0 - - # Stick on a bunch of mutations either side to force switching back to same haplotype - for k in range(1, 5): - tables.mutations.add_row( - site=start - k, derived_state="A", node=child, time=node_time - ) - tables.mutations.add_row( - site=stop + k, derived_state="A", node=child, time=node_time - ) - - tables.sort() - s = sc2ts.Sample("frankentype", "2020-02-14", haplotype=h.astype(np.int8)) - return tables.tree_sequence(), s - - -def tmp_metadata_db(tmp_path, strains, date): - data = [] - for strain in strains: - data.append({"strain": strain, "date": date}) - df = pd.DataFrame(data) - csv_path = tmp_path / "metadata.csv" - df.to_csv(csv_path) - db_path = tmp_path / "metadata.db" - sc2ts.MetadataDb.import_csv(csv_path, db_path, sep=",") - return sc2ts.MetadataDb(db_path) - - def test_get_group_strains(fx_ts_map): ts = fx_ts_map["2020-02-13"] groups = sc2ts.get_group_strains(ts) @@ -1379,3 +1326,34 @@ def test_example_3_way_same_parent(self, fx_recombinant_example_3): m = s.hmm_match assert m.parents == [55, 54, 55] assert m.breakpoints == [0, 15001, 29825, 29904] + + +class TestExtractHaplotypes: + + @pytest.mark.parametrize( + ["samples", "result"], + [ + ([0], [[0]]), + ([0, 1], [[0], [0]]), + ([0, 3], [[0], [1]]), + ([3, 0], [[1], [0]]), + ([0, 1, 2, 3], [[0], [0], [0], [1]]), + ([3, 1, 2, 3], [[1], [0], [0], [1]]), + ([3, 3, 3, 3], [[1], [1], [1], [1]]), + ], + ) + def test_one_leaf_mutation(self, samples, result): + # 3.00┊ 6 ┊ + # ┊ ┏━┻━┓ ┊ + # 2.00┊ ┃ 5 ┊ + # ┊ ┃ ┏━┻┓ ┊ + # 1.00┊ ┃ ┃ 4 ┊ + # ┊ ┃ ┃ ┏┻┓ ┊ + # 0.00┊ 0 1 2 3x┊ + # 0 1 + ts = tskit.Tree.generate_comb(4).tree_sequence + tables = ts.dump_tables() + tables.sites.add_row(0, "A") + tables.mutations.add_row(site=0, node=3, derived_state="T") + ts = tables.tree_sequence() + nt.assert_array_equal(sc2ts.extract_haplotypes(ts, samples), result)