Commit

moved overloads to numba_overloads.py, improved speed of column permutations

rishi-kulkarni committed May 2, 2021
1 parent da6ccbf commit d0dade8
Showing 2 changed files with 108 additions and 29 deletions.
68 changes: 39 additions & 29 deletions hierarch/internal_functions.py
@@ -3,6 +3,7 @@
import scipy.stats as stats
from collections import Counter
import sympy.utilities.iterables as iterables
+import hierarch.numba_overloads

@nb.njit
def nb_data_grabber(data, col, treatment_labels):
@@ -16,8 +17,6 @@ def nb_unique(input_data, axis=0):
'''
Internal function that serves the same purpose as np.unique(a, return_index=True, return_counts=True) when called on a 2D array. Appears to asymptotically approach np.unique's speed when every row is unique, but otherwise runs faster.
-Note: the returned indexes are NOT the indexes of the unique rows in the original data, they are the indexes of the unique rows in the sorted data. This doesn't make any difference so long as input_data is sorted.
Parameters
----------
input_data: 2D array
@@ -34,25 +33,34 @@ def nb_unique(input_data, axis=0):
counts: number of instances of each unique row (or column) in the input array
'''

    #don't want to sort original data
    if axis == 1:
        data = input_data.T.copy()
    else:
        data = input_data.copy()

-    for i in range(data.shape[1]-1,-1,-1):
-        data = data[data[:,i].argsort(kind="mergesort")]
-
-    idx = np.zeros(1, dtype=np.int64)
-    counts = np.ones(1, dtype=np.int64)
-
-    additional_uniques = np.where(~np_all_axis1(data[:-1] == data[1:]))[0] + 1
+    #so we can remember the original indexes of each row
+    orig_idx = np.array([i for i in range(data.shape[0])])
+
+    #sort our data AND the original indexes
+    for i in range(data.shape[1]-1,-1,-1):
+        sorter = data[:,i].argsort(kind="mergesort") #mergesort to keep associations
+        data = data[sorter]
+        orig_idx = orig_idx[sorter]
+
+    #get original indexes
+    idx = [0]
+    additional_uniques = np.nonzero(~np.all((data[:-1] == data[1:]),axis=1))[0] + 1
    idx = np.append(idx, additional_uniques)
-    counts = idx[1:].copy()
-    counts = np.append(counts, data.shape[0])
+
+    #get counts for each unique row
+    counts = [1]
+    counts = np.append(idx[1:], data.shape[0])
    counts = counts - idx
-    return data[idx], idx, counts
+
+    return data[idx], orig_idx[idx], counts
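
A minimal sketch (not part of the commit, using a made-up array) of what the rewritten function returns; with this change the second return value now indexes into the original data rather than the sorted copy:

import numpy as np
from hierarch.internal_functions import nb_unique

data = np.array([[2., 1.],
                 [1., 1.],
                 [1., 1.],
                 [1., 2.]])

rows, idx, counts = nb_unique(data)
# rows:   the unique rows, in lexicographic order
# idx:    index of each unique row in the original (unsorted) data
# counts: occurrences of each unique row, here [2, 1, 1]

# should agree with NumPy on the rows and the counts
np_rows, np_counts = np.unique(data, axis=0, return_counts=True)
assert (rows == np_rows).all() and (counts == np_counts).all()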

@nb.jit(nopython=True)
def welch_statistic(data, col, treatment_labels):
@@ -74,15 +82,18 @@ def welch_statistic(data, col, treatment_labels):
Note: The formula for Welch's t reduces to Student's t when sample_a and sample_b are the same size, so use this function whenever you need a t statistic.
'''

-    sample_a, sample_b = nb_data_grabber(data, 0, treatment_labels)

-    meandiff = (np.mean(sample_a) - np.mean(sample_b))
+    #get our two samples from the data matrix
+    sample_a, sample_b = nb_data_grabber(data, col, treatment_labels)

+    #mean difference
+    meandiff = (np.mean(sample_a) - np.mean(sample_b))

-    var_weight_one = (np.var(sample_a)*(sample_a.size/(sample_a.size - 1))) / len(sample_a)
+    #weighted sample variances
+    var_weight_one = (np.var(sample_a)*(sample_a.size/(sample_a.size - 1))) / len(sample_a)
    var_weight_two = (np.var(sample_b)*(sample_b.size/(sample_b.size - 1))) / len(sample_b)

+    #compute t statistic
    t = meandiff / np.sqrt(var_weight_one + var_weight_two)
    return t
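
The weighting above is just the Bessel correction written out: np.var(x)*(x.size/(x.size - 1)) equals np.var(x, ddof=1). A quick check (not from the diff, with invented samples) that this matches SciPy's Welch test:

import numpy as np
import scipy.stats as stats

a = np.array([4.1, 3.9, 4.4, 4.0])
b = np.array([5.2, 5.0, 5.5])

# Welch's t: mean difference over the root of the weighted sample variances
t = (a.mean() - b.mean()) / np.sqrt(a.var(ddof=1) / a.size + b.var(ddof=1) / b.size)

# should agree with SciPy's Welch (unequal-variance) t statistic
t_scipy, _ = stats.ttest_ind(a, b, equal_var=False)
assert np.isclose(t, t_scipy)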

@@ -97,7 +108,7 @@ def quick_shuffle(w):


@nb.jit()
-def randomize_chunks(values, keys):
+def randomize_chunks(shuffled_col, splits):
'''
Internal function for permuting a column of data while paying attention to the dependency structure of the prior column. Numba's implementation of np.random.permutation is faster than numpy's, so we're using this.
@@ -115,10 +126,11 @@ def randomize_chunks(values, keys):
    List of permuted values of the second-to-last column of the values array
    '''
-    append_col=np.empty(0, dtype=np.float64)
-    for i in keys:
-        append_col = np.hstack((append_col, np.random.permutation(values[np_all_axis1(values[:,:-2] == values[i][:-2])][:,-2])))
-    return append_col
+    for idx, _ in enumerate(splits):
+        if idx < len(splits)-1:
+            np.random.shuffle(shuffled_col[splits[idx]:splits[idx+1]-1])
+        else:
+            np.random.shuffle(shuffled_col[splits[idx]:])
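
A sketch (not from the diff) of the chunked in-place shuffle this implements, assuming splits holds the start index of each dependency block; the exact boundary convention comes from unique_idx_w_cache, which this diff doesn't show:

import numpy as np

col = np.array([1., 2., 3., 4., 5., 6.])
splits = np.array([0, 3])  # hypothetical block starts: rows 0-2 and rows 3-5

for i, start in enumerate(splits):
    stop = splits[i + 1] if i < len(splits) - 1 else len(col)
    np.random.shuffle(col[start:stop])  # a slice is a view, so each block shuffles in place

Shuffling views in place is what lets this version drop the np.hstack copies the old implementation built up on every call.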

@nb.jit(nopython=True, cache=True)
def np_all_axis1(x):
@@ -235,24 +247,22 @@ def permute_column(data, col_to_permute=-2, iterator=None, seed=None):


    if iterator == None:
+        shuffled_col_values = data[:,col_to_permute-1][indexes]
        if col_to_permute == 1:
-            shuffled_col_values = data[:,col_to_permute-1][indexes]
            quick_shuffle(shuffled_col_values)
        else:
            keys = unique_idx_w_cache(values)[-2]
-            shuffled_col_values = randomize_chunks(values, keys)
+            randomize_chunks(shuffled_col_values, keys)

    else:
        shuffled_col_values = iterator


-    if len(shuffled_col_values) == data.shape[0]:
-        new_col = shuffled_col_values
-    else:
-        new_col = np.repeat(shuffled_col_values, counts)
+    if len(shuffled_col_values) < data.shape[0]:
+        shuffled_col_values = np.repeat(shuffled_col_values, counts)

    permute = np.array(data)
-    permute[:,col_to_permute-1] = new_col
+    permute[:,col_to_permute-1] = shuffled_col_values

    return permute
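
The np.repeat branch covers the case where one value was shuffled per cluster rather than per row; a small illustration with invented numbers:

import numpy as np

shuffled_col_values = np.array([2., 1., 3.])  # one shuffled label per cluster
counts = np.array([2, 2, 1])                  # number of rows in each cluster

expanded = np.repeat(shuffled_col_values, counts)
# expanded -> array([2., 2., 1., 1., 3.]), one value per row of the original data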

69 changes: 69 additions & 0 deletions hierarch/numba_overloads.py
@@ -0,0 +1,69 @@
import numpy as np
import numba as nb
from numba import types
from numba.extending import overload, register_jitable
from numba.core.errors import TypingError


@register_jitable
def _np_all_flat(x):
    out = x.all()
    return out

@register_jitable
def _np_all_axis1(arr):
    out = np.logical_and(arr[:,0], arr[:,1])
    for idx, v in enumerate(arr[:,2:]):
        for v_2 in iter(v):
            out[idx] = np.logical_and(v_2, out[idx])
    return out

@register_jitable
def _np_all_axis0(arr):
    out = np.logical_and(arr[0], arr[1])
    for v in iter(arr[2:]):
        for idx, v_2 in enumerate(v):
            out[idx] = np.logical_and(v_2, out[idx])
    return out

@overload(np.all)
def np_all(x, axis=None):
    # Generalization of Numba's overload for ndarray.all with axis arguments for 2D arrays.

    if isinstance(axis, types.Optional):
        axis = axis.type

    if not isinstance(axis, (types.Integer, types.NoneType)):
        raise TypingError("'axis' must be 0, 1, or None")

    if not isinstance(x, types.Array):
        raise TypingError('Only accepts NumPy ndarray')

    if not (1 <= x.ndim <= 2):
        raise TypingError('Only supports 1D or 2D NumPy ndarrays')

    if isinstance(axis, types.NoneType):
        def _np_all_impl(x, axis=None):
            return _np_all_flat(x)
        return _np_all_impl

    elif x.ndim == 1:
        def _np_all_impl(x, axis=None):
            return _np_all_flat(x)
        return _np_all_impl

    elif x.ndim == 2:
        def _np_all_impl(x, axis=None):
            if axis == 0:
                return _np_all_axis0(x)
            else:
                return _np_all_axis1(x)
        return _np_all_impl

    else:
        def _np_all_impl(x, axis=None):
            return _np_all_flat(x)
        return _np_all_impl
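
With this overload registered (importing the module is enough, which is why internal_functions.py now imports hierarch.numba_overloads), np.all with an axis argument compiles in nopython mode, as nb_unique above relies on. A usage sketch, not part of the commit:

import numpy as np
import numba as nb
import hierarch.numba_overloads  # importing registers the np.all overload

@nb.njit
def rows_all_true(mask):
    # without the overload, the axis argument would fail to type in nopython mode
    return np.all(mask, axis=1)

mask = np.array([[True, True, False],
                 [True, True, True]])
print(rows_all_true(mask))  # -> [False  True]

Note that the axis-1 helper seeds its accumulator from the first two columns, so it assumes arrays with at least two columns.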
