Skip to content

Commit

Permalink
Fixes missing DataParallel in rs train
Browse files Browse the repository at this point in the history
  • Loading branch information
daniel-j-h committed Oct 2, 2018
1 parent 344c0bb commit c1b3780
Showing 1 changed file with 3 additions and 1 deletion.
4 changes: 3 additions & 1 deletion robosat/tools/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,9 @@ def main(args):
os.makedirs(model["common"]["checkpoint"], exist_ok=True)

num_classes = len(dataset["common"]["classes"])
net = FPNSegmentation(num_classes).to(device)
net = FPNSegmentation(num_classes)
net = DataParallel(net)
net = net.to(device)

if model["common"]["cuda"]:
torch.backends.cudnn.benchmark = True
Expand Down

0 comments on commit c1b3780

Please sign in to comment.