Skip to content

Commit

Permalink
matmul plumbing
Browse files Browse the repository at this point in the history
  • Loading branch information
petrelharp committed Mar 16, 2021
1 parent fbba717 commit b7cb129
Show file tree
Hide file tree
Showing 5 changed files with 295 additions and 0 deletions.
20 changes: 20 additions & 0 deletions c/tests/test_stats.c
Original file line number Diff line number Diff line change
Expand Up @@ -1352,6 +1352,24 @@ test_paper_ex_genetic_relatedness(void)
tsk_treeseq_free(&ts);
}

/* Smoke test: run the weighted genetic relatedness statistic in site mode on
 * the paper example tree sequence and require a zero return code. The
 * numerical contents of the result buffer are not checked here. */
static void
test_paper_ex_genetic_relatedness_weighted(void)
{
    enum { NUM_WEIGHTS = 2, NUM_INDEX_TUPLES = 2 };
    int ret;
    tsk_treeseq_t ts;
    double result[NUM_INDEX_TUPLES];
    /* Two index pairs over the weight columns: (0, 0) and (0, 1). */
    tsk_id_t indexes[] = { 0, 0, 0, 1 };
    /* Flattened weight matrix with NUM_WEIGHTS columns per row
     * (presumably one row per sample of the paper example). */
    double weights[] = { 1.2, 0.1, 0.0, 0.0, 3.4, 5.0, 1.0, -1.0 };

    tsk_treeseq_from_text(&ts, 10, paper_ex_nodes, paper_ex_edges, NULL, paper_ex_sites,
        paper_ex_mutations, paper_ex_individuals, NULL, 0);

    /* No windows are supplied (0, NULL). */
    ret = tsk_treeseq_genetic_relatedness_weighted(&ts, NUM_WEIGHTS, weights,
        NUM_INDEX_TUPLES, indexes, 0, NULL, result, TSK_STAT_SITE);
    CU_ASSERT_EQUAL_FATAL(ret, 0);

    tsk_treeseq_free(&ts);
}

