From f3c3f66c0641c0bf10d78ac23b2d6b534b8297c5 Mon Sep 17 00:00:00 2001 From: Matt Seddon <37993418+mattseddon@users.noreply.github.com> Date: Tue, 5 Dec 2023 15:05:11 +1100 Subject: [PATCH] `plots`: standardise across DVC/Studio/VS Code (#142) * update templates with new anchors * prototype filling optional anchors * move update datapoints to top level * add color into anchor definitions * add get_revs method to renderer * add anchors to fix confusion templates * split x and y labels out so they can be truncated by vs code * remove erroneous comment * fix issue with confusion matrix * fix string issues * add title to split anchors * add zoom and pan anchor * add tooltip anchor for Studio * rename anchor_revs to revs_with_datapoints * add empty sort property to order facet by rev in datapoints * add width and height anchors * add sort to y offset in horizontal bar templates * move horizontal bar plots from row to column * remove row_height and column_width anchors * use field separator for pivot field * fix linear plots with varied filename field * extend optional anchor tests * ensure all 4 variations have test cases * hoist anchors_y_definition in tests * add separate color test --- src/dvc_render/vega.py | 370 +++++++++- src/dvc_render/vega_templates.py | 176 ++--- tests/test_vega.py | 1087 +++++++++++++++++++++++++++++- 3 files changed, 1524 insertions(+), 109 deletions(-) diff --git a/src/dvc_render/vega.py b/src/dvc_render/vega.py index b10ab0a..9a62bc6 100644 --- a/src/dvc_render/vega.py +++ b/src/dvc_render/vega.py @@ -1,13 +1,63 @@ import base64 import io import json +from collections import defaultdict from pathlib import Path -from typing import Any, Dict, List, Optional, Union +from typing import Any, Dict, List, Optional, Tuple, Union from warnings import warn from .base import Renderer from .utils import list_dict_to_dict_list -from .vega_templates import BadTemplateError, LinearTemplate, get_template +from .vega_templates import BadTemplateError, 
LinearTemplate, Template, get_template + +FIELD_SEPARATOR = "::" +REV = "rev" +FILENAME = "filename" +FIELD = "field" +FILENAME_FIELD = [FILENAME, FIELD] +CONCAT_FIELDS = FIELD_SEPARATOR.join(FILENAME_FIELD) + +SPLIT_ANCHORS = [ + "color", + "data", + "plot_height", + "plot_width", + "shape", + "stroke_dash", + "title", + "tooltip", + "x_label", + "y_label", + "zoom_and_pan", +] +OPTIONAL_ANCHORS = [ + "color", + "column", + "group_by_x", + "group_by_y", + "group_by", + "pivot_field", + "plot_height", + "plot_width", + "row", + "shape", + "stroke_dash", + "tooltip", + "zoom_and_pan", +] +OPTIONAL_ANCHOR_RANGES: Dict[str, Union[List[str], List[List[int]]]] = { + "stroke_dash": [[1, 0], [8, 8], [8, 4], [4, 4], [4, 2], [2, 1], [1, 1]], + "color": [ + "#945dd6", + "#13adc7", + "#f46837", + "#48bb78", + "#4299e1", + "#ed8936", + "#f56565", + ], + "shape": ["circle", "square", "triangle", "diamond"], +} class VegaRenderer(Renderer): @@ -39,19 +89,20 @@ def __init__(self, datapoints: List, name: str, **properties): self.properties.get("template_dir", None), ) + self._split_content: Dict[str, str] = {} + def get_filled_template( self, - skip_anchors: Optional[List[str]] = None, + split_anchors: Optional[List[str]] = None, strict: bool = True, - as_string: bool = True, - ) -> Union[str, Dict[str, Any]]: + ) -> Dict[str, Any]: """Returns a functional vega specification""" self.template.reset() if not self.datapoints: return {} - if skip_anchors is None: - skip_anchors = [] + if split_anchors is None: + split_anchors = [] if strict: if self.properties.get("x"): @@ -67,13 +118,19 @@ def get_filled_template( self.properties.setdefault("y_label", self.properties.get("y")) self.properties.setdefault("data", self.datapoints) + varied_keys = self._process_optional_anchors(split_anchors) + self._update_datapoints(varied_keys) + names = ["title", "x", "y", "x_label", "y_label", "data"] for name in names: - if name in skip_anchors: - continue value = self.properties.get(name) if value 
is None: continue + + if name in split_anchors: + self._set_split_content(name, value) + continue + if name == "data": if not self.template.has_anchor(name): anchor = self.template.anchor(name) @@ -85,13 +142,27 @@ def get_filled_template( value = self.template.escape_special_characters(value) self.template.fill_anchor(name, value) - if as_string: - return json.dumps(self.template.content) + return self.template.content + def get_partial_filled_template(self): + """ + Returns a partially filled template along with the split out anchor content + """ + content = self.get_filled_template( + split_anchors=SPLIT_ANCHORS, + strict=True, + ) + return content, {"anchor_definitions": self._split_content} + + def get_template(self): + """ + Returns unfilled template (for Studio) + """ return self.template.content def partial_html(self, **kwargs) -> str: - return self.get_filled_template() # type: ignore + content = self.get_filled_template() + return json.dumps(content) def generate_markdown(self, report_path=None) -> str: if not isinstance(self.template, LinearTemplate): @@ -137,3 +208,278 @@ def generate_markdown(self, report_path=None) -> str: return f"\n![{self.name}]({src})" return "" + + def get_revs(self): + """ + Returns all revisions that were collected that have datapoints. 
+ """ + return self.properties.get("revs_with_datapoints", []) + + def _process_optional_anchors(self, split_anchors: List[str]): + optional_anchors = [ + anchor for anchor in OPTIONAL_ANCHORS if self.template.has_anchor(anchor) + ] + if not optional_anchors: + return None + + self._fill_color(split_anchors, optional_anchors) + self._fill_set_encoding(split_anchors, optional_anchors) + + y_definitions = self.properties.get("anchors_y_definitions", []) + is_single_source = len(y_definitions) <= 1 + + if is_single_source: + self._process_single_source_plot(split_anchors, optional_anchors) + return [] + + return self._process_multi_source_plot( + split_anchors, optional_anchors, y_definitions + ) + + def _fill_color(self, split_anchors: List[str], optional_anchors: List[str]): + all_revs = self.get_revs() + self._fill_optional_anchor_mapping( + split_anchors, + optional_anchors, + REV, + "color", + all_revs, + ) + + def _fill_set_encoding(self, split_anchors: List[str], optional_anchors: List[str]): + for name, encoding in [ + ("zoom_and_pan", {"name": "grid", "select": "interval", "bind": "scales"}), + ("plot_height", 300), + ("plot_width", 300), + ]: + self._fill_optional_anchor(split_anchors, optional_anchors, name, encoding) + + def _process_single_source_plot( + self, split_anchors: List[str], optional_anchors: List[str] + ): + self._fill_group_by(split_anchors, optional_anchors, [REV]) + self._fill_optional_anchor( + split_anchors, optional_anchors, "pivot_field", "datum.rev" + ) + self._fill_tooltip(split_anchors, optional_anchors) + for anchor in optional_anchors: + self.template.fill_anchor(anchor, {}) + + def _process_multi_source_plot( + self, + split_anchors: List[str], + optional_anchors: List[str], + y_definitions: List[Dict[str, str]], + ): + varied_keys, domain = self._collect_variations(y_definitions) + + self._fill_optional_multi_source_anchors( + split_anchors, optional_anchors, varied_keys, domain + ) + return varied_keys + + def 
_collect_variations( + self, y_definitions: List[Dict[str, str]] + ) -> Tuple[List[str], List[str]]: + varied_values = defaultdict(set) + for defn in y_definitions: + for key in FILENAME_FIELD: + varied_values[key].add(defn.get(key, None)) + varied_values[CONCAT_FIELDS].add( + FIELD_SEPARATOR.join([defn.get(FILENAME, ""), defn.get(FIELD, "")]) + ) + + varied_keys = [] + + for filename_or_field in FILENAME_FIELD: + value_set = varied_values[filename_or_field] + num_values = len(value_set) + if num_values == 1: + continue + varied_keys.append(filename_or_field) + + domain = self._get_domain(varied_keys, varied_values) + + return varied_keys, domain + + def _fill_optional_multi_source_anchors( + self, + split_anchors: List[str], + optional_anchors: List[str], + varied_keys: List[str], + domain: List[str], + ): + if not optional_anchors: + return + + concat_field = FIELD_SEPARATOR.join(varied_keys) + self._fill_group_by(split_anchors, optional_anchors, [REV, concat_field]) + + self._fill_optional_anchor( + split_anchors, + optional_anchors, + "pivot_field", + f" + '{FIELD_SEPARATOR}' + ".join( + [f"datum.{key}" for key in [REV, *varied_keys]] + ), + ) + + self._fill_optional_anchor( + split_anchors, optional_anchors, "row", {"field": concat_field, "sort": []} + ) + self._fill_optional_anchor( + split_anchors, + optional_anchors, + "column", + {"field": concat_field, "sort": []}, + ) + + self._fill_tooltip(split_anchors, optional_anchors, [concat_field]) + + for anchor in ["stroke_dash", "shape"]: + self._fill_optional_anchor_mapping( + split_anchors, optional_anchors, concat_field, anchor, domain + ) + + def _fill_group_by( + self, + split_anchors: List[str], + optional_anchors: List[str], + group_by: List[str], + ): + self._fill_optional_anchor( + split_anchors, optional_anchors, "group_by", group_by + ) + self._fill_optional_anchor( + split_anchors, + optional_anchors, + "group_by_x", + [*group_by, self.properties.get("x")], + ) + self._fill_optional_anchor( + 
split_anchors, + optional_anchors, + "group_by_y", + [*group_by, self.properties.get("y")], + ) + + def _fill_tooltip( + self, + split_anchors: List[str], + optional_anchors: List[str], + additional_fields: Optional[List[str]] = None, + ): + if not additional_fields: + additional_fields = [] + self._fill_optional_anchor( + split_anchors, + optional_anchors, + "tooltip", + [ + {"field": REV}, + {"field": self.properties.get("x")}, + {"field": self.properties.get("y")}, + *[{"field": field} for field in additional_fields], + ], + ) + + def _fill_optional_anchor( + self, + split_anchors: List[str], + optional_anchors: List[str], + name: str, + value: Any, + ): + if name not in optional_anchors: + return + + optional_anchors.remove(name) + + if name in split_anchors: + self._set_split_content(name, value) + return + + self.template.fill_anchor(name, value) + + def _get_domain(self, varied_keys: List[str], varied_values: Dict[str, set]): + if len(varied_keys) == 2: + domain = list(varied_values[CONCAT_FIELDS]) + else: + filename_or_field = varied_keys[0] + domain = list(varied_values[filename_or_field]) + + domain.sort() + return domain + + def _fill_optional_anchor_mapping( + self, + split_anchors: List[str], + optional_anchors: List[str], + field: str, + name: str, + domain: List[str], + ): # pylint: disable=too-many-arguments + if name not in optional_anchors: + return + + optional_anchors.remove(name) + + encoding = self._get_optional_anchor_mapping(field, name, domain) + + if name in split_anchors: + self._set_split_content(name, encoding) + return + + self.template.fill_anchor(name, encoding) + + def _get_optional_anchor_mapping( + self, + field: str, + name: str, + domain: List[str], + ): + full_range_values: List[Any] = OPTIONAL_ANCHOR_RANGES.get(name, []) + anchor_range_values = full_range_values.copy() + + anchor_range = [] + for _ in range(len(domain)): + if not anchor_range_values: + anchor_range_values = full_range_values.copy() + range_value = 
anchor_range_values.pop(0) + anchor_range.append(range_value) + + legend = ( + # fix stroke dash and shape legend entry appearance (use empty shapes) + {"legend": {"symbolFillColor": "transparent", "symbolStrokeColor": "grey"}} + if name != "color" + else {} + ) + + return { + "field": field, + "scale": {"domain": domain, "range": anchor_range}, + **legend, + } + + def _update_datapoints(self, varied_keys: Optional[List[str]] = None): + if varied_keys is None: + return + + if len(varied_keys) == 2: + to_concatenate = varied_keys + to_remove = varied_keys + else: + to_concatenate = [] + to_remove = [key for key in FILENAME_FIELD if key not in varied_keys] + + for datapoint in self.datapoints: + if to_concatenate: + concat_key = FIELD_SEPARATOR.join(to_concatenate) + datapoint[concat_key] = FIELD_SEPARATOR.join( + [datapoint.get(k) for k in to_concatenate] + ) + for key in to_remove: + datapoint.pop(key, None) + + def _set_split_content(self, name: str, value: Any): + self._split_content[Template.anchor(name)] = value diff --git a/src/dvc_render/vega_templates.py b/src/dvc_render/vega_templates.py index 3940ac3..0f6dd05 100644 --- a/src/dvc_render/vega_templates.py +++ b/src/dvc_render/vega_templates.py @@ -142,9 +142,10 @@ class BarHorizontalSortedTemplate(Template): "$schema": "https://vega.github.io/schema/vega-lite/v5.json", "data": {"values": Template.anchor("data")}, "title": Template.anchor("title"), - "width": 300, - "height": 300, + "width": Template.anchor("plot_width"), + "height": Template.anchor("plot_height"), "mark": {"type": "bar"}, + "params": [Template.anchor("zoom_and_pan")], "encoding": { "x": { "field": Template.anchor("x"), @@ -158,11 +159,9 @@ class BarHorizontalSortedTemplate(Template): "title": Template.anchor("y_label"), "sort": "-x", }, - "yOffset": {"field": "rev"}, - "color": { - "field": "rev", - "type": "nominal", - }, + "yOffset": {"field": "rev", "sort": []}, + "color": Template.anchor("color"), + "column": Template.anchor("column"), 
}, } @@ -174,9 +173,10 @@ class BarHorizontalTemplate(Template): "$schema": "https://vega.github.io/schema/vega-lite/v5.json", "data": {"values": Template.anchor("data")}, "title": Template.anchor("title"), - "width": 300, - "height": 300, + "width": Template.anchor("plot_width"), + "height": Template.anchor("plot_height"), "mark": {"type": "bar"}, + "params": [Template.anchor("zoom_and_pan")], "encoding": { "x": { "field": Template.anchor("x"), @@ -189,11 +189,9 @@ class BarHorizontalTemplate(Template): "type": "nominal", "title": Template.anchor("y_label"), }, - "yOffset": {"field": "rev"}, - "color": { - "field": "rev", - "type": "nominal", - }, + "yOffset": {"field": "rev", "sort": []}, + "color": Template.anchor("color"), + "column": Template.anchor("column"), }, } @@ -204,7 +202,10 @@ class ConfusionTemplate(Template): "$schema": "https://vega.github.io/schema/vega-lite/v5.json", "data": {"values": Template.anchor("data")}, "title": Template.anchor("title"), - "facet": {"field": "rev", "type": "nominal"}, + "facet": { + "column": {"field": "rev", "sort": []}, + "row": Template.anchor("row"), + }, "params": [ { "name": "showValues", @@ -219,13 +220,13 @@ class ConfusionTemplate(Template): }, { "impute": "xy_count", - "groupby": ["rev", Template.anchor("y")], + "groupby": Template.anchor("group_by_y"), "key": Template.anchor("x"), "value": 0, }, { "impute": "xy_count", - "groupby": ["rev", Template.anchor("x")], + "groupby": Template.anchor("group_by_x"), "key": Template.anchor("y"), "value": 0, }, @@ -257,8 +258,8 @@ class ConfusionTemplate(Template): "layer": [ { "mark": "rect", - "width": 300, - "height": 300, + "width": Template.anchor("plot_width"), + "height": Template.anchor("plot_height"), "encoding": { "color": { "field": "xy_count", @@ -327,7 +328,10 @@ class NormalizedConfusionTemplate(Template): "$schema": "https://vega.github.io/schema/vega-lite/v5.json", "data": {"values": Template.anchor("data")}, "title": Template.anchor("title"), - "facet": 
{"field": "rev", "type": "nominal"}, + "facet": { + "column": {"field": "rev", "sort": []}, + "row": Template.anchor("row"), + }, "spec": { "transform": [ { @@ -342,7 +346,7 @@ class NormalizedConfusionTemplate(Template): }, { "impute": "xy_count", - "groupby": ["rev", Template.anchor("x")], + "groupby": Template.anchor("group_by_x"), "key": Template.anchor("y"), "value": 0, }, @@ -374,8 +378,8 @@ class NormalizedConfusionTemplate(Template): "layer": [ { "mark": "rect", - "width": 300, - "height": 300, + "width": Template.anchor("plot_width"), + "height": Template.anchor("plot_height"), "encoding": { "color": { "field": "percent_of_y", @@ -442,9 +446,10 @@ class ScatterTemplate(Template): "$schema": "https://vega.github.io/schema/vega-lite/v5.json", "data": {"values": Template.anchor("data")}, "title": Template.anchor("title"), - "width": 300, - "height": 300, + "width": Template.anchor("plot_width"), + "height": Template.anchor("plot_height"), "mark": {"type": "point", "tooltip": {"content": "data"}}, + "params": [Template.anchor("zoom_and_pan")], "encoding": { "x": { "field": Template.anchor("x"), @@ -456,10 +461,9 @@ class ScatterTemplate(Template): "type": "quantitative", "title": Template.anchor("y_label"), }, - "color": { - "field": "rev", - "type": "nominal", - }, + "color": Template.anchor("color"), + "shape": Template.anchor("shape"), + "tooltip": Template.anchor("tooltip"), }, } @@ -471,8 +475,8 @@ class ScatterJitterTemplate(Template): "$schema": "https://vega.github.io/schema/vega-lite/v5.json", "data": {"values": Template.anchor("data")}, "title": Template.anchor("title"), - "width": 300, - "height": 300, + "width": Template.anchor("plot_width"), + "height": Template.anchor("plot_height"), "transform": [ {"calculate": "random()", "as": "randomX"}, {"calculate": "random()", "as": "randomY"}, @@ -487,10 +491,9 @@ class ScatterJitterTemplate(Template): "field": Template.anchor("y"), "title": Template.anchor("y_label"), }, - "color": { - "field": "rev", - 
"type": "nominal", - }, + "color": Template.anchor("color"), + "shape": Template.anchor("shape"), + "tooltip": Template.anchor("tooltip"), "xOffset": {"field": "randomX", "type": "quantitative"}, "yOffset": {"field": "randomY", "type": "quantitative"}, }, @@ -503,8 +506,8 @@ class SmoothLinearTemplate(Template): "$schema": "https://vega.github.io/schema/vega-lite/v5.json", "data": {"values": Template.anchor("data")}, "title": Template.anchor("title"), - "width": 300, - "height": 300, + "width": Template.anchor("plot_width"), + "height": Template.anchor("plot_height"), "params": [ { "name": "smooth", @@ -517,15 +520,25 @@ class SmoothLinearTemplate(Template): }, }, ], + "encoding": { + "x": { + "field": Template.anchor("x"), + "type": "quantitative", + "title": Template.anchor("x_label"), + }, + "color": Template.anchor("color"), + "strokeDash": Template.anchor("stroke_dash"), + }, "layer": [ { - "mark": "line", - "encoding": { - "x": { - "field": Template.anchor("x"), - "type": "quantitative", - "title": Template.anchor("x_label"), + "layer": [ + {"params": [Template.anchor("zoom_and_pan")], "mark": "line"}, + { + "transform": [{"filter": {"param": "hover", "empty": False}}], + "mark": "point", }, + ], + "encoding": { "y": { "field": Template.anchor("y"), "type": "quantitative", @@ -536,24 +549,12 @@ class SmoothLinearTemplate(Template): "field": "rev", "type": "nominal", }, - "tooltip": [ - { - "field": Template.anchor("x"), - "title": Template.anchor("x_label"), - "type": "quantitative", - }, - { - "field": Template.anchor("y"), - "title": Template.anchor("y_label"), - "type": "quantitative", - }, - ], }, "transform": [ { "loess": Template.anchor("y"), "on": Template.anchor("x"), - "groupby": ["rev", "filename", "field", "filename::field"], + "groupby": Template.anchor("group_by"), "bandwidth": {"signal": "smooth"}, }, ], @@ -573,26 +574,10 @@ class SmoothLinearTemplate(Template): "scale": {"zero": False}, }, "color": {"field": "rev", "type": "nominal"}, - 
"tooltip": [ - { - "field": Template.anchor("x"), - "title": Template.anchor("x_label"), - "type": "quantitative", - }, - { - "field": Template.anchor("y"), - "title": Template.anchor("y_label"), - "type": "quantitative", - }, - ], }, }, { - "mark": { - "type": "circle", - "size": 10, - "tooltip": {"content": "encoding"}, - }, + "mark": {"type": "circle", "size": 10}, "encoding": { "x": { "aggregate": "max", @@ -610,6 +595,39 @@ class SmoothLinearTemplate(Template): "color": {"field": "rev", "type": "nominal"}, }, }, + { + "transform": [ + {"calculate": Template.anchor("pivot_field"), "as": "pivot_field"}, + { + "pivot": "pivot_field", + "value": Template.anchor("y"), + "groupby": [Template.anchor("x")], + }, + ], + "mark": { + "type": "rule", + "tooltip": {"content": "data"}, + "stroke": "grey", + }, + "encoding": { + "opacity": { + "condition": {"value": 0.3, "param": "hover", "empty": False}, + "value": 0, + } + }, + "params": [ + { + "name": "hover", + "select": { + "type": "point", + "fields": [Template.anchor("x")], + "nearest": True, + "on": "mouseover", + "clear": "mouseout", + }, + } + ], + }, ], } @@ -625,8 +643,9 @@ class SimpleLinearTemplate(Template): "$schema": "https://vega.github.io/schema/vega-lite/v5.json", "data": {"values": Template.anchor("data")}, "title": Template.anchor("title"), - "width": 300, - "height": 300, + "params": [Template.anchor("zoom_and_pan")], + "width": Template.anchor("plot_width"), + "height": Template.anchor("plot_height"), "mark": { "type": "line", "tooltip": {"content": "data"}, @@ -643,10 +662,9 @@ class SimpleLinearTemplate(Template): "title": Template.anchor("y_label"), "scale": {"zero": False}, }, - "color": { - "field": "rev", - "type": "nominal", - }, + "color": Template.anchor("color"), + "strokeDash": Template.anchor("stroke_dash"), + "tooltip": Template.anchor("tooltip"), }, } diff --git a/tests/test_vega.py b/tests/test_vega.py index 0cdd3a9..482c1bb 100644 --- a/tests/test_vega.py +++ b/tests/test_vega.py @@ 
-1,11 +1,12 @@ import json +from typing import Any, Dict, List import pytest -from dvc_render.vega import BadTemplateError, VegaRenderer +from dvc_render.vega import OPTIONAL_ANCHOR_RANGES, BadTemplateError, VegaRenderer from dvc_render.vega_templates import NoFieldInDataError, Template -# pylint: disable=missing-function-docstring, C1803 +# pylint: disable=missing-function-docstring, C1803, C0302 @pytest.mark.parametrize( @@ -42,17 +43,13 @@ def test_default_template_mark(): {"first_val": 200, "second_val": 300, "val": 3}, ] - plot_content = VegaRenderer(datapoints, "foo").get_filled_template(as_string=False) + plot_content = VegaRenderer(datapoints, "foo").get_filled_template() - assert plot_content["layer"][0]["mark"] == "line" + assert plot_content["layer"][0]["layer"][0]["mark"] == "line" assert plot_content["layer"][1]["mark"] == {"type": "line", "opacity": 0.2} - assert plot_content["layer"][2]["mark"] == { - "type": "circle", - "size": 10, - "tooltip": {"content": "encoding"}, - } + assert plot_content["layer"][2]["mark"] == {"type": "circle", "size": 10} def test_choose_axes(): @@ -62,9 +59,7 @@ def test_choose_axes(): {"first_val": 200, "second_val": 300, "val": 3}, ] - plot_content = VegaRenderer(datapoints, "foo", **props).get_filled_template( - as_string=False - ) + plot_content = VegaRenderer(datapoints, "foo", **props).get_filled_template() assert plot_content["data"]["values"] == [ { @@ -78,7 +73,7 @@ def test_choose_axes(): "second_val": 300, }, ] - assert plot_content["layer"][0]["encoding"]["x"]["field"] == "first_val" + assert plot_content["encoding"]["x"]["field"] == "first_val" assert plot_content["layer"][0]["encoding"]["y"]["field"] == "second_val" @@ -89,9 +84,7 @@ def test_confusion(): ] props = {"template": "confusion", "x": "predicted", "y": "actual"} - plot_content = VegaRenderer(datapoints, "foo", **props).get_filled_template( - as_string=False - ) + plot_content = VegaRenderer(datapoints, "foo", **props).get_filled_template() assert 
plot_content["data"]["values"] == [ {"predicted": "B", "actual": "A"}, @@ -221,7 +214,7 @@ def test_escape_special_characters(): ] props = {"template": "simple", "x": "foo.bar[0]", "y": "foo.bar[1]"} renderer = VegaRenderer(datapoints, "foo", **props) - filled = renderer.get_filled_template(as_string=False) + filled = renderer.get_filled_template() # data is not escaped assert filled["data"]["values"][0] == datapoints[0] # field and title yes @@ -263,7 +256,1065 @@ def test_fill_anchor_in_string(tmp_dir): props = {"template": "custom.json", "x": x, "y": y} renderer = VegaRenderer(datapoints, "foo", **props) - filled = renderer.get_filled_template(as_string=False) + filled = renderer.get_filled_template() assert filled["transform"][1]["calculate"] == "pow(datum.lab - datum.SR,2)" assert filled["encoding"]["x"]["field"] == x assert filled["encoding"]["y"]["field"] == y + + +@pytest.mark.parametrize( + ",".join( + [ + "anchors_y_definitions", + "datapoints", + "y", + "expected_dp_keys", + "stroke_dash_encoding", + "pivot_field", + "group_by", + ] + ), + ( + pytest.param( + [{"filename": "test", "field": "acc"}], + [ + { + "rev": "B", + "acc": "0.05", + "filename": "test", + "field": "acc", + "step": 1, + }, + { + "rev": "B", + "acc": "0.1", + "filename": "test", + "field": "acc", + "step": 2, + }, + ], + "acc", + ["rev", "acc", "step"], + {}, + "datum.rev", + ["rev"], + id="single_source", + ), + pytest.param( + [ + {"filename": "test", "field": "acc"}, + {"filename": "train", "field": "acc"}, + ], + [ + { + "rev": "B", + "acc": "0.05", + "filename": "test", + "field": "acc", + "step": 1, + }, + { + "rev": "B", + "acc": "0.1", + "filename": "test", + "field": "acc", + "step": 2, + }, + { + "rev": "B", + "acc": "0.04", + "filename": "train", + "field": "acc", + "step": 1, + }, + { + "rev": "B", + "acc": "0.09", + "filename": "train", + "field": "acc", + "step": 2, + }, + ], + "acc", + ["rev", "acc", "step", "filename"], + { + "field": "filename", + "scale": { + 
"domain": ["test", "train"], + "range": OPTIONAL_ANCHOR_RANGES["stroke_dash"][0:2], + }, + "legend": { + "symbolFillColor": "transparent", + "symbolStrokeColor": "grey", + }, + }, + "datum.rev + '::' + datum.filename", + ["rev", "filename"], + id="multi_filename", + ), + pytest.param( + [ + {"filename": "test", "field": "acc"}, + {"filename": "test", "field": "acc_norm"}, + ], + [ + { + "rev": "B", + "dvc_inferred_y_value": "0.05", + "acc": "0.05", + "acc_norm": "0.04", + "filename": "test", + "field": "acc", + "step": 1, + }, + { + "rev": "B", + "dvc_inferred_y_value": "0.1", + "acc": "0.1", + "acc_norm": "0.09", + "filename": "test", + "field": "acc", + "step": 2, + }, + { + "rev": "B", + "dvc_inferred_y_value": "0.04", + "acc": "0.05", + "acc_norm": "0.04", + "filename": "test", + "field": "acc_norm", + "step": 1, + }, + { + "rev": "B", + "dvc_inferred_y_value": "0.09", + "acc": "0.1", + "acc_norm": "0.09", + "filename": "test", + "field": "acc_norm", + "step": 2, + }, + ], + "dvc_inferred_y_value", + ["rev", "dvc_inferred_y_value", "acc", "acc_norm", "step", "field"], + { + "field": "field", + "scale": { + "domain": ["acc", "acc_norm"], + "range": OPTIONAL_ANCHOR_RANGES["stroke_dash"][0:2], + }, + "legend": { + "symbolFillColor": "transparent", + "symbolStrokeColor": "grey", + }, + }, + "datum.rev + '::' + datum.field", + ["rev", "field"], + id="multi_field", + ), + pytest.param( + [ + {"filename": "test", "field": "acc_norm"}, + {"filename": "test", "field": "acc"}, + {"filename": "train", "field": "acc"}, + ], + [ + { + "rev": "B", + "dvc_inferred_y_value": "0.05", + "acc": "0.05", + "acc_norm": "0.02", + "filename": "test", + "field": "acc", + "step": 1, + }, + { + "rev": "B", + "dvc_inferred_y_value": "0.1", + "acc": "0.01", + "acc_norm": "0.07", + "filename": "test", + "field": "acc", + "step": 2, + }, + { + "rev": "B", + "dvc_inferred_y_value": "0.04", + "acc": "0.04", + "filename": "train", + "field": "acc", + "step": 1, + }, + { + "rev": "B", + 
"dvc_inferred_y_value": "0.09", + "acc": "0.09", + "filename": "train", + "field": "acc", + "step": 2, + }, + { + "rev": "B", + "dvc_inferred_y_value": "0.02", + "acc": "0.05", + "acc_norm": "0.02", + "filename": "test", + "field": "acc_norm", + "step": 1, + }, + { + "rev": "B", + "dvc_inferred_y_value": "0.07", + "acc": "0.01", + "acc_norm": "0.07", + "filename": "test", + "field": "acc_norm", + "step": 2, + }, + ], + "dvc_inferred_y_value", + [ + "rev", + "dvc_inferred_y_value", + "acc", + "acc_norm", + "step", + "filename::field", + ], + { + "field": "filename::field", + "scale": { + "domain": ["test::acc", "test::acc_norm", "train::acc"], + "range": OPTIONAL_ANCHOR_RANGES["stroke_dash"][0:3], + }, + "legend": { + "symbolFillColor": "transparent", + "symbolStrokeColor": "grey", + }, + }, + "datum.rev + '::' + datum.filename + '::' + datum.field", + ["rev", "filename::field"], + id="multi_filename_field", + ), + ), +) +def test_optional_anchors_linear( + anchors_y_definitions, + datapoints, + y, + expected_dp_keys, + stroke_dash_encoding, + pivot_field, + group_by, +): # pylint: disable=too-many-arguments + props = { + "anchors_y_definitions": anchors_y_definitions, + "revs_with_datapoints": ["B"], + "template": "linear", + "x": "step", + "y": y, + } + + expected_datapoints = _get_expected_datapoints(datapoints, expected_dp_keys) + + renderer = VegaRenderer(datapoints, "foo", **props) + plot_content = renderer.get_filled_template() + + assert plot_content["data"]["values"] == expected_datapoints + assert plot_content["encoding"]["color"] == { + "field": "rev", + "scale": {"domain": ["B"], "range": OPTIONAL_ANCHOR_RANGES["color"][0:1]}, + } + assert plot_content["encoding"]["strokeDash"] == stroke_dash_encoding + assert plot_content["layer"][3]["transform"][0]["calculate"] == pivot_field + assert plot_content["layer"][0]["transform"][0]["groupby"] == group_by + + +@pytest.mark.parametrize( + ",".join( + [ + "anchors_y_definitions", + "datapoints", + "y", + 
"expected_dp_keys", + "row_encoding", + "group_by_y", + "group_by_x", + ] + ), + ( + pytest.param( + [{"filename": "test", "field": "predicted"}], + [ + { + "rev": "B", + "predicted": "0.05", + "actual": "0.5", + "filename": "test", + "field": "predicted", + }, + { + "rev": "B", + "predicted": "0.9", + "actual": "0.9", + "filename": "test", + "field": "predicted", + }, + ], + "predicted", + ["rev", "predicted", "actual"], + {}, + ["rev", "predicted"], + ["rev", "actual"], + id="single_source", + ), + pytest.param( + [ + {"filename": "test", "field": "predicted"}, + {"filename": "train", "field": "predicted"}, + ], + [ + { + "rev": "B", + "predicted": "0.05", + "actual": "0.5", + "filename": "test", + "field": "predicted", + }, + { + "rev": "B", + "predicted": "0.9", + "actual": "0.9", + "filename": "train", + "field": "predicted", + }, + ], + "predicted", + ["rev", "predicted", "actual"], + {"field": "filename", "sort": []}, + ["rev", "filename", "predicted"], + ["rev", "filename", "actual"], + id="multi_filename", + ), + pytest.param( + [ + {"filename": "data", "field": "predicted_test"}, + {"filename": "data", "field": "predicted_train"}, + ], + [ + { + "rev": "B", + "dvc_inferred_y_value": "0.05", + "predicted_train": "0.05", + "predicted_test": "0.9", + "actual": "0.5", + "filename": "data", + "field": "predicted_test", + }, + { + "rev": "B", + "dvc_inferred_y_value": "0.9", + "predicted_train": "0.05", + "predicted_test": "0.9", + "actual": "0.5", + "filename": "data", + "field": "predicted_train", + }, + ], + "dvc_inferred_y_value", + ["rev", "dvc_inferred_y_value", "actual"], + {"field": "field", "sort": []}, + ["rev", "field", "dvc_inferred_y_value"], + ["rev", "field", "actual"], + id="multi_field", + ), + pytest.param( + [ + {"filename": "test", "field": "predicted_test"}, + {"filename": "train", "field": "predicted_train"}, + ], + [ + { + "rev": "B", + "dvc_inferred_y_value": "0.05", + "predicted_test": "0.05", + "actual": "0.5", + "filename": "test", + 
"field": "predicted_test", + }, + { + "rev": "B", + "dvc_inferred_y_value": "0.9", + "predicted_test": "0.9", + "actual": "0.9", + "filename": "test", + "field": "predicted_test", + }, + { + "rev": "B", + "dvc_inferred_y_value": "0.9", + "predicted_train": "0.9", + "actual": "0.9", + "filename": "train", + "field": "predicted_train", + }, + { + "rev": "B", + "dvc_inferred_y_value": "0.9", + "predicted_train": "0.9", + "actual": "0.9", + "filename": "train", + "field": "predicted_train", + }, + ], + "dvc_inferred_y_value", + ["rev", "predicted", "actual"], + {"field": "filename::field", "sort": []}, + ["rev", "filename::field", "dvc_inferred_y_value"], + ["rev", "filename::field", "actual"], + id="multi_filename_field", + ), + ), +) +def test_optional_anchors_confusion( + anchors_y_definitions, + datapoints, + y, + expected_dp_keys, + row_encoding, + group_by_y, + group_by_x, +): # pylint: disable=too-many-arguments + props = { + "anchors_y_definitions": anchors_y_definitions, + "revs_with_datapoints": ["B"], + "template": "confusion", + "x": "actual", + "y": y, + } + + expected_datapoints = _get_expected_datapoints(datapoints, expected_dp_keys) + + renderer = VegaRenderer(datapoints, "foo", **props) + plot_content = renderer.get_filled_template() + + assert plot_content["data"]["values"] == expected_datapoints + assert plot_content["facet"]["row"] == row_encoding + assert plot_content["spec"]["transform"][0]["groupby"] == [y, "actual"] + assert plot_content["spec"]["transform"][1]["groupby"] == group_by_y + assert plot_content["spec"]["transform"][2]["groupby"] == group_by_x + assert plot_content["spec"]["layer"][0]["width"] == 300 + assert plot_content["spec"]["layer"][0]["height"] == 300 + + +@pytest.mark.parametrize( + ",".join( + [ + "anchors_y_definitions", + "datapoints", + "y", + "expected_dp_keys", + "shape_encoding", + "tooltip_encoding", + ] + ), + ( + pytest.param( + [{"filename": "test", "field": "acc"}], + [ + { + "rev": "B", + "acc": "0.05", + 
"filename": "test", + "field": "acc", + "loss": 0.1, + }, + ], + "acc", + ["rev", "acc", "loss"], + {}, + [{"field": "rev"}, {"field": "loss"}, {"field": "acc"}], + id="single_source", + ), + pytest.param( + [ + {"filename": "train", "field": "acc"}, + {"filename": "test", "field": "acc"}, + ], + [ + { + "rev": "B", + "acc": "0.05", + "filename": "train", + "field": "acc", + "loss": "0.0001", + }, + { + "rev": "B", + "acc": "0.06", + "filename": "test", + "field": "acc", + "loss": "200121", + }, + ], + "acc", + ["rev", "acc", "filename", "loss"], + { + "field": "filename", + "legend": { + "symbolFillColor": "transparent", + "symbolStrokeColor": "grey", + }, + "scale": { + "domain": ["test", "train"], + "range": OPTIONAL_ANCHOR_RANGES["shape"][0:2], + }, + }, + [ + {"field": "rev"}, + {"field": "loss"}, + {"field": "acc"}, + {"field": "filename"}, + ], + id="multi_filename", + ), + pytest.param( + [ + {"filename": "data", "field": "train_acc"}, + {"filename": "data", "field": "test_acc"}, + ], + [ + { + "rev": "B", + "dvc_inferred_y_value": "0.05", + "test_acc": "0.05", + "train_acc": "0.06", + "filename": "data", + "field": "test_acc", + "loss": 0.1, + }, + { + "rev": "B", + "dvc_inferred_y_value": "0.06", + "test_acc": "0.05", + "train_acc": "0.06", + "filename": "data", + "field": "train_acc", + "loss": 0.1, + }, + ], + "dvc_inferred_y_value", + ["rev", "dvc_inferred_y_value", "train_acc", "test_acc", "loss"], + { + "field": "field", + "legend": { + "symbolFillColor": "transparent", + "symbolStrokeColor": "grey", + }, + "scale": { + "domain": ["test_acc", "train_acc"], + "range": OPTIONAL_ANCHOR_RANGES["shape"][0:2], + }, + }, + [ + {"field": "rev"}, + {"field": "loss"}, + {"field": "dvc_inferred_y_value"}, + {"field": "field"}, + ], + id="multi_field", + ), + pytest.param( + [ + {"filename": "train", "field": "train_acc"}, + {"filename": "test", "field": "test_acc"}, + ], + [ + { + "rev": "B", + "dvc_inferred_y_value": "0.05", + "test_acc": "0.05", + "filename": 
def test_optional_anchors_scatter(
    anchors_y_definitions,
    datapoints,
    y,
    expected_dp_keys,
    shape_encoding,
    tooltip_encoding,
):  # pylint: disable=too-many-arguments
    """Fill the ``scatter`` template and check its optional anchors.

    Verifies the data values, the colour encoding for the single revision
    ``"B"``, the parametrized shape and tooltip encodings, and the
    zoom-and-pan ``params`` entry of the filled spec.
    """
    # Derive the expectation before the renderer is handed the datapoints.
    wanted_values = _get_expected_datapoints(datapoints, expected_dp_keys)

    spec = VegaRenderer(
        datapoints,
        "foo",
        anchors_y_definitions=anchors_y_definitions,
        revs_with_datapoints=["B"],
        template="scatter",
        x="loss",
        y=y,
    ).get_filled_template()

    assert spec["data"]["values"] == wanted_values

    # Only one revision is rendered, so only the first colour is expected.
    wanted_color = {
        "field": "rev",
        "scale": {"domain": ["B"], "range": OPTIONAL_ANCHOR_RANGES["color"][:1]},
    }
    encoding = spec["encoding"]
    assert encoding["color"] == wanted_color
    assert encoding["shape"] == shape_encoding
    assert encoding["tooltip"] == tooltip_encoding

    wanted_params = [{"name": "grid", "select": "interval", "bind": "scales"}]
    assert spec["params"] == wanted_params
def test_color_anchor(revs, datapoints):
    """Check the colour encoding of the ``linear`` template per revision.

    The scale domain must equal ``revs``. The expected range is the leading
    slice of the colour palette, extended with entries taken again from the
    start of the palette when there are more revisions than colours.
    """
    spec = VegaRenderer(
        datapoints,
        "foo",
        anchors_y_definitions=[{"filename": "acc", "field": "acc"}],
        revs_with_datapoints=revs,
        template="linear",
        x="step",
        y="acc",
    ).get_filled_template()

    palette = OPTIONAL_ANCHOR_RANGES["color"]
    wanted_range = palette[: len(revs)]
    overflow = len(revs) - len(palette)
    if overflow > 0:
        # More revisions than colours: reuse colours from the palette start.
        wanted_range = wanted_range + palette[:overflow]

    wanted_color = {
        "field": "rev",
        "scale": {
            "domain": revs,
            "range": wanted_range,
        },
    }
    assert spec["encoding"]["color"] == wanted_color
"acc"}, + ], + [ + { + "rev": "B", + "dvc_inferred_y_value": "0.05", + "acc": "0.05", + "acc_norm": "0.02", + "filename": "test", + "field": "acc", + "step": 1, + }, + { + "rev": "B", + "dvc_inferred_y_value": "0.04", + "acc": "0.04", + "filename": "train", + "field": "acc", + "step": 1, + }, + { + "rev": "B", + "dvc_inferred_y_value": "0.02", + "filename": "test", + "acc": "0.05", + "acc_norm": "0.02", + "field": "acc_norm", + "step": 1, + }, + ], + "dvc_inferred_y_value", + [ + "rev", + "dvc_inferred_y_value", + "acc", + "acc_norm", + "step", + "filename::field", + ], + { + "field": "filename::field", + "scale": { + "domain": ["test::acc", "test::acc_norm", "train::acc"], + "range": OPTIONAL_ANCHOR_RANGES["stroke_dash"][0:3], + }, + "legend": { + "symbolFillColor": "transparent", + "symbolStrokeColor": "grey", + }, + }, + id="multi_filename_field", + ), + ), +) +def test_partial_filled_template( + anchors_y_definitions, + datapoints, + y, + expected_dp_keys, + stroke_dash_encoding, +): + title = f"{y} by step" + props = { + "anchors_y_definitions": anchors_y_definitions, + "revs_with_datapoints": ["B"], + "template": "linear", + "title": title, + "x": "step", + "y": y, + } + + expected_split = { + Template.anchor("color"): { + "field": "rev", + "scale": {"domain": ["B"], "range": OPTIONAL_ANCHOR_RANGES["color"][0:1]}, + }, + Template.anchor("data"): _get_expected_datapoints(datapoints, expected_dp_keys), + Template.anchor("plot_height"): 300, + Template.anchor("plot_width"): 300, + Template.anchor("title"): title, + Template.anchor("x_label"): "step", + Template.anchor("y_label"): y, + Template.anchor("zoom_and_pan"): { + "name": "grid", + "select": "interval", + "bind": "scales", + }, + } + + split_anchors = [ + Template.anchor("color"), + Template.anchor("data"), + ] + if len(anchors_y_definitions) > 1: + split_anchors.append(Template.anchor("stroke_dash")) + expected_split[Template.anchor("stroke_dash")] = stroke_dash_encoding + + content, split = 
def _get_expected_datapoints(
    datapoints: List[Dict[str, Any]], expected_dp_keys: List[str]
) -> List[Dict[str, Any]]:
    """Build the datapoints a test expects to find in the rendered output.

    Args:
        datapoints: Raw datapoints as handed to the renderer. They are only
            read here, never modified.
        expected_dp_keys: Keys each expected datapoint should keep. The
            special key ``"filename::field"`` is synthesised by joining the
            datapoint's ``filename`` and ``field`` values with ``"::"``.

    Returns:
        A new list with one new dict per input datapoint, holding only the
        requested keys. Keys that are absent, or whose value is ``None``,
        are left out.

    Raises:
        KeyError: If ``"filename::field"`` is requested for a datapoint that
            lacks ``filename`` or ``field``.
    """
    expected_datapoints: List[Dict[str, Any]] = []
    for datapoint in datapoints:
        expected_datapoint: Dict[str, Any] = {}
        for key in expected_dp_keys:
            if key == "filename::field":
                expected_datapoint[
                    key
                ] = f"{datapoint['filename']}::{datapoint['field']}"
                continue
            value = datapoint.get(key)
            if value is not None:
                expected_datapoint[key] = value
        expected_datapoints.append(expected_datapoint)

    # Return the filtered copies, not the input list. Returning `datapoints`
    # discarded everything built above, so callers compared the renderer's
    # output with the very objects they had passed to it.
    # NOTE(review): callers whose `expected_dp_keys` were never really
    # exercised may need their key lists corrected.
    return expected_datapoints