Skip to content

Commit

Permalink
Merge pull request #8 from Rahul1511E/multi-class
Browse files (browse the repository at this point in the history)
Improve multi-class detection
  • Loading branch information
Spritan authored Nov 25, 2024
2 parents c2c0e54 + efd3ae5 commit 293dff7
Showing 1 changed file with 87 additions and 62 deletions.
149 changes: 87 additions & 62 deletions YOLOv8_Explainer/core.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -97,7 +97,8 @@ def post_process(self, result: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor
logits_ = result[:, 4:]
boxes_ = result[:, :4]
sorted, indices = torch.sort(logits_.max(1)[0], descending=True)
return torch.transpose(logits_[0], dim0=0, dim1=1)[indices[0]], torch.transpose(boxes_[0], dim0=0, dim1=1)[indices[0]], xywh2xyxy(torch.transpose(boxes_[0], dim0=0, dim1=1)[indices[0]]).cpu().detach().numpy()
return torch.transpose(logits_[0], dim0=0, dim1=1)[indices[0]], torch.transpose(boxes_[0], dim0=0, dim1=1)[
indices[0]], xywh2xyxy(torch.transpose(boxes_[0], dim0=0, dim1=1)[indices[0]]).cpu().detach().numpy()

def __call__(self, x: torch.Tensor) -> List[List[Union[torch.Tensor, np.ndarray]]]:
"""
Expand Down Expand Up @@ -132,14 +133,13 @@ def __init__(self, ouput_type, conf, ratio) -> None:
def forward(self, data):
post_result, pre_post_boxes = data
result = []
for i in range(int(post_result.size(0) * self.ratio)):
if float(post_result[i].max()) < self.conf:
break
if self.ouput_type == 'class' or self.ouput_type == 'all':
result.append(post_result[i].max())
elif self.ouput_type == 'box' or self.ouput_type == 'all':
for j in range(4):
result.append(pre_post_boxes[i, j])
for i in range(post_result.size(0)):
if float(post_result[i].max()) >= self.conf:
if self.ouput_type == 'class' or self.ouput_type == 'all':
result.append(post_result[i].max())
if self.ouput_type == 'box' or self.ouput_type == 'all':
for j in range(4):
result.append(pre_post_boxes[i, j])
return sum(result)


Expand All @@ -163,15 +163,15 @@ class yolov8_heatmap:
"""

def __init__(
self,
weight: str,
device=torch.device("cuda:0" if torch.cuda.is_available() else "cpu"),
method="EigenGradCAM",
layer=[12, 17, 21],
conf_threshold=0.2,
ratio=0.02,
show_box=True,
renormalize=False,
self,
weight: str,
device=torch.device("cuda:0" if torch.cuda.is_available() else "cpu"),
method="EigenGradCAM",
layer=[12, 17, 21],
conf_threshold=0.2,
ratio=0.02,
show_box=True,
renormalize=False,
) -> None:
"""
Initialize the YOLOv8 heatmap layer.
Expand Down Expand Up @@ -200,43 +200,67 @@ def __init__(

def post_process(self, result):
"""
Perform non-maximum suppression on the detections.
Perform non-maximum suppression on the detections and process results.
Args:
result (numpy.ndarray): The detections from the model.
result (torch.Tensor): The raw detections from the model.
Returns:
numpy.ndarray: The filtered detections.
torch.Tensor: Filtered and processed detections.
"""
result = non_max_suppression(
result, conf_thres=self.conf_threshold, iou_thres=0.80)[0]
return result
# Perform non-maximum suppression
processed_result = non_max_suppression(
result,
conf_thres=self.conf_threshold, # Use the class's confidence threshold
iou_thres=0.45 # Intersection over Union threshold
)

# If no detections, return an empty tensor
if len(processed_result) == 0 or processed_result[0].numel() == 0:
return torch.empty(0, 6) # Return an empty tensor with 6 columns

# Take the first batch of detections (assuming single image)
detections = processed_result[0]

# Filter detections based on confidence
mask = detections[:, 4] >= self.conf_threshold
filtered_detections = detections[mask]

return filtered_detections

def draw_detections(self, box, color, name, img):
"""
Draw bounding boxes and labels on an image.
Draw bounding boxes and labels on an image for multiple detections.
Args:
box (list): The bounding box coordinates in the format [x1, y1, x2, y2]
box (torch.Tensor or np.ndarray): The bounding box coordinates in the format [x1, y1, x2, y2]
color (list): The color of the bounding box in the format [B, G, R]
name (str): The label for the bounding box.
img (numpy.ndarray): The image on which to draw the bounding box
img (np.ndarray): The image on which to draw the bounding box
Returns:
numpy.ndarray: The image with the bounding box drawn.
np.ndarray: The image with the bounding box drawn.
"""
xmin, ymin, xmax, ymax = list(map(int, list(box)))
# Ensure box coordinates are integers
xmin, ymin, xmax, ymax = map(int, box[:4])

# Draw rectangle
cv2.rectangle(img, (xmin, ymin), (xmax, ymax),
tuple(int(x) for x in color), 2)
cv2.putText(img, str(name), (xmin, ymin - 5), cv2.FONT_HERSHEY_SIMPLEX,
0.8, tuple(int(x) for x in color), 2, lineType=cv2.LINE_AA)

# Draw label
cv2.putText(img, name, (xmin, ymin - 5),
cv2.FONT_HERSHEY_SIMPLEX,
0.8, tuple(int(x) for x in color), 2,
lineType=cv2.LINE_AA)

return img

def renormalize_cam_in_bounding_boxes(
self,
boxes: np.ndarray, # type: ignore
image_float_np: np.ndarray, # type: ignore
grayscale_cam: np.ndarray, # type: ignore
self,
boxes: np.ndarray, # type: ignore
image_float_np: np.ndarray, # type: ignore
grayscale_cam: np.ndarray, # type: ignore
) -> np.ndarray:
"""
Normalize the CAM to be in the range [0, 1]
Expand Down Expand Up @@ -271,19 +295,11 @@ def renormalize_cam(self, boxes, image_float_np, grayscale_cam):
return eigencam_image_renormalized

def process(self, img_path):
"""Process the input image and generate CAM visualization.
Args:
img_path (str): Path to the input image.
save_path (str): Path to save the generated CAM visualization.
Returns:
None
"""
"""Process the input image and generate CAM visualization."""
img = cv2.imread(img_path)
img = letterbox(img)[0]
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
img = np.float32(img) / 255.0 # type: ignore
img = np.float32(img) / 255.0

tensor = (
torch.from_numpy(np.transpose(img, axes=[2, 0, 1]))
Expand All @@ -296,31 +312,40 @@ def process(self, img_path):
except AttributeError as e:
print(e)
return

grayscale_cam = grayscale_cam[0, :]

pred1 = self.model(tensor)[0]
pred = self.post_process(pred1)
pred = non_max_suppression(
pred1,
conf_thres=self.conf_threshold,
iou_thres=0.45
)[0]

# Debugging print

if self.renormalize:
cam_image = self.renormalize_cam(
pred[:, :4].cpu().detach().numpy().astype(
np.int32), img, grayscale_cam
pred[:, :4].cpu().detach().numpy().astype(np.int32),
img,
grayscale_cam
)
else:
cam_image = show_cam_on_image(
img, grayscale_cam, use_rgb=True) # type: ignore
if self.show_box:
for data in pred:
data = data.cpu().detach().numpy()
# Calculate the maximum value
max_value = float(data[4:].max())
if max_value > 1:
conf = 1
else:
conf = max_value
cam_image = show_cam_on_image(img, grayscale_cam, use_rgb=True)

if self.show_box and len(pred) > 0:
for detection in pred:
detection = detection.cpu().detach().numpy()

# Get class index and confidence
class_index = int(detection[5])
conf = detection[4]

# Draw detection
cam_image = self.draw_detections(
data[:4],
self.colors[int(data[4:].argmax())],
f"{self.model_names[int(data[4:].argmax())]} {conf}",
detection[:4], # Box coordinates
self.colors[class_index], # Color for this class
f"{self.model_names[class_index]}", # Label with confidence
cam_image,
)

Expand Down

0 comments on commit 293dff7

Please sign in to comment.