Skip to content

Commit

Permalink
catalog dataframe plot (#465)
Browse files Browse the repository at this point in the history
* Created an Intake Source for DataframePlotDataSource for reading Dataframe Plot visualizations
* fixing typing errors
* update styling issues
* update environment.yml
* remove nvm files
* update test_viz.py for dataframe test
* update parameters in test_viz.py for dataframe test
* merging branch to main
* resolve merge conflicts
* remove nvm files
* moved imports in viz.py and added dataframe plot tests in test_publish.py
* add dataframe argument for publish tests
* debugging test_viz.py
* removed experiments check in test_viz

---------

Co-authored-by: Jacqueline Hui <[email protected]>
  • Loading branch information
jeh362 and jhui18 authored Jul 19, 2024
1 parent 8c6b527 commit 17c4971
Show file tree
Hide file tree
Showing 6 changed files with 113 additions and 3 deletions.
2 changes: 2 additions & 0 deletions rubicon_ml/intake_rubicon/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from rubicon_ml.intake_rubicon.experiment import ExperimentSource
from rubicon_ml.intake_rubicon.viz import (
DataframePlotDataSource,
ExperimentsTableDataSource,
MetricCorrelationPlotDataSource,
)
Expand All @@ -10,4 +11,5 @@
"ExperimentSource",
"ExperimentsTableDataSource",
"MetricCorrelationPlotDataSource",
"DataframePlotDataSource",
]
21 changes: 20 additions & 1 deletion rubicon_ml/intake_rubicon/publish.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,16 @@
import yaml

if TYPE_CHECKING:
from rubicon_ml.viz import DataframePlot
from rubicon_ml.viz.experiments_table import ExperimentsTable
from rubicon_ml.viz.metric_correlation_plot import MetricCorrelationPlot


def publish(
experiments,
visualization_object: Optional[Union["ExperimentsTable", "MetricCorrelationPlot"]] = None,
visualization_object: Optional[
Union["ExperimentsTable", "MetricCorrelationPlot", "DataframePlot"]
] = None,
output_filepath=None,
base_catalog_filepath=None,
):
Expand Down Expand Up @@ -102,6 +105,7 @@ def _update_catalog(


def _build_catalog(experiments, visualization):
from rubicon_ml.viz import DataframePlot
from rubicon_ml.viz.experiments_table import ExperimentsTable
from rubicon_ml.viz.metric_correlation_plot import MetricCorrelationPlot

Expand Down Expand Up @@ -149,6 +153,7 @@ def _build_catalog(experiments, visualization):
},
}
catalog["sources"]["experiment_table"] = appended_visualization_catalog

if isinstance(visualization, MetricCorrelationPlot):
appended_visualization_catalog = {
"driver": "rubicon_ml_metric_correlation_plot",
Expand All @@ -160,6 +165,20 @@ def _build_catalog(experiments, visualization):
}
catalog["sources"]["metric_correlation_plot"] = appended_visualization_catalog

# visualization is a DataframePlot
if isinstance(visualization, DataframePlot):
appended_visualization_catalog = {
"driver": "rubicon_ml_dataframe_plot",
"args": {
"dataframe_name": visualization.dataframe_name,
"x": visualization.x,
"y": visualization.y,
},
}

# append visualization object to end of catalog file
catalog["sources"]["dataframe_plot"] = appended_visualization_catalog

# append visualization object to end of catalog file

return catalog
27 changes: 25 additions & 2 deletions rubicon_ml/intake_rubicon/viz.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from rubicon_ml import __version__
from rubicon_ml.intake_rubicon.base import VizDataSourceMixin
from rubicon_ml.viz import MetricCorrelationPlot


class ExperimentsTableDataSource(VizDataSourceMixin):
Expand Down Expand Up @@ -35,11 +34,35 @@ class MetricCorrelationPlotDataSource(VizDataSourceMixin):

def __init__(self, metadata=None, **catalog_data):
    """Store the catalog arguments and initialize the parent data source.

    Parameters
    ----------
    metadata : optional
        Forwarded unchanged to the parent initializer.
    **catalog_data
        Keyword arguments kept on `self._catalog_data`; `_get_schema`
        later unpacks them into `MetricCorrelationPlot`.
    """
    # The catalog arguments are intentionally not printed here: a data
    # source should not write to stdout every time it is constructed.
    self._catalog_data = catalog_data or {}

    super().__init__(metadata=metadata)

def _get_schema(self):
    """Build the Metric Correlation Plot and store it as this source's visualization object.

    The plot is constructed from the keyword arguments saved on
    `self._catalog_data`; the return value is whatever the parent
    `_get_schema` returns.
    """
    # Deferred import: resolved when the schema is requested rather than
    # when this module is imported.
    from rubicon_ml.viz import MetricCorrelationPlot

    plot_kwargs = self._catalog_data
    self._visualization_object = MetricCorrelationPlot(**plot_kwargs)

    return super()._get_schema()


