diff --git a/hierarch/internal_functions.py b/hierarch/internal_functions.py index 73ed7cd..93735aa 100644 --- a/hierarch/internal_functions.py +++ b/hierarch/internal_functions.py @@ -3,6 +3,7 @@ import scipy.stats as stats from collections import Counter import sympy.utilities.iterables as iterables +import hierarch.numba_overloads @nb.njit def nb_data_grabber(data, col, treatment_labels): @@ -16,8 +17,6 @@ def nb_unique(input_data, axis=0): ''' Internal function that serves the same purpose as np.unique(a, return_index=True, return_counts=True) when called on a 2D arrays. Appears to asymptotically approach np.unique's speed when every row is unique, but otherwise runs faster. - Note: the returned indexes are NOT the indexes of the unique rows in the original data, they are the indexes of the unique rows in the sorted data. This doesn't make any difference so long as input_data is sorted. - Parameters ---------- input_data: 2D array @@ -34,25 +33,34 @@ def nb_unique(input_data, axis=0): counts: number of instances of each unique row (or column) in the input array ''' + + #don't want to sort original data if axis == 1: data = input_data.T.copy() else: data = input_data.copy() - - for i in range(data.shape[1]-1,-1,-1): - data = data[data[:,i].argsort(kind="mergesort")] - - idx = np.zeros(1, dtype=np.int64) - counts = np.ones(1, dtype=np.int64) - additional_uniques = np.where(~np_all_axis1(data[:-1] == data[1:]))[0] + 1 + #so we can remember the original indexes of each row + orig_idx = np.array([i for i in range(data.shape[0])]) + + #sort our data AND the original indexes + for i in range(data.shape[1]-1,-1,-1): + sorter = data[:,i].argsort(kind="mergesort") #mergesort to keep associations + data = data[sorter] + orig_idx = orig_idx[sorter] + + #get original indexes + idx = [0] + additional_uniques = np.nonzero(~np.all((data[:-1] == data[1:]),axis=1))[0] + 1 idx = np.append(idx, additional_uniques) - counts = idx[1:].copy() - counts = np.append(counts, data.shape[0]) + + #get counts for each unique row + counts = [1] + counts = np.append(idx[1:], data.shape[0]) counts = counts - idx - - return data[idx], idx, counts + + return data[idx], orig_idx[idx], counts @nb.jit(nopython=True) def welch_statistic(data, col, treatment_labels): @@ -74,15 +82,18 @@ def welch_statistic(data, col, treatment_labels): Note: The formula for Welch's t reduces to Student's t when sample_a and sample_b are the same size, so use this function whenever you need a t statistic. ''' - - sample_a, sample_b = nb_data_grabber(data, 0, treatment_labels) - meandiff = (np.mean(sample_a) - np.mean(sample_b)) + #get our two samples from the data matrix + sample_a, sample_b = nb_data_grabber(data, col, treatment_labels) - var_weight_one = (np.var(sample_a)*(sample_a.size/(sample_a.size - 1))) / len(sample_a) + #mean difference + meandiff = (np.mean(sample_a) - np.mean(sample_b)) + #weighted sample variances + var_weight_one = (np.var(sample_a)*(sample_a.size/(sample_a.size - 1))) / len(sample_a) var_weight_two = (np.var(sample_b)*(sample_b.size/(sample_b.size - 1))) / len(sample_b) + #compute t statistic t = meandiff / np.sqrt(var_weight_one + var_weight_two) return t @@ -97,7 +108,7 @@ def quick_shuffle(w): @nb.jit() -def randomize_chunks(values, keys): +def randomize_chunks(shuffled_col, splits): ''' Internal function for permuting a column a data while paying attention to the dependency structure of the prior column. Numba's implementation of np.random.permutation is faster than numpy's, so we're using this. @@ -115,10 +126,11 @@ def randomize_chunks(values, keys): List of permuted values of the second-to-last column of the values array ''' - append_col=np.empty(0, dtype=np.float64) - for i in keys: - append_col = np.hstack((append_col, np.random.permutation(values[np_all_axis1(values[:,:-2] == values[i][:-2])][:,-2]))) - return append_col + for idx, _ in enumerate(splits): + if idx < len(splits)-1: + np.random.shuffle(shuffled_col[splits[idx]:splits[idx+1]-1]) + else: + np.random.shuffle(shuffled_col[splits[idx]:]) @nb.jit(nopython=True, cache=True) def np_all_axis1(x): @@ -235,24 +247,22 @@ def permute_column(data, col_to_permute=-2, iterator=None, seed=None): if iterator == None: + shuffled_col_values = data[:,col_to_permute-1][indexes] if col_to_permute == 1: - shuffled_col_values = data[:,col_to_permute-1][indexes] quick_shuffle(shuffled_col_values) else: keys = unique_idx_w_cache(values)[-2] - shuffled_col_values = randomize_chunks(values, keys) + randomize_chunks(shuffled_col_values, keys) else: shuffled_col_values = iterator - if len(shuffled_col_values) == data.shape[0]: - new_col = shuffled_col_values - else: - new_col = np.repeat(shuffled_col_values, counts) + if len(shuffled_col_values) < data.shape[0]: + shuffled_col_values = np.repeat(shuffled_col_values, counts) permute = np.array(data) - permute[:,col_to_permute-1] = new_col + permute[:,col_to_permute-1] = shuffled_col_values return permute diff --git a/hierarch/numba_overloads.py b/hierarch/numba_overloads.py new file mode 100644 index 0000000..2a17bcb --- /dev/null +++ b/hierarch/numba_overloads.py @@ -0,0 +1,69 @@ +import numpy as np +import numba as nb +from numba import types +from numba.extending import overload, register_jitable +from numba.core.errors import TypingError + + +@register_jitable +def _np_all_flat(x): + out = x.all() + return out + +@register_jitable +def _np_all_axis1(arr): + out = np.logical_and(arr[:,0], arr[:,1]) + for idx,v in enumerate(arr[:,2:]): + for v_2 in iter(v): + out[idx] = np.logical_and(v_2, out[idx]) + return out + +@register_jitable +def _np_all_axis0(arr): + out = np.logical_and(arr[0], arr[1]) + for v in iter(arr[2:]): + for idx, v_2 in enumerate(v): + out[idx] = np.logical_and(v_2, out[idx]) + return out + +@overload(np.all) +def np_all(x, axis=None): + + # Generalization of Numba's overload for ndarray.all with axis arguments for 2D arrays. + + + if isinstance(axis, types.Optional): + axis = axis.type + + if not isinstance(axis, (types.Integer, types.NoneType)): + raise TypingError("'axis' must be 0, 1, or None") + + if not isinstance(x, types.Array): + raise TypingError('Only accepts NumPy ndarray') + + if not (1 <= x.ndim <= 2): + raise TypingError('Only supports 1D or 2D NumPy ndarrays') + + if isinstance(axis, types.NoneType): + def _np_all_impl(x, axis=None): + return _np_all_flat(x) + return _np_all_impl + + elif x.ndim == 1: + def _np_all_impl(x, axis=None): + return _np_all_flat(x) + return _np_all_impl + + elif x.ndim == 2: + def _np_all_impl(x, axis=None): + if axis == 0: + return _np_all_axis0(x) + else: + return _np_all_axis1(x) + return _np_all_impl + + else: + def _np_all_impl(x, axis=None): + return _np_all_flat(x) + return _np_all_impl +