From 12f794c22161a4b9adbcedef42ae065a1f5e8300 Mon Sep 17 00:00:00 2001 From: Marek Kolodziej Date: Tue, 21 Aug 2018 16:48:15 -0700 Subject: [PATCH] Updated channel norm, correlation and resample2d (THC -> ATen, cffi -> setuptools) --- install.sh | 6 +- models.py | 4 +- networks/FlowNetC.py | 2 +- networks/channelnorm_package/build.py | 31 -- .../{functions => }/channelnorm.py | 19 +- .../channelnorm_package/channelnorm_cuda.cc | 31 ++ .../channelnorm_package/channelnorm_kernel.cu | 174 ++++++++++ .../channelnorm_kernel.cuh | 16 + .../channelnorm_package/functions/__init__.py | 0 networks/channelnorm_package/make.sh | 12 - .../channelnorm_package/modules/__init__.py | 0 .../modules/channelnorm.py | 13 - networks/channelnorm_package/setup.py | 28 ++ .../src/ChannelNorm_cuda.c | 17 - .../src/ChannelNorm_cuda.h | 3 - .../src/ChannelNorm_kernel.cu | 127 ------- .../src/ChannelNorm_kernel.h | 12 - networks/correlation_package/build.py | 31 -- networks/correlation_package/correlation.py | 62 ++++ .../correlation_package/correlation_cuda.cc | 169 ++++++++++ .../{src => }/correlation_cuda_kernel.cu | 313 ++++++++++-------- .../correlation_cuda_kernel.cuh | 91 +++++ .../correlation_package/functions/__init__.py | 0 .../functions/correlation.py | 55 --- networks/correlation_package/make.sh | 14 - .../correlation_package/modules/__init__.py | 0 .../modules/correlation.py | 27 -- networks/correlation_package/setup.py | 29 ++ .../correlation_package/src/correlation.c | 33 -- .../correlation_package/src/correlation.h | 25 -- .../src/correlation_cuda.c | 180 ---------- .../src/correlation_cuda.h | 18 - .../src/correlation_cuda_kernel.h | 92 ----- networks/resample2d_package/build.py | 31 -- .../resample2d_package/functions/__init__.py | 0 networks/resample2d_package/make.sh | 12 - .../resample2d_package/modules/__init__.py | 0 .../resample2d_package/modules/resample2d.py | 14 - .../{functions => }/resample2d.py | 22 +- .../resample2d_package/resample2d_cuda.cc | 32 ++ .../resample2d_package/resample2d_kernel.cu | 306 +++++++++++++++++ .../resample2d_package/resample2d_kernel.cuh | 17 + networks/resample2d_package/setup.py | 29 ++ .../resample2d_package/src/Resample2d_cuda.c | 17 - .../resample2d_package/src/Resample2d_cuda.h | 3 - .../src/Resample2d_kernel.cu | 242 -------------- .../src/Resample2d_kernel.h | 11 - 47 files changed, 1204 insertions(+), 1166 deletions(-) delete mode 100755 networks/channelnorm_package/build.py rename networks/channelnorm_package/{functions => }/channelnorm.py (60%) create mode 100755 networks/channelnorm_package/channelnorm_cuda.cc create mode 100755 networks/channelnorm_package/channelnorm_kernel.cu create mode 100755 networks/channelnorm_package/channelnorm_kernel.cuh delete mode 100755 networks/channelnorm_package/functions/__init__.py delete mode 100755 networks/channelnorm_package/make.sh delete mode 100755 networks/channelnorm_package/modules/__init__.py delete mode 100755 networks/channelnorm_package/modules/channelnorm.py create mode 100755 networks/channelnorm_package/setup.py delete mode 100755 networks/channelnorm_package/src/ChannelNorm_cuda.c delete mode 100755 networks/channelnorm_package/src/ChannelNorm_cuda.h delete mode 100755 networks/channelnorm_package/src/ChannelNorm_kernel.cu delete mode 100755 networks/channelnorm_package/src/ChannelNorm_kernel.h delete mode 100755 networks/correlation_package/build.py create mode 100644 networks/correlation_package/correlation.py create mode 100755 networks/correlation_package/correlation_cuda.cc rename 
networks/correlation_package/{src => }/correlation_cuda_kernel.cu (54%) create mode 100755 networks/correlation_package/correlation_cuda_kernel.cuh delete mode 100755 networks/correlation_package/functions/__init__.py delete mode 100755 networks/correlation_package/functions/correlation.py delete mode 100755 networks/correlation_package/make.sh delete mode 100755 networks/correlation_package/modules/__init__.py delete mode 100755 networks/correlation_package/modules/correlation.py create mode 100755 networks/correlation_package/setup.py delete mode 100755 networks/correlation_package/src/correlation.c delete mode 100755 networks/correlation_package/src/correlation.h delete mode 100755 networks/correlation_package/src/correlation_cuda.c delete mode 100755 networks/correlation_package/src/correlation_cuda.h delete mode 100755 networks/correlation_package/src/correlation_cuda_kernel.h delete mode 100755 networks/resample2d_package/build.py delete mode 100755 networks/resample2d_package/functions/__init__.py delete mode 100755 networks/resample2d_package/make.sh delete mode 100755 networks/resample2d_package/modules/__init__.py delete mode 100755 networks/resample2d_package/modules/resample2d.py rename networks/resample2d_package/{functions => }/resample2d.py (55%) create mode 100755 networks/resample2d_package/resample2d_cuda.cc create mode 100755 networks/resample2d_package/resample2d_kernel.cu create mode 100755 networks/resample2d_package/resample2d_kernel.cuh create mode 100755 networks/resample2d_package/setup.py delete mode 100755 networks/resample2d_package/src/Resample2d_cuda.c delete mode 100755 networks/resample2d_package/src/Resample2d_cuda.h delete mode 100755 networks/resample2d_package/src/Resample2d_kernel.cu delete mode 100755 networks/resample2d_package/src/Resample2d_kernel.h diff --git a/install.sh b/install.sh index 38ab6d6..7128a58 100755 --- a/install.sh +++ b/install.sh @@ -1,8 +1,8 @@ #!/bin/bash cd ./networks/correlation_package -./make.sh +python setup.py install cd ../resample2d_package -./make.sh +python setup.py install cd ../channelnorm_package -./make.sh +python setup.py install cd .. 
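Note: a quick, hypothetical smoke test for the rebuilt extensions (not part of this patch). With the old cffi `_ext` packages gone, the Python wrappers now do a plain `import channelnorm_cuda` / `import correlation_cuda`, so the `python setup.py install` calls above must have succeeded in the active environment. The sketch below assumes a CUDA device, a PyTorch version contemporary with this patch (0.4.x), and that it is run from the repository root; tensor sizes and the expected shapes in the comments are illustrative only.

import torch

from networks.channelnorm_package.channelnorm import ChannelNorm
from networks.correlation_package.correlation import Correlation

# A small 2-channel, contiguous NCHW CUDA tensor (e.g. a flow-like field).
x = torch.randn(1, 2, 8, 8).cuda()

# ChannelNorm collapses the channel dimension to 1 (L2 norm over channels).
print(ChannelNorm()(x).size())      # expected: torch.Size([1, 1, 8, 8])

# Correlation emits ((max_displacement/stride2)*2 + 1)**2 output channels.
corr = Correlation(pad_size=4, kernel_size=1, max_displacement=4,
                   stride1=1, stride2=1, corr_multiply=1)
print(corr(x, x).size())            # expected: torch.Size([1, 81, 8, 8])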
diff --git a/models.py b/models.py index 3b11327..f6a1653 100755 --- a/models.py +++ b/models.py @@ -5,8 +5,8 @@ import math import numpy as np -from networks.resample2d_package.modules.resample2d import Resample2d -from networks.channelnorm_package.modules.channelnorm import ChannelNorm +from networks.resample2d_package.resample2d import Resample2d +from networks.channelnorm_package.channelnorm import ChannelNorm from networks import FlowNetC from networks import FlowNetS diff --git a/networks/FlowNetC.py b/networks/FlowNetC.py index f107552..64f89fe 100755 --- a/networks/FlowNetC.py +++ b/networks/FlowNetC.py @@ -5,7 +5,7 @@ import math import numpy as np -from .correlation_package.modules.correlation import Correlation +from .correlation_package.correlation import Correlation from .submodules import * 'Parameter count , 39,175,298 ' diff --git a/networks/channelnorm_package/build.py b/networks/channelnorm_package/build.py deleted file mode 100755 index 2066819..0000000 --- a/networks/channelnorm_package/build.py +++ /dev/null @@ -1,31 +0,0 @@ -import os -import torch -import torch.utils.ffi - -this_folder = os.path.dirname(os.path.abspath(__file__)) + '/' - -Headers = [] -Sources = [] -Defines = [] -Objects = [] - -if torch.cuda.is_available() == True: - Headers += ['src/ChannelNorm_cuda.h'] - Sources += ['src/ChannelNorm_cuda.c'] - Defines += [('WITH_CUDA', None)] - Objects += ['src/ChannelNorm_kernel.o'] - -ffi = torch.utils.ffi.create_extension( - name='_ext.channelnorm', - headers=Headers, - sources=Sources, - verbose=False, - with_cuda=True, - package=False, - relative_to=this_folder, - define_macros=Defines, - extra_objects=[os.path.join(this_folder, Object) for Object in Objects] -) - -if __name__ == '__main__': - ffi.build() \ No newline at end of file diff --git a/networks/channelnorm_package/functions/channelnorm.py b/networks/channelnorm_package/channelnorm.py similarity index 60% rename from networks/channelnorm_package/functions/channelnorm.py rename to networks/channelnorm_package/channelnorm.py index c925679..3bb6b15 100755 --- a/networks/channelnorm_package/functions/channelnorm.py +++ b/networks/channelnorm_package/channelnorm.py @@ -1,6 +1,6 @@ from torch.autograd import Function, Variable -from .._ext import channelnorm - +from torch.nn.modules.module import Module +import channelnorm_cuda class ChannelNormFunction(Function): @@ -10,7 +10,7 @@ def forward(ctx, input1, norm_deg=2): b, _, h, w = input1.size() output = input1.new(b, 1, h, w).zero_() - channelnorm.ChannelNorm_cuda_forward(input1, output, norm_deg) + channelnorm_cuda.forward(input1, output, norm_deg) ctx.save_for_backward(input1, output) ctx.norm_deg = norm_deg @@ -22,7 +22,18 @@ def backward(ctx, grad_output): grad_input1 = Variable(input1.new(input1.size()).zero_()) - channelnorm.ChannelNorm_cuda_backward(input1, output, grad_output.data, + channelnorm_cuda.backward(input1, output, grad_output.data, grad_input1.data, ctx.norm_deg) return grad_input1, None + + +class ChannelNorm(Module): + + def __init__(self, norm_deg=2): + super(ChannelNorm, self).__init__() + self.norm_deg = norm_deg + + def forward(self, input1): + return ChannelNormFunction.apply(input1, self.norm_deg) + diff --git a/networks/channelnorm_package/channelnorm_cuda.cc b/networks/channelnorm_package/channelnorm_cuda.cc new file mode 100755 index 0000000..69d82eb --- /dev/null +++ b/networks/channelnorm_package/channelnorm_cuda.cc @@ -0,0 +1,31 @@ +#include <ATen/ATen.h> +#include <torch/torch.h> + +#include "channelnorm_kernel.cuh" + +int channelnorm_cuda_forward( + 
at::Tensor& input1, + at::Tensor& output, + int norm_deg) { + + channelnorm_kernel_forward(input1, output, norm_deg); + return 1; +} + + +int channelnorm_cuda_backward( + at::Tensor& input1, + at::Tensor& output, + at::Tensor& gradOutput, + at::Tensor& gradInput1, + int norm_deg) { + + channelnorm_kernel_backward(input1, output, gradOutput, gradInput1, norm_deg); + return 1; +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("forward", &channelnorm_cuda_forward, "Channel norm forward (CUDA)"); + m.def("backward", &channelnorm_cuda_backward, "Channel norm backward (CUDA)"); +} + diff --git a/networks/channelnorm_package/channelnorm_kernel.cu b/networks/channelnorm_package/channelnorm_kernel.cu new file mode 100755 index 0000000..7661fb6 --- /dev/null +++ b/networks/channelnorm_package/channelnorm_kernel.cu @@ -0,0 +1,174 @@ +#include <ATen/ATen.h> +#include <ATen/Context.h> + +#include "channelnorm_kernel.cuh" + +#define CUDA_NUM_THREADS 512 + +#define DIM0(TENSOR) ((TENSOR).x) +#define DIM1(TENSOR) ((TENSOR).y) +#define DIM2(TENSOR) ((TENSOR).z) +#define DIM3(TENSOR) ((TENSOR).w) + +#define DIM3_INDEX(TENSOR, xx, yy, zz, ww) ((TENSOR)[((xx) * (TENSOR##_stride.x)) + ((yy) * (TENSOR##_stride.y)) + ((zz) * (TENSOR##_stride.z)) + ((ww) * (TENSOR##_stride.w))]) + +using at::Half; + +template <typename scalar_t> +__global__ void kernel_channelnorm_update_output( + const int n, + const scalar_t* __restrict__ input1, + const long4 input1_size, + const long4 input1_stride, + scalar_t* __restrict__ output, + const long4 output_size, + const long4 output_stride, + int norm_deg) { + + int index = blockIdx.x * blockDim.x + threadIdx.x; + + if (index >= n) { + return; + } + + int dim_b = DIM0(output_size); + int dim_c = DIM1(output_size); + int dim_h = DIM2(output_size); + int dim_w = DIM3(output_size); + int dim_chw = dim_c * dim_h * dim_w; + + int b = ( index / dim_chw ) % dim_b; + int y = ( index / dim_w ) % dim_h; + int x = ( index ) % dim_w; + + int i1dim_c = DIM1(input1_size); + int i1dim_h = DIM2(input1_size); + int i1dim_w = DIM3(input1_size); + int i1dim_chw = i1dim_c * i1dim_h * i1dim_w; + int i1dim_hw = i1dim_h * i1dim_w; + + float result = 0.0; + + for (int c = 0; c < i1dim_c; ++c) { + int i1Index = b * i1dim_chw + c * i1dim_hw + y * i1dim_w + x; + scalar_t val = input1[i1Index]; + result += static_cast<float>(val * val); + } + result = sqrt(result); + output[index] = static_cast<scalar_t>(result); +} + + +template <typename scalar_t> +__global__ void kernel_channelnorm_backward_input1( + const int n, + const scalar_t* __restrict__ input1, const long4 input1_size, const long4 input1_stride, + const scalar_t* __restrict__ output, const long4 output_size, const long4 output_stride, + const scalar_t* __restrict__ gradOutput, const long4 gradOutput_size, const long4 gradOutput_stride, + scalar_t* __restrict__ gradInput, const long4 gradInput_size, const long4 gradInput_stride, + int norm_deg) { + + int index = blockIdx.x * blockDim.x + threadIdx.x; + + if (index >= n) { + return; + } + + float val = 0.0; + + int dim_b = DIM0(gradInput_size); + int dim_c = DIM1(gradInput_size); + int dim_h = DIM2(gradInput_size); + int dim_w = DIM3(gradInput_size); + int dim_chw = dim_c * dim_h * dim_w; + int dim_hw = dim_h * dim_w; + + int b = ( index / dim_chw ) % dim_b; + int y = ( index / dim_w ) % dim_h; + int x = ( index ) % dim_w; + + + int outIndex = b * dim_hw + y * dim_w + x; + val = static_cast<float>(gradOutput[outIndex]) * static_cast<float>(input1[index]) / (static_cast<float>(output[outIndex])+1e-9); + gradInput[index] = static_cast<scalar_t>(val); + +} + +void channelnorm_kernel_forward( + at::Tensor& 
input1, + at::Tensor& output, + int norm_deg) { + + const long4 input1_size = make_long4(input1.size(0), input1.size(1), input1.size(2), input1.size(3)); + const long4 input1_stride = make_long4(input1.stride(0), input1.stride(1), input1.stride(2), input1.stride(3)); + + const long4 output_size = make_long4(output.size(0), output.size(1), output.size(2), output.size(3)); + const long4 output_stride = make_long4(output.stride(0), output.stride(1), output.stride(2), output.stride(3)); + + int n = output.numel(); + + AT_DISPATCH_FLOATING_TYPES_AND_HALF(input1.type(), "channelnorm_forward", ([&] { + + kernel_channelnorm_update_output<scalar_t><<< (n + CUDA_NUM_THREADS - 1)/CUDA_NUM_THREADS, CUDA_NUM_THREADS, 0, at::globalContext().getCurrentCUDAStream() >>>( + n, + input1.data<scalar_t>(), + input1_size, + input1_stride, + output.data<scalar_t>(), + output_size, + output_stride, + norm_deg); + + })); + + // TODO: ATen-equivalent check + + // THCudaCheck(cudaGetLastError()); +} + +void channelnorm_kernel_backward( + at::Tensor& input1, + at::Tensor& output, + at::Tensor& gradOutput, + at::Tensor& gradInput1, + int norm_deg) { + + const long4 input1_size = make_long4(input1.size(0), input1.size(1), input1.size(2), input1.size(3)); + const long4 input1_stride = make_long4(input1.stride(0), input1.stride(1), input1.stride(2), input1.stride(3)); + + const long4 output_size = make_long4(output.size(0), output.size(1), output.size(2), output.size(3)); + const long4 output_stride = make_long4(output.stride(0), output.stride(1), output.stride(2), output.stride(3)); + + const long4 gradOutput_size = make_long4(gradOutput.size(0), gradOutput.size(1), gradOutput.size(2), gradOutput.size(3)); + const long4 gradOutput_stride = make_long4(gradOutput.stride(0), gradOutput.stride(1), gradOutput.stride(2), gradOutput.stride(3)); + + const long4 gradInput1_size = make_long4(gradInput1.size(0), gradInput1.size(1), gradInput1.size(2), gradInput1.size(3)); + const long4 gradInput1_stride = make_long4(gradInput1.stride(0), gradInput1.stride(1), gradInput1.stride(2), gradInput1.stride(3)); + + int n = gradInput1.numel(); + + AT_DISPATCH_FLOATING_TYPES_AND_HALF(input1.type(), "channelnorm_backward_input1", ([&] { + + kernel_channelnorm_backward_input1<scalar_t><<< (n + CUDA_NUM_THREADS - 1)/CUDA_NUM_THREADS, CUDA_NUM_THREADS, 0, at::globalContext().getCurrentCUDAStream() >>>( + n, + input1.data<scalar_t>(), + input1_size, + input1_stride, + output.data<scalar_t>(), + output_size, + output_stride, + gradOutput.data<scalar_t>(), + gradOutput_size, + gradOutput_stride, + gradInput1.data<scalar_t>(), + gradInput1_size, + gradInput1_stride, + norm_deg + ); + + })); + + // TODO: Add ATen-equivalent check + +// THCudaCheck(cudaGetLastError()); +} diff --git a/networks/channelnorm_package/channelnorm_kernel.cuh b/networks/channelnorm_package/channelnorm_kernel.cuh new file mode 100755 index 0000000..3e6223f --- /dev/null +++ b/networks/channelnorm_package/channelnorm_kernel.cuh @@ -0,0 +1,16 @@ +#pragma once + +#include <ATen/ATen.h> + +void channelnorm_kernel_forward( + at::Tensor& input1, + at::Tensor& output, + int norm_deg); + + +void channelnorm_kernel_backward( + at::Tensor& input1, + at::Tensor& output, + at::Tensor& gradOutput, + at::Tensor& gradInput1, + int norm_deg); diff --git a/networks/channelnorm_package/functions/__init__.py b/networks/channelnorm_package/functions/__init__.py deleted file mode 100755 index e69de29..0000000 diff --git a/networks/channelnorm_package/make.sh b/networks/channelnorm_package/make.sh deleted file mode 100755 index 2242f5e..0000000 --- a/networks/channelnorm_package/make.sh
+++ /dev/null @@ -1,12 +0,0 @@ -#!/usr/bin/env bash -TORCH=$(python3 -c "import os; import torch; print(os.path.dirname(torch.__file__))") - -cd src -echo "Compiling channelnorm kernels by nvcc..." -rm ChannelNorm_kernel.o -rm -r ../_ext - -nvcc -c -o ChannelNorm_kernel.o ChannelNorm_kernel.cu -x cu -Xcompiler -fPIC -arch=sm_52 -I ${TORCH}/lib/include/TH -I ${TORCH}/lib/include/THC - -cd ../ -python3 build.py diff --git a/networks/channelnorm_package/modules/__init__.py b/networks/channelnorm_package/modules/__init__.py deleted file mode 100755 index e69de29..0000000 diff --git a/networks/channelnorm_package/modules/channelnorm.py b/networks/channelnorm_package/modules/channelnorm.py deleted file mode 100755 index 39c60ce..0000000 --- a/networks/channelnorm_package/modules/channelnorm.py +++ /dev/null @@ -1,13 +0,0 @@ -from torch.nn.modules.module import Module - -from ..functions.channelnorm import ChannelNormFunction - - -class ChannelNorm(Module): - - def __init__(self, norm_deg=2): - super(ChannelNorm, self).__init__() - self.norm_deg = norm_deg - - def forward(self, input1): - return ChannelNormFunction.apply(input1, self.norm_deg) diff --git a/networks/channelnorm_package/setup.py b/networks/channelnorm_package/setup.py new file mode 100755 index 0000000..5b9e86a --- /dev/null +++ b/networks/channelnorm_package/setup.py @@ -0,0 +1,28 @@ +#!/usr/bin/env python3 +import os +import torch + +from setuptools import setup +from torch.utils.cpp_extension import BuildExtension, CUDAExtension + +cxx_args = ['-std=c++11'] + +nvcc_args = [ + '-gencode', 'arch=compute_52,code=sm_52', + '-gencode', 'arch=compute_60,code=sm_60', + '-gencode', 'arch=compute_61,code=sm_61', + '-gencode', 'arch=compute_70,code=sm_70', + '-gencode', 'arch=compute_70,code=compute_70' +] + +setup( + name='channelnorm_cuda', + ext_modules=[ + CUDAExtension('channelnorm_cuda', [ + 'channelnorm_cuda.cc', + 'channelnorm_kernel.cu' + ], extra_compile_args={'cxx': cxx_args, 'nvcc': nvcc_args}) + ], + cmdclass={ + 'build_ext': BuildExtension + }) diff --git a/networks/channelnorm_package/src/ChannelNorm_cuda.c b/networks/channelnorm_package/src/ChannelNorm_cuda.c deleted file mode 100755 index f1615ce..0000000 --- a/networks/channelnorm_package/src/ChannelNorm_cuda.c +++ /dev/null @@ -1,17 +0,0 @@ -#include -#include - -#include "ChannelNorm_kernel.h" - -extern THCState* state; - -int ChannelNorm_cuda_forward(THCudaTensor* input1, THCudaTensor* output, int norm_deg) { - ChannelNorm_kernel_forward(state, input1, output, norm_deg); - return 1; -} - - -int ChannelNorm_cuda_backward(THCudaTensor* input1, THCudaTensor* output, THCudaTensor* gradOutput, THCudaTensor* gradInput1, int norm_deg) { - ChannelNorm_kernel_backward(state, input1, output, gradOutput, gradInput1, norm_deg); - return 1; -} \ No newline at end of file diff --git a/networks/channelnorm_package/src/ChannelNorm_cuda.h b/networks/channelnorm_package/src/ChannelNorm_cuda.h deleted file mode 100755 index 46bdf0b..0000000 --- a/networks/channelnorm_package/src/ChannelNorm_cuda.h +++ /dev/null @@ -1,3 +0,0 @@ -int ChannelNorm_cuda_forward(THCudaTensor* input1, THCudaTensor* output, int norm_deg); - -int ChannelNorm_cuda_backward(THCudaTensor* input1, THCudaTensor* output, THCudaTensor* gradOutput, THCudaTensor* gradInput1, int norm_deg); \ No newline at end of file diff --git a/networks/channelnorm_package/src/ChannelNorm_kernel.cu b/networks/channelnorm_package/src/ChannelNorm_kernel.cu deleted file mode 100755 index 9976af9..0000000 --- 
a/networks/channelnorm_package/src/ChannelNorm_kernel.cu +++ /dev/null @@ -1,127 +0,0 @@ -#include -#include - -#define CUDA_NUM_THREADS 512 -#define THREADS_PER_BLOCK 64 - -#define DIM0(TENSOR) ((TENSOR).x) -#define DIM1(TENSOR) ((TENSOR).y) -#define DIM2(TENSOR) ((TENSOR).z) -#define DIM3(TENSOR) ((TENSOR).w) - -#define DIM3_INDEX(TENSOR, xx, yy, zz, ww) ((TENSOR)[((xx) * (TENSOR##_stride.x)) + ((yy) * (TENSOR##_stride.y)) + ((zz) * (TENSOR##_stride.z)) + ((ww) * (TENSOR##_stride.w))]) - - -#ifdef __cplusplus - extern "C" { -#endif - -__global__ void kernel_ChannelNorm_updateOutput(const int n, const float* input1, const long4 input1_size, const long4 input1_stride, float* output, const long4 output_size, const long4 output_stride, int norm_deg) { - int index = blockIdx.x * blockDim.x + threadIdx.x; - - if (index >= n) { - return; - } - - int dim_b = DIM0(output_size); - int dim_c = DIM1(output_size); - int dim_h = DIM2(output_size); - int dim_w = DIM3(output_size); - int dim_chw = dim_c * dim_h * dim_w; - - int b = ( index / dim_chw ) % dim_b; - int y = ( index / dim_w ) % dim_h; - int x = ( index ) % dim_w; - - int i1dim_c = DIM1(input1_size); - int i1dim_h = DIM2(input1_size); - int i1dim_w = DIM3(input1_size); - int i1dim_chw = i1dim_c * i1dim_h * i1dim_w; - int i1dim_hw = i1dim_h * i1dim_w; - - float result = 0.0; - - for (int c = 0; c < i1dim_c; ++c) { - int i1Index = b * i1dim_chw + c * i1dim_hw + y * i1dim_w + x; - float val = input1[i1Index]; - result += val * val; - } - result = sqrt(result); - output[index] = result; -} - - -__global__ void kernel_ChannelNorm_backward_input1(const int n, const float* input1, const long4 input1_size, const long4 input1_stride, - const float* output, const long4 output_size, const long4 output_stride, const float* gradOutput, const long4 gradOutput_size, const long4 gradOutput_stride, - float* gradInput, const long4 gradInput_size, const long4 gradInput_stride, int norm_deg) { - int index = blockIdx.x * blockDim.x + threadIdx.x; - - if (index >= n) { - return; - } - - float val = 0.0; - - int dim_b = DIM0(gradInput_size); - int dim_c = DIM1(gradInput_size); - int dim_h = DIM2(gradInput_size); - int dim_w = DIM3(gradInput_size); - int dim_chw = dim_c * dim_h * dim_w; - int dim_hw = dim_h * dim_w; - - int b = ( index / dim_chw ) % dim_b; - int y = ( index / dim_w ) % dim_h; - int x = ( index ) % dim_w; - - - int outIndex = b * dim_hw + y * dim_w + x; - val = gradOutput[outIndex] * input1[index] / (output[outIndex]+1e-9); - gradInput[index] = val; - -} - -void ChannelNorm_kernel_forward(THCState* state, THCudaTensor* input1, THCudaTensor* output, int norm_deg) { - int n = 0; - - const long4 input1_size = make_long4(input1->size[0], input1->size[1], input1->size[2], input1->size[3]); - const long4 input1_stride = make_long4(input1->stride[0], input1->stride[1], input1->stride[2], input1->stride[3]); - - const long4 output_size = make_long4(output->size[0], output->size[1], output->size[2], output->size[3]); - const long4 output_stride = make_long4(output->stride[0], output->stride[1], output->stride[2], output->stride[3]); - - n = THCudaTensor_nElement(state, output); - kernel_ChannelNorm_updateOutput<<< (n + CUDA_NUM_THREADS - 1)/CUDA_NUM_THREADS, CUDA_NUM_THREADS, 0, THCState_getCurrentStream(state) >>>( - n, THCudaTensor_data(state, input1), input1_size, input1_stride, THCudaTensor_data(state, output), output_size, output_stride, - norm_deg); - - THCudaCheck(cudaGetLastError()); -} - -void ChannelNorm_kernel_backward(THCState* state, 
THCudaTensor* input1, THCudaTensor* output, THCudaTensor* gradOutput, THCudaTensor* gradInput1, int norm_deg) { - int n = 0; - - const long4 input1_size = make_long4(input1->size[0], input1->size[1], input1->size[2], input1->size[3]); - const long4 input1_stride = make_long4(input1->stride[0], input1->stride[1], input1->stride[2], input1->stride[3]); - - const long4 output_size = make_long4(output->size[0], output->size[1], output->size[2], output->size[3]); - const long4 output_stride = make_long4(output->stride[0], output->stride[1], output->stride[2], output->stride[3]); - - const long4 gradOutput_size = make_long4(gradOutput->size[0], gradOutput->size[1], gradOutput->size[2], gradOutput->size[3]); - const long4 gradOutput_stride = make_long4(gradOutput->stride[0], gradOutput->stride[1], gradOutput->stride[2], gradOutput->stride[3]); - - const long4 gradInput1_size = make_long4(gradInput1->size[0], gradInput1->size[1], gradInput1->size[2], gradInput1->size[3]); - const long4 gradInput1_stride = make_long4(gradInput1->stride[0], gradInput1->stride[1], gradInput1->stride[2], gradInput1->stride[3]); - - n = THCudaTensor_nElement(state, gradInput1); - kernel_ChannelNorm_backward_input1<<< (n + CUDA_NUM_THREADS - 1)/CUDA_NUM_THREADS, CUDA_NUM_THREADS, 0, THCState_getCurrentStream(state) >>>( - n, THCudaTensor_data(state, input1), input1_size, input1_stride, THCudaTensor_data(state, output), output_size, output_stride, - THCudaTensor_data(state, gradOutput), gradOutput_size, gradOutput_stride, THCudaTensor_data(state, gradInput1), gradInput1_size, gradInput1_stride, - norm_deg - ); - - THCudaCheck(cudaGetLastError()); -} - -#ifdef __cplusplus - } -#endif \ No newline at end of file diff --git a/networks/channelnorm_package/src/ChannelNorm_kernel.h b/networks/channelnorm_package/src/ChannelNorm_kernel.h deleted file mode 100755 index 4ed2897..0000000 --- a/networks/channelnorm_package/src/ChannelNorm_kernel.h +++ /dev/null @@ -1,12 +0,0 @@ -#ifdef __cplusplus - extern "C" { -#endif - -void ChannelNorm_kernel_forward(THCState* state, THCudaTensor* input1, THCudaTensor* output, int norm_deg); - - -void ChannelNorm_kernel_backward(THCState* state, THCudaTensor* input1, THCudaTensor* output, THCudaTensor* gradOutput, THCudaTensor* gradInput1, int norm_deg); - -#ifdef __cplusplus - } -#endif \ No newline at end of file diff --git a/networks/correlation_package/build.py b/networks/correlation_package/build.py deleted file mode 100755 index 9974e2a..0000000 --- a/networks/correlation_package/build.py +++ /dev/null @@ -1,31 +0,0 @@ -import os -import torch -import torch.utils.ffi - -this_folder = os.path.dirname(os.path.abspath(__file__)) + '/' - -Headers = [] -Sources = [] -Defines = [] -Objects = [] - -if torch.cuda.is_available() == True: - Headers += ['src/correlation_cuda.h'] - Sources += ['src/correlation_cuda.c'] - Defines += [('WITH_CUDA', None)] - Objects += ['src/correlation_cuda_kernel.o'] - -ffi = torch.utils.ffi.create_extension( - name='_ext.correlation', - headers=Headers, - sources=Sources, - verbose=False, - with_cuda=True, - package=False, - relative_to=this_folder, - define_macros=Defines, - extra_objects=[os.path.join(this_folder, Object) for Object in Objects] -) - -if __name__ == '__main__': - ffi.build() \ No newline at end of file diff --git a/networks/correlation_package/correlation.py b/networks/correlation_package/correlation.py new file mode 100644 index 0000000..80a8b09 --- /dev/null +++ b/networks/correlation_package/correlation.py @@ -0,0 +1,62 @@ +import torch +from 
torch.nn.modules.module import Module +from torch.autograd import Function +import correlation_cuda + +class CorrelationFunction(Function): + + def __init__(self, pad_size=3, kernel_size=3, max_displacement=20, stride1=1, stride2=2, corr_multiply=1): + super(CorrelationFunction, self).__init__() + self.pad_size = pad_size + self.kernel_size = kernel_size + self.max_displacement = max_displacement + self.stride1 = stride1 + self.stride2 = stride2 + self.corr_multiply = corr_multiply + # self.out_channel = ((max_displacement/stride2)*2 + 1) * ((max_displacement/stride2)*2 + 1) + + def forward(self, input1, input2): + self.save_for_backward(input1, input2) + + with torch.cuda.device_of(input1): + rbot1 = input1.new() + rbot2 = input2.new() + output = input1.new() + + correlation_cuda.forward(input1, input2, rbot1, rbot2, output, + self.pad_size, self.kernel_size, self.max_displacement,self.stride1, self.stride2, self.corr_multiply) + + return output + + def backward(self, grad_output): + input1, input2 = self.saved_tensors + + with torch.cuda.device_of(input1): + rbot1 = input1.new() + rbot2 = input2.new() + + grad_input1 = input1.new() + grad_input2 = input2.new() + + correlation_cuda.backward(input1, input2, rbot1, rbot2, grad_output, grad_input1, grad_input2, + self.pad_size, self.kernel_size, self.max_displacement,self.stride1, self.stride2, self.corr_multiply) + + return grad_input1, grad_input2 + + +class Correlation(Module): + def __init__(self, pad_size=0, kernel_size=0, max_displacement=0, stride1=1, stride2=2, corr_multiply=1): + super(Correlation, self).__init__() + self.pad_size = pad_size + self.kernel_size = kernel_size + self.max_displacement = max_displacement + self.stride1 = stride1 + self.stride2 = stride2 + self.corr_multiply = corr_multiply + + def forward(self, input1, input2): + + result = CorrelationFunction(self.pad_size, self.kernel_size, self.max_displacement,self.stride1, self.stride2, self.corr_multiply)(input1, input2) + + return result + diff --git a/networks/correlation_package/correlation_cuda.cc b/networks/correlation_package/correlation_cuda.cc new file mode 100755 index 0000000..44a2386 --- /dev/null +++ b/networks/correlation_package/correlation_cuda.cc @@ -0,0 +1,169 @@ +#include +#include +#include +#include + +#include "correlation_cuda_kernel.cuh" + +int correlation_forward_cuda(at::Tensor& input1, at::Tensor& input2, at::Tensor& rInput1, at::Tensor& rInput2, at::Tensor& output, + int pad_size, + int kernel_size, + int max_displacement, + int stride1, + int stride2, + int corr_type_multiply) +{ + + int batchSize = input1.size(0); + + int nInputChannels = input1.size(1); + int inputHeight = input1.size(2); + int inputWidth = input1.size(3); + + int kernel_radius = (kernel_size - 1) / 2; + int border_radius = kernel_radius + max_displacement; + + int paddedInputHeight = inputHeight + 2 * pad_size; + int paddedInputWidth = inputWidth + 2 * pad_size; + + int nOutputChannels = ((max_displacement/stride2)*2 + 1) * ((max_displacement/stride2)*2 + 1); + + int outputHeight = ceil(static_cast(paddedInputHeight - 2 * border_radius) / static_cast(stride1)); + int outputwidth = ceil(static_cast(paddedInputWidth - 2 * border_radius) / static_cast(stride1)); + + rInput1.resize_({batchSize, paddedInputHeight, paddedInputWidth, nInputChannels}); + rInput2.resize_({batchSize, paddedInputHeight, paddedInputWidth, nInputChannels}); + output.resize_({batchSize, nOutputChannels, outputHeight, outputwidth}); + + rInput1.fill_(0); + rInput2.fill_(0); + output.fill_(0); + + 
int success = correlation_forward_cuda_kernel( + output, + output.size(0), + output.size(1), + output.size(2), + output.size(3), + output.stride(0), + output.stride(1), + output.stride(2), + output.stride(3), + input1, + input1.size(1), + input1.size(2), + input1.size(3), + input1.stride(0), + input1.stride(1), + input1.stride(2), + input1.stride(3), + input2, + input2.size(1), + input2.stride(0), + input2.stride(1), + input2.stride(2), + input2.stride(3), + rInput1, + rInput2, + pad_size, + kernel_size, + max_displacement, + stride1, + stride2, + corr_type_multiply, + at::globalContext().getCurrentCUDAStream() + ); + + //check for errors + if (!success) { + AT_ERROR("CUDA call failed"); + } + + return 1; + +} + +int correlation_backward_cuda(at::Tensor& input1, at::Tensor& input2, at::Tensor& rInput1, at::Tensor& rInput2, at::Tensor& gradOutput, + at::Tensor& gradInput1, at::Tensor& gradInput2, + int pad_size, + int kernel_size, + int max_displacement, + int stride1, + int stride2, + int corr_type_multiply) +{ + + int batchSize = input1.size(0); + int nInputChannels = input1.size(1); + int paddedInputHeight = input1.size(2)+ 2 * pad_size; + int paddedInputWidth = input1.size(3)+ 2 * pad_size; + + int height = input1.size(2); + int width = input1.size(3); + + rInput1.resize_({batchSize, paddedInputHeight, paddedInputWidth, nInputChannels}); + rInput2.resize_({batchSize, paddedInputHeight, paddedInputWidth, nInputChannels}); + gradInput1.resize_({batchSize, nInputChannels, height, width}); + gradInput2.resize_({batchSize, nInputChannels, height, width}); + + rInput1.fill_(0); + rInput2.fill_(0); + gradInput1.fill_(0); + gradInput2.fill_(0); + + int success = correlation_backward_cuda_kernel(gradOutput, + gradOutput.size(0), + gradOutput.size(1), + gradOutput.size(2), + gradOutput.size(3), + gradOutput.stride(0), + gradOutput.stride(1), + gradOutput.stride(2), + gradOutput.stride(3), + input1, + input1.size(1), + input1.size(2), + input1.size(3), + input1.stride(0), + input1.stride(1), + input1.stride(2), + input1.stride(3), + input2, + input2.stride(0), + input2.stride(1), + input2.stride(2), + input2.stride(3), + gradInput1, + gradInput1.stride(0), + gradInput1.stride(1), + gradInput1.stride(2), + gradInput1.stride(3), + gradInput2, + gradInput2.size(1), + gradInput2.stride(0), + gradInput2.stride(1), + gradInput2.stride(2), + gradInput2.stride(3), + rInput1, + rInput2, + pad_size, + kernel_size, + max_displacement, + stride1, + stride2, + corr_type_multiply, + at::globalContext().getCurrentCUDAStream() + ); + + if (!success) { + AT_ERROR("CUDA call failed"); + } + + return 1; +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("forward", &correlation_forward_cuda, "Correlation forward (CUDA)"); + m.def("backward", &correlation_backward_cuda, "Correlation backward (CUDA)"); +} + diff --git a/networks/correlation_package/src/correlation_cuda_kernel.cu b/networks/correlation_package/correlation_cuda_kernel.cu similarity index 54% rename from networks/correlation_package/src/correlation_cuda_kernel.cu rename to networks/correlation_package/correlation_cuda_kernel.cu index 36c3cde..3b52734 100755 --- a/networks/correlation_package/src/correlation_cuda_kernel.cu +++ b/networks/correlation_package/correlation_cuda_kernel.cu @@ -1,21 +1,28 @@ #include -#include "correlation_cuda_kernel.h" - -#define real float +#include "correlation_cuda_kernel.cuh" #define CUDA_NUM_THREADS 1024 #define THREADS_PER_BLOCK 32 -__global__ void channels_first(float* input, float* rinput, int channels, int 
height, int width, int pad_size) +#include +#include +#include +#include + +using at::Half; + +template +__global__ void channels_first(const scalar_t* __restrict__ input, scalar_t* rinput, int channels, int height, int width, int pad_size) { + // n (batch size), c (num of channels), y (height), x (width) int n = blockIdx.x; int y = blockIdx.y; int x = blockIdx.z; int ch_off = threadIdx.x; - float value; + scalar_t value; int dimcyx = channels * height * width; int dimyx = height * width; @@ -31,9 +38,10 @@ __global__ void channels_first(float* input, float* rinput, int channels, int he } } -__global__ void Correlation_forward( float *output, int nOutputChannels, int outputHeight, int outputWidth, - float *rInput1, int nInputChannels, int inputHeight, int inputWidth, - float *rInput2, +template +__global__ void correlation_forward( scalar_t* output, int nOutputChannels, int outputHeight, int outputWidth, + const scalar_t* __restrict__ rInput1, int nInputChannels, int inputHeight, int inputWidth, + const scalar_t* __restrict__ rInput2, int pad_size, int kernel_size, int max_displacement, @@ -50,8 +58,8 @@ __global__ void Correlation_forward( float *output, int nOutputChannels, int out int displacement_size = 2 * displacement_rad + 1; int n = blockIdx.x; - int y1 = blockIdx.y * stride1 + max_displacement + kernel_rad; - int x1 = blockIdx.z * stride1 + max_displacement + kernel_rad; + int y1 = blockIdx.y * stride1 + max_displacement; + int x1 = blockIdx.z * stride1 + max_displacement; int c = threadIdx.x; int pdimyxc = pInputHeight * pInputWidth * nInputChannels; @@ -62,9 +70,9 @@ __global__ void Correlation_forward( float *output, int nOutputChannels, int out int tdimyx = outputHeight * outputWidth; int tdimx = outputWidth; - float nelems = kernel_size * kernel_size * pdimc; + scalar_t nelems = kernel_size * kernel_size * pdimc; - __shared__ float prod_sum[THREADS_PER_BLOCK]; + __shared__ scalar_t prod_sum[THREADS_PER_BLOCK]; // no significant speed-up in using chip memory for input1 sub-data, // not enough chip memory size to accomodate memory per block for input2 sub-data @@ -91,7 +99,7 @@ __global__ void Correlation_forward( float *output, int nOutputChannels, int out // accumulate __syncthreads(); if (c == 0) { - float reduce_sum = 0; + scalar_t reduce_sum = 0; for (int index = 0; index < THREADS_PER_BLOCK; ++index) { reduce_sum += prod_sum[index]; } @@ -105,9 +113,10 @@ __global__ void Correlation_forward( float *output, int nOutputChannels, int out } -__global__ void Correlation_backward_input1(int item, float *gradInput1, int nInputChannels, int inputHeight, int inputWidth, - float *gradOutput, int nOutputChannels, int outputHeight, int outputWidth, - float *rInput2, +template +__global__ void correlation_backward_input1(int item, scalar_t* gradInput1, int nInputChannels, int inputHeight, int inputWidth, + const scalar_t* __restrict__ gradOutput, int nOutputChannels, int outputHeight, int outputWidth, + const scalar_t* __restrict__ rInput2, int pad_size, int kernel_size, int max_displacement, @@ -163,9 +172,9 @@ __global__ void Correlation_backward_input1(int item, float *gradInput1, int nIn int odimyx = inputHeight * inputWidth; int odimx = inputWidth; - float nelems = kernel_size * kernel_size * nInputChannels; + scalar_t nelems = kernel_size * kernel_size * nInputChannels; - __shared__ float prod_sum[THREADS_PER_BLOCK]; + __shared__ scalar_t prod_sum[THREADS_PER_BLOCK]; prod_sum[tch_off] = 0; for (int tc = tch_off; tc < nOutputChannels; tc += THREADS_PER_BLOCK) { @@ -175,7 +184,7 
@@ __global__ void Correlation_backward_input1(int item, float *gradInput1, int nIn int indx2 = n * pdimyxc + (y + j2)* pdimxc + (x + i2) * pdimc + c; - float val2 = rInput2[indx2]; + scalar_t val2 = rInput2[indx2]; for (int j = ymin; j <= ymax; ++j) { for (int i = xmin; i <= xmax; ++i) { @@ -187,7 +196,7 @@ __global__ void Correlation_backward_input1(int item, float *gradInput1, int nIn __syncthreads(); if(tch_off == 0) { - float reduce_sum = 0; + scalar_t reduce_sum = 0; for(int idx = 0; idx < THREADS_PER_BLOCK; idx++) { reduce_sum += prod_sum[idx]; } @@ -197,9 +206,10 @@ __global__ void Correlation_backward_input1(int item, float *gradInput1, int nIn } -__global__ void Correlation_backward_input2(int item, float *gradInput2, int nInputChannels, int inputHeight, int inputWidth, - float *gradOutput, int nOutputChannels, int outputHeight, int outputWidth, - float *rInput1, +template +__global__ void correlation_backward_input2(int item, scalar_t* gradInput2, int nInputChannels, int inputHeight, int inputWidth, + const scalar_t* __restrict__ gradOutput, int nOutputChannels, int outputHeight, int outputWidth, + const scalar_t* __restrict__ rInput1, int pad_size, int kernel_size, int max_displacement, @@ -234,9 +244,9 @@ __global__ void Correlation_backward_input2(int item, float *gradInput2, int nIn int odimyx = inputHeight * inputWidth; int odimx = inputWidth; - float nelems = kernel_size * kernel_size * nInputChannels; + scalar_t nelems = kernel_size * kernel_size * nInputChannels; - __shared__ float prod_sum[THREADS_PER_BLOCK]; + __shared__ scalar_t prod_sum[THREADS_PER_BLOCK]; prod_sum[tch_off] = 0; for (int tc = tch_off; tc < nOutputChannels; tc += THREADS_PER_BLOCK) { @@ -266,7 +276,7 @@ __global__ void Correlation_backward_input2(int item, float *gradInput2, int nIn ymax = min(outputHeight-1,ymax); int indx1 = n * pdimyxc + (y - j2)* pdimxc + (x - i2) * pdimc + c; - float val1 = rInput1[indx1]; + scalar_t val1 = rInput1[indx1]; for (int j = ymin; j <= ymax; ++j) { for (int i = xmin; i <= xmax; ++i) { @@ -279,7 +289,7 @@ __global__ void Correlation_backward_input2(int item, float *gradInput2, int nIn __syncthreads(); if(tch_off == 0) { - float reduce_sum = 0; + scalar_t reduce_sum = 0; for(int idx = 0; idx < THREADS_PER_BLOCK; idx++) { reduce_sum += prod_sum[idx]; } @@ -289,46 +299,43 @@ __global__ void Correlation_backward_input2(int item, float *gradInput2, int nIn } -#ifdef __cplusplus -extern "C" { -#endif - -int Correlation_forward_cuda_kernel(/*THCudaTensor_data(state, output)*/ float *output, - /*THCudaTensor_size(state, output, 0)*/ int ob, - /*THCudaTensor_size(state, output, 1)*/ int oc, - /*THCudaTensor_size(state, output, 2)*/ int oh, - /*THCudaTensor_size(state, output, 3)*/ int ow, - /*THCudaTensor_stride(state, output, 0)*/ int osb, - /*THCudaTensor_stride(state, output, 1)*/ int osc, - /*THCudaTensor_stride(state, output, 2)*/ int osh, - /*THCudaTensor_stride(state, output, 3)*/ int osw, - - /*THCudaTensor_data(state, input1)*/ float *input1, - /*THCudaTensor_size(state, input1, 1)*/ int ic, - /*THCudaTensor_size(state, input1, 2)*/ int ih, - /*THCudaTensor_size(state, input1, 3)*/ int iw, - /*THCudaTensor_stride(state, input1, 0)*/ int isb, - /*THCudaTensor_stride(state, input1, 1)*/ int isc, - /*THCudaTensor_stride(state, input1, 2)*/ int ish, - /*THCudaTensor_stride(state, input1, 3)*/ int isw, - - /*THCudaTensor_data(state, input2)*/ float *input2, - /*THCudaTensor_size(state, input2, 1)*/ int gc, - /*THCudaTensor_stride(state, input2, 0)*/ int gsb, - 
/*THCudaTensor_stride(state, input2, 1)*/ int gsc, - /*THCudaTensor_stride(state, input2, 2)*/ int gsh, - /*THCudaTensor_stride(state, input2, 3)*/ int gsw, - - /*THCudaTensor_data(state, rInput1)*/ float *rInput1, - /*THCudaTensor_data(state, rInput2)*/ float *rInput2, +int correlation_forward_cuda_kernel(at::Tensor& output, + int ob, + int oc, + int oh, + int ow, + int osb, + int osc, + int osh, + int osw, + + at::Tensor& input1, + int ic, + int ih, + int iw, + int isb, + int isc, + int ish, + int isw, + + at::Tensor& input2, + int gc, + int gsb, + int gsc, + int gsh, + int gsw, + + at::Tensor& rInput1, + at::Tensor& rInput2, int pad_size, int kernel_size, int max_displacement, int stride1, int stride2, int corr_type_multiply, - /*THCState_getCurrentStream(state)*/ cudaStream_t stream) + cudaStream_t stream) { + int batchSize = ob; int nInputChannels = ic; @@ -342,80 +349,98 @@ int Correlation_forward_cuda_kernel(/*THCudaTensor_data(state, output)*/ float * dim3 blocks_grid(batchSize, inputHeight, inputWidth); dim3 threads_block(THREADS_PER_BLOCK); - channels_first<<>> (input1,rInput1, nInputChannels, inputHeight, inputWidth,pad_size); - channels_first<<>> (input2,rInput2, nInputChannels, inputHeight, inputWidth, pad_size); + AT_DISPATCH_FLOATING_TYPES_AND_HALF(input1.type(), "channels_first_fwd_1", ([&] { + + channels_first<<>>( + input1.data(), rInput1.data(), nInputChannels, inputHeight, inputWidth, pad_size); + + })); + + AT_DISPATCH_FLOATING_TYPES_AND_HALF(input2.type(), "channels_first_fwd_2", ([&] { + + channels_first<<>> ( + input2.data(), rInput2.data(), nInputChannels, inputHeight, inputWidth, pad_size); + + })); dim3 threadsPerBlock(THREADS_PER_BLOCK); dim3 totalBlocksCorr(batchSize, outputHeight, outputWidth); - Correlation_forward <<< totalBlocksCorr, threadsPerBlock, 0, stream >>> - (output, nOutputChannels, outputHeight, outputWidth, - rInput1, nInputChannels, inputHeight, inputWidth, - rInput2, + AT_DISPATCH_FLOATING_TYPES_AND_HALF(input1.type(), "correlation_forward", ([&] { + + correlation_forward<<>> + (output.data(), nOutputChannels, outputHeight, outputWidth, + rInput1.data(), nInputChannels, inputHeight, inputWidth, + rInput2.data(), pad_size, kernel_size, max_displacement, stride1, stride2); - // check for errors + })); + cudaError_t err = cudaGetLastError(); + + + // check for errors if (err != cudaSuccess) { - printf("error in Correlation_forward_cuda_kernel: %s\n", cudaGetErrorString(err)); + printf("error in correlation_forward_cuda_kernel: %s\n", cudaGetErrorString(err)); return 0; } return 1; } -int Correlation_backward_cuda_kernel( - /*THCudaTensor_data(state, gradOutput)*/ float *gradOutput, - /*THCudaTensor_size(state, gradOutput, 0)*/ int gob, - /*THCudaTensor_size(state, gradOutput, 1)*/ int goc, - /*THCudaTensor_size(state, gradOutput, 2)*/ int goh, - /*THCudaTensor_size(state, gradOutput, 3)*/ int gow, - /*THCudaTensor_stride(state, gradOutput, 0)*/ int gosb, - /*THCudaTensor_stride(state, gradOutput, 1)*/ int gosc, - /*THCudaTensor_stride(state, gradOutput, 2)*/ int gosh, - /*THCudaTensor_stride(state, gradOutput, 3)*/ int gosw, - - /*THCudaTensor_data(state, input1)*/ float* input1, - /*THCudaTensor_size(state, input1, 1)*/ int ic, - /*THCudaTensor_size(state, input1, 2)*/ int ih, - /*THCudaTensor_size(state, input1, 3)*/ int iw, - /*THCudaTensor_stride(state, input1, 0)*/ int isb, - /*THCudaTensor_stride(state, input1, 1)*/ int isc, - /*THCudaTensor_stride(state, input1, 2)*/ int ish, - /*THCudaTensor_stride(state, input1, 3)*/ int isw, - - 
/*THCudaTensor_data(state, input2)*/ float *input2, - /*THCudaTensor_stride(state, input2, 0)*/ int gsb, - /*THCudaTensor_stride(state, input2, 1)*/ int gsc, - /*THCudaTensor_stride(state, input2, 2)*/ int gsh, - /*THCudaTensor_stride(state, input2, 3)*/ int gsw, - - /*THCudaTensor_data(state, gradInput1)*/ float *gradInput1, - /*THCudaTensor_stride(state, gradInput1, 0)*/ int gisb, - /*THCudaTensor_stride(state, gradInput1, 1)*/ int gisc, - /*THCudaTensor_stride(state, gradInput1, 2)*/ int gish, - /*THCudaTensor_stride(state, gradInput1, 3)*/ int gisw, - - /*THCudaTensor_data(state, gradInput2)*/ float *gradInput2, - /*THCudaTensor_size(state, gradInput2, 1)*/ int ggc, - /*THCudaTensor_stride(state, gradInput2, 0)*/ int ggsb, - /*THCudaTensor_stride(state, gradInput2, 1)*/ int ggsc, - /*THCudaTensor_stride(state, gradInput2, 2)*/ int ggsh, - /*THCudaTensor_stride(state, gradInput2, 3)*/ int ggsw, - - /*THCudaTensor_data(state, rInput1)*/ float *rInput1, - /*THCudaTensor_data(state, rInput2)*/ float *rInput2, + +int correlation_backward_cuda_kernel( + at::Tensor& gradOutput, + int gob, + int goc, + int goh, + int gow, + int gosb, + int gosc, + int gosh, + int gosw, + + at::Tensor& input1, + int ic, + int ih, + int iw, + int isb, + int isc, + int ish, + int isw, + + at::Tensor& input2, + int gsb, + int gsc, + int gsh, + int gsw, + + at::Tensor& gradInput1, + int gisb, + int gisc, + int gish, + int gisw, + + at::Tensor& gradInput2, + int ggc, + int ggsb, + int ggsc, + int ggsh, + int ggsw, + + at::Tensor& rInput1, + at::Tensor& rInput2, int pad_size, int kernel_size, int max_displacement, int stride1, int stride2, int corr_type_multiply, - /*THCState_getCurrentStream(state)*/cudaStream_t stream) + cudaStream_t stream) { int batchSize = gob; @@ -432,46 +457,74 @@ int Correlation_backward_cuda_kernel( dim3 blocks_grid(batchSize, inputHeight, inputWidth); dim3 threads_block(THREADS_PER_BLOCK); - channels_first<<>> (input1, rInput1, nInputChannels,inputHeight, inputWidth, pad_size); - channels_first<<>> (input2, rInput2, nInputChannels, inputHeight, inputWidth, pad_size); + + AT_DISPATCH_FLOATING_TYPES_AND_HALF(input1.type(), "lltm_forward_cuda", ([&] { + + channels_first<<>>( + input1.data(), + rInput1.data(), + nInputChannels, + inputHeight, + inputWidth, + pad_size + ); + })); + + AT_DISPATCH_FLOATING_TYPES_AND_HALF(input2.type(), "lltm_forward_cuda", ([&] { + + channels_first<<>>( + input2.data(), + rInput2.data(), + nInputChannels, + inputHeight, + inputWidth, + pad_size + ); + })); dim3 threadsPerBlock(THREADS_PER_BLOCK); dim3 totalBlocksCorr(inputHeight, inputWidth, nInputChannels); for (int n = 0; n < num; ++n) { - Correlation_backward_input1 << > > ( - n, gradInput1, nInputChannels, inputHeight, inputWidth, - gradOutput, nOutputChannels, outputHeight, outputWidth, - rInput2, - pad_size, - kernel_size, - max_displacement, - stride1, - stride2); + + AT_DISPATCH_FLOATING_TYPES_AND_HALF(input2.type(), "lltm_forward_cuda", ([&] { + + + correlation_backward_input1<<>> ( + n, gradInput1.data(), nInputChannels, inputHeight, inputWidth, + gradOutput.data(), nOutputChannels, outputHeight, outputWidth, + rInput2.data(), + pad_size, + kernel_size, + max_displacement, + stride1, + stride2); + })); } for(int n = 0; n < batchSize; n++) { - Correlation_backward_input2<<>>( - n, gradInput2, nInputChannels, inputHeight, inputWidth, - gradOutput, nOutputChannels, outputHeight, outputWidth, - rInput1, + + AT_DISPATCH_FLOATING_TYPES_AND_HALF(rInput1.type(), "lltm_forward_cuda", ([&] { + + 
correlation_backward_input2<<>>( + n, gradInput2.data(), nInputChannels, inputHeight, inputWidth, + gradOutput.data(), nOutputChannels, outputHeight, outputWidth, + rInput1.data(), pad_size, kernel_size, max_displacement, stride1, stride2); + + })); } // check for errors cudaError_t err = cudaGetLastError(); if (err != cudaSuccess) { - printf("error in Correlation_backward_cuda_kernel: %s\n", cudaGetErrorString(err)); + printf("error in correlation_backward_cuda_kernel: %s\n", cudaGetErrorString(err)); return 0; } return 1; } - -#ifdef __cplusplus -} -#endif diff --git a/networks/correlation_package/correlation_cuda_kernel.cuh b/networks/correlation_package/correlation_cuda_kernel.cuh new file mode 100755 index 0000000..1586d3a --- /dev/null +++ b/networks/correlation_package/correlation_cuda_kernel.cuh @@ -0,0 +1,91 @@ +#pragma once + +#include +#include +#include + +int correlation_forward_cuda_kernel(at::Tensor& output, + int ob, + int oc, + int oh, + int ow, + int osb, + int osc, + int osh, + int osw, + + at::Tensor& input1, + int ic, + int ih, + int iw, + int isb, + int isc, + int ish, + int isw, + + at::Tensor& input2, + int gc, + int gsb, + int gsc, + int gsh, + int gsw, + + at::Tensor& rInput1, + at::Tensor& rInput2, + int pad_size, + int kernel_size, + int max_displacement, + int stride1, + int stride2, + int corr_type_multiply, + cudaStream_t stream); + + +int correlation_backward_cuda_kernel( + at::Tensor& gradOutput, + int gob, + int goc, + int goh, + int gow, + int gosb, + int gosc, + int gosh, + int gosw, + + at::Tensor& input1, + int ic, + int ih, + int iw, + int isb, + int isc, + int ish, + int isw, + + at::Tensor& input2, + int gsb, + int gsc, + int gsh, + int gsw, + + at::Tensor& gradInput1, + int gisb, + int gisc, + int gish, + int gisw, + + at::Tensor& gradInput2, + int ggc, + int ggsb, + int ggsc, + int ggsh, + int ggsw, + + at::Tensor& rInput1, + at::Tensor& rInput2, + int pad_size, + int kernel_size, + int max_displacement, + int stride1, + int stride2, + int corr_type_multiply, + cudaStream_t stream); diff --git a/networks/correlation_package/functions/__init__.py b/networks/correlation_package/functions/__init__.py deleted file mode 100755 index e69de29..0000000 diff --git a/networks/correlation_package/functions/correlation.py b/networks/correlation_package/functions/correlation.py deleted file mode 100755 index 76b9989..0000000 --- a/networks/correlation_package/functions/correlation.py +++ /dev/null @@ -1,55 +0,0 @@ -from torch.autograd import Function, Variable -from .._ext import correlation - - -class CorrelationFunction(Function): - - @staticmethod - def forward(ctx, - input1, - input2, - pad_size=3, - kernel_size=3, - max_displacement=20, - stride1=1, - stride2=2, - corr_multiply=1): - assert input1.is_contiguous() - assert input2.is_contiguous() - - ctx.save_for_backward(input1, input2) - ctx.pad_size = pad_size - ctx.kernel_size = kernel_size - ctx.max_displacement = max_displacement - ctx.stride1 = stride1 - ctx.stride2 = stride2 - ctx.corr_multiply = corr_multiply - - rbot1 = input1.new() - rbot2 = input2.new() - output = input1.new() - - correlation.Correlation_forward_cuda( - input1, input2, rbot1, rbot2, output, pad_size, kernel_size, - max_displacement, stride1, stride2, corr_multiply) - - return output - - @staticmethod - def backward(ctx, grad_output): - assert grad_output.is_contiguous() - - input1, input2 = ctx.saved_tensors - - rbot1 = input1.new() - rbot2 = input2.new() - - grad_input1 = Variable(input1.new()) - grad_input2 = 
Variable(input2.new()) - - correlation.Correlation_backward_cuda( - input1, input2, rbot1, rbot2, grad_output.data, grad_input1.data, - grad_input2.data, ctx.pad_size, ctx.kernel_size, - ctx.max_displacement, ctx.stride1, ctx.stride2, ctx.corr_multiply) - - return (grad_input1, grad_input2) + (None, ) * 6 diff --git a/networks/correlation_package/make.sh b/networks/correlation_package/make.sh deleted file mode 100755 index d87a446..0000000 --- a/networks/correlation_package/make.sh +++ /dev/null @@ -1,14 +0,0 @@ -#!/usr/bin/env bash -TORCH=$(python3 -c "import os; import torch; print(os.path.dirname(torch.__file__))") - -cd src - -echo "Compiling correlation kernels by nvcc..." - -rm correlation_cuda_kernel.o -rm -r ../_ext - -nvcc -c -o correlation_cuda_kernel.o correlation_cuda_kernel.cu -x cu -Xcompiler -fPIC -arch=sm_52 - -cd ../ -python3 build.py diff --git a/networks/correlation_package/modules/__init__.py b/networks/correlation_package/modules/__init__.py deleted file mode 100755 index e69de29..0000000 diff --git a/networks/correlation_package/modules/correlation.py b/networks/correlation_package/modules/correlation.py deleted file mode 100755 index f995543..0000000 --- a/networks/correlation_package/modules/correlation.py +++ /dev/null @@ -1,27 +0,0 @@ -from torch.nn.modules.module import Module - -from ..functions.correlation import CorrelationFunction - - -class Correlation(Module): - - def __init__(self, - pad_size=0, - kernel_size=0, - max_displacement=0, - stride1=1, - stride2=2, - corr_multiply=1): - super(Correlation, self).__init__() - self.pad_size = pad_size - self.kernel_size = kernel_size - self.max_displacement = max_displacement - self.stride1 = stride1 - self.stride2 = stride2 - self.corr_multiply = corr_multiply - - def forward(self, input1, input2): - return CorrelationFunction.apply(input1, input2, self.pad_size, - self.kernel_size, - self.max_displacement, self.stride1, - self.stride2, self.corr_multiply) diff --git a/networks/correlation_package/setup.py b/networks/correlation_package/setup.py new file mode 100755 index 0000000..48b7d73 --- /dev/null +++ b/networks/correlation_package/setup.py @@ -0,0 +1,29 @@ +#!/usr/bin/env python3 +import os +import torch + +from setuptools import setup, find_packages +from torch.utils.cpp_extension import BuildExtension, CUDAExtension + +cxx_args = ['-std=c++11'] + +nvcc_args = [ + '-gencode', 'arch=compute_50,code=sm_50', + '-gencode', 'arch=compute_52,code=sm_52', + '-gencode', 'arch=compute_60,code=sm_60', + '-gencode', 'arch=compute_61,code=sm_61', + '-gencode', 'arch=compute_70,code=sm_70', + '-gencode', 'arch=compute_70,code=compute_70' +] + +setup( + name='correlation_cuda', + ext_modules=[ + CUDAExtension('correlation_cuda', [ + 'correlation_cuda.cc', + 'correlation_cuda_kernel.cu' + ], extra_compile_args={'cxx': cxx_args, 'nvcc': nvcc_args}) + ], + cmdclass={ + 'build_ext': BuildExtension + }) diff --git a/networks/correlation_package/src/correlation.c b/networks/correlation_package/src/correlation.c deleted file mode 100755 index 4a58786..0000000 --- a/networks/correlation_package/src/correlation.c +++ /dev/null @@ -1,33 +0,0 @@ -#include - -int Correlation_forward_cpu(THFloatTensor *input1, - THFloatTensor *input2, - THFloatTensor *rInput1, - THFloatTensor *rInput2, - THFloatTensor *output, - int pad_size, - int kernel_size, - int max_displacement, - int stride1, - int stride2, - int corr_type_multiply) -{ - return 1; -} - -int Correlation_backward_cpu(THFloatTensor *input1, - THFloatTensor *input2, - 
THFloatTensor *rInput1, - THFloatTensor *rInput2, - THFloatTensor *gradOutput, - THFloatTensor *gradInput1, - THFloatTensor *gradInput2, - int pad_size, - int kernel_size, - int max_displacement, - int stride1, - int stride2, - int corr_type_multiply) -{ - return 1; -} diff --git a/networks/correlation_package/src/correlation.h b/networks/correlation_package/src/correlation.h deleted file mode 100755 index 935e391..0000000 --- a/networks/correlation_package/src/correlation.h +++ /dev/null @@ -1,25 +0,0 @@ -int Correlation_forward_cpu(THFloatTensor *input1, - THFloatTensor *input2, - THFloatTensor *rInput1, - THFloatTensor *rInput2, - THFloatTensor *output, - int pad_size, - int kernel_size, - int max_displacement, - int stride1, - int stride2, - int corr_type_multiply); - -int Correlation_backward_cpu(THFloatTensor *input1, - THFloatTensor *input2, - THFloatTensor *rInput1, - THFloatTensor *rInput2, - THFloatTensor *gradOutput, - THFloatTensor *gradInput1, - THFloatTensor *gradInput2, - int pad_size, - int kernel_size, - int max_displacement, - int stride1, - int stride2, - int corr_type_multiply); diff --git a/networks/correlation_package/src/correlation_cuda.c b/networks/correlation_package/src/correlation_cuda.c deleted file mode 100755 index 7a121f5..0000000 --- a/networks/correlation_package/src/correlation_cuda.c +++ /dev/null @@ -1,180 +0,0 @@ -#include -#include - -#include "correlation_cuda_kernel.h" - -#define real float - -// symbol to be automatically resolved by PyTorch libs -extern THCState *state; - -int Correlation_forward_cuda(THCudaTensor *input1, THCudaTensor *input2, THCudaTensor *rInput1, THCudaTensor *rInput2, THCudaTensor *output, - int pad_size, - int kernel_size, - int max_displacement, - int stride1, - int stride2, - int corr_type_multiply) -{ - - int batchSize = input1->size[0]; - int nInputChannels = input1->size[1]; - int inputHeight = input1->size[2]; - int inputWidth = input1->size[3]; - - int kernel_radius = (kernel_size - 1) / 2; - int border_radius = kernel_radius + max_displacement; - - int paddedInputHeight = inputHeight + 2 * pad_size; - int paddedInputWidth = inputWidth + 2 * pad_size; - - int nOutputChannels = ((max_displacement/stride2)*2 + 1) * ((max_displacement/stride2)*2 + 1); - - int outputHeight = ceil((float)(paddedInputHeight - 2 * border_radius) / (float)stride1); - int outputwidth = ceil((float)(paddedInputWidth - 2 * border_radius) / (float)stride1); - - THCudaTensor_resize4d(state, rInput1, batchSize, paddedInputHeight, paddedInputWidth, nInputChannels); - THCudaTensor_resize4d(state, rInput2, batchSize, paddedInputHeight, paddedInputWidth, nInputChannels); - THCudaTensor_resize4d(state, output, batchSize, nOutputChannels, outputHeight, outputwidth); - - THCudaTensor_fill(state, rInput1, 0); - THCudaTensor_fill(state, rInput2, 0); - THCudaTensor_fill(state, output, 0); - - int success = 0; - success = Correlation_forward_cuda_kernel( THCudaTensor_data(state, output), - THCudaTensor_size(state, output, 0), - THCudaTensor_size(state, output, 1), - THCudaTensor_size(state, output, 2), - THCudaTensor_size(state, output, 3), - THCudaTensor_stride(state, output, 0), - THCudaTensor_stride(state, output, 1), - THCudaTensor_stride(state, output, 2), - THCudaTensor_stride(state, output, 3), - - THCudaTensor_data(state, input1), - THCudaTensor_size(state, input1, 1), - THCudaTensor_size(state, input1, 2), - THCudaTensor_size(state, input1, 3), - THCudaTensor_stride(state, input1, 0), - THCudaTensor_stride(state, input1, 1), - 
THCudaTensor_stride(state, input1, 2), - THCudaTensor_stride(state, input1, 3), - - THCudaTensor_data(state, input2), - THCudaTensor_size(state, input2, 1), - THCudaTensor_stride(state, input2, 0), - THCudaTensor_stride(state, input2, 1), - THCudaTensor_stride(state, input2, 2), - THCudaTensor_stride(state, input2, 3), - - THCudaTensor_data(state, rInput1), - THCudaTensor_data(state, rInput2), - - pad_size, - kernel_size, - max_displacement, - stride1, - stride2, - corr_type_multiply, - - THCState_getCurrentStream(state)); - - THCudaTensor_free(state, rInput1); - THCudaTensor_free(state, rInput2); - - //check for errors - if (!success) { - THError("aborting"); - } - - return 1; - -} - -int Correlation_backward_cuda(THCudaTensor *input1, THCudaTensor *input2, THCudaTensor *rInput1, THCudaTensor *rInput2, THCudaTensor *gradOutput, - THCudaTensor *gradInput1, THCudaTensor *gradInput2, - int pad_size, - int kernel_size, - int max_displacement, - int stride1, - int stride2, - int corr_type_multiply) -{ - - int batchSize = input1->size[0]; - int nInputChannels = input1->size[1]; - int paddedInputHeight = input1->size[2]+ 2 * pad_size; - int paddedInputWidth = input1->size[3]+ 2 * pad_size; - - int height = input1->size[2]; - int width = input1->size[3]; - - THCudaTensor_resize4d(state, rInput1, batchSize, paddedInputHeight, paddedInputWidth, nInputChannels); - THCudaTensor_resize4d(state, rInput2, batchSize, paddedInputHeight, paddedInputWidth, nInputChannels); - THCudaTensor_resize4d(state, gradInput1, batchSize, nInputChannels, height, width); - THCudaTensor_resize4d(state, gradInput2, batchSize, nInputChannels, height, width); - - THCudaTensor_fill(state, rInput1, 0); - THCudaTensor_fill(state, rInput2, 0); - THCudaTensor_fill(state, gradInput1, 0); - THCudaTensor_fill(state, gradInput2, 0); - - int success = 0; - success = Correlation_backward_cuda_kernel( - THCudaTensor_data(state, gradOutput), - THCudaTensor_size(state, gradOutput, 0), - THCudaTensor_size(state, gradOutput, 1), - THCudaTensor_size(state, gradOutput, 2), - THCudaTensor_size(state, gradOutput, 3), - THCudaTensor_stride(state, gradOutput, 0), - THCudaTensor_stride(state, gradOutput, 1), - THCudaTensor_stride(state, gradOutput, 2), - THCudaTensor_stride(state, gradOutput, 3), - - THCudaTensor_data(state, input1), - THCudaTensor_size(state, input1, 1), - THCudaTensor_size(state, input1, 2), - THCudaTensor_size(state, input1, 3), - THCudaTensor_stride(state, input1, 0), - THCudaTensor_stride(state, input1, 1), - THCudaTensor_stride(state, input1, 2), - THCudaTensor_stride(state, input1, 3), - - THCudaTensor_data(state, input2), - THCudaTensor_stride(state, input2, 0), - THCudaTensor_stride(state, input2, 1), - THCudaTensor_stride(state, input2, 2), - THCudaTensor_stride(state, input2, 3), - - THCudaTensor_data(state, gradInput1), - THCudaTensor_stride(state, gradInput1, 0), - THCudaTensor_stride(state, gradInput1, 1), - THCudaTensor_stride(state, gradInput1, 2), - THCudaTensor_stride(state, gradInput1, 3), - - THCudaTensor_data(state, gradInput2), - THCudaTensor_size(state, gradInput2, 1), - THCudaTensor_stride(state, gradInput2, 0), - THCudaTensor_stride(state, gradInput2, 1), - THCudaTensor_stride(state, gradInput2, 2), - THCudaTensor_stride(state, gradInput2, 3), - - THCudaTensor_data(state, rInput1), - THCudaTensor_data(state, rInput2), - pad_size, - kernel_size, - max_displacement, - stride1, - stride2, - corr_type_multiply, - THCState_getCurrentStream(state)); - - THCudaTensor_free(state, rInput1); - 
THCudaTensor_free(state, rInput2); - - if (!success) { - THError("aborting"); - } - return 1; -} diff --git a/networks/correlation_package/src/correlation_cuda.h b/networks/correlation_package/src/correlation_cuda.h deleted file mode 100755 index 2f7fb2c..0000000 --- a/networks/correlation_package/src/correlation_cuda.h +++ /dev/null @@ -1,18 +0,0 @@ -int Correlation_forward_cuda(THCudaTensor *input1, THCudaTensor *input2, THCudaTensor *rInput1, THCudaTensor *rInput2, - THCudaTensor *output, - int pad_size, - int kernel_size, - int max_displacement, - int stride1, - int stride2, - int corr_type_multiply); - -int Correlation_backward_cuda(THCudaTensor *input1, THCudaTensor *input2, THCudaTensor *rInput1, THCudaTensor *rInput2, - THCudaTensor *gradOutput, THCudaTensor *gradInput1, THCudaTensor *gradInput2, - int pad_size, - int kernel_size, - int max_displacement, - int stride1, - int stride2, - int corr_type_multiply); - diff --git a/networks/correlation_package/src/correlation_cuda_kernel.h b/networks/correlation_package/src/correlation_cuda_kernel.h deleted file mode 100755 index af72be7..0000000 --- a/networks/correlation_package/src/correlation_cuda_kernel.h +++ /dev/null @@ -1,92 +0,0 @@ -#ifdef __cplusplus -extern "C" { -#endif - - int Correlation_forward_cuda_kernel(/*THCudaTensor_data(state, output)*/ float *output, - /*THCudaTensor_size(state, output, 0)*/ int ob, - /*THCudaTensor_size(state, output, 1)*/ int oc, - /*THCudaTensor_size(state, output, 2)*/ int oh, - /*THCudaTensor_size(state, output, 3)*/ int ow, - /*THCudaTensor_stride(state, output, 0)*/ int osb, - /*THCudaTensor_stride(state, output, 1)*/ int osc, - /*THCudaTensor_stride(state, output, 2)*/ int osh, - /*THCudaTensor_stride(state, output, 3)*/ int osw, - - /*THCudaTensor_data(state, input1)*/ float *input1, - /*THCudaTensor_size(state, input1, 1)*/ int ic, - /*THCudaTensor_size(state, input1, 2)*/ int ih, - /*THCudaTensor_size(state, input1, 3)*/ int iw, - /*THCudaTensor_stride(state, input1, 0)*/ int isb, - /*THCudaTensor_stride(state, input1, 1)*/ int isc, - /*THCudaTensor_stride(state, input1, 2)*/ int ish, - /*THCudaTensor_stride(state, input1, 3)*/ int isw, - - /*THCudaTensor_data(state, input2)*/ float *input2, - /*THCudaTensor_size(state, input2, 1)*/ int gc, - /*THCudaTensor_stride(state, input2, 0)*/ int gsb, - /*THCudaTensor_stride(state, input2, 1)*/ int gsc, - /*THCudaTensor_stride(state, input2, 2)*/ int gsh, - /*THCudaTensor_stride(state, input2, 3)*/ int gsw, - - /*THCudaTensor_data(state, rInput1)*/ float *rInput1, - /*THCudaTensor_data(state, rInput2)*/ float *rInput2, - int pad_size, - int kernel_size, - int max_displacement, - int stride1, - int stride2, - int corr_type_multiply, - /*THCState_getCurrentStream(state)*/ cudaStream_t stream); - - int Correlation_backward_cuda_kernel( - /*THCudaTensor_data(state, gradOutput)*/ float *gradOutput, - /*THCudaTensor_size(state, gradOutput, 0)*/ int gob, - /*THCudaTensor_size(state, gradOutput, 1)*/ int goc, - /*THCudaTensor_size(state, gradOutput, 2)*/ int goh, - /*THCudaTensor_size(state, gradOutput, 3)*/ int gow, - /*THCudaTensor_stride(state, gradOutput, 0)*/ int gosb, - /*THCudaTensor_stride(state, gradOutput, 1)*/ int gosc, - /*THCudaTensor_stride(state, gradOutput, 2)*/ int gosh, - /*THCudaTensor_stride(state, gradOutput, 3)*/ int gosw, - - /*THCudaTensor_data(state, input1)*/ float* input1, - /*THCudaTensor_size(state, input1, 1)*/ int ic, - /*THCudaTensor_size(state, input1, 2)*/ int ih, - /*THCudaTensor_size(state, input1, 3)*/ int iw, - 
/*THCudaTensor_stride(state, input1, 0)*/ int isb, - /*THCudaTensor_stride(state, input1, 1)*/ int isc, - /*THCudaTensor_stride(state, input1, 2)*/ int ish, - /*THCudaTensor_stride(state, input1, 3)*/ int isw, - - /*THCudaTensor_data(state, input2)*/ float *input2, - /*THCudaTensor_stride(state, input2, 0)*/ int gsb, - /*THCudaTensor_stride(state, input2, 1)*/ int gsc, - /*THCudaTensor_stride(state, input2, 2)*/ int gsh, - /*THCudaTensor_stride(state, input2, 3)*/ int gsw, - - /*THCudaTensor_data(state, gradInput1)*/ float *gradInput1, - /*THCudaTensor_stride(state, gradInput1, 0)*/ int gisb, - /*THCudaTensor_stride(state, gradInput1, 1)*/ int gisc, - /*THCudaTensor_stride(state, gradInput1, 2)*/ int gish, - /*THCudaTensor_stride(state, gradInput1, 3)*/ int gisw, - - /*THCudaTensor_data(state, gradInput2)*/ float *gradInput2, - /*THCudaTensor_size(state, gradInput2, 1)*/ int ggc, - /*THCudaTensor_stride(state, gradInput2, 0)*/ int ggsb, - /*THCudaTensor_stride(state, gradInput2, 1)*/ int ggsc, - /*THCudaTensor_stride(state, gradInput2, 2)*/ int ggsh, - /*THCudaTensor_stride(state, gradInput2, 3)*/ int ggsw, - - /*THCudaTensor_data(state, rInput1)*/ float *rInput1, - /*THCudaTensor_data(state, rInput2)*/ float *rInput2, - int pad_size, - int kernel_size, - int max_displacement, - int stride1, - int stride2, - int corr_type_multiply, - /*THCState_getCurrentStream(state)*/cudaStream_t stream); - -#ifdef __cplusplus -} -#endif diff --git a/networks/resample2d_package/build.py b/networks/resample2d_package/build.py deleted file mode 100755 index f1de130..0000000 --- a/networks/resample2d_package/build.py +++ /dev/null @@ -1,31 +0,0 @@ -import os -import torch -import torch.utils.ffi - -this_folder = os.path.dirname(os.path.abspath(__file__)) + '/' - -Headers = [] -Sources = [] -Defines = [] -Objects = [] - -if torch.cuda.is_available() == True: - Headers += ['src/Resample2d_cuda.h'] - Sources += ['src/Resample2d_cuda.c'] - Defines += [('WITH_CUDA', None)] - Objects += ['src/Resample2d_kernel.o'] - -ffi = torch.utils.ffi.create_extension( - name='_ext.resample2d', - headers=Headers, - sources=Sources, - verbose=False, - with_cuda=True, - package=False, - relative_to=this_folder, - define_macros=Defines, - extra_objects=[os.path.join(this_folder, Object) for Object in Objects] -) - -if __name__ == '__main__': - ffi.build() \ No newline at end of file diff --git a/networks/resample2d_package/functions/__init__.py b/networks/resample2d_package/functions/__init__.py deleted file mode 100755 index e69de29..0000000 diff --git a/networks/resample2d_package/make.sh b/networks/resample2d_package/make.sh deleted file mode 100755 index cb7d9df..0000000 --- a/networks/resample2d_package/make.sh +++ /dev/null @@ -1,12 +0,0 @@ -#!/usr/bin/env bash -TORCH=$(python3 -c "import os; import torch; print(os.path.dirname(torch.__file__))") - -cd src -echo "Compiling resample2d kernels by nvcc..." 
-rm Resample2d_kernel.o -rm -r ../_ext - -nvcc -c -o Resample2d_kernel.o Resample2d_kernel.cu -x cu -Xcompiler -fPIC -arch=sm_52 -I ${TORCH}/lib/include/TH -I ${TORCH}/lib/include/THC - -cd ../ -python3 build.py diff --git a/networks/resample2d_package/modules/__init__.py b/networks/resample2d_package/modules/__init__.py deleted file mode 100755 index e69de29..0000000 diff --git a/networks/resample2d_package/modules/resample2d.py b/networks/resample2d_package/modules/resample2d.py deleted file mode 100755 index 17de1b9..0000000 --- a/networks/resample2d_package/modules/resample2d.py +++ /dev/null @@ -1,14 +0,0 @@ -from torch.nn.modules.module import Module - -from ..functions.resample2d import Resample2dFunction - - -class Resample2d(Module): - - def __init__(self, kernel_size=1): - super(Resample2d, self).__init__() - self.kernel_size = kernel_size - - def forward(self, input1, input2): - input1_c = input1.contiguous() - return Resample2dFunction.apply(input1_c, input2, self.kernel_size) diff --git a/networks/resample2d_package/functions/resample2d.py b/networks/resample2d_package/resample2d.py similarity index 55% rename from networks/resample2d_package/functions/resample2d.py rename to networks/resample2d_package/resample2d.py index 28a18f6..aa019df 100755 --- a/networks/resample2d_package/functions/resample2d.py +++ b/networks/resample2d_package/resample2d.py @@ -1,6 +1,6 @@ +from torch.nn.modules.module import Module from torch.autograd import Function, Variable -from .._ext import resample2d - +import resample2d_cuda class Resample2dFunction(Function): @@ -16,7 +16,7 @@ def forward(ctx, input1, input2, kernel_size=1): b, _, h, w = input2.size() output = input1.new(b, d, h, w).zero_() - resample2d.Resample2d_cuda_forward(input1, input2, output, kernel_size) + resample2d_cuda.forward(input1, input2, output, kernel_size) return output @@ -29,8 +29,18 @@ def backward(ctx, grad_output): grad_input1 = Variable(input1.new(input1.size()).zero_()) grad_input2 = Variable(input1.new(input2.size()).zero_()) - resample2d.Resample2d_cuda_backward(input1, input2, grad_output.data, - grad_input1.data, grad_input2.data, - ctx.kernel_size) + resample2d_cuda.backward(input1, input2, grad_output.data, + grad_input1.data, grad_input2.data, + ctx.kernel_size) return grad_input1, grad_input2, None + +class Resample2d(Module): + + def __init__(self, kernel_size=1): + super(Resample2d, self).__init__() + self.kernel_size = kernel_size + + def forward(self, input1, input2): + input1_c = input1.contiguous() + return Resample2dFunction.apply(input1_c, input2, self.kernel_size) diff --git a/networks/resample2d_package/resample2d_cuda.cc b/networks/resample2d_package/resample2d_cuda.cc new file mode 100755 index 0000000..8e7269c --- /dev/null +++ b/networks/resample2d_package/resample2d_cuda.cc @@ -0,0 +1,32 @@ +#include +#include + +#include "resample2d_kernel.cuh" + +int resample2d_cuda_forward( + at::Tensor& input1, + at::Tensor& input2, + at::Tensor& output, + int kernel_size) { + resample2d_kernel_forward(input1, input2, output, kernel_size); + return 1; +} + +int resample2d_cuda_backward( + at::Tensor& input1, + at::Tensor& input2, + at::Tensor& gradOutput, + at::Tensor& gradInput1, + at::Tensor& gradInput2, + int kernel_size) { + resample2d_kernel_backward(input1, input2, gradOutput, gradInput1, gradInput2, kernel_size); + return 1; +} + + + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("forward", &resample2d_cuda_forward, "Resample2D forward (CUDA)"); + m.def("backward", &resample2d_cuda_backward, 
"Resample2D backward (CUDA)"); +} + diff --git a/networks/resample2d_package/resample2d_kernel.cu b/networks/resample2d_package/resample2d_kernel.cu new file mode 100755 index 0000000..a93284d --- /dev/null +++ b/networks/resample2d_package/resample2d_kernel.cu @@ -0,0 +1,306 @@ +#include +#include + +#define CUDA_NUM_THREADS 512 +#define THREADS_PER_BLOCK 64 + +#define DIM0(TENSOR) ((TENSOR).x) +#define DIM1(TENSOR) ((TENSOR).y) +#define DIM2(TENSOR) ((TENSOR).z) +#define DIM3(TENSOR) ((TENSOR).w) + +#define DIM3_INDEX(TENSOR, xx, yy, zz, ww) ((TENSOR)[((xx) * (TENSOR##_stride.x)) + ((yy) * (TENSOR##_stride.y)) + ((zz) * (TENSOR##_stride.z)) + ((ww) * (TENSOR##_stride.w))]) + +template +__global__ void kernel_resample2d_update_output(const int n, + const scalar_t* __restrict__ input1, const long4 input1_size, const long4 input1_stride, + const scalar_t* __restrict__ input2, const long4 input2_size, const long4 input2_stride, + scalar_t* __restrict__ output, const long4 output_size, const long4 output_stride, int kernel_size) { + int index = blockIdx.x * blockDim.x + threadIdx.x; + + if (index >= n) { + return; + } + + scalar_t val = 0.0f; + + int dim_b = DIM0(output_size); + int dim_c = DIM1(output_size); + int dim_h = DIM2(output_size); + int dim_w = DIM3(output_size); + int dim_chw = dim_c * dim_h * dim_w; + int dim_hw = dim_h * dim_w; + + int b = ( index / dim_chw ) % dim_b; + int c = ( index / dim_hw ) % dim_c; + int y = ( index / dim_w ) % dim_h; + int x = ( index ) % dim_w; + + scalar_t dx = DIM3_INDEX(input2, b, 0, y, x); + scalar_t dy = DIM3_INDEX(input2, b, 1, y, x); + + scalar_t xf = static_cast(x) + dx; + scalar_t yf = static_cast(y) + dy; + scalar_t alpha = xf - floor(xf); // alpha + scalar_t beta = yf - floor(yf); // beta + + int xL = max(min( int (floor(xf)), dim_w-1), 0); + int xR = max(min( int (floor(xf)+1), dim_w -1), 0); + int yT = max(min( int (floor(yf)), dim_h-1), 0); + int yB = max(min( int (floor(yf)+1), dim_h-1), 0); + + for (int fy = 0; fy < kernel_size; fy += 1) { + for (int fx = 0; fx < kernel_size; fx += 1) { + val += static_cast((1. - alpha)*(1. - beta) * DIM3_INDEX(input1, b, c, yT + fy, xL + fx)); + val += static_cast((alpha)*(1. - beta) * DIM3_INDEX(input1, b, c, yT + fy, xR + fx)); + val += static_cast((1. 
- alpha)*(beta) * DIM3_INDEX(input1, b, c, yB + fy, xL + fx)); + val += static_cast((alpha)*(beta) * DIM3_INDEX(input1, b, c, yB + fy, xR + fx)); + } + } + + output[index] = val; + +} + + +template +__global__ void kernel_resample2d_backward_input1( + const int n, const scalar_t* __restrict__ input1, const long4 input1_size, const long4 input1_stride, + const scalar_t* __restrict__ input2, const long4 input2_size, const long4 input2_stride, + const scalar_t* __restrict__ gradOutput, const long4 gradOutput_size, const long4 gradOutput_stride, + scalar_t* __restrict__ gradInput, const long4 gradInput_size, const long4 gradInput_stride, int kernel_size) { + + int index = blockIdx.x * blockDim.x + threadIdx.x; + + if (index >= n) { + return; + } + + int dim_b = DIM0(gradOutput_size); + int dim_c = DIM1(gradOutput_size); + int dim_h = DIM2(gradOutput_size); + int dim_w = DIM3(gradOutput_size); + int dim_chw = dim_c * dim_h * dim_w; + int dim_hw = dim_h * dim_w; + + int b = ( index / dim_chw ) % dim_b; + int c = ( index / dim_hw ) % dim_c; + int y = ( index / dim_w ) % dim_h; + int x = ( index ) % dim_w; + + scalar_t dx = DIM3_INDEX(input2, b, 0, y, x); + scalar_t dy = DIM3_INDEX(input2, b, 1, y, x); + + scalar_t xf = static_cast(x) + dx; + scalar_t yf = static_cast(y) + dy; + scalar_t alpha = xf - int(xf); // alpha + scalar_t beta = yf - int(yf); // beta + + int idim_h = DIM2(input1_size); + int idim_w = DIM3(input1_size); + + int xL = max(min( int (floor(xf)), idim_w-1), 0); + int xR = max(min( int (floor(xf)+1), idim_w -1), 0); + int yT = max(min( int (floor(yf)), idim_h-1), 0); + int yB = max(min( int (floor(yf)+1), idim_h-1), 0); + + for (int fy = 0; fy < kernel_size; fy += 1) { + for (int fx = 0; fx < kernel_size; fx += 1) { + atomicAdd(&DIM3_INDEX(gradInput, b, c, (yT + fy), (xL + fx)), (1-alpha)*(1-beta) * DIM3_INDEX(gradOutput, b, c, y, x)); + atomicAdd(&DIM3_INDEX(gradInput, b, c, (yT + fy), (xR + fx)), (alpha)*(1-beta) * DIM3_INDEX(gradOutput, b, c, y, x)); + atomicAdd(&DIM3_INDEX(gradInput, b, c, (yB + fy), (xL + fx)), (1-alpha)*(beta) * DIM3_INDEX(gradOutput, b, c, y, x)); + atomicAdd(&DIM3_INDEX(gradInput, b, c, (yB + fy), (xR + fx)), (alpha)*(beta) * DIM3_INDEX(gradOutput, b, c, y, x)); + } + } + +} + +template +__global__ void kernel_resample2d_backward_input2( + const int n, const scalar_t* __restrict__ input1, const long4 input1_size, const long4 input1_stride, + const scalar_t* __restrict__ input2, const long4 input2_size, const long4 input2_stride, + const scalar_t* __restrict__ gradOutput, const long4 gradOutput_size, const long4 gradOutput_stride, + scalar_t* __restrict__ gradInput, const long4 gradInput_size, const long4 gradInput_stride, int kernel_size) { + + int index = blockIdx.x * blockDim.x + threadIdx.x; + + if (index >= n) { + return; + } + + scalar_t output = 0.0; + int kernel_rad = (kernel_size - 1)/2; + + int dim_b = DIM0(gradInput_size); + int dim_c = DIM1(gradInput_size); + int dim_h = DIM2(gradInput_size); + int dim_w = DIM3(gradInput_size); + int dim_chw = dim_c * dim_h * dim_w; + int dim_hw = dim_h * dim_w; + + int b = ( index / dim_chw ) % dim_b; + int c = ( index / dim_hw ) % dim_c; + int y = ( index / dim_w ) % dim_h; + int x = ( index ) % dim_w; + + int odim_c = DIM1(gradOutput_size); + + scalar_t dx = DIM3_INDEX(input2, b, 0, y, x); + scalar_t dy = DIM3_INDEX(input2, b, 1, y, x); + + scalar_t xf = static_cast(x) + dx; + scalar_t yf = static_cast(y) + dy; + + int xL = max(min( int (floor(xf)), dim_w-1), 0); + int xR = max(min( int (floor(xf)+1), dim_w 
-1), 0); + int yT = max(min( int (floor(yf)), dim_h-1), 0); + int yB = max(min( int (floor(yf)+1), dim_h-1), 0); + + if (c % 2) { + float gamma = 1 - (xf - floor(xf)); // alpha + for (int i = 0; i <= 2*kernel_rad; ++i) { + for (int j = 0; j <= 2*kernel_rad; ++j) { + for (int ch = 0; ch < odim_c; ++ch) { + output += (gamma) * DIM3_INDEX(gradOutput, b, ch, y, x) * DIM3_INDEX(input1, b, ch, (yB + j), (xL + i)); + output -= (gamma) * DIM3_INDEX(gradOutput, b, ch, y, x) * DIM3_INDEX(input1, b, ch, (yT + j), (xL + i)); + output += (1-gamma) * DIM3_INDEX(gradOutput, b, ch, y, x) * DIM3_INDEX(input1, b, ch, (yB + j), (xR + i)); + output -= (1-gamma) * DIM3_INDEX(gradOutput, b, ch, y, x) * DIM3_INDEX(input1, b, ch, (yT + j), (xR + i)); + } + } + } + } + else { + float gamma = 1 - (yf - floor(yf)); // alpha + for (int i = 0; i <= 2*kernel_rad; ++i) { + for (int j = 0; j <= 2*kernel_rad; ++j) { + for (int ch = 0; ch < odim_c; ++ch) { + output += (gamma) * DIM3_INDEX(gradOutput, b, ch, y, x) * DIM3_INDEX(input1, b, ch, (yT + j), (xR + i)); + output -= (gamma) * DIM3_INDEX(gradOutput, b, ch, y, x) * DIM3_INDEX(input1, b, ch, (yT + j), (xL + i)); + output += (1-gamma) * DIM3_INDEX(gradOutput, b, ch, y, x) * DIM3_INDEX(input1, b, ch, (yB + j), (xR + i)); + output -= (1-gamma) * DIM3_INDEX(gradOutput, b, ch, y, x) * DIM3_INDEX(input1, b, ch, (yB + j), (xL + i)); + } + } + } + + } + + gradInput[index] = output; + +} + +void resample2d_kernel_forward( + at::Tensor& input1, + at::Tensor& input2, + at::Tensor& output, + int kernel_size) { + + int n = output.numel(); + + const long4 input1_size = make_long4(input1.size(0), input1.size(1), input1.size(2), input1.size(3)); + const long4 input1_stride = make_long4(input1.stride(0), input1.stride(1), input1.stride(2), input1.stride(3)); + + const long4 input2_size = make_long4(input2.size(0), input2.size(1), input2.size(2), input2.size(3)); + const long4 input2_stride = make_long4(input2.stride(0), input2.stride(1), input2.stride(2), input2.stride(3)); + + const long4 output_size = make_long4(output.size(0), output.size(1), output.size(2), output.size(3)); + const long4 output_stride = make_long4(output.stride(0), output.stride(1), output.stride(2), output.stride(3)); + + // TODO: when atomicAdd gets resolved, change to AT_DISPATCH_FLOATING_TYPES_AND_HALF +// AT_DISPATCH_FLOATING_TYPES(input1.type(), "resample_forward_kernel", ([&] { + + kernel_resample2d_update_output<<< (n + CUDA_NUM_THREADS - 1)/CUDA_NUM_THREADS, CUDA_NUM_THREADS, 0, at::globalContext().getCurrentCUDAStream() >>>( + n, + input1.data(), + input1_size, + input1_stride, + input2.data(), + input2_size, + input2_stride, + output.data(), + output_size, + output_stride, + kernel_size); + +// })); + + // TODO: ATen-equivalent check + + // THCudaCheck(cudaGetLastError()); + +} + +void resample2d_kernel_backward( + at::Tensor& input1, + at::Tensor& input2, + at::Tensor& gradOutput, + at::Tensor& gradInput1, + at::Tensor& gradInput2, + int kernel_size) { + + int n = gradOutput.numel(); + + const long4 input1_size = make_long4(input1.size(0), input1.size(1), input1.size(2), input1.size(3)); + const long4 input1_stride = make_long4(input1.stride(0), input1.stride(1), input1.stride(2), input1.stride(3)); + + const long4 input2_size = make_long4(input2.size(0), input2.size(1), input2.size(2), input2.size(3)); + const long4 input2_stride = make_long4(input2.stride(0), input2.stride(1), input2.stride(2), input2.stride(3)); + + const long4 gradOutput_size = make_long4(gradOutput.size(0), gradOutput.size(1), 
gradOutput.size(2), gradOutput.size(3)); + const long4 gradOutput_stride = make_long4(gradOutput.stride(0), gradOutput.stride(1), gradOutput.stride(2), gradOutput.stride(3)); + + const long4 gradInput1_size = make_long4(gradInput1.size(0), gradInput1.size(1), gradInput1.size(2), gradInput1.size(3)); + const long4 gradInput1_stride = make_long4(gradInput1.stride(0), gradInput1.stride(1), gradInput1.stride(2), gradInput1.stride(3)); + +// AT_DISPATCH_FLOATING_TYPES(input1.type(), "resample_backward_input1", ([&] { + + kernel_resample2d_backward_input1<<< (n + CUDA_NUM_THREADS - 1)/CUDA_NUM_THREADS, CUDA_NUM_THREADS, 0, at::globalContext().getCurrentCUDAStream() >>>( + n, + input1.data(), + input1_size, + input1_stride, + input2.data(), + input2_size, + input2_stride, + gradOutput.data(), + gradOutput_size, + gradOutput_stride, + gradInput1.data(), + gradInput1_size, + gradInput1_stride, + kernel_size + ); + +// })); + + const long4 gradInput2_size = make_long4(gradInput2.size(0), gradInput2.size(1), gradInput2.size(2), gradInput2.size(3)); + const long4 gradInput2_stride = make_long4(gradInput2.stride(0), gradInput2.stride(1), gradInput2.stride(2), gradInput2.stride(3)); + + n = gradInput2.numel(); + +// AT_DISPATCH_FLOATING_TYPES(gradInput2.type(), "resample_backward_input2", ([&] { + + + kernel_resample2d_backward_input2<<< (n + CUDA_NUM_THREADS - 1)/CUDA_NUM_THREADS, CUDA_NUM_THREADS, 0, at::globalContext().getCurrentCUDAStream() >>>( + n, + input1.data(), + input1_size, + input1_stride, + input2.data(), + input2_size, + input2_stride, + gradOutput.data(), + gradOutput_size, + gradOutput_stride, + gradInput2.data(), + gradInput2_size, + gradInput2_stride, + kernel_size + ); + +// })); + + // TODO: Use the ATen equivalent to get last error + + // THCudaCheck(cudaGetLastError()); + +} diff --git a/networks/resample2d_package/resample2d_kernel.cuh b/networks/resample2d_package/resample2d_kernel.cuh new file mode 100755 index 0000000..d20d10a --- /dev/null +++ b/networks/resample2d_package/resample2d_kernel.cuh @@ -0,0 +1,17 @@ +#pragma once + +#include + +void resample2d_kernel_forward( + at::Tensor& input1, + at::Tensor& input2, + at::Tensor& output, + int kernel_size); + +void resample2d_kernel_backward( + at::Tensor& input1, + at::Tensor& input2, + at::Tensor& gradOutput, + at::Tensor& gradInput1, + at::Tensor& gradInput2, + int kernel_size); diff --git a/networks/resample2d_package/setup.py b/networks/resample2d_package/setup.py new file mode 100755 index 0000000..bbedb25 --- /dev/null +++ b/networks/resample2d_package/setup.py @@ -0,0 +1,29 @@ +#!/usr/bin/env python3 +import os +import torch + +from setuptools import setup +from torch.utils.cpp_extension import BuildExtension, CUDAExtension + +cxx_args = ['-std=c++11'] + +nvcc_args = [ + '-gencode', 'arch=compute_50,code=sm_50', + '-gencode', 'arch=compute_52,code=sm_52', + '-gencode', 'arch=compute_60,code=sm_60', + '-gencode', 'arch=compute_61,code=sm_61', + '-gencode', 'arch=compute_70,code=sm_70', + '-gencode', 'arch=compute_70,code=compute_70' +] + +setup( + name='resample2d_cuda', + ext_modules=[ + CUDAExtension('resample2d_cuda', [ + 'resample2d_cuda.cc', + 'resample2d_kernel.cu' + ], extra_compile_args={'cxx': cxx_args, 'nvcc': nvcc_args}) + ], + cmdclass={ + 'build_ext': BuildExtension + }) diff --git a/networks/resample2d_package/src/Resample2d_cuda.c b/networks/resample2d_package/src/Resample2d_cuda.c deleted file mode 100755 index 878c8d6..0000000 --- a/networks/resample2d_package/src/Resample2d_cuda.c +++ /dev/null @@ -1,17 
+0,0 @@ -#include -#include - -#include "Resample2d_kernel.h" - -extern THCState* state; - -int Resample2d_cuda_forward(THCudaTensor* input1, THCudaTensor* input2, THCudaTensor* output, int kernel_size) { - Resample2d_kernel_forward(state, input1, input2, output, kernel_size); - return 1; -} - -int Resample2d_cuda_backward(THCudaTensor* input1, THCudaTensor* input2, THCudaTensor* gradOutput, THCudaTensor* gradInput1, THCudaTensor* gradInput2, int kernel_size) { - Resample2d_kernel_backward(state, input1, input2, gradOutput, gradInput1, gradInput2, kernel_size); - - return 1; -} \ No newline at end of file diff --git a/networks/resample2d_package/src/Resample2d_cuda.h b/networks/resample2d_package/src/Resample2d_cuda.h deleted file mode 100755 index 2edfb8d..0000000 --- a/networks/resample2d_package/src/Resample2d_cuda.h +++ /dev/null @@ -1,3 +0,0 @@ -int Resample2d_cuda_forward(THCudaTensor* input1, THCudaTensor* input2, THCudaTensor* output, int kernel_size); - -int Resample2d_cuda_backward(THCudaTensor* input1, THCudaTensor* input2, THCudaTensor* gradOutput, THCudaTensor* gradInput1, THCudaTensor* gradInput2, int kernel_size); \ No newline at end of file diff --git a/networks/resample2d_package/src/Resample2d_kernel.cu b/networks/resample2d_package/src/Resample2d_kernel.cu deleted file mode 100755 index 095893d..0000000 --- a/networks/resample2d_package/src/Resample2d_kernel.cu +++ /dev/null @@ -1,242 +0,0 @@ -#include -#include -#include -#include - -#define CUDA_NUM_THREADS 512 -#define THREADS_PER_BLOCK 64 - -#define DIM0(TENSOR) ((TENSOR).x) -#define DIM1(TENSOR) ((TENSOR).y) -#define DIM2(TENSOR) ((TENSOR).z) -#define DIM3(TENSOR) ((TENSOR).w) - -#define DIM3_INDEX(TENSOR, xx, yy, zz, ww) ((TENSOR)[((xx) * (TENSOR##_stride.x)) + ((yy) * (TENSOR##_stride.y)) + ((zz) * (TENSOR##_stride.z)) + ((ww) * (TENSOR##_stride.w))]) - -#ifdef __cplusplus - extern "C" { -#endif - -__global__ void kernel_Resample2d_updateOutput(const int n, const float* input1, const long4 input1_size, const long4 input1_stride, - const float* input2, const long4 input2_size, const long4 input2_stride, float* output, const long4 output_size, const long4 output_stride, int kernel_size) { - int index = blockIdx.x * blockDim.x + threadIdx.x; - - if (index >= n) { - return; - } - - float val = 0.0; - - int dim_b = DIM0(output_size); - int dim_c = DIM1(output_size); - int dim_h = DIM2(output_size); - int dim_w = DIM3(output_size); - int dim_chw = dim_c * dim_h * dim_w; - int dim_hw = dim_h * dim_w; - - int b = ( index / dim_chw ) % dim_b; - int c = ( index / dim_hw ) % dim_c; - int y = ( index / dim_w ) % dim_h; - int x = ( index ) % dim_w; - - float dx = DIM3_INDEX(input2, b, 0, y, x); - float dy = DIM3_INDEX(input2, b, 1, y, x); - - float xf = float(x) + dx; - float yf = float(y) + dy; - float alpha = xf - floor(xf); // alpha - float beta = yf - floor(yf); // beta - - int xL = max(min( int (floor(xf)), dim_w-1), 0); - int xR = max(min( int (floor(xf)+1), dim_w -1), 0); - int yT = max(min( int (floor(yf)), dim_h-1), 0); - int yB = max(min( int (floor(yf)+1), dim_h-1), 0); - - for (int fy = 0; fy < kernel_size; fy += 1) { - for (int fx = 0; fx < kernel_size; fx += 1) { - val += (1. - alpha)*(1. - beta) * DIM3_INDEX(input1, b, c, yT + fy, xL + fx); - val += (alpha)*(1. - beta) * DIM3_INDEX(input1, b, c, yT + fy, xR + fx); - val += (1. 
- alpha)*(beta) * DIM3_INDEX(input1, b, c, yB + fy, xL + fx); - val += (alpha)*(beta) * DIM3_INDEX(input1, b, c, yB + fy, xR + fx); - } - } - - output[index] = val; - -} - - -__global__ void kernel_Resample2d_backward_input1( - const int n, const float* input1, const long4 input1_size, const long4 input1_stride, const float* input2, const long4 input2_size, const long4 input2_stride, - const float* gradOutput, const long4 gradOutput_size, const long4 gradOutput_stride, float* gradInput, const long4 gradInput_size, const long4 gradInput_stride, int kernel_size) { - - int index = blockIdx.x * blockDim.x + threadIdx.x; - - if (index >= n) { - return; - } - - int dim_b = DIM0(gradOutput_size); - int dim_c = DIM1(gradOutput_size); - int dim_h = DIM2(gradOutput_size); - int dim_w = DIM3(gradOutput_size); - int dim_chw = dim_c * dim_h * dim_w; - int dim_hw = dim_h * dim_w; - - int b = ( index / dim_chw ) % dim_b; - int c = ( index / dim_hw ) % dim_c; - int y = ( index / dim_w ) % dim_h; - int x = ( index ) % dim_w; - - float dx = DIM3_INDEX(input2, b, 0, y, x); - float dy = DIM3_INDEX(input2, b, 1, y, x); - - float xf = float(x) + dx; - float yf = float(y) + dy; - float alpha = xf - int(xf); // alpha - float beta = yf - int(yf); // beta - - int idim_h = DIM2(input1_size); - int idim_w = DIM3(input1_size); - - int xL = max(min( int (floor(xf)), idim_w-1), 0); - int xR = max(min( int (floor(xf)+1), idim_w -1), 0); - int yT = max(min( int (floor(yf)), idim_h-1), 0); - int yB = max(min( int (floor(yf)+1), idim_h-1), 0); - - for (int fy = 0; fy < kernel_size; fy += 1) { - for (int fx = 0; fx < kernel_size; fx += 1) { - atomicAdd(&DIM3_INDEX(gradInput, b, c, (yT + fy), (xL + fx)), (1-alpha)*(1-beta) * DIM3_INDEX(gradOutput, b, c, y, x)); - atomicAdd(&DIM3_INDEX(gradInput, b, c, (yT + fy), (xR + fx)), (alpha)*(1-beta) * DIM3_INDEX(gradOutput, b, c, y, x)); - atomicAdd(&DIM3_INDEX(gradInput, b, c, (yB + fy), (xL + fx)), (1-alpha)*(beta) * DIM3_INDEX(gradOutput, b, c, y, x)); - atomicAdd(&DIM3_INDEX(gradInput, b, c, (yB + fy), (xR + fx)), (alpha)*(beta) * DIM3_INDEX(gradOutput, b, c, y, x)); - } - } - -} - -__global__ void kernel_Resample2d_backward_input2( - const int n, const float* input1, const long4 input1_size, const long4 input1_stride, const float* input2, const long4 input2_size, const long4 input2_stride, - const float* gradOutput, const long4 gradOutput_size, const long4 gradOutput_stride, float* gradInput, const long4 gradInput_size, const long4 gradInput_stride, int kernel_size) { - - int index = blockIdx.x * blockDim.x + threadIdx.x; - - if (index >= n) { - return; - } - - float output = 0.0; - int kernel_rad = (kernel_size - 1)/2; - - int dim_b = DIM0(gradInput_size); - int dim_c = DIM1(gradInput_size); - int dim_h = DIM2(gradInput_size); - int dim_w = DIM3(gradInput_size); - int dim_chw = dim_c * dim_h * dim_w; - int dim_hw = dim_h * dim_w; - - int b = ( index / dim_chw ) % dim_b; - int c = ( index / dim_hw ) % dim_c; - int y = ( index / dim_w ) % dim_h; - int x = ( index ) % dim_w; - - int odim_c = DIM1(gradOutput_size); - - float dx = DIM3_INDEX(input2, b, 0, y, x); - float dy = DIM3_INDEX(input2, b, 1, y, x); - - float xf = float(x) + dx; - float yf = float(y) + dy; - - int xL = max(min( int (floor(xf)), dim_w-1), 0); - int xR = max(min( int (floor(xf)+1), dim_w -1), 0); - int yT = max(min( int (floor(yf)), dim_h-1), 0); - int yB = max(min( int (floor(yf)+1), dim_h-1), 0); - - if (c % 2) { - float gamma = 1 - (xf - floor(xf)); // alpha - for (int i = 0; i <= 2*kernel_rad; ++i) { - for 
(int j = 0; j <= 2*kernel_rad; ++j) { - for (int ch = 0; ch < odim_c; ++ch) { - output += (gamma) * DIM3_INDEX(gradOutput, b, ch, y, x) * DIM3_INDEX(input1, b, ch, (yB + j), (xL + i)); - output -= (gamma) * DIM3_INDEX(gradOutput, b, ch, y, x) * DIM3_INDEX(input1, b, ch, (yT + j), (xL + i)); - output += (1-gamma) * DIM3_INDEX(gradOutput, b, ch, y, x) * DIM3_INDEX(input1, b, ch, (yB + j), (xR + i)); - output -= (1-gamma) * DIM3_INDEX(gradOutput, b, ch, y, x) * DIM3_INDEX(input1, b, ch, (yT + j), (xR + i)); - } - } - } - } - else { - float gamma = 1 - (yf - floor(yf)); // alpha - for (int i = 0; i <= 2*kernel_rad; ++i) { - for (int j = 0; j <= 2*kernel_rad; ++j) { - for (int ch = 0; ch < odim_c; ++ch) { - output += (gamma) * DIM3_INDEX(gradOutput, b, ch, y, x) * DIM3_INDEX(input1, b, ch, (yT + j), (xR + i)); - output -= (gamma) * DIM3_INDEX(gradOutput, b, ch, y, x) * DIM3_INDEX(input1, b, ch, (yT + j), (xL + i)); - output += (1-gamma) * DIM3_INDEX(gradOutput, b, ch, y, x) * DIM3_INDEX(input1, b, ch, (yB + j), (xR + i)); - output -= (1-gamma) * DIM3_INDEX(gradOutput, b, ch, y, x) * DIM3_INDEX(input1, b, ch, (yB + j), (xL + i)); - } - } - } - - } - - gradInput[index] = output; - -} - -void Resample2d_kernel_forward(THCState* state, THCudaTensor* input1, THCudaTensor* input2, THCudaTensor* output, int kernel_size) { - int n = 0; - - const long4 input1_size = make_long4(input1->size[0], input1->size[1], input1->size[2], input1->size[3]); - const long4 input1_stride = make_long4(input1->stride[0], input1->stride[1], input1->stride[2], input1->stride[3]); - - const long4 input2_size = make_long4(input2->size[0], input2->size[1], input2->size[2], input2->size[3]); - const long4 input2_stride = make_long4(input2->stride[0], input2->stride[1], input2->stride[2], input2->stride[3]); - - const long4 output_size = make_long4(output->size[0], output->size[1], output->size[2], output->size[3]); - const long4 output_stride = make_long4(output->stride[0], output->stride[1], output->stride[2], output->stride[3]); - - n = THCudaTensor_nElement(state, output); - kernel_Resample2d_updateOutput<<< (n + CUDA_NUM_THREADS - 1)/CUDA_NUM_THREADS, CUDA_NUM_THREADS, 0, THCState_getCurrentStream(state) >>>( - n, THCudaTensor_data(state, input1), input1_size, input1_stride, THCudaTensor_data(state, input2), input2_size, input2_stride, - THCudaTensor_data(state, output), output_size, output_stride, kernel_size); - - THCudaCheck(cudaGetLastError()); -} - -void Resample2d_kernel_backward(THCState* state, THCudaTensor* input1, THCudaTensor* input2, THCudaTensor* gradOutput, THCudaTensor* gradInput1, THCudaTensor* gradInput2, int kernel_size) { - int n = 0; - - const long4 input1_size = make_long4(input1->size[0], input1->size[1], input1->size[2], input1->size[3]); - const long4 input1_stride = make_long4(input1->stride[0], input1->stride[1], input1->stride[2], input1->stride[3]); - - const long4 input2_size = make_long4(input2->size[0], input2->size[1], input2->size[2], input2->size[3]); - const long4 input2_stride = make_long4(input2->stride[0], input2->stride[1], input2->stride[2], input2->stride[3]); - - const long4 gradOutput_size = make_long4(gradOutput->size[0], gradOutput->size[1], gradOutput->size[2], gradOutput->size[3]); - const long4 gradOutput_stride = make_long4(gradOutput->stride[0], gradOutput->stride[1], gradOutput->stride[2], gradOutput->stride[3]); - - const long4 gradInput1_size = make_long4(gradInput1->size[0], gradInput1->size[1], gradInput1->size[2], gradInput1->size[3]); - const long4 gradInput1_stride 
= make_long4(gradInput1->stride[0], gradInput1->stride[1], gradInput1->stride[2], gradInput1->stride[3]); - - n = THCudaTensor_nElement(state, gradOutput); - kernel_Resample2d_backward_input1<<< (n + CUDA_NUM_THREADS - 1)/CUDA_NUM_THREADS, CUDA_NUM_THREADS, 0, THCState_getCurrentStream(state) >>>( - n, THCudaTensor_data(state, input1), input1_size, input1_stride, THCudaTensor_data(state, input2), input2_size, input2_stride, - THCudaTensor_data(state, gradOutput), gradOutput_size, gradOutput_stride, THCudaTensor_data(state, gradInput1), gradInput1_size, gradInput1_stride, kernel_size - ); - - const long4 gradInput2_size = make_long4(gradInput2->size[0], gradInput2->size[1], gradInput2->size[2], gradInput2->size[3]); - const long4 gradInput2_stride = make_long4(gradInput2->stride[0], gradInput2->stride[1], gradInput2->stride[2], gradInput2->stride[3]); - - n = THCudaTensor_nElement(state, gradInput2); - kernel_Resample2d_backward_input2<<< (n + CUDA_NUM_THREADS - 1)/CUDA_NUM_THREADS, CUDA_NUM_THREADS, 0, THCState_getCurrentStream(state) >>>( - n, THCudaTensor_data(state, input1), input1_size, input1_stride, THCudaTensor_data(state, input2), input2_size, input2_stride, - THCudaTensor_data(state, gradOutput), gradOutput_size, gradOutput_stride, THCudaTensor_data(state, gradInput2), gradInput2_size, gradInput2_stride, kernel_size - ); - THCudaCheck(cudaGetLastError()); -} - -#ifdef __cplusplus - } -#endif \ No newline at end of file diff --git a/networks/resample2d_package/src/Resample2d_kernel.h b/networks/resample2d_package/src/Resample2d_kernel.h deleted file mode 100755 index 1b8949a..0000000 --- a/networks/resample2d_package/src/Resample2d_kernel.h +++ /dev/null @@ -1,11 +0,0 @@ -#ifdef __cplusplus - extern "C" { -#endif - -void Resample2d_kernel_forward(THCState* state, THCudaTensor* input1, THCudaTensor* input2, THCudaTensor* output, int kernel_size); - -void Resample2d_kernel_backward(THCState* state, THCudaTensor* input1, THCudaTensor* input2, THCudaTensor* gradOutput, THCudaTensor* gradInput1, THCudaTensor* gradInput2, int kernel_size); - -#ifdef __cplusplus - } -#endif \ No newline at end of file
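
---
Usage sketch (illustrative only, not part of the patch): the snippet below shows how the rebuilt resample2d extension is exercised after `./install.sh` has run `python setup.py install` in each `*_package` directory on a machine with a CUDA device. The `Resample2d` interface is the one defined in the new `networks/resample2d_package/resample2d.py` added above; the tensor shapes and the gradient check are hypothetical and only meant to demonstrate the forward/backward path into `resample2d_cuda`.

    # Hypothetical smoke test for the ATen-based resample2d extension.
    # Assumes the extensions were installed via ./install.sh and CUDA is available.
    import torch
    from networks.resample2d_package.resample2d import Resample2d

    resample = Resample2d(kernel_size=1)

    # input1: feature map to warp; input2: two-channel flow field (dx, dy).
    input1 = torch.randn(1, 3, 64, 64, device='cuda', requires_grad=True)
    input2 = torch.randn(1, 2, 64, 64, device='cuda', requires_grad=True).contiguous()

    warped = resample(input1, input2)   # dispatches to resample2d_cuda.forward
    warped.sum().backward()             # dispatches to resample2d_cuda.backward
    print(warped.shape, input1.grad.shape, input2.grad.shape)

The correlation and channelnorm packages are built and invoked the same way through their respective `setup.py` scripts and Python wrappers.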