Skip to content

Commit

Permalink
Halve time taken for test_tree_stats
Browse files: browse the repository at this point in the history
  • Loading branch information
benjeffery committed Oct 16, 2024
1 parent ee7f3b5 commit c76a198
Showing 1 changed file with 32 additions and 8 deletions.
40 changes: 32 additions & 8 deletions python/tests/test_tree_stats.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -1282,7 +1282,10 @@ def site_segregating_sites(ts, sample_sets, windows=None, span_normalise=True):
haps = ts.genotype_matrix(isolated_as_missing=False)
site_positions = [x.position for x in ts.sites()]
for i, X in enumerate(sample_sets):
X_index = np.where(np.isin(samples, X))[0]
set_X = set(X)
X_index = np.where(np.fromiter((s in set_X for s in samples), dtype=bool))[
0
]
for k in range(ts.num_sites):
if (site_positions[k] >= begin) and (site_positions[k] < end):
num_alleles = len(set(haps[k, X_index]))
Expand Down Expand Up @@ -1430,7 +1433,10 @@ def site_tajimas_d(ts, sample_sets, windows=None):
nn = n[i]
S = 0
T = 0
X_index = np.where(np.isin(samples, X))[0]
set_X = set(X)
X_index = np.where(np.fromiter((s in set_X for s in samples), dtype=bool))[
0
]
for k in range(ts.num_sites):
if (site_positions[k] >= begin) and (site_positions[k] < end):
hX = haps[k, X_index]
Expand Down Expand Up @@ -4891,7 +4897,10 @@ def branch_trait_covariance(ts, W, windows=None, span_normalise=True):
has_trees = True
SS = 0
for u in range(ts.num_nodes):
below = np.isin(samples, list(tr.samples(u)))
tree_samples = set(tr.samples(u))
below = np.fromiter(
(s in tree_samples for s in samples), dtype=bool
)
branch_length = tr.branch_length(u)
SS += covsq(w, below) * branch_length
S += SS * (min(end, tr.interval.right) - max(begin, tr.interval.left))
Expand Down Expand Up @@ -4926,7 +4935,10 @@ def node_trait_covariance(ts, W, windows=None, span_normalise=True):
break
SS = np.zeros(ts.num_nodes)
for u in range(ts.num_nodes):
below = np.isin(samples, list(tr.samples(u)))
tree_samples = set(tr.samples(u))
below = np.fromiter(
(s in tree_samples for s in samples), dtype=bool
)
SS[u] += covsq(w, below)
S += SS * (min(end, tr.interval.right) - max(begin, tr.interval.left))
out[j, :, i] = S
Expand Down Expand Up @@ -5102,7 +5114,10 @@ def branch_trait_correlation(ts, W, windows=None, span_normalise=True):
has_trees = True
SS = 0
for u in range(ts.num_nodes):
below = np.isin(samples, list(tr.samples(u)))
tree_samples = set(tr.samples(u))
below = np.fromiter(
(s in tree_samples for s in samples), dtype=bool
)
p = np.mean(below)
if p > 0 and p < 1:
branch_length = tr.branch_length(u)
Expand Down Expand Up @@ -5143,7 +5158,10 @@ def node_trait_correlation(ts, W, windows=None, span_normalise=True):
break
SS = np.zeros(ts.num_nodes)
for u in range(ts.num_nodes):
below = np.isin(samples, list(tr.samples(u)))
tree_samples = set(tr.samples(u))
below = np.fromiter(
(s in tree_samples for s in samples), dtype=bool
)
p = np.mean(below)
if p > 0 and p < 1:
# SS[u] += sum(w[below])**2 / 2
Expand Down Expand Up @@ -5366,7 +5384,10 @@ def branch_trait_linear_model(ts, W, Z, windows=None, span_normalise=True):
has_trees = True
SS = 0
for u in range(ts.num_nodes):
below = np.isin(samples, list(tr.samples(u)))
tree_samples = set(tr.samples(u))
below = np.fromiter(
(s in tree_samples for s in samples), dtype=bool
)
branch_length = tr.branch_length(u)
SS += linear_model(w, below, Z) * branch_length
S += SS * (min(end, tr.interval.right) - max(begin, tr.interval.left))
Expand Down Expand Up @@ -5401,7 +5422,10 @@ def node_trait_linear_model(ts, W, Z, windows=None, span_normalise=True):
break
SS = np.zeros(ts.num_nodes)
for u in range(ts.num_nodes):
below = np.isin(samples, list(tr.samples(u)))
tree_samples = set(tr.samples(u))
below = np.fromiter(
(s in tree_samples for s in samples), dtype=bool
)
SS[u] += linear_model(w, below, Z)
S += SS * (min(end, tr.interval.right) - max(begin, tr.interval.left))
out[j, :, i] = S
Expand Down

0 comments on commit c76a198

Please sign in to comment.