Skip to content

Commit

Permalink
vega_templates: Handle content as dict instead of string. (#124)
Browse files: browse the repository at this point in the history.
Prevent unnecessary `dumps`/`loads` calls.

Closes #23
  • Loading branch information
daavoo authored Mar 16, 2023
1 parent 14f91b2 commit ba87a2e
Show file tree
Hide file tree
Showing 4 changed files with 100 additions and 72 deletions.
26 changes: 10 additions & 16 deletions src/dvc_render/vega.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,11 @@
from copy import deepcopy
import json
from pathlib import Path
from typing import List, Optional
from typing import Any, Dict, List, Optional
from warnings import warn

from .base import Renderer
from .exceptions import DvcRenderException
from .utils import list_dict_to_dict_list
from .vega_templates import LinearTemplate, get_template


class BadTemplateError(DvcRenderException):
pass
from .vega_templates import BadTemplateError, LinearTemplate, get_template


class VegaRenderer(Renderer):
Expand Down Expand Up @@ -44,16 +39,15 @@ def __init__(self, datapoints: List, name: str, **properties):

def get_filled_template(
self, skip_anchors: Optional[List[str]] = None, strict: bool = True
) -> str:
) -> Dict[str, Any]:
"""Returns a functional vega specification"""
self.template.reset()
if not self.datapoints:
return ""
return {}

if skip_anchors is None:
skip_anchors = []

content = deepcopy(self.template.content)

