diff --git a/ptlflow/data/datasets.py b/ptlflow/data/datasets.py index e33be54..30682d9 100644 --- a/ptlflow/data/datasets.py +++ b/ptlflow/data/datasets.py @@ -16,21 +16,17 @@ # limitations under the License. # ============================================================================= -import json -import logging import math from pathlib import Path from typing import Callable, Dict, List, Optional, Sequence, Tuple, Union -import cv2 +import cv2 as cv from einops import rearrange +from loguru import logger import numpy as np import torch from torch.utils.data import Dataset from ptlflow.utils import flow_utils -from ptlflow.utils.utils import config_logging - -config_logging() THIS_DIR = Path(__file__).resolve().parent @@ -126,10 +122,8 @@ def __init__( self.metadata = [] self.flow_format = None - self.flow_read_mins = None - self.flow_read_maxs = None - self.flow_b_read_mins = None - self.flow_b_read_maxs = None + + self.is_two_file_flow = False def __getitem__(self, index: int) -> Dict[str, torch.Tensor]: # noqa: C901 """Retrieve and return one input. @@ -150,38 +144,25 @@ def __getitem__(self, index: int) -> Dict[str, torch.Tensor]: # noqa: C901 """ inputs = {} - inputs["images"] = [cv2.imread(str(path)) for path in self.img_paths[index]] + inputs["images"] = [cv.imread(str(path)) for path in self.img_paths[index]] if index < len(self.flow_paths): inputs["flows"], valids = self._get_flows_and_valids( self.flow_paths[index], flow_format=self.flow_format, - flow_min=( - self.flow_read_mins[index] - if ( - self.flow_read_mins is not None - and len(self.flow_read_mins) > index - ) - else None - ), - flow_max=( - self.flow_read_maxs[index] - if ( - self.flow_read_maxs is not None - and len(self.flow_read_maxs) > index - ) - else None - ), ) if self.get_valid_mask: inputs["valids"] = valids if self.get_occlusion_mask: if index < len(self.occ_paths): - inputs["occs"] = [ - cv2.imread(str(path), 0)[:, :, None] - for path in self.occ_paths[index] - ] + inputs["occs"] = [] + for path in self.occ_paths[index]: + if str(path).endswith("npy"): + occ = np.load(path) + else: + occ = cv.imread(str(path), 0) + inputs["occs"].append(occ[:, :, None]) elif self.dataset_name.startswith("KITTI"): noc_paths = [ str(p).replace("flow_occ", "flow_noc") @@ -190,27 +171,11 @@ def __getitem__(self, index: int) -> Dict[str, torch.Tensor]: # noqa: C901 _, valids_noc = self._get_flows_and_valids( noc_paths, flow_format=self.flow_format, - flow_min=( - self.flow_read_mins[index] - if ( - self.flow_read_mins is not None - and len(self.flow_read_mins) > index - ) - else None - ), - flow_max=( - self.flow_read_maxs[index] - if ( - self.flow_read_maxs is not None - and len(self.flow_read_maxs) > index - ) - else None - ), ) inputs["occs"] = [valids[i] - valids_noc[i] for i in range(len(valids))] if self.get_motion_boundary_mask and index < len(self.mb_paths): inputs["mbs"] = [ - cv2.imread(str(path), 0)[:, :, None] for path in self.mb_paths[index] + cv.imread(str(path), 0)[:, :, None] for path in self.mb_paths[index] ] if self.get_backward: @@ -218,33 +183,20 @@ def __getitem__(self, index: int) -> Dict[str, torch.Tensor]: # noqa: C901 inputs["flows_b"], valids_b = self._get_flows_and_valids( self.flow_b_paths[index], flow_format=self.flow_format, - flow_min=( - self.flow_b_read_mins[index] - if ( - self.flow_b_read_mins is not None - and len(self.flow_b_read_mins) > index - ) - else None - ), - flow_max=( - self.flow_b_read_maxs[index] - if ( - self.flow_b_read_maxs is not None - and len(self.flow_b_read_maxs) > index 
- ) - else None - ), ) if self.get_valid_mask: inputs["valids_b"] = valids_b if self.get_occlusion_mask and index < len(self.occ_b_paths): - inputs["occs_b"] = [ - cv2.imread(str(path), 0)[:, :, None] - for path in self.occ_b_paths[index] - ] + inputs["occs_b"] = [] + for path in self.occ_b_paths[index]: + if str(path).endswith("npy"): + occ = np.load(path) + else: + occ = cv.imread(str(path), 0) + inputs["occs_b"].append(occ[:, :, None]) if self.get_motion_boundary_mask and index < len(self.mb_b_paths): inputs["mbs_b"] = [ - cv2.imread(str(path), 0)[:, :, None] + cv.imread(str(path), 0)[:, :, None] for path in self.mb_b_paths[index] ] @@ -268,15 +220,16 @@ def _get_flows_and_valids( self, flow_paths: Sequence[str], flow_format: Optional[str] = None, - flow_min: Optional[float] = None, - flow_max: Optional[float] = None, ) -> Tuple[List[np.ndarray], List[Optional[np.ndarray]]]: flows = [] valids = [] for path in flow_paths: - flow = flow_utils.flow_read( - path, format=flow_format, flow_min=flow_min, flow_max=flow_max - ) + if self.is_two_file_flow: + flow_x = -flow_utils.flow_read(path[0], format=flow_format) + flow_y = -flow_utils.flow_read(path[1], format=flow_format) + flow = np.stack([flow_x, flow_y], 2) + else: + flow = flow_utils.flow_read(path, format=flow_format) nan_mask = np.isnan(flow) flow[nan_mask] = self.max_flow + 1 @@ -294,14 +247,14 @@ def _get_flows_and_valids( def _log_status(self) -> None: if self.__len__() == 0: - logging.warning( - "No samples were found for %s dataset. Be sure to update the dataset path in datasets.yml, " + logger.warning( + "No samples were found for {} dataset. Be sure to update the dataset path in datasets.yml, " "or provide the path by the argument --[dataset_name]_root_dir.", self.dataset_name, ) else: - logging.info( - "Loading %d samples from %s dataset.", self.__len__(), self.dataset_name + logger.info( + "Loading {} samples from {} dataset.", self.__len__(), self.dataset_name ) def _extend_paths_list( @@ -411,7 +364,6 @@ def __init__( "is_val": paths[0].stem in val_names, "misc": "", "is_seq_start": True, - "is_seq_end": True, } for paths in self.img_paths ] @@ -511,7 +463,6 @@ def __init__( "is_val": paths[0].stem in val_names, "misc": "", "is_seq_start": True, - "is_seq_end": True, } for paths in self.img_paths ] @@ -526,13 +477,13 @@ def __init__( self, root_dir: str, split: str = "train", - add_reverse: bool = True, + add_reverse: bool = False, transform: Callable[[Dict[str, torch.Tensor]], Dict[str, torch.Tensor]] = None, max_flow: float = 1000.0, get_valid_mask: bool = True, - get_occlusion_mask: bool = True, - get_motion_boundary_mask: bool = True, - get_backward: bool = True, + get_occlusion_mask: bool = False, + get_motion_boundary_mask: bool = False, + get_backward: bool = False, get_meta: bool = True, ) -> None: """Initialize FlyingChairs2Dataset. 
@@ -686,7 +637,6 @@ def __init__(
                 "is_val": paths[0].stem in val_names,
                 "misc": "",
                 "is_seq_start": True,
-                "is_seq_end": True,
             }
             for paths in self.img_paths
         ]
@@ -967,10 +917,6 @@ def __init__(  # noqa: C901
                         "is_val": False,
                         "misc": "",
                         "is_seq_start": i == 0,
-                        "is_seq_end": i
-                        == (
-                            len(image_paths) - self.sequence_length
-                        ),
                     }
                 )
                 if self.get_backward:
@@ -1189,8 +1135,6 @@ def __init__(  # noqa: C901
                         "is_val": False,
                         "misc": "",
                         "is_seq_start": i == 0,
-                        "is_seq_end": i
-                        == (len(flow_group) - self.sequence_length + 1),
                     }
                 )
 
@@ -1387,7 +1331,6 @@ def __init__(  # noqa: C901
                         "is_val": (seq_img_names[i] in val_names),
                         "misc": "",
                         "is_seq_start": True,
-                        "is_seq_end": True,
                     }
                 )
 
@@ -1527,7 +1470,6 @@ def __init__(  # noqa: C901
                 "is_val": img1_paths[i].stem in val_names,
                 "misc": ver,
                 "is_seq_start": True,
-                "is_seq_end": True,
             }
             for i in range(len(img1_paths))
             if img1_paths[i].stem not in remove_names
@@ -1685,8 +1627,6 @@ def __init__(  # noqa: C901
                         "is_val": seq_name in val_seqs,
                         "misc": seq_name,
                         "is_seq_start": i == 0,
-                        "is_seq_end": i
-                        == (len(image_paths) - self.sequence_length),
                     }
                 )
 
@@ -1720,6 +1660,9 @@ def __init__(  # noqa: C901
         sequence_length: int = 2,
         sequence_position: str = "first",
         reverse_only: bool = False,
+        subsample: bool = False,
+        is_image_4k: bool = False,
+        image_4k_split_dir_suffix: str = "_4k",
     ) -> None:
         """Initialize SintelDataset.
 
@@ -1755,6 +1698,18 @@ def __init__(  # noqa: C901
            - "last": the main frame will be the penultimate in the sequence.
         reverse_only : bool, default False
             If True, only uses the backward samples, discarding the forward ones.
+        subsample : bool, default False
+            If True, the groundtruth is subsampled from 4K to 2K by nearest-neighbor subsampling.
+            If False, and is_image_4k is also False, then the groundtruth is reshaped as: einops.rearrange("b c (h nh) (w nw) -> b (nh nw) c h w", nh=2, nw=2),
+            which corresponds to stacking the groundtruth values of each 2x2 block.
+            If False, and is_image_4k is True, then the groundtruth is returned at its original 4K resolution, but the flow values are doubled.
+        is_image_4k : bool, default False
+            If True, assumes the input images will be provided in 4K resolution, instead of the original 2K.
+        image_4k_split_dir_suffix : str, default "_4k"
+            Only used when is_image_4k is True. It indicates the suffix to append to the split folder name where the 4K images are located.
+            For example, by default, the 4K images need to be located inside folders called "train_4k" and/or "test_4k".
+            The structure of these folders should be the same as the original "train" and "test".
+            The "*_4k" folders only need to contain the image directories; the groundtruth will still be loaded from the original locations.
""" if isinstance(side_names, str): side_names = [side_names] @@ -1774,6 +1729,12 @@ def __init__( # noqa: C901 self.side_names = side_names self.sequence_length = sequence_length self.sequence_position = sequence_position + self.subsample = subsample + self.is_image_4k = is_image_4k + self.image_4k_split_dir_suffix = image_4k_split_dir_suffix + + if self.is_image_4k: + assert not self.subsample # Get sequence names for the given split if split == "test": @@ -1797,9 +1758,17 @@ def __init__( # noqa: C901 for side in side_names: for direcs in directions: rev = direcs[0] == "BW" + img_split_dir_name = ( + f"{split_dir}{self.image_4k_split_dir_suffix}" + if self.is_image_4k + else split_dir + ) image_paths = sorted( ( - Path(self.root_dir) / split_dir / seq_name / f"frame_{side}" + Path(self.root_dir) + / img_split_dir_name + / seq_name + / f"frame_{side}" ).glob("*.png"), reverse=rev, ) @@ -1860,8 +1829,6 @@ def __init__( # noqa: C901 "is_val": False, "misc": seq_name, "is_seq_start": i == 0, - "is_seq_end": i - == (len(image_paths) - self.sequence_length), } ) @@ -1892,7 +1859,7 @@ def __getitem__(self, index: int) -> Dict[str, torch.Tensor]: # noqa: C901 """ inputs = {} - inputs["images"] = [cv2.imread(str(path)) for path in self.img_paths[index]] + inputs["images"] = [cv.imread(str(path)) for path in self.img_paths[index]] if index < len(self.flow_paths): inputs["flows"], valids = self._get_flows_and_valids(self.flow_paths[index]) @@ -1907,13 +1874,38 @@ def __getitem__(self, index: int) -> Dict[str, torch.Tensor]: # noqa: C901 if self.get_valid_mask: inputs["valids_b"] = valids_b - if self.transform is not None: - inputs = self.transform(inputs) + if self.subsample: + inputs["flows"] = [f[::2, ::2] for f in inputs["flows"]] + inputs["valids"] = [v[::2, ::2] for v in inputs["valids"]] + if self.get_backward: + inputs["flows_b"] = [f[::2, ::2] for f in inputs["flows_b"]] + inputs["valids_b"] = [v[::2, ::2] for v in inputs["valids_b"]] + if self.transform is not None: + inputs = self.transform(inputs) + elif self.is_image_4k: + if self.transform is not None: + inputs = self.transform(inputs) + if "flows" in inputs: + inputs["flows"] = 2 * inputs["flows"] + if self.get_backward: + inputs["flows_b"] = 2 * inputs["flows_b"] + else: + if self.transform is not None: + inputs = self.transform(inputs) - inputs["flows"] = rearrange( - inputs["flows"], "b c (h nh) (w nw) -> b (nh nw) c h w", nh=2, nw=2 - ) - inputs["valids"] = inputs["valids"][:, :, ::2, ::2] + if "flows" in inputs: + inputs["flows"] = rearrange( + inputs["flows"], "b c (h nh) (w nw) -> b (nh nw) c h w", nh=2, nw=2 + ) + inputs["valids"] = inputs["valids"][:, :, ::2, ::2] + if self.get_backward: + inputs["flows_b"] = rearrange( + inputs["flows_b"], + "b c (h nh) (w nw) -> b (nh nw) c h w", + nh=2, + nw=2, + ) + inputs["valids_b"] = inputs["valids_b"][:, :, ::2, ::2] if self.get_meta: inputs["meta"] = { @@ -1941,7 +1933,7 @@ def __init__( # noqa: C901 sequence_length: int = 2, sequence_position: str = "first", ) -> None: - """Initialize SintelDataset. + """Initialize TartanAirDataset. 
Parameters ---------- @@ -2140,7 +2132,6 @@ def __init__( # noqa: C901 "is_val": False, "misc": seq_name, "is_seq_start": True, - "is_seq_end": True, } ) @@ -2153,6 +2144,76 @@ def __init__( # noqa: C901 self._log_status() +class MiddleburySTDataset(BaseFlowDataset): + """Handle the Middlebury-ST dataset.""" + + def __init__( # noqa: C901 + self, + root_dir: str, + transform: Callable[[Dict[str, torch.Tensor]], Dict[str, torch.Tensor]] = None, + max_flow: float = 10000.0, + get_valid_mask: bool = True, + get_meta: bool = True, + ) -> None: + """Initialize MiddleburySTDataset. + + Parameters + ---------- + root_dir : str + path to the root directory of the Middlebury dataset. + transform : Callable[[Dict[str, torch.Tensor]], Dict[str, torch.Tensor]], optional + Transform to be applied on the inputs. + max_flow : float, default 10000.0 + Maximum optical flow absolute value. Flow absolute values that go over this limit are clipped, and also marked + as zero in the valid mask. + get_valid_mask : bool, default True + Whether to get or generate valid masks. + get_meta : bool, default True + Whether to get metadata. + """ + super().__init__( + dataset_name="MiddleburyST", + split_name="trainval", + transform=transform, + max_flow=max_flow, + get_valid_mask=get_valid_mask, + get_occlusion_mask=False, + get_motion_boundary_mask=False, + get_backward=False, + get_meta=get_meta, + ) + self.root_dir = root_dir + self.sequence_length = 2 + self.is_two_file_flow = True + + sequence_names = sorted( + [p.stem for p in Path(self.root_dir).glob("*") if p.is_dir()] + ) + + # Read paths from disk + for seq_name in sequence_names: + image_paths = [ + Path(self.root_dir) / seq_name / "im0.png", + Path(self.root_dir) / seq_name / "im1.png", + ] + self.img_paths.append(image_paths) + disp_paths = [ + Path(self.root_dir) / seq_name / "disp0.pfm", + Path(self.root_dir) / seq_name / "disp0y.pfm", + ] + self.flow_paths.append([disp_paths]) + self.metadata.append( + { + "image_paths": [str(p) for p in image_paths], + "is_val": False, + "misc": seq_name, + "is_seq_start": True, + } + ) + + self._log_status() + + class MonkaaDataset(BaseFlowDataset): """Handle the Monkaa dataset.""" @@ -2291,8 +2352,6 @@ def __init__( # noqa: C901 "is_val": False, "misc": "", "is_seq_start": i == 0, - "is_seq_end": i - == (len(image_paths) - self.sequence_length), } ) if self.get_backward: @@ -2337,13 +2396,14 @@ def __init__( # noqa: C901 get_meta: bool = True, sequence_length: int = 2, sequence_position: str = "first", + max_seq: Optional[int] = None, ) -> None: """Initialize KubricDataset. Parameters ---------- root_dir : str - path to the root directory of the MPI Sintel dataset. + path to the root directory of the Kubric dataset. transform : Callable[[Dict[str, torch.Tensor]], Dict[str, torch.Tensor]], optional Transform to be applied on the inputs. 
         max_flow : float, default 10000.0
@@ -2382,9 +2442,7 @@ def __init__(  # noqa: C901
         self.flow_format = "kubric_png"
 
         sequence_dirs = sorted([p for p in (Path(root_dir)).glob("*") if p.is_dir()])
-
-        self.flow_read_mins = []
-        self.flow_read_maxs = []
+        sequence_dirs = sequence_dirs[:max_seq]
 
         for seq_dir in sequence_dirs:
             seq_name = seq_dir.name
@@ -2396,37 +2454,30 @@ def __init__(  # noqa: C901
             flow_paths = self._extend_paths_list(
                 flow_paths, sequence_length, sequence_position
             )
+            flow_paths = [(p, "forward_flow") for p in flow_paths]
             assert len(image_paths) - 1 == len(
                 flow_paths
             ), f"{seq_name}: {len(image_paths)-1} vs {len(flow_paths)}"
 
-            with open(seq_dir / "data_ranges.json", "r") as f:
-                data_ranges = json.load(f)
-
             if get_backward:
                 back_flow_paths = sorted(seq_dir.glob("backward_flow_*.png"))[1:]
                 back_flow_paths = self._extend_paths_list(
                     back_flow_paths, sequence_length, sequence_position
                 )
+                back_flow_paths = [(p, "backward_flow") for p in back_flow_paths]
                 assert len(image_paths) - 1 == len(
                     back_flow_paths
                 ), f"{seq_name}: {len(image_paths)-1} vs {len(back_flow_paths)}"
-                self.flow_b_read_mins = []
-                self.flow_b_read_maxs = []
 
             for i in range(len(image_paths) - self.sequence_length + 1):
                 self.img_paths.append(image_paths[i : i + self.sequence_length])
                 if len(flow_paths) > 0:
                     self.flow_paths.append(flow_paths[i : i + self.sequence_length - 1])
-                    self.flow_read_mins.append(data_ranges["forward_flow"]["min"])
-                    self.flow_read_maxs.append(data_ranges["forward_flow"]["max"])
                 if get_backward:
                     self.flow_b_paths.append(
                         back_flow_paths[i : i + self.sequence_length - 1]
                     )
-                    self.flow_b_read_mins.append(data_ranges["backward_flow"]["min"])
-                    self.flow_b_read_maxs.append(data_ranges["backward_flow"]["max"])
 
                 self.metadata.append(
                     {
@@ -2440,3 +2491,95 @@ def __init__(  # noqa: C901
                 )
 
         self._log_status()
+
+
+class ViperDataset(BaseFlowDataset):
+    """Handle the VIPER dataset."""
+
+    def __init__(  # noqa: C901
+        self,
+        root_dir: str,
+        split: str = "train",
+        transform: Callable[[Dict[str, torch.Tensor]], Dict[str, torch.Tensor]] = None,
+        max_flow: float = 10000.0,
+        get_valid_mask: bool = True,
+        get_meta: bool = True,
+    ) -> None:
+        """Initialize ViperDataset.
+
+        Parameters
+        ----------
+        root_dir : str
+            path to the root directory of the VIPER dataset.
+        split : str, default 'train'
+            Which split of the dataset should be loaded. It can be one of {'train', 'val', 'trainval'}.
+        transform : Callable[[Dict[str, torch.Tensor]], Dict[str, torch.Tensor]], optional
+            Transform to be applied on the inputs.
+        max_flow : float, default 10000.0
+            Maximum optical flow absolute value. Flow absolute values that go over this limit are clipped, and also marked
+            as zero in the valid mask.
+        get_valid_mask : bool, default True
+            Whether to get or generate valid masks.
+        get_meta : bool, default True
+            Whether to get metadata.
+ """ + super().__init__( + dataset_name="VIPER", + split_name=split, + transform=transform, + max_flow=max_flow, + get_valid_mask=get_valid_mask, + get_occlusion_mask=False, + get_motion_boundary_mask=False, + get_backward=False, + get_meta=get_meta, + ) + self.root_dir = root_dir + self.sequence_length = 2 + + self.flow_format = "viper_npz" + + if split == "trainval": + split_dirs = ["train", "val"] + else: + split_dirs = [split] + + for spdir in split_dirs: + img_dir_path = Path(self.root_dir) / spdir / "img" + flow_dir_path = Path(self.root_dir) / spdir / "flow" + + sequence_names = sorted( + [p.stem for p in img_dir_path.glob("*") if p.is_dir()] + ) + + # Read paths from disk + for seq_name in sequence_names: + if flow_dir_path.exists(): + flow_paths = sorted(list((flow_dir_path / seq_name).glob(f"*.npz"))) + for fpath in flow_paths: + file_idx = int(fpath.stem.split("_")[1]) + img1_path = ( + img_dir_path / seq_name / f"{seq_name}_{(file_idx):05d}.png" + ) + img2_path = ( + img_dir_path + / seq_name + / f"{seq_name}_{(file_idx + 1):05d}.png" + ) + if img1_path.exists() and img2_path.exists(): + self.img_paths.append([img1_path, img2_path]) + self.flow_paths.append([fpath]) + self.metadata.append( + { + "image_paths": [ + str(p) for p in [img1_path, img2_path] + ], + "is_val": spdir == "val", + "misc": seq_name, + "is_seq_start": True, + } + ) + else: + raise NotImplementedError() + + self._log_status() diff --git a/ptlflow/data/flow_datamodule.py b/ptlflow/data/flow_datamodule.py new file mode 100644 index 0000000..bab2436 --- /dev/null +++ b/ptlflow/data/flow_datamodule.py @@ -0,0 +1,1275 @@ +# ============================================================================= +# Copyright 2024 Henrique Morimitsu +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
diff --git a/ptlflow/data/flow_datamodule.py b/ptlflow/data/flow_datamodule.py
new file mode 100644
index 0000000..bab2436
--- /dev/null
+++ b/ptlflow/data/flow_datamodule.py
@@ -0,0 +1,1275 @@
+# =============================================================================
+# Copyright 2024 Henrique Morimitsu
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# =============================================================================
+
+from typing import Optional
+
+import lightning.pytorch as pl
+from loguru import logger
+from torch.utils.data import DataLoader, Dataset
+import yaml
+
+from ptlflow.data import flow_transforms as ft
+from ptlflow.data.datasets import (
+    AutoFlowDataset,
+    FlyingChairsDataset,
+    FlyingChairs2Dataset,
+    Hd1kDataset,
+    KittiDataset,
+    KubricDataset,
+    MiddleburySTDataset,
+    SintelDataset,
+    FlyingThings3DDataset,
+    FlyingThings3DSubsetDataset,
+    SpringDataset,
+    TartanAirDataset,
+    ViperDataset,
+)
+from ptlflow.utils.utils import make_divisible
+
+
+class FlowDataModule(pl.LightningDataModule):
+    def __init__(
+        self,
+        predict_dataset: Optional[str] = None,
+        test_dataset: Optional[str] = None,
+        train_dataset: Optional[str] = None,
+        val_dataset: Optional[str] = None,
+        train_batch_size: Optional[int] = None,
+        train_num_workers: int = 4,
+        train_crop_size: Optional[tuple[int, int]] = None,
+        train_transform_cuda: bool = False,
+        train_transform_fp16: bool = False,
+        autoflow_root_dir: Optional[str] = None,
+        flying_chairs_root_dir: Optional[str] = None,
+        flying_chairs2_root_dir: Optional[str] = None,
+        flying_things3d_root_dir: Optional[str] = None,
+        flying_things3d_subset_root_dir: Optional[str] = None,
+        mpi_sintel_root_dir: Optional[str] = None,
+        kitti_2012_root_dir: Optional[str] = None,
+        kitti_2015_root_dir: Optional[str] = None,
+        hd1k_root_dir: Optional[str] = None,
+        tartanair_root_dir: Optional[str] = None,
+        spring_root_dir: Optional[str] = None,
+        kubric_root_dir: Optional[str] = None,
+        middlebury_st_root_dir: Optional[str] = None,
+        viper_root_dir: Optional[str] = None,
+        dataset_config_path: str = "./datasets.yaml",
+    ):
+        super().__init__()
+        self.predict_dataset = predict_dataset
+        self.test_dataset = test_dataset
+        self.train_dataset = train_dataset
+        self.val_dataset = val_dataset
+        self.train_batch_size = train_batch_size
+        self.train_num_workers = train_num_workers
+        self.train_crop_size = train_crop_size
+        self.train_transform_cuda = train_transform_cuda
+        self.train_transform_fp16 = train_transform_fp16
+
+        self.autoflow_root_dir = autoflow_root_dir
+        self.flying_chairs_root_dir = flying_chairs_root_dir
+        self.flying_chairs2_root_dir = flying_chairs2_root_dir
+        self.flying_things3d_root_dir = flying_things3d_root_dir
+        self.flying_things3d_subset_root_dir = flying_things3d_subset_root_dir
+        self.mpi_sintel_root_dir = mpi_sintel_root_dir
+        self.kitti_2012_root_dir = kitti_2012_root_dir
+        self.kitti_2015_root_dir = kitti_2015_root_dir
+        self.hd1k_root_dir = hd1k_root_dir
+        self.tartanair_root_dir = tartanair_root_dir
+        self.spring_root_dir = spring_root_dir
+        self.kubric_root_dir = kubric_root_dir
+        self.middlebury_st_root_dir = middlebury_st_root_dir
+        self.viper_root_dir = viper_root_dir
+        self.dataset_config_path = dataset_config_path
+
+        self.predict_dataset_parsed = None
+        self.test_dataset_parsed = None
+        self.train_dataset_parsed = None
+        self.val_dataset_parsed = None
+
+        self.train_dataloader_length = 0
+        self.train_epoch_step = 0
+
+        self.val_dataloader_names = []
+        self.val_dataloader_lengths = []
+
+        self.test_dataloader_names = []
+
+    def setup(self, stage):
+        self._load_dataset_paths()
+
+        if stage == "fit":
+            assert (
+                self.train_dataset is not None
+            ), "You need to provide a value for --data.train_dataset"
+            assert (
+                self.val_dataset is not None
+            ), "You need to provide a value for --data.val_dataset"
+
+            if self.train_batch_size is None:
+                self.train_batch_size = 8
+                logger.warning(
+                    "--data.train_batch_size is not set. It will be set to {}",
+                    self.train_batch_size,
+                )
+
+            self.train_dataset_parsed = self._parse_dataset_selection(
+                self.train_dataset
+            )
+            self.val_dataset_parsed = self._parse_dataset_selection(self.val_dataset)
+        elif stage == "predict":
+            assert (
+                self.predict_dataset is not None
+            ), "You need to provide a value for --data.predict_dataset"
+            self.predict_dataset_parsed = self._parse_dataset_selection(
+                self.predict_dataset
+            )
+        elif stage == "test":
+            assert (
+                self.test_dataset is not None
+            ), "You need to provide a value for --data.test_dataset"
+            self.test_dataset_parsed = self._parse_dataset_selection(self.test_dataset)
+        elif stage == "validate":
+            assert (
+                self.val_dataset is not None
+            ), "You need to provide a value for --data.val_dataset"
+            self.val_dataset_parsed = self._parse_dataset_selection(self.val_dataset)
+
+    def predict_dataloader(self):
+        return super().predict_dataloader()
+
+    def test_dataloader(self):
+        dataset_ids = [self.test_dataset]
+        if "sintel" in dataset_ids:
+            dataset_ids.remove("sintel")
+            dataset_ids.extend(["sintel-clean", "sintel-final"])
+        elif "spring" in dataset_ids:
+            dataset_ids.append("spring-revonly")
+
+        dataloaders = []
+        self.test_dataloader_names = []
+        for dataset_id in dataset_ids:
+            dataset_id += "-test"
+            dataset_tokens = dataset_id.split("-")
+            dataset = getattr(self, f"_get_{dataset_tokens[0]}_dataset")(
+                False, *dataset_tokens[1:]
+            )
+            dataloaders.append(
+                DataLoader(
+                    dataset,
+                    1,
+                    shuffle=False,
+                    num_workers=1,
+                    pin_memory=False,
+                    drop_last=False,
+                )
+            )
+
+            self.test_dataloader_names.append(dataset_id)
+
+        return dataloaders
+
+    def train_dataloader(self):
+        if self.train_dataset_parsed is not None:
+            train_dataset = None
+            for parsed_vals in self.train_dataset_parsed:
+                multiplier = parsed_vals[0]
+                dataset_name = parsed_vals[1]
+                dataset = getattr(self, f"_get_{dataset_name}_dataset")(
+                    True, *parsed_vals[2:]
+                )
+                dataset_mult = dataset
+                for _ in range(multiplier - 1):
+                    dataset_mult += dataset
+
+                if train_dataset is None:
+                    train_dataset = dataset_mult
+                else:
+                    train_dataset += dataset_mult
+
+            train_pin_memory = False if self.train_transform_cuda else True
+            train_dataloader = DataLoader(
+                train_dataset,
+                self.train_batch_size,
+                shuffle=True,
+                num_workers=self.train_num_workers,
+                pin_memory=train_pin_memory,
+                drop_last=False,
+                persistent_workers=self.train_transform_cuda,
+            )
+            self.train_dataloader_length = len(train_dataloader)
+            return train_dataloader
+
+    def val_dataloader(self):
+        dataloaders = []
+        self.val_dataloader_names = []
+        self.val_dataloader_lengths = []
+        for parsed_vals in self.val_dataset_parsed:
+            dataset_name = parsed_vals[1]
+            dataset = getattr(self, f"_get_{dataset_name}_dataset")(
+                False, *parsed_vals[2:]
+            )
+            dataloaders.append(
+                DataLoader(
+                    dataset,
+                    1,
+                    shuffle=False,
+                    num_workers=1,
+                    pin_memory=False,
+                    drop_last=False,
+                    persistent_workers=self.train_transform_cuda,
+                )
+            )
+
+            self.val_dataloader_names.append("-".join(parsed_vals[1:]))
+            self.val_dataloader_lengths.append(len(dataset))
+
+        return dataloaders
+
+    def _load_dataset_paths(self):
+        with open(self.dataset_config_path, "r") as f:
+            dataset_paths = yaml.safe_load(f)
+        for name, path in dataset_paths.items():
+            if getattr(self, f"{name}_root_dir", None) is None:
+                setattr(self, f"{name}_root_dir", path)
+
+    def _parse_dataset_selection(
+        self,
+        dataset_selection: str,
+    ) -> list[tuple]:
+        """Parse the input string into the selected datasets, their multipliers, and parameters.
+
+        For example, 'chairs-train+3*sintel-clean-trainval+kitti-2012-train*5' will be parsed into
+        [(1, 'chairs', 'train'), (3, 'sintel', 'clean', 'trainval'), (5, 'kitti', '2012', 'train')].
+
+        Parameters
+        ----------
+        dataset_selection : str
+            The string defining the dataset selection. Each dataset is separated by a '+' sign. The multiplier must be either
+            at the beginning or at the end of one dataset string, connected by a '*' sign. The remaining content must be separated
+            by '-' symbols.
+
+        Returns
+        -------
+        list[tuple]
+            The parsed choice of datasets and their number of repetitions.
+
+        Raises
+        ------
+        ValueError
+            If the given string is invalid.
+        """
+        if dataset_selection is None:
+            return []
+
+        dataset_selection = dataset_selection.replace(" ", "")
+        datasets = dataset_selection.split("+")
+        for i in range(len(datasets)):
+            tokens = datasets[i].split("*")
+            if len(tokens) == 1:
+                datasets[i] = (1,) + tuple(tokens[0].split("-"))
+            elif len(tokens) == 2:
+                try:
+                    mult, params = int(tokens[0]), tokens[1]
+                except ValueError:
+                    # The multiplier may also come last, e.g. 'kitti-2012-train*5'.
+                    params, mult = tokens[0], int(tokens[1])
+                datasets[i] = (mult,) + tuple(params.split("-"))
+            else:
+                raise ValueError(
+                    f"The specified dataset string {datasets[i]} is invalid. Check the "
+                    "FlowDataModule._parse_dataset_selection() documentation to see how "
+                    "to write a valid selection string."
+                )
+        return datasets
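+
+    # For example, the multiplier may come before or after the dataset string, so
+    # both of the following are parsed to the same result:
+    #     _parse_dataset_selection("3*sintel-clean")  ->  [(3, "sintel", "clean")]
+    #     _parse_dataset_selection("sintel-clean*3")  ->  [(3, "sintel", "clean")]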
+
+    def _get_model_output_stride(self):
+        if hasattr(self, "trainer") and self.trainer is not None:
+            if hasattr(self.trainer.model, "module"):
+                return self.trainer.model.module.output_stride
+            else:
+                return self.trainer.model.output_stride
+        else:
+            return 1
+
+    ###########################################################################
+    # _get_datasets
+    ###########################################################################
+
+    def _get_autoflow_dataset(self, is_train: bool, *args: str) -> Dataset:
+        device = "cuda" if self.train_transform_cuda else "cpu"
+        md = make_divisible
+
+        fbocc_transform = False
+        split = "trainval"
+        for v in args:
+            if v in ["train", "val", "trainval"]:
+                split = v
+            elif v == "fbocc":
+                fbocc_transform = True
+            else:
+                raise ValueError(f"Invalid arg: {v}")
+
+        if is_train:
+            if self.train_crop_size is None:
+                cy, cx = (
+                    md(368, self._get_model_output_stride()),
+                    md(496, self._get_model_output_stride()),
+                )
+                self.train_crop_size = (cy, cx)
+                logger.warning(
+                    "--train_crop_size is not set. It will be set as ({}, {}).", cy, cx
+                )
+            else:
+                cy, cx = (
+                    md(self.train_crop_size[0], self._get_model_output_stride()),
+                    md(self.train_crop_size[1], self._get_model_output_stride()),
+                )
+
+            # These transforms are based on RAFT: https://github.com/princeton-vl/RAFT
+            transform = ft.Compose(
+                [
+                    ft.ToTensor(device=device, fp16=self.train_transform_fp16),
+                    ft.RandomScaleAndCrop(
+                        (cy, cx),
+                        (-0.1, 1.0),
+                        (-0.2, 0.2),
+                    ),
+                    ft.ColorJitter(0.4, 0.4, 0.4, 0.5 / 3.14, 0.2),
+                    ft.GaussianNoise(0.02),
+                    ft.RandomPatchEraser(
+                        0.5, (int(1), int(3)), (int(50), int(100)), "mean"
+                    ),
+                    ft.RandomFlip(min(0.5, 0.5), min(0.1, 0.5)),
+                    (
+                        ft.GenerateFBCheckFlowOcclusion(threshold=1)
+                        if fbocc_transform
+                        else None
+                    ),
+                ]
+            )
+        else:
+            transform = ft.ToTensor()
+
+        dataset = AutoFlowDataset(
+            self.autoflow_root_dir, split=split, transform=transform
+        )
+        return dataset
+
+    def _get_chairs_dataset(self, is_train: bool, *args: str) -> Dataset:
+        device = "cuda" if self.train_transform_cuda else "cpu"
+        md = make_divisible
+
+        fbocc_transform = False
+        split = "trainval"
+        for v in args:
+            if v in ["train", "val", "trainval"]:
+                split = v
+            elif v == "fbocc":
+                fbocc_transform = True
+            else:
+                raise ValueError(f"Invalid arg: {v}")
+
+        if is_train:
+            if self.train_crop_size is None:
+                cy, cx = (
+                    md(368, self._get_model_output_stride()),
+                    md(496, self._get_model_output_stride()),
+                )
+                self.train_crop_size = (cy, cx)
+                logger.warning(
+                    "--train_crop_size is not set. It will be set as ({}, {}).", cy, cx
+                )
+            else:
+                cy, cx = (
+                    md(self.train_crop_size[0], self._get_model_output_stride()),
+                    md(self.train_crop_size[1], self._get_model_output_stride()),
+                )
+
+            # These transforms are based on RAFT: https://github.com/princeton-vl/RAFT
+            transform = ft.Compose(
+                [
+                    ft.ToTensor(device=device, fp16=self.train_transform_fp16),
+                    ft.RandomScaleAndCrop(
+                        (cy, cx),
+                        (-0.1, 1.0),
+                        (-0.2, 0.2),
+                    ),
+                    ft.ColorJitter(0.4, 0.4, 0.4, 0.5 / 3.14, 0.2),
+                    ft.GaussianNoise(0.02),
+                    ft.RandomPatchEraser(
+                        0.5, (int(1), int(3)), (int(50), int(100)), "mean"
+                    ),
+                    ft.RandomFlip(min(0.5, 0.5), min(0.1, 0.5)),
+                    (
+                        ft.GenerateFBCheckFlowOcclusion(threshold=1)
+                        if fbocc_transform
+                        else None
+                    ),
+                ]
+            )
+        else:
+            transform = ft.ToTensor()
+
+        dataset = FlyingChairsDataset(
+            self.flying_chairs_root_dir, split=split, transform=transform
+        )
+        return dataset
+
+    def _get_chairs2_dataset(self, is_train: bool, *args: str) -> Dataset:
+        device = "cuda" if self.train_transform_cuda else "cpu"
+        md = make_divisible
+
+        split = "trainval"
+        add_reverse = False
+        get_occlusion_mask = False
+        get_motion_boundary_mask = False
+        get_backward = False
+        fbocc_transform = False
+        for v in args:
+            if v in ["train", "val", "trainval"]:
+                split = v
+            elif v == "rev":
+                add_reverse = True
+            elif v == "occ":
+                get_occlusion_mask = True
+            elif v == "mb":
+                get_motion_boundary_mask = True
+            elif v == "back":
+                get_backward = True
+            elif v == "fbocc":
+                fbocc_transform = True
+            else:
+                raise ValueError(f"Invalid arg: {v}")
+
+        if is_train:
+            if self.train_crop_size is None:
+                cy, cx = (
+                    md(368, self._get_model_output_stride()),
+                    md(496, self._get_model_output_stride()),
+                )
+                self.train_crop_size = (cy, cx)
+                logger.warning(
+                    "--train_crop_size is not set. It will be set as ({}, {}).", cy, cx
+                )
+            else:
+                cy, cx = (
+                    md(self.train_crop_size[0], self._get_model_output_stride()),
+                    md(self.train_crop_size[1], self._get_model_output_stride()),
+                )
+
+            # These transforms are based on RAFT: https://github.com/princeton-vl/RAFT
+            transform = ft.Compose(
+                [
+                    ft.ToTensor(device=device, fp16=self.train_transform_fp16),
+                    ft.RandomScaleAndCrop(
+                        (cy, cx),
+                        (-0.1, 1.0),
+                        (-0.2, 0.2),
+                    ),
+                    ft.ColorJitter(0.4, 0.4, 0.4, 0.5 / 3.14, 0.2),
+                    ft.GaussianNoise(0.02),
+                    ft.RandomPatchEraser(
+                        0.5, (int(1), int(3)), (int(50), int(100)), "mean"
+                    ),
+                    ft.RandomFlip(min(0.5, 0.5), min(0.1, 0.5)),
+                    (
+                        ft.GenerateFBCheckFlowOcclusion(threshold=1)
+                        if fbocc_transform
+                        else None
+                    ),
+                ]
+            )
+        else:
+            transform = ft.ToTensor()
+
+        dataset = FlyingChairs2Dataset(
+            self.flying_chairs2_root_dir,
+            split=split,
+            transform=transform,
+            add_reverse=add_reverse,
+            get_occlusion_mask=get_occlusion_mask,
+            get_motion_boundary_mask=get_motion_boundary_mask,
+            get_backward=get_backward,
+        )
+        return dataset
+
+    def _get_hd1k_dataset(self, is_train: bool, *args: str) -> Dataset:
+        device = "cuda" if self.train_transform_cuda else "cpu"
+        md = make_divisible
+
+        split = "trainval"
+        sequence_length = 2
+        sequence_position = "first"
+        fbocc_transform = False
+        for v in args:
+            if v in ["train", "val", "trainval", "test"]:
+                split = v
+            elif v.startswith("seqlen"):
+                sequence_length = int(v.split("_")[1])
+            elif v.startswith("seqpos"):
+                sequence_position = v.split("_")[1]
+            elif v == "fbocc":
+                fbocc_transform = True
+            else:
+                raise ValueError(f"Invalid arg: {v}")
+
+        if is_train:
+            if self.train_crop_size is None:
+                cy, cx = (
+                    md(368, self._get_model_output_stride()),
+                    md(768, self._get_model_output_stride()),
+                )
+                self.train_crop_size = (cy, cx)
+                logger.warning(
+                    "--train_crop_size is not set. 
It will be set as ({}, {}).", cy, cx + ) + else: + cy, cx = ( + md(self.train_crop_size[0], self._get_model_output_stride()), + md(self.train_crop_size[1], self._get_model_output_stride()), + ) + + # These transforms are based on RAFT: https://github.com/princeton-vl/RAFT + transform = ft.Compose( + [ + ft.ToTensor(device=device, fp16=self.train_transform_fp16), + ft.RandomScaleAndCrop( + (cy, cx), (-0.5, 0.2), (-0.2, 0.2), sparse=True + ), + ft.ColorJitter(0.4, 0.4, 0.4, 0.5 / 3.14, 0.2), + ft.GaussianNoise(0.02), + ft.RandomPatchEraser( + 0.5, (int(1), int(3)), (int(50), int(100)), "mean" + ), + ft.RandomFlip(min(0.5, 0.5), min(0.1, 0.5)), + ( + ft.GenerateFBCheckFlowOcclusion(threshold=1) + if fbocc_transform + else None + ), + ] + ) + else: + transform = ft.ToTensor() + + dataset = Hd1kDataset( + self.hd1k_root_dir, + split=split, + transform=transform, + sequence_length=sequence_length, + sequence_position=sequence_position, + ) + return dataset + + def _get_kitti_dataset(self, is_train: bool, *args: str) -> Dataset: + device = "cuda" if self.train_transform_cuda else "cpu" + md = make_divisible + + versions = ["2012", "2015"] + split = "trainval" + get_occlusion_mask = False + fbocc_transform = False + for v in args: + if v in ["2012", "2015"]: + versions = [v] + elif v in ["train", "val", "trainval", "test"]: + split = v + elif v == "occ": + get_occlusion_mask = True + elif v == "fbocc": + fbocc_transform = True + else: + raise ValueError(f"Invalid arg: {v}") + + if is_train: + if self.train_crop_size is None: + cy, cx = ( + md(288, self._get_model_output_stride()), + md(960, self._get_model_output_stride()), + ) + # cy, cx = (md(416, self._get_model_output_stride()), md(960, self._get_model_output_stride())) + self.train_crop_size = (cy, cx) + logger.warning( + "--train_crop_size is not set. 
It will be set as ({}, {}).", cy, cx + ) + else: + cy, cx = ( + md(self.train_crop_size[0], self._get_model_output_stride()), + md(self.train_crop_size[1], self._get_model_output_stride()), + ) + + # These transforms are based on RAFT: https://github.com/princeton-vl/RAFT + transform = ft.Compose( + [ + ft.ToTensor(device=device, fp16=self.train_transform_fp16), + ft.RandomScaleAndCrop( + (cy, cx), (-0.2, 0.4), (-0.2, 0.2), sparse=True + ), + ft.ColorJitter(0.4, 0.4, 0.4, 0.5 / 3.14, 0.2), + ft.GaussianNoise(0.02), + ft.RandomPatchEraser( + 0.5, (int(1), int(3)), (int(50), int(100)), "mean" + ), + ( + ft.GenerateFBCheckFlowOcclusion(threshold=1) + if fbocc_transform + else None + ), + ] + ) + else: + transform = ft.ToTensor() + + dataset = KittiDataset( + self.kitti_2012_root_dir, + self.kitti_2015_root_dir, + get_occlusion_mask=get_occlusion_mask, + versions=versions, + split=split, + transform=transform, + ) + return dataset + + def _get_kubric_dataset(self, is_train: bool, *args: str) -> Dataset: + if is_train: + raise NotImplementedError() + else: + transform = ft.ToTensor() + + get_backward = False + sequence_length = 2 + sequence_position = "first" + max_seq = None + for v in args: + if v == "back": + get_backward = True + elif v.startswith("seqlen"): + sequence_length = int(v.split("_")[1]) + elif v.startswith("seqpos"): + sequence_position = v.split("_")[1] + elif v.startswith("maxseq"): + max_seq = int(v.split("_")[1]) + + dataset = KubricDataset( + self.kubric_root_dir, + transform=transform, + get_backward=get_backward, + sequence_length=sequence_length, + sequence_position=sequence_position, + max_seq=max_seq, + ) + return dataset + + def _get_sintel_dataset(self, is_train: bool, *args: str) -> Dataset: + device = "cuda" if self.train_transform_cuda else "cpu" + md = make_divisible + + pass_names = ["clean", "final"] + split = "trainval" + get_occlusion_mask = False + sequence_length = 2 + sequence_position = "first" + fbocc_transform = False + for v in args: + if v in ["clean", "final"]: + pass_names = [v] + elif v in ["train", "val", "trainval", "test"]: + split = v + elif v == "occ": + get_occlusion_mask = True + elif v.startswith("seqlen"): + sequence_length = int(v.split("_")[1]) + elif v.startswith("seqpos"): + sequence_position = v.split("_")[1] + elif v == "fbocc": + fbocc_transform = True + else: + raise ValueError(f"Invalid arg: {v}") + + if is_train: + if self.train_crop_size is None: + cy, cx = ( + md(368, self._get_model_output_stride()), + md(768, self._get_model_output_stride()), + ) + self.train_crop_size = (cy, cx) + logger.warning( + "--train_crop_size is not set. 
It will be set as ({}, {}).", cy, cx + ) + else: + cy, cx = ( + md(self.train_crop_size[0], self._get_model_output_stride()), + md(self.train_crop_size[1], self._get_model_output_stride()), + ) + + # These transforms are based on RAFT: https://github.com/princeton-vl/RAFT + transform = ft.Compose( + [ + ft.ToTensor(device=device, fp16=self.train_transform_fp16), + ft.RandomScaleAndCrop((cy, cx), (-0.2, 0.6), (-0.2, 0.2)), + ft.ColorJitter(0.4, 0.4, 0.4, 0.5 / 3.14, 0.2), + ft.GaussianNoise(0.02), + ft.RandomPatchEraser( + 0.5, (int(1), int(3)), (int(50), int(100)), "mean" + ), + ft.RandomFlip(min(0.5, 0.5), min(0.1, 0.5)), + ( + ft.GenerateFBCheckFlowOcclusion(threshold=1) + if fbocc_transform + else None + ), + ] + ) + else: + transform = ft.ToTensor() + + dataset = SintelDataset( + self.mpi_sintel_root_dir, + split=split, + pass_names=pass_names, + transform=transform, + get_occlusion_mask=get_occlusion_mask, + sequence_length=sequence_length, + sequence_position=sequence_position, + ) + return dataset + + def _get_sintel_finetune_dataset(self, is_train: bool, *args: str) -> Dataset: + device = "cuda" if self.train_transform_cuda else "cpu" + md = make_divisible + + fbocc_transform = False + searaft_split = False + for v in args: + if v == "fbocc": + fbocc_transform = True + elif v == "searaft_split": + searaft_split = True + else: + raise ValueError(f"Invalid arg: {v}") + + if is_train: + if self.train_crop_size is None: + cy, cx = ( + md(368, self._get_model_output_stride()), + md(768, self._get_model_output_stride()), + ) + # cy, cx = (md(416, self._get_model_output_stride()), md(960, self._get_model_output_stride())) + self.train_crop_size = (cy, cx) + logger.warning( + "--train_crop_size is not set. It will be set as ({}, {}).", cy, cx + ) + else: + cy, cx = ( + md(self.train_crop_size[0], self._get_model_output_stride()), + md(self.train_crop_size[1], self._get_model_output_stride()), + ) + + # These transforms are based on RAFT: https://github.com/princeton-vl/RAFT + transform1 = ft.Compose( + [ + ft.ToTensor(device=device, fp16=self.train_transform_fp16), + ft.RandomScaleAndCrop((cy, cx), (-0.2, 0.6), (-0.2, 0.2)), + ft.ColorJitter(0.4, 0.4, 0.4, 0.5 / 3.14, 0.2), + ft.RandomPatchEraser( + 0.5, (int(1), int(3)), (int(50), int(100)), "mean" + ), + ft.RandomFlip(min(0.5, 0.5), min(0.1, 0.5)), + ( + ft.GenerateFBCheckFlowOcclusion(threshold=1) + if fbocc_transform + else None + ), + ] + ) + things_dataset = FlyingThings3DDataset( + self.flying_things3d_root_dir, + split="train", + pass_names=["clean"], + side_names=["left"], + transform=transform1, + get_backward=False, + get_motion_boundary_mask=False, + get_occlusion_mask=False, + ) + + sintel_clean_dataset = SintelDataset( + self.mpi_sintel_root_dir, + split="trainval", + pass_names=["clean"], + transform=transform1, + get_occlusion_mask=False, + ) + sintel_clean_mult_dataset = sintel_clean_dataset + for _ in range(19 if searaft_split else 99): + sintel_clean_mult_dataset += sintel_clean_dataset + + sintel_final_dataset = SintelDataset( + self.mpi_sintel_root_dir, + split="trainval", + pass_names=["final"], + transform=transform1, + get_occlusion_mask=False, + ) + sintel_final_mult_dataset = sintel_final_dataset + for _ in range(19 if searaft_split else 99): + sintel_final_mult_dataset += sintel_final_dataset + + transform2 = ft.Compose( + [ + ft.ToTensor(device=device, fp16=self.train_transform_fp16), + ft.RandomScaleAndCrop( + (cy, cx), (-0.3, 0.5), (-0.2, 0.2), sparse=True + ), + ft.ColorJitter(0.4, 0.4, 0.4, 0.5 / 3.14, 0.2), 
+ ft.RandomPatchEraser( + 0.5, (int(1), int(3)), (int(50), int(100)), "mean" + ), + ft.RandomFlip(min(0.5, 0.5), min(0.1, 0.5)), + ( + ft.GenerateFBCheckFlowOcclusion(threshold=1) + if fbocc_transform + else None + ), + ] + ) + kitti_dataset = KittiDataset( + root_dir_2015=self.kitti_2015_root_dir, + split="trainval", + versions=["2015"], + transform=transform2, + get_occlusion_mask=False, + ) + kitti_mult_dataset = kitti_dataset + for _ in range(79 if searaft_split else 199): + kitti_mult_dataset += kitti_dataset + + transform3 = ft.Compose( + [ + ft.ToTensor(device=device, fp16=self.train_transform_fp16), + ft.RandomScaleAndCrop( + (cy, cx), (-0.5, 0.2), (-0.2, 0.2), sparse=True + ), + ft.ColorJitter(0.4, 0.4, 0.4, 0.5 / 3.14, 0.2), + ft.RandomPatchEraser( + 0.5, (int(1), int(3)), (int(50), int(100)), "mean" + ), + ft.RandomFlip(min(0.5, 0.5), min(0.1, 0.5)), + ( + ft.GenerateFBCheckFlowOcclusion(threshold=1) + if fbocc_transform + else None + ), + ] + ) + hd1k_dataset = Hd1kDataset( + self.hd1k_root_dir, split="trainval", transform=transform3 + ) + hd1k_mult_dataset = hd1k_dataset + for _ in range(29 if searaft_split else 4): + hd1k_mult_dataset += hd1k_dataset + + mixed_dataset = ( + things_dataset + + sintel_clean_mult_dataset + + sintel_final_mult_dataset + + kitti_mult_dataset + + hd1k_mult_dataset + ) + + logger.info("Loaded datasets:") + logger.info( + "FlyingThings3D: unique samples {} - multiplied samples {}", + len(things_dataset), + len(things_dataset), + ) + logger.info( + "Sintel clean: unique samples {} - multiplied samples {}", + len(sintel_clean_dataset), + len(sintel_clean_mult_dataset), + ) + logger.info( + "Sintel final: unique samples {} - multiplied samples {}", + len(sintel_final_dataset), + len(sintel_final_mult_dataset), + ) + logger.info( + "KITTI 2015: unique samples {} - multiplied samples {}", + len(kitti_dataset), + len(kitti_mult_dataset), + ) + logger.info( + "HD1K: unique samples {} - multiplied samples {}", + len(hd1k_dataset), + len(hd1k_mult_dataset), + ) + logger.info("Total dataset size: {}", len(mixed_dataset)) + else: + raise NotImplementedError() + + return mixed_dataset + + def _get_spring_dataset(self, is_train: bool, *args: str) -> Dataset: + device = "cuda" if self.train_transform_cuda else "cpu" + md = make_divisible + + split = "train" + add_reverse = False + get_backward = False + sequence_length = 2 + sequence_position = "first" + reverse_only = False + subsample = False + side_names = [] + fbocc_transform = False + for v in args: + if v in ["train", "val", "trainval", "test"]: + split = v + elif v == "rev": + add_reverse = True + elif v == "revonly": + reverse_only = True + elif v == "back": + get_backward = True + elif v.startswith("seqlen"): + sequence_length = int(v.split("_")[1]) + elif v.startswith("seqpos"): + sequence_position = v.split("_")[1] + elif v == "sub": + subsample = True + elif v == "left": + side_names.append("left") + elif v == "right": + side_names.append("right") + elif v == "fbocc": + fbocc_transform = True + else: + raise ValueError(f"Invalid arg: {v}") + + if is_train: + if self.train_crop_size is None: + cy, cx = ( + md(540, self._get_model_output_stride()), + md(960, self._get_model_output_stride()), + ) + self.train_crop_size = (cy, cx) + logger.warning( + "--train_crop_size is not set. 
It will be set as ({}, {}).", cy, cx + ) + else: + cy, cx = ( + md(self.train_crop_size[0], self._get_model_output_stride()), + md(self.train_crop_size[1], self._get_model_output_stride()), + ) + + # Transforms copied from SEA-RAFT + transform = ft.Compose( + [ + ft.ToTensor(device=device, fp16=self.train_transform_fp16), + ft.RandomScaleAndCrop((cy, cx), (0.0, 0.2), (-0.2, 0.2)), + ft.ColorJitter(0.4, 0.4, 0.4, 0.5 / 3.14, 0.2), + ft.GaussianNoise(0.02), + ft.RandomPatchEraser( + 0.5, (int(1), int(3)), (int(50), int(100)), "mean" + ), + ft.RandomFlip(min(0.5, 0.5), min(0.1, 0.5)), + ( + ft.GenerateFBCheckFlowOcclusion(threshold=1) + if fbocc_transform + else None + ), + ] + ) + else: + transform = ft.ToTensor() + + if len(side_names) == 0: + side_names = ["left", "right"] + + dataset = SpringDataset( + self.spring_root_dir, + split=split, + side_names=side_names, + add_reverse=add_reverse, + transform=transform, + get_backward=get_backward, + sequence_length=sequence_length, + sequence_position=sequence_position, + reverse_only=reverse_only, + subsample=subsample, + ) + return dataset + + def _get_tartanair_dataset(self, is_train: bool, *args: str) -> Dataset: + device = "cuda" if self.train_transform_cuda else "cpu" + md = make_divisible + + get_occlusion_mask = False + sequence_length = 2 + sequence_position = "first" + difficulties = [] + fbocc_transform = False + for v in args: + if v in ["easy", "hard"]: + difficulties.append(v) + elif v == "occ": + get_occlusion_mask = True + elif v.startswith("seqlen"): + sequence_length = int(v.split("_")[1]) + elif v.startswith("seqpos"): + sequence_position = v.split("_")[1] + elif v == "fbocc": + fbocc_transform = True + else: + raise ValueError(f"Invalid arg: {v}") + + if len(difficulties) == 0: + difficulties = ["easy"] + + if is_train: + if self.train_crop_size is None: + cy, cx = ( + md(360, self._get_model_output_stride()), + md(480, self._get_model_output_stride()), + ) + self.train_crop_size = (cy, cx) + logger.warning( + "--train_crop_size is not set. 
It will be set as ({}, {}).", cy, cx + ) + else: + cy, cx = ( + md(self.train_crop_size[0], self._get_model_output_stride()), + md(self.train_crop_size[1], self._get_model_output_stride()), + ) + + # These transforms are based on RAFT: https://github.com/princeton-vl/RAFT + transform = ft.Compose( + [ + ft.ToTensor(device=device, fp16=self.train_transform_fp16), + ft.RandomScaleAndCrop((cy, cx), (-0.4, 0.8), (-0.2, 0.2)), + ft.ColorJitter(0.4, 0.4, 0.4, 0.5 / 3.14, 0.2), + ft.GaussianNoise(0.02), + ft.RandomPatchEraser( + 0.5, (int(1), int(3)), (int(50), int(100)), "mean" + ), + ft.RandomFlip(min(0.5, 0.5), min(0.1, 0.5)), + ( + ft.GenerateFBCheckFlowOcclusion(threshold=1) + if fbocc_transform + else None + ), + ] + ) + else: + transform = ft.ToTensor() + + dataset = TartanAirDataset( + self.tartanair_root_dir, + difficulties=difficulties, + transform=transform, + get_occlusion_mask=get_occlusion_mask, + sequence_length=sequence_length, + sequence_position=sequence_position, + ) + return dataset + + def _get_things_dataset(self, is_train: bool, *args: str) -> Dataset: + device = "cuda" if self.train_transform_cuda else "cpu" + md = make_divisible + + pass_names = ["clean", "final"] + split = "trainval" + is_subset = False + add_reverse = False + get_occlusion_mask = False + get_motion_boundary_mask = False + get_backward = False + sequence_length = 2 + sequence_position = "first" + sintel_transform = False + fbocc_transform = False + for v in args: + if v in ["clean", "final"]: + pass_names = [v] + elif v in ["train", "val", "trainval"]: + split = v + elif v == "subset": + is_subset = True + elif v == "rev": + add_reverse = True + elif v == "occ": + get_occlusion_mask = True + elif v == "mb": + get_motion_boundary_mask = True + elif v == "back": + get_backward = True + elif v.startswith("seqlen"): + sequence_length = int(v.split("_")[1]) + elif v.startswith("seqpos"): + sequence_position = v.split("_")[1] + elif v == "sinteltransform": + sintel_transform = True + elif v == "fbocc": + fbocc_transform = True + else: + raise ValueError(f"Invalid arg: {v}") + + if is_train: + if self.train_crop_size is None: + cy, cx = ( + md(400, self._get_model_output_stride()), + md(720, self._get_model_output_stride()), + ) + # cy, cx = (md(416, self._get_model_output_stride()), md(960, self._get_model_output_stride())) + self.train_crop_size = (cy, cx) + logger.warning( + "--train_crop_size is not set. 
It will be set as ({}, {}).", cy, cx + ) + else: + cy, cx = ( + md(self.train_crop_size[0], self._get_model_output_stride()), + md(self.train_crop_size[1], self._get_model_output_stride()), + ) + + # These transforms are based on RAFT: https://github.com/princeton-vl/RAFT + if sintel_transform: + major_scale = (-0.2, 0.6) + else: + major_scale = (-0.4, 0.8) + transform = ft.Compose( + [ + ft.ToTensor(device=device, fp16=self.train_transform_fp16), + ft.RandomScaleAndCrop((cy, cx), major_scale, (-0.2, 0.2)), + ft.ColorJitter(0.4, 0.4, 0.4, 0.5 / 3.14, 0.2), + ft.GaussianNoise(0.02), + ft.RandomPatchEraser( + 0.5, (int(1), int(3)), (int(50), int(100)), "mean" + ), + ft.RandomFlip(min(0.5, 0.5), min(0.1, 0.5)), + ( + ft.GenerateFBCheckFlowOcclusion(threshold=1) + if fbocc_transform + else None + ), + ] + ) + else: + transform = ft.ToTensor() + + if is_subset: + dataset = FlyingThings3DSubsetDataset( + self.flying_things3d_subset_root_dir, + split=split, + pass_names=pass_names, + side_names=["left", "right"], + add_reverse=add_reverse, + transform=transform, + get_occlusion_mask=get_occlusion_mask, + get_motion_boundary_mask=get_motion_boundary_mask, + get_backward=get_backward, + sequence_length=sequence_length, + sequence_position=sequence_position, + ) + else: + dataset = FlyingThings3DDataset( + self.flying_things3d_root_dir, + split=split, + pass_names=pass_names, + side_names=["left", "right"], + add_reverse=add_reverse, + transform=transform, + get_occlusion_mask=get_occlusion_mask, + get_motion_boundary_mask=get_motion_boundary_mask, + get_backward=get_backward, + sequence_length=sequence_length, + sequence_position=sequence_position, + ) + return dataset + + def _get_middlebury_st_dataset(self, is_train: bool, *args: str) -> Dataset: + assert not is_train + transform = ft.ToTensor() + + dataset = MiddleburySTDataset( + self.middlebury_st_root_dir, + transform=transform, + ) + return dataset + + def _get_viper_dataset(self, is_train: bool, *args: str) -> Dataset: + assert not is_train + transform = ft.ToTensor() + + dataset = ViperDataset( + self.viper_root_dir, + split="val", + transform=transform, + ) + return dataset + + def _get_overfit_dataset(self, is_train: bool, *args: str) -> Dataset: + md = make_divisible + if self.train_crop_size is None: + cy, cx = ( + md(436, self._get_model_output_stride()), + md(1024, self._get_model_output_stride()), + ) + self.train_crop_size = (cy, cx) + logger.warning( + "--train_crop_size is not set. 
It will be set as ({}, {}).", cy, cx
+            )
+        else:
+            cy, cx = (
+                md(self.train_crop_size[0], self._get_model_output_stride()),
+                md(self.train_crop_size[1], self._get_model_output_stride()),
+            )
+        transform = ft.Compose([ft.ToTensor(), ft.Resize((cy, cx))])
+
+        dataset_name = "sintel"
+        if len(args) > 0 and args[0] in ["chairs2"]:
+            dataset_name = args[0]
+
+        if dataset_name == "sintel":
+            dataset = SintelDataset(
+                self.mpi_sintel_root_dir,
+                split="trainval",
+                pass_names="clean",
+                transform=transform,
+                get_occlusion_mask=False,
+            )
+        elif dataset_name == "chairs2":
+            dataset = FlyingChairs2Dataset(
+                self.flying_chairs2_root_dir,
+                split="trainval",
+                transform=transform,
+                add_reverse=False,
+                get_occlusion_mask=True,
+                get_motion_boundary_mask=True,
+                get_backward=True,
+            )
+
+        dataset.img_paths = dataset.img_paths[:1]
+        dataset.flow_paths = dataset.flow_paths[:1]
+        dataset.occ_paths = dataset.occ_paths[:1]
+        dataset.mb_paths = dataset.mb_paths[:1]
+        dataset.flow_b_paths = dataset.flow_b_paths[:1]
+        dataset.occ_b_paths = dataset.occ_b_paths[:1]
+        dataset.mb_b_paths = dataset.mb_b_paths[:1]
+        dataset.metadata = dataset.metadata[:1]
+
+        return dataset
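For reference, a minimal usage sketch of the datamodule defined above; the selection strings and values are illustrative, and datasets.yaml is expected to map dataset names (e.g. mpi_sintel) to their root directories:

    from ptlflow.data.flow_datamodule import FlowDataModule

    dm = FlowDataModule(
        train_dataset="chairs-train+2*sintel-clean-trainval",  # '+' joins datasets, '*' repeats them
        val_dataset="sintel-clean-val",
        train_batch_size=4,
        train_crop_size=(368, 496),
        dataset_config_path="./datasets.yaml",
    )
    dm.setup("fit")  # resolves dataset paths and parses the selection strings
    train_loader = dm.train_dataloader()
    val_loaders = dm.val_dataloader()  # one DataLoader per selected validation set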
+ """ + self.threshold = threshold + self.forward_flow_key = forward_flow_key + self.backward_flow_key = backward_flow_key + self.forward_occlusion_key = forward_occlusion_key + self.backward_occlusion_key = backward_occlusion_key + self.compute_backward_occlusion = compute_backward_occlusion + + def __call__(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: + """Perform the transformation on the inputs. + + Parameters + ---------- + inputs : Dict[str, torch.Tensor] + Elements to be transformed. Each element is a 4D tensor NCHW. + + Returns + ------- + Dict[str, torch.Tensor] + The inputs transformed by this operation. + """ + assert self.forward_flow_key in inputs + assert self.backward_flow_key in inputs + flow_f = inputs[self.forward_flow_key] + flow_b = inputs[self.backward_flow_key] + b, c, h, w = flow_f.shape + + coords = self._coords_grid(b, h, w, dtype=flow_f.dtype, device=flow_f.device) + flow_b_warped, oob_mask_f = self._bilinear_sampler(flow_b, coords + flow_f) + diff_f = torch.norm(flow_f + flow_b_warped, p=2, dim=1, keepdim=True) + inputs[self.forward_occlusion_key] = ( + (diff_f > self.threshold) | oob_mask_f + ).float() + + if self.compute_backward_occlusion: + flow_f_warped, oob_mask_b = self._bilinear_sampler(flow_f, coords + flow_b) + diff_b = torch.norm(flow_b + flow_f_warped, p=2, dim=1, keepdim=True) + inputs[self.backward_occlusion_key] = ( + (diff_b > self.threshold) | oob_mask_b + ).float() + + return inputs + + def _bilinear_sampler(self, img, coords): + # Code adapted from RAFT: https://github.com/princeton-vl/RAFT + H, W = img.shape[-2:] + xgrid, ygrid = coords.split([1, 1], dim=1) + xgrid = 2 * xgrid / (W - 1) - 1 + ygrid = 2 * ygrid / (H - 1) - 1 + + grid = torch.cat([xgrid, ygrid], dim=1) + img = F.grid_sample(img, grid.permute(0, 2, 3, 1), align_corners=True) + + mask = (xgrid < -1) | (ygrid < -1) | (xgrid > 1) | (ygrid > 1) + return img, mask + + def _coords_grid(self, batch, ht, wd, dtype, device): + # Code adapted from RAFT: https://github.com/princeton-vl/RAFT + coords = torch.meshgrid( + torch.arange(ht, dtype=dtype, device=device), + torch.arange(wd, dtype=dtype, device=device), + indexing="ij", + ) + coords = torch.stack(coords[::-1], dim=0) + return coords[None].repeat(batch, 1, 1, 1) + + +class CenterCrop(object): + """Applies center crop to the inputs.""" + + def __init__( + self, + crop_size: Optional[Tuple[int, int]] = None, + occlusion_keys: Union[KeysView, Sequence[str]] = ("occs", "occs_b"), + valid_key: str = "valids", + ignore_keys: Optional[Sequence[str]] = None, + ) -> None: + """Initialize RandomScaleAndCrop. + + Parameters + ---------- + crop_size : Optional[Tuple[int, int]], optional + If provided, crop the inputs to this size (h, w). + occlusion_keys : Union[KeysView, Sequence[str]], default ['occs', 'occs_b'] + Indicate which of the input keys correspond to occlusion mask tensors. + valid_keys : str, default 'valids' + The name of the key in inputs that contains the binary mask indicating which pixels are valid. + Only used when sparse=True. + """ + self.crop_size = crop_size + self.occlusion_keys = list(occlusion_keys) + self.valid_key = valid_key + self.ignore_keys = ignore_keys + + def __call__( # noqa: C901 + self, inputs: Dict[str, torch.Tensor] + ) -> Dict[str, torch.Tensor]: + """Perform the transformation on the inputs. + + Parameters + ---------- + inputs : Dict[str, torch.Tensor] + Elements to be transformed. Each element is a 4D tensor NCHW. 
+
+ Returns
+ -------
+ Dict[str, torch.Tensor]
+ The inputs transformed by this operation.
+ """
+ h, w = inputs[self.valid_key].shape[2:4]
+ y_crop = (h - self.crop_size[0]) // 2
+ x_crop = (w - self.crop_size[1]) // 2
+
+ for k, v in inputs.items():
+ if self.ignore_keys is None or k not in self.ignore_keys:
+ v = v[
+ :,
+ :,
+ y_crop : y_crop + self.crop_size[0],
+ x_crop : x_crop + self.crop_size[1],
+ ]
+ inputs[k] = v
+
+ # Update occlusion masks: flows that now point outside of the cropped area become occluded
+ for k, v in inputs.items():
+ if self.ignore_keys is None or k not in self.ignore_keys:
+ try:
+ i = self.occlusion_keys.index(k)
+ inputs[k] = _update_oob_flows(v, inputs[self.flow_keys[i]])
+ except ValueError:
+ pass
+ return inputs
+
+
 class ColorJitter(tt.ColorJitter):
 """Randomly apply color transformations only to the images.
@@ -989,6 +1160,7 @@ def __init__(
 flow_keys: Union[KeysView, Sequence[str]] = ("flows", "flows_b"),
 sparse: bool = False,
 valid_key: str = "valids",
+ ignore_keys: Optional[Union[KeysView, Sequence[str]]] = None,
 ) -> None:
 """Initialize Resize.
@@ -1009,6 +1181,8 @@ def __init__(
 valid_keys : str, default 'valids'
 The name of the key in inputs that contains the binary mask indicating which pixels are valid.
 Only used when sparse=True.
+ ignore_keys : Optional[Union[KeysView, Sequence[str]]]
+ If not None, the inputs with these keys are not resized.
 """
 self.size = size
 self.scale = scale
@@ -1016,6 +1190,7 @@ def __init__(
 self.flow_keys = list(flow_keys)
 self.sparse = sparse
 self.valid_key = valid_key
+ self.ignore_keys = ignore_keys
 def __call__(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
 """Perform the transformation on the inputs.
@@ -1041,6 +1216,7 @@ def __call__(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
 self.flow_keys,
 self.sparse,
 self.valid_key,
+ ignore_keys=self.ignore_keys,
 )
 return inputs
@@ -1082,6 +1258,7 @@ def _resize(
 flow_keys: Union[KeysView, Sequence[str]],
 sparse: bool,
 valid_key: str,
+ ignore_keys: Optional[Sequence[str]] = None,
 ):
 """Resize inputs to a target size. Set sparse=True when the valid mask has holes. This ensures that the resized valid mask does not interpolate the valid positions.
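As a quick sanity check of the new GenerateFBCheckFlowOcclusion transform, the minimal sketch below builds a pair of perfectly inverse flows and verifies that only out-of-bounds pixels are flagged. The import path and the dict layout (4D NCHW tensors under the "flows"/"flows_b" keys) are assumptions inferred from this diff, not a documented API:

import torch

from ptlflow.data.flow_transforms import GenerateFBCheckFlowOcclusion

fb_check = GenerateFBCheckFlowOcclusion(threshold=1.0)
inputs = {
    "flows": torch.zeros(1, 2, 64, 64),  # forward flow, NCHW (u, v channels)
    "flows_b": torch.zeros(1, 2, 64, 64),  # backward flow, NCHW
}
inputs["flows"][:, 0] = 5.0  # uniform 5 px shift to the right
inputs["flows_b"][:, 0] = -5.0  # exact inverse, so the FB check is consistent
outputs = fb_check(inputs)
# The masks are zero everywhere except the 5 px band whose warped coordinates
# fall outside the image, since out-of-bounds samples count as occluded.
print(outputs["occs"].shape)  # torch.Size([1, 1, 64, 64])
print(outputs["occs"][0, 0, 0].sum())  # tensor(5.)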
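Similarly, a short sketch of CenterCrop under the same assumptions; note that a "valids" entry is required, since the crop offsets are computed from its spatial size:

import torch

from ptlflow.data.flow_transforms import CenterCrop

inputs = {
    "images": torch.rand(2, 3, 100, 120),  # image pair, NCHW
    "flows": torch.rand(1, 2, 100, 120),
    "valids": torch.ones(1, 1, 100, 120),
}
crop = CenterCrop(crop_size=(64, 96))
outputs = crop(inputs)
# (100 - 64) // 2 = 18 rows and (120 - 96) // 2 = 12 columns are removed
# from each side of every input.
print(outputs["images"].shape)  # torch.Size([2, 3, 64, 96])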
@@ -1116,7 +1293,7 @@ def _resize( valids = inputs[valid_key] n, k, h, w = valids.shape hs, ws = target_size - scale_factor = torch.Tensor( + scale_factor = torch.tensor( [float(ws) / w, float(hs) / h], device=valids.device ) valids_flat = rearrange(valids, "n k h w -> n (k h w)") @@ -1151,7 +1328,7 @@ def _resize( inputs[valid_key] = valids_out for k, v in inputs.items(): - if k != valid_key: + if k != valid_key and (ignore_keys is None or k not in ignore_keys): if k in binary_keys or k in flow_keys: v_out = torch.zeros( v.shape[0], v.shape[1], hs, ws, dtype=v.dtype, device=v.device @@ -1163,9 +1340,9 @@ def _resize( v_valid = v_valid * scale_factor v_valid = v_valid[inbounds_list[i]] v_valid = rearrange(v_valid, "n k -> k n") - v_out[ - i, :, xy_scaled_list[i][1], xy_scaled_list[i][0] - ] = v_valid + v_out[i, :, xy_scaled_list[i][1], xy_scaled_list[i][0]] = ( + v_valid + ) v = v_out else: v = F.interpolate( @@ -1174,20 +1351,21 @@ def _resize( inputs[k] = v else: for k, v in inputs.items(): - h, w = v.shape[-2:] - if k in binary_keys: - v = F.interpolate(v, size=target_size, mode="nearest") - else: - v = F.interpolate( - v, size=target_size, mode="bilinear", align_corners=True - ) - - if k in flow_keys: - scale_mult = torch.Tensor( - [float(target_size[1]) / w, float(target_size[0]) / h], - device=v.device, - )[None, :, None, None] - v = v * scale_mult + if ignore_keys is None or k not in ignore_keys: + h, w = v.shape[-2:] + if k in binary_keys: + v = F.interpolate(v, size=target_size, mode="nearest") + else: + v = F.interpolate( + v, size=target_size, mode="bilinear", align_corners=True + ) + + if k in flow_keys: + scale_mult = torch.tensor( + [float(target_size[1]) / w, float(target_size[0]) / h], + device=v.device, + )[None, :, None, None] + v = v * scale_mult inputs[k] = v
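Finally, a hedged sketch of the dense (sparse=False) resize path with the new ignore_keys option. The Resize constructor arguments beyond what is visible in this diff are assumptions; the point is that flow values are rescaled together with the spatial dimensions, while ignored keys pass through untouched:

import torch

from ptlflow.data.flow_transforms import Resize

inputs = {
    "images": torch.rand(2, 3, 64, 64),
    "flows": torch.full((1, 2, 64, 64), 4.0),  # 4 px of motion everywhere
    "valids": torch.ones(1, 1, 64, 64),
    "extra": torch.zeros(1, 1, 64, 64),  # hypothetical key that should be skipped
}
resize = Resize(size=(128, 32), ignore_keys=["extra"])
outputs = resize(inputs)
# u is multiplied by 32 / 64 = 0.5 and v by 128 / 64 = 2.0 to match the new size
print(outputs["flows"][0, :, 0, 0])  # tensor([2., 8.])
print(outputs["extra"].shape)  # torch.Size([1, 1, 64, 64]), i.e., not resized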