Skip to content

Commit

Permalink
remaps mutations, and passes tests!
Browse files Browse the repository at this point in the history
  • Loading branch information
petrelharp committed Sep 12, 2023
1 parent 111e105 commit 90139ee
Show file tree
Hide file tree
Showing 4 changed files with 62 additions and 38 deletions.
5 changes: 5 additions & 0 deletions c/tskit/core.c
Original file line number Diff line number Diff line change
Expand Up @@ -330,6 +330,11 @@ tsk_strerror_internal(int err)
"values for any single site. "
"(TSK_ERR_MUTATION_TIME_HAS_BOTH_KNOWN_AND_UNKNOWN)";
break;
case TSK_ERR_DISALLOWED_UNKNOWN_MUTATION_TIME:
ret = "Some mutation times are marked 'unknown' for a method that requires "
"no unknown times. (Use compute_mutation_times to add times?) "
"(TSK_ERR_DISALLOWED_UNKNOWN_MUTATION_TIME)";
break;

/* Migration errors */
case TSK_ERR_UNSORTED_MIGRATIONS:
Expand Down
5 changes: 5 additions & 0 deletions c/tskit/core.h
Original file line number Diff line number Diff line change
Expand Up @@ -500,6 +500,11 @@ the edge on which it occurs, and wasn't TSK_UNKNOWN_TIME.
A single site had a mixture of known mutation times and TSK_UNKNOWN_TIME
*/
#define TSK_ERR_MUTATION_TIME_HAS_BOTH_KNOWN_AND_UNKNOWN -509
/**
Some mutations have TSK_UNKNOWN_TIME in an algorithm where that's
disallowed (use compute_mutation_times?).
*/
#define TSK_ERR_DISALLOWED_UNKNOWN_MUTATION_TIME -510
/** @} */

/**
Expand Down
20 changes: 12 additions & 8 deletions c/tskit/trees.c
Original file line number Diff line number Diff line change
Expand Up @@ -6956,7 +6956,6 @@ tsk_treeseq_extend_edges_iter(
out_parent[edges->child[e]] = edges->parent[e];
}


for (tj = tree_pos.out.start; tj != tree_pos.out.stop; tj += direction) {
e = tree_pos.out.order[tj];
if (out_parent[edges->child[e]] == TSK_NULL) {
Expand Down Expand Up @@ -7014,7 +7013,8 @@ tsk_treeseq_extend_edges_iter(
}
near_side[e_in] = there;
while (c != p) {
for (ex_out = edges_out_head; ex_out != NULL; ex_out = ex_out->next) {
for (ex_out = edges_out_head; ex_out != NULL;
ex_out = ex_out->next) {
e_out = ex_out->edge;
if (edges->child[e_out] == c) {
break;
Expand Down Expand Up @@ -7053,9 +7053,9 @@ tsk_treeseq_extend_edges_iter(
return ret;
}


static int
tsk_treeseq_slide_mutation_nodes_up(const tsk_treeseq_t *self, tsk_mutation_table_t *mutations)
tsk_treeseq_slide_mutation_nodes_up(
const tsk_treeseq_t *self, tsk_mutation_table_t *mutations)
{
int ret = 0;
bool valid;
Expand Down Expand Up @@ -7096,9 +7096,13 @@ tsk_treeseq_slide_mutation_nodes_up(const tsk_treeseq_t *self, tsk_mutation_tabl
e = tree_pos.in.order[tj];
parent[edges->child[e]] = edges->parent[e];
}
while (next_mut < (tsk_id_t) mutations->num_rows &&
sites_position[mutations->site[next_mut]] < right) {
while (next_mut < (tsk_id_t) mutations->num_rows
&& sites_position[mutations->site[next_mut]] < right) {
t = mutations->time[next_mut];
if (tsk_is_unknown_time(t)) {
ret = TSK_ERR_DISALLOWED_UNKNOWN_MUTATION_TIME;
goto out;
}
c = mutations->node[next_mut];
p = parent[c];
while (p != TSK_NULL && nodes_time[p] <= t) {
Expand All @@ -7119,7 +7123,6 @@ tsk_treeseq_slide_mutation_nodes_up(const tsk_treeseq_t *self, tsk_mutation_tabl
return ret;
}


int TSK_WARN_UNUSED
tsk_treeseq_extend_edges(const tsk_treeseq_t *self, int max_iter,
tsk_flags_t TSK_UNUSED(options), tsk_treeseq_t *output)
Expand Down Expand Up @@ -7177,7 +7180,8 @@ tsk_treeseq_extend_edges(const tsk_treeseq_t *self, int max_iter,
}

/* Remap mutation nodes */
ret = tsk_mutation_table_copy(&self->tables->mutations, &tables.mutations, TSK_NO_INIT);
ret = tsk_mutation_table_copy(
&self->tables->mutations, &tables.mutations, TSK_NO_INIT);
if (ret != 0) {
goto out;
}
Expand Down
70 changes: 40 additions & 30 deletions python/tests/test_extend_edges.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import numpy as np
import pytest

import _tskit
import tests.test_wright_fisher as wf
import tskit
from tests import tsutil
Expand Down Expand Up @@ -53,16 +54,20 @@ def _slide_mutation_nodes_up(ts, mutations):
while valid:
left, right = tree_pos.interval

for j in range(tree_pos.out_range.start, tree_pos.out_range.stop, tskit.FORWARD):
for j in range(
tree_pos.out_range.start, tree_pos.out_range.stop, tskit.FORWARD
):
e = tree_pos.out_range.order[j]
parent[edges.child[e]] = -1

