Skip to content

Commit

Permalink
dvc render: Add plotly parallel coords renderer
Browse files Browse the repository at this point in the history
New renderer based on plotly. Not exposed to `dvc plots`.
Generate plotly datapoints from `TabularData`.

pre-requisite #4455
  • Loading branch information
daavoo committed Dec 1, 2021
1 parent 004e15c commit b1b7d39
Show file tree
Hide file tree
Showing 2 changed files with 201 additions and 0 deletions.
78 changes: 78 additions & 0 deletions dvc/render/plotly.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
import json
from collections import defaultdict
from typing import TYPE_CHECKING, Any, Dict, Optional

from dvc.render.base import Renderer

if TYPE_CHECKING:
from dvc.compare import TabularData


class ParallelCoordinatesRenderer(Renderer):
DIV = """
<div id = "{id}">
<script type = "text/javascript">
var plotly_data = {partial};
Plotly.newPlot("{id}", plotly_data.data, plotly_data.layout);
</script>
</div>
"""

SCRIPTS = """
<script src="https://cdn.plot.ly/plotly-latest.min.js"></script>
"""

# pylint: disable=W0231
def __init__(
self, tabular_data: "TabularData", color_by: Optional[str] = None
):
self.tabular_data = tabular_data
self.color_by = color_by
self.filename = "experiments"

def _convert(self, path):
return self.as_json()

def as_json(self) -> str:
tabular_dict = defaultdict(list)
for row in self.tabular_data.as_dict():
for col_name, value in row.items():
tabular_dict[col_name].append(str(value))

trace: Dict[str, Any] = {"type": "parcoords", "dimensions": []}
for label, values in tabular_dict.items():
is_categorical = False

try:
float_values = [float(x) for x in values]
except ValueError:
is_categorical = True
dummy_values = list(range(len(values)))

if is_categorical:
trace["dimensions"].append(
{
"label": label,
"values": dummy_values,
"tickvals": dummy_values,
"ticktext": values,
}
)
else:
trace["dimensions"].append(
{"label": label, "values": float_values}
)

if label == self.color_by:
trace["line"] = {
"color": dummy_values if is_categorical else float_values,
"showscale": True,
}
if is_categorical:
trace["line"]["colorbar"] = {
"tickmode": "array",
"tickvals": dummy_values,
"ticktext": values,
}

return json.dumps({"data": [trace], "layout": {}})
123 changes: 123 additions & 0 deletions tests/unit/render/test_parallel_coordinates.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
import json

from dvc.compare import TabularData
from dvc.render.html import write
from dvc.render.plotly import ParallelCoordinatesRenderer

# pylint: disable=W1514


def expected_format(result):
assert "data" in result
assert "layout" in result
assert isinstance(result["data"], list)
assert result["data"][0]["type"] == "parcoords"
assert isinstance(result["data"][0]["dimensions"], list)
return True


def test_scalar_columns():
td = TabularData(["col-1", "col-2"])
td.extend([["0.1", "1"], ["2", "0.2"]])
renderer = ParallelCoordinatesRenderer(td)

result = json.loads(renderer.as_json())

assert expected_format(result)

assert result["data"][0]["dimensions"][0] == {
"label": "col-1",
"values": [0.1, 2.0],
}
assert result["data"][0]["dimensions"][1] == {
"label": "col-2",
"values": [1.0, 0.2],
}


def test_categorical_columns():
td = TabularData(["col-1"])
td.extend([["foo"], ["bar"]])
renderer = ParallelCoordinatesRenderer(td)

result = json.loads(renderer.as_json())

assert expected_format(result)

assert result["data"][0]["dimensions"][0] == {
"label": "col-1",
"values": [0, 1],
"tickvals": [0, 1],
"ticktext": ["foo", "bar"],
}


def test_mixed_columns():
td = TabularData(["categorical", "scalar"])
td.extend([["foo", "0.1"], ["bar", "2"]])
renderer = ParallelCoordinatesRenderer(td)

result = json.loads(renderer.as_json())

assert expected_format(result)

assert result["data"][0]["dimensions"][0] == {
"label": "categorical",
"values": [0, 1],
"tickvals": [0, 1],
"ticktext": ["foo", "bar"],
}
assert result["data"][0]["dimensions"][1] == {
"label": "scalar",
"values": [0.1, 2.0],
}


def test_color_by_scalar():
td = TabularData(["categorical", "scalar"])
td.extend([["foo", "0.1"], ["bar", "2"]])
renderer = ParallelCoordinatesRenderer(td, color_by="scalar")

result = json.loads(renderer.as_json())

assert expected_format(result)
assert result["data"][0]["line"] == {
"color": [0.1, 2.0],
"showscale": True,
}


def test_color_by_categorical():
td = TabularData(["categorical", "scalar"])
td.extend([["foo", "0.1"], ["bar", "2"]])
renderer = ParallelCoordinatesRenderer(td, color_by="categorical")

result = json.loads(renderer.as_json())

assert expected_format(result)
assert result["data"][0]["line"] == {
"color": [0, 1],
"showscale": True,
"colorbar": {
"tickmode": "array",
"tickvals": [0, 1],
"ticktext": ["foo", "bar"],
},
}


def test_write_parallel_coordinates(tmp_dir):
td = TabularData(["categorical", "scalar"])
td.extend([["foo", "0.1"], ["bar", "2"]])

renderer = ParallelCoordinatesRenderer(td)
html_path = write(tmp_dir, renderers=[renderer])

html_text = html_path.read_text()

assert ParallelCoordinatesRenderer.SCRIPTS in html_text

div = ParallelCoordinatesRenderer.DIV.format(
id="plot_experiments", partial=renderer.as_json()
)
assert div in html_text

0 comments on commit b1b7d39

Please sign in to comment.