Source code for msd_pytorch.msd_block

import torch
import conv_relu_cuda as cr_cuda
from msd_pytorch.msd_module import MSDFinalLayer, init_convolution_weights
import numpy as np

IDX_WEIGHT_START = 3

class MSDBlockImpl2d(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, dilations, bias, *weights):
        depth = len(dilations)
        assert depth == len(weights), "number of weights does not match depth"

        num_out_channels = sum(w.shape[0] for w in weights)
        assert (
            len(bias) == num_out_channels
        ), "number of biases does not match number of output channels from weights"

        ctx.dilations = dilations
        ctx.depth = depth

        result = input.new_empty(
            input.shape[0], input.shape[1] + num_out_channels, *input.shape[2:]
        )

        # Copy input into result buffer
        result[:, : input.shape[1]] = input

        result_start = input.shape[1]
        bias_start = 0

        for i in range(depth):
            # Extract variables
            sub_input = result[:, :result_start]
            sub_weight = weights[i]
            blocksize = sub_weight.shape[0]
            sub_bias = bias[bias_start : bias_start + blocksize]
            sub_result = result[:, result_start : result_start + blocksize]
            dilation = ctx.dilations[i]

            # Compute convolution. conv_relu_forward computes the
            # convolution and relu in one pass and stores the
            # output in sub_result.
            cr_cuda.conv_relu_forward(
                sub_input, sub_weight, sub_bias, sub_result, dilation
            )

            # Update steps etc
            result_start += blocksize
            bias_start += blocksize

        ctx.save_for_backward(bias, result, *weights)

        return result

    @staticmethod
    def backward(ctx, grad_output):
        bias, result, *weights = ctx.saved_tensors
        depth = ctx.depth

        grad_bias = torch.zeros_like(bias)
        # XXX: Could we just overwrite grad_output instead of clone?
        gradients = grad_output.clone()
        grad_weights = []

        result_end = result.shape[1]
        bias_end = len(bias)

        for i in range(depth):
            idx = depth - 1 - i
            # Get subsets
            sub_weight = weights[idx]
            blocksize = sub_weight.shape[0]
            result_start = result_end - blocksize
            bias_start = bias_end - blocksize

            sub_grad_output = gradients[:, result_start:result_end]
            sub_grad_input = gradients[:, :result_start]
            sub_result = result[:, result_start:result_end]
            sub_input = result[:, :result_start]

            dilation = ctx.dilations[idx]

            # Gradient w.r.t. input: conv_relu_backward_x computes the
            # gradient wrt sub_input and adds the gradient to
            # sub_grad_input.
            cr_cuda.conv_relu_backward_x(
                sub_result, sub_grad_output, sub_weight, sub_grad_input, dilation
            )

            # Gradient w.r.t. weights. The layers are visited in reverse
            # order, so the needs_input_grad flag is indexed by idx (the
            # layer index), not by i.
            if ctx.needs_input_grad[idx + IDX_WEIGHT_START]:
                sub_grad_weight = torch.zeros_like(sub_weight)
                cr_cuda.conv_relu_backward_k(
                    sub_result, sub_grad_output, sub_input, sub_grad_weight, dilation
                )
                grad_weights.insert(0, sub_grad_weight)
            else:
                grad_weights.insert(0, None)

            # Gradient of bias
            if ctx.needs_input_grad[2]:
                sub_grad_bias = grad_bias[bias_start:bias_end]
                cr_cuda.conv_relu_backward_bias(
                    sub_result, sub_grad_output, sub_grad_bias
                )

            # Update positions etc
            result_end -= blocksize
            bias_end -= blocksize

        grad_input = gradients[:, : weights[0].shape[1]]

        return (grad_input, None, grad_bias, *grad_weights)


msdblock2d = MSDBlockImpl2d.apply

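
# Illustrative sketch (not part of the original module): calling the raw
# autograd function directly. This assumes a CUDA device and the compiled
# conv_relu_cuda extension are available; the channel counts, dilations, and
# image size below are hypothetical and chosen for illustration only.
def _example_msdblock2d():
    in_channels, width, dilations = 2, 1, (1, 2, 4)
    x = torch.randn(1, in_channels, 32, 32, device="cuda")
    bias = torch.zeros(len(dilations) * width, device="cuda", requires_grad=True)
    weights = [
        torch.randn(
            width, in_channels + width * i, 3, 3, device="cuda", requires_grad=True
        )
        for i in range(len(dilations))
    ]
    # The output buffer holds the input followed by the output of every
    # layer, so it has in_channels + depth * width channels.
    out = msdblock2d(x, dilations, bias, *weights)
    assert out.shape == (1, in_channels + len(dilations) * width, 32, 32)
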

class MSDBlock2d(torch.nn.Module):
    def __init__(self, in_channels, dilations, width=1):
        """Multi-scale dense block

        Parameters
        ----------
        in_channels : int
            Number of input channels
        dilations : tuple of int
            Dilation for each convolution-block
        width : int
            Number of channels per convolution.

        Notes
        -----
        The number of output channels is in_channels + depth * width
        """
        super().__init__()
        self.kernel_size = (3, 3)
        self.width = width
        self.dilations = dilations

        depth = len(self.dilations)

        self.bias = torch.nn.Parameter(torch.Tensor(depth * width))

        self.weights = []
        for i in range(depth):
            n_in = in_channels + width * i

            weight = torch.nn.Parameter(torch.Tensor(width, n_in, *self.kernel_size))

            self.register_parameter("weight{}".format(i), weight)
            self.weights.append(weight)

        self.reset_parameters()

    def reset_parameters(self):
        for weight in self.weights:
            torch.nn.init.kaiming_uniform_(weight, a=np.sqrt(5))

        if self.bias is not None:
            # TODO: improve
            fan_in, _ = torch.nn.init._calculate_fan_in_and_fan_out(self.weights[0])
            bound = 1 / np.sqrt(fan_in)
            torch.nn.init.uniform_(self.bias, -bound, bound)

    def forward(self, input):
        # This is a bit of a hack, since we require but cannot assume
        # that self.parameters() remains sorted in the order that we
        # added the parameters.
        #
        # However, we need to obtain the weights in this way, because
        # self.weights may become obsolete in multi-gpu settings where
        # the weights are automatically transferred (by, e.g.,
        # torch.nn.DataParallel). In that case, self.weights may
        # continue to point to the weight parameters on the original
        # device, even when the parameters have been transferred to a
        # different gpu.
        bias, *weights = self.parameters()
        return MSDBlockImpl2d.apply(input, self.dilations, bias, *weights)
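
# A minimal usage sketch of MSDBlock2d (not part of the original module),
# assuming a CUDA device and the compiled conv_relu_cuda extension. As the
# docstring notes, the output has in_channels + depth * width channels; the
# concrete numbers below are hypothetical.
def _example_msd_block_2d():
    block = MSDBlock2d(in_channels=1, dilations=(1, 2, 4, 8), width=2).cuda()
    x = torch.randn(4, 1, 64, 64, device="cuda")
    y = block(x)
    # 1 input channel + 4 layers * width 2 = 9 output channels.
    assert y.shape == (4, 9, 64, 64)
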

class MSDModule2d(torch.nn.Module):
    def __init__(
        self, c_in, c_out, depth, width, dilations=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
    ):
        """Create a 2-dimensional MSD Module

        :param c_in: # of input channels
        :param c_out: # of output channels
        :param depth: # of layers
        :param width: the width (# of channels per layer) of the module
        :param dilations: `list(int)`

            A list of dilations to use. Default is ``[1, 2, ..., 10]``. A
            good alternative is ``[1, 2, 4, 8]``. The dilations are
            repeated when there are more layers than dilations.

        :returns: an MSD module
        :rtype: MSDModule2d

        """
        super(MSDModule2d, self).__init__()

        self.c_in = c_in
        self.c_out = c_out
        self.depth = depth
        self.width = width
        self.dilations = [dilations[i % len(dilations)] for i in range(depth)]

        self.msd_block = MSDBlock2d(self.c_in, self.dilations, self.width)
        self.final_layer = MSDFinalLayer(c_in=c_in + width * depth, c_out=c_out)

        self.reset_parameters()

    def reset_parameters(self):
        # Initialize weights for hidden layers:
        for w in self.msd_block.weights:
            init_convolution_weights(
                w.data, self.c_in, self.c_out, self.width, self.depth
            )

        self.msd_block.bias.data.zero_()
        self.final_layer.reset_parameters()

    def forward(self, input):
        output = self.msd_block(input)
        output = self.final_layer(output)
        return output
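
# A minimal usage sketch of MSDModule2d (not part of the original module),
# assuming a CUDA device and the compiled conv_relu_cuda extension. The
# depth/width/channel values are hypothetical; note how the default dilations
# [1, 2, ..., 10] are cycled when depth exceeds their number.
def _example_msd_module_2d():
    model = MSDModule2d(c_in=1, c_out=3, depth=12, width=1).cuda()
    # With depth=12 and the default dilations, the per-layer dilations are
    # [1, 2, ..., 10, 1, 2] because of the modulo indexing above.
    assert model.dilations == [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 1, 2]
    x = torch.randn(2, 1, 64, 64, device="cuda")
    y = model(x)
    # The final 1x1 layer maps the c_in + depth * width channels down to c_out.
    assert y.shape == (2, 3, 64, 64)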