Skip to content

Commit

Permalink
Merge pull request #152 from shakedzy/147-ks_abc-when-run-with-plot=f…
Browse files Browse the repository at this point in the history
…alse-still-plots-the-graph

Fixing issue #147
  • Loading branch information
shakedzy authored May 9, 2023
2 parents da87f99 + fdedc8b commit d4593b0
Show file tree
Hide file tree
Showing 7 changed files with 37 additions and 13 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
# Change Log

## 0.7.4
* Handling running plotting functions with `plot=False` in Jupyter and truly avoid plotting (issue [#147](https://github.com/shakedzy/dython/issues/147))

## 0.7.3
* _Dython now officially supports only Python 3.8 or above_ (by-product of issue [#137](https://github.com/shakedzy/dython/issues/137))
* Added `nominal.replot_last_associations`: a new method to replot `nominal.associations` heat-maps (issue [#136](https://github.com/shakedzy/dython/issues/136))
Expand Down
2 changes: 1 addition & 1 deletion VERSION
Original file line number Diff line number Diff line change
@@ -1 +1 @@
0.7.3
0.7.4
2 changes: 2 additions & 0 deletions dython/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from . import nominal, model_utils, sampling, data_utils
from ._private import set_is_jupyter


def _get_version_from_setuptools():
Expand All @@ -9,3 +10,4 @@ def _get_version_from_setuptools():

__all__ = ["__version__"]
__version__ = _get_version_from_setuptools()
set_is_jupyter()
19 changes: 19 additions & 0 deletions dython/_private.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,24 @@
import sys
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

IS_JUPYTER = None


def set_is_jupyter(force_to=None):
global IS_JUPYTER
if force_to is not None:
IS_JUPYTER = force_to
else:
IS_JUPYTER = "ipykernel_launcher.py" in sys.argv[0]


def plot_or_not(plot):
if plot:
plt.show()
elif not plot and IS_JUPYTER:
plt.close()


def convert(data, to, copy=True):
Expand Down
5 changes: 2 additions & 3 deletions dython/data_utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from ._private import convert
from ._private import convert, plot_or_not


__all__ = [
Expand Down Expand Up @@ -113,8 +113,7 @@ def split_hist(
plt.title(title)
plt.ylabel(ylabel)
ax = plt.gca()
if plot:
plt.show()
plot_or_not(plot)
return ax


Expand Down
8 changes: 3 additions & 5 deletions dython/model_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import matplotlib.pyplot as plt
from sklearn.metrics import roc_curve, precision_recall_curve, auc
from scikitplot.helpers import binary_ks_curve
from ._private import convert
from ._private import convert, plot_or_not

__all__ = ["random_forest_feature_importance", "metric_graph", "ks_abc"]

Expand All @@ -28,8 +28,7 @@ def _display_metric_plot(
ax.legend(loc=legend)
if filename:
plt.savefig(filename)
if plot:
plt.show()
plot_or_not(plot)
return ax


Expand Down Expand Up @@ -516,8 +515,7 @@ def ks_abc(
ax.legend(loc=legend)
if filename:
plt.savefig(filename)
if plot:
plt.show()
plot_or_not(plot)
return {
"abc": abc,
"ks_stat": ks_statistic,
Expand Down
11 changes: 7 additions & 4 deletions dython/nominal.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,12 @@
import seaborn as sns
from psutil import cpu_count

from ._private import convert, remove_incomplete_samples, replace_nan_with_value
from ._private import (
convert,
remove_incomplete_samples,
replace_nan_with_value,
plot_or_not,
)
from .data_utils import identify_columns_by_type

__all__ = [
Expand Down Expand Up @@ -517,7 +522,6 @@ def associations(

# handling NaN values in data
if nan_strategy == _REPLACE:

# handling pandas categorical
dataset = _handling_category_for_nan_imputation(
dataset, nan_replace_value
Expand Down Expand Up @@ -920,8 +924,7 @@ def _plot_associations(
plt.title(title)
if filename:
plt.savefig(filename)
if plot:
plt.show()
plot_or_not(plot)
return ax


Expand Down

0 comments on commit d4593b0

Please sign in to comment.