Skip to content

Commit

Permalink
Finish up implementation of samples arg
Browse files Browse the repository at this point in the history
  • Loading branch information
jeromekelleher committed Apr 21, 2023
1 parent 2895f65 commit e8a3e0d
Show file tree
Hide file tree
Showing 6 changed files with 109 additions and 28 deletions.
15 changes: 13 additions & 2 deletions c/tests/test_trees.c
Original file line number Diff line number Diff line change
Expand Up @@ -3808,11 +3808,22 @@ test_simplest_divergence_matrix(void)

tsk_treeseq_from_text(&ts, 1, nodes, edges, NULL, NULL, NULL, NULL, NULL, 0);

/* ret = tsk_treeseq_divergence_matrix(&ts, 0, NULL, 0, D); */
ret = tsk_treeseq_divergence_matrix(&ts, 2, sample_ids, 0, NULL, 0, result);
CU_ASSERT_EQUAL_FATAL(ret, 0);
assert_arrays_almost_equal(4, D, result);

ret = tsk_treeseq_divergence_matrix(&ts, 0, NULL, 0, NULL, 0, D);
CU_ASSERT_EQUAL_FATAL(ret, 0);
assert_arrays_almost_equal(4, D, result);

sample_ids[0] = -1;
ret = tsk_treeseq_divergence_matrix(&ts, 2, sample_ids, 0, NULL, 0, result);
CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_NODE_OUT_OF_BOUNDS);

sample_ids[0] = 3;
ret = tsk_treeseq_divergence_matrix(&ts, 2, sample_ids, 0, NULL, 0, result);
CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_NODE_OUT_OF_BOUNDS);

tsk_treeseq_free(&ts);
}

Expand Down Expand Up @@ -3861,7 +3872,7 @@ test_simplest_divergence_matrix_internal_sample(void)
{
const char *nodes = "1 0 0\n"
"1 0 0\n"
"1 1 0\n";
"0 1 0\n";
const char *edges = "0 1 2 0,1\n";
tsk_treeseq_t ts;
tsk_id_t sample_ids[] = { 0, 1, 2 };
Expand Down
27 changes: 26 additions & 1 deletion c/tskit/trees.c
Original file line number Diff line number Diff line change
Expand Up @@ -6252,14 +6252,34 @@ sv_tables_mrca(const sv_tables_t *self, tsk_id_t x, tsk_id_t y)
return sv_tables_mrca_one_based(self, x + 1, y + 1) - 1;
}

/* Check that each of the num_nodes IDs in the nodes array indexes a row of
 * this tree sequence's node table. Returns 0 if every ID lies in
 * [0, nodes.num_rows), and TSK_ERR_NODE_OUT_OF_BOUNDS as soon as one
 * does not. An empty array (num_nodes == 0) passes trivially. */
static int
tsk_treeseq_check_node_bounds(
    const tsk_treeseq_t *self, tsk_size_t num_nodes, const tsk_id_t *nodes)
{
    const tsk_id_t node_limit = (tsk_id_t) self->tables->nodes.num_rows;
    tsk_size_t k;

    for (k = 0; k < num_nodes; k++) {
        /* Valid IDs are non-negative and strictly below the row count. */
        if (!(nodes[k] >= 0 && nodes[k] < node_limit)) {
            return TSK_ERR_NODE_OUT_OF_BOUNDS;
        }
    }
    return 0;
}

