forked from rahul-goel/fused-ssim
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathssim_opt_v01.cu
444 lines (402 loc) · 13.1 KB
/
ssim_opt_v01.cu
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
#include <torch/extension.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>
#include <cooperative_groups.h>
#include <algorithm>
#include <iostream>
// Short alias for the cooperative-groups API used by the device code in this file.
namespace cg = cooperative_groups;
// Eleven filter taps G_00..G_10 applied by the separable x/y convolutions.
// The table is symmetric about G_05 (G_00 == G_10, G_01 == G_09, ...) and the
// taps sum to approximately 1.
// NOTE(review): the values look like a normalized 1-D Gaussian with sigma ~1.5
// (the usual SSIM window) -- confirm against whatever generated them.
#define G_00 0.001028380123898387f
#define G_01 0.0075987582094967365f
#define G_02 0.036000773310661316f
#define G_03 0.10936068743467331f
#define G_04 0.21300552785396576f
#define G_05 0.26601171493530273f
#define G_06 0.21300552785396576f
#define G_07 0.10936068743467331f
#define G_08 0.036000773310661316f
#define G_09 0.0075987582094967365f
#define G_10 0.001028380123898387f
// block size
// Threads per block along x / y; the host wrappers launch dim3(BX, BY, 1) and
// each block produces one BX x BY tile of output pixels.
#define BX 32
#define BY 32
// shared memory size
// Input tile staged in shared memory: the block's BX x BY pixels plus a
// 5-pixel halo on every side (hence +10), matching the 11-tap filter.
// SX / SY are the logical tile extents used when looping over the tile;
// SSX is the declared row length of the shared arrays (float [SY][SSX]).
#define SX (BX + 10)
#define SSX (BX + 10)
#define SY (BY + 10)
// convolution scratchpad size
// Scratchpad holding the result of the x pass before the y pass consumes it:
// CY rows (the BY output rows plus 10 halo rows) of CX values each.
// CCX is the declared row length of the scratch array (float [CY][CCX]).
#define CX (BX)
#define CCX (BX + 0)
#define CY (BY + 10)
/**
 * Read one pixel from a densely packed 4-D float array, with zero padding.
 *
 * The element address is ((b * CH + c) * H + y) * W + x, i.e. the array is
 * indexed as [batch][channel][row][column] with no padding between rows.
 * Coordinates outside [0, W) x [0, H) yield 0.0f instead of touching memory.
 *
 * The offset is computed in size_t so that arrays holding 2^31 or more
 * elements are addressed correctly (a 32-bit int product would overflow).
 * The function is also callable from host code so the addressing logic can
 * be unit-tested on the CPU.
 *
 * @param img Base pointer of the image data
 * @param b Batch index (assumed to be in range; not checked)
 * @param c Channel index (assumed to be in range; not checked)
 * @param y Y coordinate (row); may be out of range
 * @param x X coordinate (column); may be out of range
 * @param CH Number of channels
 * @param H Height of image
 * @param W Width of image
 * @return The pixel value, or 0.0f when (x, y) lies outside the image
 */
__host__ __device__ float get_pix_value(const float* img, const int b, const int c, const int y, const int x, const int CH, const int H, const int W) {
  if (x < 0 || y < 0 || x >= W || y >= H) {
    return 0.0f;
  }
  // Widen before multiplying: b * CH * H * W does not fit in int for big tensors.
  const size_t plane = static_cast<size_t>(b) * static_cast<size_t>(CH) + static_cast<size_t>(c);
  const size_t row = plane * static_cast<size_t>(H) + static_cast<size_t>(y);
  return img[row * static_cast<size_t>(W) + static_cast<size_t>(x)];
}
/**
 * Cooperatively stage one channel of the block's input tile into shared memory.
 *
 * The tile is SY x SX elements and starts 5 pixels above and to the left of
 * the block's first output pixel. Positions that fall outside the image are
 * filled with 0.0f by get_pix_value. The batch index is taken from the
 * block's z index. The caller must synchronize the block before reading
 * `pixels`.
 *
 * @param pixels Destination tile
 * @param inp Source image data, indexed as [batch][channel][y][x]
 * @param CH Number of channels
 * @param H Image height
 * @param W Image width
 * @param i Channel to load
 */
