From 3f1d648dd60b6cca191de16e0eac651291372b8a Mon Sep 17 00:00:00 2001 From: Allard Hendriksen Date: Tue, 30 Jul 2019 18:16:24 +0200 Subject: [PATCH] Add multi-gpu support MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The code has been improved to deal with `torch.nn.DataParallel`. It can now be used to train in parallel with decent speedups. This was executed on a 2 x NVIDIA TITAN RTX system: In [1]: import torch ...: from msd_pytorch import MSDRegressionModel ...: model = MSDRegressionModel(1, 1, 100, 1) ...: inp0 = torch.zeros(16, 1, 1000, 1000, dtype=torch.float32, device=torch.device("cuda:0")) ...: inp1 = torch.zeros(16, 1, 1000, 1000, dtype=torch.float32, device=torch.device("cuda:1")) ...: par_net = torch.nn.DataParallel(model.net, device_ids=[0,1]) In [2]: %timeit par_net(inp0) 1.03 s ± 1.8 ms per loop (mean ± std. dev. of 7 runs, 10 loops each) In [2]: %timeit model.net(inp0) 2.83 s ± 2.35 ms per loop (mean ± std. dev. of 7 runs, 10 loops each) --- msd_pytorch/conv_cuda.cu | 14 +++++-- msd_pytorch/conv_relu.cpp | 8 ++++ msd_pytorch/conv_relu_cuda.cu | 41 ++++++++++++------- msd_pytorch/msd_block.py | 17 +++++++- msd_pytorch/msd_regression_model.py | 9 ++++ msd_pytorch/msd_segmentation_model.py | 10 +++++ msd_pytorch/relu_cuda.cu | 9 ++-- .../tests/test_msd_regression_model.py | 15 +++++++ .../tests/test_msd_segmentation_model.py | 18 ++++++++ 9 files changed, 117 insertions(+), 24 deletions(-) diff --git a/msd_pytorch/conv_cuda.cu b/msd_pytorch/conv_cuda.cu index f55671d..aa5a960 100644 --- a/msd_pytorch/conv_cuda.cu +++ b/msd_pytorch/conv_cuda.cu @@ -12,6 +12,7 @@ #include "THC/THCDeviceUtils.cuh" #include "device_tensor.h" +using at::OptionalDeviceGuard; __device__ __forceinline__ int reflect(int i, int dimi) { @@ -585,7 +586,8 @@ at::Tensor conv_cuda_forward(at::Tensor input_t, int dilation, int implementation, int block_size) { - AT_DISPATCH_FLOATING_TYPES(input_t.type(), "conv_cuda_forward", ([&] { + OptionalDeviceGuard device_guard(device_of(input_t)); + AT_DISPATCH_FLOATING_TYPES(input_t.scalar_type(), "conv_cuda_forward", ([&] { // Create device tensors: dTensor4R input_d = toDeviceTensorR(input_t); dTensor4R kernel_d = toDeviceTensorR(kernel_t); @@ -621,7 +623,8 @@ void conv_cuda_backward_x(at::Tensor grad_output_t, int dilation, int implementation, int block_size) { - AT_DISPATCH_FLOATING_TYPES(grad_output_t.type(), "conv_cuda_backward_x", ([&] { + OptionalDeviceGuard device_guard(at::device_of(grad_output_t)); + AT_DISPATCH_FLOATING_TYPES(grad_output_t.scalar_type(), "conv_cuda_backward_x", ([&] { // Create device tensors: dTensor4R grad_output_d = toDeviceTensorR(grad_output_t); dTensor4R grad_input_d = toDeviceTensorR(grad_input_t); @@ -645,7 +648,8 @@ void conv_cuda_backward_k(at::Tensor grad_output, at::Tensor input, at::Tensor grad_kernel, int dilation, int implementation, int block_size) { - AT_DISPATCH_FLOATING_TYPES(grad_output.type(), "conv_cuda_backward_k", ([&] { + OptionalDeviceGuard device_guard(at::device_of(grad_output)); + AT_DISPATCH_FLOATING_TYPES(grad_output.scalar_type(), "conv_cuda_backward_k", ([&] { // Create device tensors: dTensor4R grad_output_d = toDeviceTensorR(grad_output); dTensor4R input_d = toDeviceTensorR(input); @@ -670,7 +674,9 @@ void conv_cuda_backward_bias(at::Tensor grad_output, at::Tensor grad_bias, int implementation, int block_size) { - AT_DISPATCH_FLOATING_TYPES(grad_output.type(), "conv_cuda_backward_bias", ([&] { + OptionalDeviceGuard device_guard(at::device_of(grad_output)); + + AT_DISPATCH_FLOATING_TYPES(grad_output.scalar_type(), "conv_cuda_backward_bias", ([&] { // Create device tensors: dTensor4R grad_output_d = toDeviceTensorR(grad_output); dTensor1R grad_bias_d = toDeviceTensorR(grad_bias); diff --git a/msd_pytorch/conv_relu.cpp b/msd_pytorch/conv_relu.cpp index 8c0990f..4eea9d5 100644 --- a/msd_pytorch/conv_relu.cpp +++ b/msd_pytorch/conv_relu.cpp @@ -63,6 +63,14 @@ at::Tensor conv_relu_forward(at::Tensor input, CHECK_CUDA(bias); CHECK_INPUT(kernel); // kernel must be contiguous. + torch::TensorArg arg_input(input, "input", 0); + torch::TensorArg arg_kernel(kernel, "kernel", 1); + torch::TensorArg arg_bias(bias, "bias", 2); + torch::TensorArg arg_output(output, "output", 3); + + // Check same device + at::checkAllSameGPU("conv_relu_forward", {arg_input, arg_kernel, arg_bias, arg_output}); + // Check data type AT_ASSERTM(input.type() == kernel.type(), "input and kernel must have same type"); AT_ASSERTM(input.type() == bias.type(), "input and bias must have same type"); diff --git a/msd_pytorch/conv_relu_cuda.cu b/msd_pytorch/conv_relu_cuda.cu index 801c2d3..4bc1e64 100644 --- a/msd_pytorch/conv_relu_cuda.cu +++ b/msd_pytorch/conv_relu_cuda.cu @@ -12,6 +12,7 @@ #include "THC/THCDeviceUtils.cuh" #include "device_tensor.h" +using at::OptionalDeviceGuard; __device__ __forceinline__ int reflect(int i, int dimi) { @@ -642,7 +643,10 @@ at::Tensor conv_relu_cuda_forward(at::Tensor input_t, int dilation, int implementation, int block_size) { - AT_DISPATCH_FLOATING_TYPES(input_t.type(), "conv_relu_cuda_forward", ([&] { + OptionalDeviceGuard device_guard(device_of(input_t)); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + AT_DISPATCH_FLOATING_TYPES(input_t.scalar_type(), "conv_relu_cuda_forward", ([&] { // Create device tensors: dTensor4R input_d = toDeviceTensorR(input_t); dTensor4R kernel_d = toDeviceTensorR(kernel_t); @@ -654,16 +658,16 @@ at::Tensor conv_relu_cuda_forward(at::Tensor input_t, dim3 blockSize(block_size, block_size); auto buffer_sz = kernel_t.numel() * sizeof(scalar_t); if (implementation == 2) { - conv_relu2<<>> + conv_relu2<<>> (input_d, kernel_d, bias_d, out_d, dilation); } else if (implementation == 3) { - conv_relu3<<>> + conv_relu3<<>> (input_d, kernel_d, bias_d, out_d, dilation); } else if (implementation == 4) { - conv_relu4<<>> + conv_relu4<<>> (input_d, kernel_d, bias_d, out_d, dilation); } else if (implementation == 5) { - conv_relu5<<>> + conv_relu5<<>> (input_d, kernel_d, bias_d, out_d, dilation); } @@ -679,7 +683,10 @@ void conv_relu_cuda_backward_x(at::Tensor output_t, int dilation, int implementation, int block_size) { - AT_DISPATCH_FLOATING_TYPES(output_t.type(), "conv_relu_cuda_backward_x", ([&] { + OptionalDeviceGuard device_guard(device_of(grad_output_t)); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + AT_DISPATCH_FLOATING_TYPES(output_t.scalar_type(), "conv_relu_cuda_backward_x", ([&] { // Create device tensors: dTensor4R output_d = toDeviceTensorR(output_t); dTensor4R grad_output_d = toDeviceTensorR(grad_output_t); @@ -690,10 +697,10 @@ void conv_relu_cuda_backward_x(at::Tensor output_t, dim3 blockSize(block_size, block_size); auto buffer_sz = kernel_t.numel() * sizeof(scalar_t); if (implementation == 0) { - conv_relu_backward_x0<<>> + conv_relu_backward_x0<<>> (output_d, grad_output_d, kernel_d, grad_input_d, dilation); } else if (implementation == 1) { - conv_relu_backward_x1<<>> + conv_relu_backward_x1<<>> (output_d, grad_output_d, kernel_d, grad_input_d, dilation); } THCudaCheck(cudaGetLastError()); @@ -704,7 +711,10 @@ void conv_relu_cuda_backward_k(at::Tensor output, at::Tensor grad_output, at::Te at::Tensor grad_kernel, int dilation, int implementation, int block_size) { - AT_DISPATCH_FLOATING_TYPES(grad_output.type(), "conv_relu_cuda_backward_k", ([&] { + OptionalDeviceGuard device_guard(device_of(grad_output)); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + AT_DISPATCH_FLOATING_TYPES(grad_output.scalar_type(), "conv_relu_cuda_backward_k", ([&] { // Create device tensors: dTensor4R output_d = toDeviceTensorR(output); dTensor4R grad_output_d = toDeviceTensorR(grad_output); @@ -714,10 +724,10 @@ void conv_relu_cuda_backward_k(at::Tensor output, at::Tensor grad_output, at::Te THCCeilDiv((int) grad_output_d.getSize(2), block_size)); dim3 blockSize(block_size, block_size); if (implementation == 0) { - conv_relu_backward_k0<<>> + conv_relu_backward_k0<<>> (output_d, grad_output_d, input_d, grad_kernel_d, dilation); } else if (implementation == 1) { - conv_relu_backward_k1<<>> + conv_relu_backward_k1<<>> (output_d, grad_output_d, input_d, grad_kernel_d, dilation); } @@ -731,7 +741,10 @@ void conv_relu_cuda_backward_bias(at::Tensor output, at::Tensor grad_bias, int implementation, int block_size) { - AT_DISPATCH_FLOATING_TYPES(grad_output.type(), "conv_relu_cuda_backward_bias", ([&] { + OptionalDeviceGuard device_guard(device_of(grad_output)); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + AT_DISPATCH_FLOATING_TYPES(grad_output.scalar_type(), "conv_relu_cuda_backward_bias", ([&] { // Create device tensors: dTensor4R output_d = toDeviceTensorR(output); dTensor4R grad_output_d = toDeviceTensorR(grad_output); @@ -740,10 +753,10 @@ void conv_relu_cuda_backward_bias(at::Tensor output, THCCeilDiv((int) grad_output_d.getSize(2), block_size)); dim3 blockSize(block_size, block_size); if (implementation == 0) { - conv_relu_backward_bias0<<>> + conv_relu_backward_bias0<<>> (output_d, grad_output_d, grad_bias_d); } else if (implementation == 1) { - conv_relu_backward_bias1<<>> + conv_relu_backward_bias1<<>> (output_d, grad_output_d, grad_bias_d); } THCudaCheck(cudaGetLastError()); diff --git a/msd_pytorch/msd_block.py b/msd_pytorch/msd_block.py index 8dc6b7b..e624960 100644 --- a/msd_pytorch/msd_block.py +++ b/msd_pytorch/msd_block.py @@ -160,6 +160,8 @@ def __init__(self, in_channels, dilations, width=1): depth = len(self.dilations) + self.bias = torch.nn.Parameter(torch.Tensor(depth * width)) + self.weights = [] for i in range(depth): n_in = in_channels + width * i @@ -170,7 +172,6 @@ def __init__(self, in_channels, dilations, width=1): self.register_parameter('weight{}'.format(i), weight) self.weights.append(weight) - self.bias = torch.nn.Parameter(torch.Tensor(depth * width)) self.reset_parameters() @@ -185,7 +186,19 @@ def reset_parameters(self): torch.nn.init.uniform_(self.bias, -bound, bound) def forward(self, input): - return MSDBlockImpl2d.apply(input, self.dilations, self.bias, *self.weights) + # This is a bit of a hack, since we require but cannot assume + # that self.parameters() remains sorted in the order that we + # added the parameters. + # + # However, we need to obtain weights in this way, because + # self.weights may become obsolete when used in multi-gpu + # settings when the weights are automatically transferred (by, + # e.g., torch.nn.DataParallel). In that case, self.weights may + # continue to point to the weight parameters on the original + # device, even when the weight parameters have been + # transferred to a different gpu. + bias, *weights = self.parameters() + return MSDBlockImpl2d.apply(input, self.dilations, bias, *weights) class MSDModule2d(torch.nn.Module): diff --git a/msd_pytorch/msd_regression_model.py b/msd_pytorch/msd_regression_model.py index 3ee54b5..c89b42f 100644 --- a/msd_pytorch/msd_regression_model.py +++ b/msd_pytorch/msd_regression_model.py @@ -26,6 +26,7 @@ def __init__( *, dilations=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10], loss="L2", + parallel=False, ): """Create a new MSD network for regression. @@ -47,6 +48,12 @@ def __init__( * "L1" - ``nn.L1Loss()`` * "L2" - ``nn.MSELoss()`` + :param parallel: `bool` + + Whether or not to execute the model on multiple GPUs. Note + that the batch size must be a multiple of the number of + available GPUs. + :returns: :rtype: @@ -62,6 +69,8 @@ def __init__( # Define the whole network: self.net = nn.Sequential(self.scale_in, self.msd, self.scale_out) self.net.cuda() + if parallel: + self.net = nn.DataParallel(self.net) # Train only MSD parameters: self.init_optimizer(self.msd) diff --git a/msd_pytorch/msd_segmentation_model.py b/msd_pytorch/msd_segmentation_model.py index 2352511..653ad8d 100644 --- a/msd_pytorch/msd_segmentation_model.py +++ b/msd_pytorch/msd_segmentation_model.py @@ -24,6 +24,7 @@ def __init__( width, *, dilations=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10], + parallel=False, ): """Create a new MSD network for segmentation. @@ -37,6 +38,12 @@ def __init__( good alternative is ``[1, 2, 4, 8]``. The dilations are repeated when there are more layers than supplied dilations. + :param parallel: `bool` + + Whether or not to execute the model on multiple GPUs. Note + that the batch size must be a multiple of the number of + available GPUs. + :returns: :rtype: @@ -52,6 +59,9 @@ def __init__( self.net = nn.Sequential(self.scale_in, net_trained) self.net.cuda() + if parallel: + self.net = nn.DataParallel(self.net) + # Train all parameters apart from self.scale_in. self.init_optimizer(net_trained) diff --git a/msd_pytorch/relu_cuda.cu b/msd_pytorch/relu_cuda.cu index da0bab9..f3ae578 100644 --- a/msd_pytorch/relu_cuda.cu +++ b/msd_pytorch/relu_cuda.cu @@ -3,6 +3,7 @@ #include #include +using at::OptionalDeviceGuard; template __global__ void relu_inplace_cuda_forward_kernel(scalar_t* __restrict__ input, size_t size) { @@ -24,26 +25,26 @@ __global__ void relu_inplace_cuda_backward_kernel(const scalar_t* __restrict__ i } at::Tensor relu_inplace_cuda_forward(at::Tensor input){ - + OptionalDeviceGuard device_guard(at::device_of(input)); const int threads = 1024; auto size = input.numel(); const dim3 blocks((size + threads - 1) / threads, 1); - AT_DISPATCH_FLOATING_TYPES(input.type(), "relu_inplace_cuda_forward", ([&] { + AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "relu_inplace_cuda_forward", ([&] { relu_inplace_cuda_forward_kernel<<>>(input.data(), size); })); return input; } at::Tensor relu_inplace_cuda_backward(at::Tensor input, at::Tensor grad_output){ - + OptionalDeviceGuard device_guard(at::device_of(grad_output)); auto grad_input = at::zeros_like(grad_output); const int threads = 1024; auto size = input.numel(); const dim3 blocks((size + threads - 1) / threads, 1); - AT_DISPATCH_FLOATING_TYPES(input.type(), "relu_inplace_cuda_forward", ([&] { + AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "relu_inplace_cuda_forward", ([&] { relu_inplace_cuda_backward_kernel<<>>(input.data(), grad_output.data(), grad_input.data(), diff --git a/msd_pytorch/tests/test_msd_regression_model.py b/msd_pytorch/tests/test_msd_regression_model.py index 4c6887c..cfff5ee 100644 --- a/msd_pytorch/tests/test_msd_regression_model.py +++ b/msd_pytorch/tests/test_msd_regression_model.py @@ -33,6 +33,21 @@ def test_params_change(): assert not torch_equal(p0, p1) +def test_data_parallel(): + """Check that msd_model is compatible with multi-GPU approaches + + Specifically, `torch.nn.DataParallel`. + """ + + shape = (100, 100) + inp = torch.zeros(4, 1, *shape, dtype=torch.float32, device=torch.device("cuda:0")) + tgt = torch.zeros(4, 1, *shape, dtype=torch.float32, device=torch.device("cuda:0")) + + model = MSDRegressionModel(1, 1, 11, 1, parallel=True) + model.forward(inp, tgt) + model.learn(inp, tgt) + + def test_api_surface(tmp_path): ########################################################################### # Create network # diff --git a/msd_pytorch/tests/test_msd_segmentation_model.py b/msd_pytorch/tests/test_msd_segmentation_model.py index b85b226..c7b3a4d 100644 --- a/msd_pytorch/tests/test_msd_segmentation_model.py +++ b/msd_pytorch/tests/test_msd_segmentation_model.py @@ -31,6 +31,24 @@ def test_params_change(): assert not torch_equal(p0, p1) +def test_data_parallel(): + """Check that msd_model is compatible with multi-GPU approaches + + Specifically, `torch.nn.DataParallel`. + """ + + shape = (100, 100) + num_labels = 3 + inp = torch.zeros(4, 1, *shape, dtype=torch.float32, device=torch.device("cuda:0")) + tgt = torch.randint( + low=0, high=num_labels, size=(4, 1, *shape), device=torch.device("cuda:0") + ) + + model = MSDSegmentationModel(1, num_labels, 11, 1, parallel=True) + model.forward(inp, tgt) + model.learn(inp, tgt) + + def test_api_surface(tmp_path): ########################################################################### # Create network #