Skip to content

Commit

Permalink
Fix __eq__ for nested Geometrycollections
Browse files Browse the repository at this point in the history
  • Loading branch information
cleder committed Oct 12, 2023
1 parent 0b322cf commit c2770ac
Show file tree
Hide file tree
Showing 5 changed files with 37 additions and 19 deletions.
2 changes: 1 addition & 1 deletion pygeoif/factories.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ def shape(
shape(fi) for fi in geometry["geometries"] # type: ignore [typeddict-item]
]
return GeometryCollection(geometries) # type: ignore [arg-type]
raise NotImplementedError(f"[{geometry['type']} is nor implemented")
raise NotImplementedError(f"[{geometry['type']} is not implemented")


def num(number: str) -> float:
Expand Down
27 changes: 26 additions & 1 deletion pygeoif/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@
from typing import cast

from pygeoif.types import CoordinatesType
from pygeoif.types import GeoCollectionInterface
from pygeoif.types import GeoInterface
from pygeoif.types import LineType
from pygeoif.types import MultiCoordinatesType
from pygeoif.types import Point2D
Expand Down Expand Up @@ -74,7 +76,8 @@ def centroid(coords: LineType) -> Tuple[Point2D, float]:


def _cross(o: Point2D, a: Point2D, b: Point2D) -> float:
"""2D cross product of OA and OB vectors, i.e. z-component of their 3D cross product.
"""
2D cross product of OA and OB vectors, i.e. z-component of their 3D cross product.
Returns a positive value, if OAB makes a counter-clockwise turn,
negative for clockwise turn, and zero if the points are collinear.
Expand Down Expand Up @@ -157,6 +160,28 @@ def compare_coordinates(
return False


def compare_geo_interface(
if1: Union[GeoInterface, GeoCollectionInterface],
if2: Union[GeoInterface, GeoCollectionInterface],
) -> bool:
"""Compare two geo interfaces."""
if if1["type"] != if2["type"]:
return False
if if1["type"] == "GeometryCollection":
return all(
compare_geo_interface(g1, g2) # type: ignore [arg-type]
for g1, g2 in zip_longest(
if1["geometries"], # type: ignore [typeddict-item]
if2["geometries"], # type: ignore [typeddict-item]
fillvalue={"type": None, "coordinates": ()},
)
)
return compare_coordinates(
if1["coordinates"], # type: ignore [typeddict-item]
if2["coordinates"], # type: ignore [typeddict-item]
)


__all__ = [
"centroid",
"compare_coordinates",
Expand Down
20 changes: 7 additions & 13 deletions pygeoif/geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
from pygeoif.exceptions import DimensionError
from pygeoif.functions import centroid
from pygeoif.functions import compare_coordinates
from pygeoif.functions import compare_geo_interface
from pygeoif.functions import convex_hull
from pygeoif.functions import dedupe
from pygeoif.functions import signed_area
Expand Down Expand Up @@ -1015,6 +1016,8 @@ def __eq__(self, other: object) -> bool:
Types and coordinates from all contained geometries must be equal.
"""
try:
if self.is_empty:
return False
if (
other.__geo_interface__.get("type") # type: ignore [attr-defined]
!= self.geom_type
Expand All @@ -1031,18 +1034,9 @@ def __eq__(self, other: object) -> bool:
return False
except AttributeError:
return False
return all(
(
s["type"] == o.get("type")
and compare_coordinates(s["coordinates"], o.get("coordinates"))
for s, o in zip(
(geom.__geo_interface__ for geom in self.geoms),
other.__geo_interface__.get( # type: ignore [attr-defined]
"geometries",
[],
),
)
)
return compare_geo_interface(
self.__geo_interface__,
other.__geo_interface__, # type: ignore [attr-defined]
)

def __len__(self) -> int:
Expand All @@ -1066,7 +1060,7 @@ def _wkt_coords(self) -> str:
def __geo_interface__(self) -> GeoCollectionInterface: # type: ignore [override]
"""Return the geo interface of the collection."""
return {
"type": self.geom_type,
"type": "GeometryCollection",
"geometries": tuple(geom.__geo_interface__ for geom in self.geoms),
}

Expand Down
5 changes: 3 additions & 2 deletions pygeoif/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from typing import Tuple
from typing import Union

from typing_extensions import Literal
from typing_extensions import Protocol
from typing_extensions import TypedDict

Expand Down Expand Up @@ -63,8 +64,8 @@ class GeoInterface(GeoInterfaceBase, total=False):
class GeoCollectionInterface(TypedDict):
"""Geometry Collection Interface."""

type: str
geometries: Sequence[GeoInterface]
type: Literal["GeometryCollection"]
geometries: Sequence[Union[GeoInterface, "GeoCollectionInterface"]]


class GeoFeatureInterfaceBase(TypedDict):
Expand Down
2 changes: 0 additions & 2 deletions tests/test_geometrycollection.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
"""Test Baseclass."""
import pytest

from pygeoif import geometry

Expand Down Expand Up @@ -360,7 +359,6 @@ def test_nested_geometry_collection_geo_interface() -> None:
}


@pytest.mark.xfail(reason="Not implemented yet")
def test_nested_geometry_collection_eq() -> None:
multipoint = geometry.MultiPoint([(0, 0), (1, 1), (1, 2), (2, 2)])
gc1 = geometry.GeometryCollection([geometry.Point(0, 0), multipoint])
Expand Down

0 comments on commit c2770ac

Please sign in to comment.