diff --git a/.devcontainer/devcontainer_all_packages.sh b/.devcontainer/devcontainer_all_packages.sh
index b1e2960..5b0f343 100755
--- a/.devcontainer/devcontainer_all_packages.sh
+++ b/.devcontainer/devcontainer_all_packages.sh
@@ -12,6 +12,7 @@ main() {
     ccache
     cm-super
     curl
+    dvipng
     gawk
     gnupg
     htop
@@ -27,6 +28,8 @@ main() {
     software-properties-common
     ssh
     sudo
+    texlive-fonts-recommended
+    texlive-latex-extra
    udev
     unzip
     usbutils
diff --git a/ci/pr_train_networks.bash b/ci/pr_train_networks.bash
index d93b07e..f9ca0d1 100755
--- a/ci/pr_train_networks.bash
+++ b/ci/pr_train_networks.bash
@@ -13,7 +13,7 @@ config_relative=$config_path/test_train_relative.json
 pip install -e .
 
 # RADAR TRAINING (depth prior + 2 output channels)
-python3 scripts/train.py \
+python3 scripts/train/train.py \
     --checkpoints $checkpoints \
     --config $config_radar \
     --datasets $datasets \
@@ -27,7 +27,7 @@ else
 fi
 
 # RGB TRAINING (no depth prior + 1 output channel)
-python3 scripts/train.py \
+python3 scripts/train/train.py \
     --checkpoints $checkpoints \
     --config $config_metric \
     --datasets $datasets \
@@ -41,7 +41,7 @@ else
 fi
 
 # Relative RGB TRAINING (no depth prior + 1 output channel)
-python3 scripts/train.py \
+python3 scripts/train/train.py \
     --checkpoints $checkpoints \
     --config $config_relative \
     --datasets $datasets \
diff --git a/radarmeetsvision/interface.py b/radarmeetsvision/interface.py
index 072321f..432e7ac 100644
--- a/radarmeetsvision/interface.py
+++ b/radarmeetsvision/interface.py
@@ -36,7 +36,8 @@ def __init__(self):
         self.use_depth_prior = None
 
     def reset_previous_best(self):
-        return {'d1': 0, 'd2': 0, 'd3': 0, 'abs_rel': 100, 'sq_rel': 100, 'rmse': 100, 'rmse_log': 100, 'log10': 100, 'silog': 100}
+        return {'d1': 0, 'd2': 0, 'd3': 0, 'abs_rel': 100, 'sq_rel': 100, 'rmse': 100, 'rmse_log': 100, 'log10': 100, 'silog': 100,
+                'average_depth': 0.0}
 
     def set_use_depth_prior(self, use):
         self.use_depth_prior = use
@@ -112,13 +113,18 @@ def set_optimizer(self, lr=0.000005):
 
         return self.optimizer
 
-    def get_dataset_loader(self, task, datasets_dir, dataset_list):
+    def get_dataset_loader(self, task, datasets_dir, dataset_list, index_list=None):
         datasets = []
         datasets_dir = Path(datasets_dir)
-        for dataset_name in dataset_list:
+        for i, dataset_name in enumerate(dataset_list):
             dataset_dir = datasets_dir / dataset_name
-            dataset = BlearnDataset(dataset_dir, task, self.size)
+            index_min, index_max = 0, -1
+            if index_list is not None:
+                index_min, index_max = index_list[i][0], index_list[i][1]
+
+            dataset = BlearnDataset(dataset_dir, task, self.size, index_min, index_max)
+
             if len(dataset) > 0:
                 datasets.append(dataset)
diff --git a/radarmeetsvision/metric_depth_network/dataset/blearndataset.py b/radarmeetsvision/metric_depth_network/dataset/blearndataset.py
index 49b09af..8ef1e19 100644
--- a/radarmeetsvision/metric_depth_network/dataset/blearndataset.py
+++ b/radarmeetsvision/metric_depth_network/dataset/blearndataset.py
@@ -23,7 +23,7 @@ logger = logging.getLogger(__name__)
 
 
 class BlearnDataset(Dataset):
-    def __init__(self, dataset_dir, mode, size):
+    def __init__(self, dataset_dir, mode, size, index_min=0, index_max=-1):
         self.mode = mode
         self.size = size
@@ -65,20 +65,25 @@ def __init__(self, dataset_dir, mode, size):
         self.width = size[1]
         self.train_split = 0.8
 
-        self.filelist = self.get_filelist()
+        self.filelist = self.get_filelist(index_min, index_max)
         if self.filelist:
             logger.info(f"Loaded {dataset_dir} with length {len(self.filelist)}")
 
-    def get_filelist(self):
-        # Sort and find all .jpg's
+
+    def get_filelist(self, index_min=0, index_max=-1):
         all_rgb_files = list(self.rgb_dir.glob('*.jpg'))
         all_rgb_files = sorted(all_rgb_files)
         all_indexes = []
+
         for rgb_file in all_rgb_files:
             out = re.search(self.rgb_mask, str(rgb_file))
             if out is not None:
                 all_indexes.append(int(out.group(1)))
 
+        if index_min != 0 or index_max != -1:
+            logger.info(f'Limiting dataset index: {index_min} - {index_max}')
+            all_indexes = [idx for idx in all_indexes if index_min <= idx <= index_max]
+
         train_split_len = round(len(all_indexes) * self.train_split)
         val_split_len = len(all_indexes) - train_split_len
 
@@ -94,12 +99,13 @@ def get_filelist(self):
             else:
                 filelist = None
 
-        # Default case is to use full split
+        # Log the number of files selected
         if filelist is not None:
             logger.info(f"Using {len(filelist)}/{len(all_indexes)} for task {self.mode}")
 
         return filelist
 
+
     def __getitem__(self, item):
         index = int(self.filelist[item])
         img_path = self.rgb_dir / self.rgb_template.format(index)
diff --git a/radarmeetsvision/metric_depth_network/util/utils.py b/radarmeetsvision/metric_depth_network/util/utils.py
index 59d60f1..73bab70 100644
--- a/radarmeetsvision/metric_depth_network/util/utils.py
+++ b/radarmeetsvision/metric_depth_network/util/utils.py
@@ -36,7 +36,8 @@ def get_empty_results(device):
         'rmse': 0.0,
         'rmse_log': 0.0,
         'log10': 0.0,
-        'silog': 0.0
+        'silog': 0.0,
+        'average_depth': 0.0
     }
     results_per_sample = {
         'd1': [],
@@ -47,7 +48,8 @@ def get_empty_results(device):
         'rmse': [],
         'rmse_log': [],
         'log10': [],
-        'silog': []
+        'silog': [],
+        'average_depth': []
     }
     nsamples = torch.tensor([0.0]).to(device)
     return results, results_per_sample, nsamples
diff --git a/scripts/evaluation/config.json b/scripts/evaluation/config.json
index cc25b37..8a989d6 100644
--- a/scripts/evaluation/config.json
+++ b/scripts/evaluation/config.json
@@ -4,6 +4,11 @@
         "Agricultural Field": "outdoor0",
         "Rhône Glacier": "rhone_flight"
     },
