Skip to content

Commit

Permalink
BUG FIXES: fot other methods rather than Eigencam
Browse files Browse the repository at this point in the history
  • Loading branch information
Spritan committed Feb 26, 2024
1 parent 7ff1ebc commit f6544d0
Show file tree
Hide file tree
Showing 3 changed files with 132 additions and 75 deletions.
198 changes: 129 additions & 69 deletions YOLOv8_Explainer/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from PIL import Image

from ultralytics.nn.tasks import attempt_load_weights
from ultralytics.utils.ops import non_max_suppression
from ultralytics.utils.ops import non_max_suppression, xywh2xyxy

from pytorch_grad_cam import (
GradCAMPlusPlus,
Expand All @@ -24,52 +24,135 @@
from pytorch_grad_cam.activations_and_gradients import ActivationsAndGradients
from pytorch_grad_cam.utils.image import show_cam_on_image, scale_cam_image

from typing import List, Optional, Union, Tuple

from .utils import letterbox

class ActivationsAndGradients:
""" Class for extracting activations and
registering gradients from targetted intermediate layers """

class yolov8_target:
"""
This class is used to implement the YOLOv8 target layer.
def __init__(self, model: torch.nn.Module,
target_layers: List[torch.nn.Module],
reshape_transform: Optional[callable]) -> None: # type: ignore
"""
Initializes the ActivationsAndGradients object.
Args:
ouput_type (str): The type of output to return. Can be "class", "box", or "all".
conf (float): The confidence threshold.
ratio (float): The ratio of maximum scores to return.
Args:
model (torch.nn.Module): The neural network model.
target_layers (List[torch.nn.Module]): List of target layers from which to extract activations and gradients.
reshape_transform (Optional[callable]): A function to transform the shape of the activations and gradients if needed.
"""
self.model = model
self.gradients = []
self.activations = []
self.reshape_transform = reshape_transform
self.handles = []
for target_layer in target_layers:
self.handles.append(
target_layer.register_forward_hook(self.save_activation))
# Because of https://github.com/pytorch/pytorch/issues/61519,
# we don't use backward hook to record gradients.
self.handles.append(
target_layer.register_forward_hook(self.save_gradient))

def save_activation(self, module: torch.nn.Module,
input: Union[torch.Tensor, Tuple[torch.Tensor, ...]],
output: torch.Tensor) -> None:
"""
Saves the activation of the targeted layer.
Returns:
A tensor containing the output.
Args:
module (torch.nn.Module): The targeted layer module.
input (Union[torch.Tensor, Tuple[torch.Tensor, ...]]): The input to the targeted layer.
output (torch.Tensor): The output activation of the targeted layer.
"""
activation = output

"""
if self.reshape_transform is not None:
activation = self.reshape_transform(activation)
self.activations.append(activation.cpu().detach())

def __init__(self, ouput_type, conf, ratio) -> None:
super().__init__()
self.ouput_type = ouput_type
self.conf = conf
self.ratio = ratio
def save_gradient(self, module: torch.nn.Module,
input: Union[torch.Tensor, Tuple[torch.Tensor, ...]],
output: torch.Tensor) -> None:
"""
Saves the gradient of the targeted layer.
def forward(self, data):
Args:
module (torch.nn.Module): The targeted layer module.
input (Union[torch.Tensor, Tuple[torch.Tensor, ...]]): The input to the targeted layer.
output (torch.Tensor): The output activation of the targeted layer.
"""
if not hasattr(output, "requires_grad") or not output.requires_grad:
# You can only register hooks on tensor requires grad.
return

# Gradients are computed in reverse order
def _store_grad(grad: torch.Tensor) -> None:
if self.reshape_transform is not None:
grad = self.reshape_transform(grad)
self.gradients = [grad.cpu().detach()] + self.gradients

output.register_hook(_store_grad)

def post_process(self, result: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, np.ndarray]:
"""
This function is used to perform the forward pass of the YOLOv8 target layer.
Post-processes the result.
Args:
data (tensor): The input data.
result (torch.Tensor): The result tensor.
Returns:
A tensor containing the output.
Tuple[torch.Tensor, torch.Tensor, np.ndarray]: A tuple containing the post-processed result.
"""
logits_ = result[:, 4:]
boxes_ = result[:, :4]
sorted, indices = torch.sort(logits_.max(1)[0], descending=True)
return torch.transpose(logits_[0], dim0=0, dim1=1)[indices[0]], torch.transpose(boxes_[0], dim0=0, dim1=1)[indices[0]], xywh2xyxy(torch.transpose(boxes_[0], dim0=0, dim1=1)[indices[0]]).cpu().detach().numpy()

def __call__(self, x: torch.Tensor) -> List[List[Union[torch.Tensor, np.ndarray]]]:
"""
Calls the ActivationsAndGradients object.
Args:
x (torch.Tensor): The input tensor.
Returns:
List[List[Union[torch.Tensor, np.ndarray]]]: A list containing activations and gradients.
"""
self.gradients = []
self.activations = []
model_output = self.model(x)
post_result, pre_post_boxes, post_boxes = self.post_process(model_output[0])
return [[post_result, pre_post_boxes]]

def release(self) -> None:
"""Removes hooks."""
for handle in self.handles:
handle.remove()


class yolov8_target(torch.nn.Module):
def __init__(self, ouput_type, conf, ratio) -> None:
super().__init__()
self.ouput_type = ouput_type
self.conf = conf
self.ratio = ratio

def forward(self, data):
post_result, pre_post_boxes = data
result = []
for i in range(int(post_result.size(0) * self.ratio)):
if float(post_result[i].max()) < self.conf:
break
if self.ouput_type in ["class", "all"]:
if self.ouput_type == 'class' or self.ouput_type == 'all':
result.append(post_result[i].max())
elif self.ouput_type in ["box", "all"]:
result.extend(pre_post_boxes[i, :4])
elif self.ouput_type == 'box' or self.ouput_type == 'all':
for j in range(4):
result.append(pre_post_boxes[i, j])
return sum(result)


class yolov8_heatmap:
"""
This class is used to implement the YOLOv8 target layer.
Expand All @@ -94,8 +177,8 @@ def __init__(
self,
weight:str,
device=torch.device("cuda:0" if torch.cuda.is_available() else "cpu"),
method: str = "EigenCAM",
layer=[10, 12, 14, 16, 18, -3],
method= "EigenCAM",
layer=[10, 12, 14, 16, 18],
backward_type="all",
conf_threshold=0.2,
ratio=0.02,
Expand All @@ -105,33 +188,24 @@ def __init__(
"""
Initialize the YOLOv8 heatmap layer.
"""
self.conf_threshold = conf_threshold
self.device = device
self.renormalize = renormalize
self.show_box = show_box

device = device
ckpt = torch.load(weight)
self.model_names = model_names = ckpt["model"].names
self.model = model = attempt_load_weights(weight, device)
self.model.info()
for p in self.model.parameters():
model_names = ckpt['model'].names
model = attempt_load_weights(weight, device)
model.info()
for p in model.parameters():
p.requires_grad_(True)
self.model.eval()

self.target = target = yolov8_target(backward_type, conf_threshold, ratio)
model.eval()
target = yolov8_target(backward_type, conf_threshold, ratio)
target_layers = [model.model[l] for l in layer]

self.method = method = eval(method)(
self.model, target_layers, use_cuda=device.type == "cuda"
)
method.activations_and_grads = ActivationsAndGradients( # type:ignore
self.model, target_layers, None
)

self.colors = colors = np.random.uniform(
0, 255, size=(len(model_names), 3)
).astype(int)
method = eval(method)(model, target_layers, use_cuda=device.type == 'cuda')
method.activations_and_grads = ActivationsAndGradients(model, target_layers, None)

colors = np.random.uniform(0, 255, size=(len(model_names), 3)).astype(int)
self.__dict__.update(locals())


def post_process(self, result):
"""
Expand All @@ -143,9 +217,8 @@ def post_process(self, result):
Returns:
numpy.ndarray: The filtered detections.
"""
return non_max_suppression(
result, conf_thres=self.conf_threshold, iou_thres=0.65
)[0]
result = non_max_suppression(result, conf_thres=self.conf_threshold, iou_thres=0.65)[0]
return result

def draw_detections(self, box, color, name, img):
"""
Expand All @@ -162,16 +235,7 @@ def draw_detections(self, box, color, name, img):
"""
xmin, ymin, xmax, ymax = list(map(int, list(box)))
cv2.rectangle(img, (xmin, ymin), (xmax, ymax), tuple(int(x) for x in color), 2)
cv2.putText(
img,
str(name),
(xmin, ymin - 5),
cv2.FONT_HERSHEY_SIMPLEX,
0.8,
tuple(int(x) for x in color),
2,
lineType=cv2.LINE_AA,
)
cv2.putText(img, str(name), (xmin, ymin - 5), cv2.FONT_HERSHEY_SIMPLEX, 0.8, tuple(int(x) for x in color), 2, lineType=cv2.LINE_AA)
return img

def renormalize_cam_in_bounding_boxes(
Expand All @@ -195,21 +259,17 @@ def renormalize_cam_in_bounding_boxes(
renormalized_cam = np.zeros(grayscale_cam.shape, dtype=np.float32)
for x1, y1, x2, y2 in boxes:
x1, y1 = max(x1, 0), max(y1, 0)
x2, y2 = min(grayscale_cam.shape[1] - 1, x2), min(
grayscale_cam.shape[0] - 1, y2
)
renormalized_cam[y1:y2, x1:x2] = scale_cam_image(
grayscale_cam[y1:y2, x1:x2].copy()
)
x2, y2 = min(grayscale_cam.shape[1] - 1, x2), min(grayscale_cam.shape[0] - 1, y2)
renormalized_cam[y1:y2, x1:x2] = scale_cam_image(grayscale_cam[y1:y2, x1:x2].copy())
renormalized_cam = scale_cam_image(renormalized_cam)
eigencam_image_renormalized = show_cam_on_image(image_float_np, renormalized_cam, use_rgb=True) # type: ignore
eigencam_image_renormalized = show_cam_on_image(image_float_np, renormalized_cam, use_rgb=True)
return eigencam_image_renormalized

def renormalize_cam(self, boxes, image_float_np, grayscale_cam):
"""Normalize the CAM to be in the range [0, 1]
across the entire image."""
renormalized_cam = scale_cam_image(grayscale_cam)
eigencam_image_renormalized = show_cam_on_image(image_float_np, renormalized_cam, use_rgb=True) # type: ignore
eigencam_image_renormalized = show_cam_on_image(image_float_np, renormalized_cam, use_rgb=True)
return eigencam_image_renormalized

def process(self, img_path):
Expand Down
7 changes: 2 additions & 5 deletions YOLOv8_Explainer/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@ def letterbox(
tuple: Padding sizes.
"""
# Resize and pad image while meeting stride-multiple constraints
shape = im.shape[:2] # current shape [height, width]
if isinstance(new_shape, int):
new_shape = (new_shape, new_shape)
Expand Down Expand Up @@ -61,10 +60,8 @@ def letterbox(
im = cv2.resize(im, new_unpad, interpolation=cv2.INTER_LINEAR)
top, bottom = int(round(dh - 0.1)), int(round(dh + 0.1))
left, right = int(round(dw - 0.1)), int(round(dw + 0.1))
im = cv2.copyMakeBorder(
im, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color
) # add border

im = cv2.copyMakeBorder(im, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color) # add border

return im, ratio, (dw, dh)

def display_images(images: list[Image.Image]) -> None:
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

setup(
name="YOLOv8_Explainer",
version="0.0.01",
version="0.0.03",
description="Python packages that enable XAI methods for YOLOv8",
packages=find_packages(),
long_description=long_description,
Expand Down

0 comments on commit f6544d0

Please sign in to comment.