diff --git a/model.py b/model.py index 8bee563..9793ae3 100644 --- a/model.py +++ b/model.py @@ -72,9 +72,10 @@ def main(): gt = torch.cat((gt, target), 0) bs, n_crops, c, h, w = inp.size() input_var = torch.autograd.Variable(inp.view(-1, c, h, w).cuda(), volatile=True) - output = model(input_var) - output_mean = output.view(bs, n_crops, -1).mean(1) - pred = torch.cat((pred, output_mean.data), 0) + with torch.no_grad(): # Fix CUDA out of memory + output = model(input_var) + output_mean = output.view(bs, n_crops, -1).mean(1) + pred = torch.cat((pred, output_mean.data), 0) AUROCs = compute_AUCs(gt, pred) AUROC_avg = np.array(AUROCs).mean() @@ -126,4 +127,4 @@ def forward(self, x): if __name__ == '__main__': - main() \ No newline at end of file + main()