forked from hpi-xnor/BMXNet-v2
-
Notifications
You must be signed in to change notification settings - Fork 0
/
bounding_box-inl.h
838 lines (786 loc) · 31.9 KB
/
bounding_box-inl.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file bounding_box-inl.h
* \brief bounding box util functions and operators
* \author Joshua Zhang
*/
#ifndef MXNET_OPERATOR_CONTRIB_BOUNDING_BOX_INL_H_
#define MXNET_OPERATOR_CONTRIB_BOUNDING_BOX_INL_H_
#include <mxnet/operator_util.h>
#include <dmlc/optional.h>
#include <nnvm/tuple.h>
#include <vector>
#include <utility>
#include <string>
#include <algorithm>
#include "../mshadow_op.h"
#include "../mxnet_op.h"
#include "../operator_common.h"
#include "../tensor/sort_op.h"
namespace mxnet {
namespace op {
/*! \brief Box coordinate encodings shared by the operators in this file. */
namespace box_common_enum {
// kCorner: [xmin, ymin, xmax, ymax]; kCenter: [x, y, width, height]
enum BoxType {kCorner, kCenter};
}
/*! \brief Input / output / resource slot indices of the box_nms operator. */
namespace box_nms_enum {
// kData: the single input tensor of boxes
enum BoxNMSOpInputs {kData};
// kOut: the visible nms result; kTemp: hidden per-row record of the source
// row index, written by forward and read by backward
enum BoxNMSOpOutputs {kOut, kTemp};
// kTempSpace: index into ctx.requested for the temporary workspace
enum BoxNMSOpResource {kTempSpace};
} // box_nms_enum
/*!
 * \brief Parameters of the box_nms operator.
 *
 * Each box is one row of the input's last axis. score_index, coord_start and
 * the optional id_index select which columns hold the score, the 4
 * consecutive coordinates and the class id.
 */
struct BoxNMSParam : public dmlc::Parameter<BoxNMSParam> {
float overlap_thresh;  // IoU above which the lower-scored box is suppressed
float valid_thresh;    // boxes with score <= valid_thresh are dropped before nms
int topk;              // candidates per batch considered by nms, -1 for all
int coord_start;       // column of the first of the 4 consecutive coordinates
int score_index;       // column holding the score
int id_index;          // column holding the class id, -1 to disable
bool force_suppress;   // if true, suppress across classes even with id_index set
int in_format;         // box_common_enum::BoxType of the input coordinates
int out_format;        // box_common_enum::BoxType of the output coordinates
DMLC_DECLARE_PARAMETER(BoxNMSParam) {
DMLC_DECLARE_FIELD(overlap_thresh).set_default(0.5)
.describe("Overlapping(IoU) threshold to suppress object with smaller score.");
DMLC_DECLARE_FIELD(valid_thresh).set_default(0)
.describe("Filter input boxes to those whose scores greater than valid_thresh.");
DMLC_DECLARE_FIELD(topk).set_default(-1)
.describe("Apply nms to topk boxes with descending scores, -1 to no restriction.");
DMLC_DECLARE_FIELD(coord_start).set_default(2)
.describe("Start index of the consecutive 4 coordinates.");
DMLC_DECLARE_FIELD(score_index).set_default(1)
.describe("Index of the scores/confidence of boxes.");
DMLC_DECLARE_FIELD(id_index).set_default(-1)
.describe("Optional, index of the class categories, -1 to disable.");
DMLC_DECLARE_FIELD(force_suppress).set_default(false)
.describe("Optional, if set false and id_index is provided, nms will only apply"
" to boxes belongs to the same category");
DMLC_DECLARE_FIELD(in_format).set_default(box_common_enum::kCorner)
.add_enum("corner", box_common_enum::kCorner)
.add_enum("center", box_common_enum::kCenter)
.describe("The input box encoding type. \n"
" \"corner\" means boxes are encoded as [xmin, ymin, xmax, ymax],"
" \"center\" means boxes are encodes as [x, y, width, height].");
DMLC_DECLARE_FIELD(out_format).set_default(box_common_enum::kCorner)
.add_enum("corner", box_common_enum::kCorner)
.add_enum("center", box_common_enum::kCenter)
.describe("The output box encoding type. \n"
" \"corner\" means boxes are encoded as [xmin, ymin, xmax, ymax],"
" \"center\" means boxes are encodes as [x, y, width, height].");
}
}; // BoxNMSParam
/*!
 * \brief Shape inference for box_nms.
 *
 * The input must have ndim >= 2, with the last two axes (num_box, box_width).
 * Output 0 has the input shape. Output 1 (the hidden record) has the input
 * shape with the last axis set to 1. Also validates that score_index, the
 * coordinate columns [coord_start, coord_start + 3] and the optional id_index
 * are in range and do not overlap.
 *
 * \return false while neither the input nor output 0 has a known shape.
 */
inline bool BoxNMSShape(const nnvm::NodeAttrs& attrs,
                        std::vector<TShape> *in_attrs,
                        std::vector<TShape> *out_attrs) {
  const BoxNMSParam& param = nnvm::get<BoxNMSParam>(attrs.parsed);
  CHECK_EQ(in_attrs->size(), 1U);
  CHECK_EQ(out_attrs->size(), 2U);
  if (in_attrs->at(0).ndim() == 0U) {
    if (out_attrs->at(0).ndim() == 0U) {
      return false;
    }
    // output 0 always has exactly the input shape, so back-infer the input
    // instead of failing the ndim check below on an empty shape
    SHAPE_ASSIGN_CHECK(*in_attrs, 0, out_attrs->at(0));
  }
  TShape& ishape = (*in_attrs)[0];
  int indim = ishape.ndim();
  CHECK(indim >= 2)
    << "input must have dim >= 2"
    << " the last two dimensions are num_box and box_width "
    << ishape << " provided";
  int width_elem = ishape[indim - 1];
  // score + 4 coordinates, plus one class-id column when id_index is enabled.
  // id_index == 0 is a valid column (it is validated with >= 0 below), so the
  // extra column must be counted for every non-negative id_index.
  const bool has_id = param.id_index >= 0;
  int expected = has_id ? 6 : 5;
  CHECK_GE(width_elem, expected)
    << "the last dimension must have at least " << expected << " elements"
    << " namely (score, coordinates x 4" << (has_id ? ", class id" : "") << ") "
    << width_elem << " provided, " << expected << " expected.";
  // check indices
  int coord_start = param.coord_start;
  int coord_end = param.coord_start + 3;
  int score_index = param.score_index;
  CHECK(score_index >= 0 && score_index < width_elem)
    << "score_index: " << score_index << " out of range: (0, "
    << width_elem << ")";
  CHECK(score_index < coord_start || score_index > coord_end)
    << "score_index: " << score_index << " conflict with coordinates: ("
    << coord_start << ", " << coord_end << ").";
  CHECK(coord_start >= 0 && coord_end < width_elem)
    << "coordinates: (" << coord_start << ", " << coord_end
    << ") out of range:: (0, " << width_elem << ")";
  if (has_id) {
    int id_index = param.id_index;
    CHECK(id_index >= 0 && id_index < width_elem)
      << "id_index: " << id_index << " out of range: (0, "
      << width_elem << ")";
    CHECK(id_index < coord_start || id_index > coord_end)
      << "id_index: " << id_index << " conflict with coordinates: ("
      << coord_start << ", " << coord_end << ").";
    CHECK_NE(id_index, score_index)
      << "id_index: " << id_index << " conflict with score_index: " << score_index;
  }
  // record: one slot per box
  TShape oshape = ishape;
  oshape[indim - 1] = 1;
  SHAPE_ASSIGN_CHECK(*out_attrs, 0, ishape);  // out_shape[0] == in_shape
  SHAPE_ASSIGN_CHECK(*out_attrs, 1, oshape);  // out_shape[1]
  return true;
}
/*!
 * \brief Number of user-visible outputs of box_nms.
 *
 * Only the nms result (kOut) is exposed. The index record (kTemp) stays
 * hidden and is consumed by the backward pass.
 */
inline uint32_t BoxNMSNumVisibleOutputs(const NodeAttrs& attrs) {
  constexpr uint32_t kNumVisible = 1U;
  return kNumVisible;
}
/*!
 * \brief CPU compaction of entries whose score exceeds a threshold.
 *
 * Every scores[i] > valid_thresh is copied, together with sorted_index[i],
 * to the front of out_scores / out_sorted_index. The original relative order
 * is preserved.
 *
 * \return the number of entries written
 */
template<typename DType>
int FilterScores(mshadow::Tensor<cpu, 1, DType> out_scores,
                 mshadow::Tensor<cpu, 1, int32_t> out_sorted_index,
                 mshadow::Tensor<cpu, 1, DType> scores,
                 mshadow::Tensor<cpu, 1, int32_t> sorted_index,
                 float valid_thresh) {
  const index_t total = scores.size(0);
  index_t kept = 0;
  for (index_t src = 0; src < total; ++src) {
    // scores that fail the test (including NaN) are skipped
    if (!(scores[src] > valid_thresh)) continue;
    out_scores[kept] = scores[src];
    out_sorted_index[kept] = sorted_index[src];
    ++kept;
  }
  return kept;
}
namespace mshadow_op {
/*!
 * \brief Elementwise "strictly less than" operator.
 *
 * Yields 1 when lhs < rhs and 0 otherwise, expressed in DType so that it can
 * be summed by a reduction.
 */
struct less_than : public mxnet_op::tunable {
  template<typename DType>
  MSHADOW_XINLINE static DType Map(DType lhs, DType rhs) {
    const bool is_less = lhs < rhs;
    return static_cast<DType>(is_less);
  }
};  // struct less_than
}  // namespace mshadow_op
/*!
 * \brief In-place conversion of row i from [xmin, ymin, xmax, ymax] to
 *        [x, y, width, height].
 *
 * \param i row index
 * \param data pointer to the first coordinate of row 0
 * \param stride distance between consecutive rows
 *
 * Padding rows (every coordinate == -1) are left untouched.
 */
struct corner_to_center {
  template<typename DType>
  MSHADOW_XINLINE static void Map(int i, DType *data, int stride) {
    int index = i * stride;
    DType left = data[index];
    DType top = data[index+1];
    DType right = data[index+2];
    DType bot = data[index+3];
    // Only skip rows that are -1 in every coordinate (the padding value used
    // for discarded rows). Testing `left < 0` alone would also skip valid boxes
    // that extend past the left border, silently leaving them in corner form.
    if (left == DType(-1) && top == DType(-1) &&
        right == DType(-1) && bot == DType(-1)) {
      return;
    }
    data[index] = (left + right) / 2;
    data[index+1] = (top + bot) / 2;
    data[index+2] = right - left;
    data[index+3] = bot - top;
  }
};
/*!
 * \brief In-place conversion of row i from [x, y, width, height] to
 *        [xmin, ymin, xmax, ymax].
 *
 * \param i row index
 * \param data pointer to the first coordinate of row 0
 * \param stride distance between consecutive rows
 *
 * Padding rows (every coordinate == -1) are left untouched.
 */
struct center_to_corner {
  template<typename DType>
  MSHADOW_XINLINE static void Map(int i, DType *data, int stride) {
    int index = i * stride;
    DType x = data[index];
    DType y = data[index+1];
    DType w = data[index+2];
    DType h = data[index+3];
    // Only skip rows that are -1 in every coordinate (the padding value used
    // for discarded rows). Testing `x < 0` alone would also skip valid boxes
    // whose center lies left of the origin, leaving them in center form.
    if (x == DType(-1) && y == DType(-1) &&
        w == DType(-1) && h == DType(-1)) {
      return;
    }
    DType width = w / 2;
    DType height = h / 2;
    data[index] = x - width;
    data[index+1] = y - height;
    data[index+2] = x + width;
    data[index+3] = y + height;
  }
};
/*!
 * \brief Area of one box; returns 0 when width or height is negative.
 *
 * \param box pointer to the 4 consecutive coordinates of the box
 * \param encode box_common_enum::kCorner for [xmin, ymin, xmax, ymax];
 *        any other value is treated as center form [x, y, width, height]
 */
template<typename DType>
MSHADOW_XINLINE DType BoxArea(const DType *box, int encode) {
  DType w, h;
  if (box_common_enum::kCorner != encode) {
    // center form stores the extents directly
    w = box[2];
    h = box[3];
  } else {
    w = box[2] - box[0];
    h = box[3] - box[1];
  }
  if (w < 0 || h < 0) return DType(0);
  return w * h;
}
/*!
 * \brief Kernel computing candidate box areas, specialized for nms.
 *
 * Thread i handles candidate k = i % topk of batch b = i / topk. The candidate
 * lives at compact position batch_start[b] + k of `indices`. Threads whose
 * position reaches batch_start[b + 1] (past the batch's valid candidates) do
 * nothing.
 *
 * \param i launched thread index (total num_batch * topk)
 * \param out areas, indexed by the box's row index (size num_batch * num_elem)
 * \param in first coordinate of the first box (buffer + coord_start)
 * \param indices compact sorted row indices (sorted_index)
 * \param batch_start per-batch start offsets into indices (num_batch + 1 entries)
 * \param topk number of candidates handled per batch
 * \param num_elem not read by this kernel
 * \param stride row width, i.e. width_elem (e.g. 6 including cls and score)
 * \param encode box encoding forwarded to BoxArea
 */
struct compute_area {
  template<typename DType>
  MSHADOW_XINLINE static void Map(int i, DType *out, const DType *in,
                                  const int32_t *indices, const int32_t *batch_start,
                                  int topk, int num_elem, int stride, int encode) {
    const int batch = i / topk;
    const int begin = static_cast<int>(batch_start[batch]);
    const int end = static_cast<int>(batch_start[batch + 1]);
    const int pos = begin + i % topk;
    if (pos < end) {
      const int row = static_cast<int>(indices[pos]);
      out[row] = BoxArea(in + row * stride, encode);
    }
  }
};
/*!
 * \brief Overlap length of two boxes along one axis, clamped at 0.
 *
 * a and b point at the axis' first value. The matching second value sits two
 * elements later. Pass box pointers for the x axis, or box pointers + 1 for
 * the y axis.
 * \param encode box_common_enum::kCorner: values are (min, max); otherwise
 *        values are (center, full extent)
 */
template<typename DType>
MSHADOW_XINLINE DType Intersect(const DType *a, const DType *b, int encode) {
  DType overlap;
  if (box_common_enum::kCorner != encode) {
    // center form: turn (center, extent) into [lo, hi] for both boxes
    const DType a_half = a[2] / 2;
    const DType b_half = b[2] / 2;
    const DType a_lo = a[0] - a_half;
    const DType a_hi = a[0] + a_half;
    const DType b_lo = b[0] - b_half;
    const DType b_hi = b[0] + b_half;
    const DType lo = (b_lo > a_lo) ? b_lo : a_lo;
    const DType hi = (b_hi < a_hi) ? b_hi : a_hi;
    overlap = hi - lo;
  } else {
    const DType lo = (a[0] > b[0]) ? a[0] : b[0];
    const DType hi = (a[2] < b[2]) ? a[2] : b[2];
    overlap = hi - lo;
  }
  if (overlap > 0) return overlap;
  return DType(0);
}
/*!
 * \brief One suppression step of non-maximum suppression.
 *
 * For batch b = i / k, compares the reference candidate at compact position
 * batch_start[b] + ref with the candidate at batch_start[b] + ref + 1 + i % k.
 * Nothing happens when either position lies past batch_start[b + 1] or either
 * candidate is already suppressed (negative index). Unless `force` is set,
 * candidates of different class ids (when offset_id >= 0) are also left
 * alone. Otherwise the later candidate is suppressed (index set to -1) when
 * the IoU of the two boxes exceeds `thresh`.
 *
 * \param i the launched thread index (total num_batch * k)
 * \param index compact sorted row indices; -1 marks suppressed entries
 * \param batch_start per-batch start offsets into index (num_batch + 1 entries)
 * \param input the (possibly copied) input rows of the nms op
 * \param areas pre-computed box areas, indexed by row index
 * \param k number of candidates compared against the reference in this step
 * \param ref reference position relative to the batch start
 * \param num not read by this kernel
 * \param stride input row width, usually 6 (id-score-x1-y1-x2-y2)
 * \param offset_box column of the first coordinate, usually 2
 * \param offset_id class id column, used when force == false, usually 0
 * \param thresh nms IoU threshold
 * \param force suppress regardless of class id
 * \param encode box encoding, corner(0) or center(1)
 */
struct nms_impl {
  template<typename DType>
  MSHADOW_XINLINE static void Map(int i, int32_t *index, const int32_t *batch_start,
                                  const DType *input, const DType *areas,
                                  int k, int ref, int num,
                                  int stride, int offset_box, int offset_id,
                                  float thresh, bool force, int encode) {
    const int batch = i / k;
    const int begin = static_cast<int>(batch_start[batch]);
    const int end = static_cast<int>(batch_start[batch + 1]);
    const int ref_pos = begin + ref;
    const int cand_pos = begin + (i % k + ref + 1);
    // outside this batch's candidate range
    if (ref_pos >= end || cand_pos >= end) return;
    const int ref_row = static_cast<int>(index[ref_pos]);
    if (ref_row < 0) return;   // reference already suppressed
    const int cand_row = static_cast<int>(index[cand_pos]);
    if (cand_row < 0) return;  // candidate already suppressed
    const int ref_base = ref_row * stride;
    const int cand_base = cand_row * stride;
    if (!force && offset_id >= 0) {
      const int ref_cls = static_cast<int>(input[ref_base + offset_id]);
      const int cand_cls = static_cast<int>(input[cand_base + offset_id]);
      if (ref_cls != cand_cls) return;  // only compare within a class
    }
    const DType *ref_box = input + (ref_base + offset_box);
    const DType *cand_box = input + (cand_base + offset_box);
    DType inter = Intersect(ref_box, cand_box, encode);
    inter *= Intersect(ref_box + 1, cand_box + 1, encode);
    const DType iou = inter / (areas[ref_row] + areas[cand_row] - inter);
    if (iou > thresh) index[cand_pos] = -1;
  }
};
/*!
 * \brief Gather the nms survivors of batch i to the front of its output slab.
 *
 * Walks at most k compact positions starting at batch_start[i], stopping at
 * batch_start[i + 1]. Each surviving (non-negative) row index is copied to
 * the next free output row of batch i, and its source row index is stored in
 * record for the backward pass.
 *
 * \param i the launched thread index (total num_batch)
 * \param out output rows, num rows per batch
 * \param record source row index of each output row
 * \param input rows to copy from
 * \param index compact sorted row indices; negative entries are suppressed
 * \param batch_start per-batch start offsets into index
 * \param k nms topk number
 * \param num number of rows per batch
 * \param stride row width, usually 6 (id-score-x1-y1-x2-y2)
 */
struct nms_assign {
  template<typename DType>
  MSHADOW_XINLINE static void Map(int i, DType *out, DType *record, const DType *input,
                                  const int32_t *index, const int32_t *batch_start,
                                  int k, int num, int stride) {
    const int begin = static_cast<int>(batch_start[i]);
    const int end = static_cast<int>(batch_start[i + 1]);
    int kept = 0;
    for (int pos = begin; pos < begin + k && pos < end; ++pos) {
      const int src = static_cast<int>(index[pos]);
      if (src < 0) continue;  // suppressed by nms
      const int dst = i * num + kept;
      const DType *from = input + src * stride;
      DType *to = out + dst * stride;
      for (int col = 0; col < stride; ++col) {
        to[col] = from[col];
      }
      record[dst] = src;
      ++kept;
    }
  }
};
/*!
 * \brief Scatter the gradient of output row i back to its source input row.
 *
 * record[i] is the source row index; a negative value means the output row
 * was padding and contributes no gradient.
 * \param num not read by this kernel
 * \param stride row width
 */
struct nms_backward {
  template<typename DType>
  MSHADOW_XINLINE static void Map(int i, DType *in_grad, const DType *out_grad,
                                  const DType *record, int num, int stride) {
    const int src_row = static_cast<int>(record[i]);
    if (src_row >= 0) {
      DType *dst = in_grad + src_row * stride;
      const DType *grad = out_grad + i * stride;
      for (int col = 0; col < stride; ++col) {
        dst[col] = grad[col];
      }
    }
  }
};
/*!
 * \brief Forward pass of box_nms.
 *
 * The input is viewed as (num_batch, num_elem, width_elem). Steps:
 *  1. for in-place requests, copy the input into a private buffer;
 *  2. keep only boxes whose score > valid_thresh (FilterScores);
 *  3. sort the survivors by score (is_ascend = false), then by batch id, and
 *     compute batch_start, the per-batch start offsets into the compact index;
 *  4. pre-compute areas of the first topk candidates of each batch;
 *  5. for each reference position, suppress later candidates (index -> -1)
 *     whose IoU with it exceeds overlap_thresh;
 *  6. fill out/record with -1, copy survivors to the front of each batch of
 *     `out`, and record their source row indices for backward;
 *  7. convert the box encoding when in_format != out_format.
 * Early exits: if topk < 1, the input is copied through and record holds
 * 0..N-1; if no score passes valid_thresh, both outputs are filled with -1.
 *
 * Workspace layout, in DType units:
 *   [sorted_index | all_sorted_index | batch_id | batch_start]  (int32,
 *   rounded up to whole DType elements), then scores and areas
 *   (num_batch * num_elem each), then for in-place requests a copy of the input.
 */
template<typename xpu>
void BoxNMSForward(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
const std::vector<TBlob>& inputs,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs) {
using namespace mshadow;
using namespace mshadow::expr;
using namespace mxnet_op;
CHECK_EQ(inputs.size(), 1U);
CHECK_EQ(outputs.size(), 2U) << "BoxNMS output: [output, temp]";
const BoxNMSParam& param = nnvm::get<BoxNMSParam>(attrs.parsed);
Stream<xpu> *s = ctx.get_stream<xpu>();
TShape in_shape = inputs[box_nms_enum::kData].shape_;
int indim = in_shape.ndim();
// all leading axes are flattened into one batch axis
int num_batch = indim <= 2? 1 : in_shape.ProdShape(0, indim - 2);
int num_elem = in_shape[indim - 2];
int width_elem = in_shape[indim - 1];
MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, {
Tensor<xpu, 3, DType> data = inputs[box_nms_enum::kData]
.get_with_shape<xpu, 3, DType>(Shape3(num_batch, num_elem, width_elem), s);
Tensor<xpu, 3, DType> out = outputs[box_nms_enum::kOut]
.get_with_shape<xpu, 3, DType>(Shape3(num_batch, num_elem, width_elem), s);
Tensor<xpu, 3, DType> record = outputs[box_nms_enum::kTemp]
.get_with_shape<xpu, 3, DType>(Shape3(num_batch, num_elem, 1), s);
// prepare workspace
Shape<1> sort_index_shape = Shape1(num_batch * num_elem);
Shape<3> buffer_shape = Shape3(num_batch, num_elem, width_elem);
Shape<1> batch_start_shape = Shape1(num_batch + 1);
// element counts: three int32 index arrays + batch_start; two DType arrays
// (scores, areas) plus an input copy when running in place
index_t int32_size = sort_index_shape.Size() * 3 + batch_start_shape.Size();
index_t dtype_size = sort_index_shape.Size() * 2;
if (req[0] == kWriteInplace) {
dtype_size += buffer_shape.Size();
}
// size of the int32 region in DType units, rounded up to a whole element
index_t int32_offset = (int32_size * sizeof(int32_t) - 1) / sizeof(DType) + 1;
index_t workspace_size = int32_offset + dtype_size;
Tensor<xpu, 1, DType> workspace = ctx.requested[box_nms_enum::kTempSpace]
.get_space_typed<xpu, 1, DType>(Shape1(workspace_size), s);
Tensor<xpu, 1, int32_t> sorted_index(
reinterpret_cast<int32_t*>(workspace.dptr_), sort_index_shape, s);
Tensor<xpu, 1, int32_t> all_sorted_index(sorted_index.dptr_ + sorted_index.MSize(),
sort_index_shape, s);
Tensor<xpu, 1, int32_t> batch_id(
all_sorted_index.dptr_ + all_sorted_index.MSize(), sort_index_shape, s);
Tensor<xpu, 1, int32_t> batch_start(batch_id.dptr_ + batch_id.MSize(), batch_start_shape, s);
Tensor<xpu, 1, DType> scores(workspace.dptr_ + int32_offset,
sort_index_shape, s);
Tensor<xpu, 1, DType> areas(scores.dptr_ + scores.MSize(), sort_index_shape, s);
Tensor<xpu, 3, DType> buffer = data;
if (req[0] == kWriteInplace) {
// make copy: `out` aliases `data`, so later reads must use a private copy
buffer = Tensor<xpu, 3, DType>(areas.dptr_ + areas.MSize(), buffer_shape, s);
buffer = F<mshadow_op::identity>(data);
}
// column indices
int score_index = param.score_index;
int coord_start = param.coord_start;
int id_index = param.id_index;
// sort topk
int topk = param.topk < 0? num_elem : std::min(num_elem, param.topk);
if (topk < 1) {
out = F<mshadow_op::identity>(buffer);
record = reshape(range<DType>(0, num_batch * num_elem), record.shape_);
return;
}
// use areas as temporary storage for every box's score
Tensor<xpu, 1, DType> all_scores = areas;
// Tensor<xpu, 1, DType> all_sorted_index = areas;
all_scores = reshape(slice<2>(buffer, score_index, score_index + 1), all_scores.shape_);
all_sorted_index = range<int32_t>(0, num_batch * num_elem);
// filter scores but keep original sorted_index value
// move valid score and index to the front, return valid size
int num_valid = mxnet::op::FilterScores(scores, sorted_index, all_scores, all_sorted_index,
param.valid_thresh);
// if everything is filtered, output -1
if (num_valid == 0) {
record = -1;
out = -1;
return;
}
// mark the invalid boxes before nms
if (num_valid < num_batch * num_elem) {
slice<0>(sorted_index, num_valid, num_batch * num_elem) = -1;
}
// only sort the valid scores and batch_id
Shape<1> valid_score_shape = Shape1(num_valid);
Tensor<xpu, 1, DType> valid_scores(scores.dptr_, valid_score_shape, s);
Tensor<xpu, 1, int32_t> valid_sorted_index(sorted_index.dptr_, valid_score_shape, s);
Tensor<xpu, 1, int32_t> valid_batch_id(batch_id.dptr_, valid_score_shape, s);
// sort index by batch_id then score (stable sort)
mxnet::op::SortByKey(valid_scores, valid_sorted_index, false);
valid_batch_id = (valid_sorted_index / ScalarExp<int32_t>(num_elem));
mxnet::op::SortByKey(valid_batch_id, valid_sorted_index, true);
// calculate batch_start: accumulated sum to denote 1st sorted_index for a given batch_index
// (batch_start[b] = number of valid entries whose batch id is < b)
valid_batch_id = (valid_sorted_index / ScalarExp<int32_t>(num_elem));
for (int b = 0; b < num_batch + 1; b++) {
slice<0>(batch_start, b, b + 1) = reduce_keepdim<red::sum, false>(
F<mshadow_op::less_than>(valid_batch_id, ScalarExp<int32_t>(b)), 0);
}
// pre-compute areas of candidates
areas = 0;
Kernel<compute_area, xpu>::Launch(s, num_batch * topk,
areas.dptr_, buffer.dptr_ + coord_start, sorted_index.dptr_, batch_start.dptr_,
topk, num_elem, width_elem, param.in_format);
// apply nms
// go through each box as reference, suppress if overlap > threshold
// sorted_index with -1 is marked as suppressed
for (int ref = 0; ref < topk; ++ref) {
int num_worker = topk - ref - 1;
if (num_worker < 1) continue;
Kernel<nms_impl, xpu>::Launch(s, num_batch * num_worker,
sorted_index.dptr_, batch_start.dptr_, buffer.dptr_, areas.dptr_,
num_worker, ref, num_elem,
width_elem, coord_start, id_index,
param.overlap_thresh, param.force_suppress, param.in_format);
}
// store the results to output, keep a record for backward
// (rows not filled by nms_assign stay -1 in both tensors)
record = -1;
out = -1;
Kernel<nms_assign, xpu>::Launch(s, num_batch,
out.dptr_, record.dptr_, buffer.dptr_, sorted_index.dptr_, batch_start.dptr_,
topk, num_elem, width_elem);
// convert encoding
if (param.in_format != param.out_format) {
if (box_common_enum::kCenter == param.out_format) {
Kernel<corner_to_center, xpu>::Launch(s, num_batch * num_elem,
out.dptr_ + coord_start, width_elem);
} else {
Kernel<center_to_corner, xpu>::Launch(s, num_batch * num_elem,
out.dptr_ + coord_start, width_elem);
}
}
});
}
/*!
 * \brief Backward pass of box_nms.
 *
 * inputs[kOut] is read as the gradient of the nms output, and
 * inputs[kTemp + 2] as the forward record of source row indices. The other
 * inputs are not read here. Input rows that were not selected receive zero
 * gradient. Each selected row receives the gradient of the output row it
 * was copied to.
 */
template<typename xpu>
void BoxNMSBackward(const nnvm::NodeAttrs& attrs,
                    const OpContext& ctx,
                    const std::vector<TBlob>& inputs,
                    const std::vector<OpReqType>& req,
                    const std::vector<TBlob>& outputs) {
  using namespace mshadow;
  using namespace mshadow::expr;
  using namespace mxnet_op;
  CHECK_EQ(inputs.size(), 4U);
  CHECK_EQ(outputs.size(), 1U);
  Stream<xpu> *s = ctx.get_stream<xpu>();
  const TShape& dshape = outputs[box_nms_enum::kData].shape_;
  const int ndim = dshape.ndim();
  const int rows = dshape[ndim - 2];
  const int cols = dshape[ndim - 1];
  // leading axes are flattened into a single batch axis
  const int batches = (ndim > 2) ? static_cast<int>(dshape.ProdShape(0, ndim - 2)) : 1;
  MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, {
    const Shape<3> grad_shape = Shape3(batches, rows, cols);
    Tensor<xpu, 3, DType> out_grad = inputs[box_nms_enum::kOut]
        .get_with_shape<xpu, 3, DType>(grad_shape, s);
    Tensor<xpu, 3, DType> in_grad = outputs[box_nms_enum::kData]
        .get_with_shape<xpu, 3, DType>(grad_shape, s);
    Tensor<xpu, 3, DType> record = inputs[box_nms_enum::kTemp + 2]
        .get_with_shape<xpu, 3, DType>(Shape3(batches, rows, 1), s);
    in_grad = 0;
    Kernel<nms_backward, xpu>::Launch(s, batches * rows, in_grad.dptr_,
        out_grad.dptr_, record.dptr_, rows, cols);
  });
}
/*!
 * \brief Parameters of the box overlap (pairwise IoU) operator.
 */
