diff --git a/rubicon_ml/intake_rubicon/__init__.py b/rubicon_ml/intake_rubicon/__init__.py index 3ea2161e..12695f5e 100644 --- a/rubicon_ml/intake_rubicon/__init__.py +++ b/rubicon_ml/intake_rubicon/__init__.py @@ -2,6 +2,7 @@ from rubicon_ml.intake_rubicon.experiment import ExperimentSource from rubicon_ml.intake_rubicon.viz import ( + DataframePlotDataSource, ExperimentsTableDataSource, MetricCorrelationPlotDataSource, ) @@ -10,4 +11,5 @@ "ExperimentSource", "ExperimentsTableDataSource", "MetricCorrelationPlotDataSource", + "DataframePlotDataSource", ] diff --git a/rubicon_ml/intake_rubicon/publish.py b/rubicon_ml/intake_rubicon/publish.py index e925a092..cadc8b61 100644 --- a/rubicon_ml/intake_rubicon/publish.py +++ b/rubicon_ml/intake_rubicon/publish.py @@ -4,13 +4,16 @@ import yaml if TYPE_CHECKING: + from rubicon_ml.viz import DataframePlot from rubicon_ml.viz.experiments_table import ExperimentsTable from rubicon_ml.viz.metric_correlation_plot import MetricCorrelationPlot def publish( experiments, - visualization_object: Optional[Union["ExperimentsTable", "MetricCorrelationPlot"]] = None, + visualization_object: Optional[ + Union["ExperimentsTable", "MetricCorrelationPlot", "DataframePlot"] + ] = None, output_filepath=None, base_catalog_filepath=None, ): @@ -102,6 +105,7 @@ def _update_catalog( def _build_catalog(experiments, visualization): + from rubicon_ml.viz import DataframePlot from rubicon_ml.viz.experiments_table import ExperimentsTable from rubicon_ml.viz.metric_correlation_plot import MetricCorrelationPlot @@ -149,6 +153,7 @@ def _build_catalog(experiments, visualization): }, } catalog["sources"]["experiment_table"] = appended_visualization_catalog + if isinstance(visualization, MetricCorrelationPlot): appended_visualization_catalog = { "driver": "rubicon_ml_metric_correlation_plot", @@ -160,6 +165,20 @@ def _build_catalog(experiments, visualization): } catalog["sources"]["metric_correlation_plot"] = appended_visualization_catalog + # vizualization 
class DataframePlotDataSource(VizDataSourceMixin):
    """Intake data source that rebuilds a `rubicon` Dataframe Plot visualization.

    Every keyword argument received at construction time is kept as-is and
    later forwarded to `rubicon_ml.viz.DataframePlot` when the schema is
    resolved.
    """

    # driver name that catalog entries use to select this source
    name = "rubicon_ml_dataframe_plot"
    container = "python"
    version = __version__

    def __init__(self, metadata=None, **catalog_data):
        """Remember the catalog arguments, then initialize the mixin with `metadata`."""
        # stored before the parent initializer runs, matching the sibling sources
        self._catalog_data = catalog_data if catalog_data else {}

        super().__init__(metadata=metadata)

    def _get_schema(self):
        """Build the Dataframe Plot, store it as the visualization object attribute,
        and return the parent's schema."""
        # imported inside the method instead of at module import time
        from rubicon_ml.viz import DataframePlot

        dataframe_plot = DataframePlot(**self._catalog_data)
        self._visualization_object = dataframe_plot

        return super()._get_schema()
def test_datatable_plot_source():
    """A `DataframePlotDataSource` built from catalog arguments reads back a
    visualization whose attributes equal those arguments."""
    catalog_data_sample = {
        "dataframe_name": "dataframe_name",
        "experiments": None,
        "plotting_func": None,
        "plotting_func_kwargs": None,
        "x": None,
        "y": None,
    }

    # same keyword arguments as the sample, passed by unpacking the dict
    source = DataframePlotDataSource(**catalog_data_sample)
    assert source is not None

    source.discover()
    visualization = source.read()
    assert visualization is not None

    # `experiments` is passed in but deliberately not compared here
    compared_attributes = (
        "dataframe_name",
        "plotting_func",
        "plotting_func_kwargs",
        "x",
        "y",
    )
    for attribute in compared_attributes:
        assert getattr(visualization, attribute) == catalog_data_sample[attribute]

    source.close()