Skip to content

Commit

Permalink
Adding MetricListComparisonDataSource (#467)
Browse files Browse the repository at this point in the history
* Created an intake source, MetricListComparisonDataSource, for reading Metric List Comparison visualizations
* updated formatting
* removed duplicate test in test_viz.py
* edit metriclist types and removing experiments from metriclistcomparison in publish.py
* edit test_viz.py metric list parameter names
* update metric list tests
* edit comments

---------

Co-authored-by: Jacqueline Hui <[email protected]>
  • Loading branch information
jeh362 and jhui18 authored Jul 22, 2024
1 parent 17c4971 commit f452770
Show file tree
Hide file tree
Showing 6 changed files with 94 additions and 2 deletions.
2 changes: 2 additions & 0 deletions rubicon_ml/intake_rubicon/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,13 @@
DataframePlotDataSource,
ExperimentsTableDataSource,
MetricCorrelationPlotDataSource,
MetricListComparisonDataSource,
)

__all__ = [
"ExperimentSource",
"ExperimentsTableDataSource",
"MetricCorrelationPlotDataSource",
"DataframePlotDataSource",
"MetricListComparisonDataSource",
]
23 changes: 21 additions & 2 deletions rubicon_ml/intake_rubicon/publish.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,13 @@
from rubicon_ml.viz import DataframePlot
from rubicon_ml.viz.experiments_table import ExperimentsTable
from rubicon_ml.viz.metric_correlation_plot import MetricCorrelationPlot
from rubicon_ml.viz.metric_lists_comparison import MetricListsComparison


def publish(
experiments,
visualization_object: Optional[
Union["ExperimentsTable", "MetricCorrelationPlot", "DataframePlot"]
Union["ExperimentsTable", "MetricCorrelationPlot", "DataframePlot", "MetricListsComparison"]
] = None,
output_filepath=None,
base_catalog_filepath=None,
Expand Down Expand Up @@ -108,6 +109,7 @@ def _build_catalog(experiments, visualization):
from rubicon_ml.viz import DataframePlot
from rubicon_ml.viz.experiments_table import ExperimentsTable
from rubicon_ml.viz.metric_correlation_plot import MetricCorrelationPlot
from rubicon_ml.viz.metric_lists_comparison import MetricListsComparison

"""Helper function to build catalog dictionary from given experiments.
Expand Down Expand Up @@ -138,6 +140,7 @@ def _build_catalog(experiments, visualization):
catalog["sources"][experiment_catalog_name] = appended_experiment_catalog

# create visualization entry to the catalog file
# visualization is an ExperimentsTable
if visualization is not None:
if isinstance(visualization, ExperimentsTable):
appended_visualization_catalog = {
Expand All @@ -154,6 +157,7 @@ def _build_catalog(experiments, visualization):
}
catalog["sources"]["experiment_table"] = appended_visualization_catalog

# visualization is a MetricCorrelationPlot
if isinstance(visualization, MetricCorrelationPlot):
appended_visualization_catalog = {
"driver": "rubicon_ml_metric_correlation_plot",
Expand All @@ -165,7 +169,7 @@ def _build_catalog(experiments, visualization):
}
catalog["sources"]["metric_correlation_plot"] = appended_visualization_catalog

# vizualization is an DataframePlot
# visualization is a DataframePlot
if isinstance(visualization, DataframePlot):
appended_visualization_catalog = {
"driver": "rubicon_ml_dataframe_plot",
Expand All @@ -181,4 +185,19 @@ def _build_catalog(experiments, visualization):

# append visualization object to end of catalog file

# visualization is a MetricListsComparison
if isinstance(visualization, MetricListsComparison):
appended_visualization_catalog = {
"driver": "rubicon_ml_metric_list",
"args": {
"column_names": visualization.column_names,
"selected_metric": visualization.selected_metric,
},
}

# append visualization object to end of catalog file
catalog["sources"]["metric_list"] = appended_visualization_catalog

# append visualization object to end of catalog file

return catalog
22 changes: 22 additions & 0 deletions rubicon_ml/intake_rubicon/viz.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,3 +66,25 @@ def _get_schema(self):
self._visualization_object = DataframePlot(**self._catalog_data)

return super()._get_schema()


class MetricListComparisonDataSource(VizDataSourceMixin):
    """Intake data source that reads `rubicon` Metric List Comparison visualizations.

    Any keyword arguments other than `metadata` are kept as catalog data and
    forwarded to `MetricListsComparison` when the schema is requested.
    """

    version = __version__

    container = "python"
    name = "rubicon_ml_metric_list"

    def __init__(self, metadata=None, **catalog_data):
        # Stash the catalog arguments before delegating to the parent
        # initializer; `_get_schema` reads them to build the visualization.
        self._catalog_data = catalog_data if catalog_data else {}

        super().__init__(metadata=metadata)

    def _get_schema(self):
        """Build a `MetricListsComparison` from the stored catalog data,
        set it as the visualization object attribute, and return the
        parent's schema."""
        from rubicon_ml.viz import MetricListsComparison

        comparison = MetricListsComparison(**self._catalog_data)
        self._visualization_object = comparison

        return super()._get_schema()
1 change: 1 addition & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ intake.drivers =
rubicon_ml_experiment_table = rubicon_ml.intake_rubicon.viz:ExperimentsTableDataSource
rubicon_ml_metric_correlation_plot = rubicon_ml.intake_rubicon.viz:MetricCorrelationPlotDataSource
rubicon_ml_dataframe_plot = rubicon_ml.intake_rubicon.viz:DataframePlotDataSource
rubicon_ml_metric_list = rubicon_ml.intake_rubicon.viz:MetricListComparisonDataSource

[versioneer]
vcs = git
Expand Down
30 changes: 30 additions & 0 deletions tests/unit/intake_rubicon/test_publish.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from rubicon_ml.viz.dataframe_plot import DataframePlot
from rubicon_ml.viz.experiments_table import ExperimentsTable
from rubicon_ml.viz.metric_correlation_plot import MetricCorrelationPlot
from rubicon_ml.viz.metric_lists_comparison import MetricListsComparison


def test_publish(project_client):
Expand Down Expand Up @@ -97,6 +98,35 @@ def test_publish(project_client):
)
assert catalog["sources"]["dataframe_plot"] is not None

