forked from RolandGao/RegSeg
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdata_utils.py
39 lines (34 loc) · 1.46 KB
/
data_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
import torch
import torch.utils.data
import numpy as np
import random
def cat_list(images, fill_value=0):
    """Stack tensors of differing sizes into one padded batch tensor.

    Every dimension of the output is the maximum of that dimension over all
    inputs; each input is copied into the top-left corner of its slot and the
    remaining area keeps ``fill_value``.

    Args:
        images: Non-empty sequence of tensors. Each must have at least two
            dimensions, because the last two are used to place the copy.
            Assumes all tensors share the same rank -- ``zip`` over the
            shapes would silently truncate to the shortest one otherwise.
        fill_value: Value written to every padded position.

    Returns:
        A tensor of shape ``(len(images), *max_shape)`` with the dtype and
        device of ``images[0]``.
    """
    # Per-dimension maximum across the batch.
    max_size = tuple(max(dims) for dims in zip(*(img.shape for img in images)))
    batch_shape = (len(images),) + max_size

    # new_full allocates and fills in one documented call, inheriting dtype
    # and device from the first tensor (replaces legacy Tensor.new + fill_).
    batched_imgs = images[0].new_full(batch_shape, fill_value)

    # Iterating the batch tensor yields views along dim 0, so copy_ writes
    # straight into the batch. Leading dims are left to `...` on purpose.
    for img, pad_img in zip(images, batched_imgs):
        pad_img[..., :img.shape[-2], :img.shape[-1]].copy_(img)
    return batched_imgs
def collate_fn(batch):
    """Collate ``(image, target)`` samples into two padded batch tensors.

    Args:
        batch: Iterable of two-element samples; the first elements are
            gathered as images and the second as targets.

    Returns:
        ``(images, targets)`` as produced by ``cat_list``: images are padded
        with 0 and targets with 255.
    """
    # Transpose the list of pairs into one tuple of images and one of targets.
    imgs, tgts = tuple(zip(*batch))
    return (
        cat_list(imgs, fill_value=0),
        cat_list(tgts, fill_value=255),
    )
def worker_init_fn(worker_id):
    """Seed Python's ``random`` and NumPy's global RNG from torch's seed.

    The seed is ``torch.initial_seed()`` reduced modulo 2**32 (presumably to
    fit the range NumPy's seeding accepts -- confirm if this is changed).

    Args:
        worker_id: Unused; present so the function matches the signature
            expected of a ``worker_init_fn`` callback.
    """
    seed = torch.initial_seed() % (1 << 32)
    # Same seed for both generators, Python's first and then NumPy's.
    for seed_rng in (random.seed, np.random.seed):
        seed_rng(seed)
def get_dataloader_train(dataset,batch_size,num_workers=4):
    """Build the training ``DataLoader`` for ``dataset``.

    Samples are drawn through a ``RandomSampler``, batched with this module's
    ``collate_fn``, and the final incomplete batch is dropped. Each worker is
    initialised with ``worker_init_fn``.

    Args:
        dataset: Dataset to load from.
        batch_size: Number of samples per batch.
        num_workers: Number of loader worker processes.

    Returns:
        The configured ``torch.utils.data.DataLoader``.
    """
    loader_options = dict(
        batch_size=batch_size,
        sampler=torch.utils.data.RandomSampler(dataset),
        num_workers=num_workers,
        collate_fn=collate_fn,
        drop_last=True,
        worker_init_fn=worker_init_fn,
    )
    return torch.utils.data.DataLoader(dataset, **loader_options)
def get_dataloader_val(dataset_test,num_workers=4):
    """Build the evaluation ``DataLoader`` for ``dataset_test``.

    Samples are visited in order through a ``SequentialSampler``, one per
    batch, and collated with this module's ``collate_fn``. Each worker is
    initialised with ``worker_init_fn``.

    Args:
        dataset_test: Dataset to load from.
        num_workers: Number of loader worker processes.

    Returns:
        The configured ``torch.utils.data.DataLoader``.
    """
    loader_options = dict(
        batch_size=1,
        sampler=torch.utils.data.SequentialSampler(dataset_test),
        num_workers=num_workers,
        collate_fn=collate_fn,
        worker_init_fn=worker_init_fn,
    )
    return torch.utils.data.DataLoader(dataset_test, **loader_options)