From e8a3e0d69e57a006b2b5410ae8e39d7c40adca55 Mon Sep 17 00:00:00 2001 From: Jerome Kelleher Date: Sat, 22 Apr 2023 00:04:24 +0100 Subject: [PATCH] Finish up implementation of samples arg --- c/tests/test_trees.c | 15 +++++++++++-- c/tskit/trees.c | 27 +++++++++++++++++++++- python/_tskitmodule.c | 26 +++++++++++++++++----- python/tests/test_divmat.py | 42 +++++++++++++++++++++++++++-------- python/tests/test_lowlevel.py | 4 ++++ python/tskit/trees.py | 23 ++++++++++--------- 6 files changed, 109 insertions(+), 28 deletions(-) diff --git a/c/tests/test_trees.c b/c/tests/test_trees.c index 51491e0643..c8d051f5c1 100644 --- a/c/tests/test_trees.c +++ b/c/tests/test_trees.c @@ -3808,11 +3808,22 @@ test_simplest_divergence_matrix(void) tsk_treeseq_from_text(&ts, 1, nodes, edges, NULL, NULL, NULL, NULL, NULL, 0); - /* ret = tsk_treeseq_divergence_matrix(&ts, 0, NULL, 0, D); */ ret = tsk_treeseq_divergence_matrix(&ts, 2, sample_ids, 0, NULL, 0, result); CU_ASSERT_EQUAL_FATAL(ret, 0); assert_arrays_almost_equal(4, D, result); + ret = tsk_treeseq_divergence_matrix(&ts, 0, NULL, 0, NULL, 0, D); + CU_ASSERT_EQUAL_FATAL(ret, 0); + assert_arrays_almost_equal(4, D, result); + + sample_ids[0] = -1; + ret = tsk_treeseq_divergence_matrix(&ts, 2, sample_ids, 0, NULL, 0, result); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_NODE_OUT_OF_BOUNDS); + + sample_ids[0] = 3; + ret = tsk_treeseq_divergence_matrix(&ts, 2, sample_ids, 0, NULL, 0, result); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_NODE_OUT_OF_BOUNDS); + tsk_treeseq_free(&ts); } @@ -3861,7 +3872,7 @@ test_simplest_divergence_matrix_internal_sample(void) { const char *nodes = "1 0 0\n" "1 0 0\n" - "1 1 0\n"; + "0 1 0\n"; const char *edges = "0 1 2 0,1\n"; tsk_treeseq_t ts; tsk_id_t sample_ids[] = { 0, 1, 2 }; diff --git a/c/tskit/trees.c b/c/tskit/trees.c index 435a0fe839..a2141c2d2f 100644 --- a/c/tskit/trees.c +++ b/c/tskit/trees.c @@ -6252,6 +6252,26 @@ sv_tables_mrca(const sv_tables_t *self, tsk_id_t x, tsk_id_t y) return 
sv_tables_mrca_one_based(self, x + 1, y + 1) - 1; } +static int +tsk_treeseq_check_node_bounds( + const tsk_treeseq_t *self, tsk_size_t num_nodes, const tsk_id_t *nodes) +{ + int ret = 0; + tsk_size_t j; + tsk_id_t u; + const tsk_id_t N = (tsk_id_t) self->tables->nodes.num_rows; + + for (j = 0; j < num_nodes; j++) { + u = nodes[j]; + if (u < 0 || u >= N) { + ret = TSK_ERR_NODE_OUT_OF_BOUNDS; + goto out; + } + } +out: + return ret; +} + int tsk_treeseq_divergence_matrix(const tsk_treeseq_t *self, tsk_size_t num_samples, const tsk_id_t *samples_in, tsk_size_t num_windows, const double *windows, @@ -6259,7 +6279,7 @@ tsk_treeseq_divergence_matrix(const tsk_treeseq_t *self, tsk_size_t num_samples, { int ret = 0; tsk_tree_t tree; - const tsk_id_t *samples = self->samples; + const tsk_id_t *restrict samples = self->samples; const double default_windows[] = { 0, self->tables->sequence_length }; const double *restrict nodes_time = self->tables->nodes.time; tsk_size_t n = self->num_samples; @@ -6292,7 +6312,12 @@ tsk_treeseq_divergence_matrix(const tsk_treeseq_t *self, tsk_size_t num_samples, if (samples_in != NULL) { samples = samples_in; n = num_samples; + ret = tsk_treeseq_check_node_bounds(self, n, samples); + if (ret != 0) { + goto out; + } } + memset(result, 0, num_windows * n * n * sizeof(*result)); for (i = 0; i < num_windows; i++) { diff --git a/python/_tskitmodule.c b/python/_tskitmodule.c index 5410985de9..802fafe7c9 100644 --- a/python/_tskitmodule.c +++ b/python/_tskitmodule.c @@ -9641,25 +9641,38 @@ static PyObject * TreeSequence_divergence_matrix(TreeSequence *self, PyObject *args, PyObject *kwds) { PyObject *ret = NULL; - static char *kwlist[] = { "windows", NULL }; + static char *kwlist[] = { "windows", "samples", NULL }; PyArrayObject *result_array = NULL; PyObject *windows = NULL; + PyObject *py_samples = Py_None; PyArrayObject *windows_array = NULL; + PyArrayObject *samples_array = NULL; tsk_flags_t options = 0; - npy_intp dims[3]; + npy_intp *shape, 
dims[3]; tsk_size_t num_samples, num_windows; + tsk_id_t *samples = NULL; int err; if (TreeSequence_check_state(self) != 0) { goto out; } - if (!PyArg_ParseTupleAndKeywords(args, kwds, "O", kwlist, &windows)) { + if (!PyArg_ParseTupleAndKeywords(args, kwds, "O|O", kwlist, &windows, &py_samples)) { goto out; } + num_samples = tsk_treeseq_get_num_samples(self->tree_sequence); + if (py_samples != Py_None) { + samples_array = (PyArrayObject *) PyArray_FROMANY( + py_samples, NPY_INT32, 1, 1, NPY_ARRAY_IN_ARRAY); + if (samples_array == NULL) { + goto out; + } + shape = PyArray_DIMS(samples_array); + samples = PyArray_DATA(samples_array); + num_samples = (tsk_size_t) shape[0]; + } if (parse_windows(windows, &windows_array, &num_windows) != 0) { goto out; } - num_samples = tsk_treeseq_get_num_samples(self->tree_sequence); dims[0] = num_windows; dims[1] = num_samples; dims[2] = num_samples; @@ -9670,7 +9683,8 @@ TreeSequence_divergence_matrix(TreeSequence *self, PyObject *args, PyObject *kwd // clang-format off Py_BEGIN_ALLOW_THREADS err = tsk_treeseq_divergence_matrix( - self->tree_sequence, 0, NULL, + self->tree_sequence, + num_samples, samples, num_windows, PyArray_DATA(windows_array), options, PyArray_DATA(result_array)); Py_END_ALLOW_THREADS @@ -9685,6 +9699,8 @@ TreeSequence_divergence_matrix(TreeSequence *self, PyObject *args, PyObject *kwd result_array = NULL; out: Py_XDECREF(result_array); + Py_XDECREF(windows_array); + Py_XDECREF(samples_array); return ret; } diff --git a/python/tests/test_divmat.py b/python/tests/test_divmat.py index 206a97e344..193fd79135 100644 --- a/python/tests/test_divmat.py +++ b/python/tests/test_divmat.py @@ -260,14 +260,18 @@ def check_divmat( D1 = divergence_matrix(ts, windows=windows, samples=samples) if compare_stats_api: + # Some things like duplicate samples aren't worth hacking around for in + # the stats API.
D2 = lib_divergence_matrix(ts, windows=windows, samples=samples) # print("windows = ", windows) # print(D1) # print(D2) np.testing.assert_allclose(D1, D2) - # D3 = ts.divergence_matrix(windows=windows) - # # print(D3) - # np.testing.assert_allclose(D1, D3) + assert D1.shape == D2.shape + D3 = ts.divergence_matrix(windows=windows, samples=samples) + # print(D3) + assert D1.shape == D3.shape + np.testing.assert_allclose(D1, D3) return D1 @@ -579,9 +583,9 @@ def test_disconnected_non_sample_topology(self): class TestThreadsNoWindows: - def check(self, ts, num_threads): - D1 = ts.divergence_matrix(num_threads=0) - D2 = ts.divergence_matrix(num_threads=num_threads) + def check(self, ts, num_threads, samples=None): + D1 = ts.divergence_matrix(num_threads=0, samples=samples) + D2 = ts.divergence_matrix(num_threads=num_threads, samples=samples) np.testing.assert_array_almost_equal(D1, D2) @pytest.mark.parametrize("num_threads", [1, 2, 3, 5, 26, 27]) @@ -590,6 +594,12 @@ def test_all_trees(self, num_threads): assert ts.num_trees == 26 self.check(ts, num_threads) + @pytest.mark.parametrize("samples", [None, [0, 1]]) + def test_all_trees_samples(self, samples): + ts = tsutil.all_trees_ts(4) + assert ts.num_trees == 26 + self.check(ts, 2, samples) + @pytest.mark.parametrize("n", [2, 3, 5, 15]) @pytest.mark.parametrize("num_threads", range(1, 5)) def test_simple_sims(self, n, num_threads): @@ -606,9 +616,11 @@ def test_simple_sims(self, n, num_threads): class TestThreadsWindows: - def check(self, ts, num_threads, *, windows): - D1 = ts.divergence_matrix(num_threads=0, windows=windows) - D2 = ts.divergence_matrix(num_threads=num_threads, windows=windows) + def check(self, ts, num_threads, *, windows, samples=None): + D1 = ts.divergence_matrix(num_threads=0, windows=windows, samples=samples) + D2 = ts.divergence_matrix( + num_threads=num_threads, windows=windows, samples=samples + ) np.testing.assert_array_almost_equal(D1, D2) @pytest.mark.parametrize("num_threads", [1, 2, 3, 5, 
26, 27]) @@ -628,6 +640,18 @@ def test_all_trees(self, num_threads, windows): assert ts.num_trees == 26 self.check(ts, num_threads, windows=windows) + @pytest.mark.parametrize("samples", [None, [0, 1]]) + @pytest.mark.parametrize( + ["windows"], + [ + ([0, 26],), + (None,), + ], + ) + def test_all_trees_samples(self, samples, windows): + ts = tsutil.all_trees_ts(4) + self.check(ts, 2, windows=windows, samples=samples) + @pytest.mark.parametrize("num_threads", range(1, 5)) @pytest.mark.parametrize( ["windows"], diff --git a/python/tests/test_lowlevel.py b/python/tests/test_lowlevel.py index b9b731f0a5..5833a79da2 100644 --- a/python/tests/test_lowlevel.py +++ b/python/tests/test_lowlevel.py @@ -1534,12 +1534,16 @@ def test_divergence_matrix(self): ts = self.get_example_tree_sequence(n, random_seed=12) D = ts.divergence_matrix([0, ts.get_sequence_length()]) assert D.shape == (1, n, n) + D = ts.divergence_matrix([0, ts.get_sequence_length()], samples=[0, 1]) + assert D.shape == (1, 2, 2) with pytest.raises(TypeError): ts.divergence_matrix(windoze=[0, 1]) with pytest.raises(ValueError, match="at least 2"): ts.divergence_matrix(windows=[0]) with pytest.raises(_tskit.LibraryError, match="BAD_WINDOWS"): ts.divergence_matrix(windows=[-1, 0, 1]) + with pytest.raises(ValueError): + ts.divergence_matrix(windows=[0, 1], samples="sdf") def test_load_tables_build_indexes(self): for ts in self.get_example_tree_sequences(): diff --git a/python/tskit/trees.py b/python/tskit/trees.py index a8d75d8341..94195cbde6 100644 --- a/python/tskit/trees.py +++ b/python/tskit/trees.py @@ -7770,24 +7770,24 @@ def _chunk_windows(windows, num_chunks): # k += 1 # return A - # NOTE see older definition above that we didn't finish up. Are there things - # we should take from this? + # NOTE see older definition of divmat above that we didn't finish up. + # Are there things we should take from this? 
- def _parallelise_divmat_by_tree(self, num_threads): + def _parallelise_divmat_by_tree(self, num_threads, samples): """ No windows were specified, so we can chunk up the whole genome by tree, and do a simple sum of the results. """ def worker(interval): - return self._ll_tree_sequence.divergence_matrix(interval) + return self._ll_tree_sequence.divergence_matrix(interval, samples=samples) work = self._chunk_sequence_by_tree(num_threads) with concurrent.futures.ThreadPoolExecutor(max_workers=num_threads) as pool: results = pool.map(worker, work) return sum(results) - def _parallelise_divmat_by_window(self, windows, num_threads): + def _parallelise_divmat_by_window(self, windows, num_threads, samples): """ We assume we have a number of windows that's >= to the number of threads available, and let each thread have a chunk of the @@ -7797,7 +7797,9 @@ def _parallelise_divmat_by_window(self, windows, num_threads): """ def worker(sub_windows): - return self._ll_tree_sequence.divergence_matrix(sub_windows) + return self._ll_tree_sequence.divergence_matrix( + sub_windows, samples=samples + ) work = self._chunk_windows(windows, num_threads) with concurrent.futures.ThreadPoolExecutor(max_workers=num_threads) as executor: @@ -7805,20 +7807,19 @@ def worker(sub_windows): concurrent.futures.wait(futures) return np.vstack([future.result() for future in futures]) - def divergence_matrix(self, *, windows=None, num_threads=0): - # TODO implement "samples" argument + def divergence_matrix(self, *, windows=None, samples=None, num_threads=0): windows_specified = windows is not None windows = [0, self.sequence_length] if windows is None else windows # NOTE: maybe we want to use a different default for num_threads here, just # following the approach in GNN if num_threads <= 0: - D = self._ll_tree_sequence.divergence_matrix(windows) + D = self._ll_tree_sequence.divergence_matrix(windows, samples=samples) else: if windows_specified: - D = self._parallelise_divmat_by_window(windows, 
num_threads) + D = self._parallelise_divmat_by_window(windows, num_threads, samples) else: - D = self._parallelise_divmat_by_tree(num_threads) + D = self._parallelise_divmat_by_tree(num_threads, samples) if not windows_specified: # Drop the windows dimension