Skip to content

Commit

Permalink
Merge pull request #60 from penpot/invisible-shape-identification
Browse files Browse the repository at this point in the history
Invisible shape identification
  • Loading branch information
kklemon authored Jun 20, 2024
2 parents fe16989 + 56adb6f commit a2c73ea
Show file tree
Hide file tree
Showing 10 changed files with 824 additions and 556 deletions.
1,034 changes: 526 additions & 508 deletions poetry.lock

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ tqdm = "^4.66.4"
transit-python2 = "^0.8.321"
types-markdown = "^3.6.0.20240316"
webdriver-manager = "^4.0.1"
cssutils = "^2.11.1"
requests-cache = "^1.2.1"

[tool.poetry.group.dev]
Expand Down
129 changes: 116 additions & 13 deletions src/penai/svg.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from selenium.webdriver.remote.webdriver import WebDriver
from tqdm import tqdm

from penai import utils
from penai.registries.web_drivers import RegisteredWebDriver, get_web_driver_for_html
from penai.types import PathLike, RecursiveStrDict
from penai.utils.dict import apply_func_to_nested_keys
Expand Down Expand Up @@ -305,6 +306,33 @@ def _el_is_group(el: Element) -> bool:
return el.tag == el.get_namespaced_key("g")


def _is_zero_opacity(value: str) -> bool:
    """Return whether a CSS ``opacity`` value denotes full transparency.

    Accepts plain numbers as well as percentages. Values that cannot be parsed
    (including the empty string returned for an unset property) count as
    not transparent.
    """
    try:
        # CSS clamps opacity to the range [0, 1], so anything <= 0 is invisible.
        return float(value.strip().removesuffix("%")) <= 0
    except ValueError:
        return False


def _el_has_visible_content(el: Element) -> bool:
    """Heuristically decide whether a content ``<g>`` element renders anything.

    Only two situations are reported as invisible:

    * the element has no children at all, or
    * its only child is a ``<path>`` that is fully transparent, or that has no
      children of its own and a fill of ``none``.

    Every other element is assumed to be visible.

    :param el: the element to inspect.
    :return: ``False`` for the invisible cases listed above, ``True`` otherwise.
    """
    children = el.getchildren()

    # NOTE: this is not guaranteed to hold in general - a class set on the
    # element could apply a fill / background colour and make it visible.
    if not children:
        return False

    if len(children) != 1 or children[0].tag != el.get_namespaced_key("path"):
        return True

    path = children[0]
    path_style = utils.get_css_parser().parseStyle(path.get("style", ""))

    if _is_zero_opacity(path_style.getPropertyValue("opacity")):
        return False

    if path.getchildren():
        return True

    # The fill can be given either as presentation attribute or inline style.
    fill_is_none = path.get("fill") == "none" or path_style.getPropertyValue("fill") == "none"
    return not fill_is_none


_PenpotShapeDictEntry = dict["PenpotShapeElement", "_PenpotShapeDictEntry"]


Expand Down Expand Up @@ -575,6 +603,22 @@ def is_container_type(self) -> bool:
def is_primitive_type(self) -> bool:
return self._shape_type.value.category == PenpotShapeTypeCategory.PRIMITIVE

def check_for_visible_content(self) -> bool:
    """Return whether this shape renders any visible content.

    A group shape is visible if at least one shape in ``self.child_shapes`` is.
    Any other shape is visible if at least one of its inner ``<g>`` elements
    has visible content; a shape without inner ``<g>`` elements is reported as
    invisible.
    """
    if self.type == PenpotShapeType.GROUP:
        return any(child.check_for_visible_content() for child in self.child_shapes)

    inner_groups = self.get_inner_g_elements()

    # No content elements at all means there is nothing that could be visible.
    if not inner_groups:
        return False

    return any(_el_has_visible_content(group) for group in inner_groups)

