-
Notifications
You must be signed in to change notification settings - Fork 1.2k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
dvc render: Add plotly parallel coords renderer
New renderer based on plotly. Not exposed to `dvc plots`. Generate plotly datapoints from `TabularData`. pre-requisite #4455
- Loading branch information
Showing
2 changed files
with
201 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,78 @@ | ||
import json | ||
from collections import defaultdict | ||
from typing import TYPE_CHECKING, Any, Dict, Optional | ||
|
||
from dvc.render.base import Renderer | ||
|
||
if TYPE_CHECKING: | ||
from dvc.compare import TabularData | ||
|
||
|
||
class ParallelCoordinatesRenderer(Renderer): | ||
DIV = """ | ||
<div id = "{id}"> | ||
<script type = "text/javascript"> | ||
var plotly_data = {partial}; | ||
Plotly.newPlot("{id}", plotly_data.data, plotly_data.layout); | ||
</script> | ||
</div> | ||
""" | ||
|
||
SCRIPTS = """ | ||
<script src="https://cdn.plot.ly/plotly-latest.min.js"></script> | ||
""" | ||
|
||
# pylint: disable=W0231 | ||
def __init__( | ||
self, tabular_data: "TabularData", color_by: Optional[str] = None | ||
): | ||
self.tabular_data = tabular_data | ||
self.color_by = color_by | ||
self.filename = "experiments" | ||
|
||
def _convert(self, path): | ||
return self.as_json() | ||
|
||
def as_json(self) -> str: | ||
tabular_dict = defaultdict(list) | ||
for row in self.tabular_data.as_dict(): | ||
for col_name, value in row.items(): | ||
tabular_dict[col_name].append(str(value)) | ||
|
||
trace: Dict[str, Any] = {"type": "parcoords", "dimensions": []} | ||
for label, values in tabular_dict.items(): | ||
is_categorical = False | ||
|
||
try: | ||
float_values = [float(x) for x in values] | ||
except ValueError: | ||
is_categorical = True | ||
dummy_values = list(range(len(values))) | ||
|
||
if is_categorical: | ||
trace["dimensions"].append( | ||
{ | ||
"label": label, | ||
"values": dummy_values, | ||
"tickvals": dummy_values, | ||
"ticktext": values, | ||
} | ||
) | ||
else: | ||
trace["dimensions"].append( | ||
{"label": label, "values": float_values} | ||
) | ||
|
||
if label == self.color_by: | ||
trace["line"] = { | ||
"color": dummy_values if is_categorical else float_values, | ||
"showscale": True, | ||
} | ||
if is_categorical: | ||
trace["line"]["colorbar"] = { | ||
"tickmode": "array", | ||
"tickvals": dummy_values, | ||
"ticktext": values, | ||
} | ||
|
||
return json.dumps({"data": [trace], "layout": {}}) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,123 @@ | ||
import json | ||
|
||
from dvc.compare import TabularData | ||
from dvc.render.html import write | ||
from dvc.render.plotly import ParallelCoordinatesRenderer | ||
|
||
# pylint: disable=W1514 | ||
|
||
|
||
def expected_format(result): | ||
assert "data" in result | ||
assert "layout" in result | ||
assert isinstance(result["data"], list) | ||
assert result["data"][0]["type"] == "parcoords" | ||
assert isinstance(result["data"][0]["dimensions"], list) | ||
return True | ||
|
||
|
||
def test_scalar_columns(): | ||
td = TabularData(["col-1", "col-2"]) | ||
td.extend([["0.1", "1"], ["2", "0.2"]]) | ||
renderer = ParallelCoordinatesRenderer(td) | ||
|
||
result = json.loads(renderer.as_json()) | ||
|
||
assert expected_format(result) | ||
|
||
assert result["data"][0]["dimensions"][0] == { | ||
"label": "col-1", | ||
"values": [0.1, 2.0], | ||
} | ||
assert result["data"][0]["dimensions"][1] == { | ||
"label": "col-2", | ||
"values": [1.0, 0.2], | ||
} | ||
|
||
|
||
def test_categorical_columns(): | ||
td = TabularData(["col-1"]) | ||
td.extend([["foo"], ["bar"]]) | ||
renderer = ParallelCoordinatesRenderer(td) | ||
|
||
result = json.loads(renderer.as_json()) | ||
|
||
assert expected_format(result) | ||
|
||
assert result["data"][0]["dimensions"][0] == { | ||
"label": "col-1", | ||
"values": [0, 1], | ||
"tickvals": [0, 1], | ||
"ticktext": ["foo", "bar"], | ||
} | ||
|
||
|
||
def test_mixed_columns(): | ||
td = TabularData(["categorical", "scalar"]) | ||
td.extend([["foo", "0.1"], ["bar", "2"]]) | ||
renderer = ParallelCoordinatesRenderer(td) | ||
|
||
result = json.loads(renderer.as_json()) | ||
|
||
assert expected_format(result) | ||
|
||
assert result["data"][0]["dimensions"][0] == { | ||
"label": "categorical", | ||
"values": [0, 1], | ||
"tickvals": [0, 1], | ||
"ticktext": ["foo", "bar"], | ||
} | ||
assert result["data"][0]["dimensions"][1] == { | ||
"label": "scalar", | ||
"values": [0.1, 2.0], | ||
} | ||
|
||
|
||
def test_color_by_scalar(): | ||
td = TabularData(["categorical", "scalar"]) | ||
td.extend([["foo", "0.1"], ["bar", "2"]]) | ||
renderer = ParallelCoordinatesRenderer(td, color_by="scalar") | ||
|
||
result = json.loads(renderer.as_json()) | ||
|
||
assert expected_format(result) | ||
assert result["data"][0]["line"] == { | ||
"color": [0.1, 2.0], | ||
"showscale": True, | ||
} | ||
|
||
|
||
def test_color_by_categorical(): | ||
td = TabularData(["categorical", "scalar"]) | ||
td.extend([["foo", "0.1"], ["bar", "2"]]) | ||
renderer = ParallelCoordinatesRenderer(td, color_by="categorical") | ||
|
||
result = json.loads(renderer.as_json()) | ||
|
||
assert expected_format(result) | ||
assert result["data"][0]["line"] == { | ||
"color": [0, 1], | ||
"showscale": True, | ||
"colorbar": { | ||
"tickmode": "array", | ||
"tickvals": [0, 1], | ||
"ticktext": ["foo", "bar"], | ||
}, | ||
} | ||
|
||
|
||
def test_write_parallel_coordinates(tmp_dir): | ||
td = TabularData(["categorical", "scalar"]) | ||
td.extend([["foo", "0.1"], ["bar", "2"]]) | ||
|
||
renderer = ParallelCoordinatesRenderer(td) | ||
html_path = write(tmp_dir, renderers=[renderer]) | ||
|
||
html_text = html_path.read_text() | ||
|
||
assert ParallelCoordinatesRenderer.SCRIPTS in html_text | ||
|
||
div = ParallelCoordinatesRenderer.DIV.format( | ||
id="plot_experiments", partial=renderer.as_json() | ||
) | ||
assert div in html_text |