struct BoxOverlapParam : public dmlc::Parameter<BoxOverlapParam> {
int format;  // box_common_enum::BoxType shared by both inputs
DMLC_DECLARE_PARAMETER(BoxOverlapParam) {
DMLC_DECLARE_FIELD(format).set_default(box_common_enum::kCorner)
.add_enum("corner", box_common_enum::kCorner)
.add_enum("center", box_common_enum::kCenter)
.describe("The box encoding type. \n"
" \"corner\" means boxes are encoded as [xmin, ymin, xmax, ymax],"
" \"center\" means boxes are encodes as [x, y, width, height].");
}
}; // BoxOverlapParam
/*!
 * \brief Shape inference for the box overlap operator.
 *
 * lhs has shape (..., 4) and rhs has shape (..., 4). The output shape is
 * lhs.shape[:-1] followed by rhs.shape[:-1]: one overlap value per
 * (lhs box, rhs box) pair.
 *
 * \return false while either input shape is still unknown.
 */
inline bool BoxOverlapShape(const nnvm::NodeAttrs& attrs,
                            std::vector<TShape> *in_attrs,
                            std::vector<TShape> *out_attrs) {
  CHECK_EQ(in_attrs->size(), 2U);
  CHECK_EQ(out_attrs->size(), 1U);
  TShape& lshape = (*in_attrs)[0];
  TShape& rshape = (*in_attrs)[1];
  // An unknown input (ndim 0) is normal during partial shape inference; defer
  // instead of tripping the ndim >= 2 checks below.
  if (lshape.ndim() == 0U || rshape.ndim() == 0U) {
    return false;
  }
  CHECK_GE(lshape.ndim(), 2)
    << "lhs must have dim >= 2 "
    << lshape.ndim() << " provided";
  int ldim = lshape[lshape.ndim() - 1];
  CHECK_EQ(ldim, 4)
    << "last dimension of lhs must be 4 "
    << ldim << " provided";
  CHECK_GE(rshape.ndim(), 2)
    << "rhs must have dim >= 2 "
    << rshape.ndim() << " provided";
  int rdim = rshape[rshape.ndim() - 1];
  CHECK_EQ(rdim, 4)
    << "last dimension of rhs must be 4 "
    << rdim << " provided";
  // assign output shape: lhs.shape[:-1] + rhs.shape[:-1]
  TShape oshape(lshape.ndim() + rshape.ndim() - 2);
  int idx = 0;
  for (index_t i = 0; i < lshape.ndim() - 1; ++i) {
    oshape[idx++] = lshape[i];
  }
  for (index_t i = 0; i < rshape.ndim() - 1; ++i) {
    oshape[idx++] = rshape[i];
  }
  SHAPE_ASSIGN_CHECK(*out_attrs, 0, oshape);
  return true;
}
/*!
 * \brief Pairwise intersection-over-union kernel.
 *
 * Thread i compares lhs box i / num with rhs box i % num and writes their IoU
 * to out[i], or 0 when the boxes do not intersect.
 * \param num number of rhs boxes
 * \param begin column of the first coordinate within each row
 * \param stride row width
 * \param encode box encoding forwarded to Intersect / BoxArea
 */
