Skip to content

Commit

Permalink
programmatically select columns (#437)
Browse files Browse the repository at this point in the history
* programmatically select columns
* add new tests
  • Loading branch information
ryanSoley authored Jun 6, 2024
1 parent 81ebb43 commit 0a28bc3
Show file tree
Hide file tree
Showing 5 changed files with 158 additions and 25 deletions.
28 changes: 14 additions & 14 deletions notebooks/viz/experiments-table.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -53,9 +53,11 @@
"\n",
"for i in range(0, 24):\n",
" experiment = project.log_experiment()\n",
" experiment.log_parameter(name=\"max_depth\", value=random.randrange(5, 25, 5))\n",
" experiment.log_parameter(name=\"n_estimators\", value=random.randrange(2, 12, 2))\n",
" experiment.log_metric(name=\"accuracy\", value=random.random())"
" experiment.log_parameter(name=\"max_depth\", tags=[\"show me\"], value=random.randrange(5, 25, 5))\n",
" experiment.log_parameter(name=\"n_estimators\", tags=[\"show me\"], value=random.randrange(2, 12, 2))\n",
" experiment.log_parameter(name=\"extra_parameter\", value=random.randrange(0, 2))\n",
" experiment.log_metric(name=\"accuracy\", tags=[\"show me\"], value=random.random())\n",
" experiment.log_metric(name=\"extra_metric\", value=random.random())"
]
},
{
Expand All @@ -67,26 +69,24 @@
"logged and view the table right in this notebook with `show`. The Dash application\n",
"itself will be running on http://127.0.0.1:8050/ when running locally. Use the\n",
"`serve` command to launch the server directly without rendering the widget in the\n",
"current Python interpreter."
"current Python interpreter.\n",
"\n",
"**Note:** The parameters and metrics shown when the `ExperimentsTable` first renders\n",
"can be controlled by the class' other input arguments. The configuration below will\n",
"only show the metrics and parameters tagged with \"show me\" when the table first renders."
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "8cb7d191-2b60-4984-80ba-dbb2b310c4bf",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Dash is running on http://127.0.0.1:8050/\n"
]
}
],
"outputs": [],
"source": [
"ExperimentsTable(\n",
" experiments=project.experiments(),\n",
" parameter_query_tags=[\"show me\"],\n",
" metric_query_tags=[\"show me\"],\n",
").show()"
]
},
Expand Down Expand Up @@ -115,7 +115,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.7"
"version": "3.10.14"
}
},
"nbformat": 4,
Expand Down
Binary file modified notebooks/viz/experiments-table.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
78 changes: 69 additions & 9 deletions rubicon_ml/viz/experiments_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,48 @@ class ExperimentsTable(VizBase):
is_selectable : bool, optional
True to enable selection of the rows in the table, False otherwise.
Defaults to True.
metric_names : list of str
If provided, only show the metrics with names in the given list. If
`metric_query_tags` are also provided, this will only select metrics
from the tag-filtered results.
metric_query_tags : list of str, optional
If provided, only show the metrics with the given tags in the table.
metric_query_type : 'and' or 'or', optional
When `metric_query_tags` are given, 'and' shows the metrics with all of
the given tags and 'or' shows the metrics with any of the given tags.
parameter_names : list of str
If provided, only show the parameters with names in the given list. If
`parameter_query_tags` are also provided, this will only select
parameters from the tag-filtered results.
parameter_query_tags : list of str, optional
If provided, only show the parameters with the given tags in the table.
parameter_query_type : 'and' or 'or', optional
When `parameter_query_tags` are given, 'and' shows the paramters with
all of the given tags and 'or' shows the parameters with any of the
given tags.
"""

def __init__(self, experiments=None, is_selectable=True):
def __init__(
self,
experiments=None,
is_selectable=True,
metric_names=None,
metric_query_tags=None,
metric_query_type=None,
parameter_names=None,
parameter_query_tags=None,
parameter_query_type=None,
):
super().__init__(dash_title="experiment table")

self.experiments = experiments
self.is_selectable = is_selectable
self.metric_names = metric_names
self.metric_query_tags = metric_query_tags
self.metric_query_type = metric_query_type
self.parameter_names = parameter_names
self.parameter_query_tags = parameter_query_tags
self.parameter_query_type = parameter_query_type

@property
def layout(self):
Expand Down Expand Up @@ -219,15 +254,15 @@ def load_experiment_data(self):
if applicable.
"""
self.experiment_records = []
self.metric_names = set()
self.parameter_names = set()

self.all_columns = ["id", "name", "created_at", "model_name", "commit_hash", "tags"]
self.hidden_columns = []

self.commit_hash = None
self.github_url = None

all_parameter_names = set()
all_metric_names = set()
commit_hashes = set()
show_columns = {"id", "created_at"}

Expand Down Expand Up @@ -259,23 +294,48 @@ def load_experiment_data(self):
for parameter in experiment.parameters():
experiment_record[parameter.name] = str(parameter.value)

self.parameter_names.add(parameter.name)
all_parameter_names.add(parameter.name)

for metric in experiment.metrics():
experiment_record[metric.name] = str(metric.value)

self.metric_names.add(metric.name)
all_metric_names.add(metric.name)

