# dataset.py
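# On Colab: uncomment and run the lines below first. They download the COCO
# sample, define `coco_path`, and flag the Colab code path.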
#!pip install fastai --upgrade
# from fastai.data.external import untar_data, URLs
# coco_path = untar_data(URLs.COCO_SAMPLE)
# coco_path = str(coco_path) + "/train_sample"
# use_colab = True
import os
import glob
import time
import numpy as np
from PIL import Image
from pathlib import Path
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt
from skimage.color import rgb2lab, lab2rgb
import torch
from torch import nn, optim
from torchvision import transforms
from torchvision.utils import make_grid
from torch.utils.data import Dataset, DataLoader
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Set to True on Colab, after running the commented-out setup at the top of
# this file (that setup is what defines `coco_path`).
use_colab = None
if use_colab:
    path = coco_path
else:
    path = "Path to the dataset"

N_IMAGES = 10_000  # size of the random working subset
N_TRAIN = 8_000    # images used for training; the remainder is validation

# NOTE: glob order depends on the file system, so even with the fixed seed
# below the selected subset can differ from one machine to another.
paths = glob.glob(path + "/*.jpg")  # Grabbing all the image file names
if len(paths) < N_IMAGES:
    # Sampling without replacement needs at least N_IMAGES files; say so
    # explicitly instead of surfacing numpy's generic sampling error.
    raise ValueError(
        f"Found {len(paths)} .jpg files under {path!r}, "
        f"but {N_IMAGES} are required"
    )

np.random.seed(123)
paths_subset = np.random.choice(paths, N_IMAGES, replace=False)  # choosing 10,000 images randomly
rand_idxs = np.random.permutation(N_IMAGES)
train_idxs = rand_idxs[:N_TRAIN]  # choosing the first 8000 as training set
val_idxs = rand_idxs[N_TRAIN:]  # choosing last 2000 as validation set
train_paths = paths_subset[train_idxs]
val_paths = paths_subset[val_idxs]
print(len(train_paths), len(val_paths))

# Preview the first 16 training images in a 4x4 grid.
_, axes = plt.subplots(4, 4, figsize=(10, 10))
for ax, img_path in zip(axes.flatten(), train_paths):
    # Close each file once it has been handed to matplotlib.
    with Image.open(img_path) as preview:
        ax.imshow(preview)
    ax.axis("off")

# Side length (in pixels) every image is resized to by ColorizationDataset.
SIZE = 256
class ColorizationDataset(Dataset):
    """Dataset that serves images as normalized L*a*b* channel tensors.

    Every item is a dict holding the lightness channel under ``'L'`` and the
    two color channels under ``'ab'``, both taken from the image after it has
    been resized to ``SIZE`` x ``SIZE`` and converted from RGB to L*a*b*.

    Args:
        paths: Indexable, sized collection of image file paths.
        split: ``'train'`` (resize plus random horizontal flip) or ``'val'``
            (resize only).

    Raises:
        ValueError: If ``split`` is neither ``'train'`` nor ``'val'``.
    """

    def __init__(self, paths, split='train'):
        # Fail fast: an unknown split used to leave `self.transforms` unset
        # and only showed up as an AttributeError on the first item access.
        if split not in ('train', 'val'):
            raise ValueError(f"split must be 'train' or 'val', got {split!r}")

        resize = transforms.Resize((SIZE, SIZE), Image.BICUBIC)
        if split == 'train':
            self.transforms = transforms.Compose([
                resize,
                transforms.RandomHorizontalFlip(),  # A little data augmentation!
            ])
        else:
            self.transforms = resize

        self.split = split
        self.size = SIZE
        self.paths = paths

    def __getitem__(self, idx):
        """Load image ``idx`` and return ``{'L': ..., 'ab': ...}`` tensors."""
        # Open inside a context manager so the file handle is released right
        # after decoding; convert() hands back an independent, loaded image.
        with Image.open(self.paths[idx]) as raw:
            img = raw.convert("RGB")
        img = self.transforms(img)
        img = np.array(img)
        img_lab = rgb2lab(img).astype("float32")  # Converting RGB to L*a*b
        img_lab = transforms.ToTensor()(img_lab)  # channels-first tensor
        # List indexing keeps the channel dimension on both outputs.
        L = img_lab[[0], ...] / 50. - 1.  # Between -1 and 1
        ab = img_lab[[1, 2], ...] / 110.  # Between -1 and 1
        return {'L': L, 'ab': ab}

    def __len__(self):
        """Return the number of image paths in the dataset."""
        return len(self.paths)
def make_dataloaders(batch_size=16, n_workers=4, pin_memory=True, *, shuffle=False, **kwargs):  # A handy function to make our dataloaders
    """Build a ``DataLoader`` over a freshly created ``ColorizationDataset``.

    Args:
        batch_size: Number of samples per batch.
        n_workers: Passed to ``DataLoader`` as ``num_workers``.
        pin_memory: Passed to ``DataLoader`` as ``pin_memory``.
        shuffle: Keyword-only; passed to ``DataLoader`` as ``shuffle``. The
            default ``False`` keeps the order the loader has always used.
        **kwargs: Forwarded unchanged to ``ColorizationDataset`` (for
            example ``paths`` and ``split``).

    Returns:
        The configured ``DataLoader``.
    """
    dataset = ColorizationDataset(**kwargs)
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=n_workers,
        pin_memory=pin_memory,
    )
# Build one loader per split from the path arrays prepared above.
train_dl = make_dataloaders(split='train', paths=train_paths)
val_dl = make_dataloaders(split='val', paths=val_paths)

# Pull a single training batch as a sanity check and report its tensor
# shapes, followed by the number of batches in each loader.
data = next(iter(train_dl))
Ls = data['L']
abs_ = data['ab']
print(Ls.shape, abs_.shape)
print(len(train_dl), len(val_dl))