Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

"matrix multiplication" statistic #1246

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions c/tests/test_stats.c
Original file line number Diff line number Diff line change
Expand Up @@ -1504,6 +1504,24 @@ test_paper_ex_genetic_relatedness(void)
tsk_treeseq_free(&ts);
}

static void
test_paper_ex_genetic_relatedness_weighted(void)
{
tsk_treeseq_t ts;
double weights[] = { 1.2, 0.1, 0.0, 0.0, 3.4, 5.0, 1.0, -1.0 };
tsk_id_t indexes[] = { 0, 0, 0, 1 };
double result[2];
int ret;

tsk_treeseq_from_text(&ts, 10, paper_ex_nodes, paper_ex_edges, NULL, paper_ex_sites,
paper_ex_mutations, paper_ex_individuals, NULL, 0);

ret = tsk_treeseq_genetic_relatedness_weighted(
&ts, 2, weights, 2, indexes, 0, NULL, result, TSK_STAT_SITE);
CU_ASSERT_EQUAL_FATAL(ret, 0);
tsk_treeseq_free(&ts);
}

static void
test_paper_ex_genetic_relatedness_errors(void)
{
Expand Down Expand Up @@ -2108,6 +2126,8 @@ main(int argc, char **argv)
{ "test_paper_ex_genetic_relatedness_errors",
test_paper_ex_genetic_relatedness_errors },
{ "test_paper_ex_genetic_relatedness", test_paper_ex_genetic_relatedness },
{ "test_paper_ex_genetic_relatedness_weighted",
test_paper_ex_genetic_relatedness_weighted },
{ "test_paper_ex_Y2_errors", test_paper_ex_Y2_errors },
{ "test_paper_ex_Y2", test_paper_ex_Y2 },
{ "test_paper_ex_f2_errors", test_paper_ex_f2_errors },
Expand Down
74 changes: 74 additions & 0 deletions c/tskit/trees.c
Original file line number Diff line number Diff line change
Expand Up @@ -2071,6 +2071,11 @@
const tsk_id_t *set_indexes;
} sample_count_stat_params_t;

typedef struct {
double *total_weights;
const tsk_id_t *index_tuples;
} indexed_weight_stat_params_t;

static int
tsk_treeseq_sample_count_stat(const tsk_treeseq_t *self, tsk_size_t num_sample_sets,
const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets,
Expand Down Expand Up @@ -3025,6 +3030,75 @@
return ret;
}

static int
genetic_relatedness_weighted_summary_func(tsk_size_t state_dim, const double *state,
tsk_size_t result_dim, double *result, void *params)
{
indexed_weight_stat_params_t args = *(indexed_weight_stat_params_t *) params;
const double *x = state;
tsk_id_t i, j;
tsk_size_t k;
double meanx, ni, nj;

meanx = state[state_dim - 1] / args.total_weights[state_dim - 1];
for (k = 0; k < result_dim; k++) {
i = args.index_tuples[2 * k];
j = args.index_tuples[2 * k + 1];
ni = args.total_weights[i];
nj = args.total_weights[j];
result[k] = (x[i] - ni * meanx) * (x[j] - nj * meanx) / 2;
}
return 0;
}

int
tsk_treeseq_genetic_relatedness_weighted(const tsk_treeseq_t *self,
tsk_size_t num_weights, const double *weights, tsk_size_t num_index_tuples,
const tsk_id_t *index_tuples, tsk_size_t num_windows, const double *windows,
double *result, tsk_flags_t options)
{
int ret = 0;
tsk_size_t num_samples = self->num_samples;
size_t j, k;
indexed_weight_stat_params_t args;
const double *row;
double *new_row;
double *total_weights = tsk_calloc((num_weights + 1), sizeof(*total_weights));
double *new_weights
= tsk_malloc((num_weights + 1) * num_samples * sizeof(*new_weights));

if (total_weights == NULL || new_weights == NULL) {
ret = TSK_ERR_NO_MEMORY;
goto out;

Check warning on line 3072 in c/tskit/trees.c

View check run for this annotation

Codecov / codecov/patch

c/tskit/trees.c#L3071-L3072

Added lines #L3071 - L3072 were not covered by tests
}

// Add a column of ones to W
for (j = 0; j < num_samples; j++) {
row = GET_2D_ROW(weights, num_weights, j);
new_row = GET_2D_ROW(new_weights, num_weights + 1, j);
for (k = 0; k < num_weights; k++) {
new_row[k] = row[k];
total_weights[k] += row[k];
}
new_row[num_weights] = 1.0;
}
total_weights[num_weights] = (double) num_samples;

args.total_weights = total_weights;
args.index_tuples = index_tuples;
ret = tsk_treeseq_general_stat(self, num_weights + 1, new_weights, num_index_tuples,
genetic_relatedness_weighted_summary_func, &args, num_windows, windows, options,
result);
if (ret != 0) {
goto out;

Check warning on line 3093 in c/tskit/trees.c

View check run for this annotation

Codecov / codecov/patch

c/tskit/trees.c#L3093

Added line #L3093 was not covered by tests
}

out:
tsk_safe_free(total_weights);
tsk_safe_free(new_weights);
return ret;
}