self.experiment_records.append(experiment_record)

self.metric_names = list(self.metric_names)
self.parameter_names = list(self.parameter_names)
if self.parameter_query_tags is not None:
parameters = experiment.parameters(
tags=self.parameter_query_tags,
qtype=self.parameter_query_type,
)
show_parameter_names = set([p.name for p in parameters])
else:
show_parameter_names = all_parameter_names

if self.parameter_names is not None:
show_parameter_names = set(
[name for name in show_parameter_names if name in self.parameter_names]
)

if self.metric_query_tags is not None:
metrics = experiment.metrics(
tags=self.metric_query_tags,
qtype=self.metric_query_type,
)
show_metric_names = set([m.name for m in metrics])
else:
show_metric_names = all_metric_names

if self.metric_names is not None:
show_metric_names = set(
[name for name in show_metric_names if name in self.metric_names]
)

self.all_columns.extend(self.parameter_names + self.metric_names)
self.all_columns.extend(list(all_parameter_names) + list(all_metric_names))
self.hidden_columns = [
column
for column in self.all_columns
if column not in list(show_columns) + self.metric_names + self.parameter_names
if column not in show_columns | show_metric_names | show_parameter_names
]

if len(commit_hashes) == 1:
Expand Down
10 changes: 8 additions & 2 deletions tests/fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,14 +272,20 @@ def viz_experiments(rubicon_and_project_client):
experiment.log_parameter(name="test param 0", value=random.choice([True, False]))
experiment.log_parameter(name="test param 1", value=random.randrange(2, 10, 2))
experiment.log_parameter(
name="test param 2", value=random.choice(["A", "B", "C", "D", "E"])
name="test param 2",
value=random.choice(["A", "B", "C", "D", "E"]),
tags=["a", "b"],
)

experiment.log_metric(name="test metric 0", value=random.random())
experiment.log_metric(name="test metric 1", value=random.random())

experiment.log_metric(name="test metric 2", value=[random.random() for _ in range(0, 5)])
experiment.log_metric(name="test metric 3", value=[random.random() for _ in range(0, 5)])
experiment.log_metric(
name="test metric 3",
value=[random.random() for _ in range(0, 5)],
tags=["a", "b"],
)

data = np.array(
[
Expand Down
67 changes: 67 additions & 0 deletions tests/unit/viz/test_experiments_table.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import pytest
from dash import Dash

from rubicon_ml.viz import ExperimentsTable
Expand Down Expand Up @@ -39,10 +40,76 @@ def test_experiments_table_load_data(viz_experiments):
assert all([record.get(name) is not None for name in expected_metric_names])
assert all([record.get(name) is not None for name in expected_parameter_names])

assert all(
[
name not in experiments_table.hidden_columns
for name in expected_metric_names + expected_parameter_names
]
)

assert experiments_table.commit_hash == viz_experiments[0].commit_hash
assert experiments_table.github_url == f"test.github.url/tree/{viz_experiments[0].commit_hash}"


@pytest.mark.parametrize("filter_by", ["tags", "names"])
def test_experiments_table_load_filtered_data(filter_by, viz_experiments):
if filter_by == "tags":
tags = ["a", "b"]
qtype = "and"

expected_metric_names = [m.name for m in viz_experiments[0].metrics(tags=tags, qtype=qtype)]
expected_parameter_names = [
p.name for p in viz_experiments[0].parameters(tags=tags, qtype=qtype)
]

experiments_table_kwargs = {
"metric_query_tags": tags,
"metric_query_type": qtype,
"parameter_query_tags": tags,
"parameter_query_type": qtype,
}
elif filter_by == "names":
expected_metric_names = ["test metric 0"]
expected_parameter_names = ["test param 1", "test param 2"]

experiments_table_kwargs = {
"metric_names": expected_metric_names,
"parameter_names": expected_parameter_names,
}

experiments_table = ExperimentsTable(
experiments=viz_experiments,
**experiments_table_kwargs,
)
experiments_table.load_experiment_data()

expected_experiment_ids = [e.id for e in viz_experiments]

all_metric_names = [m.name for m in viz_experiments[0].metrics()]
all_parameter_names = [p.name for p in viz_experiments[0].parameters()]
unexpected_metric_names = list(set(all_metric_names).difference(set(expected_metric_names)))
unexpected_parameter_names = list(
set(all_parameter_names).difference(set(expected_parameter_names))
)

for record in experiments_table.experiment_records:
assert record["id"] in expected_experiment_ids
assert all([record.get(name) is not None for name in expected_metric_names])
assert all([record.get(name) is not None for name in expected_parameter_names])

# unexpected should still be in table...
assert all([record.get(name) is not None for name in unexpected_metric_names])
assert all([record.get(name) is not None for name in unexpected_parameter_names])

# ...but they should be hidden
assert all(
[
name in experiments_table.hidden_columns
for name in unexpected_metric_names + unexpected_parameter_names
]
)


def test_experiments_table_layout(viz_experiments):
experiments_table = ExperimentsTable(experiments=viz_experiments)
experiments_table.load_experiment_data()
Expand Down

0 comments on commit 0a28bc3

Please sign in to comment.