if strict:
if self.properties.get("x"):
self.template.check_field_exists(
Expand All @@ -76,20 +70,20 @@ def get_filled_template(
if value is None:
continue
if name == "data":
if self.template.anchor_str(name) not in self.template.content:
if not self.template.has_anchor(name):
anchor = self.template.anchor(name)
raise BadTemplateError(
f"Template '{self.template.name}' "
f"is not using '{anchor}' anchor"
)
elif name in {"x", "y"}:
value = self.template.escape_special_characters(value)
content = self.template.fill_anchor(content, name, value)
self.template.fill_anchor(name, value)

return content
return self.template.content

def partial_html(self, **kwargs) -> str:
return self.get_filled_template()
return json.dumps(self.get_filled_template())

def generate_markdown(self, report_path=None) -> str:
if not isinstance(self.template, LinearTemplate):
Expand Down
114 changes: 75 additions & 39 deletions src/dvc_render/vega_templates.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
# pylint: disable=missing-function-docstring
import json
from pathlib import Path
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
Expand Down Expand Up @@ -27,67 +28,102 @@ def __init__(self, template_name: str, path: str):
)


class BadTemplateError(DvcRenderException):
pass


def dict_replace_value(d: dict, name: str, value: Any) -> dict:
    """Return a copy of `d` where every string value equal to `name` is `value`.

    Nested dicts and lists are rebuilt recursively; the input is not mutated.
    Only values are inspected (never keys), and only exact string matches are
    replaced -- a string that merely contains `name` is left as-is.
    """
    replaced = {}
    for key, item in d.items():
        if isinstance(item, dict):
            replaced[key] = dict_replace_value(item, name, value)
        elif isinstance(item, list):
            replaced[key] = list_replace_value(item, name, value)
        elif isinstance(item, str) and item == name:
            replaced[key] = value
        else:
            # Non-matching strings and all other leaf types pass through.
            replaced[key] = item
    return replaced


def list_replace_value(l: list, name: str, value: str) -> list:  # noqa: E741
    """Return a copy of `l` where every string element equal to `name` is `value`.

    List counterpart of `dict_replace_value`: nested lists and dicts are
    rebuilt recursively and the input list is not mutated.
    """
    replaced = []
    for item in l:
        if isinstance(item, list):
            replaced.append(list_replace_value(item, name, value))
        elif isinstance(item, dict):
            replaced.append(dict_replace_value(item, name, value))
        elif isinstance(item, str) and item == name:
            replaced.append(value)
        else:
            replaced.append(item)
    return replaced


def dict_find_value(d: dict, value: str) -> bool:
    """Return True if `value` appears as a string value anywhere inside `d`.

    The search covers nested dicts and lists, the same containers that
    `dict_replace_value` descends into, so a value that can be replaced can
    also be found. Only values are inspected (never keys), and only exact
    string matches count.

    Args:
        d: the (possibly nested) mapping to search.
        value: the exact string to look for.

    Returns:
        True when at least one matching string value exists, False otherwise.
    """
    # Iterative walk: a miss in one nested container must not end the search,
    # the remaining siblings still have to be checked.
    pending = [d]
    while pending:
        node = pending.pop()
        if isinstance(node, dict):
            pending.extend(node.values())
        elif isinstance(node, list):
            pending.extend(node)
        elif isinstance(node, str) and node == value:
            return True
    return False


class Template:
INDENT = 4
SEPARATORS = (",", ": ")
EXTENSION = ".json"
ANCHOR = "<DVC_METRIC_{}>"

DEFAULT_CONTENT: Optional[Dict[str, Any]] = None
DEFAULT_NAME: Optional[str] = None

def __init__(self, content=None, name=None):
if content:
self.content = content
else:
self.content = (
json.dumps(
self.DEFAULT_CONTENT,
indent=self.INDENT,
separators=self.SEPARATORS,
)
+ "\n"
)

DEFAULT_CONTENT: Dict[str, Any] = {}
DEFAULT_NAME: str = ""

def __init__(
self, content: Optional[Dict[str, Any]] = None, name: Optional[str] = None
):
if (
content
and not isinstance(content, dict)
or self.DEFAULT_CONTENT
and not isinstance(self.DEFAULT_CONTENT, dict)
):
raise BadTemplateError()
self._original_content = content or self.DEFAULT_CONTENT
self.content: Dict[str, Any] = self._original_content
self.name = name or self.DEFAULT_NAME
assert self.content and self.name
self.filename = Path(self.name).with_suffix(self.EXTENSION)

@classmethod
def anchor(cls, name):
"Get ANCHOR formatted with name."
return cls.ANCHOR.format(name.upper())

def has_anchor(self, name) -> bool:
"Check if ANCHOR formatted with name is in content."
return self.anchor_str(name) in self.content

@classmethod
def fill_anchor(cls, content, name, value) -> str:
"Replace anchor `name` with `value` in content."
value_str = json.dumps(
value, indent=cls.INDENT, separators=cls.SEPARATORS, sort_keys=True
)
return content.replace(cls.anchor_str(name), value_str)

@classmethod
def escape_special_characters(cls, value: str) -> str:
"Escape special characters in `value`"
for character in (".", "[", "]"):
value = value.replace(character, "\\" + character)
return value

@classmethod
def anchor_str(cls, name) -> str:
"Get string wrapping ANCHOR formatted with name."
return f'"{cls.anchor(name)}"'

@staticmethod
def check_field_exists(data, field):
"Raise NoFieldInDataError if `field` not in `data`."
if not any(field in row for row in data):
raise NoFieldInDataError(field)

def reset(self):
"""Reset self.content to its original state."""
self.content = self._original_content

def has_anchor(self, name) -> bool:
    """Report whether the anchor for `name` is present in the content.

    Delegates the lookup to `dict_find_value` on `self.content`.
    """
    return dict_find_value(self.content, self.anchor(name))

def fill_anchor(self, name, value) -> None:
    """Substitute `value` for the `name` anchor and store the result.

    `self.content` is rebound to the dict returned by `dict_replace_value`.
    """
    anchor = self.anchor(name)
    filled = dict_replace_value(self.content, anchor, value)
    self.content = filled


class BarHorizontalSortedTemplate(Template):
DEFAULT_NAME = "bar_horizontal_sorted"
Expand Down Expand Up @@ -606,7 +642,7 @@ def get_template(
_open = open if fs is None else fs.open
if template_path:
with _open(template_path, encoding="utf-8") as f:
content = f.read()
content = json.load(f)
return Template(content, name=template)

for template_cls in TEMPLATES:
Expand Down Expand Up @@ -635,6 +671,6 @@ def dump_templates(output: "StrPath", targets: Optional[List] = None) -> None:
if path.exists():
content = path.read_text(encoding="utf-8")
if content != template.content:
raise TemplateContentDoesNotMatch(template.DEFAULT_NAME or "", path)
raise TemplateContentDoesNotMatch(template.DEFAULT_NAME, str(path))
else:
path.write_text(template.content, encoding="utf-8")
path.write_text(json.dumps(template.content), encoding="utf-8")
15 changes: 10 additions & 5 deletions tests/test_templates.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import json
import os

import pytest
Expand Down Expand Up @@ -38,8 +39,9 @@ def test_raise_on_no_template():
],
)
def test_get_template_from_dir(tmp_dir, template_path, target_name):
tmp_dir.gen(template_path, "template_content")
assert get_template(target_name, ".dvc/plots").content == "template_content"
template_content = {"template_content": "foo"}
tmp_dir.gen(template_path, json.dumps(template_content))
assert get_template(target_name, ".dvc/plots").content == template_content


