Skip to content

Commit

Permalink
blearndataset: Add dataset range
Browse files Browse the repository at this point in the history
  • Loading branch information
marcojob committed Oct 3, 2024
1 parent 0eedcf4 commit 336fd35
Show file tree
Hide file tree
Showing 11 changed files with 45 additions and 21 deletions.
11 changes: 8 additions & 3 deletions radarmeetsvision/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,13 +112,18 @@ def set_optimizer(self, lr=0.000005):
return self.optimizer


def get_dataset_loader(self, task, datasets_dir, dataset_list):
def get_dataset_loader(self, task, datasets_dir, dataset_list, index_list=None):
datasets = []
datasets_dir = Path(datasets_dir)
for dataset_name in dataset_list:
for i, dataset_name in enumerate(dataset_list):
dataset_dir = datasets_dir / dataset_name

dataset = BlearnDataset(dataset_dir, task, self.size)
index_min, index_max = 0, -1
if index_list != None:
index_min, index_max = index_list[i][0], index_list[i][1]

dataset = BlearnDataset(dataset_dir, task, self.size, index_min, index_max)

if len(dataset) > 0:
datasets.append(dataset)

Expand Down
16 changes: 11 additions & 5 deletions radarmeetsvision/metric_depth_network/dataset/blearndataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
logger = logging.getLogger(__name__)

class BlearnDataset(Dataset):
def __init__(self, dataset_dir, mode, size):
def __init__(self, dataset_dir, mode, size, index_min=0, index_max=-1):
self.mode = mode
self.size = size

Expand Down Expand Up @@ -65,20 +65,25 @@ def __init__(self, dataset_dir, mode, size):
self.width = size[1]
self.train_split = 0.8

self.filelist = self.get_filelist()
self.filelist = self.get_filelist(index_min, index_max)
if self.filelist:
logger.info(f"Loaded {dataset_dir} with length {len(self.filelist)}")

def get_filelist(self):
# Sort and find all .jpg's

def get_filelist(self, index_min=0, index_max=-1):
all_rgb_files = list(self.rgb_dir.glob('*.jpg'))
all_rgb_files = sorted(all_rgb_files)
all_indexes = []

for rgb_file in all_rgb_files:
out = re.search(self.rgb_mask, str(rgb_file))
if out is not None:
all_indexes.append(int(out.group(1)))

if index_min != 0 or index_max != -1:
logger.info(f'Limiting dataset index: {index_min} - {index_max}')
all_indexes = [idx for idx in all_indexes if index_min <= idx <= index_max]

train_split_len = round(len(all_indexes) * self.train_split)
val_split_len = len(all_indexes) - train_split_len

Expand All @@ -94,12 +99,13 @@ def get_filelist(self):
else:
filelist = None

# Default case is to use full split
# Log the number of files selected
if filelist is not None:
logger.info(f"Using {len(filelist)}/{len(all_indexes)} for task {self.mode}")

return filelist


