-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtransforms.py
79 lines (59 loc) · 2.26 KB
/
transforms.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
import random
from typing import List, Union
from torchvision.transforms import functional as F
from torchvision.transforms import transforms as T
class Compose(object):
    """Apply a sequence of joint (image, target) transforms in order.

    Each element of ``transforms`` is called as ``t(image, target)`` and must
    return a 2-item ``(image, target)`` pair, which is fed to the next element.

    Args:
        transforms: Iterable of callables taking and returning ``(image, target)``.
    """

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, image, target=None):
        current_image, current_target = image, target
        for transform in self.transforms:
            # Unpack each result so a transform returning anything other than
            # a 2-item pair fails immediately.
            current_image, current_target = transform(current_image, current_target)
        return current_image, current_target
class ToTensor(object):
    """Convert both the image and the target with ``F.to_tensor``.

    NOTE(review): ``F.to_tensor`` is applied to the target as well. Per
    torchvision, that converts uint8 input to a float tensor scaled to [0, 1].
    This is presumably intended for soft/binary masks; confirm it is not used
    with integer class-label maps.
    """

    def __call__(self, image, target):
        image_tensor = F.to_tensor(image)
        target_tensor = F.to_tensor(target)
        return image_tensor, target_tensor
class RandomHorizontalFlip(object):
    """Horizontally flip the image and target together with probability ``prob``.

    Args:
        prob: Flip probability. A single ``random.random()`` draw decides for
            both inputs, so image and target stay aligned.
    """

    def __init__(self, prob):
        self.flip_prob = prob

    def __call__(self, image, target):
        should_flip = random.random() < self.flip_prob
        if not should_flip:
            return image, target
        return F.hflip(image), F.hflip(target)
class Normalize(object):
    """Normalize the image channel-wise; the target is passed through unchanged.

    NOTE(review): torchvision's ``F.normalize`` operates on tensors, so this
    presumably runs after ``ToTensor`` in a pipeline; confirm with callers.

    Args:
        mean: Per-channel means forwarded to ``F.normalize``.
        std: Per-channel standard deviations forwarded to ``F.normalize``.
    """

    def __init__(self, mean, std):
        self.mean, self.std = mean, std

    def __call__(self, image, target):
        normalized_image = F.normalize(image, mean=self.mean, std=self.std)
        return normalized_image, target
class Resize(object):
    """Resize an image, and optionally its target, with ``F.resize``.

    Args:
        size: Output size, documented here as ``[h, w]``. Per torchvision's
            ``F.resize``, a single int instead rescales the shorter edge to
            that value and keeps the aspect ratio.
        resize_mask: When true and a target is supplied, the target is
            resized to ``size`` as well.
    """

    def __init__(self, size: Union[int, List[int]], resize_mask: bool = True):
        self.size = size  # [h, w]
        self.resize_mask = resize_mask

    def __call__(self, image, target=None):
        image = F.resize(image, self.size)
        # ``target`` is optional (defaults to None). Only resize it when one was
        # actually given; F.resize(None, ...) would raise.
        # NOTE(review): the target uses F.resize's default interpolation; for
        # integer label maps NEAREST would be required — confirm intent.
        if self.resize_mask and target is not None:
            target = F.resize(target, self.size)
        return image, target
class RandomCrop(object):
    """Randomly crop a square ``size x size`` patch from an image and its target.

    If either spatial side of an input is shorter than ``size``, the input is
    first padded on the right/bottom. The same crop window, sampled from the
    image, is then applied to both image and target so they stay aligned.
    NOTE(review): this assumes image and target share the same spatial size.

    Args:
        size: Side length of the square crop, in pixels.
    """

    def __init__(self, size: int):
        self.size = size

    @staticmethod
    def _image_hw(img):
        """Return ``(height, width)`` for a tensor ``[..., H, W]`` or a PIL image."""
        if hasattr(img, "shape"):
            # torch.Tensor: the last two dims are (H, W). Note ``Tensor.size``
            # is a method, so it must not be unpacked directly.
            h, w = img.shape[-2:]
        else:
            # PIL.Image: ``size`` is the tuple (W, H).
            w, h = img.size
        return int(h), int(w)

    def _pad_amounts(self, h: int, w: int):
        """Return ``(pad_w, pad_h)`` needed to grow an ``h x w`` image to at least ``size`` per side."""
        return max(self.size - w, 0), max(self.size - h, 0)

    def pad_if_smaller(self, img, fill=0):
        """Pad ``img`` with ``fill`` on the right/bottom if a side is shorter than ``size``.

        Args:
            img: Tensor image with trailing ``(H, W)`` dims, or a PIL image.
            fill: Pixel value used for the padded area.

        Returns:
            ``img`` itself when no padding is needed, otherwise the padded image.
        """
        h, w = self._image_hw(img)
        padw, padh = self._pad_amounts(h, w)
        if padw or padh:
            # torchvision padding order is [left, top, right, bottom].
            img = F.pad(img, [0, 0, padw, padh], fill=fill)
        return img

    def __call__(self, image, target):
        image = self.pad_if_smaller(image)
        target = self.pad_if_smaller(target)
        # Sample one window from the (padded) image and reuse it for the target.
        crop_params = T.RandomCrop.get_params(image, (self.size, self.size))
        image = F.crop(image, *crop_params)
        target = F.crop(target, *crop_params)
        return image, target