batch_transformers.py (forked from Zhou-Yijie/MSLTNet)
import numpy as np
import torch
from collections.abc import Iterable  # collections.Iterable was removed in Python 3.10
from PIL import Image
from torchvision import transforms

# Candidate shorter-side resolutions for random multi-scale resizing.
RANDOM_RESOLUTIONS = [512, 768, 1024, 1280, 1536]


class BatchRandomResolution(object):
    """Resizes every image in the batch to one randomly chosen resolution."""

    def __init__(self, size=None, interpolation=Image.BILINEAR):
        assert (size is None or isinstance(size, int)
                or (isinstance(size, Iterable) and len(size) == 2))
        self.size = size
        self.interpolation = interpolation

    def __call__(self, imgs):
        size = self.size
        if size is None:
            # PIL's Image.size is (width, height).
            w, h = imgs[0].size
            # Count the candidate resolutions that fit inside the image and
            # sample one uniformly; fall back to the smallest candidate if
            # the image is smaller than all of them.
            max_idx = sum(1 for r in RANDOM_RESOLUTIONS if h > r and w > r)
            idx = np.random.randint(max_idx) if max_idx > 0 else 0
            # Keep the draw local so each call samples a fresh resolution.
            size = RANDOM_RESOLUTIONS[idx]
        return [transforms.Resize(size, self.interpolation)(img) for img in imgs]


class BatchTestResolution(object):
    """Downscales the batch to `size` when both sides exceed it; otherwise
    returns the batch unchanged."""

    def __init__(self, size=None, interpolation=Image.BILINEAR):
        assert (size is None or isinstance(size, int)
                or (isinstance(size, Iterable) and len(size) == 2))
        self.size = size
        self.interpolation = interpolation

    def __call__(self, imgs):
        # PIL's Image.size is (width, height).
        w, h = imgs[0].size
        if h > self.size and w > self.size:
            return [transforms.Resize(self.size, self.interpolation)(img) for img in imgs]
        return imgs


class BatchToTensor(object):
    """Converts each PIL image in the batch to a CxHxW float tensor in [0, 1]."""

    def __call__(self, imgs):
        return [transforms.ToTensor()(img) for img in imgs]


class BatchUnsqueeze1(object):
    """Adds a singleton dimension at position 1 to every tensor in the batch."""

    def __call__(self, imgs):
        return [img.unsqueeze(1) for img in imgs]


class BatchRGBToGray(object):
    """Converts each RGB tensor to grayscale with ITU-R BT.601 luma weights."""

    def __call__(self, imgs):
        # The blue channel is indexed with [2], not [2:], so the output is HxW.
        return [img[0, :, :] * 0.299 + img[1, :, :] * 0.587 + img[2, :, :] * 0.114 for img in imgs]


class BatchRGBToYCbCr(object):
    """Converts each 3xHxW RGB tensor in the batch to YCbCr (BT.601, values in [0, 1])."""

    def __call__(self, imgs):
        return [torch.stack((0. / 256. + img[0, :, :] * 0.299000 + img[1, :, :] * 0.587000 + img[2, :, :] * 0.114000,
                             128. / 256. - img[0, :, :] * 0.168736 - img[1, :, :] * 0.331264 + img[2, :, :] * 0.500000,
                             128. / 256. + img[0, :, :] * 0.500000 - img[1, :, :] * 0.418688 - img[2, :, :] * 0.081312),
                            dim=0) for img in imgs]


class YCbCrToRGB(object):
    """Converts a batched NxCxHxW YCbCr tensor back to RGB (inverse BT.601 transform)."""

    def __call__(self, img):
        return torch.stack((img[:, 0, :, :] + (img[:, 2, :, :] - 128 / 256.) * 1.402,
                            img[:, 0, :, :] - (img[:, 1, :, :] - 128 / 256.) * 0.344136 - (img[:, 2, :, :] - 128 / 256.) * 0.714136,
                            img[:, 0, :, :] + (img[:, 1, :, :] - 128 / 256.) * 1.772),
                           dim=1)


class RGBToYCbCr(object):
    """Converts a batched NxCxHxW RGB tensor to YCbCr (BT.601, values in [0, 1])."""

    def __call__(self, img):
        return torch.stack((0. / 256. + img[:, 0, :, :] * 0.299000 + img[:, 1, :, :] * 0.587000 + img[:, 2, :, :] * 0.114000,
                            128. / 256. - img[:, 0, :, :] * 0.168736 - img[:, 1, :, :] * 0.331264 + img[:, 2, :, :] * 0.500000,
                            128. / 256. + img[:, 0, :, :] * 0.500000 - img[:, 1, :, :] * 0.418688 - img[:, 2, :, :] * 0.081312),
                           dim=1)
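

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original file): a minimal, hedged example of
# how these batch transforms might compose. `frames`, `pipeline`, and the
# dummy image sizes below are illustrative assumptions, not MSLTNet's actual
# training code.
if __name__ == "__main__":
    # Three dummy 1600x1200 RGB frames standing in for a real image batch.
    frames = [Image.new("RGB", (1600, 1200), (128, 64, 32)) for _ in range(3)]

    # Each transform maps a list of images to a list, so Compose chains them.
    pipeline = transforms.Compose([
        BatchRandomResolution(),  # size=None: sample a resolution that fits
        BatchToTensor(),          # PIL -> CxHxW float tensors in [0, 1]
        BatchRGBToYCbCr(),        # per-image 3xHxW YCbCr tensors
    ])
    ycbcr_list = pipeline(frames)
    print(ycbcr_list[0].shape)

    # Round trip on a batched NxCxHxW tensor: RGBToYCbCr followed by
    # YCbCrToRGB should recover the input up to float32 rounding error.
    batch = torch.stack([transforms.ToTensor()(f) for f in frames], dim=0)
    restored = YCbCrToRGB()(RGBToYCbCr()(batch))
    print(torch.max(torch.abs(restored - batch)).item())  # tiny, ~1e-7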