-
Notifications
You must be signed in to change notification settings - Fork 0
/
transforms.py
38 lines (30 loc) · 1.13 KB
/
transforms.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
import random
from torchvision.transforms import functional as F
class Compose(object):
"""组合多个transform函数"""
def __init__(self, transforms):
self.transforms = transforms
def __call__(self, image, target):
for t in self.transforms:
image, target = t(image, target)
return image, target
class ToTensor(object):
"""将PIL图像转为Tensor"""
def __call__(self, image, target):
image = F.to_tensor(image)
return image, target
class RandomHorizontalFlip(object):
"""随机水平翻转图像以及bboxes"""
def __init__(self, prob=0.5):
self.prob = prob
def __call__(self, image, target):
if random.random() < self.prob:
height, width = image.shape[-2:]
image = image.flip(-1) # 水平翻转图片
bbox = target["boxes"]
# bbox: xmin, ymin, xmax, ymax
bbox[:, [0, 2]] = width - bbox[:, [2, 0]] # 翻转对应bbox坐标信息
target["boxes"] = bbox
if "masks" in target:
target["masks"] = target["masks"].flip(-1)
return image, target