int
tsk_treeseq_divergence_matrix(const tsk_treeseq_t *self, tsk_size_t num_samples,
const tsk_id_t *samples_in, tsk_size_t num_windows, const double *windows,
tsk_flags_t TSK_UNUSED(options), double *result)
{
int ret = 0;
tsk_tree_t tree;
const tsk_id_t *samples = self->samples;
const tsk_id_t *restrict samples = self->samples;
const double default_windows[] = { 0, self->tables->sequence_length };
const double *restrict nodes_time = self->tables->nodes.time;
tsk_size_t n = self->num_samples;
Expand Down Expand Up @@ -6292,7 +6312,12 @@ tsk_treeseq_divergence_matrix(const tsk_treeseq_t *self, tsk_size_t num_samples,
if (samples_in != NULL) {
samples = samples_in;
n = num_samples;
ret = tsk_treeseq_check_node_bounds(self, n, samples);
if (ret != 0) {
goto out;
}
}

memset(result, 0, num_windows * n * n * sizeof(*result));

for (i = 0; i < num_windows; i++) {
Expand Down
26 changes: 21 additions & 5 deletions python/_tskitmodule.c
Original file line number Diff line number Diff line change
Expand Up @@ -9641,25 +9641,38 @@ static PyObject *
TreeSequence_divergence_matrix(TreeSequence *self, PyObject *args, PyObject *kwds)
{
PyObject *ret = NULL;
static char *kwlist[] = { "windows", NULL };
static char *kwlist[] = { "windows", "samples", NULL };
PyArrayObject *result_array = NULL;
PyObject *windows = NULL;
PyObject *py_samples = Py_None;
PyArrayObject *windows_array = NULL;
PyArrayObject *samples_array = NULL;
tsk_flags_t options = 0;
npy_intp dims[3];
npy_intp *shape, dims[3];
tsk_size_t num_samples, num_windows;
tsk_id_t *samples = NULL;
int err;

if (TreeSequence_check_state(self) != 0) {
goto out;
}
if (!PyArg_ParseTupleAndKeywords(args, kwds, "O", kwlist, &windows)) {
if (!PyArg_ParseTupleAndKeywords(args, kwds, "O|O", kwlist, &windows, &py_samples)) {
goto out;
}
num_samples = tsk_treeseq_get_num_samples(self->tree_sequence);
if (py_samples != Py_None) {
samples_array = (PyArrayObject *) PyArray_FROMANY(
py_samples, NPY_INT32, 1, 1, NPY_ARRAY_IN_ARRAY);
if (samples_array == NULL) {
goto out;
}
shape = PyArray_DIMS(samples_array);
samples = PyArray_DATA(samples_array);
num_samples = (tsk_size_t) shape[0];
}
if (parse_windows(windows, &windows_array, &num_windows) != 0) {
goto out;
}
num_samples = tsk_treeseq_get_num_samples(self->tree_sequence);
dims[0] = num_windows;
dims[1] = num_samples;
dims[2] = num_samples;
Expand All @@ -9670,7 +9683,8 @@ TreeSequence_divergence_matrix(TreeSequence *self, PyObject *args, PyObject *kwd
// clang-format off
Py_BEGIN_ALLOW_THREADS
err = tsk_treeseq_divergence_matrix(
self->tree_sequence, 0, NULL,
self->tree_sequence,
num_samples, samples,
num_windows, PyArray_DATA(windows_array),
options, PyArray_DATA(result_array));
Py_END_ALLOW_THREADS
Expand All @@ -9685,6 +9699,8 @@ TreeSequence_divergence_matrix(TreeSequence *self, PyObject *args, PyObject *kwd
result_array = NULL;
out:
Py_XDECREF(result_array);
Py_XDECREF(windows_array);
/* Py_XDECREF(samples_array); */
return ret;
}

Expand Down
42 changes: 33 additions & 9 deletions python/tests/test_divmat.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,14 +260,18 @@ def check_divmat(

D1 = divergence_matrix(ts, windows=windows, samples=samples)
if compare_stats_api:
# Some things, like duplicate samples, aren't worth hacking around in the
# stats API.
D2 = lib_divergence_matrix(ts, windows=windows, samples=samples)
# print("windows = ", windows)
# print(D1)
# print(D2)
np.testing.assert_allclose(D1, D2)
# D3 = ts.divergence_matrix(windows=windows)
# # print(D3)
# np.testing.assert_allclose(D1, D3)
assert D1.shape == D2.shape
D3 = ts.divergence_matrix(windows=windows, samples=samples)
# print(D3)
assert D1.shape == D3.shape
np.testing.assert_allclose(D1, D3)
return D1


Expand Down Expand Up @@ -579,9 +583,9 @@ def test_disconnected_non_sample_topology(self):


class TestThreadsNoWindows:
def check(self, ts, num_threads):
D1 = ts.divergence_matrix(num_threads=0)
D2 = ts.divergence_matrix(num_threads=num_threads)
def check(self, ts, num_threads, samples=None):
D1 = ts.divergence_matrix(num_threads=0, samples=samples)
D2 = ts.divergence_matrix(num_threads=num_threads, samples=samples)
np.testing.assert_array_almost_equal(D1, D2)

@pytest.mark.parametrize("num_threads", [1, 2, 3, 5, 26, 27])
Expand All @@ -590,6 +594,12 @@ def test_all_trees(self, num_threads):
assert ts.num_trees == 26
self.check(ts, num_threads)

@pytest.mark.parametrize("samples", [None, [0, 1]])
def test_all_trees_samples(self, samples):
    """Two-thread and unthreaded results agree, with and without samples."""
    all_trees = tsutil.all_trees_ts(4)
    # Sanity check on the fixture before comparing the matrices.
    assert all_trees.num_trees == 26
    self.check(all_trees, 2, samples)

@pytest.mark.parametrize("n", [2, 3, 5, 15])
@pytest.mark.parametrize("num_threads", range(1, 5))
def test_simple_sims(self, n, num_threads):
Expand All @@ -606,9 +616,11 @@ def test_simple_sims(self, n, num_threads):


class TestThreadsWindows:
def check(self, ts, num_threads, *, windows):
D1 = ts.divergence_matrix(num_threads=0, windows=windows)
D2 = ts.divergence_matrix(num_threads=num_threads, windows=windows)
def check(self, ts, num_threads, *, windows, samples=None):
D1 = ts.divergence_matrix(num_threads=0, windows=windows, samples=samples)
D2 = ts.divergence_matrix(
num_threads=num_threads, windows=windows, samples=samples
)
np.testing.assert_array_almost_equal(D1, D2)

@pytest.mark.parametrize("num_threads", [1, 2, 3, 5, 26, 27])
Expand All @@ -628,6 +640,18 @@ def test_all_trees(self, num_threads, windows):
assert ts.num_trees == 26
self.check(ts, num_threads, windows=windows)

@pytest.mark.parametrize("samples", [None, [0, 1]])
@pytest.mark.parametrize(
    ["windows"],
    [
        ([0, 26],),
        (None,),
    ],
)
def test_all_trees_samples(self, samples, windows):
    """Two-thread and unthreaded results agree for each samples/windows pair."""
    all_trees = tsutil.all_trees_ts(4)
    self.check(all_trees, 2, windows=windows, samples=samples)

@pytest.mark.parametrize("num_threads", range(1, 5))
@pytest.mark.parametrize(
["windows"],
Expand Down
4 changes: 4 additions & 0 deletions python/tests/test_lowlevel.py
Original file line number Diff line number Diff line change
Expand Up @@ -1534,12 +1534,16 @@ def test_divergence_matrix(self):
ts = self.get_example_tree_sequence(n, random_seed=12)
D = ts.divergence_matrix([0, ts.get_sequence_length()])
assert D.shape == (1, n, n)
D = ts.divergence_matrix([0, ts.get_sequence_length()], samples=[0, 1])
assert D.shape == (1, 2, 2)
with pytest.raises(TypeError):
ts.divergence_matrix(windoze=[0, 1])
with pytest.raises(ValueError, match="at least 2"):
ts.divergence_matrix(windows=[0])
with pytest.raises(_tskit.LibraryError, match="BAD_WINDOWS"):
ts.divergence_matrix(windows=[-1, 0, 1])
with pytest.raises(ValueError):
ts.divergence_matrix(windows=[0, 1], samples="sdf")

def test_load_tables_build_indexes(self):
for ts in self.get_example_tree_sequences():
Expand Down
23 changes: 12 additions & 11 deletions python/tskit/trees.py
Original file line number Diff line number Diff line change
Expand Up @@ -7770,24 +7770,24 @@ def _chunk_windows(windows, num_chunks):
# k += 1
# return A

# NOTE see older definition above that we didn't finish up. Are there things
# we should take from this?
# NOTE see older definition of divmat above that we didn't finish up.
# Are there things we should take from this?

def _parallelise_divmat_by_tree(self, num_threads):
def _parallelise_divmat_by_tree(self, num_threads, samples):
"""
No windows were specified, so we can chunk up the whole genome by
tree, and do a simple sum of the results.
"""

def worker(interval):
return self._ll_tree_sequence.divergence_matrix(interval)
return self._ll_tree_sequence.divergence_matrix(interval, samples=samples)

work = self._chunk_sequence_by_tree(num_threads)
with concurrent.futures.ThreadPoolExecutor(max_workers=num_threads) as pool:
results = pool.map(worker, work)
return sum(results)

def _parallelise_divmat_by_window(self, windows, num_threads):
def _parallelise_divmat_by_window(self, windows, num_threads, samples):
"""
We assume we have a number of windows that's >= to the number
of threads available, and let each thread have a chunk of the
Expand All @@ -7797,28 +7797,29 @@ def _parallelise_divmat_by_window(self, windows, num_threads):
"""

def worker(sub_windows):
return self._ll_tree_sequence.divergence_matrix(sub_windows)
return self._ll_tree_sequence.divergence_matrix(
sub_windows, samples=samples
)

work = self._chunk_windows(windows, num_threads)
with concurrent.futures.ThreadPoolExecutor(max_workers=num_threads) as executor:
futures = [executor.submit(worker, sub_windows) for sub_windows in work]
concurrent.futures.wait(futures)
return np.vstack([future.result() for future in futures])

def divergence_matrix(self, *, windows=None, num_threads=0):
# TODO implement "samples" argument
def divergence_matrix(self, *, windows=None, samples=None, num_threads=0):
windows_specified = windows is not None
windows = [0, self.sequence_length] if windows is None else windows

# NOTE: maybe we want to use a different default for num_threads here, just
# following the approach in GNN
if num_threads <= 0:
D = self._ll_tree_sequence.divergence_matrix(windows)
D = self._ll_tree_sequence.divergence_matrix(windows, samples=samples)
else:
if windows_specified:
D = self._parallelise_divmat_by_window(windows, num_threads)
D = self._parallelise_divmat_by_window(windows, num_threads, samples)
else:
D = self._parallelise_divmat_by_tree(num_threads)
D = self._parallelise_divmat_by_tree(num_threads, samples)

if not windows_specified:
# Drop the windows dimension
Expand Down

0 comments on commit e8a3e0d

Please sign in to comment.