From c00d375ef868fd789eada834db2e6575644d930c Mon Sep 17 00:00:00 2001 From: peter Date: Mon, 23 Jan 2023 12:19:08 -0800 Subject: [PATCH] updates? --- c/tests/test_stats.c | 20 +++++++ c/tskit/trees.c | 74 ++++++++++++++++++++++++++ c/tskit/trees.h | 11 ++++ python/_tskitmodule.c | 99 +++++++++++++++++++++++++++++++++++ python/tests/test_lowlevel.py | 89 +++++++++++++++++++++++++++++++ python/tskit/trees.py | 83 +++++++++++++++++++++++++++++ 6 files changed, 376 insertions(+) diff --git a/c/tests/test_stats.c b/c/tests/test_stats.c index 35991288d4..43f6b04544 100644 --- a/c/tests/test_stats.c +++ b/c/tests/test_stats.c @@ -1340,6 +1340,24 @@ test_paper_ex_genetic_relatedness(void) tsk_treeseq_free(&ts); } +static void +test_paper_ex_genetic_relatedness_weighted(void) +{ + tsk_treeseq_t ts; + double weights[] = { 1.2, 0.1, 0.0, 0.0, 3.4, 5.0, 1.0, -1.0 }; + tsk_id_t indexes[] = { 0, 0, 0, 1 }; + double result[2]; + int ret; + + tsk_treeseq_from_text(&ts, 10, paper_ex_nodes, paper_ex_edges, NULL, paper_ex_sites, + paper_ex_mutations, paper_ex_individuals, NULL, 0); + + ret = tsk_treeseq_genetic_relatedness_weighted( + &ts, 2, weights, 2, indexes, 0, NULL, result, TSK_STAT_SITE); + CU_ASSERT_EQUAL_FATAL(ret, 0); + tsk_treeseq_free(&ts); +} + static void test_paper_ex_genetic_relatedness_errors(void) { @@ -1773,6 +1791,8 @@ main(int argc, char **argv) { "test_paper_ex_genetic_relatedness_errors", test_paper_ex_genetic_relatedness_errors }, { "test_paper_ex_genetic_relatedness", test_paper_ex_genetic_relatedness }, + { "test_paper_ex_genetic_relatedness_weighted", + test_paper_ex_genetic_relatedness_weighted }, { "test_paper_ex_Y2_errors", test_paper_ex_Y2_errors }, { "test_paper_ex_Y2", test_paper_ex_Y2 }, { "test_paper_ex_f2_errors", test_paper_ex_f2_errors }, diff --git a/c/tskit/trees.c b/c/tskit/trees.c index 4fcb2ee376..a0d2fb82c1 100644 --- a/c/tskit/trees.c +++ b/c/tskit/trees.c @@ -2059,6 +2059,11 @@ typedef struct { const tsk_id_t *set_indexes; } sample_count_stat_params_t; +typedef struct { + double *total_weights; + const tsk_id_t *index_tuples; +} indexed_weight_stat_params_t; + static int tsk_treeseq_sample_count_stat(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, @@ -3012,6 +3017,75 @@ tsk_treeseq_genetic_relatedness(const tsk_treeseq_t *self, tsk_size_t num_sample return ret; } +static int +genetic_relatedness_weighted_summary_func(tsk_size_t state_dim, const double *state, + tsk_size_t result_dim, double *result, void *params) +{ + indexed_weight_stat_params_t args = *(indexed_weight_stat_params_t *) params; + const double *x = state; + tsk_id_t i, j; + tsk_size_t k; + double meanx, ni, nj; + + meanx = state[state_dim - 1] / args.total_weights[state_dim - 1]; + for (k = 0; k < result_dim; k++) { + i = args.index_tuples[2 * k]; + j = args.index_tuples[2 * k + 1]; + ni = args.total_weights[i]; + nj = args.total_weights[j]; + result[k] = (x[i] - ni * meanx) * (x[j] - nj * meanx) / 2; + } + return 0; +} + +int +tsk_treeseq_genetic_relatedness_weighted(const tsk_treeseq_t *self, + tsk_size_t num_weights, const double *weights, tsk_size_t num_index_tuples, + const tsk_id_t *index_tuples, tsk_size_t num_windows, const double *windows, + double *result, tsk_flags_t options) +{ + int ret = 0; + tsk_size_t num_samples = self->num_samples; + size_t j, k; + indexed_weight_stat_params_t args; + const double *row; + double *new_row; + double *total_weights = tsk_calloc((num_weights + 1), sizeof(*total_weights)); + double *new_weights + = tsk_malloc((num_weights + 1) * num_samples * sizeof(*new_weights)); + + if (total_weights == NULL || new_weights == NULL) { + ret = TSK_ERR_NO_MEMORY; + goto out; + } + + // Add a column of ones to W + for (j = 0; j < num_samples; j++) { + row = GET_2D_ROW(weights, num_weights, j); + new_row = GET_2D_ROW(new_weights, num_weights + 1, j); + for (k = 0; k < num_weights; k++) { + new_row[k] = row[k]; + total_weights[k] += row[k]; + } + new_row[num_weights] = 1.0; + } + total_weights[num_weights] = (double) num_samples; + + args.total_weights = total_weights; + args.index_tuples = index_tuples; + ret = tsk_treeseq_general_stat(self, num_weights + 1, new_weights, num_index_tuples, + genetic_relatedness_weighted_summary_func, &args, num_windows, windows, options, + result); + if (ret != 0) { + goto out; + } + +out: + tsk_safe_free(total_weights); + tsk_safe_free(new_weights); + return ret; +} + static int Y2_summary_func(tsk_size_t TSK_UNUSED(state_dim), const double *state, tsk_size_t result_dim, double *result, void *params) diff --git a/c/tskit/trees.h b/c/tskit/trees.h index cae952dee3..17d2c06c22 100644 --- a/c/tskit/trees.h +++ b/c/tskit/trees.h @@ -970,6 +970,17 @@ int tsk_treeseq_trait_linear_model(const tsk_treeseq_t *self, tsk_size_t num_wei const double *weights, tsk_size_t num_covariates, const double *covariates, tsk_size_t num_windows, const double *windows, tsk_flags_t options, double *result); +/* Two way weighted stats with covariates */ + +typedef int two_way_weighted_method(const tsk_treeseq_t *self, tsk_size_t num_weights, + const double *weights, tsk_size_t num_index_tuples, const tsk_id_t *index_tuples, + tsk_size_t num_windows, const double *windows, double *result, tsk_flags_t options); + +int tsk_treeseq_genetic_relatedness_weighted(const tsk_treeseq_t *self, + tsk_size_t num_weights, const double *weights, tsk_size_t num_index_tuples, + const tsk_id_t *index_tuples, tsk_size_t num_windows, const double *windows, + double *result, tsk_flags_t options); + /* One way sample set stats */ typedef int one_way_sample_stat_method(const tsk_treeseq_t *self, diff --git a/python/_tskitmodule.c b/python/_tskitmodule.c index 22f78c1244..a531499fa0 100644 --- a/python/_tskitmodule.c +++ b/python/_tskitmodule.c @@ -9322,6 +9322,93 @@ TreeSequence_k_way_stat_method(TreeSequence *self, PyObject *args, PyObject *kwd return ret; } +static PyObject * +TreeSequence_k_way_weighted_stat_method(TreeSequence *self, PyObject *args, + PyObject *kwds, npy_intp tuple_size, two_way_weighted_method *method) +{ + PyObject *ret = NULL; + static char *kwlist[] = { "weights", "indexes", "windows", "mode", "span_normalise", + "polarised", NULL }; + PyObject *weights = NULL; + PyObject *indexes = NULL; + PyObject *windows = NULL; + PyArrayObject *weights_array = NULL; + PyArrayObject *indexes_array = NULL; + PyArrayObject *windows_array = NULL; + PyArrayObject *result_array = NULL; + tsk_size_t num_windows, num_index_tuples; + npy_intp *w_shape, *shape; + tsk_flags_t options = 0; + char *mode = NULL; + int span_normalise = true; + int polarised = false; + int err; + + if (TreeSequence_check_state(self) != 0) { + goto out; + } + if (!PyArg_ParseTupleAndKeywords(args, kwds, "OOO|sii", kwlist, &weights, &indexes, + &windows, &mode, &span_normalise, &polarised)) { + goto out; + } + if (parse_stats_mode(mode, &options) != 0) { + goto out; + } + if (span_normalise) { + options |= TSK_STAT_SPAN_NORMALISE; + } + if (polarised) { + options |= TSK_STAT_POLARISED; + } + if (parse_windows(windows, &windows_array, &num_windows) != 0) { + goto out; + } + weights_array = (PyArrayObject *) PyArray_FROMANY( + weights, NPY_FLOAT64, 2, 2, NPY_ARRAY_IN_ARRAY); + if (weights_array == NULL) { + goto out; + } + w_shape = PyArray_DIMS(weights_array); + if (w_shape[0] != (npy_intp) tsk_treeseq_get_num_samples(self->tree_sequence)) { + PyErr_SetString(PyExc_ValueError, "First dimension must be num_samples"); + goto out; + } + + indexes_array = (PyArrayObject *) PyArray_FROMANY( + indexes, NPY_INT32, 2, 2, NPY_ARRAY_IN_ARRAY); + if (indexes_array == NULL) { + goto out; + } + shape = PyArray_DIMS(indexes_array); + if (shape[0] < 1 || shape[1] != tuple_size) { + PyErr_Format( + PyExc_ValueError, "indexes must be a k x %d array.", (int) tuple_size); + goto out; + } + num_index_tuples = shape[0]; + + result_array = TreeSequence_allocate_results_array( + self, options, num_windows, num_index_tuples); + if (result_array == NULL) { + goto out; + } + err = method(self->tree_sequence, w_shape[1], PyArray_DATA(weights_array), + num_index_tuples, PyArray_DATA(indexes_array), num_windows, + PyArray_DATA(windows_array), PyArray_DATA(result_array), options); + if (err != 0) { + handle_library_error(err); + goto out; + } + ret = (PyObject *) result_array; + result_array = NULL; +out: + Py_XDECREF(weights_array); + Py_XDECREF(indexes_array); + Py_XDECREF(windows_array); + Py_XDECREF(result_array); + return ret; +} + static PyObject * TreeSequence_divergence(TreeSequence *self, PyObject *args, PyObject *kwds) { @@ -9335,6 +9422,14 @@ TreeSequence_genetic_relatedness(TreeSequence *self, PyObject *args, PyObject *k self, args, kwds, 2, tsk_treeseq_genetic_relatedness); } +static PyObject * +TreeSequence_genetic_relatedness_weighted( + TreeSequence *self, PyObject *args, PyObject *kwds) +{ + return TreeSequence_k_way_weighted_stat_method( + self, args, kwds, 2, tsk_treeseq_genetic_relatedness_weighted); +} + static PyObject * TreeSequence_Y2(TreeSequence *self, PyObject *args, PyObject *kwds) { @@ -10049,6 +10144,10 @@ static PyMethodDef TreeSequence_methods[] = { .ml_meth = (PyCFunction) TreeSequence_genetic_relatedness, .ml_flags = METH_VARARGS | METH_KEYWORDS, .ml_doc = "Computes genetic relatedness between sample sets." }, + { .ml_name = "genetic_relatedness_weighted", + .ml_meth = (PyCFunction) TreeSequence_genetic_relatedness_weighted, + .ml_flags = METH_VARARGS | METH_KEYWORDS, + .ml_doc = "Computes genetic relatedness between weighted sums of samples." }, { .ml_name = "Y1", .ml_meth = (PyCFunction) TreeSequence_Y1, .ml_flags = METH_VARARGS | METH_KEYWORDS, diff --git a/python/tests/test_lowlevel.py b/python/tests/test_lowlevel.py index 8133c9d7eb..89e2c8e143 100644 --- a/python/tests/test_lowlevel.py +++ b/python/tests/test_lowlevel.py @@ -1663,6 +1663,15 @@ def test_window_errors(self): with pytest.raises(_tskit.LibraryError): f(windows=bad_window, **params) + def test_polarisation(self): + ts, f, params = self.get_example() + with pytest.raises(TypeError): + f(polarised="sdf", **params) + x1 = f(polarised=False, **params) + x2 = f(polarised=True, **params) + # Basic check just to run both code paths + assert x1.shape == x2.shape + def test_windows_output(self): ts, f, params = self.get_example() del params["windows"] @@ -2025,6 +2034,74 @@ def f(indexes): f(bad_dim) +class TwoWayWeightedStatsMixin(StatsInterfaceMixin): + """ + Tests for the weighted two way sample stats. + """ + + def get_example(self): + ts, method = self.get_method() + params = { + "weights": np.zeros((ts.get_num_samples(), 2)) + 0.5, + "indexes": [[0, 1]], + "windows": [0, ts.get_sequence_length()], + } + return ts, method, params + + def test_basic_example(self): + ts, method = self.get_method() + div = method( + np.zeros((ts.get_num_samples(), 1)) + 0.5, + [[0, 1]], + windows=[0, ts.get_sequence_length()], + ) + assert div.shape == (1, 1) + + def test_bad_weights(self): + ts, f, params = self.get_example() + del params["weights"] + n = ts.get_num_samples() + + for bad_weight_shape in [(n - 1, 1), (n + 1, 1), (0, 3)]: + with pytest.raises(ValueError): + f(weights=np.ones(bad_weight_shape), **params) + + def test_output_dims(self): + ts, method, params = self.get_example() + weights = params.pop("weights") + params["windows"] = [0, ts.get_sequence_length()] + + for mode in ["site", "branch"]: + out = method(weights[:, [0]], mode=mode, **params) + assert out.shape == (1, 1) + out = method(weights, mode=mode, **params) + assert out.shape == (1, 1) + out = method(weights[:, [0, 0, 0]], mode=mode, **params) + assert out.shape == (1, 1) + mode = "node" + N = ts.get_num_nodes() + out = method(weights[:, [0]], mode=mode, **params) + assert out.shape == (1, N, 1) + out = method(weights, mode=mode, **params) + assert out.shape == (1, N, 1) + out = method(weights[:, [0, 0, 0]], mode=mode, **params) + assert out.shape == (1, N, 1) + + def test_set_index_errors(self): + ts, method, params = self.get_example() + del params["indexes"] + + def f(indexes): + method(indexes=indexes, **params) + + for bad_array in ["wer", {}, [[[], []], [[], []]]]: + with pytest.raises(ValueError): + f(bad_array) + for bad_dim in [[[]], [[1], [1]]]: + with pytest.raises(ValueError): + f(bad_dim) + + class ThreeWaySampleStatsMixin(SampleSetMixin): """ Tests for the two way sample stats. @@ -2211,6 +2288,12 @@ def get_method(self): return ts, ts.f2 +class TestGeneticRelatedness(LowLevelTestCase, TwoWaySampleStatsMixin): + def get_method(self): + ts = self.get_example_tree_sequence() + return ts, ts.genetic_relatedness + + class TestY3(LowLevelTestCase, ThreeWaySampleStatsMixin): def get_method(self): ts = self.get_example_tree_sequence() @@ -2229,6 +2312,12 @@ def get_method(self): return ts, ts.f4 +class TestWeightedGeneticRelatedness(LowLevelTestCase, TwoWayWeightedStatsMixin): + def get_method(self): + ts = self.get_example_tree_sequence() + return ts, ts.genetic_relatedness_weighted + + class TestGeneralStatsInterface(LowLevelTestCase, StatsInterfaceMixin): """ Tests for the general stats interface. diff --git a/python/tskit/trees.py b/python/tskit/trees.py index 80a3028da3..5393dfc53d 100644 --- a/python/tskit/trees.py +++ b/python/tskit/trees.py @@ -7555,6 +7555,47 @@ def __k_way_sample_set_stat( stat = stat[()] return stat + def __k_way_weighted_stat( + self, + ll_method, + k, + W, + indexes=None, + windows=None, + mode=None, + span_normalise=True, + polarised=False, + ): + if indexes is None: + if W.shape[1] != k: + raise ValueError( + "Must specify indexes if there are not exactly {} columsn " + "in W.".format(k) + ) + indexes = np.arange(k, dtype=np.int32) + drop_dimension = False + indexes = util.safe_np_int_cast(indexes, np.int32) + if len(indexes.shape) == 1: + indexes = indexes.reshape((1, indexes.shape[0])) + drop_dimension = True + if len(indexes.shape) != 2 or indexes.shape[1] != k: + raise ValueError( + "Indexes must be convertable to a 2D numpy array with {} " + "columns".format(k) + ) + stat = self.__run_windowed_stat( + windows, + ll_method, + W, + indexes, + mode=mode, + span_normalise=span_normalise, + polarised=polarised, + ) + if drop_dimension: + stat = stat.reshape(stat.shape[:-1]) + return stat + ############################################ # Statistics definitions ############################################ @@ -7840,6 +7881,48 @@ def genetic_relatedness( return out + def genetic_relatedness_weighted( + self, + W, + indexes=None, + windows=None, + mode="site", + span_normalise=True, + polarised=False, + ): + r""" + Computes weighted genetic relatedness: if the k-th pair of indices is (i, j) + then the k-th column of output will be + :math:`\sum_{a,b} W_{ai} W_{bj} C_{ab}`, + where :math:`W` is the matrix of weights, and :math:`C_{ab}` is the + {meth}`.genetic_relatedness` between sample i and sample j. + + :param numpy.ndarray W: An array of values with one row for each sample and one + column for each set of weights. + :param list indexes: A list of 2-tuples, or None. + :param list windows: An increasing list of breakpoints between the windows + to compute the statistic in. + :param str mode: A string giving the "type" of the statistic to be computed + (defaults to "site"). + :param bool span_normalise: Whether to divide the result by the span of the + window (defaults to True). + :return: A ndarray with shape equal to (num windows, num statistics). + """ + if W.shape[0] != self.num_samples: + raise ValueError( + "First trait dimension must be equal to number of samples." + ) + return self.__k_way_weighted_stat( + self._ll_tree_sequence.genetic_relatedness_weighted, + 2, + W, + indexes=indexes, + windows=windows, + mode=mode, + span_normalise=span_normalise, + polarised=polarised, + ) + def trait_covariance(self, W, windows=None, mode="site", span_normalise=True): """ Computes the mean squared covariances between each of the columns of ``W``