From b7cb129fb96a72b3f0bedcc27d433201c964b122 Mon Sep 17 00:00:00 2001
From: peter
Date: Mon, 15 Mar 2021 07:13:51 -0700
Subject: [PATCH] matmul plumbing

---
 c/tests/test_stats.c  | 20 +++++++++
 c/tskit/trees.c       | 82 +++++++++++++++++++++++++++++++++++
 c/tskit/trees.h       | 11 +++++
 python/_tskitmodule.c | 99 +++++++++++++++++++++++++++++++++++++++++++
 python/tskit/trees.py | 83 ++++++++++++++++++++++++++++++++++++
 5 files changed, 295 insertions(+)

diff --git a/c/tests/test_stats.c b/c/tests/test_stats.c
index 26a36ddb46..ffaacede54 100644
--- a/c/tests/test_stats.c
+++ b/c/tests/test_stats.c
@@ -1352,6 +1352,24 @@ test_paper_ex_genetic_relatedness(void)
     tsk_treeseq_free(&ts);
 }
 
+static void
+test_paper_ex_genetic_relatedness_weighted(void)
+{
+    tsk_treeseq_t ts;
+    double weights[] = { 1.2, 0.1, 0.0, 0.0, 3.4, 5.0, 1.0, -1.0 };
+    tsk_id_t indexes[] = { 0, 0, 0, 1 }; /* the pairs (0, 0) and (0, 1) */
+    double result[2];
+    int ret;
+
+    tsk_treeseq_from_text(&ts, 10, paper_ex_nodes, paper_ex_edges, NULL, paper_ex_sites,
+        paper_ex_mutations, paper_ex_individuals, NULL, 0);
+    /* Smoke test: only the return code is checked, not the values in result. */
+    ret = tsk_treeseq_genetic_relatedness_weighted(
+        &ts, 2, weights, 2, indexes, 0, NULL, result, TSK_STAT_SITE);
+    CU_ASSERT_EQUAL_FATAL(ret, 0);
+    tsk_treeseq_free(&ts);
+}
+
 static void
 test_paper_ex_genetic_relatedness_errors(void)
 {
@@ -1712,6 +1730,8 @@ main(int argc, char **argv)
         { "test_paper_ex_genetic_relatedness_errors",
             test_paper_ex_genetic_relatedness_errors },
         { "test_paper_ex_genetic_relatedness", test_paper_ex_genetic_relatedness },
+        { "test_paper_ex_genetic_relatedness_weighted",
+            test_paper_ex_genetic_relatedness_weighted },
         { "test_paper_ex_Y2_errors", test_paper_ex_Y2_errors },
         { "test_paper_ex_Y2", test_paper_ex_Y2 },
         { "test_paper_ex_f2_errors", test_paper_ex_f2_errors },
diff --git a/c/tskit/trees.c b/c/tskit/trees.c
index 5338f4b6cf..d6f0728b6d 100644
--- a/c/tskit/trees.c
+++ b/c/tskit/trees.c
@@ -1851,6 +1851,11 @@ typedef struct {
     const tsk_id_t *set_indexes;
 } sample_count_stat_params_t;
 
+typedef struct {
+    double *total_weights; /* sum over samples of each weight column */
+    const tsk_id_t *index_tuples; /* flattened (i, j) pairs of column indexes */
+} indexed_weight_stat_params_t;
+
 static int
 tsk_treeseq_sample_count_stat(const tsk_treeseq_t *self, tsk_size_t num_sample_sets,
     const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets,
@@ -2799,6 +2804,83 @@ tsk_treeseq_genetic_relatedness(const tsk_treeseq_t *self, tsk_size_t num_sample
     return ret;
 }
 
+/* Summary function for tsk_treeseq_genetic_relatedness_weighted. The caller
+ * appends a column of ones to the weights, so the last entry of state divided
+ * by the last total weight is the mean that each column is centred on. */
+static int
+genetic_relatedness_weighted_summary_func(size_t state_dim, const double *state,
+    size_t result_dim, double *result, void *params)
+{
+    indexed_weight_stat_params_t args = *(indexed_weight_stat_params_t *) params;
+    const double *x = state;
+    tsk_id_t i, j;
+    size_t k;
+    double meanx, ni, nj;
+
+    meanx = state[state_dim - 1] / args.total_weights[state_dim - 1];
+    for (k = 0; k < result_dim; k++) {
+        i = args.index_tuples[2 * k];
+        j = args.index_tuples[2 * k + 1];
+        ni = args.total_weights[i];
+        nj = args.total_weights[j];
+        result[k] = (x[i] - ni * meanx) * (x[j] - nj * meanx) / 2;
+    }
+    return 0;
+}
+
+/* Runs tsk_treeseq_general_stat on weights plus an appended column of ones.
+ * Returns TSK_ERR_NO_MEMORY on allocation failure, else that call's result. */
+int
+tsk_treeseq_genetic_relatedness_weighted(const tsk_treeseq_t *self,
+    tsk_size_t num_weights, const double *weights, tsk_size_t num_index_tuples,
+    const tsk_id_t *index_tuples, tsk_size_t num_windows, const double *windows,
+    double *result, tsk_flags_t options)
+{
+    int ret = 0;
+    tsk_size_t num_samples = self->num_samples;
+    size_t j, k;
+    indexed_weight_stat_params_t args;
+    const double *row;
+    double *new_row;
+    /* Column totals are built up with +=, so they must start out as zero. */
+    double *total_weights = calloc(num_weights + 1, sizeof(*total_weights));
+    double *new_weights = malloc((num_weights + 1) * num_samples * sizeof(*new_weights));
+
+    if (total_weights == NULL || new_weights == NULL) {
+        ret = TSK_ERR_NO_MEMORY;
+        goto out;
+    }
+
+    /* Copy each row of weights into new_weights with a trailing 1.0 appended,
+     * adding the widened row into the per-column totals as we go. */
+    for (k = 0; k < num_samples; k++) {
+        row = GET_2D_ROW(weights, num_weights, k);
+        new_row = GET_2D_ROW(new_weights, num_weights + 1, k);
+        for (j = 0; j < num_weights; j++) {
+            new_row[j] = row[j];
+        }
+        new_row[num_weights] = 1.0;
+        for (j = 0; j < num_weights + 1; j++) {
+            total_weights[j] += new_row[j];
+        }
+    }
+
+    /* TODO: sanity check indexes. Entries of index_tuples are used unchecked to
+     * index total_weights and the state vector in the summary function. */
+
+    args.total_weights = total_weights;
+    args.index_tuples = index_tuples;
+
+    ret = tsk_treeseq_general_stat(self, num_weights + 1, new_weights, num_index_tuples,
+        genetic_relatedness_weighted_summary_func, &args, num_windows, windows, result,
+        options);
+
+out:
+    tsk_safe_free(total_weights);
+    tsk_safe_free(new_weights);
+    return ret;
+}
+
 static int
 Y2_summary_func(size_t TSK_UNUSED(state_dim), const double *state, size_t result_dim,
     double *result, void *params)
diff --git a/c/tskit/trees.h b/c/tskit/trees.h
index d9fb3f12f2..fc79919917 100644
--- a/c/tskit/trees.h
+++ b/c/tskit/trees.h
@@ -317,6 +317,17 @@ int tsk_treeseq_trait_linear_model(const tsk_treeseq_t *self, tsk_size_t num_wei
     const double *weights, tsk_size_t num_covariates, const double *covariates,
     tsk_size_t num_windows, const double *windows, double *result, tsk_flags_t options);
 
+/* Two way weighted stats */
+
+typedef int two_way_weighted_method(const tsk_treeseq_t *self, tsk_size_t num_weights,
+    const double *weights, tsk_size_t num_index_tuples, const tsk_id_t *index_tuples,
+    tsk_size_t num_windows, const double *windows, double *result, tsk_flags_t options);
+
+int tsk_treeseq_genetic_relatedness_weighted(const tsk_treeseq_t *self,
+    tsk_size_t num_weights, const double *weights, tsk_size_t num_index_tuples,
+    const tsk_id_t *index_tuples, tsk_size_t num_windows, const double *windows,
+    double *result, tsk_flags_t options);
+
 /* One way sample set stats */
 
 typedef int one_way_sample_stat_method(const tsk_treeseq_t *self,
diff --git a/python/_tskitmodule.c b/python/_tskitmodule.c
index 5e61f37e4d..23d2957c64 100644
--- a/python/_tskitmodule.c
+++ b/python/_tskitmodule.c
@@ -7453,6 +7453,93 @@ TreeSequence_k_way_stat_method(TreeSequence *self, PyObject *args, PyObject
*kwd
     return ret;
 }
 
+static PyObject *
+TreeSequence_k_way_weighted_stat_method(TreeSequence *self, PyObject *args,
+    PyObject *kwds, npy_intp tuple_size, two_way_weighted_method *method)
+{
+    PyObject *ret = NULL;
+    static char *kwlist[] = { "weights", "indexes", "windows", "mode", "span_normalise",
+        "polarised", NULL };
+    PyObject *weights = NULL;
+    PyObject *indexes = NULL;
+    PyObject *windows = NULL;
+    PyArrayObject *weights_array = NULL;
+    PyArrayObject *indexes_array = NULL;
+    PyArrayObject *windows_array = NULL;
+    PyArrayObject *result_array = NULL;
+    tsk_size_t num_windows, num_index_tuples;
+    npy_intp *w_shape, *shape;
+    tsk_flags_t options = 0;
+    char *mode = NULL;
+    int span_normalise = true;
+    int polarised = false;
+    int err;
+    /* Parse the arguments; every failure jumps to out with ret still NULL. */
+    if (TreeSequence_check_state(self) != 0) {
+        goto out;
+    }
+    if (!PyArg_ParseTupleAndKeywords(args, kwds, "OOO|sii", kwlist, &weights, &indexes,
+            &windows, &mode, &span_normalise, &polarised)) {
+        goto out;
+    }
+    if (parse_stats_mode(mode, &options) != 0) {
+        goto out;
+    }
+    if (span_normalise) {
+        options |= TSK_STAT_SPAN_NORMALISE;
+    }
+    if (polarised) {
+        options |= TSK_STAT_POLARISED;
+    }
+    if (parse_windows(windows, &windows_array, &num_windows) != 0) {
+        goto out;
+    }
+    weights_array = (PyArrayObject *) PyArray_FROMANY(
+        weights, NPY_FLOAT64, 2, 2, NPY_ARRAY_IN_ARRAY); /* exactly 2 dims */
+    if (weights_array == NULL) {
+        goto out;
+    }
+    w_shape = PyArray_DIMS(weights_array);
+    if (w_shape[0] != tsk_treeseq_get_num_samples(self->tree_sequence)) {
+        PyErr_SetString(PyExc_ValueError, "First dimension must be num_samples");
+        goto out;
+    }
+    /* indexes: 2D int32 array, at least one row, tuple_size entries per row. */
+    indexes_array = (PyArrayObject *) PyArray_FROMANY(
+        indexes, NPY_INT32, 2, 2, NPY_ARRAY_IN_ARRAY);
+    if (indexes_array == NULL) {
+        goto out;
+    }
+    shape = PyArray_DIMS(indexes_array);
+    if (shape[0] < 1 || shape[1] != tuple_size) {
+        PyErr_Format(
+            PyExc_ValueError, "indexes must be a k x %d array.", (int) tuple_size);
+        goto out;
+    }
+    num_index_tuples = shape[0];
+    /* Get the output array, then hand the raw array buffers to method. */
+    result_array = TreeSequence_allocate_results_array(
+        self, options, num_windows, num_index_tuples);
+    if (result_array == NULL) {
+        goto out;
+    }
+    err = method(self->tree_sequence, w_shape[1], PyArray_DATA(weights_array),
+        num_index_tuples, PyArray_DATA(indexes_array), num_windows,
+        PyArray_DATA(windows_array), PyArray_DATA(result_array), options);
+    if (err != 0) {
+        handle_library_error(err);
+        goto out;
+    }
+    ret = (PyObject *) result_array;
+    result_array = NULL; /* now owned by ret, so the XDECREF below is a no-op */
+out:
+    Py_XDECREF(weights_array);
+    Py_XDECREF(indexes_array);
+    Py_XDECREF(windows_array);
+    Py_XDECREF(result_array);
+    return ret;
+}
+
 static PyObject *
 TreeSequence_divergence(TreeSequence *self, PyObject *args, PyObject *kwds)
 {
@@ -7466,6 +7553,14 @@ TreeSequence_genetic_relatedness(TreeSequence *self, PyObject *args, PyObject *k
         self, args, kwds, 2, tsk_treeseq_genetic_relatedness);
 }
 
+static PyObject *
+TreeSequence_genetic_relatedness_weighted(
+    TreeSequence *self, PyObject *args, PyObject *kwds)
+{
+    return TreeSequence_k_way_weighted_stat_method(
+        self, args, kwds, 2, tsk_treeseq_genetic_relatedness_weighted);
+}
+
 static PyObject *
 TreeSequence_Y2(TreeSequence *self, PyObject *args, PyObject *kwds)
 {
@@ -7787,6 +7882,10 @@ static PyMethodDef TreeSequence_methods[] = {
         .ml_meth = (PyCFunction) TreeSequence_genetic_relatedness,
         .ml_flags = METH_VARARGS | METH_KEYWORDS,
         .ml_doc = "Computes genetic relatedness between sample sets." },
+    { .ml_name = "genetic_relatedness_weighted",
+        .ml_meth = (PyCFunction) TreeSequence_genetic_relatedness_weighted,
+        .ml_flags = METH_VARARGS | METH_KEYWORDS,
+        .ml_doc = "Computes genetic relatedness between weighted sums of samples."
},
     { .ml_name = "Y1",
         .ml_meth = (PyCFunction) TreeSequence_Y1,
         .ml_flags = METH_VARARGS | METH_KEYWORDS,
diff --git a/python/tskit/trees.py b/python/tskit/trees.py
index 78d1193455..47b1ed9898 100644
--- a/python/tskit/trees.py
+++ b/python/tskit/trees.py
@@ -5751,6 +5751,47 @@ def __k_way_sample_set_stat(
             stat = stat.reshape(stat.shape[:-1])
         return stat
 
+    def __k_way_weighted_stat(
+        self,
+        ll_method,
+        k,
+        W,
+        indexes=None,
+        windows=None,
+        mode=None,
+        span_normalise=True,
+        polarised=False,
+    ):
+        """Coerce ``indexes`` to an (n, k) int32 array and run the windowed stat."""
+        if indexes is None:
+            if W.shape[1] != k:
+                raise ValueError(
+                    "Must specify indexes if there are not exactly {} columns "
+                    "in W.".format(k)
+                )
+            indexes = np.arange(k, dtype=np.int32)
+        indexes = util.safe_np_int_cast(indexes, np.int32)
+        drop_dimension = len(indexes.shape) == 1  # one tuple, not a list of them
+        if drop_dimension:
+            indexes = indexes.reshape((1, indexes.shape[0]))
+        if len(indexes.shape) != 2 or indexes.shape[1] != k:
+            raise ValueError(
+                "Indexes must be convertable to a 2D numpy array with {} "
+                "columns".format(k)
+            )
+        stat = self.__run_windowed_stat(
+            windows,
+            ll_method,
+            W,
+            indexes,
+            mode=mode,
+            span_normalise=span_normalise,
+            polarised=polarised,
+        )
+        if drop_dimension:
+            stat = stat.reshape(stat.shape[:-1])
+        return stat
+
     ############################################
     # Statistics definitions
     ############################################
@@ -6012,6 +6053,48 @@ def genetic_relatedness(
 
         return out
 
+    def genetic_relatedness_weighted(
+        self,
+        W,
+        indexes=None,
+        windows=None,
+        mode="site",
+        span_normalise=True,
+        polarised=False,
+    ):
+        r"""
+        Computes weighted genetic relatedness: if the k-th pair of indices is (i, j)
+        then the k-th column of output will be :math:`\sum_{a,b} W_{ai} W_{bj} C_{ab}`,
+        where :math:`W` is the matrix of weights, and :math:`C_{ab}` is the
+        :meth:`.genetic_relatedness` between sample a and sample b.
+
+        :param numpy.ndarray W: An array of values with one row for each sample and one
+            column for each set of weights.
+        :param list indexes: A list of 2-tuples, or None if W has exactly two columns.
+        :param list windows: An increasing list of breakpoints between the windows
+            to compute the statistic in.
+        :param str mode: A string giving the "type" of the statistic to be computed
+            (defaults to "site").
+        :param bool span_normalise: Whether to divide the result by the span of the
+            window (defaults to True).
+        :param bool polarised: Passed on to the low-level statistic (defaults to False).
+        :return: A ndarray with shape equal to (num windows, num statistics).
+        """
+        if W.shape[0] != self.num_samples:
+            raise ValueError(
+                "First trait dimension must be equal to number of samples."
+            )
+        return self.__k_way_weighted_stat(
+            self._ll_tree_sequence.genetic_relatedness_weighted,
+            2,
+            W,
+            indexes=indexes,
+            windows=windows,
+            mode=mode,
+            span_normalise=span_normalise,
+            polarised=polarised,
+        )
+
     def trait_covariance(self, W, windows=None, mode="site", span_normalise=True):
         """
         Computes the mean squared covariances between each of the columns of ``W``