__device__ void load_into_shared(float pixels[SY][SSX], const float *inp, const int CH, const int H, const int W, const int i) {
  auto block = cg::this_thread_block();
  const int batch = block.group_index().z;
  // Image coordinates of tile element [0][0] (includes the 5-pixel halo).
  const int origin_y = block.group_index().y * BY - 5;
  const int origin_x = block.group_index().x * BX - 5;
  constexpr int kTileElems = SY * SX;
  constexpr int kStride = BX * BY;
  // The tile holds more elements than the stride, so each thread walks the
  // flattened tile in steps of kStride.
  for (int idx = block.thread_rank(); idx < kTileElems; idx += kStride) {
    const int tile_y = idx / SX;
    const int tile_x = idx % SX;
    pixels[tile_y][tile_x] =
        get_pix_value(inp, batch, i, origin_y + tile_y, origin_x + tile_x, CH, H, W);
  }
}
/**
 * Element-wise product of two shared tiles, stored in place: pix1 *= pix2.
 *
 * All SY x SX elements are processed cooperatively by the block. `pix2` is
 * only read. The caller must synchronize before and after the call.
 *
 * @param pix1 Tile that is multiplied and overwritten
 * @param pix2 Tile providing the second factor
 */
__device__ void multiply_shared_mem(float pix1[SY][SSX], float pix2[SY][SSX]) {
  auto block = cg::this_thread_block();
  constexpr int kTileElems = SY * SX;
  constexpr int kStride = BX * BY;
  for (int idx = block.thread_rank(); idx < kTileElems; idx += kStride) {
    const int tile_y = idx / SX;
    const int tile_x = idx % SX;
    const float lhs = pix1[tile_y][tile_x];
    const float rhs = pix2[tile_y][tile_x];
    pix1[tile_y][tile_x] = lhs * rhs;
  }
}
/**
 * Square a value.
 *
 * @param val Input value
 * @return val * val
 */
__device__ inline float do_sq(float val) {
  const float squared = val * val;
  return squared;
}
/**
 * Cooperatively zero the convolution scratchpad.
 *
 * Every one of the CY x CX entries is set to 0.0f. The caller must
 * synchronize the block before the scratchpad is used again.
 *
 * @param buf Scratchpad to clear
 */
__device__ void
flush_conv_scratch(float buf[CY][CCX]) {
  auto block = cg::this_thread_block();
  constexpr int kScratchElems = CY * CX;
  constexpr int kStride = BX * BY;
  for (int idx = block.thread_rank(); idx < kScratchElems; idx += kStride) {
    const int row = idx / CX;
    const int col = idx % CX;
    buf[row][col] = 0.0f;
  }
}
// The two passes below index the tiles directly with threadIdx; make the
// size assumptions explicit so a change to the tile macros cannot silently
// reintroduce out-of-bounds shared-memory accesses.
static_assert(SSX >= BX + 10, "input tile row must hold BX columns plus 10 halo columns");
static_assert(CCX >= BX, "scratch row must hold one value per thread column");
static_assert(CY >= SY, "scratch must have a row for every input tile row");
static_assert(CY >= BY + 10, "scratch must hold BY rows plus 10 halo rows");

/**
 * Weighted sum of 11 consecutive values with the G_00..G_10 taps.
 *
 * @param taps Pointer to the first of 11 consecutive values
 * @param sq When true, each value is squared before it is weighted
 * @return sum_k G_k * taps[k] (or G_k * taps[k]^2 when sq is set)
 */