struct compute_overlap {
  template<typename DType>
  MSHADOW_XINLINE static void Map(int i, DType *out, const DType *lhs,
                                  const DType *rhs, int num,
                                  int begin, int stride, int encode) {
    const DType *lbox = lhs + ((i / num) * stride + begin);
    const DType *rbox = rhs + ((i % num) * stride + begin);
    DType inter = Intersect(lbox, rbox, encode);
    inter *= Intersect(lbox + 1, rbox + 1, encode);
    if (inter <= 0) {
      out[i] = DType(0);
    } else {
      out[i] = inter / (BoxArea(lbox, encode) + BoxArea(rbox, encode) - inter);
    }
  }
};
/*!
 * \brief Forward pass of the box overlap operator.
 *
 * Both inputs are flattened to lists of 4-coordinate boxes. The output is
 * filled with the IoU of every (lhs, rhs) pair in row-major order.
 */
template<typename xpu>
void BoxOverlapForward(const nnvm::NodeAttrs& attrs,
                       const OpContext& ctx,
                       const std::vector<TBlob>& inputs,
                       const std::vector<OpReqType>& req,
                       const std::vector<TBlob>& outputs) {
  using namespace mshadow;
  using namespace mshadow::expr;
  using namespace mxnet_op;
  CHECK_EQ(inputs.size(), 2U);
  CHECK_EQ(outputs.size(), 1U);
  const BoxOverlapParam& param = nnvm::get<BoxOverlapParam>(attrs.parsed);
  Stream<xpu> *s = ctx.get_stream<xpu>();
  const TShape& lshape = inputs[0].shape_;
  const TShape& rshape = inputs[1].shape_;
  // box counts: product of every axis except the last (the 4 coordinates)
  const int num_lhs = lshape.ProdShape(0, lshape.ndim() - 1);
  const int num_rhs = rshape.ProdShape(0, rshape.ndim() - 1);
  const int kCoords = 4;
  MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, {
    Tensor<xpu, 1, DType> lhs = inputs[0]
        .get_with_shape<xpu, 1, DType>(Shape1(num_lhs * kCoords), s);
    Tensor<xpu, 1, DType> rhs = inputs[1]
        .get_with_shape<xpu, 1, DType>(Shape1(num_rhs * kCoords), s);
    Tensor<xpu, 1, DType> out = outputs[0]
        .get_with_shape<xpu, 1, DType>(Shape1(num_lhs * num_rhs), s);
    Kernel<compute_overlap, xpu>::Launch(s, num_lhs * num_rhs, out.dptr_,
        lhs.dptr_, rhs.dptr_, num_rhs, 0, kCoords, param.format);
  });
}
/*!
 * \brief Backward pass of the box overlap operator.
 *
 * No gradient is propagated. Both input gradients are zero-filled.
 */
