From ba87a2ee1b68c338f48a4bca892e2fb541b5d5f8 Mon Sep 17 00:00:00 2001 From: David de la Iglesia Castro Date: Thu, 16 Mar 2023 14:47:53 +0100 Subject: [PATCH] vega_templates: Handle `content` as dict instead of string. (#124) Prevent unnecessary `dumps`/`loads` calls. Closes #23 --- src/dvc_render/vega.py | 26 +++---- src/dvc_render/vega_templates.py | 114 ++++++++++++++++++++----------- tests/test_templates.py | 15 ++-- tests/test_vega.py | 17 ++--- 4 files changed, 100 insertions(+), 72 deletions(-) diff --git a/src/dvc_render/vega.py b/src/dvc_render/vega.py index d88b36a..7d1f4cd 100644 --- a/src/dvc_render/vega.py +++ b/src/dvc_render/vega.py @@ -1,16 +1,11 @@ -from copy import deepcopy +import json from pathlib import Path -from typing import List, Optional +from typing import Any, Dict, List, Optional from warnings import warn from .base import Renderer -from .exceptions import DvcRenderException from .utils import list_dict_to_dict_list -from .vega_templates import LinearTemplate, get_template - - -class BadTemplateError(DvcRenderException): - pass +from .vega_templates import BadTemplateError, LinearTemplate, get_template class VegaRenderer(Renderer): @@ -44,16 +39,15 @@ def __init__(self, datapoints: List, name: str, **properties): def get_filled_template( self, skip_anchors: Optional[List[str]] = None, strict: bool = True - ) -> str: + ) -> Dict[str, Any]: """Returns a functional vega specification""" + self.template.reset() if not self.datapoints: - return "" + return {} if skip_anchors is None: skip_anchors = [] - content = deepcopy(self.template.content) - if strict: if self.properties.get("x"): self.template.check_field_exists( @@ -76,7 +70,7 @@ def get_filled_template( if value is None: continue if name == "data": - if self.template.anchor_str(name) not in self.template.content: + if not self.template.has_anchor(name): anchor = self.template.anchor(name) raise BadTemplateError( f"Template '{self.template.name}' " @@ -84,12 +78,12 @@ def get_filled_template( ) elif name in {"x", "y"}: value = self.template.escape_special_characters(value) - content = self.template.fill_anchor(content, name, value) + self.template.fill_anchor(name, value) - return content + return self.template.content def partial_html(self, **kwargs) -> str: - return self.get_filled_template() + return json.dumps(self.get_filled_template()) def generate_markdown(self, report_path=None) -> str: if not isinstance(self.template, LinearTemplate): diff --git a/src/dvc_render/vega_templates.py b/src/dvc_render/vega_templates.py index 765faab..c77f178 100644 --- a/src/dvc_render/vega_templates.py +++ b/src/dvc_render/vega_templates.py @@ -1,3 +1,4 @@ +# pylint: disable=missing-function-docstring import json from pathlib import Path from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union @@ -27,30 +28,69 @@ def __init__(self, template_name: str, path: str): ) +class BadTemplateError(DvcRenderException): + pass + + +def dict_replace_value(d: dict, name: str, value: Any) -> dict: + x = {} + for k, v in d.items(): + if isinstance(v, dict): + v = dict_replace_value(v, name, value) + elif isinstance(v, list): + v = list_replace_value(v, name, value) + elif isinstance(v, str): + if v == name: + x[k] = value + continue + x[k] = v + return x + + +def list_replace_value(l: list, name: str, value: str) -> list: # noqa: E741 + x = [] + for e in l: + if isinstance(e, list): + e = list_replace_value(e, name, value) + elif isinstance(e, dict): + e = dict_replace_value(e, name, value) + elif isinstance(e, str): + if e == name: + e = value + x.append(e) + return x + + +def dict_find_value(d: dict, value: str) -> bool: + for v in d.values(): + if isinstance(v, dict): + return dict_find_value(v, value) + if isinstance(v, str): + if v == value: + return True + return False + + class Template: - INDENT = 4 - SEPARATORS = (",", ": ") EXTENSION = ".json" ANCHOR = "" - DEFAULT_CONTENT: Optional[Dict[str, Any]] = None - DEFAULT_NAME: Optional[str] = None - - def __init__(self, content=None, name=None): - if content: - self.content = content - else: - self.content = ( - json.dumps( - self.DEFAULT_CONTENT, - indent=self.INDENT, - separators=self.SEPARATORS, - ) - + "\n" - ) - + DEFAULT_CONTENT: Dict[str, Any] = {} + DEFAULT_NAME: str = "" + + def __init__( + self, content: Optional[Dict[str, Any]] = None, name: Optional[str] = None + ): + if ( + content + and not isinstance(content, dict) + or self.DEFAULT_CONTENT + and not isinstance(self.DEFAULT_CONTENT, dict) + ): + raise BadTemplateError() + self._original_content = content or self.DEFAULT_CONTENT + self.content: Dict[str, Any] = self._original_content self.name = name or self.DEFAULT_NAME - assert self.content and self.name self.filename = Path(self.name).with_suffix(self.EXTENSION) @classmethod @@ -58,18 +98,6 @@ def anchor(cls, name): "Get ANCHOR formatted with name." return cls.ANCHOR.format(name.upper()) - def has_anchor(self, name) -> bool: - "Check if ANCHOR formatted with name is in content." - return self.anchor_str(name) in self.content - - @classmethod - def fill_anchor(cls, content, name, value) -> str: - "Replace anchor `name` with `value` in content." - value_str = json.dumps( - value, indent=cls.INDENT, separators=cls.SEPARATORS, sort_keys=True - ) - return content.replace(cls.anchor_str(name), value_str) - @classmethod def escape_special_characters(cls, value: str) -> str: "Escape special characters in `value`" @@ -77,17 +105,25 @@ def escape_special_characters(cls, value: str) -> str: value = value.replace(character, "\\" + character) return value - @classmethod - def anchor_str(cls, name) -> str: - "Get string wrapping ANCHOR formatted with name." - return f'"{cls.anchor(name)}"' - @staticmethod def check_field_exists(data, field): "Raise NoFieldInDataError if `field` not in `data`." if not any(field in row for row in data): raise NoFieldInDataError(field) + def reset(self): + """Reset self.content to its original state.""" + self.content = self._original_content + + def has_anchor(self, name) -> bool: + "Check if ANCHOR formatted with name is in content." + found = dict_find_value(self.content, self.anchor(name)) + return found + + def fill_anchor(self, name, value) -> None: + "Replace anchor `name` with `value` in content." + self.content = dict_replace_value(self.content, self.anchor(name), value) + class BarHorizontalSortedTemplate(Template): DEFAULT_NAME = "bar_horizontal_sorted" @@ -606,7 +642,7 @@ def get_template( _open = open if fs is None else fs.open if template_path: with _open(template_path, encoding="utf-8") as f: - content = f.read() + content = json.load(f) return Template(content, name=template) for template_cls in TEMPLATES: @@ -635,6 +671,6 @@ def dump_templates(output: "StrPath", targets: Optional[List] = None) -> None: if path.exists(): content = path.read_text(encoding="utf-8") if content != template.content: - raise TemplateContentDoesNotMatch(template.DEFAULT_NAME or "", path) + raise TemplateContentDoesNotMatch(template.DEFAULT_NAME, str(path)) else: - path.write_text(template.content, encoding="utf-8") + path.write_text(json.dumps(template.content), encoding="utf-8") diff --git a/tests/test_templates.py b/tests/test_templates.py index f99bcaf..af0b02a 100644 --- a/tests/test_templates.py +++ b/tests/test_templates.py @@ -1,3 +1,4 @@ +import json import os import pytest @@ -38,8 +39,9 @@ def test_raise_on_no_template(): ], ) def test_get_template_from_dir(tmp_dir, template_path, target_name): - tmp_dir.gen(template_path, "template_content") - assert get_template(target_name, ".dvc/plots").content == "template_content" + template_content = {"template_content": "foo"} + tmp_dir.gen(template_path, json.dumps(template_content)) + assert get_template(target_name, ".dvc/plots").content == template_content def test_get_template_exact_match(tmp_dir): @@ -51,13 +53,16 @@ def test_get_template_exact_match(tmp_dir): def test_get_template_from_file(tmp_dir): - tmp_dir.gen("foo/bar.json", "template_content") - assert get_template("foo/bar.json").content == "template_content" + template_content = {"template_content": "foo"} + tmp_dir.gen("foo/bar.json", json.dumps(template_content)) + assert get_template("foo/bar.json").content == template_content def test_get_template_fs(tmp_dir, mocker): - tmp_dir.gen("foo/bar.json", "template_content") + template_content = {"template_content": "foo"} + tmp_dir.gen("foo/bar.json", json.dumps(template_content)) fs = mocker.MagicMock() + mocker.patch("json.load", return_value={}) get_template("foo/bar.json", fs=fs) fs.open.assert_called() fs.exists.assert_called() diff --git a/tests/test_vega.py b/tests/test_vega.py index 39b6d4d..dbc823b 100644 --- a/tests/test_vega.py +++ b/tests/test_vega.py @@ -1,5 +1,3 @@ -import json - import pytest from dvc_render.vega import BadTemplateError, VegaRenderer @@ -33,7 +31,6 @@ def test_init_empty(): assert renderer.name == "" assert renderer.properties == {} - assert renderer.generate_html() == "" assert renderer.generate_markdown("foo") == "" @@ -43,7 +40,7 @@ def test_default_template_mark(): {"first_val": 200, "second_val": 300, "val": 3}, ] - plot_content = json.loads(VegaRenderer(datapoints, "foo").partial_html()) + plot_content = VegaRenderer(datapoints, "foo").get_filled_template() assert plot_content["layer"][0]["mark"] == "line" @@ -60,7 +57,7 @@ def test_choose_axes(): {"first_val": 200, "second_val": 300, "val": 3}, ] - plot_content = json.loads(VegaRenderer(datapoints, "foo", **props).partial_html()) + plot_content = VegaRenderer(datapoints, "foo", **props).get_filled_template() assert plot_content["data"]["values"] == [ { @@ -85,7 +82,7 @@ def test_confusion(): ] props = {"template": "confusion", "x": "predicted", "y": "actual"} - plot_content = json.loads(VegaRenderer(datapoints, "foo", **props).partial_html()) + plot_content = VegaRenderer(datapoints, "foo", **props).get_filled_template() assert plot_content["data"]["values"] == [ {"predicted": "B", "actual": "A"}, @@ -100,12 +97,8 @@ def test_confusion(): def test_bad_template(): - datapoints = [{"val": 2}, {"val": 3}] - props = {"template": Template("name", "content")} - renderer = VegaRenderer(datapoints, "foo", **props) with pytest.raises(BadTemplateError): - renderer.get_filled_template() - renderer.get_filled_template(skip_anchors=["data"]) + Template("name", "content") def test_raise_on_wrong_field(): @@ -177,7 +170,7 @@ def test_escape_special_characters(): ] props = {"template": "simple", "x": "foo.bar[0]", "y": "foo.bar[1]"} renderer = VegaRenderer(datapoints, "foo", **props) - filled = json.loads(renderer.get_filled_template()) + filled = renderer.get_filled_template() # data is not escaped assert filled["data"]["values"][0] == datapoints[0] # field and title yes