-
Notifications
You must be signed in to change notification settings - Fork 6
/
Dataset.py
50 lines (37 loc) · 2.12 KB
/
Dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
import DataLoader, DataPreprocesser, Debugger
import DatasetInstance_OurAerial
import numpy as np
class Dataset(object):
    """
    Load a dataset, split it into train/val/test by K-fold and preprocess it.

    Depending on ``init_source`` the dataset is either loaded immediately from
    a stable dataset variant, or left unloaded so the caller can supply data
    and labels manually.
    """

    def __init__(self, settings, init_source=1):
        """
        Create the loader/debugger helpers and optionally load the dataset.

        :param settings: project settings object; it is passed to the data
            loader, the debugger and (when loading) the dataset instance and
            preprocesser.
        :param init_source: ``1`` (default) loads the stable dataset right
            away via ``init_from_stable_datasets``; any other value skips
            loading and leaves ``datasetInstance`` / ``dataPreprocesser``
            set to ``None``.
        """
        self.settings = settings
        self.dataLoader = DataLoader.DataLoader(settings)
        self.debugger = Debugger.Debugger(settings)

        if init_source == 1:
            self.init_from_stable_datasets()
        else:
            print("Init manually from data and labels")
            # Nothing is loaded yet; the caller is expected to provide data.
            self.datasetInstance = None
            self.dataPreprocesser = None

    def init_from_stable_datasets(self, dataset_variant="256_cleanManual"):
        """
        Load a stable dataset variant, split it by K-fold and preprocess it.

        :param dataset_variant: variant name handed to
            ``DatasetInstance_OurAerial``; defaults to ``"256_cleanManual"``,
            the value that was previously hard-coded here.

        Sets ``datasetInstance``, ``dataPreprocesser``, ``data``, ``paths``,
        ``train``/``val``/``test`` and ``train_paths``/``val_paths``/
        ``test_paths`` on the instance.
        """
        self.datasetInstance = DatasetInstance_OurAerial.DatasetInstance_OurAerial(
            self.settings, self.dataLoader, dataset_variant)
        number_of_channels = self.datasetInstance.CHANNEL_NUMBER
        self.dataPreprocesser = DataPreprocesser.DataPreprocesser(
            self.settings, number_of_channels)

        self.data, self.paths = self.datasetInstance.load_dataset()
        if self.settings.verbose >= 3:
            self.debugger.inspect_dataset(self.data, self.paths, 3)
        print("Dataset loaded with", len(self.data[0]), "images.")

        # Split into training, validation and test using the configured fold.
        K = self.settings.TestDataset_K_Folds
        test_fold = self.settings.TestDataset_Fold_Index
        print("K-Fold crossval: [", test_fold, "from", K, "]")
        self.train, self.val, self.test = self.datasetInstance.split_train_val_test_KFOLDCROSSVAL(
            self.data, test_fold=test_fold, K=K)

        # NOTE(review): data and paths are split by two separate calls; they
        # stay aligned only if the split is deterministic for the same
        # (test_fold, K) -- confirm in DatasetInstance_OurAerial.
        self.paths = np.asarray(self.paths)
        self.train_paths, self.val_paths, self.test_paths = self.datasetInstance.split_train_val_test_KFOLDCROSSVAL(
            self.paths, test_fold=test_fold, K=K)

        print("Has ", len(self.train[0]), "train, ", len(self.val[0]), "val, ", len(self.test[0]), "test, ")
        print("Has ", len(self.train_paths[0]), "train_paths, ", len(self.val_paths[0]), "val_paths, ", len(self.test_paths[0]), "test_paths, ")

        # Preprocess the data splits; the path splits are not passed to the
        # preprocesser and are kept unchanged.
        self.train, self.val, self.test = self.dataPreprocesser.process_dataset(
            self.train, self.val, self.test)