__device__ __forceinline__ float ssim_filter_row11(const float* taps, bool sq) {
  float val = 0.0f;
  if (sq) {
    val += G_00 * do_sq(taps[0]);
    val += G_01 * do_sq(taps[1]);
    val += G_02 * do_sq(taps[2]);
    val += G_03 * do_sq(taps[3]);
    val += G_04 * do_sq(taps[4]);
    val += G_05 * do_sq(taps[5]);
    val += G_06 * do_sq(taps[6]);
    val += G_07 * do_sq(taps[7]);
    val += G_08 * do_sq(taps[8]);
    val += G_09 * do_sq(taps[9]);
    val += G_10 * do_sq(taps[10]);
  } else {
    val += G_00 * taps[0];
    val += G_01 * taps[1];
    val += G_02 * taps[2];
    val += G_03 * taps[3];
    val += G_04 * taps[4];
    val += G_05 * taps[5];
    val += G_06 * taps[6];
    val += G_07 * taps[7];
    val += G_08 * taps[8];
    val += G_09 * taps[9];
    val += G_10 * taps[10];
  }
  return val;
}

/**
 * Horizontal (x) pass of the separable 11-tap filter.
 *
 * For every row of the SY-row input tile, thread column tx writes the
 * filtered value centred on tile column tx + 5 to opt[row][tx]. Each thread
 * handles row threadIdx.y and then every BY-th row after it, so the halo
 * rows needed by the vertical pass are produced as well.
 *
 * The result is stored at column threadIdx.x (0..BX-1). Storing it at
 * threadIdx.x + 5, as an earlier revision did, runs past the CCX-wide
 * scratch rows and, on the last row, past the end of the shared array.
 *
 * Assumes the block was launched as (BX, BY, 1). The caller must synchronize
 * the block before `pixels` is read here and before `opt` is read afterwards.
 *
 * @param pixels Input tile (block pixels plus 5-pixel halo)
 * @param opt Scratchpad receiving the row-filtered values
 * @param H Image height (unused)
 * @param W Image width (unused)
 * @param sq When true, filter the squared pixel values
 */
__device__ void do_separable_conv_x(float pixels[SY][SSX], float opt[CY][CCX], int H, int W, bool sq = false) {
  auto block = cg::this_thread_block();
  const int col = block.thread_index().x;
  for (int row = block.thread_index().y; row < SY; row += BY) {
    // Taps span tile columns col .. col + 10, centred on col + 5.
    opt[row][col] = ssim_filter_row11(&pixels[row][col], sq);
  }
}

/**
 * Vertical (y) pass of the separable 11-tap filter.
 *
 * Combines scratch rows threadIdx.y .. threadIdx.y + 10 at column
 * threadIdx.x, i.e. the values written by do_separable_conv_x, and returns
 * the filtered value for this thread's output pixel.
 *
 * @param pixels Scratchpad filled by do_separable_conv_x
 * @param H Image height (unused)
 * @param W Image width (unused)
 * @param sq Unused; squaring is applied in the x pass
 * @return The fully filtered value for the calling thread's pixel
 */
