Skip to content

Commit

Permalink
Merge pull request #80 from FocoosAI/hotfix/bbox-coordinates
Browse files Browse the repository at this point in the history
hotfix(detection): fix wrong bbox coordinates
  • Loading branch information
CuriousDolphin authored Feb 25, 2025
2 parents dec64e0 + 01d7d99 commit 53cf873
Show file tree
Hide file tree
Showing 4 changed files with 11 additions and 20 deletions.
10 changes: 3 additions & 7 deletions focoos/local_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@
from focoos.utils.vision import (
get_postprocess_fn,
image_preprocess,
scale_detections,
sv_to_fai_detections,
)

Expand Down Expand Up @@ -190,21 +189,18 @@ def infer(
resize = None #!TODO check for segmentation
if self.metadata.task == FocoosTask.DETECTION:
resize = 640 if not self.metadata.im_size else self.metadata.im_size
logger.debug(f"Resize: {resize}")

t0 = perf_counter()
im1, im0 = image_preprocess(image, resize=resize)
logger.debug(f"Input image size: {im0.shape}, Resize to: {resize}")
t1 = perf_counter()
detections = self.runtime(im1.astype(np.float32))

t2 = perf_counter()

detections = self.postprocess_fn(
out=detections, im0_shape=(im0.shape[1], im0.shape[0]), conf_threshold=threshold
out=detections, im0_shape=(im0.shape[0], im0.shape[1]), conf_threshold=threshold
)

if resize:
detections = scale_detections(detections, (resize, resize), (im0.shape[1], im0.shape[0]))

out = sv_to_fai_detections(detections, classes=self.metadata.classes)
t3 = perf_counter()
latency = {
Expand Down
4 changes: 2 additions & 2 deletions focoos/utils/vision.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,8 +296,8 @@ def det_postprocess(out: List[np.ndarray], im0_shape: Tuple[int, int], conf_thre
sv.Detections: A sv.Detections object containing the filtered bounding boxes, class ids, and confidences.
"""
cls_ids, boxes, confs = out
boxes[:, 0::2] *= im0_shape[1]
boxes[:, 1::2] *= im0_shape[0]
boxes[:, :, 0::2] *= im0_shape[1]
boxes[:, :, 1::2] *= im0_shape[0]
high_conf_indices = (confs > conf_threshold).nonzero()

return sv.Detections(
Expand Down
5 changes: 0 additions & 5 deletions tests/test_local_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,10 +189,6 @@ def mock_infer_setup(
mock_image_preprocess = mocker.patch("focoos.local_model.image_preprocess")
mock_image_preprocess.return_value = (image_ndarray, image_ndarray)

# Mock scale_detections
mock_scale_detections = mocker.patch("focoos.local_model.scale_detections")
mock_scale_detections.return_value = mock_sv_detections

# Mock sv_to_focoos_detections
mock_sv_to_focoos_detections = mocker.patch("focoos.local_model.sv_to_fai_detections")
mock_sv_to_focoos_detections.return_value = mock_focoos_detections.detections
Expand All @@ -219,7 +215,6 @@ def __call__(self, *args, **kwargs):
return (
mock_image_preprocess,
mock_runtime_call,
mock_scale_detections,
mock_sv_to_focoos_detections,
mock_annotate,
)
Expand Down
12 changes: 6 additions & 6 deletions tests/utils/test_vision.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,9 +208,9 @@ def test_masks_to_xyxy():


def test_det_post_process():
cls_ids = np.array([1, 2, 3])
boxes = np.array([[0.1, 0.2, 0.3, 0.4], [0.5, 0.6, 0.7, 0.8], [0.9, 1.0, 1.1, 1.2]])
confs = np.array([0.8, 0.9, 0.7])
cls_ids = np.array([[1, 2, 3]])
boxes = np.array([[[0.1, 0.2, 0.3, 0.4], [0.5, 0.6, 0.7, 0.8], [0.9, 1.0, 1.1, 1.2]]])
confs = np.array([[0.8, 0.9, 0.7]])
out = [cls_ids, boxes, confs]

im0_shape = (640, 480)
Expand Down Expand Up @@ -316,9 +316,9 @@ def test_instance_postprocess():
def test_confidence_threshold_filtering():
"""Test that confidence threshold filtering works correctly"""
out = [
np.array([0, 1, 2]), # cls_ids
np.array([[0.1, 0.1, 0.3, 0.3], [0.4, 0.4, 0.6, 0.6], [0.7, 0.7, 0.9, 0.9]]), # boxes
np.array([0.95, 0.55, 0.85]), # confs
np.array([[0, 1, 2]]), # cls_ids
np.array([[[0.1, 0.1, 0.3, 0.3], [0.4, 0.4, 0.6, 0.6], [0.7, 0.7, 0.9, 0.9]]]), # boxes
np.array([[0.95, 0.55, 0.85]]), # confs
]

result = det_postprocess(out, (100, 100), conf_threshold=0.8)
Expand Down

0 comments on commit 53cf873

Please sign in to comment.