Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

metric correlation plot intake source added #464

Merged
merged 3 commits into from
Jul 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ dependencies:
- versioneer

# for packaging
- setuptools
- setuptools<71.0.0
- wheel

# for edgetest
Expand Down
11 changes: 9 additions & 2 deletions rubicon_ml/intake_rubicon/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,13 @@
import intake # noqa F401

from rubicon_ml.intake_rubicon.experiment import ExperimentSource
from rubicon_ml.intake_rubicon.viz import ExperimentsTableDataSource
from rubicon_ml.intake_rubicon.viz import (
ExperimentsTableDataSource,
MetricCorrelationPlotDataSource,
)

__all__ = ["ExperimentSource", "ExperimentsTableDataSource"]
__all__ = [
"ExperimentSource",
"ExperimentsTableDataSource",
"MetricCorrelationPlotDataSource",
]
46 changes: 30 additions & 16 deletions rubicon_ml/intake_rubicon/publish.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,16 @@
from typing import TYPE_CHECKING, Optional
from typing import TYPE_CHECKING, Optional, Union

import fsspec
import yaml

if TYPE_CHECKING:
from rubicon_ml.viz.experiments_table import ExperimentsTable
from rubicon_ml.viz.metric_correlation_plot import MetricCorrelationPlot


def publish(
experiments,
# visualization object passed, defaulted to None
visualization_object: Optional["ExperimentsTable"] = None,
visualization_object: Optional[Union["ExperimentsTable", "MetricCorrelationPlot"]] = None,
output_filepath=None,
base_catalog_filepath=None,
):
Expand Down Expand Up @@ -102,6 +102,9 @@ def _update_catalog(


def _build_catalog(experiments, visualization):
from rubicon_ml.viz.experiments_table import ExperimentsTable
from rubicon_ml.viz.metric_correlation_plot import MetricCorrelationPlot

"""Helper function to build catalog dictionary from given experiments.
Parameters
Expand Down Expand Up @@ -132,20 +135,31 @@ def _build_catalog(experiments, visualization):

# create visualization entry to the catalog file
if visualization is not None:
appended_visualization_catalog = {
"driver": "rubicon_ml_experiment_table",
"args": {
"is_selectable": visualization.is_selectable,
"metric_names": visualization.metric_names,
"metric_query_tags": visualization.metric_query_tags,
"metric_query_type": visualization.metric_query_type,
"parameter_names": visualization.parameter_names,
"parameter_query_tags": visualization.parameter_query_tags,
"parameter_query_type": visualization.parameter_query_type,
},
}
if isinstance(visualization, ExperimentsTable):
appended_visualization_catalog = {
"driver": "rubicon_ml_experiment_table",
"args": {
"is_selectable": visualization.is_selectable,
"metric_names": visualization.metric_names,
"metric_query_tags": visualization.metric_query_tags,
"metric_query_type": visualization.metric_query_type,
"parameter_names": visualization.parameter_names,
"parameter_query_tags": visualization.parameter_query_tags,
"parameter_query_type": visualization.parameter_query_type,
},
}
catalog["sources"]["experiment_table"] = appended_visualization_catalog
if isinstance(visualization, MetricCorrelationPlot):
appended_visualization_catalog = {
"driver": "rubicon_ml_metric_correlation_plot",
"args": {
"metric_names": visualization.metric_names,
"parameter_names": visualization.parameter_names,
"selected_metric": visualization.selected_metric,
},
}
catalog["sources"]["metric_correlation_plot"] = appended_visualization_catalog

# append visualization object to end of catalog file
catalog["sources"]["experiment_table"] = appended_visualization_catalog

return catalog
21 changes: 21 additions & 0 deletions rubicon_ml/intake_rubicon/viz.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from rubicon_ml import __version__
from rubicon_ml.intake_rubicon.base import VizDataSourceMixin
from rubicon_ml.viz import MetricCorrelationPlot


class ExperimentsTableDataSource(VizDataSourceMixin):
Expand All @@ -22,3 +23,23 @@ def _get_schema(self):
self._visualization_object = ExperimentsTable(**self._catalog_data)

return super()._get_schema()


class MetricCorrelationPlotDataSource(VizDataSourceMixin):
"""An Intake data source for reading `rubicon` Metric Correlation Plot visualizations."""

version = __version__

container = "python"
name = "rubicon_ml_metric_correlation_plot"

def __init__(self, metadata=None, **catalog_data):
self._catalog_data = catalog_data or {}
print(self._catalog_data)

super().__init__(metadata=metadata)

def _get_schema(self):
"""Creates a Metric Correlation Plot visualization and sets it as the visualization object attribute"""
self._visualization_object = MetricCorrelationPlot(**self._catalog_data)
return super()._get_schema()
1 change: 1 addition & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ console_scripts =
intake.drivers =
rubicon_ml_experiment = rubicon_ml.intake_rubicon.experiment:ExperimentSource
rubicon_ml_experiment_table = rubicon_ml.intake_rubicon.viz:ExperimentsTableDataSource
rubicon_ml_metric_correlation_plot = rubicon_ml.intake_rubicon.viz:MetricCorrelationPlotDataSource

[versioneer]
vcs = git
Expand Down
32 changes: 32 additions & 0 deletions tests/unit/intake_rubicon/test_publish.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,12 @@

from rubicon_ml import Rubicon, publish
from rubicon_ml.viz.experiments_table import ExperimentsTable
from rubicon_ml.viz.metric_correlation_plot import MetricCorrelationPlot


def test_publish(project_client):
# Experiment Table

project = project_client
experiment = project.log_experiment()
visualization_object = ExperimentsTable()
Expand Down Expand Up @@ -35,6 +38,35 @@ def test_publish(project_client):
)
assert catalog["sources"]["experiment_table"] is not None

# Metric Correlation Plot

visualization_object = MetricCorrelationPlot()
catalog_yaml = publish(project.experiments(), visualization_object)
catalog = yaml.safe_load(catalog_yaml)

assert f"experiment_{experiment.id.replace('-', '_')}" in catalog["sources"]
assert (
"rubicon_ml_experiment"
== catalog["sources"][f"experiment_{experiment.id.replace('-', '_')}"]["driver"]
)
assert (
experiment.repository.root_dir
== catalog["sources"][f"experiment_{experiment.id.replace('-', '_')}"]["args"]["urlpath"]
)
assert (
experiment.id
== catalog["sources"][f"experiment_{experiment.id.replace('-', '_')}"]["args"][
"experiment_id"
]
)
assert (
project.name
== catalog["sources"][f"experiment_{experiment.id.replace('-', '_')}"]["args"][
"project_name"
]
)
assert catalog["sources"]["metric_correlation_plot"] is not None


def test_publish_from_multiple_sources():
rubicon_a = Rubicon(persistence="memory", root_dir="path/a")
Expand Down
27 changes: 26 additions & 1 deletion tests/unit/intake_rubicon/test_viz.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
import os

from rubicon_ml.intake_rubicon.viz import ExperimentsTableDataSource
from rubicon_ml.intake_rubicon.viz import (
ExperimentsTableDataSource,
MetricCorrelationPlotDataSource,
)

root = os.path.dirname(__file__)

Expand Down Expand Up @@ -33,3 +36,25 @@ def test_experiments_table_source():
assert visualization.parameter_query_type == catalog_data_sample["parameter_query_type"]

source.close()


def test_metric_correlation_plot_source():
catalog_data_sample = {
"metric_names": None,
"parameter_names": None,
"selected_metric": None,
}

source = MetricCorrelationPlotDataSource(catalog_data_sample)
assert source is not None

source.discover()

visualization = source.read()

assert visualization is not None
assert visualization.metric_names == catalog_data_sample["metric_names"]
assert visualization.parameter_names == catalog_data_sample["parameter_names"]
assert visualization.selected_metric == catalog_data_sample["selected_metric"]

source.close()
Loading