Skip to content

Commit

Permalink
CuPy and Pytorch utils imp
Browse files Browse the repository at this point in the history
  • Loading branch information
jasukej committed Dec 7, 2024
1 parent be0f10e commit ca1850a
Show file tree
Hide file tree
Showing 5 changed files with 550 additions and 129 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
import cv2
import os
import numpy as np
import cupy as cp
from typing import Tuple, List
from ultralytics import YOLO

class ModelInferenceCupy:
    """YOLO pre/post-processing helpers that use CuPy for the array math.

    OpenCV routines (resize, colour conversion, contour extraction) only
    accept host arrays, so data is moved to the GPU solely for the
    element-wise steps and moved back once per step.
    """

    # Size (in pixels) of the centred region of interest used by
    # ``verify_object``.
    ROI_ROWS = 300
    ROI_COLS = 300
    # Minimum number of in-range pixels / minimum contour area for a detection.
    CONTOUR_AREA_THRESHOLD = 500
    # Target size handed to ``cv2.resize`` in ``preprocess``.
    POSTPROCESS_OUTPUT_SHAPE = (640, 640)

    def __init__(self, weights_path=None, precision=None):
        """Load the YOLO weights when a path is given.

        Args:
            weights_path: Path passed to ``YOLO``; when falsy no model is
                loaded and ``self.yolo`` is ``None``.
            precision: Stored on the instance; not used by the code visible
                in this class.
        """
        self.yolo = YOLO(weights_path) if weights_path else None
        self.precision = precision

    def preprocess(self, image: np.ndarray):
        """Resize ``image`` to ``POSTPROCESS_OUTPUT_SHAPE`` and upload it.

        Args:
            image: Input frame (NumPy or CuPy array).

        Returns:
            The resized frame as a CuPy array.
        """
        # cv2.resize runs on the host, so resize first and upload once instead
        # of uploading, downloading and uploading again.
        resized = cv2.resize(cp.asnumpy(image), self.POSTPROCESS_OUTPUT_SHAPE)
        return cp.asarray(resized)

    def _convert_bboxes_to_pixel(self, bbox_array: np.ndarray, image_shape: Tuple[int, ...]):
        """Scale bounding boxes by the image size and round outwards.

        Columns 0 and 2 are multiplied by the image width, columns 1 and 3 by
        the image height; the first corner is floored and the second is
        ceiled so the pixel box never shrinks.

        Args:
            bbox_array: Array of boxes, one row per box.
                NOTE(review): presumably normalised (x1, y1, x2, y2)
                fractions - confirm against the caller.
            image_shape: Shape of the image; only the first two entries
                (height, width) are used.

        Returns:
            NumPy ``int32`` array with one (x1, y1, x2, y2) row per box;
            shape ``(0, 4)`` when there are no boxes.
        """
        height, width = image_shape[:2]
        bbox_gpu = cp.asarray(bbox_array)

        # An empty input has no second axis to index, so short-circuit it.
        if bbox_gpu.size == 0:
            return np.empty((0, 4), dtype=np.int32)
        # Accept a single box given as a flat sequence of four values.
        bbox_gpu = cp.atleast_2d(bbox_gpu)

        # Vectorized conversion w CuPy
        x1 = cp.floor(bbox_gpu[:, 0] * width).astype(cp.int32)
        y1 = cp.floor(bbox_gpu[:, 1] * height).astype(cp.int32)
        x2 = cp.ceil(bbox_gpu[:, 2] * width).astype(cp.int32)
        y2 = cp.ceil(bbox_gpu[:, 3] * height).astype(cp.int32)

        return cp.stack([x1, y1, x2, y2], axis=1).get()

    def object_filter(self, image: np.ndarray, bboxes: List[Tuple[int, int, int, int]]):
        """Keep colour-matching contours found inside each bounding box.

        For every box the region is converted to HSV, pixels with
        35 < H < 80 and S > 50 are selected, and each contour of the selected
        area that is larger than ``CONTOUR_AREA_THRESHOLD`` is reported.

        Args:
            image: Frame in BGR channel order (it is passed to
                ``cv2.COLOR_BGR2HSV``).
            bboxes: Pixel boxes as (x1, y1, x2, y2).

        Returns:
            List of (x1, y1, x2, y2) tuples in full-image coordinates, one
            per accepted contour.
        """
        # OpenCV needs host memory, so slice ROIs on the host instead of
        # uploading the whole frame and downloading every ROI again.
        host_image = cp.asnumpy(image)
        height, width = host_image.shape[:2]
        detections = []

        for bbox in bboxes:
            x1, y1, x2, y2 = (int(coord) for coord in bbox)

            # Clamp to the frame: negative indices would wrap around when
            # slicing, and an empty ROI makes cv2.cvtColor raise.
            x1, y1 = max(x1, 0), max(y1, 0)
            x2, y2 = min(x2, width), min(y2, height)
            if x2 <= x1 or y2 <= y1:
                continue

            roi = np.ascontiguousarray(host_image[y1:y2, x1:x2])

            # Convert to HSV on the host, then threshold on the GPU.
            hsv = cp.asarray(cv2.cvtColor(roi, cv2.COLOR_BGR2HSV))
            hue = hsv[:, :, 0]
            saturation = hsv[:, :, 1]

            # Color segmentation
            mask = (hue > 35) & (hue < 80) & (saturation > 50)
            if int(cp.sum(mask)) <= self.CONTOUR_AREA_THRESHOLD:
                continue

            # Processing the contours: zero out everything outside the mask,
            # then binarise what is left.
            gray_image = cv2.cvtColor(roi, cv2.COLOR_BGR2GRAY)
            gray_image = gray_image * cp.asnumpy(mask).astype(np.uint8)
            _, thresh = cv2.threshold(gray_image, 0, 255, cv2.THRESH_BINARY)

            contours, _ = cv2.findContours(thresh, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
            for cnt in contours:
                if cv2.contourArea(cnt) > self.CONTOUR_AREA_THRESHOLD:
                    x, y, w, h = cv2.boundingRect(cnt)
                    # Shift from ROI coordinates back to image coordinates.
                    detections.append((x + x1, y + y1, x + w + x1, y + h + y1))

        return detections

    def verify_object(self, image: np.ndarray, detections: List[Tuple[int, int, int, int]],
                      velocity=0):
        """Drop boxes outside the velocity-shifted ROI and clip the rest to it.

        The ROI is ``ROI_COLS`` x ``ROI_ROWS`` pixels, centred in the image
        and moved left by ``int(velocity)`` pixels. A box is discarded when
        both of its corners lie on the same outer side of the ROI on either
        axis.

        Args:
            image: Frame; only ``image.shape[:2]`` is read.
            detections: Pixel boxes as (x1, y1, x2, y2).
            velocity: Horizontal ROI shift in pixels (truncated to int).

        Returns:
            List of (x1, y1, x2, y2) integer tuples clipped to the ROI.
        """
        height, width = image.shape[:2]

        roi_x1 = width // 2 - self.ROI_COLS // 2 - int(velocity)
        roi_y1 = height // 2 - self.ROI_ROWS // 2
        roi_x2 = roi_x1 + self.ROI_COLS
        roi_y2 = roi_y1 + self.ROI_ROWS

        verified = []
        for x1, y1, x2, y2 in detections:
            x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2)

            outside_x = (x1 < roi_x1 and x2 < roi_x1) or (x1 > roi_x2 and x2 > roi_x2)
            outside_y = (y1 < roi_y1 and y2 < roi_y1) or (y1 > roi_y2 and y2 > roi_y2)
            if outside_x or outside_y:
                continue

            verified.append((
                min(max(x1, roi_x1), roi_x2),
                min(max(y1, roi_y1), roi_y2),
                min(max(x2, roi_x1), roi_x2),
                min(max(y2, roi_y1), roi_y2),
            ))

        return verified

    def postprocess(self, confidence, bbox_array, raw_image: np.ndarray, velocity=0):
        """Convert, colour-filter and ROI-verify the raw model boxes.

        Args:
            confidence: Accepted for interface compatibility; not used.
            bbox_array: Boxes as accepted by ``_convert_bboxes_to_pixel``.
            raw_image: Frame the boxes refer to (BGR).
            velocity: Horizontal ROI shift forwarded to ``verify_object``.

        Returns:
            List of verified (x1, y1, x2, y2) pixel boxes.
        """
        detections = self._convert_bboxes_to_pixel(bbox_array, raw_image.shape)
        detections = self.object_filter(raw_image, detections)
        return self.verify_object(raw_image, detections, velocity)
