Reformat msd_block.py
Allard Hendriksen committed Jul 30, 2019
1 parent 3f1d648 commit ff7f5ef
Showing 1 changed file with 22 additions and 39 deletions.
61 changes: 22 additions & 39 deletions msd_pytorch/msd_block.py
@@ -1,36 +1,31 @@
 import torch
 import conv_relu_cuda as cr_cuda
-from msd_pytorch.msd_module import (
-    MSDFinalLayer,
-    init_convolution_weights
-)
+from msd_pytorch.msd_module import MSDFinalLayer, init_convolution_weights
 import numpy as np
 
 IDX_WEIGHT_START = 3
 
 
 class MSDBlockImpl2d(torch.autograd.Function):
-
     @staticmethod
     def forward(ctx, input, dilations, bias, *weights):
         depth = len(dilations)
         assert depth == len(weights), "number of weights does not match depth"
 
         num_out_channels = sum(w.shape[0] for w in weights)
-        assert len(bias) == num_out_channels, \
-            "number of biases does not match number of output channels from weights"
+        assert (
+            len(bias) == num_out_channels
+        ), "number of biases does not match number of output channels from weights"
 
         ctx.dilations = dilations
         ctx.depth = depth
 
         result = input.new_empty(
-            input.shape[0],
-            input.shape[1] + num_out_channels,
-            *input.shape[2:],
+            input.shape[0], input.shape[1] + num_out_channels, *input.shape[2:]
         )
 
         # Copy input into result buffer
-        result[:, :input.shape[1]] = input
+        result[:, : input.shape[1]] = input
 
         result_start = input.shape[1]
         bias_start = 0
@@ -40,19 +35,15 @@ def forward(ctx, input, dilations, bias, *weights):
             sub_input = result[:, :result_start]
             sub_weight = weights[i]
             blocksize = sub_weight.shape[0]
-            sub_bias = bias[bias_start: bias_start + blocksize]
-            sub_result = result[:, result_start: result_start + blocksize]
+            sub_bias = bias[bias_start : bias_start + blocksize]
+            sub_result = result[:, result_start : result_start + blocksize]
             dilation = ctx.dilations[i]
 
             # Compute convolution. conv_relu_forward computes the
             # convolution and relu in one pass and stores the
             # output in sub_result.
             cr_cuda.conv_relu_forward(
-                sub_input,
-                sub_weight,
-                sub_bias,
-                sub_result,
-                dilation
+                sub_input, sub_weight, sub_bias, sub_result, dilation
             )
 
             # Update steps etc
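
As an aside (not part of the commit): a rough pure-PyTorch equivalent of what the comment above describes cr_cuda.conv_relu_forward as doing. The 3x3 kernel and zero padding here are assumptions; the actual extension fuses both steps in a single CUDA kernel and writes into the preallocated slice.

import torch
import torch.nn.functional as F

def conv_relu_forward_reference(sub_input, sub_weight, sub_bias, sub_result, dilation):
    # Dilated 2d convolution followed by a ReLU; padding=dilation keeps
    # the spatial size unchanged for a 3x3 kernel.
    out = F.conv2d(sub_input, sub_weight, sub_bias, padding=dilation, dilation=dilation)
    sub_result.copy_(torch.relu(out))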
@@ -85,7 +76,7 @@ def backward(ctx, grad_output):
             bias_start = bias_end - blocksize
 
             sub_grad_output = gradients[:, result_start:result_end]
-            sub_grad_input = gradients[:, : result_start]
+            sub_grad_input = gradients[:, :result_start]
             sub_result = result[:, result_start:result_end]
             sub_input = result[:, :result_start]
 
@@ -95,22 +86,14 @@
             # gradient wrt sub_input and adds the gradient to
             # sub_grad_input.
             cr_cuda.conv_relu_backward_x(
-                sub_result,
-                sub_grad_output,
-                sub_weight,
-                sub_grad_input,
-                dilation
+                sub_result, sub_grad_output, sub_weight, sub_grad_input, dilation
             )
 
             # Gradient w.r.t weights
             if ctx.needs_input_grad[i + IDX_WEIGHT_START]:
                 sub_grad_weight = torch.zeros_like(sub_weight)
                 cr_cuda.conv_relu_backward_k(
-                    sub_result,
-                    sub_grad_output,
-                    sub_input,
-                    sub_grad_weight,
-                    dilation
+                    sub_result, sub_grad_output, sub_input, sub_grad_weight, dilation
                 )
                 grad_weights.insert(0, sub_grad_weight)
             else:
@@ -119,16 +102,14 @@
             if ctx.needs_input_grad[2]:
                 sub_grad_bias = grad_bias[bias_start:bias_end]
                 cr_cuda.conv_relu_backward_bias(
-                    sub_result,
-                    sub_grad_output,
-                    sub_grad_bias,
+                    sub_result, sub_grad_output, sub_grad_bias
                 )
 
             # Update positions etc
             result_end -= blocksize
             bias_end -= blocksize
 
-        grad_input = gradients[:, :weights[0].shape[1]]
+        grad_input = gradients[:, : weights[0].shape[1]]
 
         return (grad_input, None, grad_bias, *grad_weights)
 
@@ -166,13 +147,11 @@ def __init__(self, in_channels, dilations, width=1):
         for i in range(depth):
             n_in = in_channels + width * i
 
-            weight = torch.nn.Parameter(torch.Tensor(
-                width, n_in, *self.kernel_size))
+            weight = torch.nn.Parameter(torch.Tensor(width, n_in, *self.kernel_size))
 
-            self.register_parameter('weight{}'.format(i), weight)
+            self.register_parameter("weight{}".format(i), weight)
             self.weights.append(weight)
 
-
         self.reset_parameters()
 
     def reset_parameters(self):
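
The n_in = in_channels + width * i arithmetic above is the dense-connectivity pattern that MSDBlockImpl2d.forward exploits: layer i convolves the original input channels plus the width output channels of every earlier layer, then appends its own block of channels to the shared result buffer. A tiny illustration of the indexing, not from the commit, assuming every weight tensor contributes width output channels:

c_in, width, depth = 2, 1, 3
result_start = c_in
for i in range(depth):
    # layer i reads channels [0, result_start) and
    # writes channels [result_start, result_start + width)
    print("layer %d: reads [0, %d), writes [%d, %d)"
          % (i, result_start, result_start, result_start + width))
    result_start += width
# layer 0: reads [0, 2), writes [2, 3)
# layer 1: reads [0, 3), writes [3, 4)
# layer 2: reads [0, 4), writes [4, 5)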
@@ -202,7 +181,9 @@ def forward(self, input):


 class MSDModule2d(torch.nn.Module):
-    def __init__(self, c_in, c_out, depth, width, dilations=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]):
+    def __init__(
+        self, c_in, c_out, depth, width, dilations=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
+    ):
         """Create a 2-dimensional MSD Module
 
         :param c_in: # of input channels
@@ -236,7 +217,9 @@ def __init__(self, c_in, c_out, depth, width, dilations=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]):
     def reset_parameters(self):
         # Initialize weights for hidden layers:
         for w in self.msd_block.weights:
-            init_convolution_weights(w.data, self.c_in, self.c_out, self.width, self.depth)
+            init_convolution_weights(
+                w.data, self.c_in, self.c_out, self.width, self.depth
+            )
 
         self.msd_block.bias.data.zero_()
         self.final_layer.reset_parameters()
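
Finally, a minimal usage sketch of the module defined in this file. This is not part of the commit; it assumes a CUDA device (the block calls into the conv_relu_cuda extension) and same-size padded convolutions, so the spatial dimensions should come through unchanged.

import torch
from msd_pytorch.msd_block import MSDModule2d

model = MSDModule2d(c_in=1, c_out=1, depth=10, width=1).cuda()
x = torch.randn(2, 1, 64, 64, device="cuda")
y = model(x)
print(y.shape)  # expected: torch.Size([2, 1, 64, 64])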