Skip to content

Commit

Permalink
maybe works!
Browse files Browse the repository at this point in the history
  • Loading branch information
petrelharp committed Sep 12, 2023
1 parent 69ed0e3 commit 111e105
Show file tree
Hide file tree
Showing 3 changed files with 118 additions and 14 deletions.
37 changes: 29 additions & 8 deletions c/tests/test_trees.c
Original file line number Diff line number Diff line change
Expand Up @@ -8217,13 +8217,17 @@ test_extend_edges_simple(void)
{
int ret;
tsk_treeseq_t ts, ets;
const char *nodes_ex = "1 0 -1 -1\n"
"1 0 -1 -1\n"
"0 2.0 -1 -1\n";
const char *edges_ex = "0 10 2 0\n"
"0 10 2 1\n";
const char *nodes = "1 0 -1 -1\n"
"1 0 -1 -1\n"
"0 2.0 -1 -1\n";
const char *edges = "0 10 2 0\n"
"0 10 2 1\n";
const char *sites = "0.0 0\n"
"1.0 0\n";
const char *mutations = "0 0 1 -1 0.5\n"
"1 1 1 -1 0.5\n";

tsk_treeseq_from_text(&ts, 10, nodes_ex, edges_ex, NULL, NULL, NULL, NULL, NULL, 0);
tsk_treeseq_from_text(&ts, 10, nodes, edges, NULL, sites, mutations, NULL, NULL, 0);

ret = tsk_treeseq_extend_edges(&ts, 10, 0, &ets);
CU_ASSERT_EQUAL_FATAL(ret, 0);
Expand All @@ -8234,7 +8238,7 @@ test_extend_edges_simple(void)
}

