Skip to content

Commit

Permalink
updates?
Browse files Browse the repository at this point in the history
  • Loading branch information
petrelharp committed Jan 23, 2023
1 parent 66ca353 commit c00d375
Show file tree
Hide file tree
Showing 6 changed files with 376 additions and 0 deletions.
20 changes: 20 additions & 0 deletions c/tests/test_stats.c
Original file line number Diff line number Diff line change
Expand Up @@ -1340,6 +1340,24 @@ test_paper_ex_genetic_relatedness(void)
tsk_treeseq_free(&ts);
}

static void
test_paper_ex_genetic_relatedness_weighted(void)
{
tsk_treeseq_t ts;
double weights[] = { 1.2, 0.1, 0.0, 0.0, 3.4, 5.0, 1.0, -1.0 };
tsk_id_t indexes[] = { 0, 0, 0, 1 };
double result[2];
int ret;

tsk_treeseq_from_text(&ts, 10, paper_ex_nodes, paper_ex_edges, NULL, paper_ex_sites,
paper_ex_mutations, paper_ex_individuals, NULL, 0);

ret = tsk_treeseq_genetic_relatedness_weighted(
&ts, 2, weights, 2, indexes, 0, NULL, result, TSK_STAT_SITE);
CU_ASSERT_EQUAL_FATAL(ret, 0);
tsk_treeseq_free(&ts);
}

