diff --git a/.gitignore b/.gitignore
index 71e60b8a..44ae04d9 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,3 +1,4 @@
+.idea
*.pyc
*.swp
*.DS_Store
diff --git a/README.md b/README.md
index 5167fa96..f215fd1d 100644
--- a/README.md
+++ b/README.md
@@ -1,9 +1,18 @@
-# ggplot
+# ggplot - a working, maintained fork
-
-
+## Why this fork?
+`ggplot` is a great python library. However, it is no longer maintained by its owner and still has a bunch of issues which remain unsolved. Some of these include incompatibility with newer versions of `pandas` and Python 3.
-### What is it?
+Many projects still rely on `ggplot` and many have to either move to alternatives or manually update `ggpy`'s code after installing it to fix compatibility issues. To get a better context of the issue you can refer to [#654 Is this project dead?](https://github.com/yhat/ggpy/issues/654)
+
+This fork is a working copy of `ggplot` which is readily maintained and is open to updates and fixes so that developers do not have to make fixes manually.
+
+## Installation
+```bash
+$ pip3 install git+https://github.com/sushinoya/ggpy
+```
+
+## What is ggpy?
`ggplot` is a Python implementation of the grammar of graphics. It is not intended
to be a feature-for-feature port of [`ggplot2 for R`](https://github.com/hadley/ggplot2)--though
there is much greatness in `ggplot2`, the Python world could stand to benefit
@@ -19,31 +28,3 @@ ggplot(diamonds, aes(x='price', color='clarity')) + \
facet_wrap('cut')
```

-
-### Installation
-```bash
-$ pip install -U ggplot
-# or
-$ conda install -c conda-forge ggplot
-# or
-pip install git+https://github.com/yhat/ggplot.git
-```
-
-### Examples
-Examples are the best way to learn. There is a Jupyter Notebook full of them.
-There are also notebooks that show how to do particular things with ggplot
-(i.e. [make a scatter plot](./docs/how-to/Making%20a%20Scatter%20Plot.ipynb) or [make a histogram](./docs/how-to/Making%20a%20Scatter%20Plot.ipynb)).
-
-- [docs](./docs)
-- [gallery](./docs/Gallery.ipynb)
-- [various examples](./examples.md)
-
-
-### What happened to the old version that didn't work?
-It's gone--the windows, the doors, [everything](https://www.youtube.com/watch?v=YuxCKv_0GZc).
-Just kidding, [you can find it here](https://github.com/yhat/ggplot/tree/v0.6.6), though I'm not sure why you'd want to look at it. The data grouping and manipulation bits were re-written
-(so they actually worked) with things like facets in mind.
-
-### Contributing
-Thanks to all of the ggplot [contributors](./contributors.md#contributors)!
-See *[contributing.md](./contributing.md)*.
diff --git a/docs/examples.py b/docs/examples.py
index bfbdb5fb..db123f5f 100644
--- a/docs/examples.py
+++ b/docs/examples.py
@@ -103,7 +103,7 @@
df = pd.DataFrame({"x": np.arange(1000)})
df['y_low'] = df.x * 0.9
df['y_high'] = df.x * 1.1
-df['thing'] = ['a' if i%2==0 else 'b' for i in df.x]
+df['thing'] = ['a' if i % 2 == 0 else 'b' for i in df.x]
p = ggplot(df, aes(x='x', ymin='y_low', ymax='y_high')) + geom_area()
p.save("./examples/example-" + str(uuid.uuid4()) + ".png")
# # area w/ facet
@@ -131,7 +131,7 @@
#
df = pd.DataFrame({"x": np.arange(100)})
df['y'] = df.x * 10
-df['z'] = ["a" if x%2==0 else "b" for x in df.x]
+df['z'] = ["a" if x % 2 == 0 else "b" for x in df.x]
#
# # polar coords
p = ggplot(df, aes(x='x', y='y')) + geom_point() + coord_polar()
@@ -158,7 +158,7 @@
p.save("./examples/example-" + str(uuid.uuid4()) + ".png")
#
# # # x dates formatting faceted
-pageviews['z'] = ["a" if i%2==0 else "b" for i in range(len(pageviews))]
+pageviews['z'] = ["a" if i % 2 == 0 else "b" for i in range(len(pageviews))]
p = ggplot(pageviews, aes(x='date_hour', y='pageviews')) + geom_line() + scale_x_date(labels=date_format('%B %-d, %Y')) + facet_grid(y='z')
p.save("./examples/example-" + str(uuid.uuid4()) + ".png")
#
diff --git a/ggplot/aes.py b/ggplot/aes.py
index 495b2265..85c6a375 100755
--- a/ggplot/aes.py
+++ b/ggplot/aes.py
@@ -8,11 +8,10 @@
from patsy.eval import EvalEnvironment
-from . import utils
-
import numpy as np
import pandas as pd
+
class aes(UserDict):
"""
Creates a dictionary that is used to evaluate
@@ -72,7 +71,7 @@ def __init__(self, *args, **kwargs):
self.__eval_env__ = EvalEnvironment.capture(1)
def __deepcopy__(self, memo):
- '''deepcopy support for ggplot'''
+ """deepcopy support for ggplot"""
result = aes()
for key, item in self.__dict__.items():
# don't make a deepcopy of the env!
@@ -122,7 +121,7 @@ def _get_discrete_aes(self, df):
for aes_type, column in self.data.items():
if aes_type in ['x', 'y']:
continue
- elif aes_type=="group":
+ elif aes_type == "group":
discrete_aes.append((aes_type, column))
elif column not in non_numeric_columns:
continue
diff --git a/ggplot/chart_components.py b/ggplot/chart_components.py
index dacae6c6..dcf248c8 100644
--- a/ggplot/chart_components.py
+++ b/ggplot/chart_components.py
@@ -46,13 +46,13 @@ class xlim(object):
>>> ggplot(mpg, aes(x='hwy')) + geom_hisotgram() + xlim(0, 20)
"""
def __init__(self, low = None, high = None):
- if low != None :
+ if low is not None:
try:
_ = low - 0
except TypeError:
raise Exception("The 'low' argument to", self.__class__.__name__,
"must be of a numeric type or None")
- if high != None :
+ if high is not None:
try:
_ = high - 0
except TypeError:
@@ -83,13 +83,13 @@ class ylim(object):
>>> ggplot(mpg, aes(x='hwy')) + geom_hisotgram() + ylim(0, 5)
"""
def __init__(self, low = None, high = None):
- if low != None :
+ if low is not None:
try:
_ = low - 0
except TypeError:
raise Exception("The 'low' argument to", self.__class__.__name__,
"must be of a numeric type or None")
- if high != None :
+ if high is not None:
try:
_ = high - 0
except TypeError:
@@ -140,7 +140,7 @@ class ylab(object):
Examples
--------
- >>> ggplot(mpg, aes(x='hwy')) + geom_hisotgram() + ylab("Count\n(# of cars)")
+ >>> ggplot(mpg, aes(x='hwy')) + geom_hisotgram() + ylab('''Count\n(# of cars)''')
"""
def __init__(self, ylab):
if ylab is None:
@@ -169,7 +169,7 @@ class labs(object):
Examples
--------
- >>> ggplot(mpg, aes(x='hwy')) + geom_hisotgram() + labs("Miles / gallon", "Count\n(# of cars)", "MPG Plot")
+ >>> ggplot(mpg, aes(x='hwy')) + geom_hisotgram() + labs("Miles / gallon", '''Count\n(# of cars)''', "MPG Plot")
"""
def __init__(self, x=None, y=None, title=None):
self.x = x
diff --git a/ggplot/facets.py b/ggplot/facets.py
index be6827b9..d3b8bfb6 100644
--- a/ggplot/facets.py
+++ b/ggplot/facets.py
@@ -20,7 +20,7 @@ def __init__(self, data, is_wrap, rowvar=None, colvar=None, nrow=None, ncol=None
# assign subplot indices to rowvars and columnvars
self.ndim = ndim = self.calculate_ndimensions(data, rowvar, colvar)
- if is_wrap==True:
+ if is_wrap:
if self.nrow:
self.ncol = ncol = int(math.ceil(ndim / float(self.nrow)))
self.nrow = nrow = int(self.nrow)
@@ -47,9 +47,9 @@ def __init__(self, data, is_wrap, rowvar=None, colvar=None, nrow=None, ncol=None
value = next(facet_values)
except Exception as e:
continue
- if ncol==1:
+ if ncol == 1:
self.facet_map[value] = (row, None)
- elif nrow==1:
+ elif nrow == 1:
self.facet_map[value] = (None, col)
else:
self.facet_map[value] = (row, col)
@@ -119,12 +119,13 @@ def __init__(self, x=None, y=None, nrow=None, ncol=None, scales=None):
self.scales = scales
def __radd__(self, gg):
- if gg.__class__.__name__=="ggplot":
+ if gg.__class__.__name__ == "ggplot":
gg.facets = Facet(gg.data, True, self.x_var, self.y_var, nrow=self.nrow, ncol=self.ncol, scales=self.scales)
return gg
return self
+
class facet_grid(object):
"""
Layout panels from x and (optionally) y variables in a grid format.
@@ -155,7 +156,7 @@ def __init__(self, x=None, y=None, scales=None):
self.scales = scales
def __radd__(self, gg):
- if gg.__class__.__name__=="ggplot":
+ if gg.__class__.__name__ == "ggplot":
gg.facets = Facet(gg.data, False, self.x_var, self.y_var, scales=self.scales)
return gg
return self
diff --git a/ggplot/geoms/geom.py b/ggplot/geoms/geom.py
index 1b177368..9142c00b 100755
--- a/ggplot/geoms/geom.py
+++ b/ggplot/geoms/geom.py
@@ -59,7 +59,7 @@ def _get_plot_args(self, data, _aes):
for key, value in _aes.items():
if value not in data:
mpl_params[key] = value
- elif data[value].nunique()==1:
+ elif data[value].nunique() == 1:
mpl_params[key] = data[value].iloc[0]
else:
mpl_params[key] = data[value]
diff --git a/ggplot/geoms/geom_area.py b/ggplot/geoms/geom_area.py
index c89b3135..971e9564 100755
--- a/ggplot/geoms/geom_area.py
+++ b/ggplot/geoms/geom_area.py
@@ -43,7 +43,7 @@ def plot(self, ax, data, _aes):
if self.last_y is None:
self.last_y = pd.Series(np.repeat(0, len(data)))
ymin = self.last_y
- if self.DEFAULT_PARAMS['position']=="stack":
+ if self.DEFAULT_PARAMS['position'] == "stack":
ymax = self.last_y.reset_index(drop=True) + data[variables['y']].reset_index(drop=True)
else:
ymax = data[variables['y']]
diff --git a/ggplot/geoms/geom_bar.py b/ggplot/geoms/geom_bar.py
index 07a62fdd..037bb10a 100755
--- a/ggplot/geoms/geom_bar.py
+++ b/ggplot/geoms/geom_bar.py
@@ -91,9 +91,9 @@ def plot(self, ax, data, _aes, x_levels, fill_levels, lookups):
xticks = []
for i, x_level in enumerate(x_levels):
- mask = data[variables['x']]==x_level
+ mask = data[variables['x']] == x_level
row = data[mask]
- if len(row)==0:
+ if len(row) == 0:
xticks.append(i)
continue
@@ -111,19 +111,19 @@ def plot(self, ax, data, _aes, x_levels, fill_levels, lookups):
height = 1.0
ypos = 0
else:
- mask = (lookups[variables['x']]==x_level) & (lookups[variables['fill']]==fillval)
+ mask = (lookups[variables['x']] == x_level) & (lookups[variables['fill']] == fillval)
height = lookups[mask]['__calc_weight__'].sum()
- mask = (lookups[variables['x']]==x_level) & (lookups[variables['fill']] < fillval)
+ mask = (lookups[variables['x']] == x_level) & (lookups[variables['fill']] < fillval)
ypos = lookups[mask]['__calc_weight__'].sum()
else:
if fill_levels is not None:
- dodge = (width * fill_idx)
+ dodge = width * fill_idx
else:
dodge = width
ypos = 0.0
height = row[weight_col].sum()
- xy = (dodge + i - fill_x_adjustment, ypos)
+ xy = (dodge + i - fill_x_adjustment, ypos)
ax.add_patch(patches.Rectangle(xy, width, height, **params))
if fill_levels is not None:
diff --git a/ggplot/geoms/geom_boxplot.py b/ggplot/geoms/geom_boxplot.py
index 9c6b6ee8..7550b9e5 100755
--- a/ggplot/geoms/geom_boxplot.py
+++ b/ggplot/geoms/geom_boxplot.py
@@ -34,7 +34,7 @@ def plot(self, ax, data, _aes, x_levels):
xticks = []
for (i, xvalue) in enumerate(x_levels):
- subset = data[data[variables['x']]==xvalue]
+ subset = data[data[variables['x']] == xvalue]
xi = np.repeat(i, len(subset))
yvalues = subset[variables['y']]
xticks.append(i)
@@ -42,22 +42,22 @@ def plot(self, ax, data, _aes, x_levels):
bounds_25_75 = yvalues.quantile([0.25, 0.75]).values
bounds_5_95 = yvalues.quantile([0.05, 0.95]).values
- if self.params.get('outliers', True)==True:
+ if self.params.get('outliers', True):
mask = ((yvalues > bounds_5_95[1]) | (yvalues < bounds_5_95[0])).values
ax.scatter(x=xi[mask], y=yvalues[mask], c=self.params.get('outlier_color', 'black'))
- if self.params.get('lines', True)==True:
+ if self.params.get('lines', True):
ax.vlines(x=i, ymin=bounds_25_75[1], ymax=bounds_5_95[1])
ax.vlines(x=i, ymin=bounds_5_95[0], ymax=bounds_25_75[0])
- if self.params.get('notch', False)==True:
+ if self.params.get('notch', False):
ax.hlines(bounds_5_95[0], i - 0.25/2, i + 0.25/2, linewidth=2)
ax.hlines(bounds_5_95[1], i - 0.25/2, i + 0.25/2, linewidth=2)
- if self.params.get('median', True)==True:
+ if self.params.get('median', True):
ax.hlines(yvalues.median(), i - 0.25, i + 0.25, linewidth=2)
- if self.params.get('box', True)==True:
+ if self.params.get('box', True):
params = {
'facecolor': 'white',
'edgecolor': 'black',
diff --git a/ggplot/geoms/geom_density.py b/ggplot/geoms/geom_density.py
index b306f587..27995920 100755
--- a/ggplot/geoms/geom_density.py
+++ b/ggplot/geoms/geom_density.py
@@ -49,6 +49,6 @@ def plot(self, ax, data, _aes):
params = self._get_plot_args(data, _aes)
variables = _aes.data
x = data[variables['x']]
- x = x[x.isnull()==False]
+ x = x[x.isnull() == False]
x, y = self._calculate_density(x)
ax.plot(x, y, **params)
diff --git a/ggplot/geoms/geom_histogram.py b/ggplot/geoms/geom_histogram.py
index 9940b857..b54ce418 100755
--- a/ggplot/geoms/geom_histogram.py
+++ b/ggplot/geoms/geom_histogram.py
@@ -42,7 +42,7 @@ def plot(self, ax, data, _aes):
variables = _aes.data
x = data[variables['x']]
- x = x[x.isnull()==False]
+ x = x[x.isnull() == False]
if 'binwidth' in self.params:
params['bins'] = np.arange(np.min(x), np.max(x) + self.params['binwidth'], self.params['binwidth'])
diff --git a/ggplot/geoms/geom_line.py b/ggplot/geoms/geom_line.py
index c6dc72a6..25488bdc 100755
--- a/ggplot/geoms/geom_line.py
+++ b/ggplot/geoms/geom_line.py
@@ -37,8 +37,8 @@ def plot(self, ax, data, _aes):
y = data[variables['y']]
nulls = (x.isnull() | y.isnull())
- x = x[nulls==False]
- y = y[nulls==False]
+ x = x[nulls == False]
+ y = y[nulls == False]
if self.is_path:
pass
diff --git a/ggplot/geoms/geom_step.py b/ggplot/geoms/geom_step.py
index 2d46c7bb..16957c14 100755
--- a/ggplot/geoms/geom_step.py
+++ b/ggplot/geoms/geom_step.py
@@ -37,8 +37,8 @@ def plot(self, ax, data, _aes):
y = data[variables['y']]
nulls = (x.isnull() | y.isnull())
- x = x[nulls==False]
- y = y[nulls==False]
+ x = x[nulls == False]
+ y = y[nulls == False]
xs = [None] * (2 * (len(x)-1))
ys = [None] * (2 * (len(x)-1))
diff --git a/ggplot/geoms/geom_tile.py b/ggplot/geoms/geom_tile.py
index 17e44d48..cddc52b7 100755
--- a/ggplot/geoms/geom_tile.py
+++ b/ggplot/geoms/geom_tile.py
@@ -55,7 +55,7 @@ def plot(self, ax, data, _aes):
counts = data[[weight, variables['x'] + "_cut", variables['y'] + "_cut"]].groupby([variables['x'] + "_cut", variables['y'] + "_cut"]).count().fillna(0)
weighted = data[[weight, variables['x'] + "_cut", variables['y'] + "_cut"]].groupby([variables['x'] + "_cut", variables['y'] + "_cut"]).sum().fillna(0)
- if self.params['interpolate']==False:
+ if self.params['interpolate'] == False:
def get_xy():
for x in x_bins:
for y in y_bins:
diff --git a/ggplot/geoms/geom_violin.py b/ggplot/geoms/geom_violin.py
index f4c69b07..6fec6064 100755
--- a/ggplot/geoms/geom_violin.py
+++ b/ggplot/geoms/geom_violin.py
@@ -27,8 +27,8 @@ def plot(self, ax, data, _aes, x_levels):
variables = _aes.data
xticks = []
- for (i, xvalue) in enumerate(x_levels):
- subset = data[data[variables['x']]==xvalue]
+ for i, xvalue in enumerate(x_levels):
+ subset = data[data[variables['x']] == xvalue]
yi = subset[variables['y']].values
# so this is weird, apparently violinplot is *the only plot that
diff --git a/ggplot/ggplot.py b/ggplot/ggplot.py
index ba12a6bd..776d46c1 100755
--- a/ggplot/ggplot.py
+++ b/ggplot/ggplot.py
@@ -23,6 +23,7 @@
if os.environ.get("GGPLOT_DEV"):
from PIL import Image
+
class ggplot(object):
"""
ggplot is the base layer or object that you use to define
@@ -192,13 +193,13 @@ def apply_theme(self):
warnings.warn(msg, RuntimeWarning)
def apply_coords(self):
- if self.coords=="equal":
+ if self.coords == "equal":
for ax in self._iterate_subplots():
min_val = np.min([np.min(ax.get_yticks()), np.min(ax.get_xticks())])
max_val = np.max([np.max(ax.get_yticks()), np.max(ax.get_xticks())])
ax.set_xticks(np.linspace(min_val, max_val, 7))
ax.set_yticks(np.linspace(min_val, max_val, 7))
- elif self.coords=="flip":
+ elif self.coords == "flip":
if 'x' in self._aes.data and 'y' in self._aes.data:
x = self._aes.data['x']
y = self._aes.data['y']
@@ -337,17 +338,17 @@ def _get_mapping(self, aes_type, colname):
from "a" => "#4682B4".
"""
mapping = None
- if aes_type=="color":
+ if aes_type == "color":
mapping = discretemappers.color_gen(self.data[colname].nunique(), colors=self.manual_color_list)
- elif aes_type=="fill":
+ elif aes_type == "fill":
mapping = discretemappers.color_gen(self.data[colname].nunique(), colors=self.manual_fill_list)
- elif aes_type=="shape":
+ elif aes_type == "shape":
mapping = discretemappers.shape_gen()
- elif aes_type=="linetype":
+ elif aes_type == "linetype":
mapping = discretemappers.linetype_gen()
- elif aes_type=="size":
+ elif aes_type == "size":
mapping = discretemappers.size_gen(self.data[colname].unique())
- elif aes_type=="group":
+ elif aes_type == "group":
mapping = discretemappers.identity_gen(self.data[colname].unique())
return mapping
@@ -425,11 +426,11 @@ def make_facets(self):
"Creates figure and axes for m x n facet grid/wrap"
sharex, sharey = True, True
if self.facets:
- if self.facets.scales=="free":
+ if self.facets.scales == "free":
sharex, sharey = False, False
- elif self.facets.scales=="free_x":
+ elif self.facets.scales == "free_x":
sharex, sharey = False, True
- elif self.facets.scales=="free_y":
+ elif self.facets.scales == "free_y":
sharex, sharey = True, False
facet_params = dict(sharex=sharex, sharey=sharey)
@@ -438,7 +439,7 @@ def make_facets(self):
facet_params['nrows'] = nrow
facet_params['ncols'] = ncol
- if self.coords=="polar":
+ if self.coords == "polar":
facet_params['subplot_kw'] = { "polar": True }
fig, axs = plt.subplots(**facet_params)
@@ -463,7 +464,8 @@ def get_facet_groups(self, group):
col_variable = self.facets.colvar
row_variable = self.facets.rowvar
- if self.facets.is_wrap==True:
+ font = {'fontsize': 10}
+ if self.facets.is_wrap:
groups = [row_variable, col_variable]
groups = [g for g in groups if g]
for (i, (name, subgroup)) in enumerate(group.groupby(groups)):
@@ -472,12 +474,11 @@ def get_facet_groups(self, group):
# this only happens when a field is being used both as a facet parameter AND as a discrete aesthetic (i.e. shape)
row, col = self.facets.facet_map[name]
- if len(self.subplots.shape)==1:
+ if len(self.subplots.shape) == 1:
ax = self.subplots[i]
else:
ax = self.get_subplot(row, col)
- font = { 'fontsize': 10 }
yield (ax, subgroup)
for item in self.facets.generate_subplot_index(self.data, self.facets.rowvar, self.facets.colvar):
@@ -510,7 +511,7 @@ def get_facet_groups(self, group):
for (_, (colname, subgroup)) in enumerate(group.groupby(col_variable)):
row, col = self.facets.facet_map[colname]
ax = self.subplots[col]
- if self.facets.is_wrap==True:
+ if self.facets.is_wrap:
ax.set_title("%s=%s" % (col_variable, colname))
else:
ax.set_title(colname, fontdict={'fontsize': 10})
@@ -520,7 +521,7 @@ def get_facet_groups(self, group):
for (row, (rowname, subgroup)) in enumerate(group.groupby(row_variable)):
row, col = self.facets.facet_map[rowname]
- if self.facets.is_wrap==True:
+ if self.facets.is_wrap:
ax = self.subplots[row]
ax.set_title("%s=%s" % (row_variable, rowname))
else:
@@ -570,11 +571,11 @@ def save_as_base64(self, as_tag=False, width=None, height=None, dpi=180):
height: int, float
height of the plot in inches
"""
- imgdata = six.StringIO()
+ imgdata = six.BytesIO()
self.save(imgdata, width=width, height=height, dpi=dpi)
imgdata.seek(0) # rewind the data
- uri = 'data:image/png;base64,' + urllib.quote(base64.b64encode(imgdata.buf))
- if as_tag==True:
+ uri = 'data:image/png;base64,' + urllib.quote(base64.b64encode(imgdata.read()))
+ if as_tag:
return '
' % uri
else:
return uri
@@ -585,7 +586,7 @@ def _prep_layer_for_plotting(self, layer, facetgroup):
function on them. This function performs those perperations and then
returns a dictionary of **kwargs for the layer.plot function to use.
"""
- if layer.__class__.__name__=="geom_bar":
+ if layer.__class__.__name__ == "geom_bar":
mask = True
df = layer.setup_data(self.data, self._aes, facets=self.facets)
if df is None:
@@ -593,13 +594,13 @@ def _prep_layer_for_plotting(self, layer, facetgroup):
if self.facets:
facet_filter = facetgroup[self.facets.facet_cols].iloc[0].to_dict()
for k, v in facet_filter.items():
- mask = (mask) & (df[k]==v)
+ mask = mask & (df[k] == v)
df = df[mask]
if 'fill' in self._aes:
fillcol_raw = self._aes['fill'][:-5]
fillcol = self._aes['fill']
- fill_levels = self.data[[fillcol_raw, fillcol]].sort(fillcol_raw)[fillcol].unique()
+ fill_levels = self.data[[fillcol_raw, fillcol]].sort_values(by=fillcol_raw)[fillcol].unique()
else:
fill_levels = None
return dict(x_levels=self.data[self._aes['x']].unique(), fill_levels=fill_levels, lookups=df)
@@ -610,7 +611,7 @@ def _prep_layer_for_plotting(self, layer, facetgroup):
return dict()
def make(self):
- "Constructs the plot using the methods. This is the 'main' for ggplot"
+ """Constructs the plot using the methods. This is the 'main' for ggplot"""
plt.close()
with mpl.rc_context():
self.apply_theme()
@@ -619,8 +620,8 @@ def make(self):
self.fig, self.subplots = self.make_facets()
else:
subplot_kw = {}
- if self.coords=="polar":
- subplot_kw = { "polar": True }
+ if self.coords == "polar":
+ subplot_kw = {"polar": True}
self.fig, self.subplots = plt.subplots(subplot_kw=subplot_kw)
self.apply_scales()
@@ -631,7 +632,7 @@ def make(self):
for ax, facetgroup in self.get_facet_groups(group):
for layer in self.layers:
kwargs = self._prep_layer_for_plotting(layer, facetgroup)
- if kwargs==False:
+ if not kwargs:
continue
layer.plot(ax, facetgroup, self._aes, **kwargs)
diff --git a/ggplot/legend.py b/ggplot/legend.py
index ca894cf8..c29f6694 100644
--- a/ggplot/legend.py
+++ b/ggplot/legend.py
@@ -44,17 +44,17 @@ def linetype_legend(linetype):
return plt.Line2D([0],[0], color='black', linestyle=linetype)
def make_aesthetic_legend(aesthetic, value):
- if aesthetic=='color':
+ if aesthetic == 'color':
return color_legend(value)
- elif aesthetic=='fill':
+ elif aesthetic == 'fill':
return color_legend(value)
- elif aesthetic=='size':
+ elif aesthetic == 'size':
return size_legend(value)
- elif aesthetic=='alpha':
+ elif aesthetic == 'alpha':
return alpha_legend(value)
- elif aesthetic=='shape':
+ elif aesthetic == 'shape':
return shape_legend(value)
- elif aesthetic=='linetype':
+ elif aesthetic == 'linetype':
return linetype_legend(value)
else:
print(aesthetic + " not found")
diff --git a/ggplot/qplot.py b/ggplot/qplot.py
index 9ef231d6..13b0f325 100644
--- a/ggplot/qplot.py
+++ b/ggplot/qplot.py
@@ -3,17 +3,16 @@
from .ggplot import ggplot
from .aes import aes
-from .chart_components import ggtitle, xlim, ylim, xlab, ylab, labs
-from .geoms import geom_point, geom_bar, geom_histogram, geom_line # , geom_boxplot
+from .chart_components import ggtitle, xlim, ylim, xlab, ylab
+from .geoms import geom_point, geom_bar, geom_histogram, geom_line #geom_boxplot
from .scales.scale_log import scale_x_log, scale_y_log
import pandas as pd
-import numpy as np
import six
def qplot(x, y=None, color=None, size=None, fill=None, data=None,
- geom="auto", stat=[], position=[], xlim=None, ylim=None, log="",
- main=None, xlab=None, ylab="", asp=None):
+ geom="auto", stat=[], position=[], xlimit=None, ylimit=None, log="",
+ main=None, xlabel=None, ylabel="", asp=None):
"""
Parameters
----------
@@ -35,17 +34,17 @@ def qplot(x, y=None, color=None, size=None, fill=None, data=None,
specifies which statistics to use
position: list
gives position adjustment to use
- xlim: tuple
+ xlimit: tuple
limits on x axis; i.e. (0, 10)
- ylim: tuple, None
+ ylimit: tuple, None
limits on y axis; i.e. (0, 10)
log: string
which variables to log transform ("x", "y", or "xy")
main: string
title for the plot
- xlab: string
+ xlabel: string
title for the x axis
- ylab: string
+ ylabel: string
title for the y axis
asp: string
the y/x aspect ratio
@@ -69,8 +68,8 @@ def qplot(x, y=None, color=None, size=None, fill=None, data=None,
>>> print qplot('mpg', data=mtcars, geom="hist", main="hist")
>>> print qplot('mpg', data=mtcars, geom="histogram", main="histogram")
>>> print qplot('cyl', 'mpg', data=mtcars, geom="bar", main="bar")
- >>> print qplot('mpg', 'drat', data=mtcars, xlab= "x lab", main="xlab")
- >>> print qplot('mpg', 'drat', data=mtcars, ylab = "y lab", main="ylab")
+ >>> print qplot('mpg', 'drat', data=mtcars, xlabel="x lab", main="xlab")
+ >>> print qplot('mpg', 'drat', data=mtcars, ylabel="y lab", main="ylab")
"""
if x is not None and not isinstance(x, six.string_types):
@@ -80,7 +79,6 @@ def qplot(x, y=None, color=None, size=None, fill=None, data=None,
data['y'] = y
y = 'y'
-
aes_elements = {"x": x}
if y:
aes_elements["y"] = y
@@ -101,7 +99,7 @@ def qplot(x, y=None, color=None, size=None, fill=None, data=None,
"point": geom_point,
}
# taking our best guess
- if geom=="auto":
+ if geom == "auto":
if y is None:
geom = geom_histogram
else:
@@ -114,10 +112,14 @@ def qplot(x, y=None, color=None, size=None, fill=None, data=None,
p += scale_x_log()
if "y" in log:
p += scale_y_log()
- if xlab:
- p += xlabel(xlab)
- if ylab:
- p += ylabel(ylab)
+ if xlabel:
+ p += xlab(xlabel)
+ if ylabel:
+ p += ylab(ylabel)
+ if xlimit:
+ p += xlim(*tuple(xlimit))
+ if ylimit:
+ p += ylim(*tuple(ylimit))
if main:
p += ggtitle(main)
return p
diff --git a/ggplot/scales/scale_color_brewer.py b/ggplot/scales/scale_color_brewer.py
index d73393b1..465813bc 100644
--- a/ggplot/scales/scale_color_brewer.py
+++ b/ggplot/scales/scale_color_brewer.py
@@ -1,7 +1,7 @@
from __future__ import (absolute_import, division, print_function,
unicode_literals)
+
from .scale import scale
-from copy import deepcopy
import brewer2mpl
@@ -73,7 +73,15 @@ def __radd__(self, gg):
# only one color used
n_colors = 3
- bmap = brewer2mpl.get_map(palette, ctype, n_colors)
+ try:
+ bmap = brewer2mpl.get_map(palette, ctype, n_colors)
+ except ValueError as e:
+ if not str(e).startswith('Invalid number for map type'):
+ raise e
+ palettes = brewer2mpl.COLOR_MAPS[_handle_shorthand(ctype).lower().capitalize()][palette]
+ n_colors = int(max(str(k) for k in palettes))
+ bmap = brewer2mpl.get_map(palette, ctype, n_colors)
+
gg.manual_color_list = bmap.hex_colors
return gg
diff --git a/ggplot/scales/scale_color_gradient.py b/ggplot/scales/scale_color_gradient.py
index 9fa96500..e47083d6 100644
--- a/ggplot/scales/scale_color_gradient.py
+++ b/ggplot/scales/scale_color_gradient.py
@@ -30,7 +30,7 @@ class scale_color_gradient(scale):
Examples
--------
>>> from ggplot import *
- >>> diamons_premium = diamonds[diamonds.cut=='Premium']
+ >>> diamons_premium = diamonds[diamonds.cut == 'Premium']
>>> gg = ggplot(diamons_premium, aes(x='depth', y='carat', colour='price')) + \\
... geom_point()
>>> print(gg + scale_colour_gradient(low='red', mid='white', high='blue', limits=[4000,6000]) + \\
diff --git a/ggplot/stats/smoothers.py b/ggplot/stats/smoothers.py
index 5a8fbf37..e8b1869c 100644
--- a/ggplot/stats/smoothers.py
+++ b/ggplot/stats/smoothers.py
@@ -1,7 +1,6 @@
from __future__ import (absolute_import, division, print_function,
unicode_literals)
import numpy as np
-from pandas.lib import Timestamp
import pandas as pd
import statsmodels.api as sm
from statsmodels.nonparametric.smoothers_lowess import lowess as smlowess
@@ -11,7 +10,7 @@
import datetime
date_types = (
- pd.tslib.Timestamp,
+ pd.Timestamp,
pd.DatetimeIndex,
pd.Period,
pd.PeriodIndex,
@@ -19,8 +18,8 @@
datetime.time
)
_isdate = lambda x: isinstance(x, date_types)
-SPAN = 2/3.
-ALPHA = 0.05 # significance level for confidence interval
+SPAN = 2 / 3.
+ALPHA = 0.05 # significance level for confidence interval
def _snakify(txt):
txt = txt.strip().lower()
@@ -49,8 +48,8 @@ def lm(x, y, alpha=ALPHA):
predict_ci_upp = df['predict_ci_95%_upp'].values
if x_is_date:
- x = [Timestamp.fromordinal(int(i)) for i in x]
- return (x, fittedvalues, predict_mean_ci_low, predict_mean_ci_upp)
+ x = [pd.Timestamp.fromordinal(int(i)) for i in x]
+ return x, fittedvalues, predict_mean_ci_low, predict_mean_ci_upp
def lowess(x, y, span=SPAN):
"returns y-values estimated using the lowess function in statsmodels."
@@ -67,13 +66,13 @@ def lowess(x, y, span=SPAN):
y = pd.Series(result[::,1])
lower, upper = stats.t.interval(span, len(x), loc=0, scale=2)
std = np.std(y)
- y1 = pd.Series(lower * std + y)
- y2 = pd.Series(upper * std + y)
+ y1 = pd.Series(lower * std + y)
+ y2 = pd.Series(upper * std + y)
if x_is_date:
- x = [Timestamp.fromordinal(int(i)) for i in x]
+ x = [pd.Timestamp.fromordinal(int(i)) for i in x]
- return (x, y, y1, y2)
+ return x, y, y1, y2
def mavg(x,y, window):
"compute moving average"
@@ -87,5 +86,5 @@ def mavg(x,y, window):
y2 = y + std_err
if x_is_date:
- x = [Timestamp.fromordinal(int(i)) for i in x]
- return (x, y, y1, y2)
+ x = [pd.Timestamp.fromordinal(int(i)) for i in x]
+ return x, y, y1, y2
diff --git a/ggplot/stats/stat_smooth.py b/ggplot/stats/stat_smooth.py
index 2264a6aa..ea5a3925 100644
--- a/ggplot/stats/stat_smooth.py
+++ b/ggplot/stats/stat_smooth.py
@@ -86,7 +86,7 @@ def plot(self, ax, data, _aes):
params['alpha'] = 0.2
order = np.argsort(x)
- if self.params.get('se', True)==True:
+ if self.params.get('se', True):
if is_date(smoothed_data.x.iloc[0]):
dtype = smoothed_data.x.iloc[0].__class__
x = np.array([i.toordinal() for i in smoothed_data.x])
@@ -95,6 +95,6 @@ def plot(self, ax, data, _aes):
ax.set_xticklabels(new_ticks)
else:
ax.fill_between(smoothed_data.x, smoothed_data.y1, smoothed_data.y2, **params)
- if self.params.get('fit', True)==True:
+ if self.params.get('fit', True):
del params['alpha']
ax.plot(smoothed_data.x, smoothed_data.y, **params)
diff --git a/ggplot/themes/element_text.py b/ggplot/themes/element_text.py
index 785d22ea..56be88cf 100644
--- a/ggplot/themes/element_text.py
+++ b/ggplot/themes/element_text.py
@@ -1,8 +1,6 @@
from __future__ import (absolute_import, division, print_function,
unicode_literals)
-from matplotlib.text import Text
-
FACES = ["plain", "italic", "bold", "bold.italic"]
class element_text(object):
diff --git a/ggplot/themes/theme.py b/ggplot/themes/theme.py
index 1bccf16a..c4b0e70a 100644
--- a/ggplot/themes/theme.py
+++ b/ggplot/themes/theme.py
@@ -50,7 +50,7 @@ def __init__(self):
pass
def __radd__(self, other):
- if other.__class__.__name__=="ggplot":
+ if other.__class__.__name__ == "ggplot":
other.theme = self
return other
@@ -124,7 +124,7 @@ def __init__(self, *args, **kwargs):
self.things = deepcopy(kwargs)
def __radd__(self, other):
- if other.__class__.__name__=="ggplot":
+ if other.__class__.__name__ == "ggplot":
other.theme = self
for key, value in self.things.items():
try:
diff --git a/ggplot/themes/theme_538.py b/ggplot/themes/theme_538.py
index d9182f25..06790ab0 100644
--- a/ggplot/themes/theme_538.py
+++ b/ggplot/themes/theme_538.py
@@ -1,9 +1,9 @@
from __future__ import (absolute_import, division, print_function,
unicode_literals)
-from cycler import cycler
from .theme import theme_base
from cycler import cycler
+
class theme_538(theme_base):
"""
Theme for 538.
diff --git a/ggplot/themes/themes.py b/ggplot/themes/themes.py
index 408d1e62..5ef6ab4b 100644
--- a/ggplot/themes/themes.py
+++ b/ggplot/themes/themes.py
@@ -1,15 +1,18 @@
from __future__ import (absolute_import, division, print_function,
unicode_literals)
+
+from copy import deepcopy, copy
+
import matplotlib as mpl
import matplotlib.pyplot as plt
-from cycler import cycler
+
class theme(object):
def __init__(self):
self._rcParams = {}
def __radd__(self, other):
- if other.__class__.__name__=="ggplot":
+ if other.__class__.__name__ == "ggplot":
other.theme = self
return other
diff --git a/ggplot/utils.py b/ggplot/utils.py
index d0d5dfb6..e83de79a 100644
--- a/ggplot/utils.py
+++ b/ggplot/utils.py
@@ -1,21 +1,21 @@
from __future__ import (absolute_import, division, print_function,
unicode_literals)
+import datetime
import matplotlib.cbook as cbook
import numpy as np
import pandas as pd
-import datetime
def format_ticks(ticks):
are_ints = True
for t in ticks:
try:
- if int(t)!=t:
+ if int(t) != t:
are_ints = False
except:
return ticks
- if are_ints==True:
+ if are_ints:
return [int(t) for t in ticks]
return ticks
@@ -60,15 +60,10 @@ def is_categorical(obj):
"""
try:
float(obj.iloc[0])
- return False
+ return is_sequence_of_strings(obj) or is_sequence_of_booleans(obj)
except:
return True
- if is_sequence_of_strings(obj):
- return True
- if is_sequence_of_booleans(obj):
- return True
- return False
def is_iterable(obj):
try:
@@ -77,8 +72,9 @@ def is_iterable(obj):
except:
return False
+
date_types = (
- pd.tslib.Timestamp,
+ pd.Timestamp,
pd.DatetimeIndex,
pd.Period,
pd.PeriodIndex,
@@ -86,17 +82,20 @@ def is_iterable(obj):
datetime.time
)
+
def is_date(x):
return isinstance(x, date_types)
+
def calc_n_bins(series):
- "https://en.wikipedia.org/wiki/Histogram#Number_of_bins_and_width"
- q75, q25 = np.percentile(series, [75 , 25])
+ """https://en.wikipedia.org/wiki/Histogram#Number_of_bins_and_width"""
+ q75, q25 = np.percentile(series, [75, 25])
iqr = q75 - q25
h = (2 * iqr) / (len(series)**(1/3.))
k = (series.max() - series.min()) / h
return k
+
def sorted_unique(series):
"""Return the unique values of *series*, correctly sorted."""
# This handles Categorical data types, which sorted(series.unique()) fails
diff --git a/tests/compare_runs.py b/tests/compare_runs.py
index b778aa5c..50d3eed7 100644
--- a/tests/compare_runs.py
+++ b/tests/compare_runs.py
@@ -1,5 +1,7 @@
+from __future__ import print_function
import os
import sys
+
import imgdiff
import tempfile
import operator
@@ -7,6 +9,9 @@
from PIL import Image
import pandas as pd
+if sys.version[0] == 3:
+ from functools import reduce
+
rundir1 = sys.argv[1]
rundir2 = sys.argv[2]
html = """
@@ -59,4 +64,4 @@ def calc_mse(image1, image2):
html += "\n" + "