static void
assert_equal_except_edges(const tsk_treeseq_t *ts1, const tsk_treeseq_t *ts2)
assert_equal_except_edges_and_mutation_nodes(const tsk_treeseq_t *ts1, const tsk_treeseq_t *ts2)
{
tsk_table_collection_t t1, t2;
int ret;
Expand All @@ -8245,8 +8249,15 @@ assert_equal_except_edges(const tsk_treeseq_t *ts1, const tsk_treeseq_t *ts2)
ret = tsk_table_collection_copy(ts2->tables, &t2, 0);
CU_ASSERT_EQUAL_FATAL(ret, 0);

memset(t1.mutations.node, 0, t1.mutations.num_rows * sizeof(*t1.mutations.node));
memset(t2.mutations.node, 0, t2.mutations.num_rows * sizeof(*t2.mutations.node));

CU_ASSERT_TRUE(tsk_mutation_table_equals(&t1.mutations, &t2.mutations, 0));

tsk_edge_table_clear(&t1.edges);
tsk_edge_table_clear(&t2.edges);
tsk_mutation_table_clear(&t1.mutations);
tsk_mutation_table_clear(&t2.mutations);

CU_ASSERT_TRUE(tsk_table_collection_equals(&t1, &t2, 0));

Expand Down Expand Up @@ -8299,6 +8310,12 @@ test_extend_edges(void)
"2 5 7 4\n"
"5 7 8 1\n"
"5 7 8 5\n";
const char *sites_ex = "0.0 0\n"
"9.0 0\n";
const char *mutations_ex = "0 4 1 -1 2.5\n"
"0 4 2 0 1.5\n"
"1 5 1 -1 2.5\n"
"1 5 2 2 1.5\n";

/* Doing this rather than tsk_treeseq_from_text because the edges are unsorted */
ret = tsk_table_collection_init(&tables, 0);
Expand All @@ -8308,6 +8325,10 @@ test_extend_edges(void)
CU_ASSERT_EQUAL_FATAL(tables.nodes.num_rows, 9);
parse_edges(edges_ex, &tables.edges);
CU_ASSERT_EQUAL_FATAL(tables.edges.num_rows, 18);
parse_sites(sites_ex, &tables.sites);
CU_ASSERT_EQUAL_FATAL(tables.sites.num_rows, 2);
parse_mutations(mutations_ex, &tables.mutations);
CU_ASSERT_EQUAL_FATAL(tables.mutations.num_rows, 4);
ret = tsk_table_collection_sort(&tables, NULL, 0);
CU_ASSERT_EQUAL_FATAL(ret, 0);
ret = tsk_treeseq_init(&ts, &tables, TSK_TS_INIT_BUILD_INDEXES);
Expand All @@ -8322,7 +8343,7 @@ test_extend_edges(void)
for (max_iter = 1; max_iter < 10; max_iter++) {
ret = tsk_treeseq_extend_edges(&ts, max_iter, 0, &ets);
CU_ASSERT_EQUAL_FATAL(ret, 0);
assert_equal_except_edges(&ts, &ets);
assert_equal_except_edges_and_mutation_nodes(&ts, &ets);
CU_ASSERT_TRUE(ets.tables->edges.num_rows >= 12);
tsk_treeseq_free(&ets);
}
Expand Down
86 changes: 86 additions & 0 deletions c/tskit/trees.c
Original file line number Diff line number Diff line change
Expand Up @@ -7053,6 +7053,73 @@ tsk_treeseq_extend_edges_iter(
return ret;
}


static int
tsk_treeseq_slide_mutation_nodes_up(const tsk_treeseq_t *self, tsk_mutation_table_t *mutations)
{
int ret = 0;
bool valid;
double t, right;
tsk_id_t c, p, tj, e, next_mut;
tsk_tree_position_t tree_pos;
const tsk_table_collection_t *tables = self->tables;
const tsk_size_t num_nodes = tables->nodes.num_rows;
const tsk_edge_table_t *edges = &tables->edges;
tsk_id_t *parent = tsk_malloc(num_nodes * sizeof(*parent));
double *sites_position = tables->sites.position;
double *nodes_time = tables->nodes.time;

memset(&tree_pos, 0, sizeof(tree_pos));

if (parent == NULL) {
ret = TSK_ERR_NO_MEMORY;
goto out;
}
memset(parent, 0xff, num_nodes * sizeof(*parent));

ret = tsk_tree_position_init(&tree_pos, self, 0);
if (ret != 0) {
goto out;
}

valid = tsk_tree_position_next(&tree_pos);
next_mut = 0;

while (valid) {
right = tree_pos.interval.right;

for (tj = tree_pos.out.start; tj != tree_pos.out.stop; tj += TSK_DIR_FORWARD) {
e = tree_pos.out.order[tj];
parent[edges->child[e]] = TSK_NULL;
}
for (tj = tree_pos.in.start; tj != tree_pos.in.stop; tj += TSK_DIR_FORWARD) {
e = tree_pos.in.order[tj];
parent[edges->child[e]] = edges->parent[e];
}
while (next_mut < (tsk_id_t) mutations->num_rows &&
sites_position[mutations->site[next_mut]] < right) {
t = mutations->time[next_mut];
c = mutations->node[next_mut];
p = parent[c];
while (p != TSK_NULL && nodes_time[p] <= t) {
c = p;
p = parent[c];
}
tsk_bug_assert(nodes_time[c] <= t);
mutations->node[next_mut] = c;
next_mut++;
}

valid = tsk_tree_position_next(&tree_pos);
}

out:
tsk_safe_free(parent);

return ret;
}


int TSK_WARN_UNUSED
tsk_treeseq_extend_edges(const tsk_treeseq_t *self, int max_iter,
tsk_flags_t TSK_UNUSED(options), tsk_treeseq_t *output)
Expand Down Expand Up @@ -7080,6 +7147,10 @@ tsk_treeseq_extend_edges(const tsk_treeseq_t *self, int max_iter,
if (ret != 0) {
goto out;
}
ret = tsk_mutation_table_clear(&tables.mutations);
if (ret != 0) {
goto out;
}
ret = tsk_treeseq_init(&ts, &tables, 0);
if (ret != 0) {
goto out;
Expand All @@ -7105,6 +7176,21 @@ tsk_treeseq_extend_edges(const tsk_treeseq_t *self, int max_iter,
last_num_edges = tsk_treeseq_get_num_edges(&ts);
}

/* Remap mutation nodes */
ret = tsk_mutation_table_copy(&self->tables->mutations, &tables.mutations, TSK_NO_INIT);
if (ret != 0) {
goto out;
}
ret = tsk_treeseq_slide_mutation_nodes_up(&ts, &tables.mutations);
if (ret != 0) {
goto out;
}
tsk_treeseq_free(&ts);
ret = tsk_treeseq_init(&ts, &tables, TSK_TS_INIT_BUILD_INDEXES);
if (ret != 0) {
goto out;
}

/* Hand ownership of the tree sequence to the calling code */
tsk_memcpy(output, &ts, sizeof(ts));
tsk_memset(&ts, 0, sizeof(*output));
Expand Down
9 changes: 3 additions & 6 deletions python/tests/test_extend_edges.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,15 +29,15 @@ def extend_edges(ts, max_iter=10):
last_num_edges = ts.num_edges

tables = ts.dump_tables()
mutations = _fix_mutation_nodes(ts, mutations)
mutations = _slide_mutation_nodes_up(ts, mutations)
tables.mutations.replace_with(mutations)
tables.build_index()
ts = tables.tree_sequence()

return ts


def _fix_mutation_nodes(ts, mutations):
def _slide_mutation_nodes_up(ts, mutations):
# adjusts mutations' nodes to place each mutation on the correct edge given
# their time; requires mutation times be nonmissing and the mutation times
# be >= their nodes' times.
Expand Down Expand Up @@ -70,7 +70,7 @@ def _fix_mutation_nodes(ts, mutations):
while p != -1 and ts.nodes_time[p] <= t:
c = p
p = parent[c]
assert t >= ts.nodes_time[c]
assert ts.nodes_time[c] <= t
if p != -1:
assert t < ts.nodes_time[p]
new_nodes[next_mut] = c
Expand All @@ -88,9 +88,6 @@ def _fix_mutation_nodes(ts, mutations):
for mut, n in zip(mutations, new_nodes):
new_mutations.append(mut.replace(node=n))

print("before", mutations)
print("after", new_mutations)

return new_mutations


Expand Down

0 comments on commit 111e105

Please sign in to comment.