Skip to content

Commit

Permalink
treatment_col can be specified using a column name if the input data …
Browse files Browse the repository at this point in the history
…is a DataFrame
  • Loading branch information
rishi-kulkarni committed Jul 7, 2021
1 parent a4800a0 commit 12ae78a
Showing 1 changed file with 17 additions and 8 deletions.
25 changes: 17 additions & 8 deletions hierarch/stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,7 +240,7 @@ def _grabber(X, y, treatment_labels):

def hypothesis_test(
data_array,
treatment_col: int,
treatment_col,
compare="corr",
alternative="two-sided",
skip=None,
Expand All @@ -261,9 +261,10 @@ def hypothesis_test(
Array-like containing both the independent and dependent variables to
be analyzed. It's assumed that the final (rightmost) column
contains the dependent variable values.
treatment_col : int
treatment_col : int or str
The index number of the column containing "two samples" to be compared.
Indexing starts at 0.
Indexing starts at 0. If input data is a pandas DataFrame, this can be
the name of the column.
compare : str, optional
The test statistic to use to perform the hypothesis test, by default "corr"
which automatically calls the studentized covariance test statistic.
Expand Down Expand Up @@ -354,6 +355,8 @@ def hypothesis_test(

# turns the input array or dataframe into a float64 array
if isinstance(data_array, (np.ndarray, pd.DataFrame)):
if isinstance(data_array, pd.DataFrame) and isinstance(treatment_col, str):
treatment_col = int(data_array.columns.get_loc(treatment_col))
data = _preprocess_data(data_array)
else:
raise TypeError("Input data must be ndarray or DataFrame.")
Expand Down Expand Up @@ -490,7 +493,7 @@ def hypothesis_test(

def multi_sample_test(
data_array,
treatment_col: int,
treatment_col,
hypotheses="all",
correction="fdr",
compare="means",
Expand All @@ -511,9 +514,10 @@ def multi_sample_test(
Array-like containing both the independent and dependent variables to
be analyzed. It's assumed that the final (rightmost) column
contains the dependent variable values.
treatment_col : int
treatment_col : int or str
The index number of the column containing labels to be compared.
Indexing starts at 0.
Indexing starts at 0. If input data is a pandas DataFrame, this can
be the column name.
hypotheses : list of two-element lists or "all", optional
Hypotheses to be tested. If 'all' every pairwise comparison will be
tested. Can be passed a list of lists to restrict comparisons, which
Expand Down Expand Up @@ -667,6 +671,8 @@ def multi_sample_test(

# coerce data into an object array
if isinstance(data_array, pd.DataFrame):
if isinstance(treatment_col, str):
treatment_col = data_array.columns.get_loc(treatment_col)
data = data_array.to_numpy()
elif isinstance(data_array, np.ndarray):
data = data_array
Expand Down Expand Up @@ -870,9 +876,10 @@ def confidence_interval(
Array-like containing both the independent and dependent variables to
be analyzed. It's assumed that the final (rightmost) column
contains the dependent variable values.
treatment_col : int
treatment_col : int or str
The index number of the column containing "two samples" to be compared.
Indexing starts at 0.
Indexing starts at 0. If input data is a pandas DataFrame, this can be
the column name.
interval : float, optional
Percentage value indicating the confidence interval's coverage, by default 95
iterations : int, optional
Expand Down Expand Up @@ -980,6 +987,8 @@ def confidence_interval(

# turns the input array or dataframe into a float64 array
if isinstance(data_array, (np.ndarray, pd.DataFrame)):
if isinstance(data_array, pd.DataFrame) and isinstance(treatment_col, str):
treatment_col = int(data_array.columns.get_loc(treatment_col))
data = _preprocess_data(data_array)
else:
raise TypeError("Input data must be ndarray or DataFrame.")
Expand Down

0 comments on commit 12ae78a

Please sign in to comment.