
Commit 3f1d648: Add multi-gpu support
The code has been improved to work with `torch.nn.DataParallel`. The network
can now be trained in parallel on multiple GPUs with decent speedups.

The following benchmark was executed on a 2 x NVIDIA TITAN RTX system:

In [1]: import torch
   ...: from msd_pytorch import MSDRegressionModel
   ...: model = MSDRegressionModel(1, 1, 100, 1)
   ...: inp0 = torch.zeros(16, 1, 1000, 1000, dtype=torch.float32, device=torch.device("cuda:0"))
   ...: inp1 = torch.zeros(16, 1, 1000, 1000, dtype=torch.float32, device=torch.device("cuda:1"))
   ...: par_net = torch.nn.DataParallel(model.net, device_ids=[0,1])
In [2]: %timeit par_net(inp0)
1.03 s ± 1.8 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
In [2]: %timeit model.net(inp0)
2.83 s ± 2.35 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
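
With this change, the same speedup can be requested directly from the model classes
through the new `parallel` flag. A minimal sketch (assuming two visible GPUs and a
batch size that is a multiple of the GPU count; the model and method names are taken
from the diff below):

import torch
from msd_pytorch import MSDRegressionModel

# parallel=True wraps model.net in torch.nn.DataParallel (see msd_regression_model.py below)
model = MSDRegressionModel(1, 1, 100, 1, parallel=True)
inp = torch.zeros(16, 1, 1000, 1000, dtype=torch.float32, device="cuda:0")
tgt = torch.zeros(16, 1, 1000, 1000, dtype=torch.float32, device="cuda:0")
model.forward(inp, tgt)   # forward pass; the batch is scattered over the GPUs
model.learn(inp, tgt)     # a single training step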
Allard Hendriksen committed Jul 30, 2019
1 parent 335cf17 commit 3f1d648
Showing 9 changed files with 117 additions and 24 deletions.
14 changes: 10 additions & 4 deletions msd_pytorch/conv_cuda.cu
@@ -12,6 +12,7 @@
#include "THC/THCDeviceUtils.cuh"
#include "device_tensor.h"

using at::OptionalDeviceGuard;

__device__ __forceinline__ int
reflect(int i, int dimi) {
@@ -585,7 +586,8 @@ at::Tensor conv_cuda_forward(at::Tensor input_t,
int dilation,
int implementation,
int block_size) {
AT_DISPATCH_FLOATING_TYPES(input_t.type(), "conv_cuda_forward", ([&] {
OptionalDeviceGuard device_guard(device_of(input_t));
AT_DISPATCH_FLOATING_TYPES(input_t.scalar_type(), "conv_cuda_forward", ([&] {
// Create device tensors:
dTensor4R input_d = toDeviceTensorR<scalar_t,4>(input_t);
dTensor4R kernel_d = toDeviceTensorR<scalar_t,4>(kernel_t);
@@ -621,7 +623,8 @@ void conv_cuda_backward_x(at::Tensor grad_output_t,
int dilation,
int implementation,
int block_size) {
AT_DISPATCH_FLOATING_TYPES(grad_output_t.type(), "conv_cuda_backward_x", ([&] {
OptionalDeviceGuard device_guard(at::device_of(grad_output_t));
AT_DISPATCH_FLOATING_TYPES(grad_output_t.scalar_type(), "conv_cuda_backward_x", ([&] {
// Create device tensors:
dTensor4R grad_output_d = toDeviceTensorR<scalar_t,4>(grad_output_t);
dTensor4R grad_input_d = toDeviceTensorR<scalar_t,4>(grad_input_t);
@@ -645,7 +648,8 @@ void conv_cuda_backward_k(at::Tensor grad_output, at::Tensor input,
at::Tensor grad_kernel,
int dilation, int implementation, int block_size)
{
AT_DISPATCH_FLOATING_TYPES(grad_output.type(), "conv_cuda_backward_k", ([&] {
OptionalDeviceGuard device_guard(at::device_of(grad_output));
AT_DISPATCH_FLOATING_TYPES(grad_output.scalar_type(), "conv_cuda_backward_k", ([&] {
// Create device tensors:
dTensor4R grad_output_d = toDeviceTensorR<scalar_t,4>(grad_output);
dTensor4R input_d = toDeviceTensorR<scalar_t,4>(input);
@@ -670,7 +674,9 @@ void conv_cuda_backward_bias(at::Tensor grad_output,
at::Tensor grad_bias,
int implementation, int block_size)
{
AT_DISPATCH_FLOATING_TYPES(grad_output.type(), "conv_cuda_backward_bias", ([&] {
OptionalDeviceGuard device_guard(at::device_of(grad_output));

AT_DISPATCH_FLOATING_TYPES(grad_output.scalar_type(), "conv_cuda_backward_bias", ([&] {
// Create device tensors:
dTensor4R grad_output_d = toDeviceTensorR<scalar_t,4>(grad_output);
dTensor1R grad_bias_d = toDeviceTensorR<scalar_t,1>(grad_bias);
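
The recurring addition in this file is an `OptionalDeviceGuard` at the top of each entry
point, so that kernels launch on the device that actually holds the input tensor rather
than on the default GPU. A condensed sketch of the pattern, illustrative only and not
part of the diff:

#include <ATen/ATen.h>
#include <ATen/DeviceGuard.h>

void example_op(at::Tensor input) {
  // device_of() reads the tensor's device; the guard makes it the current CUDA
  // device for its own lifetime and restores the previous device afterwards.
  at::OptionalDeviceGuard device_guard(at::device_of(input));
  // ... launch kernels here; they now run on input.device(), which is what
  // torch.nn.DataParallel expects when it feeds tensors on cuda:1, cuda:2, ...
}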
8 changes: 8 additions & 0 deletions msd_pytorch/conv_relu.cpp
@@ -63,6 +63,14 @@ at::Tensor conv_relu_forward(at::Tensor input,
CHECK_CUDA(bias);
CHECK_INPUT(kernel); // kernel must be contiguous.

torch::TensorArg arg_input(input, "input", 0);
torch::TensorArg arg_kernel(kernel, "kernel", 1);
torch::TensorArg arg_bias(bias, "bias", 2);
torch::TensorArg arg_output(output, "output", 3);

// Check same device
at::checkAllSameGPU("conv_relu_forward", {arg_input, arg_kernel, arg_bias, arg_output});

// Check data type
AT_ASSERTM(input.type() == kernel.type(), "input and kernel must have same type");
AT_ASSERTM(input.type() == bias.type(), "input and bias must have same type");
41 changes: 27 additions & 14 deletions msd_pytorch/conv_relu_cuda.cu
@@ -12,6 +12,7 @@
#include "THC/THCDeviceUtils.cuh"
#include "device_tensor.h"

using at::OptionalDeviceGuard;

__device__ __forceinline__ int
reflect(int i, int dimi) {
@@ -642,7 +643,10 @@ at::Tensor conv_relu_cuda_forward(at::Tensor input_t,
int dilation,
int implementation,
int block_size) {
AT_DISPATCH_FLOATING_TYPES(input_t.type(), "conv_relu_cuda_forward", ([&] {
OptionalDeviceGuard device_guard(device_of(input_t));
cudaStream_t stream = at::cuda::getCurrentCUDAStream();

AT_DISPATCH_FLOATING_TYPES(input_t.scalar_type(), "conv_relu_cuda_forward", ([&] {
// Create device tensors:
dTensor4R input_d = toDeviceTensorR<scalar_t,4>(input_t);
dTensor4R kernel_d = toDeviceTensorR<scalar_t,4>(kernel_t);
@@ -654,16 +658,16 @@ at::Tensor conv_relu_cuda_forward(at::Tensor input_t,
dim3 blockSize(block_size, block_size);
auto buffer_sz = kernel_t.numel() * sizeof(scalar_t);
if (implementation == 2) {
conv_relu2<scalar_t><<<gridSize, blockSize, buffer_sz>>>
conv_relu2<scalar_t><<<gridSize, blockSize, buffer_sz, stream>>>
(input_d, kernel_d, bias_d, out_d, dilation);
} else if (implementation == 3) {
conv_relu3<scalar_t><<<gridSize, blockSize, buffer_sz>>>
conv_relu3<scalar_t><<<gridSize, blockSize, buffer_sz, stream>>>
(input_d, kernel_d, bias_d, out_d, dilation);
} else if (implementation == 4) {
conv_relu4<scalar_t><<<gridSize, blockSize, buffer_sz>>>
conv_relu4<scalar_t><<<gridSize, blockSize, buffer_sz, stream>>>
(input_d, kernel_d, bias_d, out_d, dilation);
} else if (implementation == 5) {
conv_relu5<scalar_t><<<gridSize, blockSize, buffer_sz>>>
conv_relu5<scalar_t><<<gridSize, blockSize, buffer_sz, stream>>>
(input_d, kernel_d, bias_d, out_d, dilation);
}

@@ -679,7 +683,10 @@ void conv_relu_cuda_backward_x(at::Tensor output_t,
int dilation,
int implementation,
int block_size) {
AT_DISPATCH_FLOATING_TYPES(output_t.type(), "conv_relu_cuda_backward_x", ([&] {
OptionalDeviceGuard device_guard(device_of(grad_output_t));
cudaStream_t stream = at::cuda::getCurrentCUDAStream();

AT_DISPATCH_FLOATING_TYPES(output_t.scalar_type(), "conv_relu_cuda_backward_x", ([&] {
// Create device tensors:
dTensor4R output_d = toDeviceTensorR<scalar_t,4>(output_t);
dTensor4R grad_output_d = toDeviceTensorR<scalar_t,4>(grad_output_t);
@@ -690,10 +697,10 @@ void conv_relu_cuda_backward_x(at::Tensor output_t,
dim3 blockSize(block_size, block_size);
auto buffer_sz = kernel_t.numel() * sizeof(scalar_t);
if (implementation == 0) {
conv_relu_backward_x0<scalar_t><<<gridSize, blockSize, buffer_sz>>>
conv_relu_backward_x0<scalar_t><<<gridSize, blockSize, buffer_sz, stream>>>
(output_d, grad_output_d, kernel_d, grad_input_d, dilation);
} else if (implementation == 1) {
conv_relu_backward_x1<scalar_t><<<gridSize, blockSize, buffer_sz>>>
conv_relu_backward_x1<scalar_t><<<gridSize, blockSize, buffer_sz, stream>>>
(output_d, grad_output_d, kernel_d, grad_input_d, dilation);
}
THCudaCheck(cudaGetLastError());
@@ -704,7 +711,10 @@ void conv_relu_cuda_backward_k(at::Tensor output, at::Tensor grad_output, at::Te
at::Tensor grad_kernel,
int dilation, int implementation, int block_size)
{
AT_DISPATCH_FLOATING_TYPES(grad_output.type(), "conv_relu_cuda_backward_k", ([&] {
OptionalDeviceGuard device_guard(device_of(grad_output));
cudaStream_t stream = at::cuda::getCurrentCUDAStream();

AT_DISPATCH_FLOATING_TYPES(grad_output.scalar_type(), "conv_relu_cuda_backward_k", ([&] {
// Create device tensors:
dTensor4R output_d = toDeviceTensorR<scalar_t,4>(output);
dTensor4R grad_output_d = toDeviceTensorR<scalar_t,4>(grad_output);
@@ -714,10 +724,10 @@ void conv_relu_cuda_backward_k(at::Tensor output, at::Tensor grad_output, at::Te
THCCeilDiv((int) grad_output_d.getSize(2), block_size));
dim3 blockSize(block_size, block_size);
if (implementation == 0) {
conv_relu_backward_k0<scalar_t><<<gridSize, blockSize>>>
conv_relu_backward_k0<scalar_t><<<gridSize, blockSize, 0, stream>>>
(output_d, grad_output_d, input_d, grad_kernel_d, dilation);
} else if (implementation == 1) {
conv_relu_backward_k1<scalar_t><<<gridSize, blockSize>>>
conv_relu_backward_k1<scalar_t><<<gridSize, blockSize, 0, stream>>>
(output_d, grad_output_d, input_d, grad_kernel_d, dilation);
}

@@ -731,7 +741,10 @@ void conv_relu_cuda_backward_bias(at::Tensor output,
at::Tensor grad_bias,
int implementation, int block_size)
{
AT_DISPATCH_FLOATING_TYPES(grad_output.type(), "conv_relu_cuda_backward_bias", ([&] {
OptionalDeviceGuard device_guard(device_of(grad_output));
cudaStream_t stream = at::cuda::getCurrentCUDAStream();

AT_DISPATCH_FLOATING_TYPES(grad_output.scalar_type(), "conv_relu_cuda_backward_bias", ([&] {
// Create device tensors:
dTensor4R output_d = toDeviceTensorR<scalar_t,4>(output);
dTensor4R grad_output_d = toDeviceTensorR<scalar_t,4>(grad_output);
@@ -740,10 +753,10 @@ void conv_relu_cuda_backward_bias(at::Tensor output,
THCCeilDiv((int) grad_output_d.getSize(2), block_size));
dim3 blockSize(block_size, block_size);
if (implementation == 0) {
conv_relu_backward_bias0<scalar_t><<<gridSize, blockSize>>>
conv_relu_backward_bias0<scalar_t><<<gridSize, blockSize, 0, stream>>>
(output_d, grad_output_d, grad_bias_d);
} else if (implementation == 1) {
conv_relu_backward_bias1<scalar_t><<<gridSize, blockSize>>>
conv_relu_backward_bias1<scalar_t><<<gridSize, blockSize, 0, stream>>>
(output_d, grad_output_d, grad_bias_d);
}
THCudaCheck(cudaGetLastError());
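
Besides the device guard, every kernel launch in this file now passes the current ATen
CUDA stream as the fourth launch parameter instead of using the legacy default stream.
A minimal sketch of the pattern, illustrative only and not part of the diff; it keeps the
custom kernels ordered with the rest of the work PyTorch queues on that device:

#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>

template <typename scalar_t>
__global__ void scale_kernel(scalar_t* data, scalar_t factor, int64_t n) {
  int64_t i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) data[i] *= factor;
}

void scale_inplace(at::Tensor t, double factor) {
  // Use the stream ATen is currently using for this device, not stream 0.
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  const int threads = 256;
  const int blocks = (t.numel() + threads - 1) / threads;
  AT_DISPATCH_FLOATING_TYPES(t.scalar_type(), "scale_inplace", ([&] {
    scale_kernel<scalar_t><<<blocks, threads, 0, stream>>>(
        t.data<scalar_t>(), static_cast<scalar_t>(factor), t.numel());
  }));
}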
17 changes: 15 additions & 2 deletions msd_pytorch/msd_block.py
@@ -160,6 +160,8 @@ def __init__(self, in_channels, dilations, width=1):

depth = len(self.dilations)

self.bias = torch.nn.Parameter(torch.Tensor(depth * width))

self.weights = []
for i in range(depth):
n_in = in_channels + width * i
@@ -170,7 +172,6 @@ def __init__(self, in_channels, dilations, width=1):
self.register_parameter('weight{}'.format(i), weight)
self.weights.append(weight)

self.bias = torch.nn.Parameter(torch.Tensor(depth * width))

self.reset_parameters()

@@ -185,7 +186,19 @@ def reset_parameters(self):
torch.nn.init.uniform_(self.bias, -bound, bound)

def forward(self, input):
return MSDBlockImpl2d.apply(input, self.dilations, self.bias, *self.weights)
# This is a bit of a hack, since we require but cannot assume
# that self.parameters() remains sorted in the order that we
# added the parameters.
#
# However, we need to obtain weights in this way, because
# self.weights may become obsolete when used in multi-gpu
# settings where the weights are automatically transferred (by,
# e.g., torch.nn.DataParallel). In that case, self.weights may
# continue to point to the weight parameters on the original
# device, even when the weight parameters have been
# transferred to a different gpu.
bias, *weights = self.parameters()
return MSDBlockImpl2d.apply(input, self.dilations, bias, *weights)


class MSDModule2d(torch.nn.Module):
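
To make the comment in `forward` above concrete: `torch.nn.DataParallel` replicates a
module onto each GPU and swaps in per-device copies of the registered parameters, but a
plain Python list attribute in the replica still references the original parameters on
the source device. An illustrative sketch, not part of the diff:

import torch

class Block(torch.nn.Module):
    def __init__(self):
        super().__init__()
        w = torch.nn.Parameter(torch.randn(3, 3))
        self.register_parameter("weight0", w)
        self.weights = [w]  # plain list: not updated when the module is replicated

    def forward(self, x):
        # Inside a DataParallel replica, self.weights[0] can still point to the
        # parameter on the original GPU; self.parameters() yields the replica's own.
        (weight,) = self.parameters()
        return x @ weight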
9 changes: 9 additions & 0 deletions msd_pytorch/msd_regression_model.py
@@ -26,6 +26,7 @@ def __init__(
*,
dilations=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
loss="L2",
parallel=False,
):
"""Create a new MSD network for regression.
@@ -47,6 +48,12 @@ def __init__(
* "L1" - ``nn.L1Loss()``
* "L2" - ``nn.MSELoss()``
:param parallel: `bool`
Whether or not to execute the model on multiple GPUs. Note
that the batch size must be a multiple of the number of
available GPUs.
:returns:
:rtype:
@@ -62,6 +69,8 @@
# Define the whole network:
self.net = nn.Sequential(self.scale_in, self.msd, self.scale_out)
self.net.cuda()
if parallel:
self.net = nn.DataParallel(self.net)

# Train only MSD parameters:
self.init_optimizer(self.msd)
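
For reference, wrapping `self.net` in `nn.DataParallel` makes each forward call scatter
the batch, run one replica per GPU, and gather the outputs on the primary device.
Roughly, and purely as an illustrative sketch of what the wrapper does:

import torch
from torch.nn.parallel import replicate, scatter, parallel_apply, gather

def data_parallel_forward(net, inp, device_ids=(0, 1), output_device=0):
    inputs = scatter(inp, device_ids)                 # split the batch over the GPUs
    replicas = replicate(net, device_ids[:len(inputs)])
    outputs = parallel_apply(replicas, inputs)        # run each replica on its chunk
    return gather(outputs, output_device)             # collect the results on one GPU

The wrapping happens after `self.net.cuda()`, which matches `DataParallel`'s expectation
that the module's parameters start out on the first device in `device_ids`.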
10 changes: 10 additions & 0 deletions msd_pytorch/msd_segmentation_model.py
@@ -24,6 +24,7 @@ def __init__(
width,
*,
dilations=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
parallel=False,
):
"""Create a new MSD network for segmentation.
@@ -37,6 +38,12 @@ def __init__(
good alternative is ``[1, 2, 4, 8]``. The dilations are
repeated when there are more layers than supplied dilations.
:param parallel: `bool`
Whether or not to execute the model on multiple GPUs. Note
that the batch size must be a multiple of the number of
available GPUs.
:returns:
:rtype:
@@ -52,6 +59,9 @@
self.net = nn.Sequential(self.scale_in, net_trained)
self.net.cuda()

if parallel:
self.net = nn.DataParallel(self.net)

# Train all parameters apart from self.scale_in.
self.init_optimizer(net_trained)

9 changes: 5 additions & 4 deletions msd_pytorch/relu_cuda.cu
@@ -3,6 +3,7 @@
#include <cuda.h>
#include <cuda_runtime.h>

using at::OptionalDeviceGuard;

template <typename scalar_t>
__global__ void relu_inplace_cuda_forward_kernel(scalar_t* __restrict__ input, size_t size) {
@@ -24,26 +25,26 @@ __global__ void relu_inplace_cuda_backward_kernel(const scalar_t* __restrict__ i
}

at::Tensor relu_inplace_cuda_forward(at::Tensor input){

OptionalDeviceGuard device_guard(at::device_of(input));
const int threads = 1024;
auto size = input.numel();
const dim3 blocks((size + threads - 1) / threads, 1);

AT_DISPATCH_FLOATING_TYPES(input.type(), "relu_inplace_cuda_forward", ([&] {
AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "relu_inplace_cuda_forward", ([&] {
relu_inplace_cuda_forward_kernel<scalar_t><<<blocks, threads>>>(input.data<scalar_t>(), size);
}));
return input;
}

at::Tensor relu_inplace_cuda_backward(at::Tensor input, at::Tensor grad_output){

OptionalDeviceGuard device_guard(at::device_of(grad_output));
auto grad_input = at::zeros_like(grad_output);

const int threads = 1024;
auto size = input.numel();
const dim3 blocks((size + threads - 1) / threads, 1);

AT_DISPATCH_FLOATING_TYPES(input.type(), "relu_inplace_cuda_forward", ([&] {
AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "relu_inplace_cuda_forward", ([&] {
relu_inplace_cuda_backward_kernel<scalar_t><<<blocks, threads>>>(input.data<scalar_t>(),
grad_output.data<scalar_t>(),
grad_input.data<scalar_t>(),
15 changes: 15 additions & 0 deletions msd_pytorch/tests/test_msd_regression_model.py
@@ -33,6 +33,21 @@ def test_params_change():
assert not torch_equal(p0, p1)


def test_data_parallel():
"""Check that msd_model is compatible with multi-GPU approaches
Specifically, `torch.nn.DataParallel`.
"""

shape = (100, 100)
inp = torch.zeros(4, 1, *shape, dtype=torch.float32, device=torch.device("cuda:0"))
tgt = torch.zeros(4, 1, *shape, dtype=torch.float32, device=torch.device("cuda:0"))

model = MSDRegressionModel(1, 1, 11, 1, parallel=True)
model.forward(inp, tgt)
model.learn(inp, tgt)


def test_api_surface(tmp_path):
###########################################################################
# Create network #
18 changes: 18 additions & 0 deletions msd_pytorch/tests/test_msd_segmentation_model.py
@@ -31,6 +31,24 @@ def test_params_change():
assert not torch_equal(p0, p1)


def test_data_parallel():
"""Check that msd_model is compatible with multi-GPU approaches
Specifically, `torch.nn.DataParallel`.
"""

shape = (100, 100)
num_labels = 3
inp = torch.zeros(4, 1, *shape, dtype=torch.float32, device=torch.device("cuda:0"))
tgt = torch.randint(
low=0, high=num_labels, size=(4, 1, *shape), device=torch.device("cuda:0")
)

model = MSDSegmentationModel(1, num_labels, 11, 1, parallel=True)
model.forward(inp, tgt)
model.learn(inp, tgt)


def test_api_surface(tmp_path):
###########################################################################
# Create network #

