
Commit

Merge pull request #12 from sony/feature/20170805-fix_batchnormalization
Improve backprop options of batch_normalization with CUDNN
TakuyaNarihira authored Sep 1, 2017
2 parents 12c4350 + 63c3a89 commit 67510b6
Showing 1 changed file with 33 additions and 12 deletions.
src/nbla/cuda/cudnn/function/batch_normalization.cu (45 changes: 33 additions & 12 deletions)
@@ -15,6 +15,7 @@
 #include <nbla/array.hpp>
 #include <nbla/variable.hpp>
 
+#include <nbla/cuda/array/cuda_array.hpp>
 #include <nbla/cuda/common.hpp>
 #include <nbla/cuda/cudnn/cudnn.hpp>
 #include <nbla/cuda/cudnn/function/batch_normalization.hpp>
@@ -149,10 +150,11 @@ void BatchNormalizationCudaCudnn<T>::backward_impl_batch(
   const T *v = batch_var->get_data_pointer<T>(this->ctx_);
   const T *x = inputs[0]->get_data_pointer<T>(this->ctx_);
 
-  if (propagate_down[0] && (propagate_down[1] || propagate_down[2])) {
-    T a = 1;
-    T b_data = accum[0] ? 1 : 0;
-    T b_param = 1;
+  if (propagate_down[0] || propagate_down[1] || propagate_down[2]) {
+    T a_data = propagate_down[0] ? 1 : 0;
+    T b_data = accum[0] && propagate_down[0] ? 1 : 0;
+    T a_param = propagate_down[1] || propagate_down[2] ? 1 : 0;
+    T b_param = a_param;
     if (!(accum[1] || accum[2])) {
       b_param = 0;
     } else {
@@ -162,16 +164,35 @@ void BatchNormalizationCudaCudnn<T>::backward_impl_batch(
         inputs[2]->grad()->zero();
     }
 
-    T *dx = inputs[0]->cast_grad_and_get_pointer<T>(this->ctx_);
+    size_t workspace_size = 0;
+    if (!propagate_down[0]) {
+      workspace_size = inputs[0]->size() * sizeof(T);
+    } else if (!propagate_down[1] || !propagate_down[2]) {
+      workspace_size = inputs[1]->size() * sizeof(T);
+    }
+    T *tmp_buf = nullptr;
+    shared_ptr<CudaCachedArray> mem_workspace(
+        workspace_size
+            ? new CudaCachedArray(workspace_size, dtypes::BYTE, this->ctx_)
+            : nullptr);
+    if (workspace_size) {
+      tmp_buf = (T *)mem_workspace->pointer();
+    }
+
+    T *dx = propagate_down[0]
+                ? inputs[0]->cast_grad_and_get_pointer<T>(this->ctx_)
+                : tmp_buf;
     const T *gamma = inputs[2]->get_data_pointer<T>(this->ctx_);
-    NBLA_CHECK(propagate_down[1] && propagate_down[2], error_code::value,
-               "'need_grad' of beta and gamma must be the same.");
-    T *db = inputs[1]->cast_grad_and_get_pointer<T>(this->ctx_);
-    T *dg = inputs[2]->cast_grad_and_get_pointer<T>(this->ctx_);
+    T *db = propagate_down[1]
+                ? inputs[1]->cast_grad_and_get_pointer<T>(this->ctx_)
+                : tmp_buf;
+    T *dg = propagate_down[2]
+                ? inputs[2]->cast_grad_and_get_pointer<T>(this->ctx_)
+                : tmp_buf;
     NBLA_CUDNN_CHECK(cudnnBatchNormalizationBackward(
-        cudnn_handle_, mode_, &a, &b_data, &a, &b_param, input_desc_, x,
-        output_desc_, dy, input_desc_, dx, bn_scale_bias_mean_var_desc_, gamma,
-        dg, db, epsilon, m, v));
+        cudnn_handle_, mode_, &a_data, &b_data, &a_param, &b_param, input_desc_,
+        x, output_desc_, dy, input_desc_, dx, bn_scale_bias_mean_var_desc_,
+        gamma, dg, db, epsilon, m, v));
   }
 }
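For context (not part of the commit): cudnnBatchNormalizationBackward takes two alpha/beta scaling pairs, one for the data gradient dx and one for the parameter gradients dgamma/dbeta, and per the cuDNN documentation blends each result into its destination buffer as dst = alpha * computed + beta * dst. The patch maps propagate_down to the alpha factors (compute a gradient or skip it) and accum to the beta factors (accumulate into an existing gradient or overwrite it), and points any output that is not requested at a scratch CudaCachedArray so the single cuDNN call always has valid destination pointers. Below is a minimal standalone sketch of that blending rule; the a_data/b_data/a_param/b_param and tmp_buf names refer to the diff above, while blend() and the numeric values are illustrative only.

#include <cstdio>

// dst = alpha * computed + beta * prior -- the cuDNN-style output blending
// that a_data/b_data (for dx) and a_param/b_param (for dgamma/dbeta) control.
static float blend(float alpha, float computed, float beta, float prior) {
  return alpha * computed + beta * prior;
}

int main() {
  const float computed_dx = 0.5f, prior_dx = 2.0f;
  // propagate_down[0] = true,  accum[0] = true  -> a_data = 1, b_data = 1
  std::printf("accumulate: %.2f\n", blend(1.f, computed_dx, 1.f, prior_dx)); // 2.50
  // propagate_down[0] = true,  accum[0] = false -> a_data = 1, b_data = 0
  std::printf("overwrite:  %.2f\n", blend(1.f, computed_dx, 0.f, prior_dx)); // 0.50
  // propagate_down[0] = false                   -> a_data = 0, dx routed to tmp_buf
  std::printf("skipped:    %.2f\n", blend(0.f, computed_dx, 0.f, prior_dx)); // 0.00
  return 0;
}

Any gradient that is not requested is simply written into the scratch buffer and discarded, which is what lets the backward pass stay a single cuDNN call instead of branching into several.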
