diff --git a/sparseconvnet/SCN/CUDA/AveragePooling.cu b/sparseconvnet/SCN/CUDA/AveragePooling.cu index 6777dda..f3579f7 100644 --- a/sparseconvnet/SCN/CUDA/AveragePooling.cu +++ b/sparseconvnet/SCN/CUDA/AveragePooling.cu @@ -9,24 +9,24 @@ // NTX must be >=2 so r is filled properly template __global__ void AveragePooling_fp(T *input_features, T *output_features, - Int nPlanes, Int input_stride, - Int output_stride, Int *rules, Int nHot, - T alpha) { + Int nPlanes, Int input_stride, + Int output_stride, Int *rules, Int nHot, + T alpha) { __shared__ Int r[NTY * 2]; for (Int n = blockIdx.x * NTY; n < nHot; n += gridDim.x * NTY) { { Int i = threadIdx.x + NTX * threadIdx.y; if (i < NTY * 2 and i < 2 * (nHot - n)) - r[i] = rules[2 * n + i]; + r[i] = rules[2 * n + i]; } __syncthreads(); if (n + threadIdx.y < nHot) { Int i = r[2 * threadIdx.y] * input_stride; Int o = r[2 * threadIdx.y + 1] * output_stride; for (Int plane = threadIdx.x; plane < nPlanes; plane += NTX) - output_features[o + plane]+= alpha * input_features[i + plane]; - // atomicAdd(&output_features[o + plane], - // alpha * input_features[i + plane]); + output_features[o + plane] += alpha * input_features[i + plane]; + // atomicAdd(&output_features[o + plane], + // alpha * input_features[i + plane]); } __syncthreads(); } @@ -34,32 +34,35 @@ __global__ void AveragePooling_fp(T *input_features, T *output_features, template void cuda_AveragePooling_ForwardPass(T *input_features, T *output_features, - Int nPlanes, Int input_stride, - Int output_stride, RuleBook _rules, - Int filterVolume) { - RULEBOOKITERATOR((AveragePooling_fp<<<32, dim3(32, 32)>>>( - input_features, output_features, nPlanes, input_stride, output_stride, - rbB, nHotB, 1.0 / filterVolume)); - , ) + Int nPlanes, Int input_stride, + Int output_stride, RuleBook _rules, + Int filterVolume) { + auto application = [&](Int *rbB, Int nHotB, cudaStream_t &stream) -> void { + AveragePooling_fp<<<32, dim3(32, 32), 0, stream>>>( + input_features, output_features, nPlanes, input_stride, output_stride, + rbB, nHotB, 1.0 / filterVolume); + }; + + iterateRuleBook(_rules, application); } template __global__ void AveragePooling_bp(T *d_input_features, T *d_output_features, - Int nPlanes, Int input_stride, - Int output_stride, Int *rules, Int nHot, - T alpha) { + Int nPlanes, Int input_stride, + Int output_stride, Int *rules, Int nHot, + T alpha) { __shared__ Int r[NTY * 2]; for (Int n = blockIdx.x * NTY; n < nHot; n += gridDim.x * NTY) { { Int i = threadIdx.x + NTX * threadIdx.y; if (i < NTY * 2 and i < 2 * (nHot - n)) - r[i] = rules[2 * n + i]; + r[i] = rules[2 * n + i]; } __syncthreads(); if (n + threadIdx.y < nHot) { Int i = r[2 * threadIdx.y] * input_stride; Int o = r[2 * threadIdx.y + 1] * output_stride; for (Int plane = threadIdx.x; plane < nPlanes; plane += NTX) - d_input_features[i + plane] += alpha * d_output_features[o + plane]; + d_input_features[i + plane] += alpha * d_output_features[o + plane]; } __syncthreads(); } @@ -67,79 +70,72 @@ __global__ void AveragePooling_bp(T *d_input_features, T *d_output_features, template void cuda_AveragePooling_BackwardPass(T *d_input_features, T *d_output_features, - Int nPlanes, Int input_stride, - Int output_stride, RuleBook _rules, - Int filterVolume) { - RULEBOOKITERATOR((AveragePooling_bp<<<32, dim3(32, 32)>>>( - d_input_features, d_output_features, nPlanes, input_stride, output_stride, - rbB, nHotB, 1.0 / filterVolume)); - , ) -} - - - - - - - - - + Int nPlanes, Int input_stride, + Int output_stride, RuleBook _rules, + Int filterVolume) { + auto application = [&](Int *rbB, Int nHotB, cudaStream_t &stream) -> void { + AveragePooling_bp<<<32, dim3(32, 32), 0, stream>>>( + d_input_features, d_output_features, nPlanes, input_stride, output_stride, + rbB, nHotB, 1.0 / filterVolume); + }; + iterateRuleBook(_rules, application); +} // NTX must be >=2 so r is filled properly template -__global__ void CopyFeaturesHelper_fp(T *input_features, T *output_features, Int * rules, - Int nPlanes, Int nHot) { +__global__ void CopyFeaturesHelper_fp(T *input_features, T *output_features, + Int *rules, Int nPlanes, Int nHot) { __shared__ Int r[NTY * 2]; for (Int n = blockIdx.x * NTY; n < nHot; n += gridDim.x * NTY) { { Int i = threadIdx.x + NTX * threadIdx.y; if (i < NTY * 2 and i < 2 * (nHot - n)) - r[i] = rules[2 * n + i]; + r[i] = rules[2 * n + i]; } __syncthreads(); if (n + threadIdx.y < nHot) { - Int i = r[2 * threadIdx.y+1] * nPlanes; - Int o = r[2 * threadIdx.y ] * nPlanes; + Int i = r[2 * threadIdx.y + 1] * nPlanes; + Int o = r[2 * threadIdx.y] * nPlanes; for (Int plane = threadIdx.x; plane < nPlanes; plane += NTX) - output_features[o + plane]= input_features[i + plane]; + output_features[o + plane] = input_features[i + plane]; } __syncthreads(); } } template -void cuda_CopyFeaturesHelper_ForwardPass(T *input_features, T *output_features, Int* rules, - Int nPlanes, Int nHot) { -CopyFeaturesHelper_fp<<<32, dim3(32, 32)>>>( - input_features, output_features, rules, nPlanes, - nHot); +void cuda_CopyFeaturesHelper_ForwardPass(T *input_features, T *output_features, + Int *rules, Int nPlanes, Int nHot) { + CopyFeaturesHelper_fp<<<32, dim3(32, 32)>>>( + input_features, output_features, rules, nPlanes, nHot); } template -__global__ void CopyFeaturesHelper_bp(T *d_input_features, T *d_output_features, Int* rules, - Int nPlanes,Int nHot) { +__global__ void CopyFeaturesHelper_bp(T *d_input_features, T *d_output_features, + Int *rules, Int nPlanes, Int nHot) { __shared__ Int r[NTY * 2]; for (Int n = blockIdx.x * NTY; n < nHot; n += gridDim.x * NTY) { { Int i = threadIdx.x + NTX * threadIdx.y; if (i < NTY * 2 and i < 2 * (nHot - n)) - r[i] = rules[2 * n + i]; + r[i] = rules[2 * n + i]; } __syncthreads(); if (n + threadIdx.y < nHot) { - Int i = r[2 * threadIdx.y+1] * nPlanes; + Int i = r[2 * threadIdx.y + 1] * nPlanes; Int o = r[2 * threadIdx.y] * nPlanes; for (Int plane = threadIdx.x; plane < nPlanes; plane += NTX) - d_input_features[i + plane] = d_output_features[o + plane]; + d_input_features[i + plane] = d_output_features[o + plane]; } __syncthreads(); } } template -void cuda_CopyFeaturesHelper_BackwardPass(T *d_input_features, T *d_output_features, - Int* rules, Int nPlanes, Int nHot) { -CopyFeaturesHelper_bp<<<32, dim3(32, 32)>>>( +void cuda_CopyFeaturesHelper_BackwardPass(T *d_input_features, + T *d_output_features, Int *rules, + Int nPlanes, Int nHot) { + CopyFeaturesHelper_bp<<<32, dim3(32, 32)>>>( d_input_features, d_output_features, rules, nPlanes, nHot); } diff --git a/sparseconvnet/SCN/CUDA/Convolution.cu b/sparseconvnet/SCN/CUDA/Convolution.cu index f56f355..7cfd587 100644 --- a/sparseconvnet/SCN/CUDA/Convolution.cu +++ b/sparseconvnet/SCN/CUDA/Convolution.cu @@ -9,7 +9,7 @@ template __global__ void Convolution_fp_bias_(T *output_features, T *bias, Int nPlanes, - Int nActive) { + Int nActive) { Int n = blockIdx.x * 32 + threadIdx.x; T b = bias[n]; output_features += n; @@ -22,34 +22,36 @@ template void Convolution_fp_bias(T *oF, T *b, Int nPlanes, Int nActive) { if (nPlanes / 32 > 0) Convolution_fp_bias_<<>>(oF, b, nPlanes, - nActive); + nActive); if (nPlanes % 32 > 0) { Int o = nPlanes / 32 * 32; Convolution_fp_bias_<<>>(oF + o, b + o, nPlanes, - nActive); + nActive); } } template -__global__ void Convolution_bp_bias_(T *d_oF, T *d_b, Int nPlanes, Int nActive) { +__global__ void Convolution_bp_bias_(T *d_oF, T *d_b, Int nPlanes, + Int nActive) { Int n = blockIdx.x * 32 + threadIdx.x; - d_oF+=n; + d_oF += n; TACC t = 0; for (Int row = blockIdx.y; row < nActive; row += gridDim.y) - t += d_oF[row * nPlanes ]; + t += d_oF[row * nPlanes]; atomicAdd(&d_b[n], t); } template void Convolution_bp_bias(T *d_oF, T *d_b, Int nPlanes, Int nActive) { if (nPlanes / 32 > 0) - Convolution_bp_bias_<<>>(d_oF, d_b, nPlanes, nActive); + Convolution_bp_bias_<<>>(d_oF, d_b, nPlanes, + nActive); if (nPlanes % 32 > 0) { Int o = nPlanes / 32 * 32; - Convolution_bp_bias_<<>>(d_oF + o, d_b + o, nPlanes, nActive); + Convolution_bp_bias_<<>>(d_oF + o, d_b + o, + nPlanes, nActive); } } - // .._nPlanes == planes per nGroup // weight = nGroups x input_nPlanes x output_nPlanes // = nGroups x M*K x N*K @@ -57,8 +59,8 @@ void Convolution_bp_bias(T *d_oF, T *d_b, Int nPlanes, Int nActive) { template __global__ void dConvolution_KMxKN_forwardA(T *inFeatures, T *outFeatures, T *w, Int *rules, - Int nHot, Int input_nPlanes, Int input_stride, - Int output_nPlanes, Int output_stride) { + Int nHot, Int input_nPlanes, Int input_stride, + Int output_nPlanes, Int output_stride) { // nHot must be a multiple of K!! // Input x Weight -> Output @@ -95,31 +97,31 @@ dConvolution_KMxKN_forwardA(T *inFeatures, T *outFeatures, T *w, Int *rules, for (Int s = blockIdx.x * K; s < nHot; s += K * gridDim.x) { #pragma unroll for (int v = 0; v < V; v++) { - R0[v] = rules[2 * (s + ty[v])]; - R1[v] = rules[2 * (s + ty[v]) + 1]; + R0[v] = rules[2 * (s + ty[v])]; + R1[v] = rules[2 * (s + ty[v]) + 1]; } __syncthreads(); // Read input, reset O[] #pragma unroll for (int v = 0; v < V; v++) { - I[ty[v]][tx] = inFeatures[R0[v] * input_stride + tx]; - O[v] = 0; + I[ty[v]][tx] = inFeatures[R0[v] * input_stride + tx]; + O[v] = 0; } __syncthreads(); #pragma unroll for (int k = 0; k < K; k++) #pragma unroll - for (int v = 0; v < V; v++) - O[v] += I[ty[v]][k] * W[k][tx]; + for (int v = 0; v < V; v++) + O[v] += I[ty[v]][k] * W[k][tx]; #pragma unroll for (int v = 0; v < V; v++) - O[v] += outFeatures[R1[v] * output_stride + tx]; + O[v] += outFeatures[R1[v] * output_stride + tx]; #pragma unroll for (int v = 0; v < V; v++) - outFeatures[R1[v] * output_stride + tx] = O[v]; + outFeatures[R1[v] * output_stride + tx] = O[v]; __syncthreads(); } w += K * output_nPlanes; @@ -129,8 +131,8 @@ dConvolution_KMxKN_forwardA(T *inFeatures, T *outFeatures, T *w, Int *rules, template __global__ void dConvolution_KMxKN_forwardB(T *inFeatures, T *outFeatures, T *w, Int *rules, - Int nHot, Int input_nPlanes, Int input_stride, - Int output_nPlanes, Int output_stride) { + Int nHot, Int input_nPlanes, Int input_stride, + Int output_nPlanes, Int output_stride) { // Input x Weight -> Output // blockDim=(K,K/V,1), gridDim=(nBlocks,N,nGroups) Volkov-blocks // K is a multiple of V, @@ -165,36 +167,36 @@ dConvolution_KMxKN_forwardB(T *inFeatures, T *outFeatures, T *w, Int *rules, for (Int s = blockIdx.x * K; s < nHot; s += K * gridDim.x) { #pragma unroll for (int v = 0; v < V; v++) { - if (s + ty[v] < nHot) { - R0[v] = rules[2 * (s + ty[v])]; - R1[v] = rules[2 * (s + ty[v]) + 1]; - } + if (s + ty[v] < nHot) { + R0[v] = rules[2 * (s + ty[v])]; + R1[v] = rules[2 * (s + ty[v]) + 1]; + } } __syncthreads(); // Read input, reset O[] #pragma unroll for (int v = 0; v < V; v++) { - if (s + ty[v] < nHot) - I[ty[v]][tx] = inFeatures[R0[v] * input_stride + tx]; - O[v] = 0; + if (s + ty[v] < nHot) + I[ty[v]][tx] = inFeatures[R0[v] * input_stride + tx]; + O[v] = 0; } __syncthreads(); #pragma unroll for (int k = 0; k < K; k++) #pragma unroll - for (int v = 0; v < V; v++) - O[v] += I[ty[v]][k] * W[k][tx]; + for (int v = 0; v < V; v++) + O[v] += I[ty[v]][k] * W[k][tx]; #pragma unroll for (int v = 0; v < V; v++) - if (s + ty[v] < nHot) - O[v] += outFeatures[R1[v] * output_stride + tx]; + if (s + ty[v] < nHot) + O[v] += outFeatures[R1[v] * output_stride + tx]; #pragma unroll for (int v = 0; v < V; v++) - if (s + ty[v] < nHot) - outFeatures[R1[v] * output_stride + tx] = O[v]; + if (s + ty[v] < nHot) + outFeatures[R1[v] * output_stride + tx] = O[v]; __syncthreads(); } w += K * output_nPlanes; @@ -207,24 +209,26 @@ dConvolution_KMxKN_forwardB(T *inFeatures, T *outFeatures, T *w, Int *rules, if (input_nPlanes % K == 0 and output_nPlanes % K == 0) { \ Int o = (nHot / K) * K; \ if (o >= K) \ - dConvolution_KMxKN_forwardA< \ - T, K, V><<>>(inFeatures, outFeatures, w, rules, o, \ - input_nPlanes, input_stride, \ - output_nPlanes, output_stride); \ + dConvolution_KMxKN_forwardA< \ + T, K, \ + V><<>>( \ + inFeatures, outFeatures, w, rules, o, input_nPlanes, input_stride, \ + output_nPlanes, output_stride); \ if (nHot > o) \ - dConvolution_KMxKN_forwardB< \ - T, K, V><<>>( \ - inFeatures, outFeatures, w, rules + 2 * o, nHot - o, \ - input_nPlanes, input_stride, output_nPlanes, output_stride); \ + dConvolution_KMxKN_forwardB<<>>( \ + inFeatures, outFeatures, w, rules + 2 * o, nHot - o, \ + input_nPlanes, input_stride, output_nPlanes, output_stride); \ return; \ } \ } template void dConvolution_forward(T *inFeatures, T *outFeatures, T *w, Int *rules, - Int nHot, Int input_nPlanes, Int input_stride, - Int output_nPlanes, Int output_stride, Int nGroups) { + Int nHot, Int input_nPlanes, Int input_stride, + Int output_nPlanes, Int output_stride, Int nGroups, cudaStream_t &stream) { FOO(T, 64, 16) FOO(T, 32, 8) FOO(T, 16, 4) @@ -233,9 +237,10 @@ void dConvolution_forward(T *inFeatures, T *outFeatures, T *w, Int *rules, } template <> void dConvolution_forward(double *inFeatures, double *outFeatures, - double *w, Int *rules, Int nHot, - Int input_nPlanes, Int input_stride, - Int output_nPlanes, Int output_stride, Int nGroups) { + double *w, Int *rules, Int nHot, + Int input_nPlanes, Int input_stride, + Int output_nPlanes, Int output_stride, + Int nGroups, cudaStream_t &stream) { FOO(double, 32, 8) FOO(double, 16, 4) FOO(double, 8, 2) @@ -249,9 +254,9 @@ void dConvolution_forward(double *inFeatures, double *outFeatures, template __global__ void dConvolution_KMxKN_backward_dW_A(T *inFeatures, T *dInFeatures, T *dOutFeatures, - T *w, T *dw, Int *rules, Int nHot, - Int input_nPlanes, Int input_stride, - Int output_nPlanes, Int output_stride) { + T *w, T *dw, Int *rules, Int nHot, + Int input_nPlanes, Int input_stride, + Int output_nPlanes, Int output_stride) { // M = gridDim.y == input_nPlanes / K Int N = output_nPlanes / K; Int m = blockIdx.y; @@ -259,8 +264,8 @@ dConvolution_KMxKN_backward_dW_A(T *inFeatures, T *dInFeatures, T *dOutFeatures, inFeatures += m * K + g * input_nPlanes; dInFeatures += m * K + g * input_nPlanes; dOutFeatures += g * output_nPlanes; - w += m * K * output_nPlanes+ g * input_nPlanes * output_nPlanes; - dw += m * K * output_nPlanes+ g * input_nPlanes * output_nPlanes; + w += m * K * output_nPlanes + g * input_nPlanes * output_nPlanes; + dw += m * K * output_nPlanes + g * input_nPlanes * output_nPlanes; TACC dI[V]; TACC dW[V]; @@ -286,31 +291,31 @@ dConvolution_KMxKN_backward_dW_A(T *inFeatures, T *dInFeatures, T *dOutFeatures, for (Int s = blockIdx.x * K; s < nHot; s += K * gridDim.x) { #pragma unroll for (int v = 0; v < V; v++) { - R0[v] = rules[2 * (s + ty[v])]; - R1[v] = rules[2 * (s + ty[v]) + 1]; - dI[v] = 0; + R0[v] = rules[2 * (s + ty[v])]; + R1[v] = rules[2 * (s + ty[v]) + 1]; + dI[v] = 0; } __syncthreads(); // Read input and dOutput #pragma unroll for (int v = 0; v < V; v++) { - I[ty[v]][tx] = inFeatures[R0[v] * input_stride + tx]; - dO[ty[v]][tx] = dOutFeatures[R1[v] * output_stride + tx]; + I[ty[v]][tx] = inFeatures[R0[v] * input_stride + tx]; + dO[ty[v]][tx] = dOutFeatures[R1[v] * output_stride + tx]; } __syncthreads(); #pragma unroll for (int k = 0; k < K; k++) #pragma unroll - for (int v = 0; v < V; v++) { - dI[v] += dO[ty[v]][k] * W[tx][k]; - dW[v] += I[k][ty[v]] * dO[k][tx]; - } + for (int v = 0; v < V; v++) { + dI[v] += dO[ty[v]][k] * W[tx][k]; + dW[v] += I[k][ty[v]] * dO[k][tx]; + } #pragma unroll for (int v = 0; v < V; v++) - dI[v] += dInFeatures[R0[v] * input_stride + tx]; + dI[v] += dInFeatures[R0[v] * input_stride + tx]; #pragma unroll for (int v = 0; v < V; v++) - dInFeatures[R0[v] * input_stride + tx] = dI[v]; + dInFeatures[R0[v] * input_stride + tx] = dI[v]; __syncthreads(); } #pragma unroll @@ -328,9 +333,9 @@ dConvolution_KMxKN_backward_dW_A(T *inFeatures, T *dInFeatures, T *dOutFeatures, template __global__ void dConvolution_KMxKN_backward_dW_B(T *inFeatures, T *dInFeatures, T *dOutFeatures, - T *w, T *dw, Int *rules, Int nHot, - Int input_nPlanes, Int input_stride, - Int output_nPlanes, Int output_stride) { + T *w, T *dw, Int *rules, Int nHot, + Int input_nPlanes, Int input_stride, + Int output_nPlanes, Int output_stride) { // M = gridDim.y == input_nPlanes / K Int N = output_nPlanes / K; Int m = blockIdx.y; @@ -338,8 +343,8 @@ dConvolution_KMxKN_backward_dW_B(T *inFeatures, T *dInFeatures, T *dOutFeatures, inFeatures += m * K + g * input_nPlanes; dInFeatures += m * K + g * input_nPlanes; dOutFeatures += g * output_nPlanes; - w += m * K * output_nPlanes+ g * input_nPlanes * output_nPlanes; - dw += m * K * output_nPlanes+ g * input_nPlanes * output_nPlanes; + w += m * K * output_nPlanes + g * input_nPlanes * output_nPlanes; + dw += m * K * output_nPlanes + g * input_nPlanes * output_nPlanes; TACC dI[V]; TACC dW[V]; @@ -365,39 +370,39 @@ dConvolution_KMxKN_backward_dW_B(T *inFeatures, T *dInFeatures, T *dOutFeatures, for (Int s = blockIdx.x * K; s < nHot; s += K * gridDim.x) { #pragma unroll for (int v = 0; v < V; v++) { - if (s + ty[v] < nHot) { - R0[v] = rules[2 * (s + ty[v])]; - R1[v] = rules[2 * (s + ty[v]) + 1]; - } - dI[v] = 0; + if (s + ty[v] < nHot) { + R0[v] = rules[2 * (s + ty[v])]; + R1[v] = rules[2 * (s + ty[v]) + 1]; + } + dI[v] = 0; } __syncthreads(); // Read input and dOutput #pragma unroll for (int v = 0; v < V; v++) - if (s + ty[v] < nHot) { - I[ty[v]][tx] = inFeatures[R0[v] * input_stride + tx]; - dO[ty[v]][tx] = dOutFeatures[R1[v] * output_stride + tx]; - } else { - I[ty[v]][tx] = 0; - dO[ty[v]][tx] = 0; - } + if (s + ty[v] < nHot) { + I[ty[v]][tx] = inFeatures[R0[v] * input_stride + tx]; + dO[ty[v]][tx] = dOutFeatures[R1[v] * output_stride + tx]; + } else { + I[ty[v]][tx] = 0; + dO[ty[v]][tx] = 0; + } __syncthreads(); #pragma unroll for (int k = 0; k < K; k++) #pragma unroll - for (int v = 0; v < V; v++) { - dI[v] += dO[ty[v]][k] * W[tx][k]; - dW[v] += I[k][ty[v]] * dO[k][tx]; - } + for (int v = 0; v < V; v++) { + dI[v] += dO[ty[v]][k] * W[tx][k]; + dW[v] += I[k][ty[v]] * dO[k][tx]; + } #pragma unroll for (int v = 0; v < V; v++) - if (s + ty[v] < nHot) - dI[v] += dInFeatures[R0[v] * input_stride + tx]; + if (s + ty[v] < nHot) + dI[v] += dInFeatures[R0[v] * input_stride + tx]; #pragma unroll for (int v = 0; v < V; v++) - if (s + ty[v] < nHot) - dInFeatures[R0[v] * input_stride + tx] = dI[v]; + if (s + ty[v] < nHot) + dInFeatures[R0[v] * input_stride + tx] = dI[v]; __syncthreads(); } #pragma unroll @@ -414,26 +419,28 @@ dConvolution_KMxKN_backward_dW_B(T *inFeatures, T *dInFeatures, T *dOutFeatures, if (input_nPlanes % K == 0 and output_nPlanes % K == 0) { \ Int o = (nHot / K) * K; \ if (o >= K) \ - dConvolution_KMxKN_backward_dW_A< \ - T, K, V><<>>( \ - inFeatures, dInFeatures, dOutFeatures, w, dw, rules, o, \ - input_nPlanes, input_stride, output_nPlanes, output_stride); \ + dConvolution_KMxKN_backward_dW_A< \ + T, K, \ + V><<>>(inFeatures, dInFeatures, dOutFeatures, w, \ + dw, rules, o, input_nPlanes, input_stride, \ + output_nPlanes, output_stride); \ if (nHot > o) \ - dConvolution_KMxKN_backward_dW_B< \ - T, K, V><<>>( \ - inFeatures, dInFeatures, dOutFeatures, w, dw, rules + 2 * o, \ - nHot - o, input_nPlanes, input_stride, output_nPlanes, \ - output_stride); \ + dConvolution_KMxKN_backward_dW_B< \ + T, K, V><<>>( \ + inFeatures, dInFeatures, dOutFeatures, w, dw, rules + 2 * o, \ + nHot - o, input_nPlanes, input_stride, output_nPlanes, \ + output_stride); \ return; \ } \ } template void dConvolution_backward_dW(T *inFeatures, T *dInFeatures, T *dOutFeatures, - T *w, T *dw, Int *rules, Int nHot, - Int input_nPlanes, Int input_stride, - Int output_nPlanes, Int output_stride, Int nGroups) { + T *w, T *dw, Int *rules, Int nHot, + Int input_nPlanes, Int input_stride, + Int output_nPlanes, Int output_stride, + Int nGroups, cudaStream_t &stream) { FOO(T, 32, 8) FOO(T, 16, 4) FOO(T, 8, 2) @@ -444,8 +451,8 @@ void dConvolution_backward_dW(T *inFeatures, T *dInFeatures, T *dOutFeatures, template __global__ void dConvolution_KMxKN_forward2(T *inFeatures, T *outFeatures, T *w, Int *rules, - Int nHot, Int input_nPlanes, Int input_stride, - Int output_nPlanes, Int output_stride) { + Int nHot, Int input_nPlanes, Int input_stride, + Int output_nPlanes, Int output_stride) { // Input x Weight -> Output // blockDim=(K,K/V,1), gridDim=(nBlocks,N,nGroups) Volkov-blocks // K is a multiple of V, @@ -479,40 +486,40 @@ dConvolution_KMxKN_forward2(T *inFeatures, T *outFeatures, T *w, Int *rules, #pragma unroll for (int v = 0; v < V; v++) if (ty[v] < KI and tx < KO) - W[ty[v]][tx] = w[ty[v] * output_nPlanes + tx]; + W[ty[v]][tx] = w[ty[v] * output_nPlanes + tx]; for (Int s = blockIdx.x * K; s < nHot; s += K * gridDim.x) { // Read rules for K input/output pairs #pragma unroll for (int v = 0; v < V; v++) { - if (ty[v] < 2) { - int q = ty[v] * K + tx; - if (s + q / 2 < nHot) - R[q] = rules[2 * s + q]; - } + if (ty[v] < 2) { + int q = ty[v] * K + tx; + if (s + q / 2 < nHot) + R[q] = rules[2 * s + q]; + } } __syncthreads(); // Read input, reset O[] #pragma unroll for (int v = 0; v < V; v++) { - if (tx < KI and s + ty[v] < nHot) - I[ty[v]][tx] = inFeatures[R[2 * ty[v]] * input_stride + tx]; - O[v] = 0; + if (tx < KI and s + ty[v] < nHot) + I[ty[v]][tx] = inFeatures[R[2 * ty[v]] * input_stride + tx]; + O[v] = 0; } __syncthreads(); #pragma unroll for (int k = 0; k < KI; k++) #pragma unroll - for (int v = 0; v < V; v++) - O[v] += I[ty[v]][k] * W[k][tx]; + for (int v = 0; v < V; v++) + O[v] += I[ty[v]][k] * W[k][tx]; __syncthreads(); #pragma unroll for (int v = 0; v < V; v++) - if (tx < KO and s + ty[v] < nHot) - outFeatures[R[2 * ty[v] + 1] * output_stride + tx] += O[v]; + if (tx < KO and s + ty[v] < nHot) + outFeatures[R[2 * ty[v] + 1] * output_stride + tx] += O[v]; __syncthreads(); } w += K * output_nPlanes; @@ -526,9 +533,9 @@ dConvolution_KMxKN_forward2(T *inFeatures, T *outFeatures, T *w, Int *rules, template __global__ void dConvolution_KMxKN_backward_dW2(T *inFeatures, T *dInFeatures, T *dOutFeatures, - T *w, T *dw, Int *rules, Int nHot, - Int input_nPlanes, Int input_stride, - Int output_nPlanes, Int output_stride) { + T *w, T *dw, Int *rules, Int nHot, + Int input_nPlanes, Int input_stride, + Int output_nPlanes, Int output_stride) { // M = gridDim.y == input_nPlanes / K Int N = (output_nPlanes + K - 1) / K; Int m = blockIdx.y; @@ -536,8 +543,8 @@ dConvolution_KMxKN_backward_dW2(T *inFeatures, T *dInFeatures, T *dOutFeatures, inFeatures += m * K + g * input_nPlanes; dInFeatures += m * K + g * input_nPlanes; dOutFeatures += g * output_nPlanes; - w += m * K * output_nPlanes+ g * input_nPlanes * output_nPlanes; - dw += m * K * output_nPlanes+ g * input_nPlanes * output_nPlanes; + w += m * K * output_nPlanes + g * input_nPlanes * output_nPlanes; + dw += m * K * output_nPlanes + g * input_nPlanes * output_nPlanes; Int KI = min(K, input_nPlanes - K * m); TACC dI[V]; @@ -559,7 +566,7 @@ dConvolution_KMxKN_backward_dW2(T *inFeatures, T *dInFeatures, T *dOutFeatures, #pragma unroll for (int v = 0; v < V; v++) { if (ty[v] < KI and tx < KO) - W[ty[v]][tx] = w[ty[v] * output_nPlanes + tx]; + W[ty[v]][tx] = w[ty[v] * output_nPlanes + tx]; dW[v] = 0; } @@ -567,48 +574,48 @@ dConvolution_KMxKN_backward_dW2(T *inFeatures, T *dInFeatures, T *dOutFeatures, // Read rules for K input/output pairs, reset dI[] #pragma unroll for (int v = 0; v < V; v++) { - if (ty[v] < 2) { - int q = ty[v] * K + tx; - if (s + q / 2 < nHot) - R[q] = rules[2 * s + q]; - } - dI[v] = 0; + if (ty[v] < 2) { + int q = ty[v] * K + tx; + if (s + q / 2 < nHot) + R[q] = rules[2 * s + q]; + } + dI[v] = 0; } __syncthreads(); // Read input and dOutput #pragma unroll for (int v = 0; v < V; v++) { - if (tx < KI and s + ty[v] < nHot) - I[ty[v]][tx] = inFeatures[R[2 * ty[v]] * input_stride + tx]; - else - I[ty[v]][tx] = 0; - if (tx < KO and s + ty[v] < nHot) - dO[ty[v]][tx] = dOutFeatures[R[2 * ty[v] + 1] * output_stride + tx]; - else - dO[ty[v]][tx] = 0; + if (tx < KI and s + ty[v] < nHot) + I[ty[v]][tx] = inFeatures[R[2 * ty[v]] * input_stride + tx]; + else + I[ty[v]][tx] = 0; + if (tx < KO and s + ty[v] < nHot) + dO[ty[v]][tx] = dOutFeatures[R[2 * ty[v] + 1] * output_stride + tx]; + else + dO[ty[v]][tx] = 0; } __syncthreads(); #pragma unroll for (int k = 0; k < KO; k++) #pragma unroll - for (int v = 0; v < V; v++) - dI[v] += dO[ty[v]][k] * W[tx][k]; + for (int v = 0; v < V; v++) + dI[v] += dO[ty[v]][k] * W[tx][k]; #pragma unroll for (int k = 0; k < K; k++) #pragma unroll - for (int v = 0; v < V; v++) - dW[v] += I[k][ty[v]] * dO[k][tx]; + for (int v = 0; v < V; v++) + dW[v] += I[k][ty[v]] * dO[k][tx]; __syncthreads(); #pragma unroll for (int v = 0; v < V; v++) - if (tx < KI and s + ty[v] < nHot) - dInFeatures[R[2 * ty[v]] * input_stride + tx] += dI[v]; + if (tx < KI and s + ty[v] < nHot) + dInFeatures[R[2 * ty[v]] * input_stride + tx] += dI[v]; __syncthreads(); } #pragma unroll for (int v = 0; v < V; v++) if (ty[v] < KI and tx < KO) - atomicAdd(&dw[ty[v] * output_nPlanes + tx], dW[v]); + atomicAdd(&dw[ty[v] * output_nPlanes + tx], dW[v]); w += K; dw += K; dOutFeatures += K; @@ -617,52 +624,74 @@ dConvolution_KMxKN_backward_dW2(T *inFeatures, T *dInFeatures, T *dOutFeatures, template double dConvolution_forward2(T *inFeatures, T *outFeatures, T *w, - RuleBook _rules, Int input_nPlanes, - Int input_stride, Int output_nPlanes, - Int output_stride, Int nGroups) { + RuleBook _rules, Int input_nPlanes, + Int input_stride, Int output_nPlanes, + Int output_stride, Int nGroups) { Int c = input_nPlanes * output_nPlanes * nGroups; double flops = 0; + auto command = [&](Int nHotB) -> void { + w += c; + flops += nHotB * c; + }; + if (input_nPlanes % 8 != 0 or output_nPlanes % 8 != 0) { const int K = 16; const int V = 4; - RULEBOOKITERATOR( - (dConvolution_KMxKN_forward2< - T, K, - V><<>>( - inFeatures, outFeatures, w, rbB, nHotB, input_nPlanes, input_stride, - output_nPlanes, output_stride)); - , w += c; flops += nHotB * c;) + + auto application = [&](Int *rbB, Int nHotB, cudaStream_t &stream) -> void { + dConvolution_KMxKN_forward2< + T, K, V><<>>( + inFeatures, outFeatures, w, rbB, nHotB, input_nPlanes, input_stride, + output_nPlanes, output_stride); + }; + + iterateRuleBook(_rules, application, command); } else { - RULEBOOKITERATOR(dConvolution_forward(inFeatures, outFeatures, w, rbB, - nHotB, input_nPlanes, input_stride, - output_nPlanes, output_stride, nGroups); - , w += c; flops += nHotB * c;) + auto application = [&](Int *rbB, Int nHotB, cudaStream_t &stream) -> void { + dConvolution_forward(inFeatures, outFeatures, w, rbB, nHotB, + input_nPlanes, input_stride, output_nPlanes, + output_stride, nGroups, stream); + }; + + iterateRuleBook(_rules, application, command); } return flops; } template void dConvolution_backward_dW2(T *inFeatures, T *dInFeatures, T *dOutFeatures, - T *w, T *dw, RuleBook _rules, Int input_nPlanes, - Int input_stride, Int output_nPlanes, - Int output_stride, Int nGroups) { + T *w, T *dw, RuleBook _rules, Int input_nPlanes, + Int input_stride, Int output_nPlanes, + Int output_stride, Int nGroups) { Int c = input_nPlanes * output_nPlanes * nGroups; + auto command = [&](Int nHotB) -> void { + w += c; + dw += c; + }; + if (input_nPlanes % 8 != 0 or output_nPlanes % 8 != 0) { const int K = 16; const int V = 4; - RULEBOOKITERATOR( - (dConvolution_KMxKN_backward_dW2< - T, K, - V><<>>( - inFeatures, dInFeatures, dOutFeatures, w, dw, rbB, nHotB, - input_nPlanes, input_stride, output_nPlanes, output_stride)); - , w += c; dw += c;) + + auto application = [&](Int *rbB, Int nHotB, cudaStream_t &stream) -> void { + dConvolution_KMxKN_backward_dW2< + T, K, V><<>>( + inFeatures, dInFeatures, dOutFeatures, w, dw, rbB, nHotB, + input_nPlanes, input_stride, output_nPlanes, output_stride); + }; + + iterateRuleBook(_rules, application, command); } else { - RULEBOOKITERATOR(dConvolution_backward_dW(inFeatures, dInFeatures, - dOutFeatures, w, dw, rbB, nHotB, - input_nPlanes, input_stride, - output_nPlanes, output_stride, nGroups); - , w += c; dw += c;) + + auto application = [&](Int *rbB, Int nHotB, cudaStream_t &stream) -> void { + dConvolution_backward_dW(inFeatures, dInFeatures, dOutFeatures, w, dw, + rbB, nHotB, input_nPlanes, input_stride, + output_nPlanes, output_stride, nGroups, stream); + }; + + iterateRuleBook(_rules, application, command); } } #undef TACC \ No newline at end of file diff --git a/sparseconvnet/SCN/CUDA/Deconvolution.cu b/sparseconvnet/SCN/CUDA/Deconvolution.cu index f3e3e9e..a91139b 100644 --- a/sparseconvnet/SCN/CUDA/Deconvolution.cu +++ b/sparseconvnet/SCN/CUDA/Deconvolution.cu @@ -9,8 +9,8 @@ template __global__ void dDeconvolution_KMxKN_forwardA(T *inFeatures, T *outFeatures, T *w, Int *rules, - Int nHot, Int input_nPlanes, Int input_stride, - Int output_nPlanes, Int output_stride) { + Int nHot, Int input_nPlanes, Int input_stride, + Int output_nPlanes, Int output_stride) { // nHot must be a multiple of K!! // Input x Weight -> Output @@ -47,31 +47,31 @@ dDeconvolution_KMxKN_forwardA(T *inFeatures, T *outFeatures, T *w, Int *rules, for (Int s = blockIdx.x * K; s < nHot; s += K * gridDim.x) { #pragma unroll for (int v = 0; v < V; v++) { - R1[v] = rules[2 * (s + ty[v])]; - R0[v] = rules[2 * (s + ty[v]) + 1]; + R1[v] = rules[2 * (s + ty[v])]; + R0[v] = rules[2 * (s + ty[v]) + 1]; } __syncthreads(); // Read input, reset O[] #pragma unroll for (int v = 0; v < V; v++) { - I[ty[v]][tx] = inFeatures[R0[v] * input_stride + tx]; - O[v] = 0; + I[ty[v]][tx] = inFeatures[R0[v] * input_stride + tx]; + O[v] = 0; } __syncthreads(); #pragma unroll for (int k = 0; k < K; k++) #pragma unroll - for (int v = 0; v < V; v++) - O[v] += I[ty[v]][k] * W[k][tx]; + for (int v = 0; v < V; v++) + O[v] += I[ty[v]][k] * W[k][tx]; #pragma unroll for (int v = 0; v < V; v++) - O[v] += outFeatures[R1[v] * output_stride + tx]; + O[v] += outFeatures[R1[v] * output_stride + tx]; #pragma unroll for (int v = 0; v < V; v++) - outFeatures[R1[v] * output_stride + tx] = O[v]; + outFeatures[R1[v] * output_stride + tx] = O[v]; __syncthreads(); } w += K * output_nPlanes; @@ -81,8 +81,8 @@ dDeconvolution_KMxKN_forwardA(T *inFeatures, T *outFeatures, T *w, Int *rules, template __global__ void dDeconvolution_KMxKN_forwardB(T *inFeatures, T *outFeatures, T *w, Int *rules, - Int nHot, Int input_nPlanes, Int input_stride, - Int output_nPlanes, Int output_stride) { + Int nHot, Int input_nPlanes, Int input_stride, + Int output_nPlanes, Int output_stride) { // Input x Weight -> Output // blockDim=(K,K/V,1), gridDim=(nBlocks,N,nGroups) Volkov-blocks // K is a multiple of V, @@ -117,36 +117,36 @@ dDeconvolution_KMxKN_forwardB(T *inFeatures, T *outFeatures, T *w, Int *rules, for (Int s = blockIdx.x * K; s < nHot; s += K * gridDim.x) { #pragma unroll for (int v = 0; v < V; v++) { - if (s + ty[v] < nHot) { - R1[v] = rules[2 * (s + ty[v])]; - R0[v] = rules[2 * (s + ty[v]) + 1]; - } + if (s + ty[v] < nHot) { + R1[v] = rules[2 * (s + ty[v])]; + R0[v] = rules[2 * (s + ty[v]) + 1]; + } } __syncthreads(); // Read input, reset O[] #pragma unroll for (int v = 0; v < V; v++) { - if (s + ty[v] < nHot) - I[ty[v]][tx] = inFeatures[R0[v] * input_stride + tx]; - O[v] = 0; + if (s + ty[v] < nHot) + I[ty[v]][tx] = inFeatures[R0[v] * input_stride + tx]; + O[v] = 0; } __syncthreads(); #pragma unroll for (int k = 0; k < K; k++) #pragma unroll - for (int v = 0; v < V; v++) - O[v] += I[ty[v]][k] * W[k][tx]; + for (int v = 0; v < V; v++) + O[v] += I[ty[v]][k] * W[k][tx]; #pragma unroll for (int v = 0; v < V; v++) - if (s + ty[v] < nHot) - O[v] += outFeatures[R1[v] * output_stride + tx]; + if (s + ty[v] < nHot) + O[v] += outFeatures[R1[v] * output_stride + tx]; #pragma unroll for (int v = 0; v < V; v++) - if (s + ty[v] < nHot) - outFeatures[R1[v] * output_stride + tx] = O[v]; + if (s + ty[v] < nHot) + outFeatures[R1[v] * output_stride + tx] = O[v]; __syncthreads(); } w += K * output_nPlanes; @@ -159,24 +159,27 @@ dDeconvolution_KMxKN_forwardB(T *inFeatures, T *outFeatures, T *w, Int *rules, if (input_nPlanes % K == 0 and output_nPlanes % K == 0) { \ Int o = (nHot / K) * K; \ if (o >= K) \ - dDeconvolution_KMxKN_forwardA< \ - T, K, V><<>>(inFeatures, outFeatures, w, rules, o, \ - input_nPlanes, input_stride, \ - output_nPlanes, output_stride); \ + dDeconvolution_KMxKN_forwardA< \ + T, K, \ + V><<>>( \ + inFeatures, outFeatures, w, rules, o, input_nPlanes, input_stride, \ + output_nPlanes, output_stride); \ if (nHot > o) \ - dDeconvolution_KMxKN_forwardB< \ - T, K, V><<>>( \ - inFeatures, outFeatures, w, rules + 2 * o, nHot - o, \ - input_nPlanes, input_stride, output_nPlanes, output_stride); \ + dDeconvolution_KMxKN_forwardB< \ + T, K, V><<>>(inFeatures, outFeatures, w, rules + 2 * o, \ + nHot - o, input_nPlanes, input_stride, \ + output_nPlanes, output_stride); \ return; \ } \ } template void dDeconvolution_forward(T *inFeatures, T *outFeatures, T *w, Int *rules, - Int nHot, Int input_nPlanes, Int input_stride, - Int output_nPlanes, Int output_stride, Int nGroups) { + Int nHot, Int input_nPlanes, Int input_stride, + Int output_nPlanes, Int output_stride, Int nGroups, + cudaStream_t &stream) { FOO(T, 64, 16) FOO(T, 32, 8) FOO(T, 16, 4) @@ -185,9 +188,10 @@ void dDeconvolution_forward(T *inFeatures, T *outFeatures, T *w, Int *rules, } template <> void dDeconvolution_forward(double *inFeatures, double *outFeatures, - double *w, Int *rules, Int nHot, - Int input_nPlanes, Int input_stride, - Int output_nPlanes, Int output_stride, Int nGroups) { + double *w, Int *rules, Int nHot, + Int input_nPlanes, Int input_stride, + Int output_nPlanes, Int output_stride, + Int nGroups, cudaStream_t &stream) { FOO(double, 32, 8) FOO(double, 16, 4) FOO(double, 8, 2) @@ -210,8 +214,8 @@ __global__ void dDeconvolution_KMxKN_backward_dW_A( inFeatures += m * K + g * input_nPlanes; dInFeatures += m * K + g * input_nPlanes; dOutFeatures += g * output_nPlanes; - w += m * K * output_nPlanes+ g * input_nPlanes * output_nPlanes; - dw += m * K * output_nPlanes+ g * input_nPlanes * output_nPlanes; + w += m * K * output_nPlanes + g * input_nPlanes * output_nPlanes; + dw += m * K * output_nPlanes + g * input_nPlanes * output_nPlanes; TACC dI[V]; TACC dW[V]; @@ -237,31 +241,31 @@ __global__ void dDeconvolution_KMxKN_backward_dW_A( for (Int s = blockIdx.x * K; s < nHot; s += K * gridDim.x) { #pragma unroll for (int v = 0; v < V; v++) { - R1[v] = rules[2 * (s + ty[v])]; - R0[v] = rules[2 * (s + ty[v]) + 1]; - dI[v] = 0; + R1[v] = rules[2 * (s + ty[v])]; + R0[v] = rules[2 * (s + ty[v]) + 1]; + dI[v] = 0; } __syncthreads(); // Read input and dOutput #pragma unroll for (int v = 0; v < V; v++) { - I[ty[v]][tx] = inFeatures[R0[v] * input_stride + tx]; - dO[ty[v]][tx] = dOutFeatures[R1[v] * output_stride + tx]; + I[ty[v]][tx] = inFeatures[R0[v] * input_stride + tx]; + dO[ty[v]][tx] = dOutFeatures[R1[v] * output_stride + tx]; } __syncthreads(); #pragma unroll for (int k = 0; k < K; k++) #pragma unroll - for (int v = 0; v < V; v++) { - dI[v] += dO[ty[v]][k] * W[tx][k]; - dW[v] += I[k][ty[v]] * dO[k][tx]; - } + for (int v = 0; v < V; v++) { + dI[v] += dO[ty[v]][k] * W[tx][k]; + dW[v] += I[k][ty[v]] * dO[k][tx]; + } #pragma unroll for (int v = 0; v < V; v++) - dI[v] += dInFeatures[R0[v] * input_stride + tx]; + dI[v] += dInFeatures[R0[v] * input_stride + tx]; #pragma unroll for (int v = 0; v < V; v++) - dInFeatures[R0[v] * input_stride + tx] = dI[v]; + dInFeatures[R0[v] * input_stride + tx] = dI[v]; __syncthreads(); } #pragma unroll @@ -288,8 +292,8 @@ __global__ void dDeconvolution_KMxKN_backward_dW_B( inFeatures += m * K + g * input_nPlanes; dInFeatures += m * K + g * input_nPlanes; dOutFeatures += g * output_nPlanes; - w += m * K * output_nPlanes+ g * input_nPlanes * output_nPlanes; - dw += m * K * output_nPlanes+ g * input_nPlanes * output_nPlanes; + w += m * K * output_nPlanes + g * input_nPlanes * output_nPlanes; + dw += m * K * output_nPlanes + g * input_nPlanes * output_nPlanes; TACC dI[V]; TACC dW[V]; @@ -315,39 +319,39 @@ __global__ void dDeconvolution_KMxKN_backward_dW_B( for (Int s = blockIdx.x * K; s < nHot; s += K * gridDim.x) { #pragma unroll for (int v = 0; v < V; v++) { - if (s + ty[v] < nHot) { - R1[v] = rules[2 * (s + ty[v])]; - R0[v] = rules[2 * (s + ty[v]) + 1]; - } - dI[v] = 0; + if (s + ty[v] < nHot) { + R1[v] = rules[2 * (s + ty[v])]; + R0[v] = rules[2 * (s + ty[v]) + 1]; + } + dI[v] = 0; } __syncthreads(); // Read input and dOutput #pragma unroll for (int v = 0; v < V; v++) - if (s + ty[v] < nHot) { - I[ty[v]][tx] = inFeatures[R0[v] * input_stride + tx]; - dO[ty[v]][tx] = dOutFeatures[R1[v] * output_stride + tx]; - } else { - I[ty[v]][tx] = 0; - dO[ty[v]][tx] = 0; - } + if (s + ty[v] < nHot) { + I[ty[v]][tx] = inFeatures[R0[v] * input_stride + tx]; + dO[ty[v]][tx] = dOutFeatures[R1[v] * output_stride + tx]; + } else { + I[ty[v]][tx] = 0; + dO[ty[v]][tx] = 0; + } __syncthreads(); #pragma unroll for (int k = 0; k < K; k++) #pragma unroll - for (int v = 0; v < V; v++) { - dI[v] += dO[ty[v]][k] * W[tx][k]; - dW[v] += I[k][ty[v]] * dO[k][tx]; - } + for (int v = 0; v < V; v++) { + dI[v] += dO[ty[v]][k] * W[tx][k]; + dW[v] += I[k][ty[v]] * dO[k][tx]; + } #pragma unroll for (int v = 0; v < V; v++) - if (s + ty[v] < nHot) - dI[v] += dInFeatures[R0[v] * input_stride + tx]; + if (s + ty[v] < nHot) + dI[v] += dInFeatures[R0[v] * input_stride + tx]; #pragma unroll for (int v = 0; v < V; v++) - if (s + ty[v] < nHot) - dInFeatures[R0[v] * input_stride + tx] = dI[v]; + if (s + ty[v] < nHot) + dInFeatures[R0[v] * input_stride + tx] = dI[v]; __syncthreads(); } #pragma unroll @@ -364,26 +368,28 @@ __global__ void dDeconvolution_KMxKN_backward_dW_B( if (input_nPlanes % K == 0 and output_nPlanes % K == 0) { \ Int o = (nHot / K) * K; \ if (o >= K) \ - dDeconvolution_KMxKN_backward_dW_A< \ - T, K, V><<>>( \ - inFeatures, dInFeatures, dOutFeatures, w, dw, rules, o, \ - input_nPlanes, input_stride, output_nPlanes, output_stride); \ + dDeconvolution_KMxKN_backward_dW_A< \ + T, K, \ + V><<>>( \ + inFeatures, dInFeatures, dOutFeatures, w, dw, rules, o, \ + input_nPlanes, input_stride, output_nPlanes, output_stride); \ if (nHot > o) \ - dDeconvolution_KMxKN_backward_dW_B< \ - T, K, V><<>>( \ - inFeatures, dInFeatures, dOutFeatures, w, dw, rules + 2 * o, \ - nHot - o, input_nPlanes, input_stride, output_nPlanes, \ - output_stride); \ + dDeconvolution_KMxKN_backward_dW_B<<< \ + dim3(1, input_nPlanes / K, nGroups), dim3(K, K / V), 0, stream>>>( \ + inFeatures, dInFeatures, dOutFeatures, w, dw, rules + 2 * o, \ + nHot - o, input_nPlanes, input_stride, output_nPlanes, \ + output_stride); \ return; \ } \ } template void dDeconvolution_backward_dW(T *inFeatures, T *dInFeatures, T *dOutFeatures, - T *w, T *dw, Int *rules, Int nHot, - Int input_nPlanes, Int input_stride, - Int output_nPlanes, Int output_stride, Int nGroups) { + T *w, T *dw, Int *rules, Int nHot, + Int input_nPlanes, Int input_stride, + Int output_nPlanes, Int output_stride, + Int nGroups, cudaStream_t &stream) { FOO(T, 32, 8) FOO(T, 16, 4) FOO(T, 8, 2) @@ -394,8 +400,8 @@ void dDeconvolution_backward_dW(T *inFeatures, T *dInFeatures, T *dOutFeatures, template __global__ void dDeconvolution_KMxKN_forward2(T *inFeatures, T *outFeatures, T *w, Int *rules, - Int nHot, Int input_nPlanes, Int input_stride, - Int output_nPlanes, Int output_stride) { + Int nHot, Int input_nPlanes, Int input_stride, + Int output_nPlanes, Int output_stride) { // Input x Weight -> Output // blockDim=(K,K/V,1), gridDim=(nBlocks,N,nGroups) Volkov-blocks // K is a multiple of V, @@ -429,40 +435,40 @@ dDeconvolution_KMxKN_forward2(T *inFeatures, T *outFeatures, T *w, Int *rules, #pragma unroll for (int v = 0; v < V; v++) if (ty[v] < KI and tx < KO) - W[ty[v]][tx] = w[ty[v] * output_nPlanes + tx]; + W[ty[v]][tx] = w[ty[v] * output_nPlanes + tx]; for (Int s = blockIdx.x * K; s < nHot; s += K * gridDim.x) { // Read rules for K input/output pairs #pragma unroll for (int v = 0; v < V; v++) { - if (ty[v] < 2) { - int q = ty[v] * K + tx; - if (s + q / 2 < nHot) - R[q] = rules[2 * s + q]; - } + if (ty[v] < 2) { + int q = ty[v] * K + tx; + if (s + q / 2 < nHot) + R[q] = rules[2 * s + q]; + } } __syncthreads(); // Read input, reset O[] #pragma unroll for (int v = 0; v < V; v++) { - if (tx < KI and s + ty[v] < nHot) - I[ty[v]][tx] = inFeatures[R[2 * ty[v] + 1] * input_stride + tx]; - O[v] = 0; + if (tx < KI and s + ty[v] < nHot) + I[ty[v]][tx] = inFeatures[R[2 * ty[v] + 1] * input_stride + tx]; + O[v] = 0; } __syncthreads(); #pragma unroll for (int k = 0; k < KI; k++) #pragma unroll - for (int v = 0; v < V; v++) - O[v] += I[ty[v]][k] * W[k][tx]; + for (int v = 0; v < V; v++) + O[v] += I[ty[v]][k] * W[k][tx]; __syncthreads(); #pragma unroll for (int v = 0; v < V; v++) - if (tx < KO and s + ty[v] < nHot) - outFeatures[R[2 * ty[v]] * output_stride + tx] += O[v]; + if (tx < KO and s + ty[v] < nHot) + outFeatures[R[2 * ty[v]] * output_stride + tx] += O[v]; __syncthreads(); } w += K * output_nPlanes; @@ -476,9 +482,9 @@ dDeconvolution_KMxKN_forward2(T *inFeatures, T *outFeatures, T *w, Int *rules, template __global__ void dDeconvolution_KMxKN_backward_dW2(T *inFeatures, T *dInFeatures, - T *dOutFeatures, T *w, T *dw, Int *rules, - Int nHot, Int input_nPlanes, Int input_stride, - Int output_nPlanes, Int output_stride) { + T *dOutFeatures, T *w, T *dw, Int *rules, + Int nHot, Int input_nPlanes, Int input_stride, + Int output_nPlanes, Int output_stride) { // M = gridDim.y == input_nPlanes / K Int N = (output_nPlanes + K - 1) / K; Int m = blockIdx.y; @@ -486,8 +492,8 @@ dDeconvolution_KMxKN_backward_dW2(T *inFeatures, T *dInFeatures, inFeatures += m * K + g * input_nPlanes; dInFeatures += m * K + g * input_nPlanes; dOutFeatures += g * output_nPlanes; - w += m * K * output_nPlanes+ g * input_nPlanes * output_nPlanes; - dw += m * K * output_nPlanes+ g * input_nPlanes * output_nPlanes; + w += m * K * output_nPlanes + g * input_nPlanes * output_nPlanes; + dw += m * K * output_nPlanes + g * input_nPlanes * output_nPlanes; Int KI = min(K, input_nPlanes - K * m); TACC dI[V]; @@ -509,7 +515,7 @@ dDeconvolution_KMxKN_backward_dW2(T *inFeatures, T *dInFeatures, #pragma unroll for (int v = 0; v < V; v++) { if (ty[v] < KI and tx < KO) - W[ty[v]][tx] = w[ty[v] * output_nPlanes + tx]; + W[ty[v]][tx] = w[ty[v] * output_nPlanes + tx]; dW[v] = 0; } @@ -517,48 +523,48 @@ dDeconvolution_KMxKN_backward_dW2(T *inFeatures, T *dInFeatures, // Read rules for K input/output pairs, reset dI[] #pragma unroll for (int v = 0; v < V; v++) { - if (ty[v] < 2) { - int q = ty[v] * K + tx; - if (s + q / 2 < nHot) - R[q] = rules[2 * s + q]; - } - dI[v] = 0; + if (ty[v] < 2) { + int q = ty[v] * K + tx; + if (s + q / 2 < nHot) + R[q] = rules[2 * s + q]; + } + dI[v] = 0; } __syncthreads(); // Read input and dOutput #pragma unroll for (int v = 0; v < V; v++) { - if (tx < KI and s + ty[v] < nHot) - I[ty[v]][tx] = inFeatures[R[2 * ty[v] + 1] * input_stride + tx]; - else - I[ty[v]][tx] = 0; - if (tx < KO and s + ty[v] < nHot) - dO[ty[v]][tx] = dOutFeatures[R[2 * ty[v]] * output_stride + tx]; - else - dO[ty[v]][tx] = 0; + if (tx < KI and s + ty[v] < nHot) + I[ty[v]][tx] = inFeatures[R[2 * ty[v] + 1] * input_stride + tx]; + else + I[ty[v]][tx] = 0; + if (tx < KO and s + ty[v] < nHot) + dO[ty[v]][tx] = dOutFeatures[R[2 * ty[v]] * output_stride + tx]; + else + dO[ty[v]][tx] = 0; } __syncthreads(); #pragma unroll for (int k = 0; k < KO; k++) #pragma unroll - for (int v = 0; v < V; v++) - dI[v] += dO[ty[v]][k] * W[tx][k]; + for (int v = 0; v < V; v++) + dI[v] += dO[ty[v]][k] * W[tx][k]; #pragma unroll for (int k = 0; k < K; k++) #pragma unroll - for (int v = 0; v < V; v++) - dW[v] += I[k][ty[v]] * dO[k][tx]; + for (int v = 0; v < V; v++) + dW[v] += I[k][ty[v]] * dO[k][tx]; __syncthreads(); #pragma unroll for (int v = 0; v < V; v++) - if (tx < KI and s + ty[v] < nHot) - dInFeatures[R[2 * ty[v] + 1] * input_stride + tx] += dI[v]; + if (tx < KI and s + ty[v] < nHot) + dInFeatures[R[2 * ty[v] + 1] * input_stride + tx] += dI[v]; __syncthreads(); } #pragma unroll for (int v = 0; v < V; v++) if (ty[v] < KI and tx < KO) - atomicAdd(&dw[ty[v] * output_nPlanes + tx], dW[v]); + atomicAdd(&dw[ty[v] * output_nPlanes + tx], dW[v]); w += K; dw += K; dOutFeatures += K; @@ -567,52 +573,79 @@ dDeconvolution_KMxKN_backward_dW2(T *inFeatures, T *dInFeatures, template double dDeconvolution_forward2(T *inFeatures, T *outFeatures, T *w, - RuleBook _rules, Int input_nPlanes, - Int input_stride, Int output_nPlanes, - Int output_stride, Int nGroups) { + RuleBook _rules, Int input_nPlanes, + Int input_stride, Int output_nPlanes, + Int output_stride, Int nGroups) { Int c = input_nPlanes * output_nPlanes * nGroups; double flops = 0; + + auto command = [&](Int nHotB) -> void { + w += c; + flops += nHotB * c; + }; + if (input_nPlanes % 8 != 0 or output_nPlanes % 8 != 0) { const int K = 16; const int V = 4; - RULEBOOKITERATOR( - (dDeconvolution_KMxKN_forward2< - T, K, - V><<>>( - inFeatures, outFeatures, w, rbB, nHotB, input_nPlanes, input_stride, - output_nPlanes, output_stride)); - , w += c; flops += nHotB * c;) + + auto application = [&](Int *rbB, Int nHotB, cudaStream_t &stream) -> void { + dDeconvolution_KMxKN_forward2< + T, K, V><<>>( + inFeatures, outFeatures, w, rbB, nHotB, input_nPlanes, input_stride, + output_nPlanes, output_stride); + }; + + iterateRuleBook(_rules, application, command); + } else { - RULEBOOKITERATOR(dDeconvolution_forward(inFeatures, outFeatures, w, rbB, - nHotB, input_nPlanes, input_stride, - output_nPlanes, output_stride, nGroups); - , w += c; flops += nHotB * c;) + + auto application = [&](Int *rbB, Int nHotB, cudaStream_t &stream) -> void { + dDeconvolution_forward(inFeatures, outFeatures, w, rbB, nHotB, + input_nPlanes, input_stride, output_nPlanes, + output_stride, nGroups, stream); + }; + + iterateRuleBook(_rules, application, command); } return flops; } template void dDeconvolution_backward_dW2(T *inFeatures, T *dInFeatures, T *dOutFeatures, - T *w, T *dw, RuleBook _rules, - Int input_nPlanes, Int input_stride, - Int output_nPlanes, Int output_stride, Int nGroups) { + T *w, T *dw, RuleBook _rules, + Int input_nPlanes, Int input_stride, + Int output_nPlanes, Int output_stride, + Int nGroups) { Int c = input_nPlanes * output_nPlanes * nGroups; + auto command = [&](Int nHotB) -> void { + w += c; + dw += c; + }; + if (input_nPlanes % 8 != 0 or output_nPlanes % 8 != 0) { const int K = 16; const int V = 4; - RULEBOOKITERATOR( - (dDeconvolution_KMxKN_backward_dW2< - T, K, - V><<>>( - inFeatures, dInFeatures, dOutFeatures, w, dw, rbB, nHotB, - input_nPlanes, input_stride, output_nPlanes, output_stride)); - , w += c; dw += c;) + + auto application = [&](Int *rbB, Int nHotB, cudaStream_t &stream) -> void { + dDeconvolution_KMxKN_backward_dW2< + T, K, V><<>>( + inFeatures, dInFeatures, dOutFeatures, w, dw, rbB, nHotB, + input_nPlanes, input_stride, output_nPlanes, output_stride); + }; + + iterateRuleBook(_rules, application, command); + } else { - RULEBOOKITERATOR(dDeconvolution_backward_dW(inFeatures, dInFeatures, - dOutFeatures, w, dw, rbB, nHotB, - input_nPlanes, input_stride, - output_nPlanes, output_stride, nGroups); - , w += c; dw += c;) + auto application = [&](Int *rbB, Int nHotB, cudaStream_t &stream) -> void { + dDeconvolution_backward_dW(inFeatures, dInFeatures, dOutFeatures, w, dw, + rbB, nHotB, input_nPlanes, input_stride, + output_nPlanes, output_stride, nGroups, + stream); + }; + + iterateRuleBook(_rules, application, command); } } #undef TACC \ No newline at end of file diff --git a/sparseconvnet/SCN/CUDA/MaxPooling.cu b/sparseconvnet/SCN/CUDA/MaxPooling.cu index 4585efc..9051f7e 100644 --- a/sparseconvnet/SCN/CUDA/MaxPooling.cu +++ b/sparseconvnet/SCN/CUDA/MaxPooling.cu @@ -36,10 +36,14 @@ template void cuda_MaxPooling_ForwardPass(T *input_features, T *output_features, Int nPlanes, Int input_stride, Int output_stride, RuleBook _rules) { - RULEBOOKITERATOR((MaxPooling_fp<<<32, dim3(32, 32)>>>( - input_features, output_features, nPlanes, input_stride, output_stride, - rbB, nHotB)); - , ) + + auto application = [&](Int *rbB, Int nHotB, cudaStream_t &stream) -> void { + MaxPooling_fp<<<32, dim3(32, 32), 0, stream>>>( + input_features, output_features, nPlanes, input_stride, output_stride, + rbB, nHotB); + }; + + iterateRuleBook(_rules, application); } template __global__ void MaxPooling_bp(T *input_features, T *d_input_features, @@ -70,8 +74,13 @@ void cuda_MaxPooling_BackwardPass(T *input_features, T *d_input_features, T *output_features, T *d_output_features, Int nPlanes, Int input_stride, Int output_stride, RuleBook _rules) { - RULEBOOKITERATOR((MaxPooling_bp<<<32, dim3(32, 32)>>>( + + auto application = [&](Int *rbB, Int nHotB, cudaStream_t &stream) -> void { + MaxPooling_bp<<<32, dim3(32, 32), 0, stream>>>( input_features, d_input_features, output_features, d_output_features, - nPlanes, input_stride, output_stride, rbB, nHotB)); - , ) + nPlanes, input_stride, output_stride, rbB, nHotB); + }; + + iterateRuleBook(_rules, application); + } diff --git a/sparseconvnet/SCN/CUDA/RuleBookIterator.h b/sparseconvnet/SCN/CUDA/RuleBookIterator.h index 24992cb..f5d158a 100644 --- a/sparseconvnet/SCN/CUDA/RuleBookIterator.h +++ b/sparseconvnet/SCN/CUDA/RuleBookIterator.h @@ -7,28 +7,81 @@ #ifndef CUDA_RULEBOOKITERATOR_H #define CUDA_RULEBOOKITERATOR_H -// Macro to parallelize loading rulebook elements to CUDA memory and operating -// on the elements of the rulebook. -// X is the function to apply. -// Y is a command to run - -#define RULEBOOKITERATOR(X, Y) \ - { \ - Int rbMaxSize = 0; \ - for (auto &r : _rules) \ - rbMaxSize = std::max(rbMaxSize, (Int)r.size()); \ - at::Tensor rulesBuffer = at::empty({rbMaxSize}, at::CUDA(at_kINT)); \ - Int *rbB = rulesBuffer.data(); \ - for (int k = 0; k < _rules.size(); ++k) { \ - auto &r = _rules[k]; \ - Int nHotB = r.size() / 2; \ - if (nHotB) { \ - cudaMemcpy(rbB, &r[0], sizeof(Int) * 2 * nHotB, \ - cudaMemcpyHostToDevice); \ - X \ - } \ - Y \ - } \ +using RuleBook = std::vector>; + +void checkCuda(const cudaError_t &result) { + if (result != cudaSuccess) { + throw std::string("CUDA Runtime Error: ") + cudaGetErrorString(result); + } +} + +// Templated function to parallelize loading rulebook +// elements to CUDA memory and operating on the elements of the rulebook. +// Application is the function to apply. +// Command is a command to run. + +template +void iterateRuleBook(const RuleBook &_rules, Application app, Command comm) { + Int rbMaxSize = 0; + const Int streamCount = 4; + for (auto &r : _rules) + rbMaxSize = std::max(rbMaxSize, (Int)r.size()); + at::Tensor rulesBuffer = at::empty({rbMaxSize}, at::CUDA(at_kINT)); + Int *rbB = rulesBuffer.data(); + std::vector streams(streamCount); + std::vector pinnedBooks; + + for (int i = 0; i < streamCount; ++i) { + checkCuda(cudaStreamCreate(&streams[i])); + } + + int nextStream = 0; + cudaEvent_t prevEvent; + cudaEventCreate(&prevEvent); + + for (int k = 0; k < _rules.size(); ++k) { + auto &r = _rules[k]; + Int nHotB = r.size() / 2; + + if (nHotB) { + size_t ruleSize = sizeof(Int) * 2 * nHotB; + + Int *pinnedRules; + checkCuda(cudaMallocHost((Int **)&pinnedRules, ruleSize)); + memcpy(pinnedRules, &r[0], ruleSize); + + auto &stream = streams[nextStream]; + cudaMemcpyAsync(rbB, pinnedRules, ruleSize, cudaMemcpyHostToDevice, + stream); + + cudaStreamWaitEvent(stream, prevEvent, 0); + app(rbB, nHotB, stream); + + cudaEvent_t event; + cudaEventCreate(&event); + cudaEventRecord(event, stream); + + pinnedBooks.push_back(pinnedRules); + prevEvent = event; + nextStream = (nextStream + 1) % streamCount; + } + + comm(nHotB); + } + + for (auto &stream : streams) { + checkCuda(cudaStreamSynchronize(stream)); + checkCuda(cudaStreamDestroy(stream)); + } + + for (auto &rules : pinnedBooks) { + checkCuda(cudaFreeHost(rules)); } +} + +template +void iterateRuleBook(const RuleBook &_rules, Application app) { + iterateRuleBook(_rules, app, [](Int nHotB) -> void {}); +} #endif /* CUDA_RULEBOOKITERATOR_H */ diff --git a/sparseconvnet/SCN/CUDA/SparseToDense.cu b/sparseconvnet/SCN/CUDA/SparseToDense.cu index c526424..0c35902 100644 --- a/sparseconvnet/SCN/CUDA/SparseToDense.cu +++ b/sparseconvnet/SCN/CUDA/SparseToDense.cu @@ -31,9 +31,17 @@ template void cuda_SparseToDense_ForwardPass(T *input_features, T *output_features, Int nPlanes, Int spatialVolume, RuleBook _rules) { - RULEBOOKITERATOR((SparseToDense_fp<<<32, dim3(32, 32)>>>( - input_features, output_features, nPlanes, spatialVolume, rbB, nHotB)); - , output_features += nPlanes * spatialVolume;) + + auto application = [&](Int *rbB, Int nHotB, cudaStream_t &stream) -> void { + SparseToDense_fp<<<32, dim3(32, 32), 0, stream>>>( + input_features, output_features, nPlanes, spatialVolume, rbB, nHotB); + }; + + auto command = [&](Int nHotB) -> void { + output_features += nPlanes * spatialVolume; + }; + + iterateRuleBook(_rules, application, command); } // NTX must be >=2 so r is filled properly @@ -63,7 +71,15 @@ template void cuda_SparseToDense_BackwardPass(T *d_input_features, T *d_output_features, Int nPlanes, Int spatialVolume, RuleBook _rules) { - RULEBOOKITERATOR((SparseToDense_bp<<<32, dim3(32, 32)>>>( - d_input_features, d_output_features, nPlanes, spatialVolume, rbB, nHotB)); - , d_output_features += nPlanes * spatialVolume;) + + auto application = [&](Int *rbB, Int nHotB, cudaStream_t &stream) -> void { + SparseToDense_bp<<<32, dim3(32, 32), 0, stream>>>( + d_input_features, d_output_features, nPlanes, spatialVolume, rbB, nHotB); + }; + + auto command = [&](Int nHotB) -> void { + d_output_features += nPlanes * spatialVolume; + }; + + iterateRuleBook(_rules, application, command); } diff --git a/sparseconvnet/SCN/CUDA/UnPooling.cu b/sparseconvnet/SCN/CUDA/UnPooling.cu index 6bf3474..5aa04c2 100644 --- a/sparseconvnet/SCN/CUDA/UnPooling.cu +++ b/sparseconvnet/SCN/CUDA/UnPooling.cu @@ -33,11 +33,16 @@ template void cuda_UnPooling_ForwardPass(T *input_features, T *output_features, Int nPlanes, Int input_stride, Int output_stride, RuleBook _rules) { - RULEBOOKITERATOR((UnPooling_fp<<<32, dim3(32, 32)>>>( - input_features, output_features, nPlanes, input_stride, output_stride, - rbB, nHotB)); - , ) + + auto application = [&](Int *rbB, Int nHotB, cudaStream_t &stream) -> void { + UnPooling_fp<<<32, dim3(32, 32), 0, stream>>>( + input_features, output_features, nPlanes, input_stride, output_stride, + rbB, nHotB); + }; + + iterateRuleBook(_rules, application); } + template __global__ void UnPooling_bp(T *d_input_features, T *d_output_features, Int nPlanes, Int input_stride, Int output_stride, @@ -64,8 +69,12 @@ template void cuda_UnPooling_BackwardPass(T *d_input_features, T *d_output_features, Int nPlanes, Int input_stride, Int output_stride, RuleBook _rules) { - RULEBOOKITERATOR((UnPooling_bp<<<32, dim3(32, 32)>>>( - d_input_features, d_output_features, nPlanes, input_stride, output_stride, - rbB, nHotB)); - , ) + + auto application = [&](Int *rbB, Int nHotB, cudaStream_t &stream) -> void { + UnPooling_bp<<<32, dim3(32, 32), 0, stream>>>( + d_input_features, d_output_features, nPlanes, input_stride, + output_stride, rbB, nHotB); + }; + + iterateRuleBook(_rules, application); } diff --git a/sparseconvnet/SCN/sparseconvnet_cuda.cpp b/sparseconvnet/SCN/sparseconvnet_cuda.cpp index ac96449..f3b34d2 100644 --- a/sparseconvnet/SCN/sparseconvnet_cuda.cpp +++ b/sparseconvnet/SCN/sparseconvnet_cuda.cpp @@ -284,6 +284,8 @@ double SubmanifoldConvolution_updateOutput(at::Tensor inputSize, at::Tensor input_features, at::Tensor output_features, at::Tensor weight, at::Tensor bias) { + + if (input_features.type().is_cuda()) return cuda_SubmanifoldConvolution_updateOutput( inputSize, filterSize, m, input_features, output_features, weight,