diff --git a/rubicon_ml/intake_rubicon/__init__.py b/rubicon_ml/intake_rubicon/__init__.py index 12695f5e..30c3a338 100644 --- a/rubicon_ml/intake_rubicon/__init__.py +++ b/rubicon_ml/intake_rubicon/__init__.py @@ -5,6 +5,7 @@ DataframePlotDataSource, ExperimentsTableDataSource, MetricCorrelationPlotDataSource, + MetricListComparisonDataSource, ) __all__ = [ @@ -12,4 +13,5 @@ "ExperimentsTableDataSource", "MetricCorrelationPlotDataSource", "DataframePlotDataSource", + "MetricListComparisonDataSource", ] diff --git a/rubicon_ml/intake_rubicon/publish.py b/rubicon_ml/intake_rubicon/publish.py index cadc8b61..f70ab975 100644 --- a/rubicon_ml/intake_rubicon/publish.py +++ b/rubicon_ml/intake_rubicon/publish.py @@ -7,12 +7,13 @@ from rubicon_ml.viz import DataframePlot from rubicon_ml.viz.experiments_table import ExperimentsTable from rubicon_ml.viz.metric_correlation_plot import MetricCorrelationPlot + from rubicon_ml.viz.metric_lists_comparison import MetricListsComparison def publish( experiments, visualization_object: Optional[ - Union["ExperimentsTable", "MetricCorrelationPlot", "DataframePlot"] + Union["ExperimentsTable", "MetricCorrelationPlot", "DataframePlot", "MetricListsComparison"] ] = None, output_filepath=None, base_catalog_filepath=None, @@ -108,6 +109,7 @@ def _build_catalog(experiments, visualization): from rubicon_ml.viz import DataframePlot from rubicon_ml.viz.experiments_table import ExperimentsTable from rubicon_ml.viz.metric_correlation_plot import MetricCorrelationPlot + from rubicon_ml.viz.metric_lists_comparison import MetricListsComparison """Helper function to build catalog dictionary from given experiments. @@ -138,6 +140,7 @@ def _build_catalog(experiments, visualization): catalog["sources"][experiment_catalog_name] = appended_experiment_catalog # create visualization entry to the catalog file + # visualization is an ExperimentsTable if visualization is not None: if isinstance(visualization, ExperimentsTable): appended_visualization_catalog = { @@ -154,6 +157,7 @@ def _build_catalog(experiments, visualization): } catalog["sources"]["experiment_table"] = appended_visualization_catalog + # vizualization is a MetricCorrelationPlot if isinstance(visualization, MetricCorrelationPlot): appended_visualization_catalog = { "driver": "rubicon_ml_metric_correlation_plot", @@ -165,7 +169,7 @@ def _build_catalog(experiments, visualization): } catalog["sources"]["metric_correlation_plot"] = appended_visualization_catalog - # vizualization is an DataframePlot + # vizualization is a DataframePlot if isinstance(visualization, DataframePlot): appended_visualization_catalog = { "driver": "rubicon_ml_dataframe_plot", @@ -181,4 +185,19 @@ def _build_catalog(experiments, visualization): # append visualization object to end of catalog file + # vizualization is a MetricListsComparison + if isinstance(visualization, MetricListsComparison): + appended_visualization_catalog = { + "driver": "rubicon_ml_metric_list", + "args": { + "column_names": visualization.column_names, + "selected_metric": visualization.selected_metric, + }, + } + + # append visualization object to end of catalog file + catalog["sources"]["metric_list"] = appended_visualization_catalog + + # append visualization object to end of catalog file + return catalog diff --git a/rubicon_ml/intake_rubicon/viz.py b/rubicon_ml/intake_rubicon/viz.py index 024f1a90..78ff88fb 100644 --- a/rubicon_ml/intake_rubicon/viz.py +++ b/rubicon_ml/intake_rubicon/viz.py @@ -66,3 +66,25 @@ def _get_schema(self): self._visualization_object = DataframePlot(**self._catalog_data) return super()._get_schema() + + +class MetricListComparisonDataSource(VizDataSourceMixin): + """An Intake data source for reading `rubicon` Metric List Comparison visualizations.""" + + version = __version__ + + container = "python" + name = "rubicon_ml_metric_list" + + def __init__(self, metadata=None, **catalog_data): + self._catalog_data = catalog_data or {} + + super().__init__(metadata=metadata) + + def _get_schema(self): + """Creates a Metric List Comparison visualization and sets it as the visualization object attribute""" + from rubicon_ml.viz import MetricListsComparison + + self._visualization_object = MetricListsComparison(**self._catalog_data) + + return super()._get_schema() diff --git a/setup.cfg b/setup.cfg index d3cc418d..dc51eb0f 100644 --- a/setup.cfg +++ b/setup.cfg @@ -64,6 +64,7 @@ intake.drivers = rubicon_ml_experiment_table = rubicon_ml.intake_rubicon.viz:ExperimentsTableDataSource rubicon_ml_metric_correlation_plot = rubicon_ml.intake_rubicon.viz:MetricCorrelationPlotDataSource rubicon_ml_dataframe_plot = rubicon_ml.intake_rubicon.viz:DataframePlotDataSource + rubicon_ml_metric_list = rubicon_ml.intake_rubicon.viz:MetricListComparisonDataSource [versioneer] vcs = git diff --git a/tests/unit/intake_rubicon/test_publish.py b/tests/unit/intake_rubicon/test_publish.py index f7691bf5..ecdfb14f 100644 --- a/tests/unit/intake_rubicon/test_publish.py +++ b/tests/unit/intake_rubicon/test_publish.py @@ -5,6 +5,7 @@ from rubicon_ml.viz.dataframe_plot import DataframePlot from rubicon_ml.viz.experiments_table import ExperimentsTable from rubicon_ml.viz.metric_correlation_plot import MetricCorrelationPlot +from rubicon_ml.viz.metric_lists_comparison import MetricListsComparison def test_publish(project_client): @@ -97,6 +98,35 @@ def test_publish(project_client): ) assert catalog["sources"]["dataframe_plot"] is not None + # MetricListComparison + + visualization_object = MetricListsComparison() + catalog_yaml = publish(project.experiments(), visualization_object) + catalog = yaml.safe_load(catalog_yaml) + + assert f"experiment_{experiment.id.replace('-', '_')}" in catalog["sources"] + assert ( + "rubicon_ml_experiment" + == catalog["sources"][f"experiment_{experiment.id.replace('-', '_')}"]["driver"] + ) + assert ( + experiment.repository.root_dir + == catalog["sources"][f"experiment_{experiment.id.replace('-', '_')}"]["args"]["urlpath"] + ) + assert ( + experiment.id + == catalog["sources"][f"experiment_{experiment.id.replace('-', '_')}"]["args"][ + "experiment_id" + ] + ) + assert ( + project.name + == catalog["sources"][f"experiment_{experiment.id.replace('-', '_')}"]["args"][ + "project_name" + ] + ) + assert catalog["sources"]["metric_list"] is not None + def test_publish_from_multiple_sources(): rubicon_a = Rubicon(persistence="memory", root_dir="path/a") diff --git a/tests/unit/intake_rubicon/test_viz.py b/tests/unit/intake_rubicon/test_viz.py index d519be00..31534669 100644 --- a/tests/unit/intake_rubicon/test_viz.py +++ b/tests/unit/intake_rubicon/test_viz.py @@ -4,6 +4,7 @@ DataframePlotDataSource, ExperimentsTableDataSource, MetricCorrelationPlotDataSource, + MetricListComparisonDataSource, ) root = os.path.dirname(__file__) @@ -93,3 +94,20 @@ def test_datatable_plot_source(): assert visualization.y == catalog_data_sample["y"] source.close() + + +def test_metric_list_source(): + catalog_data_sample = {"column_names": None, "selected_metric": None} + + source = MetricListComparisonDataSource(catalog_data_sample) + assert source is not None + + source.discover() + + visualization = source.read() + + assert visualization is not None + assert visualization.column_names == catalog_data_sample["column_names"] + assert visualization.selected_metric == catalog_data_sample["selected_metric"] + + source.close()