Skip to content

Commit

Permalink
style: run ruff
Browse files Browse the repository at this point in the history
  • Loading branch information
matthiasschaub committed Dec 18, 2023
1 parent a27fe76 commit 815456f
Show file tree
Hide file tree
Showing 5 changed files with 45 additions and 36 deletions.
2 changes: 1 addition & 1 deletion sketch_map_tool/upload_processing/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from .clip import clip
from .detect_markings import detect_markings
from .create_marking_array import create_marking_array
from .detect_markings import detect_markings
from .georeference import georeference
from .merge import merge
from .polygonize import polygonize
Expand Down
2 changes: 1 addition & 1 deletion sketch_map_tool/upload_processing/create_marking_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,4 +23,4 @@ def create_marking_array(
)
for color, mask in zip(colors, masks):
single_color_marking[mask] = color
return single_color_marking
return single_color_marking
17 changes: 7 additions & 10 deletions sketch_map_tool/upload_processing/detect_markings.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,6 @@
from segment_anything import SamPredictor
from ultralytics import YOLO

from sketch_map_tool.upload_processing.create_marking_array import create_marking_array


def detect_markings(
image: NDArray,
Expand All @@ -17,9 +15,9 @@ def detect_markings(
# Sam can only deal with RGB and not RGBA etc.
img = Image.fromarray(image[:, :, ::-1]).convert("RGB")
# masks represent markings
masks, bboxes,colors = apply_ml_pipeline(img, yolo_model, sam_predictor)
masks, bboxes, colors = apply_ml_pipeline(img, yolo_model, sam_predictor)
colors = [int(c) + 1 for c in colors] # +1 because 0 is background
masks_processed = post_process(masks,bboxes)
masks_processed = post_process(masks, bboxes)
return masks_processed, colors


Expand Down Expand Up @@ -84,7 +82,7 @@ def apply_sam(
return masks, scores


def mask_from_bbox(bbox:list, sam_predictor: SamPredictor) -> tuple:
def mask_from_bbox(bbox: list, sam_predictor: SamPredictor) -> tuple:
"""Generate a mask using SAM (Segment Anything) predictor for a given bounding box.
Returns:
Expand Down Expand Up @@ -127,17 +125,16 @@ def post_process(masks: list[NDArray], bboxes: list[list[int]]) -> list[NDArray]
kernel = np.ones(kernel_size, np.uint8)

# Apply morphological closing operation
closed_mask = cv2.morphologyEx(mask.astype('uint8'), cv2.MORPH_CLOSE, kernel)
closed_mask = cv2.morphologyEx(mask.astype("uint8"), cv2.MORPH_CLOSE, kernel)

# Find contours
contours, _ = cv2.findContours(closed_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
contours, _ = cv2.findContours(
closed_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
)

# Create a blank canvas for filled contours
filled_contours = np.zeros_like(closed_mask, dtype=np.uint8)
cv2.drawContours(filled_contours, contours, -1, 1, thickness=cv2.FILLED)
cleaned_masks.append(filled_contours.astype(bool))

return cleaned_masks



52 changes: 31 additions & 21 deletions sketch_map_tool/upload_processing/post_process.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,12 @@
import geojson
from geojson import FeatureCollection
from shapely.geometry import shape
from shapely import Polygon, MultiPolygon
from shapely.geometry import mapping
from shapely import MultiPolygon, Polygon
from shapely.geometry import mapping, shape
from shapely.ops import cascaded_union
from shapelysmooth import chaikin_smooth

from sketch_map_tool.definitions import COLORS

from typing import Union

def post_process(fc: FeatureCollection, name: str) -> FeatureCollection:
fc = clean(fc)
Expand Down Expand Up @@ -68,30 +66,44 @@ def simplify(fc: FeatureCollection) -> FeatureCollection:
# Buffer operation
buffer_distance_percentage = 0.1
max_diag = max(
((geometry.bounds[2] - geometry.bounds[0]) ** 2 + (geometry.bounds[3] - geometry.bounds[1]) ** 2) ** 0.5 for
geometry in geometries) # check for webmercator
(
(geometry.bounds[2] - geometry.bounds[0]) ** 2
+ (geometry.bounds[3] - geometry.bounds[1]) ** 2
)
** 0.5
for geometry in geometries
) # check for webmercator
buffer_distance = buffer_distance_percentage * max_diag
buffered_geometries = [geometry.buffer(buffer_distance) for geometry in geometries]
# Dissolve by color field (assuming there's a "color" field)
try:
dissolved_geometrie = [remove_inner_rings(geometry) for geometry in cascaded_union(buffered_geometries)]
dissolved_geometrie = [
remove_inner_rings(geometry)
for geometry in cascaded_union(buffered_geometries)
]
except:
dissolved_geometrie = [remove_inner_rings(geometry) for geometry in [cascaded_union(buffered_geometries)]]
dissolved_geometrie = [
remove_inner_rings(geometry)
for geometry in [cascaded_union(buffered_geometries)]
]

simplified_geometries = [geometry.buffer(-buffer_distance).simplify(0.0025 * max_diag) for geometry in dissolved_geometrie]
simplified_geometries = [
geometry.buffer(-buffer_distance).simplify(0.0025 * max_diag)
for geometry in dissolved_geometrie
]

# Create a single GeoJSON feature
features = [geojson.Feature(
geometry=mapping(geometry),
properties=properties
) for geometry in simplified_geometries]
features = [
geojson.Feature(geometry=mapping(geometry), properties=properties)
for geometry in simplified_geometries
]

# Create a GeoJSON feature collection with the single feature
fc = geojson.FeatureCollection(features)
return fc


def remove_inner_rings(geometry: Polygon | MultiPolygon) -> Polygon | MultiPolygon:
def remove_inner_rings(geometry: Polygon | MultiPolygon) -> Polygon | MultiPolygon:
"""
Removes inner rings (holes) from a given Shapely geometry object.
Expand All @@ -107,9 +119,9 @@ def remove_inner_rings(geometry: Polygon | MultiPolygon) -> Polygon | MultiPoly
"""
if geometry.is_empty:
return geometry
elif geometry.type == 'Polygon':
elif geometry.type == "Polygon":
return Polygon(geometry.exterior)
elif geometry.type == 'MultiPolygon':
elif geometry.type == "MultiPolygon":
return MultiPolygon([Polygon(poly.exterior) for poly in geometry.geoms])
else:
raise ValueError("Unsupported geometry type")
Expand Down Expand Up @@ -143,12 +155,10 @@ def smooth(fc: FeatureCollection) -> FeatureCollection:

corrected_geometry = chaikin_smooth(geometry)

updated_features.append(geojson.Feature(
geometry=mapping(corrected_geometry),
properties=properties
))
updated_features.append(
geojson.Feature(geometry=mapping(corrected_geometry), properties=properties)
)

# Create a GeoJSON feature collection with the updated features
fc = geojson.FeatureCollection(updated_features)
return fc

8 changes: 5 additions & 3 deletions tests/integration/upload_processing/test_detect_markings.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,9 @@ def test_detect_markings(sam_predictor, yolo_model, map_frame_marked):


def test_apply_ml_pipeline(sam_predictor, yolo_model, map_frame_marked):
masks, bboxes ,colors = apply_ml_pipeline(map_frame_marked, yolo_model, sam_predictor)
masks, bboxes, colors = apply_ml_pipeline(
map_frame_marked, yolo_model, sam_predictor
)
# TODO: Should the len not be 2? Only two markings are on the input image.
assert len(masks) == len(colors)

Expand All @@ -50,11 +52,11 @@ def test_apply_ml_pipeline_show_masks(
yolo_model,
map_frame_marked,
):
masks, _ ,_ = apply_ml_pipeline(map_frame_marked, yolo_model, sam_predictor)
masks, _, _ = apply_ml_pipeline(map_frame_marked, yolo_model, sam_predictor)
for mask in masks:
plt.imshow(mask, cmap="viridis", alpha=0.7)
plt.show()


def test_post_process():
assert post_process([],[] ) is not None
assert post_process([], []) is not None

0 comments on commit 815456f

Please sign in to comment.