__device__ float do_separable_conv_y(float pixels[CY][CCX], int H, int W, bool sq = false) {
  auto block = cg::this_thread_block();
  const int top = block.thread_index().y;  // first tap; the centre is top + 5
  const int col = block.thread_index().x;
  float val = 0.0f;
  val += G_00 * pixels[top     ][col];
  val += G_01 * pixels[top +  1][col];
  val += G_02 * pixels[top +  2][col];
  val += G_03 * pixels[top +  3][col];
  val += G_04 * pixels[top +  4][col];
  val += G_05 * pixels[top +  5][col];
  val += G_06 * pixels[top +  6][col];
  val += G_07 * pixels[top +  7][col];
  val += G_08 * pixels[top +  8][col];
  val += G_09 * pixels[top +  9][col];
  val += G_10 * pixels[top + 10][col];
  return val;
}
/**
 * Forward kernel: per-pixel SSIM map and, optionally, the partial-derivative
 * maps consumed by fusedssim_backwardCUDA.
 *
 * Launch layout: block = (BX, BY, 1), grid = (ceil(W / BX), ceil(H / BY), B).
 * Each block handles one BX x BY tile of one batch item and loops over all
 * channels. Static shared memory: two [SY][SSX] input tiles and one
 * [CY][CCX] convolution scratchpad.
 *
 * For each pixel, with mu / sigma obtained from the separable 11-tap filter:
 *   A = mu1^2 + mu2^2 + C1      B = sigma1_sq + sigma2_sq + C2
 *   C = 2 * mu1 * mu2 + C1      D = 2 * sigma12 + C2
 *   ssim = (C * D) / (A * B)
 *
 * All tensors are indexed as densely packed [B][CH][H][W] float arrays.
 * Output offsets are computed in size_t so tensors with 2^31 or more
 * elements are addressed correctly.
 *
 * @param H Image height
 * @param W Image width
 * @param CH Number of channels
 * @param C1 Stabilising constant added to the mean terms
 * @param C2 Stabilising constant added to the variance terms
 * @param img1 First image
 * @param img2 Second image
 * @param ssim_map Output SSIM map
 * @param dm_dmu1 Optional output: derivative term with respect to mu1
 * @param dm_dsigma1_sq Optional output: derivative with respect to sigma1_sq
 * @param dm_dsigma12 Optional output: derivative with respect to sigma12
 *        The three derivative maps are written only when all three pointers
 *        are non-null.
 */
