Commit

Merge pull request #77 from mkolod/python36
Make FlowNet2 work with PyTorch 0.4.1
fitsumreda authored Aug 23, 2018
2 parents 3d6db9f + 12f794c commit cf5a3eb
Showing 47 changed files with 1,204 additions and 1,166 deletions.
install.sh — 6 changes: 3 additions & 3 deletions
@@ -1,8 +1,8 @@
 #!/bin/bash
 cd ./networks/correlation_package
-./make.sh
+python setup.py install
 cd ../resample2d_package
-./make.sh
+python setup.py install
 cd ../channelnorm_package
-./make.sh
+python setup.py install
 cd ..
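
With the make.sh-based builds gone, each op package is now installed as a regular PyTorch C++/CUDA extension. A minimal post-install sanity check (a sketch; channelnorm_cuda is the module name used later on this page, while correlation_cuda and resample2d_cuda are assumed names for the sibling packages not shown here):

# Sanity check after running install.sh: the compiled extension modules
# should import. Names other than channelnorm_cuda are assumptions based
# on the sibling packages, which are not shown in this truncated listing.
import importlib
import torch

for name in ("channelnorm_cuda", "correlation_cuda", "resample2d_cuda"):
    try:
        importlib.import_module(name)
        print(name, "OK")
    except ImportError as exc:
        print(name, "missing:", exc)

print("CUDA available:", torch.cuda.is_available())
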
models.py — 4 changes: 2 additions & 2 deletions
@@ -5,8 +5,8 @@
 import math
 import numpy as np
 
-from networks.resample2d_package.modules.resample2d import Resample2d
-from networks.channelnorm_package.modules.channelnorm import ChannelNorm
+from networks.resample2d_package.resample2d import Resample2d
+from networks.channelnorm_package.channelnorm import ChannelNorm
 
 from networks import FlowNetC
 from networks import FlowNetS
networks/FlowNetC.py — 2 changes: 1 addition & 1 deletion
@@ -5,7 +5,7 @@
 import math
 import numpy as np
 
-from .correlation_package.modules.correlation import Correlation
+from .correlation_package.correlation import Correlation
 
 from .submodules import *
 'Parameter count , 39,175,298 '
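
The same flattening applies across the op packages: the old functions/ and modules/ submodules collapse into a single Python file per op. A quick import check of the new paths (a sketch; run from the repository root so the networks package resolves):

# New import paths after this PR; the old ones went through
# *_package.modules.* and *_package.functions.*.
from networks.channelnorm_package.channelnorm import ChannelNorm
from networks.correlation_package.correlation import Correlation
from networks.resample2d_package.resample2d import Resample2d

print(ChannelNorm, Correlation, Resample2d)
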
networks/channelnorm_package/build.py — 31 changes: 0 additions & 31 deletions

This file was deleted.

networks/channelnorm_package/channelnorm.py

@@ -1,6 +1,6 @@
 from torch.autograd import Function, Variable
-from .._ext import channelnorm
-
+from torch.nn.modules.module import Module
+import channelnorm_cuda
 
 class ChannelNormFunction(Function):
 
@@ -10,7 +10,7 @@ def forward(ctx, input1, norm_deg=2):
         b, _, h, w = input1.size()
         output = input1.new(b, 1, h, w).zero_()
 
-        channelnorm.ChannelNorm_cuda_forward(input1, output, norm_deg)
+        channelnorm_cuda.forward(input1, output, norm_deg)
         ctx.save_for_backward(input1, output)
         ctx.norm_deg = norm_deg
 
@@ -22,7 +22,18 @@ def backward(ctx, grad_output):
 
         grad_input1 = Variable(input1.new(input1.size()).zero_())
 
-        channelnorm.ChannelNorm_cuda_backward(input1, output, grad_output.data,
+        channelnorm.backward(input1, output, grad_output.data,
                 grad_input1.data, ctx.norm_deg)
 
         return grad_input1, None
+
+
+class ChannelNorm(Module):
+
+    def __init__(self, norm_deg=2):
+        super(ChannelNorm, self).__init__()
+        self.norm_deg = norm_deg
+
+    def forward(self, input1):
+        return ChannelNormFunction.apply(input1, self.norm_deg)
+
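For reference, a minimal usage sketch of the rewritten op (assumes the extension is built and a CUDA device is available; the module reduces a (B, C, H, W) tensor to a (B, 1, H, W) map of per-pixel L2 norms over the channels):

import torch
from networks.channelnorm_package.channelnorm import ChannelNorm

x = torch.randn(2, 3, 8, 8, device="cuda").contiguous()
y = ChannelNorm(norm_deg=2)(x)           # -> shape (2, 1, 8, 8)

# Compare against the same reduction in plain PyTorch.
ref = x.pow(2).sum(dim=1, keepdim=True).sqrt()
print(y.shape, torch.allclose(y, ref, atol=1e-5))
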
networks/channelnorm_package/channelnorm_cuda.cc — 31 changes: 31 additions & 0 deletions
@@ -0,0 +1,31 @@
#include <torch/torch.h>
#include <ATen/ATen.h>

#include "channelnorm_kernel.cuh"

int channelnorm_cuda_forward(
at::Tensor& input1,
at::Tensor& output,
int norm_deg) {

channelnorm_kernel_forward(input1, output, norm_deg);
return 1;
}


int channelnorm_cuda_backward(
at::Tensor& input1,
at::Tensor& output,
at::Tensor& gradOutput,
at::Tensor& gradInput1,
int norm_deg) {

channelnorm_kernel_backward(input1, output, gradOutput, gradInput1, norm_deg);
return 1;
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward", &channelnorm_cuda_forward, "Channel norm forward (CUDA)");
m.def("backward", &channelnorm_cuda_backward, "Channel norm backward (CUDA)");
}

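The pybind11 module exports forward and backward under the extension name channelnorm_cuda, which is what the Python wrapper above calls. Both entry points take preallocated tensors and fill them in place; a direct-call sketch (assuming the extension has been built and imported on a CUDA machine):

import torch
import channelnorm_cuda  # built from this .cc file plus the kernel below

x = torch.randn(1, 4, 16, 16, device="cuda").contiguous()
out = x.new(1, 1, 16, 16).zero_()        # output buffer, written in place

channelnorm_cuda.forward(x, out, 2)      # returns 1; result lives in `out`
print(out.shape)
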
networks/channelnorm_package/channelnorm_kernel.cu — 174 changes: 174 additions & 0 deletions
@@ -0,0 +1,174 @@
#include <ATen/ATen.h>
#include <ATen/Context.h>

#include "channelnorm_kernel.cuh"

#define CUDA_NUM_THREADS 512

#define DIM0(TENSOR) ((TENSOR).x)
#define DIM1(TENSOR) ((TENSOR).y)
#define DIM2(TENSOR) ((TENSOR).z)
#define DIM3(TENSOR) ((TENSOR).w)

#define DIM3_INDEX(TENSOR, xx, yy, zz, ww) ((TENSOR)[((xx) * (TENSOR##_stride.x)) + ((yy) * (TENSOR##_stride.y)) + ((zz) * (TENSOR##_stride.z)) + ((ww) * (TENSOR##_stride.w))])

using at::Half;

template <typename scalar_t>
__global__ void kernel_channelnorm_update_output(
const int n,
const scalar_t* __restrict__ input1,
const long4 input1_size,
const long4 input1_stride,
scalar_t* __restrict__ output,
const long4 output_size,
const long4 output_stride,
int norm_deg) {

int index = blockIdx.x * blockDim.x + threadIdx.x;

if (index >= n) {
return;
}

int dim_b = DIM0(output_size);
int dim_c = DIM1(output_size);
int dim_h = DIM2(output_size);
int dim_w = DIM3(output_size);
int dim_chw = dim_c * dim_h * dim_w;

int b = ( index / dim_chw ) % dim_b;
int y = ( index / dim_w ) % dim_h;
int x = ( index ) % dim_w;

int i1dim_c = DIM1(input1_size);
int i1dim_h = DIM2(input1_size);
int i1dim_w = DIM3(input1_size);
int i1dim_chw = i1dim_c * i1dim_h * i1dim_w;
int i1dim_hw = i1dim_h * i1dim_w;

float result = 0.0;

for (int c = 0; c < i1dim_c; ++c) {
int i1Index = b * i1dim_chw + c * i1dim_hw + y * i1dim_w + x;
scalar_t val = input1[i1Index];
result += static_cast<float>(val * val);
}
result = sqrt(result);
output[index] = static_cast<scalar_t>(result);
}


template <typename scalar_t>
__global__ void kernel_channelnorm_backward_input1(
const int n,
const scalar_t* __restrict__ input1, const long4 input1_size, const long4 input1_stride,
const scalar_t* __restrict__ output, const long4 output_size, const long4 output_stride,
const scalar_t* __restrict__ gradOutput, const long4 gradOutput_size, const long4 gradOutput_stride,
scalar_t* __restrict__ gradInput, const long4 gradInput_size, const long4 gradInput_stride,
int norm_deg) {

int index = blockIdx.x * blockDim.x + threadIdx.x;

if (index >= n) {
return;
}

float val = 0.0;

int dim_b = DIM0(gradInput_size);
int dim_c = DIM1(gradInput_size);
int dim_h = DIM2(gradInput_size);
int dim_w = DIM3(gradInput_size);
int dim_chw = dim_c * dim_h * dim_w;
int dim_hw = dim_h * dim_w;

int b = ( index / dim_chw ) % dim_b;
int y = ( index / dim_w ) % dim_h;
int x = ( index ) % dim_w;


int outIndex = b * dim_hw + y * dim_w + x;
val = static_cast<float>(gradOutput[outIndex]) * static_cast<float>(input1[index]) / (static_cast<float>(output[outIndex])+1e-9);
gradInput[index] = static_cast<scalar_t>(val);

}

void channelnorm_kernel_forward(
at::Tensor& input1,
at::Tensor& output,
int norm_deg) {

const long4 input1_size = make_long4(input1.size(0), input1.size(1), input1.size(2), input1.size(3));
const long4 input1_stride = make_long4(input1.stride(0), input1.stride(1), input1.stride(2), input1.stride(3));

const long4 output_size = make_long4(output.size(0), output.size(1), output.size(2), output.size(3));
const long4 output_stride = make_long4(output.stride(0), output.stride(1), output.stride(2), output.stride(3));

int n = output.numel();

AT_DISPATCH_FLOATING_TYPES_AND_HALF(input1.type(), "channelnorm_forward", ([&] {

kernel_channelnorm_update_output<scalar_t><<< (n + CUDA_NUM_THREADS - 1)/CUDA_NUM_THREADS, CUDA_NUM_THREADS, 0, at::globalContext().getCurrentCUDAStream() >>>(
n,
input1.data<scalar_t>(),
input1_size,
input1_stride,
output.data<scalar_t>(),
output_size,
output_stride,
norm_deg);

}));

// TODO: ATen-equivalent check

// THCudaCheck(cudaGetLastError());
}

void channelnorm_kernel_backward(
at::Tensor& input1,
at::Tensor& output,
at::Tensor& gradOutput,
at::Tensor& gradInput1,
int norm_deg) {

const long4 input1_size = make_long4(input1.size(0), input1.size(1), input1.size(2), input1.size(3));
const long4 input1_stride = make_long4(input1.stride(0), input1.stride(1), input1.stride(2), input1.stride(3));

const long4 output_size = make_long4(output.size(0), output.size(1), output.size(2), output.size(3));
const long4 output_stride = make_long4(output.stride(0), output.stride(1), output.stride(2), output.stride(3));

const long4 gradOutput_size = make_long4(gradOutput.size(0), gradOutput.size(1), gradOutput.size(2), gradOutput.size(3));
const long4 gradOutput_stride = make_long4(gradOutput.stride(0), gradOutput.stride(1), gradOutput.stride(2), gradOutput.stride(3));

const long4 gradInput1_size = make_long4(gradInput1.size(0), gradInput1.size(1), gradInput1.size(2), gradInput1.size(3));
const long4 gradInput1_stride = make_long4(gradInput1.stride(0), gradInput1.stride(1), gradInput1.stride(2), gradInput1.stride(3));

int n = gradInput1.numel();

AT_DISPATCH_FLOATING_TYPES_AND_HALF(input1.type(), "channelnorm_backward_input1", ([&] {

kernel_channelnorm_backward_input1<scalar_t><<< (n + CUDA_NUM_THREADS - 1)/CUDA_NUM_THREADS, CUDA_NUM_THREADS, 0, at::globalContext().getCurrentCUDAStream() >>>(
n,
input1.data<scalar_t>(),
input1_size,
input1_stride,
output.data<scalar_t>(),
output_size,
output_stride,
gradOutput.data<scalar_t>(),
gradOutput_size,
gradOutput_stride,
gradInput1.data<scalar_t>(),
gradInput1_size,
gradInput1_stride,
norm_deg
);

}));

// TODO: Add ATen-equivalent check

// THCudaCheck(cudaGetLastError());
}
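
In plain PyTorch terms, the forward kernel computes out[b, 0, y, x] = sqrt(sum over c of input[b, c, y, x]^2), and the backward kernel scales the incoming gradient by input / (output + 1e-9). A CPU reference sketch of the same math (not the CUDA path; note that norm_deg is accepted but effectively ignored, as in the kernels above):

import torch

def channelnorm_forward_ref(x):
    # per-pixel L2 norm over the channel dimension, keeping a channel axis
    return x.pow(2).sum(dim=1, keepdim=True).sqrt()

def channelnorm_backward_ref(x, out, grad_out):
    # mirrors kernel_channelnorm_backward_input1: grad_x = grad_out * x / out,
    # with 1e-9 guarding against division by zero
    return grad_out * x / (out + 1e-9)

x = torch.randn(2, 5, 4, 4)
out = channelnorm_forward_ref(x)
grad_x = channelnorm_backward_ref(x, out, torch.ones_like(out))
print(out.shape, grad_x.shape)   # torch.Size([2, 1, 4, 4]) torch.Size([2, 5, 4, 4])
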
networks/channelnorm_package/channelnorm_kernel.cuh — 16 changes: 16 additions & 0 deletions
@@ -0,0 +1,16 @@
#pragma once

#include <ATen/ATen.h>

void channelnorm_kernel_forward(
at::Tensor& input1,
at::Tensor& output,
int norm_deg);


void channelnorm_kernel_backward(
at::Tensor& input1,
at::Tensor& output,
at::Tensor& gradOutput,
at::Tensor& gradInput1,
int norm_deg);
networks/channelnorm_package/make.sh — 12 changes: 0 additions & 12 deletions

This file was deleted.

networks/channelnorm_package/modules/channelnorm.py — 13 changes: 0 additions & 13 deletions

This file was deleted.

networks/channelnorm_package/setup.py — 28 changes: 28 additions & 0 deletions
@@ -0,0 +1,28 @@
#!/usr/bin/env python3
import os
import torch

from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension

cxx_args = ['-std=c++11']

nvcc_args = [
'-gencode', 'arch=compute_52,code=sm_52',
'-gencode', 'arch=compute_60,code=sm_60',
'-gencode', 'arch=compute_61,code=sm_61',
'-gencode', 'arch=compute_70,code=sm_70',
'-gencode', 'arch=compute_70,code=compute_70'
]

setup(
name='channelnorm_cuda',
ext_modules=[
CUDAExtension('channelnorm_cuda', [
'channelnorm_cuda.cc',
'channelnorm_kernel.cu'
], extra_compile_args={'cxx': cxx_args, 'nvcc': nvcc_args})
],
cmdclass={
'build_ext': BuildExtension
})
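
The -gencode pairs target sm_52 through sm_70 plus PTX for compute_70; other GPU generations would need an extra entry here. As an alternative to python setup.py install, the same two sources can be JIT-compiled with torch.utils.cpp_extension.load (a sketch; run from networks/channelnorm_package so the relative paths resolve, and adjust the arch flags to the local GPU):

from torch.utils.cpp_extension import load

# JIT-build alternative to `python setup.py install` for quick iteration.
channelnorm_cuda = load(
    name="channelnorm_cuda",
    sources=["channelnorm_cuda.cc", "channelnorm_kernel.cu"],
    extra_cflags=["-std=c++11"],
    extra_cuda_cflags=["-gencode", "arch=compute_61,code=sm_61"],
    verbose=True,
)
print(channelnorm_cuda.forward)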