-
Notifications
You must be signed in to change notification settings - Fork 0
/
dataset.py
executable file
·66 lines (55 loc) · 1.92 KB
/
dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
from io import BytesIO
import lmdb
from PIL import Image
from torch.utils.data import Dataset
import random
import numpy as np
import torch
import torch.nn.functional as F
from torchvision import transforms
import glob
import os
class AIDataset(Dataset):
    """Image dataset that pairs every image with a randomly masked copy of itself.

    Each file matched by ``<root_path>/*`` is one sample. ``__getitem__``
    returns ``(image, image * mask)``, where ``mask`` is a binary mask built
    from a file chosen at random from ``<seg_root>/*/*``.

    Args:
        name: Label printed together with the number of samples.
        root_path: Directory whose direct children are the image files.
        resolution: Stored as ``self.resolution`` but not read by this class.
            NOTE(review): the 3-band pipeline always resizes to 1024; confirm
            whether ``resolution`` was meant to drive that size.
        seg_root: Directory searched two levels deep for mask files. Defaults
            to ``'distort'`` (relative to the current working directory), which
            is the previously hard-coded location.
    """

    # Shorter-edge target size used by the 3-band pipeline.
    _RESIZE_TO = 1024

    def __init__(self, name, root_path, resolution, seg_root='distort'):
        self.name = name
        self.root_path = root_path
        self.resolution = resolution
        # 4-band images: tensor conversion, then Normalize((x - 0.5) / 0.5) per band.
        # No geometric augmentation is applied on this path.
        self.transform_4ch = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5, 0.5), (0.5, 0.5, 0.5, 0.5), inplace=True),
        ])
        # 3-band images: 40px pad, random horizontal flip, resize the shorter
        # edge to _RESIZE_TO, tensor conversion, Normalize((x - 0.5) / 0.5).
        self.transform_3ch = transforms.Compose([
            transforms.Pad(40),
            transforms.RandomHorizontalFlip(),
            transforms.Resize(self._RESIZE_TO),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True),
        ])
        # Kept as a public attribute for callers; __getitem__ does not use it.
        self.transform_1ch = transforms.Compose([
            transforms.ToTensor(),
        ])
        # Sorted so the index -> file mapping does not depend on filesystem order.
        self.image_dir_lists = sorted(glob.glob(os.path.join(root_path, '*')))
        self.seg_dir_lists = sorted(glob.glob(os.path.join(seg_root, '*', '*')))
        self.length = len(self.image_dir_lists)
        print(f'{self.name} : {self.length}')

    def __len__(self):
        """Return the number of image files found under ``root_path``."""
        return self.length

    def __getitem__(self, index):
        """Return ``(image, masked_image)`` for sample ``index``.

        ``image`` has shape ``(C, H, W)``, with C = 4 for 4-band files and
        C = 3 otherwise. ``masked_image`` is ``image`` multiplied by a random
        binary mask that has been resized to the same ``(H, W)``.
        """
        image = self._load_image(self.image_dir_lists[index])
        mask = self._random_mask(image.shape[-2], image.shape[-1])
        # mask is (1, H, W), so it broadcasts over every channel of image.
        return image, image * mask

    def _load_image(self, path):
        """Open ``path`` and apply the transform that matches its band count."""
        with Image.open(path) as img:
            bands = len(img.getbands())
            if bands == 4:
                return self.transform_4ch(img)
            if bands != 3:
                # Grayscale / palette / LA images previously left the result
                # unbound; promote them to RGB so they use the 3-band pipeline.
                img = img.convert('RGB')
            return self.transform_3ch(img)

    def _random_mask(self, height, width):
        """Pick a random mask file, resize it to ``(height, width)`` and binarise it."""
        if not self.seg_dir_lists:
            raise RuntimeError(f'{self.name}: no mask files found (seg_dir_lists is empty)')
        # random.choice draws the same index as random.randint(0, len - 1).
        path = random.choice(self.seg_dir_lists)
        with Image.open(path) as seg:
            # PIL sizes are (width, height); the result no longer depends on the file.
            seg = seg.resize((width, height))
        return self._binarize_mask(np.array(seg))

    @staticmethod
    def _binarize_mask(mask):
        """Convert a mask array into a ``(1, H, W)`` float32 tensor of 0s and 1s.

        A pixel becomes 1 when its value is > 0. A multi-band mask of shape
        ``(H, W, K)`` is collapsed so that a pixel is 1 if any band is > 0.
        """
        binary = mask > 0
        if binary.ndim == 3:
            binary = binary.any(axis=-1)
        return torch.from_numpy(binary.astype(np.float32)).unsqueeze(0)