Skip to content

Commit

Permalink
Merge pull request #133 from IDEA-Research/feature/save_annotations
Browse files Browse the repository at this point in the history
feature(save annotations): support save segmentation, keypoints and mask
  • Loading branch information
imhuwq authored Feb 28, 2024
2 parents 131cd5e + 144b780 commit 088f64f
Show file tree
Hide file tree
Showing 6 changed files with 76 additions and 202 deletions.
1 change: 1 addition & 0 deletions deepdataspace/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ class AnnotationType:
Segmentation = "Segmentation" #: The annotation segments the object.
Matting = "Matting" #: The annotation matting the object.
KeyPoints = "KeyPoints" #: The annotation marks the keypoints of the object.
Mask = "Mask" #: The annotation contains RLE format of mask


class TaskStatus:
Expand Down
38 changes: 27 additions & 11 deletions deepdataspace/model/label_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -489,7 +489,7 @@ def _export_dataset(self, dataset: DataSet, label_set_name: str):
LTImage = LabelTaskImage(dataset.id)
images, offset = self._get_image_batch(dataset.id, 0)

has_bbox = False # whether the annotations have bbox
obj_types = set()

# iter every label image, save every annotation to target image
for ltimage in LTImage.find_many({}, sort=[("image_id", 1)]):
Expand Down Expand Up @@ -517,26 +517,42 @@ def _export_dataset(self, dataset: DataSet, label_set_name: str):
if not category:
continue

obj_types.add(AnnotationType.Classification)

bounding_box = anno["bounding_box"]
if not bounding_box:
continue
segmentation = anno["segmentation"]
mask = anno["mask"]
points = anno["points"]
lines = anno["lines"]
point_names = anno["point_names"]
point_colors = anno["point_colors"]

has_bbox = True
cat_obj = self._get_category(dataset.id, category, categories)
if bool(bounding_box):
obj_types.add(AnnotationType.Detection)

if bool(segmentation):
obj_types.add(AnnotationType.Segmentation)

if bool(mask):
obj_types.add(AnnotationType.Mask)

if bool(points) and bool(lines) and bool(point_names) and bool(point_colors):
obj_types.add(AnnotationType.KeyPoints)

cat_obj = self._get_category(dataset.id, category, categories)
anno_obj = Object(label_name=label_obj.name, label_type=label_obj.type, label_id=label_obj.id,
category_name=cat_obj.name, category_id=cat_obj.id,
bounding_box=anno["bounding_box"])
bounding_box=bounding_box, segmentation=segmentation, mask=mask,
points=points, lines=lines, point_names=point_names, point_colors=point_colors)
image.objects.append(anno_obj)
image.batch_save()

Image(dataset.id).finish_batch_save()

if has_bbox:
if AnnotationType.Classification not in dataset.object_types:
dataset.object_types.append(AnnotationType.Classification)
if AnnotationType.Detection not in dataset.object_types:
dataset.object_types.append(AnnotationType.Detection)
cur_obj_types = set(dataset.object_types)
if cur_obj_types != obj_types:
obj_types.union(cur_obj_types)
dataset.object_types = list(sorted(obj_types))
dataset.save()

def export_project(self, label_set_name: str):
Expand Down
1 change: 1 addition & 0 deletions deepdataspace/model/object.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,3 +83,4 @@ def get_collection(cls, *args, **kwargs):
caption: Optional[str] = ""
compare_result: Optional[Dict[str, str]] = {} # {"90": "FP", ..., "10": "OK"}
matched_det_idx: Optional[int] = None # The matched ground truth index, for prediction objects only.
mask: dict = {}
2 changes: 0 additions & 2 deletions deepdataspace/server/resources/api_v1/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
from deepdataspace.server.resources.api_v1 import lints
from deepdataspace.server.resources.api_v1 import login
from deepdataspace.server.resources.api_v1 import ping
from deepdataspace.server.resources.api_v1.annotations import AnnotationsView
from deepdataspace.server.resources.api_v1.comparisons import ComparisonsView
from deepdataspace.server.resources.api_v1.datasets import DatasetView
from deepdataspace.server.resources.api_v1.datasets import DatasetsView
Expand All @@ -28,7 +27,6 @@
path("datasets/<dataset_id>", DatasetView.as_view()),
path("image_flags", ImageFlagsView.as_view()),
path("label_clone", LabelCloneView.as_view()),
path("annotations", AnnotationsView.as_view()),
path("comparisons", ComparisonsView.as_view()),
path("label_projects", label_tasks.ProjectsView.as_view()),
path("label_projects/<project_id>", label_tasks.ProjectView.as_view()),
Expand Down
163 changes: 0 additions & 163 deletions deepdataspace/server/resources/api_v1/annotations.py

