Skip to content

Commit

Permalink
dev(narugo): add support for rtdetr models
Browse files Browse the repository at this point in the history
  • Loading branch information
narugo1992 committed Oct 8, 2024
1 parent c894ec5 commit 2e77452
Showing 1 changed file with 65 additions and 17 deletions.
82 changes: 65 additions & 17 deletions imgutils/generic/yolo.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
import numpy as np
from PIL import Image
from hbutils.color import rnd_colors
from hfutils.operate import get_hf_client
from hfutils.operate import get_hf_client, get_hf_fs
from hfutils.repository import hf_hub_repo_url
from hfutils.utils import hf_fs_path, hf_normpath
from huggingface_hub import HfFileSystem, hf_hub_download
Expand Down Expand Up @@ -194,7 +194,7 @@ def _image_preprocess(image: Image.Image, max_infer_size: int = 1216, align: int
return image, (old_width, old_height), (new_width, new_height)


def _xy_postprocess(x, y, old_size: Tuple[int, int], new_size: Tuple[int, int]):
def _xy_postprocess(x, y, old_size: Tuple[float, float], new_size: Tuple[float, float]):
"""
Convert coordinates from the preprocessed image size back to the original image size.
Expand All @@ -203,9 +203,9 @@ def _xy_postprocess(x, y, old_size: Tuple[int, int], new_size: Tuple[int, int]):
:param y: Y-coordinate in the preprocessed image.
:type y: float
:param old_size: Original image dimensions (width, height).
:type old_size: Tuple[int, int]
:type old_size: Tuple[float, float]
:param new_size: Preprocessed image dimensions (width, height).
:type new_size: Tuple[int, int]
:type new_size: Tuple[float, float]
:return: Adjusted (x, y) coordinates for the original image size.
:rtype: Tuple[int, int]
Expand All @@ -224,7 +224,7 @@ def _xy_postprocess(x, y, old_size: Tuple[int, int], new_size: Tuple[int, int]):


def _end2end_postprocess(output, conf_threshold: float, iou_threshold: float,
old_size: Tuple[int, int], new_size: Tuple[int, int], labels: List[str]) \
old_size: Tuple[float, float], new_size: Tuple[float, float], labels: List[str]) \
-> List[Tuple[Tuple[int, int, int, int], str, float]]:
assert output.shape[-1] == 6
_ = iou_threshold # actually the iou_threshold has not been supplied to end2end post-processing
Expand All @@ -240,8 +240,9 @@ def _end2end_postprocess(output, conf_threshold: float, iou_threshold: float,


def _nms_postprocess(output, conf_threshold: float, iou_threshold: float,
old_size: Tuple[int, int], new_size: Tuple[int, int], labels: List[str]) \
old_size: Tuple[float, float], new_size: Tuple[float, float], labels: List[str]) \
-> List[Tuple[Tuple[int, int, int, int], str, float]]:
assert output.shape[0] == 4 + len(labels)
# the output should be like [4+cls, box_cnt]
# cls means count of classes
# box_cnt means count of bboxes
Expand Down Expand Up @@ -269,7 +270,7 @@ def _nms_postprocess(output, conf_threshold: float, iou_threshold: float,


def _yolo_postprocess(output, conf_threshold: float, iou_threshold: float,
old_size: Tuple[int, int], new_size: Tuple[int, int], labels: List[str]) \
old_size: Tuple[float, float], new_size: Tuple[float, float], labels: List[str]) \
-> List[Tuple[Tuple[int, int, int, int], str, float]]:
"""
Post-process the raw output from the object detection model.
Expand All @@ -284,9 +285,9 @@ def _yolo_postprocess(output, conf_threshold: float, iou_threshold: float,
:param iou_threshold: IoU threshold for non-maximum suppression.
:type iou_threshold: float
:param old_size: Original image dimensions (width, height).
:type old_size: Tuple[int, int]
:type old_size: Tuple[float, float]
:param new_size: Preprocessed image dimensions (width, height).
:type new_size: Tuple[int, int]
:type new_size: Tuple[float, float]
:param labels: List of class labels.
:type labels: List[str]
Expand Down Expand Up @@ -319,6 +320,22 @@ def _yolo_postprocess(output, conf_threshold: float, iou_threshold: float,
)


def _rtdetr_postprocess(output, conf_threshold: float, iou_threshold: float,
                        old_size: Tuple[float, float], new_size: Tuple[float, float], labels: List[str]) \
        -> List[Tuple[Tuple[int, int, int, int], str, float]]:
    """
    Post-process the raw output from an RT-DETR object detection model.

    The output is transposed and handed over to :func:`_nms_postprocess`, with the
    preprocessed image size replaced by ``(1.0, 1.0)``.

    :param output: Raw model output whose last dimension is ``4 + len(labels)``.
        It is transposed with axes ``(1, 0)``, so a 2-dimensional array is expected.
    :type output: np.ndarray
    :param conf_threshold: Confidence threshold, forwarded to :func:`_nms_postprocess`.
    :type conf_threshold: float
    :param iou_threshold: IoU threshold, forwarded to :func:`_nms_postprocess`.
    :type iou_threshold: float
    :param old_size: Original image dimensions (width, height).
    :type old_size: Tuple[float, float]
    :param new_size: Preprocessed image dimensions (width, height). Accepted for signature
        parity with the other post-processors, but not used.
    :type new_size: Tuple[float, float]
    :param labels: List of class labels.
    :type labels: List[str]

    :return: The result of :func:`_nms_postprocess` on the transposed output.
    :rtype: List[Tuple[Tuple[int, int, int, int], str, float]]
    """
    assert output.shape[-1] == 4 + len(labels)
    # The coordinates produced by rtdetr are within [0.0, 1.0], so the actual
    # preprocessed size is dropped and a unit size is passed on instead.
    _ = new_size
    return _nms_postprocess(
        # _nms_postprocess asserts a [4 + cls, box_cnt] layout, hence the transpose
        output=output.transpose(1, 0),
        conf_threshold=conf_threshold,
        iou_threshold=iou_threshold,
        old_size=old_size,
        new_size=(1.0, 1.0),
        labels=labels,
    )


def _safe_eval_names_str(names_str):
"""
Safely evaluate the names string from model metadata.
Expand Down Expand Up @@ -383,6 +400,7 @@ def __init__(self, repo_id: str, hf_token: Optional[str] = None):
self.repo_id = repo_id
self._model_names = None
self._models = {}
self._model_types = {}
self._hf_token = hf_token

def _get_hf_token(self) -> Optional[str]:
Expand Down Expand Up @@ -454,6 +472,23 @@ def _open_model(self, model_name: str):

return self._models[model_name]

def _get_model_type(self, model_name: str):
    """
    Get the type of the given model, resolving it from the repository on first use.

    The type is read from the ``model_type`` field of ``<model_name>/model_type.json``
    at revision ``main`` of ``self.repo_id``. When that file does not exist, the type
    falls back to ``'yolo'``. Resolved types are kept in ``self._model_types``, so each
    model is only looked up once.

    :param model_name: Name of the model inside the repository.
    :type model_name: str

    :return: The model type.
    """
    # Already resolved earlier, no remote access required.
    if model_name in self._model_types:
        return self._model_types[model_name]

    remote_fs = get_hf_fs(hf_token=self._get_hf_token())
    type_file = hf_fs_path(
        repo_id=self.repo_id,
        repo_type='model',
        filename=f'{model_name}/model_type.json',
        revision='main',
    )
    if not remote_fs.exists(type_file):
        # No type file in the model directory, treat it as a yolo model.
        resolved_type = 'yolo'
    else:
        type_info = json.loads(remote_fs.read_text(type_file))
        resolved_type = type_info['model_type']

    self._model_types[model_name] = resolved_type
    return resolved_type

def predict(self, image: ImageTyping, model_name: str,
conf_threshold: float = 0.25, iou_threshold: float = 0.7) \
-> List[Tuple[Tuple[int, int, int, int], str, float]]:
Expand Down Expand Up @@ -485,14 +520,27 @@ def predict(self, image: ImageTyping, model_name: str,
new_image, old_size, new_size = _image_preprocess(image, max_infer_size)
data = rgb_encode(new_image)[None, ...]
output, = model.run(['output0'], {'images': data})
return _yolo_postprocess(
output=output[0],
conf_threshold=conf_threshold,
iou_threshold=iou_threshold,
old_size=old_size,
new_size=new_size,
labels=labels
)
model_type = self._get_model_type(model_name=model_name)
if model_type == 'yolo':
return _yolo_postprocess(
output=output[0],
conf_threshold=conf_threshold,
iou_threshold=iou_threshold,
old_size=old_size,
new_size=new_size,
labels=labels
)
elif model_type == 'rtdetr':
return _rtdetr_postprocess(

Check warning on line 534 in imgutils/generic/yolo.py

View check run for this annotation

Codecov / codecov/patch

imgutils/generic/yolo.py#L533-L534

Added lines #L533 - L534 were not covered by tests
output=output[0],
conf_threshold=conf_threshold,
iou_threshold=iou_threshold,
old_size=old_size,
new_size=new_size,
labels=labels
)
else:
raise ValueError(f'Unknown object detection model type - {model_type!r}.') # pragma: no cover

def clear(self):
"""
Expand Down

0 comments on commit 2e77452

Please sign in to comment.