Skip to content

Commit

Permalink
Fixed metric lists comparison length None bug (#473)
Browse files Browse the repository at this point in the history
* fixed metric lists comparison length None bug

* lint check

* link == to is check

* test that not setting a column name results in an empty list
  • Loading branch information
austinhk authored Jul 30, 2024
1 parent f452770 commit b5f44ce
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 5 deletions.
5 changes: 4 additions & 1 deletion rubicon_ml/viz/metric_lists_comparison.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,10 @@ def __init__(
):
super().__init__(dash_title="compare metric lists")

self.column_names = column_names
if column_names is None:
self.column_names = []
else:
self.column_names = column_names
self.experiments = experiments
self.selected_metric = selected_metric

Expand Down
2 changes: 1 addition & 1 deletion tests/unit/intake_rubicon/test_viz.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ def test_metric_list_source():
visualization = source.read()

assert visualization is not None
assert visualization.column_names == catalog_data_sample["column_names"]
assert visualization.column_names == []
assert visualization.selected_metric == catalog_data_sample["selected_metric"]

source.close()
10 changes: 7 additions & 3 deletions tests/unit/viz/test_metric_lists_comparison.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,10 @@
from rubicon_ml.viz import MetricListsComparison


def test_metric_lists_comparison(viz_experiments):
@pytest.mark.parametrize("column_names", [["var_0", "var_1", "var_2", "var_3", "var_4"], None])
def test_metric_lists_comparison(viz_experiments, column_names):
metric_comparison = MetricListsComparison(
column_names=["var_0", "var_1", "var_2", "var_3", "var_4"],
column_names=column_names,
experiments=viz_experiments,
selected_metric="test metric 2",
)
Expand All @@ -19,7 +20,10 @@ def test_metric_lists_comparison(viz_experiments):
expected_experiment_ids.remove(experiment.id)

assert len(expected_experiment_ids) == 0
assert metric_comparison.column_names == ["var_0", "var_1", "var_2", "var_3", "var_4"]
if column_names is None:
assert metric_comparison.column_names == []
else:
assert metric_comparison.column_names == ["var_0", "var_1", "var_2", "var_3", "var_4"]
assert metric_comparison.selected_metric == "test metric 2"


Expand Down

0 comments on commit b5f44ce

Please sign in to comment.