diff --git a/c/tskit/trees.c b/c/tskit/trees.c index cc70668679..2d6e6240c3 100644 --- a/c/tskit/trees.c +++ b/c/tskit/trees.c @@ -2931,7 +2931,6 @@ genetic_relatedness_weighted_summary_func(size_t state_dim, const double *state, double meanx, ni, nj; meanx = state[state_dim - 1] / args.total_weights[state_dim - 1]; - ; for (k = 0; k < result_dim; k++) { i = args.index_tuples[2 * k]; j = args.index_tuples[2 * k + 1]; @@ -2954,7 +2953,7 @@ tsk_treeseq_genetic_relatedness_weighted(const tsk_treeseq_t *self, indexed_weight_stat_params_t args; const double *row; double *new_row; - double *total_weights = malloc((num_weights + 1) * sizeof(*total_weights)); + double *total_weights = calloc((num_weights + 1), sizeof(*total_weights)); double *new_weights = malloc((num_weights + 1) * num_samples * sizeof(*new_weights)); if (total_weights == NULL || new_weights == NULL) { @@ -2962,25 +2961,20 @@ tsk_treeseq_genetic_relatedness_weighted(const tsk_treeseq_t *self, goto out; } - // add a column of ones to W - for (k = 0; k < num_samples; k++) { - row = GET_2D_ROW(weights, num_weights, k); - new_row = GET_2D_ROW(new_weights, num_weights + 1, k); - for (j = 0; j < num_weights; j++) { - new_row[j] = row[j]; + // Add a column of ones to W + for (j = 0; j < num_samples; j++) { + row = GET_2D_ROW(weights, num_weights, j); + new_row = GET_2D_ROW(new_weights, num_weights + 1, j); + for (k = 0; k < num_weights; k++) { + new_row[k] = row[k]; + total_weights[k] += row[k]; } new_row[num_weights] = 1.0; } + total_weights[num_weights] = num_samples; /* TODO: sanity check indexes */ - for (j = 0; j < num_samples; j++) { - row = GET_2D_ROW(new_weights, num_weights + 1, j); - for (k = 0; k < num_weights + 1; k++) { - total_weights[k] += row[k]; - } - } - args.total_weights = total_weights; args.index_tuples = index_tuples; diff --git a/python/tests/test_lowlevel.py b/python/tests/test_lowlevel.py index 6dfdec6fa3..448b5f4052 100644 --- a/python/tests/test_lowlevel.py +++ b/python/tests/test_lowlevel.py @@ -1577,6 +1577,15 @@ def test_window_errors(self): with pytest.raises(_tskit.LibraryError): f(windows=bad_window, **params) + def test_polarisation(self): + ts, f, params = self.get_example() + with pytest.raises(TypeError): + f(polarised="sdf", **params) + x1 = f(polarised=False, **params) + x2 = f(polarised=True, **params) + # Basic check just to run both code paths + assert x1.shape == x2.shape + def test_windows_output(self): ts, f, params = self.get_example() del params["windows"] @@ -1992,6 +2001,20 @@ def test_output_dims(self): out = method(weights[:, [0, 0, 0]], mode=mode, **params) assert out.shape == (1, N, 1) + def test_set_index_errors(self): + ts, method, params = self.get_example() + del params["indexes"] + + def f(indexes): + method(indexes=indexes, **params) + + for bad_array in ["wer", {}, [[[], []], [[], []]]]: + with pytest.raises(ValueError): + f(bad_array) + for bad_dim in [[[]], [[1], [1]]]: + with pytest.raises(ValueError): + f(bad_dim) + class ThreeWaySampleStatsMixin(SampleSetMixin): """