template<typename xpu>
void BoxOverlapBackward(const nnvm::NodeAttrs& attrs,
                        const OpContext& ctx,
                        const std::vector<TBlob>& inputs,
                        const std::vector<OpReqType>& req,
                        const std::vector<TBlob>& outputs) {
  using namespace mshadow;
  using namespace mshadow::expr;
  using namespace mxnet_op;
  CHECK_EQ(inputs.size(), 1U);
  CHECK_EQ(outputs.size(), 2U);
  Stream<xpu> *s = ctx.get_stream<xpu>();
  MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, {
    Tensor<xpu, 2, DType> grads[2] = {outputs[0].FlatTo2D<xpu, DType>(s),
                                      outputs[1].FlatTo2D<xpu, DType>(s)};
    // TODO(Joshua Zhang): allow backprop?
    for (auto &grad : grads) {
      grad = 0;
    }
  });
}
/*!
 * \brief Parameters of the bipartite matching operator.
 */
struct BipartiteMatchingParam : public dmlc::Parameter<BipartiteMatchingParam> {
bool is_ascend;   // if true, lower scores are better
float threshold;  // required; score cut-off for accepting a match
int topk;         // maximum number of matches per batch, -1 for no limit
DMLC_DECLARE_PARAMETER(BipartiteMatchingParam) {
DMLC_DECLARE_FIELD(is_ascend).set_default(false)
.describe("Use ascend order for scores instead of descending. "
"Please set threshold accordingly.");
DMLC_DECLARE_FIELD(threshold)
.describe("Ignore matching when score < thresh, if is_ascend=false, "
"or ignore score > thresh, if is_ascend=true.");
DMLC_DECLARE_FIELD(topk).set_default(-1)
.describe("Limit the number of matches to topk, set -1 for no limit");
}
}; // BipartiteMatchingParam
/*!
 * \brief Shape inference for bipartite matching.
 *
 * The input is a score matrix of shape (..., num_row, num_col). Output 0
 * (row matches) has shape (..., num_row). Output 1 (column matches) has
 * shape (..., num_col).
 *
 * \return false while the input shape is still unknown.
 */