def get_parent_shape(self) -> Self | None:
g_containing_par_shape_candidate = self.get_containing_g_element().getparent()
while g_containing_par_shape_candidate is not None:
Expand All @@ -598,6 +642,12 @@ def get_containing_g_element(self) -> BetterElement:
"""
return self.getparent()

def get_inner_g_elements(self) -> list[BetterElement]:
    """Return the ``<g>`` children of the containing ``<g>`` that are not shapes.

    A child counts as a shape when its ``id`` attribute starts with ``shape-``.
    """
    container = self.get_containing_g_element()
    non_shape_groups_query = "default:g[not(starts-with(@id, 'shape-'))]"
    return container.xpath(non_shape_groups_query, empty_namespace_name="svg")

def is_leave(self) -> bool:
return not self.get_direct_children_shapes()

Expand Down Expand Up @@ -672,20 +722,20 @@ def __init__(
style_supplier: BaseStyleSupplier | None = None,
):
super().__init__(dom)
(
self._shape_elements,
self._depth_to_shape_el,
self._shape_el_to_depth,
) = find_all_penpot_shapes(self.dom, style_supplier=style_supplier)

shape_els, depth_to_shape_el, shape_el_to_depth = find_all_penpot_shapes(
dom,
style_supplier,
)
self._depth_to_shape_el = depth_to_shape_el
self._shape_el_to_depth = shape_el_to_depth
if depth_to_shape_el:
self._max_shape_depth = max(depth_to_shape_el.keys())
else:
self._max_shape_depth = 0
self._style_supplier = style_supplier

self.style_supplier = style_supplier
self.penpot_shape_elements = shape_els
def _reset_state(self) -> None:
    """Re-scan the DOM and rebuild the shape lookups.

    Meant to be called after the tree was modified (e.g. after removing a
    shape). The style supplier stored at construction time is passed on again,
    so that re-discovered shapes are created the same way as the initial ones.
    """
    (
        self._shape_elements,
        self._depth_to_shape_el,
        self._shape_el_to_depth,
    ) = find_all_penpot_shapes(self.dom, style_supplier=self._style_supplier)

@overload
def _get_shapes_by_attr(
Expand Down Expand Up @@ -740,9 +790,16 @@ def get_shape_by_name(
def get_shape_by_id(self, shape_id: str) -> PenpotShapeElement:
    """Look up the single shape whose ``shape_id`` attribute equals ``shape_id``."""
    matching_shape = self._get_shapes_by_attr("shape_id", shape_id, should_be_unique=True)
    return matching_shape

@property
def penpot_shape_elements(self) -> list[PenpotShapeElement]:
    """The shape elements found by the most recent scan of the DOM."""
    shape_elements = self._shape_elements
    return shape_elements

@property
def max_shape_depth(self) -> int:
    """Largest depth at which a shape occurs, or ``0`` if there are no shapes.

    Computed from the depth lookup on every access, so the value follows the
    lookup whenever it is rebuilt.
    """
    return max(self._depth_to_shape_el, default=0)

def get_shape_elements_at_depth(self, depth: int) -> list[PenpotShapeElement]:
    """Return the shapes registered at ``depth``, or an empty list if there are none."""
    shapes_by_depth = self._depth_to_shape_el
    return shapes_by_depth.get(depth, [])
Expand All @@ -751,6 +808,52 @@ def pprint_hierarchy(self, horizontal: bool = True) -> None:
for shape in self.get_shape_elements_at_depth(0):
shape.pprint_hierarchy(horizontal=horizontal)

def _remove_shape_from_tree(self, shape_id: str) -> None:
    """Detach the containing ``<g>`` element of the given shape from its parent.

    Only the DOM is modified here; the shape lookups are not refreshed.
    """
    containing_g = self.get_shape_by_id(shape_id).get_containing_g_element()
    parent_el = containing_g.getparent()
    parent_el.remove(containing_g)

def remove_shape(self, shape_id: str) -> None:
    """Remove the shape with the given id and rebuild the shape lookups.

    :param shape_id: id of the shape to remove.
    :raises AssertionError: if the shape can still be looked up afterwards.
    """
    self._remove_shape_from_tree(shape_id)
    self._reset_state()

    # The removal is treated as successful when the lookup raises a KeyError.
    try:
        self.get_shape_by_id(shape_id)
    except KeyError:
        pass
    else:
        raise AssertionError(f"Shape with id {shape_id} was not removed correctly.")

def remove_elements_with_no_visible_content(self) -> None:
    """Remove every shape that has no visible content, then rebuild the shape lookups.

    :raises AssertionError: if a removed shape can still be looked up afterwards.
    """
    # Process the deepest shapes first. Removing a parent before its children
    # would detach the children from the tree, which leads to odd behaviour
    # (lxml assigns arbitrary namespace names) and errors. Detecting such
    # relationships and removing only the invisible parents would work too,
    # but sorting by depth is simpler and sufficient for now.
    deepest_first = sorted(
        self.penpot_shape_elements,
        key=lambda el: el.depth_in_shapes,
        reverse=True,
    )

    removed_ids = []
    for candidate in deepest_first:
        if candidate.check_for_visible_content():
            continue
        self._remove_shape_from_tree(candidate.shape_id)
        removed_ids.append(candidate.shape_id)

    self._reset_state()

    # Sanity check: none of the removed shapes may be found any more.
    for shape_id in removed_ids:
        try:
            self.get_shape_by_id(shape_id)
        except KeyError:
            continue
        raise AssertionError(f"Shape with id {shape_id} was not removed correctly.")

def retrieve_and_set_view_boxes_for_shape_elements(
self,
web_driver: WebDriver | RegisteredWebDriver,
Expand Down
10 changes: 9 additions & 1 deletion src/penai/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import json
import logging
from functools import cache
from pathlib import Path

import requests
import requests_cache
from cssutils import CSSParser

from penai.types import PathLike

Expand All @@ -14,6 +16,12 @@ def read_json(path: PathLike) -> dict:


@cache
def get_cached_requests_session(cache_name: str) -> requests.Session:
def get_css_parser() -> CSSParser:
    """Create a CSS parser whose log level is set to ``logging.CRITICAL``."""
    quiet_level = logging.CRITICAL
    return CSSParser(loglevel=quiet_level)


@cache
def get_cached_requests_session(cache_name: str = "cache") -> requests.Session:
    """Return a ``requests_cache`` session that uses the cache named ``cache_name``.

    The function itself is memoised via ``functools.cache``.
    """
    session = requests_cache.CachedSession(cache_name)
    return session
2 changes: 2 additions & 0 deletions src/penai/utils/svg.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,8 @@ def temp_file_for_content(
# The code below is essentially equivalent to `with open()...write`
with NamedTemporaryFile(prefix="penai_", suffix=extension, mode=mode, delete=False) as file:
file.write(content)
file.flush()

path = Path(file.name)
yield path

Expand Down
18 changes: 10 additions & 8 deletions src/penai/xml.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,10 +61,13 @@ def xpath(
class BetterElement(CustomElement):
"""Simplifies handling of namespaces in ElementTree."""

@property
def query_compatible_nsmap(self) -> dict[str, str]:
    """Namespace map of this element in a form usable for XPath queries.

    XPath does not support an empty prefix, so the default namespace (stored
    under the key ``None`` in ``nsmap``) is exposed under the prefix
    ``default`` instead. Elements without a default namespace are supported
    as well; their map is returned as is.
    """
    # Work on a copy so that the mapping read from the element stays untouched.
    nsmap = dict(self.nsmap)

    if None in nsmap:
        nsmap["default"] = nsmap.pop(None)

    return nsmap

@override
Expand All @@ -82,12 +85,11 @@ def xpath(
namespaces: dict[str, str] | None = None,
**kwargs: dict[str, Any],
) -> list[Self]:
namespaces = namespaces or self.query_compatible_nsmap

# xpath() does not support empty namespaces (applies to both None and empty string)
namespaces.pop("", None)

return super().xpath(path, namespaces=namespaces, **kwargs)
return super().xpath(
path,
namespaces=namespaces or self.query_compatible_nsmap,
**kwargs,
)

@overload
def get_namespaced_key(self, key: str) -> str:
Expand Down
24 changes: 15 additions & 9 deletions test/penai/conftest.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
from collections.abc import Generator
from collections.abc import Generator, Iterable
from datetime import datetime
from pathlib import Path
from typing import Any, Iterable
from typing import Any

import pytest
from pytest import FixtureRequest, MonkeyPatch
from selenium.webdriver.remote.webdriver import WebDriver

from penai.config import top_level_directory
from penai.registries.projects import SavedPenpotProject
from penai.render import BaseSVGRenderer, WebDriverSVGRenderer
from penai.render import BaseSVGRenderer, ResvgRenderer, WebDriverSVGRenderer
from penai.types import PathLike
from penai.utils.web_drivers import create_chrome_web_driver

Expand Down Expand Up @@ -66,11 +66,17 @@ def chrome_svg_renderer(chrom_web_driver: WebDriver) -> Iterable[BaseSVGRenderer
return WebDriverSVGRenderer(chrom_web_driver)


@pytest.fixture(params=[
SavedPenpotProject.AVATAAARS,
SavedPenpotProject.BLACK_AND_WHITE_MOBILE_TEMPLATES,
SavedPenpotProject.MATERIAL_DESIGN_3,
])
@pytest.fixture(scope="session")
def resvg_renderer() -> BaseSVGRenderer:
    """Session-scoped fixture providing a ``ResvgRenderer`` instance."""
    # A single renderer is returned (not yielded), hence no Iterable annotation.
    return ResvgRenderer()


@pytest.fixture(
    params=[
        SavedPenpotProject.AVATAAARS,
        SavedPenpotProject.BLACK_AND_WHITE_MOBILE_TEMPLATES,
        SavedPenpotProject.MATERIAL_DESIGN_3,
    ],
)
def example_project(request: FixtureRequest) -> SavedPenpotProject:
    """Parametrized fixture providing each of the listed saved projects in turn."""
    selected_project = request.param
    return selected_project

1 change: 0 additions & 1 deletion test/penai/test_registries.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@


class TestPenpotProjectRegistry:

@staticmethod
def test_can_be_loaded(example_project: SavedPenpotProject) -> None:
loaded_project = example_project.load(pull=True)
Expand Down
18 changes: 12 additions & 6 deletions test/penai/test_renderers.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,9 @@ def test_rendering(
)

def test_size_inference(
self, renderer: BaseSVGRenderer, example_svg_path: Path,
self,
renderer: BaseSVGRenderer,
example_svg_path: Path,
) -> None:
img = renderer.render_svg_file(example_svg_path)

Expand All @@ -76,7 +78,9 @@ def test_size_inference(
assert img.size == (view_box.width, view_box.height)

def test_explicit_size_specification(
self, renderer: BaseSVGRenderer, example_svg_path: Path,
self,
renderer: BaseSVGRenderer,
example_svg_path: Path,
) -> None:
orig_aspect_ratio = SVG.from_file(example_svg_path).get_view_box().aspect_ratio

Expand All @@ -85,10 +89,12 @@ def test_explicit_size_specification(

img = renderer.render_svg_file(example_svg_path, width=100)
assert img.size[0] == 100
assert img.size[0] >= img.size[1] if orig_aspect_ratio > 1 else img.size[0] <= img.size[1], \
f"Original aspect ratio: {orig_aspect_ratio}, new size: {img.size}"
assert (
img.size[0] >= img.size[1] if orig_aspect_ratio > 1 else img.size[0] <= img.size[1]
), f"Original aspect ratio: {orig_aspect_ratio}, new size: {img.size}"

img = renderer.render_svg_file(example_svg_path, height=100)
assert img.size[1] == 100
assert img.size[0] >= img.size[1] if orig_aspect_ratio > 1 else img.size[0] <= img.size[1], \
f"Original aspect ratio: {orig_aspect_ratio}, new size: {img.size}"
assert (
img.size[0] >= img.size[1] if orig_aspect_ratio > 1 else img.size[0] <= img.size[1]
), f"Original aspect ratio: {orig_aspect_ratio}, new size: {img.size}"
Loading

0 comments on commit a2c73ea

Please sign in to comment.