feat(marking-detection): integrate new ml-models
Co-authored-by: Clemens Langer <[email protected]>
Co-authored-by: Marcel Reinmuth <[email protected]>
3 people authored Dec 7, 2023
1 parent 82054f0 commit 3d5e227
Showing 16 changed files with 211 additions and 220 deletions.
3 changes: 2 additions & 1 deletion docker-compose.yaml
@@ -37,7 +37,8 @@ services:
     environment:
       SMT_BROKER_URL: "redis://redis:6379"
       SMT_RESULT_BACKEND: "db+postgresql://smt:smt@postgres:5432"
-      SMT_NEPTUNE_API_TOKEN: ""
+      SMT_NEPTUNE_API_TOKEN: "eyJhcGlfYWRkcmVzcyI6Imh0dHBzOi8vYXBwLm5lcHR1bmUuYWkiLCJhcGlfdXJsIjoiaHR0cHM6Ly9hcHAubmVwdHVuZS5haSIsImFwaV9rZXkiOiJmMDY5OGU4Yy1kOGE1LTRhOTQtODkwNC0yY2E2NjEzMTQ2OTUifQ=="
+      SMT_NEPTUNE_MODEL_ID_YOLO: "SMT-OSM-2"
     entrypoint:
       [
         "mamba",
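The worker reads these settings from the environment; the tasks.py hunk later in this commit fetches the model ID via get_config_value("neptune_model_id_yolo"). A minimal sketch of that mapping, assuming (hypothetically — the real loader lives elsewhere in the repo) that the config loader strips the SMT_ prefix and lower-cases the key:

import os

# Hypothetical stand-in for the project's config loader: strip the "SMT_"
# prefix and lower-case the rest, so SMT_NEPTUNE_MODEL_ID_YOLO is reachable
# under the config key "neptune_model_id_yolo".
def get_config_value(key: str) -> str:
    return os.environ["SMT_" + key.upper()]

model_id = get_config_value("neptune_model_id_yolo")  # -> "SMT-OSM-2"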
10 changes: 9 additions & 1 deletion sketch_map_tool/definitions.py
@@ -13,7 +13,15 @@
     "quality-report", "sketch-map", "raster-results", "vector-results"
 ]
 # Colors to be detected
-COLORS = ["red", "blue", "green", "yellow", "turquoise", "pink"]
+COLORS = {
+    "1": "black",
+    "2": "blue",
+    "3": "green",
+    "4": "orange",
+    "5": "pink",
+    "6": "red",
+    "7": "yellow",
+}
 # Resources for PDF generation
 PDF_RESOURCES_PATH = Path(__file__).parent.resolve() / "resources"
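COLORS is no longer a list of colors to scan for but a lookup from the detected class code to a color name. The keys are 1-based strings because detect_markings (below) shifts YOLO's 0-based class indices by one so that 0 can mean "background". A minimal sketch of the lookup:

COLORS = {
    "1": "black",
    "2": "blue",
    "3": "green",
    "4": "orange",
    "5": "pink",
    "6": "red",
    "7": "yellow",
}

# YOLO class indices start at 0; detect_markings applies +1 so that 0 is
# reserved for background pixels in the marking array.
yolo_class = 1                    # e.g. the second YOLO class
color_code = str(yolo_class + 1)  # "2"
print(COLORS[color_code])         # -> "blue"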
46 changes: 29 additions & 17 deletions sketch_map_tool/tasks.py
@@ -7,25 +7,26 @@
 from celery.signals import worker_process_init, worker_process_shutdown
 from geojson import FeatureCollection
 from numpy.typing import NDArray
+from segment_anything import SamPredictor, sam_model_registry
+from ultralytics import YOLO
 
 from sketch_map_tool import celery_app as celery
-from sketch_map_tool import map_generation
+from sketch_map_tool import get_config_value, map_generation
 from sketch_map_tool.database import client_celery as db_client_celery
 from sketch_map_tool.definitions import COLORS
 from sketch_map_tool.helpers import to_array
 from sketch_map_tool.models import Bbox, PaperFormat, Size
 from sketch_map_tool.oqt_analyses import generate_pdf as generate_report_pdf
 from sketch_map_tool.oqt_analyses import get_report
 from sketch_map_tool.upload_processing import (
     clean,
     clip,
-    detect_markings,
     enrich,
     georeference,
     merge,
     polygonize,
-    prepare_img_for_markings,
 )
+from sketch_map_tool.upload_processing.detect_markings import detect_markings
+from sketch_map_tool.upload_processing.ml_models import init_model
 from sketch_map_tool.wms import client as wms_client

@@ -125,29 +126,40 @@ def digitize_sketches(
     map_frames: dict[str, NDArray],
     bboxes: list[Bbox],
 ) -> AsyncResult | FeatureCollection:
+    # Initialize ml-models. This has to happen inside of celery context
+    #
+    # Zero shot segment anything model
+    sam_path = init_model(get_config_value("neptune_model_id_sam"))
+    sam_model = sam_model_registry["vit_b"](sam_path)
+    sam_predictor = SamPredictor(sam_model)  # mask predictor
+    # Custom trained model for object detection of markings and colors
+    yolo_path = init_model(get_config_value("neptune_model_id_yolo"))
+    yolo_model = YOLO(yolo_path)
+
     def process(
-        sketch_map_id: int, name: str, uuid: str, bbox: Bbox
+        sketch_map_id: int,
+        name: str,
+        uuid: str,
+        bbox: Bbox,
+        sam_predictor,
+        yolo_model,
     ) -> FeatureCollection:
         """Process a Sketch Map."""
         # r = interim result
         r = db_client_celery.select_file(sketch_map_id)
         r = to_array(r)
         r = clip(r, map_frames[uuid])
-        r = prepare_img_for_markings(map_frames[uuid], r)
-        geojsons = []
-        for color in COLORS:
-            r_ = detect_markings(r, color)
-            r_ = georeference(r_, bbox)
-            r_ = polygonize(r_, color)
-            r_ = geojson.load(r_)
-            r_ = clean(r_)
-            r_ = enrich(r_, {"color": color, "name": name})
-            geojsons.append(r_)
-        return merge(geojsons)
+        r = detect_markings(r, yolo_model, sam_predictor)
+        r = georeference(r, bbox, bgr=False)
+        r = polygonize(r, name)
+        r = geojson.load(r)
+        r = clean(r)
+        r = enrich(r, {"name": name})
+        return r
 
     return merge(
         [
-            process(file_id, name, uuid, bbox)
+            process(file_id, name, uuid, bbox, sam_predictor, yolo_model)
            for file_id, name, uuid, bbox in zip(file_ids, file_names, uuids, bboxes)
         ]
     )
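Per the in-code comment, the two models are constructed inside the task body because initialization "has to happen inside of celery context", i.e. in the worker process rather than at module import. A minimal sketch of the same bootstrap, runnable outside Celery for local debugging; it assumes init_model resolves a Neptune model ID (such as the "SMT-OSM-2" set in docker-compose.yaml) to a local weights file:

from segment_anything import SamPredictor, sam_model_registry
from ultralytics import YOLO

from sketch_map_tool import get_config_value
from sketch_map_tool.upload_processing.ml_models import init_model

# init_model presumably fetches the weights from Neptune and returns a path.
sam_path = init_model(get_config_value("neptune_model_id_sam"))
sam_predictor = SamPredictor(sam_model_registry["vit_b"](sam_path))
yolo_model = YOLO(init_model(get_config_value("neptune_model_id_yolo")))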
3 changes: 1 addition & 2 deletions sketch_map_tool/upload_processing/__init__.py
@@ -1,6 +1,6 @@
 from .clean import clean
 from .clip import clip
-from .detect_markings import detect_markings, prepare_img_for_markings
+from .detect_markings import detect_markings
 from .enrich import enrich
 from .georeference import georeference
 from .merge import merge
@@ -12,7 +12,6 @@
     "clean",
     "clip",
     "detect_markings",
-    "prepare_img_for_markings",
     "georeference",
     "read_qr_code",
     "polygonize",
2 changes: 1 addition & 1 deletion sketch_map_tool/upload_processing/clean.py
@@ -10,7 +10,7 @@ def clean(fc: FeatureCollection) -> FeatureCollection:
     """
     # f -> feature
     # fc -> feature collection
-    fc.features = [f for f in fc.features if f.properties["color"] == "255"]
+    fc.features = [f for f in fc.features if f.properties["color"] != "0"]
     for f in fc.features:
         if not isinstance(f.geometry, geojson.Polygon):
             raise TypeError(
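Under the old pipeline every polygonized marking carried the raster value "255", so clean kept exactly those features; with the new class-encoded marking array, any non-zero color code is a real marking and "0" is background. A small sketch of the new filter (mirroring the changed line), using a hypothetical two-feature collection:

import geojson

fc = geojson.FeatureCollection([
    geojson.Feature(
        geometry=geojson.Polygon([[(0, 0), (1, 0), (1, 1), (0, 0)]]),
        properties={"color": "2"},  # a detected marking (color code 2)
    ),
    geojson.Feature(
        geometry=geojson.Polygon([[(2, 2), (3, 2), (3, 3), (2, 2)]]),
        properties={"color": "0"},  # background polygon from polygonize
    ),
])
fc.features = [f for f in fc.features if f.properties["color"] != "0"]
print(len(fc.features))  # -> 1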
195 changes: 88 additions & 107 deletions sketch_map_tool/upload_processing/detect_markings.py
@@ -1,132 +1,113 @@
 # -*- coding: utf-8 -*-
 """
 Functions to process images of sketch maps and detect markings on them
 """
 
 import cv2
 import numpy as np
 from numpy.typing import NDArray
-from PIL import Image, ImageEnhance
+from PIL import Image
+from segment_anything import SamPredictor
+from ultralytics import YOLO
 
 
 def detect_markings(
-    sketch_map_frame: NDArray,
-    color: str,
-    threshold_bgr: float = 0.5,
+    image: NDArray,
+    yolo_model: YOLO,
+    sam_predictor: SamPredictor,
 ) -> NDArray:
-    """
-    Detect markings in the colours blue, green, red, pink, turquoise, white, and yellow
-
-    Note that there must be a sufficient difference between the colour of the markings
-    and the background. White and yellow markings might not be detected on sketch maps.
-
-    :param sketch_map_frame: TODO
-    :param threshold_bgr: Threshold for the colour detection. 0.5 means 50%, i.e. all
-                          BGR values above 50% * 255 will be considered 255,
-                          all values below this threshold will be considered
-                          0 for determining the colour of the markings.
-    """
-    threshold_bgr_abs = threshold_bgr * 255
-
-    colors = {
-        "white": (255, 255, 255),
-        "red": (0, 0, 255),
-        "blue": (255, 0, 0),
-        "green": (0, 255, 0),
-        "yellow": (0, 255, 255),
-        "turquoise": (255, 255, 0),
-        "pink": (255, 0, 255),
-    }
-    bgr = colors[color]
-
-    # for color, bgr in colors.items():
-    single_color_marking = np.zeros_like(sketch_map_frame, np.uint8)
-    single_color_marking[
-        (
-            (sketch_map_frame[:, :, 0] < threshold_bgr_abs)
-            == (bgr[0] < threshold_bgr_abs)
-        )
-        & (
-            (sketch_map_frame[:, :, 1] < threshold_bgr_abs)
-            == (bgr[1] < threshold_bgr_abs)
-        )
-        & (
-            (sketch_map_frame[:, :, 2] < threshold_bgr_abs)
-            == (bgr[2] < threshold_bgr_abs)
-        )
-    ] = 255
-    single_color_marking = _reduce_noise(single_color_marking)
-    single_color_marking = _reduce_holes(single_color_marking)
-    single_color_marking[single_color_marking > 0] = 255
-    return single_color_marking
-
-
-def prepare_img_for_markings(
-    img_base: NDArray,
-    img_markings: NDArray,
-    threshold_img_diff: int = 100,
-) -> NDArray:
-    """
-    TODO pydoc
-
-    :param threshold_img_diff: Threshold for the marking detection concerning the
-                               absolute grayscale difference between corresponding pixels
-                               in 'img_base' and 'img_markings'.
-    """
-    img_base_height, img_base_width, _ = img_base.shape
-    img_markings = cv2.resize(
-        img_markings,
-        (img_base_width, img_base_height),
-        fx=4,
-        fy=4,
-        interpolation=cv2.INTER_NEAREST,
-    )
-    img_markings_contrast = _enhance_contrast(img_markings)
-    img_diff = cv2.absdiff(img_base, img_markings_contrast)
-    img_diff_gray = cv2.cvtColor(img_diff, cv2.COLOR_BGR2GRAY)
-    mask_markings = img_diff_gray > threshold_img_diff
-    markings_multicolor = np.zeros_like(img_markings, np.uint8)
-    markings_multicolor[mask_markings] = img_markings[mask_markings]
-    return markings_multicolor
-
-
-def _enhance_contrast(img: NDArray, factor: float = 2.0) -> NDArray:
-    """
-    Enhance the contrast of a given image
-
-    :param img: Image of which the contrast should be enhanced.
-    :param factor: Factor for the contrast enhancement.
-    :return: Image with enhanced contrast.
-    """
-    input_img = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
-    result = ImageEnhance.Contrast(input_img).enhance(factor)
-    return cv2.cvtColor(np.array(result), cv2.COLOR_RGB2BGR)
-
-
-def _reduce_noise(img: NDArray, factor: int = 2) -> NDArray:
-    """
-    Reduce the noise, i.e. artifacts, in an image containing markings
-
-    :param img: Image in which the noise should be reduced.
-    :param factor: Kernel size (x*x) for the noise reduction.
-    :return: 'img' with less noise.
-    """
-    # See https://docs.opencv.org/4.x/d9/d61/tutorial_py_morphological_ops.html
-    reduced_noise = cv2.morphologyEx(
-        img, cv2.MORPH_OPEN, np.ones((factor, factor), np.uint8)
-    )
-    # TODO: Long running job in next line -> Does the slightly improved noise
-    #       reduction justify keeping it?
-    return cv2.fastNlMeansDenoisingColored(reduced_noise, None, 30, 30, 20, 21)
-
-
-def _reduce_holes(img: NDArray, factor: int = 4) -> NDArray:
-    """
-    Reduce the holes in markings on a given image
-
-    :param img: Image in which the holes should be reduced.
-    :param factor: Kernel size (x*x) of the reduction.
-    :return: 'img' with fewer and smaller holes.
-    """
-    # See https://docs.opencv.org/4.x/d9/d61/tutorial_py_morphological_ops.html
-    return cv2.morphologyEx(img, cv2.MORPH_CLOSE, np.ones((factor, factor), np.uint8))
+    # Sam can only deal with RGB and not RGBA etc.
+    img = Image.fromarray(image[:, :, ::-1]).convert("RGB")
+    # masks represent markings
+    masks, colors = apply_ml_pipeline(img, yolo_model, sam_predictor)
+    colors = [int(c) + 1 for c in colors]  # +1 because 0 is background
+    return create_marking_array(masks, colors, image)
+
+
+def apply_ml_pipeline(
+    image: Image.Image,
+    yolo_model: YOLO,
+    sam_predictor: SamPredictor,
+) -> tuple[list, list]:
+    """Apply the entire machine learning pipeline on an image.
+
+    Steps:
+        1. Apply YOLO to detect bounding boxes and label (colors) of objects (markings)
+        2. Apply SAM to create binary masks of detected objects
+
+    Returns:
+        tuple: A list of masks and class labels.
+            A mask is a binary numpy array with same dimensions as input image
+            (map frame), masking the dominant segment inside of a bbox detected by YOLO.
+            A class label is a color.
+    """
+    bounding_boxes, class_labels = apply_yolo(image, yolo_model)
+    masks, _ = apply_sam(image, bounding_boxes, sam_predictor)
+    return masks, class_labels
+
+
+def apply_yolo(
+    image: Image.Image,
+    yolo_model: YOLO,
+) -> tuple[list, list]:
+    """Apply YOLO object detection on an image.
+
+    Returns:
+        tuple: Detected bounding boxes around individual markings and corresponding
+            class labels.
+    """
+    result = yolo_model(image)[0].boxes
+    bounding_boxes = result.xyxy
+    class_labels = result.cls
+    return bounding_boxes, class_labels
+
+
+def apply_sam(
+    image: Image.Image,
+    bounding_boxes: list,
+    sam_predictor: SamPredictor,
+) -> tuple:
+    """Apply SAM (Segment Anything) on an image using bounding boxes.
+
+    Creates masks based on image segmentation and bbox from object detection (YOLO)
+
+    Returns:
+        tuple: List of masks and corresponding scores.
+    """
+    sam_predictor.set_image(np.array(image))
+    masks = []
+    scores = []
+    for bbox in bounding_boxes:
+        mask, score = mask_from_bbox(np.array(bbox), sam_predictor)
+        masks.append(mask)
+        scores.append(score)
+    return masks, scores
+
+
+def mask_from_bbox(bbox, sam_predictor: SamPredictor) -> tuple:
+    """Generate a mask using SAM (Segment Anything) predictor for a given bounding box.
+
+    Returns:
+        tuple: Mask and corresponding score.
+    """
+    masks, scores, _ = sam_predictor.predict(box=bbox, multimask_output=False)
+    return masks[0], scores[0]
+
+
+def create_marking_array(
+    masks: list[NDArray],
+    colors: list[int],
+    image: NDArray,
+) -> NDArray:
+    """Create a single color marking array based on masks and colors.
+
+    Parameters:
+        - masks: List of masks representing markings.
+        - colors: List of colors corresponding to each mask.
+        - image: Original sketch map frame.
+
+    Returns:
+        NDArray: Single color marking array.
+    """
+    single_color_marking = np.zeros(
+        (image.shape[0], image.shape[1]),
+        dtype=np.uint8,
+    )
+    for color, mask in zip(colors, masks):
+        single_color_marking[mask] = color
+    return single_color_marking
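The per-color thresholding is gone: one YOLO pass finds the markings and their color classes, SAM cuts a mask per bounding box, and create_marking_array burns each mask into a single-channel array whose pixel values are the 1-based color codes (0 = background) that polygonize, clean, and enrich consume downstream. A toy numpy illustration of that encoding, with hand-made masks standing in for SAM output:

import numpy as np

# Two hand-made 4x4 boolean masks standing in for SAM segmentations.
mask_a = np.zeros((4, 4), dtype=bool)
mask_a[0:2, 0:2] = True  # marking detected as YOLO class 1
mask_b = np.zeros((4, 4), dtype=bool)
mask_b[2:4, 2:4] = True  # marking detected as YOLO class 5

colors = [c + 1 for c in (1, 5)]  # shift: 0 is reserved for background

marking_array = np.zeros((4, 4), dtype=np.uint8)
for color, mask in zip(colors, [mask_a, mask_b]):
    marking_array[mask] = color

print(marking_array)
# [[2 2 0 0]
#  [2 2 0 0]
#  [0 0 6 6]
#  [0 0 6 6]]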
6 changes: 5 additions & 1 deletion sketch_map_tool/upload_processing/enrich.py
@@ -1,8 +1,12 @@
 from geojson import FeatureCollection
 
+from sketch_map_tool.definitions import COLORS
+
 
 def enrich(fc: FeatureCollection, properties):
-    """Enrich GeoJSON properties."""
+    """Enrich GeoJSON properties and map colors."""
     for feature in fc.features:
         feature.properties = feature.properties | properties
+        if "color" in feature.properties.keys():
+            feature.properties["color"] = COLORS[feature.properties["color"]]
     return fc
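enrich is now where the numeric color codes become human-readable names: after merging in the extra properties, a "color" value such as "2" is replaced by COLORS["2"], i.e. "blue". A small sketch with a hypothetical single-feature collection (the file name is made up):

import geojson

from sketch_map_tool.upload_processing import enrich

feature = geojson.Feature(
    geometry=geojson.Polygon([[(0, 0), (1, 0), (1, 1), (0, 0)]]),
    properties={"color": "2"},
)
fc = enrich(geojson.FeatureCollection([feature]), {"name": "sketch-map-1.png"})
print(fc.features[0].properties)
# -> {'color': 'blue', 'name': 'sketch-map-1.png'}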