-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathaugmentation.py
51 lines (47 loc) · 1.73 KB
/
augmentation.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
import torch
from torchvision.transforms import v2
from torchvision import tv_tensors
class XrayAugmentation:
    """Flip / rotation augmentations built on ``torchvision.transforms.v2``.

    Every pipeline produced by this class ends with
    ``v2.ToDtype(torch.float32, scale=True)``.

    Attributes are documented with comments in ``__init__``.
    """

    def __init__(self, target_size=None, padding=None):
        """Store the configuration values.

        Args:
            target_size: Stored as-is; not used by any method of this class.
            padding: Stored as-is; not used by any method of this class.
        """
        self.target_size = target_size
        self.padding = padding
        # Toggle table read by the 'Normal' mode: maps 'Hflip' / 'Vflip' to
        # objects exposing ``.get()`` that returns 1 when the flip is enabled.
        # NOTE(review): these look like GUI toggle variables populated by the
        # caller -- confirm against the code that fills this dict.
        self.w_state = {}

    def augmentation(self, img, bbox, degrees=(0, 0), method='Normal'):
        """Apply augmentation pipelines to an image and its bounding boxes.

        Args:
            img: Image passed straight to the torchvision v2 transforms.
            bbox: Bounding boxes passed alongside ``img``.
                NOTE(review): presumably a ``tv_tensors.BoundingBoxes`` so
                that v2 flips/rotates it together with the image -- confirm
                against the caller.
            degrees: Range forwarded to ``v2.RandomRotation`` (used by the
                'Demo' mode only).
            method: ``'Demo'`` runs four fixed pipelines (horizontal flip,
                vertical flip, both flips, rotation); ``'Normal'`` runs one
                pipeline made of the flips enabled in ``self.w_state``.

        Returns:
            For ``'Demo'``: a list with one pipeline output per pipeline, in
            the order listed above. For ``'Normal'``: the output of the single
            pipeline. Each output is whatever ``v2.Compose`` returns when
            called as ``pipeline(img, bbox)``.

        Raises:
            ValueError: If ``method`` is neither ``'Demo'`` nor ``'Normal'``.
        """
        if method == 'Demo':
            return self._demo_results(img, bbox, degrees)
        if method == 'Normal':
            return self._normal_result(img, bbox)
        raise ValueError(
            f"unknown augmentation method {method!r}; "
            "expected 'Demo' or 'Normal'"
        )

    @staticmethod
    def _pipeline(*steps):
        """Compose ``steps`` followed by the shared float32 conversion."""
        return v2.Compose([*steps, v2.ToDtype(torch.float32, scale=True)])

    def _flag_enabled(self, name):
        """Return True when ``self.w_state[name].get()`` equals 1."""
        try:
            toggle = self.w_state[name]
        except (AttributeError, KeyError):
            # No toggle table, or no entry for this flip: treat as disabled
            # rather than crashing the whole augmentation call.
            return False
        return toggle.get() == 1

    def _demo_results(self, img, bbox, degrees):
        """Run the four demonstration pipelines and collect their outputs."""
        pipelines = [
            self._pipeline(v2.RandomHorizontalFlip(p=1)),
            self._pipeline(v2.RandomVerticalFlip(p=1)),
            self._pipeline(
                v2.RandomHorizontalFlip(p=1),
                v2.RandomVerticalFlip(p=1),
            ),
            self._pipeline(v2.RandomRotation(degrees=degrees)),
        ]
        return [pipeline(img, bbox) for pipeline in pipelines]

    def _normal_result(self, img, bbox):
        """Apply the flips enabled in ``self.w_state`` in a single pipeline."""
        steps = []
        if self._flag_enabled('Hflip'):
            steps.append(v2.RandomHorizontalFlip(p=1))
            print('H')
        if self._flag_enabled('Vflip'):
            # Appended rather than replacing the horizontal flip, so enabling
            # both toggles applies both flips.
            steps.append(v2.RandomVerticalFlip(p=1))
            print('V')
        # With no flip enabled the pipeline is only the dtype conversion, so
        # the output type is the same regardless of the toggles.
        return self._pipeline(*steps)(img, bbox)