From 12805762cc4c98bd348b6169eb1074067c6f878a Mon Sep 17 00:00:00 2001
From: Luca Coviello
Date: Thu, 26 Jul 2018 14:37:36 +0200
Subject: [PATCH] Perform data augmentation in training and normalize
 correctly at test time

---
 image.py  |  4 ++--
 val.ipynb | 25 +++++++++++--------------
 2 files changed, 13 insertions(+), 16 deletions(-)

diff --git a/image.py b/image.py
index c8506d49..c4a6a848 100644
--- a/image.py
+++ b/image.py
@@ -11,7 +11,7 @@ def load_data(img_path,train = True):
     img = Image.open(img_path).convert('RGB')
     gt_file = h5py.File(gt_path)
     target = np.asarray(gt_file['density'])
-    if False:
+    if train:
         crop_size = (img.size[0]/2,img.size[1]/2)
         if random.randint(0,9)<= -1:
             
@@ -40,4 +40,4 @@ def load_data(img_path,train = True):
     target = cv2.resize(target,(target.shape[1]/8,target.shape[0]/8),interpolation = cv2.INTER_CUBIC)*64
     
     
-    return img,target
\ No newline at end of file
+    return img,target
diff --git a/val.ipynb b/val.ipynb
index ba73c72a..e6475ccd 100644
--- a/val.ipynb
+++ b/val.ipynb
@@ -18,11 +18,12 @@
     "from scipy.ndimage.filters import gaussian_filter \n",
     "import scipy\n",
     "import json\n",
-    "import torchvision.transforms.functional as F\n",
+    "from torchvision import datasets, transforms\n",
     "from matplotlib import cm as CM\n",
     "from image import *\n",
     "from model import CSRNet\n",
     "import torch\n",
+    "\n",
     "%matplotlib inline"
    ]
   },
@@ -34,11 +35,11 @@
    },
    "outputs": [],
    "source": [
-    "from torchvision import datasets, transforms\n",
-    "transform=transforms.Compose([\n",
-    "                       transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406],\n",
-    "                                     std=[0.229, 0.224, 0.225]),\n",
-    "                   ])"
+    "transform = transforms.Compose([\n",
+    "    transforms.ToTensor(),\n",
+    "    transforms.Normalize(mean=[0.485, 0.456, 0.406],\n",
+    "                         std=[0.229, 0.224, 0.225])\n",
+    "])"
    ]
   },
   {
@@ -326,17 +327,13 @@
    "source": [
     "mae = 0\n",
     "for i in xrange(len(img_paths)):\n",
-    "    img = 255.0 * F.to_tensor(Image.open(img_paths[i]).convert('RGB'))\n",
-    "\n",
-    "    img[0,:,:]=img[0,:,:]-92.8207477031\n",
-    "    img[1,:,:]=img[1,:,:]-95.2757037428\n",
-    "    img[2,:,:]=img[2,:,:]-104.877445883\n",
-    "    img = img.cuda()\n",
-    "    #img = transform(Image.open(img_paths[i]).convert('RGB')).cuda()\n",
+    "    img = transform(Image.open(img_paths[i]).convert('RGB')).cuda()\n",
     "    gt_file = h5py.File(img_paths[i].replace('.jpg','.h5').replace('images','ground_truth'),'r')\n",
     "    groundtruth = np.asarray(gt_file['density'])\n",
+    "    \n",
     "    output = model(img.unsqueeze(0))\n",
     "    mae += abs(output.detach().cpu().sum().numpy()-np.sum(groundtruth))\n",
+    "    \n",
     "    print i,mae\n",
     "print mae/len(img_paths)"
    ]
@@ -358,7 +355,7 @@
    "name": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython2",
-   "version": "2.7.13"
+   "version": "2.7.15"
   }
  },
  "nbformat": 4,
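
Note (not part of the commit): below is a minimal, self-contained sketch of the evaluation flow this patch arrives at, rewritten for Python 3 and a current PyTorch/torchvision; the notebook itself targets Python 2.7. The commented checkpoint path and the mean_absolute_error helper are illustrative assumptions, not repo code.

# Hypothetical sketch of the post-patch evaluation pipeline (Python 3).
import h5py
import numpy as np
import torch
from PIL import Image
from torchvision import transforms

from model import CSRNet  # CSRNet definition shipped with this repo

# ImageNet statistics, matching the pretrained VGG-16 front end used in training.
transform = transforms.Compose([
    transforms.ToTensor(),                    # uint8 HWC [0,255] -> float CHW [0,1]
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])

model = CSRNet().cuda().eval()
# model.load_state_dict(torch.load('checkpoint.pth.tar')['state_dict'])  # path illustrative

def mean_absolute_error(img_paths):
    """MAE between predicted and ground-truth counts, as in the val notebook."""
    mae = 0.0
    for path in img_paths:
        img = transform(Image.open(path).convert('RGB')).cuda()
        with torch.no_grad():
            output = model(img.unsqueeze(0))  # predicted density map
        gt_path = path.replace('.jpg', '.h5').replace('images', 'ground_truth')
        with h5py.File(gt_path, 'r') as gt_file:
            groundtruth = np.asarray(gt_file['density'])
        mae += abs(float(output.sum()) - float(np.sum(groundtruth)))
    return mae / len(img_paths)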