__global__ void fusedssimCUDA(
    int H,
    int W,
    int CH,
    float C1,
    float C2,
    float* img1,
    float* img2,
    float* ssim_map,
    float* dm_dmu1 = nullptr,
    float* dm_dsigma1_sq = nullptr,
    float* dm_dsigma12 = nullptr
)
{
  auto block = cg::this_thread_block();
  const int pix_y = block.group_index().y * BY + block.thread_index().y;
  const int pix_x = block.group_index().x * BX + block.thread_index().x;
  const int batch = block.group_index().z;

  // Threads outside the image still take part in loading and filtering (the
  // barriers below must be reached by the whole block); they just do not store.
  const bool in_image = pix_x < W && pix_y < H;
  const size_t num_pix = static_cast<size_t>(H) * static_cast<size_t>(W);
  const size_t pix_id =
      in_image ? static_cast<size_t>(pix_y) * static_cast<size_t>(W) + static_cast<size_t>(pix_x) : 0;
  const bool write_grads = dm_dmu1 && dm_dsigma1_sq && dm_dsigma12;

  // shared memory that will be used to load pixels temporarily
  __shared__ float buf1[SY][SSX];
  __shared__ float buf2[SY][SSX];
  __shared__ float buf3[CY][CCX];

  for (int i = 0; i < CH; ++i) {
    // load img1 tile into shared memory
    load_into_shared(buf1, img1, CH, H, W, i);
    block.sync();

    // calculate mu1
    do_separable_conv_x(buf1, buf3, H, W);
    block.sync();
    float mu1 = do_separable_conv_y(buf3, H, W);
    block.sync();

    // calculate sigma1_sq = E[x^2] - mu1^2
    do_separable_conv_x(buf1, buf3, H, W, true);
    block.sync();
    float sigma1_sq = do_separable_conv_y(buf3, H, W) - mu1 * mu1;

    // calculate mu2 (the barrier after the load also orders the scratch reads
    // above before the scratch writes below)
    load_into_shared(buf2, img2, CH, H, W, i);
    block.sync();
    do_separable_conv_x(buf2, buf3, H, W);
    block.sync();
    float mu2 = do_separable_conv_y(buf3, H, W);
    block.sync();

    // calculate sigma2_sq = E[y^2] - mu2^2
    do_separable_conv_x(buf2, buf3, H, W, true);
    block.sync();
    float sigma2_sq = do_separable_conv_y(buf3, H, W) - mu2 * mu2;
    block.sync();

    // calculate sigma12 = E[x * y] - mu1 * mu2; buf1 is overwritten with x * y
    multiply_shared_mem(buf1, buf2);
    block.sync();
    do_separable_conv_x(buf1, buf3, H, W);
    block.sync();
    float sigma12 = do_separable_conv_y(buf3, H, W) - mu1 * mu2;
    block.sync();

    float mu1_sq = mu1 * mu1;
    float mu2_sq = mu2 * mu2;
    float mu1_mu2 = mu1 * mu2;
    float C = (2.0f * mu1_mu2 + C1);
    float D = (2.0f * sigma12 + C2);
    float A = (mu1_sq + mu2_sq + C1);
    float B = (sigma1_sq + sigma2_sq + C2);
    float m = (C * D) / (A * B);

    if (in_image) {
      const size_t global_idx =
          (static_cast<size_t>(batch) * static_cast<size_t>(CH) + static_cast<size_t>(i)) * num_pix + pix_id;
      ssim_map[global_idx] = m;
      if (write_grads) {
        dm_dmu1[global_idx] = (
          (mu2 * 2.0f * D) / (A * B)
          -(mu2 * 2.0f * C) / (A * B)
          -(mu1 * 2.0f * C * D) / ( A * A * B)
          +(mu1 * 2.0f * C * D) / (A * B * B)
        );
        dm_dsigma1_sq[global_idx] = ((-C * D) / (A * B * B));
        dm_dsigma12[global_idx] = ((2 * C) / (A * B));
      }
    }
  }
}
/**
 * Backward kernel: gradient of the loss with respect to img1.
 *
 * Launch layout: block = (BX, BY, 1), grid = (ceil(W / BX), ceil(H / BY), B).
 * Static shared memory: two [SY][SSX] tiles and one [CY][CCX] scratchpad.
 *
 * With conv() denoting the separable 11-tap filter, each pixel receives
 *   dL_dimg1 = conv(dL_dmap * dm_dmu1)
 *            + 2 * img1 * conv(dL_dmap * dm_dsigma1_sq)
 *            +     img2 * conv(dL_dmap * dm_dsigma12)
 *
 * All tensors are indexed as densely packed [B][CH][H][W] float arrays.
 * The output offset is computed in size_t so tensors with 2^31 or more
 * elements are addressed correctly.
 *
 * @param H Image height
 * @param W Image width
 * @param CH Number of channels
 * @param C1 Unused in this kernel
 * @param C2 Unused in this kernel
 * @param img1 First image
 * @param img2 Second image
 * @param dL_dmap Upstream gradient with respect to the SSIM map
 * @param dL_dimg1 Output gradient with respect to img1
 * @param dm_dmu1 Derivative map saved by the forward kernel
 * @param dm_dsigma1_sq Derivative map saved by the forward kernel
 * @param dm_dsigma12 Derivative map saved by the forward kernel
 *        The three derivative maps are read unconditionally, so despite the
 *        nullptr defaults they must be valid device pointers.
 */
