From b3b04a6f3947c0440c4d05634acd927ea94770a1 Mon Sep 17 00:00:00 2001 From: Jinglin Peng Date: Fri, 15 Jul 2022 07:29:00 +0800 Subject: [PATCH] refactor(eda): make comp. and plot API consistent --- dataprep/eda/distribution/__init__.py | 2 +- dataprep/eda/distribution/compute/__init__.py | 13 ++++---- dataprep/eda/distribution/render.py | 31 +++++++++---------- 3 files changed, 22 insertions(+), 24 deletions(-) diff --git a/dataprep/eda/distribution/__init__.py b/dataprep/eda/distribution/__init__.py index 4800326fc..baf4cccf7 100644 --- a/dataprep/eda/distribution/__init__.py +++ b/dataprep/eda/distribution/__init__.py @@ -98,7 +98,7 @@ def plot( cfg = Config.from_dict(display, config) with ProgressBar(minimum=1, disable=not progress): - itmdt = compute(df, col1, col2, col3, cfg=cfg, dtype=dtype) + itmdt = compute(df, col1, col2, col3, config=cfg, dtype=dtype) to_render = render(itmdt, cfg) diff --git a/dataprep/eda/distribution/compute/__init__.py b/dataprep/eda/distribution/compute/__init__.py index e7fc309b8..c4c0a88ab 100644 --- a/dataprep/eda/distribution/compute/__init__.py +++ b/dataprep/eda/distribution/compute/__init__.py @@ -25,7 +25,7 @@ def compute( col2: Optional[Union[str, LatLong]] = None, col3: Optional[str] = None, *, - cfg: Union[Config, Dict[str, Any], None] = None, + config: Union[Config, Dict[str, Any], None] = None, display: Optional[List[str]] = None, dtype: Optional[DTypeDef] = None, ) -> Intermediate: @@ -36,10 +36,10 @@ def compute( ---------- df DataFrame from which visualizations are generated - cfg: Union[Config, Dict[str, Any], None], default None + config: Union[Config, Dict[str, Any], None], default None When a user call plot(), the created Config object will be passed to compute(). When a user call compute() directly, if he/she wants to customize the output, - cfg is a dictionary for configuring. If not, cfg is None and + config is a dictionary for configuring. If not, config is None and default values will be used for parameters. display: Optional[List[str]], default None A list containing the names of the visualizations to display. Only exist when @@ -60,10 +60,9 @@ def compute( suppress_warnings() - if isinstance(cfg, dict): - cfg = Config.from_dict(display, cfg) - - elif not cfg: + if isinstance(config, dict): + cfg = Config.from_dict(display, config) + else: cfg = Config() x, y, z = col1, col2, col3 diff --git a/dataprep/eda/distribution/render.py b/dataprep/eda/distribution/render.py index 478b75292..43b5afe91 100644 --- a/dataprep/eda/distribution/render.py +++ b/dataprep/eda/distribution/render.py @@ -2455,43 +2455,42 @@ def render_dt_num_cat(itmdt: Intermediate, cfg: Config) -> Dict[str, Any]: } -def render(itmdt: Intermediate, cfg: Config) -> Union[LayoutDOM, Dict[str, Any]]: +def render(itmdt: Intermediate, config: Config) -> Union[LayoutDOM, Dict[str, Any]]: """ Render a basic plot Parameters ---------- itmdt The Intermediate containing results from the compute function. - cfg + config Config instance """ # pylint: disable = too-many-branches - if itmdt.visual_type == "distribution_grid": - visual_elem = render_distribution_grid(itmdt, cfg) + visual_elem = render_distribution_grid(itmdt, config) elif itmdt.visual_type == "categorical_column": - visual_elem = render_cat(itmdt, cfg) + visual_elem = render_cat(itmdt, config) elif itmdt.visual_type == "geography_column": - visual_elem = render_geo(itmdt, cfg) + visual_elem = render_geo(itmdt, config) elif itmdt.visual_type == "numerical_column": - visual_elem = render_num(itmdt, cfg) + visual_elem = render_num(itmdt, config) elif itmdt.visual_type == "datetime_column": - visual_elem = render_dt(itmdt, cfg) + visual_elem = render_dt(itmdt, config) elif itmdt.visual_type == "cat_and_num_cols": - visual_elem = render_cat_num(itmdt, cfg) + visual_elem = render_cat_num(itmdt, config) elif itmdt.visual_type == "geo_and_num_cols": - visual_elem = render_geo_num(itmdt, cfg) + visual_elem = render_geo_num(itmdt, config) elif itmdt.visual_type == "latlong_and_num_cols": - visual_elem = render_latlong_num(itmdt, cfg) + visual_elem = render_latlong_num(itmdt, config) elif itmdt.visual_type == "two_num_cols": - visual_elem = render_two_num(itmdt, cfg) + visual_elem = render_two_num(itmdt, config) elif itmdt.visual_type == "two_cat_cols": - visual_elem = render_two_cat(itmdt, cfg) + visual_elem = render_two_cat(itmdt, config) elif itmdt.visual_type == "dt_and_num_cols": - visual_elem = render_dt_num(itmdt, cfg) + visual_elem = render_dt_num(itmdt, config) elif itmdt.visual_type == "dt_and_cat_cols": - visual_elem = render_dt_cat(itmdt, cfg) + visual_elem = render_dt_cat(itmdt, config) elif itmdt.visual_type == "dt_cat_num_cols": - visual_elem = render_dt_num_cat(itmdt, cfg) + visual_elem = render_dt_num_cat(itmdt, config) return visual_elem