# MetricListsComparison

visualization_object = MetricListsComparison()
catalog_yaml = publish(project.experiments(), visualization_object)
catalog = yaml.safe_load(catalog_yaml)

assert f"experiment_{experiment.id.replace('-', '_')}" in catalog["sources"]
assert (
"rubicon_ml_experiment"
== catalog["sources"][f"experiment_{experiment.id.replace('-', '_')}"]["driver"]
)
assert (
experiment.repository.root_dir
== catalog["sources"][f"experiment_{experiment.id.replace('-', '_')}"]["args"]["urlpath"]
)
assert (
experiment.id
== catalog["sources"][f"experiment_{experiment.id.replace('-', '_')}"]["args"][
"experiment_id"
]
)
assert (
project.name
== catalog["sources"][f"experiment_{experiment.id.replace('-', '_')}"]["args"][
"project_name"
]
)
assert catalog["sources"]["metric_list"] is not None


def test_publish_from_multiple_sources():
rubicon_a = Rubicon(persistence="memory", root_dir="path/a")
Expand Down
18 changes: 18 additions & 0 deletions tests/unit/intake_rubicon/test_viz.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
DataframePlotDataSource,
ExperimentsTableDataSource,
MetricCorrelationPlotDataSource,
MetricListComparisonDataSource,
)

root = os.path.dirname(__file__)
Expand Down Expand Up @@ -93,3 +94,20 @@ def test_datatable_plot_source():
assert visualization.y == catalog_data_sample["y"]

source.close()


def test_metric_list_source():
    """Check that a `MetricListComparisonDataSource` built from catalog data
    yields a visualization carrying that same catalog data."""
    catalog_data_sample = {"column_names": None, "selected_metric": None}

    # The catalog data must be unpacked as keyword arguments: the data source
    # signature is `(metadata=None, **catalog_data)`, so passing the dict
    # positionally would bind it to `metadata` and leave the catalog data empty.
    source = MetricListComparisonDataSource(**catalog_data_sample)
    assert source is not None

    source.discover()

    visualization = source.read()

    assert visualization is not None
    assert visualization.column_names == catalog_data_sample["column_names"]
    assert visualization.selected_metric == catalog_data_sample["selected_metric"]

    source.close()

0 comments on commit f452770

Please sign in to comment.