Skip to content

Commit

Permalink
Merge pull request #184 from siemdejong/issue-183
Browse files Browse the repository at this point in the history
fix: changed Subplot to Axes, changed the minimum support version to Python 3.9, upgraded dependencies.
  • Loading branch information
Jacobluke- authored Sep 25, 2024
2 parents 360fdc3 + 3ac4852 commit d3d531f
Show file tree
Hide file tree
Showing 10 changed files with 124 additions and 98 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/test-pytest.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ jobs:
- uses: actions/checkout@v3
- uses: actions/setup-python@v4
with:
python-version: 3.8
python-version: 3.9
cache: "pip"
cache-dependency-path: settings.ini
- name: Run pytest
Expand Down
5 changes: 3 additions & 2 deletions dabest/_dabest_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
# %% ../nbs/API/dabest_object.ipynb 4
# Import standard data science libraries
from numpy import array, repeat, random, issubdtype, number
import numpy as np
import pandas as pd
from scipy.stats import norm
from scipy.stats import randint
Expand Down Expand Up @@ -479,7 +480,7 @@ def _check_errors(self, x, y, idx, experiment, experiment_label, x1_level):

# Handling str type condition
if is_str_condition_met:
if len(pd.unique(idx).tolist()) != 2:
if len(np.unique(idx).tolist()) != 2:
err0 = "`mini_meta` is True, but `idx` ({})".format(idx)
err1 = "does not contain exactly 2 unique columns."
raise ValueError(err0 + err1)
Expand Down Expand Up @@ -667,7 +668,7 @@ def _get_plot_data(self, x, y, all_plot_groups):
all_plot_groups, ordered=True, inplace=True
)
else:
plot_data.loc[:, self.__xvar] = pd.Categorical(
plot_data[self.__xvar] = pd.Categorical(
plot_data[self.__xvar], categories=all_plot_groups, ordered=True
)

Expand Down
2 changes: 2 additions & 0 deletions dabest/_modidx.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,8 @@
'dabest.misc_tools.get_kwargs': ('API/misc_tools.html#get_kwargs', 'dabest/misc_tools.py'),
'dabest.misc_tools.get_params': ('API/misc_tools.html#get_params', 'dabest/misc_tools.py'),
'dabest.misc_tools.get_plot_groups': ('API/misc_tools.html#get_plot_groups', 'dabest/misc_tools.py'),
'dabest.misc_tools.get_unique_categories': ( 'API/misc_tools.html#get_unique_categories',
'dabest/misc_tools.py'),
'dabest.misc_tools.get_varname': ('API/misc_tools.html#get_varname', 'dabest/misc_tools.py'),
'dabest.misc_tools.initialize_fig': ('API/misc_tools.html#initialize_fig', 'dabest/misc_tools.py'),
'dabest.misc_tools.merge_two_dicts': ('API/misc_tools.html#merge_two_dicts', 'dabest/misc_tools.py'),
Expand Down
41 changes: 28 additions & 13 deletions dabest/misc_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,11 @@
# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/API/misc_tools.ipynb.

# %% auto 0
__all__ = ['merge_two_dicts', 'unpack_and_add', 'print_greeting', 'get_varname', 'get_params', 'get_kwargs', 'get_color_palette',
'initialize_fig', 'get_plot_groups', 'add_counts_to_ticks', 'extract_contrast_plotting_ticks',
'set_xaxis_ticks_and_lims', 'show_legend', 'Gardner_Altman_Plot_Aesthetic_Adjustments',
'Cumming_Plot_Aesthetic_Adjustments', 'General_Plot_Aesthetic_Adjustments']
__all__ = ['merge_two_dicts', 'unpack_and_add', 'print_greeting', 'get_varname', 'get_unique_categories', 'get_params',
'get_kwargs', 'get_color_palette', 'initialize_fig', 'get_plot_groups', 'add_counts_to_ticks',
'extract_contrast_plotting_ticks', 'set_xaxis_ticks_and_lims', 'show_legend',
'Gardner_Altman_Plot_Aesthetic_Adjustments', 'Cumming_Plot_Aesthetic_Adjustments',
'General_Plot_Aesthetic_Adjustments']

# %% ../nbs/API/misc_tools.ipynb 4
import datetime as dt
Expand Down Expand Up @@ -78,6 +79,19 @@ def get_varname(obj):
if len(matching_vars) > 0:
return matching_vars[0]
return ""


def get_unique_categories(names):
"""
Extract unique categories from various input types.
"""
if isinstance(names, np.ndarray):
return names # numpy.unique() returns a sorted array
elif isinstance(names, (pd.Categorical, pd.Series)):
return names.cat.categories if hasattr(names, 'cat') else names.unique()
else:
# For dict_keys and other iterables
return np.unique(list(names))

def get_params(effectsize_df, plot_kwargs):
"""
Expand Down Expand Up @@ -369,6 +383,7 @@ def get_color_palette(plot_kwargs, plot_data, xvar, show_pairs, idx, all_plot_gr
raise ValueError(err1 + err2)

if custom_pal is None and color_col is None:
categories = get_unique_categories(names)
swarm_colors = [sns.desaturate(c, swarm_desat) for c in unsat_colors]
contrast_colors = [sns.desaturate(c, contrast_desat) for c in unsat_colors]
bar_color = [sns.desaturate(c, bar_desat) for c in unsat_colors]
Expand All @@ -382,9 +397,9 @@ def get_color_palette(plot_kwargs, plot_data, xvar, show_pairs, idx, all_plot_gr
plot_palette_raw[names_i] = swarm_colors[i]
plot_palette_contrast[names_i] = contrast_colors[i]
else:
plot_palette_raw = dict(zip(names.categories, swarm_colors))
plot_palette_contrast = dict(zip(names.categories, contrast_colors))
plot_palette_bar = dict(zip(names.categories, bar_color))
plot_palette_raw = dict(zip(categories, swarm_colors))
plot_palette_contrast = dict(zip(categories, contrast_colors))
plot_palette_bar = dict(zip(categories, bar_color))

# For Sankey Diagram plot, no need to worry about the color, each bar will have the same two colors
# default color palette will be set to "hls"
Expand Down Expand Up @@ -541,7 +556,7 @@ def get_plot_groups(is_paired, idx, proportional, all_plot_groups):

def add_counts_to_ticks(plot_data, xvar, yvar, rawdata_axes, plot_kwargs):
# Add the counts to the rawdata axes xticks.
counts = plot_data.groupby(xvar).count()[yvar]
counts = plot_data.groupby(xvar, observed=False).count()[yvar]

def lookup_value(text):
try:
Expand Down Expand Up @@ -695,19 +710,19 @@ def Gardner_Altman_Plot_Aesthetic_Adjustments(effect_size_type, plot_data, xvar,
# Check that the effect size is within the swarm ylims.
if effect_size_type in ["mean_diff", "cohens_d", "hedges_g", "cohens_h"]:
control_group_summary = (
plot_data.groupby(xvar)
plot_data.groupby(xvar, observed=False)
.mean(numeric_only=True)
.loc[current_control, yvar]
)
test_group_summary = (
plot_data.groupby(xvar).mean(numeric_only=True).loc[current_group, yvar]
plot_data.groupby(xvar, observed=False).mean(numeric_only=True).loc[current_group, yvar]
)
elif effect_size_type == "median_diff":
control_group_summary = (
plot_data.groupby(xvar).median().loc[current_control, yvar]
plot_data.groupby(xvar, observed=False).median(numeric_only=True).loc[current_control, yvar]
)
test_group_summary = (
plot_data.groupby(xvar).median().loc[current_group, yvar]
plot_data.groupby(xvar, observed=False).median(numeric_only=True).loc[current_group, yvar]
)

if swarm_ylim is None:
Expand Down Expand Up @@ -751,7 +766,7 @@ def Gardner_Altman_Plot_Aesthetic_Adjustments(effect_size_type, plot_data, xvar,
pooled_sd = stds[0]

if effect_size_type == "hedges_g":
gby_count = plot_data.groupby(xvar).count()
gby_count = plot_data.groupby(xvar, observed=False).count()
len_control = gby_count.loc[current_control, yvar]
len_test = gby_count.loc[current_group, yvar]

Expand Down
58 changes: 29 additions & 29 deletions dabest/plot_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,25 +120,25 @@ def error_bar(
else:
group_order = pd.unique(data[x])

means = data.groupby(x)[y].mean().reindex(index=group_order)
means = data.groupby(x, observed=False)[y].mean().reindex(index=group_order)

if method in ["proportional_error_bar", "sankey_error_bar"]:
g = lambda x: np.sqrt(
(np.sum(x) * (len(x) - np.sum(x))) / (len(x) * len(x) * len(x))
)
sd = data.groupby(x)[y].apply(g)
sd = data.groupby(x, observed=False)[y].apply(g)
else:
sd = data.groupby(x)[y].std().reindex(index=group_order)
sd = data.groupby(x, observed=False)[y].std().reindex(index=group_order)

lower_sd = means - sd
upper_sd = means + sd

if (lower_sd < ax_ylims[0]).any() or (upper_sd > ax_ylims[1]).any():
kwargs["clip_on"] = True

medians = data.groupby(x)[y].median().reindex(index=group_order)
medians = data.groupby(x, observed=False)[y].median().reindex(index=group_order)
quantiles = (
data.groupby(x)[y].quantile([0.25, 0.75]).unstack().reindex(index=group_order)
data.groupby(x, observed=False)[y].quantile([0.25, 0.75]).unstack().reindex(index=group_order)
)
lower_quartiles = quantiles[0.25]
upper_quartiles = quantiles[0.75]
Expand Down Expand Up @@ -978,7 +978,7 @@ def swarm_bars_plotter(plot_data: object, xvar: str, yvar: str, ax: object,
else:
swarm_bars_order = pd.unique(plot_data[xvar])

swarm_means = plot_data.groupby(xvar)[yvar].mean().reindex(index=swarm_bars_order)
swarm_means = plot_data.groupby(xvar, observed=False)[yvar].mean().reindex(index=swarm_bars_order)
swarm_bars_colors = (
[swarm_bars_kwargs.get('color')] * (max(swarm_bars_order) + 1)
if swarm_bars_kwargs.get('color') is not None
Expand Down Expand Up @@ -1199,7 +1199,7 @@ def slopegraph_plotter(dabest_obj, plot_data, xvar, yvar, color_col, plot_palett
if color_col is None:
slopegraph_kwargs["color"] = ytick_color
else:
color_key = observation[color_col][0]
color_key = observation[color_col].iloc[0]
if isinstance(color_key, (str, np.int64, np.float64)):
slopegraph_kwargs["color"] = plot_palette_raw[color_key]
slopegraph_kwargs["label"] = color_key
Expand Down Expand Up @@ -1497,7 +1497,7 @@ def swarmplot(
data: pd.DataFrame,
x: str,
y: str,
ax: axes.Subplot,
ax: axes.Axes,
order: List = None,
hue: str = None,
palette: Union[Iterable, str] = "black",
Expand All @@ -1521,8 +1521,8 @@ def swarmplot(
The column in the DataFrame to be used as the x-axis.
y : str
The column in the DataFrame to be used as the y-axis.
ax : axes._subplots.Subplot | axes._axes.Axes
Matplotlib AxesSubplot object for which the plot would be drawn on. Default is None.
ax : axes.Axes
Matplotlib axes.Axes object for which the plot would be drawn on. Default is None.
order : List
The order in which x-axis categories should be displayed. Default is None.
hue : str
Expand Down Expand Up @@ -1552,8 +1552,8 @@ def swarmplot(
Returns
-------
axes._subplots.Subplot | axes._axes.Axes
Matplotlib AxesSubplot object for which the swarm plot has been drawn on.
axes.Axes
Matplotlib axes.Axes object for which the swarm plot has been drawn on.
"""
s = SwarmPlot(data, x, y, ax, order, hue, palette, zorder, size, side, jitter)
ax = s.plot(is_drop_gutter, gutter_limit, ax, filled, **kwargs)
Expand All @@ -1566,7 +1566,7 @@ def __init__(
data: pd.DataFrame,
x: str,
y: str,
ax: axes.Subplot,
ax: axes.Axes,
order: List = None,
hue: str = None,
palette: Union[Iterable, str] = "black",
Expand All @@ -1586,8 +1586,8 @@ def __init__(
The column in the DataFrame to be used as the x-axis.
y : str
The column in the DataFrame to be used as the y-axis.
ax : axes.Subplot
Matplotlib AxesSubplot object for which the plot would be drawn on.
ax : axes.Axes
Matplotlib axes.Axes object for which the plot would be drawn on.
order : List
The order in which x-axis categories should be displayed. Default is None.
hue : str
Expand Down Expand Up @@ -1674,7 +1674,7 @@ def __init__(
self.__dsize = dsize

def _check_errors(
self, data: pd.DataFrame, ax: axes.Subplot, size: float, side: str
self, data: pd.DataFrame, ax: axes.Axes, size: float, side: str
) -> None:
"""
Check the validity of input parameters. Raises exceptions if detected.
Expand All @@ -1683,8 +1683,8 @@ def _check_errors(
----------
data : pd.Dataframe
Input data used for generation of the swarmplot.
ax : axes.Subplot
Matplotlib AxesSubplot object for which the plot would be drawn on.
ax : axes.Axes
Matplotlib axes.Axes object for which the plot would be drawn on.
size : int | float
scalar value determining size of dots of the swarmplot.
side: str
Expand All @@ -1697,9 +1697,9 @@ def _check_errors(
# Type enforcement
if not isinstance(data, pd.DataFrame):
raise ValueError("`data` must be a Pandas Dataframe.")
if not isinstance(ax, (axes._subplots.Subplot, axes._axes.Axes)):
if not isinstance(ax, axes.Axes):
raise ValueError(
f"`ax` must be a Matplotlib AxesSubplot. The current `ax` is a {type(ax)}"
f"`ax` must be a Matplotlib axes.Axes. The current `ax` is a {type(ax)}"
)
if not isinstance(size, (int, float)):
raise ValueError("`size` must be a scalar or float.")
Expand Down Expand Up @@ -1859,9 +1859,10 @@ def _swarm(
raise ValueError("`dsize` must be a scalar or float.")

# Sorting algorithm based off of: https://github.com/mgymrek/pybeeswarm
points_data = pd.DataFrame(
{"y": [yval * 1.0 / dsize for yval in values], "x": [0] * len(values)}
)
points_data = pd.DataFrame({
"y": [yval * 1.0 / dsize for yval in values],
"x": np.zeros(len(values), dtype=float) # Initialize with float zeros
})
for i in range(1, points_data.shape[0]):
y_i = points_data["y"].values[i]
points_placed = points_data[0:i]
Expand Down Expand Up @@ -1968,7 +1969,7 @@ def plot(
ax: axes.Subplot,
filled: Union[bool, List, Tuple],
**kwargs,
) -> axes.Subplot:
) -> axes.Axes:
"""
Generate a swarm plot.
Expand All @@ -1978,7 +1979,7 @@ def plot(
If True, drop points that hit the gutters; otherwise, readjust them.
gutter_limit : int | float
The limit for points hitting the gutters.
ax : axes.Subplot
ax : axes.Axes
The matplotlib figure object to which the swarm plot will be added.
filled : bool | List | Tuple
Determines whether the dots in the swarmplot are filled or not. If set to False,
Expand All @@ -1990,8 +1991,8 @@ def plot(
Returns
-------
axes.Subplot:
The matplotlib figure containing the swarm plot.
axes.Axes:
The matplotlib axes containing the swarm plot.
"""
# Input validation
if not isinstance(is_drop_gutter, bool):
Expand Down Expand Up @@ -2019,8 +2020,7 @@ def plot(
0 # x-coordinate of center of each individual swarm of the swarm plot
)
x_tick_tabels = []

for group_i, values_i in self.__data_copy.groupby(self.__x):
for group_i, values_i in self.__data_copy.groupby(self.__x, observed=False):
x_new = []
values_i_y = values_i[self.__y]
x_offset = self._swarm(
Expand Down
5 changes: 3 additions & 2 deletions nbs/API/dabest_object.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@
"#| export\n",
"# Import standard data science libraries\n",
"from numpy import array, repeat, random, issubdtype, number\n",
"import numpy as np\n",
"import pandas as pd\n",
"from scipy.stats import norm\n",
"from scipy.stats import randint"
Expand Down Expand Up @@ -547,7 +548,7 @@
"\n",
" # Handling str type condition\n",
" if is_str_condition_met:\n",
" if len(pd.unique(idx).tolist()) != 2:\n",
" if len(np.unique(idx).tolist()) != 2:\n",
" err0 = \"`mini_meta` is True, but `idx` ({})\".format(idx)\n",
" err1 = \"does not contain exactly 2 unique columns.\"\n",
" raise ValueError(err0 + err1)\n",
Expand Down Expand Up @@ -735,7 +736,7 @@
" all_plot_groups, ordered=True, inplace=True\n",
" )\n",
" else:\n",
" plot_data.loc[:, self.__xvar] = pd.Categorical(\n",
" plot_data[self.__xvar] = pd.Categorical(\n",
" plot_data[self.__xvar], categories=all_plot_groups, ordered=True\n",
" )\n",
"\n",
Expand Down
Loading

0 comments on commit d3d531f

Please sign in to comment.