-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add GradCAM for object detection (#75)
* Add example with GradCAM for YOLOv5 in object detection * Remove constraints for min and max bbox size * Extract base ObjectDetector class * Remove previous example with object detection task * Add original code source URL in docstring * Move fetching layer from YOLO model function, add target_layer parameter to GradCAM * Add dataclasses for complex object detection output data types * Refactor code, add docstrings, fix names, add workaround with np.abs for bbox coordinates * Add support for torchvision SSD model for object detection, refactoring * Rename directories * Refactor, change BaseObjectDetector class definition, make SSD model inherit from it * Refactor modules structure, add GradCAM to object detection module * Remove unused imports * Refactor, move OD models to separate module, fix YOLO prediction generation * Apply pre-commit hooks to all files * Replace excessive unnecessary dependency with simple parsing * Remove model warmup run * Remove unused path in forward pass algorithm * Replace custom implementations with torchvision imports * Add unit tests for object detection utils * Remove obsolete directory * Fix YOLOv5 bbox conversion * Add unit test for object detection visualization utils, refactor * Refactor GradCAM for OD * Remove redundant device argument to YOLOv5ObjectDetector class initializer * Fix unit tests for image preprocessing - use rectangle instead of square shapes * Simplify forward function in WrapperYOLOv5ObjectDetectionModule class * Add custom GradCAM base algorithm implementation for classification and object detection * Move object detection model examples to examples directory * Restructure library directory structure for explainer algorithms, move object detection example script to notebook * Add interpolation method parameter to resize_image function * Fix unit tests and imports * Remove adding epsilon in preprocess object detection image function * Fix example notebook - add assertions of YOLOv5 image shape * Fix basic 
usage notebook after refactoring class names and directory structure * Enable changing image ratio to match YOLOv5 requirements of image shape * Update README * Refactor object detection custom modules * Refactor GradCAM for object detection
- Loading branch information
1 parent
1f57591
commit 26b31ef
Showing
60 changed files
with
3,539 additions
and
2,177 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -11,7 +11,7 @@ jobs: | |
|
||
- name: Checkout | ||
uses: actions/checkout@v3 | ||
|
||
- name: Setup Python | ||
uses: actions/setup-python@v4 | ||
with: | ||
|
@@ -26,18 +26,18 @@ jobs: | |
restore-keys: | | ||
venv-${{ runner.os }}-${{ env.PYTHON_VERSION }}-${{ hashFiles('poetry.lock') }} | ||
venv-${{ runner.os }}-${{ env.PYTHON_VERSION }}- | ||
- name: Setup poetry | ||
uses: abatilo/[email protected] | ||
with: | ||
poetry-version: ${{ env.POETRY_VERSION }} | ||
|
||
- name: Install dependencies | ||
shell: bash | ||
run: | | ||
poetry install --with docs | ||
working-directory: "" | ||
|
||
- name: Get Package Version | ||
id: get_version | ||
run: echo ::set-output name=VERSION::$(poetry version | cut -d " " -f 2) | ||
|
@@ -77,7 +77,7 @@ jobs: | |
echo ${{ github.ref }} | ||
echo "BRANCH_NAME=$(echo ${GITHUB_REF##*/} | tr / -)" >> $GITHUB_ENV | ||
cat $GITHUB_ENV | ||
- uses: actions/checkout@v3 | ||
name: Check out gh-pages branch (full history) | ||
with: | ||
|
@@ -109,13 +109,13 @@ jobs: | |
|
||
- name: Run docs-versions-menu | ||
run: docs-versions-menu | ||
|
||
- name: Set git configuration | ||
shell: bash | ||
run: | | ||
git config user.name github-actions | ||
git config user.email [email protected] | ||
- name: Commit changes | ||
shell: bash | ||
run: | | ||
|
@@ -124,7 +124,7 @@ jobs: | |
git add -A --verbose | ||
echo "# GIT STATUS" | ||
git status | ||
echo "# GIT COMMIT" | ||
echo "# GIT COMMIT" | ||
git commit --verbose -m "Auto-update from Github Actions Workflow" -m "Deployed from commit ${GITHUB_SHA} (${GITHUB_REF})" | ||
git log -n 1 | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
Empty file.
205 changes: 205 additions & 0 deletions
205
example/gradcam_object_detection/custom_models/ssd/ssd_object_detector.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,205 @@ | ||
"""File contains SSD ObjectDetector class.""" | ||
from collections import OrderedDict | ||
from typing import Dict, List, Optional, Tuple, Union | ||
|
||
import torch | ||
import torch.nn.functional as F | ||
from torchvision.models.detection import _utils as det_utils | ||
from torchvision.models.detection.image_list import ImageList | ||
from torchvision.models.detection.ssd import SSD | ||
from torchvision.ops import boxes as box_ops | ||
|
||
from foxai.explainer.computer_vision.object_detection.base_object_detector import ( | ||
BaseObjectDetector, | ||
) | ||
from foxai.explainer.computer_vision.object_detection.types import PredictionOutput | ||
|
||
|
||
class SSDObjectDetector(BaseObjectDetector):
    """Custom SSD ObjectDetector class which returns predictions with logits to explain.

    Code based on
    https://github.com/pytorch/vision/blob/main/torchvision/models/detection/ssd.py.
    """

    def __init__(
        self,
        model: SSD,
        class_names: Optional[List[str]] = None,
    ):
        """Initialize SSDObjectDetector.

        Args:
            model: Trained torchvision SSD model to wrap and explain.
            class_names: Optional list mapping class index to human-readable
                name; when not given, numeric labels are used as names.
        """
        super().__init__()
        self.model = model
        self.class_names = class_names

    def forward(
        self,
        image: torch.Tensor,
    ) -> Tuple[List[PredictionOutput], List[torch.Tensor]]:
        """Forward pass of the network.

        Args:
            image: Image to process.

        Returns:
            Tuple of 2 values, first is tuple of predictions containing
            bounding-boxes, class number, class name and confidence; second
            value is tensor with logits per each detection.

        Raises:
            ValueError: If an input tensor does not end with H and W
                dimensions, or a target bounding box has non-positive
                height or width.
        """
        # get the original image sizes so detections can be rescaled back
        # from the model's internal resolution
        images = list(image)
        original_image_sizes: List[Tuple[int, int]] = []
        for img in images:
            img_shape_hw = img.shape[-2:]
            if len(img_shape_hw) != 2:
                # raise instead of assert so the check survives `python -O`
                raise ValueError(
                    "expecting the last two dimensions of the Tensor to be H and W"
                    f" instead got {img.shape[-2:]}"
                )
            original_image_sizes.append((img_shape_hw[0], img_shape_hw[1]))

        # transform the input into the model's internal ImageList format
        image_list: ImageList
        targets: Optional[List[Dict[str, torch.Tensor]]]
        image_list, targets = self.model.transform(images, None)

        # Check for degenerate boxes (non-positive width/height).
        # NOTE(review): `transform` is called with targets=None above, so
        # this branch appears to be defensive only — confirm.
        if targets is not None:
            for target_idx, target in enumerate(targets):
                boxes = target["boxes"]
                degenerate_boxes = boxes[:, 2:] <= boxes[:, :2]
                if degenerate_boxes.any():
                    bb_idx = torch.where(degenerate_boxes.any(dim=1))[0][0]
                    degen_bb: List[float] = boxes[bb_idx].tolist()
                    # raise instead of assert so the check survives `python -O`
                    raise ValueError(
                        "All bounding boxes should have positive height and width. "
                        + f"Found invalid box {degen_bb} for target at index {target_idx}."
                    )

        # get the features from the backbone
        features: Union[Dict[str, torch.Tensor], torch.Tensor] = self.model.backbone(
            image_list.tensors
        )
        if isinstance(features, torch.Tensor):
            features = OrderedDict([("0", features)])

        features_list = list(features.values())

        # compute the ssd heads outputs using the features
        head_outputs = self.model.head(features_list)

        # create the set of anchors
        anchors = self.model.anchor_generator(image_list, features_list)

        detections, logits = self.postprocess_detections(
            head_outputs=head_outputs,
            image_anchors=anchors,
            image_shapes=image_list.image_sizes,
        )
        # map boxes back to the original (pre-transform) image coordinates
        detections = self.model.transform.postprocess(
            detections, image_list.image_sizes, original_image_sizes
        )

        detection_class_names = [str(val.item()) for val in detections[0]["labels"]]
        if self.class_names:
            detection_class_names = [
                str(self.class_names[val.item()]) for val in detections[0]["labels"]
            ]

        # change order of bounding box coordinates: the loop swaps the two
        # values inside each coordinate pair, i.e. [a, b, c, d] -> [b, a, d, c].
        # NOTE(review): the original comment claimed the order changes from
        # [x2, y2, x1, y1] to [x1, y1, x2, y2], which would instead be
        # [c, d, a, b] — confirm the intended ordering against the consumer.
        detections[0]["boxes"] = detections[0]["boxes"].detach().cpu()
        for detection in detections[0]["boxes"]:
            tmp1 = detection[0].item()
            tmp2 = detection[2].item()
            detection[0] = detection[1]
            detection[2] = detection[3]
            detection[1] = tmp1
            detection[3] = tmp2

        predictions = [
            PredictionOutput(
                bbox=bbox.tolist(),
                class_number=class_no.item(),
                class_name=class_name,
                confidence=confidence.item(),
            )
            for bbox, class_no, class_name, confidence in zip(
                detections[0]["boxes"],
                detections[0]["labels"],
                detection_class_names,
                detections[0]["scores"],
            )
        ]

        return predictions, logits

    def postprocess_detections(
        self,
        head_outputs: Dict[str, torch.Tensor],
        image_anchors: List[torch.Tensor],
        image_shapes: List[Tuple[int, int]],
    ) -> Tuple[List[Dict[str, torch.Tensor]], List[torch.Tensor]]:
        """Convert raw SSD head outputs into per-image detections.

        Applies score thresholding, per-class top-k selection and
        non-maximum suppression, mirroring torchvision's
        ``SSD.postprocess_detections`` but additionally returning class
        logits for the kept detections.

        Args:
            head_outputs: Dict with ``"bbox_regression"`` and
                ``"cls_logits"`` tensors produced by the SSD head.
            image_anchors: Anchor boxes, one tensor per image.
            image_shapes: Sizes of each transformed image.

        Returns:
            Tuple of per-image detection dicts (``boxes``/``scores``/
            ``labels``) and a list of logits for the kept detections of the
            first image.
        """
        bbox_regression = head_outputs["bbox_regression"]
        logits = head_outputs["cls_logits"]
        confidence_scores = F.softmax(head_outputs["cls_logits"], dim=-1)

        num_classes = confidence_scores.size(-1)
        device = confidence_scores.device

        detections: List[Dict[str, torch.Tensor]] = []

        for boxes, scores, anchors, image_shape in zip(
            bbox_regression, confidence_scores, image_anchors, image_shapes
        ):
            boxes = self.model.box_coder.decode_single(boxes, anchors)
            boxes = box_ops.clip_boxes_to_image(boxes, image_shape)

            image_boxes: List[torch.Tensor] = []
            image_scores: List[torch.Tensor] = []
            image_labels: List[torch.Tensor] = []
            # label 0 is the background class and is skipped
            for label in range(1, num_classes):
                score = scores[:, label]

                keep_idxs = score > self.model.score_thresh
                score = score[keep_idxs]
                box = boxes[keep_idxs]

                # keep only topk scoring predictions
                num_topk = det_utils._topk_min(  # pylint: disable = (protected-access)
                    score, self.model.topk_candidates, 0
                )
                score, idxs = score.topk(num_topk)
                box = box[idxs]

                image_boxes.append(box)
                image_scores.append(score)
                image_labels.append(
                    torch.full_like(
                        score, fill_value=label, dtype=torch.int64, device=device
                    )
                )

            image_box: torch.Tensor = torch.cat(image_boxes, dim=0)
            image_score: torch.Tensor = torch.cat(image_scores, dim=0)
            image_label: torch.Tensor = torch.cat(image_labels, dim=0)

            # non-maximum suppression, class-aware via the label index
            keep = box_ops.batched_nms(
                boxes=image_box,
                scores=image_score,
                idxs=image_label,
                iou_threshold=self.model.nms_thresh,
            )
            keep = keep[: self.model.detections_per_img]

            detections.append(
                {
                    "boxes": image_box[keep],
                    "scores": image_score[keep],
                    "labels": image_label[keep],
                }
            )
        # add batch dimension for further processing
        # NOTE(review): only the first image's logits are returned, indexed
        # by `keep` from the *last* loop iteration — this lines up only for
        # batch size 1, and `keep` indexes the concatenated per-class
        # candidates rather than raw anchor positions; confirm against the
        # GradCAM consumer.
        keep_logits = logits[0][keep][None, :]
        return detections, list(keep_logits)
Empty file.
Oops, something went wrong.