diff --git a/.github/workflows/test-pytest.yaml b/.github/workflows/test-pytest.yaml index 344c88b7..599c62a6 100644 --- a/.github/workflows/test-pytest.yaml +++ b/.github/workflows/test-pytest.yaml @@ -8,7 +8,7 @@ jobs: - uses: actions/checkout@v3 - uses: actions/setup-python@v4 with: - python-version: 3.9 + python-version: 3.8 cache: "pip" cache-dependency-path: settings.ini - name: Run pytest diff --git a/dabest/_dabest_object.py b/dabest/_dabest_object.py index d61245dd..ec917b03 100644 --- a/dabest/_dabest_object.py +++ b/dabest/_dabest_object.py @@ -8,7 +8,6 @@ # %% ../nbs/API/dabest_object.ipynb 4 # Import standard data science libraries from numpy import array, repeat, random, issubdtype, number -import numpy as np import pandas as pd from scipy.stats import norm from scipy.stats import randint @@ -480,7 +479,7 @@ def _check_errors(self, x, y, idx, experiment, experiment_label, x1_level): # Handling str type condition if is_str_condition_met: - if len(np.unique(idx).tolist()) != 2: + if len(pd.unique(idx).tolist()) != 2: err0 = "`mini_meta` is True, but `idx` ({})".format(idx) err1 = "does not contain exactly 2 unique columns." raise ValueError(err0 + err1) @@ -668,7 +667,7 @@ def _get_plot_data(self, x, y, all_plot_groups): all_plot_groups, ordered=True, inplace=True ) else: - plot_data[self.__xvar] = pd.Categorical( + plot_data.loc[:, self.__xvar] = pd.Categorical( plot_data[self.__xvar], categories=all_plot_groups, ordered=True ) diff --git a/dabest/_modidx.py b/dabest/_modidx.py index eb260d99..5c62ae70 100644 --- a/dabest/_modidx.py +++ b/dabest/_modidx.py @@ -81,8 +81,6 @@ 'dabest.misc_tools.get_kwargs': ('API/misc_tools.html#get_kwargs', 'dabest/misc_tools.py'), 'dabest.misc_tools.get_params': ('API/misc_tools.html#get_params', 'dabest/misc_tools.py'), 'dabest.misc_tools.get_plot_groups': ('API/misc_tools.html#get_plot_groups', 'dabest/misc_tools.py'), - 'dabest.misc_tools.get_unique_categories': ( 'API/misc_tools.html#get_unique_categories', - 'dabest/misc_tools.py'), 'dabest.misc_tools.get_varname': ('API/misc_tools.html#get_varname', 'dabest/misc_tools.py'), 'dabest.misc_tools.initialize_fig': ('API/misc_tools.html#initialize_fig', 'dabest/misc_tools.py'), 'dabest.misc_tools.merge_two_dicts': ('API/misc_tools.html#merge_two_dicts', 'dabest/misc_tools.py'), diff --git a/dabest/misc_tools.py b/dabest/misc_tools.py index 3295a3e9..ff7e56d9 100644 --- a/dabest/misc_tools.py +++ b/dabest/misc_tools.py @@ -3,11 +3,10 @@ # AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/API/misc_tools.ipynb. # %% auto 0 -__all__ = ['merge_two_dicts', 'unpack_and_add', 'print_greeting', 'get_varname', 'get_unique_categories', 'get_params', - 'get_kwargs', 'get_color_palette', 'initialize_fig', 'get_plot_groups', 'add_counts_to_ticks', - 'extract_contrast_plotting_ticks', 'set_xaxis_ticks_and_lims', 'show_legend', - 'Gardner_Altman_Plot_Aesthetic_Adjustments', 'Cumming_Plot_Aesthetic_Adjustments', - 'General_Plot_Aesthetic_Adjustments'] +__all__ = ['merge_two_dicts', 'unpack_and_add', 'print_greeting', 'get_varname', 'get_params', 'get_kwargs', 'get_color_palette', + 'initialize_fig', 'get_plot_groups', 'add_counts_to_ticks', 'extract_contrast_plotting_ticks', + 'set_xaxis_ticks_and_lims', 'show_legend', 'Gardner_Altman_Plot_Aesthetic_Adjustments', + 'Cumming_Plot_Aesthetic_Adjustments', 'General_Plot_Aesthetic_Adjustments'] # %% ../nbs/API/misc_tools.ipynb 4 import datetime as dt @@ -79,19 +78,6 @@ def get_varname(obj): if len(matching_vars) > 0: return matching_vars[0] return "" - - -def get_unique_categories(names): - """ - Extract unique categories from various input types. - """ - if isinstance(names, np.ndarray): - return names # numpy.unique() returns a sorted array - elif isinstance(names, (pd.Categorical, pd.Series)): - return names.cat.categories if hasattr(names, 'cat') else names.unique() - else: - # For dict_keys and other iterables - return np.unique(list(names)) def get_params(effectsize_df, plot_kwargs): """ @@ -383,7 +369,6 @@ def get_color_palette(plot_kwargs, plot_data, xvar, show_pairs, idx, all_plot_gr raise ValueError(err1 + err2) if custom_pal is None and color_col is None: - categories = get_unique_categories(names) swarm_colors = [sns.desaturate(c, swarm_desat) for c in unsat_colors] contrast_colors = [sns.desaturate(c, contrast_desat) for c in unsat_colors] bar_color = [sns.desaturate(c, bar_desat) for c in unsat_colors] @@ -397,9 +382,9 @@ def get_color_palette(plot_kwargs, plot_data, xvar, show_pairs, idx, all_plot_gr plot_palette_raw[names_i] = swarm_colors[i] plot_palette_contrast[names_i] = contrast_colors[i] else: - plot_palette_raw = dict(zip(categories, swarm_colors)) - plot_palette_contrast = dict(zip(categories, contrast_colors)) - plot_palette_bar = dict(zip(categories, bar_color)) + plot_palette_raw = dict(zip(names.categories, swarm_colors)) + plot_palette_contrast = dict(zip(names.categories, contrast_colors)) + plot_palette_bar = dict(zip(names.categories, bar_color)) # For Sankey Diagram plot, no need to worry about the color, each bar will have the same two colors # default color palette will be set to "hls" @@ -556,7 +541,7 @@ def get_plot_groups(is_paired, idx, proportional, all_plot_groups): def add_counts_to_ticks(plot_data, xvar, yvar, rawdata_axes, plot_kwargs): # Add the counts to the rawdata axes xticks. - counts = plot_data.groupby(xvar, observed=False).count()[yvar] + counts = plot_data.groupby(xvar).count()[yvar] def lookup_value(text): try: @@ -710,19 +695,19 @@ def Gardner_Altman_Plot_Aesthetic_Adjustments(effect_size_type, plot_data, xvar, # Check that the effect size is within the swarm ylims. if effect_size_type in ["mean_diff", "cohens_d", "hedges_g", "cohens_h"]: control_group_summary = ( - plot_data.groupby(xvar, observed=False) + plot_data.groupby(xvar) .mean(numeric_only=True) .loc[current_control, yvar] ) test_group_summary = ( - plot_data.groupby(xvar, observed=False).mean(numeric_only=True).loc[current_group, yvar] + plot_data.groupby(xvar).mean(numeric_only=True).loc[current_group, yvar] ) elif effect_size_type == "median_diff": control_group_summary = ( - plot_data.groupby(xvar, observed=False).median(numeric_only=True).loc[current_control, yvar] + plot_data.groupby(xvar).median().loc[current_control, yvar] ) test_group_summary = ( - plot_data.groupby(xvar, observed=False).median(numeric_only=True).loc[current_group, yvar] + plot_data.groupby(xvar).median().loc[current_group, yvar] ) if swarm_ylim is None: @@ -766,7 +751,7 @@ def Gardner_Altman_Plot_Aesthetic_Adjustments(effect_size_type, plot_data, xvar, pooled_sd = stds[0] if effect_size_type == "hedges_g": - gby_count = plot_data.groupby(xvar, observed=False).count() + gby_count = plot_data.groupby(xvar).count() len_control = gby_count.loc[current_control, yvar] len_test = gby_count.loc[current_group, yvar] diff --git a/dabest/plot_tools.py b/dabest/plot_tools.py index 2d82053e..fd56323d 100644 --- a/dabest/plot_tools.py +++ b/dabest/plot_tools.py @@ -120,15 +120,15 @@ def error_bar( else: group_order = pd.unique(data[x]) - means = data.groupby(x, observed=False)[y].mean().reindex(index=group_order) + means = data.groupby(x)[y].mean().reindex(index=group_order) if method in ["proportional_error_bar", "sankey_error_bar"]: g = lambda x: np.sqrt( (np.sum(x) * (len(x) - np.sum(x))) / (len(x) * len(x) * len(x)) ) - sd = data.groupby(x, observed=False)[y].apply(g) + sd = data.groupby(x)[y].apply(g) else: - sd = data.groupby(x, observed=False)[y].std().reindex(index=group_order) + sd = data.groupby(x)[y].std().reindex(index=group_order) lower_sd = means - sd upper_sd = means + sd @@ -136,9 +136,9 @@ def error_bar( if (lower_sd < ax_ylims[0]).any() or (upper_sd > ax_ylims[1]).any(): kwargs["clip_on"] = True - medians = data.groupby(x, observed=False)[y].median().reindex(index=group_order) + medians = data.groupby(x)[y].median().reindex(index=group_order) quantiles = ( - data.groupby(x, observed=False)[y].quantile([0.25, 0.75]).unstack().reindex(index=group_order) + data.groupby(x)[y].quantile([0.25, 0.75]).unstack().reindex(index=group_order) ) lower_quartiles = quantiles[0.25] upper_quartiles = quantiles[0.75] @@ -978,7 +978,7 @@ def swarm_bars_plotter(plot_data: object, xvar: str, yvar: str, ax: object, else: swarm_bars_order = pd.unique(plot_data[xvar]) - swarm_means = plot_data.groupby(xvar, observed=False)[yvar].mean().reindex(index=swarm_bars_order) + swarm_means = plot_data.groupby(xvar)[yvar].mean().reindex(index=swarm_bars_order) swarm_bars_colors = ( [swarm_bars_kwargs.get('color')] * (max(swarm_bars_order) + 1) if swarm_bars_kwargs.get('color') is not None @@ -1199,7 +1199,7 @@ def slopegraph_plotter(dabest_obj, plot_data, xvar, yvar, color_col, plot_palett if color_col is None: slopegraph_kwargs["color"] = ytick_color else: - color_key = observation[color_col].iloc[0] + color_key = observation[color_col][0] if isinstance(color_key, (str, np.int64, np.float64)): slopegraph_kwargs["color"] = plot_palette_raw[color_key] slopegraph_kwargs["label"] = color_key @@ -1497,7 +1497,7 @@ def swarmplot( data: pd.DataFrame, x: str, y: str, - ax: axes.Axes, + ax: axes.Subplot, order: List = None, hue: str = None, palette: Union[Iterable, str] = "black", @@ -1521,8 +1521,8 @@ def swarmplot( The column in the DataFrame to be used as the x-axis. y : str The column in the DataFrame to be used as the y-axis. - ax : axes.Axes - Matplotlib axes.Axes object for which the plot would be drawn on. Default is None. + ax : axes._subplots.Subplot | axes._axes.Axes + Matplotlib AxesSubplot object for which the plot would be drawn on. Default is None. order : List The order in which x-axis categories should be displayed. Default is None. hue : str @@ -1552,8 +1552,8 @@ def swarmplot( Returns ------- - axes.Axes - Matplotlib axes.Axes object for which the swarm plot has been drawn on. + axes._subplots.Subplot | axes._axes.Axes + Matplotlib AxesSubplot object for which the swarm plot has been drawn on. """ s = SwarmPlot(data, x, y, ax, order, hue, palette, zorder, size, side, jitter) ax = s.plot(is_drop_gutter, gutter_limit, ax, filled, **kwargs) @@ -1566,7 +1566,7 @@ def __init__( data: pd.DataFrame, x: str, y: str, - ax: axes.Axes, + ax: axes.Subplot, order: List = None, hue: str = None, palette: Union[Iterable, str] = "black", @@ -1586,8 +1586,8 @@ def __init__( The column in the DataFrame to be used as the x-axis. y : str The column in the DataFrame to be used as the y-axis. - ax : axes.Axes - Matplotlib axes.Axes object for which the plot would be drawn on. + ax : axes.Subplot + Matplotlib AxesSubplot object for which the plot would be drawn on. order : List The order in which x-axis categories should be displayed. Default is None. hue : str @@ -1674,7 +1674,7 @@ def __init__( self.__dsize = dsize def _check_errors( - self, data: pd.DataFrame, ax: axes.Axes, size: float, side: str + self, data: pd.DataFrame, ax: axes.Subplot, size: float, side: str ) -> None: """ Check the validity of input parameters. Raises exceptions if detected. @@ -1683,8 +1683,8 @@ def _check_errors( ---------- data : pd.Dataframe Input data used for generation of the swarmplot. - ax : axes.Axes - Matplotlib axes.Axes object for which the plot would be drawn on. + ax : axes.Subplot + Matplotlib AxesSubplot object for which the plot would be drawn on. size : int | float scalar value determining size of dots of the swarmplot. side: str @@ -1697,9 +1697,9 @@ def _check_errors( # Type enforcement if not isinstance(data, pd.DataFrame): raise ValueError("`data` must be a Pandas Dataframe.") - if not isinstance(ax, axes.Axes): + if not isinstance(ax, (axes._subplots.Subplot, axes._axes.Axes)): raise ValueError( - f"`ax` must be a Matplotlib axes.Axes. The current `ax` is a {type(ax)}" + f"`ax` must be a Matplotlib AxesSubplot. The current `ax` is a {type(ax)}" ) if not isinstance(size, (int, float)): raise ValueError("`size` must be a scalar or float.") @@ -1859,10 +1859,9 @@ def _swarm( raise ValueError("`dsize` must be a scalar or float.") # Sorting algorithm based off of: https://github.com/mgymrek/pybeeswarm - points_data = pd.DataFrame({ - "y": [yval * 1.0 / dsize for yval in values], - "x": np.zeros(len(values), dtype=float) # Initialize with float zeros - }) + points_data = pd.DataFrame( + {"y": [yval * 1.0 / dsize for yval in values], "x": [0] * len(values)} + ) for i in range(1, points_data.shape[0]): y_i = points_data["y"].values[i] points_placed = points_data[0:i] @@ -1969,7 +1968,7 @@ def plot( ax: axes.Subplot, filled: Union[bool, List, Tuple], **kwargs, - ) -> axes.Axes: + ) -> axes.Subplot: """ Generate a swarm plot. @@ -1979,7 +1978,7 @@ def plot( If True, drop points that hit the gutters; otherwise, readjust them. gutter_limit : int | float The limit for points hitting the gutters. - ax : axes.Axes + ax : axes.Subplot The matplotlib figure object to which the swarm plot will be added. filled : bool | List | Tuple Determines whether the dots in the swarmplot are filled or not. If set to False, @@ -1991,8 +1990,8 @@ def plot( Returns ------- - axes.Axes: - The matplotlib axes containing the swarm plot. + axes.Subplot: + The matplotlib figure containing the swarm plot. """ # Input validation if not isinstance(is_drop_gutter, bool): @@ -2020,7 +2019,8 @@ def plot( 0 # x-coordinate of center of each individual swarm of the swarm plot ) x_tick_tabels = [] - for group_i, values_i in self.__data_copy.groupby(self.__x, observed=False): + + for group_i, values_i in self.__data_copy.groupby(self.__x): x_new = [] values_i_y = values_i[self.__y] x_offset = self._swarm( diff --git a/nbs/API/dabest_object.ipynb b/nbs/API/dabest_object.ipynb index 48576a87..776b4fb1 100644 --- a/nbs/API/dabest_object.ipynb +++ b/nbs/API/dabest_object.ipynb @@ -57,7 +57,6 @@ "#| export\n", "# Import standard data science libraries\n", "from numpy import array, repeat, random, issubdtype, number\n", - "import numpy as np\n", "import pandas as pd\n", "from scipy.stats import norm\n", "from scipy.stats import randint" @@ -548,7 +547,7 @@ "\n", " # Handling str type condition\n", " if is_str_condition_met:\n", - " if len(np.unique(idx).tolist()) != 2:\n", + " if len(pd.unique(idx).tolist()) != 2:\n", " err0 = \"`mini_meta` is True, but `idx` ({})\".format(idx)\n", " err1 = \"does not contain exactly 2 unique columns.\"\n", " raise ValueError(err0 + err1)\n", @@ -736,7 +735,7 @@ " all_plot_groups, ordered=True, inplace=True\n", " )\n", " else:\n", - " plot_data[self.__xvar] = pd.Categorical(\n", + " plot_data.loc[:, self.__xvar] = pd.Categorical(\n", " plot_data[self.__xvar], categories=all_plot_groups, ordered=True\n", " )\n", "\n", diff --git a/nbs/API/misc_tools.ipynb b/nbs/API/misc_tools.ipynb index c8440e2b..6a1f40fc 100644 --- a/nbs/API/misc_tools.ipynb +++ b/nbs/API/misc_tools.ipynb @@ -49,12 +49,11 @@ { "cell_type": "code", "execution_count": null, - "id": "3c9a6ef1", + "id": "5f54be1c", "metadata": {}, "outputs": [], "source": [ "#| export\n", - "\n", "import datetime as dt\n", "import numpy as np\n", "from numpy import repeat\n", @@ -67,7 +66,7 @@ { "cell_type": "code", "execution_count": null, - "id": "5f54be1c", + "id": "6b50da46", "metadata": {}, "outputs": [], "source": [ @@ -132,19 +131,6 @@ " if len(matching_vars) > 0:\n", " return matching_vars[0]\n", " return \"\"\n", - "\t\n", - "\n", - "def get_unique_categories(names):\n", - " \"\"\"\n", - " Extract unique categories from various input types.\n", - " \"\"\"\n", - " if isinstance(names, np.ndarray):\n", - " return names # numpy.unique() returns a sorted array\n", - " elif isinstance(names, (pd.Categorical, pd.Series)):\n", - " return names.cat.categories if hasattr(names, 'cat') else names.unique()\n", - " else:\n", - " # For dict_keys and other iterables\n", - " return np.unique(list(names))\n", "\n", "def get_params(effectsize_df, plot_kwargs):\n", " \"\"\"\n", @@ -436,7 +422,6 @@ " raise ValueError(err1 + err2)\n", "\n", " if custom_pal is None and color_col is None:\n", - " categories = get_unique_categories(names)\n", " swarm_colors = [sns.desaturate(c, swarm_desat) for c in unsat_colors]\n", " contrast_colors = [sns.desaturate(c, contrast_desat) for c in unsat_colors]\n", " bar_color = [sns.desaturate(c, bar_desat) for c in unsat_colors]\n", @@ -450,9 +435,9 @@ " plot_palette_raw[names_i] = swarm_colors[i]\n", " plot_palette_contrast[names_i] = contrast_colors[i]\n", " else:\n", - " plot_palette_raw = dict(zip(categories, swarm_colors))\n", - " plot_palette_contrast = dict(zip(categories, contrast_colors))\n", - " plot_palette_bar = dict(zip(categories, bar_color))\n", + " plot_palette_raw = dict(zip(names.categories, swarm_colors))\n", + " plot_palette_contrast = dict(zip(names.categories, contrast_colors))\n", + " plot_palette_bar = dict(zip(names.categories, bar_color))\n", "\n", " # For Sankey Diagram plot, no need to worry about the color, each bar will have the same two colors\n", " # default color palette will be set to \"hls\"\n", @@ -609,7 +594,7 @@ "\n", "def add_counts_to_ticks(plot_data, xvar, yvar, rawdata_axes, plot_kwargs):\n", " # Add the counts to the rawdata axes xticks.\n", - " counts = plot_data.groupby(xvar, observed=False).count()[yvar]\n", + " counts = plot_data.groupby(xvar).count()[yvar]\n", " \n", " def lookup_value(text):\n", " try:\n", @@ -763,19 +748,19 @@ " # Check that the effect size is within the swarm ylims.\n", " if effect_size_type in [\"mean_diff\", \"cohens_d\", \"hedges_g\", \"cohens_h\"]:\n", " control_group_summary = (\n", - " plot_data.groupby(xvar, observed=False)\n", + " plot_data.groupby(xvar)\n", " .mean(numeric_only=True)\n", " .loc[current_control, yvar]\n", " )\n", " test_group_summary = (\n", - " plot_data.groupby(xvar, observed=False).mean(numeric_only=True).loc[current_group, yvar]\n", + " plot_data.groupby(xvar).mean(numeric_only=True).loc[current_group, yvar]\n", " )\n", " elif effect_size_type == \"median_diff\":\n", " control_group_summary = (\n", - " plot_data.groupby(xvar, observed=False).median(numeric_only=True).loc[current_control, yvar]\n", + " plot_data.groupby(xvar).median().loc[current_control, yvar]\n", " )\n", " test_group_summary = (\n", - " plot_data.groupby(xvar, observed=False).median(numeric_only=True).loc[current_group, yvar]\n", + " plot_data.groupby(xvar).median().loc[current_group, yvar]\n", " )\n", "\n", " if swarm_ylim is None:\n", @@ -819,7 +804,7 @@ " pooled_sd = stds[0]\n", "\n", " if effect_size_type == \"hedges_g\":\n", - " gby_count = plot_data.groupby(xvar, observed=False).count()\n", + " gby_count = plot_data.groupby(xvar).count()\n", " len_control = gby_count.loc[current_control, yvar]\n", " len_test = gby_count.loc[current_group, yvar]\n", "\n", diff --git a/nbs/API/plot_tools.ipynb b/nbs/API/plot_tools.ipynb index 524295ab..f60aaff1 100644 --- a/nbs/API/plot_tools.ipynb +++ b/nbs/API/plot_tools.ipynb @@ -171,15 +171,15 @@ " else:\n", " group_order = pd.unique(data[x])\n", "\n", - " means = data.groupby(x, observed=False)[y].mean().reindex(index=group_order)\n", + " means = data.groupby(x)[y].mean().reindex(index=group_order)\n", "\n", " if method in [\"proportional_error_bar\", \"sankey_error_bar\"]:\n", " g = lambda x: np.sqrt(\n", " (np.sum(x) * (len(x) - np.sum(x))) / (len(x) * len(x) * len(x))\n", " )\n", - " sd = data.groupby(x, observed=False)[y].apply(g)\n", + " sd = data.groupby(x)[y].apply(g)\n", " else:\n", - " sd = data.groupby(x, observed=False)[y].std().reindex(index=group_order)\n", + " sd = data.groupby(x)[y].std().reindex(index=group_order)\n", "\n", " lower_sd = means - sd\n", " upper_sd = means + sd\n", @@ -187,9 +187,9 @@ " if (lower_sd < ax_ylims[0]).any() or (upper_sd > ax_ylims[1]).any():\n", " kwargs[\"clip_on\"] = True\n", "\n", - " medians = data.groupby(x, observed=False)[y].median().reindex(index=group_order)\n", + " medians = data.groupby(x)[y].median().reindex(index=group_order)\n", " quantiles = (\n", - " data.groupby(x, observed=False)[y].quantile([0.25, 0.75]).unstack().reindex(index=group_order)\n", + " data.groupby(x)[y].quantile([0.25, 0.75]).unstack().reindex(index=group_order)\n", " )\n", " lower_quartiles = quantiles[0.25]\n", " upper_quartiles = quantiles[0.75]\n", @@ -1029,7 +1029,7 @@ " else:\n", " swarm_bars_order = pd.unique(plot_data[xvar])\n", "\n", - " swarm_means = plot_data.groupby(xvar, observed=False)[yvar].mean().reindex(index=swarm_bars_order)\n", + " swarm_means = plot_data.groupby(xvar)[yvar].mean().reindex(index=swarm_bars_order)\n", " swarm_bars_colors = (\n", " [swarm_bars_kwargs.get('color')] * (max(swarm_bars_order) + 1) \n", " if swarm_bars_kwargs.get('color') is not None \n", @@ -1250,7 +1250,7 @@ " if color_col is None:\n", " slopegraph_kwargs[\"color\"] = ytick_color\n", " else:\n", - " color_key = observation[color_col].iloc[0]\n", + " color_key = observation[color_col][0]\n", " if isinstance(color_key, (str, np.int64, np.float64)):\n", " slopegraph_kwargs[\"color\"] = plot_palette_raw[color_key]\n", " slopegraph_kwargs[\"label\"] = color_key\n", @@ -1556,7 +1556,7 @@ " data: pd.DataFrame,\n", " x: str,\n", " y: str,\n", - " ax: axes.Axes,\n", + " ax: axes.Subplot,\n", " order: List = None,\n", " hue: str = None,\n", " palette: Union[Iterable, str] = \"black\",\n", @@ -1580,8 +1580,8 @@ " The column in the DataFrame to be used as the x-axis.\n", " y : str\n", " The column in the DataFrame to be used as the y-axis.\n", - " ax : axes.Axes\n", - " Matplotlib axes.Axes object for which the plot would be drawn on. Default is None.\n", + " ax : axes._subplots.Subplot | axes._axes.Axes\n", + " Matplotlib AxesSubplot object for which the plot would be drawn on. Default is None.\n", " order : List\n", " The order in which x-axis categories should be displayed. Default is None.\n", " hue : str\n", @@ -1611,8 +1611,8 @@ "\n", " Returns\n", " -------\n", - " axes.Axes\n", - " Matplotlib axes.Axes object for which the swarm plot has been drawn on.\n", + " axes._subplots.Subplot | axes._axes.Axes\n", + " Matplotlib AxesSubplot object for which the swarm plot has been drawn on.\n", " \"\"\"\n", " s = SwarmPlot(data, x, y, ax, order, hue, palette, zorder, size, side, jitter)\n", " ax = s.plot(is_drop_gutter, gutter_limit, ax, filled, **kwargs)\n", @@ -1625,7 +1625,7 @@ " data: pd.DataFrame,\n", " x: str,\n", " y: str,\n", - " ax: axes.Axes,\n", + " ax: axes.Subplot,\n", " order: List = None,\n", " hue: str = None,\n", " palette: Union[Iterable, str] = \"black\",\n", @@ -1645,8 +1645,8 @@ " The column in the DataFrame to be used as the x-axis.\n", " y : str\n", " The column in the DataFrame to be used as the y-axis.\n", - " ax : axes.Axes\n", - " Matplotlib axes.Axes object for which the plot would be drawn on.\n", + " ax : axes.Subplot\n", + " Matplotlib AxesSubplot object for which the plot would be drawn on.\n", " order : List\n", " The order in which x-axis categories should be displayed. Default is None.\n", " hue : str\n", @@ -1733,7 +1733,7 @@ " self.__dsize = dsize\n", "\n", " def _check_errors(\n", - " self, data: pd.DataFrame, ax: axes.Axes, size: float, side: str\n", + " self, data: pd.DataFrame, ax: axes.Subplot, size: float, side: str\n", " ) -> None:\n", " \"\"\"\n", " Check the validity of input parameters. Raises exceptions if detected.\n", @@ -1742,8 +1742,8 @@ " ----------\n", " data : pd.Dataframe\n", " Input data used for generation of the swarmplot.\n", - " ax : axes.Axes\n", - " Matplotlib axes.Axes object for which the plot would be drawn on.\n", + " ax : axes.Subplot\n", + " Matplotlib AxesSubplot object for which the plot would be drawn on.\n", " size : int | float\n", " scalar value determining size of dots of the swarmplot.\n", " side: str\n", @@ -1756,9 +1756,9 @@ " # Type enforcement\n", " if not isinstance(data, pd.DataFrame):\n", " raise ValueError(\"`data` must be a Pandas Dataframe.\")\n", - " if not isinstance(ax, axes.Axes):\n", + " if not isinstance(ax, (axes._subplots.Subplot, axes._axes.Axes)):\n", " raise ValueError(\n", - " f\"`ax` must be a Matplotlib axes.Axes. The current `ax` is a {type(ax)}\"\n", + " f\"`ax` must be a Matplotlib AxesSubplot. The current `ax` is a {type(ax)}\"\n", " )\n", " if not isinstance(size, (int, float)):\n", " raise ValueError(\"`size` must be a scalar or float.\")\n", @@ -1918,10 +1918,9 @@ " raise ValueError(\"`dsize` must be a scalar or float.\")\n", "\n", " # Sorting algorithm based off of: https://github.com/mgymrek/pybeeswarm\n", - " points_data = pd.DataFrame({\n", - " \"y\": [yval * 1.0 / dsize for yval in values],\n", - " \"x\": np.zeros(len(values), dtype=float) # Initialize with float zeros\n", - " })\n", + " points_data = pd.DataFrame(\n", + " {\"y\": [yval * 1.0 / dsize for yval in values], \"x\": [0] * len(values)}\n", + " )\n", " for i in range(1, points_data.shape[0]):\n", " y_i = points_data[\"y\"].values[i]\n", " points_placed = points_data[0:i]\n", @@ -2028,7 +2027,7 @@ " ax: axes.Subplot,\n", " filled: Union[bool, List, Tuple],\n", " **kwargs,\n", - " ) -> axes.Axes:\n", + " ) -> axes.Subplot:\n", " \"\"\"\n", " Generate a swarm plot.\n", "\n", @@ -2038,7 +2037,7 @@ " If True, drop points that hit the gutters; otherwise, readjust them.\n", " gutter_limit : int | float\n", " The limit for points hitting the gutters.\n", - " ax : axes.Axes\n", + " ax : axes.Subplot\n", " The matplotlib figure object to which the swarm plot will be added.\n", " filled : bool | List | Tuple\n", " Determines whether the dots in the swarmplot are filled or not. If set to False,\n", @@ -2050,8 +2049,8 @@ "\n", " Returns\n", " -------\n", - " axes.Axes:\n", - " The matplotlib axes containing the swarm plot.\n", + " axes.Subplot:\n", + " The matplotlib figure containing the swarm plot.\n", " \"\"\"\n", " # Input validation\n", " if not isinstance(is_drop_gutter, bool):\n", @@ -2079,7 +2078,8 @@ " 0 # x-coordinate of center of each individual swarm of the swarm plot\n", " )\n", " x_tick_tabels = []\n", - " for group_i, values_i in self.__data_copy.groupby(self.__x, observed=False):\n", + "\n", + " for group_i, values_i in self.__data_copy.groupby(self.__x):\n", " x_new = []\n", " values_i_y = values_i[self.__y]\n", " x_offset = self._swarm(\n", @@ -2172,6 +2172,14 @@ "\n", " return ax, swarm_legend_kwargs if self.__hue is not None else None" ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "022ea903", + "metadata": {}, + "outputs": [], + "source": [] } ], "metadata": { diff --git a/nbs/tests/test_plot_tools.py b/nbs/tests/test_plot_tools.py index 70f73640..fb07f9c9 100644 --- a/nbs/tests/test_plot_tools.py +++ b/nbs/tests/test_plot_tools.py @@ -84,7 +84,7 @@ def test_check_data_matches_labels(): ("data", None, "`data` must be a Pandas Dataframe.", ValueError), ("x", None, "`x` must be a string.", ValueError), ("y", None, "`y` must be a string.", ValueError), - ("ax", None, "`ax` must be a Matplotlib axes.Axes. The current `ax` is a ", ValueError), + ("ax", None, "`ax` must be a Matplotlib AxesSubplot. The current `ax` is a ", ValueError), ("order", 5, "`order` must be either an Iterable or None.", ValueError), ("hue", 5, "`hue` must be either a string or None.", ValueError), ("palette", None, "`palette` must be either a string indicating a color name or an Iterable.", ValueError), diff --git a/settings.ini b/settings.ini index a6b36da8..5c22d22d 100644 --- a/settings.ini +++ b/settings.ini @@ -3,7 +3,7 @@ repo = DABEST-python lib_name = dabest version = 2024.03.29 -min_python = 3.9 +min_python = 3.8 license = apache2 ### nbdev ### @@ -37,7 +37,7 @@ language = English status = 3 user = acclab -requirements = fastcore pandas~=2.1.4 numpy~=1.26 matplotlib~=3.8.4 seaborn~=0.12.2 scipy~=1.12 datetime statsmodels lqrt +requirements = fastcore pandas~=1.5.0 numpy~=1.23.5 matplotlib~=3.6.3 seaborn~=0.12.2 scipy~=1.9.3 datetime statsmodels lqrt dev_requirements = pytest~=7.2.1 pytest-mpl~=0.16.1 ### Optional ###