-
Notifications
You must be signed in to change notification settings - Fork 7
/
utils.py
77 lines (61 loc) · 2.28 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
import torch
from torch.utils.data.sampler import BatchSampler, SubsetRandomSampler
from torch.utils.data import Dataset, DataLoader
from skimage.util.shape import view_as_windows
import os
import numpy as np
def reparameterize(mu, logsigma):
std = torch.exp(0.5*logsigma)
eps = torch.randn_like(std)
return mu + eps*std
def obs_extract(obs):
obs = np.transpose(obs['rgb'], (0,3,1,2))
return torch.from_numpy(obs)
def count_step(i_update, i_env, i_step, num_envs, num_steps):
step = i_update * (num_steps * num_envs) + i_env * num_steps + i_step
return step
# for representation learning
class ExpDataset(Dataset):
def __init__(self, file_dir, game, num_splitted, transform):
super(ExpDataset, self).__init__()
self.file_dir = file_dir
self.files = [f for f in os.listdir(file_dir) if game in f]
self.num_splitted = num_splitted
self.data = []
self.progress = 0
self.transform = transform
self.loadnext()
def __len__(self):
assert(len(self.data) > 0)
return self.data[0].shape[0]
def __getitem__(self, idx):
if torch.is_tensor(idx):
idx = idx.tolist()
return torch.stack([self.transform(d[idx]) for d in self.data])
def loadnext(self):
self.data = []
for file in self.files:
frames = np.load(os.path.join(self.file_dir, file, '%d.npz' % (self.progress)))['obs']
self.data.append(frames)
self.progress = (self.progress + 1) % self.num_splitted
# referred from https://github.com/MishaLaskin/curl
def random_crop(imgs, output_size):
"""
Vectorized way to do random crop using sliding windows
and picking out random ones
args:
imgs, batch images with shape (B,C,H,W)
"""
# batch size
n = imgs.shape[0]
img_size = imgs.shape[-1]
crop_max = img_size - output_size
imgs = np.transpose(imgs, (0, 2, 3, 1))
w1 = np.random.randint(0, crop_max, n)
h1 = np.random.randint(0, crop_max, n)
# creates all sliding windows combinations of size (output_size)
windows = view_as_windows(
imgs, (1, output_size, output_size, 1))[..., 0,:,:, 0]
# selects a random window for each batch element
cropped_imgs = windows[np.arange(n), w1, h1]
return cropped_imgs