diff --git a/datasets.yml b/datasets.yml index 8a0175d..d9ae9d5 100644 --- a/datasets.yml +++ b/datasets.yml @@ -7,3 +7,6 @@ mpi_sintel: /path/to/MPI-Sintel kitti_2012: /path/to/KITTI/2012 kitti_2015: /path/to/KITTI/2015 hd1k: /path/to/HD1K +tartanair: /path/to/tartanair +spring: /path/to/spring +kubric: /path/to/kubric diff --git a/ptlflow/data/datasets.py b/ptlflow/data/datasets.py index 840e37e..6bbaef0 100644 --- a/ptlflow/data/datasets.py +++ b/ptlflow/data/datasets.py @@ -16,12 +16,14 @@ # limitations under the License. # ============================================================================= +import json import logging import math from pathlib import Path from typing import Callable, Dict, List, Optional, Sequence, Tuple, Union import cv2 +from einops import rearrange import numpy as np import torch from torch.utils.data import Dataset @@ -123,6 +125,12 @@ def __init__( self.mb_b_paths = [] self.metadata = [] + self.flow_format = None + self.flow_read_mins = None + self.flow_read_maxs = None + self.flow_b_read_mins = None + self.flow_b_read_maxs = None + def __getitem__(self, index: int) -> Dict[str, torch.Tensor]: # noqa: C901 """Retrieve and return one input. 
@@ -145,7 +153,20 @@ def __getitem__(self, index: int) -> Dict[str, torch.Tensor]: # noqa: C901 inputs["images"] = [cv2.imread(str(path)) for path in self.img_paths[index]] if index < len(self.flow_paths): - inputs["flows"], valids = self._get_flows_and_valids(self.flow_paths[index]) + inputs["flows"], valids = self._get_flows_and_valids( + self.flow_paths[index], + flow_format=self.flow_format, + flow_min=self.flow_read_mins[index] + if ( + self.flow_read_mins is not None and len(self.flow_read_mins) > index + ) + else None, + flow_max=self.flow_read_maxs[index] + if ( + self.flow_read_maxs is not None and len(self.flow_read_maxs) > index + ) + else None, + ) if self.get_valid_mask: inputs["valids"] = valids @@ -160,7 +181,22 @@ def __getitem__(self, index: int) -> Dict[str, torch.Tensor]: # noqa: C901 str(p).replace("flow_occ", "flow_noc") for p in self.flow_paths[index] ] - _, valids_noc = self._get_flows_and_valids(noc_paths) + _, valids_noc = self._get_flows_and_valids( + noc_paths, + flow_format=self.flow_format, + flow_min=self.flow_read_mins[index] + if ( + self.flow_read_mins is not None + and len(self.flow_read_mins) > index + ) + else None, + flow_max=self.flow_read_maxs[index] + if ( + self.flow_read_maxs is not None + and len(self.flow_read_maxs) > index + ) + else None, + ) inputs["occs"] = [valids[i] - valids_noc[i] for i in range(len(valids))] if self.get_motion_boundary_mask and index < len(self.mb_paths): inputs["mbs"] = [ @@ -170,7 +206,20 @@ def __getitem__(self, index: int) -> Dict[str, torch.Tensor]: # noqa: C901 if self.get_backward: if index < len(self.flow_b_paths): inputs["flows_b"], valids_b = self._get_flows_and_valids( - self.flow_b_paths[index] + self.flow_b_paths[index], + flow_format=self.flow_format, + flow_min=self.flow_b_read_mins[index] + if ( + self.flow_b_read_mins is not None + and len(self.flow_b_read_mins) > index + ) + else None, + flow_max=self.flow_b_read_maxs[index] + if ( + self.flow_b_read_maxs is not None + and 
len(self.flow_b_read_maxs) > index + ) + else None, ) if self.get_valid_mask: inputs["valids_b"] = valids_b @@ -202,12 +251,18 @@ def __len__(self) -> int: return len(self.img_paths) def _get_flows_and_valids( - self, flow_paths: Sequence[str] + self, + flow_paths: Sequence[str], + flow_format: Optional[str] = None, + flow_min: Optional[float] = None, + flow_max: Optional[float] = None, ) -> Tuple[List[np.ndarray], List[Optional[np.ndarray]]]: flows = [] valids = [] for path in flow_paths: - flow = flow_utils.flow_read(path) + flow = flow_utils.flow_read( + path, format=flow_format, flow_min=flow_min, flow_max=flow_max + ) nan_mask = np.isnan(flow) flow[nan_mask] = self.max_flow + 1 @@ -1410,7 +1465,11 @@ def __init__( # noqa: C901 img2_paths ), f"{len(img1_paths)} vs {len(img2_paths)}" flow_paths = [] - if split != "test": + + if ( + split != "test" + or (Path(self.root_dir[ver]) / split_dir / "flow_occ").exists() + ): flow_paths = sorted( (Path(self.root_dir[ver]) / split_dir / "flow_occ").glob("*_10.png") ) @@ -1436,7 +1495,10 @@ def __init__( # noqa: C901 if img1_paths[i].stem not in remove_names ] ) - if split != "test": + if ( + split != "test" + or (Path(self.root_dir[ver]) / split_dir / "flow_occ").exists() + ): self.flow_paths.extend( [ [flow_paths[i]] @@ -1562,7 +1624,10 @@ def __init__( # noqa: C901 ) flow_paths = [] occ_paths = [] - if split != "test": + if ( + split != "test" + or (Path(self.root_dir) / split_dir / "flow").exists() + ): flow_paths = sorted( (Path(self.root_dir) / split_dir / "flow" / seq_name).glob( "*.flo" @@ -1661,7 +1726,7 @@ def __init__( # noqa: C901 as zero in the valid mask. get_valid_mask : bool, default True Whether to get or generate valid masks. - get_backward : bool, default True + get_backward : bool, default False Whether to get the backward version of the inputs. get_meta : bool, default True Whether to get metadata. 
@@ -1720,7 +1785,10 @@ def __init__(  # noqa: C901
                     rev = direcs[0] == "BW"
                     image_paths = sorted(
                         (
-                            Path(self.root_dir) / split_dir / seq_name / f"frame_{side}"
+                            Path(self.root_dir)
+                            / split_dir
+                            / seq_name
+                            / f"frame_{side}"
                         ).glob("*.png"),
                         reverse=rev,
                     )
@@ -1794,6 +1862,190 @@ def __init__(  # noqa: C901
 
         self._log_status()
 
+    def __getitem__(self, index: int) -> Dict[str, torch.Tensor]:  # noqa: C901
+        """Retrieve and return one input.
+
+        Parameters
+        ----------
+        index : int
+            The index of the entry on the input lists.
+
+        Returns
+        -------
+        Dict[str, torch.Tensor]
+            The retrieved input. This dict may contain the following keys, depending on the initialization choices:
+            ['images', 'flows', 'mbs', 'occs', 'valids', 'flows_b', 'mbs_b', 'occs_b', 'valids_b', 'meta'].
+            Except for 'meta', all the values are 4D tensors with shape NCHW. Notice that N does not correspond to the batch
+            size, but rather to the number of images of a given key. For example, typically 'images' will have N=2, and
+            'flows' will have N=1, and so on. Therefore, a batch of these inputs will be a 5D tensor BNCHW.
+            The exception is 'flows' (when present): it is split into 4 half-size maps (one per position in each 2x2
+            pixel cell), giving the shape N4CHW, and 'valids' is subsampled by a factor of 2 accordingly.
+        """
+        inputs = {}
+
+        inputs["images"] = [cv2.imread(str(path)) for path in self.img_paths[index]]
+
+        if index < len(self.flow_paths):
+            inputs["flows"], valids = self._get_flows_and_valids(self.flow_paths[index])
+            if self.get_valid_mask:
+                inputs["valids"] = valids
+
+        if self.get_backward:
+            if index < len(self.flow_b_paths):
+                inputs["flows_b"], valids_b = self._get_flows_and_valids(
+                    self.flow_b_paths[index]
+                )
+                if self.get_valid_mask:
+                    inputs["valids_b"] = valids_b
+
+        if self.transform is not None:
+            inputs = self.transform(inputs)
+
+        # "flows" only exists for samples that have a flow path, and "valids" only
+        # when get_valid_mask is set, so both are guarded to avoid a KeyError.
+        if "flows" in inputs:
+            inputs["flows"] = rearrange(
+                inputs["flows"], "b c (h nh) (w nw) -> b (nh nw) c h w", nh=2, nw=2
+            )
+        if "valids" in inputs:
+            inputs["valids"] = inputs["valids"][:, :, ::2, ::2]
+
+        if self.get_meta:
+            inputs["meta"] = {
+                "dataset_name": self.dataset_name,
+                "split_name": self.split_name,
+            }
+            if index < len(self.metadata):
+                inputs["meta"].update(self.metadata[index])
+
+        return inputs
+
+
+class TartanAirDataset(BaseFlowDataset):
+    """Handle the TartanAir dataset."""
+
+    def __init__(  # noqa: C901
+        self,
+        root_dir: str,
+        difficulties: Union[str, List[str]] = "easy",
+        transform: Callable[[Dict[str, torch.Tensor]], Dict[str, torch.Tensor]] = None,
+        max_flow: float = 10000.0,
+        get_valid_mask: bool = True,
+        get_occlusion_mask: bool = True,
+        get_meta: bool = True,
+        sequence_length: int = 2,
+        sequence_position: str = "first",
+    ) -> None:
+        """Initialize TartanAirDataset.
+
+        Parameters
+        ----------
+        root_dir : str
+            path to the root directory of the TartanAir dataset.
+        difficulties : Union[str, List[str]], default 'easy'
+            Which difficulties should be loaded. It can be one of {'easy', 'hard', ['easy', 'hard']}.
+        transform : Callable[[Dict[str, torch.Tensor]], Dict[str, torch.Tensor]], optional
+            Transform to be applied on the inputs.
+        max_flow : float, default 10000.0
+            Maximum optical flow absolute value. Flow absolute values that go over this limit are clipped, and also marked
+            as zero in the valid mask.
+        get_valid_mask : bool, default True
+            Whether to get or generate valid masks.
+        get_occlusion_mask : bool, default True
+            Whether to get occlusion masks.
+        get_meta : bool, default True
+            Whether to get metadata.
+        sequence_length : int, default 2
+            How many consecutive images are loaded per sample. More than two images can be used for model which exploit more
+            temporal information.
+        sequence_position : str, default "first"
+            Only used when sequence_length > 2.
+            Determines the position where the main image frame will be in the sequence. It can one of three values:
+            - "first": the main frame will be the first one of the sequence,
+            - "middle": the main frame will be in the middle of the sequence (at position sequence_length // 2),
+            - "last": the main frame will be the penultimate in the sequence.
+        """
+        if isinstance(difficulties, str):
+            difficulties = [difficulties]
+        difficulties = [d.capitalize() for d in difficulties]
+        super().__init__(
+            dataset_name=f'TartanAir_{"_".join(difficulties)}',
+            split_name="trainval",
+            transform=transform,
+            max_flow=max_flow,
+            get_valid_mask=get_valid_mask,
+            get_occlusion_mask=get_occlusion_mask,
+            get_motion_boundary_mask=False,
+            get_backward=False,
+            get_meta=get_meta,
+        )
+        self.root_dir = root_dir
+        self.difficulties = difficulties
+        self.sequence_length = sequence_length
+        self.sequence_position = sequence_position
+
+        sequence_paths = sorted([p for p in Path(root_dir).glob("*") if p.is_dir()])
+
+        # Read paths from disk
+        for seq_path in sequence_paths:
+            for diff in difficulties:
+                trajectory_paths = sorted(
+                    [p for p in (seq_path / diff).glob("*") if p.is_dir()]
+                )
+                for traj_path in trajectory_paths:
+                    image_paths = sorted((traj_path / "image_left").glob("*.png"))
+                    image_paths = self._extend_paths_list(
+                        image_paths, sequence_length, sequence_position
+                    )
+
+                    flow_paths = sorted((traj_path / "flow").glob("*_flow.npy"))
+                    flow_paths = self._extend_paths_list(
+                        flow_paths, sequence_length, sequence_position
+                    )
+                    assert len(image_paths) - 1 == len(
+                        flow_paths
+                    ), f"{seq_path.name}, {traj_path.name}: {len(image_paths)-1} vs {len(flow_paths)}"
+
+                    occ_paths = []
+                    if get_occlusion_mask:
+                        occ_paths = sorted((traj_path / "flow").glob("*_mask.npy"))
+                        occ_paths = self._extend_paths_list(
+                            occ_paths, sequence_length, sequence_position
+                        )
+                        assert len(occ_paths) == len(flow_paths)
+                    for i in range(len(image_paths) - self.sequence_length + 1):
+                        self.img_paths.append(image_paths[i : i + self.sequence_length])
+                        if len(flow_paths) > 0:
+                            self.flow_paths.append(
+                                flow_paths[i : i + self.sequence_length - 1]
+                            )
+                        if len(occ_paths) > 0:
+                            self.occ_paths.append(
+                                occ_paths[i : i + self.sequence_length - 1]
+                            )
+                        self.metadata.append(
+                            {
+                                "image_paths": [
+                                    str(p)
+                                    for p in image_paths[i : i + self.sequence_length]
+                                ],
+                                "is_val": False,
+                                "misc": seq_path.name,
+                                "is_seq_start": i == 0,
+                            }
+                        )
+
+        # Sanity check
+        assert len(self.img_paths) == len(
+            self.flow_paths
+        ), f"{len(self.img_paths)} vs {len(self.flow_paths)}"
+        if len(self.occ_paths) > 0:
+            assert len(self.img_paths) == len(
+                self.occ_paths
+            ), f"{len(self.img_paths)} vs {len(self.occ_paths)}"
+
+        self._log_status()
+
 
 class MiddleburyDataset(BaseFlowDataset):
     """Handle the Middlebury dataset."""
@@ -2063,3 +2315,124 @@ def __init__(  # noqa: C901
             ), f"{len(self.img_paths)} vs {len(self.mb_b_paths)}"
 
         self._log_status()
+
+
+class KubricDataset(BaseFlowDataset):
+    """Handle datasets generated by Kubric."""
+
+    def __init__(  # noqa: C901
+        self,
+        root_dir: str,
+        transform: Callable[[Dict[str, torch.Tensor]], Dict[str, torch.Tensor]] = None,
+        max_flow: float = 10000.0,
+        get_valid_mask: bool = True,
+        get_backward: bool = True,
+        get_meta: bool = True,
+        sequence_length: int = 2,
+        sequence_position: str = "first",
+    ) -> None:
+        """Initialize KubricDataset.
+
+        Parameters
+        ----------
+        root_dir : str
+            path to the root directory of the Kubric dataset.
+        transform : Callable[[Dict[str, torch.Tensor]], Dict[str, torch.Tensor]], optional
+            Transform to be applied on the inputs.
+        max_flow : float, default 10000.0
+            Maximum optical flow absolute value. Flow absolute values that go over this limit are clipped, and also marked
+            as zero in the valid mask.
+        get_valid_mask : bool, default True
+            Whether to get or generate valid masks.
+        get_backward : bool, default True
+            Whether to get the backward version of the inputs.
+        get_meta : bool, default True
+            Whether to get metadata.
+        sequence_length : int, default 2
+            How many consecutive images are loaded per sample. More than two images can be used for model which exploit more
+            temporal information.
+        sequence_position : str, default "first"
+            Only used when sequence_length > 2.
+            Determines the position where the main image frame will be in the sequence. It can one of three values:
+            - "first": the main frame will be the first one of the sequence,
+            - "middle": the main frame will be in the middle of the sequence (at position sequence_length // 2),
+            - "last": the main frame will be the penultimate in the sequence.
+        """
+        super().__init__(
+            dataset_name="Kubric",
+            split_name="trainval",
+            transform=transform,
+            max_flow=max_flow,
+            get_valid_mask=get_valid_mask,
+            get_motion_boundary_mask=False,
+            get_backward=get_backward,
+            get_meta=get_meta,
+        )
+        self.root_dir = root_dir
+        self.sequence_length = sequence_length
+        self.sequence_position = sequence_position
+
+        self.flow_format = "kubric_png"
+
+        sequence_dirs = sorted([p for p in (Path(root_dir)).glob("*") if p.is_dir()])
+
+        self.flow_read_mins = []
+        self.flow_read_maxs = []
+        if get_backward:
+            # Created once, before the loop over sequences: re-creating these lists
+            # for each sequence would discard the ranges of the previous ones.
+            self.flow_b_read_mins = []
+            self.flow_b_read_maxs = []
+
+        for seq_dir in sequence_dirs:
+            seq_name = seq_dir.name
+            image_paths = sorted(seq_dir.glob("rgba_*.png"))
+            image_paths = self._extend_paths_list(
+                image_paths, sequence_length, sequence_position
+            )
+            flow_paths = sorted(seq_dir.glob("forward_flow_*.png"))[:-1]
+            flow_paths = self._extend_paths_list(
+                flow_paths, sequence_length, sequence_position
+            )
+            assert len(image_paths) - 1 == len(
+                flow_paths
+            ), f"{seq_name}: {len(image_paths)-1} vs {len(flow_paths)}"
+
+            with open(seq_dir / "data_ranges.json", "r") as f:
+                data_ranges = json.load(f)
+
+            if get_backward:
+                back_flow_paths = sorted(seq_dir.glob("backward_flow_*.png"))[1:]
+                back_flow_paths = self._extend_paths_list(
+                    back_flow_paths, sequence_length, sequence_position
+                )
+                assert len(image_paths) - 1 == len(
+                    back_flow_paths
+                ), f"{seq_name}: {len(image_paths)-1} vs {len(back_flow_paths)}"
+
+            for i in range(len(image_paths) - self.sequence_length + 1):
+                self.img_paths.append(image_paths[i : i + self.sequence_length])
+                if len(flow_paths) > 0:
+                    self.flow_paths.append(flow_paths[i : i + self.sequence_length - 1])
+                    self.flow_read_mins.append(data_ranges["forward_flow"]["min"])
+                    self.flow_read_maxs.append(data_ranges["forward_flow"]["max"])
+
+                if get_backward:
+                    self.flow_b_paths.append(
+                        back_flow_paths[i : i + self.sequence_length - 1]
+                    )
+                    self.flow_b_read_mins.append(data_ranges["backward_flow"]["min"])
+                    self.flow_b_read_maxs.append(data_ranges["backward_flow"]["max"])
+
+                self.metadata.append(
+                    {
+                        "image_paths": [
+                            str(p) for p in image_paths[i : i + self.sequence_length]
+                        ],
+                        "is_val": False,
+                        "misc": seq_name,
+                        "is_seq_start": i == 0,
+                    }
+                )
+
+        self._log_status()
diff --git a/ptlflow/models/base_model/base_model.py b/ptlflow/models/base_model/base_model.py
index 0ec4cda..9a20748 100644
--- a/ptlflow/models/base_model/base_model.py
+++ b/ptlflow/models/base_model/base_model.py
@@ -41,10 +41,12 @@
     FlyingChairs2Dataset,
     Hd1kDataset,
     KittiDataset,
+    KubricDataset,
     SintelDataset,
     FlyingThings3DDataset,
     FlyingThings3DSubsetDataset,
     SpringDataset,
+    TartanAirDataset,
 )
 from ptlflow.utils.utils import InputPadder, InputScaler
 from ptlflow.utils.utils import config_logging, make_divisible, bgr_val_as_tensor
@@ -76,7 +78,6 @@ def __init__(self, args: Namespace, loss_fn: Callable, output_stride: int) -> No
         self.output_stride = output_stride
 
         self.train_size = None
-
         self.train_avg_length = None
 
         self.extra_params = None
@@ -997,6 +998,8 @@ def _get_chairs2_dataset(self, is_train: bool, *args: str) -> Dataset:
                 get_motion_boundary_mask = True
             elif v == "back":
                 get_backward = True
+            else:
+                raise ValueError(f"Invalid arg: {v}")
 
         dataset = FlyingChairs2Dataset(
             self.args.flying_chairs2_root_dir,
@@ -1054,6 +1057,8 @@ def _get_hd1k_dataset(self, is_train: bool, *args: str) -> Dataset:
                 sequence_length = int(v.split("_")[1])
             elif v.startswith("seqpos"):
                 sequence_position = v.split("_")[1]
+            else:
+                raise ValueError(f"Invalid arg: {v}")
 
         dataset = Hd1kDataset(
             self.args.hd1k_root_dir,
@@ -1100,21 +1105,60 @@ def _get_kitti_dataset(self, is_train: bool, *args: str) -> Dataset:
 
         versions = ["2012", "2015"]
         split = "trainval"
+        get_occlusion_mask = False
         for v in args:
             if v in ["2012", "2015"]:
                 versions = [v]
             elif v in ["train", "val", "trainval", "test"]:
                 split = v
+            elif v == "occ":
+                get_occlusion_mask = True
+            else:
+                raise ValueError(f"Invalid arg: {v}")
 
         dataset = KittiDataset(
             self.args.kitti_2012_root_dir,
             self.args.kitti_2015_root_dir,
+            get_occlusion_mask=get_occlusion_mask,
             versions=versions,
             split=split,
             transform=transform,
         )
         return dataset
 
+    def _get_kubric_dataset(self, is_train: bool, *args: str) -> Dataset:
+        """Build a KubricDataset for validation (training is not implemented).
+
+        Accepted args: "back", "seqlen_<int>" and "seqpos_<position>".
+        """
+        if is_train:
+            raise NotImplementedError()
+        else:
+            transform = ft.ToTensor()
+
+        get_backward = False
+        sequence_length = 2
+        sequence_position = "first"
+        for v in args:
+            if v == "back":
+                get_backward = True
+            elif v.startswith("seqlen"):
+                sequence_length = int(v.split("_")[1])
+            elif v.startswith("seqpos"):
+                sequence_position = v.split("_")[1]
+            else:
+                # Reject unknown args instead of ignoring them, as the other getters do.
+                raise ValueError(f"Invalid arg: {v}")
+
+        dataset = KubricDataset(
+            self.args.kubric_root_dir,
+            transform=transform,
+            get_backward=get_backward,
+            sequence_length=sequence_length,
+            sequence_position=sequence_position,
+        )
+        return dataset
+
     def _get_sintel_dataset(self, is_train: bool, *args: str) -> Dataset:
         device = "cuda" if self.args.train_transform_cuda else "cpu"
         md = make_divisible
@@ -1164,6 +1208,8 @@ def _get_sintel_dataset(self, is_train: bool, *args: str) -> Dataset:
                 sequence_length = int(v.split("_")[1])
             elif v.startswith("seqpos"):
                 sequence_position = v.split("_")[1]
+            else:
+                raise ValueError(f"Invalid arg: {v}")
 
         dataset = SintelDataset(
             self.args.mpi_sintel_root_dir,
@@ -1216,6 +1262,7 @@ def _get_spring_dataset(self, is_train: bool, *args: str) -> Dataset:
         sequence_length = 2
         sequence_position = "first"
         reverse_only = False
+        side_names = []
         for v in args:
             if v in ["train", "val", "trainval", "test"]:
                 split = v
@@ -1229,11 +1276,20 @@ def _get_spring_dataset(self, is_train: bool, *args: str) -> Dataset:
                 sequence_length = int(v.split("_")[1])
             elif v.startswith("seqpos"):
                 sequence_position = v.split("_")[1]
+            elif v == "left":
+                side_names.append("left")
+            elif v == "right":
+                side_names.append("right")
+            else:
+                raise ValueError(f"Invalid arg: {v}")
+
+        if len(side_names) == 0:
+            side_names = ["left", "right"]
 
         dataset = SpringDataset(
             self.args.spring_root_dir,
             split=split,
-            side_names=["left", "right"],
+            side_names=side_names,
             add_reverse=add_reverse,
             transform=transform,
             get_backward=get_backward,
@@ -1243,6 +1299,68 @@ def _get_spring_dataset(self, is_train: bool, *args: str) -> Dataset:
         )
         return dataset
 
+    def _get_tartanair_dataset(self, is_train: bool, *args: str) -> Dataset:
+        device = "cuda" if self.args.train_transform_cuda else "cpu"
+        md = make_divisible
+
+        get_occlusion_mask = False
+        sequence_length = 2
+        sequence_position = "first"
+        difficulties = []
+        for v in args:
+            if v in ["easy", "hard"]:
+                difficulties.append(v)
+            elif v == "occ":
+                get_occlusion_mask = True
+            elif v.startswith("seqlen"):
+                sequence_length = int(v.split("_")[1])
+            elif v.startswith("seqpos"):
+                sequence_position = v.split("_")[1]
+            else:
+                raise ValueError(f"Invalid arg: {v}")
+
+        if len(difficulties) == 0:
+            difficulties = ["easy"]
+
+        if is_train:
+            if self.args.train_crop_size is None:
+                cy, cx = (md(360, self.output_stride), md(480, self.output_stride))
+                self.args.train_crop_size = (cy, cx)
+                logging.warning(
+                    "--train_crop_size is not set. It will be set as (%d, %d).", cy, cx
+                )
+            else:
+                cy, cx = (
+                    md(self.args.train_crop_size[0], self.output_stride),
+                    md(self.args.train_crop_size[1], self.output_stride),
+                )
+
+            # These transforms are based on RAFT: https://github.com/princeton-vl/RAFT
+            transform = ft.Compose(
+                [
+                    ft.ToTensor(device=device, fp16=self.args.train_transform_fp16),
+                    ft.RandomScaleAndCrop((cy, cx), (-0.4, 0.8), (-0.2, 0.2)),
+                    ft.ColorJitter(0.4, 0.4, 0.4, 0.5 / 3.14, 0.2),
+                    ft.GaussianNoise(0.02),
+                    ft.RandomPatchEraser(
+                        0.5, (int(1), int(3)), (int(50), int(100)), "mean"
+                    ),
+                    ft.RandomFlip(min(0.5, 0.5), min(0.1, 0.5)),
+                ]
+            )
+        else:
+            transform = ft.ToTensor()
+
+        dataset = TartanAirDataset(
+            self.args.tartanair_root_dir,
+            difficulties=difficulties,
+            transform=transform,
+            get_occlusion_mask=get_occlusion_mask,
+            sequence_length=sequence_length,
+            sequence_position=sequence_position,
+        )
+        return dataset
+
     def _get_things_dataset(self, is_train: bool, *args: str) -> Dataset:
         device = "cuda" if self.args.train_transform_cuda else "cpu"
         md = make_divisible
@@ -1278,6 +1396,8 @@ def _get_things_dataset(self, is_train: bool, *args: str) -> Dataset:
                 sequence_position = v.split("_")[1]
             elif v == "sinteltransform":
                 sintel_transform = True
+            else:
+                raise ValueError(f"Invalid arg: {v}")
 
         if is_train:
             if self.args.train_crop_size is None:
diff --git a/ptlflow/models/neuflow/utils.py b/ptlflow/models/neuflow/utils.py
index 68cbf81..7ba0d88 100644
--- a/ptlflow/models/neuflow/utils.py
+++ b/ptlflow/models/neuflow/utils.py
@@ -41,6 +41,8 @@ def bilinear_sample(img, sample_coords):
 def flow_warp(feature, flow):
     b, c, h, w = feature.size()
 
-    grid = coords_grid(b, h, w, dtype=flow.dtype, device=flow.device) + flow  # [B, 2, H, W]
+    grid = (
+        coords_grid(b, h, w, dtype=flow.dtype, device=flow.device) + flow
+    )  # [B, 2, H, W]
 
     return bilinear_sample(feature, grid)
diff --git a/ptlflow/utils/flow_metrics.py b/ptlflow/utils/flow_metrics.py
index d24887e..b8ae410 100644
--- a/ptlflow/utils/flow_metrics.py
+++ b/ptlflow/utils/flow_metrics.py
@@ -157,11 +157,18 @@ def update(
         if occlusion_target is not None:
             occlusion_target = self._fix_shape(occlusion_target, batch_size)
 
-        epe = torch.norm(flow_pred - flow_target, p=2, dim=1)
+        if len(flow_target.shape) == 5:
+            epe = torch.norm(flow_pred[:, None] - flow_target, p=2, dim=2)
+            epe, min_idx = epe.min(dim=1)
+            target_norm = torch.norm(flow_target, p=2, dim=2)
+            target_norm = target_norm.gather(1, min_idx[:, None])[:, 0]
+        else:
+            epe = torch.norm(flow_pred - flow_target, p=2, dim=1)
+            target_norm = torch.norm(flow_target, p=2, dim=1)
+
         px1_mask = (epe < 1).float()
         px3_mask = (epe < 3).float()
         px5_mask = (epe < 5).float()
-        target_norm = torch.norm(flow_target, p=2, dim=1)
         outlier_mask = ((epe > 3) & (epe > (0.05 * target_norm))).float() * 100
         self.used_keys = [
             ("epe", "epe", "valid_target"),
@@ -339,10 +346,28 @@ def _fix_shape(self, tensor: torch.Tensor, batch_size: int) -> torch.Tensor:
                 tensor.shape[3],
                 tensor.shape[4],
             )
+        elif len(tensor.shape) == 6:
+            tensor = tensor.view(
+                tensor.shape[0] * tensor.shape[1],
+                tensor.shape[2],
+                tensor.shape[3],
+                tensor.shape[4],
+                tensor.shape[5],
+            )
         return tensor
 
     def _get_batch_size(self, flow_tensor: torch.Tensor) -> int:
         if len(flow_tensor.shape) < 4:
             return 1
-        else:
-            return flow_tensor.view(-1, *flow_tensor.shape[-3:]).shape[0]
+        elif len(flow_tensor.shape) == 4:
+            return flow_tensor.shape[0]
+        elif len(flow_tensor.shape) == 5:
+            return flow_tensor.shape[0] * flow_tensor.shape[1]
+        elif len(flow_tensor.shape) == 6:
+            # NOTE(review): the 5D branch returns shape[0] * shape[1] and _fix_shape
+            # folds the first two dims of 6D tensors; confirm shape[0] alone is intended.
+            return flow_tensor.shape[0]
+        else:
+            raise ValueError(
+                f"Unsupported flow tensor shape: {tuple(flow_tensor.shape)}"
+            )
diff --git a/ptlflow/utils/flow_utils.py b/ptlflow/utils/flow_utils.py
index 69ab1ea..91d451a 100644
--- a/ptlflow/utils/flow_utils.py
+++ b/ptlflow/utils/flow_utils.py
@@ -19,6 +19,7 @@
 import pathlib
 from typing import IO, Optional, Union
 
+import cv2 as cv
 import numpy as np
 import torch
 
@@ -74,7 +75,10 @@ def flow_to_rgb(
 
 
 def flow_read(
-    input_file: Union[str, pathlib.Path, IO], format: str = None
+    input_file: Union[str, pathlib.Path, IO],
+    format: Optional[str] = None,
+    flow_min: Optional[float] = None,
+    flow_max: Optional[float] = None,
 ) -> np.ndarray:
     """Read optical flow from file.
 
@@ -85,7 +89,11 @@ def flow_read(
     input_file: str, pathlib.Path or IO
         Path of the file to read or file object.
     format: str, optional
-        Specify in what format the flow is read, accepted formats: "png", "flo", "pfm", or "flo5".
+        Specify in what format the flow is read, accepted formats: "png", "flo", "pfm", "flo5", "kubric_png".
         If None, it is guessed on the file extension.
+    flow_min: float, optional
+        The minimum flow range value. Only used (and required) when format is "kubric_png".
+    flow_max: float, optional
+        The maximum flow range value. Only used (and required) when format is "kubric_png".
 
     Returns
@@ -106,6 +114,8 @@ def flow_read(
         return raft.read_pfm(input_file)
     elif (format is not None and format == "flo5") or str(input_file).endswith("flo5"):
         return flow_IO.readFlo5Flow(input_file)
+    elif format is not None and format == "kubric_png":
+        return read_kubric_flow(input_file, flow_min=flow_min, flow_max=flow_max)
     else:
         return flowpy.flow_read(input_file, format)
 
@@ -139,3 +149,46 @@ def flow_write(
         flow_IO.writeFlo5File(flow, output_file)
     else:
         flowpy.flow_write(output_file, flow, format)
+
+
+def read_kubric_flow(
+    input_file: Union[str, pathlib.Path, IO],
+    flow_min: Optional[float],
+    flow_max: Optional[float],
+) -> np.ndarray:
+    """Read optical flow in Kubric PNG format from file.
+
+    Parameters
+    ----------
+    input_file: str, pathlib.Path or IO
+        Path of the file to read. It is converted to a string and passed to cv2.imread.
+    flow_min: float
+        The minimum flow range value, generated by Kubric and stored in the file data_ranges.json
+    flow_max: float
+        The maximum flow range value, generated by Kubric and stored in the file data_ranges.json
+
+    Returns
+    -------
+    numpy.ndarray
+        3D flow in the HWF (Height, Width, Flow) layout.
+        flow[..., 0] is the x-displacement.
+        flow[..., 1] is the y-displacement.
+
+    Raises
+    ------
+    ValueError
+        If flow_min or flow_max is None.
+    OSError
+        If the file cannot be read as an image.
+    """
+    if flow_min is None or flow_max is None:
+        raise ValueError("flow_min and flow_max are required for the kubric_png format.")
+    # cv2.imread signals a failure by returning None instead of raising.
+    flow = cv.imread(str(input_file), cv.IMREAD_UNCHANGED)
+    if flow is None:
+        raise OSError(f"Could not read the flow file: {input_file}")
+    # Keep the channels from index 1 onwards and rescale the stored values
+    # (assumed to span [0, 65535]) to the range [flow_min, flow_max].
+    flow = flow[..., 1:].astype(np.float32)
+    flow = flow / 65535 * (flow_max - flow_min) + flow_min
+    return flow
diff --git a/validate.py b/validate.py
index 1bb1164..60756a8 100644
--- a/validate.py
+++ b/validate.py
@@ -403,6 +403,10 @@ def validate_one_dataloader(
         filename = ""
         if "sintel" in inputs["meta"]["dataset_name"][0].lower():
             filename = f'{Path(inputs["meta"]["image_paths"][0][0]).parent.name}/'
+        elif "spring" in inputs["meta"]["dataset_name"][0].lower():
+            filename = (
+                f'{Path(inputs["meta"]["image_paths"][0][0]).parent.parent.name}/'
+            )
         filename += Path(inputs["meta"]["image_paths"][0][0]).stem
 
         if metrics_individual is not None:
@@ -484,6 +488,9 @@ def _write_to_file(
     if "sintel" in dataloader_name:
         seq_name = img_path.parts[-2]
         extra_dirs = seq_name
+    elif "spring" in dataloader_name:
+        seq_name = img_path.parts[-3]
+        extra_dirs = seq_name
     else:
         image_name = f"{batch_idx:08d}"
 