Skip to content

Commit

Permalink
Evaluation: Finish up (#5)
Browse files Browse the repository at this point in the history
* train: Create dir

* blearndataset: Add dataset range

* eval: Fix index limits for eval

* eval: Fix latex table generation
  • Loading branch information
marcojob authored Oct 8, 2024
1 parent f35ffdd commit 8f478a9
Show file tree
Hide file tree
Showing 15 changed files with 71 additions and 33 deletions.
3 changes: 3 additions & 0 deletions .devcontainer/devcontainer_all_packages.sh
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ main() {
ccache
cm-super
curl
dvipng
gawk
gnupg
htop
Expand All @@ -27,6 +28,8 @@ main() {
software-properties-common
ssh
sudo
texlive-fonts-recommended
texlive-latex-extra
udev
unzip
usbutils
Expand Down
6 changes: 3 additions & 3 deletions ci/pr_train_networks.bash
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ config_relative=$config_path/test_train_relative.json
pip install -e .

# RADAR TRAINING (depth prior + 2 output channels)
python3 scripts/train.py \
python3 scripts/train/train.py \
--checkpoints $checkpoints \
--config $config_radar \
--datasets $datasets \
Expand All @@ -27,7 +27,7 @@ else
fi

# RGB TRAINING (no depth prior + 1 output channel)
python3 scripts/train.py \
python3 scripts/train/train.py \
--checkpoints $checkpoints \
--config $config_metric \
--datasets $datasets \
Expand All @@ -41,7 +41,7 @@ else
fi

# Relative RGB TRAINING (no depth prior + 1 output channel)
python3 scripts/train.py \
python3 scripts/train/train.py \
--checkpoints $checkpoints \
--config $config_relative \
--datasets $datasets \
Expand Down
14 changes: 10 additions & 4 deletions radarmeetsvision/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,8 @@ def __init__(self):
self.use_depth_prior = None

def reset_previous_best(self):
return {'d1': 0, 'd2': 0, 'd3': 0, 'abs_rel': 100, 'sq_rel': 100, 'rmse': 100, 'rmse_log': 100, 'log10': 100, 'silog': 100}
return {'d1': 0, 'd2': 0, 'd3': 0, 'abs_rel': 100, 'sq_rel': 100, 'rmse': 100, 'rmse_log': 100, 'log10': 100, 'silog': 100,
'average_depth': 0.0}

def set_use_depth_prior(self, use):
self.use_depth_prior = use
Expand Down Expand Up @@ -112,13 +113,18 @@ def set_optimizer(self, lr=0.000005):
return self.optimizer


def get_dataset_loader(self, task, datasets_dir, dataset_list):
def get_dataset_loader(self, task, datasets_dir, dataset_list, index_list=None):
datasets = []
datasets_dir = Path(datasets_dir)
for dataset_name in dataset_list:
for i, dataset_name in enumerate(dataset_list):
dataset_dir = datasets_dir / dataset_name

dataset = BlearnDataset(dataset_dir, task, self.size)
index_min, index_max = 0, -1
if index_list != None:
index_min, index_max = index_list[i][0], index_list[i][1]

dataset = BlearnDataset(dataset_dir, task, self.size, index_min, index_max)

if len(dataset) > 0:
datasets.append(dataset)

Expand Down
16 changes: 11 additions & 5 deletions radarmeetsvision/metric_depth_network/dataset/blearndataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
logger = logging.getLogger(__name__)

class BlearnDataset(Dataset):
def __init__(self, dataset_dir, mode, size):
def __init__(self, dataset_dir, mode, size, index_min=0, index_max=-1):
self.mode = mode
self.size = size

Expand Down Expand Up @@ -65,20 +65,25 @@ def __init__(self, dataset_dir, mode, size):
self.width = size[1]
self.train_split = 0.8

self.filelist = self.get_filelist()
self.filelist = self.get_filelist(index_min, index_max)
if self.filelist:
logger.info(f"Loaded {dataset_dir} with length {len(self.filelist)}")

def get_filelist(self):
# Sort and find all .jpg's

def get_filelist(self, index_min=0, index_max=-1):
all_rgb_files = list(self.rgb_dir.glob('*.jpg'))
all_rgb_files = sorted(all_rgb_files)
all_indexes = []

for rgb_file in all_rgb_files:
out = re.search(self.rgb_mask, str(rgb_file))
if out is not None:
all_indexes.append(int(out.group(1)))

if index_min != 0 or index_max != -1:
logger.info(f'Limiting dataset index: {index_min} - {index_max}')
all_indexes = [idx for idx in all_indexes if index_min <= idx <= index_max]

train_split_len = round(len(all_indexes) * self.train_split)
val_split_len = len(all_indexes) - train_split_len

Expand All @@ -94,12 +99,13 @@ def get_filelist(self):
else:
filelist = None

# Default case is to use full split
# Log the number of files selected
if filelist is not None:
logger.info(f"Using {len(filelist)}/{len(all_indexes)} for task {self.mode}")

return filelist


def __getitem__(self, item):
index = int(self.filelist[item])
img_path = self.rgb_dir / self.rgb_template.format(index)
Expand Down
6 changes: 4 additions & 2 deletions radarmeetsvision/metric_depth_network/util/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,8 @@ def get_empty_results(device):
'rmse': 0.0,
'rmse_log': 0.0,
'log10': 0.0,
'silog': 0.0
'silog': 0.0,
'average_depth': 0.0
}
results_per_sample = {
'd1': [],
Expand All @@ -47,7 +48,8 @@ def get_empty_results(device):
'rmse': [],
'rmse_log': [],
'log10': [],
'silog': []
'silog': [],
'average_depth': []
}
nsamples = torch.tensor([0.0]).to(device)
return results, results_per_sample, nsamples
Expand Down
5 changes: 5 additions & 0 deletions scripts/evaluation/config.json
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,11 @@
"Agricultural Field": "outdoor0",
"Rhône Glacier": "rhone_flight"
},
"index": {
"maschinenhalle0": [0, -1],
"outdoor0": [0, -1],
"rhone_flight": [1205, 1505]
},
"networks": {
"Metric Depth \\cite{depthanythingv2}-S": "rgb_s_bs8_e9.pth",
"Metric Depth \\cite{depthanythingv2}-B": "rgb_b_bs4_e8.pth",
Expand Down
15 changes: 10 additions & 5 deletions scripts/evaluation/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,10 @@ def __init__(self, config, scenario_key, network_key, args):
self.results_per_sample = {}
self.results_dict = {}
self.interface = rmv.Interface()
self.networks_dir = Path(args.network)
self.networks_dir = None
if args.network is not None:
self.networks_dir = Path(args.network)

self.datasets_dir = args.dataset
self.results, self.results_per_sample = None, None
self.setup_interface(config, scenario_key, network_key)
Expand All @@ -43,17 +46,19 @@ def setup_interface(self, config, scenario_key, network_key):
self.interface.set_output_channels(network_config['output_channels'])
self.interface.set_use_depth_prior(network_config['use_depth_prior'])

network_file = config['networks'][network_key]
if network_file is not None:
network_file = self.networks_dir / network_file
network_file = None
if self.networks_dir is not None and config['networks'][network_key] is not None:
network_file = self.networks_dir / config['networks'][network_key]

self.interface.load_model(pretrained_from=network_file)

self.interface.set_size(config['height'], config['width'])
self.interface.set_batch_size(1)
self.interface.set_criterion()

dataset_list = [config['scenarios'][scenario_key]]
self.loader, _ = self.interface.get_dataset_loader('val_all', self.datasets_dir, dataset_list)
index_list = [config['index'][config["scenarios"][scenario_key]]]
self.loader, _ = self.interface.get_dataset_loader('val_all', self.datasets_dir, dataset_list, index_list)

def get_results_per_sample(self):
return self.results_per_sample
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,12 @@
"task": {
"train_all": {
"dir": "training",
"datasets": ["rhone2", "mountain_area", "rural_area", "road_corridor", "HyperSim"]
"datasets": ["rhone2", "mountain_area", "rural_area", "road_corridor", "HyperSim"],
},
"val_all": {
"dir": "validation",
"datasets": ["outdoor0", "maschinenhalle0", "rhone_flight"]
"datasets": ["outdoor0", "maschinenhalle0", "rhone_flight"],
"indeces": [[0, -1], [0, -1], [1205, 1505]]
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@
},
"val_all": {
"dir": "validation",
"datasets": ["outdoor0", "maschinenhalle0", "rhone_flight"]
"datasets": ["outdoor0", "maschinenhalle0", "rhone_flight"],
"indeces": [[0, -1], [0, -1], [1205, 1505]]
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@
},
"val_all": {
"dir": "validation",
"datasets": ["outdoor0", "maschinenhalle0", "rhone_flight"]
"datasets": ["outdoor0", "maschinenhalle0", "rhone_flight"],
"indeces": [[0, -1], [0, -1], [1205, 1505]]
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@
},
"val_all": {
"dir": "validation",
"datasets": ["outdoor0", "maschinenhalle0", "rhone_flight"]
"datasets": ["outdoor0", "maschinenhalle0", "rhone_flight"],
"indeces": [[0, -1], [0, -1], [1205, 1505]]
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@
},
"val_all": {
"dir": "validation",
"datasets": ["outdoor0", "maschinenhalle0", "rhone_flight"]
"datasets": ["outdoor0", "maschinenhalle0", "rhone_flight"],
"indeces": [[0, -1], [0, -1], [1205, 1505]]
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@
},
"val_all": {
"dir": "validation",
"datasets": ["outdoor0", "maschinenhalle0", "rhone_flight"]
"datasets": ["outdoor0", "maschinenhalle0", "rhone_flight"],
"indeces": [[0, -1], [0, -1], [1205, 1505]]
}
}
}
}
3 changes: 2 additions & 1 deletion scripts/train.py → scripts/train/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,9 @@ def main(config, checkpoints, datasets, results):
loaders = {}
for task in config['task'].keys():
dataset_list = config['task'][task]['datasets']
index_list = config['task'][task].get('indeces', None)
datasets_dir = Path(datasets) / config['task'][task]['dir']
loader, _ = interface.get_dataset_loader(task, str(datasets_dir), dataset_list)
loader, _ = interface.get_dataset_loader(task, str(datasets_dir), dataset_list, index_list)
loaders[task] = loader

for epoch in range(config['epochs']):
Expand Down
4 changes: 4 additions & 0 deletions tests/resources/test_evaluation.json
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,10 @@
"Training": "tiny_dataset",
"Industrial Hall": "tiny_dataset_validation"
},
"index": {
"tiny_dataset": [0, -1],
"tiny_dataset_validation": [0, 3]
},
"networks": {
"RGB": null,
"Naive": null,
Expand Down

0 comments on commit 8f478a9

Please sign in to comment.