static int
Y2_summary_func(tsk_size_t TSK_UNUSED(state_dim), const double *state,
tsk_size_t result_dim, double *result, void *params)
Expand Down
11 changes: 11 additions & 0 deletions c/tskit/trees.h
Original file line number Diff line number Diff line change
Expand Up @@ -943,6 +943,17 @@ int tsk_treeseq_trait_linear_model(const tsk_treeseq_t *self, tsk_size_t num_wei
const double *weights, tsk_size_t num_covariates, const double *covariates,
tsk_size_t num_windows, const double *windows, tsk_flags_t options, double *result);

/* Two way weighted stats with covariates */

typedef int two_way_weighted_method(const tsk_treeseq_t *self, tsk_size_t num_weights,
const double *weights, tsk_size_t num_index_tuples, const tsk_id_t *index_tuples,
tsk_size_t num_windows, const double *windows, double *result, tsk_flags_t options);

int tsk_treeseq_genetic_relatedness_weighted(const tsk_treeseq_t *self,
tsk_size_t num_weights, const double *weights, tsk_size_t num_index_tuples,
const tsk_id_t *index_tuples, tsk_size_t num_windows, const double *windows,
double *result, tsk_flags_t options);

/* One way sample set stats */

typedef int one_way_sample_stat_method(const tsk_treeseq_t *self,
Expand Down
1 change: 1 addition & 0 deletions docs/python-api.md
Original file line number Diff line number Diff line change
Expand Up @@ -321,6 +321,7 @@ Single site
TreeSequence.Fst
TreeSequence.genealogical_nearest_neighbours
TreeSequence.genetic_relatedness
TreeSequence.genetic_relatedness_weighted
TreeSequence.general_stat
TreeSequence.segregating_sites
TreeSequence.sample_count_stat
Expand Down
7 changes: 7 additions & 0 deletions docs/stats.md
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ appears beside the listed method.
* Multi-way
* {meth}`~TreeSequence.divergence`
* {meth}`~TreeSequence.genetic_relatedness`
{meth}`~TreeSequence.genetic_relatedness_weighted`
* {meth}`~TreeSequence.f4`
{meth}`~TreeSequence.f3`
{meth}`~TreeSequence.f2`
Expand Down Expand Up @@ -593,6 +594,12 @@ and boolean expressions (e.g., {math}`(x > 0)`) are interpreted as 0/1.
where {math}`m = \frac{1}{n}\sum_{k=1}^n x_k` with {math}`n` the total number
of samples.

`genetic_relatedness_weighted`
: {math}`f(w_i, w_j, x_i, x_j) = \frac{1}{2}(x_i - w_i m) (x_j - w_j m)`,

where {math}`m = \frac{1}{n}\sum_{k=1}^n x_k` with {math}`n` the total number
of samples, and {math}`w_j = \sum_{k=1}^n W_kj` is the sum of the weights in the {math}`j`th column of the weight matrix.

`Y2`
: {math}`f(x_1, x_2) = \frac{x_1 (n_2 - x_2) (n_2 - x_2 - 1)}{n_1 n_2 (n_2 - 1)}`

Expand Down
99 changes: 99 additions & 0 deletions python/_tskitmodule.c
Original file line number Diff line number Diff line change
Expand Up @@ -9595,6 +9595,93 @@
return ret;
}

