From 742613287bf02b5ba59dae6155578e585a5b8f4a Mon Sep 17 00:00:00 2001
From: Spritan
Date: Wed, 28 Feb 2024 00:12:37 +0530
Subject: [PATCH] UPDATES: Default layers change and backward_type deprecated

---
 YOLOv8_Explainer/core.py | 115 +++++++++++++++++++++------------------
 setup.py                 |   8 ++-
 2 files changed, 67 insertions(+), 56 deletions(-)

diff --git a/YOLOv8_Explainer/core.py b/YOLOv8_Explainer/core.py
index aa46651..0c16181 100644
--- a/YOLOv8_Explainer/core.py
+++ b/YOLOv8_Explainer/core.py
@@ -1,40 +1,28 @@
 import os
 import shutil
-
+from typing import List, Optional, Tuple, Union
 import cv2
-import torch
 import numpy as np
+import torch
 from PIL import Image
-
+from pytorch_grad_cam import (EigenCAM, EigenGradCAM, GradCAM, GradCAMPlusPlus,
+                              HiResCAM, LayerCAM, RandomCAM, XGradCAM)
+from pytorch_grad_cam.activations_and_gradients import ActivationsAndGradients
+from pytorch_grad_cam.utils.image import scale_cam_image, show_cam_on_image
 from ultralytics.nn.tasks import attempt_load_weights
 from ultralytics.utils.ops import non_max_suppression, xywh2xyxy
-from pytorch_grad_cam import (
-    GradCAMPlusPlus,
-    GradCAM,
-    XGradCAM,
-    EigenCAM,
-    HiResCAM,
-    LayerCAM,
-    RandomCAM,
-    EigenGradCAM,
-)
-
-from pytorch_grad_cam.activations_and_gradients import ActivationsAndGradients
-from pytorch_grad_cam.utils.image import show_cam_on_image, scale_cam_image
-
-from typing import List, Optional, Union, Tuple
-
 from .utils import letterbox
 
+
 class ActivationsAndGradients:
     """ Class for extracting activations and
     registering gradients from targetted intermediate layers """
 
-    def __init__(self, model: torch.nn.Module,
-                 target_layers: List[torch.nn.Module],
-                 reshape_transform: Optional[callable]) -> None: # type: ignore
+    def __init__(self, model: torch.nn.Module,
+                 target_layers: List[torch.nn.Module],
+                 reshape_transform: Optional[callable]) -> None:  # type: ignore
         """
         Initializes the ActivationsAndGradients object.
@@ -56,8 +44,8 @@ def __init__(self, model: torch.nn.Module,
             self.handles.append(
                 target_layer.register_forward_hook(self.save_gradient))
 
-    def save_activation(self, module: torch.nn.Module,
-                        input: Union[torch.Tensor, Tuple[torch.Tensor, ...]],
+    def save_activation(self, module: torch.nn.Module,
+                        input: Union[torch.Tensor, Tuple[torch.Tensor, ...]],
                         output: torch.Tensor) -> None:
         """
         Saves the activation of the targeted layer.
@@ -73,8 +61,8 @@ def save_activation(self, module: torch.nn.Module,
             activation = self.reshape_transform(activation)
         self.activations.append(activation.cpu().detach())
 
-    def save_gradient(self, module: torch.nn.Module,
-                      input: Union[torch.Tensor, Tuple[torch.Tensor, ...]],
+    def save_gradient(self, module: torch.nn.Module,
+                      input: Union[torch.Tensor, Tuple[torch.Tensor, ...]],
                       output: torch.Tensor) -> None:
         """
         Saves the gradient of the targeted layer.
@@ -110,7 +98,7 @@ def post_process(self, result: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor
         boxes_ = result[:, :4]
         sorted, indices = torch.sort(logits_.max(1)[0], descending=True)
         return torch.transpose(logits_[0], dim0=0, dim1=1)[indices[0]], torch.transpose(boxes_[0], dim0=0, dim1=1)[indices[0]], xywh2xyxy(torch.transpose(boxes_[0], dim0=0, dim1=1)[indices[0]]).cpu().detach().numpy()
-
+
     def __call__(self, x: torch.Tensor) -> List[List[Union[torch.Tensor, np.ndarray]]]:
         """
         Calls the ActivationsAndGradients object.
@@ -124,7 +112,8 @@ def __call__(self, x: torch.Tensor) -> List[List[Union[torch.Tensor, np.ndarray]
         self.gradients = []
         self.activations = []
         model_output = self.model(x)
-        post_result, pre_post_boxes, post_boxes = self.post_process(model_output[0])
+        post_result, pre_post_boxes, post_boxes = self.post_process(
+            model_output[0])
         return [[post_result, pre_post_boxes]]
 
     def release(self) -> None:
@@ -139,7 +128,7 @@ def __init__(self, ouput_type, conf, ratio) -> None:
         self.ouput_type = ouput_type
         self.conf = conf
         self.ratio = ratio
-
+
     def forward(self, data):
         post_result, pre_post_boxes = data
         result = []
@@ -153,6 +142,7 @@ def forward(self, data):
                     result.append(pre_post_boxes[i, j])
         return sum(result)
 
+
 class yolov8_heatmap:
     """
     This class is used to implement the YOLOv8 target layer.
@@ -167,7 +157,7 @@ class yolov8_heatmap:
         ratio (float): The ratio of maximum scores to return. Defaults to 0.02.
         show_box (bool): Whether to show bounding boxes with the CAM. Defaults to True.
         renormalize (bool): Whether to renormalize the CAM to be in the range [0, 1] across the entire image. Defaults to False.
-
+
     Returns:
         A tensor containing the output.
 
@@ -175,11 +165,10 @@ class yolov8_heatmap:
 
     def __init__(
         self,
-        weight:str,
+        weight: str,
         device=torch.device("cuda:0" if torch.cuda.is_available() else "cpu"),
-        method= "EigenCAM",
-        layer=[10, 12, 14, 16, 18],
-        backward_type="all",
+        method="EigenGradCAM",
+        layer=[12, 17, 21],
         conf_threshold=0.2,
         ratio=0.02,
         show_box=True,
@@ -189,6 +178,7 @@ def __init__(
         """
         Initialize the YOLOv8 heatmap layer.
         """
         device = device
+        self.backward_type = "all"
         ckpt = torch.load(weight)
         model_names = ckpt['model'].names
         model = attempt_load_weights(weight, device)
@@ -196,16 +186,18 @@ def __init__(
         for p in model.parameters():
             p.requires_grad_(True)
         model.eval()
-
+
         target = yolov8_target(backward_type, conf_threshold, ratio)
         target_layers = [model.model[l] for l in layer]
-        method = eval(method)(model, target_layers, use_cuda=device.type == 'cuda')
-        method.activations_and_grads = ActivationsAndGradients(model, target_layers, None)
-
-        colors = np.random.uniform(0, 255, size=(len(model_names), 3)).astype(int)
+        method = eval(method)(model, target_layers,
+                              use_cuda=device.type == 'cuda')
+        method.activations_and_grads = ActivationsAndGradients(
+            model, target_layers, None)
+
+        colors = np.random.uniform(
+            0, 255, size=(len(model_names), 3)).astype(int)
         self.__dict__.update(locals())
 
-
     def post_process(self, result):
         """
         Filter and process the model output to get detections.
@@ -217,7 +209,8 @@ def post_process(self, result):
         Returns:
             numpy.ndarray: The filtered detections.
         """
-        result = non_max_suppression(result, conf_thres=self.conf_threshold, iou_thres=0.65)[0]
+        result = non_max_suppression(
+            result, conf_thres=self.conf_threshold, iou_thres=0.65)[0]
         return result
 
     def draw_detections(self, box, color, name, img):
@@ -225,17 +218,19 @@ def draw_detections(self, box, color, name, img):
         Draw bounding boxes and labels on an image.
 
         Args:
-            box (list): The bounding box coordinates in the format [x1, y1, x2, y2].
-            color (list): The color of the bounding box in the format [B, G, R].
+            box (list): The bounding box coordinates in the format [x1, y1, x2, y2]
+            color (list): The color of the bounding box in the format [B, G, R]
             name (str): The label for the bounding box.
-            img (numpy.ndarray): The image on which to draw the bounding box.
+            img (numpy.ndarray): The image on which to draw the bounding box
 
         Returns:
             numpy.ndarray: The image with the bounding box drawn.
         """
         xmin, ymin, xmax, ymax = list(map(int, list(box)))
-        cv2.rectangle(img, (xmin, ymin), (xmax, ymax), tuple(int(x) for x in color), 2)
-        cv2.putText(img, str(name), (xmin, ymin - 5), cv2.FONT_HERSHEY_SIMPLEX, 0.8, tuple(int(x) for x in color), 2, lineType=cv2.LINE_AA)
+        cv2.rectangle(img, (xmin, ymin), (xmax, ymax),
+                      tuple(int(x) for x in color), 2)
+        cv2.putText(img, str(name), (xmin, ymin - 5), cv2.FONT_HERSHEY_SIMPLEX,
+                    0.8, tuple(int(x) for x in color), 2, lineType=cv2.LINE_AA)
         return img
 
     def renormalize_cam_in_bounding_boxes(
@@ -259,17 +254,21 @@ def renormalize_cam_in_bounding_boxes(
         renormalized_cam = np.zeros(grayscale_cam.shape, dtype=np.float32)
         for x1, y1, x2, y2 in boxes:
             x1, y1 = max(x1, 0), max(y1, 0)
-            x2, y2 = min(grayscale_cam.shape[1] - 1, x2), min(grayscale_cam.shape[0] - 1, y2)
-            renormalized_cam[y1:y2, x1:x2] = scale_cam_image(grayscale_cam[y1:y2, x1:x2].copy())
+            x2, y2 = min(grayscale_cam.shape[1] - 1,
+                         x2), min(grayscale_cam.shape[0] - 1, y2)
+            renormalized_cam[y1:y2, x1:x2] = scale_cam_image(
+                grayscale_cam[y1:y2, x1:x2].copy())
         renormalized_cam = scale_cam_image(renormalized_cam)
-        eigencam_image_renormalized = show_cam_on_image(image_float_np, renormalized_cam, use_rgb=True)
+        eigencam_image_renormalized = show_cam_on_image(
+            image_float_np, renormalized_cam, use_rgb=True)
         return eigencam_image_renormalized
 
     def renormalize_cam(self, boxes, image_float_np, grayscale_cam):
         """Normalize the CAM to be in the range [0, 1]
         across the entire image."""
         renormalized_cam = scale_cam_image(grayscale_cam)
-        eigencam_image_renormalized = show_cam_on_image(image_float_np, renormalized_cam, use_rgb=True)
+        eigencam_image_renormalized = show_cam_on_image(
+            image_float_np, renormalized_cam, use_rgb=True)
         return eigencam_image_renormalized
 
     def process(self, img_path):
@@ -299,21 +298,29 @@ def process(self, img_path):
                 print(e)
                 return
         grayscale_cam = grayscale_cam[0, :]
-        cam_image = show_cam_on_image(img, grayscale_cam, use_rgb=True) # type: ignore
+        cam_image = show_cam_on_image(
+            img, grayscale_cam, use_rgb=True)  # type: ignore
         pred = self.model(tensor)[0]
         pred = self.post_process(pred)
         if self.renormalize:
             cam_image = self.renormalize_cam(
-                pred[:, :4].cpu().detach().numpy().astype(np.int32), img, grayscale_cam
+                pred[:, :4].cpu().detach().numpy().astype(
+                    np.int32), img, grayscale_cam
             )
         if self.show_box:
            for data in pred:
                 data = data.cpu().detach().numpy()
+                # Calculate the maximum value
+                max_value = float(data[4:].max())
+                if max_value > 1:
+                    conf = 1
+                else:
+                    conf = max_value
                 cam_image = self.draw_detections(
                     data[:4],
                     self.colors[int(data[4:].argmax())],
-                    f"{self.model_names[int(data[4:].argmax())]} {float(data[4:].max()):.2f}",
+                    f"{self.model_names[int(data[4:].argmax())]} {conf}",
                     cam_image,
                 )
diff --git a/setup.py b/setup.py
index 2373707..4029540 100644
--- a/setup.py
+++ b/setup.py
@@ -5,12 +5,13 @@
 
 setup(
     name="YOLOv8_Explainer",
-    version="0.0.03",
+    version="0.0.04",
     description="Python packages that enable XAI methods for YOLOv8",
     packages=find_packages(),
     long_description=long_description,
     long_description_content_type="text/markdown",
     url="https://github.com/Spritan/YOLOv8_Explainer",
+    # Github="https://github.com/Spritan/YOLOv8_Explainer",
     author="Spritan",
     author_email="proypabsab@gmail.com",
     license="MIT",
@@ -32,5 +33,8 @@
     extras_require={
         "dev": ["twine>=4.0.2"],
     },
-    # python_requires=">=3.10",
+    project_urls={
+        'Homepage': 'https://spritan.github.io/YOLOv8_Explainer/',  # Homepage URL
+        'Source': 'https://github.com/Spritan/YOLOv8_Explainer',  # GitHub URL
+    },
 )
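
After this patch, the explainer defaults to EigenGradCAM over layers [12, 17, 21] and no longer accepts a backward_type argument (it is fixed to "all" internally). A minimal usage sketch under those assumptions follows; the checkpoint path, image path, and the `YOLOv8_Explainer.core` import path are placeholders inferred from the diff, not part of the patch.

# Illustrative sketch only, not part of the patch; "yolov8n.pt" and
# "images/test.jpg" are placeholder paths.
from YOLOv8_Explainer.core import yolov8_heatmap

explainer = yolov8_heatmap(
    weight="yolov8n.pt",      # YOLOv8 checkpoint to explain
    method="EigenGradCAM",    # new default CAM method
    layer=[12, 17, 21],       # new default target layers
    conf_threshold=0.2,
    ratio=0.02,
    show_box=True,
    renormalize=False,
)
explainer.process("images/test.jpg")  # run the CAM + detection overlay on one image

With show_box=True, the box label now uses the clamped confidence value introduced above (capped at 1) instead of the raw maximum score.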