+    "index": {
+        "maschinenhalle0": [0, -1],
+        "outdoor0": [0, -1],
+        "rhone_flight": [1205, 1505]
+    },
     "networks": {
         "Metric Depth \\cite{depthanythingv2}-S": "rgb_s_bs8_e9.pth",
         "Metric Depth \\cite{depthanythingv2}-B": "rgb_b_bs4_e8.pth",
diff --git a/scripts/evaluation/evaluate.py b/scripts/evaluation/evaluate.py
index 1049fca..0a08c57 100644
--- a/scripts/evaluation/evaluate.py
+++ b/scripts/evaluation/evaluate.py
@@ -21,7 +21,10 @@ def __init__(self, config, scenario_key, network_key, args):
         self.results_per_sample = {}
         self.results_dict = {}
         self.interface = rmv.Interface()
-        self.networks_dir = Path(args.network)
+        self.networks_dir = None
+        if args.network is not None:
+            self.networks_dir = Path(args.network)
+
         self.datasets_dir = args.dataset
         self.results, self.results_per_sample = None, None
         self.setup_interface(config, scenario_key, network_key)
@@ -43,9 +46,10 @@ def setup_interface(self, config, scenario_key, network_key):
         self.interface.set_output_channels(network_config['output_channels'])
         self.interface.set_use_depth_prior(network_config['use_depth_prior'])
 
-        network_file = config['networks'][network_key]
-        if network_file is not None:
-            network_file = self.networks_dir / network_file
+        network_file = None
+        if self.networks_dir is not None and config['networks'][network_key] is not None:
+            network_file = self.networks_dir / config['networks'][network_key]
+
         self.interface.load_model(pretrained_from=network_file)
         self.interface.set_size(config['height'], config['width'])
@@ -53,7 +57,8 @@ def setup_interface(self, config, scenario_key, network_key):
         self.interface.set_criterion()
 
         dataset_list = [config['scenarios'][scenario_key]]
-        self.loader, _ = self.interface.get_dataset_loader('val_all', self.datasets_dir, dataset_list)
+        index_list = [config['index'][config['scenarios'][scenario_key]]]
+        self.loader, _ = self.interface.get_dataset_loader('val_all', self.datasets_dir, dataset_list, index_list)
 
     def get_results_per_sample(self):
         return self.results_per_sample
diff --git a/scripts/configs/train_metric_b.json b/scripts/train/configs/train_metric_b.json
similarity index 84%
rename from scripts/configs/train_metric_b.json
rename to scripts/train/configs/train_metric_b.json
index 1d1aaa5..e9a1d32 100644
--- a/scripts/configs/train_metric_b.json
+++ b/scripts/train/configs/train_metric_b.json
@@ -11,11 +11,12 @@
     "task": {
         "train_all": {
             "dir": "training",
             "datasets": ["rhone2", "mountain_area", "rural_area", "road_corridor", "HyperSim"]
         },
         "val_all": {
             "dir": "validation",
-            "datasets": ["outdoor0", "maschinenhalle0", "rhone_flight"]
+            "datasets": ["outdoor0", "maschinenhalle0", "rhone_flight"],
+            "indices": [[0, -1], [0, -1], [1205, 1505]]
         }
     }
-}
\ No newline at end of file
+}
diff --git a/scripts/configs/train_metric_s.json b/scripts/train/configs/train_metric_s.json
similarity index 89%
rename from scripts/configs/train_metric_s.json
rename to scripts/train/configs/train_metric_s.json
index 1adc3bf..06ef212 100644
--- a/scripts/configs/train_metric_s.json
+++ b/scripts/train/configs/train_metric_s.json
@@ -15,7 +15,8 @@
         },
         "val_all": {
             "dir": "validation",
-            "datasets": ["outdoor0", "maschinenhalle0", "rhone_flight"]
+            "datasets": ["outdoor0", "maschinenhalle0", "rhone_flight"],
+            "indices": [[0, -1], [0, -1], [1205, 1505]]
         }
     }