__global__ void fusedssim_backwardCUDA(
    int H,
    int W,
    int CH,
    float C1,
    float C2,
    float* img1,
    float* img2,
    float *dL_dmap,
    float *dL_dimg1,
    float* dm_dmu1 = nullptr,
    float* dm_dsigma1_sq = nullptr,
    float* dm_dsigma12 = nullptr
)
{
  auto block = cg::this_thread_block();
  const int pix_y = block.group_index().y * BY + block.thread_index().y;
  const int pix_x = block.group_index().x * BX + block.thread_index().x;
  const int batch = block.group_index().z;

  // Threads outside the image still take part in every load, filter and
  // barrier; they only skip the final store.
  const bool in_image = pix_x < W && pix_y < H;
  const size_t num_pix = static_cast<size_t>(H) * static_cast<size_t>(W);
  const size_t pix_id =
      in_image ? static_cast<size_t>(pix_y) * static_cast<size_t>(W) + static_cast<size_t>(pix_x) : 0;

  // shared memory that will be used to load pixels temporarily
  __shared__ float buf1[SY][SSX];
  __shared__ float buf2[SY][SSX];
  __shared__ float buf3[CY][CCX];

  for (int i = 0; i < CH; ++i) {
    float dL_dpix = 0.0f;
    float tmp = 0.0f;
    // get_pix_value returns 0 for threads that fall outside the image
    float pix1 = get_pix_value(img1, batch, i, pix_y, pix_x, CH, H, W);
    float pix2 = get_pix_value(img2, batch, i, pix_y, pix_x, CH, H, W);

    // buf1 keeps the upstream gradient tile for all three terms below
    load_into_shared(buf1, dL_dmap, CH, H, W, i);

    // gradient from mu1
    load_into_shared(buf2, dm_dmu1, CH, H, W, i);
    block.sync();
    multiply_shared_mem(buf2, buf1);
    block.sync();
    flush_conv_scratch(buf3);
    block.sync();
    do_separable_conv_x(buf2, buf3, H, W);
    block.sync();
    tmp = do_separable_conv_y(buf3, H, W);
    block.sync();
    dL_dpix += tmp;

    // gradient from sigma1_sq
    load_into_shared(buf2, dm_dsigma1_sq, CH, H, W, i);
    block.sync();
    multiply_shared_mem(buf2, buf1);
    block.sync();
    flush_conv_scratch(buf3);
    block.sync();
    do_separable_conv_x(buf2, buf3, H, W);
    block.sync();
    tmp = pix1 * 2.0f * do_separable_conv_y(buf3, H, W);
    block.sync();
    dL_dpix += tmp;

    // gradient from sigma12
    load_into_shared(buf2, dm_dsigma12, CH, H, W, i);
    block.sync();
    multiply_shared_mem(buf2, buf1);
    block.sync();
    flush_conv_scratch(buf3);
    block.sync();
    do_separable_conv_x(buf2, buf3, H, W);
    block.sync();
    tmp = pix2 * do_separable_conv_y(buf3, H, W);
    block.sync();
    dL_dpix += tmp;

    if (in_image) {
      const size_t global_idx =
          (static_cast<size_t>(batch) * static_cast<size_t>(CH) + static_cast<size_t>(i)) * num_pix + pix_id;
      dL_dimg1[global_idx] = dL_dpix;
    }
  }
}
/**
 * Host entry point for the forward pass.
 *
 * Validates the inputs, allocates the outputs and launches fusedssimCUDA on
 * PyTorch's current CUDA stream for img1's device.
 *
 * @param C1 Stabilising constant for the mean terms
 * @param C2 Stabilising constant for the variance terms
 * @param img1 First image, float32 CUDA tensor of shape [B, CH, H, W]
 * @param img2 Second image, same shape, dtype and device as img1
 * @param train When true, the three derivative maps needed by
 *        fusedssim_opt_backward are computed as well
 * @return (ssim_map, dm_dmu1, dm_dsigma1_sq, dm_dsigma12). The derivative
 *         tensors have img1's shape when `train` is true and are empty
 *         (zero-element) tensors otherwise.
 * @throws c10::Error if the inputs are invalid or the kernel launch fails
 */