inline bool MatchingShape(const nnvm::NodeAttrs& attrs,
                          std::vector<TShape> *in_attrs,
                          std::vector<TShape> *out_attrs) {
  // const BipartiteMatchingParam& param = nnvm::get<BipartiteMatchingParam>(attrs.parsed);
  CHECK_EQ(in_attrs->size(), 1U);
  CHECK_EQ(out_attrs->size(), 2U);
  TShape& dshape = (*in_attrs)[0];
  // An unknown input (ndim 0) is normal during partial shape inference; defer
  // instead of failing the ndim >= 2 check below.
  if (dshape.ndim() == 0U) {
    return false;
  }
  CHECK_GE(dshape.ndim(), 2)
    << "score matrix must have dim >= 2 "
    << dshape.ndim() << " provided";
  // assign output shape: drop the last (column) axis for the row markers
  TShape oshape(dshape.ndim() - 1);
  for (index_t i = 0; i < dshape.ndim() - 1; ++i) {
    oshape[i] = dshape[i];
  }
  SHAPE_ASSIGN_CHECK(*out_attrs, 0, oshape);
  // column markers: replace the row axis with the column axis
  oshape[oshape.ndim() - 1] = dshape[dshape.ndim() - 1];
  SHAPE_ASSIGN_CHECK(*out_attrs, 1, oshape);
  return true;
}
/*!
 * \brief Greedy bipartite matching, one thread per batch.
 *
 * Visits batch i's (row, col) pairs in the given sorted order. A pair is
 * matched when neither its row nor its column is matched yet and its score
 * passes the threshold test (> threshold, or < threshold when is_ascend).
 * The walk stops at the first such unmatched pair that fails the test, or as
 * soon as topk matches have been made (topk <= 0 means no limit).
 *
 * \param i batch index
 * \param row_marker per-row matched column, -1 while unmatched (caller fills -1)
 * \param col_marker per-column matched row, -1 while unmatched (caller fills -1)
 * \param scores sorted scores, num_row * num_col per batch
 * \param sorted_index flat positions of the sorted scores, aligned with scores
 * \param num_batch not read by this kernel
 * \param threshold score cut-off
 * \param is_ascend whether lower scores are better
 * \param topk maximum number of matches per batch; <= 0 for no limit
 */
