diff --git a/c/tests/test_trees.c b/c/tests/test_trees.c
index 8c73d3ffb4..155973c75e 100644
--- a/c/tests/test_trees.c
+++ b/c/tests/test_trees.c
@@ -8217,13 +8217,17 @@ test_extend_edges_simple(void)
 {
     int ret;
     tsk_treeseq_t ts, ets;
-    const char *nodes_ex = "1 0 -1 -1\n"
-                           "1 0 -1 -1\n"
-                           "0 2.0 -1 -1\n";
-    const char *edges_ex = "0 10 2 0\n"
-                           "0 10 2 1\n";
+    const char *nodes = "1 0 -1 -1\n"
+                        "1 0 -1 -1\n"
+                        "0 2.0 -1 -1\n";
+    const char *edges = "0 10 2 0\n"
+                        "0 10 2 1\n";
+    const char *sites = "0.0 0\n"
+                        "1.0 0\n";
+    const char *mutations = "0 0 1 -1 0.5\n"
+                            "1 1 1 -1 0.5\n";
 
-    tsk_treeseq_from_text(&ts, 10, nodes_ex, edges_ex, NULL, NULL, NULL, NULL, NULL, 0);
+    tsk_treeseq_from_text(&ts, 10, nodes, edges, NULL, sites, mutations, NULL, NULL, 0);
 
     ret = tsk_treeseq_extend_edges(&ts, 10, 0, &ets);
     CU_ASSERT_EQUAL_FATAL(ret, 0);
@@ -8234,7 +8238,7 @@ test_extend_edges_simple(void)
 }
 
 static void
