forked from zsef123/EfficientNets-PyTorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
loader.py
38 lines (33 loc) · 1.15 KB
/
loader.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
import os

from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
from torchvision import transforms
def _make_loader(dataset, batch_size, shuffle, num_workers):
    """Wrap ``dataset`` in a DataLoader with the settings shared by train/val."""
    return DataLoader(dataset,
                      batch_size=batch_size, shuffle=shuffle,
                      num_workers=num_workers, pin_memory=True)


def get_loaders(root, batch_size, resolution, num_workers=32):
    """Build the training and validation DataLoaders for an ImageFolder tree.

    ``root`` must contain ``train`` and ``val`` subdirectories laid out as
    ``torchvision.datasets.ImageFolder`` expects (one sub-folder per class).

    Args:
        root: Dataset root directory, as a ``str`` or ``os.PathLike``.
        batch_size: Number of samples per batch for both loaders.
        resolution: Side length in pixels of the square images produced.
        num_workers: Number of worker processes for each DataLoader.

    Returns:
        A ``(train_loader, val_loader)`` tuple. The training loader shuffles
        and the validation loader does not. Both use pinned memory.
    """
    # Per-channel mean/std matching torchvision's ImageNet-pretrained models.
    # Applied after ToTensor, which scales PIL pixel values to [0, 1].
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    train_transform = transforms.Compose([
        # NOTE(review): forcing a square Resize before RandomResizedCrop
        # distorts the aspect ratio and downsamples before cropping. Common
        # ImageNet recipes apply RandomResizedCrop to the original image.
        # Kept here to preserve existing training behaviour.
        transforms.Resize([resolution, resolution]),
        transforms.RandomResizedCrop(resolution),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])
    val_transform = transforms.Compose([
        transforms.Resize([resolution, resolution]),
        transforms.ToTensor(),
        normalize,
    ])

    # os.path.join accepts pathlib.Path roots, unlike `root + "/train"`.
    # It also avoids producing the absolute path "/train" when root is "".
    train_dataset = ImageFolder(os.path.join(root, "train"), train_transform)
    val_dataset = ImageFolder(os.path.join(root, "val"), val_transform)

    train_loader = _make_loader(train_dataset, batch_size,
                                shuffle=True, num_workers=num_workers)
    val_loader = _make_loader(val_dataset, batch_size,
                              shuffle=False, num_workers=num_workers)
    return train_loader, val_loader