diff --git a/msd_pytorch/msd_block.py b/msd_pytorch/msd_block.py
index e624960..9832075 100644
--- a/msd_pytorch/msd_block.py
+++ b/msd_pytorch/msd_block.py
@@ -1,36 +1,31 @@
 import torch
 import conv_relu_cuda as cr_cuda
-from msd_pytorch.msd_module import (
-    MSDFinalLayer,
-    init_convolution_weights
-)
+from msd_pytorch.msd_module import MSDFinalLayer, init_convolution_weights
 import numpy as np
 
 IDX_WEIGHT_START = 3
 
 
 class MSDBlockImpl2d(torch.autograd.Function):
-
     @staticmethod
     def forward(ctx, input, dilations, bias, *weights):
         depth = len(dilations)
 
         assert depth == len(weights), "number of weights does not match depth"
         num_out_channels = sum(w.shape[0] for w in weights)
-        assert len(bias) == num_out_channels, \
-            "number of biases does not match number of output channels from weights"
+        assert (
+            len(bias) == num_out_channels
+        ), "number of biases does not match number of output channels from weights"
 
         ctx.dilations = dilations
         ctx.depth = depth
 
         result = input.new_empty(
-            input.shape[0],
-            input.shape[1] + num_out_channels,
-            *input.shape[2:],
+            input.shape[0], input.shape[1] + num_out_channels, *input.shape[2:]
         )
 
         # Copy input into result buffer
-        result[:, :input.shape[1]] = input
+        result[:, : input.shape[1]] = input
 
         result_start = input.shape[1]
         bias_start = 0
@@ -40,19 +35,15 @@ def forward(ctx, input, dilations, bias, *weights):
             sub_input = result[:, :result_start]
             sub_weight = weights[i]
             blocksize = sub_weight.shape[0]
-            sub_bias = bias[bias_start: bias_start + blocksize]
-            sub_result = result[:, result_start: result_start + blocksize]
+            sub_bias = bias[bias_start : bias_start + blocksize]
+            sub_result = result[:, result_start : result_start + blocksize]
             dilation = ctx.dilations[i]
 
             # Compute convolution. conv_relu_forward computes the
             # convolution and relu in one pass and stores the
             # output in sub_result.
             cr_cuda.conv_relu_forward(
-                sub_input,
-                sub_weight,
-                sub_bias,
-                sub_result,
-                dilation
+                sub_input, sub_weight, sub_bias, sub_result, dilation
             )
 
             # Update steps etc
@@ -85,7 +76,7 @@ def backward(ctx, grad_output):
             bias_start = bias_end - blocksize
 
             sub_grad_output = gradients[:, result_start:result_end]
-            sub_grad_input = gradients[:, : result_start]
+            sub_grad_input = gradients[:, :result_start]
             sub_result = result[:, result_start:result_end]
             sub_input = result[:, :result_start]
 
@@ -95,22 +86,14 @@ def backward(ctx, grad_output):
             # gradient wrt sub_input and adds the gradient to
             # sub_grad_input.
             cr_cuda.conv_relu_backward_x(
-                sub_result,
-                sub_grad_output,
-                sub_weight,
-                sub_grad_input,
-                dilation
+                sub_result, sub_grad_output, sub_weight, sub_grad_input, dilation
             )
 
             # Gradient w.r.t weights
             if ctx.needs_input_grad[i + IDX_WEIGHT_START]:
                 sub_grad_weight = torch.zeros_like(sub_weight)
                 cr_cuda.conv_relu_backward_k(
-                    sub_result,
-                    sub_grad_output,
-                    sub_input,
-                    sub_grad_weight,
-                    dilation
+                    sub_result, sub_grad_output, sub_input, sub_grad_weight, dilation
                 )
                 grad_weights.insert(0, sub_grad_weight)
             else:
@@ -119,16 +102,14 @@ def backward(ctx, grad_output):
             if ctx.needs_input_grad[2]:
                 sub_grad_bias = grad_bias[bias_start:bias_end]
                 cr_cuda.conv_relu_backward_bias(
-                    sub_result,
-                    sub_grad_output,
-                    sub_grad_bias,
+                    sub_result, sub_grad_output, sub_grad_bias
                 )
 
             # Update positions etc
             result_end -= blocksize
             bias_end -= blocksize
 
-        grad_input = gradients[:, :weights[0].shape[1]]
+        grad_input = gradients[:, : weights[0].shape[1]]
 
         return (grad_input, None, grad_bias, *grad_weights)
 
@@ -166,13 +147,11 @@ def __init__(self, in_channels, dilations, width=1):
 
         for i in range(depth):
             n_in = in_channels + width * i
-            weight = torch.nn.Parameter(torch.Tensor(
-                width, n_in, *self.kernel_size))
+            weight = torch.nn.Parameter(torch.Tensor(width, n_in, *self.kernel_size))
 
-            self.register_parameter('weight{}'.format(i), weight)
+            self.register_parameter("weight{}".format(i), weight)
             self.weights.append(weight)
 
-
         self.reset_parameters()
 
     def reset_parameters(self):
@@ -202,7 +181,9 @@ def forward(self, input):
 
 
 class MSDModule2d(torch.nn.Module):
-    def __init__(self, c_in, c_out, depth, width, dilations=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]):
+    def __init__(
+        self, c_in, c_out, depth, width, dilations=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
+    ):
         """Create a 2-dimensional MSD Module
 
         :param c_in: # of input channels
@@ -236,7 +217,9 @@ def __init__(self, c_in, c_out, depth, width, dilations=[1, 2, 3, 4, 5, 6, 7, 8,
     def reset_parameters(self):
         # Initialize weights for hidden layers:
         for w in self.msd_block.weights:
-            init_convolution_weights(w.data, self.c_in, self.c_out, self.width, self.depth)
+            init_convolution_weights(
+                w.data, self.c_in, self.c_out, self.width, self.depth
+            )
 
         self.msd_block.bias.data.zero_()
         self.final_layer.reset_parameters()
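For context, the module reformatted above can be exercised as in the following minimal sketch. It is not part of the change: the channel counts, image size, and the `.cuda()` call are illustrative assumptions (the `conv_relu_cuda` extension provides GPU kernels, so a CUDA device is assumed to be required).

    import torch
    from msd_pytorch.msd_block import MSDModule2d

    # Hypothetical configuration: 1 input channel, 2 output channels,
    # depth 10, width 1, and the default dilations [1..10].
    net = MSDModule2d(c_in=1, c_out=2, depth=10, width=1).cuda()

    x = torch.randn(4, 1, 64, 64, device="cuda")  # NCHW input batch
    y = net(x)          # assumed output shape: (4, 2, 64, 64)
    y.sum().backward()  # exercises MSDBlockImpl2d.backward via the cr_cuda kernels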