-
Notifications
You must be signed in to change notification settings - Fork 7
/
Copy pathroi_pooling_layer.cpp
168 lines (148 loc) · 6.06 KB
/
roi_pooling_layer.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
#include <algorithm>
#include <cfloat>
#include <vector>
#include "caffe/layers/roi_pooling_layer.hpp"
using std::max;
using std::min;
using std::floor;
using std::ceil;
namespace caffe {
template <typename Dtype>
void ROIPoolingLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom,
    const vector<Blob<Dtype>*>& top) {
  // One-time configuration: read the pooled output grid size and the factor
  // that maps ROI coordinates onto the bottom[0] feature-map grid.
  // Bound by const reference; no copy of the parameter message is needed.
  const ROIPoolingParameter& roi_pool_param =
      this->layer_param_.roi_pooling_param();

  // Both output dimensions must be positive (they become blob dimensions).
  CHECK_GT(roi_pool_param.pooled_h(), 0) << "pooled_h must be > 0";
  CHECK_GT(roi_pool_param.pooled_w(), 0) << "pooled_w must be > 0";

  pooled_height_ = roi_pool_param.pooled_h();
  pooled_width_ = roi_pool_param.pooled_w();
  spatial_scale_ = roi_pool_param.spatial_scale();

  LOG(INFO) << "Spatial scale: " << spatial_scale_;
}
template <typename Dtype>
void ROIPoolingLayer<Dtype>::Reshape(const vector<Blob<Dtype>*>& bottom,
    const vector<Blob<Dtype>*>& top) {
  // Cache the feature-map geometry used by Forward_cpu / Backward_cpu.
  channels_ = bottom[0]->channels();
  height_ = bottom[0]->height();
  width_ = bottom[0]->width();

  // Forward_cpu reads five values per ROI row ([batch_index x1 y1 x2 y2]).
  // A narrower row would make it read into the next ROI, or past the end of
  // the blob for the last ROI, so reject such inputs up front.
  CHECK_GE(bottom[1]->count(1), 5)
      << "Each ROI must provide [batch_index x1 y1 x2 y2]";

  // One pooled (channels_ x pooled_h x pooled_w) map per ROI; max_idx_ mirrors
  // the output shape and records, per output cell, the winning input index.
  const int num_rois = bottom[1]->num();
  top[0]->Reshape(num_rois, channels_, pooled_height_, pooled_width_);
  max_idx_.Reshape(num_rois, channels_, pooled_height_, pooled_width_);
}
template <typename Dtype>
void ROIPoolingLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
    const vector<Blob<Dtype>*>& top) {
  // Max-pools each ROI of bottom[1] over the matching image of bottom[0].
  //   bottom[0]: feature maps, indexed by (image, channel, row, col).
  //   bottom[1]: one ROI per num entry, read as [batch_index x1 y1 x2 y2]
  //              in unscaled coordinates (scaled here by spatial_scale_).
  //   top[0]:    one channels_ x pooled_height_ x pooled_width_ map per ROI.
  //   max_idx_:  same shape as top[0]; per-plane index (row * width_ + col)
  //              of each winning input cell, or -1 when nothing was pooled.
  const Dtype* feature_data = bottom[0]->cpu_data();
  const Dtype* roi_data = bottom[1]->cpu_data();
  const int roi_total = bottom[1]->num();
  const int batch_size = bottom[0]->num();
  const int output_total = top[0]->count();

  // Seed outputs with -FLT_MAX and argmaxes with -1; the strict ">" below then
  // keeps the first maximum met in row-major scan order.
  Dtype* output_data = top[0]->mutable_cpu_data();
  caffe_set(output_total, Dtype(-FLT_MAX), output_data);
  int* winner_data = max_idx_.mutable_cpu_data();
  caffe_set(output_total, -1, winner_data);

  for (int r = 0; r < roi_total; ++r) {
    const Dtype* roi = roi_data + bottom[1]->offset(r);
    const int roi_batch_ind = static_cast<int>(roi[0]);
    CHECK_GE(roi_batch_ind, 0);
    CHECK_LT(roi_batch_ind, batch_size);

    // ROI corners projected onto the feature-map grid.
    const int roi_col_begin = round(roi[1] * spatial_scale_);
    const int roi_row_begin = round(roi[2] * spatial_scale_);
    const int roi_col_end = round(roi[3] * spatial_scale_);
    const int roi_row_end = round(roi[4] * spatial_scale_);

    // Corners are inclusive; a degenerate ROI still spans one cell.
    const int roi_rows = max(roi_row_end - roi_row_begin + 1, 1);
    const int roi_cols = max(roi_col_end - roi_col_begin + 1, 1);
    const Dtype bin_rows = static_cast<Dtype>(roi_rows)
        / static_cast<Dtype>(pooled_height_);
    const Dtype bin_cols = static_cast<Dtype>(roi_cols)
        / static_cast<Dtype>(pooled_width_);

    for (int c = 0; c < channels_; ++c) {
      const Dtype* plane = feature_data + bottom[0]->offset(roi_batch_ind, c);
      Dtype* pooled = output_data + top[0]->offset(r, c);
      int* winner = winner_data + max_idx_.offset(r, c);

      for (int ph = 0; ph < pooled_height_; ++ph) {
        // Bin rows: [floor(ph * bin), ceil((ph + 1) * bin)) relative to the
        // ROI origin, then shifted and clipped to [0, height_].
        const int bin_row_lo = static_cast<int>(
            floor(static_cast<Dtype>(ph) * bin_rows));
        const int bin_row_hi = static_cast<int>(
            ceil(static_cast<Dtype>(ph + 1) * bin_rows));
        const int row_lo = min(max(bin_row_lo + roi_row_begin, 0), height_);
        const int row_hi = min(max(bin_row_hi + roi_row_begin, 0), height_);

        for (int pw = 0; pw < pooled_width_; ++pw) {
          // Bin columns, computed and clipped the same way against width_.
          const int bin_col_lo = static_cast<int>(
              floor(static_cast<Dtype>(pw) * bin_cols));
          const int bin_col_hi = static_cast<int>(
              ceil(static_cast<Dtype>(pw + 1) * bin_cols));
          const int col_lo = min(max(bin_col_lo + roi_col_begin, 0), width_);
          const int col_hi = min(max(bin_col_hi + roi_col_begin, 0), width_);

          const int bin = ph * pooled_width_ + pw;
          if (row_hi <= row_lo || col_hi <= col_lo) {
            // Empty after clipping (bin lies outside the feature map):
            // output 0 and record no winner.
            pooled[bin] = 0;
            winner[bin] = -1;
            continue;
          }

          for (int h = row_lo; h < row_hi; ++h) {
            const Dtype* row = plane + h * width_;
            for (int w = col_lo; w < col_hi; ++w) {
              if (row[w] > pooled[bin]) {
                pooled[bin] = row[w];
                winner[bin] = h * width_ + w;
              }
            }
          }
        }
      }
    }
  }
}
template <typename Dtype>
void ROIPoolingLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
    const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom) {
  // Routes each pooled output's gradient back to the single input cell that
  // won the max in Forward_cpu (recorded in max_idx_). Bins whose argmax is -1
  // (empty after clipping) contribute nothing. ROIs pooling overlapping cells
  // accumulate into bottom[0]'s diff.
  if (propagate_down[1]) {
    LOG(FATAL) << this->type()
               << " Layer cannot backpropagate to roi inputs.";
  }
  if (!propagate_down[0]) {
    return;
  }
  const Dtype* bottom_rois = bottom[1]->cpu_data();
  const Dtype* top_diff = top[0]->cpu_diff();
  Dtype* bottom_diff = bottom[0]->mutable_cpu_diff();
  caffe_set(bottom[0]->count(), Dtype(0.), bottom_diff);
  const int* argmax_data = max_idx_.cpu_data();
  const int num_rois = top[0]->num();
  const int batch_size = bottom[0]->num();
  const int bins_per_channel = pooled_height_ * pooled_width_;

  for (int roi_n = 0; roi_n < num_rois; ++roi_n) {
    // Address the ROI row with the blob's own stride, exactly as Forward_cpu
    // does, instead of assuming 5 values per row.
    const int roi_batch_ind =
        static_cast<int>(bottom_rois[bottom[1]->offset(roi_n)]);
    // Guard the scatter below against writing outside bottom_diff.
    CHECK_GE(roi_batch_ind, 0);
    CHECK_LT(roi_batch_ind, batch_size);

    for (int c = 0; c < channels_; ++c) {
      const Dtype* channel_top_diff = top_diff + top[0]->offset(roi_n, c);
      const int* channel_argmax = argmax_data + max_idx_.offset(roi_n, c);
      Dtype* channel_bottom_diff =
          bottom_diff + bottom[0]->offset(roi_batch_ind, c);
      // Bins are visited in the same (ph, pw) row-major order as before, so
      // the floating-point accumulation order is unchanged.
      for (int bin = 0; bin < bins_per_channel; ++bin) {
        const int argmax_index = channel_argmax[bin];
        if (argmax_index >= 0) {
          channel_bottom_diff[argmax_index] += channel_top_diff[bin];
        }
      }
    }
  }
}
#ifdef CPU_ONLY
// CPU-only build: STUB_GPU presumably provides stand-in GPU Forward/Backward
// definitions so the class links without CUDA sources (TODO confirm against
// Caffe's util/device_alternate.hpp).
STUB_GPU(ROIPoolingLayer);
#endif
// Explicitly instantiate the ROIPoolingLayer<Dtype> template in this
// translation unit (presumably for Caffe's float and double Dtypes; see
// caffe/common.hpp).
INSTANTIATE_CLASS(ROIPoolingLayer);
// Register the layer with Caffe's layer factory; presumably this makes it
// selectable from a net definition as type "ROIPooling" — verify in
// caffe/layer_factory.hpp.
REGISTER_LAYER_CLASS(ROIPooling);
} // namespace caffe