Skip to content

Commit

Permalink
Created an Intake Source for Experiment Table (#456)
Browse files Browse the repository at this point in the history
* Update __init__.py

Added the ability to get access to the ExperimentsTableDataSource

* Update __init__.py

* Create viz.py

Uses VizDataSourceMixing and is an Intake data source for reading rubicon Experiment Table Visualizations

* Update experiments_table.py

Had to add a more detailed path to publish due to error with partially initialized module.

* Update setup.cfg

Added the experiment table data source as a an Intake Driver

* Fixed visualization object input for test_publish.py

Inputs to tests are fixtures only and since we can't pass visualization_object, we create one with an empty ExperimentsTable in the test_publish() function itself.

* Create test_viz.py

Add a new test file to contain the new test_experiments_table_source for testing on the source itself.
  • Loading branch information
yashvb authored Jul 9, 2024
1 parent 3748bf6 commit 3735436
Show file tree
Hide file tree
Showing 6 changed files with 65 additions and 8 deletions.
3 changes: 2 additions & 1 deletion rubicon_ml/intake_rubicon/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import intake # noqa F401

from rubicon_ml.intake_rubicon.experiment import ExperimentSource
from rubicon_ml.intake_rubicon.viz import ExperimentsTableDataSource

__all__ = ["ExperimentSource"]
__all__ = ["ExperimentSource", "ExperimentsTableDataSource"]
22 changes: 22 additions & 0 deletions rubicon_ml/intake_rubicon/viz.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
from rubicon_ml import __version__
from rubicon_ml.intake_rubicon.base import VizDataSourceMixin
from rubicon_ml.viz import ExperimentsTable


class ExperimentsTableDataSource(VizDataSourceMixin):
"""An Intake data source for reading `rubicon` Experiment Table visualizations."""

version = __version__

container = "python"
name = "rubicon_ml_experiments_table"

def __init__(self, metadata=None, **catalog_data):
self._catalog_data = catalog_data or {}

super().__init__(metadata=metadata)

def _get_schema(self):
"""Creates an Experiments Table visualization and sets it as the visualization object attribute"""
self._visualization_object = ExperimentsTable(**self._catalog_data)
return super()._get_schema()
2 changes: 1 addition & 1 deletion rubicon_ml/viz/experiments_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from dash import dash_table, dcc, html
from dash.dependencies import ALL, Input, Output, State

from rubicon_ml import publish
from rubicon_ml.intake_rubicon.publish import publish
from rubicon_ml.viz.base import VizBase
from rubicon_ml.viz.common.colors import light_blue, plot_background_blue

Expand Down
1 change: 1 addition & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ console_scripts =
rubicon_ml = rubicon_ml.cli:cli
intake.drivers =
rubicon_ml_experiment = rubicon_ml.intake_rubicon.experiment:ExperimentSource
rubicon_ml_experiment_table = rubicon_ml.intake_rubicon.viz:ExperimentsTableDataSource

[versioneer]
vcs = git
Expand Down
10 changes: 4 additions & 6 deletions tests/unit/intake_rubicon/test_publish.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,14 @@
from typing import TYPE_CHECKING, Optional

import fsspec
import yaml

from rubicon_ml import Rubicon, publish

if TYPE_CHECKING:
from rubicon_ml.viz.experiments_table import ExperimentsTable
from rubicon_ml.viz.experiments_table import ExperimentsTable


def test_publish(project_client, visualization_object: Optional["ExperimentsTable"] = None):
def test_publish(project_client):
project = project_client
experiment = project.log_experiment()
visualization_object = ExperimentsTable()
catalog_yaml = publish(project.experiments(), visualization_object)
catalog = yaml.safe_load(catalog_yaml)

Expand All @@ -36,6 +33,7 @@ def test_publish(project_client, visualization_object: Optional["ExperimentsTabl
"project_name"
]
)
assert catalog["sources"]["experiment_table"] is not None


def test_publish_from_multiple_sources():
Expand Down
35 changes: 35 additions & 0 deletions tests/unit/intake_rubicon/test_viz.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
import os

from rubicon_ml.intake_rubicon.viz import ExperimentsTableDataSource

root = os.path.dirname(__file__)


def test_experiments_table_source():
catalog_data_sample = {
"is_selectable": True,
"metric_names": None,
"metric_query_tags": None,
"metric_query_type": None,
"parameter_names": None,
"parameter_query_tags": None,
"parameter_query_type": None,
}

source = ExperimentsTableDataSource(catalog_data_sample)
assert source is not None

source.discover()

visualization = source.read()

assert visualization is not None
assert visualization.is_selectable == catalog_data_sample["is_selectable"]
assert visualization.metric_names == catalog_data_sample["metric_names"]
assert visualization.metric_query_tags == catalog_data_sample["metric_query_tags"]
assert visualization.metric_query_type == catalog_data_sample["metric_query_type"]
assert visualization.parameter_names == catalog_data_sample["parameter_names"]
assert visualization.parameter_query_tags == catalog_data_sample["parameter_query_tags"]
assert visualization.parameter_query_type == catalog_data_sample["parameter_query_type"]

source.close()

0 comments on commit 3735436

Please sign in to comment.