Skip to content

Commit

Permalink
interface: Add methods to get results
Browse files Browse the repository at this point in the history
  • Loading branch information
marcojob committed Sep 27, 2024
1 parent f73eb58 commit e84d83b
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 21 deletions.
27 changes: 16 additions & 11 deletions radarmeetsvision/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ def __init__(self):
self.optimizer = None
self.output_channels = None
self.previous_best = self.reset_previous_best()
self.results = None
self.results_path = None
self.use_depth_prior = None

Expand Down Expand Up @@ -71,6 +72,9 @@ def set_results_path(self, results_path):
else:
logger.error(f'{self.results_path} does not exist')

def get_results(self):
return self.results, self.results_per_sample

def load_model(self, pretrained_from=None):
if self.encoder is not None and self.max_depth is not None and self.output_channels is not None and self.use_depth_prior is not None:
logger.info(f'Using pretrained from: {pretrained_from}')
Expand Down Expand Up @@ -193,7 +197,7 @@ def train_epoch(self, epoch, train_loader):
def validate_epoch(self, epoch, val_loader):
self.model.eval()

results, nsamples = get_empty_results(self.device)
self.results, self.results_per_sample, nsamples = get_empty_results(self.device)
for i, sample in enumerate(val_loader):
image, _, depth_target, mask = self.prepare_sample(sample, random_flip=True)

Expand All @@ -206,21 +210,22 @@ def validate_epoch(self, epoch, val_loader):

current_results = eval_depth(depth_prediction[mask], depth_target[mask])
if current_results is not None:
for k in results.keys():
results[k] += current_results[k]
for k in self.results.keys():
self.results[k] += current_results[k]
self.results_per_sample[k].append(current_results[k])
nsamples += 1

self.update_best_result(results, nsamples)
self.update_best_result(self.results, nsamples)
self.save_checkpoint(epoch)


def save_checkpoint(self, epoch):
checkpoint = {
'model': self.model.state_dict(),
'optimizer': self.optimizer.state_dict(),
'epoch': epoch,
'previous_best': self.previous_best,
}
# TODO: How to check properly if current path is not .
if self.results_path is not None and len(str(self.results_path)) > 1:
checkpoint = {
'model': self.model.state_dict(),
'optimizer': self.optimizer.state_dict(),
'epoch': epoch,
'previous_best': self.previous_best,
}
# TODO: How to check properly if current path is not .
torch.save(checkpoint, self.results_path / f'latest_{epoch}.pth')
31 changes: 21 additions & 10 deletions radarmeetsvision/metric_depth_network/util/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,18 +28,29 @@ def print_epoch_summary(epoch, epochs, result_dict):

def get_empty_results(device):
results = {
'd1': torch.tensor([0.0]).to(device),
'd2': torch.tensor([0.0]).to(device),
'd3': torch.tensor([0.0]).to(device),
'abs_rel': torch.tensor([0.0]).to(device),
'sq_rel': torch.tensor([0.0]).to(device),
'rmse': torch.tensor([0.0]).to(device),
'rmse_log': torch.tensor([0.0]).to(device),
'log10': torch.tensor([0.0]).to(device),
'silog': torch.tensor([0.0]).to(device)
'd1': 0.0,
'd2': 0.0,
'd3': 0.0,
'abs_rel': 0.0,
'sq_rel': 0.0,
'rmse': 0.0,
'rmse_log': 0.0,
'log10': 0.0,
'silog': 0.0
}
results_per_sample = {
'd1': [],
'd2': [],
'd3': [],
'abs_rel': [],
'sq_rel': [],
'rmse': [],
'rmse_log': [],
'log10': [],
'silog': []
}
nsamples = torch.tensor([0.0]).to(device)
return results, nsamples
return results, results_per_sample, nsamples

def randomly_flip(img, target, valid_mask):
if random.random() < 0.5:
Expand Down

0 comments on commit e84d83b

Please sign in to comment.