diff --git a/rubicon_ml/intake_rubicon/__init__.py b/rubicon_ml/intake_rubicon/__init__.py index 936d4a68..cd940b99 100644 --- a/rubicon_ml/intake_rubicon/__init__.py +++ b/rubicon_ml/intake_rubicon/__init__.py @@ -1,6 +1,6 @@ import intake # noqa F401 from rubicon_ml.intake_rubicon.experiment import ExperimentSource -from rubicon_ml.intake_rubicon.viz import ExperimentsTableDataSource +from rubicon_ml.intake_rubicon.viz import ExperimentsTableDataSource, MetricCorrelationPlotDataSource -__all__ = ["ExperimentSource", "ExperimentsTableDataSource"] +__all__ = ["ExperimentSource", "ExperimentsTableDataSource", "MetricCorrelationPlotDataSource",] diff --git a/rubicon_ml/intake_rubicon/publish.py b/rubicon_ml/intake_rubicon/publish.py index a7142a14..ecaabcc0 100644 --- a/rubicon_ml/intake_rubicon/publish.py +++ b/rubicon_ml/intake_rubicon/publish.py @@ -1,16 +1,17 @@ -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING, Optional, Union import fsspec import yaml if TYPE_CHECKING: from rubicon_ml.viz.experiments_table import ExperimentsTable + from rubicon_ml.viz.metric_correlation_plot import MetricCorrelationPlot + def publish( experiments, - # visualization object passed, defaulted to None - visualization_object: Optional["ExperimentsTable"] = None, + visualization_object: Optional[Union["ExperimentsTable", "MetricCorrelationPlot"]] = None, output_filepath=None, base_catalog_filepath=None, ): @@ -102,6 +103,9 @@ def _update_catalog( def _build_catalog(experiments, visualization): + from rubicon_ml.viz.experiments_table import ExperimentsTable + from rubicon_ml.viz.metric_correlation_plot import MetricCorrelationPlot + """Helper function to build catalog dictionary from given experiments. Parameters @@ -132,20 +136,31 @@ def _build_catalog(experiments, visualization): # create visualization entry to the catalog file if visualization is not None: - appended_visualization_catalog = { - "driver": "rubicon_ml_experiment_table", - "args": { - "is_selectable": visualization.is_selectable, - "metric_names": visualization.metric_names, - "metric_query_tags": visualization.metric_query_tags, - "metric_query_type": visualization.metric_query_type, - "parameter_names": visualization.parameter_names, - "parameter_query_tags": visualization.parameter_query_tags, - "parameter_query_type": visualization.parameter_query_type, - }, - } + if isinstance(visualization, ExperimentsTable): + appended_visualization_catalog = { + "driver": "rubicon_ml_experiment_table", + "args": { + "is_selectable": visualization.is_selectable, + "metric_names": visualization.metric_names, + "metric_query_tags": visualization.metric_query_tags, + "metric_query_type": visualization.metric_query_type, + "parameter_names": visualization.parameter_names, + "parameter_query_tags": visualization.parameter_query_tags, + "parameter_query_type": visualization.parameter_query_type, + }, + } + catalog["sources"]["experiment_table"] = appended_visualization_catalog + if isinstance(visualization, MetricCorrelationPlot): + appended_visualization_catalog = { + "driver": "rubicon_ml_metric_correlation_plot", + "args": { + "metric_names": visualization.metric_names, + "parameter_names": visualization.parameter_names, + "selected_metric": visualization.selected_metric, + }, + } + catalog["sources"]["metric_correlation_plot"] = appended_visualization_catalog # append visualization object to end of catalog file - catalog["sources"]["experiment_table"] = appended_visualization_catalog - + return catalog diff --git a/rubicon_ml/intake_rubicon/viz.py b/rubicon_ml/intake_rubicon/viz.py index dbda152a..c03a4220 100644 --- a/rubicon_ml/intake_rubicon/viz.py +++ b/rubicon_ml/intake_rubicon/viz.py @@ -1,5 +1,6 @@ from rubicon_ml import __version__ from rubicon_ml.intake_rubicon.base import VizDataSourceMixin +from rubicon_ml.viz import MetricCorrelationPlot class ExperimentsTableDataSource(VizDataSourceMixin): @@ -22,3 +23,22 @@ def _get_schema(self): self._visualization_object = ExperimentsTable(**self._catalog_data) return super()._get_schema() + +class MetricCorrelationPlotDataSource(VizDataSourceMixin): + """An Intake data source for reading `rubicon` Metric Correlation Plot visualizations.""" + + version = __version__ + + container = "python" + name = "rubicon_ml_metric_correlation_plot" + + def __init__(self, metadata=None, **catalog_data): + self._catalog_data = catalog_data or {} + print(self._catalog_data) + + super().__init__(metadata=metadata) + + def _get_schema(self): + """Creates a Metric Correlation Plot visualization and sets it as the visualization object attribute""" + self._visualization_object = MetricCorrelationPlot(**self._catalog_data) + return super()._get_schema() diff --git a/setup.cfg b/setup.cfg index 753a6605..626674d8 100644 --- a/setup.cfg +++ b/setup.cfg @@ -62,6 +62,7 @@ console_scripts = intake.drivers = rubicon_ml_experiment = rubicon_ml.intake_rubicon.experiment:ExperimentSource rubicon_ml_experiment_table = rubicon_ml.intake_rubicon.viz:ExperimentsTableDataSource + rubicon_ml_metric_correlation_plot = rubicon_ml.intake_rubicon.viz:MetricCorrelationPlotDataSource [versioneer] vcs = git diff --git a/tests/unit/intake_rubicon/test_publish.py b/tests/unit/intake_rubicon/test_publish.py index 09ea394a..b3c31021 100644 --- a/tests/unit/intake_rubicon/test_publish.py +++ b/tests/unit/intake_rubicon/test_publish.py @@ -3,9 +3,13 @@ from rubicon_ml import Rubicon, publish from rubicon_ml.viz.experiments_table import ExperimentsTable +from rubicon_ml.viz.metric_correlation_plot import MetricCorrelationPlot def test_publish(project_client): + + # Experiment Table + project = project_client experiment = project.log_experiment() visualization_object = ExperimentsTable() @@ -35,6 +39,35 @@ def test_publish(project_client): ) assert catalog["sources"]["experiment_table"] is not None + # Metric Correlation Plot + + visualization_object = MetricCorrelationPlot() + catalog_yaml = publish(project.experiments(), visualization_object) + catalog = yaml.safe_load(catalog_yaml) + + assert f"experiment_{experiment.id.replace('-', '_')}" in catalog["sources"] + assert ( + "rubicon_ml_experiment" + == catalog["sources"][f"experiment_{experiment.id.replace('-', '_')}"]["driver"] + ) + assert ( + experiment.repository.root_dir + == catalog["sources"][f"experiment_{experiment.id.replace('-', '_')}"]["args"]["urlpath"] + ) + assert ( + experiment.id + == catalog["sources"][f"experiment_{experiment.id.replace('-', '_')}"]["args"][ + "experiment_id" + ] + ) + assert ( + project.name + == catalog["sources"][f"experiment_{experiment.id.replace('-', '_')}"]["args"][ + "project_name" + ] + ) + assert catalog["sources"]["metric_correlation_plot"] is not None + def test_publish_from_multiple_sources(): rubicon_a = Rubicon(persistence="memory", root_dir="path/a") diff --git a/tests/unit/intake_rubicon/test_viz.py b/tests/unit/intake_rubicon/test_viz.py index 1623eb26..a6ec6cc5 100644 --- a/tests/unit/intake_rubicon/test_viz.py +++ b/tests/unit/intake_rubicon/test_viz.py @@ -1,6 +1,6 @@ import os -from rubicon_ml.intake_rubicon.viz import ExperimentsTableDataSource +from rubicon_ml.intake_rubicon.viz import ExperimentsTableDataSource, MetricCorrelationPlotDataSource root = os.path.dirname(__file__) @@ -33,3 +33,25 @@ def test_experiments_table_source(): assert visualization.parameter_query_type == catalog_data_sample["parameter_query_type"] source.close() + +def test_metric_correlation_plot_source(): + catalog_data_sample = { + "metric_names": None, + "parameter_names": None, + "selected_metric": None, + } + + source = MetricCorrelationPlotDataSource(catalog_data_sample) + assert source is not None + + source.discover() + + visualization = source.read() + + assert visualization is not None + assert visualization.metric_names == catalog_data_sample["metric_names"] + assert visualization.parameter_names == catalog_data_sample["parameter_names"] + assert visualization.selected_metric == catalog_data_sample["selected_metric"] + + source.close() +