diff --git a/configs/default_config.py b/configs/default_config.py
index fa2c09d3..236a95ab 100644
--- a/configs/default_config.py
+++ b/configs/default_config.py
@@ -31,8 +31,11 @@
 ########################################################################################################################
 cfg.save = CN()
 cfg.save.folder = ''                # Folder where data will be saved
-cfg.save.viz = True                 # Flag for saving inverse depth map visualization
-cfg.save.npz = True                 # Flag for saving numpy depth maps
+cfg.save.depth = CN()
+cfg.save.depth.rgb = True           # Flag for saving rgb images
+cfg.save.depth.viz = True           # Flag for saving inverse depth map visualization
+cfg.save.depth.npz = True           # Flag for saving numpy depth maps
+cfg.save.depth.png = True           # Flag for saving png depth maps
 ########################################################################################################################
 ### WANDB
 ########################################################################################################################
diff --git a/configs/eval_ddad.yaml b/configs/eval_ddad.yaml
index bc1e153b..ab822205 100644
--- a/configs/eval_ddad.yaml
+++ b/configs/eval_ddad.yaml
@@ -21,5 +21,8 @@ datasets:
         cameras: [['camera_01']]
 save:
     folder: '/data/save'
-    viz: True
-    npz: True
\ No newline at end of file
+    depth:
+        rgb: True
+        viz: True
+        npz: True
+        png: True
diff --git a/configs/eval_image.yaml b/configs/eval_image.yaml
index 0184e3f4..926c05b9 100644
--- a/configs/eval_image.yaml
+++ b/configs/eval_image.yaml
@@ -15,5 +15,8 @@ datasets:
         split: ['{:010d}']
 save:
     folder: '/data/save'
-    viz: True
-    npy: True
+    depth:
+        rgb: True
+        viz: True
+        npz: True
+        png: True
diff --git a/configs/eval_kitti.yaml b/configs/eval_kitti.yaml
index 0cde0d7a..bd3bb274 100644
--- a/configs/eval_kitti.yaml
+++ b/configs/eval_kitti.yaml
@@ -20,5 +20,8 @@ datasets:
         depth_type: ['velodyne']
 save:
     folder: '/data/save'
-    viz: True
-    npz: True
\ No newline at end of file
+    depth:
+        rgb: True
+        viz: True
+        npz: True
+        png: True
diff --git a/packnet_sfm/__init__.py b/packnet_sfm/__init__.py
index 6a89a8ad..81cde5cf 100644
--- a/packnet_sfm/__init__.py
+++ b/packnet_sfm/__init__.py
@@ -19,8 +19,3 @@
 Furthermore, it gets better with input resolution and number of parameters, generalizes better,
 and can run in real-time (with TensorRT). See [References](#references) for more info on our models.
""" - -from packnet_sfm.models import ModelWrapper, ModelCheckpoint -from packnet_sfm.trainers import HorovodTrainer - -__all__ = ["ModelWrapper", "HorovodTrainer", "ModelCheckpoint"] \ No newline at end of file diff --git a/packnet_sfm/datasets/__init__.py b/packnet_sfm/datasets/__init__.py index e528d685..a22702ce 100644 --- a/packnet_sfm/datasets/__init__.py +++ b/packnet_sfm/datasets/__init__.py @@ -9,13 +9,3 @@ - ImageDataset: reads from a folder containing image sequences (no support for depth maps) """ - -from packnet_sfm.datasets.kitti_dataset import KITTIDataset -from packnet_sfm.datasets.dgp_dataset import DGPDataset -from packnet_sfm.datasets.image_dataset import ImageDataset - -__all__ = [ - "KITTIDataset", - "DGPDataset", - "ImageDataset", -] diff --git a/packnet_sfm/datasets/kitti_dataset.py b/packnet_sfm/datasets/kitti_dataset.py index ad854fc8..8403b5f3 100644 --- a/packnet_sfm/datasets/kitti_dataset.py +++ b/packnet_sfm/datasets/kitti_dataset.py @@ -13,10 +13,12 @@ ######################################################################################################################## +# Cameras from the stero pair (left is the origin) IMAGE_FOLDER = { 'left': 'image_02', 'right': 'image_03', } +# Name of different calibration files CALIB_FILE = { 'cam2cam': 'calib_cam_to_cam.txt', 'velo2cam': 'calib_velo_to_cam.txt', @@ -144,9 +146,12 @@ def _get_parent_folder(image_file): return os.path.abspath(os.path.join(image_file, "../../../..")) @staticmethod - def _get_intrinsics(calib_data): + def _get_intrinsics(image_file, calib_data): """Get intrinsics from the calib_data dictionary.""" - return np.reshape(calib_data['P_rect_02'], (3, 4))[:, :3] + for cam in ['left', 'right']: + # Check for both cameras, if found replace and return intrinsics + if IMAGE_FOLDER[cam] in image_file: + return np.reshape(calib_data[IMAGE_FOLDER[cam].replace('image', 'P_rect')], (3, 4))[:, :3] @staticmethod def _read_raw_calib_file(folder): @@ -358,7 +363,7 @@ def __getitem__(self, idx): c_data = self._read_raw_calib_file(parent_folder) self.calibration_cache[parent_folder] = c_data sample.update({ - 'intrinsics': self._get_intrinsics(c_data), + 'intrinsics': self._get_intrinsics(self.paths[idx], c_data), }) # Add pose information if requested diff --git a/packnet_sfm/models/__init__.py b/packnet_sfm/models/__init__.py index 9f6c6bba..84cfd989 100644 --- a/packnet_sfm/models/__init__.py +++ b/packnet_sfm/models/__init__.py @@ -9,17 +9,3 @@ - ModelCheckpoint enables saving/restoring state of torch.nn.Module objects """ - -from packnet_sfm.models.model_checkpoint import ModelCheckpoint -from packnet_sfm.models.model_wrapper import ModelWrapper -from packnet_sfm.models.SfmModel import SfmModel -from packnet_sfm.models.SelfSupModel import SelfSupModel -from packnet_sfm.models.SemiSupModel import SemiSupModel - -__all__ = [ - "ModelCheckpoint", - "ModelWrapper", - "SfmModel", - "SelfSupModel", - "SemiSupModel", - ] \ No newline at end of file diff --git a/packnet_sfm/models/model_wrapper.py b/packnet_sfm/models/model_wrapper.py index 0eacd162..520906dd 100644 --- a/packnet_sfm/models/model_wrapper.py +++ b/packnet_sfm/models/model_wrapper.py @@ -8,7 +8,6 @@ import torch from torch.utils.data import ConcatDataset, DataLoader -from packnet_sfm.datasets import KITTIDataset, DGPDataset, ImageDataset from packnet_sfm.datasets.transforms import get_transforms from packnet_sfm.utils.depth import inv2depth, post_process_inv_depth, compute_depth_metrics from packnet_sfm.utils.horovod import print0, world_size, 
rank, on_rank_0 @@ -516,12 +515,14 @@ def setup_dataset(config, mode, requirements, **kwargs): # KITTI dataset if config.dataset[i] == 'KITTI': + from packnet_sfm.datasets.kitti_dataset import KITTIDataset dataset = KITTIDataset( config.path[i], path_split, **dataset_args, **dataset_args_i, ) # DGP dataset elif config.dataset[i] == 'DGP': + from packnet_sfm.datasets.dgp_dataset import DGPDataset dataset = DGPDataset( config.path[i], config.split[i], **dataset_args, **dataset_args_i, @@ -529,6 +530,7 @@ def setup_dataset(config, mode, requirements, **kwargs): ) # Image dataset elif config.dataset[i] == 'Image': + from packnet_sfm.datasets.image_dataset import ImageDataset dataset = ImageDataset( config.path[i], config.split[i], **dataset_args, **dataset_args_i, diff --git a/packnet_sfm/utils/depth.py b/packnet_sfm/utils/depth.py index 2842aa7d..1eb61fbe 100644 --- a/packnet_sfm/utils/depth.py +++ b/packnet_sfm/utils/depth.py @@ -1,13 +1,68 @@ # Copyright 2020 Toyota Research Institute. All rights reserved. -from matplotlib.cm import get_cmap -import torch import numpy as np -from packnet_sfm.utils.image import \ - gradient_x, gradient_y, flip_lr, interpolate_image +import torch +import torchvision.transforms as transforms +from matplotlib.cm import get_cmap + +from packnet_sfm.utils.image import load_image, gradient_x, gradient_y, flip_lr, interpolate_image from packnet_sfm.utils.types import is_seq, is_tensor +def load_depth(file): + """ + Load a depth map from file + Parameters + ---------- + file : str + Depth map filename (.npz or .png) + + Returns + ------- + depth : np.array [H,W] + Depth map (invalid pixels are 0) + """ + if file.endswith('npz'): + return np.load(file)['depth'] + elif file.endswith('png'): + depth_png = np.array(load_image(file), dtype=int) + assert (np.max(depth_png) > 255), 'Wrong .png depth file' + return depth_png.astype(np.float) / 256. + else: + raise NotImplementedError('Depth extension not supported.') + + +def write_depth(filename, depth, intrinsics=None): + """ + Write a depth map to file, and optionally its corresponding intrinsics. 
+
+    Parameters
+    ----------
+    filename : str
+        File where depth map will be saved (.npz or .png)
+    depth : np.array [H,W]
+        Depth map
+    intrinsics : np.array [3,3]
+        Optional camera intrinsics matrix
+    """
+    # If depth is a tensor
+    if is_tensor(depth):
+        depth = depth.detach().squeeze().cpu()
+    # If intrinsics is a tensor
+    if is_tensor(intrinsics):
+        intrinsics = intrinsics.detach().cpu()
+    # If we are saving as a .npz
+    if filename.endswith('.npz'):
+        np.savez_compressed(filename, depth=depth, intrinsics=intrinsics)
+    # If we are saving as a .png
+    elif filename.endswith('.png'):
+        depth = transforms.ToPILImage()((depth * 256).int())
+        depth.save(filename)
+    # Something is wrong
+    else:
+        raise NotImplementedError('Depth filename not valid.')
+
+
 def viz_inv_depth(inv_depth, normalizer=None, percentile=95,
                   colormap='plasma', filter_zeros=False):
     """
diff --git a/packnet_sfm/utils/horovod.py b/packnet_sfm/utils/horovod.py
index 05aa2c8c..5e825a9e 100644
--- a/packnet_sfm/utils/horovod.py
+++ b/packnet_sfm/utils/horovod.py
@@ -1,10 +1,15 @@
-import horovod.torch as hvd
+try:
+    import horovod.torch as hvd
+    HAS_HOROVOD = True
+except ImportError:
+    HAS_HOROVOD = False
 
-########################################################################################################################
 
 def hvd_init():
-    hvd.init()
+    if HAS_HOROVOD:
+        hvd.init()
+    return HAS_HOROVOD
 
 def on_rank_0(func):
     def wrapper(*args, **kwargs):
@@ -13,13 +18,31 @@ def wrapper(*args, **kwargs):
     return wrapper
 
 def rank():
-    return hvd.rank()
+    return hvd.rank() if HAS_HOROVOD else 0
 
 def world_size():
-    return hvd.size()
+    return hvd.size() if HAS_HOROVOD else 1
 
 @on_rank_0
 def print0(string='\n'):
     print(string)
 
-########################################################################################################################
+def reduce_value(value, average, name):
+    """
+    Reduce the mean value of a tensor from all GPUs
+
+    Parameters
+    ----------
+    value : torch.Tensor
+        Value to be reduced
+    average : bool
+        Whether values will be averaged or not
+    name : str
+        Value name
+
+    Returns
+    -------
+    value : torch.Tensor
+        Reduced value
+    """
+    return hvd.allreduce(value, average=average, name=name) if HAS_HOROVOD else value
diff --git a/packnet_sfm/utils/image.py b/packnet_sfm/utils/image.py
index 77d0f671..78c69118 100644
--- a/packnet_sfm/utils/image.py
+++ b/packnet_sfm/utils/image.py
@@ -1,5 +1,6 @@
 # Copyright 2020 Toyota Research Institute. All rights reserved.
 
+import cv2
 import torch
 import torch.nn.functional as funct
 from functools import lru_cache
@@ -7,7 +8,6 @@
 
 from packnet_sfm.utils.misc import same_shape
 
-########################################################################################################################
 
 def load_image(path):
     """
@@ -25,7 +25,20 @@
     """
     return Image.open(path)
 
-########################################################################################################################
+
+def write_image(filename, image):
+    """
+    Write an image to file.
+
+    Parameters
+    ----------
+    filename : str
+        File where image will be saved
+    image : np.array [H,W,3]
+        RGB image
+    """
+    cv2.imwrite(filename, image[:, :, ::-1])
+
 
 def flip_lr(image):
     """
diff --git a/packnet_sfm/utils/logging.py b/packnet_sfm/utils/logging.py
index ed12876c..b5f446fd 100644
--- a/packnet_sfm/utils/logging.py
+++ b/packnet_sfm/utils/logging.py
@@ -30,7 +30,7 @@ def pcolor(string, color, on_color=None, attrs=None):
     return colored(string, color, on_color, attrs)
 
 
-def prepare_dataset_prefix(config, n):
+def prepare_dataset_prefix(config, dataset_idx):
     """
     Concatenates dataset path and split for metrics logging
 
@@ -38,7 +38,7 @@
     ----------
     config : CfgNode
         Dataset configuration
-    n : int
+    dataset_idx : int
         Dataset index for multiple datasets
 
     Returns
@@ -46,13 +46,18 @@
     prefix : str
         Dataset prefix for metrics logging
     """
-    prefix = '{}-{}'.format(
-        os.path.splitext(config.path[n].split('/')[-1])[0],
-        os.path.splitext(os.path.basename(config.split[n]))[0])
-    if config.depth_type[n] is not '':
-        prefix += '-{}'.format(config.depth_type[n])
-    if len(config.cameras[n]) == 1:  # only allows single cameras
-        prefix += '-{}'.format(config.cameras[n][0])
+    # Path is always available
+    prefix = '{}'.format(os.path.splitext(config.path[dataset_idx].split('/')[-1])[0])
+    # If split is available and does not contain { character
+    if config.split[dataset_idx] != '' and '{' not in config.split[dataset_idx]:
+        prefix += '-{}'.format(os.path.splitext(os.path.basename(config.split[dataset_idx]))[0])
+    # If depth type is available
+    if config.depth_type[dataset_idx] != '':
+        prefix += '-{}'.format(config.depth_type[dataset_idx])
+    # If we are using specific cameras
+    if len(config.cameras[dataset_idx]) == 1:  # only allows single cameras
+        prefix += '-{}'.format(config.cameras[dataset_idx][0])
+    # Return full prefix
     return prefix
diff --git a/packnet_sfm/utils/reduce.py b/packnet_sfm/utils/reduce.py
index c91ca554..028531ff 100644
--- a/packnet_sfm/utils/reduce.py
+++ b/packnet_sfm/utils/reduce.py
@@ -1,25 +1,51 @@
 import torch
-import horovod.torch as hvd
 import numpy as np
 
 from collections import OrderedDict
 
+from packnet_sfm.utils.horovod import reduce_value
 from packnet_sfm.utils.logging import prepare_dataset_prefix
 
-########################################################################################################################
 
-def reduce_value(value):
-    """Reduce the mean value of a tensor from all GPUs"""
-    return hvd.allreduce(value, average=True, name='value')
+def reduce_dict(data, to_item=False):
+    """
+    Reduce the mean values of a dictionary from all GPUs
+
+    Parameters
+    ----------
+    data : dict
+        Dictionary to be reduced
+    to_item : bool
+        True if the reduced values will be returned as .item()
 
-def reduce_dict(dict, to_item=False):
-    """Reduce the mean values of a dictionary from all GPUs"""
-    for key, val in dict.items():
-        dict[key] = reduce_value(dict[key])
+    Returns
+    -------
+    data : dict
+        Reduced dictionary
+    """
+    for key, val in data.items():
+        data[key] = reduce_value(data[key], average=True, name=key)
         if to_item:
-            dict[key] = dict[key].item()
-    return dict
+            data[key] = data[key].item()
+    return data
 
 
 def all_reduce_metrics(output_data_batch, datasets, name='depth'):
+    """
+    Reduce metrics for all batches and all datasets using Horovod
+
+    Parameters
+    ----------
+    output_data_batch : list
+        List of outputs for each batch
+    datasets : list
+        List of all considered datasets
+    name : str
+        Name of the task for the metric
+
+    Returns
+    -------
+    all_metrics_dict : list
+        List of reduced metrics
+    """
     # If there is only one dataset, wrap in a list
     if isinstance(output_data_batch[0], dict):
         output_data_batch = [output_data_batch]
@@ -37,7 +63,7 @@
         for output in output_batch:
             for i, idx in enumerate(output['idx']):
                 seen[idx] += 1
-        seen = hvd.allreduce(seen, average=False, name='idx')
+        seen = reduce_value(seen, average=False, name='idx')
         assert not np.any(seen.numpy() == 0), \
             'Not all samples were seen during evaluation'
         # Reduce all relevant metrics
@@ -46,7 +72,7 @@
             for output in output_batch:
                 for i, idx in enumerate(output['idx']):
                     metrics[idx] = output[name]
-            metrics = hvd.allreduce(metrics, average=False, name='depth_pp_gt')
+            metrics = reduce_value(metrics, average=False, name=name)
             metrics_dict[name] = (metrics / seen.view(-1, 1)).mean(0)
         # Append metrics dictionary to the list
         all_metrics_dict.append(metrics_dict)
@@ -56,8 +82,21 @@
 ########################################################################################################################
 
 def collate_metrics(output_data_batch, name='depth'):
-    """Collate epoch output to produce average metrics."""
+    """
+    Collate epoch output to produce average metrics
+
+    Parameters
+    ----------
+    output_data_batch : list
+        List of outputs for each batch
+    name : str
+        Name of the task for the metric
+
+    Returns
+    -------
+    metrics_data : list
+        List of collated metrics
+    """
     # If there is only one dataset, wrap in a list
     if isinstance(output_data_batch[0], dict):
         output_data_batch = [output_data_batch]
@@ -77,8 +116,27 @@
 
 
 def create_dict(metrics_data, metrics_keys, metrics_modes,
                 dataset, name='depth'):
-    """Creates a dictionary from collated metrics."""
+    """
+    Creates a dictionary from collated metrics
+
+    Parameters
+    ----------
+    metrics_data : list
+        List containing collated metrics
+    metrics_keys : list
+        List of keys for the metrics
+    metrics_modes : list
+        List of modes for the metrics
+    dataset : CfgNode
+        Dataset configuration file
+    name : str
+        Name of the task for the metric
+
+    Returns
+    -------
+    metrics_dict : dict
+        Metrics dictionary
+    """
     # Create metrics dictionary
     metrics_dict = {}
     # For all datasets
diff --git a/packnet_sfm/utils/save.py b/packnet_sfm/utils/save.py
index 9161ac31..9c923712 100644
--- a/packnet_sfm/utils/save.py
+++ b/packnet_sfm/utils/save.py
@@ -1,11 +1,11 @@
 # Copyright 2020 Toyota Research Institute. All rights reserved.
 
-import cv2
 import numpy as np
 import os
 
+from packnet_sfm.utils.image import write_image
+from packnet_sfm.utils.depth import write_depth, inv2depth, viz_inv_depth
 from packnet_sfm.utils.logging import prepare_dataset_prefix
-from packnet_sfm.utils.depth import inv2depth, viz_inv_depth
 
 
 def save_depth(batch, output, args, dataset, save):
@@ -29,8 +29,8 @@ def save_depth(batch, output, args, dataset, save):
     if save.folder is '':
         return
 
-    # If we want to save depth maps
-    if save.viz or save.npz:
+    # If we want to save
+    if save.depth.rgb or save.depth.viz or save.depth.npz or save.depth.png:
         # Retrieve useful tensors
         rgb = batch['rgb']
         pred_inv_depth = output['inv_depth']
@@ -39,8 +39,8 @@ def save_depth(batch, output, args, dataset, save):
         filename = batch['filename']
         dataset_idx = 0 if len(args) == 1 else args[1]
         save_path = os.path.join(save.folder, 'depth',
-            prepare_dataset_prefix(dataset, dataset_idx),
-            os.path.basename(save.pretrained).split('.')[0])
+                                 prepare_dataset_prefix(dataset, dataset_idx),
+                                 os.path.basename(save.pretrained).split('.')[0])
         # Create folder
         os.makedirs(save_path, exist_ok=True)
@@ -48,18 +48,19 @@ def save_depth(batch, output, args, dataset, save):
         length = rgb.shape[0]
         for i in range(length):
             # Save numpy depth maps
-            if save.npz:
-                # Get depth from predicted depth map and save to .npz
-                np.savez_compressed('{}/{}.npz'.format(save_path, filename[i]),
-                    depth=inv2depth(pred_inv_depth[i]).squeeze().detach().cpu().numpy())
-            # Save inverse depth visualizations
-            if save.viz:
-                # Prepare RGB image
+            if save.depth.npz:
+                write_depth('{}/{}_depth.npz'.format(save_path, filename[i]),
+                            depth=inv2depth(pred_inv_depth[i]),
+                            intrinsics=batch['intrinsics'][i] if 'intrinsics' in batch else None)
+            # Save png depth maps
+            if save.depth.png:
+                write_depth('{}/{}_depth.png'.format(save_path, filename[i]),
+                            depth=inv2depth(pred_inv_depth[i]))
+            # Save rgb images
+            if save.depth.rgb:
                 rgb_i = rgb[i].permute(1, 2, 0).detach().cpu().numpy() * 255
-                # Prepare inverse depth
-                pred_inv_depth_i = viz_inv_depth(pred_inv_depth[i]) * 255
-                # Concatenate both vertically
-                image = np.concatenate([rgb_i, pred_inv_depth_i], 0)
-                # Write to disk
-                cv2.imwrite('{}/{}.png'.format(
-                    save_path, filename[i]), image[:, :, ::-1])
+                write_image('{}/{}_rgb.png'.format(save_path, filename[i]), rgb_i)
+            # Save inverse depth visualizations
+            if save.depth.viz:
+                viz_i = viz_inv_depth(pred_inv_depth[i]) * 255
+                write_image('{}/{}_viz.png'.format(save_path, filename[i]), viz_i)
diff --git a/scripts/eval.py b/scripts/eval.py
index 1f9df504..dd6456ca 100644
--- a/scripts/eval.py
+++ b/scripts/eval.py
@@ -3,7 +3,8 @@
 import argparse
 import torch
 
-from packnet_sfm import ModelWrapper, HorovodTrainer
+from packnet_sfm.models.model_wrapper import ModelWrapper
+from packnet_sfm.trainers.horovod_trainer import HorovodTrainer
 from packnet_sfm.utils.config import parse_test_file
 from packnet_sfm.utils.load import set_debug
 from packnet_sfm.utils.horovod import hvd_init
@@ -33,6 +34,8 @@ def test(ckpt_file, cfg_file, half):
         Checkpoint path for a pretrained model
     cfg_file : str
         Configuration file
+    half: bool
+        Use half precision (fp16)
     """
     # Initialize horovod
     hvd_init()
diff --git a/scripts/evaluate_depth_maps.py b/scripts/evaluate_depth_maps.py
index 6772206c..1b7b2350 100644
--- a/scripts/evaluate_depth_maps.py
+++ b/scripts/evaluate_depth_maps.py
@@ -1,12 +1,14 @@
-import os
-from glob import glob
+import argparse
 import numpy as np
+import os
 import torch
+
+from glob import glob
 from argparse import Namespace
+from tqdm import tqdm
-
-from packnet_sfm.utils.depth import compute_depth_metrics
-import argparse
+from packnet_sfm.utils.depth import load_depth, compute_depth_metrics
 
 
 def parse_args():
@@ -28,53 +31,33 @@ def parse_args():
     return args
 
 
-def evaluate_depth_maps(pred_folder, gt_folder, use_gt_scale, **kwargs):
-    """
-    Calculates depth metrics from a folder of predicted and ground-truth depth files
-
-    Parameters
-    ----------
-    pred_folder : str
-        Folder containing predicted depth maps (.npz with key 'depth')
-    gt_folder : str
-        Folder containing ground-truth depth maps (.npz with key 'depth')
-    use_gt_scale : bool
-        Using ground-truth median scaling or not
-    kwargs : dict
-        Extra parameters for depth evaluation
-    """
-    # Get and sort ground-truth files
-    gt_files = glob(os.path.join(gt_folder, '*.npz'))
+def main(args):
+    # Get and sort ground-truth and predicted files
+    exts = ('npz', 'png')
+    gt_files, pred_files = [], []
+    for ext in exts:
+        gt_files.extend(glob(os.path.join(args.gt_folder, '*.{}'.format(ext))))
+        pred_files.extend(glob(os.path.join(args.pred_folder, '*.{}'.format(ext))))
+    # Sort ground-truth and prediction
     gt_files.sort()
-    # Get and sort predicted files
-    pred_files = glob(os.path.join(pred_folder, '*.npz'))
     pred_files.sort()
-    # Prepare configuration
-    config = Namespace(**kwargs)
     # Loop over all files
     metrics = []
-    for gt, pred in zip(gt_files, pred_files):
-        # Get and prepare ground-truth
-        gt = np.load(gt)['depth']
-        gt = torch.tensor(gt).unsqueeze(0).unsqueeze(0)
-        # Get and prepare predictions
-        pred = np.load(pred)['depth']
-        pred = torch.tensor(pred).unsqueeze(0).unsqueeze(0)
+    progress_bar = tqdm(zip(gt_files, pred_files), total=len(gt_files))
+    for gt, pred in progress_bar:
+        # Get and prepare ground-truth and predictions
+        gt = torch.tensor(load_depth(gt)).unsqueeze(0).unsqueeze(0)
+        pred = torch.tensor(load_depth(pred)).unsqueeze(0).unsqueeze(0)
         # Calculate metrics
-        metrics.append(compute_depth_metrics(config, gt, pred,
-            use_gt_scale=use_gt_scale))
+        metrics.append(compute_depth_metrics(
+            args, gt, pred, use_gt_scale=args.use_gt_scale))
     # Get and print average value
     metrics = (sum(metrics) / len(metrics)).detach().cpu().numpy()
     names = ['abs_rel', 'sqr_rel', 'rmse', 'rmse_log', 'a1', 'a2', 'a3']
     for name, metric in zip(names, metrics):
         print('{} = {}'.format(name, metric))
 
+
 if __name__ == '__main__':
     args = parse_args()
-    evaluate_depth_maps(args.pred_folder, args.gt_folder,
-                        use_gt_scale=args.use_gt_scale,
-                        min_depth=args.min_depth,
-                        max_depth=args.max_depth,
-                        crop=args.crop)
-
-
+    main(args)
diff --git a/scripts/infer.py b/scripts/infer.py
index 5934cc8b..340fe1f6 100644
--- a/scripts/infer.py
+++ b/scripts/infer.py
@@ -1,19 +1,20 @@
 # Copyright 2020 Toyota Research Institute. All rights reserved.
 
+import argparse
+import numpy as np
 import os
 import torch
-import argparse
+
 from glob import glob
 from cv2 import imwrite
-import numpy as np
 
-from packnet_sfm import ModelWrapper
+from packnet_sfm.models.model_wrapper import ModelWrapper
 from packnet_sfm.datasets.augmentations import resize_image, to_tensor
 from packnet_sfm.utils.horovod import hvd_init, rank, world_size, print0
 from packnet_sfm.utils.image import load_image
 from packnet_sfm.utils.config import parse_test_file
 from packnet_sfm.utils.load import set_debug
-from packnet_sfm.utils.depth import inv2depth, viz_inv_depth
+from packnet_sfm.utils.depth import write_depth, inv2depth, viz_inv_depth
 from packnet_sfm.utils.logging import pcolor
@@ -27,11 +28,12 @@
     parser.add_argument('--checkpoint', type=str, help='Checkpoint (.ckpt)')
     parser.add_argument('--input', type=str, help='Input file or folder')
    parser.add_argument('--output', type=str, help='Output file or folder')
-    parser.add_argument('--image_shape', type=tuple, default=None,
+    parser.add_argument('--image_shape', type=int, nargs='+', default=None,
                         help='Input and output image shape '
                              '(default: checkpoint\'s config.datasets.augmentation.image_shape)')
     parser.add_argument('--half', action="store_true", help='Use half precision (fp16)')
-    parser.add_argument('--save_npz', action='store_true', help='save in .npz format')
+    parser.add_argument('--save', type=str, choices=['npz', 'png'], default=None,
+                        help='Save format (npz or png). Default is None (no depth map is saved).')
     args = parser.parse_args()
     assert args.checkpoint.endswith('.ckpt'), \
         'You need to provide a .ckpt file as checkpoint'
@@ -44,7 +46,7 @@
 
 
 @torch.no_grad()
-def infer_and_save_depth(input_file, output_file, model_wrapper, image_shape, half, save_npz):
+def infer_and_save_depth(input_file, output_file, model_wrapper, image_shape, half, save):
     """
     Process a single input file to produce and save visualization
 
@@ -60,9 +62,8 @@ def infer_and_save_depth(input_file, output_file, model_wrapper, image_shape, ha
         Input image shape
     half: bool
         use half precision (fp16)
-    save_npz: bool
-        save .npz output depth maps if True, else save as png
-
+    save: str
+        Save format (npz or png)
     """
     if not is_image(output_file):
         # If not an image, assume it's a folder and append the input name
@@ -85,14 +86,13 @@ def infer_and_save_depth(input_file, output_file, model_wrapper, image_shape, ha
     # Depth inference (returns predicted inverse depth)
     pred_inv_depth = model_wrapper.depth(image)[0]
 
-    if save_npz:
-        # Get depth from predicted depth map and save to .npz
-        depth = inv2depth(pred_inv_depth).squeeze().detach().cpu().numpy()
-        output_file = os.path.splitext(output_file)[0] + ".npz"
+    if save == 'npz' or save == 'png':
+        # Get depth from predicted depth map and save to different formats
+        filename = '{}.{}'.format(os.path.splitext(output_file)[0], save)
         print('Saving {} to {}'.format(
             pcolor(input_file, 'cyan', attrs=['bold']),
-            pcolor(output_file, 'magenta', attrs=['bold'])))
-        np.savez_compressed(output_file, depth=depth)
+            pcolor(filename, 'magenta', attrs=['bold'])))
+        write_depth(filename, depth=inv2depth(pred_inv_depth))
     else:
         # Prepare RGB image
         rgb = image[0].permute(1, 2, 0).detach().cpu().numpy() * 255
@@ -152,7 +152,7 @@ def main(args):
     # Process each file
     for fn in files[rank()::world_size()]:
         infer_and_save_depth(
-            fn, args.output, model_wrapper, image_shape, args.half, args.save_npz)
+            fn, args.output, model_wrapper, image_shape, args.half, args.save)
 
 
 if __name__ == '__main__':
diff --git a/scripts/train.py b/scripts/train.py
index 8bef107a..d9759d76 100644
--- a/scripts/train.py
+++ b/scripts/train.py
@@ -2,7 +2,9 @@
 
 import argparse
 
-from packnet_sfm import ModelWrapper, ModelCheckpoint, HorovodTrainer
+from packnet_sfm.models.model_wrapper import ModelWrapper
+from packnet_sfm.models.model_checkpoint import ModelCheckpoint
+from packnet_sfm.trainers.horovod_trainer import HorovodTrainer
 from packnet_sfm.utils.config import parse_train_file
 from packnet_sfm.utils.load import set_debug, filter_args_create
 from packnet_sfm.utils.horovod import hvd_init, rank
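
Usage sketch for the new depth I/O helpers in packnet_sfm/utils/depth.py (not part of the diff; the /tmp paths and the synthetic depth map are illustrative assumptions): the .npz branch of write_depth round-trips losslessly and can also carry intrinsics, while the .png branch stores depth * 256 as integers, so loaded values come back quantized to 1/256 m.

import numpy as np
import torch

from packnet_sfm.utils.depth import load_depth, write_depth

# Synthetic depth map in meters (shape and values are arbitrary placeholders)
depth = torch.linspace(1.0, 80.0, 192 * 640).reshape(192, 640)

# .npz round-trip is lossless and can also store the camera intrinsics
write_depth('/tmp/example_depth.npz', depth=depth, intrinsics=torch.eye(3))
assert np.allclose(load_depth('/tmp/example_depth.npz'), depth.numpy())

# .png stores (depth * 256) as integer PNG, so loading recovers depth up to 1/256 m
write_depth('/tmp/example_depth.png', depth=depth)
assert np.abs(load_depth('/tmp/example_depth.png') - depth.numpy()).max() < 1.0 / 256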