std::tuple<torch::Tensor,torch::Tensor,torch::Tensor,torch::Tensor>
fusedssim_opt(
    float C1,
    float C2,
    torch::Tensor &img1,
    torch::Tensor &img2,
    bool train
)
{
  // The kernel indexes both images with img1's extents and reads raw float
  // pointers, so any mismatch would be an out-of-bounds device access.
  TORCH_CHECK(img1.is_cuda() && img2.is_cuda(),
              "fusedssim_opt: img1 and img2 must be CUDA tensors");
  TORCH_CHECK(img1.device() == img2.device(),
              "fusedssim_opt: img1 and img2 must be on the same device");
  TORCH_CHECK(img1.scalar_type() == torch::kFloat32 && img2.scalar_type() == torch::kFloat32,
              "fusedssim_opt: img1 and img2 must be float32 tensors");
  TORCH_CHECK(img1.dim() == 4,
              "fusedssim_opt: expected a 4-D [B, CH, H, W] tensor, got ", img1.dim(), " dimensions");
  TORCH_CHECK(img1.sizes() == img2.sizes(),
              "fusedssim_opt: img1 and img2 must have the same shape");

  const at::cuda::OptionalCUDAGuard device_guard(device_of(img1));

  const int B = static_cast<int>(img1.size(0));
  const int CH = static_cast<int>(img1.size(1));
  const int H = static_cast<int>(img1.size(2));
  const int W = static_cast<int>(img1.size(3));

  torch::Tensor target = torch::zeros_like(img1).contiguous();
  torch::Tensor dm_dmu1 = train ? torch::zeros_like(img1).contiguous() : torch::empty(0);
  torch::Tensor dm_dsigma1_sq = train ? torch::zeros_like(img1).contiguous() : torch::empty(0);
  torch::Tensor dm_dsigma12 = train ? torch::zeros_like(img1).contiguous() : torch::empty(0);

  // A zero-sized grid is an invalid launch configuration; nothing to compute.
  if (img1.numel() == 0) {
    return std::make_tuple(target, dm_dmu1, dm_dsigma1_sq, dm_dsigma12);
  }

  // Keep the (possibly newly allocated) contiguous copies alive in named locals.
  const torch::Tensor img1_c = img1.contiguous();
  const torch::Tensor img2_c = img2.contiguous();

  const dim3 grid((W + BX - 1) / BX, (H + BY - 1) / BY, B);
  const dim3 block(BX, BY, 1);
  // Launch on PyTorch's current stream so the kernel is ordered with the
  // surrounding tensor operations.
  const cudaStream_t stream = c10::cuda::getCurrentCUDAStream().stream();

  fusedssimCUDA<<<grid, block, 0, stream>>>(
    H,
    W,
    CH,
    C1,
    C2,
    img1_c.data_ptr<float>(),
    img2_c.data_ptr<float>(),
    target.data_ptr<float>(),
    // Explicit nullptr tells the kernel to skip the derivative maps.
    train ? dm_dmu1.data_ptr<float>() : nullptr,
    train ? dm_dsigma1_sq.data_ptr<float>() : nullptr,
    train ? dm_dsigma12.data_ptr<float>() : nullptr
  );
  // Kernel launches do not report configuration errors by themselves.
  const cudaError_t launch_err = cudaGetLastError();
  TORCH_CHECK(launch_err == cudaSuccess,
              "fusedssim_opt: kernel launch failed: ", cudaGetErrorString(launch_err));

  return std::make_tuple(target, dm_dmu1, dm_dsigma1_sq, dm_dsigma12);
}
/**
 * Host entry point for the backward pass.
 *
 * Validates the inputs, allocates the output and launches
 * fusedssim_backwardCUDA on PyTorch's current CUDA stream for img1's device.
 *
 * @param C1 Stabilising constant (forwarded to the kernel)
 * @param C2 Stabilising constant (forwarded to the kernel)
 * @param img1 First image, float32 CUDA tensor of shape [B, CH, H, W]
 * @param img2 Second image
 * @param dL_dmap Upstream gradient with respect to the SSIM map
 * @param dm_dmu1 Derivative map returned by fusedssim_opt with train = true
 * @param dm_dsigma1_sq Derivative map returned by fusedssim_opt with train = true
 * @param dm_dsigma12 Derivative map returned by fusedssim_opt with train = true
 *        Every tensor must match img1 in shape, dtype and device.
 * @return Gradient of the loss with respect to img1, same shape as img1
 * @throws c10::Error if the inputs are invalid or the kernel launch fails
 */
