-
Notifications
You must be signed in to change notification settings - Fork 0
/
dataloader.py
97 lines (81 loc) · 3.3 KB
/
dataloader.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
from torch.utils.data import Dataset, DataLoader
import numpy as np
import sys
import torch
class UAVDatasetTuple(Dataset):
    """Map-style dataset pairing task rasters, initial states and labels from ``.npy`` files.

    Each of the three paths is read whole with ``np.load`` and cast to ``float``.
    Sample ``idx`` is the ``idx``-th entry of each array, and ``len()`` follows the
    label array. The three arrays are NOT checked for equal length (an earlier
    assertion on task/label length was disabled). Indexing past a shorter task or
    init array therefore raises from ``__getitem__``.
    """

    # Number of task slots rasterised per step; extra slots in the input are ignored.
    MAX_TASKS = 15
    # Side length of the square grid used for both task rasters and labels.
    GRID_SIZE = 100

    def __init__(self, task_path, init_path, label_path):
        """Store the three ``.npy`` paths and eagerly load their contents.

        Args:
            task_path: Path to the task-coordinate array.
            init_path: Path to the initial-state array.
            label_path: Path to the label array.
        """
        self.task_path = task_path
        self.init_path = init_path
        self.label_path = label_path
        self.label_md = []
        self.init_md = []
        self.task_md = []
        self._get_tuple()

    def __len__(self):
        """Return the number of samples, taken from the label array."""
        return len(self.label_md)

    def _get_tuple(self):
        """Load the task, init and label arrays from disk as float arrays."""
        self.task_md = np.load(self.task_path).astype(float)
        self.init_md = np.load(self.init_path).astype(float)
        self.label_md = np.load(self.label_path).astype(float)
        # NOTE: task/label lengths are deliberately not asserted equal here.

    def __getitem__(self, idx):
        """Build sample ``idx``.

        Returns:
            dict with
              ``'task'``: torch.Tensor of shape ``(steps, MAX_TASKS, GRID_SIZE**2)``;
              ``'init'``: ``self.init_md[idx]`` with a new leading axis of size 1;
              ``'label'``: ``self.label_md[idx]`` reshaped to ``(GRID_SIZE, GRID_SIZE)``.

        Any error is reported on stdout with the failing index and then re-raised.
        """
        try:
            task = self._prepare_task(idx)
            init = self._prepare_init(idx)
            label = self._get_label(idx)
            init = np.expand_dims(init, axis=0)
        except Exception as e:
            print('error encountered while loading {}'.format(idx))
            print("Unexpected error:", sys.exc_info()[0])
            print(e)
            raise
        return {'task': task, 'init': init, 'label': label}

    def _prepare_init(self, idx):
        """Return the raw initial-state entry for sample ``idx``."""
        return self.init_md[idx]

    def _prepare_task(self, idx):
        """Rasterise the task coordinates of sample ``idx`` into flattened one-hot grids.

        ``self.task_md[idx]`` is indexed as ``[step][task][coord]`` with at least four
        coordinates ``(x1, y1, x2, y2)`` per task. This layout is inferred from the
        indexing; confirm it against the data producer. For every step and each of
        the first ``MAX_TASKS`` tasks, grid cells ``(x1, y1)`` and ``(x2, y2)`` are set
        to 1.0. A tuple whose four coordinates all truncate to 0 is treated as an
        empty slot and skipped.

        Returns:
            torch.Tensor of shape ``(steps, MAX_TASKS, GRID_SIZE * GRID_SIZE)``.

        Raises:
            IndexError: if a coordinate falls outside the grid.
        """
        coords = np.asarray(self.task_md[idx])
        n_steps = coords.shape[0]
        grid = self.GRID_SIZE
        task_md = torch.zeros([n_steps, self.MAX_TASKS, grid, grid])

        # astype truncates toward zero, exactly like the former per-element int().
        boxes = coords[:, :self.MAX_TASKS, :4].astype(np.int64)
        # Keep only (step, task) slots where at least one coordinate is non-zero.
        step_idx, task_idx = np.nonzero(np.any(boxes != 0, axis=-1))

        picked = torch.as_tensor(boxes[step_idx, task_idx])  # (k, 4) long tensor
        rows = torch.as_tensor(step_idx)
        cols = torch.as_tensor(task_idx)
        task_md[rows, cols, picked[:, 0], picked[:, 1]] = 1.00
        task_md[rows, cols, picked[:, 2], picked[:, 3]] = 1.00

        return task_md.reshape(n_steps, self.MAX_TASKS, grid * grid)

    def _get_label(self, idx):
        """Return label ``idx`` reshaped to a ``(GRID_SIZE, GRID_SIZE)`` map."""
        return self.label_md[idx].reshape(self.GRID_SIZE, self.GRID_SIZE)

    def get_class_count(self):
        """Print and return the positive/negative mass ratio over all label cells.

        Label values are summed, so the "positive" figure is a true cell count only
        for binary 0/1 labels. The calculation works whether labels are stored flat
        or as 2-D maps.

        Returns:
            ``(positive_ratio, negative_ratio)``, which sum to 1.

        Raises:
            ValueError: if there are no label cells.
        """
        labels = np.asarray(self.label_md)
        total = labels.size
        if total == 0:
            raise ValueError("cannot compute class ratios of an empty label set")
        positive_class = np.sum(labels)
        print("The number of positive image pair is:", positive_class)
        print("The number of negative image pair is:", total - positive_class)
        positive_ratio = positive_class / total
        negative_ratio = (total - positive_class) / total
        return positive_ratio, negative_ratio
if __name__ == '__main__':
    # Smoke check: load the dataset and print the task-tensor shape of sample 0.
    # Paths may be overridden positionally: TASK_PATH [INIT_PATH [LABEL_PATH]].
    default_paths = [
        '/data/zzhao/uav_regression/main_test/data_tasks.npy',
        '/data/zzhao/uav_regression/main_test/data_init_density.npy',
        '/data/zzhao/uav_regression/main_test/training_label_density.npy',
    ]
    cli_paths = sys.argv[1:4]
    data_path, init_path, label_path = cli_paths + default_paths[len(cli_paths):]
    all_dataset = UAVDatasetTuple(task_path=data_path, init_path=init_path, label_path=label_path)
    sample = all_dataset[0]
    print(sample['task'].shape)