diff --git a/python/tests/test_highlevel.py b/python/tests/test_highlevel.py index e560fbfe40..5929daa6ae 100644 --- a/python/tests/test_highlevel.py +++ b/python/tests/test_highlevel.py @@ -3822,6 +3822,23 @@ def test_num_children(self): for u in tree.nodes(): assert tree.num_children(u) == len(tree.children(u)) + def test_ancestors(self): + tree = tskit.Tree.generate_balanced(10, arity=3) + ancestors_arrays = {u: [] for u in np.arange(tree.tree_sequence.num_nodes)} + ancestors_arrays[-1] = [] + for u in tree.nodes(order="preorder"): + parent = tree.parent(u) + if parent != tskit.NULL: + ancestors_arrays[u] = [parent] + ancestors_arrays[tree.parent(u)] + for u in tree.nodes(): + assert list(tree.ancestors(u)) == ancestors_arrays[u] + + def test_ancestors_empty(self): + ts = tskit.Tree.generate_comb(10).tree_sequence + tree = ts.delete_intervals([[0, 1]]).first() + for u in ts.samples(): + assert len(list(tree.ancestors(u))) == 0 + @pytest.mark.parametrize("ts", get_example_tree_sequences()) def test_virtual_root_semantics(self, ts): for tree in ts.trees(): diff --git a/python/tskit/trees.py b/python/tskit/trees.py index 79ca182d67..085750941d 100644 --- a/python/tskit/trees.py +++ b/python/tskit/trees.py @@ -1091,6 +1091,15 @@ def parent_array(self): """ return self._parent_array + def ancestors(self, u): + """ + Returns an iterator over the ancestors of node ``u`` in this tree. + """ + u = self.parent(u) + while u != -1: + yield u + u = self.parent(u) + # Quintuply linked tree structure. def left_child(self, u):