static void
test_paper_ex_genetic_relatedness_errors(void)
{
Expand Down Expand Up @@ -1712,6 +1730,8 @@ main(int argc, char **argv)
{ "test_paper_ex_genetic_relatedness_errors",
test_paper_ex_genetic_relatedness_errors },
{ "test_paper_ex_genetic_relatedness", test_paper_ex_genetic_relatedness },
{ "test_paper_ex_genetic_relatedness_weighted",
test_paper_ex_genetic_relatedness_weighted },
{ "test_paper_ex_Y2_errors", test_paper_ex_Y2_errors },
{ "test_paper_ex_Y2", test_paper_ex_Y2 },
{ "test_paper_ex_f2_errors", test_paper_ex_f2_errors },
Expand Down
82 changes: 82 additions & 0 deletions c/tskit/trees.c
Original file line number Diff line number Diff line change
Expand Up @@ -1851,6 +1851,11 @@ typedef struct {
const tsk_id_t *set_indexes;
} sample_count_stat_params_t;

/* Parameters passed (through the void *params argument) to summary functions
 * of weighted statistics that are evaluated on pairs of weight columns. */
typedef struct {
/* Per-column totals of the weight matrix; indexed by weight column. */
double *total_weights;
/* Flattened index pairs: the k-th pair occupies entries 2k and 2k + 1. */
const tsk_id_t *index_tuples;
} indexed_weight_stat_params_t;

static int
tsk_treeseq_sample_count_stat(const tsk_treeseq_t *self, tsk_size_t num_sample_sets,
const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets,
Expand Down Expand Up @@ -2799,6 +2804,83 @@ tsk_treeseq_genetic_relatedness(const tsk_treeseq_t *self, tsk_size_t num_sample
return ret;
}

/*
 * Summary function for the weighted genetic relatedness statistic.
 *
 * params must point to an indexed_weight_stat_params_t. The last entry of
 * state divided by the last entry of total_weights gives a mean value; for
 * each index pair (a, b) the result is the product of the two centred state
 * values, (state[a] - total[a] * mean) * (state[b] - total[b] * mean), halved.
 *
 * Indexes are used as given; they are not range checked here.
 * Always returns 0.
 */
static int
genetic_relatedness_weighted_summary_func(size_t state_dim, const double *state,
    size_t result_dim, double *result, void *params)
{
    const indexed_weight_stat_params_t *stat_params = params;
    const double *totals = stat_params->total_weights;
    const tsk_id_t *pairs = stat_params->index_tuples;
    const size_t last = state_dim - 1;
    const double mean = state[last] / totals[last];
    size_t tuple;
    tsk_id_t a, b;
    double centred_a, centred_b;

    for (tuple = 0; tuple < result_dim; tuple++) {
        a = pairs[2 * tuple];
        b = pairs[2 * tuple + 1];
        centred_a = state[a] - totals[a] * mean;
        centred_b = state[b] - totals[b] * mean;
        result[tuple] = centred_a * centred_b / 2;
    }
    return 0;
}

/*
 * Computes the weighted genetic relatedness statistic for pairs of weight
 * columns.
 *
 * weights is read as a row-major matrix with self->num_samples rows and
 * num_weights columns. A copy with one extra trailing column of ones is built
 * and passed to tsk_treeseq_general_stat together with the per-column totals,
 * which genetic_relatedness_weighted_summary_func uses to centre the state.
 *
 * index_tuples holds 2 * num_index_tuples column indexes; each consecutive
 * pair selects the two weight columns of one output statistic.
 *
 * Returns 0 on success, TSK_ERR_NO_MEMORY if the temporary buffers cannot be
 * allocated, or the error code returned by tsk_treeseq_general_stat.
 */
int
tsk_treeseq_genetic_relatedness_weighted(const tsk_treeseq_t *self,
    tsk_size_t num_weights, const double *weights, tsk_size_t num_index_tuples,
    const tsk_id_t *index_tuples, tsk_size_t num_windows, const double *windows,
    double *result, tsk_flags_t options)
{
    int ret = 0;
    const tsk_size_t num_samples = self->num_samples;
    const tsk_size_t num_columns = num_weights + 1;
    size_t j, k;
    indexed_weight_stat_params_t args;
    const double *row;
    double *new_row;
    /* The totals are accumulated with += below, so they must start at zero:
     * calloc rather than malloc, whose memory is indeterminate. */
    double *total_weights = calloc(num_columns, sizeof(*total_weights));
    double *new_weights
        = malloc((size_t) num_columns * num_samples * sizeof(*new_weights));

    if (total_weights == NULL || new_weights == NULL) {
        ret = TSK_ERR_NO_MEMORY;
        goto out;
    }

    /* Copy W, append a column of ones, and accumulate the column totals. */
    for (k = 0; k < num_samples; k++) {
        row = GET_2D_ROW(weights, num_weights, k);
        new_row = GET_2D_ROW(new_weights, num_columns, k);
        for (j = 0; j < num_weights; j++) {
            new_row[j] = row[j];
            total_weights[j] += row[j];
        }
        new_row[num_weights] = 1.0;
        total_weights[num_weights] += 1.0;
    }

    /* TODO: sanity check indexes. The summary function uses them unchecked
     * to index both the state vector and total_weights. */

    args.total_weights = total_weights;
    args.index_tuples = index_tuples;

    ret = tsk_treeseq_general_stat(self, num_columns, new_weights, num_index_tuples,
        genetic_relatedness_weighted_summary_func, &args, num_windows, windows, result,
        options);

out:
    tsk_safe_free(total_weights);
    tsk_safe_free(new_weights);
    return ret;
}

static int
Y2_summary_func(size_t TSK_UNUSED(state_dim), const double *state, size_t result_dim,
double *result, void *params)
Expand Down
11 changes: 11 additions & 0 deletions c/tskit/trees.h
Original file line number Diff line number Diff line change
Expand Up @@ -317,6 +317,17 @@ int tsk_treeseq_trait_linear_model(const tsk_treeseq_t *self, tsk_size_t num_wei
const double *weights, tsk_size_t num_covariates, const double *covariates,
tsk_size_t num_windows, const double *windows, double *result, tsk_flags_t options);

/* Two way weighted stats */

/* Function type shared by weighted statistics that take a weight matrix and
 * a flattened list of index tuples selecting weight columns. The Python
 * module receives such a function as a callback. */
typedef int two_way_weighted_method(const tsk_treeseq_t *self, tsk_size_t num_weights,
const double *weights, tsk_size_t num_index_tuples, const tsk_id_t *index_tuples,
tsk_size_t num_windows, const double *windows, double *result, tsk_flags_t options);

/* Weighted genetic relatedness between pairs of weight columns. weights has
 * one row per sample and num_weights columns; index_tuples holds
 * 2 * num_index_tuples column indexes, read as consecutive pairs. Returns 0
 * on success or a nonzero error code. */
int tsk_treeseq_genetic_relatedness_weighted(const tsk_treeseq_t *self,
tsk_size_t num_weights, const double *weights, tsk_size_t num_index_tuples,
const tsk_id_t *index_tuples, tsk_size_t num_windows, const double *windows,
double *result, tsk_flags_t options);

/* One way sample set stats */

typedef int one_way_sample_stat_method(const tsk_treeseq_t *self,
Expand Down
99 changes: 99 additions & 0 deletions python/_tskitmodule.c
Original file line number Diff line number Diff line change
Expand Up @@ -7453,6 +7453,93 @@ TreeSequence_k_way_stat_method(TreeSequence *self, PyObject *args, PyObject *kwd
return ret;
}

/*
 * Shared implementation for Python methods computing weighted statistics
 * over tuples of weight columns.
 *
 * Parses (weights, indexes, windows, mode=, span_normalise=, polarised=),
 * converts weights to a 2D float64 array and indexes to a 2D int32 array,
 * checks that weights has one row per sample and that indexes has at least
 * one row and exactly tuple_size columns, then calls method and returns the
 * freshly allocated result array. Returns NULL with a Python exception set
 * on any failure.
 */
static PyObject *
TreeSequence_k_way_weighted_stat_method(TreeSequence *self, PyObject *args,
    PyObject *kwds, npy_intp tuple_size, two_way_weighted_method *method)
{
    static char *kwlist[] = { "weights", "indexes", "windows", "mode", "span_normalise",
        "polarised", NULL };
    PyObject *ret = NULL;
    PyObject *py_weights = NULL;
    PyObject *py_indexes = NULL;
    PyObject *py_windows = NULL;
    PyArrayObject *w_arr = NULL;
    PyArrayObject *idx_arr = NULL;
    PyArrayObject *win_arr = NULL;
    PyArrayObject *out_arr = NULL;
    npy_intp *weight_dims, *index_dims;
    tsk_size_t n_windows, n_tuples;
    tsk_flags_t flags = 0;
    char *mode_name = NULL;
    int want_span_normalise = true;
    int want_polarised = false;
    int status;

    if (TreeSequence_check_state(self) != 0) {
        goto out;
    }
    if (!PyArg_ParseTupleAndKeywords(args, kwds, "OOO|sii", kwlist, &py_weights,
            &py_indexes, &py_windows, &mode_name, &want_span_normalise,
            &want_polarised)) {
        goto out;
    }

    /* Translate the keyword arguments into library option flags. */
    if (parse_stats_mode(mode_name, &flags) != 0) {
        goto out;
    }
    if (want_span_normalise) {
        flags |= TSK_STAT_SPAN_NORMALISE;
    }
    if (want_polarised) {
        flags |= TSK_STAT_POLARISED;
    }
    if (parse_windows(py_windows, &win_arr, &n_windows) != 0) {
        goto out;
    }

    /* Weights: 2D float64 with one row per sample. */
    w_arr = (PyArrayObject *) PyArray_FROMANY(
        py_weights, NPY_FLOAT64, 2, 2, NPY_ARRAY_IN_ARRAY);
    if (w_arr == NULL) {
        goto out;
    }
    weight_dims = PyArray_DIMS(w_arr);
    if (weight_dims[0] != tsk_treeseq_get_num_samples(self->tree_sequence)) {
        PyErr_SetString(PyExc_ValueError, "First dimension must be num_samples");
        goto out;
    }

    /* Indexes: 2D int32 with tuple_size columns and at least one row. */
    idx_arr = (PyArrayObject *) PyArray_FROMANY(
        py_indexes, NPY_INT32, 2, 2, NPY_ARRAY_IN_ARRAY);
    if (idx_arr == NULL) {
        goto out;
    }
    index_dims = PyArray_DIMS(idx_arr);
    if (index_dims[0] < 1 || index_dims[1] != tuple_size) {
        PyErr_Format(
            PyExc_ValueError, "indexes must be a k x %d array.", (int) tuple_size);
        goto out;
    }
    n_tuples = index_dims[0];

    out_arr = TreeSequence_allocate_results_array(self, flags, n_windows, n_tuples);
    if (out_arr == NULL) {
        goto out;
    }
    status = method(self->tree_sequence, weight_dims[1], PyArray_DATA(w_arr), n_tuples,
        PyArray_DATA(idx_arr), n_windows, PyArray_DATA(win_arr), PyArray_DATA(out_arr),
        flags);
    if (status != 0) {
        handle_library_error(status);
        goto out;
    }
    /* Hand the result to the caller; clear the local so it is not released
     * by the cleanup below. */
    ret = (PyObject *) out_arr;
    out_arr = NULL;
out:
    Py_XDECREF(w_arr);
    Py_XDECREF(idx_arr);
    Py_XDECREF(win_arr);
    Py_XDECREF(out_arr);
    return ret;
}

static PyObject *
TreeSequence_divergence(TreeSequence *self, PyObject *args, PyObject *kwds)
{
Expand All @@ -7466,6 +7553,14 @@ TreeSequence_genetic_relatedness(TreeSequence *self, PyObject *args, PyObject *k
self, args, kwds, 2, tsk_treeseq_genetic_relatedness);
}

/* Python method: weighted genetic relatedness. Delegates to the shared
 * weighted-statistic driver with index tuples of size 2. */
static PyObject *
TreeSequence_genetic_relatedness_weighted(
    TreeSequence *self, PyObject *args, PyObject *kwds)
{
    PyObject *stat = TreeSequence_k_way_weighted_stat_method(
        self, args, kwds, 2, tsk_treeseq_genetic_relatedness_weighted);

    return stat;
}

static PyObject *
TreeSequence_Y2(TreeSequence *self, PyObject *args, PyObject *kwds)
{
Expand Down Expand Up @@ -7787,6 +7882,10 @@ static PyMethodDef TreeSequence_methods[] = {
.ml_meth = (PyCFunction) TreeSequence_genetic_relatedness,
.ml_flags = METH_VARARGS | METH_KEYWORDS,
.ml_doc = "Computes genetic relatedness between sample sets." },
{ .ml_name = "genetic_relatedness_weighted",
.ml_meth = (PyCFunction) TreeSequence_genetic_relatedness_weighted,
.ml_flags = METH_VARARGS | METH_KEYWORDS,
.ml_doc = "Computes genetic relatedness between weighted sums of samples." },
{ .ml_name = "Y1",
.ml_meth = (PyCFunction) TreeSequence_Y1,
.ml_flags = METH_VARARGS | METH_KEYWORDS,
Expand Down
83 changes: 83 additions & 0 deletions python/tskit/trees.py
Original file line number Diff line number Diff line change
Expand Up @@ -5751,6 +5751,47 @@ def __k_way_sample_set_stat(
stat = stat.reshape(stat.shape[:-1])
return stat

def __k_way_weighted_stat(
    self,
    ll_method,
    k,
    W,
    indexes=None,
    windows=None,
    mode=None,
    span_normalise=True,
    polarised=False,
):
    """
    Runs a low-level weighted statistic that is evaluated on k-tuples of
    columns of the weight matrix ``W``.

    :param ll_method: The low-level method to run; it is called through the
        windowed-stat helper with ``W`` and the index array.
    :param int k: The number of columns of ``W`` combined in each statistic.
    :param numpy.ndarray W: The weight matrix.
    :param indexes: Either ``None`` (only allowed if ``W`` has exactly ``k``
        columns, in which case ``(0, ..., k - 1)`` is used), a single
        k-tuple, or a list of k-tuples. Converted to an int32 array.
    :param windows: Passed through to the windowed-stat helper.
    :param str mode: Passed through to the low-level method.
    :param bool span_normalise: Passed through to the low-level method.
    :param bool polarised: Passed through to the low-level method.
    :return: The computed statistic. If ``indexes`` was ``None`` or
        one-dimensional, the trailing dimension of the result is dropped.
    :raises ValueError: If ``indexes`` is ``None`` and ``W`` does not have
        exactly ``k`` columns, or if ``indexes`` is not a 1D or 2D array
        with ``k`` entries per tuple.
    """
    if indexes is None:
        if W.shape[1] != k:
            raise ValueError(
                "Must specify indexes if there are not exactly {} columns "
                "in W.".format(k)
            )
        indexes = np.arange(k, dtype=np.int32)
    drop_dimension = False
    indexes = util.safe_np_int_cast(indexes, np.int32)
    if len(indexes.shape) == 1:
        # A single tuple: run it as a one-row array and strip the extra
        # output dimension again afterwards.
        indexes = indexes.reshape((1, indexes.shape[0]))
        drop_dimension = True
    if len(indexes.shape) != 2 or indexes.shape[1] != k:
        raise ValueError(
            "Indexes must be convertable to a 2D numpy array with {} "
            "columns".format(k)
        )
    stat = self.__run_windowed_stat(
        windows,
        ll_method,
        W,
        indexes,
        mode=mode,
        span_normalise=span_normalise,
        polarised=polarised,
    )
    if drop_dimension:
        stat = stat.reshape(stat.shape[:-1])
    return stat

############################################
# Statistics definitions
############################################
Expand Down Expand Up @@ -6012,6 +6053,48 @@ def genetic_relatedness(

return out

def genetic_relatedness_weighted(
    self,
    W,
    indexes=None,
    windows=None,
    mode="site",
    span_normalise=True,
    polarised=False,
):
    r"""
    Computes weighted genetic relatedness: if the k-th pair of indices is (i, j)
    then the k-th column of output will be
    :math:`\sum_{a,b} W_{ai} W_{bj} C_{ab}`,
    where :math:`W` is the matrix of weights, and :math:`C_{ab}` is the
    :meth:`.genetic_relatedness` between sample a and sample b.

    :param numpy.ndarray W: An array of values with one row for each sample and one
        column for each set of weights.
    :param list indexes: A list of 2-tuples, or None.
    :param list windows: An increasing list of breakpoints between the windows
        to compute the statistic in.
    :param str mode: A string giving the "type" of the statistic to be computed
        (defaults to "site").
    :param bool span_normalise: Whether to divide the result by the span of the
        window (defaults to True).
    :param bool polarised: Passed through to the low-level statistic
        (defaults to False).
    :return: A ndarray with shape equal to (num windows, num statistics).
    :raises ValueError: If the number of rows of ``W`` is not equal to the
        number of samples.
    """
    # W must have exactly one row of weights per sample.
    if W.shape[0] != self.num_samples:
        raise ValueError(
            "First trait dimension must be equal to number of samples."
        )
    # Relatedness is a two-way statistic, so index tuples have length 2.
    return self.__k_way_weighted_stat(
        self._ll_tree_sequence.genetic_relatedness_weighted,
        2,
        W,
        indexes=indexes,
        windows=windows,
        mode=mode,
        span_normalise=span_normalise,
        polarised=polarised,
    )

def trait_covariance(self, W, windows=None, mode="site", span_normalise=True):
"""
Computes the mean squared covariances between each of the columns of ``W``
Expand Down

0 comments on commit b7cb129

Please sign in to comment.