Skip to content

Commit

Permalink
removed some redundancies from permute_column, 2x speedup
Browse files Browse the repository at this point in the history
  • Loading branch information
rishi-kulkarni committed Apr 30, 2021
1 parent d29f812 commit da6ccbf
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 24 deletions.
47 changes: 27 additions & 20 deletions hierarch/internal_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,13 @@
from collections import Counter
import sympy.utilities.iterables as iterables

@nb.njit
def nb_data_grabber(data, col, treatment_labels):
ret_list = []
for key in treatment_labels:
ret_list.append(data[:,-1][np.equal(data[:,col],key)])
return ret_list

@nb.jit(nopython=True, cache=True)
def nb_unique(input_data, axis=0):
'''
Expand Down Expand Up @@ -48,10 +55,13 @@ def nb_unique(input_data, axis=0):
return data[idx], idx, counts

@nb.jit(nopython=True)
def welch_statistic(sample_a, sample_b):
def welch_statistic(data, col, treatment_labels):
'''
Internal function that calculates Welch's t statistic.
Details on the validity of this test statistic can be found in "Studentized permutation tests for non-i.i.d. hypotheses and the generalized Behrens-Fisher problem
" by Arnold Janssen. https://doi.org/10.1016/S0167-7152(97)00043-6.
Parameters
----------
Expand All @@ -64,6 +74,9 @@ def welch_statistic(sample_a, sample_b):
Note: The formula for Welch's t reduces to Student's t when sample_a and sample_b are the same size, so use this function whenever you need a t statistic.
'''

sample_a, sample_b = nb_data_grabber(data, 0, treatment_labels)

meandiff = (np.mean(sample_a) - np.mean(sample_b))

var_weight_one = (np.var(sample_a)*(sample_a.size/(sample_a.size - 1))) / len(sample_a)
Expand Down Expand Up @@ -208,12 +221,9 @@ def permute_column(data, col_to_permute=-2, iterator=None, seed=None):
---------
permuted: an array the same size as data with column n - 1 permuted within column n - 2's clusters.
"""

key = hash(tuple((data[:,:col_to_permute+1].tobytes(),col_to_permute)))
key = hash(data[:,:col_to_permute+1].tobytes())

try:
values, indexes, counts = permute_column.__dict__[key]
Expand All @@ -225,29 +235,26 @@ def permute_column(data, col_to_permute=-2, iterator=None, seed=None):


if iterator == None:
try:
keys = unique_idx_w_cache(values)[-2]
shuffled_col_values = randomize_chunks(values, keys)


except:

if col_to_permute == 1:
shuffled_col_values = data[:,col_to_permute-1][indexes]
quick_shuffle(shuffled_col_values)
quick_shuffle(shuffled_col_values)
else:
keys = unique_idx_w_cache(values)[-2]
shuffled_col_values = randomize_chunks(values, keys)

else:
shuffled_col_values = iterator


if len(shuffled_col_values) != data[:,col_to_permute-1].size:
new_col = np.repeat(shuffled_col_values, counts)
else:
if len(shuffled_col_values) == data.shape[0]:
new_col = shuffled_col_values
else:
new_col = np.repeat(shuffled_col_values, counts)


permuted = data.copy()
permuted[:,col_to_permute-1] = new_col
return permuted
permute = np.array(data)
permute[:,col_to_permute-1] = new_col

return permute

def bootstrap_agg(bootstrap_sample, func=np.nanmean, agg_to=-2, first_data_col=-1):

Expand Down
8 changes: 4 additions & 4 deletions hierarch/stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ def two_sample_test(data_array, treatment_col, teststat="welch", skip=[], bootst
if data.dtype != 'float64':
data[:,:-1] = internal_functions.label_encode(data[:,:-1])
data = data.astype('float64')
data = np.unique(data, axis=0)
data = np.unique(data, axis=0) ###sorts the data matrix by row. 100% necessary.

treatment_labels = np.unique(data[:,treatment_col])

Expand All @@ -36,7 +36,7 @@ def two_sample_test(data_array, treatment_col, teststat="welch", skip=[], bootst
for m in range(levels_to_agg):
test = internal_functions.mean_agg(test)

truediff = np.abs(teststat(test[test[:,treatment_col] == treatment_labels[0]][:,-1], test[test[:,treatment_col] == treatment_labels[1]][:,-1]))
truediff = np.abs(teststat(test, treatment_col, treatment_labels))


means = []
Expand All @@ -55,12 +55,12 @@ def two_sample_test(data_array, treatment_col, teststat="welch", skip=[], bootst
#we are sampling all 20 permutations, so no need for rng.
for k in it_list:
permute_resample = internal_functions.permute_column(bootstrapped_sample, treatment_col+1, k)
means.append(teststat(permute_resample[permute_resample[:,treatment_col] == treatment_labels[0]][:,-1], permute_resample[permute_resample[:,treatment_col] == treatment_labels[1]][:,-1]))
means.append(teststat(permute_resample, treatment_col, treatment_labels))

else:
for k in range(permutations):
permute_resample = internal_functions.permute_column(bootstrapped_sample, treatment_col+1)
means.append(teststat(permute_resample[permute_resample[:,treatment_col] == treatment_labels[0]][:,-1], permute_resample[permute_resample[:,treatment_col] == treatment_labels[1]][:,-1]))
means.append(teststat(permute_resample, treatment_col, treatment_labels))

pval = np.where((np.array(np.abs(means)) >= truediff))[0].size / len(means)

Expand Down

0 comments on commit da6ccbf

Please sign in to comment.