Skip to content

Commit

Permalink
metric correlation plot
Browse files Browse the repository at this point in the history
  • Loading branch information
austinhk authored and Austin Kim committed Jul 18, 2024
1 parent 878d691 commit 566be50
Show file tree
Hide file tree
Showing 6 changed files with 111 additions and 20 deletions.
4 changes: 2 additions & 2 deletions rubicon_ml/intake_rubicon/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import intake # noqa F401

from rubicon_ml.intake_rubicon.experiment import ExperimentSource
from rubicon_ml.intake_rubicon.viz import ExperimentsTableDataSource
from rubicon_ml.intake_rubicon.viz import ExperimentsTableDataSource, MetricCorrelationPlotDataSource

__all__ = ["ExperimentSource", "ExperimentsTableDataSource"]
__all__ = ["ExperimentSource", "ExperimentsTableDataSource", "MetricCorrelationPlotDataSource",]
49 changes: 32 additions & 17 deletions rubicon_ml/intake_rubicon/publish.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,17 @@
from typing import TYPE_CHECKING, Optional
from typing import TYPE_CHECKING, Optional, Union

import fsspec
import yaml

if TYPE_CHECKING:
from rubicon_ml.viz.experiments_table import ExperimentsTable
from rubicon_ml.viz.metric_correlation_plot import MetricCorrelationPlot



