Skip to content

Commit

Permalink
Added [local_avgpool] layer
Browse files Browse the repository at this point in the history
  • Loading branch information
AlexeyAB committed Dec 20, 2019
1 parent e66d3b1 commit 005714a
Show file tree
Hide file tree
Showing 6 changed files with 317 additions and 24 deletions.
2 changes: 2 additions & 0 deletions include/darknet.h
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,7 @@ typedef enum {
DECONVOLUTIONAL,
CONNECTED,
MAXPOOL,
LOCAL_AVGPOOL,
SOFTMAX,
DETECTION,
DROPOUT,
Expand Down Expand Up @@ -202,6 +203,7 @@ struct layer {
void(*update_gpu) (struct layer, int, float, float, float);
layer *share_layer;
int train;
int avgpool;
int batch_normalize;
int shortcut;
int batch;
Expand Down
163 changes: 142 additions & 21 deletions src/maxpool_layer.c
Original file line number Diff line number Diff line change
Expand Up @@ -49,10 +49,31 @@ void cudnn_maxpool_setup(layer *l)
}


maxpool_layer make_maxpool_layer(int batch, int h, int w, int c, int size, int stride_x, int stride_y, int padding, int maxpool_depth, int out_channels, int antialiasing, int train)
/*
 * Configure the cuDNN pooling and tensor descriptors for a
 * local average-pooling layer. The pooling mode excludes padded
 * cells from the divisor, so the GPU path matches the CPU
 * forward_local_avgpool_layer behavior. Compiles to a no-op
 * unless built with CUDNN.
 */
void cudnn_local_avgpool_setup(layer *l)
{
#ifdef CUDNN
    // Window geometry: size x size window, pad/2 cells on each side,
    // independent horizontal and vertical strides.
    CHECK_CUDNN(cudnnSetPooling2dDescriptor(
        l->poolingDesc,
        CUDNN_POOLING_AVERAGE_COUNT_EXCLUDE_PADDING,
        CUDNN_NOT_PROPAGATE_NAN,
        l->size, l->size,
        l->pad / 2, l->pad / 2,
        l->stride_x, l->stride_y));

    // Input and output tensors, both NCHW float32.
    CHECK_CUDNN(cudnnSetTensor4dDescriptor(l->srcTensorDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, l->batch, l->c, l->h, l->w));
    CHECK_CUDNN(cudnnSetTensor4dDescriptor(l->dstTensorDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, l->batch, l->out_c, l->out_h, l->out_w));
#endif // CUDNN
}

maxpool_layer make_maxpool_layer(int batch, int h, int w, int c, int size, int stride_x, int stride_y, int padding, int maxpool_depth, int out_channels, int antialiasing, int avgpool, int train)
{
maxpool_layer l = { (LAYER_TYPE)0 };
l.type = MAXPOOL;
l.avgpool = avgpool;
if (avgpool) l.type = LOCAL_AVGPOOL;
else l.type = MAXPOOL;
l.train = train;

const int blur_stride_x = stride_x;
Expand Down Expand Up @@ -88,31 +109,53 @@ maxpool_layer make_maxpool_layer(int batch, int h, int w, int c, int size, int s
int output_size = l.out_h * l.out_w * l.out_c * batch;

if (train) {
l.indexes = (int*)calloc(output_size, sizeof(int));
if (!avgpool) l.indexes = (int*)calloc(output_size, sizeof(int));
l.delta = (float*)calloc(output_size, sizeof(float));
}
l.output = (float*)calloc(output_size, sizeof(float));
l.forward = forward_maxpool_layer;
l.backward = backward_maxpool_layer;
if (avgpool) {
l.forward = forward_local_avgpool_layer;
l.backward = backward_local_avgpool_layer;
}
else {
l.forward = forward_maxpool_layer;
l.backward = backward_maxpool_layer;
}
#ifdef GPU
l.forward_gpu = forward_maxpool_layer_gpu;
l.backward_gpu = backward_maxpool_layer_gpu;
if (avgpool) {
l.forward_gpu = forward_local_avgpool_layer_gpu;
l.backward_gpu = backward_local_avgpool_layer_gpu;
}
else {
l.forward_gpu = forward_maxpool_layer_gpu;
l.backward_gpu = backward_maxpool_layer_gpu;
}

if (train) {
l.indexes_gpu = cuda_make_int_array(output_size);
if (!avgpool) l.indexes_gpu = cuda_make_int_array(output_size);
l.delta_gpu = cuda_make_array(l.delta, output_size);
}
l.output_gpu = cuda_make_array(l.output, output_size);
create_maxpool_cudnn_tensors(&l);
cudnn_maxpool_setup(&l);
if (avgpool) cudnn_local_avgpool_setup(&l);
else cudnn_maxpool_setup(&l);

#endif // GPU
l.bflops = (l.size*l.size*l.c * l.out_h*l.out_w) / 1000000000.;
if (maxpool_depth)
fprintf(stderr, "max-depth %2dx%2d/%2d %4d x%4d x%4d -> %4d x%4d x%4d %5.3f BF\n", size, size, stride_x, w, h, c, l.out_w, l.out_h, l.out_c, l.bflops);
else if(stride_x == stride_y)
fprintf(stderr, "max %2dx%2d/%2d %4d x%4d x%4d -> %4d x%4d x%4d %5.3f BF\n", size, size, stride_x, w, h, c, l.out_w, l.out_h, l.out_c, l.bflops);
else
fprintf(stderr, "max %2dx%2d/%2dx%2d %4d x%4d x%4d -> %4d x%4d x%4d %5.3f BF\n", size, size, stride_x, stride_y, w, h, c, l.out_w, l.out_h, l.out_c, l.bflops);
if (avgpool) {
if (stride_x == stride_y)
fprintf(stderr, "avg %2dx%2d/%2d %4d x%4d x%4d -> %4d x%4d x%4d %5.3f BF\n", size, size, stride_x, w, h, c, l.out_w, l.out_h, l.out_c, l.bflops);
else
fprintf(stderr, "avg %2dx%2d/%2dx%2d %4d x%4d x%4d -> %4d x%4d x%4d %5.3f BF\n", size, size, stride_x, stride_y, w, h, c, l.out_w, l.out_h, l.out_c, l.bflops);
}
else {
if (maxpool_depth)
fprintf(stderr, "max-depth %2dx%2d/%2d %4d x%4d x%4d -> %4d x%4d x%4d %5.3f BF\n", size, size, stride_x, w, h, c, l.out_w, l.out_h, l.out_c, l.bflops);
else if (stride_x == stride_y)
fprintf(stderr, "max %2dx%2d/%2d %4d x%4d x%4d -> %4d x%4d x%4d %5.3f BF\n", size, size, stride_x, w, h, c, l.out_w, l.out_h, l.out_c, l.bflops);
else
fprintf(stderr, "max %2dx%2d/%2dx%2d %4d x%4d x%4d -> %4d x%4d x%4d %5.3f BF\n", size, size, stride_x, stride_y, w, h, c, l.out_w, l.out_h, l.out_c, l.bflops);
}

if (l.antialiasing) {
printf("AA: ");
Expand Down Expand Up @@ -152,7 +195,7 @@ maxpool_layer make_maxpool_layer(int batch, int h, int w, int c, int size, int s
for (i = 0; i < l.out_c; ++i) l.input_layer->biases[i] = 0;
#ifdef GPU
if (gpu_index >= 0) {
l.input_antialiasing_gpu = cuda_make_array(NULL, l.batch*l.outputs);
if (l.antialiasing) l.input_antialiasing_gpu = cuda_make_array(NULL, l.batch*l.outputs);
push_convolutional_layer(*(l.input_layer));
}
#endif // GPU
Expand All @@ -173,7 +216,7 @@ void resize_maxpool_layer(maxpool_layer *l, int w, int h)
int output_size = l->outputs * l->batch;

if (l->train) {
l->indexes = (int*)realloc(l->indexes, output_size * sizeof(int));
if (!l->avgpool) l->indexes = (int*)realloc(l->indexes, output_size * sizeof(int));
l->delta = (float*)realloc(l->delta, output_size * sizeof(float));
}
l->output = (float*)realloc(l->output, output_size * sizeof(float));
Expand All @@ -183,13 +226,16 @@ void resize_maxpool_layer(maxpool_layer *l, int w, int h)
l->output_gpu = cuda_make_array(l->output, output_size);

if (l->train) {
CHECK_CUDA(cudaFree((float *)l->indexes_gpu));
if (!l->avgpool) {
CHECK_CUDA(cudaFree((float *)l->indexes_gpu));
l->indexes_gpu = cuda_make_int_array(output_size);
}
CHECK_CUDA(cudaFree(l->delta_gpu));
l->indexes_gpu = cuda_make_int_array(output_size);
l->delta_gpu = cuda_make_array(l->delta, output_size);
}

cudnn_maxpool_setup(l);
if(l->avgpool) cudnn_local_avgpool_setup(l);
else cudnn_maxpool_setup(l);
#endif
}

Expand Down Expand Up @@ -229,7 +275,8 @@ void forward_maxpool_layer(const maxpool_layer l, network_state state)
if (!state.train && l.stride_x == l.stride_y) {
forward_maxpool_layer_avx(state.input, l.output, l.indexes, l.size, l.w, l.h, l.out_w, l.out_h, l.c, l.pad, l.stride, l.batch);
}
else {
else
{

int b, i, j, k, m, n;
int w_offset = -l.pad / 2;
Expand Down Expand Up @@ -290,3 +337,77 @@ void backward_maxpool_layer(const maxpool_layer l, network_state state)
state.delta[index] += l.delta[i];
}
}


/*
 * Forward pass for a local average-pooling layer (LOCAL_AVGPOOL).
 * For every output cell, averages the input activations inside a
 * size x size window, counting ONLY in-bounds cells in the divisor
 * (padding excluded, matching the GPU path's
 * CUDNN_POOLING_AVERAGE_COUNT_EXCLUDE_PADDING mode).
 *
 * l     - the pooling layer (size, stride_x/stride_y, pad, w/h/c,
 *         out_w/out_h are read; result written to l.output)
 * state - state.input holds the input activations,
 *         laid out as batch x channel x row x col
 */
void forward_local_avgpool_layer(const maxpool_layer l, network_state state)
{
    int b, i, j, k, m, n;
    int w_offset = -l.pad / 2;
    int h_offset = -l.pad / 2;

    int h = l.out_h;
    int w = l.out_w;
    int c = l.c;

    for (b = 0; b < l.batch; ++b) {
        for (k = 0; k < c; ++k) {
            for (i = 0; i < h; ++i) {
                for (j = 0; j < w; ++j) {
                    int out_index = j + w*(i + h*(k + c*b));
                    float avg = 0;
                    int counter = 0;
                    for (n = 0; n < l.size; ++n) {
                        for (m = 0; m < l.size; ++m) {
                            int cur_h = h_offset + i*l.stride_y + n;
                            int cur_w = w_offset + j*l.stride_x + m;
                            int index = cur_w + l.w*(cur_h + l.h*(k + b*l.c));
                            int valid = (cur_h >= 0 && cur_h < l.h &&
                                cur_w >= 0 && cur_w < l.w);
                            if (valid) {
                                counter++;
                                avg += state.input[index];
                            }
                        }
                    }
                    // Guard against a window that falls entirely inside the
                    // padding region (counter == 0): the original code would
                    // divide by zero here.
                    l.output[out_index] = (counter > 0) ? (avg / counter) : 0;
                }
            }
        }
    }
}

/*
 * Backward pass for a local average-pooling layer.
 * Distributes each output gradient equally over the in-bounds input
 * cells of its window.
 *
 * Fix: the divisor is the number of VALID (in-bounds) cells, not
 * l.size*l.size. The forward pass averages over valid cells only
 * (padding excluded), so dividing by size*size under-scaled the
 * gradient for windows that overlap the border.
 *
 * l     - the pooling layer; l.delta holds the output gradients
 * state - state.delta accumulates the input gradients (+=)
 */
void backward_local_avgpool_layer(const maxpool_layer l, network_state state)
{
    int b, i, j, k, m, n;
    int w_offset = -l.pad / 2;
    int h_offset = -l.pad / 2;

    int h = l.out_h;
    int w = l.out_w;
    int c = l.c;

    for (b = 0; b < l.batch; ++b) {
        for (k = 0; k < c; ++k) {
            for (i = 0; i < h; ++i) {
                for (j = 0; j < w; ++j) {
                    int out_index = j + w*(i + h*(k + c*b));

                    // First pass: count in-bounds cells so the gradient
                    // divisor matches the forward average.
                    int counter = 0;
                    for (n = 0; n < l.size; ++n) {
                        for (m = 0; m < l.size; ++m) {
                            int cur_h = h_offset + i*l.stride_y + n;
                            int cur_w = w_offset + j*l.stride_x + m;
                            if (cur_h >= 0 && cur_h < l.h &&
                                cur_w >= 0 && cur_w < l.w) counter++;
                        }
                    }
                    if (counter == 0) continue;  // window entirely in padding

                    // Second pass: spread the delta over the valid cells.
                    for (n = 0; n < l.size; ++n) {
                        for (m = 0; m < l.size; ++m) {
                            int cur_h = h_offset + i*l.stride_y + n;
                            int cur_w = w_offset + j*l.stride_x + m;
                            int index = cur_w + l.w*(cur_h + l.h*(k + b*l.c));
                            int valid = (cur_h >= 0 && cur_h < l.h &&
                                cur_w >= 0 && cur_w < l.w);
                            if (valid) state.delta[index] += l.delta[out_index] / counter;
                        }
                    }
                }
            }
        }
    }
}
8 changes: 7 additions & 1 deletion src/maxpool_layer.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,21 @@ typedef layer maxpool_layer;
extern "C" {
#endif
image get_maxpool_image(maxpool_layer l);
maxpool_layer make_maxpool_layer(int batch, int h, int w, int c, int size, int stride_x, int stride_y, int padding, int maxpool_depth, int out_channels, int antialiasing, int train);
maxpool_layer make_maxpool_layer(int batch, int h, int w, int c, int size, int stride_x, int stride_y, int padding, int maxpool_depth, int out_channels, int antialiasing, int avgpool, int train);
void resize_maxpool_layer(maxpool_layer *l, int w, int h);
void forward_maxpool_layer(const maxpool_layer l, network_state state);
void backward_maxpool_layer(const maxpool_layer l, network_state state);

void forward_local_avgpool_layer(const maxpool_layer l, network_state state);
void backward_local_avgpool_layer(const maxpool_layer l, network_state state);

#ifdef GPU
void forward_maxpool_layer_gpu(maxpool_layer l, network_state state);
void backward_maxpool_layer_gpu(maxpool_layer l, network_state state);
void cudnn_maxpool_setup(maxpool_layer *l);

void forward_local_avgpool_layer_gpu(maxpool_layer layer, network_state state);
void backward_local_avgpool_layer_gpu(maxpool_layer layer, network_state state);
#endif // GPU

#ifdef __cplusplus
Expand Down
Loading

0 comments on commit 005714a

Please sign in to comment.