diff --git a/robosat/tools/train.py b/robosat/tools/train.py index c8e8fee3..f6b74079 100644 --- a/robosat/tools/train.py +++ b/robosat/tools/train.py @@ -63,7 +63,9 @@ def main(args): os.makedirs(model["common"]["checkpoint"], exist_ok=True) num_classes = len(dataset["common"]["classes"]) - net = FPNSegmentation(num_classes).to(device) + net = FPNSegmentation(num_classes) + net = DataParallel(net) + net = net.to(device) if model["common"]["cuda"]: torch.backends.cudnn.benchmark = True