From fd0402bf0fdd09ed4a098840f8300938aa687619 Mon Sep 17 00:00:00 2001 From: hzphzp Date: Tue, 29 Aug 2023 10:35:56 +0800 Subject: [PATCH] fix .gitignore and upload missing dir --- .gitignore | 1 - Adaptive Frequency Filters/data/__init__.py | 6 + .../data/collate_fns/__init__.py | 94 + .../data/collate_fns/collate_functions.py | 42 + .../data/data_loaders.py | 137 ++ .../data/datasets/__init__.py | 288 +++ .../data/datasets/classification/__init__.py | 8 + .../data/datasets/classification/imagenet.py | 220 ++ .../datasets/classification/imagenet_fast.py | 197 ++ .../classification/imagenet_opencv.py | 161 ++ .../imagenet_opencv_bitplane_fast.py | 158 ++ .../classification/imagenet_opencv_fast.py | 160 ++ .../datasets/classification/imagenet_v2.py | 173 ++ .../data/datasets/dataset_base.py | 230 ++ .../data/datasets/detection/__init__.py | 0 .../data/datasets/detection/coco_base.py | 342 +++ .../data/datasets/detection/coco_mask_rcnn.py | 150 ++ .../data/datasets/detection/coco_ssd.py | 224 ++ .../data/datasets/segmentation/__init__.py | 0 .../data/datasets/segmentation/ade20k.py | 521 ++++ .../segmentation/coco_segmentation.py | 230 ++ .../data/datasets/segmentation/pascal_voc.py | 275 +++ .../data/loader/__init__.py | 0 .../data/loader/dataloader.py | 54 + .../data/sampler/__init__.py | 116 + .../data/sampler/base_sampler.py | 295 +++ .../data/sampler/batch_sampler.py | 155 ++ .../data/sampler/multi_scale_sampler.py | 339 +++ .../data/sampler/utils.py | 124 + .../data/sampler/variable_batch_sampler.py | 421 ++++ .../data/transforms/__init__.py | 56 + .../data/transforms/base_transforms.py | 26 + .../data/transforms/image_opencv.py | 1760 ++++++++++++++ .../data/transforms/image_pil.py | 2158 +++++++++++++++++ .../data/transforms/image_torch.py | 247 ++ .../data/transforms/utils.py | 47 + .../data/transforms/video.py | 608 +++++ 37 files changed, 10022 insertions(+), 1 deletion(-) create mode 100644 Adaptive Frequency Filters/data/__init__.py create mode 100644 Adaptive Frequency Filters/data/collate_fns/__init__.py create mode 100644 Adaptive Frequency Filters/data/collate_fns/collate_functions.py create mode 100644 Adaptive Frequency Filters/data/data_loaders.py create mode 100644 Adaptive Frequency Filters/data/datasets/__init__.py create mode 100644 Adaptive Frequency Filters/data/datasets/classification/__init__.py create mode 100644 Adaptive Frequency Filters/data/datasets/classification/imagenet.py create mode 100644 Adaptive Frequency Filters/data/datasets/classification/imagenet_fast.py create mode 100644 Adaptive Frequency Filters/data/datasets/classification/imagenet_opencv.py create mode 100644 Adaptive Frequency Filters/data/datasets/classification/imagenet_opencv_bitplane_fast.py create mode 100644 Adaptive Frequency Filters/data/datasets/classification/imagenet_opencv_fast.py create mode 100644 Adaptive Frequency Filters/data/datasets/classification/imagenet_v2.py create mode 100644 Adaptive Frequency Filters/data/datasets/dataset_base.py create mode 100644 Adaptive Frequency Filters/data/datasets/detection/__init__.py create mode 100644 Adaptive Frequency Filters/data/datasets/detection/coco_base.py create mode 100644 Adaptive Frequency Filters/data/datasets/detection/coco_mask_rcnn.py create mode 100644 Adaptive Frequency Filters/data/datasets/detection/coco_ssd.py create mode 100644 Adaptive Frequency Filters/data/datasets/segmentation/__init__.py create mode 100644 Adaptive Frequency Filters/data/datasets/segmentation/ade20k.py create mode 100644 Adaptive Frequency 
Filters/data/datasets/segmentation/coco_segmentation.py create mode 100644 Adaptive Frequency Filters/data/datasets/segmentation/pascal_voc.py create mode 100644 Adaptive Frequency Filters/data/loader/__init__.py create mode 100644 Adaptive Frequency Filters/data/loader/dataloader.py create mode 100644 Adaptive Frequency Filters/data/sampler/__init__.py create mode 100644 Adaptive Frequency Filters/data/sampler/base_sampler.py create mode 100644 Adaptive Frequency Filters/data/sampler/batch_sampler.py create mode 100644 Adaptive Frequency Filters/data/sampler/multi_scale_sampler.py create mode 100644 Adaptive Frequency Filters/data/sampler/utils.py create mode 100644 Adaptive Frequency Filters/data/sampler/variable_batch_sampler.py create mode 100644 Adaptive Frequency Filters/data/transforms/__init__.py create mode 100644 Adaptive Frequency Filters/data/transforms/base_transforms.py create mode 100644 Adaptive Frequency Filters/data/transforms/image_opencv.py create mode 100644 Adaptive Frequency Filters/data/transforms/image_pil.py create mode 100644 Adaptive Frequency Filters/data/transforms/image_torch.py create mode 100644 Adaptive Frequency Filters/data/transforms/utils.py create mode 100644 Adaptive Frequency Filters/data/transforms/video.py diff --git a/.gitignore b/.gitignore index ac50f0f..119b625 100644 --- a/.gitignore +++ b/.gitignore @@ -143,7 +143,6 @@ cython_debug/ .vscode/ abc/ xyz/ -data/ **/_backup_/** **/exp/** **/ckpts/** diff --git a/Adaptive Frequency Filters/data/__init__.py b/Adaptive Frequency Filters/data/__init__.py new file mode 100644 index 0000000..04ef584 --- /dev/null +++ b/Adaptive Frequency Filters/data/__init__.py @@ -0,0 +1,6 @@ +# +# For licensing see accompanying LICENSE file. +# Copyright (C) 2023 Apple Inc. All Rights Reserved. +# + +from .data_loaders import create_train_val_loader, create_eval_loader diff --git a/Adaptive Frequency Filters/data/collate_fns/__init__.py b/Adaptive Frequency Filters/data/collate_fns/__init__.py new file mode 100644 index 0000000..608a845 --- /dev/null +++ b/Adaptive Frequency Filters/data/collate_fns/__init__.py @@ -0,0 +1,94 @@ +# +# For licensing see accompanying LICENSE file. +# Copyright (C) 2023 Apple Inc. All Rights Reserved. +# + +import os +import importlib +import argparse + +COLLATE_FN_REGISTRY = {} + + +def register_collate_fn(name): + def register_collate_fn_method(f): + if name in COLLATE_FN_REGISTRY: + raise ValueError( + "Cannot register duplicate collate function ({})".format(name) + ) + COLLATE_FN_REGISTRY[name] = f + return f + + return register_collate_fn_method + + +def arguments_collate_fn(parser: argparse.ArgumentParser): + group = parser.add_argument_group( + title="Collate function arguments", description="Collate function arguments" + ) + group.add_argument( + "--dataset.collate-fn-name-train", + type=str, + default="default_collate_fn", + help="Name of collate function", + ) + group.add_argument( + "--dataset.collate-fn-name-val", + type=str, + default="default_collate_fn", + help="Name of collate function", + ) + group.add_argument( + "--dataset.collate-fn-name-eval", + type=str, + default=None, + help="Name of collate function used for evaluation. 
" + "Default is None, i.e., use PyTorch's inbuilt collate function", + ) + return parser + + +def build_collate_fn(opts, *args, **kwargs): + collate_fn_name_train = getattr( + opts, "dataset.collate_fn_name_train", "default_collate_fn" + ) + collate_fn_name_val = getattr( + opts, "dataset.collate_fn_name_val", "default_collate_fn" + ) + collate_fn_train = None + if ( + collate_fn_name_train is not None + and collate_fn_name_train in COLLATE_FN_REGISTRY + ): + collate_fn_train = COLLATE_FN_REGISTRY[collate_fn_name_train] + + collate_fn_val = None + if collate_fn_name_val is None: + collate_fn_val = collate_fn_name_train + elif collate_fn_name_val is not None and collate_fn_name_val in COLLATE_FN_REGISTRY: + collate_fn_val = COLLATE_FN_REGISTRY[collate_fn_name_val] + + return collate_fn_train, collate_fn_val + + +def build_eval_collate_fn(opts, *args, **kwargs): + collate_fn_name_eval = getattr(opts, "dataset.collate_fn_name_eval", None) + collate_fn_eval = None + if collate_fn_name_eval is not None and collate_fn_name_eval in COLLATE_FN_REGISTRY: + collate_fn_eval = COLLATE_FN_REGISTRY[collate_fn_name_eval] + + return collate_fn_eval + + +# automatically import the augmentations +collate_fn_dir = os.path.dirname(__file__) + +for file in os.listdir(collate_fn_dir): + path = os.path.join(collate_fn_dir, file) + if ( + not file.startswith("_") + and not file.startswith(".") + and (file.endswith(".py") or os.path.isdir(path)) + ): + collate_fn_fname = file[: file.find(".py")] if file.endswith(".py") else file + module = importlib.import_module("data.collate_fns." + collate_fn_fname) diff --git a/Adaptive Frequency Filters/data/collate_fns/collate_functions.py b/Adaptive Frequency Filters/data/collate_fns/collate_functions.py new file mode 100644 index 0000000..185d149 --- /dev/null +++ b/Adaptive Frequency Filters/data/collate_fns/collate_functions.py @@ -0,0 +1,42 @@ +# +# For licensing see accompanying LICENSE file. +# Copyright (C) 2023 Apple Inc. All Rights Reserved. +# +import numpy as np +import torch +from typing import List, Dict + +from utils import logger + +from . import register_collate_fn + + +@register_collate_fn(name="default_collate_fn") +def default_collate_fn(batch: List[Dict], opts): + """Default collate function""" + batch_size = len(batch) + + keys = list(batch[0].keys()) + + new_batch = {k: [] for k in keys} + for b in range(batch_size): + for k in keys: + new_batch[k].append(batch[b][k]) + + # stack the keys + for k in keys: + batch_elements = new_batch.pop(k) + + if isinstance(batch_elements[0], (int, float, np.integer, np.floating)): + # list of ints or floats + batch_elements = torch.as_tensor(batch_elements) + else: + # stack tensors (including 0-dimensional) + try: + batch_elements = torch.stack(batch_elements, dim=0).contiguous() + except Exception as e: + logger.error("Unable to stack the tensors. Error: {}".format(e)) + + new_batch[k] = batch_elements + + return new_batch diff --git a/Adaptive Frequency Filters/data/data_loaders.py b/Adaptive Frequency Filters/data/data_loaders.py new file mode 100644 index 0000000..e466e98 --- /dev/null +++ b/Adaptive Frequency Filters/data/data_loaders.py @@ -0,0 +1,137 @@ +# +# For licensing see accompanying LICENSE file. +# Copyright (C) 2023 Apple Inc. All Rights Reserved. 
+# + +import torch +from functools import partial + +from utils import logger +from utils.ddp_utils import is_master +from utils.tensor_utils import image_size_from_opts + +from .datasets import train_val_datasets, evaluation_datasets +from .sampler import build_sampler +from .collate_fns import build_collate_fn, build_eval_collate_fn +from .loader.dataloader import affnetDataLoader + + +def create_eval_loader(opts): + eval_dataset = evaluation_datasets(opts) + n_eval_samples = len(eval_dataset) + is_master_node = is_master(opts) + + # overwrite the validation argument + setattr( + opts, "dataset.val_batch_size0", getattr(opts, "dataset.eval_batch_size0", 1) + ) + + # we don't need variable batch sampler for evaluation + sampler_name = getattr(opts, "sampler.name", "batch_sampler") + crop_size_h, crop_size_w = image_size_from_opts(opts) + if sampler_name.find("video") > -1 and sampler_name != "video_batch_sampler": + clips_per_video = getattr(opts, "sampler.vbs.clips_per_video", 1) + frames_per_clip = getattr(opts, "sampler.vbs.num_frames_per_clip", 8) + setattr(opts, "sampler.name", "video_batch_sampler") + setattr(opts, "sampler.bs.crop_size_width", crop_size_w) + setattr(opts, "sampler.bs.crop_size_height", crop_size_h) + setattr(opts, "sampler.bs.clips_per_video", clips_per_video) + setattr(opts, "sampler.bs.num_frames_per_clip", frames_per_clip) + elif sampler_name.find("var") > -1: + setattr(opts, "sampler.name", "batch_sampler") + setattr(opts, "sampler.bs.crop_size_width", crop_size_w) + setattr(opts, "sampler.bs.crop_size_height", crop_size_h) + + eval_sampler = build_sampler( + opts=opts, n_data_samples=n_eval_samples, is_training=False + ) + + collate_fn_eval = build_eval_collate_fn(opts=opts) + + data_workers = getattr(opts, "dataset.workers", 1) + persistent_workers = False + pin_memory = False + + eval_loader = affnetDataLoader( + dataset=eval_dataset, + batch_size=1, + batch_sampler=eval_sampler, + num_workers=data_workers, + pin_memory=pin_memory, + persistent_workers=persistent_workers, + collate_fn=partial(collate_fn_eval, opts=opts) + if collate_fn_eval is not None + else None, + ) + + if is_master_node: + logger.log("Evaluation sampler details: ") + print("{}".format(eval_sampler)) + + return eval_loader + + +def create_train_val_loader(opts): + train_dataset, valid_dataset = train_val_datasets(opts) + + n_train_samples = len(train_dataset) + is_master_node = is_master(opts) + + train_sampler = build_sampler( + opts=opts, n_data_samples=n_train_samples, is_training=True + ) + if valid_dataset is not None: + n_valid_samples = len(valid_dataset) + valid_sampler = build_sampler( + opts=opts, n_data_samples=n_valid_samples, is_training=False + ) + else: + valid_sampler = None + + data_workers = getattr(opts, "dataset.workers", 1) + persistent_workers = getattr(opts, "dataset.persistent_workers", False) and ( + data_workers > 0 + ) + pin_memory = getattr(opts, "dataset.pin_memory", False) + prefetch_factor = getattr(opts, "dataset.prefetch_factor", 2) + + collate_fn_train, collate_fn_val = build_collate_fn(opts=opts) + + train_loader = affnetDataLoader( + dataset=train_dataset, + batch_size=1, # Handled inside data sampler + num_workers=data_workers, + pin_memory=pin_memory, + batch_sampler=train_sampler, + persistent_workers=persistent_workers, + collate_fn=partial(collate_fn_train, opts=opts) + if collate_fn_train is not None + else None, + prefetch_factor=prefetch_factor, + ) + + if valid_dataset is not None: + val_loader = affnetDataLoader( + dataset=valid_dataset, + 
batch_size=1, + batch_sampler=valid_sampler, + num_workers=data_workers, + pin_memory=pin_memory, + persistent_workers=persistent_workers, + collate_fn=partial(collate_fn_val, opts=opts) + if collate_fn_val is not None + else None, + ) + else: + val_loader = None + + if is_master_node: + logger.log("Training sampler details: ") + print("{}".format(train_sampler)) + + if valid_dataset is not None: + logger.log("Validation sampler details: ") + print("{}".format(valid_sampler)) + logger.log("Number of data workers: {}".format(data_workers)) + + return train_loader, val_loader, train_sampler diff --git a/Adaptive Frequency Filters/data/datasets/__init__.py b/Adaptive Frequency Filters/data/datasets/__init__.py new file mode 100644 index 0000000..4a907ce --- /dev/null +++ b/Adaptive Frequency Filters/data/datasets/__init__.py @@ -0,0 +1,288 @@ +# +# For licensing see accompanying LICENSE file. +# Copyright (C) 2023 Apple Inc. All Rights Reserved. +# + +import os +import importlib +import argparse +import glob + +from utils.ddp_utils import is_master +from utils import logger + +from .dataset_base import BaseImageDataset + + +SUPPORTED_TASKS = [] +DATASET_REGISTRY = {} + +SEPARATOR = ":" + + +def register_dataset(name, task): + def register_dataset_class(cls): + if name in DATASET_REGISTRY: + raise ValueError( + "Cannot register duplicate dataset class ({})".format(name) + ) + + if not issubclass(cls, BaseImageDataset): + raise ValueError( + "Dataset ({}: {}) must extend BaseImageDataset".format( + name, cls.__name__ + ) + ) + + DATASET_REGISTRY[name + SEPARATOR + task] = cls + return cls + + return register_dataset_class + + +def supported_dataset_str(dataset_name, dataset_category): + supp_list = list(DATASET_REGISTRY.keys()) + supp_str = "Dataset ({}) under task ({}) is not yet supported. 
\n Supported datasets are:".format( + dataset_name, dataset_category + ) + for t_name in SUPPORTED_TASKS: + supp_str += "\n\t {}: ".format(logger.color_text(t_name)) + for i, m_name in enumerate(supp_list): + d_name, t_name1 = m_name.split(SEPARATOR) + if t_name == t_name1: + supp_str += "\n\t\t{}".format(d_name) + logger.error(supp_str + "\n") + + +def evaluation_datasets(opts): + dataset_name = getattr(opts, "dataset.name", "imagenet") + dataset_category = getattr(opts, "dataset.category", "classification") + + is_master_node = is_master(opts) + + name_dataset_task = dataset_name + SEPARATOR + dataset_category + eval_dataset = None + if name_dataset_task in DATASET_REGISTRY: + eval_dataset = DATASET_REGISTRY[name_dataset_task]( + opts=opts, is_training=False, is_evaluation=True + ) + else: + supported_dataset_str( + dataset_name=dataset_name, dataset_category=dataset_category + ) + + if is_master_node: + logger.log("Evaluation dataset details: ") + print("{}".format(eval_dataset)) + + return eval_dataset + + +def train_val_datasets(opts): + dataset_name = getattr(opts, "dataset.name", "imagenet") + dataset_category = getattr(opts, "dataset.category", "classification") + disable_val = getattr(opts, "dataset.disable_val", False) + + is_master_node = is_master(opts) + + name_dataset_task = dataset_name + SEPARATOR + dataset_category + train_dataset = valid_dataset = None + if name_dataset_task in DATASET_REGISTRY and not disable_val: + train_dataset = DATASET_REGISTRY[name_dataset_task](opts=opts, is_training=True) + valid_dataset = DATASET_REGISTRY[name_dataset_task]( + opts=opts, is_training=False + ) + elif name_dataset_task in DATASET_REGISTRY and disable_val: + train_dataset = DATASET_REGISTRY[name_dataset_task](opts=opts, is_training=True) + valid_dataset = None + else: + supported_dataset_str( + dataset_name=dataset_name, dataset_category=dataset_category + ) + + if is_master_node: + logger.log("Training and validation dataset details: ") + print("{}".format(train_dataset)) + print("{}".format(valid_dataset)) + return train_dataset, valid_dataset + + + +def general_dataset_args(parser: argparse.ArgumentParser): + group = parser.add_argument_group( + title="Dataset", description="Arguments related to dataset" + ) + group.add_argument( + "--dataset.root-train", + type=str, + default="", + help="Root location of train dataset", + ) + group.add_argument( + "--dataset.root-val", + type=str, + default="", + help="Root location of valid dataset", + ) + group.add_argument( + "--dataset.root-test", + type=str, + default="", + help="Root location of test dataset", + ) + group.add_argument( + "--dataset.disable-val", action="store_true", help="Disable validation" + ) + + group.add_argument( + "--dataset.name", type=str, default="imagenet", help="Dataset name" + ) + group.add_argument( + "--dataset.category", + type=str, + default="classification", + help="Dataset category (e.g., segmentation, classification)", + ) + group.add_argument( + "--dataset.train-batch-size0", default=128, type=int, help="Training batch size" + ) + group.add_argument( + "--dataset.val-batch-size0", default=1, type=int, help="Validation batch size" + ) + group.add_argument( + "--dataset.eval-batch-size0", default=1, type=int, help="Validation batch size" + ) + group.add_argument( + "--dataset.workers", default=-1, type=int, help="Number of data workers" + ) + group.add_argument( + "--dataset.dali-workers", + default=-1, + type=int, + help="Number of data workers for dali", + ) + group.add_argument( + 
"--dataset.persistent-workers", + action="store_true", + help="Use same workers across all epochs in data loader", + ) + group.add_argument( + "--dataset.pin-memory", + action="store_true", + help="Use pin memory option in data loader", + ) + group.add_argument( + "--dataset.prefetch-factor", + type=int, + default=2, + help="Number of samples loaded in advance by each data worker", + ) + group.add_argument( + "--dataset.img-dtype", + type=str, + choices=["float", "half", "float16"], + default="float", + help="Image datatype", + ) + + group.add_argument( + "--dataset.cache-images-on-ram", action="store_true", help="Cache data on RAM" + ) + group.add_argument( + "--dataset.cache-limit", + type=float, + default=80.0, + help="Max. memory to use in RAM.", + ) + + # sample efficient training + group.add_argument( + "--dataset.sample-efficient-training.enable", + action="store_true", + help="sample efficient training", + ) + group.add_argument( + "--dataset.sample-efficient-training.sample-confidence", + type=float, + default=0.5, + help="Confidence for sample", + ) + group.add_argument( + "--dataset.sample-efficient-training.find-easy-samples-every-k-epochs", + type=int, + default=5, + help="Find easy samples after every K epochs", + ) + group.add_argument( + "--dataset.sample-efficient-training.min-sample-frequency", + type=int, + default=5, + help="Frequency that sample has been classified as easy for N number of times.", + ) + + group.add_argument( + "--dataset.decode-data-on-gpu", action="store_true", help="Decode data on GPU" + ) + group.add_argument( + "--dataset.sampler-type", + type=str, + default="batch", + help="Batch sampler or not.", + ) + + group.add_argument( + "--dataset.padding-index", + type=int, + default=None, + help="Padding index for text vocabulary", + ) + + group.add_argument( + "--dataset.text-vocab-size", type=int, default=-1, help="Text vocabulary size" + ) + + return parser + + +def arguments_dataset(parser: argparse.ArgumentParser): + parser = general_dataset_args(parser=parser) + + try: + from internal.utils.server_utils import dataset_server_args + parser = dataset_server_args(parser) + except ImportError as e: + pass + + + # add dataset specific arguments + for k, v in DATASET_REGISTRY.items(): + parser = v.add_arguments(parser=parser) + + return parser + + +# automatically import the datasets +dataset_dir = os.path.dirname(__file__) + +# supported tasks (each folder in datasets is for a particular task) +for abs_dir_path in glob.glob("{}/*".format(dataset_dir)): + if os.path.isdir(abs_dir_path): + file_or_folder_name = os.path.basename(abs_dir_path).strip() + if not file_or_folder_name.startswith( + "_" + ) and not file_or_folder_name.startswith("."): + SUPPORTED_TASKS.append(file_or_folder_name) + +for task in SUPPORTED_TASKS: + task_path = os.path.join(dataset_dir, task) + for file in os.listdir(task_path): + path = os.path.join(task_path, file) + if ( + not file.startswith("_") + and not file.startswith(".") + and (file.endswith(".py") or os.path.isdir(path)) + ): + dataset_name = file[: file.find(".py")] if file.endswith(".py") else file + module = importlib.import_module( + "data.datasets." + task + "." 
+ dataset_name + ) diff --git a/Adaptive Frequency Filters/data/datasets/classification/__init__.py b/Adaptive Frequency Filters/data/datasets/classification/__init__.py new file mode 100644 index 0000000..f63f68f --- /dev/null +++ b/Adaptive Frequency Filters/data/datasets/classification/__init__.py @@ -0,0 +1,8 @@ +# -------------------------------------------------------- +# Copyright (c) 2023 Microsoft +# Licensed under The MIT License +# -------------------------------------------------------- + +""" +Image Classification Datasets +""" diff --git a/Adaptive Frequency Filters/data/datasets/classification/imagenet.py b/Adaptive Frequency Filters/data/datasets/classification/imagenet.py new file mode 100644 index 0000000..adfdfbe --- /dev/null +++ b/Adaptive Frequency Filters/data/datasets/classification/imagenet.py @@ -0,0 +1,220 @@ +# -------------------------------------------------------- +# Copyright (c) 2023 Microsoft +# Licensed under The MIT License +# -------------------------------------------------------- + +from torchvision.datasets import ImageFolder +from typing import Optional, Tuple, Dict, List, Union +import torch +import argparse + +from utils import logger + +from .. import register_dataset +from ..dataset_base import BaseImageDataset +from ...transforms import image_pil as T +from ...collate_fns import register_collate_fn + + +@register_dataset(name="imagenet", task="classification") +class ImagenetDataset(BaseImageDataset, ImageFolder): + """ + ImageNet Classification Dataset that uses PIL for reading and augmenting images. The dataset structure should + follow the ImageFolder class in :class:`torchvision.datasets.imagenet` + + Args: + opts: command-line arguments + is_training (Optional[bool]): A flag used to indicate training or validation mode. Default: True + is_evaluation (Optional[bool]): A flag used to indicate evaluation (or inference) mode. Default: False + + .. note:: + We recommend to use this dataset class over the imagenet_opencv.py file. + + """ + + def __init__( + self, + opts, + is_training: Optional[bool] = True, + is_evaluation: Optional[bool] = False, + *args, + **kwargs + ) -> None: + BaseImageDataset.__init__( + self, opts=opts, is_training=is_training, is_evaluation=is_evaluation + ) + root = self.root + ImageFolder.__init__( + self, root=root, transform=None, target_transform=None, is_valid_file=None + ) + + self.n_classes = len(list(self.class_to_idx.keys())) + setattr(opts, "model.classification.n_classes", self.n_classes) + setattr(opts, "dataset.collate_fn_name_train", "imagenet_collate_fn") + setattr(opts, "dataset.collate_fn_name_val", "imagenet_collate_fn") + setattr(opts, "dataset.collate_fn_name_eval", "imagenet_collate_fn") + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser) -> argparse.ArgumentParser: + """Add dataset-specific arguments to the parser.""" + group = parser.add_argument_group( + title="".format(cls.__name__), description="".format(cls.__name__) + ) + group.add_argument( + "--dataset.imagenet.crop-ratio", + type=float, + default=0.875, + help="Crop ratio", + ) + return parser + + def _training_transforms(self, size: Union[Tuple, int], *args, **kwargs): + """ + Training data augmentation methods. + Image --> RandomResizedCrop --> RandomHorizontalFlip --> Optional(AutoAugment or RandAugment) + --> Tensor --> Optional(RandomErasing) --> Optional(MixUp) --> Optional(CutMix) + + .. note:: + 1. AutoAugment, RandAugment and TrivialAugmentWide are mutually exclusive. + 2. 
Mixup and CutMix are applied on batches are implemented in trainer. + """ + aug_list = [ + T.RandomResizedCrop(opts=self.opts, size=size), + T.RandomHorizontalFlip(opts=self.opts), + ] + auto_augment = getattr( + self.opts, "image_augmentation.auto_augment.enable", False + ) + rand_augment = getattr( + self.opts, "image_augmentation.rand_augment.enable", False + ) + trivial_augment_wide = getattr( + self.opts, "image_augmentation.trivial_augment_wide.enable", False + ) + if bool(auto_augment) + bool(rand_augment) + bool(trivial_augment_wide) > 1: + logger.error( + "AutoAugment, RandAugment and TrivialAugmentWide are mutually exclusive. Use either of them, but not more than one" + ) + elif auto_augment: + aug_list.append(T.AutoAugment(opts=self.opts)) + elif rand_augment: + if getattr( + self.opts, "image_augmentation.rand_augment.use_timm_library", False + ): + aug_list.append(T.RandAugmentTimm(opts=self.opts)) + else: + aug_list.append(T.RandAugment(opts=self.opts)) + elif trivial_augment_wide: + aug_list.append(T.TrivialAugmentWide(opts=self.opts)) + + aug_list.append(T.ToTensor(opts=self.opts)) + + if getattr(self.opts, "image_augmentation.random_erase.enable", False): + aug_list.append(T.RandomErasing(opts=self.opts)) + + return T.Compose(opts=self.opts, img_transforms=aug_list) + + def _validation_transforms(self, size: Union[Tuple, int], *args, **kwargs): + """ + Validation augmentation + Image --> Resize --> CenterCrop --> ToTensor + """ + aug_list = [ + T.Resize(opts=self.opts), + T.CenterCrop(opts=self.opts), + T.ToTensor(opts=self.opts), + ] + + return T.Compose(opts=self.opts, img_transforms=aug_list) + + def _evaluation_transforms(self, size: Union[Tuple, int], *args, **kwargs): + """Same as the validation_transforms""" + return self._validation_transforms(size=size) + + def __getitem__(self, batch_indexes_tup: Tuple) -> Dict: + """ + :param batch_indexes_tup: Tuple of the form (Crop_size_W, Crop_size_H, Image_ID) + :return: dictionary containing input image, label, and sample_id. 
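As the docstring above indicates, `__getitem__` does not receive a plain integer index: the batch samplers in this patch yield `(crop_h, crop_w, image_id)` tuples so the crop resolution can vary per batch. A toy sketch of that contract follows; the generator below is illustrative only and is not the `variable_batch_sampler` shipped in this patch:

import random
from typing import Iterator, List, Tuple

def toy_multi_scale_batches(
    n_samples: int,
    batch_size: int,
    resolutions: List[Tuple[int, int]],
) -> Iterator[List[Tuple[int, int, int]]]:
    ids = list(range(n_samples))
    random.shuffle(ids)
    for start in range(0, n_samples, batch_size):
        crop_h, crop_w = random.choice(resolutions)  # one resolution per batch
        yield [(crop_h, crop_w, i) for i in ids[start : start + batch_size]]

# each element of a batch is later unpacked by the dataset as
# crop_size_h, crop_size_w, img_index = batch_indexes_tup
for batch in toy_multi_scale_batches(8, 4, [(160, 160), (224, 224), (256, 256)]):
    print(batch)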
+ """ + crop_size_h, crop_size_w, img_index = batch_indexes_tup + if self.is_training: + transform_fn = self._training_transforms(size=(crop_size_h, crop_size_w)) + else: + # same for validation and evaluation + transform_fn = self._validation_transforms(size=(crop_size_h, crop_size_w)) + + img_path, target = self.samples[img_index] + + input_img = self.read_image_pil(img_path) + + if input_img is None: + # Sometimes images are corrupt + # Skip such images + logger.log("Img index {} is possibly corrupt.".format(img_index)) + input_tensor = torch.zeros( + size=(3, crop_size_h, crop_size_w), dtype=self.img_dtype + ) + target = -1 + data = {"image": input_tensor} + else: + data = {"image": input_img} + data = transform_fn(data) + + data["samples"] = data.pop("image") + data["targets"] = target + data["sample_id"] = img_index + + return data + + def __len__(self) -> int: + return len(self.samples) + + def __repr__(self) -> str: + from utils.tensor_utils import image_size_from_opts + + im_h, im_w = image_size_from_opts(opts=self.opts) + + if self.is_training: + transforms_str = self._training_transforms(size=(im_h, im_w)) + else: + transforms_str = self._validation_transforms(size=(im_h, im_w)) + + return "{}(\n\troot={}\n\tis_training={}\n\tsamples={}\n\tn_classes={}\n\ttransforms={}\n)".format( + self.__class__.__name__, + self.root, + self.is_training, + len(self.samples), + self.n_classes, + transforms_str, + ) + + +@register_collate_fn(name="imagenet_collate_fn") +def imagenet_collate_fn(batch: List, opts) -> Dict: + batch_size = len(batch) + img_size = [batch_size, *batch[0]["samples"].shape] + img_dtype = batch[0]["samples"].dtype + + images = torch.zeros(size=img_size, dtype=img_dtype) + # fill with -1, so that we can ignore corrupted images + labels = torch.full(size=[batch_size], fill_value=-1, dtype=torch.long) + sample_ids = torch.zeros(size=[batch_size], dtype=torch.long) + valid_indexes = [] + for i, batch_i in enumerate(batch): + label_i = batch_i.pop("targets") + images[i] = batch_i.pop("samples") + labels[i] = label_i # label is an int + sample_ids[i] = batch_i.pop("sample_id") # sample id is an int + if label_i != -1: + valid_indexes.append(i) + + valid_indexes = torch.tensor(valid_indexes, dtype=torch.long) + images = torch.index_select(images, dim=0, index=valid_indexes) + labels = torch.index_select(labels, dim=0, index=valid_indexes) + sample_ids = torch.index_select(sample_ids, dim=0, index=valid_indexes) + + channels_last = getattr(opts, "common.channels_last", False) + if channels_last: + images = images.to(memory_format=torch.channels_last) + + return {"samples": images, "targets": labels, "sample_id": sample_ids} diff --git a/Adaptive Frequency Filters/data/datasets/classification/imagenet_fast.py b/Adaptive Frequency Filters/data/datasets/classification/imagenet_fast.py new file mode 100644 index 0000000..ffa3d18 --- /dev/null +++ b/Adaptive Frequency Filters/data/datasets/classification/imagenet_fast.py @@ -0,0 +1,197 @@ +# -------------------------------------------------------- +# Copyright (c) 2023 Microsoft +# Licensed under The MIT License +# -------------------------------------------------------- + +# from torchvision.datasets import ImageFolder +from utils.my_dataset_folder import ImageFolder +import os +from typing import Optional, Tuple, Dict, List, Union +import torch +import argparse + +from utils import logger + +from .. 
import register_dataset +from ..dataset_base import BaseImageDataset +from ...transforms import image_pil as T +from ...collate_fns import register_collate_fn + + +@register_dataset(name="imagenet_fast", task="classification") +class ImagenetDataset(BaseImageDataset, ImageFolder): + """ + ImageNet Classification Dataset that uses PIL for reading and augmenting images. The dataset structure should + follow the ImageFolder class in :class:`torchvision.datasets.imagenet` + + Args: + opts: command-line arguments + is_training (Optional[bool]): A flag used to indicate training or validation mode. Default: True + is_evaluation (Optional[bool]): A flag used to indicate evaluation (or inference) mode. Default: False + + .. note:: + We recommend to use this dataset class over the imagenet_opencv.py file. + + """ + + def __init__( + self, + opts, + is_training: Optional[bool] = True, + is_evaluation: Optional[bool] = False, + *args, + **kwargs + ) -> None: + BaseImageDataset.__init__( + self, opts=opts, is_training=is_training, is_evaluation=is_evaluation + ) + root = self.root + # ImageFolder.__init__( + # self, root=root, transform=None, target_transform=None, is_valid_file=None + # ) + # assert is_training ^ is_evaluation + prefix = 'train' if is_training else 'val' + map_txt = os.path.join(root, '..', f"{prefix}_map.txt") + ImageFolder.__init__( + self, root=root, transform=None, target_transform=None, is_valid_file=None, map_txt=map_txt + ) + # self.n_classes = len(list(self.class_to_idx.keys())) + self.n_classes = len(self.classes) + setattr(opts, "model.classification.n_classes", self.n_classes) + setattr(opts, "dataset.collate_fn_name_train", "imagenet_collate_fn") + setattr(opts, "dataset.collate_fn_name_val", "imagenet_collate_fn") + setattr(opts, "dataset.collate_fn_name_eval", "imagenet_collate_fn") + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser) -> argparse.ArgumentParser: + """Add dataset-specific arguments to the parser.""" + group = parser.add_argument_group( + title="".format(cls.__name__), description="".format(cls.__name__) + ) + group.add_argument( + "--dataset.imagenet_fast.crop-ratio", # --dataset.imagenet.crop-ratio + type=float, + default=0.875, + help="Crop ratio", + ) + return parser + + def _training_transforms(self, size: Union[Tuple, int], *args, **kwargs): + """ + Training data augmentation methods. + Image --> RandomResizedCrop --> RandomHorizontalFlip --> Optional(AutoAugment or RandAugment) + --> Tensor --> Optional(RandomErasing) --> Optional(MixUp) --> Optional(CutMix) + + .. note:: + 1. AutoAugment, RandAugment and TrivialAugmentWide are mutually exclusive. + 2. Mixup and CutMix are applied on batches are implemented in trainer. + """ + aug_list = [ + T.RandomResizedCrop(opts=self.opts, size=size), + T.RandomHorizontalFlip(opts=self.opts), + ] + auto_augment = getattr( + self.opts, "image_augmentation.auto_augment.enable", False + ) + rand_augment = getattr( + self.opts, "image_augmentation.rand_augment.enable", False + ) + trivial_augment_wide = getattr( + self.opts, "image_augmentation.trivial_augment_wide.enable", False + ) + if bool(auto_augment) + bool(rand_augment) + bool(trivial_augment_wide) > 1: + logger.error( + "AutoAugment, RandAugment and TrivialAugmentWide are mutually exclusive. 
Use either of them, but not more than one" + ) + elif auto_augment: + aug_list.append(T.AutoAugment(opts=self.opts)) + elif rand_augment: + if getattr( + self.opts, "image_augmentation.rand_augment.use_timm_library", False + ): + aug_list.append(T.RandAugmentTimm(opts=self.opts)) + else: + aug_list.append(T.RandAugment(opts=self.opts)) + elif trivial_augment_wide: + aug_list.append(T.TrivialAugmentWide(opts=self.opts)) + + aug_list.append(T.ToTensor(opts=self.opts)) + + if getattr(self.opts, "image_augmentation.random_erase.enable", False): + aug_list.append(T.RandomErasing(opts=self.opts)) + + return T.Compose(opts=self.opts, img_transforms=aug_list) + + def _validation_transforms(self, size: Union[Tuple, int], *args, **kwargs): + """ + Validation augmentation + Image --> Resize --> CenterCrop --> ToTensor + """ + aug_list = [ + T.Resize(opts=self.opts), + T.CenterCrop(opts=self.opts), + T.ToTensor(opts=self.opts), + ] + + return T.Compose(opts=self.opts, img_transforms=aug_list) + + def _evaluation_transforms(self, size: Union[Tuple, int], *args, **kwargs): + """Same as the validation_transforms""" + return self._validation_transforms(size=size) + + def __getitem__(self, batch_indexes_tup: Tuple) -> Dict: + """ + :param batch_indexes_tup: Tuple of the form (Crop_size_W, Crop_size_H, Image_ID) + :return: dictionary containing input image, label, and sample_id. + """ + crop_size_h, crop_size_w, img_index = batch_indexes_tup + if self.is_training: + transform_fn = self._training_transforms(size=(crop_size_h, crop_size_w)) + else: + # same for validation and evaluation + transform_fn = self._validation_transforms(size=(crop_size_h, crop_size_w)) + + img_path, target = self.samples[img_index] + + input_img = self.read_image_pil(img_path) + + if input_img is None: + # Sometimes images are corrupt + # Skip such images + logger.log("Img index {} is possibly corrupt.".format(img_index)) + input_tensor = torch.zeros( + size=(3, crop_size_h, crop_size_w), dtype=self.img_dtype + ) + target = -1 + data = {"image": input_tensor} + else: + data = {"image": input_img} + data = transform_fn(data) + + data["samples"] = data.pop("image") + data["targets"] = target + data["sample_id"] = img_index + + return data + + def __len__(self) -> int: + return len(self.samples) + + def __repr__(self) -> str: + from utils.tensor_utils import image_size_from_opts + + im_h, im_w = image_size_from_opts(opts=self.opts) + + if self.is_training: + transforms_str = self._training_transforms(size=(im_h, im_w)) + else: + transforms_str = self._validation_transforms(size=(im_h, im_w)) + + return "{}(\n\troot={}\n\tis_training={}\n\tsamples={}\n\tn_classes={}\n\ttransforms={}\n)".format( + self.__class__.__name__, + self.root, + self.is_training, + len(self.samples), + self.n_classes, + transforms_str, + ) diff --git a/Adaptive Frequency Filters/data/datasets/classification/imagenet_opencv.py b/Adaptive Frequency Filters/data/datasets/classification/imagenet_opencv.py new file mode 100644 index 0000000..676d6cf --- /dev/null +++ b/Adaptive Frequency Filters/data/datasets/classification/imagenet_opencv.py @@ -0,0 +1,161 @@ +# -------------------------------------------------------- +# Copyright (c) 2023 Microsoft +# Licensed under The MIT License +# -------------------------------------------------------- + +from torchvision.datasets import ImageFolder +from typing import Optional, Tuple, Dict +import numpy as np +import math +import warnings + +from utils import logger + +from .. 
import register_dataset +from ..dataset_base import BaseImageDataset +from ...transforms import image_opencv as tf + + +@register_dataset(name="imagenet_opencv", task="classification") +class ImagenetOpenCVDataset(BaseImageDataset, ImageFolder): + """ + ImageNet Classification Dataset that uses OpenCV for data augmentation. + + The dataset structure should follow the ImageFolder class in :class:`torchvision.datasets.imagenet` + + Args: + opts: command-line arguments + is_training (Optional[bool]): A flag used to indicate training or validation mode. Default: True + is_evaluation (Optional[bool]): A flag used to indicate evaluation (or inference) mode. Default: False + + .. note:: + This class is depreciated and will be removed in future versions (Use it for evaluation). + """ + + def __init__( + self, + opts, + is_training: Optional[bool] = True, + is_evaluation: Optional[bool] = False, + *args, + **kwargs + ) -> None: + warnings.warn( + "The use of dataset.name=imagenet_opencv is depreciated. Please use dataset.name=imagenet", + DeprecationWarning, + ) + BaseImageDataset.__init__( + self, opts=opts, is_training=is_training, is_evaluation=is_evaluation + ) + root = self.root + ImageFolder.__init__( + self, root=root, transform=None, target_transform=None, is_valid_file=None + ) + + self.n_classes = len(list(self.class_to_idx.keys())) + setattr(opts, "model.classification.n_classes", self.n_classes) + + setattr(opts, "dataset.collate_fn_name_train", "imagenet_collate_fn") + setattr(opts, "dataset.collate_fn_name_val", "imagenet_collate_fn") + setattr(opts, "dataset.collate_fn_name_eval", "imagenet_collate_fn") + + def _training_transforms(self, size: tuple or int): + """ + Training data augmentation methods (RandomResizedCrop --> RandomHorizontalFlip --> ToTensor). + """ + aug_list = [ + tf.RandomResizedCrop(opts=self.opts, size=size), + tf.RandomHorizontalFlip(opts=self.opts), + tf.NumpyToTensor(opts=self.opts), + ] + return tf.Compose(opts=self.opts, img_transforms=aug_list) + + def _validation_transforms(self, size: tuple): + """Implements validation transformation method (Resize --> CenterCrop --> ToTensor).""" + if isinstance(size, (tuple, list)): + size = min(size) + + assert isinstance(size, int) + # (256 - 224) = 32 + # where 224/0.875 = 256 + + crop_ratio = getattr(self.opts, "dataset.imagenet.crop_ratio", 0.875) + if 0 < crop_ratio < 1.0: + scale_size = int(math.ceil(size / crop_ratio)) + scale_size = (scale_size // 32) * 32 + else: + logger.warning( + "Crop ratio should be between 0 and 1. Got: {}".format(crop_ratio) + ) + logger.warning("Setting scale_size as size + 32") + scale_size = size + 32 # int(make_divisible(crop_size / 0.875, divisor=32)) + + return tf.Compose( + opts=self.opts, + img_transforms=[ + tf.Resize(opts=self.opts, size=scale_size), + tf.CenterCrop(opts=self.opts, size=size), + tf.NumpyToTensor(opts=self.opts), + ], + ) + + def _evaluation_transforms(self, size: tuple): + """Same as the validation_transforms""" + return self._validation_transforms(size=size) + + def __getitem__(self, batch_indexes_tup: Tuple) -> Dict: + """ + + :param batch_indexes_tup: Tuple of the form (Crop_size_W, Crop_size_H, Image_ID) + :return: dictionary containing input image and label ID. 
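The validation resize in `_validation_transforms` above derives `scale_size` from the crop ratio: the crop size is divided by `crop_ratio` and the result is snapped down to a multiple of 32 (224 / 0.875 = 256, so a 224-pixel crop is taken from a 256-pixel resize). A small worked sketch of that arithmetic; the helper name is illustrative:

import math

def scale_size_for(crop_size: int, crop_ratio: float = 0.875) -> int:
    # mirrors the logic in _validation_transforms above
    if 0 < crop_ratio < 1.0:
        scale_size = int(math.ceil(crop_size / crop_ratio))
        return (scale_size // 32) * 32
    return crop_size + 32  # fallback used when crop_ratio is out of range

print(scale_size_for(224))  # 256
print(scale_size_for(256))  # 288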
+ """ + crop_size_h, crop_size_w, img_index = batch_indexes_tup + if self.is_training: + transform_fn = self._training_transforms(size=(crop_size_h, crop_size_w)) + else: # same for validation and evaluation + transform_fn = self._validation_transforms(size=(crop_size_h, crop_size_w)) + + img_path, target = self.samples[img_index] + input_img = self.read_image_opencv(img_path) + + if input_img is None: + # Sometimes images are corrupt and cv2 is not able to load them + # Skip such images + logger.log( + "Img index {} is possibly corrupt. Removing it from the sample list".format( + img_index + ) + ) + del self.samples[img_index] + input_img = np.zeros(shape=(crop_size_h, crop_size_w, 3), dtype=np.uint8) + + data = {"image": input_img} + data = transform_fn(data) + + data["samples"] = data.pop("image") + data["targets"] = target + data["sample_id"] = img_index + + return data + + def __len__(self): + return len(self.samples) + + def __repr__(self): + from utils.tensor_utils import image_size_from_opts + + im_h, im_w = image_size_from_opts(opts=self.opts) + + if self.is_training: + transforms_str = self._training_transforms(size=(im_h, im_w)) + else: + transforms_str = self._validation_transforms(size=(im_h, im_w)) + + return "{}(\n\troot={}\n\tis_training={}\n\tsamples={}\n\tn_classes={}\n\ttransforms={}\n)".format( + self.__class__.__name__, + self.root, + self.is_training, + len(self.samples), + self.n_classes, + transforms_str, + ) diff --git a/Adaptive Frequency Filters/data/datasets/classification/imagenet_opencv_bitplane_fast.py b/Adaptive Frequency Filters/data/datasets/classification/imagenet_opencv_bitplane_fast.py new file mode 100644 index 0000000..9da868e --- /dev/null +++ b/Adaptive Frequency Filters/data/datasets/classification/imagenet_opencv_bitplane_fast.py @@ -0,0 +1,158 @@ +# -------------------------------------------------------- +# Copyright (c) 2023 Microsoft +# Licensed under The MIT License +# -------------------------------------------------------- +import os + +from utils.my_dataset_folder import ImageFolder +from typing import Optional, Tuple, Dict +import numpy as np +import math + +from utils import logger + +from .. import register_dataset +from ..dataset_base import BaseImageDataset +from ...transforms import image_opencv as tf +from ...transforms.image_opencv import BitPlane + +# change name +@register_dataset(name="imagenet_opencv_bitplane_fast", task="classification") +class ImagenetOpenCVDataset(BaseImageDataset, ImageFolder): + """ + ImageNet Classification Dataset that uses OpenCV for data augmentation. + + The dataset structure should follow the ImageFolder class in :class:`torchvision.datasets.imagenet` + + Args: + opts: command-line arguments + is_training (Optional[bool]): A flag used to indicate training or validation mode. Default: True + is_evaluation (Optional[bool]): A flag used to indicate evaluation (or inference) mode. Default: False + + .. note:: + This class is depreciated and will be removed in future versions (Use it for evaluation). 
+ """ + + def __init__( + self, + opts, + is_training: Optional[bool] = True, + is_evaluation: Optional[bool] = False, + *args, + **kwargs + ) -> None: + BaseImageDataset.__init__( + self, opts=opts, is_training=is_training, is_evaluation=is_evaluation + ) + root = self.root + # assert is_training ^ is_evaluation + prefix = 'train' if is_training else 'val' + map_txt = os.path.join(root, '..', f"{prefix}_map.txt") + ImageFolder.__init__( + self, root=root, transform=None, target_transform=None, is_valid_file=None, map_txt=map_txt + ) + + self.n_classes = len(self.classes) + setattr(opts, "model.classification.n_classes", self.n_classes) + + def _training_transforms(self, size: tuple or int): + """ + Training data augmentation methods (RandomResizedCrop --> RandomHorizontalFlip --> ToTensor). + """ + aug_list = [ + tf.RandomResizedCrop(opts=self.opts, size=size), + tf.RandomHorizontalFlip(opts=self.opts), + BitPlane(opts=self.opts, h=size[0], w=size[1]), + tf.NumpyToTensor(opts=self.opts), + ] + return tf.Compose(opts=self.opts, img_transforms=aug_list) + + def _validation_transforms(self, size: tuple): + """Implements validation transformation method (Resize --> CenterCrop --> ToTensor).""" + if isinstance(size, (tuple, list)): + size = min(size) + + assert isinstance(size, int) + # (256 - 224) = 32 + # where 224/0.875 = 256 + + crop_ratio = getattr(self.opts, "dataset.imagenet.crop_ratio", 0.875) + if 0 < crop_ratio < 1.0: + scale_size = int(math.ceil(size / crop_ratio)) + scale_size = (scale_size // 32) * 32 + else: + logger.warning( + "Crop ratio should be between 0 and 1. Got: {}".format(crop_ratio) + ) + logger.warning("Setting scale_size as size + 32") + scale_size = size + 32 # int(make_divisible(crop_size / 0.875, divisor=32)) + + return tf.Compose( + opts=self.opts, + img_transforms=[ + tf.Resize(opts=self.opts, size=scale_size), + tf.CenterCrop(opts=self.opts, size=size), + BitPlane(opts=self.opts, h=size, w=size), + tf.NumpyToTensor(opts=self.opts), + ], + ) + + def _evaluation_transforms(self, size: tuple): + """Same as the validation_transforms""" + return self._validation_transforms(size=size) + + def __getitem__(self, batch_indexes_tup: Tuple) -> Dict: + """ + + :param batch_indexes_tup: Tuple of the form (Crop_size_W, Crop_size_H, Image_ID) + :return: dictionary containing input image and label ID. + """ + crop_size_h, crop_size_w, img_index = batch_indexes_tup + if self.is_training: + transform_fn = self._training_transforms(size=(crop_size_h, crop_size_w)) + else: # same for validation and evaluation + transform_fn = self._validation_transforms(size=(crop_size_h, crop_size_w)) + + img_path, target = self.samples[img_index] + input_img = self.read_image_opencv(img_path) + + if input_img is None: + # Sometimes images are corrupt and cv2 is not able to load them + # Skip such images + logger.log( + "Img index {} is possibly corrupt. 
Removing it from the sample list".format( + img_index + ) + ) + del self.samples[img_index] + input_img = np.zeros(shape=(crop_size_h, crop_size_w, 3), dtype=np.uint8) + + data = {"image": input_img} + data = transform_fn(data) + + data["label"] = target + data["sample_id"] = img_index + + return data + + def __len__(self): + return len(self.samples) + + def __repr__(self): + from utils.tensor_utils import image_size_from_opts + + im_h, im_w = image_size_from_opts(opts=self.opts) + + if self.is_training: + transforms_str = self._training_transforms(size=(im_h, im_w)) + else: + transforms_str = self._validation_transforms(size=(im_h, im_w)) + + return "{}(\n\troot={}\n\tis_training={}\n\tsamples={}\n\tn_classes={}\n\ttransforms={}\n)".format( + self.__class__.__name__, + self.root, + self.is_training, + len(self.samples), + self.n_classes, + transforms_str, + ) diff --git a/Adaptive Frequency Filters/data/datasets/classification/imagenet_opencv_fast.py b/Adaptive Frequency Filters/data/datasets/classification/imagenet_opencv_fast.py new file mode 100644 index 0000000..d6db0f3 --- /dev/null +++ b/Adaptive Frequency Filters/data/datasets/classification/imagenet_opencv_fast.py @@ -0,0 +1,160 @@ +# -------------------------------------------------------- +# Copyright (c) 2023 Microsoft +# Licensed under The MIT License +# -------------------------------------------------------- +import os + +from utils.my_dataset_folder import ImageFolder +from typing import Optional, Tuple, Dict +import numpy as np +import math + +from utils import logger + +from .. import register_dataset +from ..dataset_base import BaseImageDataset +from ...transforms import image_opencv as tf + + +@register_dataset(name="imagenet_opencv_fast", task="classification") +class ImagenetOpenCVDataset(BaseImageDataset, ImageFolder): + """ + ImageNet Classification Dataset that uses OpenCV for data augmentation. + + The dataset structure should follow the ImageFolder class in :class:`torchvision.datasets.imagenet` + + Args: + opts: command-line arguments + is_training (Optional[bool]): A flag used to indicate training or validation mode. Default: True + is_evaluation (Optional[bool]): A flag used to indicate evaluation (or inference) mode. Default: False + + .. note:: + This class is depreciated and will be removed in future versions (Use it for MobileViT evaluation). + """ + + def __init__( + self, + opts, + is_training: Optional[bool] = True, + is_evaluation: Optional[bool] = False, + *args, + **kwargs + ) -> None: + BaseImageDataset.__init__( + self, opts=opts, is_training=is_training, is_evaluation=is_evaluation + ) + root = self.root + # assert is_training ^ is_evaluation + prefix = 'train' if is_training else 'val' + map_txt = os.path.join(root, '..', f"{prefix}_map.txt") + ImageFolder.__init__( + self, root=root, transform=None, target_transform=None, is_valid_file=None, map_txt=map_txt + ) + + self.n_classes = len(self.classes) + setattr(opts, "model.classification.n_classes", self.n_classes) + + setattr(opts, "dataset.collate_fn_name_train", "imagenet_collate_fn") + setattr(opts, "dataset.collate_fn_name_val", "imagenet_collate_fn") + setattr(opts, "dataset.collate_fn_name_eval", "imagenet_collate_fn") + + def _training_transforms(self, size: tuple or int): + """ + Training data augmentation methods (RandomResizedCrop --> RandomHorizontalFlip --> ToTensor). 
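These pipelines follow a dict-in/dict-out convention: each transform receives `{"image": ...}` (plus any extra keys) and returns the same dict with the image replaced, so composition is a left-to-right fold over the dict. A stripped-down sketch of that convention; the real `Compose` and transform classes in this patch additionally take `opts`, and the class below is illustrative:

from typing import Callable, Dict, List

import numpy as np

class DictCompose:
    def __init__(self, transforms: List[Callable[[Dict], Dict]]) -> None:
        self.transforms = transforms

    def __call__(self, data: Dict) -> Dict:
        for t in self.transforms:  # apply transforms left to right
            data = t(data)
        return data

def hflip(data: Dict) -> Dict:
    data["image"] = np.ascontiguousarray(data["image"][:, ::-1, :])
    return data

pipeline = DictCompose([hflip])
out = pipeline({"image": np.zeros((224, 224, 3), dtype=np.uint8)})
print(out["image"].shape)  # (224, 224, 3)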
+ """ + aug_list = [ + tf.RandomResizedCrop(opts=self.opts, size=size), + tf.RandomHorizontalFlip(opts=self.opts), + tf.NumpyToTensor(opts=self.opts), + ] + return tf.Compose(opts=self.opts, img_transforms=aug_list) + + def _validation_transforms(self, size: tuple): + """Implements validation transformation method (Resize --> CenterCrop --> ToTensor).""" + if isinstance(size, (tuple, list)): + size = min(size) + + assert isinstance(size, int) + # (256 - 224) = 32 + # where 224/0.875 = 256 + + crop_ratio = getattr(self.opts, "dataset.imagenet.crop_ratio", 0.875) + if 0 < crop_ratio < 1.0: + scale_size = int(math.ceil(size / crop_ratio)) + scale_size = (scale_size // 32) * 32 + else: + logger.warning( + "Crop ratio should be between 0 and 1. Got: {}".format(crop_ratio) + ) + logger.warning("Setting scale_size as size + 32") + scale_size = size + 32 # int(make_divisible(crop_size / 0.875, divisor=32)) + + return tf.Compose( + opts=self.opts, + img_transforms=[ + tf.Resize(opts=self.opts, size=scale_size), + tf.CenterCrop(opts=self.opts, size=size), + tf.NumpyToTensor(opts=self.opts), + ], + ) + + def _evaluation_transforms(self, size: tuple): + """Same as the validation_transforms""" + return self._validation_transforms(size=size) + + def __getitem__(self, batch_indexes_tup: Tuple) -> Dict: + """ + + :param batch_indexes_tup: Tuple of the form (Crop_size_W, Crop_size_H, Image_ID) + :return: dictionary containing input image and label ID. + """ + crop_size_h, crop_size_w, img_index = batch_indexes_tup + if self.is_training: + transform_fn = self._training_transforms(size=(crop_size_h, crop_size_w)) + else: # same for validation and evaluation + transform_fn = self._validation_transforms(size=(crop_size_h, crop_size_w)) + + img_path, target = self.samples[img_index] + input_img = self.read_image_opencv(img_path) + + if input_img is None: + # Sometimes images are corrupt and cv2 is not able to load them + # Skip such images + logger.log( + "Img index {} is possibly corrupt. 
Removing it from the sample list".format( + img_index + ) + ) + del self.samples[img_index] + input_img = np.zeros(shape=(crop_size_h, crop_size_w, 3), dtype=np.uint8) + + data = {"image": input_img} + data = transform_fn(data) + + data["samples"] = data.pop("image") + data["targets"] = target + data["sample_id"] = img_index + + return data + + def __len__(self): + return len(self.samples) + + def __repr__(self): + from utils.tensor_utils import image_size_from_opts + + im_h, im_w = image_size_from_opts(opts=self.opts) + + if self.is_training: + transforms_str = self._training_transforms(size=(im_h, im_w)) + else: + transforms_str = self._validation_transforms(size=(im_h, im_w)) + + return "{}(\n\troot={}\n\tis_training={}\n\tsamples={}\n\tn_classes={}\n\ttransforms={}\n)".format( + self.__class__.__name__, + self.root, + self.is_training, + len(self.samples), + self.n_classes, + transforms_str, + ) diff --git a/Adaptive Frequency Filters/data/datasets/classification/imagenet_v2.py b/Adaptive Frequency Filters/data/datasets/classification/imagenet_v2.py new file mode 100644 index 0000000..47a1881 --- /dev/null +++ b/Adaptive Frequency Filters/data/datasets/classification/imagenet_v2.py @@ -0,0 +1,173 @@ +# -------------------------------------------------------- +# Copyright (c) 2023 Microsoft +# Licensed under The MIT License +# -------------------------------------------------------- + +import argparse +import tarfile +from pathlib import Path +from typing import Optional, Tuple, Dict, Union + +import torch + +from utils import logger +from utils.download_utils import get_local_path + +from .. import register_dataset +from ..dataset_base import BaseImageDataset +from ...transforms import image_pil as T + +IMAGENETv2_SPLIT_LINK_MAP = { + "matched_frequency": { + "url": "https://s3-us-west-2.amazonaws.com/imagenetv2public/imagenetv2-matched-frequency.tar.gz", + "extracted_folder_name": "imagenetv2-matched-frequency-format-val", + }, + "threshold_0.7": { + "url": "https://s3-us-west-2.amazonaws.com/imagenetv2public/imagenetv2-threshold0.7.tar.gz", + "extracted_folder_name": "imagenetv2-threshold0.7-format-val", + }, + "top_images": { + "url": "https://s3-us-west-2.amazonaws.com/imagenetv2public/imagenetv2-top-images.tar.gz", + "extracted_folder_name": "imagenetv2-top-images-format-val", + }, +} + + +@register_dataset(name="imagenet_v2", task="classification") +class Imagenetv2Dataset(BaseImageDataset): + """ + `ImageNetv2 Dataset `_ for studying the robustness of models trained on ImageNet dataset + + Args: + opts: command-line arguments + is_training (Optional[bool]): ImageNetv2 should be used for evaluation only Default: False + is_evaluation (Optional[bool]): A flag used to indicate evaluation (or inference) mode. Default: True + + """ + + def __init__( + self, + opts, + is_training: Optional[bool] = False, + is_evaluation: Optional[bool] = True, + *args, + **kwargs, + ) -> None: + if is_training: + logger.error( + "{} can only be used for evaluation".format(self.__class__.__name__) + ) + + super().__init__( + opts=opts, is_training=is_training, is_evaluation=is_evaluation + ) + + split = getattr(opts, "dataset.imagenet_v2.split", None) + if split is None or split not in IMAGENETv2_SPLIT_LINK_MAP.keys(): + logger.error( + "Please specify split for ImageNetv2. 
Supported ImageNetv2 splits are: {}".format( + IMAGENETv2_SPLIT_LINK_MAP.keys() + ) + ) + + split_path = get_local_path(opts, path=IMAGENETv2_SPLIT_LINK_MAP[split]["url"]) + with tarfile.open(split_path) as tf: + tf.extractall(self.root) + + root = Path( + "{}/{}".format( + self.root, IMAGENETv2_SPLIT_LINK_MAP[split]["extracted_folder_name"] + ) + ) + file_names = list(root.glob("**/*.jpeg")) + self.file_names = file_names + + setattr(opts, "dataset.collate_fn_name_train", "imagenet_collate_fn") + setattr(opts, "dataset.collate_fn_name_val", "imagenet_collate_fn") + setattr(opts, "dataset.collate_fn_name_eval", "imagenet_collate_fn") + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser) -> argparse.ArgumentParser: + """Add dataset-specific arguments to the parser.""" + group = parser.add_argument_group( + title="".format(cls.__name__), description="".format(cls.__name__) + ) + group.add_argument( + "--dataset.imagenet-v2.split", + type=str, + default="matched-frequency", + help="ImageNetv2 dataset. Possible choices are: {}".format( + [ + f"{i + 1}: {split_name}" + for i, split_name in enumerate(IMAGENETv2_SPLIT_LINK_MAP.keys()) + ] + ), + choices=IMAGENETv2_SPLIT_LINK_MAP.keys(), + ) + return parser + + def _validation_transforms(self, size: Union[Tuple, int], *args, **kwargs): + """ + Validation augmentation + Image --> Resize --> CenterCrop --> ToTensor + """ + aug_list = [ + T.Resize(opts=self.opts), + T.CenterCrop(opts=self.opts), + T.ToTensor(opts=self.opts), + ] + + return T.Compose(opts=self.opts, img_transforms=aug_list) + + def __getitem__(self, batch_indexes_tup: Tuple) -> Dict: + """ + :param batch_indexes_tup: Tuple of the form (Crop_size_W, Crop_size_H, Image_ID) + :return: dictionary containing input image, label, and sample_id. + """ + crop_size_h, crop_size_w, img_index = batch_indexes_tup + + # same for validation and evaluation + transform_fn = self._validation_transforms(size=(crop_size_h, crop_size_w)) + + # infer target label from the file name + # file names are organized as SPLIT_NAME-format-val/class_idx/*.jpg + # Example: All images in this folder (imagenetv2-matched-frequency-format-val/0/*.jpg) belong to class 0 + img_path = str(self.file_names[img_index]) + target = int(self.file_names[img_index].parent.name) + + input_img = self.read_image_pil(img_path) + if input_img is None: + # Sometimes images are corrupt + # Skip such images + logger.log("Img index {} is possibly corrupt.".format(img_index)) + input_tensor = torch.zeros( + size=(3, crop_size_h, crop_size_w), dtype=self.img_dtype + ) + target = -1 + data = {"image": input_tensor} + else: + data = {"image": input_img} + data = transform_fn(data) + + data["samples"] = data["image"] + data["targets"] = target + data["sample_id"] = img_index + + return data + + def __len__(self) -> int: + return len(self.file_names) + + def __repr__(self) -> str: + from utils.tensor_utils import image_size_from_opts + + im_h, im_w = image_size_from_opts(opts=self.opts) + + transforms_str = self._validation_transforms(size=(im_h, im_w)) + + return "{}(\n\troot={}\n\tsamples={}\n\ttransforms={}\n)".format( + self.__class__.__name__, + self.root, + len(self.file_names), + transforms_str, + ) diff --git a/Adaptive Frequency Filters/data/datasets/dataset_base.py b/Adaptive Frequency Filters/data/datasets/dataset_base.py new file mode 100644 index 0000000..34bd627 --- /dev/null +++ b/Adaptive Frequency Filters/data/datasets/dataset_base.py @@ -0,0 +1,230 @@ +# +# For licensing see accompanying LICENSE file. 
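In the ImageNetv2 loader above, the target is recovered purely from the directory layout of the extracted split: every image sits in a folder named after its class index. A minimal sketch of that lookup; the file paths and helper name are illustrative:

from pathlib import Path

def target_from_path(img_path: str) -> int:
    # class index is simply the name of the image's parent directory
    return int(Path(img_path).parent.name)

print(target_from_path("imagenetv2-matched-frequency-format-val/0/0af3f.jpeg"))   # 0
print(target_from_path("imagenetv2-matched-frequency-format-val/417/b12.jpeg"))   # 417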
+# Copyright (C) 2023 Apple Inc. All Rights Reserved. +# + +import copy +import warnings +import torch +from torch import Tensor +from torch.utils import data +import cv2 +from PIL import Image +from typing import Optional, Union, Dict +import argparse +import psutil +import time +import numpy as np +from torchvision.io import ( + read_image, + read_file, + decode_jpeg, + ImageReadMode, + decode_image, +) +import io + +from utils import logger +from utils.ddp_utils import is_start_rank_node, is_master + + +class BaseImageDataset(data.Dataset): + """ + Base Dataset class for Image datasets + """ + + def __init__( + self, + opts, + is_training: Optional[bool] = True, + is_evaluation: Optional[bool] = False, + *args, + **kwargs + ): + if getattr(opts, "dataset.trove.enable", False): + opts = self.load_from_server(opts=opts, is_training=is_training) + + root = ( + getattr(opts, "dataset.root_train", None) + if is_training + else getattr(opts, "dataset.root_val", None) + ) + self.root = root + self.is_training = is_training + self.is_evaluation = is_evaluation + self.sampler_name = getattr(opts, "sampler.name", None) + self.opts = opts + + image_device_cuda = getattr(self.opts, "dataset.decode_data_on_gpu", False) + device = getattr(self.opts, "dev.device", torch.device("cpu")) + use_cuda = False + if image_device_cuda and ( + (isinstance(device, str) and device.find("cuda") > -1) + or (isinstance(device, torch.device) and device.type.find("cuda") > -1) + ): # cuda could be cuda:0 + use_cuda = True + + if use_cuda and getattr(opts, "dataset.pin_memory", False): + if is_master(opts): + logger.error( + "For loading images on GPU, --dataset.pin-memory should be disabled." + ) + + self.device = device if use_cuda else torch.device("cpu") + + self.cached_data = ( + dict() + if getattr(opts, "dataset.cache_images_on_ram", False) and is_training + else None + ) + if self.cached_data is not None: + if not getattr(opts, "dataset.persistent_workers", False): + if is_master(opts): + logger.error( + "For caching, --dataset.persistent-workers should be enabled." + ) + + self.cache_limit = getattr(opts, "dataset.cache_limit", 80.0) + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser): + return parser + + @staticmethod + def load_from_server(opts, is_training): + try: + from internal.utils.server_utils import load_from_data_server + + opts = load_from_data_server(opts=opts, is_training=is_training) + except ImportError as e: + import traceback + traceback.print_exc() + logger.error( + "Unable to load data. Please load data manually. 
Error: {}".format(e) + ) + + return opts + + def _training_transforms(self, *args, **kwargs): + raise NotImplementedError + + def _validation_transforms(self, *args, **kwargs): + raise NotImplementedError + + def _evaluation_transforms(self, *args, **kwargs): + raise NotImplementedError + + def read_image_pil(self, path: str, *args, **kwargs): + def convert_to_rgb(inp_data: Union[str, io.BytesIO]): + try: + rgb_img = Image.open(inp_data).convert("RGB") + except: + rgb_img = None + return rgb_img + + if self.cached_data is not None: + # code for caching data on RAM + used_memory = float(psutil.virtual_memory().percent) + + if path in self.cached_data: + img_byte = self.cached_data[path] + + elif (path not in self.cached_data) and (used_memory <= self.cache_limit): + # image is not present in cache and RAM usage is less than the threshold, add to cache + with open(path, "rb") as bin_file: + bin_file_data = bin_file.read() + img_byte = io.BytesIO(bin_file_data) + self.cached_data[path] = img_byte + else: + with open(path, "rb") as bin_file: + bin_file_data = bin_file.read() + img_byte = io.BytesIO(bin_file_data) # in-memory data + img = convert_to_rgb(img_byte) + else: + img = convert_to_rgb(path) + return img + + def read_pil_image_torchvision(self, path: str): + if self.cached_data is not None: + # code for caching data on RAM + used_memory = float(psutil.virtual_memory().percent) + + if path in self.cached_data: + byte_img = self.cached_data[path] + elif (path not in self.cached_data) and (used_memory <= self.cache_limit): + # image is not present in cache and RAM usage is less than the threshold, add to cache + byte_img = read_file(path) + self.cached_data[path] = byte_img + else: + byte_img = read_file(path) + else: + byte_img = read_file(path) + img = decode_image(byte_img, mode=ImageReadMode.RGB) + return img + + def read_image_tensor(self, path: str): + if self.cached_data is not None: + # code for caching data on RAM + used_memory = float(psutil.virtual_memory().percent) + + if path in self.cached_data: + byte_img = self.cached_data[path] + elif (path not in self.cached_data) and (used_memory <= self.cache_limit): + # image is not present in cache and RAM usage is less than the threshold, add to cache + byte_img = read_file(path) + self.cached_data[path] = byte_img + else: + byte_img = read_file(path) + else: + byte_img = read_file(path) + img = decode_jpeg(byte_img, device=self.device, mode=ImageReadMode.RGB) + return img + + @staticmethod + def read_mask_pil(path: str): + try: + mask = Image.open(path) + if mask.mode != "L": + logger.error("Mask mode should be L. Got: {}".format(mask.mode)) + return mask + except: + return None + + @staticmethod + def read_image_opencv(path: str): + warnings.warn( + "The use of read_image_opencv function is depreciated. Please use read_image_pil", + DeprecationWarning, + ) + return cv2.imread( + path, cv2.IMREAD_COLOR + ) # Image is read in BGR Format and not RGB format + + @staticmethod + def read_mask_opencv(path: str): + warnings.warn( + "The use of read_mask_opencv function is depreciated. 
Please use read_mask_pil", + DeprecationWarning, + ) + return cv2.imread(path, cv2.IMREAD_GRAYSCALE) + + @staticmethod + def convert_mask_to_tensor(mask): + # convert to tensor + mask = np.array(mask) + if len(mask.shape) > 2 and mask.shape[-1] > 1: + mask = np.ascontiguousarray(mask.transpose(2, 0, 1)) + return torch.as_tensor(mask, dtype=torch.long) + + @staticmethod + def adjust_mask_value(): + return 0 + + @staticmethod + def class_names(): + pass + + def __repr__(self): + return "{}(\n\troot={}\n\t is_training={})".format( + self.__class__.__name__, self.root, self.is_training + ) diff --git a/Adaptive Frequency Filters/data/datasets/detection/__init__.py b/Adaptive Frequency Filters/data/datasets/detection/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/Adaptive Frequency Filters/data/datasets/detection/coco_base.py b/Adaptive Frequency Filters/data/datasets/detection/coco_base.py new file mode 100644 index 0000000..a9883e8 --- /dev/null +++ b/Adaptive Frequency Filters/data/datasets/detection/coco_base.py @@ -0,0 +1,342 @@ +# +# For licensing see accompanying LICENSE file. +# Copyright (C) 2023 Apple Inc. All Rights Reserved. +# + +import torch +from pycocotools.coco import COCO +from pycocotools import mask as coco_mask +import os +from typing import Optional, Tuple, Dict, List +import numpy as np +import argparse + +from utils import logger + +from ...transforms import image_pil as T +from ...datasets import BaseImageDataset, register_dataset + + +@register_dataset(name="coco", task="detection") +class COCODetection(BaseImageDataset): + """ + Base class for the MS COCO Object Detection Dataset. + + Args: + opts: command-line arguments + is_training (Optional[bool]): A flag used to indicate training or validation mode. Default: True + is_evaluation (Optional[bool]): A flag used to indicate evaluation (or inference) mode. Default: False + + .. note:: + This class implements basic functions (e.g., reading image and annotations), and does not implement + training/validation transforms. Detector specific sub-classes should extend this class and implement those + methods. See `coco_ssd.py` as an example for SSD. 
+ + """ + + def __init__( + self, + opts, + is_training: Optional[bool] = True, + is_evaluation: Optional[bool] = False, + *args, + **kwargs + ) -> None: + super().__init__( + opts=opts, is_training=is_training, is_evaluation=is_evaluation + ) + + split = "train" if is_training else "val" + year = 2017 + ann_file = os.path.join( + self.root, "annotations/instances_{}{}.json".format(split, year) + ) + + # disable printing, so that pycocotools print statements are not printed + logger.disable_printing() + + self.coco = COCO(ann_file) + self.img_dir = os.path.join(self.root, "{}{}".format(split, year)) + self.ids = ( + list(self.coco.imgToAnns.keys()) + if is_training + else list(self.coco.imgs.keys()) + ) + + coco_categories = sorted(self.coco.getCatIds()) + bkrnd_id = ( + 0 if getattr(opts, "dataset.detection.no_background_id", False) else 1 + ) + self.coco_id_to_contiguous_id = { + coco_id: i + bkrnd_id for i, coco_id in enumerate(coco_categories) + } + self.contiguous_id_to_coco_id = { + v: k for k, v in self.coco_id_to_contiguous_id.items() + } + self.num_classes = len(self.contiguous_id_to_coco_id.keys()) + bkrnd_id + + # enable printing + logger.enable_printing() + + setattr(opts, "model.detection.n_classes", self.num_classes) + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser) -> argparse.ArgumentParser: + group = parser.add_argument_group( + title="".format(cls.__name__), description="".format(cls.__name__) + ) + group.add_argument( + "--dataset.detection.no-background-id", + action="store_true", + help="Do not include background id", + ) + return parser + + def _training_transforms(self, size: tuple, ignore_idx: Optional[int] = 255): + """Training transforms should be implemented in sub-class""" + raise NotImplementedError + + def _validation_transforms(self, size: tuple, *args, **kwargs): + """Validation transforms should be implemented in sub-class""" + raise NotImplementedError + + def _evaluation_transforms(self, size: tuple, *args, **kwargs): + """Evaluation or Inference transforms (Resize (Optional) --> Tensor). + + .. note:: + Resizing the input to the same resolution as the detector's input is not enabled by default. + It can be enabled by passing **--evaluation.detection.resize-input-images** flag. 
+ + """ + aug_list = [] + if getattr(self.opts, "evaluation.detection.resize_input_images", False): + aug_list.append(T.Resize(opts=self.opts, img_size=size)) + + aug_list.append(T.ToTensor(opts=self.opts)) + return T.Compose(opts=self.opts, img_transforms=aug_list) + + def __getitem__(self, batch_indexes_tup: Tuple, *args, **kwargs) -> Dict: + crop_size_h, crop_size_w, img_index = batch_indexes_tup + + if self.is_training: + transform_fn = self._training_transforms(size=(crop_size_h, crop_size_w)) + elif self.is_evaluation: + transform_fn = self._evaluation_transforms(size=(crop_size_h, crop_size_w)) + else: # same for validation and evaluation + transform_fn = self._validation_transforms(size=(crop_size_h, crop_size_w)) + + image_id = self.ids[img_index] + + image, img_name = self.get_image(image_id=image_id) + im_width, im_height = image.size + + boxes, labels, mask = self.get_boxes_and_labels( + image_id=image_id, + image_width=im_width, + image_height=im_height, + include_masks=True, + ) + + data = { + "image": image, + "box_labels": labels, + "box_coordinates": boxes, + "mask": mask, + } + + if transform_fn is not None: + data = transform_fn(data) + + output_data = { + "samples": { + "image": data["image"], + }, + "targets": { + "box_labels": data["box_labels"], + "box_coordinates": data["box_coordinates"], + "mask": data["mask"], + "image_id": torch.tensor(image_id), + "image_width": torch.tensor(im_width), + "image_height": torch.tensor(im_height), + }, + } + + return output_data + + def __len__(self): + return len(self.ids) + + def get_boxes_and_labels( + self, image_id, image_width, image_height, *args, include_masks=False, **kwargs + ) -> Tuple[np.ndarray, np.ndarray, Optional[np.ndarray]]: + ann_ids = self.coco.getAnnIds(imgIds=image_id) + ann = self.coco.loadAnns(ann_ids) + + # filter crowd annotations + ann = [obj for obj in ann if obj["iscrowd"] == 0] + boxes = np.array( + [self._xywh2xyxy(obj["bbox"], image_width, image_height) for obj in ann], + np.float32, + ).reshape((-1, 4)) + labels = np.array( + [self.coco_id_to_contiguous_id[obj["category_id"]] for obj in ann], np.int64 + ).reshape((-1,)) + # remove invalid boxes + keep = (boxes[:, 3] > boxes[:, 1]) & (boxes[:, 2] > boxes[:, 0]) + boxes = boxes[keep] + labels = labels[keep] + + masks = None + if include_masks: + masks = [] + for obj in ann: + rle = coco_mask.frPyObjects( + obj["segmentation"], image_height, image_width + ) + m = coco_mask.decode(rle) + if len(m.shape) < 3: + mask = m.astype(np.uint8) + else: + mask = (np.sum(m, axis=2) > 0).astype(np.uint8) + masks.append(mask) + + if len(masks) > 0: + masks = np.stack(masks, axis=0) + else: + masks = np.zeros(shape=(0, image_height, image_width), dtype=np.uint8) + masks = masks.astype(np.uint8) + masks = torch.from_numpy(masks) + masks = masks[keep] + assert len(boxes) == len(labels) == len(masks) + return boxes, labels, masks + else: + return boxes, labels, None + + def _xywh2xyxy(self, box, image_width, image_height) -> List: + x1, y1, w, h = box + return [ + max(0, x1), + max(0, y1), + min(x1 + w, image_width), + min(y1 + h, image_height), + ] + + def get_image(self, image_id: int) -> Tuple: + file_name = self.coco.loadImgs(image_id)[0]["file_name"] + image_file = os.path.join(self.img_dir, file_name) + image = self.read_image_pil(image_file) + return image, file_name + + def extra_repr(self) -> str: + return "" + + def __repr__(self) -> str: + from utils.tensor_utils import image_size_from_opts + + im_h, im_w = image_size_from_opts(opts=self.opts) + + if 
self.is_training: + transforms_str = self._training_transforms(size=(im_h, im_w)) + elif self.is_evaluation: + transforms_str = self._evaluation_transforms(size=(im_h, im_w)) + else: + transforms_str = self._validation_transforms(size=(im_h, im_w)) + + repr_str = ( + "{}(\n\troot={}\n\tis_training={}\n\tsamples={}\n\ttransforms={}".format( + self.__class__.__name__, + self.root, + self.is_training, + len(self.ids), + transforms_str, + ) + ) + repr_str += self.extra_repr() + repr_str += "\n)" + return repr_str + + @staticmethod + def class_names() -> List: + return [ + "background", + "person", + "bicycle", + "car", + "motorcycle", + "airplane", + "bus", + "train", + "truck", + "boat", + "traffic light", + "fire hydrant", + "stop sign", + "parking meter", + "bench", + "bird", + "cat", + "dog", + "horse", + "sheep", + "cow", + "elephant", + "bear", + "zebra", + "giraffe", + "backpack", + "umbrella", + "handbag", + "tie", + "suitcase", + "frisbee", + "skis", + "snowboard", + "sports ball", + "kite", + "baseball bat", + "baseball glove", + "skateboard", + "surfboard", + "tennis racket", + "bottle", + "wine glass", + "cup", + "fork", + "knife", + "spoon", + "bowl", + "banana", + "apple", + "sandwich", + "orange", + "broccoli", + "carrot", + "hot dog", + "pizza", + "donut", + "cake", + "chair", + "couch", + "potted plant", + "bed", + "dining table", + "toilet", + "tv", + "laptop", + "mouse", + "remote", + "keyboard", + "cell phone", + "microwave", + "oven", + "toaster", + "sink", + "refrigerator", + "book", + "clock", + "vase", + "scissors", + "teddy bear", + "hair drier", + "toothbrush", + ] diff --git a/Adaptive Frequency Filters/data/datasets/detection/coco_mask_rcnn.py b/Adaptive Frequency Filters/data/datasets/detection/coco_mask_rcnn.py new file mode 100644 index 0000000..986c5c4 --- /dev/null +++ b/Adaptive Frequency Filters/data/datasets/detection/coco_mask_rcnn.py @@ -0,0 +1,150 @@ +# +# For licensing see accompanying LICENSE file. +# Copyright (C) 2023 Apple Inc. All Rights Reserved. +# + +import torch +from typing import Optional, Tuple, Dict, List +import math +import argparse + +from .coco_base import COCODetection +from ...transforms import image_pil as T +from ...datasets import register_dataset +from ...collate_fns import register_collate_fn + + +@register_dataset(name="coco_mask_rcnn", task="detection") +class COCODetectionMaskRCNN(COCODetection): + """Dataset class for the MS COCO Object Detection using Mask RCNN . 
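+    It reuses the annotation loading of ``COCODetection`` and, as described in
+    ``__getitem__`` below, packs boxes, labels and masks under the ``samples`` key
+    so that they can be consumed by the PyTorch Mask R-CNN implementation without
+    changing the affnet training infrastructure.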
+ + Args: + opts : + Command line arguments + is_training : bool + A flag used to indicate training or validation mode + is_evaluation : bool + A flag used to indicate evaluation (or inference) mode + """ + + def __init__( + self, + opts, + is_training: Optional[bool] = True, + is_evaluation: Optional[bool] = False, + *args, + **kwargs + ) -> None: + super().__init__( + opts=opts, is_training=is_training, is_evaluation=is_evaluation + ) + + # set the collate functions for the dataset + setattr(opts, "dataset.collate_fn_name_train", "coco_mask_rcnn_collate_fn") + setattr(opts, "dataset.collate_fn_name_val", "coco_mask_rcnn_collate_fn") + setattr(opts, "dataset.collate_fn_name_eval", "coco_mask_rcnn_collate_fn") + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser) -> argparse.ArgumentParser: + group = parser.add_argument_group( + title="".format(cls.__name__), description="".format(cls.__name__) + ) + group.add_argument( + "--dataset.detection.coco-mask-rcnn.use-lsj-aug", + action="store_true", + help="Use large scale jitter augmentation for training Mask RCNN model", + ) + + return parser + + def _training_transforms(self, size: tuple, ignore_idx: Optional[int] = 255): + """Training data augmentation methods + (Resize --> RandomHorizontalFlip --> ToTensor). + """ + + if getattr(self.opts, "dataset.detection.coco_mask_rcnn.use_lsj_aug", False): + aug_list = [ + T.ScaleJitter(opts=self.opts), + T.FixedSizeCrop(opts=self.opts), + T.RandomHorizontalFlip(opts=self.opts), + T.ToTensor(opts=self.opts), + ] + else: + aug_list = [ + T.Resize(opts=self.opts, img_size=size), + T.RandomHorizontalFlip(opts=self.opts), + T.ToTensor(opts=self.opts), + ] + + return T.Compose(opts=self.opts, img_transforms=aug_list) + + def _validation_transforms(self, size: tuple, *args, **kwargs): + """Implements validation transformation method (Resize --> ToTensor).""" + aug_list = [ + T.Resize(opts=self.opts), + T.ToTensor(opts=self.opts), + ] + return T.Compose(opts=self.opts, img_transforms=aug_list) + + def __getitem__(self, batch_indexes_tup: Tuple, *args, **kwargs) -> Dict: + crop_size_h, crop_size_w, img_index = batch_indexes_tup + + if self.is_training: + transform_fn = self._training_transforms(size=(crop_size_h, crop_size_w)) + else: # same for validation and evaluation + transform_fn = self._validation_transforms(size=(crop_size_h, crop_size_w)) + + image_id = self.ids[img_index] + + image, img_name = self.get_image(image_id=image_id) + im_width, im_height = image.size + + boxes, labels, mask = self.get_boxes_and_labels( + image_id=image_id, + image_width=im_width, + image_height=im_height, + include_masks=True, + ) + + data = { + "image": image, + "box_labels": labels, + "box_coordinates": boxes, + "mask": mask, + } + + if transform_fn is not None: + data = transform_fn(data) + + output_data = { + "samples": { + "image": data["image"], + # PyTorch Mask RCNN implementation expect labels as an input. Because we do not want to change the + # the training infrastructure of affnet library, we pass labels as part of image key and + # handle it in the model. 
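+                # After collation (see coco_mask_rcnn_collate_fn below), samples["label"]
+                # is a list with one such dict of "labels"/"boxes"/"masks" per image,
+                # i.e. the per-image target format the PyTorch Mask R-CNN implementation expects.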
+ "label": { + "labels": data["box_labels"], + "boxes": data["box_coordinates"], + "masks": data["mask"], + }, + }, + "targets": { + "image_id": torch.tensor(image_id), + "image_width": torch.tensor(im_width), + "image_height": torch.tensor(im_height), + }, + } + + return output_data + + +@register_collate_fn(name="coco_mask_rcnn_collate_fn") +def coco_mask_rcnn_collate_fn(batch: List, opts, *args, **kwargs) -> Dict: + new_batch = {"samples": {"image": [], "label": []}, "targets": []} + + for b_id, batch_ in enumerate(batch): + new_batch["samples"]["image"].append(batch_["samples"]["image"]) + new_batch["samples"]["label"].append(batch_["samples"]["label"]) + new_batch["targets"].append(batch_["targets"]) + + return new_batch diff --git a/Adaptive Frequency Filters/data/datasets/detection/coco_ssd.py b/Adaptive Frequency Filters/data/datasets/detection/coco_ssd.py new file mode 100644 index 0000000..321dd46 --- /dev/null +++ b/Adaptive Frequency Filters/data/datasets/detection/coco_ssd.py @@ -0,0 +1,224 @@ +# +# For licensing see accompanying LICENSE file. +# Copyright (C) 2023 Apple Inc. All Rights Reserved. +# + +import torch +from typing import Optional, Tuple, Dict +import math +import argparse + +from utils import logger +from affnet.matcher_det import build_matcher +from affnet.anchor_generator import build_anchor_generator + +from .coco_base import COCODetection +from ...transforms import image_pil as T +from ...datasets import register_dataset +from ...collate_fns import register_collate_fn + + +@register_dataset(name="coco_ssd", task="detection") +class COCODetectionSSD(COCODetection): + """Dataset class for the MS COCO Object Detection using Single Shot Object Detector (SSD). + + Args: + opts : + Command line arguments + is_training : bool + A flag used to indicate training or validation mode + is_evaluation : bool + A flag used to indicate evaluation (or inference) mode + """ + + def __init__( + self, + opts, + is_training: Optional[bool] = True, + is_evaluation: Optional[bool] = False, + *args, + **kwargs + ) -> None: + super().__init__( + opts=opts, is_training=is_training, is_evaluation=is_evaluation + ) + + anchor_gen_name = getattr(opts, "anchor_generator.name", None) + if anchor_gen_name is None or anchor_gen_name != "ssd": + logger.error("For SSD, we need --anchor-generator.name to be ssd") + + self.anchor_box_generator = build_anchor_generator(opts=opts, is_numpy=True) + + self.output_strides = self.anchor_box_generator.output_strides + + if getattr(opts, "matcher.name") != "ssd": + logger.error("For SSD, we need --matcher.name as ssd") + + self.match_prior = build_matcher(opts=opts) + + # set the collate functions for the dataset + setattr(opts, "dataset.collate_fn_name_train", "coco_ssd_collate_fn") + setattr(opts, "dataset.collate_fn_name_val", "coco_ssd_collate_fn") + setattr(opts, "dataset.collate_fn_name_eval", "coco_ssd_collate_fn") + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser) -> argparse.ArgumentParser: + return parser + + def _training_transforms(self, size: tuple, ignore_idx: Optional[int] = 255): + """Training data augmentation methods + (SSDCroping --> PhotometricDistort --> RandomHorizontalFlip -> Resize --> ToTensor). 
+ """ + aug_list = [ + T.SSDCroping(opts=self.opts), + T.PhotometricDistort(opts=self.opts), + T.RandomHorizontalFlip(opts=self.opts), + T.Resize(opts=self.opts, img_size=size), + T.BoxPercentCoords(opts=self.opts), + T.ToTensor(opts=self.opts), + ] + + return T.Compose(opts=self.opts, img_transforms=aug_list) + + def _validation_transforms(self, size: tuple, *args, **kwargs): + """Implements validation transformation method (Resize --> ToTensor).""" + aug_list = [ + T.Resize(opts=self.opts), + T.BoxPercentCoords(opts=self.opts), + T.ToTensor(opts=self.opts), + ] + return T.Compose(opts=self.opts, img_transforms=aug_list) + + def generate_anchors(self, height, width): + """Generate anchors **on-the-fly** based on the input resolution.""" + anchors = [] + for output_stride in self.output_strides: + if output_stride == -1: + fm_width = fm_height = 1 + else: + fm_width = int(math.ceil(width / output_stride)) + fm_height = int(math.ceil(height / output_stride)) + fm_anchor = self.anchor_box_generator( + fm_height=fm_height, fm_width=fm_width, fm_output_stride=output_stride + ) + anchors.append(fm_anchor) + anchors = torch.cat(anchors, dim=0) + return anchors + + def __getitem__(self, batch_indexes_tup: Tuple) -> Dict: + crop_size_h, crop_size_w, img_index = batch_indexes_tup + + if self.is_training: + transform_fn = self._training_transforms(size=(crop_size_h, crop_size_w)) + else: + # During evaluation, we use base class + transform_fn = self._validation_transforms(size=(crop_size_h, crop_size_w)) + + image_id = self.ids[img_index] + + image, img_fname = self.get_image(image_id=image_id) + im_width, im_height = image.size + boxes, labels, _ = self.get_boxes_and_labels( + image_id=image_id, image_width=im_width, image_height=im_height + ) + + data = {"image": image, "box_labels": labels, "box_coordinates": boxes} + + data = transform_fn(data) + + # convert to priors + anchors = self.generate_anchors(height=crop_size_h, width=crop_size_w) + + gt_coordinates, gt_labels = self.match_prior( + gt_boxes=data["box_coordinates"], + gt_labels=data["box_labels"], + anchors=anchors, + ) + + output_data = { + "samples": {"image": data.pop("image")}, + "targets": { + "box_labels": gt_labels, + "box_coordinates": gt_coordinates, + "image_id": torch.tensor(image_id), + "image_width": torch.tensor(im_width), + "image_height": torch.tensor(im_height), + }, + } + + return output_data + + def __repr__(self): + from utils.tensor_utils import image_size_from_opts + + im_h, im_w = image_size_from_opts(opts=self.opts) + + if self.is_training: + transforms_str = self._training_transforms(size=(im_h, im_w)) + elif self.is_evaluation: + transforms_str = self._evaluation_transforms(size=(im_h, im_w)) + else: + transforms_str = self._validation_transforms(size=(im_h, im_w)) + + return "{}(\n\troot={}\n\tis_training={}\n\tsamples={}\n\ttransforms={}\n\tmatcher={}\n\tanchor_gen={}\n)".format( + self.__class__.__name__, + self.root, + self.is_training, + len(self.ids), + transforms_str, + self.match_prior, + self.anchor_box_generator, + ) + + +@register_collate_fn(name="coco_ssd_collate_fn") +def coco_ssd_collate_fn(batch, opts): + new_batch = { + "samples": {"image": []}, + "targets": { + "box_labels": [], + "box_coordinates": [], + "image_id": [], + "image_width": [], + "image_height": [], + }, + } + + for b_id, batch_ in enumerate(batch): + # prepare inputs + new_batch["samples"]["image"].append(batch_["samples"]["image"]) + + # prepare outputs + new_batch["targets"]["box_labels"].append(batch_["targets"]["box_labels"]) 
+ new_batch["targets"]["box_coordinates"].append( + batch_["targets"]["box_coordinates"] + ) + new_batch["targets"]["image_id"].append(batch_["targets"]["image_id"]) + new_batch["targets"]["image_width"].append(batch_["targets"]["image_width"]) + new_batch["targets"]["image_height"].append(batch_["targets"]["image_height"]) + + # stack inputs + new_batch["samples"]["image"] = torch.stack(new_batch["samples"]["image"], dim=0) + + # stack outputs + new_batch["targets"]["box_labels"] = torch.stack( + new_batch["targets"]["box_labels"], dim=0 + ) + + new_batch["targets"]["box_coordinates"] = torch.stack( + new_batch["targets"]["box_coordinates"], dim=0 + ) + + new_batch["targets"]["image_id"] = torch.stack( + new_batch["targets"]["image_id"], dim=0 + ) + + new_batch["targets"]["image_width"] = torch.stack( + new_batch["targets"]["image_width"], dim=0 + ) + + new_batch["targets"]["image_height"] = torch.stack( + new_batch["targets"]["image_height"], dim=0 + ) + + return new_batch diff --git a/Adaptive Frequency Filters/data/datasets/segmentation/__init__.py b/Adaptive Frequency Filters/data/datasets/segmentation/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/Adaptive Frequency Filters/data/datasets/segmentation/ade20k.py b/Adaptive Frequency Filters/data/datasets/segmentation/ade20k.py new file mode 100644 index 0000000..65f7148 --- /dev/null +++ b/Adaptive Frequency Filters/data/datasets/segmentation/ade20k.py @@ -0,0 +1,521 @@ +# +# For licensing see accompanying LICENSE file. +# Copyright (C) 2023 Apple Inc. All Rights Reserved. +# + +import os +from typing import Optional, List, Dict, Tuple +import numpy as np + +from utils import logger + +from .. import register_dataset +from ..dataset_base import BaseImageDataset +from ...transforms import image_pil as T + + +@register_dataset(name="ade20k", task="segmentation") +class ADE20KDataset(BaseImageDataset): + """ + Dataset class for the ADE20K dataset + + The structure of the dataset should be something like this: :: + + ADEChallengeData2016/annotations/training/*.png + ADEChallengeData2016/annotations/validation/*.png + + ADEChallengeData2016/images/training/*.jpg + ADEChallengeData2016/images/validation/*.jpg + + Args: + opts: command-line arguments + is_training (Optional[bool]): A flag used to indicate training or validation mode. Default: True + is_evaluation (Optional[bool]): A flag used to indicate evaluation (or inference) mode. 
Default: False + + """ + + def __init__( + self, + opts, + is_training: Optional[bool] = True, + is_evaluation: Optional[bool] = False, + *args, + **kwargs + ) -> None: + """ + + :param opts: arguments + :param is_training: Training or validation mode + :param is_evaluation: Evaluation mode + """ + super().__init__( + opts=opts, is_training=is_training, is_evaluation=is_evaluation + ) + root = self.root + + image_dir = os.path.join( + root, "images", "training" if is_training else "validation" + ) + annotation_dir = os.path.join( + root, "annotations", "training" if is_training else "validation" + ) + + images = [] + masks = [] + for file_name in os.listdir(image_dir): + if file_name.endswith(".jpg"): + img_f_name = "{}/{}".format(image_dir, file_name) + mask_f_name = "{}/{}".format( + annotation_dir, file_name.replace("jpg", "png") + ) + + if os.path.isfile(img_f_name) and os.path.isfile(mask_f_name): + images.append(img_f_name) + masks.append(mask_f_name) + + self.images = images + self.masks = masks + self.ignore_label = 255 + self.bgrnd_idx = 0 + setattr( + opts, "model.segmentation.n_classes", len(self.class_names()) - 1 + ) # ignore background + + # set the collate functions for the dataset + # For evaluation, we use PyTorch's default collate function. So, we set to collate_fn_name_eval to None + setattr(opts, "dataset.collate_fn_name_train", "default_collate_fn") + setattr(opts, "dataset.collate_fn_name_val", "default_collate_fn") + setattr(opts, "dataset.collate_fn_name_eval", None) + + def _training_transforms(self, size: tuple): + first_aug = T.RandomShortSizeResize(opts=self.opts) + aug_list = [ + T.RandomHorizontalFlip(opts=self.opts), + T.RandomCrop(opts=self.opts, size=size, ignore_idx=self.ignore_label), + ] + + if getattr(self.opts, "image_augmentation.random_gaussian_noise.enable", False): + aug_list.append(T.RandomGaussianBlur(opts=self.opts)) + + if getattr(self.opts, "image_augmentation.photo_metric_distort.enable", False): + aug_list.append(T.PhotometricDistort(opts=self.opts)) + + if getattr(self.opts, "image_augmentation.random_rotate.enable", False): + aug_list.append(T.RandomRotate(opts=self.opts)) + + if getattr(self.opts, "image_augmentation.random_order.enable", False): + new_aug_list = [ + first_aug, + T.RandomOrder(opts=self.opts, img_transforms=aug_list), + T.ToTensor(opts=self.opts), + ] + return T.Compose(opts=self.opts, img_transforms=new_aug_list) + else: + aug_list.insert(0, first_aug) + aug_list.append(T.ToTensor(opts=self.opts)) + return T.Compose(opts=self.opts, img_transforms=aug_list) + + def _validation_transforms(self, size: tuple, *args, **kwargs): + aug_list = [T.Resize(opts=self.opts), T.ToTensor(opts=self.opts)] + return T.Compose(opts=self.opts, img_transforms=aug_list) + + def _evaluation_transforms(self, size: tuple, *args, **kwargs): + aug_list = [] + if getattr(self.opts, "evaluation.segmentation.resize_input_images", False): + # we want to resize while maintaining aspect ratio. 
So, we pass img_size argument to resize function + aug_list.append(T.Resize(opts=self.opts, img_size=min(size))) + + aug_list.append(T.ToTensor(opts=self.opts)) + return T.Compose(opts=self.opts, img_transforms=aug_list) + + def __getitem__(self, batch_indexes_tup: Tuple[int, int, int]) -> Dict: + crop_size_h, crop_size_w, img_index = batch_indexes_tup + crop_size = (crop_size_h, crop_size_w) + + if self.is_training: + _transform = self._training_transforms(size=crop_size) + elif self.is_evaluation: + _transform = self._evaluation_transforms(size=crop_size) + else: + _transform = self._validation_transforms(size=crop_size) + + mask = self.read_mask_pil(self.masks[img_index]) + img = self.read_image_pil(self.images[img_index]) + + if (img.size[0] != mask.size[0]) or (img.size[1] != mask.size[1]): + logger.error( + "Input image and mask sizes are different. Input size: {} and Mask size: {}".format( + img.size, mask.size + ) + ) + + data = {"image": img} + if not self.is_evaluation: + data["mask"] = mask + + data = _transform(data) + + if self.is_evaluation: + # for evaluation purposes, resize only the input and not mask + data["mask"] = self.convert_mask_to_tensor(mask) + + output_data = { + "samples": data["image"], + "targets": data["mask"] - 1, # ignore background during training + } + + if self.is_evaluation: + im_width, im_height = img.size + img_name = self.images[img_index].split(os.sep)[-1].replace("jpg", "png") + mask = output_data.pop("targets") + output_data["targets"] = { + "mask": mask, + "file_name": img_name, + "im_width": im_width, + "im_height": im_height, + } + + return output_data + + @staticmethod + def adjust_mask_value(): + return 1 + + def __len__(self) -> int: + return len(self.images) + + @staticmethod + def color_palette() -> List: + color_codes = [ + [0, 0, 0], # background + [120, 120, 120], + [180, 120, 120], + [6, 230, 230], + [80, 50, 50], + [4, 200, 3], + [120, 120, 80], + [140, 140, 140], + [204, 5, 255], + [230, 230, 230], + [4, 250, 7], + [224, 5, 255], + [235, 255, 7], + [150, 5, 61], + [120, 120, 70], + [8, 255, 51], + [255, 6, 82], + [143, 255, 140], + [204, 255, 4], + [255, 51, 7], + [204, 70, 3], + [0, 102, 200], + [61, 230, 250], + [255, 6, 51], + [11, 102, 255], + [255, 7, 71], + [255, 9, 224], + [9, 7, 230], + [220, 220, 220], + [255, 9, 92], + [112, 9, 255], + [8, 255, 214], + [7, 255, 224], + [255, 184, 6], + [10, 255, 71], + [255, 41, 10], + [7, 255, 255], + [224, 255, 8], + [102, 8, 255], + [255, 61, 6], + [255, 194, 7], + [255, 122, 8], + [0, 255, 20], + [255, 8, 41], + [255, 5, 153], + [6, 51, 255], + [235, 12, 255], + [160, 150, 20], + [0, 163, 255], + [140, 140, 140], + [250, 10, 15], + [20, 255, 0], + [31, 255, 0], + [255, 31, 0], + [255, 224, 0], + [153, 255, 0], + [0, 0, 255], + [255, 71, 0], + [0, 235, 255], + [0, 173, 255], + [31, 0, 255], + [11, 200, 200], + [255, 82, 0], + [0, 255, 245], + [0, 61, 255], + [0, 255, 112], + [0, 255, 133], + [255, 0, 0], + [255, 163, 0], + [255, 102, 0], + [194, 255, 0], + [0, 143, 255], + [51, 255, 0], + [0, 82, 255], + [0, 255, 41], + [0, 255, 173], + [10, 0, 255], + [173, 255, 0], + [0, 255, 153], + [255, 92, 0], + [255, 0, 255], + [255, 0, 245], + [255, 0, 102], + [255, 173, 0], + [255, 0, 20], + [255, 184, 184], + [0, 31, 255], + [0, 255, 61], + [0, 71, 255], + [255, 0, 204], + [0, 255, 194], + [0, 255, 82], + [0, 10, 255], + [0, 112, 255], + [51, 0, 255], + [0, 194, 255], + [0, 122, 255], + [0, 255, 163], + [255, 153, 0], + [0, 255, 10], + [255, 112, 0], + [143, 255, 0], + [82, 0, 255], + [163, 
255, 0], + [255, 235, 0], + [8, 184, 170], + [133, 0, 255], + [0, 255, 92], + [184, 0, 255], + [255, 0, 31], + [0, 184, 255], + [0, 214, 255], + [255, 0, 112], + [92, 255, 0], + [0, 224, 255], + [112, 224, 255], + [70, 184, 160], + [163, 0, 255], + [153, 0, 255], + [71, 255, 0], + [255, 0, 163], + [255, 204, 0], + [255, 0, 143], + [0, 255, 235], + [133, 255, 0], + [255, 0, 235], + [245, 0, 255], + [255, 0, 122], + [255, 245, 0], + [10, 190, 212], + [214, 255, 0], + [0, 204, 255], + [20, 0, 255], + [255, 255, 0], + [0, 153, 255], + [0, 41, 255], + [0, 255, 204], + [41, 0, 255], + [41, 255, 0], + [173, 0, 255], + [0, 245, 255], + [71, 0, 255], + [122, 0, 255], + [0, 255, 184], + [0, 92, 255], + [184, 255, 0], + [0, 133, 255], + [255, 214, 0], + [25, 194, 194], + [102, 255, 0], + [92, 0, 255], + ] + color_codes = np.asarray(color_codes).flatten() + return list(color_codes) + + @staticmethod + def class_names() -> List: + return [ + "background", + "wall", + "building", + "sky", + "floor", + "tree", + "ceiling", + "road", + "bed ", + "windowpane", + "grass", + "cabinet", + "sidewalk", + "person", + "earth", + "door", + "table", + "mountain", + "plant", + "curtain", + "chair", + "car", + "water", + "painting", + "sofa", + "shelf", + "house", + "sea", + "mirror", + "rug", + "field", + "armchair", + "seat", + "fence", + "desk", + "rock", + "wardrobe", + "lamp", + "bathtub", + "railing", + "cushion", + "base", + "box", + "column", + "signboard", + "chest of drawers", + "counter", + "sand", + "sink", + "skyscraper", + "fireplace", + "refrigerator", + "grandstand", + "path", + "stairs", + "runway", + "case", + "pool table", + "pillow", + "screen door", + "stairway", + "river", + "bridge", + "bookcase", + "blind", + "coffee table", + "toilet", + "flower", + "book", + "hill", + "bench", + "countertop", + "stove", + "palm", + "kitchen island", + "computer", + "swivel chair", + "boat", + "bar", + "arcade machine", + "hovel", + "bus", + "towel", + "light", + "truck", + "tower", + "chandelier", + "awning", + "streetlight", + "booth", + "television receiver", + "airplane", + "dirt track", + "apparel", + "pole", + "land", + "bannister", + "escalator", + "ottoman", + "bottle", + "buffet", + "poster", + "stage", + "van", + "ship", + "fountain", + "conveyer belt", + "canopy", + "washer", + "plaything", + "swimming pool", + "stool", + "barrel", + "basket", + "waterfall", + "tent", + "bag", + "minibike", + "cradle", + "oven", + "ball", + "food", + "step", + "tank", + "trade name", + "microwave", + "pot", + "animal", + "bicycle", + "lake", + "dishwasher", + "screen", + "blanket", + "sculpture", + "hood", + "sconce", + "vase", + "traffic light", + "tray", + "ashcan", + "fan", + "pier", + "crt screen", + "plate", + "monitor", + "bulletin board", + "shower", + "radiator", + "glass", + "clock", + "flag", + ] + + def __repr__(self) -> str: + from utils.tensor_utils import image_size_from_opts + + im_h, im_w = image_size_from_opts(opts=self.opts) + + if self.is_training: + transforms_str = self._training_transforms(size=(im_h, im_w)) + elif self.is_evaluation: + transforms_str = self._evaluation_transforms(size=(im_h, im_w)) + else: + transforms_str = self._validation_transforms(size=(im_h, im_w)) + + return ( + "{}(\n\troot={}\n\tis_training={}\n\tsamples={}\n\ttransforms={}\n)".format( + self.__class__.__name__, + self.root, + self.is_training, + len(self.images), + transforms_str, + ) + ) diff --git a/Adaptive Frequency Filters/data/datasets/segmentation/coco_segmentation.py b/Adaptive Frequency 
Filters/data/datasets/segmentation/coco_segmentation.py new file mode 100644 index 0000000..9fe8359 --- /dev/null +++ b/Adaptive Frequency Filters/data/datasets/segmentation/coco_segmentation.py @@ -0,0 +1,230 @@ +# +# For licensing see accompanying LICENSE file. +# Copyright (C) 2023 Apple Inc. All Rights Reserved. +# + +import os +from typing import Optional, List, Dict, Union +import argparse + +from pycocotools.coco import COCO +from pycocotools import mask +import numpy as np +import os +from typing import Optional + + +from .. import register_dataset +from ..dataset_base import BaseImageDataset +from ...transforms import image_pil as T + + +@register_dataset("coco", "segmentation") +class COCODataset(BaseImageDataset): + """ + Dataset class for the COCO dataset that maps classes to PASCAL VOC classes + + Args: + opts: command-line arguments + is_training (Optional[bool]): A flag used to indicate training or validation mode. Default: True + is_evaluation (Optional[bool]): A flag used to indicate evaluation (or inference) mode. Default: False + """ + + def __init__( + self, + opts, + is_training: Optional[bool] = True, + is_evaluation: Optional[bool] = False, + *args, + **kwargs + ) -> None: + """ + + :param opts: arguments + :param is_training: Training or validation mode + :param is_evaluation: Evaluation mode + """ + super().__init__( + opts=opts, is_training=is_training, is_evaluation=is_evaluation + ) + year = 2017 + split = "train" if is_training else "val" + ann_file = os.path.join( + self.root, "annotations/instances_{}{}.json".format(split, year) + ) + self.img_dir = os.path.join(self.root, "images/{}{}".format(split, year)) + self.split = split + self.coco = COCO(ann_file) + self.coco_mask = mask + self.ids = list(self.coco.imgs.keys()) + + self.ignore_label = 255 + self.bgrnd_idx = 0 + + setattr(opts, "model.segmentation.n_classes", len(self.class_names())) + + def __getitem__(self, batch_indexes_tup): + crop_size_h, crop_size_w, img_index = batch_indexes_tup + crop_size = (crop_size_h, crop_size_w) + + if self.is_training: + _transform = self._training_transforms( + size=crop_size, ignore_idx=self.ignore_label + ) + elif self.is_evaluation: + _transform = self._evaluation_transforms(size=crop_size) + else: + _transform = self._validation_transforms(size=crop_size) + + coco = self.coco + img_id = self.ids[img_index] + img_metadata = coco.loadImgs(img_id)[0] + path = img_metadata["file_name"] + + rgb_img = self.read_image_opencv(os.path.join(self.img_dir, path)) + cocotarget = coco.loadAnns(coco.getAnnIds(imgIds=img_id)) + + im_height, im_width = rgb_img.shape[:2] + + mask = self._gen_seg_mask( + cocotarget, img_metadata["height"], img_metadata["width"] + ) + + data = {"image": rgb_img, "mask": None if self.is_evaluation else mask} + + data = _transform(data) + + if self.is_evaluation: + # for evaluation purposes, resize only the input and not mask + data["mask"] = mask + + output_data = {"samples": data["image"], "targets": data["mask"]} + + if self.is_evaluation: + img_name = path.replace("jpg", "png") + mask = output_data.pop("targets") + output_data["targets"] = { + "mask": mask, + "file_name": img_name, + "im_width": im_width, + "im_height": im_height, + } + + return output_data + + def _gen_seg_mask(self, target, h, w): + mask = np.zeros((h, w), dtype=np.uint8) + coco_mask = self.coco_mask + coco_to_pascal = self.coco_to_pascal_mapping() + for instance in target: + rle = coco_mask.frPyObjects(instance["segmentation"], h, w) + m = coco_mask.decode(rle) + cat = 
instance["category_id"] + if cat in coco_to_pascal: + c = coco_to_pascal.index(cat) + else: + continue + if len(m.shape) < 3: + mask[:, :] += (mask == 0) * (m * c) + else: + mask[:, :] += (mask == 0) * (((np.sum(m, axis=2)) > 0) * c).astype( + np.uint8 + ) + return mask + + def _training_transforms(self, size: tuple, ignore_idx: Optional[int] = 255): + aug_list = [ + T.RandomResize(opts=self.opts), + T.RandomCrop(opts=self.opts, size=size), + T.RandomHorizontalFlip(opts=self.opts), + T.ToTensor(opts=self.opts), + ] + + return T.Compose(opts=self.opts, img_transforms=aug_list) + + def _validation_transforms(self, size: tuple, *args, **kwargs): + aug_list = [T.Resize(opts=self.opts), T.ToTensor(opts=self.opts)] + return T.Compose(opts=self.opts, img_transforms=aug_list) + + def _evaluation_transforms(self, size: tuple, *args, **kwargs): + aug_list = [] + if getattr(self.opts, "evaluation.segmentation.resize_input_images", False): + aug_list.append(T.Resize(opts=self.opts)) + + aug_list.append(T.ToTensor(opts=self.opts)) + return T.Compose(opts=self.opts, img_transforms=aug_list) + + def __len__(self): + return len(self.ids) + + @staticmethod + def class_names() -> List: + return [ + "background", + "aeroplane", + "bicycle", + "bird", + "boat", + "bottle", + "bus", + "car", + "cat", + "chair", + "cow", + "diningtable", + "dog", + "horse", + "motorbike", + "person", + "potted_plant", + "sheep", + "sofa", + "train", + "tv_monitor", + ] + + @staticmethod + def coco_to_pascal_mapping(): + return [ + 0, + 5, + 2, + 16, + 9, + 44, + 6, + 3, + 17, + 62, + 21, + 67, + 18, + 19, + 4, + 1, + 64, + 20, + 63, + 7, + 72, + ] + + def __repr__(self): + from utils.tensor_utils import image_size_from_opts + + im_h, im_w = image_size_from_opts(opts=self.opts) + + if self.is_training: + transforms_str = self._training_transforms(size=(im_h, im_w)) + elif self.is_evaluation: + transforms_str = self._evaluation_transforms(size=(im_h, im_w)) + else: + transforms_str = self._validation_transforms(size=(im_h, im_w)) + + return "{}(\n\troot={}\n\tis_training={}\n\tsamples={}\n\t\n\ttransforms={}\n)".format( + self.__class__.__name__, + self.root, + self.is_training, + len(self.ids), + transforms_str, + ) diff --git a/Adaptive Frequency Filters/data/datasets/segmentation/pascal_voc.py b/Adaptive Frequency Filters/data/datasets/segmentation/pascal_voc.py new file mode 100644 index 0000000..5ee858b --- /dev/null +++ b/Adaptive Frequency Filters/data/datasets/segmentation/pascal_voc.py @@ -0,0 +1,275 @@ +# +# For licensing see accompanying LICENSE file. +# Copyright (C) 2023 Apple Inc. All Rights Reserved. +# + +import os +from typing import Optional, List, Tuple, Dict +import argparse +import numpy as np + +from .. 
import register_dataset +from ..dataset_base import BaseImageDataset +from ...transforms import image_pil as T + + +@register_dataset("pascal", "segmentation") +class PascalVOCDataset(BaseImageDataset): + """ + Dataset class for the PASCAL VOC 2012 dataset + + The structure of PASCAL VOC dataset should be something like this: :: + + pascal_voc/VOCdevkit/VOC2012/Annotations + pascal_voc/VOCdevkit/VOC2012/JPEGImages + pascal_voc/VOCdevkit/VOC2012/SegmentationClass + pascal_voc/VOCdevkit/VOC2012/SegmentationClassAug_Visualization + pascal_voc/VOCdevkit/VOC2012/ImageSets + pascal_voc/VOCdevkit/VOC2012/list + pascal_voc/VOCdevkit/VOC2012/SegmentationClassAug + pascal_voc/VOCdevkit/VOC2012/SegmentationObject + + Args: + opts: command-line arguments + is_training (Optional[bool]): A flag used to indicate training or validation mode. Default: True + is_evaluation (Optional[bool]): A flag used to indicate evaluation (or inference) mode. Default: False + """ + + def __init__( + self, + opts, + is_training: Optional[bool] = True, + is_evaluation: Optional[bool] = False, + *args, + **kwargs + ) -> None: + super().__init__( + opts=opts, is_training=is_training, is_evaluation=is_evaluation + ) + use_coco_data = getattr(opts, "dataset.pascal.use_coco_data", False) + coco_root_dir = getattr(opts, "dataset.pascal.coco_root_dir", None) + root = self.root + + voc_root_dir = os.path.join(root, "VOC2012") + voc_list_dir = os.path.join(voc_root_dir, "list") + + coco_data_file = None + if self.is_training: + # use the PASCAL VOC 2012 train data with augmented data + data_file = os.path.join(voc_list_dir, "train_aug.txt") + if use_coco_data and coco_root_dir is not None: + coco_data_file = os.path.join(coco_root_dir, "train_2017.txt") + assert os.path.isfile( + coco_data_file + ), "COCO data file does not exist at: {}".format(coco_root_dir) + else: + data_file = os.path.join(voc_list_dir, "val.txt") + + self.images = [] + self.masks = [] + with open(data_file, "r") as lines: + for line in lines: + line_split = line.split(" ") + rgb_img_loc = voc_root_dir + os.sep + line_split[0].strip() + mask_img_loc = voc_root_dir + os.sep + line_split[1].strip() + assert os.path.isfile( + rgb_img_loc + ), "RGB file does not exist at: {}".format(rgb_img_loc) + assert os.path.isfile( + mask_img_loc + ), "Mask image does not exist at: {}".format(rgb_img_loc) + self.images.append(rgb_img_loc) + self.masks.append(mask_img_loc) + + # if you want to use Coarse data for training + if self.is_training and coco_data_file is not None: + with open(coco_data_file, "r") as lines: + for line in lines: + line_split = line.split(" ") + rgb_img_loc = coco_root_dir + os.sep + line_split[0].rstrip() + mask_img_loc = coco_root_dir + os.sep + line_split[1].rstrip() + # assert os.path.isfile(rgb_img_loc) + # assert os.path.isfile(mask_img_loc) + self.images.append(rgb_img_loc) + self.masks.append(mask_img_loc) + self.use_coco_data = use_coco_data + self.ignore_label = 255 + self.bgrnd_idx = 0 + setattr(opts, "model.segmentation.n_classes", len(self.class_names())) + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser): + group = parser.add_argument_group( + title="".format(cls.__name__), description="".format(cls.__name__) + ) + group.add_argument( + "--dataset.pascal.use-coco-data", + action="store_true", + help="Use MS-COCO data for training", + ) + group.add_argument( + "--dataset.pascal.coco-root-dir", + type=str, + default=None, + help="Location of MS-COCO data", + ) + return parser + + @staticmethod + def color_palette(): + 
color_codes = [ + [0, 0, 0], + [128, 0, 0], + [0, 128, 0], + [128, 128, 0], + [0, 0, 128], + [128, 0, 128], + [0, 128, 128], + [128, 128, 128], + [64, 0, 0], + [192, 0, 0], + [64, 128, 0], + [192, 128, 0], + [64, 0, 128], + [192, 0, 128], + [64, 128, 128], + [192, 128, 128], + [0, 64, 0], + [128, 64, 0], + [0, 192, 0], + [128, 192, 0], + [0, 64, 128], + ] + + color_codes = np.asarray(color_codes).flatten() + return list(color_codes) + + @staticmethod + def class_names() -> List: + return [ + "background", + "aeroplane", + "bicycle", + "bird", + "boat", + "bottle", + "bus", + "car", + "cat", + "chair", + "cow", + "diningtable", + "dog", + "horse", + "motorbike", + "person", + "potted_plant", + "sheep", + "sofa", + "train", + "tv_monitor", + ] + + def _training_transforms(self, size: tuple): + first_aug = T.RandomShortSizeResize(opts=self.opts) + aug_list = [ + T.RandomHorizontalFlip(opts=self.opts), + T.RandomCrop(opts=self.opts, size=size, ignore_idx=self.ignore_label), + ] + + if getattr(self.opts, "image_augmentation.random_gaussian_noise.enable", False): + aug_list.append(T.RandomGaussianBlur(opts=self.opts)) + + if getattr(self.opts, "image_augmentation.photo_metric_distort.enable", False): + aug_list.append(T.PhotometricDistort(opts=self.opts)) + + if getattr(self.opts, "image_augmentation.random_rotate.enable", False): + aug_list.append(T.RandomRotate(opts=self.opts)) + + if getattr(self.opts, "image_augmentation.random_order.enable", False): + new_aug_list = [ + first_aug, + T.RandomOrder(opts=self.opts, img_transforms=aug_list), + T.ToTensor(opts=self.opts), + ] + return T.Compose(opts=self.opts, img_transforms=new_aug_list) + else: + aug_list.insert(0, first_aug) + aug_list.append(T.ToTensor(opts=self.opts)) + return T.Compose(opts=self.opts, img_transforms=aug_list) + + def _validation_transforms(self, size: tuple, *args, **kwargs): + aug_list = [T.Resize(opts=self.opts), T.ToTensor(opts=self.opts)] + return T.Compose(opts=self.opts, img_transforms=aug_list) + + def _evaluation_transforms(self, size: tuple, *args, **kwargs): + aug_list = [] + if getattr(self.opts, "evaluation.segmentation.resize_input_images", False): + # we want to resize while maintaining aspect ratio. 
So, we pass img_size argument to resize function + aug_list.append(T.Resize(opts=self.opts, img_size=min(size))) + + aug_list.append(T.ToTensor(opts=self.opts)) + return T.Compose(opts=self.opts, img_transforms=aug_list) + + def __getitem__(self, batch_indexes_tup: Tuple) -> Dict: + crop_size_h, crop_size_w, img_index = batch_indexes_tup + crop_size = (crop_size_h, crop_size_w) + + if self.is_training: + _transform = self._training_transforms(size=crop_size) + elif self.is_evaluation: + _transform = self._evaluation_transforms(size=crop_size) + else: + _transform = self._validation_transforms(size=crop_size) + + img = self.read_image_pil(self.images[img_index]) + mask = self.read_mask_pil(self.masks[img_index]) + + data = {"image": img} + if not self.is_evaluation: + data["mask"] = mask + + data = _transform(data) + + if self.is_evaluation: + # for evaluation purposes, resize only the input and not mask + data["mask"] = self.convert_mask_to_tensor(mask) + + output_data = {"samples": data["image"], "targets": data["mask"]} + + if self.is_evaluation: + im_width, im_height = img.size + img_name = self.images[img_index].split(os.sep)[-1].replace("jpg", "png") + mask = output_data.pop("targets") + output_data["targets"] = { + "mask": mask, + "file_name": img_name, + "im_width": im_width, + "im_height": im_height, + } + + return output_data + + def __len__(self): + return len(self.images) + + def __repr__(self): + from utils.tensor_utils import image_size_from_opts + + im_h, im_w = image_size_from_opts(opts=self.opts) + + if self.is_training: + transforms_str = self._training_transforms(size=(im_h, im_w)) + elif self.is_evaluation: + transforms_str = self._evaluation_transforms(size=(im_h, im_w)) + else: + transforms_str = self._validation_transforms(size=(im_h, im_w)) + + return "{}(\n\troot={}\n\tis_training={}\n\tsamples={}\n\tuse_coco={}\n\ttransforms={}\n)".format( + self.__class__.__name__, + self.root, + self.is_training, + len(self.images), + self.use_coco_data, + transforms_str, + ) diff --git a/Adaptive Frequency Filters/data/loader/__init__.py b/Adaptive Frequency Filters/data/loader/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/Adaptive Frequency Filters/data/loader/dataloader.py b/Adaptive Frequency Filters/data/loader/dataloader.py new file mode 100644 index 0000000..d6c9f77 --- /dev/null +++ b/Adaptive Frequency Filters/data/loader/dataloader.py @@ -0,0 +1,54 @@ +# +# For licensing see accompanying LICENSE file. +# Copyright (C) 2023 Apple Inc. All Rights Reserved. 
+# + +import torch +from typing import Optional, Union, List +from torch.utils.data import DataLoader + +from ..sampler.base_sampler import BaseSamplerDP, BaseSamplerDDP +from ..datasets.dataset_base import BaseImageDataset + + +class affnetDataLoader(DataLoader): + """This class extends PyTorch's Dataloader""" + + def __init__( + self, + dataset: BaseImageDataset, + batch_size: int, + batch_sampler: Union[BaseSamplerDP, BaseSamplerDDP], + num_workers: Optional[int] = 1, + pin_memory: Optional[bool] = False, + persistent_workers: Optional[bool] = False, + collate_fn: Optional = None, + prefetch_factor: Optional[int] = 2, + *args, + **kwargs + ): + super(affnetDataLoader, self).__init__( + dataset=dataset, + batch_size=batch_size, + batch_sampler=batch_sampler, + num_workers=num_workers, + pin_memory=pin_memory, + persistent_workers=persistent_workers, + collate_fn=collate_fn, + prefetch_factor=prefetch_factor, + ) + + def update_indices(self, new_indices: List, *args, **kwargs): + """Update indices in the dataset class""" + if hasattr(self.batch_sampler, "img_indices") and hasattr( + self.batch_sampler, "update_indices" + ): + self.batch_sampler.update_indices(new_indices) + + def samples_in_dataset(self): + """Number of samples in the dataset""" + return len(self.batch_sampler.img_indices) + + def get_sample_indices(self) -> List: + """Sample IDs""" + return self.batch_sampler.img_indices diff --git a/Adaptive Frequency Filters/data/sampler/__init__.py b/Adaptive Frequency Filters/data/sampler/__init__.py new file mode 100644 index 0000000..3c1ece8 --- /dev/null +++ b/Adaptive Frequency Filters/data/sampler/__init__.py @@ -0,0 +1,116 @@ +# +# For licensing see accompanying LICENSE file. +# Copyright (C) 2023 Apple Inc. All Rights Reserved. +# + +import os +import importlib +from typing import Optional +from utils import logger +import argparse + +from utils.ddp_utils import is_master + +from .base_sampler import BaseSamplerDDP, BaseSamplerDP + +SAMPLER_REGISTRY = {} + + +def register_sampler(name): + def register_sampler_class(cls): + if name in SAMPLER_REGISTRY: + raise ValueError( + "Cannot register duplicate sampler class ({})".format(name) + ) + + if not (issubclass(cls, BaseSamplerDDP) or issubclass(cls, BaseSamplerDP)): + raise ValueError( + "Sampler ({}: {}) must extend BaseSamplerDDP or BaseSamplerDP".format( + name, cls.__name__ + ) + ) + + SAMPLER_REGISTRY[name] = cls + return cls + + return register_sampler_class + + +def build_sampler(opts, n_data_samples: int, is_training: Optional[bool] = False): + sampler_name = getattr(opts, "sampler.name", "variable_batch_sampler") + is_distributed = getattr(opts, "ddp.use_distributed", False) + + if is_distributed and sampler_name.split("_")[-1] != "ddp": + sampler_name = sampler_name + "_ddp" + + sampler = None + if sampler_name in SAMPLER_REGISTRY: + sampler = SAMPLER_REGISTRY[sampler_name]( + opts, n_data_samples=n_data_samples, is_training=is_training + ) + else: + supp_list = list(SAMPLER_REGISTRY.keys()) + supp_str = ( + "Sampler ({}) not yet supported. \n Supported optimizers are:".format( + sampler_name + ) + ) + for i, m_name in enumerate(supp_list): + supp_str += "\n\t {}: {}".format(i, logger.color_text(m_name)) + logger.error(supp_str) + + return sampler + + +def sampler_common_args(parser: argparse.ArgumentParser): + parser.add_argument( + "--sampler.name", type=str, default="batch_sampler", help="Name of the sampler" + ) + parser.add_argument( + "--sampler.use-shards", + action="store_true", + help="Use data sharding. 
Only applicable to DDP", + ) + parser.add_argument( + "--sampler.num-repeats", + type=int, + default=1, + help="Repeat samples, as in repeated augmentation", + ) + + parser.add_argument( + "--sampler.truncated-repeat-aug-sampler", + action="store_true", + help="Use truncated repeated augmentation sampler", + ) + + parser.add_argument( + "--sampler.disable-shuffle-sharding", + action="store_true", + help="Disable shuffling while sharding for extremely large datasets", + ) + + return parser + + +def arguments_sampler(parser: argparse.ArgumentParser): + parser = sampler_common_args(parser=parser) + + # add classification specific arguments + for k, v in SAMPLER_REGISTRY.items(): + parser = v.add_arguments(parser=parser) + + return parser + + +# automatically import the samplers +sampler_dir = os.path.dirname(__file__) +for file in os.listdir(sampler_dir): + path = os.path.join(sampler_dir, file) + if ( + not file.startswith("_") + and not file.startswith(".") + and (file.endswith(".py") or os.path.isdir(path)) + ): + sampler_name = file[: file.find(".py")] if file.endswith(".py") else file + module = importlib.import_module("data.sampler." + sampler_name) diff --git a/Adaptive Frequency Filters/data/sampler/base_sampler.py b/Adaptive Frequency Filters/data/sampler/base_sampler.py new file mode 100644 index 0000000..827c2f0 --- /dev/null +++ b/Adaptive Frequency Filters/data/sampler/base_sampler.py @@ -0,0 +1,295 @@ +# +# For licensing see accompanying LICENSE file. +# Copyright (C) 2023 Apple Inc. All Rights Reserved. +# + +import torch +from torch.utils.data.sampler import Sampler +from typing import Optional +import torch.distributed as dist +import math +import argparse +import copy +import numpy as np +import random + + +class BaseSamplerDP(Sampler): + """ + Base class for DataParallel Sampler + + Args: + opts: command line argument + n_data_samples (int): Number of samples in the dataset + is_training (Optional[bool]): Training or validation mode. Default: False + """ + + def __init__( + self, + opts, + n_data_samples: int, + is_training: Optional[bool] = False, + *args, + **kwargs + ) -> None: + # max between 1 and number of available GPUs. 
1 because for supporting CPUs + n_gpus: int = max(1, torch.cuda.device_count()) + batch_size_gpu0: int = ( + getattr(opts, "dataset.train_batch_size0", 32) + if is_training + else getattr(opts, "dataset.val_batch_size0", 32) + ) + + n_samples_per_gpu = int(math.ceil(n_data_samples * 1.0 / n_gpus)) + total_size = n_samples_per_gpu * n_gpus + + indexes = [idx for idx in range(n_data_samples)] + # This ensures that we can divide the batches evenly across GPUs + indexes += indexes[: (total_size - n_data_samples)] + assert total_size == len(indexes) + + self.img_indices = indexes + self.n_samples = total_size + self.batch_size_gpu0 = batch_size_gpu0 + self.n_gpus = n_gpus + self.shuffle = True if is_training else False + self.epoch = 0 + + self.num_repeats = getattr(opts, "sampler.num_repeats", 1) if is_training else 1 + self.trunc_rep_aug = getattr( + opts, "sampler.truncated_repeat_aug_sampler", False + ) + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser): + return parser + + def extra_repr(self): + extra_repr_str = "\n\t num_repeat={}" "\n\t trunc_rep_aug={}".format( + self.num_repeats, self.trunc_rep_aug + ) + return extra_repr_str + + def get_indices(self): + img_indices = copy.deepcopy(self.img_indices) + if self.shuffle: + random.seed(self.epoch) + random.shuffle(img_indices) + + if self.num_repeats > 1: + # Apply repeated augmentation + """Assume that we have [0, 1, 2, 3] samples. With repeated augmentation, + we first repeat the samples [0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3] and then select 4 + samples [0, 0, 0, 1]. Note that we do shuffle at the beginning, so samples are not the + same at every iteration. + """ + n_samples_before_repeat = len(img_indices) + img_indices = np.repeat(img_indices, repeats=self.num_repeats) + img_indices = list(img_indices) + if self.trunc_rep_aug: + img_indices = img_indices[:n_samples_before_repeat] + return img_indices + + def __iter__(self): + raise NotImplementedError + + def __len__(self): + return len(self.img_indices) * (1 if self.trunc_rep_aug else self.num_repeats) + + def set_epoch(self, epoch): + self.epoch = epoch + + def update_scales(self, epoch, is_master_node=False, *args, **kwargs): + pass + + def update_indices(self, new_indices): + self.img_indices = new_indices + + def __repr__(self): + return "{}()".format(self.__class__.__name__) + + +class BaseSamplerDDP(Sampler): + """ + Base class for DistributedDataParallel Sampler + + Args: + opts: command line argument + n_data_samples (int): Number of samples in the dataset + is_training (Optional[bool]): Training or validation mode. Default: False + """ + + def __init__( + self, + opts, + n_data_samples: int, + is_training: Optional[bool] = False, + *args, + **kwargs + ) -> None: + # max between 1 and number of available GPUs. 
1 because for supporting CPUs + batch_size_gpu0: int = ( + getattr(opts, "dataset.train_batch_size0", 32) + if is_training + else getattr(opts, "dataset.val_batch_size0", 32) + ) + + if not dist.is_available(): + raise RuntimeError("Requires distributed package to be available") + + num_replicas = dist.get_world_size() + rank = dist.get_rank() + gpus_node_i = max(1, torch.cuda.device_count()) + + num_samples_per_replica = int(math.ceil(n_data_samples * 1.0 / num_replicas)) + total_size = num_samples_per_replica * num_replicas + + img_indices = [idx for idx in range(n_data_samples)] + img_indices += img_indices[: (total_size - n_data_samples)] + assert len(img_indices) == total_size + + self.img_indices = img_indices + self.n_samples_per_replica = num_samples_per_replica + self.shuffle = True if is_training else False + self.epoch = 0 + self.rank = rank + self.batch_size_gpu0 = batch_size_gpu0 + self.num_replicas = num_replicas + self.skip_sample_indices = [] + self.node_id = rank // gpus_node_i + + self.num_nodes = max(1, num_replicas // gpus_node_i) + self.local_rank = rank % gpus_node_i + self.num_gpus_node_i = gpus_node_i + + self.sharding = ( + getattr(opts, "sampler.use_shards", False) if is_training else False + ) + self.num_repeats = getattr(opts, "sampler.num_repeats", 1) if is_training else 1 + self.trunc_rep_aug = ( + getattr(opts, "sampler.truncated_repeat_aug_sampler", False) + if self.num_repeats + else False + ) + self.n_samples_per_replica = num_samples_per_replica * ( + 1 if self.trunc_rep_aug else self.num_repeats + ) + self.disable_shuffle_sharding = getattr( + opts, "sampler.disable_shuffle_sharding", False + ) + + def extra_repr(self): + extra_repr_str = ( + "\n\t num_repeat={}" + "\n\t trunc_rep_aug={}" + "\n\t sharding={}" + "\n\t disable_shuffle_sharding={}".format( + self.num_repeats, + self.trunc_rep_aug, + self.sharding, + self.disable_shuffle_sharding, + ) + ) + return extra_repr_str + + def get_indices_rank_i(self): + img_indices = copy.deepcopy(self.img_indices) + if self.shuffle: + random.seed(self.epoch) + + if self.sharding: + """If we have 8 samples, say [0, 1, 2, 3, 4, 5, 6, 7], and we have two nodes, + then node 0 will receive first 4 samples and node 1 will receive last 4 samples. + + note: + This strategy is useful when dataset is large and we want to process subset of dataset on each node. + """ + + # compute number pf samples per node. + # Each node may have multiple GPUs + # Node id = rank // num_gpus_per_rank + samples_per_node = int(math.ceil(len(img_indices) / self.num_nodes)) + indices_node_i = img_indices[ + self.node_id + * samples_per_node : (self.node_id + 1) + * samples_per_node + ] + + # Ensure that each node has equal number of samples + if len(indices_node_i) < samples_per_node: + indices_node_i += indices_node_i[ + : (samples_per_node - len(indices_node_i)) + ] + + # Note: For extremely large datasets, we may want to disable shuffling for efficient data loading + if not self.disable_shuffle_sharding: + # shuffle the indices within a node. + random.shuffle(indices_node_i) + + if self.num_repeats > 1: + """Assume that we have [0, 1, 2, 3] samples in rank_i. With repeated augmentation, + we first repeat the samples [0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3] and then select 4 + samples [0, 0, 0, 1]. 
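The repeated-augmentation scheme described in this docstring can be illustrated with a standalone sketch (hypothetical helper, not part of the patch):

import numpy as np

def repeat_and_truncate(indices, num_repeats=3, truncate=True):
    # [0, 1, 2, 3] -> [0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3]; truncation keeps the original length.
    n_before = len(indices)
    repeated = np.repeat(indices, repeats=num_repeats).tolist()
    return repeated[:n_before] if truncate else repeated

print(repeat_and_truncate([0, 1, 2, 3]))  # [0, 0, 0, 1], matching the example above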
Note shuffling at the beginning + """ + # Apply repeated augmentation + n_samples_before_repeat = len(indices_node_i) + indices_node_i = np.repeat(indices_node_i, repeats=self.num_repeats) + indices_node_i = list(indices_node_i) + if self.trunc_rep_aug: + indices_node_i = indices_node_i[:n_samples_before_repeat] + + # divide the samples among each GPU in a node + indices_rank_i = indices_node_i[ + self.local_rank : len(indices_node_i) : self.num_gpus_node_i + ] + else: + """If we have 8 samples, say [0, 1, 2, 3, 4, 5, 6, 7], and we have two nodes, + then node 0 will receive [0, 2, 4, 6] and node 1 will receive [1, 3, 4, 7]. + + note: + This strategy is useful when each data sample is stored independently, and is + default in many frameworks + """ + random.shuffle(img_indices) + + if self.num_repeats > 1: + # Apply repeated augmentation + n_samples_before_repeat = len(img_indices) + img_indices = np.repeat(img_indices, repeats=self.num_repeats) + img_indices = list(img_indices) + if self.trunc_rep_aug: + img_indices = img_indices[:n_samples_before_repeat] + + # divide the samples among each GPU in a node + indices_rank_i = img_indices[ + self.rank : len(img_indices) : self.num_replicas + ] + else: + indices_rank_i = img_indices[ + self.rank : len(self.img_indices) : self.num_replicas + ] + return indices_rank_i + + def __iter__(self): + raise NotImplementedError + + def __len__(self): + return (len(self.img_indices) // self.num_replicas) * ( + 1 if self.trunc_rep_aug else self.num_repeats + ) + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser): + return parser + + def set_epoch(self, epoch): + self.epoch = epoch + + def update_scales(self, epoch, is_master_node=False, *args, **kwargs): + pass + + def update_indices(self, new_indices): + self.img_indices = new_indices + + def __repr__(self): + return "{}()".format(self.__class__.__name__) diff --git a/Adaptive Frequency Filters/data/sampler/batch_sampler.py b/Adaptive Frequency Filters/data/sampler/batch_sampler.py new file mode 100644 index 0000000..a12bf79 --- /dev/null +++ b/Adaptive Frequency Filters/data/sampler/batch_sampler.py @@ -0,0 +1,155 @@ +# +# For licensing see accompanying LICENSE file. +# Copyright (C) 2023 Apple Inc. All Rights Reserved. +# +import copy +import random +import argparse +from typing import Optional +import math +import numpy as np + +from common import DEFAULT_IMAGE_WIDTH, DEFAULT_IMAGE_HEIGHT + +from . import register_sampler, BaseSamplerDDP, BaseSamplerDP + + +@register_sampler(name="batch_sampler") +class BatchSampler(BaseSamplerDP): + """ + Standard Batch Sampler for data parallel + + Args: + opts: command line argument + n_data_samples (int): Number of samples in the dataset + is_training (Optional[bool]): Training or validation mode. 
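The two index-partitioning strategies implemented in get_indices_rank_i above can be summarized, for the simple single-GPU-per-node case, by the following sketch (hypothetical helper, not part of the patch):

def partition_indices(indices, num_replicas, rank, use_shards):
    if use_shards:
        # contiguous shard per replica: [0..7] with 2 replicas -> rank 0: [0, 1, 2, 3], rank 1: [4, 5, 6, 7]
        per_replica = len(indices) // num_replicas
        return indices[rank * per_replica:(rank + 1) * per_replica]
    # strided (interleaved) split: [0..7] with 2 replicas -> rank 0: [0, 2, 4, 6], rank 1: [1, 3, 5, 7]
    return indices[rank:len(indices):num_replicas]

print(partition_indices(list(range(8)), num_replicas=2, rank=0, use_shards=True))   # [0, 1, 2, 3]
print(partition_indices(list(range(8)), num_replicas=2, rank=1, use_shards=False))  # [1, 3, 5, 7]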
Default: False + """ + + def __init__( + self, + opts, + n_data_samples: int, + is_training: Optional[bool] = False, + *args, + **kwargs + ) -> None: + super().__init__( + opts=opts, n_data_samples=n_data_samples, is_training=is_training + ) + crop_size_w: int = getattr( + opts, "sampler.bs.crop_size_width", DEFAULT_IMAGE_WIDTH + ) + crop_size_h: int = getattr( + opts, "sampler.bs.crop_size_height", DEFAULT_IMAGE_HEIGHT + ) + + self.crop_size_w = crop_size_w + self.crop_size_h = crop_size_h + + def __iter__(self): + img_indices = self.get_indices() + + start_index = 0 + batch_size = self.batch_size_gpu0 + n_samples = len(img_indices) + while start_index < n_samples: + + end_index = min(start_index + batch_size, n_samples) + batch_ids = img_indices[start_index:end_index] + start_index += batch_size + + if len(batch_ids) > 0: + batch = [ + (self.crop_size_h, self.crop_size_w, b_id) for b_id in batch_ids + ] + yield batch + + def __repr__(self): + repr_str = "{}(".format(self.__class__.__name__) + repr_str += "\n\tbase_im_size=(h={}, w={})" "\n\tbase_batch_size={}".format( + self.crop_size_h, self.crop_size_w, self.batch_size_gpu0 + ) + repr_str += self.extra_repr() + repr_str += "\n)" + return repr_str + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser): + group = parser.add_argument_group( + title="Batch sampler", description="Arguments related to Batch sampler" + ) + group.add_argument( + "--sampler.bs.crop-size-width", + default=DEFAULT_IMAGE_WIDTH, + type=int, + help="Base crop size (along width) during training", + ) + group.add_argument( + "--sampler.bs.crop-size-height", + default=DEFAULT_IMAGE_HEIGHT, + type=int, + help="Base crop size (along height) during training", + ) + return parser + + +@register_sampler(name="batch_sampler_ddp") +class BatchSamplerDDP(BaseSamplerDDP): + """ + Standard Batch Sampler for distributed data parallel + + Args: + opts: command line argument + n_data_samples (int): Number of samples in the dataset + is_training (Optional[bool]): Training or validation mode. 
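Each batch yielded by these batch samplers is a list of (crop_height, crop_width, sample_index) tuples sharing one crop size, which the dataset's __getitem__ is expected to consume. A purely illustrative batch for a 224x224 crop and a per-GPU batch size of 4:

batch = [(224, 224, 0), (224, 224, 1), (224, 224, 2), (224, 224, 3)]
for crop_h, crop_w, sample_id in batch:
    assert (crop_h, crop_w) == (224, 224)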
Default: False + """ + + def __init__( + self, + opts, + n_data_samples: int, + is_training: Optional[bool] = False, + *args, + **kwargs + ) -> None: + super().__init__( + opts=opts, n_data_samples=n_data_samples, is_training=is_training + ) + crop_size_w: int = getattr( + opts, "sampler.bs.crop_size_width", DEFAULT_IMAGE_WIDTH + ) + crop_size_h: int = getattr( + opts, "sampler.bs.crop_size_height", DEFAULT_IMAGE_HEIGHT + ) + + self.crop_size_w = crop_size_w + self.crop_size_h = crop_size_h + + def __iter__(self): + indices_rank_i = self.get_indices_rank_i() + start_index = 0 + batch_size = self.batch_size_gpu0 + + n_samples_rank_i = len(indices_rank_i) + while start_index < n_samples_rank_i: + end_index = min(start_index + batch_size, n_samples_rank_i) + batch_ids = indices_rank_i[start_index:end_index] + n_batch_samples = len(batch_ids) + if n_batch_samples != batch_size: + batch_ids += indices_rank_i[: (batch_size - n_batch_samples)] + start_index += batch_size + + if len(batch_ids) > 0: + batch = [ + (self.crop_size_h, self.crop_size_w, b_id) for b_id in batch_ids + ] + yield batch + + def __repr__(self): + repr_str = "{}(".format(self.__class__.__name__) + repr_str += "\n\tbase_im_size=(h={}, w={})" "\n\tbase_batch_size={}".format( + self.crop_size_h, self.crop_size_w, self.batch_size_gpu0 + ) + repr_str += self.extra_repr() + repr_str += "\n)" + return repr_str diff --git a/Adaptive Frequency Filters/data/sampler/multi_scale_sampler.py b/Adaptive Frequency Filters/data/sampler/multi_scale_sampler.py new file mode 100644 index 0000000..411423b --- /dev/null +++ b/Adaptive Frequency Filters/data/sampler/multi_scale_sampler.py @@ -0,0 +1,339 @@ +# +# For licensing see accompanying LICENSE file. +# Copyright (C) 2023 Apple Inc. All Rights Reserved. +# +import copy +import random +import argparse +from utils import logger +from typing import Optional +from common import DEFAULT_IMAGE_WIDTH, DEFAULT_IMAGE_HEIGHT +import numpy as np + +from . import register_sampler, BaseSamplerDP, BaseSamplerDDP +from .utils import _image_batch_pairs + + +@register_sampler(name="multi_scale_sampler") +class MultiScaleSampler(BaseSamplerDP): + """ + Multi-scale Batch Sampler for data parallel + + Args: + opts: command line argument + n_data_samples (int): Number of samples in the dataset + is_training (Optional[bool]): Training or validation mode. 
Default: False + """ + + def __init__( + self, + opts, + n_data_samples: int, + is_training: Optional[bool] = False, + *args, + **kwargs + ) -> None: + super().__init__( + opts=opts, n_data_samples=n_data_samples, is_training=is_training + ) + + crop_size_w: int = getattr( + opts, "sampler.msc.crop_size_width", DEFAULT_IMAGE_WIDTH + ) + crop_size_h: int = getattr( + opts, "sampler.msc.crop_size_height", DEFAULT_IMAGE_HEIGHT + ) + + min_crop_size_w: int = getattr(opts, "sampler.msc.min_crop_size_width", 160) + max_crop_size_w: int = getattr(opts, "sampler.msc.max_crop_size_width", 320) + + min_crop_size_h: int = getattr(opts, "sampler.msc.min_crop_size_height", 160) + max_crop_size_h: int = getattr(opts, "sampler.msc.max_crop_size_height", 320) + + scale_inc: bool = getattr(opts, "sampler.msc.scale_inc", False) + scale_ep_intervals: list or int = getattr( + opts, "sampler.msc.ep_intervals", [40] + ) + scale_inc_factor: float = getattr(opts, "sampler.msc.scale_inc_factor", 0.25) + + check_scale_div_factor: int = getattr(opts, "sampler.msc.check_scale", 32) + max_img_scales: int = getattr(opts, "sampler.msc.max_n_scales", 10) + + if isinstance(scale_ep_intervals, int): + scale_ep_intervals = [scale_ep_intervals] + + self.min_crop_size_w = min_crop_size_w + self.max_crop_size_w = max_crop_size_w + self.min_crop_size_h = min_crop_size_h + self.max_crop_size_h = max_crop_size_h + + self.crop_size_w = crop_size_w + self.crop_size_h = crop_size_h + + self.scale_inc_factor = scale_inc_factor + self.scale_ep_intervals = scale_ep_intervals + + self.max_img_scales = max_img_scales + self.check_scale_div_factor = check_scale_div_factor + self.scale_inc = scale_inc + + if is_training: + self.img_batch_tuples = _image_batch_pairs( + crop_size_h=self.crop_size_h, + crop_size_w=self.crop_size_w, + batch_size_gpu0=self.batch_size_gpu0, + n_gpus=self.n_gpus, + max_scales=self.max_img_scales, + check_scale_div_factor=self.check_scale_div_factor, + min_crop_size_w=self.min_crop_size_w, + max_crop_size_w=self.max_crop_size_w, + min_crop_size_h=self.min_crop_size_h, + max_crop_size_h=self.max_crop_size_h, + ) + # over-ride the batch-size + self.img_batch_tuples = [ + (h, w, self.batch_size_gpu0) for h, w, b in self.img_batch_tuples + ] + else: + self.img_batch_tuples = [(crop_size_h, crop_size_w, self.batch_size_gpu0)] + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser): + group = parser.add_argument_group( + title="Multi-scale sampler", description="Multi-scale sampler" + ) + group.add_argument( + "--sampler.msc.crop-size-width", + default=DEFAULT_IMAGE_WIDTH, + type=int, + help="Base crop size (along width) during training", + ) + group.add_argument( + "--sampler.msc.crop-size-height", + default=DEFAULT_IMAGE_HEIGHT, + type=int, + help="Base crop size (along height) during training", + ) + + group.add_argument( + "--sampler.msc.min-crop-size-width", + default=160, + type=int, + help="Min. crop size along width during training", + ) + group.add_argument( + "--sampler.msc.max-crop-size-width", + default=320, + type=int, + help="Max. crop size along width during training", + ) + + group.add_argument( + "--sampler.msc.min-crop-size-height", + default=160, + type=int, + help="Min. crop size along height during training", + ) + group.add_argument( + "--sampler.msc.max-crop-size-height", + default=320, + type=int, + help="Max. crop size along height during training", + ) + group.add_argument( + "--sampler.msc.max-n-scales", + default=5, + type=int, + help="Max. scales in variable batch sampler. 
For example, [0.25, 0.5, 0.75, 1, 1.25] ", + ) + group.add_argument( + "--sampler.msc.check-scale", + default=32, + type=int, + help="Image scales should be divisible by this factor", + ) + group.add_argument( + "--sampler.msc.ep-intervals", + default=[40], + type=int, + help="Epoch intervals at which scales are adjusted", + ) + group.add_argument( + "--sampler.msc.scale-inc-factor", + default=0.25, + type=float, + help="Factor by which we should increase the scale", + ) + group.add_argument( + "--sampler.msc.scale-inc", + action="store_true", + help="Increase image scales during training", + ) + + return parser + + def __iter__(self): + img_indices = self.get_indices() + start_index = 0 + n_samples = len(img_indices) + while start_index < n_samples: + crop_h, crop_w, batch_size = random.choice(self.img_batch_tuples) + + end_index = min(start_index + batch_size, n_samples) + batch_ids = img_indices[start_index:end_index] + n_batch_samples = len(batch_ids) + if len(batch_ids) != batch_size: + batch_ids += img_indices[: (batch_size - n_batch_samples)] + start_index += batch_size + + if len(batch_ids) > 0: + batch = [(crop_h, crop_w, b_id) for b_id in batch_ids] + yield batch + + def update_scales(self, epoch, is_master_node=False, *args, **kwargs): + pass + + def __repr__(self): + repr_str = "{}(".format(self.__class__.__name__) + repr_str += ( + "\n\t base_im_size=(h={}, w={}), " + "\n\t base_batch_size={} " + "\n\t scales={} " + "\n\t scale_inc={} " + "\n\t scale_inc_factor={} " + "\n\t ep_intervals={}".format( + self.crop_size_h, + self.crop_size_w, + self.batch_size_gpu0, + self.img_batch_tuples, + self.scale_inc, + self.scale_inc_factor, + self.scale_ep_intervals, + ) + ) + repr_str += self.extra_repr() + repr_str += "\n)" + return repr_str + + +@register_sampler(name="multi_scale_sampler_ddp") +class MultiScaleSamplerDDP(BaseSamplerDDP): + """ + Multi-scale Batch Sampler for distributed data parallel + + Args: + opts: command line argument + n_data_samples (int): Number of samples in the dataset + is_training (Optional[bool]): Training or validation mode. 
Default: False + """ + + def __init__( + self, + opts, + n_data_samples: int, + is_training: Optional[bool] = False, + *args, + **kwargs + ) -> None: + super().__init__( + opts=opts, n_data_samples=n_data_samples, is_training=is_training + ) + crop_size_w: int = getattr( + opts, "sampler.msc.crop_size_width", DEFAULT_IMAGE_WIDTH + ) + crop_size_h: int = getattr( + opts, "sampler.msc.crop_size_height", DEFAULT_IMAGE_HEIGHT + ) + + min_crop_size_w: int = getattr(opts, "sampler.msc.min_crop_size_width", 160) + max_crop_size_w: int = getattr(opts, "sampler.msc.max_crop_size_width", 320) + + min_crop_size_h: int = getattr(opts, "sampler.msc.min_crop_size_height", 160) + max_crop_size_h: int = getattr(opts, "sampler.msc.max_crop_size_height", 320) + + scale_inc: bool = getattr(opts, "sampler.msc.scale_inc", False) + scale_ep_intervals: list or int = getattr( + opts, "sampler.msc.ep_intervals", [40] + ) + scale_inc_factor: float = getattr(opts, "sampler.msc.scale_inc_factor", 0.25) + check_scale_div_factor: int = getattr(opts, "sampler.msc.check_scale", 32) + + max_img_scales: int = getattr(opts, "sampler.msc.max_n_scales", 10) + + self.crop_size_h = crop_size_h + self.crop_size_w = crop_size_w + self.min_crop_size_h = min_crop_size_h + self.max_crop_size_h = max_crop_size_h + self.min_crop_size_w = min_crop_size_w + self.max_crop_size_w = max_crop_size_w + + self.scale_inc_factor = scale_inc_factor + self.scale_ep_intervals = scale_ep_intervals + self.max_img_scales = max_img_scales + self.check_scale_div_factor = check_scale_div_factor + self.scale_inc = scale_inc + + if is_training: + self.img_batch_tuples = _image_batch_pairs( + crop_size_h=self.crop_size_h, + crop_size_w=self.crop_size_w, + batch_size_gpu0=self.batch_size_gpu0, + n_gpus=self.num_replicas, + max_scales=self.max_img_scales, + check_scale_div_factor=self.check_scale_div_factor, + min_crop_size_w=self.min_crop_size_w, + max_crop_size_w=self.max_crop_size_w, + min_crop_size_h=self.min_crop_size_h, + max_crop_size_h=self.max_crop_size_h, + ) + self.img_batch_tuples = [ + (h, w, self.batch_size_gpu0) for h, w, b in self.img_batch_tuples + ] + else: + self.img_batch_tuples = [ + (self.crop_size_h, self.crop_size_w, self.batch_size_gpu0) + ] + + def __iter__(self): + indices_rank_i = self.get_indices_rank_i() + + start_index = 0 + n_samples_rank_i = len(indices_rank_i) + while start_index < n_samples_rank_i: + crop_h, crop_w, batch_size = random.choice(self.img_batch_tuples) + + end_index = min(start_index + batch_size, n_samples_rank_i) + batch_ids = indices_rank_i[start_index:end_index] + n_batch_samples = len(batch_ids) + if n_batch_samples != batch_size: + batch_ids += indices_rank_i[: (batch_size - n_batch_samples)] + start_index += batch_size + + if len(batch_ids) > 0: + batch = [(crop_h, crop_w, b_id) for b_id in batch_ids] + yield batch + + def update_scales(self, epoch, is_master_node=False, *args, **kwargs): + pass + + def __repr__(self): + repr_str = "{}(".format(self.__class__.__name__) + repr_str += ( + "\n\t base_im_size=(h={}, w={}), " + "\n\t base_batch_size={} " + "\n\t scales={} " + "\n\t scale_inc={} " + "\n\t scale_inc_factor={} " + "\n\t ep_intervals={}".format( + self.crop_size_h, + self.crop_size_w, + self.batch_size_gpu0, + self.img_batch_tuples, + self.scale_inc, + self.scale_inc_factor, + self.scale_ep_intervals, + ) + ) + repr_str += self.extra_repr() + repr_str += "\n )" + return repr_str diff --git a/Adaptive Frequency Filters/data/sampler/utils.py b/Adaptive Frequency Filters/data/sampler/utils.py new 
file mode 100644 index 0000000..936e780 --- /dev/null +++ b/Adaptive Frequency Filters/data/sampler/utils.py @@ -0,0 +1,124 @@ +# +# For licensing see accompanying LICENSE file. +# Copyright (C) 2023 Apple Inc. All Rights Reserved. +# + +from typing import Optional, List +import numpy as np + +from utils.math_utils import make_divisible + + +def _image_batch_pairs( + crop_size_w: int, + crop_size_h: int, + batch_size_gpu0: int, + n_gpus: int, + max_scales: Optional[float] = 5, + check_scale_div_factor: Optional[int] = 32, + min_crop_size_w: Optional[int] = 160, + max_crop_size_w: Optional[int] = 320, + min_crop_size_h: Optional[int] = 160, + max_crop_size_h: Optional[int] = 320, + *args, + **kwargs +) -> List: + """ + This function creates batch and image size pairs. For a given batch size and image size, different image sizes + are generated and batch size is adjusted so that GPU memory can be utilized efficiently. + + Args: + crop_size_w (int): Base Image width (e.g., 224) + crop_size_h (int): Base Image height (e.g., 224) + batch_size_gpu0 (int): Batch size on GPU 0 for base image + n_gpus (int): Number of available GPUs + max_scales (Optional[int]): Number of scales. How many image sizes that we want to generate between min and max scale factors. Default: 5 + check_scale_div_factor (Optional[int]): Check if image scales are divisible by this factor. Default: 32 + min_crop_size_w (Optional[int]): Min. crop size along width. Default: 160 + max_crop_size_w (Optional[int]): Max. crop size along width. Default: 320 + min_crop_size_h (Optional[int]): Min. crop size along height. Default: 160 + max_crop_size_h (Optional[int]): Max. crop size along height. Default: 320 + + Returns: + a sorted list of tuples. Each index is of the form (h, w, batch_size) + + """ + width_dims = list(np.linspace(min_crop_size_w, max_crop_size_w, max_scales)) + if crop_size_w not in width_dims: + width_dims.append(crop_size_w) + + height_dims = list(np.linspace(min_crop_size_h, max_crop_size_h, max_scales)) + if crop_size_h not in height_dims: + height_dims.append(crop_size_h) + + image_scales = set() + + for h, w in zip(height_dims, width_dims): + # ensure that sampled sizes are divisible by check_scale_div_factor + # This is important in some cases where input undergoes a fixed number of down-sampling stages + # for instance, in ImageNet training, CNNs usually have 5 downsampling stages, which downsamples the + # input image of resolution 224x224 to 7x7 size + h = make_divisible(h, check_scale_div_factor) + w = make_divisible(w, check_scale_div_factor) + image_scales.add((h, w)) + + image_scales = list(image_scales) + + img_batch_tuples = set() + n_elements = crop_size_w * crop_size_h * batch_size_gpu0 + for (crop_h, crop_y) in image_scales: + # compute the batch size for sampled image resolutions with respect to the base resolution + _bsz = max(1, int(round(n_elements / (crop_h * crop_y), 2))) + + img_batch_tuples.add((crop_h, crop_y, _bsz)) + + img_batch_tuples = list(img_batch_tuples) + return sorted(img_batch_tuples) + + +def make_video_pairs( + crop_size_h: int, + crop_size_w: int, + min_crop_size_h: int, + max_crop_size_h: int, + min_crop_size_w: int, + max_crop_size_w: int, + default_frames: int, + max_scales: Optional[int] = 5, + check_scale_div_factor: Optional[int] = 32, + *args, + **kwargs +) -> List: + """ + This function creates number of frames and spatial size pairs for videos. 
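A hedged usage sketch for _image_batch_pairs, assuming the repository root ("Adaptive Frequency Filters") is on PYTHONPATH so the data and utils packages resolve; the exact tuples depend on make_divisible rounding, so the comment states the invariant rather than concrete values:

from data.sampler.utils import _image_batch_pairs

pairs = _image_batch_pairs(
    crop_size_w=224, crop_size_h=224,
    batch_size_gpu0=32, n_gpus=1,
    max_scales=5, check_scale_div_factor=32,
    min_crop_size_w=160, max_crop_size_w=320,
    min_crop_size_h=160, max_crop_size_h=320,
)
# Each entry is (height, width, batch_size); smaller crops receive proportionally larger
# batch sizes so that height * width * batch_size stays close to 224 * 224 * 32.
for h, w, bsz in pairs:
    print(h, w, bsz)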
+ + Args: + crop_size_h (int): Base Image height (e.g., 224) + crop_size_w (int): Base Image width (e.g., 224) + min_crop_size_w (int): Min. crop size along width. + max_crop_size_w (int): Max. crop size along width. + min_crop_size_h (int): Min. crop size along height. + max_crop_size_h (int): Max. crop size along height. + default_frames (int): Default number of frames per clip in a video. + max_scales (Optional[int]): Number of scales. Default: 5 + check_scale_div_factor (Optional[int]): Check if spatial scales are divisible by this factor. Default: 32 + Returns: + a sorted list of tuples. Each index is of the form (h, w, n_frames) + """ + + width_dims = list(np.linspace(min_crop_size_w, max_crop_size_w, max_scales)) + if crop_size_w not in width_dims: + width_dims.append(crop_size_w) + height_dims = list(np.linspace(min_crop_size_h, max_crop_size_h, max_scales)) + if crop_size_h not in height_dims: + height_dims.append(crop_size_h) + + # ensure that spatial dimensions are divisible by check_scale_div_factor + width_dims = [make_divisible(w, check_scale_div_factor) for w in width_dims] + height_dims = [make_divisible(h, check_scale_div_factor) for h in height_dims] + batch_pairs = set() + n_elements = crop_size_w * crop_size_h * default_frames + for (h, w) in zip(height_dims, width_dims): + n_frames = max(1, int(round(n_elements / (h * w), 2))) + batch_pairs.add((h, w, n_frames)) + return sorted(list(batch_pairs)) diff --git a/Adaptive Frequency Filters/data/sampler/variable_batch_sampler.py b/Adaptive Frequency Filters/data/sampler/variable_batch_sampler.py new file mode 100644 index 0000000..a5a7449 --- /dev/null +++ b/Adaptive Frequency Filters/data/sampler/variable_batch_sampler.py @@ -0,0 +1,421 @@ +# +# For licensing see accompanying LICENSE file. +# Copyright (C) 2023 Apple Inc. All Rights Reserved. +# +import copy +import random +import argparse +from typing import Optional +import numpy as np +import math + +from utils import logger +from common import DEFAULT_IMAGE_WIDTH, DEFAULT_IMAGE_HEIGHT + +from . import register_sampler, BaseSamplerDP, BaseSamplerDDP +from .utils import _image_batch_pairs + + +@register_sampler(name="variable_batch_sampler") +class VariableBatchSampler(BaseSamplerDP): + """ + `Variably-size multi-scale batch sampler ` for data parallel + + Args: + opts: command line argument + n_data_samples (int): Number of samples in the dataset + is_training (Optional[bool]): Training or validation mode. 
Default: False + """ + + def __init__( + self, + opts, + n_data_samples: int, + is_training: Optional[bool] = False, + *args, + **kwargs + ) -> None: + super().__init__( + opts=opts, n_data_samples=n_data_samples, is_training=is_training + ) + + crop_size_w: int = getattr( + opts, "sampler.vbs.crop_size_width", DEFAULT_IMAGE_WIDTH + ) + crop_size_h: int = getattr( + opts, "sampler.vbs.crop_size_height", DEFAULT_IMAGE_HEIGHT + ) + + min_crop_size_w: int = getattr(opts, "sampler.vbs.min_crop_size_width", 160) + max_crop_size_w: int = getattr(opts, "sampler.vbs.max_crop_size_width", 320) + + min_crop_size_h: int = getattr(opts, "sampler.vbs.min_crop_size_height", 160) + max_crop_size_h: int = getattr(opts, "sampler.vbs.max_crop_size_height", 320) + + scale_inc: bool = getattr(opts, "sampler.vbs.scale_inc", False) + scale_ep_intervals: list or int = getattr( + opts, "sampler.vbs.ep_intervals", [40] + ) + min_scale_inc_factor: float = getattr( + opts, "sampler.vbs.min_scale_inc_factor", 1.0 + ) + max_scale_inc_factor: float = getattr( + opts, "sampler.vbs.max_scale_inc_factor", 1.0 + ) + + check_scale_div_factor: int = getattr(opts, "sampler.vbs.check_scale", 32) + max_img_scales: int = getattr(opts, "sampler.vbs.max_n_scales", 10) + + if isinstance(scale_ep_intervals, int): + scale_ep_intervals = [scale_ep_intervals] + + self.min_crop_size_w = min_crop_size_w + self.max_crop_size_w = max_crop_size_w + self.min_crop_size_h = min_crop_size_h + self.max_crop_size_h = max_crop_size_h + + self.crop_size_w = crop_size_w + self.crop_size_h = crop_size_h + + self.min_scale_inc_factor = min_scale_inc_factor + self.max_scale_inc_factor = max_scale_inc_factor + self.scale_ep_intervals = scale_ep_intervals + + self.max_img_scales = max_img_scales + self.check_scale_div_factor = check_scale_div_factor + self.scale_inc = scale_inc + + if is_training: + self.img_batch_tuples = _image_batch_pairs( + crop_size_h=self.crop_size_h, + crop_size_w=self.crop_size_w, + batch_size_gpu0=self.batch_size_gpu0, + n_gpus=self.n_gpus, + max_scales=self.max_img_scales, + check_scale_div_factor=self.check_scale_div_factor, + min_crop_size_w=self.min_crop_size_w, + max_crop_size_w=self.max_crop_size_w, + min_crop_size_h=self.min_crop_size_h, + max_crop_size_h=self.max_crop_size_h, + ) + else: + self.img_batch_tuples = [(crop_size_h, crop_size_w, self.batch_size_gpu0)] + + def __iter__(self): + img_indices = self.get_indices() + start_index = 0 + n_samples = len(img_indices) + while start_index < n_samples: + crop_h, crop_w, batch_size = random.choice(self.img_batch_tuples) + + end_index = min(start_index + batch_size, n_samples) + batch_ids = img_indices[start_index:end_index] + n_batch_samples = len(batch_ids) + if len(batch_ids) != batch_size: + batch_ids += img_indices[: (batch_size - n_batch_samples)] + start_index += batch_size + + if len(batch_ids) > 0: + batch = [(crop_h, crop_w, b_id) for b_id in batch_ids] + yield batch + + def update_scales(self, epoch, is_master_node=False, *args, **kwargs): + if epoch in self.scale_ep_intervals and self.scale_inc: + self.min_crop_size_w += int( + self.min_crop_size_w * self.min_scale_inc_factor + ) + self.max_crop_size_w += int( + self.max_crop_size_w * self.max_scale_inc_factor + ) + + self.min_crop_size_h += int( + self.min_crop_size_h * self.min_scale_inc_factor + ) + self.max_crop_size_h += int( + self.max_crop_size_h * self.max_scale_inc_factor + ) + + self.img_batch_tuples = _image_batch_pairs( + crop_size_h=self.crop_size_h, + crop_size_w=self.crop_size_w, + 
batch_size_gpu0=self.batch_size_gpu0, + n_gpus=self.n_gpus, + max_scales=self.max_img_scales, + check_scale_div_factor=self.check_scale_div_factor, + min_crop_size_w=self.min_crop_size_w, + max_crop_size_w=self.max_crop_size_w, + min_crop_size_h=self.min_crop_size_h, + max_crop_size_h=self.max_crop_size_h, + ) + if is_master_node: + logger.log("Scales updated in {}".format(self.__class__.__name__)) + logger.log("New scales: {}".format(self.img_batch_tuples)) + + def __repr__(self): + repr_str = "{}(".format(self.__class__.__name__) + repr_str += ( + "\n\t base_im_size=(h={}, w={}), " + "\n\t base_batch_size={} " + "\n\t scales={} " + "\n\t scale_inc={} " + "\n\t min_scale_inc_factor={} " + "\n\t max_scale_inc_factor={} " + "\n\t ep_intervals={}".format( + self.crop_size_h, + self.crop_size_w, + self.batch_size_gpu0, + self.img_batch_tuples, + self.scale_inc, + self.min_scale_inc_factor, + self.max_scale_inc_factor, + self.scale_ep_intervals, + ) + ) + repr_str += self.extra_repr() + repr_str += "\n)" + return repr_str + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser): + group = parser.add_argument_group( + title="Variable batch sampler", + description="Arguments related to variable batch sampler", + ) + group.add_argument( + "--sampler.vbs.crop-size-width", + default=DEFAULT_IMAGE_WIDTH, + type=int, + help="Base crop size (along width) during training", + ) + group.add_argument( + "--sampler.vbs.crop-size-height", + default=DEFAULT_IMAGE_HEIGHT, + type=int, + help="Base crop size (along height) during training", + ) + + group.add_argument( + "--sampler.vbs.min-crop-size-width", + default=160, + type=int, + help="Min. crop size along width during training", + ) + group.add_argument( + "--sampler.vbs.max-crop-size-width", + default=320, + type=int, + help="Max. crop size along width during training", + ) + + group.add_argument( + "--sampler.vbs.min-crop-size-height", + default=160, + type=int, + help="Min. crop size along height during training", + ) + group.add_argument( + "--sampler.vbs.max-crop-size-height", + default=320, + type=int, + help="Max. crop size along height during training", + ) + group.add_argument( + "--sampler.vbs.max-n-scales", + default=5, + type=int, + help="Max. scales in variable batch sampler. For example, [0.25, 0.5, 0.75, 1, 1.25] ", + ) + group.add_argument( + "--sampler.vbs.check-scale", + default=32, + type=int, + help="Image scales should be divisible by this factor", + ) + group.add_argument( + "--sampler.vbs.ep-intervals", + default=[40], + type=int, + help="Epoch intervals at which scales are adjusted", + ) + group.add_argument( + "--sampler.vbs.min-scale-inc-factor", + default=1.0, + type=float, + help="Factor by which we should increase the minimum scale", + ) + group.add_argument( + "--sampler.vbs.max-scale-inc-factor", + default=1.0, + type=float, + help="Factor by which we should increase the maximum scale", + ) + group.add_argument( + "--sampler.vbs.scale-inc", + action="store_true", + help="Increase image scales during training", + ) + + return parser + + +@register_sampler(name="variable_batch_sampler_ddp") +class VariableBatchSamplerDDP(BaseSamplerDDP): + """ + `Variably-size multi-scale batch sampler ` for distributed + data parallel + + Args: + opts: command line argument + n_data_samples (int): Number of samples in the dataset + is_training (Optional[bool]): Training or validation mode. 
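For reference, with the default factors of 1.0 each epoch listed in --sampler.vbs.ep-intervals doubles the crop-size bounds when --sampler.vbs.scale-inc is enabled; a worked version of the update statements above (values are illustrative):

min_crop_size_w, max_crop_size_w = 160, 320
min_scale_inc_factor = max_scale_inc_factor = 1.0
min_crop_size_w += int(min_crop_size_w * min_scale_inc_factor)  # 160 -> 320
max_crop_size_w += int(max_crop_size_w * max_scale_inc_factor)  # 320 -> 640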
Default: False + """ + + def __init__( + self, + opts, + n_data_samples: int, + is_training: Optional[bool] = False, + *args, + **kwargs + ) -> None: + """ + + :param opts: arguments + :param n_data_samples: number of data samples in the dataset + :param is_training: Training or evaluation mode (eval mode includes validation mode) + """ + super().__init__( + opts=opts, n_data_samples=n_data_samples, is_training=is_training + ) + crop_size_w: int = getattr( + opts, "sampler.vbs.crop_size_width", DEFAULT_IMAGE_WIDTH + ) + crop_size_h: int = getattr( + opts, "sampler.vbs.crop_size_height", DEFAULT_IMAGE_HEIGHT + ) + + min_crop_size_w: int = getattr(opts, "sampler.vbs.min_crop_size_width", 160) + max_crop_size_w: int = getattr(opts, "sampler.vbs.max_crop_size_width", 320) + + min_crop_size_h: int = getattr(opts, "sampler.vbs.min_crop_size_height", 160) + max_crop_size_h: int = getattr(opts, "sampler.vbs.max_crop_size_height", 320) + + scale_inc: bool = getattr(opts, "sampler.vbs.scale_inc", False) + scale_ep_intervals: list or int = getattr( + opts, "sampler.vbs.ep_intervals", [40] + ) + min_scale_inc_factor: float = getattr( + opts, "sampler.vbs.min_scale_inc_factor", 1.0 + ) + max_scale_inc_factor: float = getattr( + opts, "sampler.vbs.max_scale_inc_factor", 1.0 + ) + check_scale_div_factor: int = getattr(opts, "sampler.vbs.check_scale", 32) + + max_img_scales: int = getattr(opts, "sampler.vbs.max_n_scales", 10) + + self.crop_size_h = crop_size_h + self.crop_size_w = crop_size_w + self.min_crop_size_h = min_crop_size_h + self.max_crop_size_h = max_crop_size_h + self.min_crop_size_w = min_crop_size_w + self.max_crop_size_w = max_crop_size_w + + self.min_scale_inc_factor = min_scale_inc_factor + self.max_scale_inc_factor = max_scale_inc_factor + self.scale_ep_intervals = scale_ep_intervals + self.max_img_scales = max_img_scales + self.check_scale_div_factor = check_scale_div_factor + self.scale_inc = scale_inc + + if is_training: + self.img_batch_tuples = _image_batch_pairs( + crop_size_h=self.crop_size_h, + crop_size_w=self.crop_size_w, + batch_size_gpu0=self.batch_size_gpu0, + n_gpus=self.num_replicas, + max_scales=self.max_img_scales, + check_scale_div_factor=self.check_scale_div_factor, + min_crop_size_w=self.min_crop_size_w, + max_crop_size_w=self.max_crop_size_w, + min_crop_size_h=self.min_crop_size_h, + max_crop_size_h=self.max_crop_size_h, + ) + else: + self.img_batch_tuples = [ + (self.crop_size_h, self.crop_size_w, self.batch_size_gpu0) + ] + + def __iter__(self): + indices_rank_i = self.get_indices_rank_i() + start_index = 0 + n_samples_rank_i = len(indices_rank_i) + while start_index < n_samples_rank_i: + crop_h, crop_w, batch_size = random.choice(self.img_batch_tuples) + + end_index = min(start_index + batch_size, n_samples_rank_i) + batch_ids = indices_rank_i[start_index:end_index] + n_batch_samples = len(batch_ids) + if n_batch_samples != batch_size: + batch_ids += indices_rank_i[: (batch_size - n_batch_samples)] + start_index += batch_size + + if len(batch_ids) > 0: + batch = [(crop_h, crop_w, b_id) for b_id in batch_ids] + yield batch + + def update_scales(self, epoch, is_master_node=False, *args, **kwargs): + if (epoch in self.scale_ep_intervals) and self.scale_inc: # Training mode + self.min_crop_size_w += int( + self.min_crop_size_w * self.min_scale_inc_factor + ) + self.max_crop_size_w += int( + self.max_crop_size_w * self.max_scale_inc_factor + ) + + self.min_crop_size_h += int( + self.min_crop_size_h * self.min_scale_inc_factor + ) + self.max_crop_size_h += int( + 
self.max_crop_size_h * self.max_scale_inc_factor + ) + + self.img_batch_tuples = _image_batch_pairs( + crop_size_h=self.crop_size_h, + crop_size_w=self.crop_size_w, + batch_size_gpu0=self.batch_size_gpu0, + n_gpus=self.num_replicas, + max_scales=self.max_img_scales, + check_scale_div_factor=self.check_scale_div_factor, + min_crop_size_w=self.min_crop_size_w, + max_crop_size_w=self.max_crop_size_w, + min_crop_size_h=self.min_crop_size_h, + max_crop_size_h=self.max_crop_size_h, + ) + if is_master_node: + logger.log("Scales updated in {}".format(self.__class__.__name__)) + logger.log("New scales: {}".format(self.img_batch_tuples)) + + def __repr__(self): + repr_str = "{}(".format(self.__class__.__name__) + repr_str += ( + "\n\t base_im_size=(h={}, w={}), " + "\n\t base_batch_size={} " + "\n\t scales={} " + "\n\t scale_inc={} " + "\n\t min_scale_inc_factor={} " + "\n\t max_scale_inc_factor={} " + "\n\t ep_intervals={} ".format( + self.crop_size_h, + self.crop_size_w, + self.batch_size_gpu0, + self.img_batch_tuples, + self.scale_inc, + self.min_scale_inc_factor, + self.max_scale_inc_factor, + self.scale_ep_intervals, + ) + ) + repr_str += self.extra_repr() + repr_str += "\n )" + return repr_str diff --git a/Adaptive Frequency Filters/data/transforms/__init__.py b/Adaptive Frequency Filters/data/transforms/__init__.py new file mode 100644 index 0000000..ebb9c05 --- /dev/null +++ b/Adaptive Frequency Filters/data/transforms/__init__.py @@ -0,0 +1,56 @@ +# +# For licensing see accompanying LICENSE file. +# Copyright (C) 2023 Apple Inc. All Rights Reserved. +# + +import os +import importlib +import argparse + +from .base_transforms import BaseTransformation + +SUPPORTED_AUG_CATEGORIES = [] +AUGMENTAION_REGISTRY = {} + + +def register_transformations(name, type): + def register_transformation_class(cls): + if name in AUGMENTAION_REGISTRY: + raise ValueError( + "Cannot register duplicate transformation class ({})".format(name) + ) + + if not issubclass(cls, BaseTransformation): + raise ValueError( + "Transformation ({}: {}) must extend BaseTransformation".format( + name, cls.__name__ + ) + ) + + AUGMENTAION_REGISTRY[name + "_" + type] = cls + return cls + + return register_transformation_class + + +def arguments_augmentation(parser: argparse.ArgumentParser) -> argparse.ArgumentParser: + + # add augmentation specific arguments + for k, v in AUGMENTAION_REGISTRY.items(): + parser = v.add_arguments(parser=parser) + + return parser + + +# automatically import the augmentations +transform_dir = os.path.dirname(__file__) + +for file in os.listdir(transform_dir): + path = os.path.join(transform_dir, file) + if ( + not file.startswith("_") + and not file.startswith(".") + and (file.endswith(".py") or os.path.isdir(path)) + ): + transform_name = file[: file.find(".py")] if file.endswith(".py") else file + module = importlib.import_module("data.transforms." + transform_name) diff --git a/Adaptive Frequency Filters/data/transforms/base_transforms.py b/Adaptive Frequency Filters/data/transforms/base_transforms.py new file mode 100644 index 0000000..b76fb01 --- /dev/null +++ b/Adaptive Frequency Filters/data/transforms/base_transforms.py @@ -0,0 +1,26 @@ +# +# For licensing see accompanying LICENSE file. +# Copyright (C) 2023 Apple Inc. All Rights Reserved. 
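The transformation registry above keys entries by name plus type (e.g. "image"). A hypothetical custom transform (class and registry names invented for illustration, assuming the repository root is on PYTHONPATH) would be registered as below, subclassing BaseTransformation from base_transforms.py that follows and keeping the dict-in/dict-out convention:

from typing import Dict

from data.transforms import register_transformations
from data.transforms.base_transforms import BaseTransformation

@register_transformations(name="identity", type="image")
class Identity(BaseTransformation):
    def __init__(self, opts, *args, **kwargs) -> None:
        super().__init__(opts=opts)

    def __call__(self, data: Dict) -> Dict:
        # `data` carries "image" and, depending on the task, "mask", "box_coordinates", etc.
        return data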
+# + +import argparse +from typing import Dict + + +class BaseTransformation(object): + """ + Base class for augmentation methods + """ + + def __init__(self, opts, *args, **kwargs) -> None: + self.opts = opts + + def __call__(self, data: Dict) -> Dict: + raise NotImplementedError + + def __repr__(self) -> str: + return "{}()".format(self.__class__.__name__) + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser) -> argparse.ArgumentParser: + return parser diff --git a/Adaptive Frequency Filters/data/transforms/image_opencv.py b/Adaptive Frequency Filters/data/transforms/image_opencv.py new file mode 100644 index 0000000..d1f9c67 --- /dev/null +++ b/Adaptive Frequency Filters/data/transforms/image_opencv.py @@ -0,0 +1,1760 @@ +# +# For licensing see accompanying LICENSE file. +# Copyright (C) 2023 Apple Inc. All Rights Reserved. +# + +import cv2 +from typing import Optional +import numpy as np +import random +import torch +import math +import argparse +from typing import Sequence, Dict, Any, Union, Tuple + +from utils import logger + +from .utils import jaccard_numpy +from . import register_transformations, BaseTransformation + +# This file is for compatibility with affnet_v0.1. In future, we won't maintain it. Please use functions from +# image_pil.py + + +_str_to_cv2_interpolation = { + "nearest": cv2.INTER_NEAREST, + "bilinear": cv2.INTER_LINEAR, + "cubic": cv2.INTER_CUBIC, +} + +_cv2_to_str_interpolation = { + cv2.INTER_NEAREST: "nearest", + cv2.INTER_LINEAR: "bilinear", + cv2.INTER_CUBIC: "cubic", +} + +_str_to_cv2_pad = { + "constant": cv2.BORDER_CONSTANT, + "edge": cv2.BORDER_REPLICATE, + "reflect": cv2.BORDER_REFLECT_101, + "symmetric": cv2.BORDER_REFLECT, +} + + +def _cv2_interpolation(interpolation): + if interpolation not in _str_to_cv2_interpolation: + interpolate_modes = list(_str_to_cv2_interpolation.keys()) + inter_str = "Supported interpolation modes are:" + for i, j in enumerate(interpolate_modes): + inter_str += "\n\t{}: {}".format(i, j) + logger.error(inter_str) + return _str_to_cv2_interpolation[interpolation] + + +def _cv2_padding(pad_mode): + if pad_mode not in _str_to_cv2_pad: + pad_modes = list(_str_to_cv2_pad.keys()) + pad_mode_str = "Supported padding modes are:" + for i, j in enumerate(pad_modes): + pad_mode_str += "\n\t{}: {}".format(i, j) + logger.error(pad_mode_str) + return _str_to_cv2_pad[pad_mode] + + +def _crop_fn(data: Dict, i: int, j: int, h: int, w: int): + img = data["image"] + crop_image = img[i : i + h, j : j + w] + data["image"] = crop_image + + if "mask" in data: + mask = data.pop("mask") + crop_mask = mask[i : i + h, j : j + w] + data["mask"] = crop_mask + + if "box_coordinates" in data: + boxes = data.pop("box_coordinates") + + area_before_cropping = (boxes[..., 2] - boxes[..., 0]) * ( + boxes[..., 3] - boxes[..., 1] + ) + + boxes[..., 0::2] = np.clip(boxes[..., 0::2] - j, a_min=0, a_max=j + w) + boxes[..., 1::2] = np.clip(boxes[..., 1::2] - i, a_min=0, a_max=i + h) + + area_after_cropping = (boxes[..., 2] - boxes[..., 0]) * ( + boxes[..., 3] - boxes[..., 1] + ) + area_ratio = area_after_cropping / (area_before_cropping + 1) + + # keep the boxes whose area is atleast 20% of the area before cropping + keep = area_ratio >= 0.2 + + box_labels = data.pop("box_labels") + + data["box_coordinates"] = boxes[keep] + data["box_labels"] = box_labels[keep] + + if "instance_mask" in data: + assert "instance_coords" in data + + instance_masks = data.pop("instance_mask") + data["instance_mask"] = instance_masks[i : i + h, j : j + w] + + 
instance_coords = data.pop("instance_coords") + instance_coords[..., 0::2] = np.clip( + instance_coords[..., 0::2] - j, a_min=0, a_max=j + w + ) + instance_coords[..., 1::2] = np.clip( + instance_coords[..., 1::2] - i, a_min=0, a_max=i + h + ) + data["instance_coords"] = instance_coords + + return data + + +def _resize_fn( + data: Dict, size: Union[Sequence, int], interpolation: Optional[str] = "bilinear" +): + img = data["image"] + h, w = img.shape[:2] + + if isinstance(size, Sequence) and len(size) == 2: + size_h, size_w = size[0], size[1] + elif isinstance(size, int): + if (w <= h and w == size) or (h <= w and h == size): + return data + + if w < h: + size_h = int(size * h / w) + + size_w = size + else: + size_w = int(size * w / h) + size_h = size + else: + raise TypeError( + "Supported size args are int or tuple of length 2. Got inappropriate size arg: {}".format( + size + ) + ) + if isinstance(interpolation, str): + interpolation = _str_to_cv2_interpolation[interpolation] + img = cv2.resize(img, dsize=(size_w, size_h), interpolation=interpolation) + data["image"] = img + + if "mask" in data: + mask = data.pop("mask") + resized_mask = cv2.resize( + mask, dsize=(size_w, size_h), interpolation=cv2.INTER_NEAREST + ) + # this occurs when input is (H, W, 1) + if len(resized_mask.shape) != len(mask.shape): + resized_mask = resized_mask[..., None] + + data["mask"] = resized_mask + + if "box_coordinates" in data: + boxes = data.pop("box_coordinates") + boxes[:, 0::2] *= 1.0 * size_w / w + boxes[:, 1::2] *= 1.0 * size_h / h + data["box_coordinates"] = boxes + + if "instance_mask" in data: + assert "instance_coords" in data + + instance_masks = data.pop("instance_mask") + + resized_instance_masks = cv2.resize( + instance_masks, dsize=(size_w, size_h), interpolation=cv2.INTER_NEAREST + ) + if len(instance_masks.shape) != len(resized_instance_masks.shape): + resized_instance_masks = resized_instance_masks[..., None] + data["instance_mask"] = resized_instance_masks + + instance_coords = data.pop("instance_coords") + instance_coords = instance_coords.astype(np.float) + instance_coords[..., 0::2] *= 1.0 * size_w / w + instance_coords[..., 1::2] *= 1.0 * size_h / h + data["instance_coords"] = instance_coords + + return data + + +def setup_size(size: Any, error_msg="Need a tuple of length 2"): + if isinstance(size, int): + return size, size + + if isinstance(size, (list, tuple)) and len(size) == 1: + return size[0], size[0] + + if len(size) != 2: + raise ValueError(error_msg) + + return size + + +@register_transformations(name="random_gamma_correction", type="image") +class RandomGammaCorrection(BaseTransformation): + def __init__(self, opts): + gamma_range = getattr( + opts, "image_augmentation.random_gamma_correction.gamma", (0.25, 1.75) + ) + p = getattr(opts, "image_augmentation.random_gamma_correction.p", 0.5) + super(RandomGammaCorrection, self).__init__(opts=opts) + self.gamma = setup_size(gamma_range) + self.p = p + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser): + group = parser.add_argument_group( + title="".format(cls.__name__), description="".format(cls.__name__) + ) + + group.add_argument( + "--image-augmentation.random-gamma-correction.enable", + action="store_true", + help="use gamma correction", + ) + group.add_argument( + "--image-augmentation.random-gamma-correction.gamma", + type=float or tuple, + default=(0.5, 1.5), + help="Gamma range", + ) + group.add_argument( + "--image-augmentation.random-gamma-correction.p", + type=float, + default=0.5, + 
help="Probability that {} will be applied".format(cls.__name__), + ) + return parser + + def __call__(self, data: Dict) -> Dict: + if random.random() <= self.p: + img = data["image"] + gamma = random.uniform(self.gamma[0], self.gamma[1]) + table = np.array( + [((i / 255.0) ** gamma) * 255 for i in np.arange(0, 256)] + ).astype("uint8") + img = cv2.LUT(img, table) + data["image"] = img + return data + + def __repr__(self): + return "{}(gamma={}, p={})".format(self.__class__.__name__, self.gamma, self.p) + + +@register_transformations(name="random_resize", type="image") +class RandomResize(BaseTransformation): + def __init__(self, opts): + min_size = getattr(opts, "image_augmentation.random_resize.min_size", 256) + max_size = getattr(opts, "image_augmentation.random_resize.max_size", 1024) + interpolation = getattr( + opts, "image-augmentation.random_resize.interpolation", "bilinear" + ) + super(RandomResize, self).__init__(opts=opts) + self.min_size = min_size + self.max_size = max_size + self.interpolation = _cv2_interpolation(interpolation=interpolation) + + def __call__(self, data: Dict) -> Dict: + random_size = random.randint(self.min_size, self.max_size) + return _resize_fn(data, size=random_size, interpolation=self.interpolation) + + def __repr__(self): + return "{}(min_size={}, max_size={}, interpolation={})".format( + self.__class__.__name__, + self.min_size, + self.max_size, + _cv2_to_str_interpolation[self.interpolation], + ) + + +@register_transformations(name="random_zoom_out", type="image") +class RandomZoomOut(BaseTransformation): + def __init__(self, opts, size: Optional[Sequence or int] = None): + side_range = getattr( + opts, "image_augmentation.random_zoom_out.side_range", [1, 4] + ) + p = getattr(opts, "image_augmentation.random_zoom_out.p", 0.5) + super(RandomZoomOut, self).__init__(opts=opts) + self.fill = 0.5 + self.side_range = side_range + self.p = p + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser): + group = parser.add_argument_group( + title="".format(cls.__name__), description="".format(cls.__name__) + ) + group.add_argument( + "--image-augmentation.random-zoom-out.enable", + action="store_true", + help="Use random scale", + ) + group.add_argument( + "--image-augmentation.random-zoom-out.side-range", + type=list or tuple, + default=[1, 4], + help="Side range", + ) + group.add_argument( + "--image-augmentation.random-zoom-out.p", + type=float, + default=0.5, + help="Probability of applying RandomZoomOut transformation", + ) + return parser + + def zoom_out( + self, image: np.ndarray, boxes: Optional[np.ndarray] = None + ) -> Tuple[np.ndarray, Optional[np.ndarray]]: + height, width, depth = image.shape + ratio = random.uniform(self.side_range[0], self.side_range[1]) + left = int(random.uniform(0, width * ratio - width)) + top = int(random.uniform(0, height * ratio - height)) + + expand_image = ( + np.ones((int(height * ratio), int(width * ratio), depth), dtype=image.dtype) + * self.fill + ) + expand_image[top : top + height, left : left + width] = image + + expand_boxes = None + if boxes is not None: + expand_boxes = boxes.copy() + expand_boxes[:, :2] += (left, top) + expand_boxes[:, 2:] += (left, top) + + return expand_image, expand_boxes + + def __call__(self, data: Dict) -> Dict: + if random.random() > self.p: + return data + img = data["image"] + + boxes = None + if "box_coordinates" in data: + boxes = data.pop("box_coordinates") + + img, boxes = self.zoom_out(image=img, boxes=boxes) + + data["image"] = img + data["box_coordinates"] = 
boxes + + return data + + def __repr__(self): + return "{}(min_scale={}, max_scale={}, interpolation={})".format( + self.__class__.__name__, + self.min_scale, + self.max_scale, + _cv2_to_str_interpolation[self.interpolation], + ) + + +@register_transformations(name="random_scale", type="image") +class RandomScale(BaseTransformation): + def __init__(self, opts, size: Optional[Sequence or int] = None): + min_scale = getattr(opts, "image_augmentation.random_scale.min_scale", 0.5) + max_scale = getattr(opts, "image_augmentation.random_scale.max_scale", 2.0) + interpolation = getattr( + opts, "image_augmentation.random_scale.interpolation", "bilinear" + ) + super(RandomScale, self).__init__(opts=opts) + self.min_scale = min_scale + self.max_scale = max_scale + self.interpolation = _cv2_interpolation(interpolation) + self.size = None + if size is not None: + self.size = setup_size(size) + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser): + group = parser.add_argument_group( + title="".format(cls.__name__), description="".format(cls.__name__) + ) + group.add_argument( + "--image-augmentation.random-scale.enable", + action="store_true", + help="Use random scale", + ) + group.add_argument( + "--image-augmentation.random-scale.min-scale", + type=float, + default=0.5, + help="Min scale", + ) + group.add_argument( + "--image-augmentation.random-scale.max-scale", + type=float, + default=2.0, + help="Max scale", + ) + group.add_argument( + "--image-augmentation.random-scale.interpolation", + type=str, + default="bilinear", + help="Interpolation method", + ) + return parser + + def __call__(self, data: Dict) -> Dict: + scale = random.uniform(self.min_scale, self.max_scale) + + img = data["image"] + if self.size is None: + height, width = img.shape[:2] + else: + height, width = self.size + target_height, target_width = int(height * scale), int(width * scale) + img = cv2.resize( + img, dsize=(target_width, target_height), interpolation=self.interpolation + ) + data["image"] = img + + if "mask" in data: + mask = data.pop("mask") + mask = cv2.resize( + mask, + dsize=(target_width, target_height), + interpolation=cv2.INTER_NEAREST, + ) + data["mask"] = mask + return data + + def __repr__(self): + return "{}(min_scale={}, max_scale={}, interpolation={})".format( + self.__class__.__name__, + self.min_scale, + self.max_scale, + _cv2_to_str_interpolation[self.interpolation], + ) + + +@register_transformations(name="random_resized_crop", type="image") +class RandomResizedCrop(BaseTransformation): + """ + Adapted from Pytorch Torchvision + """ + + def __init__(self, opts, size: tuple or int): + + interpolation = getattr( + opts, "image_augmentation.random_resized_crop.interpolation", "bilinear" + ) + scale = getattr( + opts, "image_augmentation.random_resized_crop.scale", (0.08, 1.0) + ) + ratio = getattr( + opts, + "image_augmentation.random_resized_crop.aspect_ratio", + (3.0 / 4.0, 4.0 / 3.0), + ) + + if not isinstance(scale, Sequence) or ( + isinstance(scale, Sequence) + and len(scale) != 2 + and 0.0 <= scale[0] < scale[1] + ): + logger.error( + "--image-augmentation.random-resized-crop.scale should be a tuple of length 2 " + "such that 0.0 <= scale[0] < scale[1]. Got: {}".format(scale) + ) + + if not isinstance(ratio, Sequence) or ( + isinstance(ratio, Sequence) + and len(ratio) != 2 + and 0.0 < ratio[0] < ratio[1] + ): + logger.error( + "--image-augmentation.random-resized-crop.aspect-ratio should be a tuple of length 2 " + "such that 0.0 < ratio[0] < ratio[1]. 
Got: {}".format(ratio) + ) + + ratio = (round(ratio[0], 3), round(ratio[1], 3)) + + super(RandomResizedCrop, self).__init__(opts=opts) + + self.scale = scale + self.size = setup_size(size=size) + + self.interpolation = _cv2_interpolation(interpolation) + self.ratio = ratio + + def get_params(self, height: int, width: int) -> (int, int, int, int): + area = height * width + for _ in range(10): + target_area = random.uniform(*self.scale) * area + log_ratio = (math.log(self.ratio[0]), math.log(self.ratio[1])) + aspect_ratio = math.exp(random.uniform(*log_ratio)) + + w = int(round(math.sqrt(target_area * aspect_ratio))) + h = int(round(math.sqrt(target_area / aspect_ratio))) + + if 0 < w <= width and 0 < h <= height: + i = random.randint(0, height - h) + j = random.randint(0, width - w) + return i, j, h, w + + # Fallback to central crop + in_ratio = (1.0 * width) / height + if in_ratio < min(self.ratio): + w = width + h = int(round(w / min(self.ratio))) + elif in_ratio > max(self.ratio): + h = height + w = int(round(h * max(self.ratio))) + else: # whole image + w = width + h = height + i = (height - h) // 2 + j = (width - w) // 2 + return i, j, h, w + + def __call__(self, data: Dict) -> Dict: + img = data["image"] + height, width = img.shape[:2] + + i, j, h, w = self.get_params(height=height, width=width) + data = _crop_fn(data=data, i=i, j=j, h=h, w=w) + return _resize_fn(data=data, size=self.size, interpolation=self.interpolation) + + def __repr__(self): + return "{}(scale={}, ratio={}, interpolation={})".format( + self.__class__.__name__, + self.scale, + self.ratio, + _cv2_to_str_interpolation[self.interpolation], + ) + + +@register_transformations(name="random_crop", type="image") +class RandomCrop(BaseTransformation): + """ + Randomly crop the image to a given size + """ + + def __init__(self, opts, size: Sequence or int): + super(RandomCrop, self).__init__(opts=opts) + self.height, self.width = setup_size(size=size) + self.opts = opts + self.fill_mask = getattr(opts, "image_augmentation.random_crop.mask_fill", 255) + is_padding = not getattr( + opts, "image_augmentation.random_crop.resize_if_needed", False + ) + self.inp_process_fn = ( + self.pad_if_needed if not is_padding else self.resize_if_needed + ) + + @staticmethod + def get_params(img_h, img_w, target_h, target_w): + if img_w == target_w and img_h == target_h: + return 0, 0, img_h, img_w + i = random.randint(0, img_h - target_h) + j = random.randint(0, img_w - target_w) + return i, j, target_h, target_w + + @staticmethod + def get_params_from_box(boxes, img_h, img_w): + # x, y, w, h + offset = random.randint(20, 50) + start_x = max(0, int(round(np.min(boxes[..., 0]))) - offset) + start_y = max(0, int(round(np.min(boxes[..., 1]))) - offset) + end_x = min(int(round(np.max(boxes[..., 2]))) + offset, img_w) + end_y = min(int(round(np.max(boxes[..., 3]))) + offset, img_h) + + return start_y, start_x, end_y - start_y, end_x - start_x + + def pad_if_needed(self, data: Dict) -> Dict: + img = data["image"] + + h, w, channels = img.shape + pad_h = self.height - h if h < self.height else 0 + pad_w = self.width - w if w < self.width else 0 + + # padding format is (top, bottom, left, right) + img = cv2.copyMakeBorder(img, 0, pad_h, 0, pad_w, cv2.BORDER_CONSTANT, value=0) + data["image"] = img + + if "mask" in data: + mask = data.pop("mask") + mask = cv2.copyMakeBorder( + mask, 0, pad_h, 0, pad_w, cv2.BORDER_CONSTANT, value=self.fill_mask + ) + data["mask"] = mask + return data + + def resize_if_needed(self, data: Dict) -> Dict: + img = 
data["image"] + + h, w, channels = img.shape + new_size = min(h + max(0, self.height - h), w + max(0, self.width - w)) + # resize while maintaining the aspect ratio + return _resize_fn(data, size=new_size, interpolation="bilinear") + + def __call__(self, data: Dict) -> Dict: + # box_info + if "box_coordinates" in data: + boxes = data.get("box_coordinates") + # crop the relevant area + image_h, image_w = data["image"].shape[:2] + box_i, box_j, box_h, box_w = self.get_params_from_box( + boxes, image_h, image_w + ) + data = _crop_fn(data, i=box_i, j=box_j, h=box_h, w=box_w) + + data = self.inp_process_fn(data) + img_h, img_w = data["image"].shape[:2] + i, j, h, w = self.get_params( + img_h=img_h, img_w=img_w, target_h=self.height, target_w=self.width + ) + data = _crop_fn(data=data, i=i, j=j, h=h, w=w) + + return data + + def __repr__(self): + return "{}(size=(h={}, w={}))".format( + self.__class__.__name__, self.height, self.width + ) + + +@register_transformations(name="random_flip", type="image") +class RandomFlip(BaseTransformation): + def __init__(self, opts): + super(RandomFlip, self).__init__(opts=opts) + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser): + group = parser.add_argument_group( + title="".format(cls.__name__), description="".format(cls.__name__) + ) + + group.add_argument( + "--image-augmentation.random-flip.enable", + action="store_true", + help="use random flipping", + ) + return parser + + def __call__(self, data: Dict) -> Dict: + flip_choice = random.choices([0, 1, 2])[0] + if flip_choice in [0, 1]: # 1 - Horizontal, 0 - vertical + img = data["image"] + img = cv2.flip(img, flip_choice) + data["image"] = img + + if "mask" in data: + mask = data.pop("mask") + mask = cv2.flip(mask, flip_choice) + data["mask"] = mask + + if "box_coordinates" in data: + boxes = data.pop("box_coordinates") + if flip_choice == 0: + height = img.shape[0] + boxes[:, 1::2] = height - boxes[:, 3::-2] + elif flip_choice == 1: + width = img.shape[1] + boxes[:, 0::2] = width - boxes[:, 2::-2] + + data["box_coordinates"] = boxes + + return data + + +@register_transformations(name="random_horizontal_flip", type="image") +class RandomHorizontalFlip(BaseTransformation): + def __init__(self, opts): + p = getattr(opts, "image_augmentation.random_horizontal_flip.p", 0.5) + super(RandomHorizontalFlip, self).__init__(opts=opts) + self.p = p + + def __call__(self, data: Dict) -> Dict: + + if random.random() <= self.p: + img = data["image"] + width = img.shape[1] + data["image"] = img[:, ::-1, ...] + + if "mask" in data: + mask = data.pop("mask") + mask = mask[:, ::-1, ...] + data["mask"] = mask + + if "box_coordinates" in data: + boxes = data.pop("box_coordinates") + boxes[..., 0::2] = width - boxes[..., 2::-2] + data["box_coordinates"] = boxes + + if "instance_mask" in data: + assert "instance_coords" in data + + instance_coords = data.pop("instance_coords") + instance_coords[..., 0::2] = width - instance_coords[..., 2::-2] + data["instance_coords"] = instance_coords + + instance_masks = data.pop("instance_mask") + instance_masks = instance_masks[:, ::-1, ...] 
+ data["instance_mask"] = instance_masks + + return data + + def __repr__(self): + return "{}(p={})".format(self.__class__.__name__, self.p) + + +@register_transformations(name="instance_processor", type="image") +class InstanceProcessor(BaseTransformation): + def __init__( + self, + opts, + instance_size: Optional[Union[int, Tuple[int, ...]]] = 16, + *args, + **kwargs + ): + super(InstanceProcessor, self).__init__(opts=opts) + self.instance_size = setup_size(instance_size) + + def __call__(self, data: Dict) -> Dict: + + if "instance_mask" in data: + assert "instance_coords" in data + instance_masks = data.pop("instance_mask") + instance_coords = data.pop("instance_coords") + instance_coords = instance_coords.astype(np.int) + + valid_boxes = (instance_coords[..., 3] > instance_coords[..., 1]) & ( + instance_coords[..., 2] > instance_coords[..., 0] + ) + instance_masks = instance_masks[..., valid_boxes] + instance_coords = instance_coords[valid_boxes] + + num_instances = instance_masks.shape[-1] + + resized_instances = [] + for i in range(num_instances): + instance_m = instance_masks[..., i] + box_coords = instance_coords[i] + instance_m = instance_m[ + box_coords[1] : box_coords[3], box_coords[0] : box_coords[2] + ] + instance_m = cv2.resize( + instance_m, + dsize=self.instance_size, + interpolation=cv2.INTER_NEAREST, + ) + resized_instances.append(instance_m) + + if len(resized_instances) == 0: + resized_instances = np.zeros( + shape=(self.instance_size[0], self.instance_size[1], 1), + dtype=np.uint8, + ) + instance_coords = np.array( + [[0, 0, self.instance_size[0], self.instance_size[1]]] + ) + else: + resized_instances = np.stack(resized_instances, axis=-1) + + data["instance_mask"] = resized_instances + data["instance_coords"] = instance_coords.astype(np.float) + return data + + +@register_transformations(name="random_vertical_flip", type="image") +class RandomVerticalFlip(BaseTransformation): + def __init__(self, opts): + p = getattr(opts, "image_augmentation.random_vertical_flip.p", 0.5) + super(RandomVerticalFlip, self).__init__(opts=opts) + self.p = p + + def __call__(self, data: Dict) -> Dict: + if random.random() <= self.p: + img = data["image"] + img = cv2.flip(img, 0) + data["image"] = img + + if "mask" in data: + mask = data.pop("mask") + mask = cv2.flip(mask, 0) + data["mask"] = mask + + if "box_coordinates" in data: + boxes = data.pop("box_coordinates") + height = img.shape[0] + boxes[:, 1::2] = height - boxes[:, 3::-2] + + data["box_coordinates"] = boxes + + return data + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser): + group = parser.add_argument_group( + title="".format(cls.__name__), description="".format(cls.__name__) + ) + group.add_argument( + "--image-augmentation.random-vertical-flip.enable", + action="store_true", + help="use random vertical flipping", + ) + group.add_argument( + "--image-augmentation.random-vertical-flip.p", + type=float, + default=0.5, + help="Probability for random vertical flip", + ) + return parser + + def __repr__(self): + return "{}(p={})".format(self.__class__.__name__, self.p) + + +@register_transformations(name="random_rotation", type="image") +class RandomRotate(BaseTransformation): + def __init__(self, opts): + angle = getattr(opts, "image_augmentation.random_rotate.angle", 10.0) + fill = getattr(opts, "image_augmentation.random_rotate.mask_fill", 255) + interpolation = getattr( + opts, "image_augmentation.random_rotate.interpolation", "bilinear" + ) + p = getattr(opts, "image_augmentation.random_rotate.p", 
0.5)
+        super(RandomRotate, self).__init__(opts=opts)
+        self.angle = angle
+        self.fill = fill
+        self.p = p
+        self.interpolation = _cv2_interpolation(interpolation)
+
+    @classmethod
+    def add_arguments(cls, parser: argparse.ArgumentParser):
+        group = parser.add_argument_group(
+            title="".format(cls.__name__), description="".format(cls.__name__)
+        )
+        # These two arguments are CV2-specific
+        group.add_argument(
+            "--image-augmentation.random-rotate.interpolation",
+            type=str,
+            default="bilinear",
+            help="Interpolation method",
+        )
+        group.add_argument(
+            "--image-augmentation.random-rotate.p",
+            type=float,
+            default=0.5,
+            help="Probability that {} will be applied".format(cls.__name__),
+        )
+        return parser
+
+    def __call__(self, data: Dict) -> Dict:
+        if random.random() <= self.p:
+            img = data["image"]
+            height, width = img.shape[:2]
+
+            random_angle = random.uniform(-self.angle, self.angle)
+            rotation_mat = cv2.getRotationMatrix2D(
+                center=(width / 2, height / 2), angle=random_angle, scale=1
+            )
+
+            img_rotated = cv2.warpAffine(
+                src=img,
+                M=rotation_mat,
+                dsize=(width, height),
+                flags=self.interpolation,
+                borderValue=0,
+            )
+            data["image"] = img_rotated
+
+            if "mask" in data:
+                mask = data.pop("mask")
+                mask_rotated = cv2.warpAffine(
+                    src=mask,
+                    M=rotation_mat,
+                    dsize=(width, height),
+                    flags=cv2.INTER_NEAREST,
+                    borderValue=self.fill,
+                )
+                data["mask"] = mask_rotated
+
+            if "box_coordinates" in data:
+                raise NotImplementedError(
+                    "RandomRotate is not implemented for box coordinates"
+                )
+
+        return data
+
+    def __repr__(self):
+        return "{}(angle={}, interpolation={}, p={})".format(
+            self.__class__.__name__,
+            self.angle,
+            _cv2_to_str_interpolation[self.interpolation],
+            self.p,
+        )
+
+
+BLUR_METHODS = ["gauss", "median", "average", "none", "any"]
+
+
+@register_transformations(name="random_blur", type="image")
+class RandomBlur(BaseTransformation):
+    def __init__(self, opts):
+        kernel_range = getattr(
+            opts, "image_augmentation.random_blur.kernel_size", [3, 7]
+        )
+        blur_type = getattr(opts, "image_augmentation.random_blur.kernel_type", "any")
+        p = getattr(opts, "image_augmentation.random_blur.p", 0.5)
+        super(RandomBlur, self).__init__(opts=opts)
+        self.kernel_range = setup_size(kernel_range)
+        assert 1 <= self.kernel_range[0] <= self.kernel_range[1], "Got: {}".format(
+            self.kernel_range
+        )
+        self.blur_type = blur_type
+        self.p = p
+
+    @classmethod
+    def add_arguments(cls, parser: argparse.ArgumentParser):
+        group = parser.add_argument_group(
+            title="".format(cls.__name__), description="".format(cls.__name__)
+        )
+
+        group.add_argument(
+            "--image-augmentation.random-blur.enable",
+            action="store_true",
+            help="use random blurring",
+        )
+
+        group.add_argument(
+            "--image-augmentation.random-blur.kernel-size",
+            type=tuple or int or list,
+            default=[3, 7],
+            help="Randomly sample the kernel size from the given range",
+        )
+        group.add_argument(
+            "--image-augmentation.random-blur.kernel-type",
+            type=str,
+            choices=BLUR_METHODS,
+            default="any",
+            help="Blur method to use. Defaults to any",
+        )
+        group.add_argument(
+            "--image-augmentation.random-blur.p",
+            type=float,
+            default=0.5,
+            help="Probability that {} will be applied".format(cls.__name__),
+        )
+        return parser
+
+    def blur_median(self, img: np.ndarray, ksize_x: int, ksize_y: int) -> np.ndarray:
+        ksize = ksize_x if random.random() < 0.5 else ksize_y
+        img = cv2.medianBlur(src=img, ksize=ksize)
+        return img
+
+    def blur_avg(self, img: np.ndarray, ksize_x: int, ksize_y: int) -> np.ndarray:
+        return
cv2.blur(src=img, ksize=(ksize_x, ksize_y)) + + def blur_gauss(self, img: np.ndarray, ksize_x: int, ksize_y: int) -> np.ndarray: + return cv2.GaussianBlur(src=img, ksize=(ksize_x, ksize_y), sigmaX=0) + + def blur_any(self, img: np.ndarray, ksize_x: int, ksize_y: int) -> np.ndarray: + blur_method = random.choice(BLUR_METHODS[:-1]) + if blur_method == "gauss": + img = self.blur_gauss(img=img, ksize_x=ksize_x, ksize_y=ksize_y) + elif blur_method == "median": + img = self.blur_median(img=img, ksize_x=ksize_x, ksize_y=ksize_y) + elif blur_method == "average": + img = self.blur_avg(img=img, ksize_x=ksize_x, ksize_y=ksize_y) + return img + + def __call__(self, data: Dict) -> Dict: + if self.blur_type == "none": + return data + + ksize_x = random.randint(self.kernel_range[0], self.kernel_range[1]) + ksize_y = random.randint(self.kernel_range[0], self.kernel_range[1]) + ksize_x = (ksize_x // 2) * 2 + 1 + ksize_y = (ksize_y // 2) * 2 + 1 + + img = data["image"] + + if self.blur_type == "any": + img = self.blur_any(img, ksize_x=ksize_x, ksize_y=ksize_y) + elif self.blur_type == "gaussian" and random.random() <= self.p: + img = self.blur_gauss(img=img, ksize_x=ksize_x, ksize_y=ksize_y) + elif self.blur_type == "median" and random.random() <= self.p: + img = self.blur_median(img=img, ksize_x=ksize_x, ksize_y=ksize_y) + elif self.blur_type == "average" and random.random() <= self.p: + img = self.blur_avg(img=img, ksize_x=ksize_x, ksize_y=ksize_y) + + data["image"] = img + return data + + def __repr__(self): + if self.blur_type == "any": + blur_type = ["gaussian", "median", "average"] + else: + blur_type = self.blur_type + return "{}(blur_type={}, kernel_range={})".format( + self.__class__.__name__, blur_type, self.kernel_range + ) + + +@register_transformations(name="random_translate", type="image") +class RandomTranslate(BaseTransformation): + def __init__(self, opts): + translate_factor = getattr( + opts, "image_augmentation.random_translate.factor", 0.2 + ) + assert 0 < translate_factor < 0.5, "Factor should be between 0 and 0.5" + super(RandomTranslate, self).__init__(opts=opts) + + self.translation_factor = translate_factor + + def __call__(self, data: Dict) -> Dict: + img = data["image"] + + height, width = img.shape[:2] + th = int(math.ceil(random.uniform(0, self.translation_factor) * height)) + tw = int(math.ceil(random.uniform(0, self.translation_factor) * width)) + img_translated = np.zeros_like(img) + translate_from_left = True if random.random() <= 0.5 else False + if translate_from_left: + img_translated[th:, tw:] = img[: height - th, : width - tw] + else: + img_translated[: height - th, : width - tw] = img[th:, tw:] + data["image"] = img_translated + + if "mask" in data: + mask = data.pop("mask") + mask_translated = np.zeros_like(mask) + if translate_from_left: + mask_translated[th:, tw:] = mask[: height - th, : width - tw] + else: + mask_translated[: height - th, : width - tw] = mask[th:, tw:] + data["mask"] = mask_translated + return data + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser): + group = parser.add_argument_group( + title="".format(cls.__name__), description="".format(cls.__name__) + ) + group.add_argument( + "--image-augmentation.random-translate.enable", + action="store_true", + help="use random translation", + ) + group.add_argument( + "--image-augmentation.random-translate.factor", + type=float, + default=0.2, + help="Translate uniformly between (-u, u)", + ) + return parser + + def __repr__(self): + return 
"{}(factor={})".format(self.__class__.__name__, self.translation_factor) + + +@register_transformations(name="resize", type="image") +class Resize(BaseTransformation): + def __init__(self, opts, size, *args, **kwargs): + if not ( + isinstance(size, int) + or (isinstance(size, Sequence) and len(size) in (1, 2)) + ): + raise TypeError( + "Supported size args are int or tuple of length 2. Got inappropriate size arg: {}".format( + self.size + ) + ) + interpolation = getattr( + opts, "image_augmentation.resize.interpolation", "bilinear" + ) + super(Resize, self).__init__(opts=opts) + + self.size = size + self.interpolation = _cv2_interpolation(interpolation) + + def __call__(self, data: Dict) -> Dict: + return _resize_fn(data=data, size=self.size, interpolation=self.interpolation) + + def __repr__(self): + return "{}(size={}, interpolation={})".format( + self.__class__.__name__, + self.size, + _cv2_to_str_interpolation[self.interpolation], + ) + + +@register_transformations(name="box_absolute_coords", type="image") +class BoxAbsoluteCoords(BaseTransformation): + def __init__(self, opts): + super(BoxAbsoluteCoords, self).__init__(opts=opts) + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser): + group = parser.add_argument_group( + title="".format(cls.__name__), description="".format(cls.__name__) + ) + + group.add_argument( + "--image-augmentation.box-absolute-coords.enable", + action="store_true", + help="Convert box coordinates to absolute coordinates", + ) + return parser + + def __call__(self, data: Dict) -> Dict: + if "box_coordinates" in data: + boxes = data.pop("box_coordinates") + + image = data["image"] + + height, width, channels = image.shape + boxes[..., 0::2] *= width + boxes[..., 1::2] *= height + + data["box_coordinates"] = boxes + return data + + +@register_transformations(name="box_percent_coords", type="image") +class BoxPercentCoords(BaseTransformation): + def __init__(self, opts): + super(BoxPercentCoords, self).__init__(opts=opts) + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser): + group = parser.add_argument_group( + title="".format(cls.__name__), description="".format(cls.__name__) + ) + + group.add_argument( + "--image-augmentation.box-percent-coords.enable", + action="store_true", + help="Convert box coordinates to percent", + ) + return parser + + def __call__(self, data: Dict) -> Dict: + + if "box_coordinates" in data: + boxes = data.pop("box_coordinates") + image = data["image"] + height, width, channels = image.shape + + boxes = boxes.astype(np.float) + + boxes[..., 0::2] /= width + boxes[..., 1::2] /= height + data["box_coordinates"] = boxes + + return data + + +@register_transformations(name="ssd_cropping", type="image") +class SSDCroping(BaseTransformation): + """Crop + Arguments: + img (Image): the image being input during training + boxes (Tensor): the original bounding boxes in pt form + labels (Tensor): the class labels for each bbox + mode (float tuple): the min and max jaccard overlaps + Return: + (img, boxes, classes) + img (Image): the cropped image + boxes (Tensor): the adjusted bounding boxes in pt form + labels (Tensor): the class labels for each bbox + """ + + def __init__(self, opts): + super(SSDCroping, self).__init__(opts=opts) + self.iou_sample_opts = getattr( + opts, + "image_augmentation.ssd_crop.iou_thresholds", + [0.0, 0.1, 0.3, 0.5, 0.7, 0.9, 1.0], + ) + self.trials = getattr(opts, "image_augmentation.ssd_crop.n_trials", 40) + self.min_aspect_ratio = getattr( + opts, 
"image_augmentation.ssd_crop.min_aspect_ratio", 0.5 + ) + self.max_aspect_ratio = getattr( + opts, "image_augmentation.ssd_crop.max_aspect_ratio", 0.5 + ) + + def __call__(self, data: Dict) -> Dict: + if "box_coordinates" in data: + boxes = data["box_coordinates"] + + # guard against no boxes + if boxes.shape[0] == 0: + return data + + image = data["image"] + labels = data["box_labels"] + height, width = image.shape[:2] + + while True: + # randomly choose a mode + min_jaccard_overalp = random.choice(self.iou_sample_opts) + if min_jaccard_overalp == 0.0: + return data + + for _ in range(self.trials): + w = random.uniform(0.3 * width, width) + h = random.uniform(0.3 * height, height) + + aspect_ratio = h / w + if not ( + self.min_aspect_ratio <= aspect_ratio <= self.max_aspect_ratio + ): + continue + + left = random.uniform(0, width - w) + top = random.uniform(0, height - h) + + # convert to integer rect x1,y1,x2,y2 + rect = np.array([int(left), int(top), int(left + w), int(top + h)]) + + # calculate IoU (jaccard overlap) b/t the cropped and gt boxes + ious = jaccard_numpy(boxes, rect) + + # is min and max overlap constraint satisfied? if not try again + if ious.max() < min_jaccard_overalp: + continue + + # keep overlap with gt box IF center in sampled patch + centers = (boxes[:, :2] + boxes[:, 2:]) * 0.5 + + # mask in all gt boxes that above and to the left of centers + m1 = (rect[0] < centers[:, 0]) * (rect[1] < centers[:, 1]) + + # mask in all gt boxes that under and to the right of centers + m2 = (rect[2] > centers[:, 0]) * (rect[3] > centers[:, 1]) + + # mask in that both m1 and m2 are true + mask = m1 * m2 + + # have any valid boxes? try again if not + if not mask.any(): + continue + + # if image size is too small, try again + if (rect[3] - rect[1]) < 100 or (rect[2] - rect[0]) < 100: + continue + + # cut the crop from the image + image = image[rect[1] : rect[3], rect[0] : rect[2], :] + + # take only matching gt boxes + current_boxes = boxes[mask, :].copy() + + # take only matching gt labels + current_labels = labels[mask] + + # should we use the box left and top corner or the crop's + current_boxes[:, :2] = np.maximum(current_boxes[:, :2], rect[:2]) + # adjust to crop (by substracting crop's left,top) + current_boxes[:, :2] -= rect[:2] + + current_boxes[:, 2:] = np.minimum(current_boxes[:, 2:], rect[2:]) + # adjust to crop (by substracting crop's left,top) + current_boxes[:, 2:] -= rect[:2] + + data["image"] = image + data["box_labels"] = current_labels + data["box_coordinates"] = current_boxes + + if "mask" in data: + seg_mask = data.pop("mask") + seg_mask = seg_mask[rect[1] : rect[3], rect[0] : rect[2]] + data["mask"] = seg_mask + + if "instance_mask" in data: + assert "instance_coords" in data + + instance_masks = data.pop("instance_mask") + instance_masks = instance_masks[ + rect[1] : rect[3], rect[0] : rect[2], ... 
+ ] + data["instance_mask"] = instance_masks + + instance_coords = data.pop("instance_coords") + # should we use the box left and top corner or the crop's + instance_coords[..., :2] = np.maximum( + instance_coords[..., :2], rect[:2] + ) + # adjust to crop (by substracting crop's left,top) + instance_coords[..., :2] -= rect[:2] + + instance_coords[..., 2:] = np.minimum( + instance_coords[..., 2:], rect[2:] + ) + # adjust to crop (by substracting crop's left,top) + instance_coords[..., 2:] -= rect[:2] + data["instance_coords"] = instance_coords + + return data + return data + + +@register_transformations(name="center_crop", type="image") +class CenterCrop(BaseTransformation): + def __init__(self, opts, size: Sequence or int): + super(CenterCrop, self).__init__(opts=opts) + if isinstance(size, Sequence) and len(size) == 2: + self.height, self.width = size[0], size[1] + elif isinstance(size, Sequence) and len(size) == 1: + self.height = self.width = size[0] + elif isinstance(size, int): + self.height = self.width = size + else: + logger.error("Scale should be either an int or tuple of ints") + + def __call__(self, data: Dict) -> Dict: + height, width = data["image"].shape[:2] + i = (height - self.height) // 2 + j = (width - self.width) // 2 + return _crop_fn(data=data, i=i, j=j, h=self.height, w=self.width) + + def __repr__(self): + return "{}(size=(h={}, w={}))".format( + self.__class__.__name__, self.height, self.width + ) + + +@register_transformations(name="random_jpeg_compress", type="image") +class RandomJPEGCompress(BaseTransformation): + def __init__(self, opts): + q_range = getattr( + opts, "image_augmentation.random_jpeg_compress.q_factor", (5, 25) + ) + if isinstance(q_range, (int, float)): + q_range = (max(q_range - 10, 0), q_range) + assert len(q_range) == 2 + assert q_range[0] <= q_range[1] + p = getattr(opts, "image_augmentation.random_jpeg_compress.p", 0.5) + super(RandomJPEGCompress, self).__init__(opts=opts) + self.q_factor = q_range + self.p = p + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser): + group = parser.add_argument_group( + title="".format(cls.__name__), description="".format(cls.__name__) + ) + group.add_argument( + "--image-augmentation.random-jpeg-compress.enable", + action="store_true", + help="use random compression", + ) + group.add_argument( + "--image-augmentation.random-jpeg-compress.q-factor", + type=int or tuple, + default=(5, 25), + help="Compression quality factor range", + ) + group.add_argument( + "--image-augmentation.random-jpeg-compress.p", + type=float, + default=0.5, + help="Probability that {} will be applied".format(cls.__name__), + ) + + return parser + + def __call__(self, data: Dict) -> Dict: + if random.random() <= self.p: + q_factor = random.randint(self.q_factor[0], self.q_factor[1]) + encoding_param = [int(cv2.IMWRITE_JPEG_QUALITY), q_factor] + + img = data["image"] + _, enc_img = cv2.imencode(".jpg", img, encoding_param) + comp_img = cv2.imdecode(enc_img, 1) + data["image"] = comp_img + + return data + + def __repr__(self): + return "{}(q_factor=({}, {}), p={})".format( + self.__class__.__name__, self.q_factor[0], self.q_factor[1], self.p + ) + + +@register_transformations(name="random_gauss_noise", type="image") +class RandomGaussianNoise(BaseTransformation): + def __init__(self, opts): + sigma_range = getattr( + opts, "image_augmentation.random_gauss_noise.sigma", (0.03, 0.3) + ) + if isinstance(sigma_range, (float, int)): + sigma_range = (0, sigma_range) + + assert len(sigma_range) == 2, "Got 
{}".format(sigma_range) + assert sigma_range[0] <= sigma_range[1] + p = getattr(opts, "image_augmentation.random_gauss_noise.p", 0.5) + super(RandomGaussianNoise, self).__init__(opts=opts) + self.sigma_low = sigma_range[0] + self.sigma_high = sigma_range[1] + self.p = p + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser): + group = parser.add_argument_group( + title="".format(cls.__name__), description="".format(cls.__name__) + ) + group.add_argument( + "--image-augmentation.random-gauss-noise.enable", + action="store_true", + help="use random gaussian noise", + ) + group.add_argument( + "--image-augmentation.random-gauss-noise.sigma", + type=float or tuple, + default=(0.03, 0.1), + help="Sigma (sqrt of variance) range for Gaussian noise. Default is (0.0001, 0.001).", + ) + group.add_argument( + "--image-augmentation.random-gauss-noise.p", + type=float, + default=0.5, + help="Probability that {} will be applied".format(cls.__name__), + ) + return parser + + def __call__(self, data: Dict) -> Dict: + if random.random() <= self.p: + std = random.uniform(self.sigma_low, self.sigma_high) + + img = data["image"] + noise = np.random.normal(0.0, std, img.shape) * 255 + noisy_img = img + noise + + noisy_img = np.clip(noisy_img, 0, 255).astype(np.uint8) + data["image"] = noisy_img + return data + + def __repr__(self): + return "{}(sigma=({}, {}), p={})".format( + self.__class__.__name__, self.sigma_low, self.sigma_high, self.p + ) + + +@register_transformations(name="to_tensor", type="image") +class NumpyToTensor(BaseTransformation): + def __init__(self, opts, *args, **kwargs): + super(NumpyToTensor, self).__init__(opts=opts) + + def __call__(self, data: Dict) -> Dict: + # HWC --> CHW + img = data["image"] + img = img.transpose(2, 0, 1) + img = np.ascontiguousarray(img) + + # numpy to tensor + img_tensor = torch.from_numpy(img).float() + img_tensor = torch.div(img_tensor, 255.0) + data["image"] = img_tensor + + if "mask" in data: + mask = data.pop("mask") + if len(mask.shape) > 2 and mask.shape[-1] > 1: + mask = mask.transpose(2, 0, 1) + data["mask"] = torch.from_numpy(mask).long() + + if "box_coordinates" in data: + boxes = data.pop("box_coordinates") + data["box_coordinates"] = torch.from_numpy(boxes).float() + + if "box_labels" in data: + box_labels = data.pop("box_labels") + data["box_labels"] = torch.from_numpy(box_labels) + + if "instance_mask" in data: + assert "instance_coords" in data + + instance_masks = data.pop("instance_mask") + # [H, W, N] --> [N, H, W] + instance_masks = instance_masks.transpose(2, 0, 1) + instance_masks = np.ascontiguousarray(instance_masks) + data["instance_mask"] = torch.from_numpy(instance_masks).long() + + instance_coords = data.pop("instance_coords") + data["instance_coords"] = torch.from_numpy(instance_coords).float() + return data + + +@register_transformations(name="random_order", type="image") +class RandomOrder(BaseTransformation): + def __init__(self, opts, img_transforms: list): + super(RandomOrder, self).__init__(opts=opts) + self.transforms = img_transforms + apply_k_factor = getattr(opts, "image_augmentation.random_order.apply_k", 1.0) + assert ( + 0.0 < apply_k_factor <= 1.0 + ), "--image-augmentation.random-order.apply-k should be between 0 and 1" + self.keep_t = int(math.ceil(len(self.transforms) * apply_k_factor)) + + def __call__(self, data: Dict) -> Dict: + random.shuffle(self.transforms) + for t in self.transforms[: self.keep_t]: + data = t(data) + return data + + def __repr__(self): + transform_str = ", ".join(str(t) for 
t in self.transforms) + repr_str = "{}(n_transforms={}, t_list=[{}]".format( + self.__class__.__name__, self.keep_t, transform_str + ) + return repr_str + + +@register_transformations(name="compose", type="image") +class Compose(BaseTransformation): + def __init__(self, opts, img_transforms: list): + super(Compose, self).__init__(opts=opts) + self.img_transforms = img_transforms + + def __call__(self, data: Dict) -> Dict: + for t in self.img_transforms: + data = t(data) + return data + + def __repr__(self): + transform_str = ", ".join("\n\t\t\t" + str(t) for t in self.img_transforms) + repr_str = "{}({})".format(self.__class__.__name__, transform_str) + return repr_str + + +@register_transformations(name="photo_metric_distort_opencv", type="image") +class PhotometricDistort(BaseTransformation): + def __init__(self, opts): + beta_min = getattr( + opts, "image_augmentation.photo_metric_distort_opencv.beta_min", -0.2 + ) + beta_max = getattr( + opts, "image_augmentation.photo_metric_distort_opencv.beta_max", 0.2 + ) + assert -0.5 <= beta_min < beta_max <= 0.5, "Got {} and {}".format( + beta_min, beta_max + ) + + alpha_min = getattr( + opts, "image_augmentation.photo_metric_distort_opencv.alpha_min", 0.5 + ) + alpha_max = getattr( + opts, "image_augmentation.photo_metric_distort_opencv.alpha_max", 1.5 + ) + assert 0 < alpha_min < alpha_max, "Got {} and {}".format(alpha_min, alpha_max) + + gamma_min = getattr( + opts, "image_augmentation.photo_metric_distort_opencv.gamma_min", 0.5 + ) + gamma_max = getattr( + opts, "image_augmentation.photo_metric_distort_opencv.gamma_max", 1.5 + ) + assert 0 < gamma_min < gamma_max, "Got {} and {}".format(gamma_min, gamma_max) + + delta_min = getattr( + opts, "image_augmentation.photo_metric_distort_opencv.delta_min", -0.05 + ) + delta_max = getattr( + opts, "image_augmentation.photo_metric_distort_opencv.delta_max", 0.05 + ) + assert -1.0 < delta_min < delta_max < 1.0, "Got {} and {}".format( + delta_min, delta_max + ) + + super(PhotometricDistort, self).__init__(opts=opts) + # for briightness + self.beta_min = beta_min + self.beta_max = beta_max + # for contrast + self.alpha_min = alpha_min + self.alpha_max = alpha_max + # for saturation + self.gamma_min = gamma_min + self.gamma_max = gamma_max + # for hue + self.delta_min = delta_min + self.delta_max = delta_max + self.p = getattr(opts, "image_augmentation.photo_metric_distort_opencv.p", 0.5) + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser): + group = parser.add_argument_group( + title="".format(cls.__name__), description="".format(cls.__name__) + ) + group.add_argument( + "--image-augmentation.photo-metric-distort-opencv.enable", + action="store_true", + help="Randomly apply photometric transformation", + ) + + group.add_argument( + "--image-augmentation.photo-metric-distort-opencv.alpha-min", + type=float, + default=0.5, + help="Min. alpha value for contrast. Should be > 0", + ) + group.add_argument( + "--image-augmentation.photo-metric-distort-opencv.alpha-max", + type=float, + default=1.5, + help="Max. alpha value for contrast. Should be > 0", + ) + + group.add_argument( + "--image-augmentation.photo-metric-distort-opencv.beta-min", + type=float, + default=-0.2, + help="Min. alpha value for brightness. Should be between -1 and 1.", + ) + group.add_argument( + "--image-augmentation.photo-metric-distort-opencv.beta-max", + type=float, + default=0.2, + help="Max. alpha value for brightness. 
Should be between -1 and 1.", + ) + + group.add_argument( + "--image-augmentation.photo-metric-distort-opencv.gamma-min", + type=float, + default=0.5, + help="Min. alpha value for saturation. Should be > 0", + ) + group.add_argument( + "--image-augmentation.photo-metric-distort-opencv.gamma-max", + type=float, + default=1.5, + help="Max. alpha value for saturation. Should be > 0", + ) + + group.add_argument( + "--image-augmentation.photo-metric-distort-opencv.delta-min", + type=float, + default=-0.05, + help="Min. alpha value for Hue. Should be between -1 and 1.", + ) + group.add_argument( + "--image-augmentation.photo-metric-distort-opencv.delta-max", + type=float, + default=0.05, + help="Max. alpha value for Hue. Should be between -1 and 1.", + ) + + group.add_argument( + "--image-augmentation.photo-metric-distort-opencv.p", + type=float, + default=0.5, + help="Prob of applying transformation", + ) + + return parser + + def apply_transformations(self, image): + def convert_to_uint8(img): + return np.clip(img, 0, 255).astype(np.uint8) + + rand_nums = np.random.rand(6) + + image = image.astype(np.float32) + + # apply random contrast + alpha = ( + random.uniform(self.alpha_min, self.alpha_max) + if rand_nums[0] < self.p + else 1.0 + ) + image *= alpha + + # Apply random brightness + beta = ( + (random.uniform(self.beta_min, self.beta_max) * 255) + if rand_nums[1] < self.p + else 0.0 + ) + image += beta + + image = convert_to_uint8(image) + image = cv2.cvtColor(image, cv2.COLOR_BGR2HSV) + image = image.astype(np.float32) + + # Apply random saturation + gamma = ( + random.uniform(self.gamma_min, self.gamma_max) + if rand_nums[2] < self.p + else 1.0 + ) + image[..., 1] *= gamma + + # Apply random hue + delta = ( + int(random.uniform(self.delta_min, self.delta_max) * 255) + if rand_nums[3] < self.p + else 0.0 + ) + image[..., 0] += delta + + image = convert_to_uint8(image) + image = cv2.cvtColor(image, cv2.COLOR_HSV2BGR) + + if alpha == 1.0 and rand_nums[4] < self.p: + # apply contrast if earlier not applied + image = image.astype(np.float32) + alpha = random.uniform(self.alpha_min, self.alpha_max) + image *= alpha + image = convert_to_uint8(image) + + # Lightning noise + channels = image.shape[-1] + swap = np.random.permutation(range(channels)) if rand_nums[5] < self.p else None + if swap is not None: + image = image[..., swap] + + return image + + def __call__(self, data: Dict) -> Dict: + image = data.pop("image") + data["image"] = self.apply_transformations(image) + return data + + +# add by huangzp +@register_transformations(name="bit_plane", type="image") +class BitPlane(BaseTransformation): + def __init__(self, opts, h, w): + # min_size = getattr(opts, "image_augmentation.random_resize.min_size", 256) + # max_size = getattr(opts, "image_augmentation.random_resize.max_size", 1024) + # interpolation = getattr( + # opts, "image-augmentation.random_resize.interpolation", "bilinear" + # ) + super(BitPlane, self).__init__(opts=opts) + self.h = h + self.w = w + self.weight = np.int16(np.ones([h, w, 8, 3])) + self.bias = np.int16(np.ones([h, w, 8, 3])) + for i in range(8): + self.weight[:,:,i,:] = self.weight[:,:,i,:] * (2**(7-i)) + self.bias[:,:,i,:] = self.bias[:,:,i,:] * (2**i) + + def __call__(self, data: Dict) -> Dict: + img = data['image'] + new_img = (self.weight & img[:,:,None,:]) * self.bias + new_img = new_img.reshape(self.h, self.w, 24) + # h,w = img.shape[0], img.shape[1] + # new_img = np.zeros((h,w,24)) + # for c in range(3): + # for i in range(h): + # for j in range(w): + # n = 
str(np.binary_repr(img[i,j,c],8)) + # for k in range(8): + # new_img[i,j,3*k+c] = n[k] + # # TODO: to check it out + data['image'] = new_img + return data + + def __repr__(self): + return "{}(bit plane 3 3 3 3 3 3 3 3 sum 24 plane)".format( + self.__class__.__name__, + ) diff --git a/Adaptive Frequency Filters/data/transforms/image_pil.py b/Adaptive Frequency Filters/data/transforms/image_pil.py new file mode 100644 index 0000000..59708bb --- /dev/null +++ b/Adaptive Frequency Filters/data/transforms/image_pil.py @@ -0,0 +1,2158 @@ +# +# For licensing see accompanying LICENSE file. +# Copyright (C) 2023 Apple Inc. All Rights Reserved. +# + +import copy +from PIL import Image, ImageFilter +from utils import logger +import numpy as np +import random +import torch +import math +import argparse +from torchvision import transforms as T +from torchvision.transforms import functional as F +from typing import Sequence, Dict, Any, Union, Tuple, List, Optional + +from . import register_transformations, BaseTransformation +from .utils import jaccard_numpy, setup_size + +INTERPOLATION_MODE_MAP = { + "nearest": T.InterpolationMode.NEAREST, + "bilinear": T.InterpolationMode.BILINEAR, + "bicubic": T.InterpolationMode.BICUBIC, + "cubic": T.InterpolationMode.BICUBIC, + "box": T.InterpolationMode.BOX, + "hamming": T.InterpolationMode.HAMMING, + "lanczos": T.InterpolationMode.LANCZOS, +} + + +def _interpolation_modes_from_str(name: str) -> T.InterpolationMode: + return INTERPOLATION_MODE_MAP[name] + + +def _crop_fn(data: Dict, top: int, left: int, height: int, width: int) -> Dict: + """Helper function for cropping""" + img = data["image"] + data["image"] = F.crop(img, top=top, left=left, height=height, width=width) + + if "mask" in data: + mask = data.pop("mask") + data["mask"] = F.crop(mask, top=top, left=left, height=height, width=width) + + if "box_coordinates" in data: + boxes = data.pop("box_coordinates") + + area_before_cropping = (boxes[..., 2] - boxes[..., 0]) * ( + boxes[..., 3] - boxes[..., 1] + ) + + boxes[..., 0::2] = np.clip(boxes[..., 0::2] - left, a_min=0, a_max=left + width) + boxes[..., 1::2] = np.clip(boxes[..., 1::2] - top, a_min=0, a_max=top + height) + + area_after_cropping = (boxes[..., 2] - boxes[..., 0]) * ( + boxes[..., 3] - boxes[..., 1] + ) + area_ratio = area_after_cropping / (area_before_cropping + 1) + + # keep the boxes whose area is atleast 20% of the area before cropping + keep = area_ratio >= 0.2 + + box_labels = data.pop("box_labels") + + data["box_coordinates"] = boxes[keep] + data["box_labels"] = box_labels[keep] + + if "instance_mask" in data: + assert "instance_coords" in data + + instance_masks = data.pop("instance_mask") + data["instance_mask"] = F.crop( + instance_masks, top=top, left=left, height=height, width=width + ) + + instance_coords = data.pop("instance_coords") + instance_coords[..., 0::2] = np.clip( + instance_coords[..., 0::2] - left, a_min=0, a_max=left + width + ) + instance_coords[..., 1::2] = np.clip( + instance_coords[..., 1::2] - top, a_min=0, a_max=top + height + ) + data["instance_coords"] = instance_coords + + return data + + +def _resize_fn( + data: Dict, + size: Union[Sequence, int], + interpolation: Optional[T.InterpolationMode or str] = T.InterpolationMode.BILINEAR, +) -> Dict: + """Helper function for resizing""" + img = data["image"] + + w, h = F.get_image_size(img) + + if isinstance(size, Sequence) and len(size) == 2: + size_h, size_w = size[0], size[1] + elif isinstance(size, int): + if (w <= h and w == size) or (h <= w and h == size): + 
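+            # The shorter side already matches the requested size, so the sample is
+            # returned unchanged. Otherwise the shorter side is resized to `size` and
+            # the longer side is scaled to preserve the aspect ratio; e.g. size=256 on
+            # a 640x480 (WxH) input yields a 341x256 output.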
return data + + if w < h: + size_h = int(size * h / w) + + size_w = size + else: + size_w = int(size * w / h) + size_h = size + else: + raise TypeError( + "Supported size args are int or tuple of length 2. Got inappropriate size arg: {}".format( + size + ) + ) + + if isinstance(interpolation, str): + interpolation = _interpolation_modes_from_str(name=interpolation) + + data["image"] = F.resize( + img=img, size=[size_h, size_w], interpolation=interpolation + ) + + if "mask" in data: + mask = data.pop("mask") + # mask can be a PIL or Tensor. + # Especially for Mask-RCNN, we may have tensors with first dimension as 0. + # In that case, resize, won't work. + # A workaround is that we check for the instance of a Tensor and then check its dimension. + if isinstance(mask, torch.Tensor) and mask.shape[0] == 0: + # It's empty tensor. + resized_mask = torch.zeros( + [0, size_h, size_w], dtype=mask.dtype, device=mask.device + ) + else: + resized_mask = F.resize( + img=mask, + size=[size_h, size_w], + interpolation=T.InterpolationMode.NEAREST, + ) + data["mask"] = resized_mask + + if "box_coordinates" in data: + boxes = data.pop("box_coordinates") + boxes[:, 0::2] *= 1.0 * size_w / w + boxes[:, 1::2] *= 1.0 * size_h / h + data["box_coordinates"] = boxes + + if "instance_mask" in data: + assert "instance_coords" in data + + instance_masks = data.pop("instance_mask") + + resized_instance_masks = F.resize( + img=instance_masks, + size=[size_h, size_w], + interpolation=T.InterpolationMode.NEAREST, + ) + data["instance_mask"] = resized_instance_masks + + instance_coords = data.pop("instance_coords") + instance_coords = instance_coords.astype(np.float) + instance_coords[..., 0::2] *= 1.0 * size_w / w + instance_coords[..., 1::2] *= 1.0 * size_h / h + data["instance_coords"] = instance_coords + + return data + + +def _pad_fn( + data: Dict, + padding: Union[int, Sequence], + fill: Optional[int] = 0, + padding_mode: Optional[str] = "constant", +) -> Dict: + # Taken from the functional_tensor.py pad + if isinstance(padding, int): + pad_left = pad_right = pad_top = pad_bottom = padding + elif len(padding) == 1: + pad_left = pad_right = pad_top = pad_bottom = padding[0] + elif len(padding) == 2: + pad_left = pad_right = padding[0] + pad_top = pad_bottom = padding[1] + else: + pad_left = padding[0] + pad_top = padding[1] + pad_right = padding[2] + pad_bottom = padding[3] + + padding = [pad_left, pad_top, pad_right, pad_bottom] + data["image"] = F.pad(data.pop("image"), padding, fill, padding_mode) + + if "mask" in data: + data["mask"] = F.pad(data.pop("mask"), padding, 0, "constant") + + if "box_coordinates" in data: + # labels remain unchanged + boxes = data.pop("box_coordinates") + boxes[:, 0::2] += pad_left + boxes[:, 1::2] += pad_top + data["box_coordinates"] = boxes + + return data + + +@register_transformations(name="fixed_size_crop", type="image_pil") +class FixedSizeCrop(BaseTransformation): + def __init__( + self, opts, size: Optional[Union[int, Tuple[int, int]]] = None, *args, **kwargs + ): + super().__init__(opts, *args, **kwargs) + # size can be passed as an argument or using config. 
+ # The argument is useful when implementing variable samplers + if size is None: + size = getattr(opts, "image_augmentation.fixed_size_crop.size", None) + fill = getattr(opts, "image_augmentation.fixed_size_crop.fill", 0) + padding_mode = getattr( + opts, "image_augmentation.fixed_size_crop.padding_mode", "constant" + ) + size = setup_size( + size, + error_msg="Please provide either int or (int, int) for size in {}.".format( + self.__class__.__name__ + ), + ) + self.crop_height = size[0] + self.crop_width = size[1] + self.fill = fill + self.padding_mode = padding_mode + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser) -> argparse.ArgumentParser: + group = parser.add_argument_group( + title="".format(cls.__name__), description="".format(cls.__name__) + ) + group.add_argument( + "--image-augmentation.fixed-size-crop.enable", + action="store_true", + help="use {}. This flag is useful when you want to study the effect of different " + "transforms.".format(cls.__name__), + ) + group.add_argument( + "--image-augmentation.fixed-size-crop.size", + type=int, + nargs="+", + default=None, + help="Image size either as an int or (int, int).", + ) + group.add_argument( + "--image-augmentation.fixed-size-crop.fill", + type=int, + default=0, + help="Fill value to be used during padding operation. Defaults to 0.", + ) + group.add_argument( + "--image-augmentation.fixed-size-crop.padding-mode", + type=str, + default="constant", + help="Padding modes. Defaults to constant", + ) + + return parser + + def __call__(self, data: Dict, *args, **kwargs) -> Dict: + img = data["image"] + width, height = F.get_image_size(img) + new_height = min(height, self.crop_height) + new_width = min(width, self.crop_width) + + if new_height != height or new_width != width: + offset_height = max(height - self.crop_height, 0) + offset_width = max(width - self.crop_width, 0) + + r = random.random() + top = int(offset_height * r) + left = int(offset_width * r) + + data = _crop_fn( + data, top=top, left=left, height=new_height, width=new_width + ) + + pad_bottom = max(self.crop_height - new_height, 0) + pad_right = max(self.crop_width - new_width, 0) + if pad_bottom != 0 or pad_right != 0: + data = _pad_fn( + data, + padding=[0, 0, pad_right, pad_bottom], + fill=self.fill, + padding_mode=self.padding_mode, + ) + return data + + def __repr__(self): + return "{}(crop_size=({}, {}), fill={}, padding_mode={})".format( + self.__class__.__name__, + self.crop_height, + self.crop_width, + self.fill, + self.padding_mode, + ) + + +@register_transformations(name="scale_jitter", type="image_pil") +class ScaleJitter(BaseTransformation): + """Randomly resizes the input within the scale range""" + + def __init__(self, opts, *args, **kwargs) -> None: + target_size = getattr(opts, "image_augmentation.scale_jitter.target_size", None) + if target_size is None: + logger.error( + "Target size can't be None in {}.".format(self.__class__.__name__) + ) + target_size = setup_size( + target_size, + error_msg="Need either an int or (int, int) for target size in {}".format( + self.__class__.__name__ + ), + ) + + scale_range = getattr(opts, "image_augmentation.scale_jitter.scale_range", None) + if scale_range is None: + logger.error( + "Scale range can't be None in {}".format(self.__class__.__name__) + ) + + if isinstance(scale_range, Sequence) and len(scale_range) == 2: + scale_range = scale_range + else: + logger.error( + "Need (float, float) for target size in {}".format( + self.__class__.__name__ + ) + ) + + if scale_range[0] > 
scale_range[1]: + logger.error( + "scale_range[1] >= scale_range[0] in {}. Got: {}".format( + self.__class__.__name__, scale_range[1], scale_range[0] + ) + ) + + interpolation = getattr( + opts, "image_augmentation.scale_jitter.interpolation", "bilinear" + ) + + if isinstance(interpolation, str): + interpolation = _interpolation_modes_from_str(name=interpolation) + + super().__init__(opts, *args, **kwargs) + self.target_size = target_size + self.scale_range = scale_range + self.interpolation = interpolation + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser) -> argparse.ArgumentParser: + group = parser.add_argument_group( + title="".format(cls.__name__), description="".format(cls.__name__) + ) + group.add_argument( + "--image-augmentation.scale-jitter.enable", + action="store_true", + help="use {}. This flag is useful when you want to study the effect of different " + "transforms.".format(cls.__name__), + ) + group.add_argument( + "--image-augmentation.scale-jitter.interpolation", + type=str, + default="bilinear", + help="Interpolation method. Defaults to bilinear interpolation", + ) + group.add_argument( + "--image-augmentation.scale-jitter.target-size", + type=int, + nargs="+", + default=None, + help="Target image size either as an int or (int, int).", + ) + group.add_argument( + "--image-augmentation.scale-jitter.scale-range", + type=float, + nargs="+", + default=None, + help="Scale range as (float, float).", + ) + + return parser + + def __call__(self, data: Dict, *args, **kwargs) -> Dict: + img = data["image"] + orig_width, orig_height = F.get_image_size(img) + scale = self.scale_range[0] + random.random() * ( + self.scale_range[1] - self.scale_range[0] + ) + r = ( + min(self.target_size[1] / orig_height, self.target_size[0] / orig_width) + * scale + ) + new_width = int(orig_width * r) + new_height = int(orig_height * r) + + data = _resize_fn( + data, size=(new_height, new_width), interpolation=self.interpolation + ) + return data + + def __repr__(self): + return "{}(scale_range={}, target_size={}, interpolation={})".format( + self.__class__.__name__, + self.scale_range, + self.target_size, + self.interpolation, + ) + + +@register_transformations(name="random_resized_crop", type="image_pil") +class RandomResizedCrop(BaseTransformation, T.RandomResizedCrop): + """ + This class crops a random portion of an image and resize it to a given size. + """ + + def __init__(self, opts, size: Union[Sequence, int], *args, **kwargs) -> None: + interpolation = getattr( + opts, "image_augmentation.random_resized_crop.interpolation", "bilinear" + ) + scale = getattr( + opts, "image_augmentation.random_resized_crop.scale", (0.08, 1.0) + ) + ratio = getattr( + opts, + "image_augmentation.random_resized_crop.aspect_ratio", + (3.0 / 4.0, 4.0 / 3.0), + ) + + BaseTransformation.__init__(self, opts=opts) + + T.RandomResizedCrop.__init__( + self, size=size, scale=scale, ratio=ratio, interpolation=interpolation + ) + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser) -> argparse.ArgumentParser: + group = parser.add_argument_group( + title="".format(cls.__name__), description="".format(cls.__name__) + ) + + group.add_argument( + "--image-augmentation.random-resized-crop.enable", + action="store_true", + help="use {}. 
This flag is useful when you want to study the effect of different " + "transforms.".format(cls.__name__), + ) + group.add_argument( + "--image-augmentation.random-resized-crop.interpolation", + type=str, + default="bilinear", + choices=list(INTERPOLATION_MODE_MAP.keys()), + help="Interpolation method for resizing. Defaults to bilinear.", + ) + group.add_argument( + "--image-augmentation.random-resized-crop.scale", + type=tuple, + default=(0.08, 1.0), + help="Specifies the lower and upper bounds for the random area of the crop, before resizing." + " The scale is defined with respect to the area of the original image. Defaults to " + "(0.08, 1.0)", + ) + group.add_argument( + "--image-augmentation.random-resized-crop.aspect-ratio", + type=float or tuple, + default=(3.0 / 4.0, 4.0 / 3.0), + help="lower and upper bounds for the random aspect ratio of the crop, before resizing. " + "Defaults to (3./4., 4./3.)", + ) + return parser + + def __call__(self, data: Dict) -> Dict: + img = data["image"] + i, j, h, w = super().get_params(img=img, scale=self.scale, ratio=self.ratio) + data = _crop_fn(data=data, top=i, left=j, height=h, width=w) + return _resize_fn(data=data, size=self.size, interpolation=self.interpolation) + + def __repr__(self) -> str: + return "{}(scale={}, ratio={}, size={}, interpolation={})".format( + self.__class__.__name__, + self.scale, + self.ratio, + self.size, + self.interpolation, + ) + + +@register_transformations(name="auto_augment", type="image_pil") +class AutoAugment(BaseTransformation, T.AutoAugment): + """ + This class implements the `AutoAugment data augmentation `_ method. + """ + + def __init__(self, opts, *args, **kwargs) -> None: + policy_name = getattr( + opts, "image_augmentation.auto_augment.policy", "imagenet" + ) + interpolation = getattr( + opts, "image_augmentation.auto_augment.interpolation", "bilinear" + ) + if policy_name == "imagenet": + policy = T.AutoAugmentPolicy.IMAGENET + else: + raise NotImplemented + + if isinstance(interpolation, str): + interpolation = _interpolation_modes_from_str(name=interpolation) + + BaseTransformation.__init__(self, opts=opts) + T.AutoAugment.__init__(self, policy=policy, interpolation=interpolation) + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser) -> argparse.ArgumentParser: + group = parser.add_argument_group( + title="".format(cls.__name__), description="".format(cls.__name__) + ) + + group.add_argument( + "--image-augmentation.auto-augment.enable", + action="store_true", + help="use {}. This flag is useful when you want to study the effect of different " + "transforms.".format(cls.__name__), + ) + group.add_argument( + "--image-augmentation.auto-augment.policy", + type=str, + default="imagenet", + help="Auto-augment policy name. Defaults to imagenet.", + ) + group.add_argument( + "--image-augmentation.auto-augment.interpolation", + type=str, + default="bilinear", + help="Auto-augment interpolation method. 
Defaults to bilinear interpolation", + ) + return parser + + def __call__(self, data: Dict) -> Dict: + if "box_coordinates" in data or "mask" in data or "instance_masks" in data: + logger.error( + "{} is only supported for classification tasks".format( + self.__class__.__name__ + ) + ) + + img = data["image"] + img = super().forward(img) + data["image"] = img + return data + + def __repr__(self) -> str: + return "{}(policy={}, interpolation={})".format( + self.__class__.__name__, self.policy, self.interpolation + ) + + +@register_transformations(name="rand_augment", type="image_pil") +class RandAugment(BaseTransformation, T.RandAugment): + """ + This class implements the `RandAugment data augmentation `_ method. + """ + + def __init__(self, opts, *args, **kwargs) -> None: + num_ops = getattr(opts, "image_augmentation.rand_augment.num_ops", 2) + magnitude = getattr(opts, "image_augmentation.rand_augment.magnitude", 9) + num_magnitude_bins = getattr( + opts, "image_augmentation.rand_augment.num_magnitude_bins", 31 + ) + interpolation = getattr( + opts, "image_augmentation.rand_augment.interpolation", "bilinear" + ) + + BaseTransformation.__init__(self, opts=opts) + + if isinstance(interpolation, str): + interpolation = _interpolation_modes_from_str(name=interpolation) + + T.RandAugment.__init__( + self, + num_ops=num_ops, + magnitude=magnitude, + num_magnitude_bins=num_magnitude_bins, + interpolation=interpolation, + ) + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser) -> argparse.ArgumentParser: + group = parser.add_argument_group( + title="".format(cls.__name__), description="".format(cls.__name__) + ) + + group.add_argument( + "--image-augmentation.rand-augment.enable", + action="store_true", + help="Use {}. This flag is useful when you want to study the effect of different " + "transforms.".format(cls.__name__), + ) + group.add_argument( + "--image-augmentation.rand-augment.num-ops", + type=int, + default=2, + help="Number of augmentation transformations to apply sequentially. Defaults to 2.", + ) + group.add_argument( + "--image-augmentation.rand-augment.magnitude", + type=int, + default=9, + help="Magnitude for all the transformations. Defaults to 9", + ) + group.add_argument( + "--image-augmentation.rand-augment.num-magnitude-bins", + type=int, + default=31, + help="The number of different magnitude values. Defaults to 31.", + ) + group.add_argument( + "--image-augmentation.rand-augment.interpolation", + type=str, + default="bilinear", + choices=list(INTERPOLATION_MODE_MAP.keys()), + help="Desired interpolation method. Defaults to bilinear", + ) + return parser + + def __call__(self, data: Dict) -> Dict: + if "box_coordinates" in data or "mask" in data or "instance_masks" in data: + logger.error( + "{} is only supported for classification tasks".format( + self.__class__.__name__ + ) + ) + + img = data["image"] + img = super().forward(img) + data["image"] = img + return data + + def __repr__(self) -> str: + return "{}(num_ops={}, magnitude={}, num_magnitude_bins={}, interpolation={})".format( + self.__class__.__name__, + self.num_ops, + self.magnitude, + self.num_magnitude_bins, + self.interpolation, + ) + + +@register_transformations(name="trivial_augment_wide", type="image_pil") +class TrivialAugmentWide(BaseTransformation, T.TrivialAugmentWide): + """ + This class implements the `TrivialAugment (Wide) data augmentation `_ method. 
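+    Unlike AutoAugment and RandAugment, TrivialAugment applies a single transformation
+    chosen uniformly at random per image, with a magnitude also sampled uniformly, so
+    there is no augmentation policy to search or tune (only the number of magnitude bins).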
+ """ + + def __init__(self, opts, *args, **kwargs) -> None: + num_magnitude_bins = getattr( + opts, "image_augmentation.trivial_augment_wide.num_magnitude_bins", 31 + ) + interpolation = getattr( + opts, "image_augmentation.trivial_augment_wide.interpolation", "bilinear" + ) + + BaseTransformation.__init__(self, opts=opts) + + if isinstance(interpolation, str): + interpolation = _interpolation_modes_from_str(name=interpolation) + + T.TrivialAugmentWide.__init__( + self, + num_magnitude_bins=num_magnitude_bins, + interpolation=interpolation, + ) + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser) -> argparse.ArgumentParser: + group = parser.add_argument_group( + title="".format(cls.__name__), description="".format(cls.__name__) + ) + + group.add_argument( + "--image-augmentation.trivial-augment-wide.enable", + action="store_true", + help="Use {}. This flag is useful when you want to study the effect of different " + "transforms.".format(cls.__name__), + ) + group.add_argument( + "--image-augmentation.trivial-augment-wide.num-magnitude-bins", + type=int, + default=31, + help="The number of different magnitude values. Defaults to 31.", + ) + group.add_argument( + "--image-augmentation.trivial-augment-wide.interpolation", + type=str, + default="bilinear", + choices=list(INTERPOLATION_MODE_MAP.keys()), + help="Desired interpolation method. Defaults to bilinear", + ) + return parser + + def __call__(self, data: Dict) -> Dict: + if "box_coordinates" in data or "mask" in data or "instance_masks" in data: + logger.error( + "{} is only supported for classification tasks".format( + self.__class__.__name__ + ) + ) + + img = data["image"] + img = super().forward(img) + data["image"] = img + return data + + def __repr__(self) -> str: + return "{}(num_magnitude_bins={}, interpolation={})".format( + self.__class__.__name__, + self.num_magnitude_bins, + self.interpolation, + ) + + +@register_transformations(name="random_horizontal_flip", type="image_pil") +class RandomHorizontalFlip(BaseTransformation): + """ + This class implements random horizontal flipping method + """ + + def __init__(self, opts, *args, **kwargs) -> None: + p = getattr(opts, "image_augmentation.random_horizontal_flip.p", 0.5) + super().__init__(opts=opts) + self.p = p + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser) -> argparse.ArgumentParser: + group = parser.add_argument_group( + title="".format(cls.__name__), description="".format(cls.__name__) + ) + group.add_argument( + "--image-augmentation.random-horizontal-flip.enable", + action="store_true", + help="use {}. 
This flag is useful when you want to study the effect of different " + "transforms.".format(cls.__name__), + ) + group.add_argument( + "--image-augmentation.random-horizontal-flip.p", + type=float, + default=0.5, + help="Probability for applying random horizontal flip", + ) + return parser + + def __call__(self, data: Dict) -> Dict: + if random.random() <= self.p: + img = data["image"] + width, height = F.get_image_size(img) + data["image"] = F.hflip(img) + + if "mask" in data: + mask = data.pop("mask") + data["mask"] = F.hflip(mask) + + if "box_coordinates" in data: + boxes = data.pop("box_coordinates") + boxes[..., 0::2] = width - boxes[..., 2::-2] + data["box_coordinates"] = boxes + + if "instance_mask" in data: + assert "instance_coords" in data + + instance_coords = data.pop("instance_coords") + instance_coords[..., 0::2] = width - instance_coords[..., 2::-2] + data["instance_coords"] = instance_coords + + instance_masks = data.pop("instance_mask") + data["instance_mask"] = F.hflip(instance_masks) + return data + + def __repr__(self) -> str: + return "{}(p={})".format(self.__class__.__name__, self.p) + + +@register_transformations(name="random_rotate", type="image_pil") +class RandomRotate(BaseTransformation): + """ + This class implements random rotation method + """ + + def __init__(self, opts, *args, **kwargs) -> None: + super().__init__(opts=opts) + self.angle = getattr(opts, "image_augmentation.random_rotate.angle", 10) + self.mask_fill = getattr(opts, "image_augmentation.random_rotate.mask_fill", 0) + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser) -> argparse.ArgumentParser: + group = parser.add_argument_group( + title="".format(cls.__name__), description="".format(cls.__name__) + ) + group.add_argument( + "--image-augmentation.random-rotate.enable", + action="store_true", + help="use {}. This flag is useful when you want to study the effect of different " + "transforms.".format(cls.__name__), + ) + group.add_argument( + "--image-augmentation.random-rotate.angle", + type=float, + default=10, + help="Angle for rotation. Defaults to 10. The angle is sampled " + "uniformly from [-angle, angle]", + ) + group.add_argument( + "--image-augmentation.random-rotate.mask-fill", + default=0, + help="Fill value for the segmentation mask. Defaults to 0.", + ) + return parser + + def __call__(self, data: Dict) -> Dict: + + data_keys = list(data.keys()) + if "box_coordinates" in data_keys or "instance_mask" in data_keys: + logger.error("{} supports only images and masks") + + rand_angle = random.uniform(-self.angle, self.angle) + img = data.pop("image") + data["image"] = F.rotate( + img, angle=rand_angle, interpolation=F.InterpolationMode.BILINEAR, fill=0 + ) + if "mask" in data: + mask = data.pop("mask") + data["mask"] = F.rotate( + mask, + angle=rand_angle, + interpolation=F.InterpolationMode.NEAREST, + fill=self.mask_fill, + ) + return data + + def __repr__(self) -> str: + return "{}(angle={}, mask_fill={})".format( + self.__class__.__name__, self.angle, self.mask_fill + ) + + +@register_transformations(name="resize", type="image_pil") +class Resize(BaseTransformation): + """ + This class implements resizing operation. + + .. note:: + Two possible modes for resizing. + 1. Resize while maintaining aspect ratio. To enable this option, pass int as a size + 2. Resize to a fixed size. To enable this option, pass a tuple of height and width as a size + + .. 
note:: + If img_size is passed as a positional argument, then it will override size from args + """ + + def __init__( + self, + opts, + img_size: Optional[Union[Tuple[int, int], int]] = None, + *args, + **kwargs + ) -> None: + interpolation = getattr( + opts, "image_augmentation.resize.interpolation", "bilinear" + ) + super().__init__(opts=opts) + + # img_size argument is useful for implementing multi-scale sampler + size = ( + getattr(opts, "image_augmentation.resize.size", None) + if img_size is None + else img_size + ) + if size is None: + logger.error("Size can not be None in {}".format(self.__class__.__name__)) + + # Possible modes. + # 1. Resize while maintaining aspect ratio. To enable this option, pass int as a size + # 2. Resize to a fixed size. To enable this option, pass a tuple of height and width as a size + + if isinstance(size, Sequence) and len(size) == 1: + # List with single integer + size = size[0] + elif isinstance(size, Sequence) and len(size) > 2: + logger.error( + "The length of size should be either 1 or 2 in {}. Got: {}".format( + self.__class__.__name__, size + ) + ) + + if not (isinstance(size, Sequence) or isinstance(size, int)): + logger.error( + "Size needs to be either Tuple of length 2 or an integer in {}. Got: {}".format( + self.__class__.__name__, size + ) + ) + + self.size = size + self.interpolation = interpolation + self.maintain_aspect_ratio = True if isinstance(size, int) else False + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser) -> argparse.ArgumentParser: + group = parser.add_argument_group( + title="".format(cls.__name__), description="".format(cls.__name__) + ) + group.add_argument( + "--image-augmentation.resize.enable", + action="store_true", + help="use {}. This flag is useful when you want to study the effect of different " + "transforms.".format(cls.__name__), + ) + group.add_argument( + "--image-augmentation.resize.interpolation", + type=str, + default="bilinear", + choices=list(INTERPOLATION_MODE_MAP.keys()), + help="Desired interpolation method for resizing. Defaults to bilinear", + ) + group.add_argument( + "--image-augmentation.resize.size", + type=int, + nargs="+", + default=256, + help="Resize image to the specified size. If int is passed, then shorter side is resized" + "to the specified size and longest side is resized while maintaining aspect ratio." + "Defaults to None.", + ) + return parser + + def __call__(self, data: Dict) -> Dict: + return _resize_fn(data, size=self.size, interpolation=self.interpolation) + + def __repr__(self) -> str: + return "{}(size={}, interpolation={}, maintain_aspect_ratio={})".format( + self.__class__.__name__, + self.size, + self.interpolation, + self.maintain_aspect_ratio, + ) + + +@register_transformations(name="center_crop", type="image_pil") +class CenterCrop(BaseTransformation): + """ + This class implements center cropping method. + + .. note:: + This class assumes that the input size is greater than or equal to the desired size. 
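+        For example, a 224x224 center crop of a 256x256 input uses the top-left offset
+        ((256 - 224) // 2, (256 - 224) // 2) = (16, 16), matching the computation in __call__.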
+ """ + + def __init__(self, opts, *args, **kwargs) -> None: + super().__init__(opts=opts) + size = getattr(opts, "image_augmentation.center_crop.size", None) + + if size is None: + logger.error("Size cannot be None in {}".format(self.__class__.__name__)) + + if isinstance(size, Sequence) and len(size) == 2: + self.height, self.width = size[0], size[1] + elif isinstance(size, Sequence) and len(size) == 1: + self.height = self.width = size[0] + elif isinstance(size, int): + self.height = self.width = size + else: + logger.error("Scale should be either an int or tuple of ints") + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser) -> argparse.ArgumentParser: + group = parser.add_argument_group( + title="".format(cls.__name__), description="".format(cls.__name__) + ) + + group.add_argument( + "--image-augmentation.center-crop.enable", + action="store_true", + help="use {}. This flag is useful when you want to study the effect of different " + "transforms.".format(cls.__name__), + ) + group.add_argument( + "--image-augmentation.center-crop.size", + type=int, + nargs="+", + default=224, + help="Center crop size. Defaults to None.", + ) + return parser + + def __call__(self, data: Dict) -> Dict: + width, height = F.get_image_size(data["image"]) + i = (height - self.height) // 2 + j = (width - self.width) // 2 + return _crop_fn(data=data, top=i, left=j, height=self.height, width=self.width) + + def __repr__(self) -> str: + return "{}(size=(h={}, w={}))".format( + self.__class__.__name__, self.height, self.width + ) + + +@register_transformations(name="ssd_cropping", type="image_pil") +class SSDCroping(BaseTransformation): + """ + This class implements cropping method for `Single shot object detector `_. + """ + + def __init__(self, opts, *args, **kwargs) -> None: + super().__init__(opts=opts) + + self.iou_sample_opts = getattr( + opts, + "image_augmentation.ssd_crop.iou_thresholds", + [0.0, 0.1, 0.3, 0.5, 0.7, 0.9, 1.0], + ) + self.trials = getattr(opts, "image_augmentation.ssd_crop.n_trials", 40) + self.min_aspect_ratio = getattr( + opts, "image_augmentation.ssd_crop.min_aspect_ratio", 0.5 + ) + self.max_aspect_ratio = getattr( + opts, "image_augmentation.ssd_crop.max_aspect_ratio", 2.0 + ) + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser) -> argparse.ArgumentParser: + group = parser.add_argument_group( + title="".format(cls.__name__), description="".format(cls.__name__) + ) + + group.add_argument( + "--image-augmentation.ssd-crop.enable", + action="store_true", + help="use {}. This flag is useful when you want to study the effect of different " + "transforms.".format(cls.__name__), + ) + group.add_argument( + "--image-augmentation.ssd-crop.iou-thresholds", + type=float, + nargs="+", + default=[0.0, 0.1, 0.3, 0.5, 0.7, 0.9, 1.0], + help="IoU thresholds for SSD cropping. Defaults to [0.0, 0.1, 0.3, 0.5, 0.7, 0.9, 1.0]", + ) + group.add_argument( + "--image-augmentation.ssd-crop.n-trials", + type=int, + default=40, + help="Number of trials for SSD cropping. Defaults to 40", + ) + group.add_argument( + "--image-augmentation.ssd-crop.min-aspect-ratio", + type=float, + default=0.5, + help="Min. aspect ratio in SSD Cropping. Defaults to 0.5", + ) + group.add_argument( + "--image-augmentation.ssd-crop.max-aspect-ratio", + type=float, + default=2.0, + help="Max. aspect ratio in SSD Cropping. 
Defaults to 2.0", + ) + return parser + + def __call__(self, data: Dict) -> Dict: + if "box_coordinates" in data: + boxes = data["box_coordinates"] + + # guard against no boxes + if boxes.shape[0] == 0: + return data + + image = data["image"] + labels = data["box_labels"] + width, height = F.get_image_size(image) + + while True: + # randomly choose a mode + min_jaccard_overalp = random.choice(self.iou_sample_opts) + if min_jaccard_overalp == 0.0: + return data + + for _ in range(self.trials): + new_w = int(random.uniform(0.3 * width, width)) + new_h = int(random.uniform(0.3 * height, height)) + + aspect_ratio = new_h / new_w + if not ( + self.min_aspect_ratio <= aspect_ratio <= self.max_aspect_ratio + ): + continue + + left = int(random.uniform(0, width - new_w)) + top = int(random.uniform(0, height - new_h)) + + # convert to integer rect x1,y1,x2,y2 + rect = np.array([left, top, left + new_w, top + new_h]) + + # calculate IoU (jaccard overlap) b/t the cropped and gt boxes + ious = jaccard_numpy(boxes, rect) + + # is min and max overlap constraint satisfied? if not try again + if ious.max() < min_jaccard_overalp: + continue + + # keep overlap with gt box IF center in sampled patch + centers = (boxes[:, :2] + boxes[:, 2:]) * 0.5 + + # mask in all gt boxes that above and to the left of centers + m1 = (rect[0] < centers[:, 0]) * (rect[1] < centers[:, 1]) + + # mask in all gt boxes that under and to the right of centers + m2 = (rect[2] > centers[:, 0]) * (rect[3] > centers[:, 1]) + + # mask in that both m1 and m2 are true + mask = m1 * m2 + + # have any valid boxes? try again if not + if not mask.any(): + continue + + # if image size is too small, try again + if (rect[3] - rect[1]) < 100 or (rect[2] - rect[0]) < 100: + continue + + # cut the crop from the image + image = F.crop(image, top=top, left=left, width=new_w, height=new_h) + + # take only matching gt boxes + current_boxes = boxes[mask, :].copy() + + # take only matching gt labels + current_labels = labels[mask] + + # should we use the box left and top corner or the crop's + current_boxes[:, :2] = np.maximum(current_boxes[:, :2], rect[:2]) + # adjust to crop (by substracting crop's left,top) + current_boxes[:, :2] -= rect[:2] + + current_boxes[:, 2:] = np.minimum(current_boxes[:, 2:], rect[2:]) + # adjust to crop (by substracting crop's left,top) + current_boxes[:, 2:] -= rect[:2] + + data["image"] = image + data["box_labels"] = current_labels + data["box_coordinates"] = current_boxes + + if "mask" in data: + mask = data.pop("mask") + data["mask"] = F.crop( + mask, top=top, left=left, width=new_w, height=new_h + ) + + if "instance_mask" in data: + assert "instance_coords" in data + instance_masks = data.pop("instance_mask") + data["instance_mask"] = F.crop( + instance_masks, + top=top, + left=left, + width=new_w, + height=new_h, + ) + + instance_coords = data.pop("instance_coords") + # should we use the box left and top corner or the crop's + instance_coords[..., :2] = np.maximum( + instance_coords[..., :2], rect[:2] + ) + # adjust to crop (by substracting crop's left,top) + instance_coords[..., :2] -= rect[:2] + + instance_coords[..., 2:] = np.minimum( + instance_coords[..., 2:], rect[2:] + ) + # adjust to crop (by substracting crop's left,top) + instance_coords[..., 2:] -= rect[:2] + data["instance_coords"] = instance_coords + + return data + return data + + +@register_transformations(name="photo_metric_distort", type="image_pil") +class PhotometricDistort(BaseTransformation): + """ + This class implements Photometeric distorion. 
+ + .. note:: + Hyper-parameters of PhotoMetricDistort in PIL and OpenCV are different. Be careful + """ + + def __init__(self, opts, *args, **kwargs) -> None: + # contrast + alpha_min = getattr( + opts, "image_augmentation.photo_metric_distort.alpha_min", 0.5 + ) + alpha_max = getattr( + opts, "image_augmentation.photo_metric_distort.alpha_max", 1.5 + ) + contrast = T.ColorJitter(contrast=[alpha_min, alpha_max]) + + # brightness + beta_min = getattr( + opts, "image_augmentation.photo_metric_distort.beta_min", 0.875 + ) + beta_max = getattr( + opts, "image_augmentation.photo_metric_distort.beta_max", 1.125 + ) + brightness = T.ColorJitter(brightness=[beta_min, beta_max]) + + # saturation + gamma_min = getattr( + opts, "image_augmentation.photo_metric_distort.gamma_min", 0.5 + ) + gamma_max = getattr( + opts, "image_augmentation.photo_metric_distort.gamma_max", 1.5 + ) + saturation = T.ColorJitter(saturation=[gamma_min, gamma_max]) + + # Hue + delta_min = getattr( + opts, "image_augmentation.photo_metric_distort.delta_min", -0.05 + ) + delta_max = getattr( + opts, "image_augmentation.photo_metric_distort.delta_max", 0.05 + ) + hue = T.ColorJitter(hue=[delta_min, delta_max]) + + super().__init__(opts=opts) + self._brightness = brightness + self._contrast = contrast + self._hue = hue + self._saturation = saturation + self.p = getattr(opts, "image_augmentation.photo_metric_distort.p", 0.5) + + def __repr__(self) -> str: + return "{}(contrast={}, brightness={}, saturation={}, hue={})".format( + self.__class__.__name__, + self._contrast.contrast, + self._brightness.brightness, + self._saturation.saturation, + self._hue.hue, + ) + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser) -> argparse.ArgumentParser: + group = parser.add_argument_group( + title="".format(cls.__name__), description="".format(cls.__name__) + ) + group.add_argument( + "--image-augmentation.photo-metric-distort.enable", + action="store_true", + help="use {}. This flag is useful when you want to study the effect of different " + "transforms.".format(cls.__name__), + ) + + group.add_argument( + "--image-augmentation.photo-metric-distort.alpha-min", + type=float, + default=0.5, + help="Min. alpha value for contrast. Should be > 0. Defaults to 0.5", + ) + group.add_argument( + "--image-augmentation.photo-metric-distort.alpha-max", + type=float, + default=1.5, + help="Max. alpha value for contrast. Should be > 0. Defaults to 1.5", + ) + + group.add_argument( + "--image-augmentation.photo-metric-distort.beta-min", + type=float, + default=0.875, + help="Min. beta value for brightness. Should be > 0. Defaults to 0.8", + ) + group.add_argument( + "--image-augmentation.photo-metric-distort.beta-max", + type=float, + default=1.125, + help="Max. beta value for brightness. Should be > 0. Defaults to 1.2", + ) + + group.add_argument( + "--image-augmentation.photo-metric-distort.gamma-min", + type=float, + default=0.5, + help="Min. gamma value for saturation. Should be > 0. Defaults to 0.5", + ) + group.add_argument( + "--image-augmentation.photo-metric-distort.gamma-max", + type=float, + default=1.5, + help="Max. gamma value for saturation. Should be > 0. Defaults to 1.5", + ) + + group.add_argument( + "--image-augmentation.photo-metric-distort.delta-min", + type=float, + default=-0.05, + help="Min. delta value for Hue. Should be between -1 and 1. Defaults to -0.05", + ) + group.add_argument( + "--image-augmentation.photo-metric-distort.delta-max", + type=float, + default=0.05, + help="Max. delta value for Hue. 
Should be between -1 and 1. Defaults to 0.05", + ) + + group.add_argument( + "--image-augmentation.photo-metric-distort.p", + type=float, + default=0.5, + help="Probability for applying a distortion. Defaults to 0.5", + ) + + return parser + + def _apply_transformations(self, image): + r = np.random.rand(7) + + if r[0] < self.p: + image = self._brightness(image) + + contrast_before = r[1] < self.p + if contrast_before and r[2] < self.p: + image = self._contrast(image) + + if r[3] < self.p: + image = self._saturation(image) + + if r[4] < self.p: + image = self._hue(image) + + if not contrast_before and r[5] < self.p: + image = self._contrast(image) + + if r[6] < self.p and image.mode != "L": + # Only permute channels for RGB images + # [H, W, C] format + image_np = np.asarray(image) + n_channels = image_np.shape[2] + image_np = image_np[..., np.random.permutation(range(n_channels))] + image = Image.fromarray(image_np) + return image + + def __call__(self, data: Dict) -> Dict: + image = data.pop("image") + data["image"] = self._apply_transformations(image) + return data + + +@register_transformations(name="box_percent_coords", type="image_pil") +class BoxPercentCoords(BaseTransformation): + """ + This class converts the box coordinates to percent + """ + + def __init__(self, opts, *args, **kwargs) -> None: + super().__init__(opts=opts) + + def __call__(self, data: Dict) -> Dict: + if "box_coordinates" in data: + boxes = data.pop("box_coordinates") + image = data["image"] + width, height = F.get_image_size(image) + + boxes = boxes.astype(np.float) + + boxes[..., 0::2] /= width + boxes[..., 1::2] /= height + data["box_coordinates"] = boxes + + return data + + +@register_transformations(name="instance_processor", type="image_pil") +class InstanceProcessor(BaseTransformation): + """ + This class processes the instance masks. 
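+    Each instance mask is cropped to its bounding box (given by instance_coords), resized to a
+    fixed instance_size (16x16 by default) with nearest-neighbor interpolation, and stacked so
+    that a variable number of instances can be batched.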
+ """ + + def __init__( + self, + opts, + instance_size: Optional[Union[int, Tuple[int, ...]]] = 16, + *args, + **kwargs + ) -> None: + super().__init__(opts=opts) + self.instance_size = setup_size(instance_size) + + def __call__(self, data: Dict) -> Dict: + + if "instance_mask" in data: + assert "instance_coords" in data + instance_masks = data.pop("instance_mask") + instance_coords = data.pop("instance_coords") + instance_coords = instance_coords.astype(np.int) + + valid_boxes = (instance_coords[..., 3] > instance_coords[..., 1]) & ( + instance_coords[..., 2] > instance_coords[..., 0] + ) + instance_masks = instance_masks[valid_boxes] + instance_coords = instance_coords[valid_boxes] + + num_instances = instance_masks.shape[0] + + resized_instances = [] + for i in range(num_instances): + # format is [N, H, W] + instance_m = instance_masks[i] + box_coords = instance_coords[i] + + instance_m = F.crop( + instance_m, + top=box_coords[1], + left=box_coords[0], + height=box_coords[3] - box_coords[1], + width=box_coords[2] - box_coords[0], + ) + # need to unsqueeze and squeeze to make F.resize work + instance_m = F.resize( + instance_m.unsqueeze(0), + size=self.instance_size, + interpolation=T.InterpolationMode.NEAREST, + ).squeeze(0) + resized_instances.append(instance_m) + + if len(resized_instances) == 0: + resized_instances = torch.zeros( + size=(1, self.instance_size[0], self.instance_size[1]), + dtype=torch.long, + ) + instance_coords = np.array( + [[0, 0, self.instance_size[0], self.instance_size[1]]] + ) + else: + resized_instances = torch.stack(resized_instances, dim=0) + + data["instance_mask"] = resized_instances + data["instance_coords"] = instance_coords.astype(np.float) + return data + + +@register_transformations(name="random_resize", type="image_pil") +class RandomResize(BaseTransformation): + """ + This class implements random resizing method. + """ + + def __init__(self, opts, *args, **kwargs) -> None: + min_ratio = getattr(opts, "image_augmentation.random_resize.min_ratio", 0.5) + max_ratio = getattr(opts, "image_augmentation.random_resize.max_ratio", 2.0) + interpolation = getattr( + opts, "image_augmentation.random_resize.interpolation", "bilinear" + ) + + max_scale_long_edge = getattr( + opts, "image_augmentation.random_resize.max_scale_long_edge", None + ) + max_scale_short_edge = getattr( + opts, "image_augmentation.random_resize.max_scale_short_edge", None + ) + + if max_scale_long_edge is None and max_scale_short_edge is not None: + logger.warning( + "max_scale_long_edge cannot be none when max_scale_short_edge is not None in {}. Setting both to " + "None".format(self.__class__.__name__) + ) + max_scale_long_edge = None + max_scale_short_edge = None + elif max_scale_long_edge is not None and max_scale_short_edge is None: + logger.warning( + "max_scale_short_edge cannot be none when max_scale_long_edge is not None in {}. 
Setting both to " + "None".format(self.__class__.__name__) + ) + max_scale_long_edge = None + max_scale_short_edge = None + + super().__init__(opts=opts) + self.min_ratio = min_ratio + self.max_ratio = max_ratio + + self.max_scale_long_edge = max_scale_long_edge + self.max_scale_short_edge = max_scale_short_edge + + self.interpolation = interpolation + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser) -> argparse.ArgumentParser: + group = parser.add_argument_group( + title="".format(cls.__name__), description="".format(cls.__name__) + ) + group.add_argument( + "--image-augmentation.random-resize.enable", + action="store_true", + help="use {}. This flag is useful when you want to study the effect of different " + "transforms.".format(cls.__name__), + ) + group.add_argument( + "--image-augmentation.random-resize.max-scale-long-edge", + type=int, + default=None, + help="Max. value along the longest edge. Defaults to None", + ) + group.add_argument( + "--image-augmentation.random-resize.max-scale-short-edge", + type=int, + default=None, + help="Max. value along the shortest edge. Defaults to None.", + ) + + group.add_argument( + "--image-augmentation.random-resize.min-ratio", + type=float, + default=0.5, + help="Min ratio for random resizing. Defaults to 0.5", + ) + group.add_argument( + "--image-augmentation.random-resize.max-ratio", + type=float, + default=2.0, + help="Max ratio for random resizing. Defaults to 2.0", + ) + group.add_argument( + "--image-augmentation.random-resize.interpolation", + type=str, + default="bilinear", + choices=list(INTERPOLATION_MODE_MAP.keys()), + help="Desired interpolation method. Defaults to bilinear.", + ) + return parser + + def __call__(self, data: Dict) -> Dict: + random_ratio = random.uniform(self.min_ratio, self.max_ratio) + + # compute the size + width, height = F.get_image_size(data["image"]) + if self.max_scale_long_edge is not None: + min_hw = min(height, width) + max_hw = max(height, width) + scale_factor = ( + min( + self.max_scale_long_edge / max_hw, + self.max_scale_short_edge / min_hw, + ) + * random_ratio + ) + # resize while maintaining aspect ratio + new_size = int(math.ceil(height * scale_factor)), int( + math.ceil(width * scale_factor) + ) + else: + new_size = int(math.ceil(height * random_ratio)), int( + math.ceil(width * random_ratio) + ) + # new_size should be a tuple of height and width + return _resize_fn(data, size=new_size, interpolation=self.interpolation) + + def __repr__(self) -> str: + return "{}(min_ratio={}, max_ratio={}, interpolation={}, max_long_edge={}, max_short_edge={})".format( + self.__class__.__name__, + self.min_ratio, + self.max_ratio, + self.interpolation, + self.max_scale_long_edge, + self.max_scale_short_edge, + ) + + +@register_transformations(name="random_short_size_resize", type="image_pil") +class RandomShortSizeResize(BaseTransformation): + """ + This class implements random resizing such that shortest side is between specified minimum and maximum values. 
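+    The target short side is sampled uniformly from [short_side_min, short_side_max], and the
+    scale factor is capped so that the longer side does not exceed max_img_dim, i.e.
+    scale = min(short_side / min(H, W), max_img_dim / max(H, W)).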
+ """ + + def __init__(self, opts, *args, **kwargs) -> None: + super().__init__(opts=opts) + short_size_min = getattr( + opts, "image_augmentation.random_short_size_resize.short_side_min", None + ) + short_size_max = getattr( + opts, "image_augmentation.random_short_size_resize.short_side_max", None + ) + max_img_dim = getattr( + opts, "image_augmentation.random_short_size_resize.max_img_dim", None + ) + if short_size_min is None: + logger.error( + "Short side minimum value can't be None in {}".format( + self.__class__.__name__ + ) + ) + if short_size_max is None: + logger.error( + "Short side maximum value can't be None in {}".format( + self.__class__.__name__ + ) + ) + if max_img_dim is None: + logger.error( + "Max. image dimension value can't be None in {}".format( + self.__class__.__name__ + ) + ) + + if short_size_max <= short_size_min: + logger.error( + "Short side maximum value should be >= short side minimum value in {}. Got: {} and {}".format( + self.__class__.__name__, short_size_max, short_size_min + ) + ) + + interpolation = getattr( + opts, "image_augmentation.random_short_size_resize.interpolation", "bicubic" + ) + + self.short_side_min = short_size_min + self.short_side_max = short_size_max + self.max_img_dim = max_img_dim + self.interpolation = interpolation + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser) -> argparse.ArgumentParser: + group = parser.add_argument_group( + title="".format(cls.__name__), description="".format(cls.__name__) + ) + group.add_argument( + "--image-augmentation.random-short-size-resize.enable", + action="store_true", + help="use {}. This flag is useful when you want to study the effect of different " + "transforms.".format(cls.__name__), + ) + group.add_argument( + "--image-augmentation.random-short-size-resize.short-side-min", + type=int, + default=None, + help="Minimum value for image's shortest side. Defaults to None.", + ) + group.add_argument( + "--image-augmentation.random-short-size-resize.short-side-max", + type=int, + default=None, + help="Maximum value for image's shortest side. Defaults to None.", + ) + group.add_argument( + "--image-augmentation.random-short-size-resize.interpolation", + type=str, + default="bicubic", + choices=list(INTERPOLATION_MODE_MAP.keys()), + help="Desired interpolation method. Defaults to bicubic", + ) + group.add_argument( + "--image-augmentation.random-short-size-resize.max-img-dim", + type=int, + default=None, + help="Max. image dimension. Defaults to None.", + ) + return parser + + def __call__(self, data: Dict) -> Dict: + short_side = random.randint(self.short_side_min, self.short_side_max) + img_w, img_h = data["image"].size + scale = min( + short_side / min(img_h, img_w), self.max_img_dim / max(img_h, img_w) + ) + img_w = int(img_w * scale) + img_h = int(img_h * scale) + data = _resize_fn(data, size=(img_h, img_w), interpolation=self.interpolation) + return data + + def __repr__(self) -> str: + return "{}(short_side_min={}, short_side_max={}, interpolation={})".format( + self.__class__.__name__, + self.short_side_min, + self.short_side_max, + self.interpolation, + ) + + +@register_transformations(name="random_erasing", type="image_pil") +class RandomErasing(BaseTransformation, T.RandomErasing): + """ + This class randomly selects a region in a tensor and erases its pixels. + See `this paper `_ for details. 
+ """ + + def __init__(self, opts, *args, **kwargs) -> None: + BaseTransformation.__init__(self, opts=opts) + random_erase_p = getattr(opts, "image_augmentation.random_erase.p", 0.5) + T.RandomErasing.__init__(self, p=random_erase_p) + + self.random_erase_p = random_erase_p + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser) -> argparse.ArgumentParser: + group = parser.add_argument_group( + title="".format(cls.__name__), description="".format(cls.__name__) + ) + + group.add_argument( + "--image-augmentation.random-erase.enable", + action="store_true", + help="use {}. This flag is useful when you want to study the effect of different " + "transforms.".format(cls.__name__), + ) + group.add_argument( + "--image-augmentation.random-erase.p", + type=float, + default=0.5, + help="Probability that random erasing operation will be applied. Defaults to 0.5", + ) + return parser + + def __call__(self, data: Dict) -> Dict: + data["image"] = super().forward(data.pop("image")) + return data + + def __repr__(self) -> str: + return "{}(random_erase_p={})".format( + self.__class__.__name__, self.random_erase_p + ) + + +@register_transformations(name="random_gaussian_blur", type="image_pil") +class RandomGaussianBlur(BaseTransformation): + """ + This method randomly blurs the input image. + """ + + def __init__(self, opts, *args, **kwargs): + super().__init__(opts=opts) + self.p = getattr(opts, "image_augmentation.random_gaussian_noise.p", 0.5) + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser) -> argparse.ArgumentParser: + group = parser.add_argument_group( + title="".format(cls.__name__), description="".format(cls.__name__) + ) + + group.add_argument( + "--image-augmentation.random-gaussian-noise.enable", + action="store_true", + help="use {}. This flag is useful when you want to study the effect of different " + "transforms.".format(cls.__name__), + ) + group.add_argument( + "--image-augmentation.random-gaussian-noise.p", + type=float, + default=0.5, + help="Probability for applying {}".format(cls.__name__), + ) + return parser + + def __call__(self, data: Dict) -> Dict: + if random.random() < self.p: + img = data.pop("image") + # radius is the standard devaition of the gaussian kernel + img = img.filter(ImageFilter.GaussianBlur(radius=random.random())) + data["image"] = img + return data + + +@register_transformations(name="random_crop", type="image_pil") +class RandomCrop(BaseTransformation): + """ + This method randomly crops an image area. + + .. note:: + If the size of input image is smaller than the desired crop size, the input image is first resized + while maintaining the aspect ratio and then cropping is performed. 
+ """ + + def __init__( + self, + opts, + size: Union[Sequence, int], + ignore_idx: Optional[int] = 255, + *args, + **kwargs + ) -> None: + super().__init__(opts=opts) + self.height, self.width = setup_size(size=size) + self.opts = opts + self.seg_class_max_ratio = getattr( + opts, "image_augmentation.random_crop.seg_class_max_ratio", None + ) + self.ignore_idx = ignore_idx + self.num_repeats = 10 + self.seg_fill = getattr(opts, "image_augmentation.random_crop.mask_fill", 0) + pad_if_needed = getattr( + opts, "image_augmentation.random_crop.pad_if_needed", False + ) + self.if_needed_fn = ( + self._pad_if_needed if pad_if_needed else self._resize_if_needed + ) + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser) -> argparse.ArgumentParser: + group = parser.add_argument_group( + title="".format(cls.__name__), description="".format(cls.__name__) + ) + + group.add_argument( + "--image-augmentation.random-crop.enable", + action="store_true", + help="use {}. This flag is useful when you want to study the effect of different " + "transforms.".format(cls.__name__), + ) + group.add_argument( + "--image-augmentation.random-crop.seg-class-max-ratio", + default=None, + type=float, + help="Max. ratio that single segmentation class can occupy. Defaults to None", + ) + group.add_argument( + "--image-augmentation.random-crop.pad-if-needed", + action="store_true", + help="Pad images if needed. Defaults to False, i.e., resizing will be performed", + ) + group.add_argument( + "--image-augmentation.random-crop.mask-fill", + type=int, + default=255, + help="Value to fill in segmentation mask in case of padding. Defaults to 255. " + "Generally, this value is the same as background or undefined class id.", + ) + return parser + + @staticmethod + def get_params(img_h, img_w, target_h, target_w): + if img_w == target_w and img_h == target_h: + return 0, 0, img_h, img_w + + i = random.randint(0, max(0, img_h - target_h)) + j = random.randint(0, max(0, img_w - target_w)) + return i, j, target_h, target_w + + @staticmethod + def get_params_from_box(boxes, img_h, img_w): + # x, y, w, h + offset = random.randint(20, 50) + start_x = max(0, int(round(np.min(boxes[..., 0]))) - offset) + start_y = max(0, int(round(np.min(boxes[..., 1]))) - offset) + end_x = min(int(round(np.max(boxes[..., 2]))) + offset, img_w) + end_y = min(int(round(np.max(boxes[..., 3]))) + offset, img_h) + + return start_y, start_x, end_y - start_y, end_x - start_x + + def get_params_from_mask(self, data, i, j, h, w): + img_w, img_h = F.get_image_size(data["image"]) + for _ in range(self.num_repeats): + temp_data = _crop_fn( + data=copy.deepcopy(data), top=i, left=j, height=h, width=w + ) + class_labels, cls_count = np.unique( + np.array(temp_data["mask"]), return_counts=True + ) + valid_cls_count = cls_count[class_labels != self.ignore_idx] + + if valid_cls_count.size == 0: + continue + + # compute the ratio of segmentation class with max. pixels to total pixels. 
+ # If the ratio is less than seg_class_max_ratio, then exit the loop + total_valid_pixels = np.sum(valid_cls_count) + max_valid_pixels = np.max(valid_cls_count) + ratio = max_valid_pixels / total_valid_pixels + + if len(cls_count) > 1 and ratio < self.seg_class_max_ratio: + break + i, j, h, w = self.get_params( + img_h=img_h, img_w=img_w, target_h=self.height, target_w=self.width + ) + return i, j, h, w + + def _resize_if_needed(self, data: Dict) -> Dict: + img = data["image"] + + w, h = F.get_image_size(img) + # resize while maintaining the aspect ratio + new_size = min(h + max(0, self.height - h), w + max(0, self.width - w)) + + return _resize_fn( + data, size=new_size, interpolation=T.InterpolationMode.BILINEAR + ) + + def _pad_if_needed(self, data: Dict) -> Dict: + img = data.pop("image") + + w, h = F.get_image_size(img) + new_h = h + max(self.height - h, 0) + new_w = w + max(self.width - w, 0) + + pad_img = Image.new(img.mode, (new_w, new_h), color=0) + pad_img.paste(img, (0, 0)) + data["image"] = pad_img + + if "mask" in data: + mask = data.pop("mask") + pad_mask = Image.new(mask.mode, (new_w, new_h), color=self.seg_fill) + pad_mask.paste(mask, (0, 0)) + data["mask"] = pad_mask + + return data + + def __call__(self, data: Dict) -> Dict: + # box_info + if "box_coordinates" in data: + boxes = data.get("box_coordinates") + # crop the relevant area + image_w, image_h = F.get_image_size(data["image"]) + box_i, box_j, box_h, box_w = self.get_params_from_box( + boxes, image_h, image_w + ) + data = _crop_fn(data, top=box_i, left=box_j, height=box_h, width=box_w) + + data = self.if_needed_fn(data) + + img_w, img_h = F.get_image_size(data["image"]) + i, j, h, w = self.get_params( + img_h=img_h, img_w=img_w, target_h=self.height, target_w=self.width + ) + + if ( + "mask" in data + and self.seg_class_max_ratio is not None + and self.seg_class_max_ratio < 1.0 + ): + i, j, h, w = self.get_params_from_mask(data=data, i=i, j=j, h=h, w=w) + + data = _crop_fn(data=data, top=i, left=j, height=h, width=w) + return data + + def __repr__(self) -> str: + return "{}(size=(h={}, w={}), seg_class_max_ratio={}, seg_fill={})".format( + self.__class__.__name__, + self.height, + self.width, + self.seg_class_max_ratio, + self.seg_fill, + ) + + +@register_transformations(name="to_tensor", type="image_pil") +class ToTensor(BaseTransformation): + """ + This method converts an image into a tensor. + + .. note:: + We do not perform any mean-std normalization. If mean-std normalization is desired, please modify this class. + """ + + def __init__(self, opts, *args, **kwargs) -> None: + super().__init__(opts=opts) + img_dtype = getattr(opts, "image_augmentation.to_tensor.dtype", "float") + self.img_dtype = torch.float + self.norm_factor = 255 + if img_dtype in ["half", "float16"]: + self.img_dtype = torch.float16 + elif img_dtype in ["uint8"]: + self.img_dtype = torch.uint8 + self.norm_factor = 1 + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser) -> argparse.ArgumentParser: + parser.add_argument( + "--image-augmentation.to-tensor.dtype", + type=str, + default="float", + help="Tensor data type. 
Default is float", + ) + return parser + + def __repr__(self): + return "{}(dtype={}, norm_factor={})".format( + self.__class__.__name__, self.img_dtype, self.norm_factor + ) + + def __call__(self, data: Dict) -> Dict: + # HWC --> CHW + img = data["image"] + + if F._is_pil_image(img): + # convert PIL image to tensor + img = F.pil_to_tensor(img).contiguous() + + data["image"] = img.to(dtype=self.img_dtype).div(self.norm_factor) + + if "mask" in data: + mask = data.pop("mask") + mask = np.array(mask) + + if len(mask.shape) not in (2, 3): + logger.error( + "Mask needs to be 2- or 3-dimensional. Got: {}".format(mask.shape) + ) + data["mask"] = torch.as_tensor(mask, dtype=torch.long) + + if "box_coordinates" in data: + boxes = data.pop("box_coordinates") + data["box_coordinates"] = torch.as_tensor(boxes, dtype=torch.float) + + if "box_labels" in data: + box_labels = data.pop("box_labels") + data["box_labels"] = torch.as_tensor(box_labels) + + if "instance_mask" in data: + assert "instance_coords" in data + instance_masks = data.pop("instance_mask") + data["instance_mask"] = instance_masks.to(dtype=torch.long) + + instance_coords = data.pop("instance_coords") + data["instance_coords"] = torch.as_tensor( + instance_coords, dtype=torch.float + ) + return data + + +@register_transformations(name="compose", type="image_pil") +class Compose(BaseTransformation): + """ + This method applies a list of transforms in a sequential fashion. + """ + + def __init__(self, opts, img_transforms: List, *args, **kwargs) -> None: + super().__init__(opts=opts) + self.img_transforms = img_transforms + + def __call__(self, data: Dict) -> Dict: + for t in self.img_transforms: + data = t(data) + return data + + def __repr__(self) -> str: + transform_str = ", ".join("\n\t\t\t" + str(t) for t in self.img_transforms) + repr_str = "{}({}\n\t\t)".format(self.__class__.__name__, transform_str) + return repr_str + + +@register_transformations(name="random_order", type="image_pil") +class RandomOrder(BaseTransformation): + """ + This method applies a list of all or few transforms in a random order. + """ + + def __init__(self, opts, img_transforms: List, *args, **kwargs) -> None: + super().__init__(opts=opts) + self.transforms = img_transforms + apply_k_factor = getattr(opts, "image_augmentation.random_order.apply_k", 1.0) + assert ( + 0.0 < apply_k_factor <= 1.0 + ), "--image-augmentation.random-order.apply-k should be > 0 and <= 1" + self.keep_t = int(math.ceil(len(self.transforms) * apply_k_factor)) + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser) -> argparse.ArgumentParser: + group = parser.add_argument_group( + title="".format(cls.__name__), description="".format(cls.__name__) + ) + group.add_argument( + "--image-augmentation.random-order.enable", + action="store_true", + help="use {}. This flag is useful when you want to study the effect of different " + "transforms.".format(cls.__name__), + ) + group.add_argument( + "--image-augmentation.random-order.apply-k", + type=int, + default=1.0, + help="Apply K percent of transforms randomly. Value between 0 and 1. 
" + "Defaults to 1 (i.e., apply all transforms in random order).", + ) + return parser + + def __call__(self, data: Dict) -> Dict: + random.shuffle(self.transforms) + for t in self.transforms[: self.keep_t]: + data = t(data) + return data + + def __repr__(self): + transform_str = ", ".join(str(t) for t in self.transforms) + repr_str = "{}(n_transforms={}, t_list=[{}]".format( + self.__class__.__name__, self.keep_t, transform_str + ) + return repr_str + + +@register_transformations(name="rand_augment_timm", type="image_pil") +class RandAugmentTimm(BaseTransformation): + """ + This class implements the `RandAugment data augmentation `_ method, + as described in `ResNet Strikes Back `_ paper + """ + + def __init__(self, opts, *args, **kwargs) -> None: + config_str = getattr( + opts, + "image_augmentation.rand_augment.timm_config_str", + "rand-m9-mstd0.5-inc1", + ) + + super().__init__(opts=opts, *args, **kwargs) + + rand_augment_transform = None + try: + from timm.data.transforms_factory import rand_augment_transform + except ModuleNotFoundError: + logger.error("Please install timm library") + + self.config_str = config_str + self.aug_fn = rand_augment_transform + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser) -> argparse.ArgumentParser: + group = parser.add_argument_group( + title="".format(cls.__name__), description="".format(cls.__name__) + ) + + group.add_argument( + "--image-augmentation.rand-augment.use-timm-library", + action="store_true", + help="Use timm library for randaugment over PyTorch's implementation", + ) + group.add_argument( + "--image-augmentation.rand-augment.timm-config-str", + type=str, + default="rand-m9-mstd0.5-inc1", + help="Number of augmentation transformations to apply sequentially. Defaults to 2.", + ) + return parser + + def __call__(self, data: Dict) -> Dict: + if "box_coordinates" in data or "mask" in data or "instance_masks" in data: + logger.error( + "{} is only supported for classification tasks".format( + self.__class__.__name__ + ) + ) + + img = data["image"] + img_size_min = min(img.size) + aa_params = dict( + translate_const=int(img_size_min * 0.45), + img_mean=tuple([128, 128, 128]), + ) + img = self.aug_fn(self.config_str, aa_params)(img) + data["image"] = img + return data + + def __repr__(self) -> str: + return "{}(config_str={})".format(self.__class__.__name__, self.config_str) diff --git a/Adaptive Frequency Filters/data/transforms/image_torch.py b/Adaptive Frequency Filters/data/transforms/image_torch.py new file mode 100644 index 0000000..6525409 --- /dev/null +++ b/Adaptive Frequency Filters/data/transforms/image_torch.py @@ -0,0 +1,247 @@ +# +# For licensing see accompanying LICENSE file. +# Copyright (C) 2023 Apple Inc. All Rights Reserved. +# + +import math +from typing import Dict +import argparse +import torch +from torchvision.transforms import functional as F +from torch.nn import functional as F_torch + +from utils import logger + +from . 
import register_transformations, BaseTransformation + + +# Copied from PyTorch Torchvision +@register_transformations(name="random_mixup", type="image_torch") +class RandomMixup(BaseTransformation): + """ + Given a batch of input images and labels, this class randomly applies the + `Mixup transformation `_ + + Args: + num_classes (int): Number of classes in the dataset + """ + + def __init__(self, opts, num_classes: int, *args, **kwargs) -> None: + super().__init__(opts=opts, *args, **kwargs) + alpha = getattr(opts, "image_augmentation.mixup.alpha", 1.0) + assert ( + num_classes > 0 + ), "Please provide a valid positive value for the num_classes." + assert alpha > 0, "Alpha param can't be zero." + + self.num_classes = num_classes + self.p = getattr(opts, "image_augmentation.mixup.p", 0.5) + self.alpha = alpha + self.inplace = getattr(opts, "image_augmentation.mixup.inplace", False) + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser) -> argparse.ArgumentParser: + group = parser.add_argument_group( + title="".format(cls.__name__), description="".format(cls.__name__) + ) + + group.add_argument( + "--image-augmentation.mixup.enable", + action="store_true", + help="use {}. This flag is useful when you want to study the effect of different " + "transforms.".format(cls.__name__), + ) + group.add_argument( + "--image-augmentation.mixup.alpha", + type=float, + default=1.0, + help="Alpha for MixUp augmentation. Defaults to 1.0", + ) + group.add_argument( + "--image-augmentation.mixup.p", + type=float, + default=0.5, + help="Probability for applying mixup augmentation. Defaults to 0.5", + ) + group.add_argument( + "--image-augmentation.mixup.inplace", + action="store_true", + default=False, + help="Apply Mixup augmentation inplace. Defaults to False.", + ) + return parser + + def __call__(self, data: Dict) -> Dict: + if torch.rand(1).item() >= self.p: + return data + + image_tensor, target_tensor = data.pop("samples"), data.pop("targets") + + if image_tensor.ndim != 4: + logger.error(f"Batch ndim should be 4. Got {image_tensor.ndim}") + if target_tensor.ndim != 1: + logger.error(f"Target ndim should be 1. Got {target_tensor.ndim}") + if not image_tensor.is_floating_point(): + logger.error( + f"Batch dtype should be a float tensor. Got {image_tensor.dtype}." + ) + if target_tensor.dtype != torch.int64: + logger.error( + f"Target dtype should be torch.int64. Got {target_tensor.dtype}" + ) + + if not self.inplace: + image_tensor = image_tensor.clone() + target_tensor = target_tensor.clone() + + if target_tensor.ndim == 1: + target_tensor = F_torch.one_hot( + target_tensor, num_classes=self.num_classes + ).to(dtype=image_tensor.dtype) + + # It's faster to roll the batch by one instead of shuffling it to create image pairs + batch_rolled = image_tensor.roll(1, 0) + target_rolled = target_tensor.roll(1, 0) + + # Implemented as on mixup paper, page 3. 
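+        # lambda ~ Beta(alpha, alpha): sampling Dirichlet([alpha, alpha]) and keeping the
+        # first component is equivalent to sampling Beta(alpha, alpha). The batch and the
+        # one-hot targets are then mixed as convex combinations:
+        #   x = lambda * x_i + (1 - lambda) * x_j
+        #   y = lambda * y_i + (1 - lambda) * y_j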
+ lambda_param = float( + torch._sample_dirichlet(torch.tensor([self.alpha, self.alpha]))[0] + ) + batch_rolled.mul_(1.0 - lambda_param) + image_tensor.mul_(lambda_param).add_(batch_rolled) + + target_rolled.mul_(1.0 - lambda_param) + target_tensor.mul_(lambda_param).add_(target_rolled) + + data["samples"] = image_tensor + data["targets"] = target_tensor + + return data + + def __repr__(self) -> str: + return "{}(num_classes={}, p={}, alpha={}, inplace={})".format( + self.__class__.__name__, self.num_classes, self.p, self.alpha, self.inplace + ) + + +@register_transformations(name="random_cutmix", type="image_torch") +class RandomCutmix(BaseTransformation): + """ + Given a batch of input images and labels, this class randomly applies the + `CutMix transformation `_ + + Args: + num_classes (int): Number of classes in the dataset + """ + + def __init__(self, opts, num_classes: int, *args, **kwargs) -> None: + super().__init__(opts=opts, *args, **kwargs) + alpha = getattr(opts, "image_augmentation.cutmix.alpha", 1.0) + assert ( + num_classes > 0 + ), "Please provide a valid positive value for the num_classes." + assert alpha > 0, "Alpha param can't be zero." + + self.num_classes = num_classes + self.p = getattr(opts, "image_augmentation.cutmix.p", 0.5) + self.alpha = alpha + self.inplace = getattr(opts, "image_augmentation.cutmix.inplace", False) + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser) -> argparse.ArgumentParser: + group = parser.add_argument_group( + title="".format(cls.__name__), description="".format(cls.__name__) + ) + + group.add_argument( + "--image-augmentation.cutmix.enable", + action="store_true", + help="use {}. This flag is useful when you want to study the effect of different " + "transforms.".format(cls.__name__), + ) + + group.add_argument( + "--image-augmentation.cutmix.alpha", + type=float, + default=1.0, + help="Alpha for cutmix augmentation. Defaults to 1.0", + ) + group.add_argument( + "--image-augmentation.cutmix.p", + type=float, + default=0.5, + help="Probability for applying cutmix augmentation. Defaults to 0.5", + ) + group.add_argument( + "--image-augmentation.cutmix.inplace", + action="store_true", + default=False, + help="Apply cutmix operation inplace. Defaults to False", + ) + return parser + + def __call__(self, data: Dict) -> Dict: + if torch.rand(1).item() >= self.p: + return data + + image_tensor, target_tensor = data.pop("samples"), data.pop("targets") + + if image_tensor.ndim != 4: + logger.error(f"Batch ndim should be 4. Got {image_tensor.ndim}") + if target_tensor.ndim != 1: + logger.error(f"Target ndim should be 1. Got {target_tensor.ndim}") + if not image_tensor.is_floating_point(): + logger.error( + f"Batch dtype should be a float tensor. Got {image_tensor.dtype}." + ) + if target_tensor.dtype != torch.int64: + logger.error( + f"Target dtype should be torch.int64. Got {target_tensor.dtype}" + ) + + if not self.inplace: + image_tensor = image_tensor.clone() + target_tensor = target_tensor.clone() + + if target_tensor.ndim == 1: + target_tensor = F_torch.one_hot( + target_tensor, num_classes=self.num_classes + ).to(dtype=image_tensor.dtype) + + # It's faster to roll the batch by one instead of shuffling it to create image pairs + batch_rolled = image_tensor.roll(1, 0) + target_rolled = target_tensor.roll(1, 0) + + # Implemented as on cutmix paper, page 12 (with minor corrections on typos). 
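+        # lambda ~ Beta(alpha, alpha), via the same two-component Dirichlet trick as in mixup.
+        # A box of relative area (1 - lambda) is cut from the rolled batch and pasted onto the
+        # current batch; because the box is clamped to the image bounds, lambda is re-computed
+        # from the actual pasted area before mixing the targets:
+        #   y = lambda * y_i + (1 - lambda) * y_j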
+ lambda_param = float( + torch._sample_dirichlet(torch.tensor([self.alpha, self.alpha]))[0] + ) + W, H = F.get_image_size(image_tensor) + + r_x = torch.randint(W, (1,)) + r_y = torch.randint(H, (1,)) + + r = 0.5 * math.sqrt(1.0 - lambda_param) + r_w_half = int(r * W) + r_h_half = int(r * H) + + x1 = int(torch.clamp(r_x - r_w_half, min=0)) + y1 = int(torch.clamp(r_y - r_h_half, min=0)) + x2 = int(torch.clamp(r_x + r_w_half, max=W)) + y2 = int(torch.clamp(r_y + r_h_half, max=H)) + + image_tensor[:, :, y1:y2, x1:x2] = batch_rolled[:, :, y1:y2, x1:x2] + lambda_param = float(1.0 - (x2 - x1) * (y2 - y1) / (W * H)) + + target_rolled.mul_(1.0 - lambda_param) + target_tensor.mul_(lambda_param).add_(target_rolled) + + data["samples"] = image_tensor + data["targets"] = target_tensor + + return data + + def __repr__(self) -> str: + return "{}(num_classes={}, p={}, alpha={}, inplace={})".format( + self.__class__.__name__, self.num_classes, self.p, self.alpha, self.inplace + ) diff --git a/Adaptive Frequency Filters/data/transforms/utils.py b/Adaptive Frequency Filters/data/transforms/utils.py new file mode 100644 index 0000000..33cb517 --- /dev/null +++ b/Adaptive Frequency Filters/data/transforms/utils.py @@ -0,0 +1,47 @@ +# +# For licensing see accompanying LICENSE file. +# Copyright (C) 2023 Apple Inc. All Rights Reserved. +# + +from typing import Any +import numpy as np + + +def setup_size(size: Any, error_msg="Need a tuple of length 2"): + if size is None: + raise ValueError("Size can't be None") + + if isinstance(size, int): + return size, size + elif isinstance(size, (list, tuple)) and len(size) == 1: + return size[0], size[0] + + if len(size) != 2: + raise ValueError(error_msg) + + return size + + +def intersect(box_a, box_b): + """Computes the intersection between box_a and box_b""" + max_xy = np.minimum(box_a[:, 2:], box_b[2:]) + min_xy = np.maximum(box_a[:, :2], box_b[:2]) + inter = np.clip((max_xy - min_xy), a_min=0, a_max=np.inf) + return inter[:, 0] * inter[:, 1] + + +def jaccard_numpy(box_a: np.ndarray, box_b: np.ndarray): + """ + Computes the intersection of two boxes. + Args: + box_a (np.ndarray): Boxes of shape [Num_boxes_A, 4] + box_b (np.ndarray): Box osf shape [Num_boxes_B, 4] + + Returns: + intersection over union scores. Shape is [box_a.shape[0], box_a.shape[1]] + """ + inter = intersect(box_a, box_b) + area_a = (box_a[:, 2] - box_a[:, 0]) * (box_a[:, 3] - box_a[:, 1]) # [A,B] + area_b = (box_b[2] - box_b[0]) * (box_b[3] - box_b[1]) # [A,B] + union = area_a + area_b - inter + return inter / union # [A,B] diff --git a/Adaptive Frequency Filters/data/transforms/video.py b/Adaptive Frequency Filters/data/transforms/video.py new file mode 100644 index 0000000..dfe791c --- /dev/null +++ b/Adaptive Frequency Filters/data/transforms/video.py @@ -0,0 +1,608 @@ +# +# For licensing see accompanying LICENSE file. +# Copyright (C) 2023 Apple Inc. All Rights Reserved. +# + +import random +import torch +import math +import argparse +from typing import Sequence, Dict, Any, Union, Tuple, List, Optional +from torch.nn import functional as F + +from utils import logger + +from . 
import register_transformations, BaseTransformation +from .utils import * + + +SUPPORTED_PYTORCH_INTERPOLATIONS = ["nearest", "bilinear", "bicubic"] + + +def _check_interpolation(interpolation): + if interpolation not in SUPPORTED_PYTORCH_INTERPOLATIONS: + inter_str = "Supported interpolation modes are:" + for i, j in enumerate(SUPPORTED_PYTORCH_INTERPOLATIONS): + inter_str += "\n\t{}: {}".format(i, j) + logger.error(inter_str) + return interpolation + + +def _crop_fn(data: Dict, i: int, j: int, h: int, w: int): + img = data["image"] + if not isinstance(img, torch.Tensor) and img.dim() != 4: + logger.error( + "Cropping requires 4-d tensor of shape NCHW or CNHW. Got {}-dimensional tensor".format( + img.dim() + ) + ) + + crop_image = img[..., i : i + h, j : j + w] + data["image"] = crop_image + + mask = data.get("mask", None) + if mask is not None: + crop_mask = mask[..., i : i + h, j : j + w] + data["mask"] = crop_mask + return data + + +def _resize_fn( + data: Dict, size: Union[Sequence, int], interpolation: Optional[str] = "bilinear" +): + img = data["image"] + + if isinstance(size, Sequence) and len(size) == 2: + size_h, size_w = size[0], size[1] + elif isinstance(size, int): + h, w = img.shape[-2:] + if (w <= h and w == size) or (h <= w and h == size): + return data + + if w < h: + size_h = int(size * h / w) + + size_w = size + else: + size_w = int(size * w / h) + size_h = size + else: + raise TypeError( + "Supported size args are int or tuple of length 2. Got inappropriate size arg: {}".format( + size + ) + ) + if isinstance(interpolation, str): + interpolation = _check_interpolation(interpolation) + img = F.interpolate( + input=img, + size=(size_w, size_h), + mode=interpolation, + align_corners=True if interpolation != "nearest" else None, + ) + data["image"] = img + + mask = data.get("mask", None) + if mask is not None: + mask = F.interpolate(input=mask, size=(size_w, size_h), mode="nearest") + data["mask"] = mask + + return data + + +def _check_rgb_video_tensor(clip): + if not isinstance(clip, torch.FloatTensor) or clip.dim() != 4: + logger.error( + "Video clip is either not an instance of FloatTensor or it is not a 4-d tensor (NCHW or CNHW)" + ) + + +@register_transformations(name="to_tensor", type="video") +class ToTensor(BaseTransformation): + """ + This method converts an image into a tensor. + + .. note:: + We do not perform any mean-std normalization. If mean-std normalization is desired, please modify this class. + """ + + def __init__(self, opts, *args, **kwargs) -> None: + super().__init__(opts=opts) + + def __call__(self, data: Dict) -> Dict: + # [C, N, H, W] + clip = data["image"] + if not isinstance(clip, torch.Tensor): + clip = torch.from_numpy(clip) + clip = clip.float() + + _check_rgb_video_tensor(clip=clip) + + # normalize between 0 and 1 + clip = torch.div(clip, 255.0) + data["image"] = clip + return data + + +@register_transformations(name="random_resized_crop", type="video") +class RandomResizedCrop(BaseTransformation): + """ + This class crops a random portion of an image and resize it to a given size. 
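+    The crop area is sampled as a uniform fraction of the input area (scale, defaults to
+    (0.08, 1.0)) with a log-uniform aspect ratio (defaults to (3/4, 4/3)); if no valid crop is
+    found within 10 attempts, a central crop is used as a fallback.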
+ """ + + def __init__(self, opts, size: Union[Tuple, int], *args, **kwargs) -> None: + interpolation = getattr( + opts, "video_augmentation.random_resized_crop.interpolation", "bilinear" + ) + scale = getattr( + opts, "video_augmentation.random_resized_crop.scale", (0.08, 1.0) + ) + ratio = getattr( + opts, + "video_augmentation.random_resized_crop.aspect_ratio", + (3.0 / 4.0, 4.0 / 3.0), + ) + + if not isinstance(scale, Sequence) or ( + isinstance(scale, Sequence) + and len(scale) != 2 + and 0.0 <= scale[0] < scale[1] + ): + logger.error( + "--video-augmentation.random-resized-crop.scale should be a tuple of length 2 " + "such that 0.0 <= scale[0] < scale[1]. Got: {}".format(scale) + ) + + if not isinstance(ratio, Sequence) or ( + isinstance(ratio, Sequence) + and len(ratio) != 2 + and 0.0 < ratio[0] < ratio[1] + ): + logger.error( + "--video-augmentation.random-resized-crop.aspect-ratio should be a tuple of length 2 " + "such that 0.0 < ratio[0] < ratio[1]. Got: {}".format(ratio) + ) + + ratio = (round(ratio[0], 3), round(ratio[1], 3)) + + super().__init__(opts=opts) + + self.scale = scale + self.size = setup_size(size=size) + + self.interpolation = _check_interpolation(interpolation) + self.ratio = ratio + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser) -> argparse.ArgumentParser: + group = parser.add_argument_group( + title="".format(cls.__name__), description="".format(cls.__name__) + ) + + group.add_argument( + "--video-augmentation.random-resized-crop.enable", + action="store_true", + help="use {}. This flag is useful when you want to study the effect of different " + "transforms.".format(cls.__name__), + ) + group.add_argument( + "--video-augmentation.random-resized-crop.interpolation", + type=str, + default="bilinear", + choices=SUPPORTED_PYTORCH_INTERPOLATIONS, + help="Desired interpolation method. Defaults to bilinear", + ) + group.add_argument( + "--video-augmentation.random-resized-crop.scale", + type=tuple, + default=(0.08, 1.0), + help="Specifies the lower and upper bounds for the random area of the crop, before resizing." + " The scale is defined with respect to the area of the original image. Defaults to " + "(0.08, 1.0)", + ) + group.add_argument( + "--video-augmentation.random-resized-crop.aspect-ratio", + type=float or tuple, + default=(3.0 / 4.0, 4.0 / 3.0), + help="lower and upper bounds for the random aspect ratio of the crop, before resizing. 
" + "Defaults to (3./4., 4./3.)", + ) + return parser + + def get_params(self, height: int, width: int) -> (int, int, int, int): + area = height * width + for _ in range(10): + target_area = random.uniform(*self.scale) * area + log_ratio = (math.log(self.ratio[0]), math.log(self.ratio[1])) + aspect_ratio = math.exp(random.uniform(*log_ratio)) + + w = int(round(math.sqrt(target_area * aspect_ratio))) + h = int(round(math.sqrt(target_area / aspect_ratio))) + + if 0 < w <= width and 0 < h <= height: + i = random.randint(0, height - h) + j = random.randint(0, width - w) + return i, j, h, w + + # Fallback to central crop + in_ratio = (1.0 * width) / height + if in_ratio < min(self.ratio): + w = width + h = int(round(w / min(self.ratio))) + elif in_ratio > max(self.ratio): + h = height + w = int(round(h * max(self.ratio))) + else: # whole image + w = width + h = height + i = (height - h) // 2 + j = (width - w) // 2 + return i, j, h, w + + def __call__(self, data: Dict) -> Dict: + clip = data["image"] + _check_rgb_video_tensor(clip=clip) + + height, width = clip.shape[-2:] + + i, j, h, w = self.get_params(height=height, width=width) + data = _crop_fn(data=data, i=i, j=j, h=h, w=w) + return _resize_fn(data=data, size=self.size, interpolation=self.interpolation) + + def __repr__(self) -> str: + return "{}(scale={}, ratio={}, interpolation={})".format( + self.__class__.__name__, self.scale, self.ratio, self.interpolation + ) + + +@register_transformations(name="random_short_side_resize_crop", type="video") +class RandomShortSizeResizeCrop(BaseTransformation): + """ + This class first randomly resizes the input video such that shortest side is between specified minimum and + maximum values, adn then crops a desired size video. + + .. note:: + This class assumes that the video size after resizing is greater than or equal to the desired size. + """ + + def __init__(self, opts, size: Union[Tuple, int], *args, **kwargs) -> None: + interpolation = getattr( + opts, + "video_augmentation.random_short_side_resize_crop.interpolation", + "bilinear", + ) + short_size_min = getattr( + opts, + "video_augmentation.random_short_side_resize_crop.short_side_min", + None, + ) + short_size_max = getattr( + opts, + "video_augmentation.random_short_side_resize_crop.short_side_max", + None, + ) + + if short_size_min is None: + logger.error( + "Short side minimum value can't be None in {}".format( + self.__class__.__name__ + ) + ) + if short_size_max is None: + logger.error( + "Short side maximum value can't be None in {}".format( + self.__class__.__name__ + ) + ) + + if short_size_max <= short_size_min: + logger.error( + "Short side maximum value should be >= short side minimum value in {}. Got: {} and {}".format( + self.__class__.__name__, short_size_max, short_size_min + ) + ) + + super().__init__(opts=opts) + self.short_side_min = short_size_min + self.size = size + self.short_side_max = short_size_max + self.interpolation = _check_interpolation(interpolation) + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser) -> argparse.ArgumentParser: + group = parser.add_argument_group( + title="".format(cls.__name__), description="".format(cls.__name__) + ) + + group.add_argument( + "--video-augmentation.random-short-side-resize-crop.enable", + action="store_true", + help="use {}. 
+            "transforms.".format(cls.__name__),
+        )
+        group.add_argument(
+            "--video-augmentation.random-short-side-resize-crop.interpolation",
+            type=str,
+            default="bilinear",
+            choices=SUPPORTED_PYTORCH_INTERPOLATIONS,
+            help="Desired interpolation method. Defaults to bilinear",
+        )
+        group.add_argument(
+            "--video-augmentation.random-short-side-resize-crop.short-side-min",
+            type=int,
+            default=None,
+            help="Minimum value for video's shortest side. Defaults to None.",
+        )
+        group.add_argument(
+            "--video-augmentation.random-short-side-resize-crop.short-side-max",
+            type=int,
+            default=None,
+            help="Maximum value for video's shortest side. Defaults to None.",
+        )
+        return parser
+
+    def get_params(self, height, width) -> Tuple[int, int, int, int]:
+        th, tw = self.size
+
+        if width == tw and height == th:
+            return 0, 0, height, width
+
+        i = random.randint(0, height - th)
+        j = random.randint(0, width - tw)
+        return i, j, th, tw
+
+    def __call__(self, data: Dict) -> Dict:
+        # randomly sample the target length of the shorter side
+        short_dim = random.randint(self.short_side_min, self.short_side_max)
+        # resize the video so that the shorter side is short_dim
+        data = _resize_fn(data, size=short_dim, interpolation=self.interpolation)
+
+        clip = data["image"]
+        _check_rgb_video_tensor(clip=clip)
+        height, width = clip.shape[-2:]
+        i, j, h, w = self.get_params(height=height, width=width)
+        # crop the video
+        return _crop_fn(data=data, i=i, j=j, h=h, w=w)
+
+    def __repr__(self) -> str:
+        return "{}(size={}, short_size_range=({}, {}), interpolation={})".format(
+            self.__class__.__name__,
+            self.size,
+            self.short_side_min,
+            self.short_side_max,
+            self.interpolation,
+        )
+
+
+@register_transformations(name="random_crop", type="video")
+class RandomCrop(BaseTransformation):
+    """
+    This class randomly crops a region of the video.
+
+    .. note::
+        This class assumes that the input video size is greater than or equal to the desired size.
+    """
+
+    def __init__(self, opts, size: Union[Tuple, int], *args, **kwargs) -> None:
+        size = setup_size(size=size)
+        super().__init__(opts=opts)
+        self.size = size
+
+    @classmethod
+    def add_arguments(cls, parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
+        group = parser.add_argument_group(
+            title="".format(cls.__name__), description="".format(cls.__name__)
+        )
+
+        group.add_argument(
+            "--video-augmentation.random-crop.enable",
+            action="store_true",
+            help="use {}. This flag is useful when you want to study the effect of different "
+            "transforms.".format(cls.__name__),
+        )
+        return parser
+
+    def get_params(self, height, width) -> Tuple[int, int, int, int]:
+        th, tw = self.size
+
+        if width == tw and height == th:
+            return 0, 0, height, width
+
+        i = random.randint(0, height - th)
+        j = random.randint(0, width - tw)
+        return i, j, th, tw
+
+    def __call__(self, data: Dict) -> Dict:
+        clip = data["image"]
+        _check_rgb_video_tensor(clip=clip)
+        height, width = clip.shape[-2:]
+        i, j, h, w = self.get_params(height=height, width=width)
+        return _crop_fn(data=data, i=i, j=j, h=h, w=w)
+
+    def __repr__(self) -> str:
+        return "{}(size={})".format(self.__class__.__name__, self.size)
+
+
+@register_transformations(name="random_horizontal_flip", type="video")
+class RandomHorizontalFlip(BaseTransformation):
+    """
+    This class implements the random horizontal flipping method.
+    """
+
+    def __init__(self, opts, *args, **kwargs) -> None:
+        p = getattr(opts, "video_augmentation.random_horizontal_flip.p", 0.5)
+        super().__init__(opts=opts)
+        self.p = p
+
+    @classmethod
+    def add_arguments(cls, parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
+        group = parser.add_argument_group(
+            title="".format(cls.__name__), description="".format(cls.__name__)
+        )
+        group.add_argument(
+            "--video-augmentation.random-horizontal-flip.enable",
+            action="store_true",
+            help="use {}. This flag is useful when you want to study the effect of different "
+            "transforms.".format(cls.__name__),
+        )
+        group.add_argument(
+            "--video-augmentation.random-horizontal-flip.p",
+            type=float,
+            default=0.5,
+            help="Probability for random horizontal flip",
+        )
+        return parser
+
+    def __call__(self, data: Dict) -> Dict:
+        if random.random() <= self.p:
+            clip = data["image"]
+            _check_rgb_video_tensor(clip=clip)
+            clip = torch.flip(clip, dims=[-1])
+            data["image"] = clip
+
+            mask = data.get("mask", None)
+            if mask is not None:
+                mask = torch.flip(mask, dims=[-1])
+                data["mask"] = mask
+
+        return data
+
+
+@register_transformations(name="center_crop", type="video")
+class CenterCrop(BaseTransformation):
+    """
+    This class implements the center cropping method.
+
+    .. note::
+        This class assumes that the input size is greater than or equal to the desired size.
+ """ + + def __init__(self, opts, size: Sequence or int, *args, **kwargs) -> None: + super().__init__(opts=opts) + if isinstance(size, Sequence) and len(size) == 2: + self.height, self.width = size[0], size[1] + elif isinstance(size, Sequence) and len(size) == 1: + self.height = self.width = size[0] + elif isinstance(size, int): + self.height = self.width = size + else: + logger.error("Scale should be either an int or tuple of ints") + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser) -> argparse.ArgumentParser: + group = parser.add_argument_group( + title="".format(cls.__name__), description="".format(cls.__name__) + ) + + group.add_argument( + "--video-augmentation.center-crop.enable", + action="store_true", + help="use center cropping", + ) + return parser + + def __call__(self, data: Dict) -> Dict: + height, width = data["image"].shape[-2:] + i = (height - self.height) // 2 + j = (width - self.width) // 2 + return _crop_fn(data=data, i=i, j=j, h=self.height, w=self.width) + + def __repr__(self) -> str: + return "{}(size=(h={}, w={}))".format( + self.__class__.__name__, self.height, self.width + ) + + +@register_transformations(name="resize", type="video") +class Resize(BaseTransformation): + """ + This class implements resizing operation. + + .. note:: + Two possible modes for resizing. + 1. Resize while maintaining aspect ratio. To enable this option, pass int as a size + 2. Resize to a fixed size. To enable this option, pass a tuple of height and width as a size + """ + + def __init__(self, opts, *args, **kwargs) -> None: + size = getattr(opts, "video_augmentation.resize.size", None) + if size is None: + logger.error("Size can not be None in {}".format(self.__class__.__name__)) + + # Possible modes. + # 1. Resize while maintaining aspect ratio. To enable this option, pass int as a size + # 2. Resize to a fixed size. To enable this option, pass a tuple of height and width as a size + + if isinstance(size, Sequence) and len(size) > 2: + logger.error( + "The length of size should be either 1 or 2 in {}".format( + self.__class__.__name__ + ) + ) + + interpolation = getattr( + opts, "video_augmentation.resize.interpolation", "bilinear" + ) + super().__init__(opts=opts) + + self.size = size + self.interpolation = _check_interpolation(interpolation) + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser) -> argparse.ArgumentParser: + group = parser.add_argument_group( + title="".format(cls.__name__), description="".format(cls.__name__) + ) + + group.add_argument( + "--video-augmentation.resize.enable", + action="store_true", + help="use fixed resizing", + ) + + group.add_argument( + "--video-augmentation.resize.interpolation", + type=str, + default="bilinear", + choices=SUPPORTED_PYTORCH_INTERPOLATIONS, + help="Interpolation for resizing. Default is bilinear", + ) + group.add_argument( + "--video-augmentation.resize.size", + type=int, + nargs="+", + default=None, + help="Resize video to the specified size. If int is passed, then shorter side is resized" + "to the specified size and longest side is resized while maintaining aspect ratio." 
+ "Defaults to None.", + ) + return parser + + def __call__(self, data: Dict) -> Dict: + return _resize_fn(data=data, size=self.size, interpolation=self.interpolation) + + def __repr__(self): + return "{}(size={}, interpolation={})".format( + self.__class__.__name__, self.size, self.interpolation + ) + + +@register_transformations(name="compose", type="video") +class Compose(BaseTransformation): + """ + This method applies a list of transforms in a sequential fashion. + """ + + def __init__(self, opts, video_transforms: List, *args, **kwargs) -> None: + super().__init__(opts=opts) + self.video_transforms = video_transforms + + def __call__(self, data: Dict) -> Dict: + for t in self.video_transforms: + data = t(data) + return data + + def __repr__(self) -> str: + transform_str = ", ".join("\n\t\t\t" + str(t) for t in self.video_transforms) + repr_str = "{}({})".format(self.__class__.__name__, transform_str) + return repr_str