This file was deleted.

73 changes: 47 additions & 26 deletions deepdataspace/server/resources/api_v1/label_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@
"""

import copy
from typing import List

from pydantic import BaseModel

from deepdataspace.constants import ErrCode
from deepdataspace.constants import LabelImageQAActions
Expand Down Expand Up @@ -918,6 +921,43 @@ def get(self, request, task_id):
return format_response(data)


class BoundingBox(BaseModel):
xmin: float
ymin: float
xmax: float
ymax: float


class Mask(BaseModel):
counts: str
size: List[int]


AnnoDataMissingFields = type("AnnoDataMissingFields", (ValueError,), {})


class AnnoData(BaseModel):
category_name: str
bounding_box: BoundingBox = {}
segmentation: str = ""
mask: Mask = ""
points: List[float] = []
lines: List[int] = []
point_colors: List[int] = []
point_names: List[str] = []

def model_post_init(self, __context) -> None:
empty_bbox = not self.bounding_box
empty_polygon = not self.segmentation
empty_mask = not self.mask
empty_keypoints = (not self.lines or
not self.points or
not self.point_colors or
not self.point_names)
if empty_bbox or empty_polygon or empty_mask or empty_keypoints:
raise AnnoDataMissingFields(f"annotations missing field(s)")


class TaskImageLabelView(AuthenticatedAPIView):
"""
- POST /api/v1/label_task_image_labels/<task_image_id>
Expand All @@ -933,36 +973,17 @@ def _parse_annotations(self, request):
valid_anno_list = []
for idx, anno in enumerate(annotations):
try:
assert "category_name" in anno
assert "bounding_box" in anno
assert "xmin" in anno["bounding_box"]
assert "ymin" in anno["bounding_box"]
assert "xmax" in anno["bounding_box"]
assert "ymax" in anno["bounding_box"]
except AssertionError:
raise_exception(ErrCode.LabelAnnotationMissingFields, f"annotations[{idx}] missing field(s)")

try:
bbox = anno["bounding_box"]
cat_name = str(anno["category_name"])
xmin, ymin = float(bbox["xmin"]), float(bbox["ymin"])
xmax, ymax = float(bbox["xmax"]), float(bbox["ymax"])
valid_anno = AnnoData(**anno)
except AnnoDataMissingFields as err:
raise_exception(ErrCode.LabelAnnotationMissingFields,
f"annotations.[{idx}] {ErrCode.LabelAnnotationMissingFieldsMsg}")
except Exception:
raise_exception(ErrCode.LabelAnnotationFieldValueInvalid,
f"annotations[{idx}] field data type is wrong")
f"annotations.[{idx}] {ErrCode.LabelAnnotationFieldValueInvalid}")
else:
valid_anno = {
"category_name": cat_name,
"category_id" : cat_name,
"bounding_box" : {
"xmin": xmin,
"ymin": ymin,
"xmax": xmax,
"ymax": ymax,
}
}
valid_anno = valid_anno.dict()
valid_anno["category_id"] = valid_anno["category_name"]
valid_anno_list.append(valid_anno)

return valid_anno_list

def post(self, request, task_image_id):
Expand Down

0 comments on commit 088f64f

Please sign in to comment.