-
Notifications
You must be signed in to change notification settings - Fork 0
/
misc.py
36 lines (30 loc) · 1.17 KB
/
misc.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
import numpy as np
import torch
class Cutout(object):
    """Randomly mask out one or more square patches from an image tensor.

    Each hole is centred on a pixel drawn uniformly at random (NumPy global
    RNG) and clipped at the image border, so holes near an edge may be
    smaller than ``length x length``.

    NOTE(review): randomness comes from NumPy's global RNG, not torch's. If
    used inside DataLoader workers, confirm each worker gets a distinct NumPy
    seed (older PyTorch versions did not guarantee this).

    Args:
        n_holes: number of patches to cut out of each image.
        length: side length, in pixels, of each square patch.
    """

    def __init__(self, n_holes, length):
        self.n_holes = n_holes
        self.length = length

    def __call__(self, img):
        """Zero out ``n_holes`` random ``length x length`` patches of ``img``.

        Args:
            img: tensor whose last two dimensions are (H, W), typically
                (C, H, W). The same spatial mask is applied across all
                leading dimensions (e.g. every channel).

        Returns:
            ``img`` itself, modified in place.
        """
        h, w = img.shape[-2], img.shape[-1]
        mask = np.ones((h, w), np.float32)
        # Distance from a hole's centre to its top/left edge. Ending the patch
        # at ``start + length`` (instead of ``centre + length // 2``) keeps it
        # exactly ``length`` pixels wide for odd lengths as well; for even
        # lengths the bounds are identical to the classic formulation.
        before = self.length // 2
        for _ in range(self.n_holes):
            # Draw y then x, preserving the original NumPy RNG call order.
            y = np.random.randint(h)
            x = np.random.randint(w)
            y1, y2 = max(y - before, 0), min(y - before + self.length, h)
            x1, x2 = max(x - before, 0), min(x - before + self.length, w)
            mask[y1:y2, x1:x2] = 0.0
        # Match img's dtype and device so the in-place multiply also works for
        # integer images (e.g. uint8) and tensors not on the CPU.
        mask = torch.from_numpy(mask).to(device=img.device, dtype=img.dtype)
        img.mul_(mask.expand_as(img))
        return img