struct bipartite_matching {
  template<typename DType>
  MSHADOW_XINLINE static void Map(int i, DType *row_marker, DType *col_marker,
                                  const DType *scores, const int32_t *sorted_index,
                                  int num_batch, int num_row, int num_col,
                                  float threshold, bool is_ascend, int topk) {
    int stride = num_row * num_col;
    const int32_t *index = sorted_index + i * stride;
    const DType *score = scores + i * stride;
    DType *rmarker = row_marker + i * num_row;
    DType *cmarker = col_marker + i * num_col;
    int count = 0;
    for (int j = 0; j < stride; ++j) {
      // reduce the flat position to one inside this batch's score matrix
      int idx = static_cast<int>(index[j]) % stride;
      int r = idx / num_col;
      int c = idx % num_col;
      if (rmarker[r] == -1 && cmarker[c] == -1) {
        if ((!is_ascend && score[j] > threshold) ||
            (is_ascend && score[j] < threshold)) {
          rmarker[r] = c;
          cmarker[c] = r;
          ++count;
          // stop as soon as topk matches exist, so at most topk are recorded
          if (topk > 0 && count >= topk) {
            break;
          }
        } else {
          // pairs are visited in sorted score order, so the remaining
          // ones are assumed to fail the threshold as well
          break;
        }
      }
    }
  }
};
/*!
 * \brief Forward pass of bipartite matching.
 *
 * The input is viewed as batch_size score matrices of shape (row, col).
 * All scores are sorted together (direction per is_ascend). Then both the
 * sorted scores and their flat indices are regrouped by batch with two key
 * sorts on identical batch ids. Each batch's slice therefore stays in score
 * order; this relies on SortByKey being stable (TODO confirm). Finally, one
 * bipartite_matching thread per batch fills row_marker / col_marker, which
 * start at -1.
 *
 * Workspace layout: scores_copy (DType), then sorted_index and batch_id
 * (int32, rounded up to whole DType elements).
 */
