Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(upload-processing): post processing of markings #326

Merged
merged 9 commits into from
Dec 18, 2023
126 changes: 125 additions & 1 deletion poetry.lock

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ torchvision = {version = "^0.16.0+cpu", source = "pytorch"}
ultralytics = "^8.0.209"
segment-anything = {git = "https://github.com/facebookresearch/segment-anything.git"}
pyzbar = "^0.1.9"
shapelysmooth = "^0.1.1"

[tool.poetry.group.dev.dependencies]
# Versions are fixed to match versions used by pre-commit
Expand Down
51 changes: 15 additions & 36 deletions sketch_map_tool/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
from uuid import UUID
from zipfile import ZipFile

import geojson
from celery.result import AsyncResult
from celery.signals import worker_process_init, worker_process_shutdown
from geojson import FeatureCollection
Expand All @@ -18,12 +17,11 @@
from sketch_map_tool.oqt_analyses import generate_pdf as generate_report_pdf
from sketch_map_tool.oqt_analyses import get_report
from sketch_map_tool.upload_processing import (
clean,
clip,
enrich,
georeference,
merge,
polygonize,
post_process,
)
from sketch_map_tool.upload_processing.detect_markings import detect_markings
from sketch_map_tool.upload_processing.ml_models import init_model
Expand Down Expand Up @@ -92,13 +90,7 @@ def georeference_sketch_maps(
bboxes: list[Bbox],
) -> AsyncResult | BytesIO:
def process(sketch_map_id: int, uuid: str, bbox: Bbox) -> BytesIO:
"""Process a Sketch Map.

:param sketch_map_id: ID under which uploaded file is stored in the database.
:param uuid: UUID under which the sketch map was created.
:bbox: Bounding box of the AOI on the sketch map.
:return: Georeferenced image (GeoTIFF) of the sketch map .
"""
"""Process a Sketch Map."""
# r = interim result
r = db_client_celery.select_file(sketch_map_id)
r = to_array(r)
Expand Down Expand Up @@ -129,36 +121,23 @@ def digitize_sketches(
# Initialize ml-models. This has to happen inside of celery context.
# Custom trained model for object detection of markings and colors
yolo_path = init_model(get_config_value("neptune_model_id_yolo"))
yolo_model = YOLO(yolo_path)
yolo_model: YOLO = YOLO(yolo_path)
# Zero shot segment anything model
sam_path = init_model(get_config_value("neptune_model_id_sam"))
sam_model = sam_model_registry["vit_b"](sam_path)
sam_predictor = SamPredictor(sam_model) # mask predictor

def process(
sketch_map_id: int,
name: str,
uuid: str,
bbox: Bbox,
sam_predictor: SamPredictor,
yolo_model: YOLO,
) -> FeatureCollection:
"""Process a Sketch Map."""
sam_predictor: SamPredictor = SamPredictor(sam_model) # mask predictor

l = [] # noqa: E741
for file_id, file_name, uuid, bbox in zip(file_ids, file_names, uuids, bboxes):
# r = interim result
r: BytesIO = db_client_celery.select_file(sketch_map_id) # type: ignore
r: BytesIO = db_client_celery.select_file(file_id) # type: ignore
r: NDArray = to_array(r) # type: ignore
r: NDArray = clip(r, map_frames[uuid]) # type: ignore
r: NDArray = detect_markings(r, yolo_model, sam_predictor) # type: ignore
r: BytesIO = georeference(r, bbox, bgr=False) # type: ignore
r: BytesIO = polygonize(r, name) # type: ignore
r: FeatureCollection = geojson.load(r) # type: ignore
r: FeatureCollection = clean(r) # type: ignore
r: FeatureCollection = enrich(r, {"name": name}) # type: ignore
return r

return merge(
[
process(file_id, name, uuid, bbox, sam_predictor, yolo_model)
for file_id, name, uuid, bbox in zip(file_ids, file_names, uuids, bboxes)
]
)
# m = marking
for m in r:
m: BytesIO = georeference(m, bbox, bgr=False) # type: ignore
m: FeatureCollection = polygonize(m, layer_name=file_name) # type: ignore
m: FeatureCollection = post_process(m, file_name)
l.append(m)
return merge(l)
10 changes: 4 additions & 6 deletions sketch_map_tool/upload_processing/__init__.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,17 @@
from .clean import clean
from .clip import clip
from .detect_markings import detect_markings
from .enrich import enrich
from .georeference import georeference
from .merge import merge
from .polygonize import polygonize
from .post_process import post_process
from .qr_code_reader import read as read_qr_code

__all__ = (
"enrich",
"clean",
"clip",
"detect_markings",
"georeference",
"read_qr_code",
"polygonize",
"merge",
"polygonize",
"post_process",
"read_qr_code",
)
78 changes: 53 additions & 25 deletions sketch_map_tool/upload_processing/detect_markings.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-

import cv2
import numpy as np
from numpy.typing import NDArray
from PIL import Image
Expand All @@ -11,20 +11,21 @@ def detect_markings(
image: NDArray,
yolo_model: YOLO,
sam_predictor: SamPredictor,
) -> NDArray:
# Sam can only deal with RGB and not RGBA etc.
) -> list[NDArray]:
# SAM can only deal with RGB and not RGBA etc.
img = Image.fromarray(image[:, :, ::-1]).convert("RGB")
# masks represent markings
masks, colors = apply_ml_pipeline(img, yolo_model, sam_predictor)
masks, bboxes, colors = apply_ml_pipeline(img, yolo_model, sam_predictor)
colors = [int(c) + 1 for c in colors] # +1 because 0 is background
return create_marking_array(masks, colors, image)
processed_markings = post_process(masks, bboxes, colors)
return processed_markings