static PyObject *
TreeSequence_k_way_weighted_stat_method(TreeSequence *self, PyObject *args,
PyObject *kwds, npy_intp tuple_size, two_way_weighted_method *method)
{
PyObject *ret = NULL;
static char *kwlist[] = { "weights", "indexes", "windows", "mode", "span_normalise",
"polarised", NULL };
PyObject *weights = NULL;
PyObject *indexes = NULL;
PyObject *windows = NULL;
PyArrayObject *weights_array = NULL;
PyArrayObject *indexes_array = NULL;
PyArrayObject *windows_array = NULL;
PyArrayObject *result_array = NULL;
tsk_size_t num_windows, num_index_tuples;
npy_intp *w_shape, *shape;
tsk_flags_t options = 0;
char *mode = NULL;
int span_normalise = true;
int polarised = false;
int err;

if (TreeSequence_check_state(self) != 0) {
goto out;
}
if (!PyArg_ParseTupleAndKeywords(args, kwds, "OOO|sii", kwlist, &weights, &indexes,
&windows, &mode, &span_normalise, &polarised)) {
goto out;
}
if (parse_stats_mode(mode, &options) != 0) {
goto out;
}
if (span_normalise) {
options |= TSK_STAT_SPAN_NORMALISE;
}
if (polarised) {
options |= TSK_STAT_POLARISED;
}
if (parse_windows(windows, &windows_array, &num_windows) != 0) {
goto out;
}
weights_array = (PyArrayObject *) PyArray_FROMANY(
weights, NPY_FLOAT64, 2, 2, NPY_ARRAY_IN_ARRAY);
if (weights_array == NULL) {
goto out;

Check warning on line 9642 in python/_tskitmodule.c

View check run for this annotation

Codecov / codecov/patch

python/_tskitmodule.c#L9642

Added line #L9642 was not covered by tests
}
w_shape = PyArray_DIMS(weights_array);
if (w_shape[0] != (npy_intp) tsk_treeseq_get_num_samples(self->tree_sequence)) {
PyErr_SetString(PyExc_ValueError, "First dimension must be num_samples");
goto out;
}

indexes_array = (PyArrayObject *) PyArray_FROMANY(
indexes, NPY_INT32, 2, 2, NPY_ARRAY_IN_ARRAY);
if (indexes_array == NULL) {
goto out;
}
shape = PyArray_DIMS(indexes_array);
if (shape[0] < 1 || shape[1] != tuple_size) {
PyErr_Format(
PyExc_ValueError, "indexes must be a k x %d array.", (int) tuple_size);
goto out;
}
num_index_tuples = shape[0];

result_array = TreeSequence_allocate_results_array(
self, options, num_windows, num_index_tuples);
if (result_array == NULL) {
goto out;

Check warning on line 9666 in python/_tskitmodule.c

View check run for this annotation

Codecov / codecov/patch

python/_tskitmodule.c#L9666

Added line #L9666 was not covered by tests
}
err = method(self->tree_sequence, w_shape[1], PyArray_DATA(weights_array),
num_index_tuples, PyArray_DATA(indexes_array), num_windows,
PyArray_DATA(windows_array), PyArray_DATA(result_array), options);
if (err != 0) {
handle_library_error(err);
goto out;
}
ret = (PyObject *) result_array;
result_array = NULL;
out:
Py_XDECREF(weights_array);
Py_XDECREF(indexes_array);
Py_XDECREF(windows_array);
Py_XDECREF(result_array);
return ret;
}

static PyObject *
TreeSequence_divergence(TreeSequence *self, PyObject *args, PyObject *kwds)
{
Expand All @@ -9608,6 +9695,14 @@
self, args, kwds, 2, tsk_treeseq_genetic_relatedness);
}

static PyObject *
TreeSequence_genetic_relatedness_weighted(
TreeSequence *self, PyObject *args, PyObject *kwds)
{
return TreeSequence_k_way_weighted_stat_method(
self, args, kwds, 2, tsk_treeseq_genetic_relatedness_weighted);
}

static PyObject *
TreeSequence_Y2(TreeSequence *self, PyObject *args, PyObject *kwds)
{
Expand Down Expand Up @@ -10394,6 +10489,10 @@
.ml_meth = (PyCFunction) TreeSequence_genetic_relatedness,
.ml_flags = METH_VARARGS | METH_KEYWORDS,
.ml_doc = "Computes genetic relatedness between sample sets." },
{ .ml_name = "genetic_relatedness_weighted",
.ml_meth = (PyCFunction) TreeSequence_genetic_relatedness_weighted,
.ml_flags = METH_VARARGS | METH_KEYWORDS,
.ml_doc = "Computes genetic relatedness between weighted sums of samples." },
{ .ml_name = "Y1",
.ml_meth = (PyCFunction) TreeSequence_Y1,
.ml_flags = METH_VARARGS | METH_KEYWORDS,
Expand Down
Loading
Loading