diff --git a/finetune.py b/finetune.py index c944c9a..7c1d743 100644 --- a/finetune.py +++ b/finetune.py @@ -103,8 +103,8 @@ def train(imgL,imgR,disp_L): output3 = torch.squeeze(output3,1) loss = 0.5*F.smooth_l1_loss(output1[mask], disp_true[mask], size_average=True) + 0.7*F.smooth_l1_loss(output2[mask], disp_true[mask], size_average=True) + F.smooth_l1_loss(output3[mask], disp_true[mask], size_average=True) elif args.model == 'basic': - output = model(imgL,imgR) - output = torch.squeeze(output3,1) + output3 = model(imgL,imgR) + output3 = torch.squeeze(output3,1) loss = F.smooth_l1_loss(output3[mask], disp_true[mask], size_average=True) loss.backward()