forked from vitoralbiero/img2pose
generalized_rcnn.py
import warnings
from collections import OrderedDict

import torch
from torch import Tensor, nn
from torch.jit.annotations import Dict, List, Optional, Tuple


class GeneralizedRCNN(nn.Module):
    """
    Main class for Generalized R-CNN.

    Arguments:
        backbone (nn.Module): takes the transformed images and returns feature maps
        rpn (nn.Module): generates region proposals (and RPN losses during
            training) from the images and feature maps
        roi_heads (nn.Module): takes the features + the proposals from the RPN
            and computes detections / masks from it.
        transform (nn.Module): performs the data transformation from the inputs
            to feed into the model
    """

    def __init__(self, backbone, rpn, roi_heads, transform):
        super(GeneralizedRCNN, self).__init__()
        self.transform = transform
        self.backbone = backbone
        self.rpn = rpn
        self.roi_heads = roi_heads
        # used only in torchscript mode
        self._has_warned = False

    @torch.jit.unused
    def eager_outputs(self, losses, detections, evaluating):
        # type: (Dict[str, Tensor], List[Dict[str, Tensor]], bool)
        # -> Tuple[Dict[str, Tensor], List[Dict[str, Tensor]]]
        if evaluating:
            return losses

        return detections

    def forward(self, images, targets=None):
        # type: (List[Tensor], Optional[List[Dict[str, Tensor]]])
        # -> Tuple[Dict[str, Tensor], List[Dict[str, Tensor]]]
        """
        Arguments:
            images (list[Tensor]): images to be processed
            targets (list[Dict[Tensor]]): ground-truth (optional)

        Returns:
            result (list[BoxList] or dict[Tensor]): the output from the model.
                During training, it returns a dict[Tensor] which contains the losses.
                During testing, it returns a list[BoxList] that contains additional
                fields like `scores`, `labels` and `mask` (for Mask R-CNN models).
        """
        if self.training and targets is None:
            raise ValueError("In training mode, targets should be passed")

        if self.training or targets is not None:
            assert targets is not None
            for target in targets:
                boxes = target["boxes"]
                if isinstance(boxes, torch.Tensor):
                    if len(boxes.shape) != 2 or boxes.shape[-1] != 4:
                        raise ValueError(
                            "Expected target boxes to be a tensor "
                            "of shape [N, 4], got {:}.".format(boxes.shape)
                        )
                else:
                    raise ValueError(
                        "Expected target boxes to be of type "
                        "Tensor, got {:}.".format(type(boxes))
                    )

        original_image_sizes = torch.jit.annotate(List[Tuple[int, int]], [])
        for img in images:
            val = img.shape[-2:]
            assert len(val) == 2
            original_image_sizes.append((val[0], val[1]))
        images, targets = self.transform(images, targets)

        # Check for degenerate boxes
        # TODO: Move this to a function
        if targets is not None:
            for target_idx, target in enumerate(targets):
                boxes = target["boxes"]
                degenerate_boxes = boxes[:, 2:] <= boxes[:, :2]
                if degenerate_boxes.any():
                    # print the first degenerate box
                    bb_idx = degenerate_boxes.any(dim=1).nonzero().view(-1)[0]
                    degen_bb: List[float] = boxes[bb_idx].tolist()
                    raise ValueError(
                        "All bounding boxes should have positive height and width."
                        " Found invalid box {} for target at index {}.".format(
                            degen_bb, target_idx
                        )
                    )
        features = self.backbone(images.tensors)
        if isinstance(features, torch.Tensor):
            features = OrderedDict([("0", features)])

        proposals, proposal_losses = self.rpn(images, features, targets)
        detections, detector_losses = self.roi_heads(
            features, proposals, images.image_sizes, targets
        )
        detections = self.transform.postprocess(
            detections, images.image_sizes, original_image_sizes
        )

        losses = {}
        losses.update(detector_losses)
        losses.update(proposal_losses)

        if torch.jit.is_scripting():
            if not self._has_warned:
                warnings.warn(
                    "RCNN always returns a (Losses, Detections) tuple in scripting"
                )
                self._has_warned = True
            return (losses, detections)
        else:
            return self.eager_outputs(losses, detections, targets is not None)
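

if __name__ == "__main__":
    # Usage sketch, not part of the original module: it exercises the forward()
    # contract documented above through torchvision's stock Faster R-CNN, which
    # subclasses the upstream GeneralizedRCNN this file was adapted from. The
    # keyword names below follow the older torchvision releases this file
    # targets (newer releases use weights=None / weights_backbone=None), and
    # the dummy images and targets are illustrative assumptions only.
    import torchvision

    model = torchvision.models.detection.fasterrcnn_resnet50_fpn(
        pretrained=False, pretrained_backbone=False
    )

    images = [torch.rand(3, 480, 640), torch.rand(3, 512, 512)]
    targets = [
        {
            "boxes": torch.tensor([[10.0, 20.0, 200.0, 300.0]]),
            "labels": torch.tensor([1]),
        },
        {
            "boxes": torch.tensor([[50.0, 60.0, 400.0, 450.0]]),
            "labels": torch.tensor([1]),
        },
    ]

    # Training mode with targets: forward() returns a dict of loss tensors.
    model.train()
    losses = model(images, targets)
    print(sorted(losses.keys()))

    # Eval mode without targets: forward() returns per-image detections
    # with "boxes", "labels" and "scores" fields.
    model.eval()
    with torch.no_grad():
        detections = model(images)
    print(sorted(detections[0].keys()))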