static void
test_paper_ex_genetic_relatedness_errors(void)
{
Expand Down Expand Up @@ -1773,6 +1791,8 @@ main(int argc, char **argv)
{ "test_paper_ex_genetic_relatedness_errors",
test_paper_ex_genetic_relatedness_errors },
{ "test_paper_ex_genetic_relatedness", test_paper_ex_genetic_relatedness },
{ "test_paper_ex_genetic_relatedness_weighted",
test_paper_ex_genetic_relatedness_weighted },
{ "test_paper_ex_Y2_errors", test_paper_ex_Y2_errors },
{ "test_paper_ex_Y2", test_paper_ex_Y2 },
{ "test_paper_ex_f2_errors", test_paper_ex_f2_errors },
Expand Down
74 changes: 74 additions & 0 deletions c/tskit/trees.c
Original file line number Diff line number Diff line change
Expand Up @@ -2059,6 +2059,11 @@ typedef struct {
const tsk_id_t *set_indexes;
} sample_count_stat_params_t;

typedef struct {
double *total_weights;
const tsk_id_t *index_tuples;
} indexed_weight_stat_params_t;

static int
tsk_treeseq_sample_count_stat(const tsk_treeseq_t *self, tsk_size_t num_sample_sets,
const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets,
Expand Down Expand Up @@ -3012,6 +3017,75 @@ tsk_treeseq_genetic_relatedness(const tsk_treeseq_t *self, tsk_size_t num_sample
return ret;
}

static int
genetic_relatedness_weighted_summary_func(tsk_size_t state_dim, const double *state,
tsk_size_t result_dim, double *result, void *params)
{
indexed_weight_stat_params_t args = *(indexed_weight_stat_params_t *) params;
const double *x = state;
tsk_id_t i, j;
tsk_size_t k;
double meanx, ni, nj;

meanx = state[state_dim - 1] / args.total_weights[state_dim - 1];
for (k = 0; k < result_dim; k++) {
i = args.index_tuples[2 * k];
j = args.index_tuples[2 * k + 1];
ni = args.total_weights[i];
nj = args.total_weights[j];
result[k] = (x[i] - ni * meanx) * (x[j] - nj * meanx) / 2;
}
return 0;
}

int
tsk_treeseq_genetic_relatedness_weighted(const tsk_treeseq_t *self,
tsk_size_t num_weights, const double *weights, tsk_size_t num_index_tuples,
const tsk_id_t *index_tuples, tsk_size_t num_windows, const double *windows,
double *result, tsk_flags_t options)
{
int ret = 0;
tsk_size_t num_samples = self->num_samples;
size_t j, k;
indexed_weight_stat_params_t args;
const double *row;
double *new_row;
double *total_weights = tsk_calloc((num_weights + 1), sizeof(*total_weights));
double *new_weights
= tsk_malloc((num_weights + 1) * num_samples * sizeof(*new_weights));

if (total_weights == NULL || new_weights == NULL) {
ret = TSK_ERR_NO_MEMORY;
goto out;
}

// Add a column of ones to W
for (j = 0; j < num_samples; j++) {
row = GET_2D_ROW(weights, num_weights, j);
new_row = GET_2D_ROW(new_weights, num_weights + 1, j);
for (k = 0; k < num_weights; k++) {
new_row[k] = row[k];
total_weights[k] += row[k];
}
new_row[num_weights] = 1.0;
}
total_weights[num_weights] = (double) num_samples;

args.total_weights = total_weights;
args.index_tuples = index_tuples;
ret = tsk_treeseq_general_stat(self, num_weights + 1, new_weights, num_index_tuples,
genetic_relatedness_weighted_summary_func, &args, num_windows, windows, options,
result);
if (ret != 0) {
goto out;
}

out:
tsk_safe_free(total_weights);
tsk_safe_free(new_weights);
return ret;
}

static int
Y2_summary_func(tsk_size_t TSK_UNUSED(state_dim), const double *state,
tsk_size_t result_dim, double *result, void *params)
Expand Down
11 changes: 11 additions & 0 deletions c/tskit/trees.h
Original file line number Diff line number Diff line change
Expand Up @@ -970,6 +970,17 @@ int tsk_treeseq_trait_linear_model(const tsk_treeseq_t *self, tsk_size_t num_wei
const double *weights, tsk_size_t num_covariates, const double *covariates,
tsk_size_t num_windows, const double *windows, tsk_flags_t options, double *result);

/* Two way weighted stats with covariates */

typedef int two_way_weighted_method(const tsk_treeseq_t *self, tsk_size_t num_weights,
const double *weights, tsk_size_t num_index_tuples, const tsk_id_t *index_tuples,
tsk_size_t num_windows, const double *windows, double *result, tsk_flags_t options);

int tsk_treeseq_genetic_relatedness_weighted(const tsk_treeseq_t *self,
tsk_size_t num_weights, const double *weights, tsk_size_t num_index_tuples,
const tsk_id_t *index_tuples, tsk_size_t num_windows, const double *windows,
double *result, tsk_flags_t options);

/* One way sample set stats */

typedef int one_way_sample_stat_method(const tsk_treeseq_t *self,
Expand Down
99 changes: 99 additions & 0 deletions python/_tskitmodule.c
Original file line number Diff line number Diff line change
Expand Up @@ -9322,6 +9322,93 @@ TreeSequence_k_way_stat_method(TreeSequence *self, PyObject *args, PyObject *kwd
return ret;
}

static PyObject *
TreeSequence_k_way_weighted_stat_method(TreeSequence *self, PyObject *args,
PyObject *kwds, npy_intp tuple_size, two_way_weighted_method *method)
{
PyObject *ret = NULL;
static char *kwlist[] = { "weights", "indexes", "windows", "mode", "span_normalise",
"polarised", NULL };
PyObject *weights = NULL;
PyObject *indexes = NULL;
PyObject *windows = NULL;
PyArrayObject *weights_array = NULL;
PyArrayObject *indexes_array = NULL;
PyArrayObject *windows_array = NULL;
PyArrayObject *result_array = NULL;
tsk_size_t num_windows, num_index_tuples;
npy_intp *w_shape, *shape;
tsk_flags_t options = 0;
char *mode = NULL;
int span_normalise = true;
int polarised = false;
int err;

if (TreeSequence_check_state(self) != 0) {
goto out;
}
if (!PyArg_ParseTupleAndKeywords(args, kwds, "OOO|sii", kwlist, &weights, &indexes,
&windows, &mode, &span_normalise, &polarised)) {
goto out;
}
if (parse_stats_mode(mode, &options) != 0) {
goto out;
}
if (span_normalise) {
options |= TSK_STAT_SPAN_NORMALISE;
}
if (polarised) {
options |= TSK_STAT_POLARISED;
}
if (parse_windows(windows, &windows_array, &num_windows) != 0) {
goto out;
}
weights_array = (PyArrayObject *) PyArray_FROMANY(
weights, NPY_FLOAT64, 2, 2, NPY_ARRAY_IN_ARRAY);
if (weights_array == NULL) {
goto out;
}
w_shape = PyArray_DIMS(weights_array);
if (w_shape[0] != (npy_intp) tsk_treeseq_get_num_samples(self->tree_sequence)) {
PyErr_SetString(PyExc_ValueError, "First dimension must be num_samples");
goto out;
}

indexes_array = (PyArrayObject *) PyArray_FROMANY(
indexes, NPY_INT32, 2, 2, NPY_ARRAY_IN_ARRAY);
if (indexes_array == NULL) {
goto out;
}
shape = PyArray_DIMS(indexes_array);
if (shape[0] < 1 || shape[1] != tuple_size) {
PyErr_Format(
PyExc_ValueError, "indexes must be a k x %d array.", (int) tuple_size);
goto out;
}
num_index_tuples = shape[0];

result_array = TreeSequence_allocate_results_array(
self, options, num_windows, num_index_tuples);
if (result_array == NULL) {
goto out;
}
err = method(self->tree_sequence, w_shape[1], PyArray_DATA(weights_array),
num_index_tuples, PyArray_DATA(indexes_array), num_windows,
PyArray_DATA(windows_array), PyArray_DATA(result_array), options);
if (err != 0) {
handle_library_error(err);
goto out;
}
ret = (PyObject *) result_array;
result_array = NULL;
out:
Py_XDECREF(weights_array);
Py_XDECREF(indexes_array);
Py_XDECREF(windows_array);
Py_XDECREF(result_array);
return ret;
}

static PyObject *
TreeSequence_divergence(TreeSequence *self, PyObject *args, PyObject *kwds)
{
Expand All @@ -9335,6 +9422,14 @@ TreeSequence_genetic_relatedness(TreeSequence *self, PyObject *args, PyObject *k
self, args, kwds, 2, tsk_treeseq_genetic_relatedness);
}

static PyObject *
TreeSequence_genetic_relatedness_weighted(
TreeSequence *self, PyObject *args, PyObject *kwds)
{
return TreeSequence_k_way_weighted_stat_method(
self, args, kwds, 2, tsk_treeseq_genetic_relatedness_weighted);
}

static PyObject *
TreeSequence_Y2(TreeSequence *self, PyObject *args, PyObject *kwds)
{
Expand Down Expand Up @@ -10049,6 +10144,10 @@ static PyMethodDef TreeSequence_methods[] = {
.ml_meth = (PyCFunction) TreeSequence_genetic_relatedness,
.ml_flags = METH_VARARGS | METH_KEYWORDS,
.ml_doc = "Computes genetic relatedness between sample sets." },
{ .ml_name = "genetic_relatedness_weighted",
.ml_meth = (PyCFunction) TreeSequence_genetic_relatedness_weighted,
.ml_flags = METH_VARARGS | METH_KEYWORDS,
.ml_doc = "Computes genetic relatedness between weighted sums of samples." },
{ .ml_name = "Y1",
.ml_meth = (PyCFunction) TreeSequence_Y1,
.ml_flags = METH_VARARGS | METH_KEYWORDS,
Expand Down
89 changes: 89 additions & 0 deletions python/tests/test_lowlevel.py
Original file line number Diff line number Diff line change
Expand Up @@ -1663,6 +1663,15 @@ def test_window_errors(self):
with pytest.raises(_tskit.LibraryError):
f(windows=bad_window, **params)

def test_polarisation(self):
ts, f, params = self.get_example()
with pytest.raises(TypeError):
f(polarised="sdf", **params)
x1 = f(polarised=False, **params)
x2 = f(polarised=True, **params)
# Basic check just to run both code paths
assert x1.shape == x2.shape

def test_windows_output(self):
ts, f, params = self.get_example()
del params["windows"]
Expand Down Expand Up @@ -2025,6 +2034,74 @@ def f(indexes):
f(bad_dim)


class TwoWayWeightedStatsMixin(StatsInterfaceMixin):
"""
Tests for the weighted two way sample stats.
"""

def get_example(self):
ts, method = self.get_method()
params = {
"weights": np.zeros((ts.get_num_samples(), 2)) + 0.5,
"indexes": [[0, 1]],
"windows": [0, ts.get_sequence_length()],
}
return ts, method, params

def test_basic_example(self):
ts, method = self.get_method()
div = method(
np.zeros((ts.get_num_samples(), 1)) + 0.5,
[[0, 1]],
windows=[0, ts.get_sequence_length()],
)
assert div.shape == (1, 1)

def test_bad_weights(self):
ts, f, params = self.get_example()
del params["weights"]
n = ts.get_num_samples()

for bad_weight_shape in [(n - 1, 1), (n + 1, 1), (0, 3)]:
with pytest.raises(ValueError):
f(weights=np.ones(bad_weight_shape), **params)

def test_output_dims(self):
ts, method, params = self.get_example()
weights = params.pop("weights")
params["windows"] = [0, ts.get_sequence_length()]

for mode in ["site", "branch"]:
out = method(weights[:, [0]], mode=mode, **params)
assert out.shape == (1, 1)
out = method(weights, mode=mode, **params)
assert out.shape == (1, 1)
out = method(weights[:, [0, 0, 0]], mode=mode, **params)
assert out.shape == (1, 1)
mode = "node"
N = ts.get_num_nodes()
out = method(weights[:, [0]], mode=mode, **params)
assert out.shape == (1, N, 1)
out = method(weights, mode=mode, **params)
assert out.shape == (1, N, 1)
out = method(weights[:, [0, 0, 0]], mode=mode, **params)
assert out.shape == (1, N, 1)

def test_set_index_errors(self):
ts, method, params = self.get_example()
del params["indexes"]

def f(indexes):
method(indexes=indexes, **params)

for bad_array in ["wer", {}, [[[], []], [[], []]]]:
with pytest.raises(ValueError):
f(bad_array)
for bad_dim in [[[]], [[1], [1]]]:
with pytest.raises(ValueError):
f(bad_dim)


class ThreeWaySampleStatsMixin(SampleSetMixin):
"""
Tests for the two way sample stats.
Expand Down Expand Up @@ -2211,6 +2288,12 @@ def get_method(self):
return ts, ts.f2


class TestGeneticRelatedness(LowLevelTestCase, TwoWaySampleStatsMixin):
def get_method(self):
ts = self.get_example_tree_sequence()
return ts, ts.genetic_relatedness


class TestY3(LowLevelTestCase, ThreeWaySampleStatsMixin):
def get_method(self):
ts = self.get_example_tree_sequence()
Expand All @@ -2229,6 +2312,12 @@ def get_method(self):
return ts, ts.f4


class TestWeightedGeneticRelatedness(LowLevelTestCase, TwoWayWeightedStatsMixin):
def get_method(self):
ts = self.get_example_tree_sequence()
return ts, ts.genetic_relatedness_weighted


class TestGeneralStatsInterface(LowLevelTestCase, StatsInterfaceMixin):
"""
Tests for the general stats interface.
Expand Down
Loading

0 comments on commit c00d375

Please sign in to comment.