playground.py · 75 lines (63 loc) · 2.62 KB
import os

import numpy as np
from PIL import Image
import torch
from torch.utils.data import Dataset
from torchvision import transforms

class FrameDataset(Dataset):
    def __init__(self, root_dir, labeled=True, transform=None):
        """
        Args:
            root_dir (string): Dataset folder path
            labeled (bool, optional): If the input data is labeled. Defaults to True.
            transform (torchvision.transforms, optional): Data transforms. Defaults to None.
        """
        self.root_dir = root_dir
        self.labeled = labeled
        self.transform = transform
        # Each subdirectory of root_dir is treated as one video clip.
        self.video_dirs = sorted([d for d in os.listdir(root_dir)
                                  if os.path.isdir(os.path.join(root_dir, d))])

    def __len__(self):
        return len(self.video_dirs)

    def __getitem__(self, idx):
        video_dir = os.path.join(self.root_dir, self.video_dirs[idx])
        frame_files = sorted([f for f in os.listdir(video_dir) if f.endswith('.png')])

        # Load every PNG frame of the video and apply the transform.
        frames = []
        for frame_file in frame_files:
            frame_path = os.path.join(video_dir, frame_file)
            frame = Image.open(frame_path).convert('RGB')
            if self.transform:
                frame = self.transform(frame)
            frames.append(frame)
        frames = torch.stack(frames)  # (num_frames, C, H, W)

        if self.labeled:
            # Per-video segmentation mask stored as mask.npy alongside the frames.
            mask_path = os.path.join(video_dir, 'mask.npy')
            mask = np.load(mask_path)
            mask = torch.from_numpy(mask).long()
            return frames, mask
        else:
            return frames, None
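

# Added sketch (not part of the original file): because __getitem__ returns
# None for the mask when labeled=False, the default collate function used by
# torch.utils.data.DataLoader typically cannot batch unlabeled samples.
# A minimal custom collate_fn along these lines handles both cases, assuming
# every video in a batch has the same number of frames and frame size.
def frame_collate(batch):
    frames = torch.stack([item[0] for item in batch])  # (B, T, C, H, W)
    masks = [item[1] for item in batch]
    if all(m is None for m in masks):
        return frames, None  # unlabeled batch
    return frames, torch.stack(masks)  # labeled batch
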
def dataset_test():
    train = "/scratch/jp4906/VideoMask/train"
    unlabeled = "/scratch/jp4906/VideoMask/unlabeled"
    val = "/scratch/jp4906/VideoMask/val"

    transform = transforms.Compose([
        transforms.ToTensor(),
    ])

    train_dataset = FrameDataset(root_dir=train, labeled=True, transform=transform)
    print("Train dataset has length {}".format(len(train_dataset)))
    print("")
    video_frames, video_mask = train_dataset[0]
    print('video_frames for train size: {}'.format(video_frames.size()))
    print("")
    print('video_mask for train size: {}'.format(video_mask.shape))
    print("")
    print("====================")

    unlabeled_dataset = FrameDataset(root_dir=unlabeled, labeled=False, transform=transform)
    print("Unlabeled dataset has length {}".format(len(unlabeled_dataset)))
    print("")
    video_frames, video_mask = unlabeled_dataset[0]
    print('video_frames for unlabeled size: {}'.format(video_frames.size()))
    print("")
    print('video_mask for unlabeled should be NoneType: {}'.format(video_mask))
    print("")
    print("====================")
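

# Usage sketch (added, not part of the original file): run the smoke test when
# the module is executed directly. The /scratch paths above are assumed to
# exist; the commented lines show how a DataLoader could be wired up with the
# frame_collate helper defined earlier.
if __name__ == "__main__":
    dataset_test()

    # loader = torch.utils.data.DataLoader(
    #     FrameDataset("/scratch/jp4906/VideoMask/train", labeled=True,
    #                  transform=transforms.Compose([transforms.ToTensor()])),
    #     batch_size=2, shuffle=True, collate_fn=frame_collate)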