def publish(
experiments,
# visualization object passed, defaulted to None
visualization_object: Optional["ExperimentsTable"] = None,
visualization_object: Optional[Union["ExperimentsTable", "MetricCorrelationPlot"]] = None,
output_filepath=None,
base_catalog_filepath=None,
):
Expand Down Expand Up @@ -102,6 +103,9 @@ def _update_catalog(


def _build_catalog(experiments, visualization):
from rubicon_ml.viz.experiments_table import ExperimentsTable
from rubicon_ml.viz.metric_correlation_plot import MetricCorrelationPlot

"""Helper function to build catalog dictionary from given experiments.
Parameters
Expand Down Expand Up @@ -132,20 +136,31 @@ def _build_catalog(experiments, visualization):

# create visualization entry to the catalog file
if visualization is not None:
appended_visualization_catalog = {
"driver": "rubicon_ml_experiment_table",
"args": {
"is_selectable": visualization.is_selectable,
"metric_names": visualization.metric_names,
"metric_query_tags": visualization.metric_query_tags,
"metric_query_type": visualization.metric_query_type,
"parameter_names": visualization.parameter_names,
"parameter_query_tags": visualization.parameter_query_tags,
"parameter_query_type": visualization.parameter_query_type,
},
}
if isinstance(visualization, ExperimentsTable):
appended_visualization_catalog = {
"driver": "rubicon_ml_experiment_table",
"args": {
"is_selectable": visualization.is_selectable,
"metric_names": visualization.metric_names,
"metric_query_tags": visualization.metric_query_tags,
"metric_query_type": visualization.metric_query_type,
"parameter_names": visualization.parameter_names,
"parameter_query_tags": visualization.parameter_query_tags,
"parameter_query_type": visualization.parameter_query_type,
},
}
catalog["sources"]["experiment_table"] = appended_visualization_catalog
if isinstance(visualization, MetricCorrelationPlot):
appended_visualization_catalog = {
"driver": "rubicon_ml_metric_correlation_plot",
"args": {
"metric_names": visualization.metric_names,
"parameter_names": visualization.parameter_names,
"selected_metric": visualization.selected_metric,
},
}
catalog["sources"]["metric_correlation_plot"] = appended_visualization_catalog

# append visualization object to end of catalog file
catalog["sources"]["experiment_table"] = appended_visualization_catalog


return catalog
20 changes: 20 additions & 0 deletions rubicon_ml/intake_rubicon/viz.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from rubicon_ml import __version__
from rubicon_ml.intake_rubicon.base import VizDataSourceMixin
from rubicon_ml.viz import MetricCorrelationPlot


class ExperimentsTableDataSource(VizDataSourceMixin):
Expand All @@ -22,3 +23,22 @@ def _get_schema(self):
self._visualization_object = ExperimentsTable(**self._catalog_data)

return super()._get_schema()

class MetricCorrelationPlotDataSource(VizDataSourceMixin):
    """An Intake data source for reading `rubicon` Metric Correlation Plot visualizations.

    Keyword arguments supplied at construction are stored as catalog data and
    forwarded unchanged to ``MetricCorrelationPlot`` when the schema is built.
    """

    version = __version__

    container = "python"
    # Driver name; matches the ``intake.drivers`` entry point and the
    # ``driver`` value written into published catalogs.
    name = "rubicon_ml_metric_correlation_plot"

    def __init__(self, metadata=None, **catalog_data):
        """Store the catalog arguments and initialize the base data source.

        Parameters
        ----------
        metadata : optional
            Passed through to the parent initializer as ``metadata``.
        **catalog_data
            Keyword arguments later handed to ``MetricCorrelationPlot``.
        """
        # Normalize to a dict so it can always be unpacked in `_get_schema`.
        self._catalog_data = catalog_data or {}

        super().__init__(metadata=metadata)

    def _get_schema(self):
        """Create a Metric Correlation Plot visualization and set it as the
        visualization object attribute.

        Returns whatever the parent class's ``_get_schema`` returns.
        """
        self._visualization_object = MetricCorrelationPlot(**self._catalog_data)

        return super()._get_schema()
1 change: 1 addition & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ console_scripts =
intake.drivers =
rubicon_ml_experiment = rubicon_ml.intake_rubicon.experiment:ExperimentSource
rubicon_ml_experiment_table = rubicon_ml.intake_rubicon.viz:ExperimentsTableDataSource
rubicon_ml_metric_correlation_plot = rubicon_ml.intake_rubicon.viz:MetricCorrelationPlotDataSource

[versioneer]
vcs = git
Expand Down
33 changes: 33 additions & 0 deletions tests/unit/intake_rubicon/test_publish.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,13 @@

from rubicon_ml import Rubicon, publish
from rubicon_ml.viz.experiments_table import ExperimentsTable
from rubicon_ml.viz.metric_correlation_plot import MetricCorrelationPlot


def test_publish(project_client):

# Experiment Table

project = project_client
experiment = project.log_experiment()
visualization_object = ExperimentsTable()
Expand Down Expand Up @@ -35,6 +39,35 @@ def test_publish(project_client):
)
assert catalog["sources"]["experiment_table"] is not None

# Metric Correlation Plot

visualization_object = MetricCorrelationPlot()
catalog_yaml = publish(project.experiments(), visualization_object)
catalog = yaml.safe_load(catalog_yaml)

assert f"experiment_{experiment.id.replace('-', '_')}" in catalog["sources"]
assert (
"rubicon_ml_experiment"
== catalog["sources"][f"experiment_{experiment.id.replace('-', '_')}"]["driver"]
)
assert (
experiment.repository.root_dir
== catalog["sources"][f"experiment_{experiment.id.replace('-', '_')}"]["args"]["urlpath"]
)
assert (
experiment.id
== catalog["sources"][f"experiment_{experiment.id.replace('-', '_')}"]["args"][
"experiment_id"
]
)
assert (
project.name
== catalog["sources"][f"experiment_{experiment.id.replace('-', '_')}"]["args"][
"project_name"
]
)
assert catalog["sources"]["metric_correlation_plot"] is not None


def test_publish_from_multiple_sources():
rubicon_a = Rubicon(persistence="memory", root_dir="path/a")
Expand Down
24 changes: 23 additions & 1 deletion tests/unit/intake_rubicon/test_viz.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import os

from rubicon_ml.intake_rubicon.viz import ExperimentsTableDataSource
from rubicon_ml.intake_rubicon.viz import ExperimentsTableDataSource, MetricCorrelationPlotDataSource

root = os.path.dirname(__file__)

Expand Down Expand Up @@ -33,3 +33,25 @@ def test_experiments_table_source():
assert visualization.parameter_query_type == catalog_data_sample["parameter_query_type"]

source.close()

def test_metric_correlation_plot_source():
    """Check that the data source builds a plot from its catalog arguments."""
    catalog_data_sample = {
        "metric_names": None,
        "parameter_names": None,
        "selected_metric": None,
    }

    # Unpack the sample so each entry arrives through `**catalog_data`.
    # Passed positionally, the whole dict would bind to `metadata` and the
    # plot would be built from an empty argument set.
    source = MetricCorrelationPlotDataSource(**catalog_data_sample)
    assert source is not None

    source.discover()

    visualization = source.read()

    assert visualization is not None
    assert visualization.metric_names == catalog_data_sample["metric_names"]
    assert visualization.parameter_names == catalog_data_sample["parameter_names"]
    assert visualization.selected_metric == catalog_data_sample["selected_metric"]

    source.close()

0 comments on commit 566be50

Please sign in to comment.