-
Notifications
You must be signed in to change notification settings - Fork 43
/
utils.py
36 lines (26 loc) · 1.19 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
import numpy as np
import torch
import torchnet as tnt
from torchvision.datasets.mnist import MNIST
import config
def _shift_slices(shift, size):
    """Return ``(dst, src)`` slices that translate an axis of length *size* by *shift*.

    A positive *shift* moves content toward higher indices. The shift is
    clamped to ``[-size, size]`` so both slices stay in bounds (and become
    empty when the content is shifted completely out of view) instead of
    producing negative stop indices that Python would count from the end.
    """
    shift = max(-size, min(size, int(shift)))
    dst = slice(max(0, shift), size + min(0, shift))
    src = slice(max(0, -shift), size - max(0, shift))
    return dst, src


def augmentation(x, max_shift=2):
    """Randomly translate images by up to *max_shift* pixels on each spatial axis.

    One vertical and one horizontal offset are drawn uniformly from
    ``[-max_shift, max_shift]`` (inclusive) via ``np.random.randint`` and are
    applied to every image in *x*. Pixels shifted in from outside the image
    are zero.

    Args:
        x: tensor whose last two dimensions are (height, width), e.g. an
            ``(N, C, H, W)`` batch.
        max_shift: maximum absolute shift in pixels.

    Returns:
        A new float32 tensor with the same shape and device as *x*.
    """
    height, width = x.shape[-2:]
    h_shift, w_shift = np.random.randint(-max_shift, max_shift + 1, size=2)
    dst_h, src_h = _shift_slices(h_shift, height)
    dst_w, src_w = _shift_slices(w_shift, width)
    # Allocate on x's device so CUDA inputs are not silently moved to the CPU;
    # float32 keeps the original return dtype (integer inputs are cast on copy).
    shifted_image = torch.zeros(x.shape, dtype=torch.float, device=x.device)
    shifted_image[..., dst_h, dst_w] = x[..., src_h, src_w]
    return shifted_image
def get_iterator(mode, root='./data', batch_size=None, num_workers=4):
    """Build a parallel torchnet iterator over the MNIST train or test split.

    Args:
        mode: truthy selects the training split (and enables shuffling);
            falsy selects the test split without shuffling.
        root: directory the MNIST files are downloaded to / read from.
        batch_size: samples per batch; ``None`` uses ``config.BATCH_SIZE``.
        num_workers: number of loader worker processes.

    Returns:
        The result of ``TensorDataset.parallel`` over ``[images, labels]``.
    """
    if batch_size is None:
        batch_size = config.BATCH_SIZE
    dataset = MNIST(root=root, train=mode, download=True)
    # Newer torchvision exposes ``data``/``targets``; the split-specific names
    # are deprecated aliases that emit a warning on every access. Fall back to
    # them only for older torchvision versions that lack the new attributes.
    data = getattr(dataset, 'data', None)
    labels = getattr(dataset, 'targets', None)
    if data is None or labels is None:
        split = 'train' if mode else 'test'
        data = getattr(dataset, split + '_data')
        labels = getattr(dataset, split + '_labels')
    tensor_dataset = tnt.dataset.TensorDataset([data, labels])
    return tensor_dataset.parallel(batch_size=batch_size, num_workers=num_workers, shuffle=mode)
if __name__ == "__main__":
    # Smoke test: shift one random single-channel 28x28 image and show both.
    original_image = torch.rand(1, 1, 28, 28)
    print(original_image)
    augmented_image = augmentation(original_image)
    print(augmented_image)