-assert_equal_except_edges(const tsk_treeseq_t *ts1, const tsk_treeseq_t *ts2)
+assert_equal_except_edges_and_mutation_nodes(const tsk_treeseq_t *ts1, const tsk_treeseq_t *ts2)
 {
     tsk_table_collection_t t1, t2;
     int ret;
@@ -8245,8 +8249,17 @@ assert_equal_except_edges(const tsk_treeseq_t *ts1, const tsk_treeseq_t *ts2)
     ret = tsk_table_collection_copy(ts2->tables, &t2, 0);
     CU_ASSERT_EQUAL_FATAL(ret, 0);
 
+    /* Mutation nodes are allowed to differ: zero the node column in both
+     * copies so that every other mutation column is still compared */
+    memset(t1.mutations.node, 0, t1.mutations.num_rows * sizeof(*t1.mutations.node));
+    memset(t2.mutations.node, 0, t2.mutations.num_rows * sizeof(*t2.mutations.node));
+
+    CU_ASSERT_TRUE(tsk_mutation_table_equals(&t1.mutations, &t2.mutations, 0));
+
     tsk_edge_table_clear(&t1.edges);
     tsk_edge_table_clear(&t2.edges);
+    tsk_mutation_table_clear(&t1.mutations);
+    tsk_mutation_table_clear(&t2.mutations);
 
     CU_ASSERT_TRUE(tsk_table_collection_equals(&t1, &t2, 0));
 
@@ -8299,6 +8312,12 @@ test_extend_edges(void)
                            "2 5 7 4\n"
                            "5 7 8 1\n"
                            "5 7 8 5\n";
+    const char *sites_ex = "0.0 0\n"
+                           "9.0 0\n";
+    const char *mutations_ex = "0 4 1 -1 2.5\n"
+                               "0 4 2 0 1.5\n"
+                               "1 5 1 -1 2.5\n"
+                               "1 5 2 2 1.5\n";
 
     /* Doing this rather than tsk_treeseq_from_text because the edges are unsorted */
     ret = tsk_table_collection_init(&tables, 0);
@@ -8308,6 +8327,10 @@ test_extend_edges(void)
     CU_ASSERT_EQUAL_FATAL(tables.nodes.num_rows, 9);
     parse_edges(edges_ex, &tables.edges);
     CU_ASSERT_EQUAL_FATAL(tables.edges.num_rows, 18);
+    parse_sites(sites_ex, &tables.sites);
+    CU_ASSERT_EQUAL_FATAL(tables.sites.num_rows, 2);
+    parse_mutations(mutations_ex, &tables.mutations);
+    CU_ASSERT_EQUAL_FATAL(tables.mutations.num_rows, 4);
     ret = tsk_table_collection_sort(&tables, NULL, 0);
     CU_ASSERT_EQUAL_FATAL(ret, 0);
     ret = tsk_treeseq_init(&ts, &tables, TSK_TS_INIT_BUILD_INDEXES);
@@ -8322,7 +8345,7 @@ test_extend_edges(void)
     for (max_iter = 1; max_iter < 10; max_iter++) {
         ret = tsk_treeseq_extend_edges(&ts, max_iter, 0, &ets);
         CU_ASSERT_EQUAL_FATAL(ret, 0);
-        assert_equal_except_edges(&ts, &ets);
+        assert_equal_except_edges_and_mutation_nodes(&ts, &ets);
         CU_ASSERT_TRUE(ets.tables->edges.num_rows >= 12);
         tsk_treeseq_free(&ets);
     }
diff --git a/c/tskit/trees.c b/c/tskit/trees.c
index 54c21d4d44..bb7002386f 100644
--- a/c/tskit/trees.c
+++ b/c/tskit/trees.c
@@ -7053,6 +7053,95 @@ tsk_treeseq_extend_edges_iter(
     return ret;
 }
 
+
+/* Slides each mutation up its lineage to the oldest node that is not older
+ * than the mutation. For every mutation, start at its current node in the
+ * tree covering its site and move to the parent while the parent's time is
+ * <= the mutation's time. Only the node column of `mutations` is written;
+ * trees, node times and site positions are read from `self`. The single
+ * left-to-right pass relies on `mutations` being ordered by site position.
+ *
+ * Returns 0 on success, TSK_ERR_NO_MEMORY if the parent array cannot be
+ * allocated, or TSK_ERR_BAD_PARAM_VALUE if a mutation's time is unknown.
+ */
+static int
+tsk_treeseq_slide_mutation_nodes_up(const tsk_treeseq_t *self, tsk_mutation_table_t *mutations)
+{
+    int ret = 0;
+    bool valid;
+    double t, right;
+    tsk_id_t c, p, tj, e, next_mut;
+    tsk_tree_position_t tree_pos;
+    const tsk_table_collection_t *tables = self->tables;
+    const tsk_size_t num_nodes = tables->nodes.num_rows;
+    const tsk_edge_table_t *edges = &tables->edges;
+    tsk_id_t *parent = tsk_malloc(num_nodes * sizeof(*parent));
+    const double *sites_position = tables->sites.position;
+    const double *nodes_time = tables->nodes.time;
+
+    /* Zero tree_pos first so the cleanup at `out` is safe on every path */
+    memset(&tree_pos, 0, sizeof(tree_pos));
+
+    if (parent == NULL) {
+        ret = TSK_ERR_NO_MEMORY;
+        goto out;
+    }
+    /* All-ones bytes make every entry -1; the walk below relies on that
+     * matching TSK_NULL (no parent) */
+    memset(parent, 0xff, num_nodes * sizeof(*parent));
+
+    ret = tsk_tree_position_init(&tree_pos, self, 0);
+    if (ret != 0) {
+        goto out;
+    }
+
+    valid = tsk_tree_position_next(&tree_pos);
+    next_mut = 0;
+
+    while (valid) {
+        right = tree_pos.interval.right;
+
+        /* Apply the edges leaving and entering to get this tree's parents */
+        for (tj = tree_pos.out.start; tj != tree_pos.out.stop; tj += TSK_DIR_FORWARD) {
+            e = tree_pos.out.order[tj];
+            parent[edges->child[e]] = TSK_NULL;
+        }
+        for (tj = tree_pos.in.start; tj != tree_pos.in.stop; tj += TSK_DIR_FORWARD) {
+            e = tree_pos.in.order[tj];
+            parent[edges->child[e]] = edges->parent[e];
+        }
+        /* Process the mutations at sites left of this tree's right end */
+        while (next_mut < (tsk_id_t) mutations->num_rows &&
+               sites_position[mutations->site[next_mut]] < right) {
+            t = mutations->time[next_mut];
+            if (tsk_is_unknown_time(t)) {
+                /* Comparisons against an unknown (NaN) time are always false,
+                 * so report it rather than tripping the assertion below */
+                ret = TSK_ERR_BAD_PARAM_VALUE;
+                goto out;
+            }
+            c = mutations->node[next_mut];
+            p = parent[c];
+            while (p != TSK_NULL && nodes_time[p] <= t) {
+                c = p;
+                p = parent[c];
+            }
+            tsk_bug_assert(nodes_time[c] <= t);
+            mutations->node[next_mut] = c;
+            next_mut++;
+        }
+
+        valid = tsk_tree_position_next(&tree_pos);
+    }
+
+out:
+    tsk_tree_position_free(&tree_pos);
+    tsk_safe_free(parent);
+
+    return ret;
+}
+
+
 int TSK_WARN_UNUSED
 tsk_treeseq_extend_edges(const tsk_treeseq_t *self, int max_iter,
     tsk_flags_t TSK_UNUSED(options), tsk_treeseq_t *output)
@@ -7080,6 +7169,12 @@ tsk_treeseq_extend_edges(const tsk_treeseq_t *self, int max_iter,
     if (ret != 0) {
         goto out;
     }
+    /* Drop the mutations while edges are extended; they are copied back from
+     * the input and remapped once extension has finished */
+    ret = tsk_mutation_table_clear(&tables.mutations);
+    if (ret != 0) {
+        goto out;
+    }
     ret = tsk_treeseq_init(&ts, &tables, 0);
     if (ret != 0) {
         goto out;
@@ -7105,6 +7200,21 @@ tsk_treeseq_extend_edges(const tsk_treeseq_t *self, int max_iter,
         last_num_edges = tsk_treeseq_get_num_edges(&ts);
     }
 
+    /* Remap mutation nodes */
+    ret = tsk_mutation_table_copy(&self->tables->mutations, &tables.mutations, TSK_NO_INIT);
+    if (ret != 0) {
+        goto out;
+    }
+    ret = tsk_treeseq_slide_mutation_nodes_up(&ts, &tables.mutations);
+    if (ret != 0) {
+        goto out;
+    }
+    tsk_treeseq_free(&ts);
+    ret = tsk_treeseq_init(&ts, &tables, TSK_TS_INIT_BUILD_INDEXES);
+    if (ret != 0) {
+        goto out;
+    }
+
     /* Hand ownership of the tree sequence to the calling code */
     tsk_memcpy(output, &ts, sizeof(ts));
     tsk_memset(&ts, 0, sizeof(*output));
diff --git a/python/tests/test_extend_edges.py b/python/tests/test_extend_edges.py
index 44660e3ece..94ffae0dc4 100644
--- a/python/tests/test_extend_edges.py
+++ b/python/tests/test_extend_edges.py
@@ -29,7 +29,7 @@ def extend_edges(ts, max_iter=10):
             last_num_edges = ts.num_edges
 
     tables = ts.dump_tables()
-    mutations = _fix_mutation_nodes(ts, mutations)
+    mutations = _slide_mutation_nodes_up(ts, mutations)
     tables.mutations.replace_with(mutations)
     tables.build_index()
     ts = tables.tree_sequence()
@@ -37,7 +37,7 @@ def extend_edges(ts, max_iter=10):
     return ts
 
 
-def _fix_mutation_nodes(ts, mutations):
+def _slide_mutation_nodes_up(ts, mutations):
     # adjusts mutations' nodes to place each mutation on the correct edge given
     # their time; requires mutation times be nonmissing and the mutation times
     # be >= their nodes' times.
@@ -70,7 +70,7 @@ def _fix_mutation_nodes(ts, mutations):
             while p != -1 and ts.nodes_time[p] <= t:
                 c = p
                 p = parent[c]
-            assert t >= ts.nodes_time[c]
+            assert ts.nodes_time[c] <= t
             if p != -1:
                 assert t < ts.nodes_time[p]
             new_nodes[next_mut] = c
@@ -88,7 +88,4 @@ def _fix_mutation_nodes(ts, mutations):
     for mut, n in zip(mutations, new_nodes):
         new_mutations.append(mut.replace(node=n))
 
-    print("before", mutations)
-    print("after", new_mutations)
-
     return new_mutations