-
Notifications
You must be signed in to change notification settings - Fork 7
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat(marking-detection): integrate new ml-models
Co-authored-by: Clemens Langer <[email protected]> Co-authored-by: Marcel Reinmuth <[email protected]>
- Loading branch information
1 parent
82054f0
commit 3d5e227
Showing
16 changed files
with
211 additions
and
220 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,132 +1,113 @@ | ||
# -*- coding: utf-8 -*- | ||
""" | ||
Functions to process images of sketch maps and detect markings on them | ||
""" | ||
|
||
import cv2 | ||
import numpy as np | ||
from numpy.typing import NDArray | ||
from PIL import Image, ImageEnhance | ||
from PIL import Image | ||
from segment_anything import SamPredictor | ||
from ultralytics import YOLO | ||
|
||
|
||
def detect_markings( | ||
sketch_map_frame: NDArray, | ||
color: str, | ||
threshold_bgr: float = 0.5, | ||
image: NDArray, | ||
yolo_model: YOLO, | ||
sam_predictor: SamPredictor, | ||
) -> NDArray: | ||
# Sam can only deal with RGB and not RGBA etc. | ||
img = Image.fromarray(image[:, :, ::-1]).convert("RGB") | ||
# masks represent markings | ||
masks, colors = apply_ml_pipeline(img, yolo_model, sam_predictor) | ||
colors = [int(c) + 1 for c in colors] # +1 because 0 is background | ||
return create_marking_array(masks, colors, image) | ||
|
||
|
||
def apply_ml_pipeline( | ||
image: Image.Image, | ||
yolo_model: YOLO, | ||
sam_predictor: SamPredictor, | ||
) -> tuple[list, list]: | ||
"""Apply the entire machine learning pipeline on an image. | ||
Steps: | ||
1. Apply YOLO to detect bounding boxes and label (colors) of objects (markings) | ||
2. Apply SAM to create binary masks of detected objects | ||
Returns: | ||
tuple: A list of masks and class labels. | ||
A mask is a binary numpy array with same dimensions as input image | ||
(map frame), masking the dominant segment inside of a bbox detected by YOLO. | ||
A class label is a color. | ||
""" | ||
Detect markings in the colours blue, green, red, pink, turquoise, white, and yellow | ||
Note that there must be a sufficient difference between the colour of the markings | ||
and the background. White and yellow markings might not be detected on sketch maps. | ||
bounding_boxes, class_labels = apply_yolo(image, yolo_model) | ||
masks, _ = apply_sam(image, bounding_boxes, sam_predictor) | ||
return masks, class_labels | ||
|
||
:param sketch_map_frame: TODO | ||
:param threshold_bgr: Threshold for the colour detection. 0.5 means 50%, i.e. all | ||
BGR values above 50% * 255 will be considered 255, | ||
all values below this threshold will be considered | ||
0 for determining the colour of the markings. | ||
""" | ||
threshold_bgr_abs = threshold_bgr * 255 | ||
|
||
colors = { | ||
"white": (255, 255, 255), | ||
"red": (0, 0, 255), | ||
"blue": (255, 0, 0), | ||
"green": (0, 255, 0), | ||
"yellow": (0, 255, 255), | ||
"turquoise": (255, 255, 0), | ||
"pink": (255, 0, 255), | ||
} | ||
bgr = colors[color] | ||
|
||
# for color, bgr in colors.items(): | ||
single_color_marking = np.zeros_like(sketch_map_frame, np.uint8) | ||
single_color_marking[ | ||
( | ||
(sketch_map_frame[:, :, 0] < threshold_bgr_abs) | ||
== (bgr[0] < threshold_bgr_abs) | ||
) | ||
& ( | ||
(sketch_map_frame[:, :, 1] < threshold_bgr_abs) | ||
== (bgr[1] < threshold_bgr_abs) | ||
) | ||
& ( | ||
(sketch_map_frame[:, :, 2] < threshold_bgr_abs) | ||
== (bgr[2] < threshold_bgr_abs) | ||
) | ||
] = 255 | ||
single_color_marking = _reduce_noise(single_color_marking) | ||
single_color_marking = _reduce_holes(single_color_marking) | ||
single_color_marking[single_color_marking > 0] = 255 | ||
return single_color_marking | ||
|
||
def apply_yolo( | ||
image: Image.Image, | ||
yolo_model: YOLO, | ||
) -> tuple[list, list]: | ||
"""Apply YOLO object detection on an image. | ||
def prepare_img_for_markings( | ||
img_base: NDArray, | ||
img_markings: NDArray, | ||
threshold_img_diff: int = 100, | ||
) -> NDArray: | ||
Returns: | ||
tuple: Detected bounding boxes around individual markings and corresponding | ||
class labels. | ||
""" | ||
TODO pydoc | ||
result = yolo_model(image)[0].boxes | ||
bounding_boxes = result.xyxy | ||
class_labels = result.cls | ||
return bounding_boxes, class_labels | ||
|
||
:param threshold_img_diff: Threshold for the marking detection concerning the | ||
absolute grayscale difference between corresponding pixels | ||
in 'img_base' and 'img_markings'. | ||
""" | ||
img_base_height, img_base_width, _ = img_base.shape | ||
img_markings = cv2.resize( | ||
img_markings, | ||
(img_base_width, img_base_height), | ||
fx=4, | ||
fy=4, | ||
interpolation=cv2.INTER_NEAREST, | ||
) | ||
img_markings_contrast = _enhance_contrast(img_markings) | ||
img_diff = cv2.absdiff(img_base, img_markings_contrast) | ||
img_diff_gray = cv2.cvtColor(img_diff, cv2.COLOR_BGR2GRAY) | ||
mask_markings = img_diff_gray > threshold_img_diff | ||
markings_multicolor = np.zeros_like(img_markings, np.uint8) | ||
markings_multicolor[mask_markings] = img_markings[mask_markings] | ||
return markings_multicolor | ||
|
||
def apply_sam( | ||
image: Image.Image, | ||
bounding_boxes: list, | ||
sam_predictor: SamPredictor, | ||
) -> tuple: | ||
"""Apply SAM (Segment Anything) on an image using bounding boxes. | ||
def _enhance_contrast(img: NDArray, factor: float = 2.0) -> NDArray: | ||
""" | ||
Enhance the contrast of a given image | ||
Creates masks based on image segmentation and bbox from object detection (YOLO) | ||
:param img: Image of which the contrast should be enhanced. | ||
:param factor: Factor for the contrast enhancement. | ||
:return: Image with enhanced contrast. | ||
Returns: | ||
tuple: List of masks and corresponding scores. | ||
""" | ||
input_img = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB)) | ||
result = ImageEnhance.Contrast(input_img).enhance(factor) | ||
return cv2.cvtColor(np.array(result), cv2.COLOR_RGB2BGR) | ||
sam_predictor.set_image(np.array(image)) | ||
masks = [] | ||
scores = [] | ||
for bbox in bounding_boxes: | ||
mask, score = mask_from_bbox(np.array(bbox), sam_predictor) | ||
masks.append(mask) | ||
scores.append(score) | ||
return masks, scores | ||
|
||
|
||
def _reduce_noise(img: NDArray, factor: int = 2) -> NDArray: | ||
""" | ||
Reduce the noise, i.e. artifacts, in an image containing markings | ||
def mask_from_bbox(bbox, sam_predictor: SamPredictor) -> tuple: | ||
"""Generate a mask using SAM (Segment Anything) predictor for a given bounding box. | ||
:param img: Image in which the noise should be reduced. | ||
:param factor: Kernel size (x*x) for the noise reduction. | ||
:return: 'img' with less noise. | ||
Returns: | ||
tuple: Mask and corresponding score. | ||
""" | ||
# See https://docs.opencv.org/4.x/d9/d61/tutorial_py_morphological_ops.html | ||
reduced_noise = cv2.morphologyEx( | ||
img, cv2.MORPH_OPEN, np.ones((factor, factor), np.uint8) | ||
) | ||
# TODO: Long running job in next line -> Does the slightly improved noise | ||
# reduction justify keeping it? | ||
return cv2.fastNlMeansDenoisingColored(reduced_noise, None, 30, 30, 20, 21) | ||
masks, scores, _ = sam_predictor.predict(box=bbox, multimask_output=False) | ||
return masks[0], scores[0] | ||
|
||
|
||
def _reduce_holes(img: NDArray, factor: int = 4) -> NDArray: | ||
""" | ||
Reduce the holes in markings on a given image | ||
def create_marking_array( | ||
masks: list[NDArray], | ||
colors: list[int], | ||
image: NDArray, | ||
) -> NDArray: | ||
"""Create a single color marking array based on masks and colors. | ||
:param img: Image in which the holes should be reduced. | ||
:param factor: Kernel size (x*x) of the reduction. | ||
:return: 'img' with fewer and smaller holes. | ||
Parameters: | ||
- masks: List of masks representing markings. | ||
- colors: List of colors corresponding to each mask. | ||
- image: Original sketch map frame. | ||
Returns: | ||
NDArray: Single color marking array. | ||
""" | ||
# See https://docs.opencv.org/4.x/d9/d61/tutorial_py_morphological_ops.html | ||
return cv2.morphologyEx(img, cv2.MORPH_CLOSE, np.ones((factor, factor), np.uint8)) | ||
single_color_marking = np.zeros( | ||
(image.shape[0], image.shape[1]), | ||
dtype=np.uint8, | ||
) | ||
for color, mask in zip(colors, masks): | ||
single_color_marking[mask] = color | ||
return single_color_marking |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,8 +1,12 @@ | ||
from geojson import FeatureCollection | ||
|
||
from sketch_map_tool.definitions import COLORS | ||
|
||
|
||
def enrich(fc: FeatureCollection, properties): | ||
"""Enrich GeoJSON properties.""" | ||
"""Enrich GeoJSON properties and map colors.""" | ||
for feature in fc.features: | ||
feature.properties = feature.properties | properties | ||
if "color" in feature.properties.keys(): | ||
feature.properties["color"] = COLORS[feature.properties["color"]] | ||
return fc |
Oops, something went wrong.