for j in range(tree_pos.in_range.start, tree_pos.in_range.stop, tskit.FORWARD):
e = tree_pos.in_range.order[j]
parent[edges.child[e]] = edges.parent[e]

while (next_mut < mutations.num_rows and
ts.sites_position[mutations.site[next_mut]] < right):
while (
next_mut < mutations.num_rows
and ts.sites_position[mutations.site[next_mut]] < right
):
t = mutations.time[next_mut]
c = new_nodes[next_mut]
p = parent[c]
Expand All @@ -84,12 +89,13 @@ def _slide_mutation_nodes_up(ts, mutations):
valid = tree_pos.next()

# in C the node column can be edited in place
new_mutations = tskit.MutationTable()
new_mutations = mutations.copy()
new_mutations.clear()
for mut, n in zip(mutations, new_nodes):
new_mutations.append(mut.replace(node=n))

return new_mutations


def _extend(ts, forwards=True):
# `degree` will record the degree of each node in the tree we'd get if
Expand Down Expand Up @@ -132,7 +138,7 @@ def _extend(ts, forwards=True):
# if an edge from p->c has been extended, entirely replacing
# another edge from p'->c, then both edges may be in edges_out,
# and we only want to include the *first* one.
for e, x in edges_out:
for e, _ in edges_out:
out_parent[edges.child[e]] = -1
tmp = []
for e, x in edges_out:
Expand Down Expand Up @@ -168,7 +174,7 @@ def _extend(ts, forwards=True):
# validate out_parent array
for c, p in enumerate(out_parent):
foundit = False
for e, x in edges_out:
for e, _ in edges_out:
if edges.child[e] == c:
assert edges.parent[e] == p
foundit = True
Expand Down Expand Up @@ -359,6 +365,18 @@ def test_runs(self):
ts = msprime.simulate(5, mutation_rate=1.0, random_seed=126)
self.verify_extend_edges(ts)

def test_unknown_times(self):
ts = msprime.simulate(5, mutation_rate=1.0, random_seed=126)
tables = ts.dump_tables()
tables.mutations.clear()
for mut in ts.mutations():
tables.mutations.append(mut.replace(time=tskit.UNKNOWN_TIME))
ts = tables.tree_sequence()
with pytest.raises(
_tskit.LibraryError, match="TSK_ERR_DISALLOWED_UNKNOWN_MUTATION_TIME"
):
_ = ts.extend_edges()

def test_max_iter(self):
ts = msprime.simulate(5, random_seed=126)
with pytest.raises(ValueError, match="max_iter"):
Expand All @@ -381,12 +399,12 @@ def test_simple_ex(self):
# 0 1 2 3 0 1 2 3 0 1 2 3 0 1 2 3
#
# Result:
#
# 6 6 6 6
# +-+-+ +-+-+ +-+-+ +-+-+
# 7 8 7 8 7 8 7 8
# | | ++-+ | | +-++ | |
# 4 5 4 | 5 4 | 5 4 5
#
# 6 6 6 6
# +-+-+ +-+-+ +-+-+ +-+-+
# 7 8 7 8 7 8 7 8
# | | ++-+ | | +-++ | |
# 4 5 4 | 5 4 | 5 4 5
# +++ +++ +++ | | | | +++ +++ +++
# 0 1 2 3 0 1 2 3 0 1 2 3 0 1 2 3

Expand Down Expand Up @@ -436,7 +454,7 @@ def test_simple_ex(self):
for t in ets.trees():
print(".....")
print(t.interval)
print(t.draw(format='ascii'))
print(t.draw(format="ascii"))
assert ts.num_edges == 18
assert ets.num_edges == 12
for t in ets.trees():
Expand All @@ -453,34 +471,22 @@ def test_wright_fisher(self):
tables = wf.wf_sim(N=5, ngens=20, num_loci=100, deep_history=False, seed=3)
tables.sort()
tables.simplify()
ts = msprime.sim_mutations(
tables.tree_sequence(),
rate=0.01,
random_seed=888
)
ts = msprime.sim_mutations(tables.tree_sequence(), rate=0.01, random_seed=888)
self.verify_extend_edges(ts, max_iter=1)
self.verify_extend_edges(ts)

def test_wright_fisher_unsimplified(self):
tables = wf.wf_sim(N=6, ngens=22, num_loci=100, deep_history=False, seed=4)
tables.sort()
ts = msprime.sim_mutations(
tables.tree_sequence(),
rate=0.01,
random_seed=888
)
ts = msprime.sim_mutations(tables.tree_sequence(), rate=0.01, random_seed=888)
self.verify_extend_edges(ts, max_iter=1)
self.verify_extend_edges(ts)

def test_wright_fisher_with_history(self):
tables = wf.wf_sim(N=8, ngens=15, num_loci=100, deep_history=True, seed=5)
tables.sort()
tables.simplify()
ts = msprime.sim_mutations(
tables.tree_sequence(),
rate=0.01,
random_seed=888
)
ts = msprime.sim_mutations(tables.tree_sequence(), rate=0.01, random_seed=888)
self.verify_extend_edges(ts, max_iter=1)
self.verify_extend_edges(ts)

Expand All @@ -502,8 +508,12 @@ class TestExamples:
"""

def check(self, ts):
lib_ts = ts.extend_edges()
if np.any(tskit.is_unknown_time(ts.mutations_time)):
tables = ts.dump_tables()
tables.compute_mutation_times()
ts = tables.tree_sequence()
py_ts = extend_edges(ts)
lib_ts = ts.extend_edges()
lib_ts.tables.assert_equals(py_ts.tables)

@pytest.mark.parametrize("ts", get_example_tree_sequences())
Expand Down

0 comments on commit 90139ee

Please sign in to comment.