torch::Tensor
fusedssim_opt_backward(
    float C1,
    float C2,
    torch::Tensor &img1,
    torch::Tensor &img2,
    torch::Tensor &dL_dmap,
    torch::Tensor &dm_dmu1,
    torch::Tensor &dm_dsigma1_sq,
    torch::Tensor &dm_dsigma12
)
{
  TORCH_CHECK(img1.dim() == 4,
              "fusedssim_opt_backward: expected a 4-D [B, CH, H, W] tensor, got ", img1.dim(), " dimensions");

  // The kernel reads every input through a raw float pointer using img1's
  // extents, so each tensor has to match img1 exactly. In particular the
  // empty derivative tensors produced by a train = false forward pass are
  // rejected here instead of being dereferenced on the device.
  const torch::Tensor* const inputs[] = {&img1, &img2, &dL_dmap, &dm_dmu1, &dm_dsigma1_sq, &dm_dsigma12};
  const char* const names[] = {"img1", "img2", "dL_dmap", "dm_dmu1", "dm_dsigma1_sq", "dm_dsigma12"};
  for (int k = 0; k < 6; ++k) {
    TORCH_CHECK(inputs[k]->is_cuda(),
                "fusedssim_opt_backward: ", names[k], " must be a CUDA tensor");
    TORCH_CHECK(inputs[k]->device() == img1.device(),
                "fusedssim_opt_backward: ", names[k], " must be on the same device as img1");
    TORCH_CHECK(inputs[k]->scalar_type() == torch::kFloat32,
                "fusedssim_opt_backward: ", names[k], " must be a float32 tensor");
    TORCH_CHECK(inputs[k]->sizes() == img1.sizes(),
                "fusedssim_opt_backward: ", names[k], " must have the same shape as img1");
  }

  const at::cuda::OptionalCUDAGuard device_guard(device_of(img1));

  const int B = static_cast<int>(img1.size(0));
  const int CH = static_cast<int>(img1.size(1));
  const int H = static_cast<int>(img1.size(2));
  const int W = static_cast<int>(img1.size(3));

  torch::Tensor dL_dimg1 = torch::zeros_like(img1).contiguous();

  // A zero-sized grid is an invalid launch configuration; nothing to compute.
  if (img1.numel() == 0) {
    return dL_dimg1;
  }

  // Keep the (possibly newly allocated) contiguous copies alive in named locals.
  const torch::Tensor img1_c = img1.contiguous();
  const torch::Tensor img2_c = img2.contiguous();
  const torch::Tensor dL_dmap_c = dL_dmap.contiguous();
  const torch::Tensor dm_dmu1_c = dm_dmu1.contiguous();
  const torch::Tensor dm_dsigma1_sq_c = dm_dsigma1_sq.contiguous();
  const torch::Tensor dm_dsigma12_c = dm_dsigma12.contiguous();

  const dim3 grid((W + BX - 1) / BX, (H + BY - 1) / BY, B);
  const dim3 block(BX, BY, 1);
  // Launch on PyTorch's current stream so the kernel is ordered with the
  // surrounding tensor operations.
  const cudaStream_t stream = c10::cuda::getCurrentCUDAStream().stream();

  fusedssim_backwardCUDA<<<grid, block, 0, stream>>>(
    H,
    W,
    CH,
    C1,
    C2,
    img1_c.data_ptr<float>(),
    img2_c.data_ptr<float>(),
    dL_dmap_c.data_ptr<float>(),
    dL_dimg1.data_ptr<float>(),
    dm_dmu1_c.data_ptr<float>(),
    dm_dsigma1_sq_c.data_ptr<float>(),
    dm_dsigma12_c.data_ptr<float>()
  );
  // Kernel launches do not report configuration errors by themselves.
  const cudaError_t launch_err = cudaGetLastError();
  TORCH_CHECK(launch_err == cudaSuccess,
              "fusedssim_opt_backward: kernel launch failed: ", cudaGetErrorString(launch_err));

  return dL_dimg1;
}