def test_get_template_exact_match(tmp_dir):
Expand All @@ -51,13 +53,16 @@ def test_get_template_exact_match(tmp_dir):


def test_get_template_from_file(tmp_dir):
tmp_dir.gen("foo/bar.json", "template_content")
assert get_template("foo/bar.json").content == "template_content"
template_content = {"template_content": "foo"}
tmp_dir.gen("foo/bar.json", json.dumps(template_content))
assert get_template("foo/bar.json").content == template_content


def test_get_template_fs(tmp_dir, mocker):
tmp_dir.gen("foo/bar.json", "template_content")
template_content = {"template_content": "foo"}
tmp_dir.gen("foo/bar.json", json.dumps(template_content))
fs = mocker.MagicMock()
mocker.patch("json.load", return_value={})
get_template("foo/bar.json", fs=fs)
fs.open.assert_called()
fs.exists.assert_called()
Expand Down
17 changes: 5 additions & 12 deletions tests/test_vega.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
import json

import pytest

from dvc_render.vega import BadTemplateError, VegaRenderer
Expand Down Expand Up @@ -33,7 +31,6 @@ def test_init_empty():
assert renderer.name == ""
assert renderer.properties == {}

assert renderer.generate_html() == ""
assert renderer.generate_markdown("foo") == ""


Expand All @@ -43,7 +40,7 @@ def test_default_template_mark():
{"first_val": 200, "second_val": 300, "val": 3},
]

plot_content = json.loads(VegaRenderer(datapoints, "foo").partial_html())
plot_content = VegaRenderer(datapoints, "foo").get_filled_template()

assert plot_content["layer"][0]["mark"] == "line"

Expand All @@ -60,7 +57,7 @@ def test_choose_axes():
{"first_val": 200, "second_val": 300, "val": 3},
]

plot_content = json.loads(VegaRenderer(datapoints, "foo", **props).partial_html())
plot_content = VegaRenderer(datapoints, "foo", **props).get_filled_template()

assert plot_content["data"]["values"] == [
{
Expand All @@ -85,7 +82,7 @@ def test_confusion():
]
props = {"template": "confusion", "x": "predicted", "y": "actual"}

plot_content = json.loads(VegaRenderer(datapoints, "foo", **props).partial_html())
plot_content = VegaRenderer(datapoints, "foo", **props).get_filled_template()

assert plot_content["data"]["values"] == [
{"predicted": "B", "actual": "A"},
Expand All @@ -100,12 +97,8 @@ def test_confusion():


def test_bad_template():
datapoints = [{"val": 2}, {"val": 3}]
props = {"template": Template("name", "content")}
renderer = VegaRenderer(datapoints, "foo", **props)
with pytest.raises(BadTemplateError):
renderer.get_filled_template()
renderer.get_filled_template(skip_anchors=["data"])
Template("name", "content")


def test_raise_on_wrong_field():
Expand Down Expand Up @@ -177,7 +170,7 @@ def test_escape_special_characters():
]
props = {"template": "simple", "x": "foo.bar[0]", "y": "foo.bar[1]"}
renderer = VegaRenderer(datapoints, "foo", **props)
filled = json.loads(renderer.get_filled_template())
filled = renderer.get_filled_template()
# data is not escaped
assert filled["data"]["values"][0] == datapoints[0]
# field and title yes
Expand Down

0 comments on commit ba87a2e

Please sign in to comment.