diff --git a/c/tskit/tables.c b/c/tskit/tables.c index c5c414c6c2..41f2f2f7bb 100644 --- a/c/tskit/tables.c +++ b/c/tskit/tables.c @@ -12338,45 +12338,61 @@ tsk_table_collection_add_and_remap_node(tsk_table_collection_t *self, return ret; } +typedef struct _edge_list_t { + tsk_id_t edge; + struct _edge_list_t *next; +} edge_list_t; + int forward_extend(tsk_table_collection_t *self) { int ret = 0; double *new_left, *new_right; + tsk_id_t *num_edges; tsk_id_t tj, tk; const tsk_id_t *I, *O; const tsk_id_t M = (tsk_id_t) edges.num_rows; tsk_id_t *parent = NULL; + float left, right; + edge_list_t *pending_in, *pending_out; + int num_in, num_out; + num_edges = tsk_malloc(self->nodes.num_rows * sizeof(*num_edges)); new_left = tsk_malloc(self->edges.num_rows * sizeof(*new_left)); new_right = tsk_malloc(self->edges.num_rows * sizeof(*new_right)); - if (new_left == NULL || new_right == NULL) { + if (num_edges == NULL || new_left == NULL || new_right == NULL) { ret = TSK_ERR_NO_MEMORY; goto out; } + tsk_memset(num_edges, 0x00, self->nodes.num_rows * sizeof(*num_edges)); memcpy(new_left, self->edges.left, self->edges.num_rows * sizeof(*new_left)); memcpy(new_right, self->edges.right, self->edges.num_rows * sizeof(*new_right)); - parent = tsk_malloc(nodes.num_rows * sizeof(*parent)); - if (parent == NULL) { - ret = TSK_ERR_NO_MEMORY; - goto out; - } - tsk_memset(parent, 0xff, nodes.num_rows * sizeof(*parent)); - I = self->indexes.edge_insertion_order; O = self->indexes.edge_removal_order; - tj = 0; - tk = 0; + tj = 0; // current position in I + tk = 0; // current position in O left = 0; + num_in = 0; + num_out = 0; while (tj < M || left < self->sequence_length) { - while (tk < M && edges.right[O[tk]] == left) { - parent[edges.child[O[tk]]] = TSK_NULL; + while (tk < M && self->edges.right[O[tk]] == left) { + // parent[self->edges.child[O[tk]]] = TSK_NULL; + // add edge tk to pending_out + new_out *edge_list_t; // TODO MAKE NEW ONE + new_out.edge = tk; + pending_out[num_out - 1].next = new_out; + num_out++; tk++; + num_edges[self->edges.parent[O[tk]]] -= 1; + num_edges[self->edges.parent[O[tk]]] -= 1; } - while (tj < M && edges.left[I[tj]] == left) { - parent[edges.child[I[tj]]] = edges.parent[I[tj]]; + while (tj < M && self->edges.left[I[tj]] == left) { + // parent[self->edges.child[I[tj]]] = self->edges.parent[I[tj]]; + // add edge tj to pending_in tj++; + num_edges[self->edges.parent[I[tk]]] += 1; + num_edges[self->edges.parent[I[tk]]] += 1; } right = self->sequence_length; if (tj < M) { diff --git a/python/tskit/trees.py b/python/tskit/trees.py index 84fb48ec08..80a3028da3 100644 --- a/python/tskit/trees.py +++ b/python/tskit/trees.py @@ -6856,17 +6856,13 @@ def _forward_extend(self, forwards=True): assert np.all(num_edges >= 0) pending_out = [] pending_in = [] - extended = [False for _ in edges_out] - # remove those edges we've removed entirely from consideration - for e1 in pending_out: - for e2 in pending_in: - if e1.parent == e2.parent and e1.child == e2.child: - pending_out.remove(e1) - pending_in.remove(e2) - for j1, e1 in enumerate(edges_out): - if not extended[j1]: - for j2, e2 in enumerate(edges_out): - if not extended[j2]: + # keep track of which edges we start extending + # so that we don't try to extend an edge twice + extended = np.repeat(False, self.num_edges) + for e1 in edges_out: + if not extended[e1.id]: + for e2 in edges_out: + if not extended[e2.id]: # need the intermediate node to not be present in # the new tree if (e1.parent == e2.child) and (num_edges[e2.child] == 0): @@ -6877,8 +6873,8 @@ def _forward_extend(self, forwards=True): and e2.parent == e_in.parent ): # extend e2->e1 and postpone e_in - extended[j1] = True - extended[j2] = True + extended[e1.id] = True + extended[e2.id] = True pending_out.extend([e1, e2]) pending_in.append(e_in) new_right[e1.id] = interval.right