-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathnatural_imgsource.py
77 lines (66 loc) · 2.43 KB
/
natural_imgsource.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
# This module provides the classes used to generate backgrounds for the natural-background setting.
# The image source is used inside an environment wrapper and is queried each time the env generates an observation.
# The code is largely based on https://github.com/facebookresearch/deep_bisim4control
import random
import cv2
import numpy as np
import skvideo.io
class ImageSource:
    """Base interface for a provider of natural images.

    The images are meant to be added to a simulated environment. Both
    methods here are placeholders that do nothing and return ``None``;
    concrete sources are expected to override them.
    """

    def reset(self):
        """Hook called when an episode ends; a no-op in the base class."""
        return None

    def get_image(self):
        """Produce the next image.

        Returns:
            An RGB image of [h, w, 3] with a fixed shape (the contract for
            concrete sources); the base implementation returns ``None``.
        """
        return None
class RandomVideoSource(ImageSource):
    """Image source that serves frames from randomly chosen video files.

    Exactly one decoded video is kept in memory at a time. ``reset`` swaps in
    another randomly selected video and ``get_image`` returns one of its
    frames.
    """

    def __init__(self, shape, filelist, random_bg=False, max_videos=100, grayscale=False):
        """
        Args:
            shape: [h, w] of the images to produce
            filelist: a list of video files; it is copied, so the caller's
                list is left untouched
            random_bg: if True, ``get_image`` returns a uniformly random frame
                of the current video on every call; otherwise frames are
                served sequentially, wrapping around at the end
            max_videos: maximum number of files kept after shuffling
            grayscale: if True, frames are decoded as single-channel gray and
                images have shape [h, w, 1] instead of [h, w, 3]

        Raises:
            ValueError: if no video files remain after truncation.
            RuntimeError: if none of the kept files can be loaded.
        """
        self.grayscale = grayscale
        self.shape = shape
        # Shuffle a private copy: shuffling ``filelist`` itself would reorder
        # the caller's list as a hidden side effect.
        self.filelist = list(filelist)
        random.shuffle(self.filelist)
        self.filelist = self.filelist[:max_videos]
        if not self.filelist:
            raise ValueError(
                "RandomVideoSource needs at least one video file "
                "(filelist is empty or max_videos is 0)"
            )
        self.max_videos = max_videos
        self.random_bg = random_bg
        self.current_idx = 0
        self._current_vid = None
        self.reset()

    def load_video(self, vid_id):
        """Decode one video and resize every frame to ``self.shape``.

        Args:
            vid_id: index into ``self.filelist``

        Returns:
            A float array (``np.zeros`` default dtype) of shape
            [num_frames, h, w, c], with c == 1 in grayscale mode and 3
            otherwise.
        """
        fname = self.filelist[vid_id]
        if self.grayscale:
            frames = skvideo.io.vread(fname, outputdict={"-pix_fmt": "gray"})
        else:
            # Only the color path caps the number of decoded frames.
            frames = skvideo.io.vread(fname, num_frames=1000)

        height, width = self.shape[0], self.shape[1]
        channels = 1 if self.grayscale else 3
        frame_shape = (height, width, channels)
        img_arr = np.zeros((frames.shape[0],) + frame_shape)
        for i, frame in enumerate(frames):
            # THIS IS NOT A BUG! cv2 uses (width, height)
            resized = cv2.resize(frame, (width, height))
            # cv2.resize drops the trailing axis of single-channel input and
            # returns (h, w); reshape restores (h, w, 1) so the assignment
            # works in grayscale mode. For color frames this is a no-op.
            img_arr[i] = resized.reshape(frame_shape)
        return img_arr

    def reset(self):
        """Replace the current video with a randomly chosen loadable one.

        Videos that fail to load (or decode to zero frames) are skipped.

        Raises:
            RuntimeError: if every file in ``self.filelist`` failed to load;
                the last underlying error is attached as the cause.
        """
        # Release the previous video before decoding the next one so two
        # decoded videos are never held in memory at the same time.
        self._current_vid = None

        num_videos = len(self.filelist)
        failed_ids = set()
        last_error = None
        # Keep drawing random ids, but remember which ones failed so the loop
        # terminates once every file has been tried instead of spinning
        # forever when nothing is loadable.
        while len(failed_ids) < num_videos:
            video_id = int(np.random.randint(0, num_videos))
            if video_id in failed_ids:
                continue
            try:
                video = self.load_video(video_id)
                if len(video) == 0:
                    raise ValueError(
                        "video %r contains no frames" % (self.filelist[video_id],)
                    )
            except Exception as err:  # best effort: any load failure means "try another file"
                failed_ids.add(video_id)
                last_error = err
                continue
            self._video_id = video_id
            self._current_vid = video
            break
        else:
            raise RuntimeError(
                "none of the %d video file(s) could be loaded" % num_videos
            ) from last_error

        # Start from a random position within the freshly loaded video.
        self._loc = np.random.randint(0, len(self._current_vid))

    def get_image(self):
        """Return one frame of the current video.

        Returns:
            An image of shape [h, w, 3] ([h, w, 1] in grayscale mode).
        """
        num_frames = len(self._current_vid)
        if self.random_bg:
            self._loc = np.random.randint(0, num_frames)
        else:
            self._loc += 1
        # The modulo wraps sequential playback around at the end of the video.
        return self._current_vid[self._loc % num_frames]