-}
\ No newline at end of file
+}
diff --git a/scripts/configs/train_radar_b.json b/scripts/train/configs/train_radar_b.json
similarity index 89%
rename from scripts/configs/train_radar_b.json
rename to scripts/train/configs/train_radar_b.json
index fbbfd48..3f5c2a6 100644
--- a/scripts/configs/train_radar_b.json
+++ b/scripts/train/configs/train_radar_b.json
@@ -15,7 +15,8 @@
         },
         "val_all": {
             "dir": "validation",
-            "datasets": ["outdoor0", "maschinenhalle0", "rhone_flight"]
+            "datasets": ["outdoor0", "maschinenhalle0", "rhone_flight"],
+            "indices": [[0, -1], [0, -1], [1205, 1505]]
         }
     }
-}
\ No newline at end of file
+}
diff --git a/scripts/configs/train_radar_s.json b/scripts/train/configs/train_radar_s.json
similarity index 89%
rename from scripts/configs/train_radar_s.json
rename to scripts/train/configs/train_radar_s.json
index da0a064..fee8e0b 100644
--- a/scripts/configs/train_radar_s.json
+++ b/scripts/train/configs/train_radar_s.json
@@ -15,7 +15,8 @@
         },
         "val_all": {
             "dir": "validation",
-            "datasets": ["outdoor0", "maschinenhalle0", "rhone_flight"]
+            "datasets": ["outdoor0", "maschinenhalle0", "rhone_flight"],
+            "indices": [[0, -1], [0, -1], [1205, 1505]]
         }
     }
