Added Fusion features for ASFF and for BiFPN(EfficientDet)
AlexeyAB committed Nov 28, 2019
1 parent 31d483a commit 68db8ef
Showing 11 changed files with 156 additions and 41 deletions.
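This commit adds two building blocks for feature fusion as used by ASFF and by BiFPN (EfficientDet): a new normalize_channels activation (enum value NORM_CHAN) that divides the channel values at every spatial position by their sum, and a scale_wh option for the scale_channels layer so that a single-channel spatial map can rescale all channels of another layer instead of a per-channel scale vector. Both serve the "fast normalized fusion" scheme, out = sum_i (w_i / (eps + sum_j w_j)) * in_i. The following is a minimal illustrative C sketch of that formula for two inputs with scalar weights; the function name and flow are assumptions for illustration, not darknet API.

/* Illustrative sketch only (not part of this commit): fast normalized fusion
 * of two equally sized feature maps a and b with learned scalar weights. */
void fuse_two_maps(const float *a, const float *b, float wa, float wb,
                   int n, float *out)
{
    const float eps = 0.0001f;        /* same epsilon as the new kernels below */
    const float norm = eps + wa + wb; /* weights are normalized by their sum   */
    int i;
    for (i = 0; i < n; ++i)
        out[i] = (wa / norm) * a[i] + (wb / norm) * b[i];
}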
3 changes: 2 additions & 1 deletion include/darknet.h
@@ -102,7 +102,7 @@ typedef struct tree {

// activations.h
typedef enum {
LOGISTIC, RELU, RELIE, LINEAR, RAMP, TANH, PLSE, LEAKY, ELU, LOGGY, STAIR, HARDTAN, LHTAN, SELU, SWISH, MISH
LOGISTIC, RELU, RELIE, LINEAR, RAMP, TANH, PLSE, LEAKY, ELU, LOGGY, STAIR, HARDTAN, LHTAN, SELU, SWISH, MISH, NORM_CHAN
}ACTIVATION;

// parser.h
@@ -235,6 +235,7 @@ struct layer {
int sqrt;
int flip;
int index;
int scale_wh;
int binary;
int xnor;
int peephole;
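The two header additions above are the whole public surface of this change: NORM_CHAN is the activation value that the parser maps to the string "normalize_channels", and layer.scale_wh is the flag that switches the scale_channels layer between per-channel and per-pixel scaling. The remaining files wire both into the CPU and GPU paths.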
33 changes: 33 additions & 0 deletions src/activation_kernels.cu
@@ -416,4 +416,37 @@ extern "C" void gradient_array_mish_ongpu(int n, float *activation_input_gpu, float *delta)
const int num_blocks = get_number_of_blocks(n, BLOCK);
gradient_array_mish_kernel << <cuda_gridsize(n), BLOCK, 0, get_cuda_stream() >> > (n, activation_input_gpu, delta);
CHECK_CUDA(cudaPeekAtLastError());
}


__global__ void activate_array_normalize_channels_kernel(float *x, int size, int batch, int channels, int wh_step, float *output_gpu)
{
int i = blockIdx.x * blockDim.x + threadIdx.x;

int wh_i = i % wh_step;
int b = i / wh_step;

const float eps = 0.0001;
if (i < size) {
float sum = eps;
int k;
for (k = 0; k < channels; ++k) {
sum += x[wh_i + k * wh_step + b*wh_step*channels];
}
for (k = 0; k < channels; ++k) {
output_gpu[wh_i + k * wh_step + b*wh_step*channels] = x[wh_i + k * wh_step + b*wh_step*channels] / sum;
}
}
}

extern "C" void activate_array_normalize_channels_ongpu(float *x, int n, int batch, int channels, int wh_step, float *output_gpu)
{
// n = w*h*c*batch
// size = w*h*batch
int size = n / channels;

const int num_blocks = get_number_of_blocks(size, BLOCK);

activate_array_normalize_channels_kernel << <num_blocks, BLOCK, 0, get_cuda_stream() >> > (x, size, batch, channels, wh_step, output_gpu);
CHECK_CUDA(cudaPeekAtLastError());
}
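The kernel runs one thread per (batch item, spatial position): n is w*h*c*batch, so size = n / channels is w*h*batch and wh_step is w*h. Each thread makes two passes over the channel dimension, first accumulating the channel sum (seeded with eps = 0.0001 to avoid division by zero), then dividing every channel value by that sum; this is a plain sum-normalization across channels rather than a softmax.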
25 changes: 25 additions & 0 deletions src/activations.c
@@ -47,6 +47,7 @@ ACTIVATION get_activation(char *s)
if (strcmp(s, "logistic")==0) return LOGISTIC;
if (strcmp(s, "swish") == 0) return SWISH;
if (strcmp(s, "mish") == 0) return MISH;
if (strcmp(s, "normalize_channels") == 0) return NORM_CHAN;
if (strcmp(s, "loggy")==0) return LOGGY;
if (strcmp(s, "relu")==0) return RELU;
if (strcmp(s, "elu")==0) return ELU;
@@ -147,6 +148,30 @@ void activate_array_mish(float *x, const int n, float * activation_input, float * output)
}
}

void activate_array_normalize_channels(float *x, const int n, int batch, int channels, int wh_step, float *output)
{
int size = n / channels;

int i;
#pragma omp parallel for
for (i = 0; i < size; ++i) {
int wh_i = i % wh_step;
int b = i / wh_step;

const float eps = 0.0001;
if (i < size) {
float sum = eps;
int k;
for (k = 0; k < channels; ++k) {
sum += x[wh_i + k * wh_step + b*wh_step*channels];
}
for (k = 0; k < channels; ++k) {
output[wh_i + k * wh_step + b*wh_step*channels] = x[wh_i + k * wh_step + b*wh_step*channels] / sum;
}
}
}
}

float gradient(float x, ACTIVATION a)
{
switch(a){
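Both the GPU and CPU versions compute, for each batch item and spatial position, out[k] = x[k] / (eps + sum_j x[j]) over the channel index k. A minimal usage sketch of the CPU helper above (the example program is not part of the commit): one batch item, a 1x1 spatial grid, and three channels of fusion weights.

#include <stdio.h>
#include "activations.h"

int main()
{
    float w[3] = { 0.5f, 1.0f, 0.5f };
    float out[3];
    /* n = w*h*c*batch = 3, batch = 1, channels = 3, wh_step = w*h = 1 */
    activate_array_normalize_channels(w, 3, 1, 3, 1, out);
    /* sum = 0.0001 + 0.5 + 1.0 + 0.5, so out is roughly {0.25, 0.50, 0.25} */
    printf("%f %f %f\n", out[0], out[1], out[2]);
    return 0;
}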
2 changes: 2 additions & 0 deletions src/activations.h
@@ -22,13 +22,15 @@ void gradient_array_mish(const int n, const float * activation_input, float * delta)
void activate_array(float *x, const int n, const ACTIVATION a);
void activate_array_swish(float *x, const int n, float * output_sigmoid, float * output);
void activate_array_mish(float *x, const int n, float * activation_input, float * output);
void activate_array_normalize_channels(float *x, const int n, int batch, int channels, int wh_step, float *output);
#ifdef GPU
void activate_array_ongpu(float *x, int n, ACTIVATION a);
void activate_array_swish_ongpu(float *x, int n, float *output_sigmoid_gpu, float *output_gpu);
void activate_array_mish_ongpu(float *x, int n, float *activation_input_gpu, float *output_gpu);
void gradient_array_ongpu(float *x, int n, ACTIVATION a, float *delta);
void gradient_array_swish_ongpu(float *x, int n, float *sigmoid_gpu, float *delta);
void gradient_array_mish_ongpu(int n, float *activation_input_gpu, float *delta);
void activate_array_normalize_channels_ongpu(float *x, int n, int batch, int channels, int wh_step, float *output_gpu);
#endif

static inline float stair_activate(float x)
4 changes: 2 additions & 2 deletions src/blas.h
@@ -117,8 +117,8 @@ void add_3_arrays_activate(float *a1, float *a2, float *a3, size_t size, ACTIVATION a, float *dst)
void sum_of_mults(float *a1, float *a2, float *b1, float *b2, size_t size, float *dst);
void activate_and_mult(float *a1, float *a2, size_t size, ACTIVATION a, float *dst);

void scale_channels_gpu(float *in_w_h_c, int size, int channel_size, float *scales_c, float *out);
void backward_scale_channels_gpu(float *in_w_h_c_delta, int size, int channel_size,
void scale_channels_gpu(float *in_w_h_c, int size, int channel_size, int batch_size, int scale_wh, float *scales_c, float *out);
void backward_scale_channels_gpu(float *in_w_h_c_delta, int size, int channel_size, int batch_size, int scale_wh,
float *in_scales_c, float *out_from_delta,
float *in_from_output, float *out_state_delta);

66 changes: 43 additions & 23 deletions src/blas_kernels.cu
@@ -1102,19 +1102,26 @@ extern "C" void activate_and_mult(float *a1, float *a2, size_t size, ACTIVATION a, float *dst)



__global__ void scale_channels_kernel(float *in_w_h_c, int size, int channel_size, float *scales_c, float *out)
__global__ void scale_channels_kernel(float *in_w_h_c, int size, int channel_size, int batch_size, int scale_wh, float *scales_c, float *out)
{
const int index = blockIdx.x*blockDim.x + threadIdx.x;
if (index < size) {
out[index] = in_w_h_c[index] * scales_c[index / channel_size];
if (scale_wh) {
int osd_index = index % channel_size + (index / batch_size)*channel_size;

out[index] = in_w_h_c[index] * scales_c[osd_index];
}
else {
out[index] = in_w_h_c[index] * scales_c[index / channel_size];
}
}
}

extern "C" void scale_channels_gpu(float *in_w_h_c, int size, int channel_size, float *scales_c, float *out)
extern "C" void scale_channels_gpu(float *in_w_h_c, int size, int channel_size, int batch_size, int scale_wh, float *scales_c, float *out)
{
const int block_size = BLOCK;
const int num_blocks = get_number_of_blocks(size, block_size);
scale_channels_kernel << <num_blocks, block_size, 0, get_cuda_stream() >> >(in_w_h_c, size, channel_size, scales_c, out);
scale_channels_kernel << <num_blocks, block_size, 0, get_cuda_stream() >> >(in_w_h_c, size, channel_size, batch_size, scale_wh, scales_c, out);
CHECK_CUDA(cudaPeekAtLastError());
}

@@ -1130,44 +1137,57 @@ float warpAllReduceSum(float val) {
return val;
}

__global__ void backward_scale_channels_kernel(float *in_w_h_c_delta, int size, int channel_size,
__global__ void backward_scale_channels_kernel(float *in_w_h_c_delta, int size, int channel_size, int batch_size, int scale_wh,
float *in_scales_c, float *out_from_delta,
float *in_from_output, float *out_state_delta)
{
const int index = blockIdx.x*blockDim.x + threadIdx.x;
int osd_index = index / channel_size;

if (index < size) {
//out_state_delta[osd_index] += in_w_h_c_delta[index] * in_from_output[index]; // l.delta * from (should be divided by channel_size?)

int warp_id = index / 32;
int index_warp_start = warp_id * 32;
int osd_index_warp_start = index_warp_start / channel_size;
int osd_index_warp_end = (index_warp_start + 31) / channel_size;

if (osd_index_warp_start == osd_index_warp_end) // all thread in warp process the same channel
if (scale_wh)
{
float sum = warpAllReduceSum(in_w_h_c_delta[index] * in_from_output[index]); // l.delta * from
if (threadIdx.x % 32 == 0) {
atomicAdd(&out_state_delta[osd_index], sum);
//out_state_delta[osd_index] += sum;
}
}
else {
int osd_index = index % channel_size + (index / batch_size)*channel_size;

//out_state_delta[osd_index] += in_w_h_c_delta[index] * in_from_output[index]; // l.delta * from (should be divided by channel_size?)
atomicAdd(&out_state_delta[osd_index], in_w_h_c_delta[index] * in_from_output[index]); // l.delta * from

out_from_delta[index] += in_scales_c[osd_index] * in_w_h_c_delta[index]; // input * l.delta // atomic isn't required here

}
else {
int osd_index = index / channel_size;
//out_state_delta[osd_index] += in_w_h_c_delta[index] * in_from_output[index]; // l.delta * from (should be divided by channel_size?)

int warp_id = index / 32;
int index_warp_start = warp_id * 32;
int osd_index_warp_start = index_warp_start / channel_size;
int osd_index_warp_end = (index_warp_start + 31) / channel_size;

if (osd_index_warp_start == osd_index_warp_end) // all thread in warp process the same channel
{
float sum = warpAllReduceSum(in_w_h_c_delta[index] * in_from_output[index]); // l.delta * from
if (threadIdx.x % 32 == 0) {
atomicAdd(&out_state_delta[osd_index], sum);
//out_state_delta[osd_index] += sum;
}
}
else {
atomicAdd(&out_state_delta[osd_index], in_w_h_c_delta[index] * in_from_output[index]); // l.delta * from
}

out_from_delta[index] += in_scales_c[osd_index] * in_w_h_c_delta[index]; // input * l.delta // atomic isn't required here
out_from_delta[index] += in_scales_c[osd_index] * in_w_h_c_delta[index]; // input * l.delta // atomic isn't required here
}
}
}

extern "C" void backward_scale_channels_gpu(float *in_w_h_c_delta, int size, int channel_size,
extern "C" void backward_scale_channels_gpu(float *in_w_h_c_delta, int size, int channel_size, int batch_size, int scale_wh,
float *in_scales_c, float *out_from_delta,
float *in_from_output, float *out_state_delta)
{
const int block_size = BLOCK;
const int num_blocks = get_number_of_blocks(size, block_size);
backward_scale_channels_kernel << <num_blocks, block_size, 0, get_cuda_stream() >> > (in_w_h_c_delta, size, channel_size,
backward_scale_channels_kernel << <num_blocks, block_size, 0, get_cuda_stream() >> > (in_w_h_c_delta, size, channel_size, batch_size, scale_wh,
in_scales_c, out_from_delta,
in_from_output, out_state_delta);

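The forward and backward kernels above differ only in how an output element maps to its scale. In plain C (an illustrative restatement, not code from the commit), with channel_size = out_w*out_h and the batch_size argument being the per-image element count out_c*out_w*out_h as set up in scale_channels_layer.c:

/* scale_wh == 0: one scale per (batch item, channel)          -> SE-style channel scaling.
 * scale_wh == 1: one scale per (batch item, spatial position) -> ASFF-style spatial weighting,
 *                the same w*h map multiplies every channel of the "from" layer. */
int scale_index(int index, int channel_size, int batch_size, int scale_wh)
{
    if (scale_wh) return index % channel_size + (index / batch_size) * channel_size;
    return index / channel_size;
}

Note that the warp-level reduction (warpAllReduceSum) is kept only on the scale_wh == 0 backward path, where a whole warp usually accumulates into the same per-channel scale; the new spatial path uses one atomicAdd per element instead.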
1 change: 1 addition & 0 deletions src/convolutional_kernels.cu
@@ -599,6 +599,7 @@ void forward_convolutional_layer_gpu(convolutional_layer l, network_state state)

if (l.activation == SWISH) activate_array_swish_ongpu(l.output_gpu, l.outputs*l.batch, l.activation_input_gpu, l.output_gpu);
else if (l.activation == MISH) activate_array_mish_ongpu(l.output_gpu, l.outputs*l.batch, l.activation_input_gpu, l.output_gpu);
else if (l.activation == NORM_CHAN) activate_array_normalize_channels_ongpu(l.output_gpu, l.outputs*l.batch, l.batch, l.out_c, l.out_w*l.out_h, l.output_gpu);
else if (l.activation != LINEAR) activate_array_ongpu(l.output_gpu, l.outputs*l.batch, l.activation);
//if(l.dot > 0) dot_error_gpu(l);
if(l.binary || l.xnor) swap_binary(&l);
2 changes: 2 additions & 0 deletions src/convolutional_layer.c
@@ -1199,6 +1199,7 @@ void forward_convolutional_layer(convolutional_layer l, network_state state)
//activate_array(l.output, m*n*l.batch, l.activation);
if (l.activation == SWISH) activate_array_swish(l.output, l.outputs*l.batch, l.activation_input, l.output);
else if (l.activation == MISH) activate_array_mish(l.output, l.outputs*l.batch, l.activation_input, l.output);
else if (l.activation == NORM_CHAN) activate_array_normalize_channels(l.output, l.outputs*l.batch, l.batch, l.out_c, l.out_w*l.out_h, l.output);
else activate_array_cpu_custom(l.output, m*n*l.batch, l.activation);
return;

@@ -1239,6 +1240,7 @@ void forward_convolutional_layer(convolutional_layer l, network_state state)
//activate_array(l.output, m*n*l.batch, l.activation);
if (l.activation == SWISH) activate_array_swish(l.output, l.outputs*l.batch, l.activation_input, l.output);
else if (l.activation == MISH) activate_array_mish(l.output, l.outputs*l.batch, l.activation_input, l.output);
else if (l.activation == NORM_CHAN) activate_array_normalize_channels(l.output, l.outputs*l.batch, l.batch, l.out_c, l.out_w*l.out_h, l.output);
else activate_array_cpu_custom(l.output, l.outputs*l.batch, l.activation);

if(l.binary || l.xnor) swap_binary(&l);
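On both the GPU and CPU convolutional forward paths the new activation is dispatched with the layer's own geometry (channels = l.out_c, wh_step = l.out_w*l.out_h), so a convolutional layer configured with activation=normalize_channels (the string registered in activations.c) outputs per-pixel weights normalized across its filters; presumably this is how the fusion-weight layers for ASFF/BiFPN are expressed in a cfg.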
3 changes: 2 additions & 1 deletion src/parser.c
@@ -748,11 +748,12 @@ layer parse_scale_channels(list *options, size_params params, network net)
char *l = option_find(options, "from");
int index = atoi(l);
if (index < 0) index = params.index + index;
int scale_wh = option_find_int_quiet(options, "scale_wh", 0);

int batch = params.batch;
layer from = net.layers[index];

layer s = make_scale_channels_layer(batch, index, params.w, params.h, params.c, from.out_w, from.out_h, from.out_c);
layer s = make_scale_channels_layer(batch, index, params.w, params.h, params.c, from.out_w, from.out_h, from.out_c, scale_wh);

char *activation_s = option_find_str_quiet(options, "activation", "linear");
ACTIVATION activation = get_activation(activation_s);
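For scale_channels sections the parser now reads an optional scale_wh key, defaulting to 0 so existing SE-style configs keep their per-channel behaviour; scale_wh=1 selects the new spatial mode used for ASFF fusion.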
56 changes: 43 additions & 13 deletions src/scale_channels_layer.c
@@ -4,16 +4,18 @@
#include <stdio.h>
#include <assert.h>

layer make_scale_channels_layer(int batch, int index, int w, int h, int c, int w2, int h2, int c2)
layer make_scale_channels_layer(int batch, int index, int w, int h, int c, int w2, int h2, int c2, int scale_wh)
{
fprintf(stderr,"scale Layer: %d\n", index);
layer l = { (LAYER_TYPE)0 };
l.type = SCALE_CHANNELS;
l.batch = batch;
l.scale_wh = scale_wh;
l.w = w;
l.h = h;
l.c = c;
assert(w == 1 && h == 1);
if (!l.scale_wh) assert(w == 1 && h == 1);
else assert(c == 1);

l.out_w = w2;
l.out_h = h2;
@@ -62,12 +64,24 @@ void forward_scale_channels_layer(const layer l, network_state state)
{
int size = l.batch * l.out_c * l.out_w * l.out_h;
int channel_size = l.out_w * l.out_h;
int batch_size = l.out_c * l.out_w * l.out_h;
float *from_output = state.net.layers[l.index].output;

int i;
#pragma omp parallel for
for (i = 0; i < size; ++i) {
l.output[i] = state.input[i / channel_size] * from_output[i];
if (l.scale_wh) {
int i;
#pragma omp parallel for
for (i = 0; i < size; ++i) {
int input_index = i % channel_size + (i / batch_size)*channel_size;

l.output[i] = state.input[input_index] * from_output[i];
}
}
else {
int i;
#pragma omp parallel for
for (i = 0; i < size; ++i) {
l.output[i] = state.input[i / channel_size] * from_output[i];
}
}

activate_array(l.output, l.outputs*l.batch, l.activation);
@@ -81,15 +95,29 @@ void backward_scale_channels_layer(const layer l, network_state state)

int size = l.batch * l.out_c * l.out_w * l.out_h;
int channel_size = l.out_w * l.out_h;
int batch_size = l.out_c * l.out_w * l.out_h;
float *from_output = state.net.layers[l.index].output;
float *from_delta = state.net.layers[l.index].delta;

int i;
#pragma omp parallel for
for (i = 0; i < size; ++i) {
state.delta[i / channel_size] += l.delta[i] * from_output[i] / channel_size; // l.delta * from (should be divided by channel_size?)
if (l.scale_wh) {
int i;
#pragma omp parallel for
for (i = 0; i < size; ++i) {
int input_index = i % channel_size + (i / batch_size)*channel_size;

from_delta[i] += state.input[i / channel_size] * l.delta[i]; // input * l.delta
state.delta[input_index] += l.delta[i] * from_output[i];// / l.out_c; // l.delta * from (should be divided by l.out_c?)

from_delta[i] += state.input[input_index] * l.delta[i]; // input * l.delta
}
}
else {
int i;
#pragma omp parallel for
for (i = 0; i < size; ++i) {
state.delta[i / channel_size] += l.delta[i] * from_output[i];// / channel_size; // l.delta * from (should be divided by channel_size?)

from_delta[i] += state.input[i / channel_size] * l.delta[i]; // input * l.delta
}
}
}

@@ -98,8 +126,9 @@ void forward_scale_channels_layer_gpu(const layer l, network_state state)
{
int size = l.batch * l.out_c * l.out_w * l.out_h;
int channel_size = l.out_w * l.out_h;
int batch_size = l.out_c * l.out_w * l.out_h;

scale_channels_gpu(state.net.layers[l.index].output_gpu, size, channel_size, state.input, l.output_gpu);
scale_channels_gpu(state.net.layers[l.index].output_gpu, size, channel_size, batch_size, l.scale_wh, state.input, l.output_gpu);

activate_array_ongpu(l.output_gpu, l.outputs*l.batch, l.activation);
}
@@ -110,9 +139,10 @@ void backward_scale_channels_layer_gpu(const layer l, network_state state)

int size = l.batch * l.out_c * l.out_w * l.out_h;
int channel_size = l.out_w * l.out_h;
int batch_size = l.out_c * l.out_w * l.out_h;
float *from_output = state.net.layers[l.index].output_gpu;
float *from_delta = state.net.layers[l.index].delta_gpu;

backward_scale_channels_gpu(l.delta_gpu, size, channel_size, state.input, from_delta, from_output, state.delta);
backward_scale_channels_gpu(l.delta_gpu, size, channel_size, batch_size, l.scale_wh, state.input, from_delta, from_output, state.delta);
}
#endif
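The asserts in make_scale_channels_layer encode the two legal input shapes: with scale_wh=0 the layer's direct input must be 1x1xC (a per-channel scale vector for the from layer), and with scale_wh=1 it must be WxHx1 (a single-channel map matching the from layer's spatial size); the output always takes the from layer's full W x H x C shape. As an assumed example (not from the commit), scaling a 76x76x256 feature map would take a 1x1x256 input in the first mode and a 76x76x1 input in the second. The CPU backward pass for the per-channel mode also drops its earlier division by channel_size (now commented out), matching the GPU kernel.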
2 changes: 1 addition & 1 deletion src/scale_channels_layer.h
@@ -7,7 +7,7 @@
#ifdef __cplusplus
extern "C" {
#endif
layer make_scale_channels_layer(int batch, int index, int w, int h, int c, int w2, int h2, int c2);
layer make_scale_channels_layer(int batch, int index, int w, int h, int c, int w2, int h2, int c2, int scale_wh);
void forward_scale_channels_layer(const layer l, network_state state);
void backward_scale_channels_layer(const layer l, network_state state);
void resize_scale_channels_layer(layer *l, network *net);