Original file line number Diff line number Diff line change
@@ -0,0 +1,206 @@
import cv2
import os
import numpy as np
from typing import Tuple, List
import jax
import jax.numpy as jnp
from functools import partial

class ModelInferenceJax:
    """Pre/post-processing helpers that use JAX for the array math.

    Only the fixed-shape, purely numerical kernels are jit-compiled. The
    steps that need concrete Python values (slicing by box coordinates,
    building Python lists) run on the host, because they cannot be traced.
    """

    # Constants
    # Size (in pixels) of the centred region of interest used by
    # ``verify_object``.
    ROI_ROWS = 300
    ROI_COLS = 300
    # Minimum number of in-range pixels for a box to be kept.
    CONTOUR_AREA_THRESHOLD = 500
    POSTPROCESS_OUTPUT_SHAPE = (640, 640)

    def __init__(self, weights_path=None, precision=None):
        """Validate the weights path and pick the default JAX device.

        Args:
            weights_path: Path to the weights file. When this or
                ``precision`` is ``None`` the instance is left unconfigured.
            precision: Stored on the instance; not used by the code visible
                in this class.

        Raises:
            FileNotFoundError: If ``weights_path`` is given but does not exist.
        """
        # Always define every attribute so an unconfigured instance can be
        # inspected without AttributeError.
        self.model = None
        self.precision = precision
        self.device = None

        if weights_path is None or precision is None:
            return

        if not os.path.exists(weights_path):
            print(f"Weights file not found at {weights_path}")
            raise FileNotFoundError(f"Weights file not found at {weights_path}")

        # Note: Model loading would be handled separately
        # JAX doesn't have direct YOLO support, so you might need a custom solution
        self.device = jax.devices()[0]  # Get default device

    @partial(jax.jit, static_argnums=(0,))
    def preprocess(self, image: np.ndarray):
        """Scale the image to float32 in [0, 1] and reverse the channel axis.

        Args:
            image: Input frame. NOTE(review): presumably 8-bit BGR, so that
                the division by 255 normalises it and the reversal yields
                RGB - confirm against the caller.

        Returns:
            ``float32`` JAX array with the last axis reversed.
        """
        # Convert image to float32 and normalize
        image = jnp.array(image, dtype=jnp.float32) / 255.0

        # Convert BGR to RGB (JAX implementation)
        return image[..., ::-1]

    @partial(jax.jit, static_argnums=(0,))
    def color_threshold(self, hsv_image: jnp.ndarray) -> jnp.ndarray:
        """Return a boolean mask of pixels with 35 < H < 80 and S > 50.

        Args:
            hsv_image: Array whose last axis holds (H, S, ...) on the scale
                produced by ``_bgr_to_hsv``.

        Returns:
            Boolean array with the shape of ``hsv_image`` minus its last axis.
        """
        hue = hsv_image[..., 0]
        saturation = hsv_image[..., 1]

        in_hue_band = jnp.logical_and(hue > 35, hue < 80)
        return jnp.logical_and(in_hue_band, saturation > 50)

    def object_filter(self, image: jnp.ndarray, bboxes: jnp.ndarray) -> List[Tuple[int, int, int, int]]:
        """Keep the boxes that contain enough colour-matching pixels.

        A box is kept when more than ``CONTOUR_AREA_THRESHOLD`` pixels inside
        it pass ``color_threshold``.

        Args:
            image: Frame in BGR channel order.
            bboxes: Boxes as (x1, y1, x2, y2), one row per box.
                NOTE(review): the values are truncated with ``int`` and used
                as pixel indices; normalised boxes would all collapse to 0 -
                confirm what the caller passes.

        Returns:
            List of the accepted boxes as integer (x1, y1, x2, y2) tuples,
            with their original (unclamped) coordinates.
        """
        boxes = np.asarray(bboxes)
        if boxes.size == 0:
            return []
        boxes = np.atleast_2d(boxes)

        host_image = np.asarray(image)
        height, width = host_image.shape[:2]

        # The mask is computed once for the whole frame: the kernels are
        # per-pixel, so this gives the same result as per-ROI conversion while
        # keeping the jitted input shape constant (no recompilation per box).
        # Dividing by 255 matches the [0, 1] contract of _bgr_to_hsv; H and S
        # are ratios, so the threshold does not depend on the input scale.
        frame = jnp.asarray(host_image, dtype=jnp.float32) / 255.0
        mask = np.asarray(self.color_threshold(self._bgr_to_hsv(frame)))

        detections = []
        for bbox in boxes:
            x1, y1, x2, y2 = (int(coord) for coord in bbox)

            # Clamp only for slicing; negative indices would wrap around.
            cx1, cy1 = max(x1, 0), max(y1, 0)
            cx2, cy2 = min(x2, width), min(y2, height)
            if cx2 <= cx1 or cy2 <= cy1:
                continue

            if int(mask[cy1:cy2, cx1:cx2].sum()) > self.CONTOUR_AREA_THRESHOLD:
                detections.append((x1, y1, x2, y2))

        return detections

    @partial(jax.jit, static_argnums=(0,))
    def _bgr_to_hsv(self, bgr: jnp.ndarray) -> jnp.ndarray:
        """
        JAX implementation of BGR to HSV conversion
        Args:
            bgr: Input image in BGR format with values in [0, 1]
        Returns:
            HSV image with H in [0, 180], S and V in [0, 255]
        """
        # Work in float so unsigned inputs cannot wrap on subtraction.
        bgr = jnp.asarray(bgr, dtype=jnp.float32)

        # Separate BGR channels: index 0 is blue, index 2 is red.
        b, g, r = bgr[..., 0], bgr[..., 1], bgr[..., 2]

        # Calculate Value (V)
        v = jnp.maximum(jnp.maximum(r, g), b)

        # Calculate Saturation (S); guard the denominator so the unused branch
        # of jnp.where never divides by zero.
        diff = v - jnp.minimum(jnp.minimum(r, g), b)
        s = jnp.where(v == 0, 0.0, diff / jnp.where(v == 0, 1.0, v))

        # Calculate Hue (H) in degrees, one candidate per dominant channel.
        safe_diff = jnp.where(diff == 0, 1.0, diff)
        hue_if_r_max = 60.0 * (g - b) / safe_diff
        hue_if_g_max = 120.0 + 60.0 * (b - r) / safe_diff
        hue_if_b_max = 240.0 + 60.0 * (r - g) / safe_diff
        h = jnp.where(v == r, hue_if_r_max,
                      jnp.where(v == g, hue_if_g_max, hue_if_b_max))

        # Gray pixels have no hue; report 0 as OpenCV does.
        h = jnp.where(diff == 0, 0.0, h)

        # Adjust negative values
        h = jnp.where(h < 0, h + 360.0, h)

        # Scale values to match OpenCV ranges
        h = h / 2  # Convert to [0, 180]
        s = s * 255  # Convert to [0, 255]
        v = v * 255  # Convert to [0, 255]

        return jnp.stack([h, s, v], axis=-1)

    def postprocess(self, confidences: jnp.ndarray, bbox_array: jnp.ndarray,
                    raw_image: np.ndarray, velocity: float = 0) -> List[Tuple[int, int, int, int]]:
        """Colour-filter the boxes, then verify them against the ROI.

        Args:
            confidences: Accepted for interface compatibility; not used.
            bbox_array: Boxes forwarded to ``object_filter``.
            raw_image: Frame the boxes refer to (BGR).
            velocity: Horizontal ROI shift forwarded to ``verify_object``.

        Returns:
            List of verified (x1, y1, x2, y2) pixel boxes.
        """
        detections = self.object_filter(raw_image, bbox_array)
        return self.verify_object(raw_image, detections, velocity)

    @staticmethod
    @jax.jit
    def _clip_boxes_to_roi(boxes: jnp.ndarray, roi: jnp.ndarray):
        """Clip (N, 4) boxes to ``roi`` and flag the ones that touch it.

        Returns:
            Tuple ``(clipped, keep)``: the boxes clipped to the ROI and a
            boolean vector that is False for boxes lying wholly outside it.
        """
        x1, y1, x2, y2 = boxes[:, 0], boxes[:, 1], boxes[:, 2], boxes[:, 3]
        roi_x1, roi_y1, roi_x2, roi_y2 = roi[0], roi[1], roi[2], roi[3]

        # A box is outside when both corners are on the same outer side.
        outside_x = jnp.logical_or(
            jnp.logical_and(x1 < roi_x1, x2 < roi_x1),
            jnp.logical_and(x1 > roi_x2, x2 > roi_x2),
        )
        outside_y = jnp.logical_or(
            jnp.logical_and(y1 < roi_y1, y2 < roi_y1),
            jnp.logical_and(y1 > roi_y2, y2 > roi_y2),
        )

        clipped = jnp.stack(
            [
                jnp.clip(x1, roi_x1, roi_x2),
                jnp.clip(y1, roi_y1, roi_y2),
                jnp.clip(x2, roi_x1, roi_x2),
                jnp.clip(y2, roi_y1, roi_y2),
            ],
            axis=1,
        )
        return clipped, jnp.logical_not(jnp.logical_or(outside_x, outside_y))

    def verify_object(self, image: jnp.ndarray, detections: List[Tuple[int, int, int, int]],
                      velocity: float) -> List[Tuple[int, int, int, int]]:
        """
        JAX-accelerated object verification based on ROI and velocity
        Args:
            image: Input image
            detections: List of bounding boxes in format [(x1, y1, x2, y2), ...]
            velocity: Velocity value for ROI shift
        Returns:
            List of verified and adjusted bounding boxes
        """
        if len(detections) == 0:
            return []

        # Get image dimensions
        height, width = image.shape[:2]

        # Centred ROI, shifted left by the (truncated) velocity.
        roi_x1 = width // 2 - self.ROI_COLS // 2 - int(velocity)
        roi_y1 = height // 2 - self.ROI_ROWS // 2
        roi = jnp.asarray(
            [roi_x1, roi_y1, roi_x1 + self.ROI_COLS, roi_y1 + self.ROI_ROWS],
            dtype=jnp.int32,
        )

        boxes = jnp.asarray(detections, dtype=jnp.int32).reshape(-1, 4)
        clipped, keep = self._clip_boxes_to_roi(boxes, roi)

        # Boolean-mask selection has a data-dependent shape, so it is done on
        # the host with an explicit mask rather than an all-zero sentinel box.
        kept_boxes = np.asarray(clipped)[np.asarray(keep)]
        return [(int(x1), int(y1), int(x2), int(y2))
                for x1, y1, x2, y2 in kept_boxes.tolist()]
Loading

0 comments on commit ca1850a

Please sign in to comment.