-}
\ No newline at end of file
+}
diff --git a/scripts/configs/train_relative_b.json b/scripts/train/configs/train_relative_b.json
similarity index 89%
rename from scripts/configs/train_relative_b.json
rename to scripts/train/configs/train_relative_b.json
index 078a2ca..0b66c76 100644
--- a/scripts/configs/train_relative_b.json
+++ b/scripts/train/configs/train_relative_b.json
@@ -15,7 +15,8 @@
         },
         "val_all": {
             "dir": "validation",
-            "datasets": ["outdoor0", "maschinenhalle0", "rhone_flight"]
+            "datasets": ["outdoor0", "maschinenhalle0", "rhone_flight"],
+            "indices": [[0, -1], [0, -1], [1205, 1505]]
         }
     }
-}
\ No newline at end of file
+}
diff --git a/scripts/configs/train_relative_s.json b/scripts/train/configs/train_relative_s.json
similarity index 89%
rename from scripts/configs/train_relative_s.json
rename to scripts/train/configs/train_relative_s.json
index f36585b..fa2b9e8 100644
--- a/scripts/configs/train_relative_s.json
+++ b/scripts/train/configs/train_relative_s.json
@@ -15,7 +15,8 @@
         },
         "val_all": {
             "dir": "validation",
-            "datasets": ["outdoor0", "maschinenhalle0", "rhone_flight"]
+            "datasets": ["outdoor0", "maschinenhalle0", "rhone_flight"],
+            "indices": [[0, -1], [0, -1], [1205, 1505]]
         }
     }
-}
\ No newline at end of file
+}
diff --git a/scripts/train.py b/scripts/train/train.py
similarity index 95%
rename from scripts/train.py
rename to scripts/train/train.py
index 843038f..e5cec42 100644
--- a/scripts/train.py
+++ b/scripts/train/train.py
@@ -29,8 +29,9 @@ def main(config, checkpoints, datasets, results):
     loaders = {}
     for task in config['task'].keys():
         dataset_list = config['task'][task]['datasets']
+        index_list = config['task'][task].get('indices', None)
         datasets_dir = Path(datasets) / config['task'][task]['dir']
-        loader, _ = interface.get_dataset_loader(task, str(datasets_dir), dataset_list)
+        loader, _ = interface.get_dataset_loader(task, str(datasets_dir), dataset_list, index_list)
         loaders[task] = loader
 
     for epoch in range(config['epochs']):
diff --git a/tests/resources/test_evaluation.json b/tests/resources/test_evaluation.json
index 0251721..86f983f 100644
--- a/tests/resources/test_evaluation.json
+++ b/tests/resources/test_evaluation.json
@@ -3,6 +3,10 @@
         "Training": "tiny_dataset",
         "Industrial Hall": "tiny_dataset_validation"
     },
+    "index": {
+        "tiny_dataset": [0, -1],
+        "tiny_dataset_validation": [0, 3]
+    },
     "networks": {
         "RGB": null,
         "Naive": null,