forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
MultiLabelMarginCriterion.cu
159 lines (137 loc) · 5.1 KB
/
MultiLabelMarginCriterion.cu
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
#include <THCUNN/THCUNN.h>
#include <THC/THCTensor.hpp>
#include <THCUNN/common.h>
#include <THC/THCReduceApplyUtils.cuh>
#include <TH/THHalf.h>
#include <THC/THCNumerics.cuh>
#include <c10/macros/Macros.h>
#include <thrust/functional.h>
#define MULTILABELMARGIN_THREADS 1024
// Forward kernel of the multi-label margin criterion.
//
// Launch layout: each block processes the row selected by blockIdx.x and the
// threads of that block stride over the `dim` entries of the row.
// NOTE(review): presumably launched with one block per frame (nframe blocks)
// and blockDim.x <= MULTILABELMARGIN_THREADS, since the static shared buffer
// handed to reduceBlock has exactly that many slots -- confirm in the launcher.
//
// Arguments:
//   output      - one loss value is written per row, at index blockIdx.x.
//   input       - scores, `dim` contiguous entries per row.
//   target      - `dim` class indices per row; the list is terminated by the
//                 first negative entry. Indices are NOT range-checked here, so
//                 an index >= dim reads/writes outside the row.
//   istarget    - mask buffer, `dim` entries per row; overwritten by this
//                 kernel with 1 at target positions and 0 everywhere else.
//   nframe      - only used as an extra divisor when `sizeaverage` is set.
//   dim         - number of entries per row.
//   sizeaverage - non-zero: the row loss is additionally divided by nframe.
//
// Per row the kernel accumulates max(0, 1 - input[t] + input[d]) over every
// listed target t and every non-target position d, then divides by dim.
template <typename Dtype, typename Acctype>
#if defined(__HIP_PLATFORM_HCC__)
C10_LAUNCH_BOUNDS_1(MULTILABELMARGIN_THREADS)
#endif
__global__ void cunn_MultiLabelMarginCriterion_updateOutput_kernel(Dtype *output,
                                                                   Dtype *input,
                                                                   THCIndex_t *target,
                                                                   Dtype *istarget,
                                                                   int nframe,
                                                                   int dim,
                                                                   int sizeaverage)
{
  // Temporary sums (for mapreduce)
  __shared__ Acctype sums[MULTILABELMARGIN_THREADS];

  // vectors:
  int k = blockIdx.x;
  // Row offset in 64-bit: `k*dim` evaluated in int overflows as soon as the
  // batch holds 2^31 or more elements.
  const long long row_offset = static_cast<long long>(k) * dim;
  Dtype *input_k = input + row_offset;
  THCIndex_t *target_k = target + row_offset;
  Dtype *output_k = output + k;
  Dtype *istarget_k = istarget + row_offset;

  // zero istarget
  for (int d = threadIdx.x; d < dim; d += blockDim.x) {
    istarget_k[d] = ScalarConvert<int, Dtype>::to(0);
  }
  // All zeroes must be in place before thread 0 starts setting the marks.
  __syncthreads();

  // mark targets in istarget (single thread; the list is short-circuited at
  // the first negative index)
  if (threadIdx.x == 0) {
    for (int dt = 0; dt < dim; dt++) {
      int target_idx = target_k[dt];
      if (target_idx < 0) break;
      istarget_k[target_idx] = ScalarConvert<int, Dtype>::to(1);
    }
  }
  // Publish the finished mask to every thread of the block.
  __syncthreads();

  // iterate over targets
  Acctype sum = 0;
  for (int dt = 0; dt < dim; dt++) {
    // next target (every thread reads the same value, so the break below is
    // taken uniformly across the block):
    int target_idx = target_k[dt];
    if (target_idx < 0) break;

    // current value for target
    Dtype input_target_k = input_k[target_idx];

    // compare to all inputs (multithreaded):
    for (int d = threadIdx.x; d < dim; d += blockDim.x) {
      // contribute to loss only if not a target
      if (!ScalarConvert<Dtype, int>::to(istarget_k[d])) {
        Dtype z = 1 - input_target_k + input_k[d];
        if (z > 0)
          sum += z;
      }
    }
  }

  // reduce the per-thread partial sums; only thread 0 consumes the result
  Acctype totalSum = reduceBlock(sums, blockDim.x, sum, thrust::plus<Acctype>(), (Acctype)0);
  if (threadIdx.x == 0) {
    if (sizeaverage) {
      *output_k = ScalarConvert<Acctype, Dtype>::to((totalSum / dim) / nframe);
    } else {
      *output_k = ScalarConvert<Acctype, Dtype>::to(totalSum / dim);
    }
  }
}
// Backward kernel of the multi-label margin criterion.
//
// Launch layout: each block processes the row selected by blockIdx.x and the
// threads of that block stride over the `dim` entries of the row.
// NOTE(review): presumably launched with one block per frame (nframe blocks)
// and blockDim.x <= MULTILABELMARGIN_THREADS, since the static shared buffer
// handed to reduceBlock has exactly that many slots -- confirm in the launcher.
//
// Arguments:
//   gradInput   - result, `dim` contiguous entries per row; fully overwritten.
//   gradOutput  - upstream gradient; element blockIdx.x is used when `reduce`
//                 is zero, element 0 is used for every row otherwise.
//   input       - scores, `dim` contiguous entries per row.
//   target      - `dim` class indices per row; the list is terminated by the
//                 first negative entry. Indices are NOT range-checked here.
//   istarget    - mask, `dim` entries per row; positions with a non-zero value
//                 are skipped as non-contributing. Only read by this kernel
//                 (presumably the mask produced by the forward kernel --
//                 confirm against the caller).
//   nframe      - enters the gain only when both sizeaverage and reduce are set.
//   dim         - number of entries per row.
//   sizeaverage - together with `reduce`, selects gain 1/(nframe*dim)
//                 instead of 1/dim.
//   reduce      - see gradOutput and sizeaverage.
template <typename Dtype, typename Acctype>
#if defined(__HIP_PLATFORM_HCC__)
C10_LAUNCH_BOUNDS_1(MULTILABELMARGIN_THREADS)
#endif
__global__ void cunn_MultiLabelMarginCriterion_updateGradInput_kernel(Dtype *gradInput,
                                                                      Dtype *gradOutput,
                                                                      Dtype *input,
                                                                      THCIndex_t *target,
                                                                      Dtype *istarget,
                                                                      int nframe,
                                                                      int dim,
                                                                      int sizeaverage,
                                                                      int reduce)
{
  // Temporary sums (for mapreduce)
  __shared__ Acctype sums[MULTILABELMARGIN_THREADS];

  // vectors:
  int k = blockIdx.x;
  // Row offset in 64-bit: `k*dim` evaluated in int overflows as soon as the
  // batch holds 2^31 or more elements.
  const long long row_offset = static_cast<long long>(k) * dim;
  Dtype *input_k = input + row_offset;
  Dtype *gradInput_k = gradInput + row_offset;
  THCIndex_t *target_k = target + row_offset;
  Dtype *istarget_k = istarget + row_offset;

  Dtype *gradOutput_k = gradOutput;
  if (!reduce) {
    gradOutput_k += k;
  }

  // gain (the nframe*dim product is formed in 64-bit for the same overflow
  // reason as the row offset; the division itself is unchanged):
  Dtype g = ScalarConvert<Acctype, Dtype>::to(
      sizeaverage && reduce
          ? 1./((Acctype)(static_cast<long long>(nframe) * dim))
          : 1./((Acctype)dim) );

  // zero gradients:
  for (int d = threadIdx.x; d < dim; d += blockDim.x) {
    gradInput_k[d] = ScalarConvert<int, Dtype>::to(0);
  }
  __syncthreads();

  // iterate over targets
  for (int dt = 0; dt < dim; dt++) {
    // next target (every thread reads the same value, so the break below is
    // taken uniformly across the block):
    int target_idx = (int)target_k[dt];
    if (target_idx < 0) break;

    // current value for target
    Dtype input_target_k = input_k[target_idx];

    // compare to all inputs (multithreaded):
    Acctype sum = 0;
    for (int d = threadIdx.x; d < dim; d += blockDim.x) {
      // contribute to loss only if not a target
      if (!ScalarConvert<Dtype, int>::to(istarget_k[d])) {
        Dtype z = 1 - input_target_k + input_k[d];
        if (z > 0) {
          // each violated margin pushes the non-target up by g and the
          // target down by g (the latter collected in `sum`)
          sum -= g;
          gradInput_k[d] += g;
        }
      }
    }
    // Keeps thread 0's use of `sums` / its gradInput write from the previous
    // iteration ahead of the shared-memory writes of the next reduction.
    __syncthreads();

    // reduce sum
    Acctype totalSum = reduceBlock(sums, blockDim.x, sum, thrust::plus<Acctype>(), (Acctype)0);
    if (threadIdx.x == 0) {
      gradInput_k[target_idx] += ScalarConvert<Acctype, Dtype>::to(totalSum);
    }
  }

  // Thread 0 has just updated gradInput_k[target_idx] for the last target,
  // while the scaling loop below lets whichever thread owns that index read
  // and rewrite the same element. Without a barrier the owner can scale the
  // value before thread 0's contribution lands. The loop above exits
  // uniformly for the whole block, so every thread reaches this barrier.
  __syncthreads();

  // scale by the upstream gradient
  for (int d = threadIdx.x; d < dim; d += blockDim.x) {
    gradInput_k[d] *= *gradOutput_k;
  }
}
#include <THCUNN/generic/MultiLabelMarginCriterion.cu>
#include <THC/THCGenerateFloatTypes.h>
#undef MULTILABELMARGIN_THREADS