template<typename xpu>
void BipartiteMatchingForward(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
const std::vector<TBlob>& inputs,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs) {
using namespace mshadow;
using namespace mshadow::expr;
using namespace mxnet_op;
CHECK_EQ(inputs.size(), 1U);
CHECK_EQ(outputs.size(), 2U);
const BipartiteMatchingParam& param = nnvm::get<BipartiteMatchingParam>(attrs.parsed);
Stream<xpu> *s = ctx.get_stream<xpu>();
TShape dshape = inputs[0].shape_;
int row = dshape[dshape.ndim() - 2];
int col = dshape[dshape.ndim() - 1];
// all leading axes form the batch
int batch_size = dshape.Size() / row / col;
MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, {
Tensor<xpu, 1, DType> scores = inputs[0]
.get_with_shape<xpu, 1, DType>(Shape1(dshape.Size()), s);
Tensor<xpu, 2, DType> row_marker = outputs[0]
.get_with_shape<xpu, 2, DType>(Shape2(batch_size, row), s);
Tensor<xpu, 2, DType> col_marker = outputs[1]
.get_with_shape<xpu, 2, DType>(Shape2(batch_size, col), s);
Shape<1> sort_index_shape = Shape1(dshape.Size());
// one DType array (scores_copy) plus two int32 arrays, the latter rounded up
// to a whole number of DType elements
index_t workspace_size = sort_index_shape.Size();
workspace_size += (sort_index_shape.Size() * 2 * sizeof(int32_t) - 1) / sizeof(DType) + 1;
Tensor<xpu, 1, DType> workspace = ctx.requested[0]
.get_space_typed<xpu, 1, DType>(Shape1(workspace_size), s);
Tensor<xpu, 1, DType> scores_copy(workspace.dptr_,
sort_index_shape, s);
Tensor<xpu, 1, int32_t> sorted_index(reinterpret_cast<int32_t*>(
scores_copy.dptr_ + scores_copy.MSize()), sort_index_shape, s);
Tensor<xpu, 1, int32_t> batch_id(sorted_index.dptr_ + sorted_index.MSize(),
sort_index_shape, s);
// sort according to score
scores_copy = F<mshadow_op::identity>(scores);
sorted_index = range<int32_t>(0, dshape.Size());
mxnet::op::SortByKey(scores_copy, sorted_index, param.is_ascend);
// group the sorted scores by batch
batch_id = (sorted_index / ScalarExp<int32_t>(row * col));
mxnet::op::SortByKey(batch_id, scores_copy, true);
// batch_id is recomputed because the previous call presumably reorders its
// keys in place; then group the indices by batch the same way
batch_id = (sorted_index / ScalarExp<int32_t>(row * col));
mxnet::op::SortByKey(batch_id, sorted_index, true);
// bipartite matching, parallelization is limited to batch_size
row_marker = -1;
col_marker = -1;
Kernel<bipartite_matching, xpu>::Launch(s, batch_size, row_marker.dptr_,
col_marker.dptr_, scores_copy.dptr_, sorted_index.dptr_, batch_size, row, col,
param.threshold, param.is_ascend, param.topk);
});
}
/*!
 * \brief Backward pass of bipartite matching.
 *
 * No gradient is propagated. The score-matrix gradient is zero-filled.
 */
template<typename xpu>
void BipartiteMatchingBackward(const nnvm::NodeAttrs& attrs,
                               const OpContext& ctx,
                               const std::vector<TBlob>& inputs,
                               const std::vector<OpReqType>& req,
                               const std::vector<TBlob>& outputs) {
  using namespace mshadow;
  using namespace mshadow::expr;
  using namespace mxnet_op;
  CHECK_EQ(inputs.size(), 2U);
  CHECK_EQ(outputs.size(), 1U);
  Stream<xpu> *s = ctx.get_stream<xpu>();
  MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, {
    // TODO(Joshua Zhang): allow backprop?
    Tensor<xpu, 2, DType> score_grad = outputs[0].FlatTo2D<xpu, DType>(s);
    score_grad = 0;
  });
}
} // namespace op
} // namespace mxnet
#endif // MXNET_OPERATOR_CONTRIB_BOUNDING_BOX_INL_H_