class DataframePlotDataSource(VizDataSourceMixin):
    """Intake data source that reads `rubicon` Dataframe Plot visualizations.

    The keyword arguments given at construction time are kept as-is and
    unpacked into `DataframePlot` when the schema is requested.
    """

    # driver name; matches the `rubicon_ml_dataframe_plot` catalog driver string
    name = "rubicon_ml_dataframe_plot"
    container = "python"
    version = __version__

    def __init__(self, metadata=None, **catalog_data):
        """Remember the catalog arguments, then initialize the parent source.

        Parameters
        ----------
        metadata : optional
            Forwarded unchanged to the parent initializer.
        **catalog_data
            Keyword arguments later passed to `DataframePlot`.
        """
        if catalog_data:
            self._catalog_data = catalog_data
        else:
            self._catalog_data = {}

        super().__init__(metadata=metadata)

    def _get_schema(self):
        """Build the Dataframe Plot and store it as this source's visualization object."""
        # Deferred import: resolved when the schema is requested rather than
        # when this module is imported.
        from rubicon_ml.viz import DataframePlot

        plot_kwargs = self._catalog_data
        self._visualization_object = DataframePlot(**plot_kwargs)

        return super()._get_schema()
1 change: 1 addition & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ intake.drivers =
rubicon_ml_experiment = rubicon_ml.intake_rubicon.experiment:ExperimentSource
rubicon_ml_experiment_table = rubicon_ml.intake_rubicon.viz:ExperimentsTableDataSource
rubicon_ml_metric_correlation_plot = rubicon_ml.intake_rubicon.viz:MetricCorrelationPlotDataSource
rubicon_ml_dataframe_plot = rubicon_ml.intake_rubicon.viz:DataframePlotDataSource

[versioneer]
vcs = git
Expand Down
30 changes: 30 additions & 0 deletions tests/unit/intake_rubicon/test_publish.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import yaml

from rubicon_ml import Rubicon, publish
from rubicon_ml.viz.dataframe_plot import DataframePlot
from rubicon_ml.viz.experiments_table import ExperimentsTable
from rubicon_ml.viz.metric_correlation_plot import MetricCorrelationPlot

Expand Down Expand Up @@ -67,6 +68,35 @@ def test_publish(project_client):
)
assert catalog["sources"]["metric_correlation_plot"] is not None

# Dataframe Plot

visualization_object = DataframePlot(dataframe_name="test_dataframe")
catalog_yaml = publish(project.experiments(), visualization_object)
catalog = yaml.safe_load(catalog_yaml)

assert f"experiment_{experiment.id.replace('-', '_')}" in catalog["sources"]
assert (
"rubicon_ml_experiment"
== catalog["sources"][f"experiment_{experiment.id.replace('-', '_')}"]["driver"]
)
assert (
experiment.repository.root_dir
== catalog["sources"][f"experiment_{experiment.id.replace('-', '_')}"]["args"]["urlpath"]
)
assert (
experiment.id
== catalog["sources"][f"experiment_{experiment.id.replace('-', '_')}"]["args"][
"experiment_id"
]
)
assert (
project.name
== catalog["sources"][f"experiment_{experiment.id.replace('-', '_')}"]["args"][
"project_name"
]
)
assert catalog["sources"]["dataframe_plot"] is not None


def test_publish_from_multiple_sources():
rubicon_a = Rubicon(persistence="memory", root_dir="path/a")
Expand Down
35 changes: 35 additions & 0 deletions tests/unit/intake_rubicon/test_viz.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import os

from rubicon_ml.intake_rubicon.viz import (
DataframePlotDataSource,
ExperimentsTableDataSource,
MetricCorrelationPlotDataSource,
)
Expand Down Expand Up @@ -58,3 +59,37 @@ def test_metric_correlation_plot_source():
assert visualization.selected_metric == catalog_data_sample["selected_metric"]

source.close()


def test_datatable_plot_source():
    """Read a Dataframe Plot back from its Intake source and compare the stored arguments."""
    catalog_data_sample = dict(
        dataframe_name="dataframe_name",
        experiments=None,
        plotting_func=None,
        plotting_func_kwargs=None,
        x=None,
        y=None,
    )

    # the source receives exactly the sample arguments as keywords
    source = DataframePlotDataSource(**catalog_data_sample)
    assert source is not None

    source.discover()

    visualization = source.read()

    assert visualization is not None
    # `experiments` is passed to the source but is not among the compared attributes
    compared_attributes = (
        "dataframe_name",
        "plotting_func",
        "plotting_func_kwargs",
        "x",
        "y",
    )
    for attribute in compared_attributes:
        assert getattr(visualization, attribute) == catalog_data_sample[attribute]

    source.close()

0 comments on commit 17c4971

Please sign in to comment.