def apply_ml_pipeline(
image: Image.Image,
yolo_model: YOLO,
sam_predictor: SamPredictor,
) -> tuple[list, list]:
) -> tuple[list, list, list]:
"""Apply the entire machine learning pipeline on an image.

Steps:
Expand All @@ -39,7 +40,7 @@ def apply_ml_pipeline(
"""
bounding_boxes, class_labels = apply_yolo(image, yolo_model)
masks, _ = apply_sam(image, bounding_boxes, sam_predictor)
return masks, class_labels
return masks, bounding_boxes, class_labels


def apply_yolo(
Expand Down Expand Up @@ -81,7 +82,7 @@ def apply_sam(
return masks, scores


def mask_from_bbox(bbox, sam_predictor: SamPredictor) -> tuple:
def mask_from_bbox(bbox: list, sam_predictor: SamPredictor) -> tuple:
"""Generate a mask using SAM (Segment Anything) predictor for a given bounding box.

Returns:
Expand All @@ -92,24 +93,51 @@ def mask_from_bbox(bbox, sam_predictor: SamPredictor) -> tuple:


def create_marking_array(
masks: list[NDArray],
colors: list[int],
image: NDArray,
mask: NDArray,
color: int,
) -> NDArray:
"""Create a single color marking array based on masks and colors.
"""Create a single color marking array based on masks and colors."""
single_color_marking = np.zeros(mask.shape, dtype=np.uint8)
single_color_marking[mask] = color
return single_color_marking

Parameters:
- masks: List of masks representing markings.
- colors: List of colors corresponding to each mask.
- image: Original sketch map frame.

Returns:
NDArray: Single color marking array.
def post_process(
masks: list[NDArray],
bboxes: list[list[int]],
colors,
) -> list[NDArray]:
"""Post-processes masks and bounding boxes to clean-up and fill contours.

Apply morphological operations to clean the masks, creates contours and fills them.
"""
single_color_marking = np.zeros(
(image.shape[0], image.shape[1]),
dtype=np.uint8,
)
for color, mask in zip(colors, masks):
single_color_marking[mask] = color
return single_color_marking
# Convert and preprocess masks
preprocessed_masks = np.array([np.vstack(mask) for mask in masks], dtype=np.float32)
preprocessed_masks[preprocessed_masks == 0] = np.nan

# Calculate height and width for each bounding box
bbox_sizes = [np.array([bbox[2] - bbox[0], bbox[3] - bbox[1]]) for bbox in bboxes]

processed_markings = []
for i, (mask, color) in enumerate(zip(preprocessed_masks, colors)):
# Calculate kernel size as 5% of the bounding box dimensions
kernel_size = tuple((bbox_sizes[i] * 0.05).astype(int))
kernel = np.ones(kernel_size, np.uint8)

# Apply morphological closing operation
mask_closed = cv2.morphologyEx(mask.astype("uint8"), cv2.MORPH_CLOSE, kernel)

# Find contours
mask_contour, _ = cv2.findContours(
mask_closed,
cv2.RETR_EXTERNAL,
cv2.CHAIN_APPROX_SIMPLE,
)

# Create a blank canvas for filled contours
mask_filled = np.zeros_like(mask_closed, dtype=np.uint8)
cv2.drawContours(mask_filled, mask_contour, -1, 1, thickness=cv2.FILLED)

# Mask to markings array
processed_markings.append(create_marking_array(mask_filled.astype(bool), color))
return processed_markings
15 changes: 8 additions & 7 deletions sketch_map_tool/upload_processing/polygonize.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,20 +4,22 @@
from tempfile import NamedTemporaryFile, TemporaryDirectory

import geojson
from geojson import FeatureCollection
from osgeo import gdal, ogr
from pyproj import Transformer


def transform(feature: geojson.FeatureCollection):
def transform(feature: FeatureCollection) -> FeatureCollection:
"""Reproject GeoJSON from WebMercator to EPSG:4326"""
transformer = Transformer.from_crs("EPSG:3857", "EPSG:4326", always_xy=True)
return geojson.utils.map_tuples(
raw = geojson.utils.map_tuples( # type: ignore
lambda coordinates: transformer.transform(coordinates[0], coordinates[1]),
deepcopy(feature),
)
return geojson.loads(geojson.dumps(raw))


def polygonize(geotiff: BytesIO, layer_name: str) -> BytesIO:
def polygonize(geotiff: BytesIO, layer_name: str) -> FeatureCollection:
"""Produces a polygon feature layer (GeoJSON) from a raster (GeoTIFF)."""
gdal.UseExceptions()
ogr.UseExceptions()
Expand Down Expand Up @@ -47,7 +49,6 @@ def polygonize(geotiff: BytesIO, layer_name: str) -> BytesIO:
src_ds = None # close dataset
dst_ds = None # close dataset

with open(str(outfile_name), "rb") as f:
feature = geojson.FeatureCollection(geojson.load(f))
feature = transform(feature)
return BytesIO(geojson.dumps(feature).encode())
with open(outfile_name, "rb") as f:
fc = geojson.load(f)
return transform(fc)
Loading