-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
5 changed files
with
550 additions
and
129 deletions.
There are no files selected for viewing
74 changes: 74 additions & 0 deletions
74
workspace_python/ros2_ws/src/python_workspace/python_workspace/scripts/utils_cupy.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,74 @@ | ||
import cv2 | ||
import os | ||
import numpy as np | ||
import cupy as cp | ||
from typing import Tuple, List | ||
from ultralytics import YOLO | ||
|
||
class ModelInferenceCupy:
    """CuPy-assisted pre/post-processing around a YOLO detector.

    OpenCV routines used here (resize, colour conversion, contours) only run
    on host memory, so frames are kept on the host and only the element-wise
    colour mask is evaluated with CuPy.
    """

    ROI_ROWS = 300
    ROI_COLS = 300
    CONTOUR_AREA_THRESHOLD = 500
    POSTPROCESS_OUTPUT_SHAPE = (640, 640)

    def __init__(self, weights_path=None, precision=None):
        """Load YOLO weights when a path is given; otherwise leave the model unset."""
        self.yolo = YOLO(weights_path) if weights_path else None
        self.precision = precision

    def preprocess(self, image: np.ndarray):
        """Resize ``image`` to ``POSTPROCESS_OUTPUT_SHAPE`` and return it as a CuPy array.

        ``cv2.resize`` needs a host array, so the frame is resized first and
        uploaded once, instead of being uploaded, downloaded and uploaded again.
        """
        # cp.asnumpy also accepts host arrays, so CuPy inputs keep working.
        host_image = cp.asnumpy(image)
        resized = cv2.resize(host_image, self.POSTPROCESS_OUTPUT_SHAPE)
        return cp.asarray(resized)

    def _convert_bboxes_to_pixel(self, bbox_array: np.ndarray, image_shape: Tuple[int, int]):
        """Scale normalised ``(x1, y1, x2, y2)`` rows to integer pixel coordinates.

        The top-left corner is floored and the bottom-right corner is ceiled so
        the pixel box never shrinks. Returns a host ``int32`` array of shape
        ``(N, 4)``; an empty input yields an empty ``(0, 4)`` array.
        """
        height, width = image_shape[:2]
        bbox_gpu = cp.asarray(bbox_array)

        # An empty detection set may arrive as a 1-D array, which the column
        # indexing below cannot handle.
        if bbox_gpu.size == 0:
            return np.empty((0, 4), dtype=np.int32)

        # Vectorized conversion with CuPy
        x1 = cp.floor(bbox_gpu[:, 0] * width).astype(cp.int32)
        y1 = cp.floor(bbox_gpu[:, 1] * height).astype(cp.int32)
        x2 = cp.ceil(bbox_gpu[:, 2] * width).astype(cp.int32)
        y2 = cp.ceil(bbox_gpu[:, 3] * height).astype(cp.int32)

        return cp.stack([x1, y1, x2, y2], axis=1).get()

    def object_filter(self, image: np.ndarray, bboxes: List[Tuple[int, int, int, int]]):
        """Refine boxes to the contours of the hue/saturation-masked region inside each one.

        For every box, pixels with hue strictly between 35 and 80 and
        saturation above 50 (OpenCV HSV scale) are selected. When more than
        ``CONTOUR_AREA_THRESHOLD`` pixels match, each contour whose area also
        exceeds the threshold contributes one ``(x1, y1, x2, y2)`` box in
        full-image coordinates.
        """
        host_image = cp.asnumpy(image)
        height, width = host_image.shape[:2]
        detections = []

        for bbox in bboxes:
            x1, y1, x2, y2 = (int(value) for value in bbox)

            # Clamp to the frame: negative indices would wrap around when
            # slicing, and an empty ROI makes cv2.cvtColor raise.
            x1, y1 = max(x1, 0), max(y1, 0)
            x2, y2 = min(x2, width), min(y2, height)
            if x2 <= x1 or y2 <= y1:
                continue

            roi = host_image[y1:y2, x1:x2]

            # cv2 works on host memory; only the element-wise mask runs on the GPU.
            hsv = cp.asarray(cv2.cvtColor(roi, cv2.COLOR_BGR2HSV))

            # Color segmentation
            hue = hsv[:, :, 0]
            saturation = hsv[:, :, 1]
            mask = (hue > 35) & (hue < 80) & (saturation > 50)

            if int(cp.sum(mask)) <= self.CONTOUR_AREA_THRESHOLD:
                continue

            # Keep only masked pixels, then binarise for contour extraction.
            gray_image = cv2.cvtColor(roi, cv2.COLOR_BGR2GRAY)
            gray_image = gray_image * cp.asnumpy(mask).astype(np.uint8)
            _, thresh = cv2.threshold(gray_image, 0, 255, cv2.THRESH_BINARY)

            contours, _ = cv2.findContours(thresh, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
            for cnt in contours:
                if cv2.contourArea(cnt) > self.CONTOUR_AREA_THRESHOLD:
                    x, y, w, h = cv2.boundingRect(cnt)
                    # Translate ROI-local coordinates back to the full frame.
                    detections.append((x + x1, y + y1, x + w + x1, y + h + y1))

        return detections

    def verify_object(self, image: np.ndarray, detections: List[Tuple[int, int, int, int]],
                      velocity=0):
        """Keep boxes that touch the velocity-shifted central ROI, clipped to it.

        The ROI is ``ROI_COLS`` x ``ROI_ROWS`` pixels, centred in the frame
        and shifted left by ``int(velocity)`` pixels. Boxes lying entirely on
        one side of the ROI are dropped; the rest are clipped to its bounds.

        NOTE(review): ``postprocess`` called this method but the class never
        defined it. The logic mirrors ``ModelInferenceJax.verify_object``;
        confirm it matches the intended reference behaviour.
        """
        height, width = image.shape[:2]
        shift = int(velocity)

        roi_x1 = width // 2 - self.ROI_COLS // 2 - shift
        roi_x2 = roi_x1 + self.ROI_COLS
        roi_y1 = height // 2 - self.ROI_ROWS // 2
        roi_y2 = roi_y1 + self.ROI_ROWS

        verified = []
        for x1, y1, x2, y2 in detections:
            if max(x1, x2) < roi_x1 or min(x1, x2) > roi_x2:
                continue
            if max(y1, y2) < roi_y1 or min(y1, y2) > roi_y2:
                continue
            verified.append((
                int(min(max(x1, roi_x1), roi_x2)),
                int(min(max(y1, roi_y1), roi_y2)),
                int(min(max(x2, roi_x1), roi_x2)),
                int(min(max(y2, roi_y1), roi_y2)),
            ))
        return verified

    def postprocess(self, confidence, bbox_array, raw_image: np.ndarray, velocity=0):
        """Convert normalised boxes to pixels, refine them by colour, then verify against the ROI.

        ``confidence`` is accepted for interface compatibility and is not used.
        """
        detections = self._convert_bboxes_to_pixel(bbox_array, raw_image.shape)
        detections = self.object_filter(raw_image, detections)
        return self.verify_object(raw_image, detections, velocity)
206 changes: 206 additions & 0 deletions
206
workspace_python/ros2_ws/src/python_workspace/python_workspace/scripts/utils_jax.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,206 @@ | ||
import cv2 | ||
import os | ||
import numpy as np | ||
from typing import Tuple, List | ||
import jax | ||
import jax.numpy as jnp | ||
from functools import partial | ||
|
||
class ModelInferenceJax:
    """JAX-assisted pre/post-processing for colour-based detection filtering.

    Only the fixed-shape, element-wise image math is jitted. Anything whose
    output size depends on the data (filtering boxes, building Python lists)
    runs on the host, because ``jax.jit`` cannot trace data-dependent shapes
    or Python control flow on traced values.
    """

    # Constants
    ROI_ROWS = 300
    ROI_COLS = 300
    CONTOUR_AREA_THRESHOLD = 500
    POSTPROCESS_OUTPUT_SHAPE = (640, 640)

    def __init__(self, weights_path=None, precision=None):
        """Validate the weights path and record the default JAX device.

        When either argument is ``None`` the instance is left without a model
        or device. All attributes are always set so later reads never raise
        ``AttributeError``.

        Raises:
            FileNotFoundError: ``weights_path`` does not exist.
        """
        self.model = None
        self.precision = precision
        self.device = None

        if weights_path is None or precision is None:
            return

        if not os.path.exists(weights_path):
            print(f"Weights file not found at {weights_path}")
            raise FileNotFoundError(f"Weights file not found at {weights_path}")

        # Note: Model loading would be handled separately
        # JAX doesn't have direct YOLO support, so you might need a custom solution
        self.device = jax.devices()[0]  # Get default device

    @partial(jax.jit, static_argnums=(0,))
    def preprocess(self, image: np.ndarray):
        """Return ``image`` as float32 scaled by 1/255 with the channel axis reversed (BGR -> RGB)."""
        image = jnp.asarray(image, dtype=jnp.float32) / 255.0

        # Reversing the last axis swaps BGR to RGB.
        return image[..., ::-1]

    @partial(jax.jit, static_argnums=(0,))
    def color_threshold(self, hsv_image: jnp.ndarray) -> jnp.ndarray:
        """Return a boolean mask of pixels with 35 < hue < 80 and saturation > 50.

        Expects channel 0 to be hue and channel 1 saturation, on the scale
        produced by ``_bgr_to_hsv``.
        """
        lower_hue = 35
        upper_hue = 80

        hue = hsv_image[..., 0]
        saturation = hsv_image[..., 1]

        in_hue_band = jnp.logical_and(hue > lower_hue, hue < upper_hue)
        return jnp.logical_and(in_hue_band, saturation > 50)

    def object_filter(self, image: jnp.ndarray, bboxes: jnp.ndarray) -> List[Tuple[int, int, int, int]]:
        """Keep boxes containing more than ``CONTOUR_AREA_THRESHOLD`` colour-matching pixels.

        The colour mask is computed once for the whole frame, so the jitted
        kernels see a single fixed shape. Each box is then counted on the
        host. This method is deliberately not jitted or vmapped: it slices
        with data-dependent sizes and returns a variable-length Python list.

        Args:
            image: BGR frame; integer dtypes are treated as 0-255.
            bboxes: ``(N, 4)`` pixel boxes ``(x1, y1, x2, y2)``; may be empty.
        Returns:
            The surviving boxes as tuples of ints, in input order.
        """
        boxes = np.asarray(bboxes)
        if boxes.size == 0:
            return []
        # astype truncates toward zero, matching int() on each coordinate.
        boxes = boxes.reshape(-1, 4).astype(np.int64)

        height, width = image.shape[:2]
        mask = np.asarray(self.color_threshold(self._bgr_to_hsv(image)))

        detections = []
        for x1, y1, x2, y2 in boxes.tolist():
            # Clamp for counting only, so negative coordinates cannot wrap
            # around; the reported box keeps its original coordinates.
            left, right = max(x1, 0), min(x2, width)
            top, bottom = max(y1, 0), min(y2, height)
            if right <= left or bottom <= top:
                continue
            if int(mask[top:bottom, left:right].sum()) > self.CONTOUR_AREA_THRESHOLD:
                detections.append((x1, y1, x2, y2))
        return detections

    @partial(jax.jit, static_argnums=(0,))
    def _bgr_to_hsv(self, bgr: jnp.ndarray) -> jnp.ndarray:
        """Convert a BGR image to HSV on OpenCV's 8-bit scale.

        Args:
            bgr: Image with channels ordered B, G, R. Float input is expected
                in [0, 1]; integer input is treated as 0-255 and normalised,
                which also avoids unsigned wrap-around in the subtractions.
        Returns:
            Float array with H in [0, 180], S and V in [0, 255].
        """
        bgr = jnp.asarray(bgr)
        # dtype is static under jit, so this branch is resolved at trace time.
        if jnp.issubdtype(bgr.dtype, jnp.integer):
            bgr = bgr.astype(jnp.float32) / 255.0

        # Channel 0 is blue and channel 2 is red in BGR order.
        b, g, r = bgr[..., 0], bgr[..., 1], bgr[..., 2]

        # Value and chroma
        v = jnp.maximum(jnp.maximum(r, g), b)
        diff = v - jnp.minimum(jnp.minimum(r, g), b)

        # Substitute safe denominators so unselected branches stay finite.
        safe_v = jnp.where(v == 0, 1.0, v)
        s = jnp.where(v == 0, 0.0, diff / safe_v)

        chromatic = diff > 0
        safe_diff = jnp.where(chromatic, diff, 1.0)

        # Nested selection gives ties a fixed priority (R, then G, then B)
        # instead of letting later channels overwrite earlier ones.
        h = jnp.where(
            v == r,
            60.0 * (g - b) / safe_diff,
            jnp.where(
                v == g,
                120.0 + 60.0 * (b - r) / safe_diff,
                240.0 + 60.0 * (r - g) / safe_diff,
            ),
        )
        # Grey pixels have no defined hue; report 0.
        h = jnp.where(chromatic, h, 0.0)
        h = jnp.where(h < 0, h + 360.0, h)

        # Scale values to match OpenCV ranges
        return jnp.stack([h / 2, s * 255, v * 255], axis=-1)

    def postprocess(self, confidences: jnp.ndarray, bbox_array: jnp.ndarray,
                    raw_image: np.ndarray, velocity: float = 0) -> List[Tuple[int, int, int, int]]:
        """Filter boxes by colour content, then verify them against the shifted ROI.

        ``confidences`` is accepted for interface compatibility and is not used.
        """
        detections = self.object_filter(raw_image, bbox_array)
        detections = self.verify_object(raw_image, detections, velocity)
        return detections

    def verify_object(self, image: jnp.ndarray, detections: List[Tuple[int, int, int, int]],
                      velocity: float) -> List[Tuple[int, int, int, int]]:
        """Keep boxes that touch the velocity-shifted central ROI, clipped to it.

        The ROI is ``ROI_COLS`` x ``ROI_ROWS`` pixels, centred in the frame
        and shifted left by ``int(velocity)`` pixels. Boxes entirely on one
        side of the ROI are dropped.

        This runs on the host with NumPy rather than under ``jax.jit``. The
        number of boxes changes from frame to frame, which would force a
        retrace per count, and the result is a variable-length Python list.

        Args:
            image: Input image; only its height and width are used.
            detections: Boxes as ``[(x1, y1, x2, y2), ...]``; may be empty.
            velocity: Horizontal ROI shift; truncated to an int.
        Returns:
            Verified boxes clipped to the ROI, as tuples of ints.
        """
        boxes = np.asarray(detections, dtype=np.float64)
        if boxes.size == 0:
            return []
        boxes = boxes.reshape(-1, 4)

        height, width = image.shape[:2]
        shift = int(velocity)

        # Velocity-shifted ROI bounds
        roi_x1 = width // 2 - self.ROI_COLS // 2 - shift
        roi_x2 = roi_x1 + self.ROI_COLS
        roi_y1 = height // 2 - self.ROI_ROWS // 2
        roi_y2 = roi_y1 + self.ROI_ROWS

        x1, y1, x2, y2 = boxes.T

        outside_x = ((x1 < roi_x1) & (x2 < roi_x1)) | ((x1 > roi_x2) & (x2 > roi_x2))
        outside_y = ((y1 < roi_y1) & (y2 < roi_y1)) | ((y1 > roi_y2) & (y2 > roi_y2))
        # An explicit mask replaces the all-zeros sentinel, which could also
        # discard a legitimate box clipped to the origin.
        keep = ~(outside_x | outside_y)

        clipped = np.stack(
            [
                np.clip(x1, roi_x1, roi_x2),
                np.clip(y1, roi_y1, roi_y2),
                np.clip(x2, roi_x1, roi_x2),
                np.clip(y2, roi_y1, roi_y2),
            ],
            axis=1,
        )

        return [(int(bx1), int(by1), int(bx2), int(by2))
                for bx1, by1, bx2, by2 in clipped[keep].tolist()]
Oops, something went wrong.