def __getitem__(self, item):
index = int(self.filelist[item])
img_path = self.rgb_dir / self.rgb_template.format(index)
Expand Down
1 change: 1 addition & 0 deletions scripts/evaluation/config.json
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
"Agricultural Field": "outdoor0",
"Rhône Glacier": "rhone_flight"
},
"index": [[0, -1], [0, -1], [1205, 1505]],
"networks": {
"Metric Depth \\cite{depthanythingv2}-S": "rgb_s_bs8_e9.pth",
"Metric Depth \\cite{depthanythingv2}-B": "rgb_b_bs4_e8.pth",
Expand Down
15 changes: 10 additions & 5 deletions scripts/evaluation/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,10 @@ def __init__(self, config, scenario_key, network_key, args):
self.results_per_sample = {}
self.results_dict = {}
self.interface = rmv.Interface()
self.networks_dir = Path(args.network)
self.networks_dir = None
if args.network is not None:
self.networks_dir = Path(args.network)

self.datasets_dir = args.dataset
self.results, self.results_per_sample = None, None
self.setup_interface(config, scenario_key, network_key)
Expand All @@ -43,17 +46,19 @@ def setup_interface(self, config, scenario_key, network_key):
self.interface.set_output_channels(network_config['output_channels'])
self.interface.set_use_depth_prior(network_config['use_depth_prior'])

network_file = config['networks'][network_key]
if network_file is not None:
network_file = self.networks_dir / network_file
network_file = None
if self.networks_dir is not None and config['networks'][network_key] is not None:
network_file = self.networks_dir / config['networks'][network_key]

self.interface.load_model(pretrained_from=network_file)

self.interface.set_size(config['height'], config['width'])
self.interface.set_batch_size(1)
self.interface.set_criterion()

dataset_list = [config['scenarios'][scenario_key]]
self.loader, _ = self.interface.get_dataset_loader('val_all', self.datasets_dir, dataset_list)
index_list = config.get('index', None)
self.loader, _ = self.interface.get_dataset_loader('val_all', self.datasets_dir, dataset_list, index_list)

def get_results_per_sample(self):
return self.results_per_sample
Expand Down
5 changes: 3 additions & 2 deletions scripts/train/configs/train_metric_b.json
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,12 @@
"task": {
"train_all": {
"dir": "training",
"datasets": ["rhone2", "mountain_area", "rural_area", "road_corridor", "HyperSim"]
"datasets": ["rhone2", "mountain_area", "rural_area", "road_corridor", "HyperSim"],
},
"val_all": {
"dir": "validation",
"datasets": ["outdoor0", "maschinenhalle0", "rhone_flight"]
"datasets": ["outdoor0", "maschinenhalle0", "rhone_flight"],
"indeces": [[0, -1], [0, -1], [1205, 1505]]
}
}
}
3 changes: 2 additions & 1 deletion scripts/train/configs/train_metric_s.json
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@
},
"val_all": {
"dir": "validation",
"datasets": ["outdoor0", "maschinenhalle0", "rhone_flight"]
"datasets": ["outdoor0", "maschinenhalle0", "rhone_flight"],
"indeces": [[0, -1], [0, -1], [1205, 1505]]
}
}
}
3 changes: 2 additions & 1 deletion scripts/train/configs/train_radar_b.json
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@
},
"val_all": {
"dir": "validation",
"datasets": ["outdoor0", "maschinenhalle0", "rhone_flight"]
"datasets": ["outdoor0", "maschinenhalle0", "rhone_flight"],
"indeces": [[0, -1], [0, -1], [1205, 1505]]
}
}
}
3 changes: 2 additions & 1 deletion scripts/train/configs/train_radar_s.json
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@
},
"val_all": {
"dir": "validation",
"datasets": ["outdoor0", "maschinenhalle0", "rhone_flight"]
"datasets": ["outdoor0", "maschinenhalle0", "rhone_flight"],
"indeces": [[0, -1], [0, -1], [1205, 1505]]
}
}
}
3 changes: 2 additions & 1 deletion scripts/train/configs/train_relative_b.json
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@
},
"val_all": {
"dir": "validation",
"datasets": ["outdoor0", "maschinenhalle0", "rhone_flight"]
"datasets": ["outdoor0", "maschinenhalle0", "rhone_flight"],
"indeces": [[0, -1], [0, -1], [1205, 1505]]
}
}
}
3 changes: 2 additions & 1 deletion scripts/train/configs/train_relative_s.json
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@
},
"val_all": {
"dir": "validation",
"datasets": ["outdoor0", "maschinenhalle0", "rhone_flight"]
"datasets": ["outdoor0", "maschinenhalle0", "rhone_flight"],
"indeces": [[0, -1], [0, -1], [1205, 1505]]
}
}
}
3 changes: 2 additions & 1 deletion scripts/train/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,9 @@ def main(config, checkpoints, datasets, results):
loaders = {}
for task in config['task'].keys():
dataset_list = config['task'][task]['datasets']
index_list = config['task'][task].get('indeces', None)
datasets_dir = Path(datasets) / config['task'][task]['dir']
loader, _ = interface.get_dataset_loader(task, str(datasets_dir), dataset_list)
loader, _ = interface.get_dataset_loader(task, str(datasets_dir), dataset_list, index_list)
loaders[task] = loader

for epoch in range(config['epochs']):
Expand Down

0 comments on commit 336fd35

Please sign in to comment.