Workaround for nn.DataParallel bug #1
It seems cplxmodule doesn't work with nn.DataParallel. The attached minimal example, cplxmodule-bug.py.txt, raises an error.
Thank you for the issue!

**Analysis**

Here is a minimal reproducing example:

```python
import torch
from torch import nn

from cplxmodule import cplx
import cplxmodule.nn as cplxnn

from torch.nn.parallel.data_parallel import data_parallel

# a real-valued conv goes through data_parallel without a hitch ...
net, x = nn.Conv2d(3, 3, 3), torch.randn(2, 3, 6, 6)
data_parallel(net.cuda(0), x.cuda(0), [0, 1])

# ... but its complex-valued counterpart fails with a device-mismatch error
net, x = cplxnn.CplxConv2d(3, 3, 3), cplx.Cplx(torch.randn(1, 3, 6, 6))
data_parallel(net.cuda(0), x.cuda(0), [0, 1])
```
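For contrast, the same complex layer works when everything stays on a single device, which isolates the failure to `data_parallel`'s input handling (a sanity check of my own, not from the original report):

```python
# single-device forward pass raises no error, so the layer itself is sound
net = cplxnn.CplxConv2d(3, 3, 3).cuda(0)
x = cplx.Cplx(torch.randn(1, 3, 6, 6), torch.randn(1, 3, 6, 6))
print(net(x.to(torch.device("cuda:0"))).real.shape)  # torch.Size([1, 3, 4, 4])
```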
To see where things go wrong, I pulled in the building blocks that `data_parallel` is made of:

```python
from torch.nn.parallel.scatter_gather import scatter_kwargs
from torch.nn.parallel.replicate import replicate
from torch.nn.parallel.parallel_apply import parallel_apply
```

Replicating the model and moving the inputs manually seems to work OK:

```python
from torch.nn.parallel.replicate import replicate

net = cplxnn.CplxConv2d(3, 3, 3)
x = cplx.Cplx(torch.randn(1, 3, 6, 6), torch.randn(1, 3, 6, 6))

# make a replica of the model on each device
replicas = replicate(net.cuda(0), [0, 1])
print([m.weight.device for m in replicas])

# moving the input to each replica's device by hand works fine
print([m(x.to(m.weight.device)).device for m in replicas])
```

Since the error message matches the one produced by

```python
import torch.nn.functional as F

F.conv2d(torch.randn(1, 3, 6, 6).cuda(0), torch.randn(3, 3, 3, 3).cuda(1))
```

I think the issue here is that the input ends up on the wrong device, while the model replica is on the right one. Thus I investigated `scatter_kwargs`.
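Indeed, the culprit appears to be the scatter fallback: anything that is not a tensor (or a list/tuple/dict of tensors) is handed to every device by reference and never moved. A small sketch of that behaviour (my illustration, assuming the `scatter` helper from `torch.nn.parallel.scatter_gather`; `Opaque` is a made-up stand-in for `Cplx`):

```python
import torch
from torch.nn.parallel.scatter_gather import scatter

# a tensor gets chunked along dim 0, one shard moved to each device
shards = scatter(torch.randn(2, 3), [0, 1])
print([s.device for s in shards])  # [cuda:0, cuda:1]

# a non-tensor object is not moved at all: every device "shard" is the
# very same object, still on whichever device its storage lives on
class Opaque:
    pass

shards = scatter(Opaque(), [0, 1])
print(shards[0] is shards[1])  # True -- this is what happens to a Cplx input
```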
**Solution**

Unfortunately I can only suggest a workaround wrapper:

```python
def r2r_wrap(model, dim_in=1, dim_out=1):
    return torch.nn.Sequential(
        # convert a `B x F*2 x ...` Tensor (dim_in=1) to Cplx as part of the computations
        cplxnn.RealToCplx(dim=dim_in),
        model,
        # convert Cplx back to a `B x O*2 x ...` Tensor (dim_out=1) as part of the pipeline
        cplxnn.CplxToReal(dim=dim_out),
    )

model_complex = r2r_wrap(cplxtest_net()).to(device)
```

This is a thin real-to-real wrapper around the whole model: it makes the conversion from torch Tensors to Cplx and back a part of the model itself, and thus bypasses the scatter issue. The key nuance is that the input to and the output of the model are plain real tensors with real-imag pairs interleaved along the feature dimension:

```python
# put real-imag pairs along the 1st dim
complex_data = torch.randn(1, 3*2, 224, 224).to(device)
# real-imag pairs are assumed to alternate the 1st dim
cplx_data = cplxnn.RealToCplx(dim=1)(complex_data)
assert torch.allclose(cplx_data.real, complex_data[:, 0::2])
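# (added sanity check, not in the original post) the interleaved layout is
# lossless: converting back with CplxToReal(dim=1) recovers the input exactly
assert torch.allclose(cplxnn.CplxToReal(dim=1)(cplx_data), complex_data)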
assert torch.allclose(cplx_data.imag, complex_data[:, 1::2])
```

The output follows the same interleaved layout:

```python
tensor_output = model_complex(complex_data)

cplx_data = cplxnn.RealToCplx(dim=1)(tensor_output)
assert torch.allclose(cplx_data.real, tensor_output[:, 0::2])
assert torch.allclose(cplx_data.imag, tensor_output[:, 1::2])
```

These tensors are just a way to store the complex numbers; they in no way affect the arithmetic or the operations inside the Cplx network or `cplxmodule` itself.
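For completeness, here is how the wrapped model should then go through `data_parallel`, using the `r2r_wrap` from above (my own end-to-end sketch, assuming the same two-GPU setup; the single `CplxConv2d` stands in for the attached `cplxtest_net`):

```python
import torch
import cplxmodule.nn as cplxnn
from torch.nn.parallel.data_parallel import data_parallel

# stand-in model: 3 complex in-channels, 3 complex out-channels, 3x3 kernel
model = r2r_wrap(cplxnn.CplxConv2d(3, 3, 3)).cuda(0)

# 3 complex channels stored as 6 interleaved real channels
x = torch.randn(2, 3 * 2, 6, 6).cuda(0)

# only plain tensors cross the scatter boundary now, so this should work
out = data_parallel(model, x, [0, 1])
print(out.shape)  # torch.Size([2, 6, 4, 4]), interleaved real-imag
```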